--- a/thirdparty/google_appengine/google/appengine/tools/bulkloader.py Fri Oct 23 11:17:07 2009 -0700
+++ b/thirdparty/google_appengine/google/appengine/tools/bulkloader.py Fri Oct 23 13:54:11 2009 -0500
@@ -395,10 +395,6 @@
self.row_count += 1
if self.column_count is None:
self.column_count = len(row)
- else:
- if self.column_count != len(row):
- raise ResumeError('Column count mismatch, %d: %s' %
- (self.column_count, str(row)))
self.read_rows.append((self.line_number, row))
self.line_number += 1
@@ -1186,6 +1182,20 @@
self.auth_called = True
return (email, password)
+ def IncrementId(self, ancestor_path, kind, high_id):
+ """Increment the unique id counter associated with ancestor_path and kind.
+
+ Args:
+ ancestor_path: A list encoding the path of a key.
+ kind: The string name of a kind.
+ high_id: The int value to which to increment the unique id counter.
+ """
+ model_key = datastore.Key.from_path(*(ancestor_path + [kind, 1]))
+ start, end = datastore.AllocateIds(model_key, 1)
+ if end < high_id:
+ start, end = datastore.AllocateIds(model_key, high_id - end)
+ assert end >= high_id
+
def EncodeContent(self, rows, loader=None):
"""Encodes row data to the wire format.
@@ -2368,6 +2378,14 @@
"""
Loader.__loaders[loader.kind] = loader
+ def get_high_ids(self):
+ """Returns dict {ancestor_path : {kind : id}} with high id values.
+
+ The returned dictionary is used to increment the id counters
+ associated with each ancestor_path and kind to be at least id.
+ """
+ return {}
+
def alias_old_names(self):
"""Aliases method names so that Loaders defined with old names work."""
aliases = (
@@ -2546,21 +2564,54 @@
cursor = db_conn.cursor()
cursor.execute('select id, value from result')
for entity_id, value in cursor:
- self.queue.put([entity_id, value], block=True)
+ self.queue.put(value, block=True)
self.queue.put(RestoreThread._ENTITIES_DONE, block=True)
class RestoreLoader(Loader):
"""A Loader which imports protobuffers from a file."""
- def __init__(self, kind):
+ def __init__(self, kind, app_id):
self.kind = kind
+ self.app_id = app_id
def initialize(self, filename, loader_opts):
CheckFile(filename)
self.queue = Queue.Queue(1000)
restore_thread = RestoreThread(self.queue, filename)
restore_thread.start()
+ self.high_id_table = self._find_high_id(self.generate_records(filename))
+ restore_thread = RestoreThread(self.queue, filename)
+ restore_thread.start()
+
+ def get_high_ids(self):
+ return dict(self.high_id_table)
+
+ def _find_high_id(self, record_generator):
+ """Find the highest numeric id used for each ancestor-path, kind pair.
+
+ Args:
+ record_generator: A generator of entity_encoding strings.
+
+ Returns:
+ A map from ancestor-path to maps from kind to id. {path : {kind : id}}
+ """
+ high_id = {}
+ for values in record_generator:
+ entity = self.create_entity(values)
+ key = entity.key()
+ if not key.id():
+ continue
+ kind = key.kind()
+ ancestor_path = []
+ if key.parent():
+ ancestor_path = key.parent().to_path()
+ if tuple(ancestor_path) not in high_id:
+ high_id[tuple(ancestor_path)] = {}
+ kind_map = high_id[tuple(ancestor_path)]
+ if kind not in kind_map or kind_map[kind] < key.id():
+ kind_map[kind] = key.id()
+ return high_id
def generate_records(self, filename):
while True:
@@ -2570,10 +2621,33 @@
yield record
def create_entity(self, values, key_name=None, parent=None):
- key = StrKey(unicode(values[0], 'utf-8'))
- entity_proto = entity_pb.EntityProto(contents=str(values[1]))
- entity_proto.mutable_key().CopyFrom(key._Key__reference)
- return datastore.Entity._FromPb(entity_proto)
+ entity_proto = entity_pb.EntityProto(contents=str(values))
+ fixed_entity_proto = self._translate_entity_proto(entity_proto)
+ return datastore.Entity._FromPb(fixed_entity_proto)
+
+ def rewrite_reference_proto(self, reference_proto):
+ """Transform the Reference protobuffer which underlies keys and references.
+
+ Args:
+ reference_proto: A Onestore Reference proto
+ """
+ reference_proto.set_app(self.app_id)
+
+ def _translate_entity_proto(self, entity_proto):
+ """Transform the ReferenceProperties of the given entity to fix app_id."""
+ entity_key = entity_proto.mutable_key()
+ entity_key.set_app(self.app_id)
+ for prop in entity_proto.property_list():
+ prop_value = prop.mutable_value()
+ if prop_value.has_referencevalue():
+ self.rewrite_reference_proto(prop_value.mutable_referencevalue())
+
+ for prop in entity_proto.raw_property_list():
+ prop_value = prop.mutable_value()
+ if prop_value.has_referencevalue():
+ self.rewrite_reference_proto(prop_value.mutable_referencevalue())
+
+ return entity_proto
class Exporter(object):
@@ -2662,7 +2736,7 @@
for name, fn, default in self.__properties:
try:
encoding.append(fn(entity[name]))
- except AttributeError:
+ except KeyError:
if default is None:
raise MissingPropertyError(name)
else:
@@ -2954,6 +3028,10 @@
unused_query, unused_fragment) = urlparse.urlsplit(self.post_url)
self.secure = (scheme == 'https')
+ def RunPostAuthentication(self):
+ """Method that gets called after authentication."""
+ pass
+
def Run(self):
"""Perform the work of the BulkTransporterApp.
@@ -2971,29 +3049,31 @@
threading.currentThread().exit_flag = False
progress_queue = self.progress_queue_factory(self.max_queue_size)
- request_manager = self.request_manager_factory(self.app_id,
- self.host_port,
- self.url_path,
- self.kind,
- self.throttle,
- self.batch_size,
- self.secure,
- self.email,
- self.passin,
- self.dry_run)
+ self.request_manager = self.request_manager_factory(self.app_id,
+ self.host_port,
+ self.url_path,
+ self.kind,
+ self.throttle,
+ self.batch_size,
+ self.secure,
+ self.email,
+ self.passin,
+ self.dry_run)
try:
- request_manager.Authenticate()
+ self.request_manager.Authenticate()
except Exception, e:
self.error = True
if not isinstance(e, urllib2.HTTPError) or (
e.code != 302 and e.code != 401):
logger.exception('Exception during authentication')
raise AuthenticationError()
- if (request_manager.auth_called and
- not request_manager.authenticated):
+ if (self.request_manager.auth_called and
+ not self.request_manager.authenticated):
self.error = True
raise AuthenticationError('Authentication failed')
+ self.RunPostAuthentication()
+
for thread in thread_pool.Threads():
self.throttle.Register(thread)
@@ -3007,7 +3087,7 @@
progress_generator_factory = None
self.data_source_thread = (
- self.datasourcethread_factory(request_manager,
+ self.datasourcethread_factory(self.request_manager,
thread_pool,
progress_queue,
self.input_generator_factory,
@@ -3092,6 +3172,13 @@
def __init__(self, *args, **kwargs):
BulkTransporterApp.__init__(self, *args, **kwargs)
+ def RunPostAuthentication(self):
+ loader = Loader.RegisteredLoader(self.kind)
+ high_id_table = loader.get_high_ids()
+ for ancestor_path, kind_map in high_id_table.iteritems():
+ for kind, high_id in kind_map.iteritems():
+ self.request_manager.IncrementId(list(ancestor_path), kind, high_id)
+
def ReportStatus(self):
"""Display a message reporting the final status of the transfer."""
total_up, duration = self.throttle.TotalTransferred(
@@ -3625,7 +3712,7 @@
if dump:
Exporter.RegisterExporter(DumpExporter(kind, result_db_filename))
elif restore:
- Loader.RegisterLoader(RestoreLoader(kind))
+ Loader.RegisterLoader(RestoreLoader(kind, app_id))
else:
LoadConfig(config_file)