diff --git a/model/ManyManyList.php b/model/ManyManyList.php index f464c62ca..9680b1212 100644 --- a/model/ManyManyList.php +++ b/model/ManyManyList.php @@ -78,17 +78,29 @@ class ManyManyList extends RelationList { if(!$this->foreignID) { throw new Exception("ManyManyList::add() can't be called until a foreign ID is set", E_USER_WARNING); } - - // Delete old entries, to prevent duplication - $this->removeById($itemID); - // Insert new entry/entries + if($filter = $this->foreignIDFilter()) { + $query = new SQLQuery("*", array("\"$this->joinTable\"")); + $query->setWhere($filter); + $hasExisting = ($query->count() > 0); + } else { + $hasExisting = false; + } + + // Insert or update foreach((array)$this->foreignID as $foreignID) { $manipulation = array(); - $manipulation[$this->joinTable]['command'] = 'insert'; + if($hasExisting) { + $manipulation[$this->joinTable]['command'] = 'update'; + $manipulation[$this->joinTable]['where'] = "\"$this->joinTable\".\"$this->foreignKey\" = " . + "'" . Convert::raw2sql($foreignID) . "'" . + " AND \"$this->localKey\" = {$itemID}"; + } else { + $manipulation[$this->joinTable]['command'] = 'insert'; + } if($extraFields) foreach($extraFields as $k => $v) { - $manipulation[$this->joinTable]['fields'][$k] = "'" . Convert::raw2sql($v) . "'"; + $manipulation[$this->joinTable]['fields'][$k] = (is_null($v)) ? 'NULL' : "'" . Convert::raw2sql($v) . "'"; } $manipulation[$this->joinTable]['fields'][$this->localKey] = $itemID; @@ -150,30 +162,27 @@ class ManyManyList extends RelationList { * @todo Add tests for this / refactor it / something * * @param string $componentName The name of the component - * @param int $childID The ID of the child for the relationship + * @param int $itemID The ID of the child for the relationship * @return array Map of fieldName => fieldValue */ - public function getExtraData($componentName, $childID) { - $ownerObj = $this->ownerObj; - $parentField = $this->ownerClass . 'ID'; - $childField = ($this->childClass == $this->ownerClass) ? 'ChildID' : ($this->childClass . 'ID'); + function getExtraData($componentName, $itemID) { $result = array(); - if(!isset($componentName)) { - user_error('ComponentSet::getExtraData() passed a NULL component name', E_USER_ERROR); - } - - if(!is_numeric($childID)) { + if(!is_numeric($itemID)) { user_error('ComponentSet::getExtraData() passed a non-numeric child ID', E_USER_ERROR); } // @todo Optimize into a single query instead of one per extra field if($this->extraFields) { foreach($this->extraFields as $fieldName => $dbFieldSpec) { - $query = DB::query("SELECT \"$fieldName\" FROM \"$this->tableName\" " - . "WHERE \"$parentField\" = {$this->ownerObj->ID} AND \"$childField\" = {$childID}"); - $value = $query->value(); - $result[$fieldName] = $value; + $query = new SQLQuery($fieldName, array("\"$this->joinTable\"")); + if($filter = $this->foreignIDFilter()) { + $query->setWhere($filter); + } else { + user_error("Can't call ManyManyList::getExtraData() until a foreign ID is set", E_USER_WARNING); + } + $query->addWhere("\"$this->localKey\" = {$itemID}"); + $result[$fieldName] = $query->execute()->value(); } } diff --git a/tests/model/ManyManyListTest.php b/tests/model/ManyManyListTest.php index 4d04b72c3..1979e0577 100644 --- a/tests/model/ManyManyListTest.php +++ b/tests/model/ManyManyListTest.php @@ -98,6 +98,40 @@ class ManyManyListTest extends SapphireTest { $newPlayer->Teams()->sort('Title')->column('ID') ); } + + public function testAddingExistingDoesntRemoveExtraFields() { + $player = new DataObjectTest_Player(); + $player->write(); + $team1 = $this->objFromFixture('DataObjectTest_Team', 'team1'); + + $team1->Players()->add($player, array('Position' => 'Captain')); + $this->assertEquals( + array('Position' => 'Captain'), + $team1->Players()->getExtraData('Teams', $player->ID), + 'Writes extrafields' + ); + + $team1->Players()->add($player); + $this->assertEquals( + array('Position' => 'Captain'), + $team1->Players()->getExtraData('Teams', $player->ID), + 'Retains extrafields on subsequent adds with NULL fields' + ); + + $team1->Players()->add($player, array('Position' => 'Defense')); + $this->assertEquals( + array('Position' => 'Defense'), + $team1->Players()->getExtraData('Teams', $player->ID), + 'Updates extrafields on subsequent adds with fields' + ); + + $team1->Players()->add($player, array('Position' => null)); + $this->assertEquals( + array('Position' => null), + $team1->Players()->getExtraData('Teams', $player->ID), + 'Allows clearing of extrafields on subsequent adds' + ); + } public function testSubtractOnAManyManyList() { $allList = ManyManyList::create('DataObjectTest_Player', 'DataObjectTest_Team_Players',