thirdparty/google_appengine/google/appengine/api/datastore.py
changeset 1278 a7766286a7be
parent 828 f5fd65cc3bf3
child 2309 be1b94099f2d
--- a/thirdparty/google_appengine/google/appengine/api/datastore.py	Thu Feb 12 10:24:37 2009 +0000
+++ b/thirdparty/google_appengine/google/appengine/api/datastore.py	Thu Feb 12 12:30:36 2009 +0000
@@ -31,6 +31,8 @@
 
 
 
+import heapq
+import itertools
 import logging
 import re
 import string
@@ -47,7 +49,9 @@
 from google.appengine.runtime import apiproxy_errors
 from google.appengine.datastore import entity_pb
 
-TRANSACTION_RETRIES = 3
+MAX_ALLOWABLE_QUERIES = 30
+
+DEFAULT_TRANSACTION_RETRIES = 3
 
 _MAX_INDEXED_PROPERTIES = 5000
 
@@ -488,7 +492,7 @@
       if isinstance(sample, list):
         sample = values[0]
 
-      if isinstance(sample, (datastore_types.Blob, datastore_types.Text)):
+      if isinstance(sample, datastore_types._RAW_PROPERTY_TYPES):
         pb.raw_property_list().extend(properties)
       else:
         pb.property_list().extend(properties)
@@ -1072,12 +1076,9 @@
       values = list(values)
     elif not isinstance(values, list):
       values = [values]
-    if isinstance(values[0], datastore_types.Blob):
+    if isinstance(values[0], datastore_types._RAW_PROPERTY_TYPES):
       raise datastore_errors.BadValueError(
-        'Filtering on Blob properties is not supported.')
-    if isinstance(values[0], datastore_types.Text):
-      raise datastore_errors.BadValueError(
-        'Filtering on Text properties is not supported.')
+        'Filtering on %s properties is not supported.' % typename(values[0]))
 
     if operator in self.INEQUALITY_OPERATORS:
       if self.__inequality_prop and property != self.__inequality_prop:
@@ -1165,6 +1166,306 @@
     return pb
 
 
+class MultiQuery(Query):
+  """Class representing a query which requires multiple datastore queries.
+
+  This class is actually a subclass of datastore.Query as it is intended to act
+  like a normal Query object (supporting the same interface).
+  """
+
+  def __init__(self, bound_queries, orderings):
+    if len(bound_queries) > MAX_ALLOWABLE_QUERIES:
+      raise datastore_errors.BadArgumentError(
+          'Cannot satisfy query -- too many subqueries (max: %d, got %d).'
+          ' Probable cause: too many IN/!= filters in query.' %
+          (MAX_ALLOWABLE_QUERIES, len(bound_queries)))
+    self.__bound_queries = bound_queries
+    self.__orderings = orderings
+
+  def __str__(self):
+    res = 'MultiQuery: '
+    for query in self.__bound_queries:
+      res = '%s %s' % (res, str(query))
+    return res
+
+  def Get(self, limit, offset=0):
+    """Get results of the query with a limit on the number of results.
+
+    Args:
+      limit: maximum number of values to return.
+      offset: offset requested -- if nonzero, this will override the offset in
+              the original query
+
+    Returns:
+      A list of entities with at most "limit" entries (less if the query
+      completes before reading limit values).
+    """
+    count = 1
+    result = []
+
+    iterator = self.Run()
+
+    try:
+      for i in xrange(offset):
+        val = iterator.next()
+    except StopIteration:
+      pass
+
+    try:
+      while count <= limit:
+        val = iterator.next()
+        result.append(val)
+        count += 1
+    except StopIteration:
+      pass
+    return result
+
+  class SortOrderEntity(object):
+    """Allow entity comparisons using provided orderings.
+
+    The iterator passed to the constructor is eventually consumed via
+    calls to GetNext(), which generate new SortOrderEntity s with the
+    same orderings.
+    """
+
+    def __init__(self, entity_iterator, orderings):
+      """Ctor.
+
+      Args:
+        entity_iterator: an iterator of entities which will be wrapped.
+        orderings: an iterable of (identifier, order) pairs. order
+          should be either Query.ASCENDING or Query.DESCENDING.
+      """
+      self.__entity_iterator = entity_iterator
+      self.__entity = None
+      self.__min_max_value_cache = {}
+      try:
+        self.__entity = entity_iterator.next()
+      except StopIteration:
+        pass
+      else:
+        self.__orderings = orderings
+
+    def __str__(self):
+      return str(self.__entity)
+
+    def GetEntity(self):
+      """Gets the wrapped entity."""
+      return self.__entity
+
+    def GetNext(self):
+      """Wrap and return the next entity.
+
+      The entity is retrieved from the iterator given at construction time.
+      """
+      return MultiQuery.SortOrderEntity(self.__entity_iterator,
+                                        self.__orderings)
+
+    def CmpProperties(self, that):
+      """Compare two entities and return their relative order.
+
+      Compares self to that based on the current sort orderings and the
+      key orders between them. Returns negative, 0, or positive depending on
+      whether self is less, equal to, or greater than that. This
+      comparison returns as if all values were to be placed in ascending order
+      (highest value last).  Only uses the sort orderings to compare (ignores
+       keys).
+
+      Args:
+        that: SortOrderEntity
+
+      Returns:
+        Negative if self < that
+        Zero if self == that
+        Positive if self > that
+      """
+      if not self.__entity:
+        return cmp(self.__entity, that.__entity)
+
+      for (identifier, order) in self.__orderings:
+        value1 = self.__GetValueForId(self, identifier, order)
+        value2 = self.__GetValueForId(that, identifier, order)
+
+        result = cmp(value1, value2)
+        if order == Query.DESCENDING:
+          result = -result
+        if result:
+          return result
+      return 0
+
+    def __GetValueForId(self, sort_order_entity, identifier, sort_order):
+      value = sort_order_entity.__entity[identifier]
+      entity_key = sort_order_entity.__entity.key()
+      if (entity_key, identifier) in self.__min_max_value_cache:
+        value = self.__min_max_value_cache[(entity_key, identifier)]
+      elif isinstance(value, list):
+        if sort_order == Query.DESCENDING:
+          value = min(value)
+        else:
+          value = max(value)
+        self.__min_max_value_cache[(entity_key, identifier)] = value
+
+      return value
+
+    def __cmp__(self, that):
+      """Compare self to that w.r.t. values defined in the sort order.
+
+      Compare an entity with another, using sort-order first, then the key
+      order to break ties. This can be used in a heap to have faster min-value
+      lookup.
+
+      Args:
+        that: other entity to compare to
+      Returns:
+        negative: if self is less than that in sort order
+        zero: if self is equal to that in sort order
+        positive: if self is greater than that in sort order
+      """
+      property_compare = self.CmpProperties(that)
+      if property_compare:
+        return property_compare
+      else:
+        return cmp(self.__entity.key(), that.__entity.key())
+
+  def Run(self):
+    """Return an iterable output with all results in order."""
+    results = []
+    count = 1
+    log_level = logging.DEBUG - 1
+    for bound_query in self.__bound_queries:
+      logging.log(log_level, 'Running query #%i' % count)
+      results.append(bound_query.Run())
+      count += 1
+
+    def IterateResults(results):
+      """Iterator function to return all results in sorted order.
+
+      Iterate over the array of results, yielding the next element, in
+      sorted order. This function is destructive (results will be empty
+      when the operation is complete).
+
+      Args:
+        results: list of result iterators to merge and iterate through
+
+      Yields:
+        The next result in sorted order.
+      """
+      result_heap = []
+      for result in results:
+        heap_value = MultiQuery.SortOrderEntity(result, self.__orderings)
+        if heap_value.GetEntity():
+          heapq.heappush(result_heap, heap_value)
+
+      used_keys = set()
+
+      while result_heap:
+        top_result = heapq.heappop(result_heap)
+
+        results_to_push = []
+        if top_result.GetEntity().key() not in used_keys:
+          yield top_result.GetEntity()
+        else:
+          pass
+
+        used_keys.add(top_result.GetEntity().key())
+
+        results_to_push = []
+        while result_heap:
+          next = heapq.heappop(result_heap)
+          if cmp(top_result, next):
+            results_to_push.append(next)
+            break
+          else:
+            results_to_push.append(next.GetNext())
+        results_to_push.append(top_result.GetNext())
+
+        for popped_result in results_to_push:
+          if popped_result.GetEntity():
+            heapq.heappush(result_heap, popped_result)
+
+    return IterateResults(results)
+
+  def Count(self, limit=None):
+    """Return the number of matched entities for this query.
+
+    Will return the de-duplicated count of results.  Will call the more
+    efficient Get() function if a limit is given.
+
+    Args:
+      limit: maximum number of entries to count (for any result > limit, return
+      limit).
+    Returns:
+      count of the number of entries returned.
+    """
+    if limit is None:
+      count = 0
+      for i in self.Run():
+        count += 1
+      return count
+    else:
+      return len(self.Get(limit))
+
+  def __setitem__(self, query_filter, value):
+    """Add a new filter by setting it on all subqueries.
+
+    If any of the setting operations raise an exception, the ones
+    that succeeded are undone and the exception is propagated
+    upward.
+
+    Args:
+      query_filter: a string of the form "property operand".
+      value: the value that the given property is compared against.
+    """
+    saved_items = []
+    for index, query in enumerate(self.__bound_queries):
+      saved_items.append(query.get(query_filter, None))
+      try:
+        query[query_filter] = value
+      except:
+        for q, old_value in itertools.izip(self.__bound_queries[:index],
+                                           saved_items):
+          if old_value is not None:
+            q[query_filter] = old_value
+          else:
+            del q[query_filter]
+        raise
+
+  def __delitem__(self, query_filter):
+    """Delete a filter by deleting it from all subqueries.
+
+    If a KeyError is raised during the attempt, it is ignored, unless
+    every subquery raised a KeyError. If any other exception is
+    raised, any deletes will be rolled back.
+
+    Args:
+      query_filter: the filter to delete.
+
+    Raises:
+      KeyError: No subquery had an entry containing query_filter.
+    """
+    subquery_count = len(self.__bound_queries)
+    keyerror_count = 0
+    saved_items = []
+    for index, query in enumerate(self.__bound_queries):
+      try:
+        saved_items.append(query.get(query_filter, None))
+        del query[query_filter]
+      except KeyError:
+        keyerror_count += 1
+      except:
+        for q, old_value in itertools.izip(self.__bound_queries[:index],
+                                           saved_items):
+          if old_value is not None:
+            q[query_filter] = old_value
+        raise
+
+    if keyerror_count == subquery_count:
+      raise KeyError(query_filter)
+
+  def __iter__(self):
+    return iter(self.__bound_queries)
+
+
 class Iterator(object):
   """An iterator over the results of a datastore query.
 
@@ -1331,8 +1632,29 @@
     self.modified_keys.update(keys)
 
 
+def RunInTransaction(function, *args, **kwargs):
+  """Runs a function inside a datastore transaction.
 
-def RunInTransaction(function, *args, **kwargs):
+     Runs the user-provided function inside transaction, retries default
+     number of times.
+
+    Args:
+    # a function to be run inside the transaction
+    function: callable
+    # positional arguments to pass to the function
+    args: variable number of any type
+
+  Returns:
+    the function's return value, if any
+
+  Raises:
+    TransactionFailedError, if the transaction could not be committed.
+  """
+  return RunInTransactionCustomRetries(
+      DEFAULT_TRANSACTION_RETRIES, function, *args, **kwargs)
+
+
+def RunInTransactionCustomRetries(retries, function, *args, **kwargs):
   """Runs a function inside a datastore transaction.
 
   Runs the user-provided function inside a full-featured, ACID datastore
@@ -1387,6 +1709,8 @@
   Nested transactions are not supported.
 
   Args:
+    # number of retries
+    retries: integer
     # a function to be run inside the transaction
     function: callable
     # positional arguments to pass to the function
@@ -1403,6 +1727,10 @@
     raise datastore_errors.BadRequestError(
       'Nested transactions are not supported.')
 
+  if retries < 0:
+    raise datastore_errors.BadRequestError(
+      'Number of retries should be non-negative number.')
+
   tx_key = None
 
   try:
@@ -1410,7 +1738,7 @@
     tx = _Transaction()
     _txes[tx_key] = tx
 
-    for i in range(0, TRANSACTION_RETRIES + 1):
+    for i in range(0, retries + 1):
       tx.modified_keys.clear()
 
       try:
@@ -1436,7 +1764,7 @@
 
       if tx.handle:
         try:
-          resp = api_base_pb.VoidProto()
+          resp = datastore_pb.CommitResponse()
           apiproxy_stub_map.MakeSyncCall('datastore_v3', 'Commit',
                                          tx.handle, resp)
         except apiproxy_errors.ApplicationError, err:
@@ -1544,7 +1872,7 @@
   """Walks the stack to find a RunInTransaction() call.
 
   Returns:
-    # this is the RunInTransaction() frame record, if found
+    # this is the RunInTransactionCustomRetries() frame record, if found
     frame record or None
   """
   frame = sys._getframe()
@@ -1553,7 +1881,7 @@
   frame = frame.f_back.f_back
   while frame:
     if (frame.f_code.co_filename == filename and
-        frame.f_code.co_name == 'RunInTransaction'):
+        frame.f_code.co_name == 'RunInTransactionCustomRetries'):
       return frame
     frame = frame.f_back