app/django/db/models/base.py
changeset 323 ff1a9aa48cfd
parent 54 03e267d67478
--- a/app/django/db/models/base.py	Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/db/models/base.py	Tue Oct 14 16:00:59 2008 +0000
@@ -3,45 +3,39 @@
 import sys
 import os
 from itertools import izip
+try:
+    set
+except NameError:
+    from sets import Set as set     # Python 2.3 fallback.
 
-import django.db.models.manipulators    # Imported to register signal handler.
-import django.db.models.manager         # Ditto.
-from django.core import validators
+import django.db.models.manager     # Imported to register signal handler.
 from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned, FieldError
-from django.db.models.fields import AutoField, ImageField, FieldDoesNotExist
+from django.db.models.fields import AutoField
 from django.db.models.fields.related import OneToOneRel, ManyToOneRel, OneToOneField
-from django.db.models.query import delete_objects, Q
-from django.db.models.options import Options, AdminOptions
-from django.db import connection, transaction
+from django.db.models.query import delete_objects, Q, CollectedObjects
+from django.db.models.options import Options
+from django.db import connection, transaction, DatabaseError
 from django.db.models import signals
 from django.db.models.loading import register_models, get_model
-from django.dispatch import dispatcher
-from django.utils.datastructures import SortedDict
 from django.utils.functional import curry
 from django.utils.encoding import smart_str, force_unicode, smart_unicode
 from django.conf import settings
 
-try:
-    set
-except NameError:
-    from sets import Set as set     # Python 2.3 fallback
 
 class ModelBase(type):
-    "Metaclass for all models"
+    """
+    Metaclass for all models.
+    """
     def __new__(cls, name, bases, attrs):
-        # If this isn't a subclass of Model, don't do anything special.
-        try:
-            parents = [b for b in bases if issubclass(b, Model)]
-        except NameError:
-            # 'Model' isn't defined yet, meaning we're looking at Django's own
-            # Model class, defined below.
-            parents = []
+        super_new = super(ModelBase, cls).__new__
+        parents = [b for b in bases if isinstance(b, ModelBase)]
         if not parents:
-            return super(ModelBase, cls).__new__(cls, name, bases, attrs)
+            # If this isn't a subclass of Model, don't do anything special.
+            return super_new(cls, name, bases, attrs)
 
         # Create the class.
         module = attrs.pop('__module__')
-        new_class = type.__new__(cls, name, bases, {'__module__': module})
+        new_class = super_new(cls, name, bases, {'__module__': module})
         attr_meta = attrs.pop('Meta', None)
         abstract = getattr(attr_meta, 'abstract', False)
         if not attr_meta:
@@ -50,7 +44,15 @@
             meta = attr_meta
         base_meta = getattr(new_class, '_meta', None)
 
-        new_class.add_to_class('_meta', Options(meta))
+        if getattr(meta, 'app_label', None) is None:
+            # Figure out the app_label by looking one level up.
+            # For 'django.contrib.sites.models', this would be 'sites'.
+            model_module = sys.modules[new_class.__module__]
+            kwargs = {"app_label": model_module.__name__.split('.')[-2]}
+        else:
+            kwargs = {}
+
+        new_class.add_to_class('_meta', Options(meta, **kwargs))
         if not abstract:
             new_class.add_to_class('DoesNotExist',
                     subclass_exception('DoesNotExist', ObjectDoesNotExist, module))
@@ -65,17 +67,8 @@
                 if not hasattr(meta, 'get_latest_by'):
                     new_class._meta.get_latest_by = base_meta.get_latest_by
 
-        old_default_mgr = None
         if getattr(new_class, '_default_manager', None):
-            # We have a parent who set the default manager.
-            if new_class._default_manager.model._meta.abstract:
-                old_default_mgr = new_class._default_manager
             new_class._default_manager = None
-        if getattr(new_class._meta, 'app_label', None) is None:
-            # Figure out the app_label by looking one level up.
-            # For 'django.contrib.sites.models', this would be 'sites'.
-            model_module = sys.modules[new_class.__module__]
-            new_class._meta.app_label = model_module.__name__.split('.')[-2]
 
         # Bail out early if we have already created this class.
         m = get_model(new_class._meta.app_label, name, False)
@@ -94,7 +87,15 @@
                 # Things without _meta aren't functional models, so they're
                 # uninteresting parents.
                 continue
+
+            # All the fields of any type declared on this model
+            new_fields = new_class._meta.local_fields + \
+                         new_class._meta.local_many_to_many + \
+                         new_class._meta.virtual_fields
+            field_names = set([f.name for f in new_fields])
+
             if not base._meta.abstract:
+                # Concrete classes...
                 if base in o2o_map:
                     field = o2o_map[base]
                     field.primary_key = True
@@ -105,15 +106,42 @@
                             auto_created=True, parent_link=True)
                     new_class.add_to_class(attr_name, field)
                 new_class._meta.parents[base] = field
+
             else:
-                # The abstract base class case.
-                names = set([f.name for f in new_class._meta.local_fields + new_class._meta.many_to_many])
-                for field in base._meta.local_fields + base._meta.local_many_to_many:
-                    if field.name in names:
-                        raise FieldError('Local field %r in class %r clashes with field of similar name from abstract base class %r'
-                                % (field.name, name, base.__name__))
+                # .. and abstract ones.
+
+                # Check for clashes between locally declared fields and those
+                # on the ABC.
+                parent_fields = base._meta.local_fields + base._meta.local_many_to_many
+                for field in parent_fields:
+                    if field.name in field_names:
+                        raise FieldError('Local field %r in class %r clashes '\
+                                         'with field of similar name from '\
+                                         'abstract base class %r' % \
+                                            (field.name, name, base.__name__))
                     new_class.add_to_class(field.name, copy.deepcopy(field))
 
+                # Pass any non-abstract parent classes onto child.
+                new_class._meta.parents.update(base._meta.parents)
+
+            # Inherit managers from the abstract base classes.
+            base_managers = base._meta.abstract_managers
+            base_managers.sort()
+            for _, mgr_name, manager in base_managers:
+                val = getattr(new_class, mgr_name, None)
+                if not val or val is manager:
+                    new_manager = manager._copy_to_model(new_class)
+                    new_class.add_to_class(mgr_name, new_manager)
+
+            # Inherit virtual fields (like GenericForeignKey) from the parent class
+            for field in base._meta.virtual_fields:
+                if base._meta.abstract and field.name in field_names:
+                    raise FieldError('Local field %r in class %r clashes '\
+                                     'with field of similar name from '\
+                                     'abstract base class %r' % \
+                                        (field.name, name, base.__name__))
+                new_class.add_to_class(field.name, copy.deepcopy(field))
+
         if abstract:
             # Abstract base models can't be instantiated and don't appear in
             # the list of models for an app. We do the final setup for them a
@@ -122,8 +150,6 @@
             new_class.Meta = attr_meta
             return new_class
 
-        if old_default_mgr and not new_class._default_manager:
-            new_class._default_manager = old_default_mgr._copy_to_model(new_class)
         new_class._prepare()
         register_models(new_class._meta.app_label, new_class)
 
@@ -134,16 +160,15 @@
         return get_model(new_class._meta.app_label, name, False)
 
     def add_to_class(cls, name, value):
-        if name == 'Admin':
-            assert type(value) == types.ClassType, "%r attribute of %s model must be a class, not a %s object" % (name, cls.__name__, type(value))
-            value = AdminOptions(**dict([(k, v) for k, v in value.__dict__.items() if not k.startswith('_')]))
         if hasattr(value, 'contribute_to_class'):
             value.contribute_to_class(cls, name)
         else:
             setattr(cls, name, value)
 
     def _prepare(cls):
-        # Creates some methods once self._meta has been populated.
+        """
+        Creates some methods once self._meta has been populated.
+        """
         opts = cls._meta
         opts._prepare(cls)
 
@@ -160,13 +185,14 @@
         if hasattr(cls, 'get_absolute_url'):
             cls.get_absolute_url = curry(get_absolute_url, opts, cls.get_absolute_url)
 
-        dispatcher.send(signal=signals.class_prepared, sender=cls)
+        signals.class_prepared.send(sender=cls)
+
 
 class Model(object):
     __metaclass__ = ModelBase
 
     def __init__(self, *args, **kwargs):
-        dispatcher.send(signal=signals.pre_init, sender=self.__class__, args=args, kwargs=kwargs)
+        signals.pre_init.send(sender=self.__class__, args=args, kwargs=kwargs)
 
         # There is a rather weird disparity here; if kwargs, it's set, then args
         # overrides it. It should be one or the other; don't duplicate the work
@@ -198,6 +224,7 @@
         # keywords, or default.
 
         for field in fields_iter:
+            rel_obj = None
             if kwargs:
                 if isinstance(field.rel, ManyToOneRel):
                     try:
@@ -214,17 +241,18 @@
                         # pass in "None" for related objects if it's allowed.
                         if rel_obj is None and field.null:
                             val = None
-                        else:
-                            try:
-                                val = getattr(rel_obj, field.rel.get_related_field().attname)
-                            except AttributeError:
-                                raise TypeError("Invalid value: %r should be a %s instance, not a %s" %
-                                    (field.name, field.rel.to, type(rel_obj)))
                 else:
                     val = kwargs.pop(field.attname, field.get_default())
             else:
                 val = field.get_default()
-            setattr(self, field.attname, val)
+            # If we got passed a related instance, set it using the field.name
+            # instead of field.attname (e.g. "user" instead of "user_id") so
+            # that the object gets properly cached (and type checked) by the
+            # RelatedObjectDescriptor.
+            if rel_obj:
+                setattr(self, field.name, rel_obj)
+            else:
+                setattr(self, field.attname, val)
 
         if kwargs:
             for prop in kwargs.keys():
@@ -235,7 +263,7 @@
                     pass
             if kwargs:
                 raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0]
-        dispatcher.send(signal=signals.post_init, sender=self.__class__, instance=self)
+        signals.post_init.send(sender=self.__class__, instance=self)
 
     def __repr__(self):
         return smart_str(u'<%s: %s>' % (self.__class__.__name__, unicode(self)))
@@ -264,56 +292,78 @@
 
     pk = property(_get_pk_val, _set_pk_val)
 
-    def save(self):
+    def save(self, force_insert=False, force_update=False):
         """
-        Save the current instance. Override this in a subclass if you want to
+        Saves the current instance. Override this in a subclass if you want to
         control the saving process.
+
+        The 'force_insert' and 'force_update' parameters can be used to insist
+        that the "save" must be an SQL insert or update (or equivalent for
+        non-SQL backends), respectively. Normally, they should not be set.
         """
-        self.save_base()
+        if force_insert and force_update:
+            raise ValueError("Cannot force both insert and updating in "
+                    "model saving.")
+        self.save_base(force_insert=force_insert, force_update=force_update)
 
     save.alters_data = True
 
-    def save_base(self, raw=False, cls=None):
+    def save_base(self, raw=False, cls=None, force_insert=False,
+            force_update=False):
         """
         Does the heavy-lifting involved in saving. Subclasses shouldn't need to
         override this method. It's separate from save() in order to hide the
         need for overrides of save() to pass around internal-only parameters
         ('raw' and 'cls').
         """
+        assert not (force_insert and force_update)
         if not cls:
             cls = self.__class__
             meta = self._meta
             signal = True
-            dispatcher.send(signal=signals.pre_save, sender=self.__class__,
-                    instance=self, raw=raw)
+            signals.pre_save.send(sender=self.__class__, instance=self, raw=raw)
         else:
             meta = cls._meta
             signal = False
 
-        for parent, field in meta.parents.items():
-            self.save_base(raw, parent)
-            setattr(self, field.attname, self._get_pk_val(parent._meta))
+        # If we are in a raw save, save the object exactly as presented.
+        # That means that we don't try to be smart about saving attributes
+        # that might have come from the parent class - we just save the
+        # attributes we have been given to the class we have been given.
+        if not raw:
+            for parent, field in meta.parents.items():
+                # At this point, parent's primary key field may be unknown
+                # (for example, from administration form which doesn't fill
+                # this field). If so, fill it.
+                if getattr(self, parent._meta.pk.attname) is None and getattr(self, field.attname) is not None:
+                    setattr(self, parent._meta.pk.attname, getattr(self, field.attname))
+
+                self.save_base(raw, parent)
+                setattr(self, field.attname, self._get_pk_val(parent._meta))
 
         non_pks = [f for f in meta.local_fields if not f.primary_key]
 
         # First, try an UPDATE. If that doesn't update anything, do an INSERT.
         pk_val = self._get_pk_val(meta)
-        # Note: the comparison with '' is required for compatibility with
-        # oldforms-style model creation.
-        pk_set = pk_val is not None and smart_unicode(pk_val) != u''
+        pk_set = pk_val is not None
         record_exists = True
         manager = cls._default_manager
         if pk_set:
             # Determine whether a record with the primary key already exists.
-            if manager.filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by():
+            if (force_update or (not force_insert and
+                    manager.filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by())):
                 # It does already exist, so do an UPDATE.
-                if non_pks:
+                if force_update or non_pks:
                     values = [(f, None, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks]
-                    manager.filter(pk=pk_val)._update(values)
+                    rows = manager.filter(pk=pk_val)._update(values)
+                    if force_update and not rows:
+                        raise DatabaseError("Forced update did not affect any rows.")
             else:
                 record_exists = False
         if not pk_set or not record_exists:
             if not pk_set:
+                if force_update:
+                    raise ValueError("Cannot force an update in save() with no primary key.")
                 values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields if not isinstance(f, AutoField)]
             else:
                 values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields]
@@ -336,44 +386,23 @@
         transaction.commit_unless_managed()
 
         if signal:
-            dispatcher.send(signal=signals.post_save, sender=self.__class__,
-                    instance=self, created=(not record_exists), raw=raw)
+            signals.post_save.send(sender=self.__class__, instance=self,
+                created=(not record_exists), raw=raw)
 
     save_base.alters_data = True
 
-    def validate(self):
-        """
-        First coerces all fields on this instance to their proper Python types.
-        Then runs validation on every field. Returns a dictionary of
-        field_name -> error_list.
+    def _collect_sub_objects(self, seen_objs, parent=None, nullable=False):
         """
-        error_dict = {}
-        invalid_python = {}
-        for f in self._meta.fields:
-            try:
-                setattr(self, f.attname, f.to_python(getattr(self, f.attname, f.get_default())))
-            except validators.ValidationError, e:
-                error_dict[f.name] = e.messages
-                invalid_python[f.name] = 1
-        for f in self._meta.fields:
-            if f.name in invalid_python:
-                continue
-            errors = f.validate_full(getattr(self, f.attname, f.get_default()), self.__dict__)
-            if errors:
-                error_dict[f.name] = errors
-        return error_dict
+        Recursively populates seen_objs with all objects related to this
+        object.
 
-    def _collect_sub_objects(self, seen_objs):
-        """
-        Recursively populates seen_objs with all objects related to this object.
-        When done, seen_objs will be in the format:
-            {model_class: {pk_val: obj, pk_val: obj, ...},
-             model_class: {pk_val: obj, pk_val: obj, ...}, ...}
+        When done, seen_objs.items() will be in the format:
+            [(model_class, {pk_val: obj, pk_val: obj, ...}),
+             (model_class, {pk_val: obj, pk_val: obj, ...}), ...]
         """
         pk_val = self._get_pk_val()
-        if pk_val in seen_objs.setdefault(self.__class__, {}):
+        if seen_objs.add(self.__class__, pk_val, self, parent, nullable):
             return
-        seen_objs.setdefault(self.__class__, {})[pk_val] = self
 
         for related in self._meta.get_all_related_objects():
             rel_opts_name = related.get_accessor_name()
@@ -383,26 +412,41 @@
                 except ObjectDoesNotExist:
                     pass
                 else:
-                    sub_obj._collect_sub_objects(seen_objs)
+                    sub_obj._collect_sub_objects(seen_objs, self.__class__, related.field.null)
             else:
                 for sub_obj in getattr(self, rel_opts_name).all():
-                    sub_obj._collect_sub_objects(seen_objs)
+                    sub_obj._collect_sub_objects(seen_objs, self.__class__, related.field.null)
+
+        # Handle any ancestors (for the model-inheritance case). We do this by
+        # traversing to the most remote parent classes -- those with no parents
+        # themselves -- and then adding those instances to the collection. That
+        # will include all the child instances down to "self".
+        parent_stack = self._meta.parents.values()
+        while parent_stack:
+            link = parent_stack.pop()
+            parent_obj = getattr(self, link.name)
+            if parent_obj._meta.parents:
+                parent_stack.extend(parent_obj._meta.parents.values())
+                continue
+            # At this point, parent_obj is base class (no ancestor models). So
+            # delete it and all its descendents.
+            parent_obj._collect_sub_objects(seen_objs)
 
     def delete(self):
         assert self._get_pk_val() is not None, "%s object can't be deleted because its %s attribute is set to None." % (self._meta.object_name, self._meta.pk.attname)
 
-        # Find all the objects than need to be deleted
-        seen_objs = SortedDict()
+        # Find all the objects than need to be deleted.
+        seen_objs = CollectedObjects()
         self._collect_sub_objects(seen_objs)
 
-        # Actually delete the objects
+        # Actually delete the objects.
         delete_objects(seen_objs)
 
     delete.alters_data = True
 
     def _get_FIELD_display(self, field):
         value = getattr(self, field.attname)
-        return force_unicode(dict(field.choices).get(value, value), strings_only=True)
+        return force_unicode(dict(field.flatchoices).get(value, value), strings_only=True)
 
     def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs):
         op = is_next and 'gt' or 'lt'
@@ -433,74 +477,7 @@
             setattr(self, cachename, obj)
         return getattr(self, cachename)
 
-    def _get_FIELD_filename(self, field):
-        if getattr(self, field.attname): # value is not blank
-            return os.path.join(settings.MEDIA_ROOT, getattr(self, field.attname))
-        return ''
 
-    def _get_FIELD_url(self, field):
-        if getattr(self, field.attname): # value is not blank
-            import urlparse
-            return urlparse.urljoin(settings.MEDIA_URL, getattr(self, field.attname)).replace('\\', '/')
-        return ''
-
-    def _get_FIELD_size(self, field):
-        return os.path.getsize(self._get_FIELD_filename(field))
-
-    def _save_FIELD_file(self, field, filename, raw_contents, save=True):
-        directory = field.get_directory_name()
-        try: # Create the date-based directory if it doesn't exist.
-            os.makedirs(os.path.join(settings.MEDIA_ROOT, directory))
-        except OSError: # Directory probably already exists.
-            pass
-        filename = field.get_filename(filename)
-
-        # If the filename already exists, keep adding an underscore to the name of
-        # the file until the filename doesn't exist.
-        while os.path.exists(os.path.join(settings.MEDIA_ROOT, filename)):
-            try:
-                dot_index = filename.rindex('.')
-            except ValueError: # filename has no dot
-                filename += '_'
-            else:
-                filename = filename[:dot_index] + '_' + filename[dot_index:]
-
-        # Write the file to disk.
-        setattr(self, field.attname, filename)
-
-        full_filename = self._get_FIELD_filename(field)
-        fp = open(full_filename, 'wb')
-        fp.write(raw_contents)
-        fp.close()
-
-        # Save the width and/or height, if applicable.
-        if isinstance(field, ImageField) and (field.width_field or field.height_field):
-            from django.utils.images import get_image_dimensions
-            width, height = get_image_dimensions(full_filename)
-            if field.width_field:
-                setattr(self, field.width_field, width)
-            if field.height_field:
-                setattr(self, field.height_field, height)
-
-        # Save the object because it has changed unless save is False
-        if save:
-            self.save()
-
-    _save_FIELD_file.alters_data = True
-
-    def _get_FIELD_width(self, field):
-        return self._get_image_dimensions(field)[0]
-
-    def _get_FIELD_height(self, field):
-        return self._get_image_dimensions(field)[1]
-
-    def _get_image_dimensions(self, field):
-        cachename = "__%s_dimensions_cache" % field.name
-        if not hasattr(self, cachename):
-            from django.utils.images import get_image_dimensions
-            filename = self._get_FIELD_filename(field)
-            setattr(self, cachename, get_image_dimensions(filename))
-        return getattr(self, cachename)
 
 ############################################
 # HELPER FUNCTIONS (CURRIED MODEL METHODS) #
@@ -517,6 +494,7 @@
         ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i)
     transaction.commit_unless_managed()
 
+
 def method_get_order(ordered_obj, self):
     rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.rel.field_name)
     order_name = ordered_obj._meta.order_with_respect_to.name
@@ -524,6 +502,7 @@
     return [r[pk_name] for r in
             ordered_obj.objects.filter(**{order_name: rel_val}).values(pk_name)]
 
+
 ##############################################
 # HELPER FUNCTIONS (CURRIED MODEL FUNCTIONS) #
 ##############################################
@@ -531,6 +510,7 @@
 def get_absolute_url(opts, func, self, *args, **kwargs):
     return settings.ABSOLUTE_URL_OVERRIDES.get('%s.%s' % (opts.app_label, opts.module_name), func)(self, *args, **kwargs)
 
+
 ########
 # MISC #
 ########
@@ -542,8 +522,6 @@
     # Prior to Python 2.5, Exception was an old-style class
     def subclass_exception(name, parent, unused):
         return types.ClassType(name, (parent,), {})
-
 else:
     def subclass_exception(name, parent, module):
         return type(name, (parent,), {'__module__': module})
-