From 0982f77ec70bfa69cd88767204fac55e677e4bc2 Mon Sep 17 00:00:00 2001 From: Aaron Carlino Date: Fri, 3 Feb 2017 09:03:19 +1300 Subject: [PATCH] Feature/aggregate data filters (#6553) --- src/ORM/Filters/ComparisonFilter.php | 16 +-- src/ORM/Filters/ExactMatchFilter.php | 13 ++- src/ORM/Filters/PartialMatchFilter.php | 27 ++++- src/ORM/Filters/SearchFilter.php | 90 +++++++++++++++- tests/php/ORM/DataListTest.php | 141 +++++++++++++++++++++++++ 5 files changed, 274 insertions(+), 13 deletions(-) diff --git a/src/ORM/Filters/ComparisonFilter.php b/src/ORM/Filters/ComparisonFilter.php index 2b4c08bbe..77a25a2e6 100755 --- a/src/ORM/Filters/ComparisonFilter.php +++ b/src/ORM/Filters/ComparisonFilter.php @@ -44,9 +44,11 @@ abstract class ComparisonFilter extends SearchFilter $this->model = $query->applyRelation($this->relation); $predicate = sprintf("%s %s ?", $this->getDbName(), $this->getOperator()); - return $query->where(array( - $predicate => $this->getDbFormattedValue() - )); + $clause = [$predicate => $this->getDbFormattedValue()]; + + return $this->aggregate ? + $this->applyAggregate($query, $clause) : + $query->where($clause); } /** @@ -61,9 +63,11 @@ abstract class ComparisonFilter extends SearchFilter $this->model = $query->applyRelation($this->relation); $predicate = sprintf("%s %s ?", $this->getDbName(), $this->getInverseOperator()); - return $query->where(array( - $predicate => $this->getDbFormattedValue() - )); + $clause = [$predicate => $this->getDbFormattedValue()]; + + return $this->aggregate ? + $this->applyAggregate($query, $clause) : + $query->where($clause); } public function isEmpty() diff --git a/src/ORM/Filters/ExactMatchFilter.php b/src/ORM/Filters/ExactMatchFilter.php index c375a2adb..277d9896a 100644 --- a/src/ORM/Filters/ExactMatchFilter.php +++ b/src/ORM/Filters/ExactMatchFilter.php @@ -75,7 +75,12 @@ class ExactMatchFilter extends SearchFilter $nullClause = DB::get_conn()->nullCheckClause($field, true); $where .= " OR {$nullClause}"; } - return $query->where(array($where => $value)); + + $clause = [$where => $value]; + + return $this->aggregate ? + $this->applyAggregate($query, $clause) : + $query->where($clause); } /** @@ -189,7 +194,11 @@ class ExactMatchFilter extends SearchFilter } } - return $query->where(array($predicate => $values)); + $clause = [$predicate => $values]; + + return $this->aggregate ? + $this->applyAggregate($query, $clause) : + $query->where($clause); } public function isEmpty() diff --git a/src/ORM/Filters/PartialMatchFilter.php b/src/ORM/Filters/PartialMatchFilter.php index 9e5d9f6e5..a158dd1bf 100644 --- a/src/ORM/Filters/PartialMatchFilter.php +++ b/src/ORM/Filters/PartialMatchFilter.php @@ -28,6 +28,24 @@ class PartialMatchFilter extends SearchFilter return "%$value%"; } + /** + * Apply filter criteria to a SQL query. + * + * @param DataQuery $query + * @return DataQuery + */ + public function apply(DataQuery $query) + { + if ($this->aggregate) { + throw new InvalidArgumentException(sprintf( + 'Aggregate functions can only be used with comparison filters. See %s', + $this->fullName + )); + } + + return parent::apply($query); + } + protected function applyOne(DataQuery $query) { $this->model = $query->applyRelation($this->relation); @@ -39,9 +57,12 @@ class PartialMatchFilter extends SearchFilter $this->getCaseSensitive(), true ); - return $query->where(array( - $comparisonClause => $this->getMatchPattern($this->getValue()) - )); + + $clause = [$comparisonClause => $this->getMatchPattern($this->getValue())]; + + return $this->aggregate ? + $this->applyAggregate($query, $clause) : + $query->where($clause); } protected function applyMany(DataQuery $query) diff --git a/src/ORM/Filters/SearchFilter.php b/src/ORM/Filters/SearchFilter.php index 099226ef5..7e1f34baf 100644 --- a/src/ORM/Filters/SearchFilter.php +++ b/src/ORM/Filters/SearchFilter.php @@ -58,6 +58,17 @@ abstract class SearchFilter extends Object */ protected $relation; + /** + * An array of data about an aggregate column being used + * ex: + * [ + * 'function' => 'COUNT', + * 'column' => 'ID' + * ] + * @var array + */ + protected $aggregate; + /** * @param string $fullName Determines the name of the field, as well as the searched database * column. Can contain a relation name in dot notation, which will automatically join @@ -73,6 +84,7 @@ abstract class SearchFilter extends Object // sets $this->name and $this->relation $this->addRelation($fullName); + $this->addAggregate($fullName); $this->value = $value; $this->setModifiers($modifiers); } @@ -94,6 +106,33 @@ abstract class SearchFilter extends Object } } + /** + * Parses the name for any aggregate functions and stores them in the $aggregate array + * + * @param string $name + */ + protected function addAggregate($name) + { + if (!$this->relation) { + return; + } + + if (!preg_match('/([A-Za-z]+)\(\s*(?:([A-Za-z_*][A-Za-z0-9_]*))?\s*\)$/', $name, $matches)) { + if (stristr($name, '(') !== false) { + throw new InvalidArgumentException(sprintf( + 'Malformed aggregate filter %s', + $name + )); + } + return; + } + + $this->aggregate = [ + 'function' => strtoupper($matches[1]), + 'column' => isset($matches[2]) ? $matches[2] : null + ]; + } + /** * Set the root model class to be selected by this * search query. @@ -217,14 +256,40 @@ abstract class SearchFilter extends Object } // Ensure that we're dealing with a DataObject. - if (!is_subclass_of($this->model, 'SilverStripe\\ORM\\DataObject')) { + if (!is_subclass_of($this->model, DataObject::class)) { throw new InvalidArgumentException( "Model supplied to " . get_class($this) . " should be an instance of DataObject." ); } + $schema = DataObject::getSchema(); + + if ($this->aggregate) { + $column = $this->aggregate['column']; + $function = $this->aggregate['function']; + + $table = $column ? + $schema->tableForField($this->model, $column) : + $schema->baseDataTable($this->model); + + if (!$table) { + throw new InvalidArgumentException(sprintf( + 'Invalid column %s for aggregate function %s on %s', + $column, + $function, + $this->model + )); + } + return sprintf( + '%s("%s".%s)', + $function, + $table, + $column ? "\"$column\"" : '"ID"' + ); + } + // Find table this field belongs to - $table = DataObject::getSchema()->tableForField($this->model, $this->name); + $table = $schema->tableForField($this->model, $this->name); if (!$table) { // fallback to the provided name in the event of a joined column // name (as the candidate class doesn't check joined records) @@ -244,12 +309,33 @@ abstract class SearchFilter extends Object { // SRM: This code finds the table where the field named $this->name lives // Todo: move to somewhere more appropriate, such as DataMapper, the magical class-to-be? + + if ($this->aggregate) { + return intval($this->value); + } + /** @var DBField $dbField */ $dbField = singleton($this->model)->dbObject($this->name); $dbField->setValue($this->value); return $dbField->RAW(); } + /** + * Given an escaped HAVING clause, add it along with the appropriate GROUP BY clause + * @param DataQuery $query + * @param string $having + * @return DataQuery + */ + public function applyAggregate(DataQuery $query, $having) + { + $schema = DataObject::getSchema(); + $baseTable = $schema->baseDataTable($query->dataClass()); + + return $query + ->having($having) + ->groupby("\"{$baseTable}\".\"ID\""); + } + /** * Apply filter criteria to a SQL query. * diff --git a/tests/php/ORM/DataListTest.php b/tests/php/ORM/DataListTest.php index f426d159a..e827a8194 100755 --- a/tests/php/ORM/DataListTest.php +++ b/tests/php/ORM/DataListTest.php @@ -6,6 +6,7 @@ use SilverStripe\Core\Convert; use SilverStripe\ORM\DataList; use SilverStripe\ORM\DB; use SilverStripe\ORM\Filterable; +use SilverStripe\ORM\Filters\ExactMatchFilter; use SilverStripe\Dev\SapphireTest; use SilverStripe\ORM\Tests\DataObjectTest\EquipmentCompany; use SilverStripe\ORM\Tests\DataObjectTest\Fan; @@ -15,6 +16,7 @@ use SilverStripe\ORM\Tests\DataObjectTest\SubTeam; use SilverStripe\ORM\Tests\DataObjectTest\Team; use SilverStripe\ORM\Tests\DataObjectTest\TeamComment; use SilverStripe\ORM\Tests\DataObjectTest\ValidatedObject; +use SilverStripe\ORM\Tests\DataObjectTest\Staff; class DataListTest extends SapphireTest { @@ -30,6 +32,7 @@ class DataListTest extends SapphireTest ); } + public function testFilterDataObjectByCreatedDate() { // create an object to test with @@ -1220,6 +1223,144 @@ class DataListTest extends SapphireTest $this->assertSQLNotContains('"DataObjectTest_Fan"."Email" IS NOT NULL', $items9->sql()); } + public function testAggregateDBName() + { + $filter = new ExactMatchFilter( + 'Comments.Count()' + ); + $filter->setModel(new DataObjectTest\Team()); + $this->assertEquals('COUNT("DataObjectTest_Team"."ID")', $filter->getDBName()); + + foreach (['Comments.Max(ID)', 'Comments.Max( ID )', 'Comments.Max( ID)'] as $name) { + $filter = new ExactMatchFilter($name); + $filter->setModel(new DataObjectTest\Team()); + $this->assertEquals('MAX("DataObjectTest_Team"."ID")', $filter->getDBName()); + } + } + + public function testAggregateFilterExceptions() + { + $ex = null; + try { + $filter = new ExactMatchFilter('Comments.Max( This will not parse! )'); + } catch (\Exception $e) { + $ex = $e; + } + $this->assertInstanceOf(\InvalidArgumentException::class, $ex); + $this->assertRegExp('/Malformed/', $ex->getMessage()); + + + $filter = new ExactMatchFilter('Comments.Max(NonExistentColumn)'); + $filter->setModel(new DataObjectTest\Team()); + $ex = null; + try { + $name = $filter->getDBName(); + } catch (\Exception $e) { + $ex = $e; + } + $this->assertInstanceOf(\InvalidArgumentException::class, $ex); + $this->assertRegExp('/Invalid column/', $ex->getMessage()); + } + + public function testAggregateFilters() + { + $teams = Team::get()->filter('Comments.Count()', 2); + + $team1 = $this->objFromFixture(Team::class, 'team1'); + $team2 = $this->objFromFixture(Team::class, 'team2'); + $team3 = $this->objFromFixture(Team::class, 'team3'); + $team4 = $this->objFromFixture(SubTeam::class, 'subteam1'); + $team5 = $this->objFromFixture(SubTeam::class, 'subteam2_with_player_relation'); + $team6 = $this->objFromFixture(SubTeam::class, 'subteam3_with_empty_fields'); + + $company1 = $this->objFromFixture(EquipmentCompany::class, 'equipmentcompany1'); + $company2 = $this->objFromFixture(EquipmentCompany::class, 'equipmentcompany2'); + + $company1->CurrentStaff()->add(Staff::create(['Salary' => 3])->write()); + $company1->CurrentStaff()->add(Staff::create(['Salary' => 5])->write()); + $company2->CurrentStaff()->add(Staff::create(['Salary' => 4])->write()); + + $this->assertCount(1, $teams); + $this->assertEquals($team1->ID, $teams->first()->ID); + + $teams = Team::get()->filter('Comments.Count()', [1,2]); + + $this->assertCount(2, $teams); + foreach ([$team1, $team2] as $expectedTeam) { + $this->assertContains($expectedTeam->ID, $teams->column('ID')); + } + + $teams = Team::get()->filter('Comments.Count():GreaterThan', 1); + + $this->assertCount(1, $teams); + $this->assertContains( + $this->objFromFixture(Team::class, 'team1')->ID, + $teams->column('ID') + ); + + $teams = Team::get()->filter('Comments.Count():LessThan', 2); + + $this->assertCount(5, $teams); + foreach ([$team2, $team3, $team4, $team5, $team6] as $expectedTeam) { + $this->assertContains($expectedTeam->ID, $teams->column('ID')); + } + + $teams = Team::get()->filter('Comments.Count():GreaterThanOrEqual', 1); + + $this->assertCount(2, $teams); + foreach ([$team1, $team2] as $expectedTeam) { + $this->assertContains($expectedTeam->ID, $teams->column('ID')); + } + + $teams = Team::get()->filter('Comments.Count():LessThanOrEqual', 1); + + $this->assertCount(5, $teams); + foreach ([$team2, $team3, $team4, $team5, $team6] as $expectedTeam) { + $this->assertContains($expectedTeam->ID, $teams->column('ID')); + } + + $companies = EquipmentCompany::get()->filter('CurrentStaff.Max(Salary)', 5); + $this->assertCount(1, $companies); + $this->assertEquals($company1->ID, $companies->first()->ID); + + $companies = EquipmentCompany::get()->filter('CurrentStaff.Min(Salary)', 3); + $this->assertCount(1, $companies); + $this->assertEquals($company1->ID, $companies->first()->ID); + + $companies = EquipmentCompany::get()->filter('CurrentStaff.Max(Salary):GreaterThan', 3); + $this->assertCount(2, $companies); + foreach ([$company1, $company2] as $expectedTeam) { + $this->assertContains($expectedTeam->ID, $companies->column('ID')); + } + + $companies = EquipmentCompany::get()->filter('CurrentStaff.Sum(Salary)', 8); + $this->assertCount(1, $companies); + $this->assertEquals($company1->ID, $companies->first()->ID); + + $companies = EquipmentCompany::get()->filter('CurrentStaff.Sum(Salary):LessThan', 7); + $this->assertCount(1, $companies); + $this->assertEquals($company2->ID, $companies->first()->ID); + + $companies = EquipmentCompany::get()->filter('CurrentStaff.Sum(Salary):GreaterThan', 100); + $this->assertCount(0, $companies); + + $companies = EquipmentCompany::get()->filter('CurrentStaff.Sum(Salary):GreaterThan', 7); + $this->assertCount(1, $companies); + $this->assertEquals($company1->ID, $companies->first()->ID); + + $companies = EquipmentCompany::get()->filter('CurrentStaff.Avg(Salary)', 4); + $this->assertCount(2, $companies); + foreach ([$company1, $company2] as $expectedTeam) { + $this->assertContains($expectedTeam->ID, $companies->column('ID')); + } + + $companies = EquipmentCompany::get()->filter('CurrentStaff.Avg(Salary):LessThan', 10); + $this->assertCount(2, $companies); + foreach ([$company1, $company2] as $expectedTeam) { + $this->assertContains($expectedTeam->ID, $companies->column('ID')); + } + } + /** * $list = $list->filterByCallback(function($item, $list) { return $item->Age == 21; }) */