app/django/db/backends/oracle/base.py
changeset 323 ff1a9aa48cfd
parent 54 03e267d67478
--- a/app/django/db/backends/oracle/base.py	Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/db/backends/oracle/base.py	Tue Oct 14 16:00:59 2008 +0000
@@ -5,11 +5,8 @@
 """
 
 import os
-
-from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
-from django.db.backends.oracle import query
-from django.utils.datastructures import SortedDict
-from django.utils.encoding import smart_str, force_unicode
+import datetime
+import time
 
 # Oracle takes client-side character set encoding from the environment.
 os.environ['NLS_LANG'] = '.UTF8'
@@ -19,18 +16,23 @@
     from django.core.exceptions import ImproperlyConfigured
     raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e)
 
-DatabaseError = Database.Error
+from django.db.backends import *
+from django.db.backends.oracle import query
+from django.db.backends.oracle.client import DatabaseClient
+from django.db.backends.oracle.creation import DatabaseCreation
+from django.db.backends.oracle.introspection import DatabaseIntrospection
+from django.utils.encoding import smart_str, force_unicode
+
+DatabaseError = Database.DatabaseError
 IntegrityError = Database.IntegrityError
 
+
 class DatabaseFeatures(BaseDatabaseFeatures):
-    allows_group_by_ordinal = False
-    allows_unique_and_pk = False        # Suppress UNIQUE/PK for Oracle (ORA-02259)
     empty_fetchmany_value = ()
     needs_datetime_string_cast = False
-    needs_upper_for_iops = True
-    supports_tablespaces = True
-    uses_case_insensitive_names = True
     uses_custom_query_class = True
+    interprets_empty_strings_as_nulls = True
+
 
 class DatabaseOperations(BaseDatabaseOperations):
     def autoinc_sql(self, table, column):
@@ -40,7 +42,17 @@
         tr_name = get_trigger_name(table)
         tbl_name = self.quote_name(table)
         col_name = self.quote_name(column)
-        sequence_sql = 'CREATE SEQUENCE %s;' % sq_name
+        sequence_sql = """
+            DECLARE
+                i INTEGER;
+            BEGIN
+                SELECT COUNT(*) INTO i FROM USER_CATALOG
+                    WHERE TABLE_NAME = '%(sq_name)s' AND TABLE_TYPE = 'SEQUENCE';
+                IF i = 0 THEN
+                    EXECUTE IMMEDIATE 'CREATE SEQUENCE %(sq_name)s';
+                END IF;
+            END;
+            /""" % locals()
         trigger_sql = """
             CREATE OR REPLACE TRIGGER %(tr_name)s
             BEFORE INSERT ON %(tbl_name)s
@@ -86,11 +98,6 @@
         cursor.execute('SELECT %s_sq.currval FROM dual' % sq_name)
         return cursor.fetchone()[0]
 
-    def limit_offset_sql(self, limit, offset=None):
-        # Limits and offset are too complicated to be handled here.
-        # Instead, they are handled in django/db/backends/oracle/query.py.
-        return ""
-
     def lookup_cast(self, lookup_type):
         if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
             return "UPPER(%s)"
@@ -99,6 +106,9 @@
     def max_name_length(self):
         return 30
 
+    def prep_for_iexact_query(self, x):
+        return x
+
     def query_class(self, DefaultQueryClass):
         return query.query_class(DefaultQueryClass, Database)
 
@@ -145,11 +155,11 @@
             # Since we've just deleted all the rows, running our sequence
             # ALTER code will reset the sequence to 0.
             for sequence_info in sequences:
-                table_name = sequence_info['table']
-                seq_name = get_sequence_name(table_name)
+                sequence_name = get_sequence_name(sequence_info['table'])
+                table_name = self.quote_name(sequence_info['table'])
                 column_name = self.quote_name(sequence_info['column'] or 'id')
-                query = _get_sequence_reset_sql() % {'sequence': seq_name,
-                                                     'table': self.quote_name(table_name),
+                query = _get_sequence_reset_sql() % {'sequence': sequence_name,
+                                                     'table': table_name,
                                                      'column': column_name}
                 sql.append(query)
             return sql
@@ -161,19 +171,22 @@
         output = []
         query = _get_sequence_reset_sql()
         for model in model_list:
-            for f in model._meta.fields:
+            for f in model._meta.local_fields:
                 if isinstance(f, models.AutoField):
+                    table_name = self.quote_name(model._meta.db_table)
                     sequence_name = get_sequence_name(model._meta.db_table)
-                    column_name = self.quote_name(f.db_column or f.name)
+                    column_name = self.quote_name(f.column)
                     output.append(query % {'sequence': sequence_name,
-                                           'table': model._meta.db_table,
+                                           'table': table_name,
                                            'column': column_name})
                     break # Only one AutoField is allowed per model, so don't bother continuing.
             for f in model._meta.many_to_many:
+                table_name = self.quote_name(f.m2m_db_table())
                 sequence_name = get_sequence_name(f.m2m_db_table())
+                column_name = self.quote_name('id')
                 output.append(query % {'sequence': sequence_name,
-                                       'table': f.m2m_db_table(),
-                                       'column': self.quote_name('id')})
+                                       'table': table_name,
+                                       'column': column_name})
         return output
 
     def start_transaction_sql(self):
@@ -182,9 +195,22 @@
     def tablespace_sql(self, tablespace, inline=False):
         return "%sTABLESPACE %s" % ((inline and "USING INDEX " or ""), self.quote_name(tablespace))
 
+    def value_to_db_time(self, value):
+        if value is None:
+            return None
+        if isinstance(value, basestring):
+            return datetime.datetime(*(time.strptime(value, '%H:%M:%S')[:6]))
+        return datetime.datetime(1900, 1, 1, value.hour, value.minute,
+                                 value.second, value.microsecond)
+
+    def year_lookup_bounds_for_date_field(self, value):
+        first = '%s-01-01'
+        second = '%s-12-31'
+        return [first % value, second % value]
+
+
 class DatabaseWrapper(BaseDatabaseWrapper):
-    features = DatabaseFeatures()
-    ops = DatabaseOperations()
+
     operators = {
         'exact': '= %s',
         'iexact': '= UPPER(%s)',
@@ -201,6 +227,16 @@
     }
     oracle_version = None
 
+    def __init__(self, *args, **kwargs):
+        super(DatabaseWrapper, self).__init__(*args, **kwargs)
+
+        self.features = DatabaseFeatures()
+        self.ops = DatabaseOperations()
+        self.client = DatabaseClient()
+        self.creation = DatabaseCreation(self)
+        self.introspection = DatabaseIntrospection(self)
+        self.validation = BaseDatabaseValidation()
+
     def _valid_connection(self):
         return self.connection is not None
 
@@ -244,6 +280,28 @@
         cursor.arraysize = 100
         return cursor
 
+
+class OracleParam(object):
+    """
+    Wrapper object for formatting parameters for Oracle. If the string
+    representation of the value is large enough (greater than 4000 characters)
+    the input size needs to be set as NCLOB. Alternatively, if the parameter has
+    an `input_size` attribute, then the value of the `input_size` attribute will
+    be used instead. Otherwise, no input size will be set for the parameter when
+    executing the query.
+    """
+    def __init__(self, param, charset, strings_only=False):
+        self.smart_str = smart_str(param, charset, strings_only)
+        if hasattr(param, 'input_size'):
+            # If parameter has `input_size` attribute, use that.
+            self.input_size = param.input_size
+        elif isinstance(param, basestring) and len(param) > 4000:
+            # Mark any string parameter greater than 4000 characters as an NCLOB.
+            self.input_size = Database.NCLOB
+        else:
+            self.input_size = None
+
+
 class FormatStylePlaceholderCursor(Database.Cursor):
     """
     Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
@@ -258,15 +316,13 @@
     def _format_params(self, params):
         if isinstance(params, dict):
             result = {}
-            charset = self.charset
             for key, value in params.items():
-                result[smart_str(key, charset)] = smart_str(value, charset)
+                result[smart_str(key, self.charset)] = OracleParam(param, self.charset)
             return result
         else:
-            return tuple([smart_str(p, self.charset, True) for p in params])
+            return tuple([OracleParam(p, self.charset, True) for p in params])
 
     def _guess_input_sizes(self, params_list):
-        # Mark any string parameter greater than 4000 characters as an NCLOB.
         if isinstance(params_list[0], dict):
             sizes = {}
             iterators = [params.iteritems() for params in params_list]
@@ -275,13 +331,18 @@
             iterators = [enumerate(params) for params in params_list]
         for iterator in iterators:
             for key, value in iterator:
-                if isinstance(value, basestring) and len(value) > 4000:
-                    sizes[key] = Database.NCLOB
+                if value.input_size: sizes[key] = value.input_size
         if isinstance(sizes, dict):
             self.setinputsizes(**sizes)
         else:
             self.setinputsizes(*sizes)
 
+    def _param_generator(self, params):
+        if isinstance(params, dict):
+            return dict([(k, p.smart_str) for k, p in params.iteritems()])
+        else:
+            return [p.smart_str for p in params]
+
     def execute(self, query, params=None):
         if params is None:
             params = []
@@ -296,7 +357,13 @@
             query = query[:-1]
         query = smart_str(query, self.charset) % tuple(args)
         self._guess_input_sizes([params])
-        return Database.Cursor.execute(self, query, params)
+        try:
+            return Database.Cursor.execute(self, query, self._param_generator(params))
+        except DatabaseError, e:
+            # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400.
+            if e.args[0].code == 1400 and not isinstance(e, IntegrityError):
+                e = IntegrityError(e.args[0])
+            raise e
 
     def executemany(self, query, params=None):
         try:
@@ -311,9 +378,15 @@
         if query.endswith(';') or query.endswith('/'):
             query = query[:-1]
         query = smart_str(query, self.charset) % tuple(args)
-        new_param_list = [self._format_params(i) for i in params]
-        self._guess_input_sizes(new_param_list)
-        return Database.Cursor.executemany(self, query, new_param_list)
+        formatted = [self._format_params(i) for i in params]
+        self._guess_input_sizes(formatted)
+        try:
+            return Database.Cursor.executemany(self, query, [self._param_generator(p) for p in formatted])
+        except DatabaseError, e:
+            # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400.
+            if e.args[0].code == 1400 and not isinstance(e, IntegrityError):
+                e = IntegrityError(e.args[0])
+            raise e
 
     def fetchone(self):
         row = Database.Cursor.fetchone(self)