Convert getForFields back to the Query API and add tests
authorSverre Rabbelier <srabbelier@gmail.com>
Tue, 03 Mar 2009 23:27:39 +0000 (2009-03-03)
changeset 1614 797f5ae462e7
parent 1613 59e5cc89e509
child 1615 81f26c9809dc
Convert getForFields back to the Query API and add tests This is possible now that it supports the 'IN' operator. Added tests to make sure there are no regressions. Patch by: Sverre Rabbelier
app/soc/logic/models/base.py
tests/app/soc/logic/models/test_base.py
--- a/app/soc/logic/models/base.py	Tue Mar 03 23:11:37 2009 +0000
+++ b/app/soc/logic/models/base.py	Tue Mar 03 23:27:39 2009 +0000
@@ -268,7 +268,8 @@
 
     raise out_of_band.Error(msg, status=404)
 
-  def getForFields(self, filter=None, unique=False, limit=1000, offset=0):
+  def getForFields(self, filter=None, unique=False,
+                   limit=1000, offset=0, order=None):
     """Returns all entities that have the specified properties.
 
     Args:
@@ -276,38 +277,31 @@
       unique: if set, only the first item from the resultset will be returned
       limit: the amount of entities to fetch at most
       offset: the position to start at
+      order: a list with the sort order
     """
 
+    if not filter:
+      filter = {}
     if unique:
       limit = 1
+    if not order:
+      order = []
 
-    if filter:
-      format_eq = '%(key)s = :%(num)d'
-      format_in = '%(key)s IN (%(values)s)'
+    orderset = set([i.strip('-') for i in order])
+    if len(orderset) != len(order):
+      raise InvalidArgumentError
 
-      n = 1
-      conditionals = []
-      args = []
+    q = db.Query(self._model)
 
-      for key, value in filter.iteritems():
-        if isinstance(value, list):
-          count = len(value)
-          args.extend(value)
-          values = ', '.join([':%d' % i for i in range(n, n + count)])
-          sub = format_in % {'key': key, 'values': values}
-          n = n + count
-        else:
-          sub = format_eq % {'key': key, 'num': n}
-          args.append(value)
-          n = n + 1
-        conditionals.append(sub)
+    for key, value in filter.iteritems():
+      if isinstance(value, list):
+        op = '%s IN' % key
+        q.filter(op, value)
+      else:
+        q.filter(key, value)
 
-      joined_pairs = ' AND '.join(conditionals)
-      condition = 'WHERE ' + joined_pairs
-
-      q = self._model.gql(condition, *args)
-    else:
-      q = self._model.all()
+    for key in order:
+      q.order(key)
 
     result = q.fetch(limit, offset)
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/app/soc/logic/models/test_base.py	Tue Mar 03 23:27:39 2009 +0000
@@ -0,0 +1,152 @@
+#!/usr/bin/python2.5
+#
+# Copyright 2009 the Melange authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+__authors__ = [
+  '"Sverre Rabbelier" <sverre@rabbelier.nl>',
+  ]
+
+
+import unittest
+
+from google.appengine.api import users
+from google.appengine.ext import db
+
+from soc.logic.models import base
+
+
+class TestModel(db.Model):
+  """Simpel test model.
+  """
+
+  value = db.IntegerProperty()
+
+
+class TestModelLogic(base.Logic):
+  """Simple test logic.
+  """
+
+  def __init__(self):
+    super(TestModelLogic, self).__init__(TestModel)
+
+
+class UserTest(unittest.TestCase):
+  """Tests related to user logic.
+  """
+
+  def setUp(self):
+    """Set up required for the slot allocation tests.
+    """
+
+    entities = []
+
+    for i in range(5):
+      entity = TestModel(key_name='test_%d' % i, value=i)
+      entity.put()
+      entities.append(entity)
+
+    self.logic = TestModelLogic()
+    self.entities = entities
+
+  def testGetForFields(self):
+    """Test that all entries were retrieved.
+    """
+
+    expected = set(range(5))
+    actual = set([i.value for i in self.logic.getForFields()])
+    self.assertEqual(expected, actual)
+
+  def testGetForFieldsFiltered(self):
+    """Test that only the entry that matches the filter is retrieved.
+    """
+
+    fields = {'value': 1}
+
+    expected = [1]
+    actual = [i.value for i in self.logic.getForFields(fields)]
+
+    self.assertEqual(expected, actual)
+
+  def testGetForFieldsWithOperator(self):
+    """Test that all entries matching the filter are retrieved.
+    """
+
+    fields = {'value <': 3}
+
+    expected = set(range(3))
+    actual = set([i.value for i in self.logic.getForFields(fields)])
+
+    self.assertEqual(expected, actual)
+
+  def testGetForFieldsNonMatching(self):
+    """Test that unique returns None instead of a list.
+    """
+
+    fields = {'value': 1337}
+
+    expected = []
+    actual = self.logic.getForFields(fields)
+    self.assertEqual(expected, actual)
+
+  def testGetForFieldsUnique(self):
+    """Test that unique returns an entry instead of a list.
+    """
+
+    fields = {'value': 1}
+
+    actual = self.logic.getForFields(fields, unique=True)
+    self.assertTrue(isinstance(actual, TestModel))
+
+  def testGetForFieldsUniqueEmpty(self):
+    """Test that unique returns None instead of a list.
+    """
+
+    fields = {'value': 1337}
+
+    expected = None
+    actual = self.logic.getForFields(fields, unique=True)
+    self.assertEqual(expected, actual)
+
+  def testGetForFieldsMultiFilter(self):
+    """Test that all entries matching an 'IN' filter are returned.
+    """
+
+    fields = {'value': [1, 2]}
+
+    expected = 2
+    actual = len(self.logic.getForFields(fields))
+    self.assertEqual(expected, actual)
+
+  def testGetFieldsOrdened(self):
+    """Test that fields can be ordened.
+    """
+
+    order = ['value']
+
+    expected = range(5)
+    actual = [i.value for i in self.logic.getForFields(order=order)]
+    self.assertEqual(expected, actual)
+
+  def testGetFieldsReverseOrdened(self):
+    """Test that fields can be ordened in reverse.
+    """
+
+    order = ['-value']
+
+    expected = range(5)
+    expected.reverse()
+    actual = [i.value for i in self.logic.getForFields(order=order)]
+    self.assertEqual(expected, actual)