|
1 import re |
|
2 |
|
3 from django.db.backends import BaseDatabaseOperations |
|
4 |
|
# Parses the text returned by PostgreSQL's ``SELECT version()``.  Captures
# (major, minor, optional patch) as digit strings: a string beginning
# "PostgreSQL 8.3.1" yields ('8', '3', '1'), and "PostgreSQL 8.3" yields
# ('8', '3', None).  Callers use .match(), so the string must start with
# "PostgreSQL ".
# NOTE(review): a dot is required after the major number, so a string such as
# "PostgreSQL 10beta1" will not match.
server_version_re = re.compile(r'PostgreSQL (\d{1,2})\.(\d{1,2})\.?(\d{1,2})?')
|
6 |
|
7 # This DatabaseOperations class lives in here instead of base.py because it's |
|
8 # used by both the 'postgresql' and 'postgresql_psycopg2' backends. |
|
9 |
|
class DatabaseOperations(BaseDatabaseOperations):
    """PostgreSQL-specific SQL generation shared by the PostgreSQL backends."""

    def __init__(self):
        # Server version as a list of ints (e.g. [8, 3, 1]); filled in lazily
        # on first access of ``postgres_version`` and cached afterwards.
        self._postgres_version = None

    def _get_postgres_version(self):
        """Return the server version as a list of ints, querying it only once.

        The first call runs ``SELECT version()`` through
        ``django.db.connection`` and parses the result with
        ``server_version_re``.  The list has two elements, or three when the
        server reports a patch number.  Later calls return the cached list.

        Raises Exception if the version string cannot be parsed.
        """
        if self._postgres_version is None:
            # Imported here rather than at module level -- presumably to avoid
            # an import cycle with django.db; confirm before hoisting.
            from django.db import connection
            cursor = connection.cursor()
            cursor.execute("SELECT version()")
            version_string = cursor.fetchone()[0]
            m = server_version_re.match(version_string)
            if not m:
                raise Exception('Unable to determine PostgreSQL version from version() function string: %r' % version_string)
            # The patch group is optional (None when absent); drop it then.
            self._postgres_version = [int(val) for val in m.groups() if val]
        return self._postgres_version
    postgres_version = property(_get_postgres_version)

    def date_extract_sql(self, lookup_type, field_name):
        """Return SQL extracting the ``lookup_type`` component of ``field_name``."""
        # http://www.postgresql.org/docs/8.0/static/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
        return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)

    def date_trunc_sql(self, lookup_type, field_name):
        """Return SQL truncating ``field_name`` to ``lookup_type`` precision."""
        # http://www.postgresql.org/docs/8.0/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
        return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)

    def deferrable_sql(self):
        """Return the suffix making a constraint DEFERRABLE INITIALLY DEFERRED."""
        return " DEFERRABLE INITIALLY DEFERRED"

    def field_cast_sql(self, db_type):
        """Return a ``%s`` format string wrapping a column of type ``db_type``.

        ``inet`` columns are wrapped in HOST(); every other type is used as-is.
        """
        if db_type == 'inet':
            return 'HOST(%s)'
        return '%s'

    def last_insert_id(self, cursor, table_name, pk_name):
        """Return CURRVAL of the ``<table_name>_<pk_name>_seq`` sequence.

        Assumes the primary key's sequence follows that naming convention.
        """
        cursor.execute("SELECT CURRVAL('\"%s_%s_seq\"')" % (table_name, pk_name))
        return cursor.fetchone()[0]

    def no_limit_value(self):
        """Return None: this backend uses no explicit "no limit" value."""
        return None

    def quote_name(self, name):
        """Return ``name`` wrapped in double quotes unless it already is.

        NOTE(review): embedded double quotes are not escaped.
        """
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name

    def sql_flush(self, style, tables, sequences):
        """Return SQL statements emptying ``tables`` and resetting ``sequences``.

        ``sequences`` is an iterable of dicts with 'table' and 'column' keys;
        a falsy 'column' falls back to the ``<table>_id_seq`` sequence name.
        Each sequence is reset so its next value is 1.

        Returns an empty list when ``tables`` is empty; sequences are then
        left untouched.
        """
        if not tables:
            return []

        # Compare (major, minor) as a tuple.  Checking each component on its
        # own wrongly rejected 9.0, 10.0, ... because their minor is below 1.
        if tuple(self.postgres_version[:2]) >= (8, 1):
            # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to*
            # in order to be able to truncate tables referenced by a foreign
            # key in any other table. The result is a single SQL TRUNCATE
            # statement.
            sql = ['%s %s;' % (
                style.SQL_KEYWORD('TRUNCATE'),
                style.SQL_FIELD(', '.join([self.quote_name(table) for table in tables])),
            )]
        else:
            # Older versions of Postgres can't do TRUNCATE in a single call, so
            # they must use a simple delete.
            sql = ['%s %s %s;' % (
                style.SQL_KEYWORD('DELETE'),
                style.SQL_KEYWORD('FROM'),
                style.SQL_FIELD(self.quote_name(table)),
            ) for table in tables]

        # 'SELECT setval(sequence_name, 1, false);' style SQL statements
        # to reset sequence indices.
        for sequence_info in sequences:
            table_name = sequence_info['table']
            column_name = sequence_info['column']
            if column_name:
                sequence_name = '%s_%s_seq' % (table_name, column_name)
            else:
                sequence_name = '%s_id_seq' % table_name
            sql.append("%s setval('%s', 1, false);" % (
                style.SQL_KEYWORD('SELECT'),
                style.SQL_FIELD(self.quote_name(sequence_name)),
            ))
        return sql

    def _setval_from_max_sql(self, style, sequence_name, column_name, table_name):
        """Return a ``SELECT setval(...)`` syncing a sequence with max(column).

        Uses coalesce() to set the sequence to the column's max value when
        ``table_name`` has rows, or to 1 when it has none.  The third setval
        argument (``is_called``) is true only when rows exist, since the max
        value is then already in use.
        """
        qn = self.quote_name
        return "%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
            style.SQL_KEYWORD('SELECT'),
            style.SQL_FIELD(qn(sequence_name)),
            style.SQL_FIELD(qn(column_name)),
            style.SQL_FIELD(qn(column_name)),
            style.SQL_KEYWORD('IS NOT'),
            style.SQL_KEYWORD('FROM'),
            style.SQL_TABLE(qn(table_name)),
        )

    def sequence_reset_sql(self, style, model_list):
        """Return SQL syncing each model's sequences with its existing rows.

        Emits one statement for the model's AutoField sequence
        (``<db_table>_<column>_seq``), if any.  It also emits one statement
        for each many-to-many join table's ``<m2m_table>_id_seq`` sequence.
        """
        # Imported here rather than at module level -- presumably to avoid a
        # circular import; confirm before hoisting.
        from django.db import models
        output = []
        for model in model_list:
            for f in model._meta.fields:
                if isinstance(f, models.AutoField):
                    output.append(self._setval_from_max_sql(
                        style,
                        '%s_%s_seq' % (model._meta.db_table, f.column),
                        f.column,
                        model._meta.db_table,
                    ))
                    break  # Only one AutoField is allowed per model, so don't bother continuing.
            for f in model._meta.many_to_many:
                output.append(self._setval_from_max_sql(
                    style,
                    '%s_id_seq' % f.m2m_db_table(),
                    'id',
                    f.m2m_db_table(),
                ))
        return output