thirdparty/google_appengine/google/appengine/tools/bulkloader.py
changeset 3031 7678f72140e6
parent 2864 2e0b0af889be
--- a/thirdparty/google_appengine/google/appengine/tools/bulkloader.py	Fri Oct 23 11:17:07 2009 -0700
+++ b/thirdparty/google_appengine/google/appengine/tools/bulkloader.py	Fri Oct 23 13:54:11 2009 -0500
@@ -395,10 +395,6 @@
       self.row_count += 1
       if self.column_count is None:
         self.column_count = len(row)
-      else:
-        if self.column_count != len(row):
-          raise ResumeError('Column count mismatch, %d: %s' %
-                            (self.column_count, str(row)))
       self.read_rows.append((self.line_number, row))
       self.line_number += 1
 
@@ -1186,6 +1182,20 @@
     self.auth_called = True
     return (email, password)
 
+  def IncrementId(self, ancestor_path, kind, high_id):
+    """Increment the unique id counter associated with ancestor_path and kind.
+
+    Args:
+      ancestor_path: A list encoding the path of a key.
+      kind: The string name of a kind.
+      high_id: The int value to which to increment the unique id counter.
+    """
+    model_key = datastore.Key.from_path(*(ancestor_path + [kind, 1]))
+    start, end = datastore.AllocateIds(model_key, 1)  # probe with a single id
+    if end < high_id:  # range still ends below high_id: request the shortfall
+      start, end = datastore.AllocateIds(model_key, high_id - end)
+    assert end >= high_id  # NOTE(review): assert is skipped under python -O
+
   def EncodeContent(self, rows, loader=None):
     """Encodes row data to the wire format.
 
@@ -2368,6 +2378,14 @@
     """
     Loader.__loaders[loader.kind] = loader
 
+  def get_high_ids(self):
+    """Returns dict {ancestor_path : {kind : id}} with high id values.
+
+    The returned dictionary is used to increment the id counters
+    associated with each ancestor_path and kind to be at least id.
+    """
+    return {}  # default: empty table, so callers advance no counters
+
   def alias_old_names(self):
     """Aliases method names so that Loaders defined with old names work."""
     aliases = (
@@ -2546,21 +2564,54 @@
     cursor = db_conn.cursor()
     cursor.execute('select id, value from result')
     for entity_id, value in cursor:
-      self.queue.put([entity_id, value], block=True)
+      self.queue.put(value, block=True)
     self.queue.put(RestoreThread._ENTITIES_DONE, block=True)
 
 
 class RestoreLoader(Loader):
   """A Loader which imports protobuffers from a file."""
 
-  def __init__(self, kind):
+  def __init__(self, kind, app_id):  # app_id: app to stamp on restored keys
     self.kind = kind
+    self.app_id = app_id  # later written into protos through set_app()
 
   def initialize(self, filename, loader_opts):
     CheckFile(filename)
     self.queue = Queue.Queue(1000)
     restore_thread = RestoreThread(self.queue, filename)
     restore_thread.start()
+    self.high_id_table = self._find_high_id(self.generate_records(filename))
+    restore_thread = RestoreThread(self.queue, filename)  # second pass, same file
+    restore_thread.start()  # presumably refills the queue for the real load
+
+  def get_high_ids(self):  # {ancestor_path tuple: {kind: highest numeric id}}
+    return dict(self.high_id_table)  # shallow copy; table set in initialize()
+
+  def _find_high_id(self, record_generator):
+    """Find the highest numeric id used for each ancestor-path, kind pair.
+
+    Args:
+      record_generator: A generator of entity_encoding strings.
+
+    Returns:
+      A map from ancestor-path to maps from kind to id. {path : {kind : id}}
+    """
+    high_id = {}  # {tuple(ancestor_path): {kind: highest id seen so far}}
+    for values in record_generator:
+      entity = self.create_entity(values)
+      key = entity.key()
+      if not key.id():  # no numeric id (presumably a key_name key): skip it
+        continue
+      kind = key.kind()
+      ancestor_path = []  # stays empty when the key has no parent
+      if key.parent():
+        ancestor_path = key.parent().to_path()
+      if tuple(ancestor_path) not in high_id:  # tuple: a list is unhashable
+        high_id[tuple(ancestor_path)] = {}
+      kind_map = high_id[tuple(ancestor_path)]
+      if kind not in kind_map or kind_map[kind] < key.id():  # keep the maximum
+        kind_map[kind] = key.id()
+    return high_id
 
   def generate_records(self, filename):
     while True:
@@ -2570,10 +2621,33 @@
       yield record
 
   def create_entity(self, values, key_name=None, parent=None):
-    key = StrKey(unicode(values[0], 'utf-8'))
-    entity_proto = entity_pb.EntityProto(contents=str(values[1]))
-    entity_proto.mutable_key().CopyFrom(key._Key__reference)
-    return datastore.Entity._FromPb(entity_proto)
+    entity_proto = entity_pb.EntityProto(contents=str(values))  # whole record
+    fixed_entity_proto = self._translate_entity_proto(entity_proto)  # fix app
+    return datastore.Entity._FromPb(fixed_entity_proto)
+
+  def rewrite_reference_proto(self, reference_proto):
+    """Transform the Reference protobuffer which underlies keys and references.
+
+    Args:
+      reference_proto: A Onestore Reference proto
+    """
+    reference_proto.set_app(self.app_id)  # mutates the argument; returns None
+
+  def _translate_entity_proto(self, entity_proto):
+    """Transform the ReferenceProperties of the given entity to fix app_id."""
+    entity_key = entity_proto.mutable_key()
+    entity_key.set_app(self.app_id)  # the entity's own key is rewritten too
+    for prop in entity_proto.property_list():
+      prop_value = prop.mutable_value()
+      if prop_value.has_referencevalue():  # only reference-valued properties
+        self.rewrite_reference_proto(prop_value.mutable_referencevalue())
+
+    for prop in entity_proto.raw_property_list():  # same fix-up, raw properties
+      prop_value = prop.mutable_value()
+      if prop_value.has_referencevalue():
+        self.rewrite_reference_proto(prop_value.mutable_referencevalue())
+
+    return entity_proto  # the object passed in, modified in place
 
 
 class Exporter(object):
@@ -2662,7 +2736,7 @@
     for name, fn, default in self.__properties:
       try:
         encoding.append(fn(entity[name]))
-      except AttributeError:
+      except KeyError:
         if default is None:
           raise MissingPropertyError(name)
         else:
@@ -2954,6 +3028,10 @@
      unused_query, unused_fragment) = urlparse.urlsplit(self.post_url)
     self.secure = (scheme == 'https')
 
+  def RunPostAuthentication(self):
+    """Method that gets called after authentication."""
+    pass  # no-op by default; Run() calls this once authentication has passed
+
   def Run(self):
     """Perform the work of the BulkTransporterApp.
 
@@ -2971,29 +3049,31 @@
     threading.currentThread().exit_flag = False
 
     progress_queue = self.progress_queue_factory(self.max_queue_size)
-    request_manager = self.request_manager_factory(self.app_id,
-                                                   self.host_port,
-                                                   self.url_path,
-                                                   self.kind,
-                                                   self.throttle,
-                                                   self.batch_size,
-                                                   self.secure,
-                                                   self.email,
-                                                   self.passin,
-                                                   self.dry_run)
+    self.request_manager = self.request_manager_factory(self.app_id,
+                                                        self.host_port,
+                                                        self.url_path,
+                                                        self.kind,
+                                                        self.throttle,
+                                                        self.batch_size,
+                                                        self.secure,
+                                                        self.email,
+                                                        self.passin,
+                                                        self.dry_run)
     try:
-      request_manager.Authenticate()
+      self.request_manager.Authenticate()
     except Exception, e:
       self.error = True
       if not isinstance(e, urllib2.HTTPError) or (
           e.code != 302 and e.code != 401):
         logger.exception('Exception during authentication')
       raise AuthenticationError()
-    if (request_manager.auth_called and
-        not request_manager.authenticated):
+    if (self.request_manager.auth_called and
+        not self.request_manager.authenticated):
       self.error = True
       raise AuthenticationError('Authentication failed')
 
+    self.RunPostAuthentication()
+
     for thread in thread_pool.Threads():
       self.throttle.Register(thread)
 
@@ -3007,7 +3087,7 @@
       progress_generator_factory = None
 
     self.data_source_thread = (
-        self.datasourcethread_factory(request_manager,
+        self.datasourcethread_factory(self.request_manager,
                                       thread_pool,
                                       progress_queue,
                                       self.input_generator_factory,
@@ -3092,6 +3172,13 @@
   def __init__(self, *args, **kwargs):
     BulkTransporterApp.__init__(self, *args, **kwargs)
 
+  def RunPostAuthentication(self):  # post-authentication hook for uploads
+    loader = Loader.RegisteredLoader(self.kind)
+    high_id_table = loader.get_high_ids()  # {ancestor_path: {kind: high id}}
+    for ancestor_path, kind_map in high_id_table.iteritems():  # path is a tuple
+      for kind, high_id in kind_map.iteritems():
+        self.request_manager.IncrementId(list(ancestor_path), kind, high_id)
+
   def ReportStatus(self):
     """Display a message reporting the final status of the transfer."""
     total_up, duration = self.throttle.TotalTransferred(
@@ -3625,7 +3712,7 @@
   if dump:
     Exporter.RegisterExporter(DumpExporter(kind, result_db_filename))
   elif restore:
-    Loader.RegisterLoader(RestoreLoader(kind))
+    Loader.RegisterLoader(RestoreLoader(kind, app_id))
   else:
     LoadConfig(config_file)