--- a/thirdparty/google_appengine/google/appengine/tools/bulkloader.py Sat Sep 05 14:04:24 2009 +0200
+++ b/thirdparty/google_appengine/google/appengine/tools/bulkloader.py Sun Sep 06 23:31:53 2009 +0200
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
"""Imports data over HTTP.
Usage:
@@ -33,7 +32,7 @@
the URL endpoint. The more data per row/Entity, the
smaller the batch size should be. (Default 10)
--config_file=<path> File containing Model and Loader definitions.
- (Required)
+ (Required unless --dump or --restore is used)
--db_filename=<path> Specific progress database to write to, or to
resume from. If not supplied, then a new database
will be started, named:
@@ -41,6 +40,8 @@
The special filename "skip" may be used to simply
skip reading/writing any progress information.
--download Export entities to a file.
+ --dry_run Do not execute any remote_api calls.
+ --dump Use zero-configuration dump format.
--email=<string> The username to use. Will prompt if omitted.
--exporter_opts=<string>
A string to pass to the Exporter.initialize method.
@@ -54,9 +55,12 @@
--log_file=<path> File to write bulkloader logs. If not supplied
then a new log file will be created, named:
bulkloader-log-TIMESTAMP.
+ --map Map an action across datastore entities.
+ --mapper_opts=<string> A string to pass to the Mapper.initialize method.
--num_threads=<int> Number of threads to use for uploading entities
(Default 10)
--passin Read the login password from stdin.
+ --restore Restore from zero-configuration dump format.
--result_db_filename=<path>
Result database to write to for downloads.
--rps_limit=<int> The maximum number of records per second to
@@ -78,7 +82,6 @@
-import cPickle
import csv
import errno
import getopt
@@ -88,20 +91,31 @@
import os
import Queue
import re
+import shutil
import signal
import StringIO
import sys
import threading
import time
+import traceback
import urllib2
import urlparse
+from google.appengine.datastore import entity_pb
+
+from google.appengine.api import apiproxy_stub_map
+from google.appengine.api import datastore
from google.appengine.api import datastore_errors
+from google.appengine.datastore import datastore_pb
from google.appengine.ext import db
+from google.appengine.ext import key_range as key_range_module
from google.appengine.ext.db import polymodel
from google.appengine.ext.remote_api import remote_api_stub
+from google.appengine.ext.remote_api import throttle as remote_api_throttle
from google.appengine.runtime import apiproxy_errors
+from google.appengine.tools import adaptive_thread_pool
from google.appengine.tools import appengine_rpc
+from google.appengine.tools.requeue import ReQueue
try:
import sqlite3
@@ -110,10 +124,14 @@
logger = logging.getLogger('google.appengine.tools.bulkloader')
+KeyRange = key_range_module.KeyRange
+
DEFAULT_THREAD_COUNT = 10
DEFAULT_BATCH_SIZE = 10
+DEFAULT_DOWNLOAD_BATCH_SIZE = 100
+
DEFAULT_QUEUE_SIZE = DEFAULT_THREAD_COUNT * 10
_THREAD_SHOULD_EXIT = '_THREAD_SHOULD_EXIT'
@@ -125,9 +143,7 @@
STATE_GETTING = 1
STATE_GOT = 2
-STATE_NOT_GOT = 3
-
-MINIMUM_THROTTLE_SLEEP_DURATION = 0.001
+STATE_ERROR = 3
DATA_CONSUMED_TO_HERE = 'DATA_CONSUMED_TO_HERE'
@@ -142,16 +158,8 @@
DEFAULT_REQUEST_LIMIT = 8
-BANDWIDTH_UP = 'http-bandwidth-up'
-BANDWIDTH_DOWN = 'http-bandwidth-down'
-REQUESTS = 'http-requests'
-HTTPS_BANDWIDTH_UP = 'https-bandwidth-up'
-HTTPS_BANDWIDTH_DOWN = 'https-bandwidth-down'
-HTTPS_REQUESTS = 'https-requests'
-RECORDS = 'records'
-
-MAXIMUM_INCREASE_DURATION = 8.0
-MAXIMUM_HOLD_DURATION = 10.0
+MAXIMUM_INCREASE_DURATION = 5.0
+MAXIMUM_HOLD_DURATION = 12.0
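+# For example, with these thresholds a batch that transfers in 4 seconds
+# lets the thread pool add a worker, a 9 second batch holds the current
+# thread count, and anything over 12 seconds asks the pool to shrink
+# (see _WorkItem.PerformWork below).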
def ImportStateMessage(state):
@@ -170,7 +178,17 @@
STATE_READ: 'Batch read from file.',
STATE_GETTING: 'Fetching batch from server',
STATE_GOT: 'Batch successfully fetched.',
- STATE_NOT_GOT: 'Error while fetching batch'
+ STATE_ERROR: 'Error while fetching batch'
+ }[state])
+
+
+def MapStateMessage(state):
+ """Converts a numeric state identifier to a status message."""
+ return ({
+ STATE_READ: 'Batch read from file.',
+ STATE_GETTING: 'Querying for batch from server',
+ STATE_GOT: 'Batch successfully fetched.',
+ STATE_ERROR: 'Error while fetching or mapping.'
}[state])
@@ -180,7 +198,7 @@
STATE_READ: 'READ',
STATE_GETTING: 'GETTING',
STATE_GOT: 'GOT',
- STATE_NOT_GOT: 'NOT_GOT'
+ STATE_ERROR: 'NOT_GOT'
}[state])
@@ -190,7 +208,7 @@
STATE_READ: 'READ',
STATE_GETTING: 'SENDING',
STATE_GOT: 'SENT',
- STATE_NOT_GOT: 'NOT_SENT'
+ STATE_ERROR: 'NOT_SENT'
}[state])
@@ -234,16 +252,35 @@
"""A filename passed in by the user refers to a non-writable output file."""
-class KeyRangeError(Error):
- """Error while trying to generate a KeyRange."""
-
-
class BadStateError(Error):
"""A work item in an unexpected state was encountered."""
+class KeyRangeError(Error):
+ """An error during construction of a KeyRangeItem."""
+
+
+class FieldSizeLimitError(Error):
+ """The csv module tried to read a field larger than the size limit."""
+
+ def __init__(self, limit):
+ self.message = """
+A field in your CSV input file has exceeded the current limit of %d.
+
+You can raise this limit by adding the following lines to your config file:
+
+import csv
+csv.field_size_limit(new_limit)
+
+where new_limit is a number larger than the size in bytes of the largest
+field in your CSV.
+""" % limit
+ Error.__init__(self, self.message)
+
+
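+# A config file can also raise the csv limit pre-emptively, before any rows
+# are read. A minimal sketch, where the 1 MB ceiling is illustrative rather
+# than a recommendation:
+#
+#   import csv
+#   csv.field_size_limit(1 << 20)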
class NameClashError(Error):
"""A name clash occurred while trying to alias old method names."""
+
def __init__(self, old_name, new_name, klass):
Error.__init__(self, old_name, new_name, klass)
self.old_name = old_name
@@ -253,48 +290,51 @@
def GetCSVGeneratorFactory(kind, csv_filename, batch_size, csv_has_header,
openfile=open, create_csv_reader=csv.reader):
- """Return a factory that creates a CSV-based WorkItem generator.
+ """Return a factory that creates a CSV-based UploadWorkItem generator.
Args:
kind: The kind of the entities being uploaded.
csv_filename: File on disk containing CSV data.
- batch_size: Maximum number of CSV rows to stash into a WorkItem.
+ batch_size: Maximum number of CSV rows to stash into an UploadWorkItem.
csv_has_header: Whether to skip the first row of the CSV.
openfile: Used for dependency injection.
create_csv_reader: Used for dependency injection.
Returns:
- A callable (accepting the Progress Queue and Progress Generators
- as input) which creates the WorkItem generator.
+ A callable (accepting the Request Manager, Progress Queue, and
+ Progress Generator as input) which creates the UploadWorkItem generator.
"""
loader = Loader.RegisteredLoader(kind)
loader._Loader__openfile = openfile
loader._Loader__create_csv_reader = create_csv_reader
record_generator = loader.generate_records(csv_filename)
- def CreateGenerator(progress_queue, progress_generator):
- """Initialize a WorkItem generator linked to a progress generator and queue.
+ def CreateGenerator(request_manager, progress_queue, progress_generator):
+ """Initialize a UploadWorkItem generator.
Args:
+ request_manager: A RequestManager instance.
progress_queue: A ProgressQueue instance to send progress information.
progress_generator: A generator of progress information or None.
Returns:
- A WorkItemGenerator instance.
+ An UploadWorkItemGenerator instance.
"""
- return WorkItemGenerator(progress_queue,
- progress_generator,
- record_generator,
- csv_has_header,
- batch_size)
+ return UploadWorkItemGenerator(request_manager,
+ progress_queue,
+ progress_generator,
+ record_generator,
+ csv_has_header,
+ batch_size)
return CreateGenerator
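+# A hedged usage sketch: 'Person' and 'people.csv' are illustrative, and the
+# request_manager and progress_queue come from the surrounding bulkload run:
+#
+#   factory = GetCSVGeneratorFactory('Person', 'people.csv',
+#                                    batch_size=10, csv_has_header=True)
+#   generator = factory(request_manager, progress_queue, None)
+#   for work_item in generator.Batches():
+#     ...  # each work_item is an UploadWorkItem of up to batch_size rows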
-class WorkItemGenerator(object):
- """Reads rows from a row generator and generates WorkItems of batches."""
+class UploadWorkItemGenerator(object):
+ """Reads rows from a row generator and generates UploadWorkItems."""
def __init__(self,
+ request_manager,
progress_queue,
progress_generator,
record_generator,
@@ -303,12 +343,15 @@
"""Initialize a WorkItemGenerator.
Args:
+ request_manager: A RequestManager instance with which to associate
+ WorkItems.
progress_queue: A progress queue with which to associate WorkItems.
progress_generator: A generator of progress information.
record_generator: A generator of data records.
skip_first: Whether to skip the first data record.
batch_size: The number of data records per WorkItem.
"""
+ self.request_manager = request_manager
self.progress_queue = progress_queue
self.progress_generator = progress_generator
self.reader = record_generator
@@ -360,30 +403,29 @@
self.line_number += 1
def _MakeItem(self, key_start, key_end, rows, progress_key=None):
- """Makes a WorkItem containing the given rows, with the given keys.
+ """Makes a UploadWorkItem containing the given rows, with the given keys.
Args:
- key_start: The start key for the WorkItem.
- key_end: The end key for the WorkItem.
- rows: A list of the rows for the WorkItem.
- progress_key: The progress key for the WorkItem
+ key_start: The start key for the UploadWorkItem.
+ key_end: The end key for the UploadWorkItem.
+ rows: A list of the rows for the UploadWorkItem.
+ progress_key: The progress key for the UploadWorkItem.
Returns:
- A WorkItem instance for the given batch.
+ An UploadWorkItem instance for the given batch.
"""
assert rows
- item = WorkItem(self.progress_queue, rows,
- key_start, key_end,
- progress_key=progress_key)
+ item = UploadWorkItem(self.request_manager, self.progress_queue, rows,
+ key_start, key_end, progress_key=progress_key)
return item
def Batches(self):
- """Reads from the record_generator and generates WorkItems.
+ """Reads from the record_generator and generates UploadWorkItems.
Yields:
- Instances of class WorkItem
+ Instances of class UploadWorkItem
Raises:
ResumeError: If the progress database and data file indicate a different
@@ -468,37 +510,50 @@
"""
csv_file = self.openfile(self.csv_filename, 'rb')
reader = self.create_csv_reader(csv_file, skipinitialspace=True)
- return reader
-
-
-class KeyRangeGenerator(object):
+ try:
+ for record in reader:
+ yield record
+ except csv.Error, e:
+ if e.args and e.args[0].startswith('field larger than field limit'):
+ limit = e.args[1]
+ raise FieldSizeLimitError(limit)
+ else:
+ raise
+
+
+class KeyRangeItemGenerator(object):
"""Generates ranges of keys to download.
Reads progress information from the progress database and creates
- KeyRange objects corresponding to incompletely downloaded parts of an
+ KeyRangeItem objects corresponding to incompletely downloaded parts of an
export.
"""
- def __init__(self, kind, progress_queue, progress_generator):
- """Initialize the KeyRangeGenerator.
+ def __init__(self, request_manager, kind, progress_queue, progress_generator,
+ key_range_item_factory):
+ """Initialize the KeyRangeItemGenerator.
Args:
+ request_manager: A RequestManager instance.
kind: The kind of entities being transferred.
progress_queue: A queue used for tracking progress information.
progress_generator: A generator of prior progress information, or None
if there is no prior status.
+ key_range_item_factory: A factory to produce KeyRangeItems.
"""
+ self.request_manager = request_manager
self.kind = kind
self.row_count = 0
self.xfer_count = 0
self.progress_queue = progress_queue
self.progress_generator = progress_generator
+ self.key_range_item_factory = key_range_item_factory
def Batches(self):
"""Iterate through saved progress information.
Yields:
- KeyRange instances corresponding to undownloaded key ranges.
+ KeyRangeItem instances corresponding to undownloaded key ranges.
"""
if self.progress_generator is not None:
for progress_key, state, key_start, key_end in self.progress_generator:
@@ -506,397 +561,27 @@
key_start = ParseKey(key_start)
key_end = ParseKey(key_end)
- result = KeyRange(self.progress_queue,
- self.kind,
- key_start=key_start,
- key_end=key_end,
- progress_key=progress_key,
- direction=KeyRange.ASC,
- state=STATE_READ)
+ key_range = KeyRange(key_start=key_start,
+ key_end=key_end)
+
+ result = self.key_range_item_factory(self.request_manager,
+ self.progress_queue,
+ self.kind,
+ key_range,
+ progress_key=progress_key,
+ state=STATE_READ)
yield result
else:
-
- yield KeyRange(
- self.progress_queue, self.kind,
- key_start=None,
- key_end=None,
- direction=KeyRange.DESC)
-
-
-class ReQueue(object):
- """A special thread-safe queue.
-
- A ReQueue allows unfinished work items to be returned with a call to
- reput(). When an item is reput, task_done() should *not* be called
- in addition, getting an item that has been reput does not increase
- the number of outstanding tasks.
-
- This class shares an interface with Queue.Queue and provides the
- additional reput method.
- """
-
- def __init__(self,
- queue_capacity,
- requeue_capacity=None,
- queue_factory=Queue.Queue,
- get_time=time.time):
- """Initialize a ReQueue instance.
-
- Args:
- queue_capacity: The number of items that can be put in the ReQueue.
- requeue_capacity: The numer of items that can be reput in the ReQueue.
- queue_factory: Used for dependency injection.
- get_time: Used for dependency injection.
- """
- if requeue_capacity is None:
- requeue_capacity = queue_capacity
-
- self.get_time = get_time
- self.queue = queue_factory(queue_capacity)
- self.requeue = queue_factory(requeue_capacity)
- self.lock = threading.Lock()
- self.put_cond = threading.Condition(self.lock)
- self.get_cond = threading.Condition(self.lock)
-
- def _DoWithTimeout(self,
- action,
- exc,
- wait_cond,
- done_cond,
- lock,
- timeout=None,
- block=True):
- """Performs the given action with a timeout.
-
- The action must be non-blocking, and raise an instance of exc on a
- recoverable failure. If the action fails with an instance of exc,
- we wait on wait_cond before trying again. Failure after the
- timeout is reached is propagated as an exception. Success is
- signalled by notifying on done_cond and returning the result of
- the action. If action raises any exception besides an instance of
- exc, it is immediately propagated.
-
- Args:
- action: A callable that performs a non-blocking action.
- exc: An exception type that is thrown by the action to indicate
- a recoverable error.
- wait_cond: A condition variable which should be waited on when
- action throws exc.
- done_cond: A condition variable to signal if the action returns.
- lock: The lock used by wait_cond and done_cond.
- timeout: A non-negative float indicating the maximum time to wait.
- block: Whether to block if the action cannot complete immediately.
-
- Returns:
- The result of the action, if it is successful.
-
- Raises:
- ValueError: If the timeout argument is negative.
- """
- if timeout is not None and timeout < 0.0:
- raise ValueError('\'timeout\' must not be a negative number')
- if not block:
- timeout = 0.0
- result = None
- success = False
- start_time = self.get_time()
- lock.acquire()
- try:
- while not success:
- try:
- result = action()
- success = True
- except Exception, e:
- if not isinstance(e, exc):
- raise e
- if timeout is not None:
- elapsed_time = self.get_time() - start_time
- timeout -= elapsed_time
- if timeout <= 0.0:
- raise e
- wait_cond.wait(timeout)
- finally:
- if success:
- done_cond.notify()
- lock.release()
- return result
-
- def put(self, item, block=True, timeout=None):
- """Put an item into the requeue.
-
- Args:
- item: An item to add to the requeue.
- block: Whether to block if the requeue is full.
- timeout: Maximum on how long to wait until the queue is non-full.
-
- Raises:
- Queue.Full if the queue is full and the timeout expires.
- """
- def PutAction():
- self.queue.put(item, block=False)
- self._DoWithTimeout(PutAction,
- Queue.Full,
- self.get_cond,
- self.put_cond,
- self.lock,
- timeout=timeout,
- block=block)
-
- def reput(self, item, block=True, timeout=None):
- """Re-put an item back into the requeue.
-
- Re-putting an item does not increase the number of outstanding
- tasks, so the reput item should be uniquely associated with an
- item that was previously removed from the requeue and for which
- TaskDone has not been called.
-
- Args:
- item: An item to add to the requeue.
- block: Whether to block if the requeue is full.
- timeout: Maximum on how long to wait until the queue is non-full.
-
- Raises:
- Queue.Full is the queue is full and the timeout expires.
- """
- def ReputAction():
- self.requeue.put(item, block=False)
- self._DoWithTimeout(ReputAction,
- Queue.Full,
- self.get_cond,
- self.put_cond,
- self.lock,
- timeout=timeout,
- block=block)
-
- def get(self, block=True, timeout=None):
- """Get an item from the requeue.
-
- Args:
- block: Whether to block if the requeue is empty.
- timeout: Maximum on how long to wait until the requeue is non-empty.
-
- Returns:
- An item from the requeue.
-
- Raises:
- Queue.Empty if the queue is empty and the timeout expires.
- """
- def GetAction():
- try:
- result = self.requeue.get(block=False)
- self.requeue.task_done()
- except Queue.Empty:
- result = self.queue.get(block=False)
- return result
- return self._DoWithTimeout(GetAction,
- Queue.Empty,
- self.put_cond,
- self.get_cond,
- self.lock,
- timeout=timeout,
- block=block)
-
- def join(self):
- """Blocks until all of the items in the requeue have been processed."""
- self.queue.join()
-
- def task_done(self):
- """Indicate that a previously enqueued item has been fully processed."""
- self.queue.task_done()
-
- def empty(self):
- """Returns true if the requeue is empty."""
- return self.queue.empty() and self.requeue.empty()
-
- def get_nowait(self):
- """Try to get an item from the queue without blocking."""
- return self.get(block=False)
-
- def qsize(self):
- return self.queue.qsize() + self.requeue.qsize()
-
-
-class ThrottleHandler(urllib2.BaseHandler):
- """A urllib2 handler for http and https requests that adds to a throttle."""
-
- def __init__(self, throttle):
- """Initialize a ThrottleHandler.
-
- Args:
- throttle: A Throttle instance to call for bandwidth and http/https request
- throttling.
- """
- self.throttle = throttle
-
- def AddRequest(self, throttle_name, req):
- """Add to bandwidth throttle for given request.
-
- Args:
- throttle_name: The name of the bandwidth throttle to add to.
- req: The request whose size will be added to the throttle.
- """
- size = 0
- for key, value in req.headers.iteritems():
- size += len('%s: %s\n' % (key, value))
- for key, value in req.unredirected_hdrs.iteritems():
- size += len('%s: %s\n' % (key, value))
- (unused_scheme,
- unused_host_port, url_path,
- unused_query, unused_fragment) = urlparse.urlsplit(req.get_full_url())
- size += len('%s %s HTTP/1.1\n' % (req.get_method(), url_path))
- data = req.get_data()
- if data:
- size += len(data)
- self.throttle.AddTransfer(throttle_name, size)
-
- def AddResponse(self, throttle_name, res):
- """Add to bandwidth throttle for given response.
-
- Args:
- throttle_name: The name of the bandwidth throttle to add to.
- res: The response whose size will be added to the throttle.
- """
- content = res.read()
- def ReturnContent():
- return content
- res.read = ReturnContent
- size = len(content)
- headers = res.info()
- for key, value in headers.items():
- size += len('%s: %s\n' % (key, value))
- self.throttle.AddTransfer(throttle_name, size)
-
- def http_request(self, req):
- """Process an HTTP request.
-
- If the throttle is over quota, sleep first. Then add request size to
- throttle before returning it to be sent.
-
- Args:
- req: A urllib2.Request object.
-
- Returns:
- The request passed in.
- """
- self.throttle.Sleep()
- self.AddRequest(BANDWIDTH_UP, req)
- return req
-
- def https_request(self, req):
- """Process an HTTPS request.
-
- If the throttle is over quota, sleep first. Then add request size to
- throttle before returning it to be sent.
-
- Args:
- req: A urllib2.Request object.
-
- Returns:
- The request passed in.
- """
- self.throttle.Sleep()
- self.AddRequest(HTTPS_BANDWIDTH_UP, req)
- return req
-
- def http_response(self, unused_req, res):
- """Process an HTTP response.
-
- The size of the response is added to the bandwidth throttle and the request
- throttle is incremented by one.
-
- Args:
- unused_req: The urllib2 request for this response.
- res: A urllib2 response object.
-
- Returns:
- The response passed in.
- """
- self.AddResponse(BANDWIDTH_DOWN, res)
- self.throttle.AddTransfer(REQUESTS, 1)
- return res
-
- def https_response(self, unused_req, res):
- """Process an HTTPS response.
-
- The size of the response is added to the bandwidth throttle and the request
- throttle is incremented by one.
-
- Args:
- unused_req: The urllib2 request for this response.
- res: A urllib2 response object.
-
- Returns:
- The response passed in.
- """
- self.AddResponse(HTTPS_BANDWIDTH_DOWN, res)
- self.throttle.AddTransfer(HTTPS_REQUESTS, 1)
- return res
-
-
-class ThrottledHttpRpcServer(appengine_rpc.HttpRpcServer):
- """Provides a simplified RPC-style interface for HTTP requests.
-
- This RPC server uses a Throttle to prevent exceeding quotas.
- """
-
- def __init__(self, throttle, request_manager, *args, **kwargs):
- """Initialize a ThrottledHttpRpcServer.
-
- Also sets request_manager.rpc_server to the ThrottledHttpRpcServer instance.
-
- Args:
- throttle: A Throttles instance.
- request_manager: A RequestManager instance.
- args: Positional arguments to pass through to
- appengine_rpc.HttpRpcServer.__init__
- kwargs: Keyword arguments to pass through to
- appengine_rpc.HttpRpcServer.__init__
- """
- self.throttle = throttle
- appengine_rpc.HttpRpcServer.__init__(self, *args, **kwargs)
- request_manager.rpc_server = self
-
- def _GetOpener(self):
- """Returns an OpenerDirector that supports cookies and ignores redirects.
-
- Returns:
- A urllib2.OpenerDirector object.
- """
- opener = appengine_rpc.HttpRpcServer._GetOpener(self)
- opener.add_handler(ThrottleHandler(self.throttle))
-
- return opener
-
-
-def ThrottledHttpRpcServerFactory(throttle, request_manager):
- """Create a factory to produce ThrottledHttpRpcServer for a given throttle.
-
- Args:
- throttle: A Throttle instance to use for the ThrottledHttpRpcServer.
- request_manager: A RequestManager instance.
-
- Returns:
- A factory to produce a ThrottledHttpRpcServer.
- """
-
- def MakeRpcServer(*args, **kwargs):
- """Factory to produce a ThrottledHttpRpcServer.
-
- Args:
- args: Positional args to pass to ThrottledHttpRpcServer.
- kwargs: Keyword args to pass to ThrottledHttpRpcServer.
-
- Returns:
- A ThrottledHttpRpcServer instance.
- """
- kwargs['account_type'] = 'HOSTED_OR_GOOGLE'
- kwargs['save_cookies'] = True
- return ThrottledHttpRpcServer(throttle, request_manager, *args, **kwargs)
- return MakeRpcServer
-
-
-class ExportResult(object):
- """Holds the decoded content for the result of an export requests."""
+ key_range = KeyRange()
+
+ yield self.key_range_item_factory(self.request_manager,
+ self.progress_queue,
+ self.kind,
+ key_range)
+
+
+class DownloadResult(object):
+ """Holds the result of an entity download."""
def __init__(self, continued, direction, keys, entities):
self.continued = continued
@@ -905,21 +590,31 @@
self.entities = entities
self.count = len(keys)
assert self.count == len(entities)
- assert direction in (KeyRange.ASC, KeyRange.DESC)
+ assert direction in (key_range_module.KeyRange.ASC,
+ key_range_module.KeyRange.DESC)
if self.count > 0:
- if direction == KeyRange.ASC:
+ if direction == key_range_module.KeyRange.ASC:
self.key_start = keys[0]
self.key_end = keys[-1]
else:
self.key_start = keys[-1]
self.key_end = keys[0]
+ def Entities(self):
+ """Returns the list of entities for this result in key order."""
+ if self.direction == key_range_module.KeyRange.ASC:
+ return list(self.entities)
+ else:
+ result = list(self.entities)
+ result.reverse()
+ return result
+
def __str__(self):
return 'continued = %s\n%s' % (
- str(self.continued), '\n'.join(self.entities))
-
-
-class _WorkItem(object):
+ str(self.continued), '\n'.join(str(entity) for entity in self.entities))
+
+
+class _WorkItem(adaptive_thread_pool.WorkItem):
"""Holds a description of a unit of upload or download work."""
def __init__(self, progress_queue, key_start, key_end, state_namer,
@@ -928,20 +623,101 @@
Args:
progress_queue: A queue used for tracking progress information.
- key_start: The starting key, inclusive.
- key_end: The ending key, inclusive.
+ key_start: The start key of the work item.
+ key_end: The end key of the work item.
state_namer: Function to describe work item states.
state: The initial state of the work item.
progress_key: If this WorkItem represents state from a prior run,
then this will be the key within the progress database.
"""
+ adaptive_thread_pool.WorkItem.__init__(self,
+ '[%s-%s]' % (key_start, key_end))
self.progress_queue = progress_queue
- self.key_start = key_start
- self.key_end = key_end
self.state_namer = state_namer
self.state = state
self.progress_key = progress_key
self.progress_event = threading.Event()
+ self.key_start = key_start
+ self.key_end = key_end
+ self.error = None
+ self.traceback = None
+
+ def _TransferItem(self, thread_pool):
+ raise NotImplementedError()
+
+ def SetError(self):
+ """Sets the error and traceback information for this thread.
+
+ This must be called from an exception handler.
+ """
+ if not self.error:
+ exc_info = sys.exc_info()
+ self.error = exc_info[1]
+ self.traceback = exc_info[2]
+
+ def PerformWork(self, thread_pool):
+ """Perform the work of this work item and report the results.
+
+ Args:
+ thread_pool: An AdaptiveThreadPool instance.
+
+ Returns:
+ A tuple (status, instruction) of the work status and an instruction
+ for the ThreadGate.
+ """
+ status = adaptive_thread_pool.WorkItem.FAILURE
+ instruction = adaptive_thread_pool.ThreadGate.DECREASE
+
+ try:
+ self.MarkAsTransferring()
+
+ try:
+ transfer_time = self._TransferItem(thread_pool)
+ if transfer_time is None:
+ status = adaptive_thread_pool.WorkItem.RETRY
+ instruction = adaptive_thread_pool.ThreadGate.HOLD
+ else:
+ logger.debug('[%s] %s Transferred %d entities in %0.1f seconds',
+ threading.currentThread().getName(), self, self.count,
+ transfer_time)
+ sys.stdout.write('.')
+ sys.stdout.flush()
+ status = adaptive_thread_pool.WorkItem.SUCCESS
+ if transfer_time <= MAXIMUM_INCREASE_DURATION:
+ instruction = adaptive_thread_pool.ThreadGate.INCREASE
+ elif transfer_time <= MAXIMUM_HOLD_DURATION:
+ instruction = adaptive_thread_pool.ThreadGate.HOLD
+ except (db.InternalError, db.NotSavedError, db.Timeout,
+ db.TransactionFailedError,
+ apiproxy_errors.OverQuotaError,
+ apiproxy_errors.DeadlineExceededError,
+ apiproxy_errors.ApplicationError), e:
+ status = adaptive_thread_pool.WorkItem.RETRY
+ logger.exception('Retrying on non-fatal datastore error: %s', e)
+ except urllib2.HTTPError, e:
+ http_status = e.code
+ if http_status == 403 or (http_status >= 500 and http_status < 600):
+ status = adaptive_thread_pool.WorkItem.RETRY
+ logger.exception('Retrying on non-fatal HTTP error: %d %s',
+ http_status, e.msg)
+ else:
+ self.SetError()
+ status = adaptive_thread_pool.WorkItem.FAILURE
+ except urllib2.URLError, e:
+ if IsURLErrorFatal(e):
+ self.SetError()
+ status = adaptive_thread_pool.WorkItem.FAILURE
+ else:
+ status = adaptive_thread_pool.WorkItem.RETRY
+ logger.exception('Retrying on non-fatal URL error: %s', e.reason)
+
+ finally:
+ if status == adaptive_thread_pool.WorkItem.SUCCESS:
+ self.MarkAsTransferred()
+ else:
+ self.MarkAsError()
+
+ return (status, instruction)
def _AssertInState(self, *states):
"""Raises an Error if the state of this range is not in states."""
@@ -963,7 +739,7 @@
def MarkAsTransferring(self):
"""Mark this _WorkItem as transferring, updating the progress database."""
- self._AssertInState(STATE_READ, STATE_NOT_GOT)
+ self._AssertInState(STATE_READ, STATE_ERROR)
self._AssertProgressKey()
self._StateTransition(STATE_GETTING, blocking=True)
@@ -975,7 +751,7 @@
"""Mark this _WorkItem as failed, updating the progress database."""
self._AssertInState(STATE_GETTING)
self._AssertProgressKey()
- self._StateTransition(STATE_NOT_GOT, blocking=True)
+ self._StateTransition(STATE_ERROR, blocking=True)
def _StateTransition(self, new_state, blocking=False):
"""Transition the work item to a new state, storing progress information.
@@ -998,12 +774,12 @@
-class WorkItem(_WorkItem):
+class UploadWorkItem(_WorkItem):
"""Holds a unit of uploading work.
- A WorkItem represents a number of entities that need to be uploaded to
+ An UploadWorkItem represents a number of entities that need to be uploaded to
Google App Engine. These entities are encoded in the "content" field of
- the WorkItem, and will be POST'd as-is to the server.
+ the UploadWorkItem, and will be POST'd as-is to the server.
The entities are identified by a range of numeric keys, inclusively. In
the case of a resumption of an upload, or a replay to correct errors,
@@ -1013,16 +789,17 @@
fill the entire range, they must simply bound a range of valid keys.
"""
- def __init__(self, progress_queue, rows, key_start, key_end,
+ def __init__(self, request_manager, progress_queue, rows, key_start, key_end,
progress_key=None):
- """Initialize the WorkItem instance.
+ """Initialize the UploadWorkItem instance.
Args:
+ request_manager: A RequestManager instance.
progress_queue: A queue used for tracking progress information.
rows: A list of pairs of a line number and a list of column values
key_start: The (numeric) starting key, inclusive.
key_end: The (numeric) ending key, inclusive.
- progress_key: If this WorkItem represents state from a prior run,
+ progress_key: If this UploadWorkItem represents state from a prior run,
then this will be the key within the progress database.
"""
_WorkItem.__init__(self, progress_queue, key_start, key_end,
@@ -1033,6 +810,7 @@
assert isinstance(key_end, (int, long))
assert key_start <= key_end
+ self.request_manager = request_manager
self.rows = rows
self.content = None
self.count = len(rows)
@@ -1040,8 +818,24 @@
def __str__(self):
return '[%s-%s]' % (self.key_start, self.key_end)
+ def _TransferItem(self, thread_pool, get_time=time.time):
+ """Transfers the entities associated with an item.
+
+ Args:
+ thread_pool: An AdaptiveThreadPool instance.
+ get_time: Used for dependency injection.
+ """
+ t = get_time()
+ if not self.content:
+ self.content = self.request_manager.EncodeContent(self.rows)
+ self.request_manager.PostEntities(self.content)
+ return get_time() - t
+
def MarkAsTransferred(self):
- """Mark this WorkItem as sucessfully-sent to the server."""
+ """Mark this UploadWorkItem as sucessfully-sent to the server."""
self._AssertInState(STATE_SENDING)
self._AssertProgressKey()
@@ -1068,45 +862,31 @@
implementation_class = db.class_for_kind(kind_or_class_key)
return implementation_class
-class EmptyQuery(db.Query):
- def get(self):
- return None
-
- def fetch(self, limit=1000, offset=0):
- return []
-
- def count(self, limit=1000):
- return 0
-
def KeyLEQ(key1, key2):
"""Compare two keys for less-than-or-equal-to.
- All keys with numeric ids come before all keys with names.
+ All keys with numeric ids come before all keys with names. None represents
+ an unbounded endpoint and compares as both less than and greater than any
+ other key.
Args:
- key1: An int or db.Key instance.
- key2: An int or db.Key instance.
+ key1: An int or datastore.Key instance.
+ key2: An int or datastore.Key instance.
Returns:
True if key1 <= key2
"""
- if isinstance(key1, int) and isinstance(key2, int):
- return key1 <= key2
if key1 is None or key2 is None:
return True
- if key1.id() and not key2.id():
- return True
- return key1.id_or_name() <= key2.id_or_name()
-
-
-class KeyRange(_WorkItem):
- """Represents an item of download work.
-
- A KeyRange object represents a key range (key_start, key_end) and a
- scan direction (KeyRange.DESC or KeyRange.ASC). The KeyRange object
- has an associated state: STATE_READ, STATE_GETTING, STATE_GOT, and
- STATE_ERROR.
+ return key1 <= key2
+
+
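+# Note that KeyLEQ(None, k) and KeyLEQ(k, None) are both True for any k:
+# None acts as an unbounded endpoint on either side of a range.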
+class KeyRangeItem(_WorkItem):
+ """Represents an item of work that scans over a key range.
+
+ A KeyRangeItem object holds a KeyRange and has an associated state:
+ STATE_READ, STATE_GETTING, STATE_GOT, and STATE_ERROR.
- STATE_READ indicates the range ready to be downloaded by a worker thread.
- STATE_GETTING indicates the range is currently being downloaded.
@@ -1114,280 +894,143 @@
- STATE_ERROR indicates that an error occurred during the last download
attempt
- KeyRanges not in the STATE_GOT state are stored in the progress database.
- When a piece of KeyRange work is downloaded, the download may cover only
- a portion of the range. In this case, the old KeyRange is removed from
+ KeyRangeItems not in the STATE_GOT state are stored in the progress database.
+ When a piece of KeyRangeItem work is downloaded, the download may cover only
+ a portion of the range. In this case, the old KeyRangeItem is removed from
the progress database and ranges covering the undownloaded range are
generated and stored as STATE_READ in the export progress database.
"""
- DESC = 0
- ASC = 1
-
- MAX_KEY_LEN = 500
-
def __init__(self,
+ request_manager,
progress_queue,
kind,
- direction,
- key_start=None,
- key_end=None,
- include_start=True,
- include_end=True,
+ key_range,
progress_key=None,
state=STATE_READ):
- """Initialize a KeyRange object.
+ """Initialize a KeyRangeItem object.
Args:
+ request_manager: A RequestManager instance.
progress_queue: A queue used for tracking progress information.
kind: The kind of entities for this range.
- direction: The direction of the query for this range.
- key_start: The starting key for this range.
- key_end: The ending key for this range.
- include_start: Whether the start key should be included in the range.
- include_end: Whether the end key should be included in the range.
+ key_range: A KeyRange instance for this work item.
progress_key: The key for this range within the progress database.
state: The initial state of this range.
-
- Raises:
- KeyRangeError: if key_start is None.
"""
- assert direction in (KeyRange.ASC, KeyRange.DESC)
- _WorkItem.__init__(self, progress_queue, key_start, key_end,
- ExportStateName, state=state, progress_key=progress_key)
+ _WorkItem.__init__(self, progress_queue, key_range.key_start,
+ key_range.key_end, ExportStateName, state=state,
+ progress_key=progress_key)
+ self.request_manager = request_manager
self.kind = kind
- self.direction = direction
- self.export_result = None
+ self.key_range = key_range
+ self.download_result = None
self.count = 0
- self.include_start = include_start
- self.include_end = include_end
- self.SPLIT_KEY = db.Key.from_path(self.kind, unichr(0))
+ self.key_start = key_range.key_start
+ self.key_end = key_range.key_end
def __str__(self):
- return '[%s-%s]' % (PrettyKey(self.key_start), PrettyKey(self.key_end))
+ return str(self.key_range)
def __repr__(self):
return self.__str__()
def MarkAsTransferred(self):
- """Mark this KeyRange as transferred, updating the progress database."""
+ """Mark this KeyRangeItem as transferred, updating the progress database."""
pass
- def Process(self, export_result, num_threads, batch_size, work_queue):
- """Mark this KeyRange as success, updating the progress database.
-
- Process will split this KeyRange based on the content of export_result and
- adds the unfinished ranges to the work queue.
+ def Process(self, download_result, thread_pool, batch_size,
+ new_state=STATE_GOT):
+ """Mark this KeyRangeItem as success, updating the progress database.
+
+ Process will split this KeyRangeItem based on the content of
+ download_result and adds the unfinished ranges to the work queue.
Args:
- export_result: An ExportResult instance.
- num_threads: The number of threads for parallel transfers.
+ download_result: A DownloadResult instance.
+ thread_pool: An AdaptiveThreadPool instance.
batch_size: The number of entities to transfer per request.
- work_queue: The work queue to add unfinished ranges to.
-
- Returns:
- A list of KeyRanges representing undownloaded datastore key ranges.
+ new_state: The state to transition the completed range to.
"""
self._AssertInState(STATE_GETTING)
self._AssertProgressKey()
- self.export_result = export_result
- self.count = len(export_result.keys)
- if export_result.continued:
- self._FinishedRange()._StateTransition(STATE_GOT, blocking=True)
- self._AddUnfinishedRanges(num_threads, batch_size, work_queue)
+ self.download_result = download_result
+ self.count = len(download_result.keys)
+ if download_result.continued:
+ self._FinishedRange()._StateTransition(new_state, blocking=True)
+ self._AddUnfinishedRanges(thread_pool, batch_size)
else:
- self._StateTransition(STATE_GOT, blocking=True)
+ self._StateTransition(new_state, blocking=True)
def _FinishedRange(self):
- """Returns the range completed by the export_result.
-
- Returns:
- A KeyRange representing a completed range.
- """
- assert self.export_result is not None
-
- if self.direction == KeyRange.ASC:
- key_start = self.key_start
- if self.export_result.continued:
- key_end = self.export_result.key_end
- else:
- key_end = self.key_end
- else:
- key_end = self.key_end
- if self.export_result.continued:
- key_start = self.export_result.key_start
- else:
- key_start = self.key_start
-
- result = KeyRange(self.progress_queue,
- self.kind,
- key_start=key_start,
- key_end=key_end,
- direction=self.direction)
-
- result.progress_key = self.progress_key
- result.export_result = self.export_result
- result.state = self.state
- result.count = self.count
- return result
-
- def FilterQuery(self, query):
- """Add query filter to restrict to this key range.
-
- Args:
- query: A db.Query instance.
- """
- if self.key_start == self.key_end and not (
- self.include_start or self.include_end):
- return EmptyQuery()
- if self.include_start:
- start_comparator = '>='
- else:
- start_comparator = '>'
- if self.include_end:
- end_comparator = '<='
- else:
- end_comparator = '<'
- if self.key_start and self.key_end:
- query.filter('__key__ %s' % start_comparator, self.key_start)
- query.filter('__key__ %s' % end_comparator, self.key_end)
- elif self.key_start:
- query.filter('__key__ %s' % start_comparator, self.key_start)
- elif self.key_end:
- query.filter('__key__ %s' % end_comparator, self.key_end)
-
- return query
-
- def MakeParallelQuery(self):
- """Construct a query for this key range, for parallel downloading.
-
- Returns:
- A db.Query instance.
-
- Raises:
- KeyRangeError: if self.direction is not one of
- KeyRange.ASC, KeyRange.DESC
- """
- if self.direction == KeyRange.ASC:
- direction = ''
- elif self.direction == KeyRange.DESC:
- direction = '-'
- else:
- raise KeyRangeError('KeyRange direction unexpected: %s', self.direction)
- query = db.Query(GetImplementationClass(self.kind))
- query.order('%s__key__' % direction)
-
- return self.FilterQuery(query)
-
- def MakeSerialQuery(self):
- """Construct a query for this key range without descending __key__ scan.
+ """Returns the range completed by the download_result.
Returns:
- A db.Query instance.
+ A KeyRangeItem representing a completed range.
"""
- query = db.Query(GetImplementationClass(self.kind))
- query.order('__key__')
-
- return self.FilterQuery(query)
-
- def _BisectStringRange(self, start, end):
- if start == end:
- return (start, start, end)
- start += '\0'
- end += '\0'
- midpoint = []
- expected_max = 127
- for i in xrange(min(len(start), len(end))):
- if start[i] == end[i]:
- midpoint.append(start[i])
+ assert self.download_result is not None
+
+ if self.key_range.direction == key_range_module.KeyRange.ASC:
+ key_start = self.key_range.key_start
+ if self.download_result.continued:
+ key_end = self.download_result.key_end
else:
- ord_sum = ord(start[i]) + ord(end[i])
- midpoint.append(unichr(ord_sum / 2))
- if ord_sum % 2:
- if len(start) > i + 1:
- ord_start = ord(start[i+1])
- else:
- ord_start = 0
- if ord_start < expected_max:
- ord_split = (expected_max + ord_start) / 2
- else:
- ord_split = (0xFFFF + ord_start) / 2
- midpoint.append(unichr(ord_split))
- break
- return (start[:-1], ''.join(midpoint), end[:-1])
-
- def SplitRange(self, key_start, include_start, key_end, include_end,
- export_result, num_threads, batch_size, work_queue):
- """Split the key range [key_start, key_end] into a list of ranges."""
- if export_result.direction == KeyRange.ASC:
- key_start = export_result.key_end
- include_start = False
+ key_end = self.key_range.key_end
else:
- key_end = export_result.key_start
- include_end = False
- key_pairs = []
- if not key_start:
- key_pairs.append((key_start, include_start, key_end, include_end,
- KeyRange.ASC))
- elif not key_end:
- key_pairs.append((key_start, include_start, key_end, include_end,
- KeyRange.DESC))
- elif work_queue.qsize() > 2 * num_threads:
- key_pairs.append((key_start, include_start, key_end, include_end,
- KeyRange.ASC))
- elif key_start.id() and key_end.id():
- if key_end.id() - key_start.id() > batch_size:
- key_half = db.Key.from_path(self.kind,
- (key_start.id() + key_end.id()) / 2)
- key_pairs.append((key_start, include_start,
- key_half, True,
- KeyRange.DESC))
- key_pairs.append((key_half, False,
- key_end, include_end,
- KeyRange.ASC))
+ key_end = self.key_range.key_end
+ if self.download_result.continued:
+ key_start = self.download_result.key_start
else:
- key_pairs.append((key_start, include_start, key_end, include_end,
- KeyRange.ASC))
- elif key_start.name() and key_end.name():
- (start, middle, end) = self._BisectStringRange(key_start.name(),
- key_end.name())
- key_pairs.append((key_start, include_start,
- db.Key.from_path(self.kind, middle), True,
- KeyRange.DESC))
- key_pairs.append((db.Key.from_path(self.kind, middle), False,
- key_end, include_end,
- KeyRange.ASC))
+ key_start = self.key_range.key_start
+
+ key_range = KeyRange(key_start=key_start,
+ key_end=key_end,
+ direction=self.key_range.direction)
+
+ result = self.__class__(self.request_manager,
+ self.progress_queue,
+ self.kind,
+ key_range,
+ progress_key=self.progress_key,
+ state=self.state)
+
+ result.download_result = self.download_result
+ result.count = self.count
+ return result
+
+ def _SplitAndAddRanges(self, thread_pool, batch_size):
+ """Split the key range [key_start, key_end] into a list of ranges."""
+ if self.download_result.direction == key_range_module.KeyRange.ASC:
+ key_range = KeyRange(
+ key_start=self.download_result.key_end,
+ key_end=self.key_range.key_end,
+ include_start=False)
else:
- assert key_start.id() and key_end.name()
- key_pairs.append((key_start, include_start,
- self.SPLIT_KEY, False,
- KeyRange.DESC))
- key_pairs.append((self.SPLIT_KEY, True,
- key_end, include_end,
- KeyRange.ASC))
-
- ranges = [KeyRange(self.progress_queue,
- self.kind,
- key_start=start,
- include_start=include_start,
- key_end=end,
- include_end=include_end,
- direction=direction)
- for (start, include_start, end, include_end, direction)
- in key_pairs]
+ key_range = KeyRange(
+ key_start=self.key_range.key_start,
+ key_end=self.download_result.key_start,
+ include_end=False)
+
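+ # If the pool already holds more than twice as many queued items as
+ # worker threads, further splitting just adds bookkeeping, so keep the
+ # remainder as one range; otherwise let KeyRange.split_range divide it.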
+ if thread_pool.QueuedItemCount() > 2 * thread_pool.num_threads():
+ ranges = [key_range]
+ else:
+ ranges = key_range.split_range(batch_size=batch_size)
for key_range in ranges:
- key_range.MarkAsRead()
- work_queue.put(key_range, block=True)
-
- def _AddUnfinishedRanges(self, num_threads, batch_size, work_queue):
- """Adds incomplete KeyRanges to the work_queue.
+ key_range_item = self.__class__(self.request_manager,
+ self.progress_queue,
+ self.kind,
+ key_range)
+ key_range_item.MarkAsRead()
+ thread_pool.SubmitItem(key_range_item, block=True)
+
+ def _AddUnfinishedRanges(self, thread_pool, batch_size):
+ """Adds incomplete KeyRanges to the thread_pool.
Args:
- num_threads: The number of threads for parallel transfers.
+ thread_pool: An AdaptiveThreadPool instance.
batch_size: The number of entities to transfer per request.
- work_queue: The work queue to add unfinished ranges to.
Returns:
A list of KeyRanges representing incomplete datastore key ranges.
@@ -1395,15 +1038,43 @@
Raises:
KeyRangeError: if this key range has already been completely transferred.
"""
- assert self.export_result is not None
- if self.export_result.continued:
- self.SplitRange(self.key_start, self.include_start, self.key_end,
- self.include_end, self.export_result,
- num_threads, batch_size, work_queue)
+ assert self.download_result is not None
+ if self.download_result.continued:
+ self._SplitAndAddRanges(thread_pool, batch_size)
else:
raise KeyRangeError('No unfinished part of key range.')
+class DownloadItem(KeyRangeItem):
+ """A KeyRangeItem for downloading key ranges."""
+
+ def _TransferItem(self, thread_pool, get_time=time.time):
+ """Transfers the entities associated with an item."""
+ t = get_time()
+ download_result = self.request_manager.GetEntities(self)
+ transfer_time = get_time() - t
+ self.Process(download_result, thread_pool,
+ self.request_manager.batch_size)
+ return transfer_time
+
+
+class MapperItem(KeyRangeItem):
+ """A KeyRangeItem for mapping over key ranges."""
+
+ def _TransferItem(self, thread_pool, get_time=time.time):
+ """Fetches a batch, applies the registered mapper, and records progress."""
+ t = get_time()
+ download_result = self.request_manager.GetEntities(self)
+ transfer_time = get_time() - t
+ mapper = self.request_manager.GetMapper()
+ try:
+ mapper.batch_apply(download_result.Entities())
+ except MapperRetry:
+ return None
+ self.Process(download_result, thread_pool,
+ self.request_manager.batch_size)
+ return transfer_time
+
+
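+# A hedged sketch of a user-supplied mapper for --map runs. Only MapperRetry,
+# Mapper.RegisteredMapper, and batch_apply appear in this file, so treat the
+# base-class details below as assumptions:
+#
+#   class PrintKeysMapper(Mapper):
+#     def batch_apply(self, entities):
+#       for entity in entities:
+#         print entity.key()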
class RequestManager(object):
"""A class which wraps a connection to the server."""
@@ -1416,7 +1087,8 @@
batch_size,
secure,
email,
- passin):
+ passin,
+ dry_run=False):
"""Initialize a RequestManager object.
Args:
@@ -1445,23 +1117,39 @@
self.parallel_download = True
self.email = email
self.passin = passin
- throttled_rpc_server_factory = ThrottledHttpRpcServerFactory(
- self.throttle, self)
+ self.mapper = None
+ self.dry_run = dry_run
+
+ if self.dry_run:
+ logger.info('Running in dry run mode, skipping remote_api setup')
+ return
+
logger.debug('Configuring remote_api. url_path = %s, '
'servername = %s' % (url_path, host_port))
+
+ def CookieHttpRpcServer(*args, **kwargs):
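+ # save_cookies persists the authentication cookie between runs, and
+ # HOSTED_OR_GOOGLE accepts both Google Apps (hosted) and Google accounts.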
+ kwargs['save_cookies'] = True
+ kwargs['account_type'] = 'HOSTED_OR_GOOGLE'
+ return appengine_rpc.HttpRpcServer(*args, **kwargs)
+
remote_api_stub.ConfigureRemoteDatastore(
app_id,
url_path,
self.AuthFunction,
servername=host_port,
- rpc_server_factory=throttled_rpc_server_factory,
+ rpc_server_factory=CookieHttpRpcServer,
secure=self.secure)
+ remote_api_throttle.ThrottleRemoteDatastore(self.throttle)
logger.debug('Bulkloader using app_id: %s', os.environ['APPLICATION_ID'])
def Authenticate(self):
"""Invoke authentication if necessary."""
- logger.info('Connecting to %s', self.url_path)
- self.rpc_server.Send(self.url_path, payload=None)
+ logger.info('Connecting to %s%s', self.host_port, self.url_path)
+ if self.dry_run:
+ self.authenticated = True
+ return
+
+ remote_api_stub.MaybeInvokeAuthentication()
self.authenticated = True
def AuthFunction(self,
@@ -1506,7 +1194,7 @@
loader: Used for dependency injection.
Returns:
- A list of db.Model instances.
+ A list of datastore.Entity instances.
Raises:
ConfigurationError: if no loader is defined for self.kind
@@ -1520,77 +1208,112 @@
entities = []
for line_number, values in rows:
key = loader.generate_key(line_number, values)
- if isinstance(key, db.Key):
+ if isinstance(key, datastore.Key):
parent = key.parent()
key = key.name()
else:
parent = None
entity = loader.create_entity(values, key_name=key, parent=parent)
+
+ def ToEntity(entity):
+ if isinstance(entity, db.Model):
+ return entity._populate_entity()
+ else:
+ return entity
+
if isinstance(entity, list):
- entities.extend(entity)
+ entities.extend(map(ToEntity, entity))
elif entity:
- entities.append(entity)
+ entities.append(ToEntity(entity))
return entities
- def PostEntities(self, item):
+ def PostEntities(self, entities):
"""Posts Entity records to a remote endpoint over HTTP.
Args:
- item: A workitem containing the entities to post.
-
- Returns:
- A pair of the estimated size of the request in bytes and the response
- from the server as a str.
+ entities: A list of datastore entities.
"""
- entities = item.content
- db.put(entities)
-
- def GetEntities(self, key_range):
+ if self.dry_run:
+ return
+ datastore.Put(entities)
+
+ def _QueryForPbs(self, query):
+ """Perform the given query and return a list of entity_pb's."""
+ try:
+ query_pb = query._ToPb(limit=self.batch_size)
+ result_pb = datastore_pb.QueryResult()
+ apiproxy_stub_map.MakeSyncCall('datastore_v3', 'RunQuery', query_pb,
+ result_pb)
+ next_pb = datastore_pb.NextRequest()
+ next_pb.set_count(self.batch_size)
+ next_pb.mutable_cursor().CopyFrom(result_pb.cursor())
+ result_pb = datastore_pb.QueryResult()
+ apiproxy_stub_map.MakeSyncCall('datastore_v3', 'Next', next_pb, result_pb)
+ return result_pb.result_list()
+ except apiproxy_errors.ApplicationError, e:
+ raise datastore._ToDatastoreError(e)
+
+ def GetEntities(self, key_range_item, key_factory=datastore.Key):
"""Gets Entity records from a remote endpoint over HTTP.
Args:
- key_range: Range of keys to get.
+ key_range_item: Range of keys to get.
+ key_factory: Used for dependency injection.
Returns:
- An ExportResult instance.
+ A DownloadResult instance.
Raises:
ConfigurationError: if no Exporter is defined for self.kind
"""
- try:
- Exporter.RegisteredExporter(self.kind)
- except KeyError:
- raise ConfigurationError('No Exporter defined for kind %s.' % self.kind)
-
keys = []
entities = []
if self.parallel_download:
- query = key_range.MakeParallelQuery()
+ query = key_range_item.key_range.make_directed_datastore_query(self.kind)
try:
- results = query.fetch(self.batch_size)
+ results = self._QueryForPbs(query)
except datastore_errors.NeedIndexError:
logger.info('%s: No descending index on __key__, '
'performing serial download', self.kind)
self.parallel_download = False
if not self.parallel_download:
- key_range.direction = KeyRange.ASC
- query = key_range.MakeSerialQuery()
- results = query.fetch(self.batch_size)
+ key_range_item.key_range.direction = key_range_module.KeyRange.ASC
+ query = key_range_item.key_range.make_ascending_datastore_query(self.kind)
+ results = self._QueryForPbs(query)
size = len(results)
- for model in results:
- key = model.key()
- entities.append(cPickle.dumps(model))
+ for entity in results:
+ key = key_factory()
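+ # Build a datastore.Key straight from the returned protocol buffer by
+ # setting its private (name-mangled) reference, avoiding re-encoding.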
+ key._Key__reference = entity.key()
+ entities.append(entity)
keys.append(key)
continued = (size == self.batch_size)
- key_range.count = size
-
- return ExportResult(continued, key_range.direction, keys, entities)
+ key_range_item.count = size
+
+ return DownloadResult(continued, key_range_item.key_range.direction,
+ keys, entities)
+
+ def GetMapper(self):
+ """Returns a mapper for the registered kind.
+
+ Returns:
+ A Mapper instance.
+
+ Raises:
+ ConfigurationError: if no Mapper is defined for self.kind
+ """
+ if not self.mapper:
+ try:
+ self.mapper = Mapper.RegisteredMapper(self.kind)
+ except KeyError:
+ logger.error('No Mapper defined for kind %s.', self.kind)
+ raise ConfigurationError('No Mapper defined for kind %s.' % self.kind)
+ return self.mapper
def InterruptibleSleep(sleep_time):
@@ -1611,357 +1334,6 @@
return
-class ThreadGate(object):
- """Manage the number of active worker threads.
-
- The ThreadGate limits the number of threads that are simultaneously
- uploading batches of records in order to implement adaptive rate
- control. The number of simultaneous upload threads that it takes to
- start causing timeout varies widely over the course of the day, so
- adaptive rate control allows the uploader to do many uploads while
- reducing the error rate and thus increasing the throughput.
-
- Initially the ThreadGate allows only one uploader thread to be active.
- For each successful upload, another thread is activated and for each
- failed upload, the number of active threads is reduced by one.
- """
-
- def __init__(self, enabled,
- threshhold1=MAXIMUM_INCREASE_DURATION,
- threshhold2=MAXIMUM_HOLD_DURATION,
- sleep=InterruptibleSleep):
- """Constructor for ThreadGate instances.
-
- Args:
- enabled: Whether the thread gate is enabled
- threshhold1: Maximum duration (in seconds) for a transfer to increase
- the number of active threads.
- threshhold2: Maximum duration (in seconds) for a transfer to not decrease
- the number of active threads.
- """
- self.enabled = enabled
- self.enabled_count = 1
- self.lock = threading.Lock()
- self.thread_semaphore = threading.Semaphore(self.enabled_count)
- self._threads = []
- self.backoff_time = 0
- self.sleep = sleep
- self.threshhold1 = threshhold1
- self.threshhold2 = threshhold2
-
- def Register(self, thread):
- """Register a thread with the thread gate."""
- self._threads.append(thread)
-
- def Threads(self):
- """Yields the registered threads."""
- for thread in self._threads:
- yield thread
-
- def EnableThread(self):
- """Enable one more worker thread."""
- self.lock.acquire()
- try:
- self.enabled_count += 1
- finally:
- self.lock.release()
- self.thread_semaphore.release()
-
- def EnableAllThreads(self):
- """Enable all worker threads."""
- for unused_idx in xrange(len(self._threads) - self.enabled_count):
- self.EnableThread()
-
- def StartWork(self):
- """Starts a critical section in which the number of workers is limited.
-
- If thread throttling is enabled then this method starts a critical
- section which allows self.enabled_count simultaneously operating
- threads. The critical section is ended by calling self.FinishWork().
- """
- if self.enabled:
- self.thread_semaphore.acquire()
- if self.backoff_time > 0.0:
- if not threading.currentThread().exit_flag:
- logger.info('Backing off: %.1f seconds',
- self.backoff_time)
- self.sleep(self.backoff_time)
-
- def FinishWork(self):
- """Ends a critical section started with self.StartWork()."""
- if self.enabled:
- self.thread_semaphore.release()
-
- def TransferSuccess(self, duration):
- """Informs the throttler that an item was successfully sent.
-
- If thread throttling is enabled and the duration is low enough, this
- method will cause an additional thread to run in the critical section.
-
- Args:
- duration: The duration of the transfer in seconds.
- """
- if duration > self.threshhold2:
- logger.debug('Transfer took %s, decreasing workers.', duration)
- self.DecreaseWorkers(backoff=False)
- return
- elif duration > self.threshhold1:
- logger.debug('Transfer took %s, not increasing workers.', duration)
- return
- elif self.enabled:
- if self.backoff_time > 0.0:
- logger.info('Resetting backoff to 0.0')
- self.backoff_time = 0.0
- do_enable = False
- self.lock.acquire()
- try:
- if self.enabled and len(self._threads) > self.enabled_count:
- do_enable = True
- self.enabled_count += 1
- finally:
- self.lock.release()
- if do_enable:
- logger.debug('Increasing active thread count to %d',
- self.enabled_count)
- self.thread_semaphore.release()
-
- def DecreaseWorkers(self, backoff=True):
- """Informs the thread_gate that an item failed to send.
-
- If thread throttling is enabled, this method will cause the
- throttler to allow one fewer thread in the critical section. If
- there is only one thread remaining, failures will result in
- exponential backoff until there is a success.
-
- Args:
- backoff: Whether to increase exponential backoff if there is only
- one thread enabled.
- """
- if self.enabled:
- do_disable = False
- self.lock.acquire()
- try:
- if self.enabled:
- if self.enabled_count > 1:
- do_disable = True
- self.enabled_count -= 1
- elif backoff:
- if self.backoff_time == 0.0:
- self.backoff_time = INITIAL_BACKOFF
- else:
- self.backoff_time *= BACKOFF_FACTOR
- finally:
- self.lock.release()
- if do_disable:
- logger.debug('Decreasing the number of active threads to %d',
- self.enabled_count)
- self.thread_semaphore.acquire()
-
-
-class Throttle(object):
- """A base class for upload rate throttling.
-
- Transferring large number of records, too quickly, to an application
- could trigger quota limits and cause the transfer process to halt.
- In order to stay within the application's quota, we throttle the
- data transfer to a specified limit (across all transfer threads).
- This limit defaults to about half of the Google App Engine default
- for an application, but can be manually adjusted faster/slower as
- appropriate.
-
- This class tracks a moving average of some aspect of the transfer
- rate (bandwidth, records per second, http connections per
- second). It keeps two windows of counts of bytes transferred, on a
- per-thread basis. One block is the "current" block, and the other is
- the "prior" block. It will rotate the counts from current to prior
- when ROTATE_PERIOD has passed. Thus, the current block will
- represent from 0 seconds to ROTATE_PERIOD seconds of activity
- (determined by: time.time() - self.last_rotate). The prior block
- will always represent a full ROTATE_PERIOD.
-
- Sleeping is performed just before a transfer of another block, and is
- based on the counts transferred *before* the next transfer. It really
- does not matter how much will be transferred, but only that for all the
- data transferred SO FAR that we have interspersed enough pauses to
- ensure the aggregate transfer rate is within the specified limit.
-
- These counts are maintained on a per-thread basis, so we do not require
- any interlocks around incrementing the counts. There IS an interlock on
- the rotation of the counts because we do not want multiple threads to
- multiply-rotate the counts.
-
- There are various race conditions in the computation and collection
- of these counts. We do not require precise values, but simply to
- keep the overall transfer within the bandwidth limits. If a given
- pause is a little short, or a little long, then the aggregate delays
- will be correct.
- """
-
- ROTATE_PERIOD = 600
-
- def __init__(self,
- get_time=time.time,
- thread_sleep=InterruptibleSleep,
- layout=None):
- self.get_time = get_time
- self.thread_sleep = thread_sleep
-
- self.start_time = get_time()
- self.transferred = {}
- self.prior_block = {}
- self.totals = {}
- self.throttles = {}
-
- self.last_rotate = {}
- self.rotate_mutex = {}
- if layout:
- self.AddThrottles(layout)
-
- def AddThrottle(self, name, limit):
- self.throttles[name] = limit
- self.transferred[name] = {}
- self.prior_block[name] = {}
- self.totals[name] = {}
- self.last_rotate[name] = self.get_time()
- self.rotate_mutex[name] = threading.Lock()
-
- def AddThrottles(self, layout):
- for key, value in layout.iteritems():
- self.AddThrottle(key, value)
-
- def Register(self, thread):
- """Register this thread with the throttler."""
- thread_name = thread.getName()
- for throttle_name in self.throttles.iterkeys():
- self.transferred[throttle_name][thread_name] = 0
- self.prior_block[throttle_name][thread_name] = 0
- self.totals[throttle_name][thread_name] = 0
-
- def VerifyName(self, throttle_name):
- if throttle_name not in self.throttles:
- raise AssertionError('%s is not a registered throttle' % throttle_name)
-
- def AddTransfer(self, throttle_name, token_count):
- """Add a count to the amount this thread has transferred.
-
- Each time a thread transfers some data, it should call this method to
- note the amount sent. The counts may be rotated if sufficient time
- has passed since the last rotation.
-
- Note: this method should only be called by the BulkLoaderThread
- instances. The token count is allocated towards the
- "current thread".
-
- Args:
- throttle_name: The name of the throttle to add to.
- token_count: The number to add to the throttle counter.
- """
- self.VerifyName(throttle_name)
- transferred = self.transferred[throttle_name]
- transferred[threading.currentThread().getName()] += token_count
-
- if self.last_rotate[throttle_name] + self.ROTATE_PERIOD < self.get_time():
- self._RotateCounts(throttle_name)
-
- def Sleep(self, throttle_name=None):
- """Possibly sleep in order to limit the transfer rate.
-
- Note that we sleep based on *prior* transfers rather than what we
- may be about to transfer. The next transfer could put us under/over
- and that will be rectified *after* that transfer. Net result is that
- the average transfer rate will remain within bounds. Spiky behavior
- or uneven rates among the threads could possibly bring the transfer
- rate above the requested limit for short durations.
-
- Args:
- throttle_name: The name of the throttle to sleep on. If None or
- omitted, then sleep on all throttles.
- """
- if throttle_name is None:
- for throttle_name in self.throttles:
- self.Sleep(throttle_name=throttle_name)
- return
-
- self.VerifyName(throttle_name)
-
- thread = threading.currentThread()
-
- while True:
- duration = self.get_time() - self.last_rotate[throttle_name]
-
- total = 0
- for count in self.prior_block[throttle_name].values():
- total += count
-
- if total:
- duration += self.ROTATE_PERIOD
-
- for count in self.transferred[throttle_name].values():
- total += count
-
- sleep_time = (float(total) / self.throttles[throttle_name]) - duration
-
- if sleep_time < MINIMUM_THROTTLE_SLEEP_DURATION:
- break
-
- logger.debug('[%s] Throttling on %s. Sleeping for %.1f ms '
- '(duration=%.1f ms, total=%d)',
- thread.getName(), throttle_name,
- sleep_time * 1000, duration * 1000, total)
- self.thread_sleep(sleep_time)
- if thread.exit_flag:
- break
- self._RotateCounts(throttle_name)
-
- def _RotateCounts(self, throttle_name):
- """Rotate the transfer counters.
-
- If sufficient time has passed, then rotate the counters from active to
- the prior-block of counts.
-
- This rotation is interlocked to ensure that multiple threads do not
- over-rotate the counts.
-
- Args:
- throttle_name: The name of the throttle to rotate.
- """
- self.VerifyName(throttle_name)
- self.rotate_mutex[throttle_name].acquire()
- try:
- next_rotate_time = self.last_rotate[throttle_name] + self.ROTATE_PERIOD
- if next_rotate_time >= self.get_time():
- return
-
- for name, count in self.transferred[throttle_name].items():
-
-
- self.prior_block[throttle_name][name] = count
- self.transferred[throttle_name][name] = 0
-
- self.totals[throttle_name][name] += count
-
- self.last_rotate[throttle_name] = self.get_time()
-
- finally:
- self.rotate_mutex[throttle_name].release()
-
- def TotalTransferred(self, throttle_name):
- """Return the total transferred, and over what period.
-
- Args:
- throttle_name: The name of the throttle to total.
-
- Returns:
- A tuple of the total count and running time for the given throttle name.
- """
- total = 0
- for count in self.totals[throttle_name].values():
- total += count
- for count in self.transferred[throttle_name].values():
- total += count
- return total, self.get_time() - self.start_time
-
-
class _ThreadBase(threading.Thread):
"""Provide some basic features for the threads used in the uploader.
@@ -1993,18 +1365,29 @@
self.exit_flag = False
self.error = None
+ self.traceback = None
def run(self):
"""Perform the work of the thread."""
- logger.info('[%s] %s: started', self.getName(), self.__class__.__name__)
+ logger.debug('[%s] %s: started', self.getName(), self.__class__.__name__)
try:
self.PerformWork()
except:
- self.error = sys.exc_info()[1]
+ self.SetError()
logger.exception('[%s] %s:', self.getName(), self.__class__.__name__)
- logger.info('[%s] %s: exiting', self.getName(), self.__class__.__name__)
+ logger.debug('[%s] %s: exiting', self.getName(), self.__class__.__name__)
+
+ def SetError(self):
+ """Sets the error and traceback information for this thread.
+
+ This must be called from an exception handler.
+ """
+ if not self.error:
+ exc_info = sys.exc_info()
+ self.error = exc_info[1]
+ self.traceback = exc_info[2]
def PerformWork(self):
"""Perform the thread-specific work."""
@@ -2014,6 +1397,10 @@
"""If an error is present, then log it."""
if self.error:
logger.error('Error in %s: %s', self.GetFriendlyName(), self.error)
+ if self.traceback:
+ logger.debug(''.join(traceback.format_exception(self.error.__class__,
+ self.error,
+ self.traceback)))
def GetFriendlyName(self):
"""Returns a human-friendly description of the thread."""
@@ -2044,292 +1431,12 @@
return error.reason[0] not in non_fatal_error_codes
-def PrettyKey(key):
- """Returns a nice string representation of the given key."""
- if key is None:
- return None
- elif isinstance(key, db.Key):
- return repr(key.id_or_name())
- return str(key)
-
-
-class _BulkWorkerThread(_ThreadBase):
- """A base class for worker threads.
-
- This thread will read WorkItem instances from the work_queue and upload
- the entities to the server application. Progress information will be
- pushed into the progress_queue as the work is being performed.
-
- If a _BulkWorkerThread encounters a transient error, the entities will be
- resent, if a fatal error is encoutered the BulkWorkerThread exits.
-
- Subclasses must provide implementations for PreProcessItem, TransferItem,
- and ProcessResponse.
- """
-
- def __init__(self,
- work_queue,
- throttle,
- thread_gate,
- request_manager,
- num_threads,
- batch_size,
- state_message,
- get_time):
- """Initialize the BulkLoaderThread instance.
-
- Args:
- work_queue: A queue containing WorkItems for processing.
- throttle: A Throttles to control upload bandwidth.
- thread_gate: A ThreadGate to control number of simultaneous uploads.
- request_manager: A RequestManager instance.
- num_threads: The number of threads for parallel transfers.
- batch_size: The number of entities to transfer per request.
- state_message: Used for dependency injection.
- get_time: Used for dependency injection.
- """
- _ThreadBase.__init__(self)
-
- self.work_queue = work_queue
- self.throttle = throttle
- self.thread_gate = thread_gate
- self.request_manager = request_manager
- self.num_threads = num_threads
- self.batch_size = batch_size
- self.state_message = state_message
- self.get_time = get_time
-
- def PreProcessItem(self, item):
- """Performs pre transfer processing on a work item."""
- raise NotImplementedError()
-
- def TransferItem(self, item):
- """Transfers the entities associated with an item.
-
- Args:
- item: An item of upload (WorkItem) or download (KeyRange) work.
-
- Returns:
- A tuple of (estimated transfer size, response)
- """
- raise NotImplementedError()
-
- def ProcessResponse(self, item, result):
- """Processes the response from the server application."""
- raise NotImplementedError()
-
- def PerformWork(self):
- """Perform the work of a _BulkWorkerThread."""
- while not self.exit_flag:
- transferred = False
- self.thread_gate.StartWork()
- try:
- try:
- item = self.work_queue.get(block=True, timeout=1.0)
- except Queue.Empty:
- continue
- if item == _THREAD_SHOULD_EXIT:
- break
-
- logger.debug('[%s] Got work item %s', self.getName(), item)
-
- try:
-
- item.MarkAsTransferring()
- self.PreProcessItem(item)
- response = None
- try:
- try:
- t = self.get_time()
- response = self.TransferItem(item)
- status = 200
- transferred = True
- transfer_time = self.get_time() - t
- logger.debug('[%s] %s Transferred %d entities in %0.1f seconds',
- self.getName(), item, item.count, transfer_time)
- self.throttle.AddTransfer(RECORDS, item.count)
- except (db.InternalError, db.NotSavedError, db.Timeout,
- apiproxy_errors.OverQuotaError,
- apiproxy_errors.DeadlineExceededError), e:
- logger.exception('Caught non-fatal datastore error: %s', e)
- except urllib2.HTTPError, e:
- status = e.code
- if status == 403 or (status >= 500 and status < 600):
- logger.exception('Caught non-fatal HTTP error: %d %s',
- status, e.msg)
- else:
- raise e
- except urllib2.URLError, e:
- if IsURLErrorFatal(e):
- raise e
- else:
- logger.exception('Caught non-fatal URL error: %s', e.reason)
-
- self.ProcessResponse(item, response)
-
- except:
- self.error = sys.exc_info()[1]
- logger.exception('[%s] %s: caught exception %s', self.getName(),
- self.__class__.__name__, str(sys.exc_info()))
- raise
-
- finally:
- if transferred:
- item.MarkAsTransferred()
- self.work_queue.task_done()
- self.thread_gate.TransferSuccess(transfer_time)
- else:
- item.MarkAsError()
- try:
- self.work_queue.reput(item, block=False)
- except Queue.Full:
- logger.error('[%s] Failed to reput work item.', self.getName())
- raise Error('Failed to reput work item')
- self.thread_gate.DecreaseWorkers()
- logger.info('%s %s',
- item,
- self.state_message(item.state))
-
- finally:
- self.thread_gate.FinishWork()
-
-
- def GetFriendlyName(self):
- """Returns a human-friendly name for this thread."""
- return 'worker [%s]' % self.getName()
-
-
-class BulkLoaderThread(_BulkWorkerThread):
- """A thread which transmits entities to the server application.
-
- This thread will read WorkItem instances from the work_queue and upload
- the entities to the server application. Progress information will be
- pushed into the progress_queue as the work is being performed.
-
- If a BulkLoaderThread encounters a transient error, the entities will be
- resent, if a fatal error is encoutered the BulkLoaderThread exits.
- """
-
- def __init__(self,
- work_queue,
- throttle,
- thread_gate,
- request_manager,
- num_threads,
- batch_size,
- get_time=time.time):
- """Initialize the BulkLoaderThread instance.
-
- Args:
- work_queue: A queue containing WorkItems for processing.
- throttle: A Throttles to control upload bandwidth.
- thread_gate: A ThreadGate to control number of simultaneous uploads.
- request_manager: A RequestManager instance.
- num_threads: The number of threads for parallel transfers.
- batch_size: The number of entities to transfer per request.
- get_time: Used for dependency injection.
- """
- _BulkWorkerThread.__init__(self,
- work_queue,
- throttle,
- thread_gate,
- request_manager,
- num_threads,
- batch_size,
- ImportStateMessage,
- get_time)
-
- def PreProcessItem(self, item):
- """Performs pre transfer processing on a work item."""
- if item and not item.content:
- item.content = self.request_manager.EncodeContent(item.rows)
-
- def TransferItem(self, item):
- """Transfers the entities associated with an item.
-
- Args:
- item: An item of upload (WorkItem) work.
-
- Returns:
- A tuple of (estimated transfer size, response)
- """
- return self.request_manager.PostEntities(item)
-
- def ProcessResponse(self, item, response):
- """Processes the response from the server application."""
- pass
-
-
-class BulkExporterThread(_BulkWorkerThread):
- """A thread which recieved entities to the server application.
-
- This thread will read KeyRange instances from the work_queue and export
- the entities from the server application. Progress information will be
- pushed into the progress_queue as the work is being performed.
-
- If a BulkExporterThread encounters an error when trying to post data,
- the thread will exit and cause the application to terminate.
- """
-
- def __init__(self,
- work_queue,
- throttle,
- thread_gate,
- request_manager,
- num_threads,
- batch_size,
- get_time=time.time):
-
- """Initialize the BulkExporterThread instance.
-
- Args:
- work_queue: A queue containing KeyRanges for processing.
- throttle: A Throttles to control upload bandwidth.
- thread_gate: A ThreadGate to control number of simultaneous uploads.
- request_manager: A RequestManager instance.
- num_threads: The number of threads for parallel transfers.
- batch_size: The number of entities to transfer per request.
- get_time: Used for dependency injection.
- """
- _BulkWorkerThread.__init__(self,
- work_queue,
- throttle,
- thread_gate,
- request_manager,
- num_threads,
- batch_size,
- ExportStateMessage,
- get_time)
-
- def PreProcessItem(self, unused_item):
- """Performs pre transfer processing on a work item."""
- pass
-
- def TransferItem(self, item):
- """Transfers the entities associated with an item.
-
- Args:
- item: An item of download (KeyRange) work.
-
- Returns:
- A tuple of (estimated transfer size, response)
- """
- return self.request_manager.GetEntities(item)
-
- def ProcessResponse(self, item, export_result):
- """Processes the response from the server application."""
- if export_result:
- item.Process(export_result, self.num_threads, self.batch_size,
- self.work_queue)
- item.state = STATE_GOT
-
-
class DataSourceThread(_ThreadBase):
"""A thread which reads WorkItems and pushes them into queue.
This thread will read/consume WorkItems from a generator (produced by
the generator factory). These WorkItems will then be pushed into the
- work_queue. Note that reading will block if/when the work_queue becomes
+ thread_pool. Note that reading will block if/when the thread_pool becomes
full. Information on content consumed from the generator will be pushed
into the progress_queue.
"""
@@ -2337,14 +1444,16 @@
NAME = 'data source thread'
def __init__(self,
- work_queue,
+ request_manager,
+ thread_pool,
progress_queue,
workitem_generator_factory,
progress_generator_factory):
"""Initialize the DataSourceThread instance.
Args:
- work_queue: A queue containing WorkItems for processing.
+ request_manager: A RequestManager instance.
+ thread_pool: An AdaptiveThreadPool instance.
progress_queue: A queue used for tracking progress information.
workitem_generator_factory: A factory that creates a WorkItem generator
progress_generator_factory: A factory that creates a generator which
@@ -2353,7 +1462,8 @@
"""
_ThreadBase.__init__(self)
- self.work_queue = work_queue
+ self.request_manager = request_manager
+ self.thread_pool = thread_pool
self.progress_queue = progress_queue
self.workitem_generator_factory = workitem_generator_factory
self.progress_generator_factory = progress_generator_factory
@@ -2366,7 +1476,8 @@
else:
progress_gen = None
- content_gen = self.workitem_generator_factory(self.progress_queue,
+ content_gen = self.workitem_generator_factory(self.request_manager,
+ self.progress_queue,
progress_gen)
self.xfer_count = 0
@@ -2378,7 +1489,7 @@
while not self.exit_flag:
try:
- self.work_queue.put(item, block=True, timeout=1.0)
+ self.thread_pool.SubmitItem(item, block=True, timeout=1.0)
self.entity_count += item.count
break
except Queue.Full:
@@ -2526,6 +1637,70 @@
self.update_cursor = self.secondary_conn.cursor()
+zero_matcher = re.compile(r'\x00')
+
+zero_one_matcher = re.compile(r'\x00\x01')
+
+
+def KeyStr(key):
+ """Returns a string to represent a key, preserving ordering.
+
+ Unlike datastore.Key.__str__(), we have the property:
+
+ key1 < key2 ==> KeyStr(key1) < KeyStr(key2)
+
+ The key string is constructed from the key path as follows:
+ (1) Strings are prepended with ':' and numeric ids are padded to
+ 20 digits.
+ (2) Any null characters (u'\0') present are replaced with u'\0\1'
+ (3) The sequence u'\0\0' is used to separate each component of the path.
+
+ (1) ensures that names and ids compare properly, while (2) and (3) enforce
+ the part-by-part comparison of pieces of the path.
+
+ Args:
+ key: A datastore.Key instance.
+
+ Returns:
+ A string representation of the key, which preserves ordering.
+ """
+ assert isinstance(key, datastore.Key)
+ path = key.to_path()
+
+ out_path = []
+ for part in path:
+ if isinstance(part, (int, long)):
+ part = '%020d' % part
+ else:
+ part = ':%s' % part
+
+ out_path.append(zero_matcher.sub(u'\0\1', part))
+
+ out_str = u'\0\0'.join(out_path)
+
+ return out_str
+
+
+def StrKey(key_str):
+ """The inverse of the KeyStr function.
+
+ Args:
+ key_str: A string in the range of KeyStr.
+
+ Returns:
+ A datastore.Key instance k, such that KeyStr(k) == key_str.
+ """
+ parts = key_str.split(u'\0\0')
+ for i in xrange(len(parts)):
+ if parts[i][0] == ':':
+ part = parts[i][1:]
+ part = zero_one_matcher.sub(u'\0', part)
+ parts[i] = part
+ else:
+ parts[i] = int(parts[i])
+ return datastore.Key.from_path(*parts)
+
+
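The ordering property above is what lets ResultDatabase (below) store KeyStr output in a BLOB id column and replay entities in key order via 'select ... order by id'. A minimal sketch of the round trip, assuming this module's namespace is available; 'example-app' is a placeholder app id:

    import os
    os.environ.setdefault('APPLICATION_ID', 'example-app')  # placeholder
    from google.appengine.api import datastore

    k_small = datastore.Key.from_path('Item', 9)
    k_large = datastore.Key.from_path('Item', 42)

    # Zero-padding numeric ids to 20 digits makes lexicographic order of
    # the encoded strings agree with numeric order of the ids:
    assert KeyStr(k_small) < KeyStr(k_large)

    # StrKey is an exact inverse, including for multi-level paths, since
    # u'\0\0' separates the path components:
    parent_child = datastore.Key.from_path('Parent', 'a', 'Item', 9)
    assert StrKey(KeyStr(parent_child)) == parent_child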
class ResultDatabase(_Database):
"""Persistently record all the entities downloaded during an export.
@@ -2544,7 +1719,7 @@
"""
self.complete = False
create_table = ('create table result (\n'
- 'id TEXT primary key,\n'
+ 'id BLOB primary key,\n'
'value BLOB not null)')
_Database.__init__(self,
@@ -2560,34 +1735,37 @@
self.existing_count = 0
self.count = self.existing_count
- def _StoreEntity(self, entity_id, value):
+ def _StoreEntity(self, entity_id, entity):
"""Store an entity in the result database.
Args:
- entity_id: A db.Key for the entity.
- value: A string of the contents of the entity.
+ entity_id: A datastore.Key for the entity.
+ entity: The entity to store.
Returns:
- True if this entities is not already present in the result database.
+ True if this entity is not already present in the result database.
"""
assert _RunningInThread(self.secondary_thread)
- assert isinstance(entity_id, db.Key)
-
- entity_id = entity_id.id_or_name()
+ assert isinstance(entity_id, datastore.Key), (
+ 'expected a datastore.Key, got a %s' % entity_id.__class__.__name__)
+
+ key_str = buffer(KeyStr(entity_id).encode('utf-8'))
self.insert_cursor.execute(
- 'select count(*) from result where id = ?', (unicode(entity_id),))
+ 'select count(*) from result where id = ?', (key_str,))
+
already_present = self.insert_cursor.fetchone()[0]
result = True
if already_present:
result = False
self.insert_cursor.execute('delete from result where id = ?',
- (unicode(entity_id),))
+ (key_str,))
else:
self.count += 1
+ value = entity.Encode()
self.insert_cursor.execute(
'insert into result (id, value) values (?, ?)',
- (unicode(entity_id), buffer(value)))
+ (key_str, buffer(value)))
return result
def StoreEntities(self, keys, entities):
@@ -2603,9 +1781,9 @@
self._OpenSecondaryConnection()
t = time.time()
count = 0
- for entity_id, value in zip(keys,
- entities):
- if self._StoreEntity(entity_id, value):
+ for entity_id, entity in zip(keys,
+ entities):
+ if self._StoreEntity(entity_id, entity):
count += 1
logger.debug('%s insert: delta=%.3f',
self.db_filename,
@@ -2627,7 +1805,8 @@
'select id, value from result order by id')
for unused_entity_id, entity in cursor:
- yield cPickle.loads(str(entity))
+ entity_proto = entity_pb.EntityProto(contents=entity)
+ yield datastore.Entity._FromPb(entity_proto)
class _ProgressDatabase(_Database):
@@ -2723,9 +1902,16 @@
self._OpenSecondaryConnection()
assert _RunningInThread(self.secondary_thread)
- assert not key_start or isinstance(key_start, self.py_type)
- assert not key_end or isinstance(key_end, self.py_type), '%s is a %s' % (
- key_end, key_end.__class__)
+ assert (not key_start) or isinstance(key_start, self.py_type), (
+ '%s is a %s, %s expected %s' % (key_start,
+ key_start.__class__,
+ self.__class__.__name__,
+ self.py_type))
+ assert (not key_end) or isinstance(key_end, self.py_type), (
+ '%s is a %s, %s expected %s' % (key_end,
+ key_end.__class__,
+ self.__class__.__name__,
+ self.py_type))
assert KeyLEQ(key_start, key_end), '%s not less than %s' % (
repr(key_start), repr(key_end))
@@ -2843,7 +2029,7 @@
_ProgressDatabase.__init__(self,
db_filename,
'TEXT',
- db.Key,
+ datastore.Key,
signature,
commit_periodicity=1)
@@ -3011,34 +2197,72 @@
exporter.output_entities(self.result_db.AllEntities())
def UpdateProgress(self, item):
- """Update the state of the given KeyRange.
+ """Update the state of the given KeyRangeItem.
Args:
- item: A KeyRange instance.
+ item: A KeyRangeItem instance.
"""
if item.state == STATE_GOT:
- count = self.result_db.StoreEntities(item.export_result.keys,
- item.export_result.entities)
+ count = self.result_db.StoreEntities(item.download_result.keys,
+ item.download_result.entities)
self.db.DeleteKey(item.progress_key)
self.entities_transferred += count
else:
self.db.UpdateState(item.progress_key, item.state)
+class MapperProgressThread(_ProgressThreadBase):
+ """A thread to record progress information for maps over the datastore."""
+
+ def __init__(self, kind, progress_queue, progress_db):
+ """Initialize the MapperProgressThread instance.
+
+ Args:
+ kind: The kind of entities being mapped over.
+ progress_queue: A Queue used for tracking progress information.
+ progress_db: The database for tracking progress information; should
+ be an instance of ProgressDatabase.
+ """
+ _ProgressThreadBase.__init__(self, progress_queue, progress_db)
+
+ self.kind = kind
+ self.mapper = Mapper.RegisteredMapper(self.kind)
+
+ def EntitiesTransferred(self):
+ """Return the total number of unique entities transferred."""
+ return self.entities_transferred
+
+ def WorkFinished(self):
+ """Perform actions after map is complete."""
+ pass
+
+ def UpdateProgress(self, item):
+ """Update the state of the given KeyRangeItem.
+
+ Args:
+ item: A KeyRangeItem instance.
+ """
+ if item.state == STATE_GOT:
+ self.entities_transferred += item.count
+ self.db.DeleteKey(item.progress_key)
+ else:
+ self.db.UpdateState(item.progress_key, item.state)
+
+
def ParseKey(key_string):
- """Turn a key stored in the database into a db.Key or None.
+ """Turn a key stored in the database into a Key or None.
Args:
- key_string: The string representation of a db.Key.
+ key_string: The string representation of a Key.
Returns:
- A db.Key instance or None
+ A datastore.Key instance or None
"""
if not key_string:
return None
if key_string == 'None':
return None
- return db.Key(encoded=key_string)
+ return datastore.Key(encoded=key_string)
def Validate(value, typ):
@@ -3097,9 +2321,7 @@
def __init__(self, kind, properties):
"""Constructor.
- Populates this Loader's kind and properties map. Also registers it with
- the bulk loader, so that all you need to do is instantiate your Loader,
- and the bulkload handler will automatically use it.
+ Populates this Loader's kind and properties map.
Args:
kind: a string containing the entity kind that this loader handles
@@ -3139,7 +2361,11 @@
@staticmethod
def RegisterLoader(loader):
-
+ """Register loader and the Loader instance for its kind.
+
+ Args:
+ loader: A Loader instance.
+ """
Loader.__loaders[loader.kind] = loader
def alias_old_names(self):
@@ -3166,7 +2392,7 @@
Args:
values: list/tuple of str
key_name: if provided, the name for the (single) resulting entity
- parent: A db.Key instance for the parent, or None
+ parent: A datastore.Key instance for the parent, or None
Returns:
list of db.Model
@@ -3222,7 +2448,7 @@
server generated numeric key), or a string which neither starts
with a digit nor has the form __*__ (see
http://code.google.com/appengine/docs/python/datastore/keysandentitygroups.html),
- or a db.Key instance.
+ or a datastore.Key instance.
If you generate your own string keys, keep in mind:
@@ -3305,6 +2531,51 @@
return Loader.__loaders[kind]
+class RestoreThread(_ThreadBase):
+ """A thread to read saved entity_pbs from sqlite3."""
+ NAME = 'RestoreThread'
+ _ENTITIES_DONE = 'Entities Done'
+
+ def __init__(self, queue, filename):
+ _ThreadBase.__init__(self)
+ self.queue = queue
+ self.filename = filename
+
+ def PerformWork(self):
+ db_conn = sqlite3.connect(self.filename)
+ cursor = db_conn.cursor()
+ cursor.execute('select id, value from result')
+ for entity_id, value in cursor:
+ self.queue.put([entity_id, value], block=True)
+ self.queue.put(RestoreThread._ENTITIES_DONE, block=True)
+
+
+class RestoreLoader(Loader):
+ """A Loader which imports protobuffers from a file."""
+
+ def __init__(self, kind):
+ self.kind = kind
+
+ def initialize(self, filename, loader_opts):
+ CheckFile(filename)
+ self.queue = Queue.Queue(1000)
+ restore_thread = RestoreThread(self.queue, filename)
+ restore_thread.start()
+
+ def generate_records(self, filename):
+ while True:
+ record = self.queue.get(block=True)
+ if id(record) == id(RestoreThread._ENTITIES_DONE):
+ break
+ yield record
+
+ def create_entity(self, values, key_name=None, parent=None):
+ key = StrKey(unicode(values[0], 'utf-8'))
+ entity_proto = entity_pb.EntityProto(contents=str(values[1]))
+ entity_proto.mutable_key().CopyFrom(key._Key__reference)
+ return datastore.Entity._FromPb(entity_proto)
+
+
class Exporter(object):
"""A base class for serializing datastore entities.
@@ -3326,9 +2597,7 @@
def __init__(self, kind, properties):
"""Constructor.
- Populates this Exporters's kind and properties map. Also registers
- it so that all you need to do is instantiate your Exporter, and
- the bulkload handler will automatically use it.
+ Populates this Exporter's kind and properties map.
Args:
kind: a string containing the entity kind that this exporter handles
@@ -3370,7 +2639,11 @@
@staticmethod
def RegisterExporter(exporter):
-
+ """Register exporter and the Exporter instance for its kind.
+
+ Args:
+ exporter: An Exporter instance.
+ """
Exporter.__exporters[exporter.kind] = exporter
def __ExtractProperties(self, entity):
@@ -3388,7 +2661,7 @@
encoding = []
for name, fn, default in self.__properties:
try:
- encoding.append(fn(getattr(entity, name)))
+ encoding.append(fn(entity[name]))
- except AttributeError:
+ except (AttributeError, KeyError):
if default is None:
raise MissingPropertyError(name)
@@ -3468,6 +2741,87 @@
return Exporter.__exporters[kind]
+class DumpExporter(Exporter):
+ """An exporter which dumps protobuffers to a file."""
+
+ def __init__(self, kind, result_db_filename):
+ self.kind = kind
+ self.result_db_filename = result_db_filename
+
+ def output_entities(self, entity_generator):
+ shutil.copyfile(self.result_db_filename, self.output_filename)
+
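Because DumpExporter simply copies the ResultDatabase file, a --dump output is itself a sqlite3 database with a single result(id, value) table. A hedged sketch of reading one back by hand, mirroring RestoreThread and RestoreLoader.create_entity above ('everything.sql3' and the app id are placeholders):

    import os
    import sqlite3
    os.environ.setdefault('APPLICATION_ID', 'example-app')  # placeholder
    from google.appengine.api import datastore
    from google.appengine.datastore import entity_pb

    conn = sqlite3.connect('everything.sql3')
    for key_str, value in conn.execute('select id, value from result'):
      # id holds KeyStr(key) encoded as utf-8; StrKey inverts it.
      key = StrKey(unicode(key_str, 'utf-8'))
      entity_proto = entity_pb.EntityProto(contents=str(value))
      entity_proto.mutable_key().CopyFrom(key._Key__reference)
      entity = datastore.Entity._FromPb(entity_proto)
      print key, sorted(entity.keys())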
+
+class MapperRetry(Error):
+ """An exception that indicates a non-fatal error during mapping."""
+
+
+class Mapper(object):
+ """A base class for serializing datastore entities.
+
+ To add a handler for mapping over an entity kind in your datastore,
+ write a subclass of this class that calls Mapper.__init__ from your
+ class's __init__.
+
+ You need to implement the batch_apply or apply method on your subclass
+ for the map to do anything.
+ """
+
+ __mappers = {}
+ kind = None
+
+ def __init__(self, kind):
+ """Constructor.
+
+ Populates this Mapper's kind.
+
+ Args:
+ kind: a string containing the entity kind that this mapper handles
+ """
+ Validate(kind, basestring)
+ self.kind = kind
+
+ GetImplementationClass(kind)
+
+ @staticmethod
+ def RegisterMapper(mapper):
+ """Register mapper and the Mapper instance for its kind.
+
+ Args:
+ mapper: A Mapper instance.
+ """
+ Mapper.__mappers[mapper.kind] = mapper
+
+ def initialize(self, mapper_opts):
+ """Performs initialization.
+
+ Args:
+ mapper_opts: The string given as the --mapper_opts flag argument.
+ """
+ pass
+
+ def finalize(self):
+ """Performs finalization actions after the download completes."""
+ pass
+
+ def apply(self, entity):
+ """Applies the mapper to a single entity; subclasses should override."""
+ print 'Default map function doing nothing to %s' % entity
+
+ def batch_apply(self, entities):
+ """Applies the mapper to a batch of entities, one at a time by default."""
+ for entity in entities:
+ self.apply(entity)
+
+ @staticmethod
+ def RegisteredMappers():
+ """Returns a dictionary of the mapper instances that have been created."""
+ return dict(Mapper.__mappers)
+
+ @staticmethod
+ def RegisteredMapper(kind):
+ """Returns an mapper instance for the given kind if it exists."""
+ return Mapper.__mappers[kind]
+
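A hedged sketch of a Mapper subclass as it might appear in a --config_file module (LoadConfig below registers each class listed in a module-level 'mappers' list); the Greeting model and its property are placeholders:

    from google.appengine.ext import db
    from google.appengine.tools import bulkloader

    class Greeting(db.Model):
      # Placeholder model: Mapper.__init__ resolves its kind to a db.Model
      # class through GetImplementationClass, so one must be importable.
      content = db.StringProperty()

    class GreetingMapper(bulkloader.Mapper):
      def __init__(self):
        bulkloader.Mapper.__init__(self, 'Greeting')

      def apply(self, entity):
        # Called once per downloaded entity; raise bulkloader.MapperRetry
        # to signal a transient failure that should be retried.
        print 'Greeting %s: %s' % (entity.key().id_or_name(),
                                   entity['content'])

    mappers = [GreetingMapper]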
+
class QueueJoinThread(threading.Thread):
"""A thread that joins a queue and exits.
@@ -3492,7 +2846,7 @@
def InterruptibleQueueJoin(queue,
thread_local,
- thread_gate,
+ thread_pool,
queue_join_thread_factory=QueueJoinThread,
check_workers=True):
"""Repeatedly joins the given ReQueue or Queue.Queue with short timeout.
@@ -3502,7 +2856,7 @@
Args:
queue: A Queue.Queue or ReQueue instance.
thread_local: A threading.local instance which indicates interrupts.
- thread_gate: A ThreadGate instance.
+ thread_pool: An AdaptiveThreadPool instance.
queue_join_thread_factory: Used for dependency injection.
check_workers: Whether to interrupt the join on worker death.
@@ -3519,41 +2873,29 @@
logger.debug('Queue join interrupted')
return False
if check_workers:
- for worker_thread in thread_gate.Threads():
+ for worker_thread in thread_pool.Threads():
if not worker_thread.isAlive():
return False
-def ShutdownThreads(data_source_thread, work_queue, thread_gate):
+def ShutdownThreads(data_source_thread, thread_pool):
"""Shuts down the worker and data source threads.
Args:
data_source_thread: A running DataSourceThread instance.
- work_queue: The work queue.
- thread_gate: A ThreadGate instance with workers registered.
+ thread_pool: An AdaptiveThreadPool instance with workers registered.
"""
logger.info('An error occurred. Shutting down...')
data_source_thread.exit_flag = True
- for thread in thread_gate.Threads():
- thread.exit_flag = True
-
- for unused_thread in thread_gate.Threads():
- thread_gate.EnableThread()
+ thread_pool.Shutdown()
data_source_thread.join(timeout=3.0)
if data_source_thread.isAlive():
logger.warn('%s hung while trying to exit',
data_source_thread.GetFriendlyName())
- while not work_queue.empty():
- try:
- unused_item = work_queue.get_nowait()
- work_queue.task_done()
- except Queue.Empty:
- pass
-
class BulkTransporterApp(object):
"""Class to wrap bulk transport application functionality."""
@@ -3563,13 +2905,12 @@
input_generator_factory,
throttle,
progress_db,
- workerthread_factory,
progresstrackerthread_factory,
max_queue_size=DEFAULT_QUEUE_SIZE,
request_manager_factory=RequestManager,
datasourcethread_factory=DataSourceThread,
- work_queue_factory=ReQueue,
- progress_queue_factory=Queue.Queue):
+ progress_queue_factory=Queue.Queue,
+ thread_pool_factory=adaptive_thread_pool.AdaptiveThreadPool):
"""Instantiate a BulkTransporterApp.
Uploads or downloads data to or from application using HTTP requests.
@@ -3584,13 +2925,12 @@
input_generator_factory: A factory that creates a WorkItem generator.
throttle: A Throttle instance.
progress_db: The database to use for replaying/recording progress.
- workerthread_factory: A factory for worker threads.
progresstrackerthread_factory: Used for dependency injection.
max_queue_size: Maximum size of the queues before they should block.
request_manager_factory: Used for dependency injection.
datasourcethread_factory: Used for dependency injection.
- work_queue_factory: Used for dependency injection.
progress_queue_factory: Used for dependency injection.
+ thread_pool_factory: Used for dependency injection.
"""
self.app_id = arg_dict['app_id']
self.post_url = arg_dict['url']
@@ -3600,15 +2940,15 @@
self.num_threads = arg_dict['num_threads']
self.email = arg_dict['email']
self.passin = arg_dict['passin']
+ self.dry_run = arg_dict['dry_run']
self.throttle = throttle
self.progress_db = progress_db
- self.workerthread_factory = workerthread_factory
self.progresstrackerthread_factory = progresstrackerthread_factory
self.max_queue_size = max_queue_size
self.request_manager_factory = request_manager_factory
self.datasourcethread_factory = datasourcethread_factory
- self.work_queue_factory = work_queue_factory
self.progress_queue_factory = progress_queue_factory
+ self.thread_pool_factory = thread_pool_factory
(scheme,
self.host_port, self.url_path,
unused_query, unused_fragment) = urlparse.urlsplit(self.post_url)
@@ -3623,13 +2963,13 @@
Returns:
Error code suitable for sys.exit, e.g. 0 on success, 1 on failure.
"""
- thread_gate = ThreadGate(True)
+ self.error = False
+ thread_pool = self.thread_pool_factory(
+ self.num_threads, queue_size=self.max_queue_size)
self.throttle.Register(threading.currentThread())
threading.currentThread().exit_flag = False
- work_queue = self.work_queue_factory(self.max_queue_size)
-
progress_queue = self.progress_queue_factory(self.max_queue_size)
request_manager = self.request_manager_factory(self.app_id,
self.host_port,
@@ -3639,27 +2979,23 @@
self.batch_size,
self.secure,
self.email,
- self.passin)
+ self.passin,
+ self.dry_run)
try:
request_manager.Authenticate()
except Exception, e:
+ self.error = True
if not isinstance(e, urllib2.HTTPError) or (
e.code != 302 and e.code != 401):
logger.exception('Exception during authentication')
raise AuthenticationError()
if (request_manager.auth_called and
not request_manager.authenticated):
+ self.error = True
raise AuthenticationError('Authentication failed')
- for unused_idx in xrange(self.num_threads):
- thread = self.workerthread_factory(work_queue,
- self.throttle,
- thread_gate,
- request_manager,
- self.num_threads,
- self.batch_size)
+ for thread in thread_pool.Threads():
self.throttle.Register(thread)
- thread_gate.Register(thread)
self.progress_thread = self.progresstrackerthread_factory(
progress_queue, self.progress_db)
@@ -3671,7 +3007,8 @@
progress_generator_factory = None
self.data_source_thread = (
- self.datasourcethread_factory(work_queue,
+ self.datasourcethread_factory(request_manager,
+ thread_pool,
progress_queue,
self.input_generator_factory,
progress_generator_factory))
@@ -3682,60 +3019,54 @@
def Interrupt(unused_signum, unused_frame):
"""Shutdown gracefully in response to a signal."""
thread_local.shut_down = True
+ self.error = True
signal.signal(signal.SIGINT, Interrupt)
self.progress_thread.start()
self.data_source_thread.start()
- for thread in thread_gate.Threads():
- thread.start()
while not thread_local.shut_down:
self.data_source_thread.join(timeout=0.25)
if self.data_source_thread.isAlive():
- for thread in list(thread_gate.Threads()) + [self.progress_thread]:
+ for thread in list(thread_pool.Threads()) + [self.progress_thread]:
if not thread.isAlive():
logger.info('Unexpected thread death: %s', thread.getName())
thread_local.shut_down = True
+ self.error = True
break
else:
break
- if thread_local.shut_down:
- ShutdownThreads(self.data_source_thread, work_queue, thread_gate)
-
def _Join(ob, msg):
logger.debug('Waiting for %s...', msg)
if isinstance(ob, threading.Thread):
ob.join(timeout=3.0)
if ob.isAlive():
- logger.debug('Joining %s failed', ob.GetFriendlyName())
+ logger.debug('Joining %s failed', ob)
else:
logger.debug('... done.')
elif isinstance(ob, (Queue.Queue, ReQueue)):
- if not InterruptibleQueueJoin(ob, thread_local, thread_gate):
- ShutdownThreads(self.data_source_thread, work_queue, thread_gate)
+ if not InterruptibleQueueJoin(ob, thread_local, thread_pool):
+ ShutdownThreads(self.data_source_thread, thread_pool)
else:
ob.join()
logger.debug('... done.')
- _Join(work_queue, 'work_queue to flush')
-
- for unused_thread in thread_gate.Threads():
- work_queue.put(_THREAD_SHOULD_EXIT)
-
- for unused_thread in thread_gate.Threads():
- thread_gate.EnableThread()
-
- for thread in thread_gate.Threads():
- _Join(thread, 'thread [%s] to terminate' % thread.getName())
-
- thread.CheckError()
+ if self.data_source_thread.error or thread_local.shut_down:
+ ShutdownThreads(self.data_source_thread, thread_pool)
+ else:
+ _Join(thread_pool.requeue, 'worker threads to finish')
+
+ thread_pool.Shutdown()
+ thread_pool.JoinThreads()
+ thread_pool.CheckErrors()
+ print ''
if self.progress_thread.isAlive():
- InterruptibleQueueJoin(progress_queue, thread_local, thread_gate,
+ InterruptibleQueueJoin(progress_queue, thread_local, thread_pool,
check_workers=False)
else:
logger.warn('Progress thread exited prematurely')
@@ -3763,9 +3094,10 @@
def ReportStatus(self):
"""Display a message reporting the final status of the transfer."""
- total_up, duration = self.throttle.TotalTransferred(BANDWIDTH_UP)
+ total_up, duration = self.throttle.TotalTransferred(
+ remote_api_throttle.BANDWIDTH_UP)
s_total_up, unused_duration = self.throttle.TotalTransferred(
- HTTPS_BANDWIDTH_UP)
+ remote_api_throttle.HTTPS_BANDWIDTH_UP)
total_up += s_total_up
total = total_up
- logger.info('%d entites total, %d previously transferred',
+ logger.info('%d entities total, %d previously transferred',
@@ -3793,18 +3125,49 @@
def ReportStatus(self):
"""Display a message reporting the final status of the transfer."""
- total_down, duration = self.throttle.TotalTransferred(BANDWIDTH_DOWN)
+ total_down, duration = self.throttle.TotalTransferred(
+ remote_api_throttle.BANDWIDTH_DOWN)
s_total_down, unused_duration = self.throttle.TotalTransferred(
- HTTPS_BANDWIDTH_DOWN)
+ remote_api_throttle.HTTPS_BANDWIDTH_DOWN)
total_down += s_total_down
total = total_down
existing_count = self.progress_thread.existing_count
xfer_count = self.progress_thread.EntitiesTransferred()
logger.info('Have %d entities, %d previously transferred',
- xfer_count + existing_count, existing_count)
+ xfer_count, existing_count)
logger.info('%d entities (%d bytes) transferred in %.1f seconds',
xfer_count, total, duration)
- return 0
+ if self.error:
+ return 1
+ else:
+ return 0
+
+
+class BulkMapperApp(BulkTransporterApp):
+ """Class to encapsulate bulk map functionality."""
+
+ def __init__(self, *args, **kwargs):
+ BulkTransporterApp.__init__(self, *args, **kwargs)
+
+ def ReportStatus(self):
+ """Display a message reporting the final status of the transfer."""
+ total_down, duration = self.throttle.TotalTransferred(
+ remote_api_throttle.BANDWIDTH_DOWN)
+ s_total_down, unused_duration = self.throttle.TotalTransferred(
+ remote_api_throttle.HTTPS_BANDWIDTH_DOWN)
+ total_down += s_total_down
+ total = total_down
+ xfer_count = self.progress_thread.EntitiesTransferred()
+ logger.info('The following may be inaccurate if any mapper tasks '
+ 'encountered errors and had to be retried.')
+ logger.info('Applied mapper to %s entities.',
+ xfer_count)
+ logger.info('%s entities (%s bytes) transferred in %.1f seconds',
+ xfer_count, total, duration)
+ if self.error:
+ return 1
+ else:
+ return 0
def PrintUsageExit(code):
@@ -3843,18 +3206,24 @@
'loader_opts=',
'exporter_opts=',
'log_file=',
+ 'mapper_opts=',
'email=',
'passin',
+ 'map',
+ 'dry_run',
+ 'dump',
+ 'restore',
]
-def ParseArguments(argv):
+def ParseArguments(argv, die_fn=lambda: PrintUsageExit(1)):
"""Parses command-line arguments.
Prints out a help message if -h or --help is supplied.
Args:
argv: List of command-line arguments.
+ die_fn: Function to invoke to end the program.
Returns:
A dictionary containing the value of command-line options.
@@ -3867,11 +3236,11 @@
arg_dict = {}
arg_dict['url'] = REQUIRED_OPTION
- arg_dict['filename'] = REQUIRED_OPTION
- arg_dict['config_file'] = REQUIRED_OPTION
- arg_dict['kind'] = REQUIRED_OPTION
-
- arg_dict['batch_size'] = DEFAULT_BATCH_SIZE
+ arg_dict['filename'] = None
+ arg_dict['config_file'] = None
+ arg_dict['kind'] = None
+
+ arg_dict['batch_size'] = None
arg_dict['num_threads'] = DEFAULT_THREAD_COUNT
arg_dict['bandwidth_limit'] = DEFAULT_BANDWIDTH_LIMIT
arg_dict['rps_limit'] = DEFAULT_RPS_LIMIT
@@ -3889,6 +3258,11 @@
arg_dict['log_file'] = None
arg_dict['email'] = None
arg_dict['passin'] = False
+ arg_dict['mapper_opts'] = None
+ arg_dict['map'] = False
+ arg_dict['dry_run'] = False
+ arg_dict['dump'] = False
+ arg_dict['restore'] = False
def ExpandFilename(filename):
"""Expand shell variables and ~usernames in filename."""
@@ -3938,26 +3312,39 @@
elif option == '--exporter_opts':
arg_dict['exporter_opts'] = value
elif option == '--log_file':
- arg_dict['log_file'] = value
+ arg_dict['log_file'] = ExpandFilename(value)
elif option == '--email':
arg_dict['email'] = value
elif option == '--passin':
arg_dict['passin'] = True
-
- return ProcessArguments(arg_dict, die_fn=lambda: PrintUsageExit(1))
+ elif option == '--map':
+ arg_dict['map'] = True
+ elif option == '--mapper_opts':
+ arg_dict['mapper_opts'] = value
+ elif option == '--dry_run':
+ arg_dict['dry_run'] = True
+ elif option == '--dump':
+ arg_dict['dump'] = True
+ elif option == '--restore':
+ arg_dict['restore'] = True
+
+ return ProcessArguments(arg_dict, die_fn=die_fn)
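Taken together, the new flags support three flows; a hedged sketch of invocations, where the app id, URL, and filenames are placeholders:

    bulkloader.py --dump --app_id=example-app --filename=all.sql3 \
        --url=http://example-app.appspot.com/remote_api
    bulkloader.py --restore --app_id=example-app --filename=all.sql3 \
        --url=http://example-app.appspot.com/remote_api
    bulkloader.py --map --kind=Greeting --config_file=config.py \
        --app_id=example-app --url=http://example-app.appspot.com/remote_api

--dump and --restore need no --config_file, and --map needs no --filename; ProcessArguments below enforces exactly those combinations.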
def ThrottleLayout(bandwidth_limit, http_limit, rps_limit):
"""Return a dictionary indicating the throttle options."""
- return {
- BANDWIDTH_UP: bandwidth_limit,
- BANDWIDTH_DOWN: bandwidth_limit,
- REQUESTS: http_limit,
- HTTPS_BANDWIDTH_UP: bandwidth_limit / 5,
- HTTPS_BANDWIDTH_DOWN: bandwidth_limit / 5,
- HTTPS_REQUESTS: http_limit / 5,
- RECORDS: rps_limit,
- }
+ bulkloader_limits = dict(remote_api_throttle.NO_LIMITS)
+ bulkloader_limits.update({
+ remote_api_throttle.BANDWIDTH_UP: bandwidth_limit,
+ remote_api_throttle.BANDWIDTH_DOWN: bandwidth_limit,
+ remote_api_throttle.REQUESTS: http_limit,
+ remote_api_throttle.HTTPS_BANDWIDTH_UP: bandwidth_limit,
+ remote_api_throttle.HTTPS_BANDWIDTH_DOWN: bandwidth_limit,
+ remote_api_throttle.HTTPS_REQUESTS: http_limit,
+ remote_api_throttle.ENTITIES_FETCHED: rps_limit,
+ remote_api_throttle.ENTITIES_MODIFIED: rps_limit,
+ })
+ return bulkloader_limits
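A minimal sketch of how the layout feeds the shared throttle, mirroring what the main flow does below with the module defaults:

    layout = ThrottleLayout(DEFAULT_BANDWIDTH_LIMIT,
                            DEFAULT_REQUEST_LIMIT,
                            DEFAULT_RPS_LIMIT)
    throttle = remote_api_throttle.Throttle(layout=layout)
    # Every transferring thread registers itself before making requests:
    throttle.Register(threading.currentThread())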
def CheckOutputFile(filename):
@@ -3969,12 +3356,13 @@
Raises:
- FileExistsError: if the given filename is not found
- FileNotWritableError: if the given filename is not readable.
+ FileExistsError: if the given filename already exists.
+ FileNotWritableError: if the directory of the given filename is not writable.
- """
- if os.path.exists(filename):
+ """
+ full_path = os.path.abspath(filename)
+ if os.path.exists(full_path):
raise FileExistsError('%s: output file exists' % filename)
- elif not os.access(os.path.dirname(filename), os.W_OK):
+ elif not os.access(os.path.dirname(full_path), os.W_OK):
raise FileNotWritableError(
- '%s: not writable' % os.path.dirname(filename))
+ '%s: not writable' % os.path.dirname(full_path))
def LoadConfig(config_file_name, exit_fn=sys.exit):
@@ -3999,6 +3387,11 @@
if hasattr(bulkloader_config, 'exporters'):
for cls in bulkloader_config.exporters:
Exporter.RegisterExporter(cls())
+
+ if hasattr(bulkloader_config, 'mappers'):
+ for cls in bulkloader_config.mappers:
+ Mapper.RegisterMapper(cls())
+
except NameError, e:
m = re.search(r"[^']*'([^']*)'.*", str(e))
if m.groups() and m.group(1) == 'Loader':
@@ -4058,9 +3451,12 @@
url=None,
kind=None,
db_filename=None,
+ perform_map=None,
download=None,
has_header=None,
- result_db_filename=None):
+ result_db_filename=None,
+ dump=None,
+ restore=None):
"""Returns a string that identifies the important options for the database."""
if download:
result_db_line = 'result_db: %s' % result_db_filename
@@ -4071,10 +3467,14 @@
url: %s
kind: %s
download: %s
+ map: %s
+ dump: %s
+ restore: %s
progress_db: %s
has_header: %s
%s
- """ % (app_id, url, kind, download, db_filename, has_header, result_db_line)
+ """ % (app_id, url, kind, download, perform_map, dump, restore, db_filename,
+ has_header, result_db_line)
def ProcessArguments(arg_dict,
@@ -4090,6 +3490,8 @@
"""
app_id = GetArgument(arg_dict, 'app_id', die_fn)
url = GetArgument(arg_dict, 'url', die_fn)
+ dump = GetArgument(arg_dict, 'dump', die_fn)
+ restore = GetArgument(arg_dict, 'restore', die_fn)
filename = GetArgument(arg_dict, 'filename', die_fn)
batch_size = GetArgument(arg_dict, 'batch_size', die_fn)
kind = GetArgument(arg_dict, 'kind', die_fn)
@@ -4098,21 +3500,18 @@
result_db_filename = GetArgument(arg_dict, 'result_db_filename', die_fn)
download = GetArgument(arg_dict, 'download', die_fn)
log_file = GetArgument(arg_dict, 'log_file', die_fn)
-
- unused_passin = GetArgument(arg_dict, 'passin', die_fn)
- unused_email = GetArgument(arg_dict, 'email', die_fn)
- unused_debug = GetArgument(arg_dict, 'debug', die_fn)
- unused_num_threads = GetArgument(arg_dict, 'num_threads', die_fn)
- unused_bandwidth_limit = GetArgument(arg_dict, 'bandwidth_limit', die_fn)
- unused_rps_limit = GetArgument(arg_dict, 'rps_limit', die_fn)
- unused_http_limit = GetArgument(arg_dict, 'http_limit', die_fn)
- unused_auth_domain = GetArgument(arg_dict, 'auth_domain', die_fn)
- unused_has_headers = GetArgument(arg_dict, 'has_header', die_fn)
- unused_loader_opts = GetArgument(arg_dict, 'loader_opts', die_fn)
- unused_exporter_opts = GetArgument(arg_dict, 'exporter_opts', die_fn)
+ perform_map = GetArgument(arg_dict, 'map', die_fn)
errors = []
+ if batch_size is None:
+ if download or perform_map:
+ arg_dict['batch_size'] = DEFAULT_DOWNLOAD_BATCH_SIZE
+ else:
+ arg_dict['batch_size'] = DEFAULT_BATCH_SIZE
+ elif batch_size <= 0:
+ errors.append('batch_size must be at least 1')
+
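The defaulting above, sketched as a standalone rule (EffectiveBatchSize is a hypothetical name; the constants are from the top of this file):

    def EffectiveBatchSize(batch_size, download, perform_map):
      # Reads (downloads and maps) default to larger batches than uploads.
      if batch_size is not None:
        return batch_size
      if download or perform_map:
        return DEFAULT_DOWNLOAD_BATCH_SIZE
      return DEFAULT_BATCH_SIZE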
if db_filename is None:
arg_dict['db_filename'] = time.strftime(
'bulkloader-progress-%Y%m%d.%H%M%S.sql3')
@@ -4124,37 +3523,35 @@
if log_file is None:
arg_dict['log_file'] = time.strftime('bulkloader-log-%Y%m%d.%H%M%S')
- if batch_size <= 0:
- errors.append('batch_size must be at least 1')
-
required = '%s argument required'
+ if config_file is None and not dump and not restore:
+ errors.append('One of --config_file, --dump, or --restore is required')
+
if url is REQUIRED_OPTION:
errors.append(required % 'url')
- if filename is REQUIRED_OPTION:
+ if not filename and not perform_map:
errors.append(required % 'filename')
- if kind is REQUIRED_OPTION:
- errors.append(required % 'kind')
-
- if config_file is REQUIRED_OPTION:
- errors.append(required % 'config_file')
-
- if download:
- if result_db_filename is REQUIRED_OPTION:
- errors.append(required % 'result_db_filename')
+ if kind is None:
+ if download or perform_map:
+ errors.append('kind argument required for this operation')
+ elif not dump and not restore:
+ errors.append(
+ 'kind argument required unless --dump or --restore is specified')
if not app_id:
- (unused_scheme, host_port, unused_url_path,
- unused_query, unused_fragment) = urlparse.urlsplit(url)
- suffix_idx = host_port.find('.appspot.com')
- if suffix_idx > -1:
- arg_dict['app_id'] = host_port[:suffix_idx]
- elif host_port.split(':')[0].endswith('google.com'):
- arg_dict['app_id'] = host_port.split('.')[0]
- else:
- errors.append('app_id argument required for non appspot.com domains')
+ if url and url is not REQUIRED_OPTION:
+ (unused_scheme, host_port, unused_url_path,
+ unused_query, unused_fragment) = urlparse.urlsplit(url)
+ suffix_idx = host_port.find('.appspot.com')
+ if suffix_idx > -1:
+ arg_dict['app_id'] = host_port[:suffix_idx]
+ elif host_port.split(':')[0].endswith('google.com'):
+ arg_dict['app_id'] = host_port.split('.')[0]
+ else:
+ errors.append('app_id argument required for non appspot.com domains')
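The fallback above, sketched as a standalone helper (DeriveAppId is a hypothetical name; hostnames are placeholders):

    import urlparse

    def DeriveAppId(url):
      host_port = urlparse.urlsplit(url)[1]
      suffix_idx = host_port.find('.appspot.com')
      if suffix_idx > -1:
        return host_port[:suffix_idx]
      if host_port.split(':')[0].endswith('google.com'):
        return host_port.split('.')[0]
      return None  # caller must pass --app_id explicitly

    assert DeriveAppId('http://example-app.appspot.com/x') == 'example-app'
    assert DeriveAppId('http://localhost:8080/x') is None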
if errors:
print >>sys.stderr, '\n'.join(errors)
@@ -4203,50 +3600,68 @@
result_db_filename = arg_dict['result_db_filename']
loader_opts = arg_dict['loader_opts']
exporter_opts = arg_dict['exporter_opts']
+ mapper_opts = arg_dict['mapper_opts']
email = arg_dict['email']
passin = arg_dict['passin']
+ perform_map = arg_dict['map']
+ dump = arg_dict['dump']
+ restore = arg_dict['restore']
os.environ['AUTH_DOMAIN'] = auth_domain
kind = ParseKind(kind)
- check_file(config_file)
- if not download:
+ if not dump and not restore:
+ check_file(config_file)
+
+ if download and perform_map:
+ logger.error('--download and --map are mutually exclusive.')
+ return 1
+
+ if download or dump:
+ check_output_file(filename)
+ elif not perform_map:
check_file(filename)
+
+ if dump:
+ Exporter.RegisterExporter(DumpExporter(kind, result_db_filename))
+ elif restore:
+ Loader.RegisterLoader(RestoreLoader(kind))
else:
- check_output_file(filename)
-
- LoadConfig(config_file)
+ LoadConfig(config_file)
os.environ['APPLICATION_ID'] = app_id
throttle_layout = ThrottleLayout(bandwidth_limit, http_limit, rps_limit)
-
- throttle = Throttle(layout=throttle_layout)
+ logger.info('Throttling transfers:')
+ logger.info('Bandwidth: %s bytes/second', bandwidth_limit)
+ logger.info('HTTP connections: %s/second', http_limit)
+ logger.info('Entities inserted/fetched/modified: %s/second', rps_limit)
+
+ throttle = remote_api_throttle.Throttle(layout=throttle_layout)
signature = _MakeSignature(app_id=app_id,
url=url,
kind=kind,
db_filename=db_filename,
download=download,
+ perform_map=perform_map,
has_header=has_header,
- result_db_filename=result_db_filename)
+ result_db_filename=result_db_filename,
+ dump=dump,
+ restore=restore)
max_queue_size = max(DEFAULT_QUEUE_SIZE, 3 * num_threads + 5)
if db_filename == 'skip':
progress_db = StubProgressDatabase()
- elif not download:
+ elif not download and not perform_map and not dump:
progress_db = ProgressDatabase(db_filename, signature)
else:
progress_db = ExportProgressDatabase(db_filename, signature)
- if download:
- result_db = ResultDatabase(result_db_filename, signature)
-
return_code = 1
- if not download:
+ if not download and not perform_map and not dump:
loader = Loader.RegisteredLoader(kind)
try:
loader.initialize(filename, loader_opts)
@@ -4257,12 +3672,10 @@
workitem_generator_factory,
throttle,
progress_db,
- BulkLoaderThread,
ProgressTrackerThread,
max_queue_size,
RequestManager,
DataSourceThread,
- ReQueue,
Queue.Queue)
try:
return_code = app.Run()
@@ -4270,29 +3683,31 @@
logger.info('Authentication Failed')
finally:
loader.finalize()
- else:
+ elif not perform_map:
+ result_db = ResultDatabase(result_db_filename, signature)
exporter = Exporter.RegisteredExporter(kind)
try:
exporter.initialize(filename, exporter_opts)
- def KeyRangeGeneratorFactory(progress_queue, progress_gen):
- return KeyRangeGenerator(kind, progress_queue, progress_gen)
+ def KeyRangeGeneratorFactory(request_manager, progress_queue,
+ progress_gen):
+ return KeyRangeItemGenerator(request_manager, kind, progress_queue,
+ progress_gen, DownloadItem)
def ExportProgressThreadFactory(progress_queue, progress_db):
return ExportProgressThread(kind,
progress_queue,
progress_db,
result_db)
+
app = BulkDownloaderApp(arg_dict,
KeyRangeGeneratorFactory,
throttle,
progress_db,
- BulkExporterThread,
ExportProgressThreadFactory,
0,
RequestManager,
DataSourceThread,
- ReQueue,
Queue.Queue)
try:
return_code = app.Run()
@@ -4300,6 +3715,35 @@
logger.info('Authentication Failed')
finally:
exporter.finalize()
+ elif not download:
+ mapper = Mapper.RegisteredMapper(kind)
+ try:
+ mapper.initialize(mapper_opts)
+ def KeyRangeGeneratorFactory(request_manager, progress_queue,
+ progress_gen):
+ return KeyRangeItemGenerator(request_manager, kind, progress_queue,
+ progress_gen, MapperItem)
+
+ def MapperProgressThreadFactory(progress_queue, progress_db):
+ return MapperProgressThread(kind,
+ progress_queue,
+ progress_db)
+
+ app = BulkMapperApp(arg_dict,
+ KeyRangeGeneratorFactory,
+ throttle,
+ progress_db,
+ MapperProgressThreadFactory,
+ 0,
+ RequestManager,
+ DataSourceThread,
+ Queue.Queue)
+ try:
+ return_code = app.Run()
+ except AuthenticationError:
+ logger.info('Authentication Failed')
+ finally:
+ mapper.finalize()
return return_code
@@ -4335,8 +3779,17 @@
logger.info('Logging to %s', log_file)
+ remote_api_throttle.logger.setLevel(level)
+ remote_api_throttle.logger.addHandler(file_handler)
+ remote_api_throttle.logger.addHandler(console)
+
appengine_rpc.logger.setLevel(logging.WARN)
+ adaptive_thread_pool.logger.setLevel(logging.DEBUG)
+ adaptive_thread_pool.logger.addHandler(console)
+ adaptive_thread_pool.logger.addHandler(file_handler)
+ adaptive_thread_pool.logger.propagate = False
+
def Run(arg_dict):
"""Sets up and runs the bulkloader, given the options as keyword arguments.