app/django/db/backends/oracle/base.py
changeset 54 03e267d67478
child 323 ff1a9aa48cfd
equal deleted inserted replaced
53:57b4279d8c4e 54:03e267d67478
       
     1 """
       
     2 Oracle database backend for Django.
       
     3 
       
     4 Requires cx_Oracle: http://www.python.net/crew/atuining/cx_Oracle/
       
     5 """
       
     6 
       
     7 import os
       
     8 
       
     9 from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
       
    10 from django.db.backends.oracle import query
       
    11 from django.utils.datastructures import SortedDict
       
    12 from django.utils.encoding import smart_str, force_unicode
       
    13 
       
    14 # Oracle takes client-side character set encoding from the environment.
       
    15 os.environ['NLS_LANG'] = '.UTF8'
       
    16 try:
       
    17     import cx_Oracle as Database
       
    18 except ImportError, e:
       
    19     from django.core.exceptions import ImproperlyConfigured
       
    20     raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e)
       
    21 
       
    22 DatabaseError = Database.Error
       
    23 IntegrityError = Database.IntegrityError
       
    24 
       
    25 class DatabaseFeatures(BaseDatabaseFeatures):
       
    26     allows_group_by_ordinal = False
       
    27     allows_unique_and_pk = False        # Suppress UNIQUE/PK for Oracle (ORA-02259)
       
    28     empty_fetchmany_value = ()
       
    29     needs_datetime_string_cast = False
       
    30     needs_upper_for_iops = True
       
    31     supports_tablespaces = True
       
    32     uses_case_insensitive_names = True
       
    33     uses_custom_query_class = True
       
    34 
       
    35 class DatabaseOperations(BaseDatabaseOperations):
       
    36     def autoinc_sql(self, table, column):
       
    37         # To simulate auto-incrementing primary keys in Oracle, we have to
       
    38         # create a sequence and a trigger.
       
    39         sq_name = get_sequence_name(table)
       
    40         tr_name = get_trigger_name(table)
       
    41         tbl_name = self.quote_name(table)
       
    42         col_name = self.quote_name(column)
       
    43         sequence_sql = 'CREATE SEQUENCE %s;' % sq_name
       
    44         trigger_sql = """
       
    45             CREATE OR REPLACE TRIGGER %(tr_name)s
       
    46             BEFORE INSERT ON %(tbl_name)s
       
    47             FOR EACH ROW
       
    48             WHEN (new.%(col_name)s IS NULL)
       
    49                 BEGIN
       
    50                     SELECT %(sq_name)s.nextval
       
    51                     INTO :new.%(col_name)s FROM dual;
       
    52                 END;
       
    53                 /""" % locals()
       
    54         return sequence_sql, trigger_sql
       
    55 
       
    56     def date_extract_sql(self, lookup_type, field_name):
       
    57         # http://download-east.oracle.com/docs/cd/B10501_01/server.920/a96540/functions42a.htm#1017163
       
    58         return "EXTRACT(%s FROM %s)" % (lookup_type, field_name)
       
    59 
       
    60     def date_trunc_sql(self, lookup_type, field_name):
       
    61         # Oracle uses TRUNC() for both dates and numbers.
       
    62         # http://download-east.oracle.com/docs/cd/B10501_01/server.920/a96540/functions155a.htm#SQLRF06151
       
    63         if lookup_type == 'day':
       
    64             sql = 'TRUNC(%s)' % field_name
       
    65         else:
       
    66             sql = "TRUNC(%s, '%s')" % (field_name, lookup_type)
       
    67         return sql
       
    68 
       
    69     def datetime_cast_sql(self):
       
    70         return "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')"
       
    71 
       
    72     def deferrable_sql(self):
       
    73         return " DEFERRABLE INITIALLY DEFERRED"
       
    74 
       
    75     def drop_sequence_sql(self, table):
       
    76         return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table))
       
    77 
       
    78     def field_cast_sql(self, db_type):
       
    79         if db_type and db_type.endswith('LOB'):
       
    80             return "DBMS_LOB.SUBSTR(%s)"
       
    81         else:
       
    82             return "%s"
       
    83 
       
    84     def last_insert_id(self, cursor, table_name, pk_name):
       
    85         sq_name = util.truncate_name(table_name, self.max_name_length() - 3)
       
    86         cursor.execute('SELECT %s_sq.currval FROM dual' % sq_name)
       
    87         return cursor.fetchone()[0]
       
    88 
       
    89     def limit_offset_sql(self, limit, offset=None):
       
    90         # Limits and offset are too complicated to be handled here.
       
    91         # Instead, they are handled in django/db/backends/oracle/query.py.
       
    92         return ""
       
    93 
       
    94     def lookup_cast(self, lookup_type):
       
    95         if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
       
    96             return "UPPER(%s)"
       
    97         return "%s"
       
    98 
       
    99     def max_name_length(self):
       
   100         return 30
       
   101 
       
   102     def query_class(self, DefaultQueryClass):
       
   103         return query.query_class(DefaultQueryClass, Database)
       
   104 
       
   105     def quote_name(self, name):
       
   106         # SQL92 requires delimited (quoted) names to be case-sensitive.  When
       
   107         # not quoted, Oracle has case-insensitive behavior for identifiers, but
       
   108         # always defaults to uppercase.
       
   109         # We simplify things by making Oracle identifiers always uppercase.
       
   110         if not name.startswith('"') and not name.endswith('"'):
       
   111             name = '"%s"' % util.truncate_name(name.upper(), self.max_name_length())
       
   112         return name.upper()
       
   113 
       
   114     def random_function_sql(self):
       
   115         return "DBMS_RANDOM.RANDOM"
       
   116 
       
   117     def regex_lookup_9(self, lookup_type):
       
   118         raise NotImplementedError("Regexes are not supported in Oracle before version 10g.")
       
   119 
       
   120     def regex_lookup_10(self, lookup_type):
       
   121         if lookup_type == 'regex':
       
   122             match_option = "'c'"
       
   123         else:
       
   124             match_option = "'i'"
       
   125         return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option
       
   126 
       
   127     def regex_lookup(self, lookup_type):
       
   128         # If regex_lookup is called before it's been initialized, then create
       
   129         # a cursor to initialize it and recur.
       
   130         from django.db import connection
       
   131         connection.cursor()
       
   132         return connection.ops.regex_lookup(lookup_type)
       
   133 
       
   134     def sql_flush(self, style, tables, sequences):
       
   135         # Return a list of 'TRUNCATE x;', 'TRUNCATE y;',
       
   136         # 'TRUNCATE z;'... style SQL statements
       
   137         if tables:
       
   138             # Oracle does support TRUNCATE, but it seems to get us into
       
   139             # FK referential trouble, whereas DELETE FROM table works.
       
   140             sql = ['%s %s %s;' % \
       
   141                     (style.SQL_KEYWORD('DELETE'),
       
   142                      style.SQL_KEYWORD('FROM'),
       
   143                      style.SQL_FIELD(self.quote_name(table))
       
   144                      ) for table in tables]
       
   145             # Since we've just deleted all the rows, running our sequence
       
   146             # ALTER code will reset the sequence to 0.
       
   147             for sequence_info in sequences:
       
   148                 table_name = sequence_info['table']
       
   149                 seq_name = get_sequence_name(table_name)
       
   150                 column_name = self.quote_name(sequence_info['column'] or 'id')
       
   151                 query = _get_sequence_reset_sql() % {'sequence': seq_name,
       
   152                                                      'table': self.quote_name(table_name),
       
   153                                                      'column': column_name}
       
   154                 sql.append(query)
       
   155             return sql
       
   156         else:
       
   157             return []
       
   158 
       
   159     def sequence_reset_sql(self, style, model_list):
       
   160         from django.db import models
       
   161         output = []
       
   162         query = _get_sequence_reset_sql()
       
   163         for model in model_list:
       
   164             for f in model._meta.fields:
       
   165                 if isinstance(f, models.AutoField):
       
   166                     sequence_name = get_sequence_name(model._meta.db_table)
       
   167                     column_name = self.quote_name(f.db_column or f.name)
       
   168                     output.append(query % {'sequence': sequence_name,
       
   169                                            'table': model._meta.db_table,
       
   170                                            'column': column_name})
       
   171                     break # Only one AutoField is allowed per model, so don't bother continuing.
       
   172             for f in model._meta.many_to_many:
       
   173                 sequence_name = get_sequence_name(f.m2m_db_table())
       
   174                 output.append(query % {'sequence': sequence_name,
       
   175                                        'table': f.m2m_db_table(),
       
   176                                        'column': self.quote_name('id')})
       
   177         return output
       
   178 
       
   179     def start_transaction_sql(self):
       
   180         return ''
       
   181 
       
   182     def tablespace_sql(self, tablespace, inline=False):
       
   183         return "%sTABLESPACE %s" % ((inline and "USING INDEX " or ""), self.quote_name(tablespace))
       
   184 
       
   185 class DatabaseWrapper(BaseDatabaseWrapper):
       
   186     features = DatabaseFeatures()
       
   187     ops = DatabaseOperations()
       
   188     operators = {
       
   189         'exact': '= %s',
       
   190         'iexact': '= UPPER(%s)',
       
   191         'contains': "LIKEC %s ESCAPE '\\'",
       
   192         'icontains': "LIKEC UPPER(%s) ESCAPE '\\'",
       
   193         'gt': '> %s',
       
   194         'gte': '>= %s',
       
   195         'lt': '< %s',
       
   196         'lte': '<= %s',
       
   197         'startswith': "LIKEC %s ESCAPE '\\'",
       
   198         'endswith': "LIKEC %s ESCAPE '\\'",
       
   199         'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'",
       
   200         'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'",
       
   201     }
       
   202     oracle_version = None
       
   203 
       
   204     def _valid_connection(self):
       
   205         return self.connection is not None
       
   206 
       
   207     def _cursor(self, settings):
       
   208         cursor = None
       
   209         if not self._valid_connection():
       
   210             if len(settings.DATABASE_HOST.strip()) == 0:
       
   211                 settings.DATABASE_HOST = 'localhost'
       
   212             if len(settings.DATABASE_PORT.strip()) != 0:
       
   213                 dsn = Database.makedsn(settings.DATABASE_HOST, int(settings.DATABASE_PORT), settings.DATABASE_NAME)
       
   214                 self.connection = Database.connect(settings.DATABASE_USER, settings.DATABASE_PASSWORD, dsn, **self.options)
       
   215             else:
       
   216                 conn_string = "%s/%s@%s" % (settings.DATABASE_USER, settings.DATABASE_PASSWORD, settings.DATABASE_NAME)
       
   217                 self.connection = Database.connect(conn_string, **self.options)
       
   218             cursor = FormatStylePlaceholderCursor(self.connection)
       
   219             # Set oracle date to ansi date format.  This only needs to execute
       
   220             # once when we create a new connection.
       
   221             cursor.execute("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD' "
       
   222                            "NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'")
       
   223             try:
       
   224                 self.oracle_version = int(self.connection.version.split('.')[0])
       
   225                 # There's no way for the DatabaseOperations class to know the
       
   226                 # currently active Oracle version, so we do some setups here.
       
   227                 # TODO: Multi-db support will need a better solution (a way to
       
   228                 # communicate the current version).
       
   229                 if self.oracle_version <= 9:
       
   230                     self.ops.regex_lookup = self.ops.regex_lookup_9
       
   231                 else:
       
   232                     self.ops.regex_lookup = self.ops.regex_lookup_10
       
   233             except ValueError:
       
   234                 pass
       
   235             try:
       
   236                 self.connection.stmtcachesize = 20
       
   237             except:
       
   238                 # Django docs specify cx_Oracle version 4.3.1 or higher, but
       
   239                 # stmtcachesize is available only in 4.3.2 and up.
       
   240                 pass
       
   241         if not cursor:
       
   242             cursor = FormatStylePlaceholderCursor(self.connection)
       
   243         # Default arraysize of 1 is highly sub-optimal.
       
   244         cursor.arraysize = 100
       
   245         return cursor
       
   246 
       
   247 class FormatStylePlaceholderCursor(Database.Cursor):
       
   248     """
       
   249     Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
       
   250     style. This fixes it -- but note that if you want to use a literal "%s" in
       
   251     a query, you'll need to use "%%s".
       
   252 
       
   253     We also do automatic conversion between Unicode on the Python side and
       
   254     UTF-8 -- for talking to Oracle -- in here.
       
   255     """
       
   256     charset = 'utf-8'
       
   257 
       
   258     def _format_params(self, params):
       
   259         if isinstance(params, dict):
       
   260             result = {}
       
   261             charset = self.charset
       
   262             for key, value in params.items():
       
   263                 result[smart_str(key, charset)] = smart_str(value, charset)
       
   264             return result
       
   265         else:
       
   266             return tuple([smart_str(p, self.charset, True) for p in params])
       
   267 
       
   268     def _guess_input_sizes(self, params_list):
       
   269         # Mark any string parameter greater than 4000 characters as an NCLOB.
       
   270         if isinstance(params_list[0], dict):
       
   271             sizes = {}
       
   272             iterators = [params.iteritems() for params in params_list]
       
   273         else:
       
   274             sizes = [None] * len(params_list[0])
       
   275             iterators = [enumerate(params) for params in params_list]
       
   276         for iterator in iterators:
       
   277             for key, value in iterator:
       
   278                 if isinstance(value, basestring) and len(value) > 4000:
       
   279                     sizes[key] = Database.NCLOB
       
   280         if isinstance(sizes, dict):
       
   281             self.setinputsizes(**sizes)
       
   282         else:
       
   283             self.setinputsizes(*sizes)
       
   284 
       
   285     def execute(self, query, params=None):
       
   286         if params is None:
       
   287             params = []
       
   288         else:
       
   289             params = self._format_params(params)
       
   290         args = [(':arg%d' % i) for i in range(len(params))]
       
   291         # cx_Oracle wants no trailing ';' for SQL statements.  For PL/SQL, it
       
   292         # it does want a trailing ';' but not a trailing '/'.  However, these
       
   293         # characters must be included in the original query in case the query
       
   294         # is being passed to SQL*Plus.
       
   295         if query.endswith(';') or query.endswith('/'):
       
   296             query = query[:-1]
       
   297         query = smart_str(query, self.charset) % tuple(args)
       
   298         self._guess_input_sizes([params])
       
   299         return Database.Cursor.execute(self, query, params)
       
   300 
       
   301     def executemany(self, query, params=None):
       
   302         try:
       
   303           args = [(':arg%d' % i) for i in range(len(params[0]))]
       
   304         except (IndexError, TypeError):
       
   305           # No params given, nothing to do
       
   306           return None
       
   307         # cx_Oracle wants no trailing ';' for SQL statements.  For PL/SQL, it
       
   308         # it does want a trailing ';' but not a trailing '/'.  However, these
       
   309         # characters must be included in the original query in case the query
       
   310         # is being passed to SQL*Plus.
       
   311         if query.endswith(';') or query.endswith('/'):
       
   312             query = query[:-1]
       
   313         query = smart_str(query, self.charset) % tuple(args)
       
   314         new_param_list = [self._format_params(i) for i in params]
       
   315         self._guess_input_sizes(new_param_list)
       
   316         return Database.Cursor.executemany(self, query, new_param_list)
       
   317 
       
   318     def fetchone(self):
       
   319         row = Database.Cursor.fetchone(self)
       
   320         if row is None:
       
   321             return row
       
   322         return tuple([to_unicode(e) for e in row])
       
   323 
       
   324     def fetchmany(self, size=None):
       
   325         if size is None:
       
   326             size = self.arraysize
       
   327         return tuple([tuple([to_unicode(e) for e in r]) for r in Database.Cursor.fetchmany(self, size)])
       
   328 
       
   329     def fetchall(self):
       
   330         return tuple([tuple([to_unicode(e) for e in r]) for r in Database.Cursor.fetchall(self)])
       
   331 
       
   332 def to_unicode(s):
       
   333     """
       
   334     Convert strings to Unicode objects (and return all other data types
       
   335     unchanged).
       
   336     """
       
   337     if isinstance(s, basestring):
       
   338         return force_unicode(s)
       
   339     return s
       
   340 
       
   341 def _get_sequence_reset_sql():
       
   342     # TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
       
   343     return """
       
   344         DECLARE
       
   345             startvalue integer;
       
   346             cval integer;
       
   347         BEGIN
       
   348             LOCK TABLE %(table)s IN SHARE MODE;
       
   349             SELECT NVL(MAX(%(column)s), 0) INTO startvalue FROM %(table)s;
       
   350             SELECT %(sequence)s.nextval INTO cval FROM dual;
       
   351             cval := startvalue - cval;
       
   352             IF cval != 0 THEN
       
   353                 EXECUTE IMMEDIATE 'ALTER SEQUENCE %(sequence)s MINVALUE 0 INCREMENT BY '||cval;
       
   354                 SELECT %(sequence)s.nextval INTO cval FROM dual;
       
   355                 EXECUTE IMMEDIATE 'ALTER SEQUENCE %(sequence)s INCREMENT BY 1';
       
   356             END IF;
       
   357             COMMIT;
       
   358         END;
       
   359         /"""
       
   360 
       
   361 def get_sequence_name(table):
       
   362     name_length = DatabaseOperations().max_name_length() - 3
       
   363     return '%s_SQ' % util.truncate_name(table, name_length).upper()
       
   364 
       
   365 def get_trigger_name(table):
       
   366     name_length = DatabaseOperations().max_name_length() - 3
       
   367     return '%s_TR' % util.truncate_name(table, name_length).upper()