app/django/db/models/sql/subqueries.py
changeset 323 ff1a9aa48cfd
parent 54 03e267d67478
--- a/app/django/db/models/sql/subqueries.py	Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/db/models/sql/subqueries.py	Tue Oct 14 16:00:59 2008 +0000
@@ -2,10 +2,9 @@
 Query subclasses which provide extra functionality beyond simple data retrieval.
 """
 
-from django.contrib.contenttypes import generic
 from django.core.exceptions import FieldError
 from django.db.models.sql.constants import *
-from django.db.models.sql.datastructures import RawValue, Date
+from django.db.models.sql.datastructures import Date
 from django.db.models.sql.query import Query
 from django.db.models.sql.where import AND
 
@@ -43,6 +42,7 @@
         More than one physical query may be executed if there are a
         lot of values in pk_list.
         """
+        from django.contrib.contenttypes import generic
         cls = self.model
         for related in cls._meta.get_all_related_many_to_many_objects():
             if not isinstance(related.field, generic.GenericRelation):
@@ -106,12 +106,20 @@
 
     def clone(self, klass=None, **kwargs):
         return super(UpdateQuery, self).clone(klass,
-                related_updates=self.related_updates.copy, **kwargs)
+                related_updates=self.related_updates.copy(), **kwargs)
 
     def execute_sql(self, result_type=None):
-        super(UpdateQuery, self).execute_sql(result_type)
+        """
+        Execute the specified update. Returns the number of rows affected by
+        the primary update query (there could be other updates on related
+        tables, but their rowcounts are not returned).
+        """
+        cursor = super(UpdateQuery, self).execute_sql(result_type)
+        rows = cursor.rowcount
+        del cursor
         for query in self.get_related_updates():
             query.execute_sql(result_type)
+        return rows
 
     def as_sql(self):
         """
@@ -285,7 +293,8 @@
     def clone(self, klass=None, **kwargs):
         extras = {'columns': self.columns[:], 'values': self.values[:],
                 'params': self.params}
-        return super(InsertQuery, self).clone(klass, extras)
+        extras.update(kwargs)
+        return super(InsertQuery, self).clone(klass, **extras)
 
     def as_sql(self):
         # We don't need quote_name_unless_alias() here, since these are all
@@ -335,6 +344,23 @@
     date field. This requires some special handling when converting the results
     back to Python objects, so we put it in a separate class.
     """
+    def __getstate__(self):
+        """
+        Special DateQuery-specific pickle handling.
+        """
+        for elt in self.select:
+            if isinstance(elt, Date):
+                # Eliminate a method reference that can't be pickled. The
+                # __setstate__ method restores this.
+                elt.date_sql_func = None
+        return super(DateQuery, self).__getstate__()
+
+    def __setstate__(self, obj_dict):
+        super(DateQuery, self).__setstate__(obj_dict)
+        for elt in self.select:
+            if isinstance(elt, Date):
+                self.date_sql_func = self.connection.ops.date_trunc_sql
+
     def results_iter(self):
         """
         Returns an iterator over the results from executing this query.
@@ -352,21 +378,24 @@
             for row in rows:
                 date = row[offset]
                 if resolve_columns:
-                    date = self.resolve_columns([date], fields)[0]
+                    date = self.resolve_columns(row, fields)[offset]
                 elif needs_string_cast:
                     date = typecast_timestamp(str(date))
                 yield date
 
-    def add_date_select(self, column, lookup_type, order='ASC'):
+    def add_date_select(self, field, lookup_type, order='ASC'):
         """
         Converts the query into a date extraction query.
         """
-        alias = self.join((None, self.model._meta.db_table, None, None))
-        select = Date((alias, column), lookup_type,
+        result = self.setup_joins([field.name], self.get_meta(),
+                self.get_initial_alias(), False)
+        alias = result[3][-1]
+        select = Date((alias, field.column), lookup_type,
                 self.connection.ops.date_trunc_sql)
         self.select = [select]
         self.select_fields = [None]
         self.select_related = False # See #7097.
+        self.extra_select = {}
         self.distinct = True
         self.order_by = order == 'ASC' and [1] or [-1]
 
@@ -382,4 +411,3 @@
 
     def get_ordering(self):
         return ()
-