--- a/app/django/db/transaction.py Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/db/transaction.py Tue Oct 14 16:00:59 2008 +0000
@@ -19,7 +19,7 @@
try:
from functools import wraps
except ImportError:
- from django.utils.functional import wraps # Python 2.3, 2.4 fallback.
+ from django.utils.functional import wraps # Python 2.3, 2.4 fallback.
from django.db import connection
from django.conf import settings
@@ -30,9 +30,10 @@
"""
pass
-# The state is a dictionary of lists. The key to the dict is the current
+# The states are dictionaries of lists. The key to the dict is the current
# thread and the list is handled as a stack of values.
state = {}
+savepoint_state = {}
# The dirty flag is set by *_unless_managed functions to denote that the
# code under transaction management has changed things to require a
@@ -104,6 +105,12 @@
dirty[thread_ident] = False
else:
raise TransactionManagementError("This code isn't under transaction management")
+ clean_savepoints()
+
+def clean_savepoints():
+ thread_ident = thread.get_ident()
+ if thread_ident in savepoint_state:
+ del savepoint_state[thread_ident]
def is_managed():
"""
@@ -138,6 +145,7 @@
"""
if not is_managed():
connection._commit()
+ clean_savepoints()
else:
set_dirty()
@@ -164,6 +172,38 @@
connection._rollback()
set_clean()
+def savepoint():
+ """
+ Creates a savepoint (if supported and required by the backend) inside the
+ current transaction. Returns an identifier for the savepoint that will be
+ used for the subsequent rollback or commit.
+ """
+ thread_ident = thread.get_ident()
+ if thread_ident in savepoint_state:
+ savepoint_state[thread_ident].append(None)
+ else:
+ savepoint_state[thread_ident] = [None]
+ tid = str(thread_ident).replace('-', '')
+ sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident]))
+ connection._savepoint(sid)
+ return sid
+
+def savepoint_rollback(sid):
+ """
+ Rolls back the most recent savepoint (if one exists). Does nothing if
+ savepoints are not supported.
+ """
+ if thread.get_ident() in savepoint_state:
+ connection._savepoint_rollback(sid)
+
+def savepoint_commit(sid):
+ """
+ Commits the most recent savepoint (if one exists). Does nothing if
+ savepoints are not supported.
+ """
+ if thread.get_ident() in savepoint_state:
+ connection._savepoint_commit(sid)
+
##############
# DECORATORS #
##############
@@ -196,7 +236,8 @@
managed(True)
try:
res = func(*args, **kw)
- except Exception, e:
+ except:
+ # All exceptions must be handled here (even string ones).
if is_dirty():
rollback()
raise