--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/thirdparty/google_appengine/google/appengine/tools/bulkloader.py Thu Feb 12 12:30:36 2009 +0000
@@ -0,0 +1,2588 @@
+#!/usr/bin/env python
+#
+# Copyright 2007 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Imports CSV data over HTTP.
+
+Usage:
+ %(arg0)s [flags]
+
+ --debug Show debugging information. (Optional)
+ --app_id=<string> Application ID of endpoint (Optional for
+ *.appspot.com)
+ --auth_domain=<domain> The auth domain to use for logging in and for
+ UserProperties. (Default: gmail.com)
+  --bandwidth_limit=<int> The maximum number of bytes per second for the
+                          aggregate transfer of data to the server. Bursts
+                          may exceed this, but overall transfer rate is
+                          restricted to this rate. (Default 250000)
+ --batch_size=<int> Number of Entity objects to include in each post to
+ the URL endpoint. The more data per row/Entity, the
+ smaller the batch size should be. (Default 10)
+ --config_file=<path> File containing Model and Loader definitions.
+ (Required)
+ --db_filename=<path> Specific progress database to write to, or to
+ resume from. If not supplied, then a new database
+ will be started, named:
+ bulkloader-progress-TIMESTAMP.
+ The special filename "skip" may be used to simply
+ skip reading/writing any progress information.
+ --filename=<path> Path to the CSV file to import. (Required)
+  --http_limit=<int>      The maximum number of HTTP requests per second to
+ send to the server. (Default: 8)
+ --kind=<string> Name of the Entity object kind to put in the
+ datastore. (Required)
+  --num_threads=<int>     Number of threads to use for uploading entities
+                          (Default 10)
+ --rps_limit=<int> The maximum number of records per second to
+ transfer to the server. (Default: 20)
+ --url=<string> URL endpoint to post to for importing data.
+ (Required)
+
+The exit status will be 0 on success, non-zero on import failure.
+
+Works with the remote_api mix-in library for google.appengine.ext.remote_api.
+Please look there for documentation about how to set up the server side.
+
+Example:
+
+%(arg0)s --url=http://app.appspot.com/remote_api --kind=Model \
+ --filename=data.csv --config_file=loader_config.py
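+
+A loader_config.py for this example might look like the following sketch
+(the Person model and its single column are illustrative assumptions):
+
+  from google.appengine.ext import db
+  from google.appengine.tools import bulkloader
+
+  class Person(db.Model):
+    name = db.StringProperty()
+
+  class PersonLoader(bulkloader.Loader):
+    def __init__(self):
+      bulkloader.Loader.__init__(self, 'Person', [('name', str)])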
+
+"""
+
+
+
+import csv
+import getopt
+import getpass
+import logging
+import new
+import os
+import Queue
+import signal
+import sys
+import threading
+import time
+import traceback
+import urllib2
+import urlparse
+
+from google.appengine.ext import db
+from google.appengine.ext.remote_api import remote_api_stub
+from google.appengine.tools import appengine_rpc
+
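+# sqlite3 first shipped with Python 2.5, and is only needed when progress
+# is actually being recorded, so a missing module is tolerated here.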
+try:
+ import sqlite3
+except ImportError:
+ pass
+
+UPLOADER_VERSION = '1'
+
+DEFAULT_THREAD_COUNT = 10
+
+DEFAULT_BATCH_SIZE = 10
+
+DEFAULT_QUEUE_SIZE = DEFAULT_THREAD_COUNT * 10
+
+_THREAD_SHOULD_EXIT = '_THREAD_SHOULD_EXIT'
+
+STATE_READ = 0
+STATE_SENDING = 1
+STATE_SENT = 2
+STATE_NOT_SENT = 3
+
+MINIMUM_THROTTLE_SLEEP_DURATION = 0.001
+
+DATA_CONSUMED_TO_HERE = 'DATA_CONSUMED_TO_HERE'
+
+INITIAL_BACKOFF = 1.0
+
+BACKOFF_FACTOR = 2.0
+
+
+DEFAULT_BANDWIDTH_LIMIT = 250000
+
+DEFAULT_RPS_LIMIT = 20
+
+DEFAULT_REQUEST_LIMIT = 8
+
+BANDWIDTH_UP = 'http-bandwidth-up'
+BANDWIDTH_DOWN = 'http-bandwidth-down'
+REQUESTS = 'http-requests'
+HTTPS_BANDWIDTH_UP = 'https-bandwidth-up'
+HTTPS_BANDWIDTH_DOWN = 'https-bandwidth-down'
+HTTPS_REQUESTS = 'https-requests'
+RECORDS = 'records'
+
+
+def StateMessage(state):
+ """Converts a numeric state identifier to a status message."""
+ return ({
+ STATE_READ: 'Batch read from file.',
+ STATE_SENDING: 'Sending batch to server.',
+ STATE_SENT: 'Batch successfully sent.',
+ STATE_NOT_SENT: 'Error while sending batch.'
+ }[state])
+
+
+class Error(Exception):
+ """Base-class for exceptions in this module."""
+
+
+class FatalServerError(Error):
+ """An unrecoverable error occurred while trying to post data to the server."""
+
+
+class ResumeError(Error):
+ """Error while trying to resume a partial upload."""
+
+
+class ConfigurationError(Error):
+ """Error in configuration options."""
+
+
+class AuthenticationError(Error):
+ """Error while trying to authenticate with the server."""
+
+
+def GetCSVGeneratorFactory(csv_filename, batch_size,
+ openfile=open, create_csv_reader=csv.reader):
+ """Return a factory that creates a CSV-based WorkItem generator.
+
+ Args:
+ csv_filename: File on disk containing CSV data.
+ batch_size: Maximum number of CSV rows to stash into a WorkItem.
+ openfile: Used for dependency injection.
+ create_csv_reader: Used for dependency injection.
+
+  Returns:
+    A callable (accepting the progress queue and progress generator as
+    input) which creates the WorkItem generator.
+ """
+
+ def CreateGenerator(progress_queue, progress_generator):
+ """Initialize a CSV generator linked to a progress generator and queue.
+
+ Args:
+ progress_queue: A ProgressQueue instance to send progress information.
+ progress_generator: A generator of progress information or None.
+
+ Returns:
+ A CSVGenerator instance.
+ """
+ return CSVGenerator(progress_queue,
+ progress_generator,
+ csv_filename,
+ batch_size,
+ openfile,
+ create_csv_reader)
+ return CreateGenerator
+
+
+class CSVGenerator(object):
+ """Reads a CSV file and generates WorkItems containing batches of records."""
+
+ def __init__(self,
+ progress_queue,
+ progress_generator,
+ csv_filename,
+ batch_size,
+ openfile,
+ create_csv_reader):
+ """Initializes a CSV generator.
+
+ Args:
+ progress_queue: A queue used for tracking progress information.
+ progress_generator: A generator of prior progress information, or None
+ if there is no prior status.
+ csv_filename: File on disk containing CSV data.
+ batch_size: Maximum number of CSV rows to stash into a WorkItem.
+ openfile: Used for dependency injection of 'open'.
+ create_csv_reader: Used for dependency injection of 'csv.reader'.
+ """
+ self.progress_queue = progress_queue
+ self.progress_generator = progress_generator
+ self.csv_filename = csv_filename
+ self.batch_size = batch_size
+ self.openfile = openfile
+ self.create_csv_reader = create_csv_reader
+ self.line_number = 1
+ self.column_count = None
+ self.read_rows = []
+ self.reader = None
+ self.row_count = 0
+ self.sent_count = 0
+
+ def _AdvanceTo(self, line):
+ """Advance the reader to the given line.
+
+ Args:
+ line: A line number to advance to.
+ """
+ while self.line_number < line:
+ self.reader.next()
+ self.line_number += 1
+ self.row_count += 1
+ self.sent_count += 1
+
+ def _ReadRows(self, key_start, key_end):
+ """Attempts to read and encode rows [key_start, key_end].
+
+ The encoded rows are stored in self.read_rows.
+
+ Args:
+ key_start: The starting line number.
+ key_end: The ending line number.
+
+ Raises:
+ StopIteration: if the reader runs out of rows
+ ResumeError: if there are an inconsistent number of columns.
+ """
+ assert self.line_number == key_start
+ self.read_rows = []
+ while self.line_number <= key_end:
+ row = self.reader.next()
+ self.row_count += 1
+ if self.column_count is None:
+ self.column_count = len(row)
+ else:
+ if self.column_count != len(row):
+ raise ResumeError('Column count mismatch, %d: %s' %
+ (self.column_count, str(row)))
+ self.read_rows.append((self.line_number, row))
+ self.line_number += 1
+
+ def _MakeItem(self, key_start, key_end, rows, progress_key=None):
+ """Makes a WorkItem containing the given rows, with the given keys.
+
+ Args:
+ key_start: The start key for the WorkItem.
+ key_end: The end key for the WorkItem.
+ rows: A list of the rows for the WorkItem.
+ progress_key: The progress key for the WorkItem
+
+ Returns:
+ A WorkItem instance for the given batch.
+ """
+ assert rows
+
+ item = WorkItem(self.progress_queue, rows,
+ key_start, key_end,
+ progress_key=progress_key)
+
+ return item
+
+ def Batches(self):
+ """Reads the CSV data file and generates WorkItems.
+
+ Yields:
+ Instances of class WorkItem
+
+ Raises:
+ ResumeError: If the progress database and data file indicate a different
+ number of rows.
+ """
+ csv_file = self.openfile(self.csv_filename, 'r')
+ csv_content = csv_file.read()
+ if csv_content:
+ has_headers = csv.Sniffer().has_header(csv_content)
+ else:
+ has_headers = False
+ csv_file.seek(0)
+ self.reader = self.create_csv_reader(csv_file, skipinitialspace=True)
+ if has_headers:
+ logging.info('The CSV file appears to have a header line, skipping.')
+ self.reader.next()
+
+ exhausted = False
+
+ self.line_number = 1
+ self.column_count = None
+
+ logging.info('Starting import; maximum %d entities per post',
+ self.batch_size)
+
+ state = None
+ if self.progress_generator is not None:
+ for progress_key, state, key_start, key_end in self.progress_generator:
+ if key_start:
+ try:
+ self._AdvanceTo(key_start)
+ self._ReadRows(key_start, key_end)
+ yield self._MakeItem(key_start,
+ key_end,
+ self.read_rows,
+ progress_key=progress_key)
+ except StopIteration:
+ logging.error('Mismatch between data file and progress database')
+ raise ResumeError(
+ 'Mismatch between data file and progress database')
+ elif state == DATA_CONSUMED_TO_HERE:
+ try:
+ self._AdvanceTo(key_end + 1)
+ except StopIteration:
+ state = None
+
+ if self.progress_generator is None or state == DATA_CONSUMED_TO_HERE:
+ while not exhausted:
+ key_start = self.line_number
+ key_end = self.line_number + self.batch_size - 1
+ try:
+ self._ReadRows(key_start, key_end)
+ except StopIteration:
+ exhausted = True
+ key_end = self.line_number - 1
+ if key_start <= key_end:
+ yield self._MakeItem(key_start, key_end, self.read_rows)
+
+
+class ReQueue(object):
+ """A special thread-safe queue.
+
+ A ReQueue allows unfinished work items to be returned with a call to
+  reput(). When an item is reput, task_done() should *not* be called
+  in addition; getting an item that has been reput does not increase
+  the number of outstanding tasks.
+
+ This class shares an interface with Queue.Queue and provides the
+ additional Reput method.
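+
+  A minimal usage sketch (illustrative only):
+
+    rq = ReQueue(queue_capacity=10)
+    rq.put('work')
+    item = rq.get()
+    rq.reput(item)      # processing failed; requeue without task_done()
+    item = rq.get()     # reput items are returned before new items
+    rq.task_done()      # exactly one task_done() per original put()
+    rq.join()           # returns once every put() item is accounted for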
+ """
+
+ def __init__(self,
+ queue_capacity,
+ requeue_capacity=None,
+ queue_factory=Queue.Queue,
+ get_time=time.time):
+ """Initialize a ReQueue instance.
+
+ Args:
+ queue_capacity: The number of items that can be put in the ReQueue.
+      requeue_capacity: The number of items that can be reput in the ReQueue.
+ queue_factory: Used for dependency injection.
+ get_time: Used for dependency injection.
+ """
+ if requeue_capacity is None:
+ requeue_capacity = queue_capacity
+
+ self.get_time = get_time
+ self.queue = queue_factory(queue_capacity)
+ self.requeue = queue_factory(requeue_capacity)
+ self.lock = threading.Lock()
+ self.put_cond = threading.Condition(self.lock)
+ self.get_cond = threading.Condition(self.lock)
+
+ def _DoWithTimeout(self,
+ action,
+ exc,
+ wait_cond,
+ done_cond,
+ lock,
+ timeout=None,
+ block=True):
+ """Performs the given action with a timeout.
+
+ The action must be non-blocking, and raise an instance of exc on a
+ recoverable failure. If the action fails with an instance of exc,
+ we wait on wait_cond before trying again. Failure after the
+ timeout is reached is propagated as an exception. Success is
+ signalled by notifying on done_cond and returning the result of
+ the action. If action raises any exception besides an instance of
+ exc, it is immediately propagated.
+
+ Args:
+ action: A callable that performs a non-blocking action.
+ exc: An exception type that is thrown by the action to indicate
+ a recoverable error.
+ wait_cond: A condition variable which should be waited on when
+ action throws exc.
+ done_cond: A condition variable to signal if the action returns.
+ lock: The lock used by wait_cond and done_cond.
+ timeout: A non-negative float indicating the maximum time to wait.
+ block: Whether to block if the action cannot complete immediately.
+
+ Returns:
+ The result of the action, if it is successful.
+
+ Raises:
+ ValueError: If the timeout argument is negative.
+ """
+ if timeout is not None and timeout < 0.0:
+ raise ValueError('\'timeout\' must not be a negative number')
+ if not block:
+ timeout = 0.0
+ result = None
+ success = False
+ start_time = self.get_time()
+ lock.acquire()
+ try:
+ while not success:
+ try:
+ result = action()
+ success = True
+ except Exception, e:
+ if not isinstance(e, exc):
+ raise e
+ if timeout is not None:
+ elapsed_time = self.get_time() - start_time
+ timeout -= elapsed_time
+ if timeout <= 0.0:
+ raise e
+ wait_cond.wait(timeout)
+ finally:
+ if success:
+ done_cond.notify()
+ lock.release()
+ return result
+
+ def put(self, item, block=True, timeout=None):
+ """Put an item into the requeue.
+
+ Args:
+ item: An item to add to the requeue.
+ block: Whether to block if the requeue is full.
+      timeout: Maximum time to wait for the queue to become non-full.
+
+ Raises:
+ Queue.Full if the queue is full and the timeout expires.
+ """
+ def PutAction():
+ self.queue.put(item, block=False)
+ self._DoWithTimeout(PutAction,
+ Queue.Full,
+ self.get_cond,
+ self.put_cond,
+ self.lock,
+ timeout=timeout,
+ block=block)
+
+ def reput(self, item, block=True, timeout=None):
+ """Re-put an item back into the requeue.
+
+ Re-putting an item does not increase the number of outstanding
+ tasks, so the reput item should be uniquely associated with an
+ item that was previously removed from the requeue and for which
+ task_done has not been called.
+
+ Args:
+ item: An item to add to the requeue.
+ block: Whether to block if the requeue is full.
+      timeout: Maximum time to wait for the queue to become non-full.
+
+ Raises:
+      Queue.Full if the queue is full and the timeout expires.
+ """
+ def ReputAction():
+ self.requeue.put(item, block=False)
+ self._DoWithTimeout(ReputAction,
+ Queue.Full,
+ self.get_cond,
+ self.put_cond,
+ self.lock,
+ timeout=timeout,
+ block=block)
+
+ def get(self, block=True, timeout=None):
+ """Get an item from the requeue.
+
+ Args:
+ block: Whether to block if the requeue is empty.
+      timeout: Maximum time to wait for the requeue to become non-empty.
+
+ Returns:
+ An item from the requeue.
+
+ Raises:
+ Queue.Empty if the queue is empty and the timeout expires.
+ """
+ def GetAction():
+ try:
+ result = self.requeue.get(block=False)
+ self.requeue.task_done()
+ except Queue.Empty:
+ result = self.queue.get(block=False)
+ return result
+ return self._DoWithTimeout(GetAction,
+ Queue.Empty,
+ self.put_cond,
+ self.get_cond,
+ self.lock,
+ timeout=timeout,
+ block=block)
+
+ def join(self):
+ """Blocks until all of the items in the requeue have been processed."""
+ self.queue.join()
+
+ def task_done(self):
+ """Indicate that a previously enqueued item has been fully processed."""
+ self.queue.task_done()
+
+ def empty(self):
+ """Returns true if the requeue is empty."""
+ return self.queue.empty() and self.requeue.empty()
+
+ def get_nowait(self):
+ """Try to get an item from the queue without blocking."""
+ return self.get(block=False)
+
+
+class ThrottleHandler(urllib2.BaseHandler):
+ """A urllib2 handler for http and https requests that adds to a throttle."""
+
+ def __init__(self, throttle):
+ """Initialize a ThrottleHandler.
+
+ Args:
+ throttle: A Throttle instance to call for bandwidth and http/https request
+ throttling.
+ """
+ self.throttle = throttle
+
+ def AddRequest(self, throttle_name, req):
+ """Add to bandwidth throttle for given request.
+
+ Args:
+ throttle_name: The name of the bandwidth throttle to add to.
+ req: The request whose size will be added to the throttle.
+ """
+ size = 0
+ for key, value in req.headers.iteritems():
+ size += len('%s: %s\n' % (key, value))
+ for key, value in req.unredirected_hdrs.iteritems():
+ size += len('%s: %s\n' % (key, value))
+ (unused_scheme,
+ unused_host_port, url_path,
+ unused_query, unused_fragment) = urlparse.urlsplit(req.get_full_url())
+ size += len('%s %s HTTP/1.1\n' % (req.get_method(), url_path))
+ data = req.get_data()
+ if data:
+ size += len(data)
+ self.throttle.AddTransfer(throttle_name, size)
+
+ def AddResponse(self, throttle_name, res):
+ """Add to bandwidth throttle for given response.
+
+ Args:
+ throttle_name: The name of the bandwidth throttle to add to.
+ res: The response whose size will be added to the throttle.
+ """
+ content = res.read()
+ def ReturnContent():
+ return content
+ res.read = ReturnContent
+ size = len(content)
+ headers = res.info()
+ for key, value in headers.items():
+ size += len('%s: %s\n' % (key, value))
+ self.throttle.AddTransfer(throttle_name, size)
+
+ def http_request(self, req):
+ """Process an HTTP request.
+
+ If the throttle is over quota, sleep first. Then add request size to
+ throttle before returning it to be sent.
+
+ Args:
+ req: A urllib2.Request object.
+
+ Returns:
+ The request passed in.
+ """
+ self.throttle.Sleep()
+ self.AddRequest(BANDWIDTH_UP, req)
+ return req
+
+ def https_request(self, req):
+ """Process an HTTPS request.
+
+ If the throttle is over quota, sleep first. Then add request size to
+ throttle before returning it to be sent.
+
+ Args:
+ req: A urllib2.Request object.
+
+ Returns:
+ The request passed in.
+ """
+ self.throttle.Sleep()
+ self.AddRequest(HTTPS_BANDWIDTH_UP, req)
+ return req
+
+ def http_response(self, unused_req, res):
+ """Process an HTTP response.
+
+ The size of the response is added to the bandwidth throttle and the request
+ throttle is incremented by one.
+
+ Args:
+ unused_req: The urllib2 request for this response.
+ res: A urllib2 response object.
+
+ Returns:
+ The response passed in.
+ """
+ self.AddResponse(BANDWIDTH_DOWN, res)
+ self.throttle.AddTransfer(REQUESTS, 1)
+ return res
+
+ def https_response(self, unused_req, res):
+ """Process an HTTPS response.
+
+ The size of the response is added to the bandwidth throttle and the request
+ throttle is incremented by one.
+
+ Args:
+ unused_req: The urllib2 request for this response.
+ res: A urllib2 response object.
+
+ Returns:
+ The response passed in.
+ """
+ self.AddResponse(HTTPS_BANDWIDTH_DOWN, res)
+ self.throttle.AddTransfer(HTTPS_REQUESTS, 1)
+ return res
+
+
+class ThrottledHttpRpcServer(appengine_rpc.HttpRpcServer):
+ """Provides a simplified RPC-style interface for HTTP requests.
+
+ This RPC server uses a Throttle to prevent exceeding quotas.
+ """
+
+ def __init__(self, throttle, request_manager, *args, **kwargs):
+ """Initialize a ThrottledHttpRpcServer.
+
+ Also sets request_manager.rpc_server to the ThrottledHttpRpcServer instance.
+
+ Args:
+      throttle: A Throttle instance.
+ request_manager: A RequestManager instance.
+ args: Positional arguments to pass through to
+ appengine_rpc.HttpRpcServer.__init__
+ kwargs: Keyword arguments to pass through to
+ appengine_rpc.HttpRpcServer.__init__
+ """
+ self.throttle = throttle
+ appengine_rpc.HttpRpcServer.__init__(self, *args, **kwargs)
+ request_manager.rpc_server = self
+
+ def _GetOpener(self):
+ """Returns an OpenerDirector that supports cookies and ignores redirects.
+
+ Returns:
+ A urllib2.OpenerDirector object.
+ """
+ opener = appengine_rpc.HttpRpcServer._GetOpener(self)
+ opener.add_handler(ThrottleHandler(self.throttle))
+
+ return opener
+
+
+def ThrottledHttpRpcServerFactory(throttle, request_manager):
+ """Create a factory to produce ThrottledHttpRpcServer for a given throttle.
+
+ Args:
+ throttle: A Throttle instance to use for the ThrottledHttpRpcServer.
+ request_manager: A RequestManager instance.
+
+ Returns:
+ A factory to produce a ThrottledHttpRpcServer.
+ """
+ def MakeRpcServer(*args, **kwargs):
+ kwargs['account_type'] = 'HOSTED_OR_GOOGLE'
+ kwargs['save_cookies'] = True
+ return ThrottledHttpRpcServer(throttle, request_manager, *args, **kwargs)
+ return MakeRpcServer
+
+
+class RequestManager(object):
+ """A class which wraps a connection to the server."""
+
+ source = 'google-bulkloader-%s' % UPLOADER_VERSION
+ user_agent = source
+
+ def __init__(self,
+ app_id,
+ host_port,
+ url_path,
+ kind,
+ throttle):
+ """Initialize a RequestManager object.
+
+ Args:
+ app_id: String containing the application id for requests.
+ host_port: String containing the "host:port" pair; the port is optional.
+ url_path: partial URL (path) to post entity data to.
+ kind: Kind of the Entity records being posted.
+ throttle: A Throttle instance.
+ """
+ self.app_id = app_id
+ self.host_port = host_port
+ self.host = host_port.split(':')[0]
+ if url_path and url_path[0] != '/':
+ url_path = '/' + url_path
+ self.url_path = url_path
+ self.kind = kind
+ self.throttle = throttle
+ self.credentials = None
+ throttled_rpc_server_factory = ThrottledHttpRpcServerFactory(
+ self.throttle, self)
+ logging.debug('Configuring remote_api. app_id = %s, url_path = %s, '
+ 'servername = %s' % (app_id, url_path, host_port))
+ remote_api_stub.ConfigureRemoteDatastore(
+ app_id,
+ url_path,
+ self.AuthFunction,
+ servername=host_port,
+ rpc_server_factory=throttled_rpc_server_factory)
+ self.authenticated = False
+
+ def Authenticate(self):
+ """Invoke authentication if necessary."""
+ self.rpc_server.Send(self.url_path, payload=None)
+ self.authenticated = True
+
+ def AuthFunction(self,
+ raw_input_fn=raw_input,
+ password_input_fn=getpass.getpass):
+ """Prompts the user for a username and password.
+
+ Caches the results the first time it is called and returns the
+ same result every subsequent time.
+
+ Args:
+ raw_input_fn: Used for dependency injection.
+ password_input_fn: Used for dependency injection.
+
+ Returns:
+ A pair of the username and password.
+ """
+ if self.credentials is not None:
+ return self.credentials
+ print 'Please enter login credentials for %s (%s)' % (
+ self.host, self.app_id)
+ email = raw_input_fn('Email: ')
+ if email:
+ password_prompt = 'Password for %s: ' % email
+ password = password_input_fn(password_prompt)
+ else:
+ password = None
+ self.credentials = (email, password)
+ return self.credentials
+
+ def _GetHeaders(self):
+ """Constructs a dictionary of extra headers to send with a request."""
+ headers = {
+ 'GAE-Uploader-Version': UPLOADER_VERSION,
+ 'GAE-Uploader-Kind': self.kind
+ }
+ return headers
+
+ def EncodeContent(self, rows):
+ """Encodes row data to the wire format.
+
+ Args:
+ rows: A list of pairs of a line number and a list of column values.
+
+ Returns:
+ A list of db.Model instances.
+ """
+ try:
+ loader = Loader.RegisteredLoaders()[self.kind]
+ except KeyError:
+ logging.error('No Loader defined for kind %s.' % self.kind)
+ raise ConfigurationError('No Loader defined for kind %s.' % self.kind)
+ entities = []
+ for line_number, values in rows:
+ key = loader.GenerateKey(line_number, values)
+ entity = loader.CreateEntity(values, key_name=key)
+ entities.extend(entity)
+
+ return entities
+
+ def PostEntities(self, item):
+ """Posts Entity records to a remote endpoint over HTTP.
+
+ Args:
+      item: A WorkItem containing the entities to post.
+ """
+ entities = item.content
+ db.put(entities)
+
+
+class WorkItem(object):
+ """Holds a unit of uploading work.
+
+ A WorkItem represents a number of entities that need to be uploaded to
+ Google App Engine. These entities are encoded in the "content" field of
+ the WorkItem, and will be POST'd as-is to the server.
+
+ The entities are identified by a range of numeric keys, inclusively. In
+ the case of a resumption of an upload, or a replay to correct errors,
+ these keys must be able to identify the same set of entities.
+
+ Note that keys specify a range. The entities do not have to sequentially
+  fill the entire range; they must simply bound a range of valid keys.
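+
+  The state transitions enforced by the MarkAs* methods below are:
+
+    STATE_READ -> STATE_SENDING -> STATE_SENT
+    STATE_SENDING -> STATE_NOT_SENT -> STATE_SENDING (retry)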
+ """
+
+ def __init__(self, progress_queue, rows, key_start, key_end,
+ progress_key=None):
+ """Initialize the WorkItem instance.
+
+ Args:
+ progress_queue: A queue used for tracking progress information.
+ rows: A list of pairs of a line number and a list of column values
+ key_start: The (numeric) starting key, inclusive.
+ key_end: The (numeric) ending key, inclusive.
+ progress_key: If this WorkItem represents state from a prior run,
+ then this will be the key within the progress database.
+ """
+ self.state = STATE_READ
+
+ self.progress_queue = progress_queue
+
+ assert isinstance(key_start, (int, long))
+ assert isinstance(key_end, (int, long))
+ assert key_start <= key_end
+
+ self.key_start = key_start
+ self.key_end = key_end
+ self.progress_key = progress_key
+
+ self.progress_event = threading.Event()
+
+ self.rows = rows
+ self.content = None
+ self.count = len(rows)
+
+ def MarkAsRead(self):
+ """Mark this WorkItem as read/consumed from the data source."""
+
+ assert self.state == STATE_READ
+
+ self._StateTransition(STATE_READ, blocking=True)
+
+ assert self.progress_key is not None
+
+ def MarkAsSending(self):
+ """Mark this WorkItem as in-process on being uploaded to the server."""
+
+ assert self.state == STATE_READ or self.state == STATE_NOT_SENT
+ assert self.progress_key is not None
+
+ self._StateTransition(STATE_SENDING, blocking=True)
+
+ def MarkAsSent(self):
+ """Mark this WorkItem as sucessfully-sent to the server."""
+
+ assert self.state == STATE_SENDING
+ assert self.progress_key is not None
+
+ self._StateTransition(STATE_SENT, blocking=False)
+
+ def MarkAsError(self):
+ """Mark this WorkItem as required manual error recovery."""
+
+ assert self.state == STATE_SENDING
+ assert self.progress_key is not None
+
+ self._StateTransition(STATE_NOT_SENT, blocking=True)
+
+ def _StateTransition(self, new_state, blocking=False):
+ """Transition the work item to a new state, storing progress information.
+
+ Args:
+ new_state: The state to transition to.
+ blocking: Whether to block for the progress thread to acknowledge the
+ transition.
+ """
+ logging.debug('[%s-%s] %s' %
+ (self.key_start, self.key_end, StateMessage(self.state)))
+ assert not self.progress_event.isSet()
+
+ self.state = new_state
+
+ self.progress_queue.put(self)
+
+ if blocking:
+ self.progress_event.wait()
+
+ self.progress_event.clear()
+
+
+
+def InterruptibleSleep(sleep_time):
+ """Puts thread to sleep, checking this threads exit_flag twice a second.
+
+ Args:
+ sleep_time: Time to sleep.
+ """
+ slept = 0.0
+ epsilon = .0001
+ thread = threading.currentThread()
+ while slept < sleep_time - epsilon:
+ remaining = sleep_time - slept
+ this_sleep_time = min(remaining, 0.5)
+ time.sleep(this_sleep_time)
+ slept += this_sleep_time
+ if thread.exit_flag:
+ return
+
+
+class ThreadGate(object):
+ """Manage the number of active worker threads.
+
+ The ThreadGate limits the number of threads that are simultaneously
+ uploading batches of records in order to implement adaptive rate
+ control. The number of simultaneous upload threads that it takes to
+  start causing timeouts varies widely over the course of the day, so
+ adaptive rate control allows the uploader to do many uploads while
+ reducing the error rate and thus increasing the throughput.
+
+ Initially the ThreadGate allows only one uploader thread to be active.
+ For each successful upload, another thread is activated and for each
+ failed upload, the number of active threads is reduced by one.
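+
+  The intended call pattern from a worker thread is sketched below
+  (UploadOneBatch is a hypothetical upload step; compare
+  BulkLoaderThread.PerformWork):
+
+    gate.StartWork()
+    try:
+      success = UploadOneBatch()
+      if success:
+        gate.IncreaseWorkers()
+      else:
+        gate.DecreaseWorkers()
+    finally:
+      gate.FinishWork()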
+ """
+
+ def __init__(self, enabled, sleep=InterruptibleSleep):
+ self.enabled = enabled
+ self.enabled_count = 1
+ self.lock = threading.Lock()
+ self.thread_semaphore = threading.Semaphore(self.enabled_count)
+ self._threads = []
+ self.backoff_time = 0
+ self.sleep = sleep
+
+ def Register(self, thread):
+ """Register a thread with the thread gate."""
+ self._threads.append(thread)
+
+ def Threads(self):
+ """Yields the registered threads."""
+ for thread in self._threads:
+ yield thread
+
+ def EnableThread(self):
+ """Enable one more worker thread."""
+ self.lock.acquire()
+ try:
+ self.enabled_count += 1
+ finally:
+ self.lock.release()
+ self.thread_semaphore.release()
+
+ def EnableAllThreads(self):
+ """Enable all worker threads."""
+ for unused_idx in range(len(self._threads) - self.enabled_count):
+ self.EnableThread()
+
+ def StartWork(self):
+ """Starts a critical section in which the number of workers is limited.
+
+ If thread throttling is enabled then this method starts a critical
+ section which allows self.enabled_count simultaneously operating
+ threads. The critical section is ended by calling self.FinishWork().
+ """
+ if self.enabled:
+ self.thread_semaphore.acquire()
+ if self.backoff_time > 0.0:
+ if not threading.currentThread().exit_flag:
+ logging.info('Backing off: %.1f seconds',
+ self.backoff_time)
+ self.sleep(self.backoff_time)
+
+ def FinishWork(self):
+ """Ends a critical section started with self.StartWork()."""
+ if self.enabled:
+ self.thread_semaphore.release()
+
+ def IncreaseWorkers(self):
+ """Informs the throttler that an item was successfully sent.
+
+ If thread throttling is enabled, this method will cause an
+ additional thread to run in the critical section.
+ """
+ if self.enabled:
+ if self.backoff_time > 0.0:
+ logging.info('Resetting backoff to 0.0')
+ self.backoff_time = 0.0
+ do_enable = False
+ self.lock.acquire()
+ try:
+ if self.enabled and len(self._threads) > self.enabled_count:
+ do_enable = True
+ self.enabled_count += 1
+ finally:
+ self.lock.release()
+ if do_enable:
+ self.thread_semaphore.release()
+
+ def DecreaseWorkers(self):
+ """Informs the thread_gate that an item failed to send.
+
+ If thread throttling is enabled, this method will cause the
+ throttler to allow one fewer thread in the critical section. If
+ there is only one thread remaining, failures will result in
+ exponential backoff until there is a success.
+ """
+ if self.enabled:
+ do_disable = False
+ self.lock.acquire()
+ try:
+ if self.enabled:
+ if self.enabled_count > 1:
+ do_disable = True
+ self.enabled_count -= 1
+ else:
+ if self.backoff_time == 0.0:
+ self.backoff_time = INITIAL_BACKOFF
+ else:
+ self.backoff_time *= BACKOFF_FACTOR
+ finally:
+ self.lock.release()
+ if do_disable:
+ self.thread_semaphore.acquire()
+
+
+class Throttle(object):
+ """A base class for upload rate throttling.
+
+  Transferring a large number of records too quickly to an application
+ could trigger quota limits and cause the transfer process to halt.
+ In order to stay within the application's quota, we throttle the
+ data transfer to a specified limit (across all transfer threads).
+ This limit defaults to about half of the Google App Engine default
+  for an application, but can be manually adjusted up or down as
+  appropriate.
+
+ This class tracks a moving average of some aspect of the transfer
+ rate (bandwidth, records per second, http connections per
+ second). It keeps two windows of counts of bytes transferred, on a
+ per-thread basis. One block is the "current" block, and the other is
+ the "prior" block. It will rotate the counts from current to prior
+ when ROTATE_PERIOD has passed. Thus, the current block will
+ represent from 0 seconds to ROTATE_PERIOD seconds of activity
+ (determined by: time.time() - self.last_rotate). The prior block
+ will always represent a full ROTATE_PERIOD.
+
+ Sleeping is performed just before a transfer of another block, and is
+ based on the counts transferred *before* the next transfer. It really
+  does not matter how much will be transferred, only that, for all the
+  data transferred SO FAR, we have interspersed enough pauses to ensure
+  the aggregate transfer rate is within the specified limit.
+
+ These counts are maintained on a per-thread basis, so we do not require
+ any interlocks around incrementing the counts. There IS an interlock on
+ the rotation of the counts because we do not want multiple threads to
+ multiply-rotate the counts.
+
+ There are various race conditions in the computation and collection
+  of these counts. We do not require precise values; we simply need to
+  keep the overall transfer within the bandwidth limits. If a given
+ pause is a little short, or a little long, then the aggregate delays
+ will be correct.
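+
+  A minimal usage sketch from a worker thread ('payload' is a
+  hypothetical string; the calling thread is assumed to carry the
+  exit_flag attribute that _ThreadBase provides):
+
+    throttle = Throttle(layout={BANDWIDTH_UP: DEFAULT_BANDWIDTH_LIMIT})
+    throttle.Register(threading.currentThread())
+    throttle.Sleep(BANDWIDTH_UP)              # pause if over the limit
+    throttle.AddTransfer(BANDWIDTH_UP, len(payload))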
+ """
+
+ ROTATE_PERIOD = 600
+
+ def __init__(self,
+ get_time=time.time,
+ thread_sleep=InterruptibleSleep,
+ layout=None):
+ self.get_time = get_time
+ self.thread_sleep = thread_sleep
+
+ self.start_time = get_time()
+ self.transferred = {}
+ self.prior_block = {}
+ self.totals = {}
+ self.throttles = {}
+
+ self.last_rotate = {}
+ self.rotate_mutex = {}
+ if layout:
+ self.AddThrottles(layout)
+
+ def AddThrottle(self, name, limit):
+ self.throttles[name] = limit
+ self.transferred[name] = {}
+ self.prior_block[name] = {}
+ self.totals[name] = {}
+ self.last_rotate[name] = self.get_time()
+ self.rotate_mutex[name] = threading.Lock()
+
+ def AddThrottles(self, layout):
+ for key, value in layout.iteritems():
+ self.AddThrottle(key, value)
+
+ def Register(self, thread):
+ """Register this thread with the throttler."""
+ thread_name = thread.getName()
+ for throttle_name in self.throttles.iterkeys():
+ self.transferred[throttle_name][thread_name] = 0
+ self.prior_block[throttle_name][thread_name] = 0
+ self.totals[throttle_name][thread_name] = 0
+
+ def VerifyName(self, throttle_name):
+ if throttle_name not in self.throttles:
+ raise AssertionError('%s is not a registered throttle' % throttle_name)
+
+ def AddTransfer(self, throttle_name, token_count):
+ """Add a count to the amount this thread has transferred.
+
+ Each time a thread transfers some data, it should call this method to
+ note the amount sent. The counts may be rotated if sufficient time
+ has passed since the last rotation.
+
+ Note: this method should only be called by the BulkLoaderThread
+ instances. The token count is allocated towards the
+ "current thread".
+
+ Args:
+ throttle_name: The name of the throttle to add to.
+ token_count: The number to add to the throttle counter.
+ """
+ self.VerifyName(throttle_name)
+ transferred = self.transferred[throttle_name]
+ transferred[threading.currentThread().getName()] += token_count
+
+ if self.last_rotate[throttle_name] + self.ROTATE_PERIOD < self.get_time():
+ self._RotateCounts(throttle_name)
+
+ def Sleep(self, throttle_name=None):
+ """Possibly sleep in order to limit the transfer rate.
+
+ Note that we sleep based on *prior* transfers rather than what we
+ may be about to transfer. The next transfer could put us under/over
+ and that will be rectified *after* that transfer. Net result is that
+ the average transfer rate will remain within bounds. Spiky behavior
+ or uneven rates among the threads could possibly bring the transfer
+ rate above the requested limit for short durations.
+
+ Args:
+ throttle_name: The name of the throttle to sleep on. If None or
+ omitted, then sleep on all throttles.
+ """
+ if throttle_name is None:
+ for throttle_name in self.throttles:
+ self.Sleep(throttle_name=throttle_name)
+ return
+
+ self.VerifyName(throttle_name)
+
+ thread = threading.currentThread()
+
+ while True:
+ duration = self.get_time() - self.last_rotate[throttle_name]
+
+ total = 0
+ for count in self.prior_block[throttle_name].values():
+ total += count
+
+ if total:
+ duration += self.ROTATE_PERIOD
+
+ for count in self.transferred[throttle_name].values():
+ total += count
+
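+      # At the configured limit, transferring 'total' tokens should take at
+      # least total/limit seconds; sleep off whatever portion of that time
+      # has not yet elapsed.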
+ sleep_time = (float(total) / self.throttles[throttle_name]) - duration
+
+ if sleep_time < MINIMUM_THROTTLE_SLEEP_DURATION:
+ break
+
+ logging.debug('[%s] Throttling on %s. Sleeping for %.1f ms '
+ '(duration=%.1f ms, total=%d)',
+ thread.getName(), throttle_name,
+ sleep_time * 1000, duration * 1000, total)
+ self.thread_sleep(sleep_time)
+ if thread.exit_flag:
+ break
+ self._RotateCounts(throttle_name)
+
+ def _RotateCounts(self, throttle_name):
+ """Rotate the transfer counters.
+
+ If sufficient time has passed, then rotate the counters from active to
+ the prior-block of counts.
+
+ This rotation is interlocked to ensure that multiple threads do not
+ over-rotate the counts.
+
+ Args:
+ throttle_name: The name of the throttle to rotate.
+ """
+ self.VerifyName(throttle_name)
+ self.rotate_mutex[throttle_name].acquire()
+ try:
+ next_rotate_time = self.last_rotate[throttle_name] + self.ROTATE_PERIOD
+ if next_rotate_time >= self.get_time():
+ return
+
+      for name, count in self.transferred[throttle_name].items():
+        self.prior_block[throttle_name][name] = count
+        self.transferred[throttle_name][name] = 0
+
+ self.totals[throttle_name][name] += count
+
+ self.last_rotate[throttle_name] = self.get_time()
+
+ finally:
+ self.rotate_mutex[throttle_name].release()
+
+ def TotalTransferred(self, throttle_name):
+ """Return the total transferred, and over what period.
+
+ Args:
+ throttle_name: The name of the throttle to total.
+
+ Returns:
+ A tuple of the total count and running time for the given throttle name.
+ """
+ total = 0
+ for count in self.totals[throttle_name].values():
+ total += count
+ for count in self.transferred[throttle_name].values():
+ total += count
+ return total, self.get_time() - self.start_time
+
+
+class _ThreadBase(threading.Thread):
+ """Provide some basic features for the threads used in the uploader.
+
+ This abstract base class is used to provide some common features:
+
+ * Flag to ask thread to exit as soon as possible.
+ * Record exit/error status for the primary thread to pick up.
+ * Capture exceptions and record them for pickup.
+ * Some basic logging of thread start/stop.
+ * All threads are "daemon" threads.
+ * Friendly names for presenting to users.
+
+ Concrete sub-classes must implement PerformWork().
+
+ Either self.NAME should be set or GetFriendlyName() be overridden to
+ return a human-friendly name for this thread.
+
+  The run() method wraps PerformWork() and logs start/exit messages.
+
+ self.exit_flag is intended to signal that this thread should exit
+ when it gets the chance. PerformWork() should check self.exit_flag
+ whenever it has the opportunity to exit gracefully.
+ """
+
+ def __init__(self):
+ threading.Thread.__init__(self)
+
+ self.setDaemon(True)
+
+ self.exit_flag = False
+ self.error = None
+
+ def run(self):
+ """Perform the work of the thread."""
+ logging.info('[%s] %s: started', self.getName(), self.__class__.__name__)
+
+ try:
+ self.PerformWork()
+ except:
+ self.error = sys.exc_info()[1]
+ logging.exception('[%s] %s:', self.getName(), self.__class__.__name__)
+
+ logging.info('[%s] %s: exiting', self.getName(), self.__class__.__name__)
+
+ def PerformWork(self):
+ """Perform the thread-specific work."""
+ raise NotImplementedError()
+
+ def CheckError(self):
+ """If an error is present, then log it."""
+ if self.error:
+ logging.error('Error in %s: %s', self.GetFriendlyName(), self.error)
+
+ def GetFriendlyName(self):
+ """Returns a human-friendly description of the thread."""
+ if hasattr(self, 'NAME'):
+ return self.NAME
+ return 'unknown thread'
+
+
+class BulkLoaderThread(_ThreadBase):
+ """A thread which transmits entities to the server application.
+
+ This thread will read WorkItem instances from the work_queue and upload
+ the entities to the server application. Progress information will be
+ pushed into the progress_queue as the work is being performed.
+
+  If a BulkLoaderThread encounters a transient error, the entities will be
+  resent; if a fatal error is encountered, the BulkLoaderThread exits.
+ """
+
+ def __init__(self,
+ work_queue,
+ throttle,
+ thread_gate,
+ request_manager):
+ """Initialize the BulkLoaderThread instance.
+
+ Args:
+ work_queue: A queue containing WorkItems for processing.
+      throttle: A Throttle to control upload bandwidth.
+ thread_gate: A ThreadGate to control number of simultaneous uploads.
+ request_manager: A RequestManager instance.
+ """
+ _ThreadBase.__init__(self)
+
+ self.work_queue = work_queue
+ self.throttle = throttle
+ self.thread_gate = thread_gate
+
+ self.request_manager = request_manager
+
+ def PerformWork(self):
+ """Perform the work of a BulkLoaderThread."""
+ while not self.exit_flag:
+ success = False
+ self.thread_gate.StartWork()
+ try:
+ try:
+ item = self.work_queue.get(block=True, timeout=1.0)
+ except Queue.Empty:
+ continue
+ if item == _THREAD_SHOULD_EXIT:
+ break
+
+ logging.debug('[%s] Got work item [%d-%d]',
+ self.getName(), item.key_start, item.key_end)
+
+ try:
+
+ item.MarkAsSending()
+ try:
+ if item.content is None:
+ item.content = self.request_manager.EncodeContent(item.rows)
+ try:
+ self.request_manager.PostEntities(item)
+ success = True
+ logging.debug(
+ '[%d-%d] Sent %d entities',
+ item.key_start, item.key_end, item.count)
+ self.throttle.AddTransfer(RECORDS, item.count)
+ except (db.InternalError, db.NotSavedError, db.Timeout), e:
+ logging.debug('Caught non-fatal error: %s', e)
+ except urllib2.HTTPError, e:
+ if e.code == 403 or (e.code >= 500 and e.code < 600):
+ logging.debug('Caught HTTP error %d', e.code)
+ logging.debug('%s', e.read())
+ else:
+ raise e
+
+ except:
+ self.error = sys.exc_info()[1]
+ logging.exception('[%s] %s: caught exception %s', self.getName(),
+ self.__class__.__name__, str(sys.exc_info()))
+ raise
+
+ finally:
+ if success:
+ item.MarkAsSent()
+ self.thread_gate.IncreaseWorkers()
+ self.work_queue.task_done()
+ else:
+ item.MarkAsError()
+ self.thread_gate.DecreaseWorkers()
+ try:
+ self.work_queue.reput(item, block=False)
+ except Queue.Full:
+ logging.error('[%s] Failed to reput work item.', self.getName())
+ raise Error('Failed to reput work item')
+ logging.info('[%d-%d] %s',
+ item.key_start, item.key_end, StateMessage(item.state))
+
+ finally:
+ self.thread_gate.FinishWork()
+
+
+ def GetFriendlyName(self):
+ """Returns a human-friendly name for this thread."""
+ return 'worker [%s]' % self.getName()
+
+
+class DataSourceThread(_ThreadBase):
+ """A thread which reads WorkItems and pushes them into queue.
+
+ This thread will read/consume WorkItems from a generator (produced by
+ the generator factory). These WorkItems will then be pushed into the
+ work_queue. Note that reading will block if/when the work_queue becomes
+ full. Information on content consumed from the generator will be pushed
+ into the progress_queue.
+ """
+
+ NAME = 'data source thread'
+
+ def __init__(self,
+ work_queue,
+ progress_queue,
+ workitem_generator_factory,
+ progress_generator_factory):
+ """Initialize the DataSourceThread instance.
+
+ Args:
+ work_queue: A queue containing WorkItems for processing.
+ progress_queue: A queue used for tracking progress information.
+ workitem_generator_factory: A factory that creates a WorkItem generator
+ progress_generator_factory: A factory that creates a generator which
+ produces prior progress status, or None if there is no prior status
+ to use.
+ """
+ _ThreadBase.__init__(self)
+
+ self.work_queue = work_queue
+ self.progress_queue = progress_queue
+ self.workitem_generator_factory = workitem_generator_factory
+ self.progress_generator_factory = progress_generator_factory
+ self.entity_count = 0
+
+ def PerformWork(self):
+ """Performs the work of a DataSourceThread."""
+ if self.progress_generator_factory:
+ progress_gen = self.progress_generator_factory()
+ else:
+ progress_gen = None
+
+ content_gen = self.workitem_generator_factory(self.progress_queue,
+ progress_gen)
+
+ self.sent_count = 0
+ self.read_count = 0
+ self.read_all = False
+
+ for item in content_gen.Batches():
+ item.MarkAsRead()
+
+ while not self.exit_flag:
+ try:
+ self.work_queue.put(item, block=True, timeout=1.0)
+ self.entity_count += item.count
+ break
+ except Queue.Full:
+ pass
+
+ if self.exit_flag:
+ break
+
+ if not self.exit_flag:
+ self.read_all = True
+ self.read_count = content_gen.row_count
+ self.sent_count = content_gen.sent_count
+
+
+
+def _RunningInThread(thread):
+ """Return True if we are running within the specified thread."""
+ return threading.currentThread().getName() == thread.getName()
+
+
+class ProgressDatabase(object):
+ """Persistently record all progress information during an upload.
+
+ This class wraps a very simple SQLite database which records each of
+ the relevant details from the WorkItem instances. If the uploader is
+ resumed, then data is replayed out of the database.
+ """
+
+ def __init__(self, db_filename, commit_periodicity=100):
+ """Initialize the ProgressDatabase instance.
+
+ Args:
+ db_filename: The name of the SQLite database to use.
+ commit_periodicity: How many operations to perform between commits.
+ """
+ self.db_filename = db_filename
+
+ logging.info('Using progress database: %s', db_filename)
+ self.primary_conn = sqlite3.connect(db_filename, isolation_level=None)
+ self.primary_thread = threading.currentThread()
+
+ self.progress_conn = None
+ self.progress_thread = None
+
+ self.operation_count = 0
+ self.commit_periodicity = commit_periodicity
+
+ self.prior_key_end = None
+
+ try:
+ self.primary_conn.execute(
+ """create table progress (
+ id integer primary key autoincrement,
+ state integer not null,
+ key_start integer not null,
+ key_end integer not null
+ )
+ """)
+ except sqlite3.OperationalError, e:
+ if 'already exists' not in e.message:
+ raise
+
+ try:
+ self.primary_conn.execute('create index i_state on progress (state)')
+ except sqlite3.OperationalError, e:
+ if 'already exists' not in e.message:
+ raise
+
+ def ThreadComplete(self):
+ """Finalize any operations the progress thread has performed.
+
+ The database aggregates lots of operations into a single commit, and
+ this method is used to commit any pending operations as the thread
+ is about to shut down.
+ """
+ if self.progress_conn:
+ self._MaybeCommit(force_commit=True)
+
+ def _MaybeCommit(self, force_commit=False):
+ """Periodically commit changes into the SQLite database.
+
+ Committing every operation is quite expensive, and slows down the
+ operation of the script. Thus, we only commit after every N operations,
+ as determined by the self.commit_periodicity value. Optionally, the
+ caller can force a commit.
+
+ Args:
+ force_commit: Pass True in order for a commit to occur regardless
+ of the current operation count.
+ """
+ self.operation_count += 1
+ if force_commit or (self.operation_count % self.commit_periodicity) == 0:
+ self.progress_conn.commit()
+
+ def _OpenProgressConnection(self):
+ """Possibly open a database connection for the progress tracker thread.
+
+ If the connection is not open (for the calling thread, which is assumed
+    to be the progress tracker thread), then open it. We also open a
+    couple of cursors for later use (and reuse).
+ """
+ if self.progress_conn:
+ return
+
+ assert not _RunningInThread(self.primary_thread)
+
+ self.progress_thread = threading.currentThread()
+
+ self.progress_conn = sqlite3.connect(self.db_filename)
+
+ self.insert_cursor = self.progress_conn.cursor()
+ self.update_cursor = self.progress_conn.cursor()
+
+ def HasUnfinishedWork(self):
+ """Returns True if the database has progress information.
+
+ Note there are two basic cases for progress information:
+ 1) All saved records indicate a successful upload. In this case, we
+ need to skip everything transmitted so far and then send the rest.
+ 2) Some records for incomplete transfer are present. These need to be
+ sent again, and then we resume sending after all the successful
+ data.
+
+ Returns:
+ True if the database has progress information, False otherwise.
+
+ Raises:
+ ResumeError: If there is an error reading the progress database.
+ """
+ assert _RunningInThread(self.primary_thread)
+
+ cursor = self.primary_conn.cursor()
+ cursor.execute('select count(*) from progress')
+ row = cursor.fetchone()
+ if row is None:
+ raise ResumeError('Error reading progress information.')
+
+ return row[0] != 0
+
+ def StoreKeys(self, key_start, key_end):
+ """Record a new progress record, returning a key for later updates.
+
+ The specified progress information will be persisted into the database.
+ A unique key will be returned that identifies this progress state. The
+ key is later used to (quickly) update this record.
+
+ For the progress resumption to proceed properly, calls to StoreKeys
+ MUST specify monotonically increasing key ranges. This will result in
+    a database whereby the ID, KEY_START, and KEY_END columns are all
+ increasing (rather than having ranges out of order).
+
+ NOTE: the above precondition is NOT tested by this method (since it
+ would imply an additional table read or two on each invocation).
+
+ Args:
+ key_start: The starting key of the WorkItem (inclusive)
+ key_end: The end key of the WorkItem (inclusive)
+
+ Returns:
+      The row id of the new record, to later be used as a unique key to
+      update this state.
+ """
+ self._OpenProgressConnection()
+
+ assert _RunningInThread(self.progress_thread)
+    assert isinstance(key_start, (int, long))
+    assert isinstance(key_end, (int, long))
+ assert key_start <= key_end
+
+ if self.prior_key_end is not None:
+ assert key_start > self.prior_key_end
+ self.prior_key_end = key_end
+
+ self.insert_cursor.execute(
+ 'insert into progress (state, key_start, key_end) values (?, ?, ?)',
+ (STATE_READ, key_start, key_end))
+
+ progress_key = self.insert_cursor.lastrowid
+
+ self._MaybeCommit()
+
+ return progress_key
+
+ def UpdateState(self, key, new_state):
+ """Update a specified progress record with new information.
+
+ Args:
+ key: The key for this progress record, returned from StoreKeys
+ new_state: The new state to associate with this progress record.
+ """
+ self._OpenProgressConnection()
+
+ assert _RunningInThread(self.progress_thread)
+ assert isinstance(new_state, int)
+
+ self.update_cursor.execute('update progress set state=? where id=?',
+ (new_state, key))
+
+ self._MaybeCommit()
+
+ def GetProgressStatusGenerator(self):
+ """Get a generator which returns progress information.
+
+ The returned generator will yield a series of 4-tuples that specify
+ progress information about a prior run of the uploader. The 4-tuples
+ have the following values:
+
+ progress_key: The unique key to later update this record with new
+ progress information.
+ state: The last state saved for this progress record.
+ key_start: The starting key of the items for uploading (inclusive).
+ key_end: The ending key of the items for uploading (inclusive).
+
+ After all incompletely-transferred records are provided, then one
+ more 4-tuple will be generated:
+
+ None
+ DATA_CONSUMED_TO_HERE: A unique string value indicating this record
+ is being provided.
+ None
+ key_end: An integer value specifying the last data source key that
+ was handled by the previous run of the uploader.
+
+ The caller should begin uploading records which occur after key_end.
+
+ Yields:
+ Progress information as tuples (progress_key, state, key_start, key_end).
+ """
+ conn = sqlite3.connect(self.db_filename, isolation_level=None)
+ cursor = conn.cursor()
+
+ cursor.execute('select max(id) from progress')
+ batch_id = cursor.fetchone()[0]
+
+ cursor.execute('select key_end from progress where id = ?', (batch_id,))
+ key_end = cursor.fetchone()[0]
+
+ self.prior_key_end = key_end
+
+ cursor.execute(
+ 'select id, state, key_start, key_end from progress'
+ ' where state != ?'
+ ' order by id',
+ (STATE_SENT,))
+
+ rows = cursor.fetchall()
+
+ for row in rows:
+ if row is None:
+ break
+
+ yield row
+
+ yield None, DATA_CONSUMED_TO_HERE, None, key_end
+
+
+class StubProgressDatabase(object):
+ """A stub implementation of ProgressDatabase which does nothing."""
+
+ def HasUnfinishedWork(self):
+ """Whether the stub database has progress information (it doesn't)."""
+ return False
+
+ def StoreKeys(self, unused_key_start, unused_key_end):
+ """Pretend to store a key in the stub database."""
+ return 'fake-key'
+
+ def UpdateState(self, unused_key, unused_new_state):
+ """Pretend to update the state of a progress item."""
+ pass
+
+ def ThreadComplete(self):
+ """Finalize operations on the stub database (i.e. do nothing)."""
+ pass
+
+
+class ProgressTrackerThread(_ThreadBase):
+ """A thread which records progress information for the upload process.
+
+ The progress information is stored into the provided progress database.
+ This class is not responsible for replaying a prior run's progress
+ information out of the database. Separate mechanisms must be used to
+ resume a prior upload attempt.
+ """
+
+ NAME = 'progress tracking thread'
+
+ def __init__(self, progress_queue, progress_db):
+ """Initialize the ProgressTrackerThread instance.
+
+ Args:
+ progress_queue: A Queue used for tracking progress information.
+ progress_db: The database for tracking progress information; should
+ be an instance of ProgressDatabase.
+ """
+ _ThreadBase.__init__(self)
+
+ self.progress_queue = progress_queue
+ self.db = progress_db
+ self.entities_sent = 0
+
+ def PerformWork(self):
+ """Performs the work of a ProgressTrackerThread."""
+ while not self.exit_flag:
+ try:
+ item = self.progress_queue.get(block=True, timeout=1.0)
+ except Queue.Empty:
+ continue
+ if item == _THREAD_SHOULD_EXIT:
+ break
+
+ if item.state == STATE_READ and item.progress_key is None:
+ item.progress_key = self.db.StoreKeys(item.key_start, item.key_end)
+ else:
+ assert item.progress_key is not None
+
+ self.db.UpdateState(item.progress_key, item.state)
+ if item.state == STATE_SENT:
+ self.entities_sent += item.count
+
+ item.progress_event.set()
+
+ self.progress_queue.task_done()
+
+ self.db.ThreadComplete()
+
+
+
+def Validate(value, typ):
+ """Checks that value is non-empty and of the right type.
+
+ Args:
+ value: any value
+ typ: a type or tuple of types
+
+ Raises:
+ ValueError if value is None or empty.
+ TypeError if it's not the given type.
+  """
+ if not value:
+ raise ValueError('Value should not be empty; received %s.' % value)
+ elif not isinstance(value, typ):
+ raise TypeError('Expected a %s, but received %s (a %s).' %
+ (typ, value, value.__class__))
+
+
+class Loader(object):
+ """A base class for creating datastore entities from input data.
+
+ To add a handler for bulk loading a new entity kind into your datastore,
+ write a subclass of this class that calls Loader.__init__ from your
+ class's __init__.
+
+ If you need to run extra code to convert entities from the input
+ data, create new properties, or otherwise modify the entities before
+ they're inserted, override HandleEntity.
+
+ See the CreateEntity method for the creation of entities from the
+ (parsed) input data.
+ """
+
+ __loaders = {}
+ __kind = None
+ __properties = None
+
+ def __init__(self, kind, properties):
+ """Constructor.
+
+ Populates this Loader's kind and properties map. Also registers it with
+ the bulk loader, so that all you need to do is instantiate your Loader,
+ and the bulkload handler will automatically use it.
+
+ Args:
+ kind: a string containing the entity kind that this loader handles
+
+ properties: list of (name, converter) tuples.
+
+ This is used to automatically convert the CSV columns into
+ properties. The converter should be a function that takes one
+ argument, a string value from the CSV file, and returns a
+ correctly typed property value that should be inserted. The
+ tuples in this list should match the columns in your CSV file,
+ in order.
+
+ For example:
+ [('name', str),
+ ('id_number', int),
+ ('email', datastore_types.Email),
+ ('user', users.User),
+ ('birthdate', lambda x: datetime.datetime.fromtimestamp(float(x))),
+ ('description', datastore_types.Text),
+ ]
+ """
+ Validate(kind, basestring)
+ self.__kind = kind
+
+ db.class_for_kind(kind)
+
+ Validate(properties, list)
+ for name, fn in properties:
+ Validate(name, basestring)
+ assert callable(fn), (
+ 'Conversion function %s for property %s is not callable.' % (fn, name))
+
+ self.__properties = properties
+
+ @staticmethod
+ def RegisterLoader(loader):
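+    """Register the given Loader instance under the kind it handles."""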
+
+ Loader.__loaders[loader.__kind] = loader
+
+ def kind(self):
+ """ Return the entity kind that this Loader handes.
+ """
+ return self.__kind
+
+ def CreateEntity(self, values, key_name=None):
+ """Creates a entity from a list of property values.
+
+ Args:
+ values: list/tuple of str
+ key_name: if provided, the name for the (single) resulting entity
+
+ Returns:
+ list of db.Model
+
+ The returned entities are populated with the property values from the
+ argument, converted to native types using the properties map given in
+ the constructor, and passed through HandleEntity. They're ready to be
+ inserted.
+
+ Raises:
+ AssertionError if the number of values doesn't match the number
+ of properties in the properties map.
+ ValueError if any element of values is None or empty.
+ TypeError if values is not a list or tuple.
+ """
+ Validate(values, (list, tuple))
+ assert len(values) == len(self.__properties), (
+ 'Expected %d CSV columns, found %d.' %
+ (len(self.__properties), len(values)))
+
+ model_class = db.class_for_kind(self.__kind)
+
+ properties = {'key_name': key_name}
+ for (name, converter), val in zip(self.__properties, values):
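+ # bool('0') and bool('false') are both True, since any non-empty
+ # string is truthy; map the common "false" spellings explicitly.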
+ if converter is bool and val.lower() in ('0', 'false', 'no'):
+ val = False
+ properties[name] = converter(val)
+
+ entity = model_class(**properties)
+ entities = self.HandleEntity(entity)
+
+ if entities:
+ if not isinstance(entities, (list, tuple)):
+ entities = [entities]
+
+ for entity in entities:
+ if not isinstance(entity, db.Model):
+ raise TypeError('Expected a db.Model, received %s (a %s).' %
+ (entity, entity.__class__))
+
+ return entities
+
+ def GenerateKey(self, i, values):
+ """Generates a key_name to be used in creating the underlying object.
+
+ The default implementation returns None.
+
+ This method can be overridden to control the key generation for
+ uploaded entities. The value returned should be None (to use a
+ server generated numeric key), or a string which neither starts
+ with a digit nor has the form __*__. (See
+ http://code.google.com/appengine/docs/python/datastore/keysandentitygroups.html)
+
+ If you generate your own string keys, keep in mind:
+
+ 1. The key name for each entity must be unique.
+ 2. If an entity of the same kind and key already exists in the
+ datastore, it will be overwritten.
+
+ Args:
+ i: Number corresponding to this object (assuming this is run in a
+ loop, it is your current count).
+ values: list/tuple of str.
+
+ Returns:
+ A string to be used as the key_name for an entity.
+ """
+ return None
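+
+ # A sketch of an override, assuming (hypothetically) that the first
+ # CSV column holds a unique username:
+ #
+ #   def GenerateKey(self, i, values):
+ #     return 'user-%s' % values[0]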
+
+ def HandleEntity(self, entity):
+ """Subclasses can override this to add custom entity conversion code.
+
+ This is called for each entity, after its properties are populated from
+ CSV but before it is stored. Subclasses can override this to add custom
+ entity handling code.
+
+ The entity to be inserted should be returned. If multiple entities should
+ be inserted, return a list of entities. If no entities should be inserted,
+ return None or [].
+
+ Args:
+ entity: db.Model
+
+ Returns:
+ db.Model or list of db.Model
+ """
+ return entity
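+
+ # A sketch of an override that derives an extra child entity from each
+ # row and uploads both; the Summary model is made up for illustration:
+ #
+ #   def HandleEntity(self, entity):
+ #     summary = Summary(text=entity.description)
+ #     return [entity, summary]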
+
+
+ @staticmethod
+ def RegisteredLoaders():
+ """Returns a list of the Loader instances that have been created.
+ """
+ return dict(Loader.__loaders)
+
+
+class QueueJoinThread(threading.Thread):
+ """A thread that joins a queue and exits.
+
+ Queue joins do not have a timeout. To simulate a queue join with
+ timeout, run this thread and join it with a timeout.
+ """
+
+ def __init__(self, queue):
+ """Initialize a QueueJoinThread.
+
+ Args:
+ queue: The queue for this thread to join.
+ """
+ threading.Thread.__init__(self)
+ assert isinstance(queue, (Queue.Queue, ReQueue))
+ self.queue = queue
+
+ def run(self):
+ """Perform the queue join in this thread."""
+ self.queue.join()
+
+
+def InterruptibleQueueJoin(queue,
+ thread_local,
+ thread_gate,
+ queue_join_thread_factory=QueueJoinThread):
+ """Repeatedly joins the given ReQueue or Queue.Queue with short timeout.
+
+ Between each timeout on the join, worker threads are checked.
+
+ Args:
+ queue: A Queue.Queue or ReQueue instance.
+ thread_local: A threading.local instance which indicates interrupts.
+ thread_gate: A ThreadGate instance.
+ queue_join_thread_factory: Used for dependency injection.
+
+ Returns:
+ True unless the queue join is interrupted by SIGINT or worker death.
+ """
+ thread = queue_join_thread_factory(queue)
+ thread.start()
+ while True:
+ thread.join(timeout=.5)
+ if not thread.isAlive():
+ return True
+ if thread_local.shut_down:
+ logging.debug('Queue join interrupted')
+ return False
+ for worker_thread in thread_gate.Threads():
+ if not worker_thread.isAlive():
+ return False
+
+
+def ShutdownThreads(data_source_thread, work_queue, thread_gate):
+ """Shuts down the worker and data source threads.
+
+ Args:
+ data_source_thread: A running DataSourceThread instance.
+ work_queue: The work queue.
+ thread_gate: A ThreadGate instance with workers registered.
+ """
+ logging.info('An error occurred. Shutting down...')
+
+ data_source_thread.exit_flag = True
+
+ for thread in thread_gate.Threads():
+ thread.exit_flag = True
+
+ for unused_thread in thread_gate.Threads():
+ thread_gate.EnableThread()
+
+ data_source_thread.join(timeout=3.0)
+ if data_source_thread.isAlive():
+ logging.warn('%s hung while trying to exit',
+ data_source_thread.GetFriendlyName())
+
+ while not work_queue.empty():
+ try:
+ unused_item = work_queue.get_nowait()
+ work_queue.task_done()
+ except Queue.Empty:
+ pass
+
+
+def PerformBulkUpload(app_id,
+ post_url,
+ kind,
+ workitem_generator_factory,
+ num_threads,
+ throttle,
+ progress_db,
+ max_queue_size=DEFAULT_QUEUE_SIZE,
+ request_manager_factory=RequestManager,
+ bulkloaderthread_factory=BulkLoaderThread,
+ progresstrackerthread_factory=ProgressTrackerThread,
+ datasourcethread_factory=DataSourceThread,
+ work_queue_factory=ReQueue,
+ progress_queue_factory=Queue.Queue):
+ """Uploads data into an application using a series of HTTP POSTs.
+
+ This function will spin up a number of threads to read entities from
+ the data source, pass those to a number of worker ("uploader") threads
+ for sending to the application, and track all of the progress in a
+ small database in case an error or pause/termination requires a
+ restart/resumption of the upload process.
+
+ Args:
+ app_id: String containing application id.
+ post_url: URL to post the Entity data to.
+ kind: Kind of the Entity records being posted.
+ workitem_generator_factory: A factory that creates a WorkItem generator.
+ num_threads: How many uploader threads should be created.
+ throttle: A Throttle instance.
+ progress_db: The database to use for replaying/recording progress.
+ max_queue_size: Maximum size of the queues before they should block.
+ request_manager_factory: Used for dependency injection.
+ bulkloaderthread_factory: Used for dependency injection.
+ progresstrackerthread_factory: Used for dependency injection.
+ datasourcethread_factory: Used for dependency injection.
+ work_queue_factory: Used for dependency injection.
+ progress_queue_factory: Used for dependency injection.
+
+ Raises:
+ AuthenticationError: If authentication is required and fails.
+ """
+ thread_gate = ThreadGate(True)
+
+ (unused_scheme,
+ host_port, url_path,
+ unused_query, unused_fragment) = urlparse.urlsplit(post_url)
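+ # e.g. a (hypothetical) post_url of 'http://myapp.appspot.com/remote_api'
+ # splits into host_port 'myapp.appspot.com' and url_path '/remote_api'.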
+
+ work_queue = work_queue_factory(max_queue_size)
+ progress_queue = progress_queue_factory(max_queue_size)
+ request_manager = request_manager_factory(app_id,
+ host_port,
+ url_path,
+ kind,
+ throttle)
+
+ throttle.Register(threading.currentThread())
+ try:
+ request_manager.Authenticate()
+ except Exception, e:
+ logging.exception(e)
+ raise AuthenticationError('Authentication failed')
+ if (request_manager.credentials is not None and
+ not request_manager.authenticated):
+ raise AuthenticationError('Authentication failed')
+
+ for unused_idx in range(num_threads):
+ thread = bulkloaderthread_factory(work_queue,
+ throttle,
+ thread_gate,
+ request_manager)
+ throttle.Register(thread)
+ thread_gate.Register(thread)
+
+ progress_thread = progresstrackerthread_factory(progress_queue, progress_db)
+
+ if progress_db.HasUnfinishedWork():
+ logging.debug('Restarting upload using progress database')
+ progress_generator_factory = progress_db.GetProgressStatusGenerator
+ else:
+ progress_generator_factory = None
+
+ data_source_thread = datasourcethread_factory(work_queue,
+ progress_queue,
+ workitem_generator_factory,
+ progress_generator_factory)
+
+ thread_local = threading.local()
+ thread_local.shut_down = False
+
+ def Interrupt(unused_signum, unused_frame):
+ """Shutdown gracefully in response to a signal."""
+ thread_local.shut_down = True
+
+ signal.signal(signal.SIGINT, Interrupt)
+
+ progress_thread.start()
+ data_source_thread.start()
+ for thread in thread_gate.Threads():
+ thread.start()
+
+
+ while not thread_local.shut_down:
+ data_source_thread.join(timeout=0.25)
+
+ if data_source_thread.isAlive():
+ for thread in list(thread_gate.Threads()) + [progress_thread]:
+ if not thread.isAlive():
+ logging.info('Unexpected thread death: %s', thread.getName())
+ thread_local.shut_down = True
+ break
+ else:
+ break
+
+ if thread_local.shut_down:
+ ShutdownThreads(data_source_thread, work_queue, thread_gate)
+
+ def _Join(ob, msg):
+ logging.debug('Waiting for %s...', msg)
+ if isinstance(ob, threading.Thread):
+ ob.join(timeout=3.0)
+ if ob.isAlive():
+ logging.debug('Joining %s failed', ob.GetFriendlyName())
+ else:
+ logging.debug('... done.')
+ elif isinstance(ob, (Queue.Queue, ReQueue)):
+ if not InterruptibleQueueJoin(ob, thread_local, thread_gate):
+ ShutdownThreads(data_source_thread, work_queue, thread_gate)
+ else:
+ ob.join()
+ logging.debug('... done.')
+
+ _Join(work_queue, 'work_queue to flush')
+
+ for unused_thread in thread_gate.Threads():
+ work_queue.put(_THREAD_SHOULD_EXIT)
+
+ for unused_thread in thread_gate.Threads():
+ thread_gate.EnableThread()
+
+ for thread in thread_gate.Threads():
+ _Join(thread, 'thread [%s] to terminate' % thread.getName())
+
+ thread.CheckError()
+
+ if progress_thread.isAlive():
+ _Join(progress_queue, 'progress_queue to finish')
+ else:
+ logging.warn('Progress thread exited prematurely')
+
+ progress_queue.put(_THREAD_SHOULD_EXIT)
+ _Join(progress_thread, 'progress_thread to terminate')
+ progress_thread.CheckError()
+
+ data_source_thread.CheckError()
+
+ total_up, duration = throttle.TotalTransferred(BANDWIDTH_UP)
+ s_total_up, unused_duration = throttle.TotalTransferred(HTTPS_BANDWIDTH_UP)
+ total_up += s_total_up
+ logging.info('%d entities read, %d previously transferred',
+ data_source_thread.read_count,
+ data_source_thread.sent_count)
+ logging.info('%d entities (%d bytes) transferred in %.1f seconds',
+ progress_thread.entities_sent, total_up, duration)
+ if (data_source_thread.read_all and
+ progress_thread.entities_sent + data_source_thread.sent_count >=
+ data_source_thread.read_count):
+ logging.info('All entities successfully uploaded')
+ else:
+ logging.info('Some entities not successfully uploaded')
+
+
+def PrintUsageExit(code):
+ """Prints usage information and exits with a status code.
+
+ Args:
+ code: Status code to pass to sys.exit() after displaying usage information.
+ """
+ print __doc__ % {'arg0': sys.argv[0]}
+ sys.stdout.flush()
+ sys.stderr.flush()
+ sys.exit(code)
+
+
+def ParseArguments(argv):
+ """Parses command-line arguments.
+
+ Prints out a help message if -h or --help is supplied.
+
+ Args:
+ argv: List of command-line arguments.
+
+ Returns:
+ Tuple (app_id, url, filename, batch_size, kind, num_threads,
+ bandwidth_limit, rps_limit, http_limit, db_filename, config_file,
+ auth_domain) containing the processed values from each corresponding
+ command-line flag, as returned by ProcessArguments.
+ """
+ opts, unused_args = getopt.getopt(
+ argv[1:],
+ 'h',
+ ['debug',
+ 'help',
+ 'url=',
+ 'filename=',
+ 'batch_size=',
+ 'kind=',
+ 'num_threads=',
+ 'bandwidth_limit=',
+ 'rps_limit=',
+ 'http_limit=',
+ 'db_filename=',
+ 'app_id=',
+ 'config_file=',
+ 'auth_domain=',
+ ])
+
+ url = None
+ filename = None
+ batch_size = DEFAULT_BATCH_SIZE
+ kind = None
+ num_threads = DEFAULT_THREAD_COUNT
+ bandwidth_limit = DEFAULT_BANDWIDTH_LIMIT
+ rps_limit = DEFAULT_RPS_LIMIT
+ http_limit = DEFAULT_REQUEST_LIMIT
+ db_filename = None
+ app_id = None
+ config_file = None
+ auth_domain = 'gmail.com'
+
+ for option, value in opts:
+ if option == '--debug':
+ logging.getLogger().setLevel(logging.DEBUG)
+ elif option in ('-h', '--help'):
+ PrintUsageExit(0)
+ elif option == '--url':
+ url = value
+ elif option == '--filename':
+ filename = value
+ elif option == '--batch_size':
+ batch_size = int(value)
+ elif option == '--kind':
+ kind = value
+ elif option == '--num_threads':
+ num_threads = int(value)
+ elif option == '--bandwidth_limit':
+ bandwidth_limit = int(value)
+ elif option == '--rps_limit':
+ rps_limit = int(value)
+ elif option == '--http_limit':
+ http_limit = int(value)
+ elif option == '--db_filename':
+ db_filename = value
+ elif option == '--app_id':
+ app_id = value
+ elif option == '--config_file':
+ config_file = value
+ elif option == '--auth_domain':
+ auth_domain = value
+
+ return ProcessArguments(app_id=app_id,
+ url=url,
+ filename=filename,
+ batch_size=batch_size,
+ kind=kind,
+ num_threads=num_threads,
+ bandwidth_limit=bandwidth_limit,
+ rps_limit=rps_limit,
+ http_limit=http_limit,
+ db_filename=db_filename,
+ config_file=config_file,
+ auth_domain=auth_domain,
+ die_fn=lambda: PrintUsageExit(1))
+
+
+ def ThrottleLayout(bandwidth_limit, http_limit, rps_limit):
+ """Builds a dict mapping each throttle name to its limit.
+
+ The HTTPS bandwidth and request buckets are budgeted at one fifth of
+ the corresponding HTTP limits.
+ """
+ return {
+ BANDWIDTH_UP: bandwidth_limit,
+ BANDWIDTH_DOWN: bandwidth_limit,
+ REQUESTS: http_limit,
+ HTTPS_BANDWIDTH_UP: bandwidth_limit / 5,
+ HTTPS_BANDWIDTH_DOWN: bandwidth_limit / 5,
+ HTTPS_REQUESTS: http_limit / 5,
+ RECORDS: rps_limit,
+ }
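+
+ # For instance, with the module defaults (bandwidth_limit=250000,
+ # http_limit=8, rps_limit=20), the HTTPS buckets come out to 50000
+ # bytes/s and, under Python 2 integer division, 1 request/s.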
+
+
+def LoadConfig(config_file):
+ """Loads a config file and registers any Loader classes present."""
+ if config_file:
+ global_dict = dict(globals())
+ execfile(config_file, global_dict)
+ for cls in Loader.__subclasses__():
+ Loader.RegisterLoader(cls())
+
+
+def _MissingArgument(arg_name, die_fn):
+ """Print error message about missing argument and die."""
+ print >>sys.stderr, '%s argument required' % arg_name
+ die_fn()
+
+
+def ProcessArguments(app_id=None,
+ url=None,
+ filename=None,
+ batch_size=DEFAULT_BATCH_SIZE,
+ kind=None,
+ num_threads=DEFAULT_THREAD_COUNT,
+ bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT,
+ rps_limit=DEFAULT_RPS_LIMIT,
+ http_limit=DEFAULT_REQUEST_LIMIT,
+ db_filename=None,
+ config_file=None,
+ auth_domain='gmail.com',
+ die_fn=lambda: sys.exit(1)):
+ """Processes non command-line input arguments."""
+ if db_filename is None:
+ db_filename = time.strftime('bulkloader-progress-%Y%m%d.%H%M%S.sql3')
+
+ if batch_size <= 0:
+ print >>sys.stderr, 'batch_size must be 1 or larger'
+ die_fn()
+
+ if url is None:
+ _MissingArgument('url', die_fn)
+
+ if filename is None:
+ _MissingArgument('filename', die_fn)
+
+ if kind is None:
+ _MissingArgument('kind', die_fn)
+
+ if config_file is None:
+ _MissingArgument('config_file', die_fn)
+
+ if app_id is None:
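+ # Infer the app id from the URL's host when possible, e.g. a
+ # (hypothetical) url of 'http://guestbook.appspot.com/remote_api'
+ # yields an app_id of 'guestbook'.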
+ (unused_scheme, host_port, unused_url_path,
+ unused_query, unused_fragment) = urlparse.urlsplit(url)
+ suffix_idx = host_port.find('.appspot.com')
+ if suffix_idx > -1:
+ app_id = host_port[:suffix_idx]
+ elif host_port.split(':')[0].endswith('google.com'):
+ app_id = host_port.split('.')[0]
+ else:
+ print >>sys.stderr, 'app_id required for non-appspot.com domains'
+ die_fn()
+
+ return (app_id, url, filename, batch_size, kind, num_threads,
+ bandwidth_limit, rps_limit, http_limit, db_filename, config_file,
+ auth_domain)
+
+
+def _PerformBulkload(app_id=None,
+ url=None,
+ filename=None,
+ batch_size=DEFAULT_BATCH_SIZE,
+ kind=None,
+ num_threads=DEFAULT_THREAD_COUNT,
+ bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT,
+ rps_limit=DEFAULT_RPS_LIMIT,
+ http_limit=DEFAULT_REQUEST_LIMIT,
+ db_filename=None,
+ config_file=None,
+ auth_domain='gmail.com'):
+ """Runs the bulkloader, given the options as keyword arguments.
+
+ Args:
+ app_id: The application id.
+ url: The url of the remote_api endpoint.
+ filename: The name of the file containing the CSV data.
+ batch_size: The number of records to send per request.
+ kind: The kind of entity to transfer.
+ num_threads: The number of threads to use to transfer data.
+ bandwidth_limit: Maximum bytes/second to transfer.
+ rps_limit: Maximum records/second to transfer.
+ http_limit: Maximum requests/second for transfers.
+ db_filename: The name of the SQLite3 progress database file.
+ config_file: The name of the configuration file.
+ auth_domain: The auth domain to use for logins and UserProperty.
+
+ Returns:
+ An exit code.
+ """
+ os.environ['AUTH_DOMAIN'] = auth_domain
+ LoadConfig(config_file)
+
+ throttle_layout = ThrottleLayout(bandwidth_limit, http_limit, rps_limit)
+
+ throttle = Throttle(layout=throttle_layout)
+
+
+ workitem_generator_factory = GetCSVGeneratorFactory(filename, batch_size)
+
+ if db_filename == 'skip':
+ progress_db = StubProgressDatabase()
+ else:
+ progress_db = ProgressDatabase(db_filename)
+
+
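+ # Presumably sized so the work queue can hold at least two batches per
+ # worker thread (plus slack) even when num_threads exceeds the default.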
+ max_queue_size = max(DEFAULT_QUEUE_SIZE, 2 * num_threads + 5)
+
+ PerformBulkUpload(app_id,
+ url,
+ kind,
+ workitem_generator_factory,
+ num_threads,
+ throttle,
+ progress_db,
+ max_queue_size=max_queue_size)
+
+ return 0
+
+
+def Run(app_id=None,
+ url=None,
+ filename=None,
+ batch_size=DEFAULT_BATCH_SIZE,
+ kind=None,
+ num_threads=DEFAULT_THREAD_COUNT,
+ bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT,
+ rps_limit=DEFAULT_RPS_LIMIT,
+ http_limit=DEFAULT_REQUEST_LIMIT,
+ db_filename=None,
+ auth_domain='gmail.com',
+ config_file=None):
+ """Sets up and runs the bulkloader, given the options as keyword arguments.
+
+ Args:
+ app_id: The application id.
+ url: The url of the remote_api endpoint.
+ filename: The name of the file containing the CSV data.
+ batch_size: The number of records to send per request.
+ kind: The kind of entity to transfer.
+ num_threads: The number of threads to use to transfer data.
+ bandwidth_limit: Maximum bytes/second to transfer.
+ rps_limit: Maximum records/second to transfer.
+ http_limit: Maximum requests/second for transfers.
+ db_filename: The name of the SQLite3 progress database file.
+ config_file: The name of the configuration file.
+ auth_domain: The auth domain to use for logins and UserProperty.
+
+ Returns:
+ An exit code.
+ """
+ logging.basicConfig(
+ format='%(levelname)-8s %(asctime)s %(filename)s] %(message)s')
+ args = ProcessArguments(app_id=app_id,
+ url=url,
+ filename=filename,
+ batch_size=batch_size,
+ kind=kind,
+ num_threads=num_threads,
+ bandwidth_limit=bandwidth_limit,
+ rps_limit=rps_limit,
+ http_limit=http_limit,
+ db_filename=db_filename,
+ config_file=config_file)
+
+ (app_id, url, filename, batch_size, kind, num_threads, bandwidth_limit,
+ rps_limit, http_limit, db_filename, config_file, auth_domain) = args
+
+ return _PerformBulkload(app_id=app_id,
+ url=url,
+ filename=filename,
+ batch_size=batch_size,
+ kind=kind,
+ num_threads=num_threads,
+ bandwidth_limit=bandwidth_limit,
+ rps_limit=rps_limit,
+ http_limit=http_limit,
+ db_filename=db_filename,
+ config_file=config_file,
+ auth_domain=auth_domain)
+
+
+def main(argv):
+ """Runs the importer from the command line."""
+ logging.basicConfig(
+ level=logging.INFO,
+ format='%(levelname)-8s %(asctime)s %(filename)s] %(message)s')
+
+ args = ParseArguments(argv)
+ if None in args:
+ print >>sys.stderr, 'Invalid arguments'
+ PrintUsageExit(1)
+
+ (app_id, url, filename, batch_size, kind, num_threads,
+ bandwidth_limit, rps_limit, http_limit, db_filename, config_file,
+ auth_domain) = args
+
+ return _PerformBulkload(app_id=app_id,
+ url=url,
+ filename=filename,
+ batch_size=batch_size,
+ kind=kind,
+ num_threads=num_threads,
+ bandwidth_limit=bandwidth_limit,
+ rps_limit=rps_limit,
+ http_limit=http_limit,
+ db_filename=db_filename,
+ config_file=config_file,
+ auth_domain=auth_domain)
+
+
+if __name__ == '__main__':
+ sys.exit(main(sys.argv))