app/django/db/transaction.py
changeset 323 ff1a9aa48cfd
parent 54 03e267d67478
--- a/app/django/db/transaction.py	Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/db/transaction.py	Tue Oct 14 16:00:59 2008 +0000
@@ -19,7 +19,7 @@
 try:
     from functools import wraps
 except ImportError:
-    from django.utils.functional import wraps  # Python 2.3, 2.4 fallback. 
+    from django.utils.functional import wraps  # Python 2.3, 2.4 fallback.
 from django.db import connection
 from django.conf import settings
 
@@ -30,9 +30,10 @@
     """
     pass
 
-# The state is a dictionary of lists. The key to the dict is the current
+# The states are dictionaries of lists. The key to the dict is the current
 # thread and the list is handled as a stack of values.
 state = {}
+savepoint_state = {}
 
 # The dirty flag is set by *_unless_managed functions to denote that the
 # code under transaction management has changed things to require a
@@ -104,6 +105,12 @@
         dirty[thread_ident] = False
     else:
         raise TransactionManagementError("This code isn't under transaction management")
+    clean_savepoints()
+
+def clean_savepoints():
+    thread_ident = thread.get_ident()
+    if thread_ident in savepoint_state:
+        del savepoint_state[thread_ident]
 
 def is_managed():
     """
@@ -138,6 +145,7 @@
     """
     if not is_managed():
         connection._commit()
+        clean_savepoints()
     else:
         set_dirty()
 
@@ -164,6 +172,38 @@
     connection._rollback()
     set_clean()
 
+def savepoint():
+    """
+    Creates a savepoint (if supported and required by the backend) inside the
+    current transaction. Returns an identifier for the savepoint that will be
+    used for the subsequent rollback or commit.
+    """
+    thread_ident = thread.get_ident()
+    if thread_ident in savepoint_state:
+        savepoint_state[thread_ident].append(None)
+    else:
+        savepoint_state[thread_ident] = [None]
+    tid = str(thread_ident).replace('-', '')
+    sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident]))
+    connection._savepoint(sid)
+    return sid
+
+def savepoint_rollback(sid):
+    """
+    Rolls back the most recent savepoint (if one exists). Does nothing if
+    savepoints are not supported.
+    """
+    if thread.get_ident() in savepoint_state:
+        connection._savepoint_rollback(sid)
+
+def savepoint_commit(sid):
+    """
+    Commits the most recent savepoint (if one exists). Does nothing if
+    savepoints are not supported.
+    """
+    if thread.get_ident() in savepoint_state:
+        connection._savepoint_commit(sid)
+
 ##############
 # DECORATORS #
 ##############
@@ -196,7 +236,8 @@
             managed(True)
             try:
                 res = func(*args, **kw)
-            except Exception, e:
+            except:
+                # All exceptions must be handled here (even string ones).
                 if is_dirty():
                     rollback()
                 raise