app/django/db/models/query.py
changeset 323 ff1a9aa48cfd
parent 54 03e267d67478
--- a/app/django/db/models/query.py	Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/db/models/query.py	Tue Oct 14 16:00:59 2008 +0000
@@ -1,28 +1,132 @@
-import warnings
+try:
+    set
+except NameError:
+    from sets import Set as set     # Python 2.3 fallback
 
-from django.conf import settings
 from django.db import connection, transaction, IntegrityError
-from django.db.models.fields import DateField, FieldDoesNotExist
-from django.db.models.query_utils import Q
+from django.db.models.fields import DateField
+from django.db.models.query_utils import Q, select_related_descend
 from django.db.models import signals, sql
-from django.dispatch import dispatcher
 from django.utils.datastructures import SortedDict
 
+
 # Used to control how many objects are worked with at once in some cases (e.g.
 # when deleting objects).
 CHUNK_SIZE = 100
 ITER_CHUNK_SIZE = CHUNK_SIZE
 
-# Pull into this namespace for backwards compatibility
+# The maximum number of items to display in a QuerySet.__repr__
+REPR_OUTPUT_SIZE = 20
+
+# Pull into this namespace for backwards compatibility.
 EmptyResultSet = sql.EmptyResultSet
 
+
+class CyclicDependency(Exception):
+    """
+    An error when dealing with a collection of objects that have a cyclic
+    dependency, i.e. when deleting multiple objects.
+    """
+    pass
+
+
+class CollectedObjects(object):
+    """
+    A container that stores keys and lists of values along with remembering the
+    parent objects for all the keys.
+
+    This is used for the database object deletion routines so that we can
+    calculate the 'leaf' objects which should be deleted first.
+    """
+
+    def __init__(self):
+        self.data = {}
+        self.children = {}
+
+    def add(self, model, pk, obj, parent_model, nullable=False):
+        """
+        Adds an item to the container.
+
+        Arguments:
+        * model - the class of the object being added.
+        * pk - the primary key.
+        * obj - the object itself.
+        * parent_model - the model of the parent object that this object was
+          reached through.
+        * nullable - should be True if this relation is nullable.
+
+        Returns True if the item already existed in the structure and
+        False otherwise.
+        """
+        d = self.data.setdefault(model, SortedDict())
+        retval = pk in d
+        d[pk] = obj
+        # Nullable relationships can be ignored -- they are nulled out before
+        # deleting, and therefore do not affect the order in which objects
+        # have to be deleted.
+        if parent_model is not None and not nullable:
+            self.children.setdefault(parent_model, []).append(model)
+        return retval
+
+    def __contains__(self, key):
+        return self.data.__contains__(key)
+
+    def __getitem__(self, key):
+        return self.data[key]
+
+    def __nonzero__(self):
+        return bool(self.data)
+
+    def iteritems(self):
+        for k in self.ordered_keys():
+            yield k, self[k]
+
+    def items(self):
+        return list(self.iteritems())
+
+    def keys(self):
+        return self.ordered_keys()
+
+    def ordered_keys(self):
+        """
+        Returns the models in the order that they should be dealt with (i.e.
+        models with no dependencies first).
+        """
+        dealt_with = SortedDict()
+        # Start with items that have no children
+        models = self.data.keys()
+        while len(dealt_with) < len(models):
+            found = False
+            for model in models:
+                if model in dealt_with:
+                    continue
+                children = self.children.setdefault(model, [])
+                if len([c for c in children if c not in dealt_with]) == 0:
+                    dealt_with[model] = None
+                    found = True
+            if not found:
+                raise CyclicDependency(
+                    "There is a cyclic dependency of items to be processed.")
+
+        return dealt_with.keys()
+
+    def unordered_keys(self):
+        """
+        Fallback for the case where there is a cyclic dependency but we don't care.
+        """
+        return self.data.keys()
+
+
 class QuerySet(object):
-    "Represents a lazy database lookup for a set of objects"
+    """
+    Represents a lazy database lookup for a set of objects.
+    """
     def __init__(self, model=None, query=None):
         self.model = model
         self.query = query or sql.Query(self.model, connection)
         self._result_cache = None
         self._iter = None
+        self._sticky_filter = False
 
     ########################
     # PYTHON MAGIC METHODS #
@@ -30,7 +134,7 @@
 
     def __getstate__(self):
         """
-        Allows the Queryset to be pickled.
+        Allows the QuerySet to be pickled.
         """
         # Force the cache to be fully populated.
         len(self)
@@ -40,12 +144,15 @@
         return obj_dict
 
     def __repr__(self):
-        return repr(list(self))
+        data = list(self[:REPR_OUTPUT_SIZE + 1])
+        if len(data) > REPR_OUTPUT_SIZE:
+            data[-1] = "...(remaining elements truncated)..."
+        return repr(data)
 
     def __len__(self):
         # Since __len__ is called quite frequently (for example, as part of
         # list(qs), we make some effort here to be as efficient as possible
-        # whilst not messing up any existing iterators against the queryset.
+        # whilst not messing up any existing iterators against the QuerySet.
         if self._result_cache is None:
             if self._iter:
                 self._result_cache = list(self._iter)
@@ -87,7 +194,9 @@
         return True
 
     def __getitem__(self, k):
-        "Retrieve an item or slice from the set of results."
+        """
+        Retrieves an item or slice from the set of results.
+        """
         if not isinstance(k, (slice, int, long)):
             raise TypeError
         assert ((not isinstance(k, slice) and (k >= 0))
@@ -132,6 +241,8 @@
 
     def __and__(self, other):
         self._merge_sanity_check(other)
+        if isinstance(other, EmptyQuerySet):
+            return other._clone()
         combined = self._clone()
         combined.query.combine(other.query, sql.AND)
         return combined
@@ -139,6 +250,8 @@
     def __or__(self, other):
         self._merge_sanity_check(other)
         combined = self._clone()
+        if isinstance(other, EmptyQuerySet):
+            return combined
         combined.query.combine(other.query, sql.OR)
         return combined
 
@@ -174,11 +287,10 @@
         Performs a SELECT COUNT() and returns the number of records as an
         integer.
 
-        If the queryset is already cached (i.e. self._result_cache is set) this
-        simply returns the length of the cached results set to avoid multiple
-        SELECT COUNT(*) calls.
+        If the QuerySet is already fully cached this simply returns the length
+        of the cached results set to avoid multiple SELECT COUNT(*) calls.
         """
-        if self._result_cache is not None:
+        if self._result_cache is not None and not self._iter:
             return len(self._result_cache)
 
         return self.query.get_count()
@@ -200,11 +312,11 @@
 
     def create(self, **kwargs):
         """
-        Create a new object with the given kwargs, saving it to the database
+        Creates a new object with the given kwargs, saving it to the database
         and returning the created object.
         """
         obj = self.model(**kwargs)
-        obj.save()
+        obj.save(force_insert=True)
         return obj
 
     def get_or_create(self, **kwargs):
@@ -223,10 +335,16 @@
                 params = dict([(k, v) for k, v in kwargs.items() if '__' not in k])
                 params.update(defaults)
                 obj = self.model(**params)
-                obj.save()
+                sid = transaction.savepoint()
+                obj.save(force_insert=True)
+                transaction.savepoint_commit(sid)
                 return obj, True
             except IntegrityError, e:
-                return self.get(**kwargs), False
+                transaction.savepoint_rollback(sid)
+                try:
+                    return self.get(**kwargs), False
+                except self.model.DoesNotExist:
+                    raise e
 
     def latest(self, field_name=None):
         """
@@ -275,7 +393,7 @@
         while 1:
             # Collect all the objects to be deleted in this chunk, and all the
             # objects that are related to the objects that are to be deleted.
-            seen_objs = SortedDict()
+            seen_objs = CollectedObjects()
             for object in del_query[:CHUNK_SIZE]:
                 object._collect_sub_objects(seen_objs)
 
@@ -292,11 +410,14 @@
         Updates all elements in the current QuerySet, setting all the given
         fields to the appropriate values.
         """
+        assert self.query.can_filter(), \
+                "Cannot update a query once a slice has been taken."
         query = self.query.clone(sql.UpdateQuery)
         query.add_update_values(kwargs)
-        query.execute_sql(None)
+        rows = query.execute_sql(None)
         transaction.commit_unless_managed()
         self._result_cache = None
+        return rows
     update.alters_data = True
 
     def _update(self, values):
@@ -306,10 +427,12 @@
         code (it requires too much poking around at model internals to be
         useful at that level).
         """
+        assert self.query.can_filter(), \
+                "Cannot update a query once a slice has been taken."
         query = self.query.clone(sql.UpdateQuery)
         query.add_update_fields(values)
-        query.execute_sql(None)
         self._result_cache = None
+        return query.execute_sql(None)
     _update.alters_data = True
 
     ##################################################
@@ -331,23 +454,19 @@
 
     def dates(self, field_name, kind, order='ASC'):
         """
-        Returns a list of datetime objects representing all available dates
-        for the given field_name, scoped to 'kind'.
+        Returns a list of datetime objects representing all available dates for
+        the given field_name, scoped to 'kind'.
         """
         assert kind in ("month", "year", "day"), \
                 "'kind' must be one of 'year', 'month' or 'day'."
         assert order in ('ASC', 'DESC'), \
                 "'order' must be either 'ASC' or 'DESC'."
-        # Let the FieldDoesNotExist exception propagate.
-        field = self.model._meta.get_field(field_name, many_to_many=False)
-        assert isinstance(field, DateField), "%r isn't a DateField." \
-                % field_name
-        return self._clone(klass=DateQuerySet, setup=True, _field=field,
-                _kind=kind, _order=order)
+        return self._clone(klass=DateQuerySet, setup=True,
+                _field_name=field_name, _kind=kind, _order=order)
 
     def none(self):
         """
-        Returns an empty queryset.
+        Returns an empty QuerySet.
         """
         return self._clone(klass=EmptyQuerySet)
 
@@ -391,6 +510,7 @@
     def complex_filter(self, filter_obj):
         """
         Returns a new QuerySet instance with filter_obj added to the filters.
+
         filter_obj can be a Q object (or anything with an add_to_query()
         method) or a dictionary of keyword lookup arguments.
 
@@ -398,14 +518,17 @@
         and usually it will be more natural to use other methods.
         """
         if isinstance(filter_obj, Q) or hasattr(filter_obj, 'add_to_query'):
-            return self._filter_or_exclude(None, filter_obj)
+            clone = self._clone()
+            clone.query.add_q(filter_obj)
+            return clone
         else:
             return self._filter_or_exclude(None, **filter_obj)
 
     def select_related(self, *fields, **kwargs):
         """
-        Returns a new QuerySet instance that will select related objects. If
-        fields are specified, they must be ForeignKey fields and only those
+        Returns a new QuerySet instance that will select related objects.
+
+        If fields are specified, they must be ForeignKey fields and only those
         related objects are included in the selection.
         """
         depth = kwargs.pop('depth', 0)
@@ -425,13 +548,15 @@
 
     def dup_select_related(self, other):
         """
-        Copies the related selection status from the queryset 'other' to the
-        current queryset.
+        Copies the related selection status from the QuerySet 'other' to the
+        current QuerySet.
         """
         self.query.select_related = other.query.select_related
 
     def order_by(self, *field_names):
-        """Returns a new QuerySet instance with the ordering changed."""
+        """
+        Returns a new QuerySet instance with the ordering changed.
+        """
         assert self.query.can_filter(), \
                 "Cannot reorder a query once a slice has been taken."
         obj = self._clone()
@@ -448,9 +573,9 @@
         return obj
 
     def extra(self, select=None, where=None, params=None, tables=None,
-            order_by=None, select_params=None):
+              order_by=None, select_params=None):
         """
-        Add extra SQL fragments to the query.
+        Adds extra SQL fragments to the query.
         """
         assert self.query.can_filter(), \
                 "Cannot change a query once a slice has been taken"
@@ -460,7 +585,7 @@
 
     def reverse(self):
         """
-        Reverses the ordering of the queryset.
+        Reverses the ordering of the QuerySet.
         """
         clone = self._clone()
         clone.query.standard_ordering = not clone.query.standard_ordering
@@ -473,7 +598,10 @@
     def _clone(self, klass=None, setup=False, **kwargs):
         if klass is None:
             klass = self.__class__
-        c = klass(model=self.model, query=self.query.clone())
+        query = self.query.clone()
+        if self._sticky_filter:
+            query.filter_is_sticky = True
+        c = klass(model=self.model, query=query)
         c.__dict__.update(kwargs)
         if setup and hasattr(c, '_setup_query'):
             c._setup_query()
@@ -491,13 +619,28 @@
             except StopIteration:
                 self._iter = None
 
+    def _next_is_sticky(self):
+        """
+        Indicates that the next filter call and the one following that should
+        be treated as a single filter. This is only important when it comes to
+        determining when to reuse tables for many-to-many filters. Required so
+        that we can filter naturally on the results of related managers.
+
+        This doesn't return a clone of the current QuerySet (it returns
+        "self"). The method is only used internally and should be immediately
+        followed by a filter() that does create a clone.
+        """
+        self._sticky_filter = True
+        return self
+
     def _merge_sanity_check(self, other):
         """
-        Checks that we are merging two comparable queryset classes.
+        Checks that we are merging two comparable QuerySet classes. By default
+        this does nothing, but see the ValuesQuerySet for an example of where
+        it's useful.
         """
-        if self.__class__ is not other.__class__:
-            raise TypeError("Cannot merge querysets of different types ('%s' and '%s'."
-                    % (self.__class__.__name__, other.__class__.__name__))
+        pass
+
 
 class ValuesQuerySet(QuerySet):
     def __init__(self, *args, **kwargs):
@@ -509,7 +652,9 @@
         # names of the model fields to select.
 
     def iterator(self):
-        self.query.trim_extra_select(self.extra_names)
+        if (not self.extra_names and
+            len(self.field_names) != len(self.model._meta.fields)):
+            self.query.trim_extra_select(self.extra_names)
         names = self.query.extra_select.keys() + self.field_names
         for row in self.query.results_iter():
             yield dict(zip(names, row))
@@ -519,7 +664,7 @@
         Constructs the field_names list that the values query will be
         retrieving.
 
-        Called by the _clone() method after initialising the rest of the
+        Called by the _clone() method after initializing the rest of the
         instance.
         """
         self.extra_names = []
@@ -560,6 +705,7 @@
             raise TypeError("Merging '%s' classes must involve the same values in each case."
                     % self.__class__.__name__)
 
+
 class ValuesListQuerySet(ValuesQuerySet):
     def iterator(self):
         self.query.trim_extra_select(self.extra_names)
@@ -568,7 +714,7 @@
                 yield row[0]
         elif not self.query.extra_select:
             for row in self.query.results_iter():
-                yield row
+                yield tuple(row)
         else:
             # When extra(select=...) is involved, the extra cols come are
             # always at the start of the row, so we need to reorder the fields
@@ -583,6 +729,7 @@
         clone.flat = self.flat
         return clone
 
+
 class DateQuerySet(QuerySet):
     def iterator(self):
         return self.query.results_iter()
@@ -591,28 +738,38 @@
         """
         Sets up any special features of the query attribute.
 
-        Called by the _clone() method after initialising the rest of the
+        Called by the _clone() method after initializing the rest of the
         instance.
         """
         self.query = self.query.clone(klass=sql.DateQuery, setup=True)
         self.query.select = []
-        self.query.add_date_select(self._field.column, self._kind, self._order)
-        if self._field.null:
-            self.query.add_filter(('%s__isnull' % self._field.name, True))
+        field = self.model._meta.get_field(self._field_name, many_to_many=False)
+        assert isinstance(field, DateField), "%r isn't a DateField." \
+                % field.name
+        self.query.add_date_select(field, self._kind, self._order)
+        if field.null:
+            self.query.add_filter(('%s__isnull' % field.name, False))
 
     def _clone(self, klass=None, setup=False, **kwargs):
         c = super(DateQuerySet, self)._clone(klass, False, **kwargs)
-        c._field = self._field
+        c._field_name = self._field_name
         c._kind = self._kind
         if setup and hasattr(c, '_setup_query'):
             c._setup_query()
         return c
 
+
 class EmptyQuerySet(QuerySet):
     def __init__(self, model=None, query=None):
         super(EmptyQuerySet, self).__init__(model, query)
         self._result_cache = []
 
+    def __and__(self, other):
+        return self._clone()
+
+    def __or__(self, other):
+        return other._clone()
+
     def count(self):
         return 0
 
@@ -629,22 +786,9 @@
         # (it raises StopIteration immediately).
         yield iter([]).next()
 
-# QOperator, QNot, QAnd and QOr are temporarily retained for backwards
-# compatibility. All the old functionality is now part of the 'Q' class.
-class QOperator(Q):
-    def __init__(self, *args, **kwargs):
-        warnings.warn('Use Q instead of QOr, QAnd or QOperation.',
-                DeprecationWarning, stacklevel=2)
-        super(QOperator, self).__init__(*args, **kwargs)
-
-QOr = QAnd = QOperator
-
-def QNot(q):
-    warnings.warn('Use ~q instead of QNot(q)', DeprecationWarning, stacklevel=2)
-    return ~q
 
 def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
-        requested=None):
+                   requested=None):
     """
     Helper function that recursively returns an object with the specified
     related attributes already populated.
@@ -655,10 +799,14 @@
 
     restricted = requested is not None
     index_end = index_start + len(klass._meta.fields)
-    obj = klass(*row[index_start:index_end])
+    fields = row[index_start:index_end]
+    if not [x for x in fields if x is not None]:
+        # If we only have a list of Nones, there was no related object.
+        obj = None
+    else:
+        obj = klass(*fields)
     for f in klass._meta.fields:
-        if (not f.rel or (not restricted and f.null) or
-                (restricted and f.name not in requested) or f.rel.parent_link):
+        if not select_related_descend(f, restricted, requested):
             continue
         if restricted:
             next = requested[f.name]
@@ -668,56 +816,73 @@
                 cur_depth+1, next)
         if cached_row:
             rel_obj, index_end = cached_row
-            setattr(obj, f.get_cache_name(), rel_obj)
+            if obj is not None:
+                setattr(obj, f.get_cache_name(), rel_obj)
     return obj, index_end
 
+
 def delete_objects(seen_objs):
     """
     Iterate through a list of seen classes, and remove any instances that are
     referred to.
     """
-    ordered_classes = seen_objs.keys()
-    ordered_classes.reverse()
-
-    for cls in ordered_classes:
-        seen_objs[cls] = seen_objs[cls].items()
-        seen_objs[cls].sort()
+    try:
+        ordered_classes = seen_objs.keys()
+    except CyclicDependency:
+        # If there is a cyclic dependency, we cannot in general delete the
+        # objects.  However, if an appropriate transaction is set up, or if the
+        # database is lax enough, it will succeed. So for now, we go ahead and
+        # try anyway.
+        ordered_classes = seen_objs.unordered_keys()
 
-        # Pre notify all instances to be deleted
-        for pk_val, instance in seen_objs[cls]:
-            dispatcher.send(signal=signals.pre_delete, sender=cls,
-                    instance=instance)
+    obj_pairs = {}
+    for cls in ordered_classes:
+        items = seen_objs[cls].items()
+        items.sort()
+        obj_pairs[cls] = items
 
-        pk_list = [pk for pk,instance in seen_objs[cls]]
+        # Pre-notify all instances to be deleted.
+        for pk_val, instance in items:
+            signals.pre_delete.send(sender=cls, instance=instance)
+
+        pk_list = [pk for pk,instance in items]
         del_query = sql.DeleteQuery(cls, connection)
         del_query.delete_batch_related(pk_list)
 
         update_query = sql.UpdateQuery(cls, connection)
-        for field in cls._meta.fields:
-            if field.rel and field.null and field.rel.to in seen_objs:
-                update_query.clear_related(field, pk_list)
+        for field, model in cls._meta.get_fields_with_model():
+            if (field.rel and field.null and field.rel.to in seen_objs and
+                    filter(lambda f: f.column == field.column,
+                    field.rel.to._meta.fields)):
+                if model:
+                    sql.UpdateQuery(model, connection).clear_related(field,
+                            pk_list)
+                else:
+                    update_query.clear_related(field, pk_list)
 
-    # Now delete the actual data
+    # Now delete the actual data.
     for cls in ordered_classes:
-        seen_objs[cls].reverse()
-        pk_list = [pk for pk,instance in seen_objs[cls]]
+        items = obj_pairs[cls]
+        items.reverse()
+
+        pk_list = [pk for pk,instance in items]
         del_query = sql.DeleteQuery(cls, connection)
         del_query.delete_batch(pk_list)
 
         # Last cleanup; set NULLs where there once was a reference to the
         # object, NULL the primary key of the found objects, and perform
         # post-notification.
-        for pk_val, instance in seen_objs[cls]:
+        for pk_val, instance in items:
             for field in cls._meta.fields:
                 if field.rel and field.null and field.rel.to in seen_objs:
                     setattr(instance, field.attname, None)
 
-            dispatcher.send(signal=signals.post_delete, sender=cls,
-                    instance=instance)
+            signals.post_delete.send(sender=cls, instance=instance)
             setattr(instance, cls._meta.pk.attname, None)
 
     transaction.commit_unless_managed()
 
+
 def insert_query(model, values, return_id=False, raw_values=False):
     """
     Inserts a new record for the given model. This provides an interface to
@@ -727,4 +892,3 @@
     query = sql.InsertQuery(model, connection)
     query.insert_values(values, raw_values)
     return query.execute_sql(return_id)
-