diff -r 27971a13089f -r 2e0b0af889be thirdparty/google_appengine/google/appengine/tools/bulkloader.py --- a/thirdparty/google_appengine/google/appengine/tools/bulkloader.py Sat Sep 05 14:04:24 2009 +0200 +++ b/thirdparty/google_appengine/google/appengine/tools/bulkloader.py Sun Sep 06 23:31:53 2009 +0200 @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - """Imports data over HTTP. Usage: @@ -33,7 +32,7 @@ the URL endpoint. The more data per row/Entity, the smaller the batch size should be. (Default 10) --config_file= File containing Model and Loader definitions. - (Required) + (Required unless --dump or --restore are used) --db_filename= Specific progress database to write to, or to resume from. If not supplied, then a new database will be started, named: @@ -41,6 +40,8 @@ The special filename "skip" may be used to simply skip reading/writing any progress information. --download Export entities to a file. + --dry_run Do not execute any remote_api calls. + --dump Use zero-configuration dump format. --email= The username to use. Will prompt if omitted. --exporter_opts= A string to pass to the Exporter.initialize method. @@ -54,9 +55,12 @@ --log_file= File to write bulkloader logs. If not supplied then a new log file will be created, named: bulkloader-log-TIMESTAMP. + --map Map an action across datastore entities. + --mapper_opts= A string to pass to the Mapper.Initialize method. --num_threads= Number of threads to use for uploading entities (Default 10) --passin Read the login password from stdin. + --restore Restore from zero-configuration dump format. --result_db_filename= Result database to write to for downloads. --rps_limit= The maximum number of records per second to @@ -78,7 +82,6 @@ -import cPickle import csv import errno import getopt @@ -88,20 +91,31 @@ import os import Queue import re +import shutil import signal import StringIO import sys import threading import time +import traceback import urllib2 import urlparse +from google.appengine.datastore import entity_pb + +from google.appengine.api import apiproxy_stub_map +from google.appengine.api import datastore from google.appengine.api import datastore_errors +from google.appengine.datastore import datastore_pb from google.appengine.ext import db +from google.appengine.ext import key_range as key_range_module from google.appengine.ext.db import polymodel from google.appengine.ext.remote_api import remote_api_stub +from google.appengine.ext.remote_api import throttle as remote_api_throttle from google.appengine.runtime import apiproxy_errors +from google.appengine.tools import adaptive_thread_pool from google.appengine.tools import appengine_rpc +from google.appengine.tools.requeue import ReQueue try: import sqlite3 @@ -110,10 +124,14 @@ logger = logging.getLogger('google.appengine.tools.bulkloader') +KeyRange = key_range_module.KeyRange + DEFAULT_THREAD_COUNT = 10 DEFAULT_BATCH_SIZE = 10 +DEFAULT_DOWNLOAD_BATCH_SIZE = 100 + DEFAULT_QUEUE_SIZE = DEFAULT_THREAD_COUNT * 10 _THREAD_SHOULD_EXIT = '_THREAD_SHOULD_EXIT' @@ -125,9 +143,7 @@ STATE_GETTING = 1 STATE_GOT = 2 -STATE_NOT_GOT = 3 - -MINIMUM_THROTTLE_SLEEP_DURATION = 0.001 +STATE_ERROR = 3 DATA_CONSUMED_TO_HERE = 'DATA_CONSUMED_TO_HERE' @@ -142,16 +158,8 @@ DEFAULT_REQUEST_LIMIT = 8 -BANDWIDTH_UP = 'http-bandwidth-up' -BANDWIDTH_DOWN = 'http-bandwidth-down' -REQUESTS = 'http-requests' -HTTPS_BANDWIDTH_UP = 'https-bandwidth-up' -HTTPS_BANDWIDTH_DOWN = 'https-bandwidth-down' -HTTPS_REQUESTS = 
'https-requests' -RECORDS = 'records' - -MAXIMUM_INCREASE_DURATION = 8.0 -MAXIMUM_HOLD_DURATION = 10.0 +MAXIMUM_INCREASE_DURATION = 5.0 +MAXIMUM_HOLD_DURATION = 12.0 def ImportStateMessage(state): @@ -170,7 +178,17 @@ STATE_READ: 'Batch read from file.', STATE_GETTING: 'Fetching batch from server', STATE_GOT: 'Batch successfully fetched.', - STATE_NOT_GOT: 'Error while fetching batch' + STATE_ERROR: 'Error while fetching batch' + }[state]) + + +def MapStateMessage(state): + """Converts a numeric state identifier to a status message.""" + return ({ + STATE_READ: 'Batch read from file.', + STATE_GETTING: 'Querying for batch from server', + STATE_GOT: 'Batch successfully fetched.', + STATE_ERROR: 'Error while fetching or mapping.' }[state]) @@ -180,7 +198,7 @@ STATE_READ: 'READ', STATE_GETTING: 'GETTING', STATE_GOT: 'GOT', - STATE_NOT_GOT: 'NOT_GOT' + STATE_ERROR: 'NOT_GOT' }[state]) @@ -190,7 +208,7 @@ STATE_READ: 'READ', STATE_GETTING: 'SENDING', STATE_GOT: 'SENT', - STATE_NOT_GOT: 'NOT_SENT' + STATE_NOT_SENT: 'NOT_SENT' }[state]) @@ -234,16 +252,35 @@ """A filename passed in by the user refers to a non-writable output file.""" -class KeyRangeError(Error): - """Error while trying to generate a KeyRange.""" - - class BadStateError(Error): """A work item in an unexpected state was encountered.""" +class KeyRangeError(Error): + """An error during construction of a KeyRangeItem.""" + + +class FieldSizeLimitError(Error): + """The csv module tried to read a field larger than the size limit.""" + + def __init__(self, limit): + self.message = """ +A field in your CSV input file has exceeded the current limit of %d. + +You can raise this limit by adding the following lines to your config file: + +import csv +csv.field_size_limit(new_limit) + +where new_limit is number larger than the size in bytes of the largest +field in your CSV. +""" % limit + Error.__init__(self, self.message) + + class NameClashError(Error): """A name clash occurred while trying to alias old method names.""" + def __init__(self, old_name, new_name, klass): Error.__init__(self, old_name, new_name, klass) self.old_name = old_name @@ -253,48 +290,51 @@ def GetCSVGeneratorFactory(kind, csv_filename, batch_size, csv_has_header, openfile=open, create_csv_reader=csv.reader): - """Return a factory that creates a CSV-based WorkItem generator. + """Return a factory that creates a CSV-based UploadWorkItem generator. Args: kind: The kind of the entities being uploaded. csv_filename: File on disk containing CSV data. - batch_size: Maximum number of CSV rows to stash into a WorkItem. + batch_size: Maximum number of CSV rows to stash into an UploadWorkItem. csv_has_header: Whether to skip the first row of the CSV. openfile: Used for dependency injection. create_csv_reader: Used for dependency injection. Returns: A callable (accepting the Progress Queue and Progress Generators - as input) which creates the WorkItem generator. + as input) which creates the UploadWorkItem generator. """ loader = Loader.RegisteredLoader(kind) loader._Loader__openfile = openfile loader._Loader__create_csv_reader = create_csv_reader record_generator = loader.generate_records(csv_filename) - def CreateGenerator(progress_queue, progress_generator): - """Initialize a WorkItem generator linked to a progress generator and queue. + def CreateGenerator(request_manager, progress_queue, progress_generator): + """Initialize a UploadWorkItem generator. Args: + request_manager: A RequestManager instance. progress_queue: A ProgressQueue instance to send progress information. 
progress_generator: A generator of progress information or None. Returns: - A WorkItemGenerator instance. + An UploadWorkItemGenerator instance. """ - return WorkItemGenerator(progress_queue, - progress_generator, - record_generator, - csv_has_header, - batch_size) + return UploadWorkItemGenerator(request_manager, + progress_queue, + progress_generator, + record_generator, + csv_has_header, + batch_size) return CreateGenerator -class WorkItemGenerator(object): - """Reads rows from a row generator and generates WorkItems of batches.""" +class UploadWorkItemGenerator(object): + """Reads rows from a row generator and generates UploadWorkItems.""" def __init__(self, + request_manager, progress_queue, progress_generator, record_generator, @@ -303,12 +343,15 @@ """Initialize a WorkItemGenerator. Args: + request_manager: A RequestManager instance with which to associate + WorkItems. progress_queue: A progress queue with which to associate WorkItems. progress_generator: A generator of progress information. record_generator: A generator of data records. skip_first: Whether to skip the first data record. batch_size: The number of data records per WorkItem. """ + self.request_manager = request_manager self.progress_queue = progress_queue self.progress_generator = progress_generator self.reader = record_generator @@ -360,30 +403,29 @@ self.line_number += 1 def _MakeItem(self, key_start, key_end, rows, progress_key=None): - """Makes a WorkItem containing the given rows, with the given keys. + """Makes a UploadWorkItem containing the given rows, with the given keys. Args: - key_start: The start key for the WorkItem. - key_end: The end key for the WorkItem. - rows: A list of the rows for the WorkItem. - progress_key: The progress key for the WorkItem + key_start: The start key for the UploadWorkItem. + key_end: The end key for the UploadWorkItem. + rows: A list of the rows for the UploadWorkItem. + progress_key: The progress key for the UploadWorkItem Returns: - A WorkItem instance for the given batch. + An UploadWorkItem instance for the given batch. """ assert rows - item = WorkItem(self.progress_queue, rows, - key_start, key_end, - progress_key=progress_key) + item = UploadWorkItem(self.request_manager, self.progress_queue, rows, + key_start, key_end, progress_key=progress_key) return item def Batches(self): - """Reads from the record_generator and generates WorkItems. + """Reads from the record_generator and generates UploadWorkItems. Yields: - Instances of class WorkItem + Instances of class UploadWorkItem Raises: ResumeError: If the progress database and data file indicate a different @@ -468,37 +510,50 @@ """ csv_file = self.openfile(self.csv_filename, 'rb') reader = self.create_csv_reader(csv_file, skipinitialspace=True) - return reader - - -class KeyRangeGenerator(object): + try: + for record in reader: + yield record + except csv.Error, e: + if e.args and e.args[0].startswith('field larger than field limit'): + limit = e.args[1] + raise FieldSizeLimitError(limit) + else: + raise + + +class KeyRangeItemGenerator(object): """Generates ranges of keys to download. Reads progress information from the progress database and creates - KeyRange objects corresponding to incompletely downloaded parts of an + KeyRangeItem objects corresponding to incompletely downloaded parts of an export. """ - def __init__(self, kind, progress_queue, progress_generator): - """Initialize the KeyRangeGenerator. 
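# --- Illustrative sketch, not part of the patch above ---
# The UploadWorkItemGenerator described earlier reads (line_number, row)
# records and groups them into UploadWorkItems of at most batch_size rows,
# keyed by the first and last line numbers of the batch.  The helper below is
# a minimal, standalone sketch of that grouping; the name batch_records and
# its signature are invented for illustration.
def batch_records(records, batch_size):
    """Yield (key_start, key_end, rows) tuples of at most batch_size rows.

    records is an iterable of (line_number, row) pairs, e.g. an enumerated
    CSV reader.
    """
    rows = []
    key_start = None
    key_end = None
    for line_number, row in records:
        if key_start is None:
            key_start = line_number
        key_end = line_number
        rows.append((line_number, row))
        if len(rows) >= batch_size:
            yield key_start, key_end, rows
            rows, key_start, key_end = [], None, None
    if rows:
        yield key_start, key_end, rows

# Example: list(batch_records(enumerate(['a', 'b', 'c']), 2))
# -> [(0, 1, [(0, 'a'), (1, 'b')]), (2, 2, [(2, 'c')])]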
+ def __init__(self, request_manager, kind, progress_queue, progress_generator, + key_range_item_factory): + """Initialize the KeyRangeItemGenerator. Args: + request_manager: A RequestManager instance. kind: The kind of entities being transferred. progress_queue: A queue used for tracking progress information. progress_generator: A generator of prior progress information, or None if there is no prior status. + key_range_item_factory: A factory to produce KeyRangeItems. """ + self.request_manager = request_manager self.kind = kind self.row_count = 0 self.xfer_count = 0 self.progress_queue = progress_queue self.progress_generator = progress_generator + self.key_range_item_factory = key_range_item_factory def Batches(self): """Iterate through saved progress information. Yields: - KeyRange instances corresponding to undownloaded key ranges. + KeyRangeItem instances corresponding to undownloaded key ranges. """ if self.progress_generator is not None: for progress_key, state, key_start, key_end in self.progress_generator: @@ -506,397 +561,27 @@ key_start = ParseKey(key_start) key_end = ParseKey(key_end) - result = KeyRange(self.progress_queue, - self.kind, - key_start=key_start, - key_end=key_end, - progress_key=progress_key, - direction=KeyRange.ASC, - state=STATE_READ) + key_range = KeyRange(key_start=key_start, + key_end=key_end) + + result = self.key_range_item_factory(self.request_manager, + self.progress_queue, + self.kind, + key_range, + progress_key=progress_key, + state=STATE_READ) yield result else: - - yield KeyRange( - self.progress_queue, self.kind, - key_start=None, - key_end=None, - direction=KeyRange.DESC) - - -class ReQueue(object): - """A special thread-safe queue. - - A ReQueue allows unfinished work items to be returned with a call to - reput(). When an item is reput, task_done() should *not* be called - in addition, getting an item that has been reput does not increase - the number of outstanding tasks. - - This class shares an interface with Queue.Queue and provides the - additional reput method. - """ - - def __init__(self, - queue_capacity, - requeue_capacity=None, - queue_factory=Queue.Queue, - get_time=time.time): - """Initialize a ReQueue instance. - - Args: - queue_capacity: The number of items that can be put in the ReQueue. - requeue_capacity: The numer of items that can be reput in the ReQueue. - queue_factory: Used for dependency injection. - get_time: Used for dependency injection. - """ - if requeue_capacity is None: - requeue_capacity = queue_capacity - - self.get_time = get_time - self.queue = queue_factory(queue_capacity) - self.requeue = queue_factory(requeue_capacity) - self.lock = threading.Lock() - self.put_cond = threading.Condition(self.lock) - self.get_cond = threading.Condition(self.lock) - - def _DoWithTimeout(self, - action, - exc, - wait_cond, - done_cond, - lock, - timeout=None, - block=True): - """Performs the given action with a timeout. - - The action must be non-blocking, and raise an instance of exc on a - recoverable failure. If the action fails with an instance of exc, - we wait on wait_cond before trying again. Failure after the - timeout is reached is propagated as an exception. Success is - signalled by notifying on done_cond and returning the result of - the action. If action raises any exception besides an instance of - exc, it is immediately propagated. - - Args: - action: A callable that performs a non-blocking action. - exc: An exception type that is thrown by the action to indicate - a recoverable error. 
- wait_cond: A condition variable which should be waited on when - action throws exc. - done_cond: A condition variable to signal if the action returns. - lock: The lock used by wait_cond and done_cond. - timeout: A non-negative float indicating the maximum time to wait. - block: Whether to block if the action cannot complete immediately. - - Returns: - The result of the action, if it is successful. - - Raises: - ValueError: If the timeout argument is negative. - """ - if timeout is not None and timeout < 0.0: - raise ValueError('\'timeout\' must not be a negative number') - if not block: - timeout = 0.0 - result = None - success = False - start_time = self.get_time() - lock.acquire() - try: - while not success: - try: - result = action() - success = True - except Exception, e: - if not isinstance(e, exc): - raise e - if timeout is not None: - elapsed_time = self.get_time() - start_time - timeout -= elapsed_time - if timeout <= 0.0: - raise e - wait_cond.wait(timeout) - finally: - if success: - done_cond.notify() - lock.release() - return result - - def put(self, item, block=True, timeout=None): - """Put an item into the requeue. - - Args: - item: An item to add to the requeue. - block: Whether to block if the requeue is full. - timeout: Maximum on how long to wait until the queue is non-full. - - Raises: - Queue.Full if the queue is full and the timeout expires. - """ - def PutAction(): - self.queue.put(item, block=False) - self._DoWithTimeout(PutAction, - Queue.Full, - self.get_cond, - self.put_cond, - self.lock, - timeout=timeout, - block=block) - - def reput(self, item, block=True, timeout=None): - """Re-put an item back into the requeue. - - Re-putting an item does not increase the number of outstanding - tasks, so the reput item should be uniquely associated with an - item that was previously removed from the requeue and for which - TaskDone has not been called. - - Args: - item: An item to add to the requeue. - block: Whether to block if the requeue is full. - timeout: Maximum on how long to wait until the queue is non-full. - - Raises: - Queue.Full is the queue is full and the timeout expires. - """ - def ReputAction(): - self.requeue.put(item, block=False) - self._DoWithTimeout(ReputAction, - Queue.Full, - self.get_cond, - self.put_cond, - self.lock, - timeout=timeout, - block=block) - - def get(self, block=True, timeout=None): - """Get an item from the requeue. - - Args: - block: Whether to block if the requeue is empty. - timeout: Maximum on how long to wait until the requeue is non-empty. - - Returns: - An item from the requeue. - - Raises: - Queue.Empty if the queue is empty and the timeout expires. 
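# --- Illustrative sketch, not part of the patch above ---
# The ReQueue being removed here (an equivalent now ships in
# google.appengine.tools.requeue) pairs a normal queue with a "requeue":
# get() prefers items that were handed back via reput(), and reput() does not
# add to the outstanding-task count that join()/task_done() track.  This is a
# simplified, non-blocking approximation of that behaviour; the class name and
# its exception handling are invented for the example.
import collections
import threading

class MiniReQueue(object):
    def __init__(self, maxsize):
        self.maxsize = maxsize
        self.queue = collections.deque()
        self.requeue = collections.deque()
        self.lock = threading.Lock()
        self.unfinished = 0

    def put(self, item):
        with self.lock:
            if len(self.queue) + len(self.requeue) >= self.maxsize:
                raise RuntimeError('queue full')
            self.queue.append(item)
            self.unfinished += 1          # a brand-new unit of work

    def reput(self, item):
        with self.lock:
            self.requeue.append(item)     # the same unit of work, to be retried

    def get(self):
        with self.lock:
            if self.requeue:              # retried items take priority
                return self.requeue.popleft()
            return self.queue.popleft()   # raises IndexError when empty

    def task_done(self):
        with self.lock:
            self.unfinished -= 1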
- """ - def GetAction(): - try: - result = self.requeue.get(block=False) - self.requeue.task_done() - except Queue.Empty: - result = self.queue.get(block=False) - return result - return self._DoWithTimeout(GetAction, - Queue.Empty, - self.put_cond, - self.get_cond, - self.lock, - timeout=timeout, - block=block) - - def join(self): - """Blocks until all of the items in the requeue have been processed.""" - self.queue.join() - - def task_done(self): - """Indicate that a previously enqueued item has been fully processed.""" - self.queue.task_done() - - def empty(self): - """Returns true if the requeue is empty.""" - return self.queue.empty() and self.requeue.empty() - - def get_nowait(self): - """Try to get an item from the queue without blocking.""" - return self.get(block=False) - - def qsize(self): - return self.queue.qsize() + self.requeue.qsize() - - -class ThrottleHandler(urllib2.BaseHandler): - """A urllib2 handler for http and https requests that adds to a throttle.""" - - def __init__(self, throttle): - """Initialize a ThrottleHandler. - - Args: - throttle: A Throttle instance to call for bandwidth and http/https request - throttling. - """ - self.throttle = throttle - - def AddRequest(self, throttle_name, req): - """Add to bandwidth throttle for given request. - - Args: - throttle_name: The name of the bandwidth throttle to add to. - req: The request whose size will be added to the throttle. - """ - size = 0 - for key, value in req.headers.iteritems(): - size += len('%s: %s\n' % (key, value)) - for key, value in req.unredirected_hdrs.iteritems(): - size += len('%s: %s\n' % (key, value)) - (unused_scheme, - unused_host_port, url_path, - unused_query, unused_fragment) = urlparse.urlsplit(req.get_full_url()) - size += len('%s %s HTTP/1.1\n' % (req.get_method(), url_path)) - data = req.get_data() - if data: - size += len(data) - self.throttle.AddTransfer(throttle_name, size) - - def AddResponse(self, throttle_name, res): - """Add to bandwidth throttle for given response. - - Args: - throttle_name: The name of the bandwidth throttle to add to. - res: The response whose size will be added to the throttle. - """ - content = res.read() - def ReturnContent(): - return content - res.read = ReturnContent - size = len(content) - headers = res.info() - for key, value in headers.items(): - size += len('%s: %s\n' % (key, value)) - self.throttle.AddTransfer(throttle_name, size) - - def http_request(self, req): - """Process an HTTP request. - - If the throttle is over quota, sleep first. Then add request size to - throttle before returning it to be sent. - - Args: - req: A urllib2.Request object. - - Returns: - The request passed in. - """ - self.throttle.Sleep() - self.AddRequest(BANDWIDTH_UP, req) - return req - - def https_request(self, req): - """Process an HTTPS request. - - If the throttle is over quota, sleep first. Then add request size to - throttle before returning it to be sent. - - Args: - req: A urllib2.Request object. - - Returns: - The request passed in. - """ - self.throttle.Sleep() - self.AddRequest(HTTPS_BANDWIDTH_UP, req) - return req - - def http_response(self, unused_req, res): - """Process an HTTP response. - - The size of the response is added to the bandwidth throttle and the request - throttle is incremented by one. - - Args: - unused_req: The urllib2 request for this response. - res: A urllib2 response object. - - Returns: - The response passed in. 
- """ - self.AddResponse(BANDWIDTH_DOWN, res) - self.throttle.AddTransfer(REQUESTS, 1) - return res - - def https_response(self, unused_req, res): - """Process an HTTPS response. - - The size of the response is added to the bandwidth throttle and the request - throttle is incremented by one. - - Args: - unused_req: The urllib2 request for this response. - res: A urllib2 response object. - - Returns: - The response passed in. - """ - self.AddResponse(HTTPS_BANDWIDTH_DOWN, res) - self.throttle.AddTransfer(HTTPS_REQUESTS, 1) - return res - - -class ThrottledHttpRpcServer(appengine_rpc.HttpRpcServer): - """Provides a simplified RPC-style interface for HTTP requests. - - This RPC server uses a Throttle to prevent exceeding quotas. - """ - - def __init__(self, throttle, request_manager, *args, **kwargs): - """Initialize a ThrottledHttpRpcServer. - - Also sets request_manager.rpc_server to the ThrottledHttpRpcServer instance. - - Args: - throttle: A Throttles instance. - request_manager: A RequestManager instance. - args: Positional arguments to pass through to - appengine_rpc.HttpRpcServer.__init__ - kwargs: Keyword arguments to pass through to - appengine_rpc.HttpRpcServer.__init__ - """ - self.throttle = throttle - appengine_rpc.HttpRpcServer.__init__(self, *args, **kwargs) - request_manager.rpc_server = self - - def _GetOpener(self): - """Returns an OpenerDirector that supports cookies and ignores redirects. - - Returns: - A urllib2.OpenerDirector object. - """ - opener = appengine_rpc.HttpRpcServer._GetOpener(self) - opener.add_handler(ThrottleHandler(self.throttle)) - - return opener - - -def ThrottledHttpRpcServerFactory(throttle, request_manager): - """Create a factory to produce ThrottledHttpRpcServer for a given throttle. - - Args: - throttle: A Throttle instance to use for the ThrottledHttpRpcServer. - request_manager: A RequestManager instance. - - Returns: - A factory to produce a ThrottledHttpRpcServer. - """ - - def MakeRpcServer(*args, **kwargs): - """Factory to produce a ThrottledHttpRpcServer. - - Args: - args: Positional args to pass to ThrottledHttpRpcServer. - kwargs: Keyword args to pass to ThrottledHttpRpcServer. - - Returns: - A ThrottledHttpRpcServer instance. 
- """ - kwargs['account_type'] = 'HOSTED_OR_GOOGLE' - kwargs['save_cookies'] = True - return ThrottledHttpRpcServer(throttle, request_manager, *args, **kwargs) - return MakeRpcServer - - -class ExportResult(object): - """Holds the decoded content for the result of an export requests.""" + key_range = KeyRange() + + yield self.key_range_item_factory(self.request_manager, + self.progress_queue, + self.kind, + key_range) + + +class DownloadResult(object): + """Holds the result of an entity download.""" def __init__(self, continued, direction, keys, entities): self.continued = continued @@ -905,21 +590,31 @@ self.entities = entities self.count = len(keys) assert self.count == len(entities) - assert direction in (KeyRange.ASC, KeyRange.DESC) + assert direction in (key_range_module.KeyRange.ASC, + key_range_module.KeyRange.DESC) if self.count > 0: - if direction == KeyRange.ASC: + if direction == key_range_module.KeyRange.ASC: self.key_start = keys[0] self.key_end = keys[-1] else: self.key_start = keys[-1] self.key_end = keys[0] + def Entities(self): + """Returns the list of entities for this result in key order.""" + if self.direction == key_range_module.KeyRange.ASC: + return list(self.entities) + else: + result = list(self.entities) + result.reverse() + return result + def __str__(self): return 'continued = %s\n%s' % ( - str(self.continued), '\n'.join(self.entities)) - - -class _WorkItem(object): + str(self.continued), '\n'.join(str(self.entities))) + + +class _WorkItem(adaptive_thread_pool.WorkItem): """Holds a description of a unit of upload or download work.""" def __init__(self, progress_queue, key_start, key_end, state_namer, @@ -928,20 +623,101 @@ Args: progress_queue: A queue used for tracking progress information. - key_start: The starting key, inclusive. - key_end: The ending key, inclusive. + key_start: The start key of the work item. + key_end: The end key of the work item. state_namer: Function to describe work item states. state: The initial state of the work item. progress_key: If this WorkItem represents state from a prior run, then this will be the key within the progress database. """ + adaptive_thread_pool.WorkItem.__init__(self, + '[%s-%s]' % (key_start, key_end)) self.progress_queue = progress_queue - self.key_start = key_start - self.key_end = key_end self.state_namer = state_namer self.state = state self.progress_key = progress_key self.progress_event = threading.Event() + self.key_start = key_start + self.key_end = key_end + self.error = None + self.traceback = None + + def _TransferItem(self, thread_pool): + raise NotImplementedError() + + def SetError(self): + """Sets the error and traceback information for this thread. + + This must be called from an exception handler. + """ + if not self.error: + exc_info = sys.exc_info() + self.error = exc_info[1] + self.traceback = exc_info[2] + + def PerformWork(self, thread_pool): + """Perform the work of this work item and report the results. + + Args: + thread_pool: An AdaptiveThreadPool instance. + + Returns: + A tuple (status, instruction) of the work status and an instruction + for the ThreadGate. 
+ """ + status = adaptive_thread_pool.WorkItem.FAILURE + instruction = adaptive_thread_pool.ThreadGate.DECREASE + + try: + self.MarkAsTransferring() + + try: + transfer_time = self._TransferItem(thread_pool) + if transfer_time is None: + status = adaptive_thread_pool.WorkItem.RETRY + instruction = adaptive_thread_pool.ThreadGate.HOLD + else: + logger.debug('[%s] %s Transferred %d entities in %0.1f seconds', + threading.currentThread().getName(), self, self.count, + transfer_time) + sys.stdout.write('.') + sys.stdout.flush() + status = adaptive_thread_pool.WorkItem.SUCCESS + if transfer_time <= MAXIMUM_INCREASE_DURATION: + instruction = adaptive_thread_pool.ThreadGate.INCREASE + elif transfer_time <= MAXIMUM_HOLD_DURATION: + instruction = adaptive_thread_pool.ThreadGate.HOLD + except (db.InternalError, db.NotSavedError, db.Timeout, + db.TransactionFailedError, + apiproxy_errors.OverQuotaError, + apiproxy_errors.DeadlineExceededError, + apiproxy_errors.ApplicationError), e: + status = adaptive_thread_pool.WorkItem.RETRY + logger.exception('Retrying on non-fatal datastore error: %s', e) + except urllib2.HTTPError, e: + http_status = e.code + if http_status == 403 or (http_status >= 500 and http_status < 600): + status = adaptive_thread_pool.WorkItem.RETRY + logger.exception('Retrying on non-fatal HTTP error: %d %s', + http_status, e.msg) + else: + self.SetError() + status = adaptive_thread_pool.WorkItem.FAILURE + except urllib2.URLError, e: + if IsURLErrorFatal(e): + self.SetError() + status = adaptive_thread_pool.WorkItem.FAILURE + else: + status = adaptive_thread_pool.WorkItem.RETRY + logger.exception('Retrying on non-fatal URL error: %s', e.reason) + + finally: + if status == adaptive_thread_pool.WorkItem.SUCCESS: + self.MarkAsTransferred() + else: + self.MarkAsError() + + return (status, instruction) def _AssertInState(self, *states): """Raises an Error if the state of this range is not in states.""" @@ -963,7 +739,7 @@ def MarkAsTransferring(self): """Mark this _WorkItem as transferring, updating the progress database.""" - self._AssertInState(STATE_READ, STATE_NOT_GOT) + self._AssertInState(STATE_READ, STATE_ERROR) self._AssertProgressKey() self._StateTransition(STATE_GETTING, blocking=True) @@ -975,7 +751,7 @@ """Mark this _WorkItem as failed, updating the progress database.""" self._AssertInState(STATE_GETTING) self._AssertProgressKey() - self._StateTransition(STATE_NOT_GOT, blocking=True) + self._StateTransition(STATE_ERROR, blocking=True) def _StateTransition(self, new_state, blocking=False): """Transition the work item to a new state, storing progress information. @@ -998,12 +774,12 @@ -class WorkItem(_WorkItem): +class UploadWorkItem(_WorkItem): """Holds a unit of uploading work. - A WorkItem represents a number of entities that need to be uploaded to + A UploadWorkItem represents a number of entities that need to be uploaded to Google App Engine. These entities are encoded in the "content" field of - the WorkItem, and will be POST'd as-is to the server. + the UploadWorkItem, and will be POST'd as-is to the server. The entities are identified by a range of numeric keys, inclusively. In the case of a resumption of an upload, or a replay to correct errors, @@ -1013,16 +789,17 @@ fill the entire range, they must simply bound a range of valid keys. """ - def __init__(self, progress_queue, rows, key_start, key_end, + def __init__(self, request_manager, progress_queue, rows, key_start, key_end, progress_key=None): - """Initialize the WorkItem instance. 
+ """Initialize the UploadWorkItem instance. Args: + request_manager: A RequestManager instance. progress_queue: A queue used for tracking progress information. rows: A list of pairs of a line number and a list of column values key_start: The (numeric) starting key, inclusive. key_end: The (numeric) ending key, inclusive. - progress_key: If this WorkItem represents state from a prior run, + progress_key: If this UploadWorkItem represents state from a prior run, then this will be the key within the progress database. """ _WorkItem.__init__(self, progress_queue, key_start, key_end, @@ -1033,6 +810,7 @@ assert isinstance(key_end, (int, long)) assert key_start <= key_end + self.request_manager = request_manager self.rows = rows self.content = None self.count = len(rows) @@ -1040,8 +818,24 @@ def __str__(self): return '[%s-%s]' % (self.key_start, self.key_end) + def _TransferItem(self, thread_pool, get_time=time.time): + """Transfers the entities associated with an item. + + Args: + thread_pool: An AdaptiveThreadPool instance. + get_time: Used for dependency injection. + """ + t = get_time() + if not self.content: + self.content = self.request_manager.EncodeContent(self.rows) + try: + self.request_manager.PostEntities(self.content) + except: + raise + return get_time() - t + def MarkAsTransferred(self): - """Mark this WorkItem as sucessfully-sent to the server.""" + """Mark this UploadWorkItem as sucessfully-sent to the server.""" self._AssertInState(STATE_SENDING) self._AssertProgressKey() @@ -1068,45 +862,31 @@ implementation_class = db.class_for_kind(kind_or_class_key) return implementation_class -class EmptyQuery(db.Query): - def get(self): - return None - - def fetch(self, limit=1000, offset=0): - return [] - - def count(self, limit=1000): - return 0 - def KeyLEQ(key1, key2): """Compare two keys for less-than-or-equal-to. - All keys with numeric ids come before all keys with names. + All keys with numeric ids come before all keys with names. None represents + an unbounded end-point so it is both greater and less than any other key. Args: - key1: An int or db.Key instance. - key2: An int or db.Key instance. + key1: An int or datastore.Key instance. + key2: An int or datastore.Key instance. Returns: True if key1 <= key2 """ - if isinstance(key1, int) and isinstance(key2, int): - return key1 <= key2 if key1 is None or key2 is None: return True - if key1.id() and not key2.id(): - return True - return key1.id_or_name() <= key2.id_or_name() - - -class KeyRange(_WorkItem): - """Represents an item of download work. - - A KeyRange object represents a key range (key_start, key_end) and a - scan direction (KeyRange.DESC or KeyRange.ASC). The KeyRange object - has an associated state: STATE_READ, STATE_GETTING, STATE_GOT, and - STATE_ERROR. + return key1 <= key2 + + +class KeyRangeItem(_WorkItem): + """Represents an item of work that scans over a key range. + + A KeyRangeItem object represents holds a KeyRange + and has an associated state: STATE_READ, STATE_GETTING, STATE_GOT, + and STATE_ERROR. - STATE_READ indicates the range ready to be downloaded by a worker thread. - STATE_GETTING indicates the range is currently being downloaded. @@ -1114,280 +894,143 @@ - STATE_ERROR indicates that an error occurred during the last download attempt - KeyRanges not in the STATE_GOT state are stored in the progress database. - When a piece of KeyRange work is downloaded, the download may cover only - a portion of the range. 
In this case, the old KeyRange is removed from + KeyRangeItems not in the STATE_GOT state are stored in the progress database. + When a piece of KeyRangeItem work is downloaded, the download may cover only + a portion of the range. In this case, the old KeyRangeItem is removed from the progress database and ranges covering the undownloaded range are generated and stored as STATE_READ in the export progress database. """ - DESC = 0 - ASC = 1 - - MAX_KEY_LEN = 500 - def __init__(self, + request_manager, progress_queue, kind, - direction, - key_start=None, - key_end=None, - include_start=True, - include_end=True, + key_range, progress_key=None, state=STATE_READ): - """Initialize a KeyRange object. + """Initialize a KeyRangeItem object. Args: + request_manager: A RequestManager instance. progress_queue: A queue used for tracking progress information. kind: The kind of entities for this range. - direction: The direction of the query for this range. - key_start: The starting key for this range. - key_end: The ending key for this range. - include_start: Whether the start key should be included in the range. - include_end: Whether the end key should be included in the range. + key_range: A KeyRange instance for this work item. progress_key: The key for this range within the progress database. state: The initial state of this range. - - Raises: - KeyRangeError: if key_start is None. """ - assert direction in (KeyRange.ASC, KeyRange.DESC) - _WorkItem.__init__(self, progress_queue, key_start, key_end, - ExportStateName, state=state, progress_key=progress_key) + _WorkItem.__init__(self, progress_queue, key_range.key_start, + key_range.key_end, ExportStateName, state=state, + progress_key=progress_key) + self.request_manager = request_manager self.kind = kind - self.direction = direction - self.export_result = None + self.key_range = key_range + self.download_result = None self.count = 0 - self.include_start = include_start - self.include_end = include_end - self.SPLIT_KEY = db.Key.from_path(self.kind, unichr(0)) + self.key_start = key_range.key_start + self.key_end = key_range.key_end def __str__(self): - return '[%s-%s]' % (PrettyKey(self.key_start), PrettyKey(self.key_end)) + return str(self.key_range) def __repr__(self): return self.__str__() def MarkAsTransferred(self): - """Mark this KeyRange as transferred, updating the progress database.""" + """Mark this KeyRangeItem as transferred, updating the progress database.""" pass - def Process(self, export_result, num_threads, batch_size, work_queue): - """Mark this KeyRange as success, updating the progress database. - - Process will split this KeyRange based on the content of export_result and - adds the unfinished ranges to the work queue. + def Process(self, download_result, thread_pool, batch_size, + new_state=STATE_GOT): + """Mark this KeyRangeItem as success, updating the progress database. + + Process will split this KeyRangeItem based on the content of + download_result and adds the unfinished ranges to the work queue. Args: - export_result: An ExportResult instance. - num_threads: The number of threads for parallel transfers. + download_result: A DownloadResult instance. + thread_pool: An AdaptiveThreadPool instance. batch_size: The number of entities to transfer per request. - work_queue: The work queue to add unfinished ranges to. - - Returns: - A list of KeyRanges representing undownloaded datastore key ranges. + new_state: The state to transition the completed range to. 
""" self._AssertInState(STATE_GETTING) self._AssertProgressKey() - self.export_result = export_result - self.count = len(export_result.keys) - if export_result.continued: - self._FinishedRange()._StateTransition(STATE_GOT, blocking=True) - self._AddUnfinishedRanges(num_threads, batch_size, work_queue) + self.download_result = download_result + self.count = len(download_result.keys) + if download_result.continued: + self._FinishedRange()._StateTransition(new_state, blocking=True) + self._AddUnfinishedRanges(thread_pool, batch_size) else: - self._StateTransition(STATE_GOT, blocking=True) + self._StateTransition(new_state, blocking=True) def _FinishedRange(self): - """Returns the range completed by the export_result. - - Returns: - A KeyRange representing a completed range. - """ - assert self.export_result is not None - - if self.direction == KeyRange.ASC: - key_start = self.key_start - if self.export_result.continued: - key_end = self.export_result.key_end - else: - key_end = self.key_end - else: - key_end = self.key_end - if self.export_result.continued: - key_start = self.export_result.key_start - else: - key_start = self.key_start - - result = KeyRange(self.progress_queue, - self.kind, - key_start=key_start, - key_end=key_end, - direction=self.direction) - - result.progress_key = self.progress_key - result.export_result = self.export_result - result.state = self.state - result.count = self.count - return result - - def FilterQuery(self, query): - """Add query filter to restrict to this key range. - - Args: - query: A db.Query instance. - """ - if self.key_start == self.key_end and not ( - self.include_start or self.include_end): - return EmptyQuery() - if self.include_start: - start_comparator = '>=' - else: - start_comparator = '>' - if self.include_end: - end_comparator = '<=' - else: - end_comparator = '<' - if self.key_start and self.key_end: - query.filter('__key__ %s' % start_comparator, self.key_start) - query.filter('__key__ %s' % end_comparator, self.key_end) - elif self.key_start: - query.filter('__key__ %s' % start_comparator, self.key_start) - elif self.key_end: - query.filter('__key__ %s' % end_comparator, self.key_end) - - return query - - def MakeParallelQuery(self): - """Construct a query for this key range, for parallel downloading. - - Returns: - A db.Query instance. - - Raises: - KeyRangeError: if self.direction is not one of - KeyRange.ASC, KeyRange.DESC - """ - if self.direction == KeyRange.ASC: - direction = '' - elif self.direction == KeyRange.DESC: - direction = '-' - else: - raise KeyRangeError('KeyRange direction unexpected: %s', self.direction) - query = db.Query(GetImplementationClass(self.kind)) - query.order('%s__key__' % direction) - - return self.FilterQuery(query) - - def MakeSerialQuery(self): - """Construct a query for this key range without descending __key__ scan. + """Returns the range completed by the download_result. Returns: - A db.Query instance. + A KeyRangeItem representing a completed range. 
""" - query = db.Query(GetImplementationClass(self.kind)) - query.order('__key__') - - return self.FilterQuery(query) - - def _BisectStringRange(self, start, end): - if start == end: - return (start, start, end) - start += '\0' - end += '\0' - midpoint = [] - expected_max = 127 - for i in xrange(min(len(start), len(end))): - if start[i] == end[i]: - midpoint.append(start[i]) + assert self.download_result is not None + + if self.key_range.direction == key_range_module.KeyRange.ASC: + key_start = self.key_range.key_start + if self.download_result.continued: + key_end = self.download_result.key_end else: - ord_sum = ord(start[i]) + ord(end[i]) - midpoint.append(unichr(ord_sum / 2)) - if ord_sum % 2: - if len(start) > i + 1: - ord_start = ord(start[i+1]) - else: - ord_start = 0 - if ord_start < expected_max: - ord_split = (expected_max + ord_start) / 2 - else: - ord_split = (0xFFFF + ord_start) / 2 - midpoint.append(unichr(ord_split)) - break - return (start[:-1], ''.join(midpoint), end[:-1]) - - def SplitRange(self, key_start, include_start, key_end, include_end, - export_result, num_threads, batch_size, work_queue): - """Split the key range [key_start, key_end] into a list of ranges.""" - if export_result.direction == KeyRange.ASC: - key_start = export_result.key_end - include_start = False + key_end = self.key_range.key_end else: - key_end = export_result.key_start - include_end = False - key_pairs = [] - if not key_start: - key_pairs.append((key_start, include_start, key_end, include_end, - KeyRange.ASC)) - elif not key_end: - key_pairs.append((key_start, include_start, key_end, include_end, - KeyRange.DESC)) - elif work_queue.qsize() > 2 * num_threads: - key_pairs.append((key_start, include_start, key_end, include_end, - KeyRange.ASC)) - elif key_start.id() and key_end.id(): - if key_end.id() - key_start.id() > batch_size: - key_half = db.Key.from_path(self.kind, - (key_start.id() + key_end.id()) / 2) - key_pairs.append((key_start, include_start, - key_half, True, - KeyRange.DESC)) - key_pairs.append((key_half, False, - key_end, include_end, - KeyRange.ASC)) + key_end = self.key_range.key_end + if self.download_result.continued: + key_start = self.download_result.key_start else: - key_pairs.append((key_start, include_start, key_end, include_end, - KeyRange.ASC)) - elif key_start.name() and key_end.name(): - (start, middle, end) = self._BisectStringRange(key_start.name(), - key_end.name()) - key_pairs.append((key_start, include_start, - db.Key.from_path(self.kind, middle), True, - KeyRange.DESC)) - key_pairs.append((db.Key.from_path(self.kind, middle), False, - key_end, include_end, - KeyRange.ASC)) + key_start = self.key_range.key_start + + key_range = KeyRange(key_start=key_start, + key_end=key_end, + direction=self.key_range.direction) + + result = self.__class__(self.request_manager, + self.progress_queue, + self.kind, + key_range, + progress_key=self.progress_key, + state=self.state) + + result.download_result = self.download_result + result.count = self.count + return result + + def _SplitAndAddRanges(self, thread_pool, batch_size): + """Split the key range [key_start, key_end] into a list of ranges.""" + if self.download_result.direction == key_range_module.KeyRange.ASC: + key_range = KeyRange( + key_start=self.download_result.key_end, + key_end=self.key_range.key_end, + include_start=False) else: - assert key_start.id() and key_end.name() - key_pairs.append((key_start, include_start, - self.SPLIT_KEY, False, - KeyRange.DESC)) - key_pairs.append((self.SPLIT_KEY, True, - key_end, 
include_end, - KeyRange.ASC)) - - ranges = [KeyRange(self.progress_queue, - self.kind, - key_start=start, - include_start=include_start, - key_end=end, - include_end=include_end, - direction=direction) - for (start, include_start, end, include_end, direction) - in key_pairs] + key_range = KeyRange( + key_start=self.key_range.key_start, + key_end=self.download_result.key_start, + include_end=False) + + if thread_pool.QueuedItemCount() > 2 * thread_pool.num_threads(): + ranges = [key_range] + else: + ranges = key_range.split_range(batch_size=batch_size) for key_range in ranges: - key_range.MarkAsRead() - work_queue.put(key_range, block=True) - - def _AddUnfinishedRanges(self, num_threads, batch_size, work_queue): - """Adds incomplete KeyRanges to the work_queue. + key_range_item = self.__class__(self.request_manager, + self.progress_queue, + self.kind, + key_range) + key_range_item.MarkAsRead() + thread_pool.SubmitItem(key_range_item, block=True) + + def _AddUnfinishedRanges(self, thread_pool, batch_size): + """Adds incomplete KeyRanges to the thread_pool. Args: - num_threads: The number of threads for parallel transfers. + thread_pool: An AdaptiveThreadPool instance. batch_size: The number of entities to transfer per request. - work_queue: The work queue to add unfinished ranges to. Returns: A list of KeyRanges representing incomplete datastore key ranges. @@ -1395,15 +1038,43 @@ Raises: KeyRangeError: if this key range has already been completely transferred. """ - assert self.export_result is not None - if self.export_result.continued: - self.SplitRange(self.key_start, self.include_start, self.key_end, - self.include_end, self.export_result, - num_threads, batch_size, work_queue) + assert self.download_result is not None + if self.download_result.continued: + self._SplitAndAddRanges(thread_pool, batch_size) else: raise KeyRangeError('No unfinished part of key range.') +class DownloadItem(KeyRangeItem): + """A KeyRangeItem for downloading key ranges.""" + + def _TransferItem(self, thread_pool, get_time=time.time): + """Transfers the entities associated with an item.""" + t = get_time() + download_result = self.request_manager.GetEntities(self) + transfer_time = get_time() - t + self.Process(download_result, thread_pool, + self.request_manager.batch_size) + return transfer_time + + +class MapperItem(KeyRangeItem): + """A KeyRangeItem for mapping over key ranges.""" + + def _TransferItem(self, thread_pool, get_time=time.time): + t = get_time() + download_result = self.request_manager.GetEntities(self) + transfer_time = get_time() - t + mapper = self.request_manager.GetMapper() + try: + mapper.batch_apply(download_result.Entities()) + except MapperRetry: + return None + self.Process(download_result, thread_pool, + self.request_manager.batch_size) + return transfer_time + + class RequestManager(object): """A class which wraps a connection to the server.""" @@ -1416,7 +1087,8 @@ batch_size, secure, email, - passin): + passin, + dry_run=False): """Initialize a RequestManager object. Args: @@ -1445,23 +1117,39 @@ self.parallel_download = True self.email = email self.passin = passin - throttled_rpc_server_factory = ThrottledHttpRpcServerFactory( - self.throttle, self) + self.mapper = None + self.dry_run = dry_run + + if self.dry_run: + logger.info('Running in dry run mode, skipping remote_api setup') + return + logger.debug('Configuring remote_api. 
url_path = %s, ' 'servername = %s' % (url_path, host_port)) + + def CookieHttpRpcServer(*args, **kwargs): + kwargs['save_cookies'] = True + kwargs['account_type'] = 'HOSTED_OR_GOOGLE' + return appengine_rpc.HttpRpcServer(*args, **kwargs) + remote_api_stub.ConfigureRemoteDatastore( app_id, url_path, self.AuthFunction, servername=host_port, - rpc_server_factory=throttled_rpc_server_factory, + rpc_server_factory=CookieHttpRpcServer, secure=self.secure) + remote_api_throttle.ThrottleRemoteDatastore(self.throttle) logger.debug('Bulkloader using app_id: %s', os.environ['APPLICATION_ID']) def Authenticate(self): """Invoke authentication if necessary.""" - logger.info('Connecting to %s', self.url_path) - self.rpc_server.Send(self.url_path, payload=None) + logger.info('Connecting to %s%s', self.host_port, self.url_path) + if self.dry_run: + self.authenticated = True + return + + remote_api_stub.MaybeInvokeAuthentication() self.authenticated = True def AuthFunction(self, @@ -1506,7 +1194,7 @@ loader: Used for dependency injection. Returns: - A list of db.Model instances. + A list of datastore.Entity instances. Raises: ConfigurationError: if no loader is defined for self.kind @@ -1520,77 +1208,112 @@ entities = [] for line_number, values in rows: key = loader.generate_key(line_number, values) - if isinstance(key, db.Key): + if isinstance(key, datastore.Key): parent = key.parent() key = key.name() else: parent = None entity = loader.create_entity(values, key_name=key, parent=parent) + + def ToEntity(entity): + if isinstance(entity, db.Model): + return entity._populate_entity() + else: + return entity + if isinstance(entity, list): - entities.extend(entity) + entities.extend(map(ToEntity, entity)) elif entity: - entities.append(entity) + entities.append(ToEntity(entity)) return entities - def PostEntities(self, item): + def PostEntities(self, entities): """Posts Entity records to a remote endpoint over HTTP. Args: - item: A workitem containing the entities to post. - - Returns: - A pair of the estimated size of the request in bytes and the response - from the server as a str. + entities: A list of datastore entities. """ - entities = item.content - db.put(entities) - - def GetEntities(self, key_range): + if self.dry_run: + return + datastore.Put(entities) + + def _QueryForPbs(self, query): + """Perform the given query and return a list of entity_pb's.""" + try: + query_pb = query._ToPb(limit=self.batch_size) + result_pb = datastore_pb.QueryResult() + apiproxy_stub_map.MakeSyncCall('datastore_v3', 'RunQuery', query_pb, + result_pb) + next_pb = datastore_pb.NextRequest() + next_pb.set_count(self.batch_size) + next_pb.mutable_cursor().CopyFrom(result_pb.cursor()) + result_pb = datastore_pb.QueryResult() + apiproxy_stub_map.MakeSyncCall('datastore_v3', 'Next', next_pb, result_pb) + return result_pb.result_list() + except apiproxy_errors.ApplicationError, e: + raise datastore._ToDatastoreError(e) + + def GetEntities(self, key_range_item, key_factory=datastore.Key): """Gets Entity records from a remote endpoint over HTTP. Args: - key_range: Range of keys to get. + key_range_item: Range of keys to get. + key_factory: Used for dependency injection. Returns: - An ExportResult instance. + A DownloadResult instance. Raises: ConfigurationError: if no Exporter is defined for self.kind """ - try: - Exporter.RegisteredExporter(self.kind) - except KeyError: - raise ConfigurationError('No Exporter defined for kind %s.' 
% self.kind) - keys = [] entities = [] if self.parallel_download: - query = key_range.MakeParallelQuery() + query = key_range_item.key_range.make_directed_datastore_query(self.kind) try: - results = query.fetch(self.batch_size) + results = self._QueryForPbs(query) except datastore_errors.NeedIndexError: logger.info('%s: No descending index on __key__, ' 'performing serial download', self.kind) self.parallel_download = False if not self.parallel_download: - key_range.direction = KeyRange.ASC - query = key_range.MakeSerialQuery() - results = query.fetch(self.batch_size) + key_range_item.key_range.direction = key_range_module.KeyRange.ASC + query = key_range_item.key_range.make_ascending_datastore_query(self.kind) + results = self._QueryForPbs(query) size = len(results) - for model in results: - key = model.key() - entities.append(cPickle.dumps(model)) + for entity in results: + key = key_factory() + key._Key__reference = entity.key() + entities.append(entity) keys.append(key) continued = (size == self.batch_size) - key_range.count = size - - return ExportResult(continued, key_range.direction, keys, entities) + key_range_item.count = size + + return DownloadResult(continued, key_range_item.key_range.direction, + keys, entities) + + def GetMapper(self): + """Returns a mapper for the registered kind. + + Returns: + A Mapper instance. + + Raises: + ConfigurationError: if no Mapper is defined for self.kind + """ + if not self.mapper: + try: + self.mapper = Mapper.RegisteredMapper(self.kind) + except KeyError: + logger.error('No Mapper defined for kind %s.' % self.kind) + raise ConfigurationError('No Mapper defined for kind %s.' % self.kind) + return self.mapper def InterruptibleSleep(sleep_time): @@ -1611,357 +1334,6 @@ return -class ThreadGate(object): - """Manage the number of active worker threads. - - The ThreadGate limits the number of threads that are simultaneously - uploading batches of records in order to implement adaptive rate - control. The number of simultaneous upload threads that it takes to - start causing timeout varies widely over the course of the day, so - adaptive rate control allows the uploader to do many uploads while - reducing the error rate and thus increasing the throughput. - - Initially the ThreadGate allows only one uploader thread to be active. - For each successful upload, another thread is activated and for each - failed upload, the number of active threads is reduced by one. - """ - - def __init__(self, enabled, - threshhold1=MAXIMUM_INCREASE_DURATION, - threshhold2=MAXIMUM_HOLD_DURATION, - sleep=InterruptibleSleep): - """Constructor for ThreadGate instances. - - Args: - enabled: Whether the thread gate is enabled - threshhold1: Maximum duration (in seconds) for a transfer to increase - the number of active threads. - threshhold2: Maximum duration (in seconds) for a transfer to not decrease - the number of active threads. 
- """ - self.enabled = enabled - self.enabled_count = 1 - self.lock = threading.Lock() - self.thread_semaphore = threading.Semaphore(self.enabled_count) - self._threads = [] - self.backoff_time = 0 - self.sleep = sleep - self.threshhold1 = threshhold1 - self.threshhold2 = threshhold2 - - def Register(self, thread): - """Register a thread with the thread gate.""" - self._threads.append(thread) - - def Threads(self): - """Yields the registered threads.""" - for thread in self._threads: - yield thread - - def EnableThread(self): - """Enable one more worker thread.""" - self.lock.acquire() - try: - self.enabled_count += 1 - finally: - self.lock.release() - self.thread_semaphore.release() - - def EnableAllThreads(self): - """Enable all worker threads.""" - for unused_idx in xrange(len(self._threads) - self.enabled_count): - self.EnableThread() - - def StartWork(self): - """Starts a critical section in which the number of workers is limited. - - If thread throttling is enabled then this method starts a critical - section which allows self.enabled_count simultaneously operating - threads. The critical section is ended by calling self.FinishWork(). - """ - if self.enabled: - self.thread_semaphore.acquire() - if self.backoff_time > 0.0: - if not threading.currentThread().exit_flag: - logger.info('Backing off: %.1f seconds', - self.backoff_time) - self.sleep(self.backoff_time) - - def FinishWork(self): - """Ends a critical section started with self.StartWork().""" - if self.enabled: - self.thread_semaphore.release() - - def TransferSuccess(self, duration): - """Informs the throttler that an item was successfully sent. - - If thread throttling is enabled and the duration is low enough, this - method will cause an additional thread to run in the critical section. - - Args: - duration: The duration of the transfer in seconds. - """ - if duration > self.threshhold2: - logger.debug('Transfer took %s, decreasing workers.', duration) - self.DecreaseWorkers(backoff=False) - return - elif duration > self.threshhold1: - logger.debug('Transfer took %s, not increasing workers.', duration) - return - elif self.enabled: - if self.backoff_time > 0.0: - logger.info('Resetting backoff to 0.0') - self.backoff_time = 0.0 - do_enable = False - self.lock.acquire() - try: - if self.enabled and len(self._threads) > self.enabled_count: - do_enable = True - self.enabled_count += 1 - finally: - self.lock.release() - if do_enable: - logger.debug('Increasing active thread count to %d', - self.enabled_count) - self.thread_semaphore.release() - - def DecreaseWorkers(self, backoff=True): - """Informs the thread_gate that an item failed to send. - - If thread throttling is enabled, this method will cause the - throttler to allow one fewer thread in the critical section. If - there is only one thread remaining, failures will result in - exponential backoff until there is a success. - - Args: - backoff: Whether to increase exponential backoff if there is only - one thread enabled. - """ - if self.enabled: - do_disable = False - self.lock.acquire() - try: - if self.enabled: - if self.enabled_count > 1: - do_disable = True - self.enabled_count -= 1 - elif backoff: - if self.backoff_time == 0.0: - self.backoff_time = INITIAL_BACKOFF - else: - self.backoff_time *= BACKOFF_FACTOR - finally: - self.lock.release() - if do_disable: - logger.debug('Decreasing the number of active threads to %d', - self.enabled_count) - self.thread_semaphore.acquire() - - -class Throttle(object): - """A base class for upload rate throttling. 
- - Transferring large number of records, too quickly, to an application - could trigger quota limits and cause the transfer process to halt. - In order to stay within the application's quota, we throttle the - data transfer to a specified limit (across all transfer threads). - This limit defaults to about half of the Google App Engine default - for an application, but can be manually adjusted faster/slower as - appropriate. - - This class tracks a moving average of some aspect of the transfer - rate (bandwidth, records per second, http connections per - second). It keeps two windows of counts of bytes transferred, on a - per-thread basis. One block is the "current" block, and the other is - the "prior" block. It will rotate the counts from current to prior - when ROTATE_PERIOD has passed. Thus, the current block will - represent from 0 seconds to ROTATE_PERIOD seconds of activity - (determined by: time.time() - self.last_rotate). The prior block - will always represent a full ROTATE_PERIOD. - - Sleeping is performed just before a transfer of another block, and is - based on the counts transferred *before* the next transfer. It really - does not matter how much will be transferred, but only that for all the - data transferred SO FAR that we have interspersed enough pauses to - ensure the aggregate transfer rate is within the specified limit. - - These counts are maintained on a per-thread basis, so we do not require - any interlocks around incrementing the counts. There IS an interlock on - the rotation of the counts because we do not want multiple threads to - multiply-rotate the counts. - - There are various race conditions in the computation and collection - of these counts. We do not require precise values, but simply to - keep the overall transfer within the bandwidth limits. If a given - pause is a little short, or a little long, then the aggregate delays - will be correct. - """ - - ROTATE_PERIOD = 600 - - def __init__(self, - get_time=time.time, - thread_sleep=InterruptibleSleep, - layout=None): - self.get_time = get_time - self.thread_sleep = thread_sleep - - self.start_time = get_time() - self.transferred = {} - self.prior_block = {} - self.totals = {} - self.throttles = {} - - self.last_rotate = {} - self.rotate_mutex = {} - if layout: - self.AddThrottles(layout) - - def AddThrottle(self, name, limit): - self.throttles[name] = limit - self.transferred[name] = {} - self.prior_block[name] = {} - self.totals[name] = {} - self.last_rotate[name] = self.get_time() - self.rotate_mutex[name] = threading.Lock() - - def AddThrottles(self, layout): - for key, value in layout.iteritems(): - self.AddThrottle(key, value) - - def Register(self, thread): - """Register this thread with the throttler.""" - thread_name = thread.getName() - for throttle_name in self.throttles.iterkeys(): - self.transferred[throttle_name][thread_name] = 0 - self.prior_block[throttle_name][thread_name] = 0 - self.totals[throttle_name][thread_name] = 0 - - def VerifyName(self, throttle_name): - if throttle_name not in self.throttles: - raise AssertionError('%s is not a registered throttle' % throttle_name) - - def AddTransfer(self, throttle_name, token_count): - """Add a count to the amount this thread has transferred. - - Each time a thread transfers some data, it should call this method to - note the amount sent. The counts may be rotated if sufficient time - has passed since the last rotation. - - Note: this method should only be called by the BulkLoaderThread - instances. 
The token count is allocated towards the - "current thread". - - Args: - throttle_name: The name of the throttle to add to. - token_count: The number to add to the throttle counter. - """ - self.VerifyName(throttle_name) - transferred = self.transferred[throttle_name] - transferred[threading.currentThread().getName()] += token_count - - if self.last_rotate[throttle_name] + self.ROTATE_PERIOD < self.get_time(): - self._RotateCounts(throttle_name) - - def Sleep(self, throttle_name=None): - """Possibly sleep in order to limit the transfer rate. - - Note that we sleep based on *prior* transfers rather than what we - may be about to transfer. The next transfer could put us under/over - and that will be rectified *after* that transfer. Net result is that - the average transfer rate will remain within bounds. Spiky behavior - or uneven rates among the threads could possibly bring the transfer - rate above the requested limit for short durations. - - Args: - throttle_name: The name of the throttle to sleep on. If None or - omitted, then sleep on all throttles. - """ - if throttle_name is None: - for throttle_name in self.throttles: - self.Sleep(throttle_name=throttle_name) - return - - self.VerifyName(throttle_name) - - thread = threading.currentThread() - - while True: - duration = self.get_time() - self.last_rotate[throttle_name] - - total = 0 - for count in self.prior_block[throttle_name].values(): - total += count - - if total: - duration += self.ROTATE_PERIOD - - for count in self.transferred[throttle_name].values(): - total += count - - sleep_time = (float(total) / self.throttles[throttle_name]) - duration - - if sleep_time < MINIMUM_THROTTLE_SLEEP_DURATION: - break - - logger.debug('[%s] Throttling on %s. Sleeping for %.1f ms ' - '(duration=%.1f ms, total=%d)', - thread.getName(), throttle_name, - sleep_time * 1000, duration * 1000, total) - self.thread_sleep(sleep_time) - if thread.exit_flag: - break - self._RotateCounts(throttle_name) - - def _RotateCounts(self, throttle_name): - """Rotate the transfer counters. - - If sufficient time has passed, then rotate the counters from active to - the prior-block of counts. - - This rotation is interlocked to ensure that multiple threads do not - over-rotate the counts. - - Args: - throttle_name: The name of the throttle to rotate. - """ - self.VerifyName(throttle_name) - self.rotate_mutex[throttle_name].acquire() - try: - next_rotate_time = self.last_rotate[throttle_name] + self.ROTATE_PERIOD - if next_rotate_time >= self.get_time(): - return - - for name, count in self.transferred[throttle_name].items(): - - - self.prior_block[throttle_name][name] = count - self.transferred[throttle_name][name] = 0 - - self.totals[throttle_name][name] += count - - self.last_rotate[throttle_name] = self.get_time() - - finally: - self.rotate_mutex[throttle_name].release() - - def TotalTransferred(self, throttle_name): - """Return the total transferred, and over what period. - - Args: - throttle_name: The name of the throttle to total. - - Returns: - A tuple of the total count and running time for the given throttle name. - """ - total = 0 - for count in self.totals[throttle_name].values(): - total += count - for count in self.transferred[throttle_name].values(): - total += count - return total, self.get_time() - self.start_time - - class _ThreadBase(threading.Thread): """Provide some basic features for the threads used in the uploader. 
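The Throttle class removed above keeps its rate limit as a pair of per-thread count windows: a "prior" window representing a full ROTATE_PERIOD of activity and a "current" window covering the time since the last rotation. Before each transfer a worker pauses just long enough that everything counted so far stays under the configured limit; this patch hands that job (and the worker gating that ThreadGate did) to google.appengine.ext.remote_api.throttle and adaptive_thread_pool, both imported at the top of the patch. The snippet below is a minimal, self-contained sketch of that sleep calculation only; the function name and parameters are illustrative and are not the library's API.

import time

ROTATE_PERIOD = 600  # seconds represented by the completed "prior" window


def compute_sleep_time(limit, prior_total, current_total, last_rotate,
                       now=None):
  """Seconds a worker should pause so the aggregate rate stays under limit.

  Args:
    limit: allowed units per second (bytes, requests, records, ...).
    prior_total: units counted in the last full ROTATE_PERIOD window.
    current_total: units counted since the last rotation.
    last_rotate: time.time() value recorded at the last rotation.
    now: current time; defaults to time.time() (injectable for tests).
  """
  if now is None:
    now = time.time()
  duration = now - last_rotate
  if prior_total:
    # The prior window always stands for a full ROTATE_PERIOD of activity.
    duration += ROTATE_PERIOD
  total = prior_total + current_total
  # Time the transfers *should* have taken at the limit, minus the time that
  # has actually elapsed, is how long we still need to pause. The removed
  # code skipped pauses shorter than about a millisecond.
  return max(0.0, float(total) / limit - duration)

# For example, 15000000 bytes counted over the last 30 seconds against a
# 250000 bytes/second limit yields a pause of 60 - 30 = 30 seconds.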
@@ -1993,18 +1365,29 @@ self.exit_flag = False self.error = None + self.traceback = None def run(self): """Perform the work of the thread.""" - logger.info('[%s] %s: started', self.getName(), self.__class__.__name__) + logger.debug('[%s] %s: started', self.getName(), self.__class__.__name__) try: self.PerformWork() except: - self.error = sys.exc_info()[1] + self.SetError() logger.exception('[%s] %s:', self.getName(), self.__class__.__name__) - logger.info('[%s] %s: exiting', self.getName(), self.__class__.__name__) + logger.debug('[%s] %s: exiting', self.getName(), self.__class__.__name__) + + def SetError(self): + """Sets the error and traceback information for this thread. + + This must be called from an exception handler. + """ + if not self.error: + exc_info = sys.exc_info() + self.error = exc_info[1] + self.traceback = exc_info[2] def PerformWork(self): """Perform the thread-specific work.""" @@ -2014,6 +1397,10 @@ """If an error is present, then log it.""" if self.error: logger.error('Error in %s: %s', self.GetFriendlyName(), self.error) + if self.traceback: + logger.debug(''.join(traceback.format_exception(self.error.__class__, + self.error, + self.traceback))) def GetFriendlyName(self): """Returns a human-friendly description of the thread.""" @@ -2044,292 +1431,12 @@ return error.reason[0] not in non_fatal_error_codes -def PrettyKey(key): - """Returns a nice string representation of the given key.""" - if key is None: - return None - elif isinstance(key, db.Key): - return repr(key.id_or_name()) - return str(key) - - -class _BulkWorkerThread(_ThreadBase): - """A base class for worker threads. - - This thread will read WorkItem instances from the work_queue and upload - the entities to the server application. Progress information will be - pushed into the progress_queue as the work is being performed. - - If a _BulkWorkerThread encounters a transient error, the entities will be - resent, if a fatal error is encoutered the BulkWorkerThread exits. - - Subclasses must provide implementations for PreProcessItem, TransferItem, - and ProcessResponse. - """ - - def __init__(self, - work_queue, - throttle, - thread_gate, - request_manager, - num_threads, - batch_size, - state_message, - get_time): - """Initialize the BulkLoaderThread instance. - - Args: - work_queue: A queue containing WorkItems for processing. - throttle: A Throttles to control upload bandwidth. - thread_gate: A ThreadGate to control number of simultaneous uploads. - request_manager: A RequestManager instance. - num_threads: The number of threads for parallel transfers. - batch_size: The number of entities to transfer per request. - state_message: Used for dependency injection. - get_time: Used for dependency injection. - """ - _ThreadBase.__init__(self) - - self.work_queue = work_queue - self.throttle = throttle - self.thread_gate = thread_gate - self.request_manager = request_manager - self.num_threads = num_threads - self.batch_size = batch_size - self.state_message = state_message - self.get_time = get_time - - def PreProcessItem(self, item): - """Performs pre transfer processing on a work item.""" - raise NotImplementedError() - - def TransferItem(self, item): - """Transfers the entities associated with an item. - - Args: - item: An item of upload (WorkItem) or download (KeyRange) work. 
- - Returns: - A tuple of (estimated transfer size, response) - """ - raise NotImplementedError() - - def ProcessResponse(self, item, result): - """Processes the response from the server application.""" - raise NotImplementedError() - - def PerformWork(self): - """Perform the work of a _BulkWorkerThread.""" - while not self.exit_flag: - transferred = False - self.thread_gate.StartWork() - try: - try: - item = self.work_queue.get(block=True, timeout=1.0) - except Queue.Empty: - continue - if item == _THREAD_SHOULD_EXIT: - break - - logger.debug('[%s] Got work item %s', self.getName(), item) - - try: - - item.MarkAsTransferring() - self.PreProcessItem(item) - response = None - try: - try: - t = self.get_time() - response = self.TransferItem(item) - status = 200 - transferred = True - transfer_time = self.get_time() - t - logger.debug('[%s] %s Transferred %d entities in %0.1f seconds', - self.getName(), item, item.count, transfer_time) - self.throttle.AddTransfer(RECORDS, item.count) - except (db.InternalError, db.NotSavedError, db.Timeout, - apiproxy_errors.OverQuotaError, - apiproxy_errors.DeadlineExceededError), e: - logger.exception('Caught non-fatal datastore error: %s', e) - except urllib2.HTTPError, e: - status = e.code - if status == 403 or (status >= 500 and status < 600): - logger.exception('Caught non-fatal HTTP error: %d %s', - status, e.msg) - else: - raise e - except urllib2.URLError, e: - if IsURLErrorFatal(e): - raise e - else: - logger.exception('Caught non-fatal URL error: %s', e.reason) - - self.ProcessResponse(item, response) - - except: - self.error = sys.exc_info()[1] - logger.exception('[%s] %s: caught exception %s', self.getName(), - self.__class__.__name__, str(sys.exc_info())) - raise - - finally: - if transferred: - item.MarkAsTransferred() - self.work_queue.task_done() - self.thread_gate.TransferSuccess(transfer_time) - else: - item.MarkAsError() - try: - self.work_queue.reput(item, block=False) - except Queue.Full: - logger.error('[%s] Failed to reput work item.', self.getName()) - raise Error('Failed to reput work item') - self.thread_gate.DecreaseWorkers() - logger.info('%s %s', - item, - self.state_message(item.state)) - - finally: - self.thread_gate.FinishWork() - - - def GetFriendlyName(self): - """Returns a human-friendly name for this thread.""" - return 'worker [%s]' % self.getName() - - -class BulkLoaderThread(_BulkWorkerThread): - """A thread which transmits entities to the server application. - - This thread will read WorkItem instances from the work_queue and upload - the entities to the server application. Progress information will be - pushed into the progress_queue as the work is being performed. - - If a BulkLoaderThread encounters a transient error, the entities will be - resent, if a fatal error is encoutered the BulkLoaderThread exits. - """ - - def __init__(self, - work_queue, - throttle, - thread_gate, - request_manager, - num_threads, - batch_size, - get_time=time.time): - """Initialize the BulkLoaderThread instance. - - Args: - work_queue: A queue containing WorkItems for processing. - throttle: A Throttles to control upload bandwidth. - thread_gate: A ThreadGate to control number of simultaneous uploads. - request_manager: A RequestManager instance. - num_threads: The number of threads for parallel transfers. - batch_size: The number of entities to transfer per request. - get_time: Used for dependency injection. 
- """ - _BulkWorkerThread.__init__(self, - work_queue, - throttle, - thread_gate, - request_manager, - num_threads, - batch_size, - ImportStateMessage, - get_time) - - def PreProcessItem(self, item): - """Performs pre transfer processing on a work item.""" - if item and not item.content: - item.content = self.request_manager.EncodeContent(item.rows) - - def TransferItem(self, item): - """Transfers the entities associated with an item. - - Args: - item: An item of upload (WorkItem) work. - - Returns: - A tuple of (estimated transfer size, response) - """ - return self.request_manager.PostEntities(item) - - def ProcessResponse(self, item, response): - """Processes the response from the server application.""" - pass - - -class BulkExporterThread(_BulkWorkerThread): - """A thread which recieved entities to the server application. - - This thread will read KeyRange instances from the work_queue and export - the entities from the server application. Progress information will be - pushed into the progress_queue as the work is being performed. - - If a BulkExporterThread encounters an error when trying to post data, - the thread will exit and cause the application to terminate. - """ - - def __init__(self, - work_queue, - throttle, - thread_gate, - request_manager, - num_threads, - batch_size, - get_time=time.time): - - """Initialize the BulkExporterThread instance. - - Args: - work_queue: A queue containing KeyRanges for processing. - throttle: A Throttles to control upload bandwidth. - thread_gate: A ThreadGate to control number of simultaneous uploads. - request_manager: A RequestManager instance. - num_threads: The number of threads for parallel transfers. - batch_size: The number of entities to transfer per request. - get_time: Used for dependency injection. - """ - _BulkWorkerThread.__init__(self, - work_queue, - throttle, - thread_gate, - request_manager, - num_threads, - batch_size, - ExportStateMessage, - get_time) - - def PreProcessItem(self, unused_item): - """Performs pre transfer processing on a work item.""" - pass - - def TransferItem(self, item): - """Transfers the entities associated with an item. - - Args: - item: An item of download (KeyRange) work. - - Returns: - A tuple of (estimated transfer size, response) - """ - return self.request_manager.GetEntities(item) - - def ProcessResponse(self, item, export_result): - """Processes the response from the server application.""" - if export_result: - item.Process(export_result, self.num_threads, self.batch_size, - self.work_queue) - item.state = STATE_GOT - - class DataSourceThread(_ThreadBase): """A thread which reads WorkItems and pushes them into queue. This thread will read/consume WorkItems from a generator (produced by the generator factory). These WorkItems will then be pushed into the - work_queue. Note that reading will block if/when the work_queue becomes + thread_pool. Note that reading will block if/when the thread_pool becomes full. Information on content consumed from the generator will be pushed into the progress_queue. """ @@ -2337,14 +1444,16 @@ NAME = 'data source thread' def __init__(self, - work_queue, + request_manager, + thread_pool, progress_queue, workitem_generator_factory, progress_generator_factory): """Initialize the DataSourceThread instance. Args: - work_queue: A queue containing WorkItems for processing. + request_manager: A RequestManager instance. + thread_pool: An AdaptiveThreadPool instance. progress_queue: A queue used for tracking progress information. 
workitem_generator_factory: A factory that creates a WorkItem generator progress_generator_factory: A factory that creates a generator which @@ -2353,7 +1462,8 @@ """ _ThreadBase.__init__(self) - self.work_queue = work_queue + self.request_manager = request_manager + self.thread_pool = thread_pool self.progress_queue = progress_queue self.workitem_generator_factory = workitem_generator_factory self.progress_generator_factory = progress_generator_factory @@ -2366,7 +1476,8 @@ else: progress_gen = None - content_gen = self.workitem_generator_factory(self.progress_queue, + content_gen = self.workitem_generator_factory(self.request_manager, + self.progress_queue, progress_gen) self.xfer_count = 0 @@ -2378,7 +1489,7 @@ while not self.exit_flag: try: - self.work_queue.put(item, block=True, timeout=1.0) + self.thread_pool.SubmitItem(item, block=True, timeout=1.0) self.entity_count += item.count break except Queue.Full: @@ -2526,6 +1637,70 @@ self.update_cursor = self.secondary_conn.cursor() +zero_matcher = re.compile(r'\x00') + +zero_one_matcher = re.compile(r'\x00\x01') + + +def KeyStr(key): + """Returns a string to represent a key, preserving ordering. + + Unlike datastore.Key.__str__(), we have the property: + + key1 < key2 ==> KeyStr(key1) < KeyStr(key2) + + The key string is constructed from the key path as follows: + (1) Strings are prepended with ':' and numeric id's are padded to + 20 digits. + (2) Any null characters (u'\0') present are replaced with u'\0\1' + (3) The sequence u'\0\0' is used to separate each component of the path. + + (1) assures that names and ids compare properly, while (2) and (3) enforce + the part-by-part comparison of pieces of the path. + + Args: + key: A datastore.Key instance. + + Returns: + A string representation of the key, which preserves ordering. + """ + assert isinstance(key, datastore.Key) + path = key.to_path() + + out_path = [] + for part in path: + if isinstance(part, (int, long)): + part = '%020d' % part + else: + part = ':%s' % part + + out_path.append(zero_matcher.sub(u'\0\1', part)) + + out_str = u'\0\0'.join(out_path) + + return out_str + + +def StrKey(key_str): + """The inverse of the KeyStr function. + + Args: + key_str: A string in the range of KeyStr. + + Returns: + A datastore.Key instance k, such that KeyStr(k) == key_str. + """ + parts = key_str.split(u'\0\0') + for i in xrange(len(parts)): + if parts[i][0] == ':': + part = parts[i][1:] + part = zero_one_matcher.sub(u'\0', part) + parts[i] = part + else: + parts[i] = int(parts[i]) + return datastore.Key.from_path(*parts) + + class ResultDatabase(_Database): """Persistently record all the entities downloaded during an export. @@ -2544,7 +1719,7 @@ """ self.complete = False create_table = ('create table result (\n' - 'id TEXT primary key,\n' + 'id BLOB primary key,\n' 'value BLOB not null)') _Database.__init__(self, @@ -2560,34 +1735,37 @@ self.existing_count = 0 self.count = self.existing_count - def _StoreEntity(self, entity_id, value): + def _StoreEntity(self, entity_id, entity): """Store an entity in the result database. Args: - entity_id: A db.Key for the entity. - value: A string of the contents of the entity. + entity_id: A datastore.Key for the entity. + entity: The entity to store. Returns: True if this entities is not already present in the result database. 
""" assert _RunningInThread(self.secondary_thread) - assert isinstance(entity_id, db.Key) - - entity_id = entity_id.id_or_name() + assert isinstance(entity_id, datastore.Key), ( + 'expected a datastore.Key, got a %s' % entity_id.__class__.__name__) + + key_str = buffer(KeyStr(entity_id).encode('utf-8')) self.insert_cursor.execute( - 'select count(*) from result where id = ?', (unicode(entity_id),)) + 'select count(*) from result where id = ?', (key_str,)) + already_present = self.insert_cursor.fetchone()[0] result = True if already_present: result = False self.insert_cursor.execute('delete from result where id = ?', - (unicode(entity_id),)) + (key_str,)) else: self.count += 1 + value = entity.Encode() self.insert_cursor.execute( 'insert into result (id, value) values (?, ?)', - (unicode(entity_id), buffer(value))) + (key_str, buffer(value))) return result def StoreEntities(self, keys, entities): @@ -2603,9 +1781,9 @@ self._OpenSecondaryConnection() t = time.time() count = 0 - for entity_id, value in zip(keys, - entities): - if self._StoreEntity(entity_id, value): + for entity_id, entity in zip(keys, + entities): + if self._StoreEntity(entity_id, entity): count += 1 logger.debug('%s insert: delta=%.3f', self.db_filename, @@ -2627,7 +1805,8 @@ 'select id, value from result order by id') for unused_entity_id, entity in cursor: - yield cPickle.loads(str(entity)) + entity_proto = entity_pb.EntityProto(contents=entity) + yield datastore.Entity._FromPb(entity_proto) class _ProgressDatabase(_Database): @@ -2723,9 +1902,16 @@ self._OpenSecondaryConnection() assert _RunningInThread(self.secondary_thread) - assert not key_start or isinstance(key_start, self.py_type) - assert not key_end or isinstance(key_end, self.py_type), '%s is a %s' % ( - key_end, key_end.__class__) + assert (not key_start) or isinstance(key_start, self.py_type), ( + '%s is a %s, %s expected %s' % (key_start, + key_start.__class__, + self.__class__.__name__, + self.py_type)) + assert (not key_end) or isinstance(key_end, self.py_type), ( + '%s is a %s, %s expected %s' % (key_end, + key_end.__class__, + self.__class__.__name__, + self.py_type)) assert KeyLEQ(key_start, key_end), '%s not less than %s' % ( repr(key_start), repr(key_end)) @@ -2843,7 +2029,7 @@ _ProgressDatabase.__init__(self, db_filename, 'TEXT', - db.Key, + datastore.Key, signature, commit_periodicity=1) @@ -3011,34 +2197,72 @@ exporter.output_entities(self.result_db.AllEntities()) def UpdateProgress(self, item): - """Update the state of the given KeyRange. + """Update the state of the given KeyRangeItem. Args: item: A KeyRange instance. """ if item.state == STATE_GOT: - count = self.result_db.StoreEntities(item.export_result.keys, - item.export_result.entities) + count = self.result_db.StoreEntities(item.download_result.keys, + item.download_result.entities) self.db.DeleteKey(item.progress_key) self.entities_transferred += count else: self.db.UpdateState(item.progress_key, item.state) +class MapperProgressThread(_ProgressThreadBase): + """A thread to record progress information for maps over the datastore.""" + + def __init__(self, kind, progress_queue, progress_db): + """Initialize the MapperProgressThread instance. + + Args: + kind: The kind of entities being stored in the database. + progress_queue: A Queue used for tracking progress information. + progress_db: The database for tracking progress information; should + be an instance of ProgressDatabase. 
+ """ + _ProgressThreadBase.__init__(self, progress_queue, progress_db) + + self.kind = kind + self.mapper = Mapper.RegisteredMapper(self.kind) + + def EntitiesTransferred(self): + """Return the total number of unique entities transferred.""" + return self.entities_transferred + + def WorkFinished(self): + """Perform actions after map is complete.""" + pass + + def UpdateProgress(self, item): + """Update the state of the given KeyRangeItem. + + Args: + item: A KeyRange instance. + """ + if item.state == STATE_GOT: + self.entities_transferred += item.count + self.db.DeleteKey(item.progress_key) + else: + self.db.UpdateState(item.progress_key, item.state) + + def ParseKey(key_string): - """Turn a key stored in the database into a db.Key or None. + """Turn a key stored in the database into a Key or None. Args: - key_string: The string representation of a db.Key. + key_string: The string representation of a Key. Returns: - A db.Key instance or None + A datastore.Key instance or None """ if not key_string: return None if key_string == 'None': return None - return db.Key(encoded=key_string) + return datastore.Key(encoded=key_string) def Validate(value, typ): @@ -3097,9 +2321,7 @@ def __init__(self, kind, properties): """Constructor. - Populates this Loader's kind and properties map. Also registers it with - the bulk loader, so that all you need to do is instantiate your Loader, - and the bulkload handler will automatically use it. + Populates this Loader's kind and properties map. Args: kind: a string containing the entity kind that this loader handles @@ -3139,7 +2361,11 @@ @staticmethod def RegisterLoader(loader): - + """Register loader and the Loader instance for its kind. + + Args: + loader: A Loader instance. + """ Loader.__loaders[loader.kind] = loader def alias_old_names(self): @@ -3166,7 +2392,7 @@ Args: values: list/tuple of str key_name: if provided, the name for the (single) resulting entity - parent: A db.Key instance for the parent, or None + parent: A datastore.Key instance for the parent, or None Returns: list of db.Model @@ -3222,7 +2448,7 @@ server generated numeric key), or a string which neither starts with a digit nor has the form __*__ (see http://code.google.com/appengine/docs/python/datastore/keysandentitygroups.html), - or a db.Key instance. + or a datastore.Key instance. 
If you generate your own string keys, keep in mind: @@ -3305,6 +2531,51 @@ return Loader.__loaders[kind] +class RestoreThread(_ThreadBase): + """A thread to read saved entity_pbs from sqlite3.""" + NAME = 'RestoreThread' + _ENTITIES_DONE = 'Entities Done' + + def __init__(self, queue, filename): + _ThreadBase.__init__(self) + self.queue = queue + self.filename = filename + + def PerformWork(self): + db_conn = sqlite3.connect(self.filename) + cursor = db_conn.cursor() + cursor.execute('select id, value from result') + for entity_id, value in cursor: + self.queue.put([entity_id, value], block=True) + self.queue.put(RestoreThread._ENTITIES_DONE, block=True) + + +class RestoreLoader(Loader): + """A Loader which imports protobuffers from a file.""" + + def __init__(self, kind): + self.kind = kind + + def initialize(self, filename, loader_opts): + CheckFile(filename) + self.queue = Queue.Queue(1000) + restore_thread = RestoreThread(self.queue, filename) + restore_thread.start() + + def generate_records(self, filename): + while True: + record = self.queue.get(block=True) + if id(record) == id(RestoreThread._ENTITIES_DONE): + break + yield record + + def create_entity(self, values, key_name=None, parent=None): + key = StrKey(unicode(values[0], 'utf-8')) + entity_proto = entity_pb.EntityProto(contents=str(values[1])) + entity_proto.mutable_key().CopyFrom(key._Key__reference) + return datastore.Entity._FromPb(entity_proto) + + class Exporter(object): """A base class for serializing datastore entities. @@ -3326,9 +2597,7 @@ def __init__(self, kind, properties): """Constructor. - Populates this Exporters's kind and properties map. Also registers - it so that all you need to do is instantiate your Exporter, and - the bulkload handler will automatically use it. + Populates this Exporters's kind and properties map. Args: kind: a string containing the entity kind that this exporter handles @@ -3370,7 +2639,11 @@ @staticmethod def RegisterExporter(exporter): - + """Register exporter and the Exporter instance for its kind. + + Args: + exporter: A Exporter instance. + """ Exporter.__exporters[exporter.kind] = exporter def __ExtractProperties(self, entity): @@ -3388,7 +2661,7 @@ encoding = [] for name, fn, default in self.__properties: try: - encoding.append(fn(getattr(entity, name))) + encoding.append(fn(entity[name])) except AttributeError: if default is None: raise MissingPropertyError(name) @@ -3468,6 +2741,87 @@ return Exporter.__exporters[kind] +class DumpExporter(Exporter): + """An exporter which dumps protobuffers to a file.""" + + def __init__(self, kind, result_db_filename): + self.kind = kind + self.result_db_filename = result_db_filename + + def output_entities(self, entity_generator): + shutil.copyfile(self.result_db_filename, self.output_filename) + + +class MapperRetry(Error): + """An exception that indicates a non-fatal error during mapping.""" + + +class Mapper(object): + """A base class for serializing datastore entities. + + To add a handler for exporting an entity kind from your datastore, + write a subclass of this class that calls Mapper.__init__ from your + class's __init__. + + You need to implement to batch_apply or apply method on your subclass + for the map to do anything. + """ + + __mappers = {} + kind = None + + def __init__(self, kind): + """Constructor. + + Populates this Mappers's kind. 
+ + Args: + kind: a string containing the entity kind that this mapper handles + """ + Validate(kind, basestring) + self.kind = kind + + GetImplementationClass(kind) + + @staticmethod + def RegisterMapper(mapper): + """Register mapper and the Mapper instance for its kind. + + Args: + mapper: A Mapper instance. + """ + Mapper.__mappers[mapper.kind] = mapper + + def initialize(self, mapper_opts): + """Performs initialization. + + Args: + mapper_opts: The string given as the --mapper_opts flag argument. + """ + pass + + def finalize(self): + """Performs finalization actions after the download completes.""" + pass + + def apply(self, entity): + print 'Default map function doing nothing to %s' % entity + + def batch_apply(self, entities): + for entity in entities: + self.apply(entity) + + @staticmethod + def RegisteredMappers(): + """Returns a dictionary of the mapper instances that have been created.""" + return dict(Mapper.__mappers) + + @staticmethod + def RegisteredMapper(kind): + """Returns an mapper instance for the given kind if it exists.""" + return Mapper.__mappers[kind] + + class QueueJoinThread(threading.Thread): """A thread that joins a queue and exits. @@ -3492,7 +2846,7 @@ def InterruptibleQueueJoin(queue, thread_local, - thread_gate, + thread_pool, queue_join_thread_factory=QueueJoinThread, check_workers=True): """Repeatedly joins the given ReQueue or Queue.Queue with short timeout. @@ -3502,7 +2856,7 @@ Args: queue: A Queue.Queue or ReQueue instance. thread_local: A threading.local instance which indicates interrupts. - thread_gate: A ThreadGate instance. + thread_pool: An AdaptiveThreadPool instance. queue_join_thread_factory: Used for dependency injection. check_workers: Whether to interrupt the join on worker death. @@ -3519,41 +2873,29 @@ logger.debug('Queue join interrupted') return False if check_workers: - for worker_thread in thread_gate.Threads(): + for worker_thread in thread_pool.Threads(): if not worker_thread.isAlive(): return False -def ShutdownThreads(data_source_thread, work_queue, thread_gate): +def ShutdownThreads(data_source_thread, thread_pool): """Shuts down the worker and data source threads. Args: data_source_thread: A running DataSourceThread instance. - work_queue: The work queue. - thread_gate: A ThreadGate instance with workers registered. + thread_pool: An AdaptiveThreadPool instance with workers registered. """ logger.info('An error occurred. Shutting down...') data_source_thread.exit_flag = True - for thread in thread_gate.Threads(): - thread.exit_flag = True - - for unused_thread in thread_gate.Threads(): - thread_gate.EnableThread() + thread_pool.Shutdown() data_source_thread.join(timeout=3.0) if data_source_thread.isAlive(): logger.warn('%s hung while trying to exit', data_source_thread.GetFriendlyName()) - while not work_queue.empty(): - try: - unused_item = work_queue.get_nowait() - work_queue.task_done() - except Queue.Empty: - pass - class BulkTransporterApp(object): """Class to wrap bulk transport application functionality.""" @@ -3563,13 +2905,12 @@ input_generator_factory, throttle, progress_db, - workerthread_factory, progresstrackerthread_factory, max_queue_size=DEFAULT_QUEUE_SIZE, request_manager_factory=RequestManager, datasourcethread_factory=DataSourceThread, - work_queue_factory=ReQueue, - progress_queue_factory=Queue.Queue): + progress_queue_factory=Queue.Queue, + thread_pool_factory=adaptive_thread_pool.AdaptiveThreadPool): """Instantiate a BulkTransporterApp. Uploads or downloads data to or from application using HTTP requests. 
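The new Mapper base class above is the hook behind the --map flag: LoadConfig (further down in this patch) instantiates every class listed in a config file's module-level mappers sequence and registers it by kind, and the registered mapper's apply (or batch_apply) method is then run across downloaded entities. The sketch below shows what such a config file could look like. It is illustrative only: the 'Greeting' kind and its 'content' property are made up, apply is assumed to receive a dict-like datastore.Entity (the mapper item plumbing is not shown in this hunk), and Mapper.__init__ also expects the kind to be resolvable via GetImplementationClass.

# bulkloader_config.py -- illustrative sketch, not shipped code.
from google.appengine.tools import bulkloader


class GreetingAuditMapper(bulkloader.Mapper):
  """Tallies Greeting entities whose 'content' property is empty."""

  def __init__(self):
    bulkloader.Mapper.__init__(self, 'Greeting')
    self.empty_count = 0

  def initialize(self, mapper_opts):
    # mapper_opts is the raw string passed via --mapper_opts, if any.
    self.label = mapper_opts or 'greetings'

  def apply(self, entity):
    # Assumes entity behaves like a dict-style datastore.Entity.
    if not entity.get('content'):
      self.empty_count += 1

  def finalize(self):
    print 'empty %s: %d' % (self.label, self.empty_count)


mappers = [GreetingAuditMapper]

Invoked with something like bulkloader.py --map --kind=Greeting --config_file=bulkloader_config.py --url=... (a hypothetical command line), the mapper is applied across the kind's entities without writing anything back to the datastore.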
@@ -3584,13 +2925,12 @@ input_generator_factory: A factory that creates a WorkItem generator. throttle: A Throttle instance. progress_db: The database to use for replaying/recording progress. - workerthread_factory: A factory for worker threads. progresstrackerthread_factory: Used for dependency injection. max_queue_size: Maximum size of the queues before they should block. request_manager_factory: Used for dependency injection. datasourcethread_factory: Used for dependency injection. - work_queue_factory: Used for dependency injection. progress_queue_factory: Used for dependency injection. + thread_pool_factory: Used for dependency injection. """ self.app_id = arg_dict['app_id'] self.post_url = arg_dict['url'] @@ -3600,15 +2940,15 @@ self.num_threads = arg_dict['num_threads'] self.email = arg_dict['email'] self.passin = arg_dict['passin'] + self.dry_run = arg_dict['dry_run'] self.throttle = throttle self.progress_db = progress_db - self.workerthread_factory = workerthread_factory self.progresstrackerthread_factory = progresstrackerthread_factory self.max_queue_size = max_queue_size self.request_manager_factory = request_manager_factory self.datasourcethread_factory = datasourcethread_factory - self.work_queue_factory = work_queue_factory self.progress_queue_factory = progress_queue_factory + self.thread_pool_factory = thread_pool_factory (scheme, self.host_port, self.url_path, unused_query, unused_fragment) = urlparse.urlsplit(self.post_url) @@ -3623,13 +2963,13 @@ Returns: Error code suitable for sys.exit, e.g. 0 on success, 1 on failure. """ - thread_gate = ThreadGate(True) + self.error = False + thread_pool = self.thread_pool_factory( + self.num_threads, queue_size=self.max_queue_size) self.throttle.Register(threading.currentThread()) threading.currentThread().exit_flag = False - work_queue = self.work_queue_factory(self.max_queue_size) - progress_queue = self.progress_queue_factory(self.max_queue_size) request_manager = self.request_manager_factory(self.app_id, self.host_port, @@ -3639,27 +2979,23 @@ self.batch_size, self.secure, self.email, - self.passin) + self.passin, + self.dry_run) try: request_manager.Authenticate() except Exception, e: + self.error = True if not isinstance(e, urllib2.HTTPError) or ( e.code != 302 and e.code != 401): logger.exception('Exception during authentication') raise AuthenticationError() if (request_manager.auth_called and not request_manager.authenticated): + self.error = True raise AuthenticationError('Authentication failed') - for unused_idx in xrange(self.num_threads): - thread = self.workerthread_factory(work_queue, - self.throttle, - thread_gate, - request_manager, - self.num_threads, - self.batch_size) + for thread in thread_pool.Threads(): self.throttle.Register(thread) - thread_gate.Register(thread) self.progress_thread = self.progresstrackerthread_factory( progress_queue, self.progress_db) @@ -3671,7 +3007,8 @@ progress_generator_factory = None self.data_source_thread = ( - self.datasourcethread_factory(work_queue, + self.datasourcethread_factory(request_manager, + thread_pool, progress_queue, self.input_generator_factory, progress_generator_factory)) @@ -3682,60 +3019,54 @@ def Interrupt(unused_signum, unused_frame): """Shutdown gracefully in response to a signal.""" thread_local.shut_down = True + self.error = True signal.signal(signal.SIGINT, Interrupt) self.progress_thread.start() self.data_source_thread.start() - for thread in thread_gate.Threads(): - thread.start() while not thread_local.shut_down: 
self.data_source_thread.join(timeout=0.25) if self.data_source_thread.isAlive(): - for thread in list(thread_gate.Threads()) + [self.progress_thread]: + for thread in list(thread_pool.Threads()) + [self.progress_thread]: if not thread.isAlive(): logger.info('Unexpected thread death: %s', thread.getName()) thread_local.shut_down = True + self.error = True break else: break - if thread_local.shut_down: - ShutdownThreads(self.data_source_thread, work_queue, thread_gate) - def _Join(ob, msg): logger.debug('Waiting for %s...', msg) if isinstance(ob, threading.Thread): ob.join(timeout=3.0) if ob.isAlive(): - logger.debug('Joining %s failed', ob.GetFriendlyName()) + logger.debug('Joining %s failed', ob) else: logger.debug('... done.') elif isinstance(ob, (Queue.Queue, ReQueue)): - if not InterruptibleQueueJoin(ob, thread_local, thread_gate): - ShutdownThreads(self.data_source_thread, work_queue, thread_gate) + if not InterruptibleQueueJoin(ob, thread_local, thread_pool): + ShutdownThreads(self.data_source_thread, thread_pool) else: ob.join() logger.debug('... done.') - _Join(work_queue, 'work_queue to flush') - - for unused_thread in thread_gate.Threads(): - work_queue.put(_THREAD_SHOULD_EXIT) - - for unused_thread in thread_gate.Threads(): - thread_gate.EnableThread() - - for thread in thread_gate.Threads(): - _Join(thread, 'thread [%s] to terminate' % thread.getName()) - - thread.CheckError() + if self.data_source_thread.error or thread_local.shut_down: + ShutdownThreads(self.data_source_thread, thread_pool) + else: + _Join(thread_pool.requeue, 'worker threads to finish') + + thread_pool.Shutdown() + thread_pool.JoinThreads() + thread_pool.CheckErrors() + print '' if self.progress_thread.isAlive(): - InterruptibleQueueJoin(progress_queue, thread_local, thread_gate, + InterruptibleQueueJoin(progress_queue, thread_local, thread_pool, check_workers=False) else: logger.warn('Progress thread exited prematurely') @@ -3763,9 +3094,10 @@ def ReportStatus(self): """Display a message reporting the final status of the transfer.""" - total_up, duration = self.throttle.TotalTransferred(BANDWIDTH_UP) + total_up, duration = self.throttle.TotalTransferred( + remote_api_throttle.BANDWIDTH_UP) s_total_up, unused_duration = self.throttle.TotalTransferred( - HTTPS_BANDWIDTH_UP) + remote_api_throttle.HTTPS_BANDWIDTH_UP) total_up += s_total_up total = total_up logger.info('%d entites total, %d previously transferred', @@ -3793,18 +3125,49 @@ def ReportStatus(self): """Display a message reporting the final status of the transfer.""" - total_down, duration = self.throttle.TotalTransferred(BANDWIDTH_DOWN) + total_down, duration = self.throttle.TotalTransferred( + remote_api_throttle.BANDWIDTH_DOWN) s_total_down, unused_duration = self.throttle.TotalTransferred( - HTTPS_BANDWIDTH_DOWN) + remote_api_throttle.HTTPS_BANDWIDTH_DOWN) total_down += s_total_down total = total_down existing_count = self.progress_thread.existing_count xfer_count = self.progress_thread.EntitiesTransferred() logger.info('Have %d entities, %d previously transferred', - xfer_count + existing_count, existing_count) + xfer_count, existing_count) logger.info('%d entities (%d bytes) transferred in %.1f seconds', xfer_count, total, duration) - return 0 + if self.error: + return 1 + else: + return 0 + + +class BulkMapperApp(BulkTransporterApp): + """Class to encapsulate bulk map functionality.""" + + def __init__(self, *args, **kwargs): + BulkTransporterApp.__init__(self, *args, **kwargs) + + def ReportStatus(self): + """Display a message reporting the 
final status of the transfer.""" + total_down, duration = self.throttle.TotalTransferred( + remote_api_throttle.BANDWIDTH_DOWN) + s_total_down, unused_duration = self.throttle.TotalTransferred( + remote_api_throttle.HTTPS_BANDWIDTH_DOWN) + total_down += s_total_down + total = total_down + xfer_count = self.progress_thread.EntitiesTransferred() + logger.info('The following may be inaccurate if any mapper tasks ' + 'encountered errors and had to be retried.') + logger.info('Applied mapper to %s entities.', + xfer_count) + logger.info('%s entities (%s bytes) transferred in %.1f seconds', + xfer_count, total, duration) + if self.error: + return 1 + else: + return 0 def PrintUsageExit(code): @@ -3843,18 +3206,24 @@ 'loader_opts=', 'exporter_opts=', 'log_file=', + 'mapper_opts=', 'email=', 'passin', + 'map', + 'dry_run', + 'dump', + 'restore', ] -def ParseArguments(argv): +def ParseArguments(argv, die_fn=lambda: PrintUsageExit(1)): """Parses command-line arguments. Prints out a help message if -h or --help is supplied. Args: argv: List of command-line arguments. + die_fn: Function to invoke to end the program. Returns: A dictionary containing the value of command-line options. @@ -3867,11 +3236,11 @@ arg_dict = {} arg_dict['url'] = REQUIRED_OPTION - arg_dict['filename'] = REQUIRED_OPTION - arg_dict['config_file'] = REQUIRED_OPTION - arg_dict['kind'] = REQUIRED_OPTION - - arg_dict['batch_size'] = DEFAULT_BATCH_SIZE + arg_dict['filename'] = None + arg_dict['config_file'] = None + arg_dict['kind'] = None + + arg_dict['batch_size'] = None arg_dict['num_threads'] = DEFAULT_THREAD_COUNT arg_dict['bandwidth_limit'] = DEFAULT_BANDWIDTH_LIMIT arg_dict['rps_limit'] = DEFAULT_RPS_LIMIT @@ -3889,6 +3258,11 @@ arg_dict['log_file'] = None arg_dict['email'] = None arg_dict['passin'] = False + arg_dict['mapper_opts'] = None + arg_dict['map'] = False + arg_dict['dry_run'] = False + arg_dict['dump'] = False + arg_dict['restore'] = False def ExpandFilename(filename): """Expand shell variables and ~usernames in filename.""" @@ -3938,26 +3312,39 @@ elif option == '--exporter_opts': arg_dict['exporter_opts'] = value elif option == '--log_file': - arg_dict['log_file'] = value + arg_dict['log_file'] = ExpandFilename(value) elif option == '--email': arg_dict['email'] = value elif option == '--passin': arg_dict['passin'] = True - - return ProcessArguments(arg_dict, die_fn=lambda: PrintUsageExit(1)) + elif option == '--map': + arg_dict['map'] = True + elif option == '--mapper_opts': + arg_dict['mapper_opts'] = value + elif option == '--dry_run': + arg_dict['dry_run'] = True + elif option == '--dump': + arg_dict['dump'] = True + elif option == '--restore': + arg_dict['restore'] = True + + return ProcessArguments(arg_dict, die_fn=die_fn) def ThrottleLayout(bandwidth_limit, http_limit, rps_limit): """Return a dictionary indicating the throttle options.""" - return { - BANDWIDTH_UP: bandwidth_limit, - BANDWIDTH_DOWN: bandwidth_limit, - REQUESTS: http_limit, - HTTPS_BANDWIDTH_UP: bandwidth_limit / 5, - HTTPS_BANDWIDTH_DOWN: bandwidth_limit / 5, - HTTPS_REQUESTS: http_limit / 5, - RECORDS: rps_limit, - } + bulkloader_limits = dict(remote_api_throttle.NO_LIMITS) + bulkloader_limits.update({ + remote_api_throttle.BANDWIDTH_UP: bandwidth_limit, + remote_api_throttle.BANDWIDTH_DOWN: bandwidth_limit, + remote_api_throttle.REQUESTS: http_limit, + remote_api_throttle.HTTPS_BANDWIDTH_UP: bandwidth_limit, + remote_api_throttle.HTTPS_BANDWIDTH_DOWN: bandwidth_limit, + remote_api_throttle.HTTPS_REQUESTS: http_limit, + 
remote_api_throttle.ENTITIES_FETCHED: rps_limit, + remote_api_throttle.ENTITIES_MODIFIED: rps_limit, + }) + return bulkloader_limits def CheckOutputFile(filename): @@ -3969,12 +3356,13 @@ Raises: FileExistsError: if the given filename is not found FileNotWritableError: if the given filename is not readable. - """ - if os.path.exists(filename): + """ + full_path = os.path.abspath(filename) + if os.path.exists(full_path): raise FileExistsError('%s: output file exists' % filename) - elif not os.access(os.path.dirname(filename), os.W_OK): + elif not os.access(os.path.dirname(full_path), os.W_OK): raise FileNotWritableError( - '%s: not writable' % os.path.dirname(filename)) + '%s: not writable' % os.path.dirname(full_path)) def LoadConfig(config_file_name, exit_fn=sys.exit): @@ -3999,6 +3387,11 @@ if hasattr(bulkloader_config, 'exporters'): for cls in bulkloader_config.exporters: Exporter.RegisterExporter(cls()) + + if hasattr(bulkloader_config, 'mappers'): + for cls in bulkloader_config.mappers: + Mapper.RegisterMapper(cls()) + except NameError, e: m = re.search(r"[^']*'([^']*)'.*", str(e)) if m.groups() and m.group(1) == 'Loader': @@ -4058,9 +3451,12 @@ url=None, kind=None, db_filename=None, + perform_map=None, download=None, has_header=None, - result_db_filename=None): + result_db_filename=None, + dump=None, + restore=None): """Returns a string that identifies the important options for the database.""" if download: result_db_line = 'result_db: %s' % result_db_filename @@ -4071,10 +3467,14 @@ url: %s kind: %s download: %s + map: %s + dump: %s + restore: %s progress_db: %s has_header: %s %s - """ % (app_id, url, kind, download, db_filename, has_header, result_db_line) + """ % (app_id, url, kind, download, perform_map, dump, restore, db_filename, + has_header, result_db_line) def ProcessArguments(arg_dict, @@ -4090,6 +3490,8 @@ """ app_id = GetArgument(arg_dict, 'app_id', die_fn) url = GetArgument(arg_dict, 'url', die_fn) + dump = GetArgument(arg_dict, 'dump', die_fn) + restore = GetArgument(arg_dict, 'restore', die_fn) filename = GetArgument(arg_dict, 'filename', die_fn) batch_size = GetArgument(arg_dict, 'batch_size', die_fn) kind = GetArgument(arg_dict, 'kind', die_fn) @@ -4098,21 +3500,18 @@ result_db_filename = GetArgument(arg_dict, 'result_db_filename', die_fn) download = GetArgument(arg_dict, 'download', die_fn) log_file = GetArgument(arg_dict, 'log_file', die_fn) - - unused_passin = GetArgument(arg_dict, 'passin', die_fn) - unused_email = GetArgument(arg_dict, 'email', die_fn) - unused_debug = GetArgument(arg_dict, 'debug', die_fn) - unused_num_threads = GetArgument(arg_dict, 'num_threads', die_fn) - unused_bandwidth_limit = GetArgument(arg_dict, 'bandwidth_limit', die_fn) - unused_rps_limit = GetArgument(arg_dict, 'rps_limit', die_fn) - unused_http_limit = GetArgument(arg_dict, 'http_limit', die_fn) - unused_auth_domain = GetArgument(arg_dict, 'auth_domain', die_fn) - unused_has_headers = GetArgument(arg_dict, 'has_header', die_fn) - unused_loader_opts = GetArgument(arg_dict, 'loader_opts', die_fn) - unused_exporter_opts = GetArgument(arg_dict, 'exporter_opts', die_fn) + perform_map = GetArgument(arg_dict, 'map', die_fn) errors = [] + if batch_size is None: + if download or perform_map: + arg_dict['batch_size'] = DEFAULT_DOWNLOAD_BATCH_SIZE + else: + arg_dict['batch_size'] = DEFAULT_BATCH_SIZE + elif batch_size <= 0: + errors.append('batch_size must be at least 1') + if db_filename is None: arg_dict['db_filename'] = time.strftime( 'bulkloader-progress-%Y%m%d.%H%M%S.sql3') @@ -4124,37 
+3523,35 @@ if log_file is None: arg_dict['log_file'] = time.strftime('bulkloader-log-%Y%m%d.%H%M%S') - if batch_size <= 0: - errors.append('batch_size must be at least 1') - required = '%s argument required' + if config_file is None and not dump and not restore: + errors.append('One of --config_file, --dump, or --restore is required') + if url is REQUIRED_OPTION: errors.append(required % 'url') - if filename is REQUIRED_OPTION: + if not filename and not perform_map: errors.append(required % 'filename') - if kind is REQUIRED_OPTION: - errors.append(required % 'kind') - - if config_file is REQUIRED_OPTION: - errors.append(required % 'config_file') - - if download: - if result_db_filename is REQUIRED_OPTION: - errors.append(required % 'result_db_filename') + if kind is None: + if download or map: + errors.append('kind argument required for this operation') + elif not dump and not restore: + errors.append( + 'kind argument required unless --dump or --restore is specified') if not app_id: - (unused_scheme, host_port, unused_url_path, - unused_query, unused_fragment) = urlparse.urlsplit(url) - suffix_idx = host_port.find('.appspot.com') - if suffix_idx > -1: - arg_dict['app_id'] = host_port[:suffix_idx] - elif host_port.split(':')[0].endswith('google.com'): - arg_dict['app_id'] = host_port.split('.')[0] - else: - errors.append('app_id argument required for non appspot.com domains') + if url and url is not REQUIRED_OPTION: + (unused_scheme, host_port, unused_url_path, + unused_query, unused_fragment) = urlparse.urlsplit(url) + suffix_idx = host_port.find('.appspot.com') + if suffix_idx > -1: + arg_dict['app_id'] = host_port[:suffix_idx] + elif host_port.split(':')[0].endswith('google.com'): + arg_dict['app_id'] = host_port.split('.')[0] + else: + errors.append('app_id argument required for non appspot.com domains') if errors: print >>sys.stderr, '\n'.join(errors) @@ -4203,50 +3600,68 @@ result_db_filename = arg_dict['result_db_filename'] loader_opts = arg_dict['loader_opts'] exporter_opts = arg_dict['exporter_opts'] + mapper_opts = arg_dict['mapper_opts'] email = arg_dict['email'] passin = arg_dict['passin'] + perform_map = arg_dict['map'] + dump = arg_dict['dump'] + restore = arg_dict['restore'] os.environ['AUTH_DOMAIN'] = auth_domain kind = ParseKind(kind) - check_file(config_file) - if not download: + if not dump and not restore: + check_file(config_file) + + if download and perform_map: + logger.error('--download and --map are mutually exclusive.') + + if download or dump: + check_output_file(filename) + elif not perform_map: check_file(filename) + + if dump: + Exporter.RegisterExporter(DumpExporter(kind, result_db_filename)) + elif restore: + Loader.RegisterLoader(RestoreLoader(kind)) else: - check_output_file(filename) - - LoadConfig(config_file) + LoadConfig(config_file) os.environ['APPLICATION_ID'] = app_id throttle_layout = ThrottleLayout(bandwidth_limit, http_limit, rps_limit) - - throttle = Throttle(layout=throttle_layout) + logger.info('Throttling transfers:') + logger.info('Bandwidth: %s bytes/second', bandwidth_limit) + logger.info('HTTP connections: %s/second', http_limit) + logger.info('Entities inserted/fetched/modified: %s/second', rps_limit) + + throttle = remote_api_throttle.Throttle(layout=throttle_layout) signature = _MakeSignature(app_id=app_id, url=url, kind=kind, db_filename=db_filename, download=download, + perform_map=perform_map, has_header=has_header, - result_db_filename=result_db_filename) + result_db_filename=result_db_filename, + dump=dump, + restore=restore) 
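ProcessArguments above only attempts to derive app_id from --url when a URL was actually supplied. The fragment below is a stand-alone sketch of that derivation, with illustrative host names, for readers who want to check what --app_id they can omit.

import urlparse


def infer_app_id(url):
  """Guess the application id from an upload/download URL, or return None."""
  host_port = urlparse.urlsplit(url)[1]
  suffix_idx = host_port.find('.appspot.com')
  if suffix_idx > -1:
    return host_port[:suffix_idx]
  if host_port.split(':')[0].endswith('google.com'):
    return host_port.split('.')[0]
  return None  # non-appspot.com domains must pass --app_id explicitly

# infer_app_id('http://myapp.appspot.com/remote_api') -> 'myapp'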
max_queue_size = max(DEFAULT_QUEUE_SIZE, 3 * num_threads + 5) if db_filename == 'skip': progress_db = StubProgressDatabase() - elif not download: + elif not download and not perform_map and not dump: progress_db = ProgressDatabase(db_filename, signature) else: progress_db = ExportProgressDatabase(db_filename, signature) - if download: - result_db = ResultDatabase(result_db_filename, signature) - return_code = 1 - if not download: + if not download and not perform_map and not dump: loader = Loader.RegisteredLoader(kind) try: loader.initialize(filename, loader_opts) @@ -4257,12 +3672,10 @@ workitem_generator_factory, throttle, progress_db, - BulkLoaderThread, ProgressTrackerThread, max_queue_size, RequestManager, DataSourceThread, - ReQueue, Queue.Queue) try: return_code = app.Run() @@ -4270,29 +3683,31 @@ logger.info('Authentication Failed') finally: loader.finalize() - else: + elif not perform_map: + result_db = ResultDatabase(result_db_filename, signature) exporter = Exporter.RegisteredExporter(kind) try: exporter.initialize(filename, exporter_opts) - def KeyRangeGeneratorFactory(progress_queue, progress_gen): - return KeyRangeGenerator(kind, progress_queue, progress_gen) + def KeyRangeGeneratorFactory(request_manager, progress_queue, + progress_gen): + return KeyRangeItemGenerator(request_manager, kind, progress_queue, + progress_gen, DownloadItem) def ExportProgressThreadFactory(progress_queue, progress_db): return ExportProgressThread(kind, progress_queue, progress_db, result_db) + app = BulkDownloaderApp(arg_dict, KeyRangeGeneratorFactory, throttle, progress_db, - BulkExporterThread, ExportProgressThreadFactory, 0, RequestManager, DataSourceThread, - ReQueue, Queue.Queue) try: return_code = app.Run() @@ -4300,6 +3715,35 @@ logger.info('Authentication Failed') finally: exporter.finalize() + elif not download: + mapper = Mapper.RegisteredMapper(kind) + try: + mapper.initialize(mapper_opts) + def KeyRangeGeneratorFactory(request_manager, progress_queue, + progress_gen): + return KeyRangeItemGenerator(request_manager, kind, progress_queue, + progress_gen, MapperItem) + + def MapperProgressThreadFactory(progress_queue, progress_db): + return MapperProgressThread(kind, + progress_queue, + progress_db) + + app = BulkMapperApp(arg_dict, + KeyRangeGeneratorFactory, + throttle, + progress_db, + MapperProgressThreadFactory, + 0, + RequestManager, + DataSourceThread, + Queue.Queue) + try: + return_code = app.Run() + except AuthenticationError: + logger.info('Authentication Failed') + finally: + mapper.finalize() return return_code @@ -4335,8 +3779,17 @@ logger.info('Logging to %s', log_file) + remote_api_throttle.logger.setLevel(level) + remote_api_throttle.logger.addHandler(file_handler) + remote_api_throttle.logger.addHandler(console) + appengine_rpc.logger.setLevel(logging.WARN) + adaptive_thread_pool.logger.setLevel(logging.DEBUG) + adaptive_thread_pool.logger.addHandler(console) + adaptive_thread_pool.logger.addHandler(file_handler) + adaptive_thread_pool.logger.propagate = False + def Run(arg_dict): """Sets up and runs the bulkloader, given the options as keyword arguments.