diff -r 09cae668b536 -r 7678f72140e6 thirdparty/google_appengine/google/appengine/tools/bulkloader.py --- a/thirdparty/google_appengine/google/appengine/tools/bulkloader.py Fri Oct 23 11:17:07 2009 -0700 +++ b/thirdparty/google_appengine/google/appengine/tools/bulkloader.py Fri Oct 23 13:54:11 2009 -0500 @@ -395,10 +395,6 @@ self.row_count += 1 if self.column_count is None: self.column_count = len(row) - else: - if self.column_count != len(row): - raise ResumeError('Column count mismatch, %d: %s' % - (self.column_count, str(row))) self.read_rows.append((self.line_number, row)) self.line_number += 1 @@ -1186,6 +1182,20 @@ self.auth_called = True return (email, password) + def IncrementId(self, ancestor_path, kind, high_id): + """Increment the unique id counter associated with ancestor_path and kind. + + Args: + ancestor_path: A list encoding the path of a key. + kind: The string name of a kind. + high_id: The int value to which to increment the unique id counter. + """ + model_key = datastore.Key.from_path(*(ancestor_path + [kind, 1])) + start, end = datastore.AllocateIds(model_key, 1) + if end < high_id: + start, end = datastore.AllocateIds(model_key, high_id - end) + assert end >= high_id + def EncodeContent(self, rows, loader=None): """Encodes row data to the wire format. @@ -2368,6 +2378,14 @@ """ Loader.__loaders[loader.kind] = loader + def get_high_ids(self): + """Returns dict {ancestor_path : {kind : id}} with high id values. + + The returned dictionary is used to increment the id counters + associated with each ancestor_path and kind to be at least id. + """ + return {} + def alias_old_names(self): """Aliases method names so that Loaders defined with old names work.""" aliases = ( @@ -2546,21 +2564,54 @@ cursor = db_conn.cursor() cursor.execute('select id, value from result') for entity_id, value in cursor: - self.queue.put([entity_id, value], block=True) + self.queue.put(value, block=True) self.queue.put(RestoreThread._ENTITIES_DONE, block=True) class RestoreLoader(Loader): """A Loader which imports protobuffers from a file.""" - def __init__(self, kind): + def __init__(self, kind, app_id): self.kind = kind + self.app_id = app_id def initialize(self, filename, loader_opts): CheckFile(filename) self.queue = Queue.Queue(1000) restore_thread = RestoreThread(self.queue, filename) restore_thread.start() + self.high_id_table = self._find_high_id(self.generate_records(filename)) + restore_thread = RestoreThread(self.queue, filename) + restore_thread.start() + + def get_high_ids(self): + return dict(self.high_id_table) + + def _find_high_id(self, record_generator): + """Find the highest numeric id used for each ancestor-path, kind pair. + + Args: + record_generator: A generator of entity_encoding strings. + + Returns: + A map from ancestor-path to maps from kind to id. {path : {kind : id}} + """ + high_id = {} + for values in record_generator: + entity = self.create_entity(values) + key = entity.key() + if not key.id(): + continue + kind = key.kind() + ancestor_path = [] + if key.parent(): + ancestor_path = key.parent().to_path() + if tuple(ancestor_path) not in high_id: + high_id[tuple(ancestor_path)] = {} + kind_map = high_id[tuple(ancestor_path)] + if kind not in kind_map or kind_map[kind] < key.id(): + kind_map[kind] = key.id() + return high_id def generate_records(self, filename): while True: @@ -2570,10 +2621,33 @@ yield record def create_entity(self, values, key_name=None, parent=None): - key = StrKey(unicode(values[0], 'utf-8')) - entity_proto = entity_pb.EntityProto(contents=str(values[1])) - entity_proto.mutable_key().CopyFrom(key._Key__reference) - return datastore.Entity._FromPb(entity_proto) + entity_proto = entity_pb.EntityProto(contents=str(values)) + fixed_entity_proto = self._translate_entity_proto(entity_proto) + return datastore.Entity._FromPb(fixed_entity_proto) + + def rewrite_reference_proto(self, reference_proto): + """Transform the Reference protobuffer which underlies keys and references. + + Args: + reference_proto: A Onestore Reference proto + """ + reference_proto.set_app(self.app_id) + + def _translate_entity_proto(self, entity_proto): + """Transform the ReferenceProperties of the given entity to fix app_id.""" + entity_key = entity_proto.mutable_key() + entity_key.set_app(self.app_id) + for prop in entity_proto.property_list(): + prop_value = prop.mutable_value() + if prop_value.has_referencevalue(): + self.rewrite_reference_proto(prop_value.mutable_referencevalue()) + + for prop in entity_proto.raw_property_list(): + prop_value = prop.mutable_value() + if prop_value.has_referencevalue(): + self.rewrite_reference_proto(prop_value.mutable_referencevalue()) + + return entity_proto class Exporter(object): @@ -2662,7 +2736,7 @@ for name, fn, default in self.__properties: try: encoding.append(fn(entity[name])) - except AttributeError: + except KeyError: if default is None: raise MissingPropertyError(name) else: @@ -2954,6 +3028,10 @@ unused_query, unused_fragment) = urlparse.urlsplit(self.post_url) self.secure = (scheme == 'https') + def RunPostAuthentication(self): + """Method that gets called after authentication.""" + pass + def Run(self): """Perform the work of the BulkTransporterApp. @@ -2971,29 +3049,31 @@ threading.currentThread().exit_flag = False progress_queue = self.progress_queue_factory(self.max_queue_size) - request_manager = self.request_manager_factory(self.app_id, - self.host_port, - self.url_path, - self.kind, - self.throttle, - self.batch_size, - self.secure, - self.email, - self.passin, - self.dry_run) + self.request_manager = self.request_manager_factory(self.app_id, + self.host_port, + self.url_path, + self.kind, + self.throttle, + self.batch_size, + self.secure, + self.email, + self.passin, + self.dry_run) try: - request_manager.Authenticate() + self.request_manager.Authenticate() except Exception, e: self.error = True if not isinstance(e, urllib2.HTTPError) or ( e.code != 302 and e.code != 401): logger.exception('Exception during authentication') raise AuthenticationError() - if (request_manager.auth_called and - not request_manager.authenticated): + if (self.request_manager.auth_called and + not self.request_manager.authenticated): self.error = True raise AuthenticationError('Authentication failed') + self.RunPostAuthentication() + for thread in thread_pool.Threads(): self.throttle.Register(thread) @@ -3007,7 +3087,7 @@ progress_generator_factory = None self.data_source_thread = ( - self.datasourcethread_factory(request_manager, + self.datasourcethread_factory(self.request_manager, thread_pool, progress_queue, self.input_generator_factory, @@ -3092,6 +3172,13 @@ def __init__(self, *args, **kwargs): BulkTransporterApp.__init__(self, *args, **kwargs) + def RunPostAuthentication(self): + loader = Loader.RegisteredLoader(self.kind) + high_id_table = loader.get_high_ids() + for ancestor_path, kind_map in high_id_table.iteritems(): + for kind, high_id in kind_map.iteritems(): + self.request_manager.IncrementId(list(ancestor_path), kind, high_id) + def ReportStatus(self): """Display a message reporting the final status of the transfer.""" total_up, duration = self.throttle.TotalTransferred( @@ -3625,7 +3712,7 @@ if dump: Exporter.RegisterExporter(DumpExporter(kind, result_db_filename)) elif restore: - Loader.RegisterLoader(RestoreLoader(kind)) + Loader.RegisterLoader(RestoreLoader(kind, app_id)) else: LoadConfig(config_file)