app/django/db/models/sql/where.py
changeset 323 ff1a9aa48cfd
parent 54 03e267d67478
--- a/app/django/db/models/sql/where.py	Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/db/models/sql/where.py	Tue Oct 14 16:00:59 2008 +0000
@@ -18,15 +18,55 @@
     Used to represent the SQL where-clause.
 
     The class is tied to the Query class that created it (in order to create
-    the corret SQL).
+    the correct SQL).
 
     The children in this tree are usually either Q-like objects or lists of
-    [table_alias, field_name, field_class, lookup_type, value]. However, a
-    child could also be any class with as_sql() and relabel_aliases() methods.
+    [table_alias, field_name, db_type, lookup_type, value_annotation,
+    params]. However, a child could also be any class with as_sql() and
+    relabel_aliases() methods.
     """
     default = AND
 
-    def as_sql(self, node=None, qn=None):
+    def add(self, data, connector):
+        """
+        Add a node to the where-tree. If the data is a list or tuple, it is
+        expected to be of the form (alias, col_name, field_obj, lookup_type,
+        value), which is then slightly munged before being stored (to avoid
+        storing any reference to field objects). Otherwise, the 'data' is
+        stored unchanged and can be anything with an 'as_sql()' method.
+        """
+        # Because of circular imports, we need to import this here.
+        from django.db.models.base import ObjectDoesNotExist
+
+        if not isinstance(data, (list, tuple)):
+            super(WhereNode, self).add(data, connector)
+            return
+
+        alias, col, field, lookup_type, value = data
+        try:
+            if field:
+                params = field.get_db_prep_lookup(lookup_type, value)
+                db_type = field.db_type()
+            else:
+                # This is possible when we add a comparison to NULL sometimes
+                # (we don't really need to waste time looking up the associated
+                # field object).
+                params = Field().get_db_prep_lookup(lookup_type, value)
+                db_type = None
+        except ObjectDoesNotExist:
+            # This can happen when trying to insert a reference to a null pk.
+            # We break out of the normal path and indicate there's nothing to
+            # match.
+            super(WhereNode, self).add(NothingNode(), connector)
+            return
+        if isinstance(value, datetime.datetime):
+            annotation = datetime.datetime
+        else:
+            annotation = bool(value)
+        super(WhereNode, self).add((alias, col, db_type, lookup_type,
+                annotation, params), connector)
+
+    def as_sql(self, qn=None):
         """
         Returns the SQL version of the where clause and the value to be
         substituted in. Returns None, None if this node is empty.
@@ -35,82 +75,73 @@
         (generally not needed except by the internal implementation for
         recursion).
         """
-        if node is None:
-            node = self
         if not qn:
             qn = connection.ops.quote_name
-        if not node.children:
+        if not self.children:
             return None, []
         result = []
         result_params = []
         empty = True
-        for child in node.children:
+        for child in self.children:
             try:
                 if hasattr(child, 'as_sql'):
                     sql, params = child.as_sql(qn=qn)
-                    format = '(%s)'
-                elif isinstance(child, tree.Node):
-                    sql, params = self.as_sql(child, qn)
-                    if child.negated:
-                        format = 'NOT (%s)'
-                    elif len(child.children) == 1:
-                        format = '%s'
-                    else:
-                        format = '(%s)'
                 else:
+                    # A leaf node in the tree.
                     sql, params = self.make_atom(child, qn)
-                    format = '%s'
             except EmptyResultSet:
-                if node.connector == AND and not node.negated:
+                if self.connector == AND and not self.negated:
                     # We can bail out early in this particular case (only).
                     raise
-                elif node.negated:
+                elif self.negated:
                     empty = False
                 continue
             except FullResultSet:
                 if self.connector == OR:
-                    if node.negated:
+                    if self.negated:
                         empty = True
                         break
                     # We match everything. No need for any constraints.
                     return '', []
-                if node.negated:
+                if self.negated:
                     empty = True
                 continue
             empty = False
             if sql:
-                result.append(format % sql)
+                result.append(sql)
                 result_params.extend(params)
         if empty:
             raise EmptyResultSet
-        conn = ' %s ' % node.connector
-        return conn.join(result), result_params
+
+        conn = ' %s ' % self.connector
+        sql_string = conn.join(result)
+        if sql_string:
+            if self.negated:
+                sql_string = 'NOT (%s)' % sql_string
+            elif len(self.children) != 1:
+                sql_string = '(%s)' % sql_string
+        return sql_string, result_params
 
     def make_atom(self, child, qn):
         """
-        Turn a tuple (table_alias, field_name, field_class, lookup_type, value)
-        into valid SQL.
+        Turn a tuple (table_alias, column_name, db_type, lookup_type,
+        value_annot, params) into valid SQL.
 
         Returns the string for the SQL fragment and the parameters to use for
         it.
         """
-        table_alias, name, field, lookup_type, value = child
+        table_alias, name, db_type, lookup_type, value_annot, params = child
         if table_alias:
             lhs = '%s.%s' % (qn(table_alias), qn(name))
         else:
             lhs = qn(name)
-        db_type = field and field.db_type() or None
         field_sql = connection.ops.field_cast_sql(db_type) % lhs
 
-        if isinstance(value, datetime.datetime):
+        if value_annot is datetime.datetime:
             cast_sql = connection.ops.datetime_cast_sql()
         else:
             cast_sql = '%s'
 
-        if field:
-            params = field.get_db_prep_lookup(lookup_type, value)
-        else:
-            params = Field().get_db_prep_lookup(lookup_type, value)
         if isinstance(params, QueryWrapper):
             extra, params = params.data
         else:
@@ -123,11 +154,11 @@
                     connection.operators[lookup_type] % cast_sql), params)
 
         if lookup_type == 'in':
-            if not value:
+            if not value_annot:
                 raise EmptyResultSet
             if extra:
                 return ('%s IN %s' % (field_sql, extra), params)
-            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))),
+            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(params))),
                     params)
         elif lookup_type in ('range', 'year'):
             return ('%s BETWEEN %%s and %%s' % field_sql, params)
@@ -135,8 +166,8 @@
             return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type,
                     field_sql), params)
         elif lookup_type == 'isnull':
-            return ('%s IS %sNULL' % (field_sql, (not value and 'NOT ' or '')),
-                    params)
+            return ('%s IS %sNULL' % (field_sql,
+                (not value_annot and 'NOT ' or '')), ())
         elif lookup_type == 'search':
             return (connection.ops.fulltext_search_sql(field_sql), params)
         elif lookup_type in ('regex', 'iregex'):
@@ -169,3 +200,14 @@
 
     def relabel_aliases(self, change_map, node=None):
         return
+
+class NothingNode(object):
+    """
+    A node that matches nothing.
+    """
+    def as_sql(self, qn=None):
+        raise EmptyResultSet
+
+    def relabel_aliases(self, change_map, node=None):
+        return
+