diff --git a/src/ORM/DataObject.php b/src/ORM/DataObject.php index 8a8069bf1..d5e61b9d9 100644 --- a/src/ORM/DataObject.php +++ b/src/ORM/DataObject.php @@ -3039,29 +3039,33 @@ class DataObject extends ViewableData implements DataObjectInterface, i18nEntity $limit = null, $containerClass = DataList::class ) { - + // Validate arguments if ($callerClass == null) { $callerClass = get_called_class(); - if ($callerClass == self::class) { - throw new \InvalidArgumentException('Call ::get() instead of DataObject::get()'); + if ($callerClass === self::class) { + throw new InvalidArgumentException('Call ::get() instead of DataObject::get()'); } - - if ($filter || $sort || $join || $limit || ($containerClass != DataList::class)) { - throw new \InvalidArgumentException('If calling ::get() then you shouldn\'t pass any other' + if ($filter || $sort || $join || $limit || ($containerClass !== DataList::class)) { + throw new InvalidArgumentException('If calling ::get() then you shouldn\'t pass any other' . ' arguments'); } - - return DataList::create(get_called_class()); + } elseif ($callerClass === self::class) { + throw new InvalidArgumentException('DataObject::get() cannot query non-subclass DataObject directly'); } - if ($join) { - throw new \InvalidArgumentException( + throw new InvalidArgumentException( 'The $join argument has been removed. Use leftJoin($table, $joinClause) instead.' ); } - $result = DataList::create($callerClass)->where($filter)->sort($sort); - + // Build and decorate with args + $result = DataList::create($callerClass); + if ($filter) { + $result = $result->where($filter); + } + if ($sort) { + $result = $result->sort($sort); + } if ($limit && strpos($limit, ',') !== false) { $limitArguments = explode(',', $limit); $result = $result->limit($limitArguments[1], $limitArguments[0]); @@ -3173,23 +3177,32 @@ class DataObject extends ViewableData implements DataObjectInterface, i18nEntity } /** - * Return the given element, searching by ID + * Return the given element, searching by ID. * - * @param string $callerClass The class of the object to be returned - * @param int $id The id of the element + * This can be called either via `DataObject::get_by_id(MyClass::class, $id)` + * or `MyClass::get_by_id($id)` + * + * @param string|int $classOrID The class of the object to be returned, or id if called on target class + * @param int|bool $idOrCache The id of the element, or cache if called on target class * @param boolean $cache See {@link get_one()} * - * @return DataObject The element + * @return static The element */ - public static function get_by_id($callerClass, $id, $cache = true) + public static function get_by_id($classOrID, $idOrCache = null, $cache = true) { - if (!is_numeric($id)) { - user_error("DataObject::get_by_id passed a non-numeric ID #$id", E_USER_WARNING); + // Shift arguments if passing id in first or second argument + list ($class, $id, $cached) = is_numeric($classOrID) + ? [get_called_class(), $classOrID, isset($idOrCache) ? $idOrCache : $cache] + : [$classOrID, $idOrCache, $cache]; + + // Validate class + if ($class === self::class) { + throw new InvalidArgumentException('DataObject::get_by_id() cannot query non-subclass DataObject directly'); } // Pass to get_one - $column = static::getSchema()->sqlColumnForField($callerClass, 'ID'); - return DataObject::get_one($callerClass, array($column => $id), $cache); + $column = static::getSchema()->sqlColumnForField($class, 'ID'); + return DataObject::get_one($class, [$column => $id], $cached); } /** diff --git a/tests/php/ORM/DataObjectTest.php b/tests/php/ORM/DataObjectTest.php index 606075f5b..7a9422abf 100644 --- a/tests/php/ORM/DataObjectTest.php +++ b/tests/php/ORM/DataObjectTest.php @@ -324,6 +324,20 @@ class DataObjectTest extends SapphireTest $this->assertEquals('Phil', $comment->Name); } + public function testGetByIDCallerClass() + { + $captain1ID = $this->idFromFixture(DataObjectTest\Player::class, 'captain1'); + $captain1 = DataObjectTest\Player::get_by_id($captain1ID); + $this->assertInstanceOf(DataObjectTest\Player::class, $captain1); + $this->assertEquals('Captain', $captain1->FirstName); + + $captain2ID = $this->idFromFixture(DataObjectTest\Player::class, 'captain2'); + // make sure we can call from any class but get the one passed as an argument + $captain2 = DataObjectTest\TeamComment::get_by_id(DataObjectTest\Player::class, $captain2ID); + $this->assertInstanceOf(DataObjectTest\Player::class, $captain2); + $this->assertEquals('Captain 2', $captain2->FirstName); + } + public function testGetCaseInsensitive() { // Test get_one() with bad case on the classname