diff -r 5c931bd3dc1e -r a7766286a7be thirdparty/google_appengine/google/appengine/tools/bulkloader.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/thirdparty/google_appengine/google/appengine/tools/bulkloader.py	Thu Feb 12 12:30:36 2009 +0000
@@ -0,0 +1,2588 @@
+#!/usr/bin/env python
+#
+# Copyright 2007 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Imports CSV data over HTTP.
+
+Usage:
+  %(arg0)s [flags]
+
+    --debug             Show debugging information. (Optional)
+    --app_id=           Application ID of endpoint (Optional for
+                        *.appspot.com)
+    --auth_domain=      The auth domain to use for logging in and for
+                        UserProperties. (Default: gmail.com)
+    --bandwidth_limit=  The maximum number of bytes per second for the
+                        aggregate transfer of data to the server. Bursts
+                        may exceed this, but overall transfer rate is
+                        restricted to this rate. (Default 250000)
+    --batch_size=       Number of Entity objects to include in each post to
+                        the URL endpoint. The more data per row/Entity, the
+                        smaller the batch size should be. (Default 10)
+    --config_file=      File containing Model and Loader definitions.
+                        (Required)
+    --db_filename=      Specific progress database to write to, or to
+                        resume from. If not supplied, then a new database
+                        will be started, named:
+                        bulkloader-progress-TIMESTAMP.
+                        The special filename "skip" may be used to simply
+                        skip reading/writing any progress information.
+    --filename=         Path to the CSV file to import. (Required)
+    --http_limit=       The maximum number of HTTP requests per second to
+                        send to the server. (Default: 8)
+    --kind=             Name of the Entity object kind to put in the
+                        datastore. (Required)
+    --num_threads=      Number of threads to use for uploading entities
+                        (Default 10)
+    --rps_limit=        The maximum number of records per second to
+                        transfer to the server. (Default: 20)
+    --url=              URL endpoint to post to for importing data.
+                        (Required)
+
+The exit status will be 0 on success, non-zero on import failure.
+
+Works with the remote_api mix-in library for google.appengine.ext.remote_api.
+Please look there for documentation about how to set up the server side.
+ +Example: + +%(arg0)s --url=http://app.appspot.com/remote_api --kind=Model \ + --filename=data.csv --config_file=loader_config.py + +""" + + + +import csv +import getopt +import getpass +import logging +import new +import os +import Queue +import signal +import sys +import threading +import time +import traceback +import urllib2 +import urlparse + +from google.appengine.ext import db +from google.appengine.ext.remote_api import remote_api_stub +from google.appengine.tools import appengine_rpc + +try: + import sqlite3 +except ImportError: + pass + +UPLOADER_VERSION = '1' + +DEFAULT_THREAD_COUNT = 10 + +DEFAULT_BATCH_SIZE = 10 + +DEFAULT_QUEUE_SIZE = DEFAULT_THREAD_COUNT * 10 + +_THREAD_SHOULD_EXIT = '_THREAD_SHOULD_EXIT' + +STATE_READ = 0 +STATE_SENDING = 1 +STATE_SENT = 2 +STATE_NOT_SENT = 3 + +MINIMUM_THROTTLE_SLEEP_DURATION = 0.001 + +DATA_CONSUMED_TO_HERE = 'DATA_CONSUMED_TO_HERE' + +INITIAL_BACKOFF = 1.0 + +BACKOFF_FACTOR = 2.0 + + +DEFAULT_BANDWIDTH_LIMIT = 250000 + +DEFAULT_RPS_LIMIT = 20 + +DEFAULT_REQUEST_LIMIT = 8 + +BANDWIDTH_UP = 'http-bandwidth-up' +BANDWIDTH_DOWN = 'http-bandwidth-down' +REQUESTS = 'http-requests' +HTTPS_BANDWIDTH_UP = 'https-bandwidth-up' +HTTPS_BANDWIDTH_DOWN = 'https-bandwidth-down' +HTTPS_REQUESTS = 'https-requests' +RECORDS = 'records' + + +def StateMessage(state): + """Converts a numeric state identifier to a status message.""" + return ({ + STATE_READ: 'Batch read from file.', + STATE_SENDING: 'Sending batch to server.', + STATE_SENT: 'Batch successfully sent.', + STATE_NOT_SENT: 'Error while sending batch.' + }[state]) + + +class Error(Exception): + """Base-class for exceptions in this module.""" + + +class FatalServerError(Error): + """An unrecoverable error occurred while trying to post data to the server.""" + + +class ResumeError(Error): + """Error while trying to resume a partial upload.""" + + +class ConfigurationError(Error): + """Error in configuration options.""" + + +class AuthenticationError(Error): + """Error while trying to authenticate with the server.""" + + +def GetCSVGeneratorFactory(csv_filename, batch_size, + openfile=open, create_csv_reader=csv.reader): + """Return a factory that creates a CSV-based WorkItem generator. + + Args: + csv_filename: File on disk containing CSV data. + batch_size: Maximum number of CSV rows to stash into a WorkItem. + openfile: Used for dependency injection. + create_csv_reader: Used for dependency injection. + + Returns: A callable (accepting the Progress Queue and Progress + Generators as input) which creates the WorkItem generator. + """ + + def CreateGenerator(progress_queue, progress_generator): + """Initialize a CSV generator linked to a progress generator and queue. + + Args: + progress_queue: A ProgressQueue instance to send progress information. + progress_generator: A generator of progress information or None. + + Returns: + A CSVGenerator instance. + """ + return CSVGenerator(progress_queue, + progress_generator, + csv_filename, + batch_size, + openfile, + create_csv_reader) + return CreateGenerator + + +class CSVGenerator(object): + """Reads a CSV file and generates WorkItems containing batches of records.""" + + def __init__(self, + progress_queue, + progress_generator, + csv_filename, + batch_size, + openfile, + create_csv_reader): + """Initializes a CSV generator. + + Args: + progress_queue: A queue used for tracking progress information. + progress_generator: A generator of prior progress information, or None + if there is no prior status. 
+ csv_filename: File on disk containing CSV data. + batch_size: Maximum number of CSV rows to stash into a WorkItem. + openfile: Used for dependency injection of 'open'. + create_csv_reader: Used for dependency injection of 'csv.reader'. + """ + self.progress_queue = progress_queue + self.progress_generator = progress_generator + self.csv_filename = csv_filename + self.batch_size = batch_size + self.openfile = openfile + self.create_csv_reader = create_csv_reader + self.line_number = 1 + self.column_count = None + self.read_rows = [] + self.reader = None + self.row_count = 0 + self.sent_count = 0 + + def _AdvanceTo(self, line): + """Advance the reader to the given line. + + Args: + line: A line number to advance to. + """ + while self.line_number < line: + self.reader.next() + self.line_number += 1 + self.row_count += 1 + self.sent_count += 1 + + def _ReadRows(self, key_start, key_end): + """Attempts to read and encode rows [key_start, key_end]. + + The encoded rows are stored in self.read_rows. + + Args: + key_start: The starting line number. + key_end: The ending line number. + + Raises: + StopIteration: if the reader runs out of rows + ResumeError: if there are an inconsistent number of columns. + """ + assert self.line_number == key_start + self.read_rows = [] + while self.line_number <= key_end: + row = self.reader.next() + self.row_count += 1 + if self.column_count is None: + self.column_count = len(row) + else: + if self.column_count != len(row): + raise ResumeError('Column count mismatch, %d: %s' % + (self.column_count, str(row))) + self.read_rows.append((self.line_number, row)) + self.line_number += 1 + + def _MakeItem(self, key_start, key_end, rows, progress_key=None): + """Makes a WorkItem containing the given rows, with the given keys. + + Args: + key_start: The start key for the WorkItem. + key_end: The end key for the WorkItem. + rows: A list of the rows for the WorkItem. + progress_key: The progress key for the WorkItem + + Returns: + A WorkItem instance for the given batch. + """ + assert rows + + item = WorkItem(self.progress_queue, rows, + key_start, key_end, + progress_key=progress_key) + + return item + + def Batches(self): + """Reads the CSV data file and generates WorkItems. + + Yields: + Instances of class WorkItem + + Raises: + ResumeError: If the progress database and data file indicate a different + number of rows. 
+ """ + csv_file = self.openfile(self.csv_filename, 'r') + csv_content = csv_file.read() + if csv_content: + has_headers = csv.Sniffer().has_header(csv_content) + else: + has_headers = False + csv_file.seek(0) + self.reader = self.create_csv_reader(csv_file, skipinitialspace=True) + if has_headers: + logging.info('The CSV file appears to have a header line, skipping.') + self.reader.next() + + exhausted = False + + self.line_number = 1 + self.column_count = None + + logging.info('Starting import; maximum %d entities per post', + self.batch_size) + + state = None + if self.progress_generator is not None: + for progress_key, state, key_start, key_end in self.progress_generator: + if key_start: + try: + self._AdvanceTo(key_start) + self._ReadRows(key_start, key_end) + yield self._MakeItem(key_start, + key_end, + self.read_rows, + progress_key=progress_key) + except StopIteration: + logging.error('Mismatch between data file and progress database') + raise ResumeError( + 'Mismatch between data file and progress database') + elif state == DATA_CONSUMED_TO_HERE: + try: + self._AdvanceTo(key_end + 1) + except StopIteration: + state = None + + if self.progress_generator is None or state == DATA_CONSUMED_TO_HERE: + while not exhausted: + key_start = self.line_number + key_end = self.line_number + self.batch_size - 1 + try: + self._ReadRows(key_start, key_end) + except StopIteration: + exhausted = True + key_end = self.line_number - 1 + if key_start <= key_end: + yield self._MakeItem(key_start, key_end, self.read_rows) + + +class ReQueue(object): + """A special thread-safe queue. + + A ReQueue allows unfinished work items to be returned with a call to + reput(). When an item is reput, task_done() should *not* be called + in addition, getting an item that has been reput does not increase + the number of outstanding tasks. + + This class shares an interface with Queue.Queue and provides the + additional Reput method. + """ + + def __init__(self, + queue_capacity, + requeue_capacity=None, + queue_factory=Queue.Queue, + get_time=time.time): + """Initialize a ReQueue instance. + + Args: + queue_capacity: The number of items that can be put in the ReQueue. + requeue_capacity: The numer of items that can be reput in the ReQueue. + queue_factory: Used for dependency injection. + get_time: Used for dependency injection. + """ + if requeue_capacity is None: + requeue_capacity = queue_capacity + + self.get_time = get_time + self.queue = queue_factory(queue_capacity) + self.requeue = queue_factory(requeue_capacity) + self.lock = threading.Lock() + self.put_cond = threading.Condition(self.lock) + self.get_cond = threading.Condition(self.lock) + + def _DoWithTimeout(self, + action, + exc, + wait_cond, + done_cond, + lock, + timeout=None, + block=True): + """Performs the given action with a timeout. + + The action must be non-blocking, and raise an instance of exc on a + recoverable failure. If the action fails with an instance of exc, + we wait on wait_cond before trying again. Failure after the + timeout is reached is propagated as an exception. Success is + signalled by notifying on done_cond and returning the result of + the action. If action raises any exception besides an instance of + exc, it is immediately propagated. + + Args: + action: A callable that performs a non-blocking action. + exc: An exception type that is thrown by the action to indicate + a recoverable error. + wait_cond: A condition variable which should be waited on when + action throws exc. 
+ done_cond: A condition variable to signal if the action returns. + lock: The lock used by wait_cond and done_cond. + timeout: A non-negative float indicating the maximum time to wait. + block: Whether to block if the action cannot complete immediately. + + Returns: + The result of the action, if it is successful. + + Raises: + ValueError: If the timeout argument is negative. + """ + if timeout is not None and timeout < 0.0: + raise ValueError('\'timeout\' must not be a negative number') + if not block: + timeout = 0.0 + result = None + success = False + start_time = self.get_time() + lock.acquire() + try: + while not success: + try: + result = action() + success = True + except Exception, e: + if not isinstance(e, exc): + raise e + if timeout is not None: + elapsed_time = self.get_time() - start_time + timeout -= elapsed_time + if timeout <= 0.0: + raise e + wait_cond.wait(timeout) + finally: + if success: + done_cond.notify() + lock.release() + return result + + def put(self, item, block=True, timeout=None): + """Put an item into the requeue. + + Args: + item: An item to add to the requeue. + block: Whether to block if the requeue is full. + timeout: Maximum on how long to wait until the queue is non-full. + + Raises: + Queue.Full if the queue is full and the timeout expires. + """ + def PutAction(): + self.queue.put(item, block=False) + self._DoWithTimeout(PutAction, + Queue.Full, + self.get_cond, + self.put_cond, + self.lock, + timeout=timeout, + block=block) + + def reput(self, item, block=True, timeout=None): + """Re-put an item back into the requeue. + + Re-putting an item does not increase the number of outstanding + tasks, so the reput item should be uniquely associated with an + item that was previously removed from the requeue and for which + task_done has not been called. + + Args: + item: An item to add to the requeue. + block: Whether to block if the requeue is full. + timeout: Maximum on how long to wait until the queue is non-full. + + Raises: + Queue.Full is the queue is full and the timeout expires. + """ + def ReputAction(): + self.requeue.put(item, block=False) + self._DoWithTimeout(ReputAction, + Queue.Full, + self.get_cond, + self.put_cond, + self.lock, + timeout=timeout, + block=block) + + def get(self, block=True, timeout=None): + """Get an item from the requeue. + + Args: + block: Whether to block if the requeue is empty. + timeout: Maximum on how long to wait until the requeue is non-empty. + + Returns: + An item from the requeue. + + Raises: + Queue.Empty if the queue is empty and the timeout expires. 
+ """ + def GetAction(): + try: + result = self.requeue.get(block=False) + self.requeue.task_done() + except Queue.Empty: + result = self.queue.get(block=False) + return result + return self._DoWithTimeout(GetAction, + Queue.Empty, + self.put_cond, + self.get_cond, + self.lock, + timeout=timeout, + block=block) + + def join(self): + """Blocks until all of the items in the requeue have been processed.""" + self.queue.join() + + def task_done(self): + """Indicate that a previously enqueued item has been fully processed.""" + self.queue.task_done() + + def empty(self): + """Returns true if the requeue is empty.""" + return self.queue.empty() and self.requeue.empty() + + def get_nowait(self): + """Try to get an item from the queue without blocking.""" + return self.get(block=False) + + +class ThrottleHandler(urllib2.BaseHandler): + """A urllib2 handler for http and https requests that adds to a throttle.""" + + def __init__(self, throttle): + """Initialize a ThrottleHandler. + + Args: + throttle: A Throttle instance to call for bandwidth and http/https request + throttling. + """ + self.throttle = throttle + + def AddRequest(self, throttle_name, req): + """Add to bandwidth throttle for given request. + + Args: + throttle_name: The name of the bandwidth throttle to add to. + req: The request whose size will be added to the throttle. + """ + size = 0 + for key, value in req.headers.iteritems(): + size += len('%s: %s\n' % (key, value)) + for key, value in req.unredirected_hdrs.iteritems(): + size += len('%s: %s\n' % (key, value)) + (unused_scheme, + unused_host_port, url_path, + unused_query, unused_fragment) = urlparse.urlsplit(req.get_full_url()) + size += len('%s %s HTTP/1.1\n' % (req.get_method(), url_path)) + data = req.get_data() + if data: + size += len(data) + self.throttle.AddTransfer(throttle_name, size) + + def AddResponse(self, throttle_name, res): + """Add to bandwidth throttle for given response. + + Args: + throttle_name: The name of the bandwidth throttle to add to. + res: The response whose size will be added to the throttle. + """ + content = res.read() + def ReturnContent(): + return content + res.read = ReturnContent + size = len(content) + headers = res.info() + for key, value in headers.items(): + size += len('%s: %s\n' % (key, value)) + self.throttle.AddTransfer(throttle_name, size) + + def http_request(self, req): + """Process an HTTP request. + + If the throttle is over quota, sleep first. Then add request size to + throttle before returning it to be sent. + + Args: + req: A urllib2.Request object. + + Returns: + The request passed in. + """ + self.throttle.Sleep() + self.AddRequest(BANDWIDTH_UP, req) + return req + + def https_request(self, req): + """Process an HTTPS request. + + If the throttle is over quota, sleep first. Then add request size to + throttle before returning it to be sent. + + Args: + req: A urllib2.Request object. + + Returns: + The request passed in. + """ + self.throttle.Sleep() + self.AddRequest(HTTPS_BANDWIDTH_UP, req) + return req + + def http_response(self, unused_req, res): + """Process an HTTP response. + + The size of the response is added to the bandwidth throttle and the request + throttle is incremented by one. + + Args: + unused_req: The urllib2 request for this response. + res: A urllib2 response object. + + Returns: + The response passed in. + """ + self.AddResponse(BANDWIDTH_DOWN, res) + self.throttle.AddTransfer(REQUESTS, 1) + return res + + def https_response(self, unused_req, res): + """Process an HTTPS response. 
+ + The size of the response is added to the bandwidth throttle and the request + throttle is incremented by one. + + Args: + unused_req: The urllib2 request for this response. + res: A urllib2 response object. + + Returns: + The response passed in. + """ + self.AddResponse(HTTPS_BANDWIDTH_DOWN, res) + self.throttle.AddTransfer(HTTPS_REQUESTS, 1) + return res + + +class ThrottledHttpRpcServer(appengine_rpc.HttpRpcServer): + """Provides a simplified RPC-style interface for HTTP requests. + + This RPC server uses a Throttle to prevent exceeding quotas. + """ + + def __init__(self, throttle, request_manager, *args, **kwargs): + """Initialize a ThrottledHttpRpcServer. + + Also sets request_manager.rpc_server to the ThrottledHttpRpcServer instance. + + Args: + throttle: A Throttles instance. + request_manager: A RequestManager instance. + args: Positional arguments to pass through to + appengine_rpc.HttpRpcServer.__init__ + kwargs: Keyword arguments to pass through to + appengine_rpc.HttpRpcServer.__init__ + """ + self.throttle = throttle + appengine_rpc.HttpRpcServer.__init__(self, *args, **kwargs) + request_manager.rpc_server = self + + def _GetOpener(self): + """Returns an OpenerDirector that supports cookies and ignores redirects. + + Returns: + A urllib2.OpenerDirector object. + """ + opener = appengine_rpc.HttpRpcServer._GetOpener(self) + opener.add_handler(ThrottleHandler(self.throttle)) + + return opener + + +def ThrottledHttpRpcServerFactory(throttle, request_manager): + """Create a factory to produce ThrottledHttpRpcServer for a given throttle. + + Args: + throttle: A Throttle instance to use for the ThrottledHttpRpcServer. + request_manager: A RequestManager instance. + + Returns: + A factory to produce a ThrottledHttpRpcServer. + """ + def MakeRpcServer(*args, **kwargs): + kwargs['account_type'] = 'HOSTED_OR_GOOGLE' + kwargs['save_cookies'] = True + return ThrottledHttpRpcServer(throttle, request_manager, *args, **kwargs) + return MakeRpcServer + + +class RequestManager(object): + """A class which wraps a connection to the server.""" + + source = 'google-bulkloader-%s' % UPLOADER_VERSION + user_agent = source + + def __init__(self, + app_id, + host_port, + url_path, + kind, + throttle): + """Initialize a RequestManager object. + + Args: + app_id: String containing the application id for requests. + host_port: String containing the "host:port" pair; the port is optional. + url_path: partial URL (path) to post entity data to. + kind: Kind of the Entity records being posted. + throttle: A Throttle instance. + """ + self.app_id = app_id + self.host_port = host_port + self.host = host_port.split(':')[0] + if url_path and url_path[0] != '/': + url_path = '/' + url_path + self.url_path = url_path + self.kind = kind + self.throttle = throttle + self.credentials = None + throttled_rpc_server_factory = ThrottledHttpRpcServerFactory( + self.throttle, self) + logging.debug('Configuring remote_api. app_id = %s, url_path = %s, ' + 'servername = %s' % (app_id, url_path, host_port)) + remote_api_stub.ConfigureRemoteDatastore( + app_id, + url_path, + self.AuthFunction, + servername=host_port, + rpc_server_factory=throttled_rpc_server_factory) + self.authenticated = False + + def Authenticate(self): + """Invoke authentication if necessary.""" + self.rpc_server.Send(self.url_path, payload=None) + self.authenticated = True + + def AuthFunction(self, + raw_input_fn=raw_input, + password_input_fn=getpass.getpass): + """Prompts the user for a username and password. 
+ + Caches the results the first time it is called and returns the + same result every subsequent time. + + Args: + raw_input_fn: Used for dependency injection. + password_input_fn: Used for dependency injection. + + Returns: + A pair of the username and password. + """ + if self.credentials is not None: + return self.credentials + print 'Please enter login credentials for %s (%s)' % ( + self.host, self.app_id) + email = raw_input_fn('Email: ') + if email: + password_prompt = 'Password for %s: ' % email + password = password_input_fn(password_prompt) + else: + password = None + self.credentials = (email, password) + return self.credentials + + def _GetHeaders(self): + """Constructs a dictionary of extra headers to send with a request.""" + headers = { + 'GAE-Uploader-Version': UPLOADER_VERSION, + 'GAE-Uploader-Kind': self.kind + } + return headers + + def EncodeContent(self, rows): + """Encodes row data to the wire format. + + Args: + rows: A list of pairs of a line number and a list of column values. + + Returns: + A list of db.Model instances. + """ + try: + loader = Loader.RegisteredLoaders()[self.kind] + except KeyError: + logging.error('No Loader defined for kind %s.' % self.kind) + raise ConfigurationError('No Loader defined for kind %s.' % self.kind) + entities = [] + for line_number, values in rows: + key = loader.GenerateKey(line_number, values) + entity = loader.CreateEntity(values, key_name=key) + entities.extend(entity) + + return entities + + def PostEntities(self, item): + """Posts Entity records to a remote endpoint over HTTP. + + Args: + item: A workitem containing the entities to post. + + Returns: + A pair of the estimated size of the request in bytes and the response + from the server as a str. + """ + entities = item.content + db.put(entities) + + +class WorkItem(object): + """Holds a unit of uploading work. + + A WorkItem represents a number of entities that need to be uploaded to + Google App Engine. These entities are encoded in the "content" field of + the WorkItem, and will be POST'd as-is to the server. + + The entities are identified by a range of numeric keys, inclusively. In + the case of a resumption of an upload, or a replay to correct errors, + these keys must be able to identify the same set of entities. + + Note that keys specify a range. The entities do not have to sequentially + fill the entire range, they must simply bound a range of valid keys. + """ + + def __init__(self, progress_queue, rows, key_start, key_end, + progress_key=None): + """Initialize the WorkItem instance. + + Args: + progress_queue: A queue used for tracking progress information. + rows: A list of pairs of a line number and a list of column values + key_start: The (numeric) starting key, inclusive. + key_end: The (numeric) ending key, inclusive. + progress_key: If this WorkItem represents state from a prior run, + then this will be the key within the progress database. 
+ """ + self.state = STATE_READ + + self.progress_queue = progress_queue + + assert isinstance(key_start, (int, long)) + assert isinstance(key_end, (int, long)) + assert key_start <= key_end + + self.key_start = key_start + self.key_end = key_end + self.progress_key = progress_key + + self.progress_event = threading.Event() + + self.rows = rows + self.content = None + self.count = len(rows) + + def MarkAsRead(self): + """Mark this WorkItem as read/consumed from the data source.""" + + assert self.state == STATE_READ + + self._StateTransition(STATE_READ, blocking=True) + + assert self.progress_key is not None + + def MarkAsSending(self): + """Mark this WorkItem as in-process on being uploaded to the server.""" + + assert self.state == STATE_READ or self.state == STATE_NOT_SENT + assert self.progress_key is not None + + self._StateTransition(STATE_SENDING, blocking=True) + + def MarkAsSent(self): + """Mark this WorkItem as sucessfully-sent to the server.""" + + assert self.state == STATE_SENDING + assert self.progress_key is not None + + self._StateTransition(STATE_SENT, blocking=False) + + def MarkAsError(self): + """Mark this WorkItem as required manual error recovery.""" + + assert self.state == STATE_SENDING + assert self.progress_key is not None + + self._StateTransition(STATE_NOT_SENT, blocking=True) + + def _StateTransition(self, new_state, blocking=False): + """Transition the work item to a new state, storing progress information. + + Args: + new_state: The state to transition to. + blocking: Whether to block for the progress thread to acknowledge the + transition. + """ + logging.debug('[%s-%s] %s' % + (self.key_start, self.key_end, StateMessage(self.state))) + assert not self.progress_event.isSet() + + self.state = new_state + + self.progress_queue.put(self) + + if blocking: + self.progress_event.wait() + + self.progress_event.clear() + + + +def InterruptibleSleep(sleep_time): + """Puts thread to sleep, checking this threads exit_flag twice a second. + + Args: + sleep_time: Time to sleep. + """ + slept = 0.0 + epsilon = .0001 + thread = threading.currentThread() + while slept < sleep_time - epsilon: + remaining = sleep_time - slept + this_sleep_time = min(remaining, 0.5) + time.sleep(this_sleep_time) + slept += this_sleep_time + if thread.exit_flag: + return + + +class ThreadGate(object): + """Manage the number of active worker threads. + + The ThreadGate limits the number of threads that are simultaneously + uploading batches of records in order to implement adaptive rate + control. The number of simultaneous upload threads that it takes to + start causing timeout varies widely over the course of the day, so + adaptive rate control allows the uploader to do many uploads while + reducing the error rate and thus increasing the throughput. + + Initially the ThreadGate allows only one uploader thread to be active. + For each successful upload, another thread is activated and for each + failed upload, the number of active threads is reduced by one. 
+ """ + + def __init__(self, enabled, sleep=InterruptibleSleep): + self.enabled = enabled + self.enabled_count = 1 + self.lock = threading.Lock() + self.thread_semaphore = threading.Semaphore(self.enabled_count) + self._threads = [] + self.backoff_time = 0 + self.sleep = sleep + + def Register(self, thread): + """Register a thread with the thread gate.""" + self._threads.append(thread) + + def Threads(self): + """Yields the registered threads.""" + for thread in self._threads: + yield thread + + def EnableThread(self): + """Enable one more worker thread.""" + self.lock.acquire() + try: + self.enabled_count += 1 + finally: + self.lock.release() + self.thread_semaphore.release() + + def EnableAllThreads(self): + """Enable all worker threads.""" + for unused_idx in range(len(self._threads) - self.enabled_count): + self.EnableThread() + + def StartWork(self): + """Starts a critical section in which the number of workers is limited. + + If thread throttling is enabled then this method starts a critical + section which allows self.enabled_count simultaneously operating + threads. The critical section is ended by calling self.FinishWork(). + """ + if self.enabled: + self.thread_semaphore.acquire() + if self.backoff_time > 0.0: + if not threading.currentThread().exit_flag: + logging.info('Backing off: %.1f seconds', + self.backoff_time) + self.sleep(self.backoff_time) + + def FinishWork(self): + """Ends a critical section started with self.StartWork().""" + if self.enabled: + self.thread_semaphore.release() + + def IncreaseWorkers(self): + """Informs the throttler that an item was successfully sent. + + If thread throttling is enabled, this method will cause an + additional thread to run in the critical section. + """ + if self.enabled: + if self.backoff_time > 0.0: + logging.info('Resetting backoff to 0.0') + self.backoff_time = 0.0 + do_enable = False + self.lock.acquire() + try: + if self.enabled and len(self._threads) > self.enabled_count: + do_enable = True + self.enabled_count += 1 + finally: + self.lock.release() + if do_enable: + self.thread_semaphore.release() + + def DecreaseWorkers(self): + """Informs the thread_gate that an item failed to send. + + If thread throttling is enabled, this method will cause the + throttler to allow one fewer thread in the critical section. If + there is only one thread remaining, failures will result in + exponential backoff until there is a success. + """ + if self.enabled: + do_disable = False + self.lock.acquire() + try: + if self.enabled: + if self.enabled_count > 1: + do_disable = True + self.enabled_count -= 1 + else: + if self.backoff_time == 0.0: + self.backoff_time = INITIAL_BACKOFF + else: + self.backoff_time *= BACKOFF_FACTOR + finally: + self.lock.release() + if do_disable: + self.thread_semaphore.acquire() + + +class Throttle(object): + """A base class for upload rate throttling. + + Transferring large number of records, too quickly, to an application + could trigger quota limits and cause the transfer process to halt. + In order to stay within the application's quota, we throttle the + data transfer to a specified limit (across all transfer threads). + This limit defaults to about half of the Google App Engine default + for an application, but can be manually adjusted faster/slower as + appropriate. + + This class tracks a moving average of some aspect of the transfer + rate (bandwidth, records per second, http connections per + second). It keeps two windows of counts of bytes transferred, on a + per-thread basis. 
One block is the "current" block, and the other is + the "prior" block. It will rotate the counts from current to prior + when ROTATE_PERIOD has passed. Thus, the current block will + represent from 0 seconds to ROTATE_PERIOD seconds of activity + (determined by: time.time() - self.last_rotate). The prior block + will always represent a full ROTATE_PERIOD. + + Sleeping is performed just before a transfer of another block, and is + based on the counts transferred *before* the next transfer. It really + does not matter how much will be transferred, but only that for all the + data transferred SO FAR that we have interspersed enough pauses to + ensure the aggregate transfer rate is within the specified limit. + + These counts are maintained on a per-thread basis, so we do not require + any interlocks around incrementing the counts. There IS an interlock on + the rotation of the counts because we do not want multiple threads to + multiply-rotate the counts. + + There are various race conditions in the computation and collection + of these counts. We do not require precise values, but simply to + keep the overall transfer within the bandwidth limits. If a given + pause is a little short, or a little long, then the aggregate delays + will be correct. + """ + + ROTATE_PERIOD = 600 + + def __init__(self, + get_time=time.time, + thread_sleep=InterruptibleSleep, + layout=None): + self.get_time = get_time + self.thread_sleep = thread_sleep + + self.start_time = get_time() + self.transferred = {} + self.prior_block = {} + self.totals = {} + self.throttles = {} + + self.last_rotate = {} + self.rotate_mutex = {} + if layout: + self.AddThrottles(layout) + + def AddThrottle(self, name, limit): + self.throttles[name] = limit + self.transferred[name] = {} + self.prior_block[name] = {} + self.totals[name] = {} + self.last_rotate[name] = self.get_time() + self.rotate_mutex[name] = threading.Lock() + + def AddThrottles(self, layout): + for key, value in layout.iteritems(): + self.AddThrottle(key, value) + + def Register(self, thread): + """Register this thread with the throttler.""" + thread_name = thread.getName() + for throttle_name in self.throttles.iterkeys(): + self.transferred[throttle_name][thread_name] = 0 + self.prior_block[throttle_name][thread_name] = 0 + self.totals[throttle_name][thread_name] = 0 + + def VerifyName(self, throttle_name): + if throttle_name not in self.throttles: + raise AssertionError('%s is not a registered throttle' % throttle_name) + + def AddTransfer(self, throttle_name, token_count): + """Add a count to the amount this thread has transferred. + + Each time a thread transfers some data, it should call this method to + note the amount sent. The counts may be rotated if sufficient time + has passed since the last rotation. + + Note: this method should only be called by the BulkLoaderThread + instances. The token count is allocated towards the + "current thread". + + Args: + throttle_name: The name of the throttle to add to. + token_count: The number to add to the throttle counter. + """ + self.VerifyName(throttle_name) + transferred = self.transferred[throttle_name] + transferred[threading.currentThread().getName()] += token_count + + if self.last_rotate[throttle_name] + self.ROTATE_PERIOD < self.get_time(): + self._RotateCounts(throttle_name) + + def Sleep(self, throttle_name=None): + """Possibly sleep in order to limit the transfer rate. + + Note that we sleep based on *prior* transfers rather than what we + may be about to transfer. 
The next transfer could put us under/over + and that will be rectified *after* that transfer. Net result is that + the average transfer rate will remain within bounds. Spiky behavior + or uneven rates among the threads could possibly bring the transfer + rate above the requested limit for short durations. + + Args: + throttle_name: The name of the throttle to sleep on. If None or + omitted, then sleep on all throttles. + """ + if throttle_name is None: + for throttle_name in self.throttles: + self.Sleep(throttle_name=throttle_name) + return + + self.VerifyName(throttle_name) + + thread = threading.currentThread() + + while True: + duration = self.get_time() - self.last_rotate[throttle_name] + + total = 0 + for count in self.prior_block[throttle_name].values(): + total += count + + if total: + duration += self.ROTATE_PERIOD + + for count in self.transferred[throttle_name].values(): + total += count + + sleep_time = (float(total) / self.throttles[throttle_name]) - duration + + if sleep_time < MINIMUM_THROTTLE_SLEEP_DURATION: + break + + logging.debug('[%s] Throttling on %s. Sleeping for %.1f ms ' + '(duration=%.1f ms, total=%d)', + thread.getName(), throttle_name, + sleep_time * 1000, duration * 1000, total) + self.thread_sleep(sleep_time) + if thread.exit_flag: + break + self._RotateCounts(throttle_name) + + def _RotateCounts(self, throttle_name): + """Rotate the transfer counters. + + If sufficient time has passed, then rotate the counters from active to + the prior-block of counts. + + This rotation is interlocked to ensure that multiple threads do not + over-rotate the counts. + + Args: + throttle_name: The name of the throttle to rotate. + """ + self.VerifyName(throttle_name) + self.rotate_mutex[throttle_name].acquire() + try: + next_rotate_time = self.last_rotate[throttle_name] + self.ROTATE_PERIOD + if next_rotate_time >= self.get_time(): + return + + for name, count in self.transferred[throttle_name].items(): + + + self.prior_block[throttle_name][name] = count + self.transferred[throttle_name][name] = 0 + + self.totals[throttle_name][name] += count + + self.last_rotate[throttle_name] = self.get_time() + + finally: + self.rotate_mutex[throttle_name].release() + + def TotalTransferred(self, throttle_name): + """Return the total transferred, and over what period. + + Args: + throttle_name: The name of the throttle to total. + + Returns: + A tuple of the total count and running time for the given throttle name. + """ + total = 0 + for count in self.totals[throttle_name].values(): + total += count + for count in self.transferred[throttle_name].values(): + total += count + return total, self.get_time() - self.start_time + + +class _ThreadBase(threading.Thread): + """Provide some basic features for the threads used in the uploader. + + This abstract base class is used to provide some common features: + + * Flag to ask thread to exit as soon as possible. + * Record exit/error status for the primary thread to pick up. + * Capture exceptions and record them for pickup. + * Some basic logging of thread start/stop. + * All threads are "daemon" threads. + * Friendly names for presenting to users. + + Concrete sub-classes must implement PerformWork(). + + Either self.NAME should be set or GetFriendlyName() be overridden to + return a human-friendly name for this thread. + + The run() method starts the thread and prints start/exit messages. + + self.exit_flag is intended to signal that this thread should exit + when it gets the chance. 
PerformWork() should check self.exit_flag + whenever it has the opportunity to exit gracefully. + """ + + def __init__(self): + threading.Thread.__init__(self) + + self.setDaemon(True) + + self.exit_flag = False + self.error = None + + def run(self): + """Perform the work of the thread.""" + logging.info('[%s] %s: started', self.getName(), self.__class__.__name__) + + try: + self.PerformWork() + except: + self.error = sys.exc_info()[1] + logging.exception('[%s] %s:', self.getName(), self.__class__.__name__) + + logging.info('[%s] %s: exiting', self.getName(), self.__class__.__name__) + + def PerformWork(self): + """Perform the thread-specific work.""" + raise NotImplementedError() + + def CheckError(self): + """If an error is present, then log it.""" + if self.error: + logging.error('Error in %s: %s', self.GetFriendlyName(), self.error) + + def GetFriendlyName(self): + """Returns a human-friendly description of the thread.""" + if hasattr(self, 'NAME'): + return self.NAME + return 'unknown thread' + + +class BulkLoaderThread(_ThreadBase): + """A thread which transmits entities to the server application. + + This thread will read WorkItem instances from the work_queue and upload + the entities to the server application. Progress information will be + pushed into the progress_queue as the work is being performed. + + If a BulkLoaderThread encounters a transient error, the entities will be + resent, if a fatal error is encoutered the BulkLoaderThread exits. + """ + + def __init__(self, + work_queue, + throttle, + thread_gate, + request_manager): + """Initialize the BulkLoaderThread instance. + + Args: + work_queue: A queue containing WorkItems for processing. + throttle: A Throttles to control upload bandwidth. + thread_gate: A ThreadGate to control number of simultaneous uploads. + request_manager: A RequestManager instance. 
+ """ + _ThreadBase.__init__(self) + + self.work_queue = work_queue + self.throttle = throttle + self.thread_gate = thread_gate + + self.request_manager = request_manager + + def PerformWork(self): + """Perform the work of a BulkLoaderThread.""" + while not self.exit_flag: + success = False + self.thread_gate.StartWork() + try: + try: + item = self.work_queue.get(block=True, timeout=1.0) + except Queue.Empty: + continue + if item == _THREAD_SHOULD_EXIT: + break + + logging.debug('[%s] Got work item [%d-%d]', + self.getName(), item.key_start, item.key_end) + + try: + + item.MarkAsSending() + try: + if item.content is None: + item.content = self.request_manager.EncodeContent(item.rows) + try: + self.request_manager.PostEntities(item) + success = True + logging.debug( + '[%d-%d] Sent %d entities', + item.key_start, item.key_end, item.count) + self.throttle.AddTransfer(RECORDS, item.count) + except (db.InternalError, db.NotSavedError, db.Timeout), e: + logging.debug('Caught non-fatal error: %s', e) + except urllib2.HTTPError, e: + if e.code == 403 or (e.code >= 500 and e.code < 600): + logging.debug('Caught HTTP error %d', e.code) + logging.debug('%s', e.read()) + else: + raise e + + except: + self.error = sys.exc_info()[1] + logging.exception('[%s] %s: caught exception %s', self.getName(), + self.__class__.__name__, str(sys.exc_info())) + raise + + finally: + if success: + item.MarkAsSent() + self.thread_gate.IncreaseWorkers() + self.work_queue.task_done() + else: + item.MarkAsError() + self.thread_gate.DecreaseWorkers() + try: + self.work_queue.reput(item, block=False) + except Queue.Full: + logging.error('[%s] Failed to reput work item.', self.getName()) + raise Error('Failed to reput work item') + logging.info('[%d-%d] %s', + item.key_start, item.key_end, StateMessage(item.state)) + + finally: + self.thread_gate.FinishWork() + + + def GetFriendlyName(self): + """Returns a human-friendly name for this thread.""" + return 'worker [%s]' % self.getName() + + +class DataSourceThread(_ThreadBase): + """A thread which reads WorkItems and pushes them into queue. + + This thread will read/consume WorkItems from a generator (produced by + the generator factory). These WorkItems will then be pushed into the + work_queue. Note that reading will block if/when the work_queue becomes + full. Information on content consumed from the generator will be pushed + into the progress_queue. + """ + + NAME = 'data source thread' + + def __init__(self, + work_queue, + progress_queue, + workitem_generator_factory, + progress_generator_factory): + """Initialize the DataSourceThread instance. + + Args: + work_queue: A queue containing WorkItems for processing. + progress_queue: A queue used for tracking progress information. + workitem_generator_factory: A factory that creates a WorkItem generator + progress_generator_factory: A factory that creates a generator which + produces prior progress status, or None if there is no prior status + to use. 
+ """ + _ThreadBase.__init__(self) + + self.work_queue = work_queue + self.progress_queue = progress_queue + self.workitem_generator_factory = workitem_generator_factory + self.progress_generator_factory = progress_generator_factory + self.entity_count = 0 + + def PerformWork(self): + """Performs the work of a DataSourceThread.""" + if self.progress_generator_factory: + progress_gen = self.progress_generator_factory() + else: + progress_gen = None + + content_gen = self.workitem_generator_factory(self.progress_queue, + progress_gen) + + self.sent_count = 0 + self.read_count = 0 + self.read_all = False + + for item in content_gen.Batches(): + item.MarkAsRead() + + while not self.exit_flag: + try: + self.work_queue.put(item, block=True, timeout=1.0) + self.entity_count += item.count + break + except Queue.Full: + pass + + if self.exit_flag: + break + + if not self.exit_flag: + self.read_all = True + self.read_count = content_gen.row_count + self.sent_count = content_gen.sent_count + + + +def _RunningInThread(thread): + """Return True if we are running within the specified thread.""" + return threading.currentThread().getName() == thread.getName() + + +class ProgressDatabase(object): + """Persistently record all progress information during an upload. + + This class wraps a very simple SQLite database which records each of + the relevant details from the WorkItem instances. If the uploader is + resumed, then data is replayed out of the database. + """ + + def __init__(self, db_filename, commit_periodicity=100): + """Initialize the ProgressDatabase instance. + + Args: + db_filename: The name of the SQLite database to use. + commit_periodicity: How many operations to perform between commits. + """ + self.db_filename = db_filename + + logging.info('Using progress database: %s', db_filename) + self.primary_conn = sqlite3.connect(db_filename, isolation_level=None) + self.primary_thread = threading.currentThread() + + self.progress_conn = None + self.progress_thread = None + + self.operation_count = 0 + self.commit_periodicity = commit_periodicity + + self.prior_key_end = None + + try: + self.primary_conn.execute( + """create table progress ( + id integer primary key autoincrement, + state integer not null, + key_start integer not null, + key_end integer not null + ) + """) + except sqlite3.OperationalError, e: + if 'already exists' not in e.message: + raise + + try: + self.primary_conn.execute('create index i_state on progress (state)') + except sqlite3.OperationalError, e: + if 'already exists' not in e.message: + raise + + def ThreadComplete(self): + """Finalize any operations the progress thread has performed. + + The database aggregates lots of operations into a single commit, and + this method is used to commit any pending operations as the thread + is about to shut down. + """ + if self.progress_conn: + self._MaybeCommit(force_commit=True) + + def _MaybeCommit(self, force_commit=False): + """Periodically commit changes into the SQLite database. + + Committing every operation is quite expensive, and slows down the + operation of the script. Thus, we only commit after every N operations, + as determined by the self.commit_periodicity value. Optionally, the + caller can force a commit. + + Args: + force_commit: Pass True in order for a commit to occur regardless + of the current operation count. 
+ """ + self.operation_count += 1 + if force_commit or (self.operation_count % self.commit_periodicity) == 0: + self.progress_conn.commit() + + def _OpenProgressConnection(self): + """Possibly open a database connection for the progress tracker thread. + + If the connection is not open (for the calling thread, which is assumed + to be the progress tracker thread), then open it. We also open a couple + cursors for later use (and reuse). + """ + if self.progress_conn: + return + + assert not _RunningInThread(self.primary_thread) + + self.progress_thread = threading.currentThread() + + self.progress_conn = sqlite3.connect(self.db_filename) + + self.insert_cursor = self.progress_conn.cursor() + self.update_cursor = self.progress_conn.cursor() + + def HasUnfinishedWork(self): + """Returns True if the database has progress information. + + Note there are two basic cases for progress information: + 1) All saved records indicate a successful upload. In this case, we + need to skip everything transmitted so far and then send the rest. + 2) Some records for incomplete transfer are present. These need to be + sent again, and then we resume sending after all the successful + data. + + Returns: + True if the database has progress information, False otherwise. + + Raises: + ResumeError: If there is an error reading the progress database. + """ + assert _RunningInThread(self.primary_thread) + + cursor = self.primary_conn.cursor() + cursor.execute('select count(*) from progress') + row = cursor.fetchone() + if row is None: + raise ResumeError('Error reading progress information.') + + return row[0] != 0 + + def StoreKeys(self, key_start, key_end): + """Record a new progress record, returning a key for later updates. + + The specified progress information will be persisted into the database. + A unique key will be returned that identifies this progress state. The + key is later used to (quickly) update this record. + + For the progress resumption to proceed properly, calls to StoreKeys + MUST specify monotonically increasing key ranges. This will result in + a database whereby the ID, KEY_START, and KEY_END rows are all + increasing (rather than having ranges out of order). + + NOTE: the above precondition is NOT tested by this method (since it + would imply an additional table read or two on each invocation). + + Args: + key_start: The starting key of the WorkItem (inclusive) + key_end: The end key of the WorkItem (inclusive) + + Returns: + A string to later be used as a unique key to update this state. + """ + self._OpenProgressConnection() + + assert _RunningInThread(self.progress_thread) + assert isinstance(key_start, int) + assert isinstance(key_end, int) + assert key_start <= key_end + + if self.prior_key_end is not None: + assert key_start > self.prior_key_end + self.prior_key_end = key_end + + self.insert_cursor.execute( + 'insert into progress (state, key_start, key_end) values (?, ?, ?)', + (STATE_READ, key_start, key_end)) + + progress_key = self.insert_cursor.lastrowid + + self._MaybeCommit() + + return progress_key + + def UpdateState(self, key, new_state): + """Update a specified progress record with new information. + + Args: + key: The key for this progress record, returned from StoreKeys + new_state: The new state to associate with this progress record. + """ + self._OpenProgressConnection() + + assert _RunningInThread(self.progress_thread) + assert isinstance(new_state, int) + + self.update_cursor.execute('update progress set state=? 
where id=?', + (new_state, key)) + + self._MaybeCommit() + + def GetProgressStatusGenerator(self): + """Get a generator which returns progress information. + + The returned generator will yield a series of 4-tuples that specify + progress information about a prior run of the uploader. The 4-tuples + have the following values: + + progress_key: The unique key to later update this record with new + progress information. + state: The last state saved for this progress record. + key_start: The starting key of the items for uploading (inclusive). + key_end: The ending key of the items for uploading (inclusive). + + After all incompletely-transferred records are provided, then one + more 4-tuple will be generated: + + None + DATA_CONSUMED_TO_HERE: A unique string value indicating this record + is being provided. + None + key_end: An integer value specifying the last data source key that + was handled by the previous run of the uploader. + + The caller should begin uploading records which occur after key_end. + + Yields: + Progress information as tuples (progress_key, state, key_start, key_end). + """ + conn = sqlite3.connect(self.db_filename, isolation_level=None) + cursor = conn.cursor() + + cursor.execute('select max(id) from progress') + batch_id = cursor.fetchone()[0] + + cursor.execute('select key_end from progress where id = ?', (batch_id,)) + key_end = cursor.fetchone()[0] + + self.prior_key_end = key_end + + cursor.execute( + 'select id, state, key_start, key_end from progress' + ' where state != ?' + ' order by id', + (STATE_SENT,)) + + rows = cursor.fetchall() + + for row in rows: + if row is None: + break + + yield row + + yield None, DATA_CONSUMED_TO_HERE, None, key_end + + +class StubProgressDatabase(object): + """A stub implementation of ProgressDatabase which does nothing.""" + + def HasUnfinishedWork(self): + """Whether the stub database has progress information (it doesn't).""" + return False + + def StoreKeys(self, unused_key_start, unused_key_end): + """Pretend to store a key in the stub database.""" + return 'fake-key' + + def UpdateState(self, unused_key, unused_new_state): + """Pretend to update the state of a progress item.""" + pass + + def ThreadComplete(self): + """Finalize operations on the stub database (i.e. do nothing).""" + pass + + +class ProgressTrackerThread(_ThreadBase): + """A thread which records progress information for the upload process. + + The progress information is stored into the provided progress database. + This class is not responsible for replaying a prior run's progress + information out of the database. Separate mechanisms must be used to + resume a prior upload attempt. + """ + + NAME = 'progress tracking thread' + + def __init__(self, progress_queue, progress_db): + """Initialize the ProgressTrackerThread instance. + + Args: + progress_queue: A Queue used for tracking progress information. + progress_db: The database for tracking progress information; should + be an instance of ProgressDatabase. 
+ """ + _ThreadBase.__init__(self) + + self.progress_queue = progress_queue + self.db = progress_db + self.entities_sent = 0 + + def PerformWork(self): + """Performs the work of a ProgressTrackerThread.""" + while not self.exit_flag: + try: + item = self.progress_queue.get(block=True, timeout=1.0) + except Queue.Empty: + continue + if item == _THREAD_SHOULD_EXIT: + break + + if item.state == STATE_READ and item.progress_key is None: + item.progress_key = self.db.StoreKeys(item.key_start, item.key_end) + else: + assert item.progress_key is not None + + self.db.UpdateState(item.progress_key, item.state) + if item.state == STATE_SENT: + self.entities_sent += item.count + + item.progress_event.set() + + self.progress_queue.task_done() + + self.db.ThreadComplete() + + + +def Validate(value, typ): + """Checks that value is non-empty and of the right type. + + Args: + value: any value + typ: a type or tuple of types + + Raises: + ValueError if value is None or empty. + TypeError if it's not the given type. + + """ + if not value: + raise ValueError('Value should not be empty; received %s.' % value) + elif not isinstance(value, typ): + raise TypeError('Expected a %s, but received %s (a %s).' % + (typ, value, value.__class__)) + + +class Loader(object): + """A base class for creating datastore entities from input data. + + To add a handler for bulk loading a new entity kind into your datastore, + write a subclass of this class that calls Loader.__init__ from your + class's __init__. + + If you need to run extra code to convert entities from the input + data, create new properties, or otherwise modify the entities before + they're inserted, override HandleEntity. + + See the CreateEntity method for the creation of entities from the + (parsed) input data. + """ + + __loaders = {} + __kind = None + __properties = None + + def __init__(self, kind, properties): + """Constructor. + + Populates this Loader's kind and properties map. Also registers it with + the bulk loader, so that all you need to do is instantiate your Loader, + and the bulkload handler will automatically use it. + + Args: + kind: a string containing the entity kind that this loader handles + + properties: list of (name, converter) tuples. + + This is used to automatically convert the CSV columns into + properties. The converter should be a function that takes one + argument, a string value from the CSV file, and returns a + correctly typed property value that should be inserted. The + tuples in this list should match the columns in your CSV file, + in order. + + For example: + [('name', str), + ('id_number', int), + ('email', datastore_types.Email), + ('user', users.User), + ('birthdate', lambda x: datetime.datetime.fromtimestamp(float(x))), + ('description', datastore_types.Text), + ] + """ + Validate(kind, basestring) + self.__kind = kind + + db.class_for_kind(kind) + + Validate(properties, list) + for name, fn in properties: + Validate(name, basestring) + assert callable(fn), ( + 'Conversion function %s for property %s is not callable.' % (fn, name)) + + self.__properties = properties + + @staticmethod + def RegisterLoader(loader): + + Loader.__loaders[loader.__kind] = loader + + def kind(self): + """ Return the entity kind that this Loader handes. + """ + return self.__kind + + def CreateEntity(self, values, key_name=None): + """Creates a entity from a list of property values. 
+ + Args: + values: list/tuple of str + key_name: if provided, the name for the (single) resulting entity + + Returns: + list of db.Model + + The returned entities are populated with the property values from the + argument, converted to native types using the properties map given in + the constructor, and passed through HandleEntity. They're ready to be + inserted. + + Raises: + AssertionError if the number of values doesn't match the number + of properties in the properties map. + ValueError if any element of values is None or empty. + TypeError if values is not a list or tuple. + """ + Validate(values, (list, tuple)) + assert len(values) == len(self.__properties), ( + 'Expected %d CSV columns, found %d.' % + (len(self.__properties), len(values))) + + model_class = db.class_for_kind(self.__kind) + + properties = {'key_name': key_name} + for (name, converter), val in zip(self.__properties, values): + if converter is bool and val.lower() in ('0', 'false', 'no'): + val = False + properties[name] = converter(val) + + entity = model_class(**properties) + entities = self.HandleEntity(entity) + + if entities: + if not isinstance(entities, (list, tuple)): + entities = [entities] + + for entity in entities: + if not isinstance(entity, db.Model): + raise TypeError('Expected a db.Model, received %s (a %s).' % + (entity, entity.__class__)) + + return entities + + def GenerateKey(self, i, values): + """Generates a key_name to be used in creating the underlying object. + + The default implementation returns None. + + This method can be overridden to control the key generation for + uploaded entities. The value returned should be None (to use a + server generated numeric key), or a string which neither starts + with a digit nor has the form __*__. (See + http://code.google.com/appengine/docs/python/datastore/keysandentitygroups.html) + + If you generate your own string keys, keep in mind: + + 1. The key name for each entity must be unique. + 2. If an entity of the same kind and key already exists in the + datastore, it will be overwritten. + + Args: + i: Number corresponding to this object (assume it's run in a loop, + this is your current count. + values: list/tuple of str. + + Returns: + A string to be used as the key_name for an entity. + """ + return None + + def HandleEntity(self, entity): + """Subclasses can override this to add custom entity conversion code. + + This is called for each entity, after its properties are populated from + CSV but before it is stored. Subclasses can override this to add custom + entity handling code. + + The entity to be inserted should be returned. If multiple entities should + be inserted, return a list of entities. If no entities should be inserted, + return None or []. + + Args: + entity: db.Model + + Returns: + db.Model or list of db.Model + """ + return entity + + + @staticmethod + def RegisteredLoaders(): + """Returns a list of the Loader instances that have been created. + """ + return dict(Loader.__loaders) + + +class QueueJoinThread(threading.Thread): + """A thread that joins a queue and exits. + + Queue joins do not have a timeout. To simulate a queue join with + timeout, run this thread and join it with a timeout. + """ + + def __init__(self, queue): + """Initialize a QueueJoinThread. + + Args: + queue: The queue for this thread to join. 
+ """ + threading.Thread.__init__(self) + assert isinstance(queue, (Queue.Queue, ReQueue)) + self.queue = queue + + def run(self): + """Perform the queue join in this thread.""" + self.queue.join() + + +def InterruptibleQueueJoin(queue, + thread_local, + thread_gate, + queue_join_thread_factory=QueueJoinThread): + """Repeatedly joins the given ReQueue or Queue.Queue with short timeout. + + Between each timeout on the join, worker threads are checked. + + Args: + queue: A Queue.Queue or ReQueue instance. + thread_local: A threading.local instance which indicates interrupts. + thread_gate: A ThreadGate instance. + queue_join_thread_factory: Used for dependency injection. + + Returns: + True unless the queue join is interrupted by SIGINT or worker death. + """ + thread = queue_join_thread_factory(queue) + thread.start() + while True: + thread.join(timeout=.5) + if not thread.isAlive(): + return True + if thread_local.shut_down: + logging.debug('Queue join interrupted') + return False + for worker_thread in thread_gate.Threads(): + if not worker_thread.isAlive(): + return False + + +def ShutdownThreads(data_source_thread, work_queue, thread_gate): + """Shuts down the worker and data source threads. + + Args: + data_source_thread: A running DataSourceThread instance. + work_queue: The work queue. + thread_gate: A ThreadGate instance with workers registered. + """ + logging.info('An error occurred. Shutting down...') + + data_source_thread.exit_flag = True + + for thread in thread_gate.Threads(): + thread.exit_flag = True + + for unused_thread in thread_gate.Threads(): + thread_gate.EnableThread() + + data_source_thread.join(timeout=3.0) + if data_source_thread.isAlive(): + logging.warn('%s hung while trying to exit', + data_source_thread.GetFriendlyName()) + + while not work_queue.empty(): + try: + unused_item = work_queue.get_nowait() + work_queue.task_done() + except Queue.Empty: + pass + + +def PerformBulkUpload(app_id, + post_url, + kind, + workitem_generator_factory, + num_threads, + throttle, + progress_db, + max_queue_size=DEFAULT_QUEUE_SIZE, + request_manager_factory=RequestManager, + bulkloaderthread_factory=BulkLoaderThread, + progresstrackerthread_factory=ProgressTrackerThread, + datasourcethread_factory=DataSourceThread, + work_queue_factory=ReQueue, + progress_queue_factory=Queue.Queue): + """Uploads data into an application using a series of HTTP POSTs. + + This function will spin up a number of threads to read entities from + the data source, pass those to a number of worker ("uploader") threads + for sending to the application, and track all of the progress in a + small database in case an error or pause/termination requires a + restart/resumption of the upload process. + + Args: + app_id: String containing application id. + post_url: URL to post the Entity data to. + kind: Kind of the Entity records being posted. + workitem_generator_factory: A factory that creates a WorkItem generator. + num_threads: How many uploader threads should be created. + throttle: A Throttle instance. + progress_db: The database to use for replaying/recording progress. + max_queue_size: Maximum size of the queues before they should block. + request_manager_factory: Used for dependency injection. + bulkloaderthread_factory: Used for dependency injection. + progresstrackerthread_factory: Used for dependency injection. + datasourcethread_factory: Used for dependency injection. + work_queue_factory: Used for dependency injection. + progress_queue_factory: Used for dependency injection. 
+ + Raises: + AuthenticationError: If authentication is required and fails. + """ + thread_gate = ThreadGate(True) + + (unused_scheme, + host_port, url_path, + unused_query, unused_fragment) = urlparse.urlsplit(post_url) + + work_queue = work_queue_factory(max_queue_size) + progress_queue = progress_queue_factory(max_queue_size) + request_manager = request_manager_factory(app_id, + host_port, + url_path, + kind, + throttle) + + throttle.Register(threading.currentThread()) + try: + request_manager.Authenticate() + except Exception, e: + logging.exception(e) + raise AuthenticationError('Authentication failed') + if (request_manager.credentials is not None and + not request_manager.authenticated): + raise AuthenticationError('Authentication failed') + + for unused_idx in range(num_threads): + thread = bulkloaderthread_factory(work_queue, + throttle, + thread_gate, + request_manager) + throttle.Register(thread) + thread_gate.Register(thread) + + progress_thread = progresstrackerthread_factory(progress_queue, progress_db) + + if progress_db.HasUnfinishedWork(): + logging.debug('Restarting upload using progress database') + progress_generator_factory = progress_db.GetProgressStatusGenerator + else: + progress_generator_factory = None + + data_source_thread = datasourcethread_factory(work_queue, + progress_queue, + workitem_generator_factory, + progress_generator_factory) + + thread_local = threading.local() + thread_local.shut_down = False + + def Interrupt(unused_signum, unused_frame): + """Shutdown gracefully in response to a signal.""" + thread_local.shut_down = True + + signal.signal(signal.SIGINT, Interrupt) + + progress_thread.start() + data_source_thread.start() + for thread in thread_gate.Threads(): + thread.start() + + + while not thread_local.shut_down: + data_source_thread.join(timeout=0.25) + + if data_source_thread.isAlive(): + for thread in list(thread_gate.Threads()) + [progress_thread]: + if not thread.isAlive(): + logging.info('Unexpected thread death: %s', thread.getName()) + thread_local.shut_down = True + break + else: + break + + if thread_local.shut_down: + ShutdownThreads(data_source_thread, work_queue, thread_gate) + + def _Join(ob, msg): + logging.debug('Waiting for %s...', msg) + if isinstance(ob, threading.Thread): + ob.join(timeout=3.0) + if ob.isAlive(): + logging.debug('Joining %s failed', ob.GetFriendlyName()) + else: + logging.debug('... done.') + elif isinstance(ob, (Queue.Queue, ReQueue)): + if not InterruptibleQueueJoin(ob, thread_local, thread_gate): + ShutdownThreads(data_source_thread, work_queue, thread_gate) + else: + ob.join() + logging.debug('... 
done.')
+
+  _Join(work_queue, 'work_queue to flush')
+
+  for unused_thread in thread_gate.Threads():
+    work_queue.put(_THREAD_SHOULD_EXIT)
+
+  for unused_thread in thread_gate.Threads():
+    thread_gate.EnableThread()
+
+  for thread in thread_gate.Threads():
+    _Join(thread, 'thread [%s] to terminate' % thread.getName())
+
+    thread.CheckError()
+
+  if progress_thread.isAlive():
+    _Join(progress_queue, 'progress_queue to finish')
+  else:
+    logging.warn('Progress thread exited prematurely')
+
+  progress_queue.put(_THREAD_SHOULD_EXIT)
+  _Join(progress_thread, 'progress_thread to terminate')
+  progress_thread.CheckError()
+
+  data_source_thread.CheckError()
+
+  total_up, duration = throttle.TotalTransferred(BANDWIDTH_UP)
+  s_total_up, unused_duration = throttle.TotalTransferred(HTTPS_BANDWIDTH_UP)
+  total_up += s_total_up
+  logging.info('%d entities read, %d previously transferred',
+               data_source_thread.read_count,
+               data_source_thread.sent_count)
+  logging.info('%d entities (%d bytes) transferred in %.1f seconds',
+               progress_thread.entities_sent, total_up, duration)
+  if (data_source_thread.read_all and
+      progress_thread.entities_sent + data_source_thread.sent_count >=
+      data_source_thread.read_count):
+    logging.info('All entities successfully uploaded')
+  else:
+    logging.info('Some entities were not successfully uploaded')
+
+
+def PrintUsageExit(code):
+  """Prints usage information and exits with a status code.
+
+  Args:
+    code: Status code to pass to sys.exit() after displaying usage information.
+  """
+  print __doc__ % {'arg0': sys.argv[0]}
+  sys.stdout.flush()
+  sys.stderr.flush()
+  sys.exit(code)
+
+
+def ParseArguments(argv):
+  """Parses command-line arguments.
+
+  Prints out a help message if -h or --help is supplied.
+
+  Args:
+    argv: List of command-line arguments.
+
+  Returns:
+    Tuple (app_id, url, filename, batch_size, kind, num_threads,
+    bandwidth_limit, rps_limit, http_limit, db_filename, config_file,
+    auth_domain) containing the value of each corresponding command-line
+    flag or its default.
+ """ + opts, unused_args = getopt.getopt( + argv[1:], + 'h', + ['debug', + 'help', + 'url=', + 'filename=', + 'batch_size=', + 'kind=', + 'num_threads=', + 'bandwidth_limit=', + 'rps_limit=', + 'http_limit=', + 'db_filename=', + 'app_id=', + 'config_file=', + 'auth_domain=', + ]) + + url = None + filename = None + batch_size = DEFAULT_BATCH_SIZE + kind = None + num_threads = DEFAULT_THREAD_COUNT + bandwidth_limit = DEFAULT_BANDWIDTH_LIMIT + rps_limit = DEFAULT_RPS_LIMIT + http_limit = DEFAULT_REQUEST_LIMIT + db_filename = None + app_id = None + config_file = None + auth_domain = 'gmail.com' + + for option, value in opts: + if option == '--debug': + logging.getLogger().setLevel(logging.DEBUG) + elif option in ('-h', '--help'): + PrintUsageExit(0) + elif option == '--url': + url = value + elif option == '--filename': + filename = value + elif option == '--batch_size': + batch_size = int(value) + elif option == '--kind': + kind = value + elif option == '--num_threads': + num_threads = int(value) + elif option == '--bandwidth_limit': + bandwidth_limit = int(value) + elif option == '--rps_limit': + rps_limit = int(value) + elif option == '--http_limit': + http_limit = int(value) + elif option == '--db_filename': + db_filename = value + elif option == '--app_id': + app_id = value + elif option == '--config_file': + config_file = value + elif option == '--auth_domain': + auth_domain = value + + return ProcessArguments(app_id=app_id, + url=url, + filename=filename, + batch_size=batch_size, + kind=kind, + num_threads=num_threads, + bandwidth_limit=bandwidth_limit, + rps_limit=rps_limit, + http_limit=http_limit, + db_filename=db_filename, + config_file=config_file, + auth_domain=auth_domain, + die_fn=lambda: PrintUsageExit(1)) + + +def ThrottleLayout(bandwidth_limit, http_limit, rps_limit): + return { + BANDWIDTH_UP: bandwidth_limit, + BANDWIDTH_DOWN: bandwidth_limit, + REQUESTS: http_limit, + HTTPS_BANDWIDTH_UP: bandwidth_limit / 5, + HTTPS_BANDWIDTH_DOWN: bandwidth_limit / 5, + HTTPS_REQUESTS: http_limit / 5, + RECORDS: rps_limit, + } + + +def LoadConfig(config_file): + """Loads a config file and registers any Loader classes present.""" + if config_file: + global_dict = dict(globals()) + execfile(config_file, global_dict) + for cls in Loader.__subclasses__(): + Loader.RegisterLoader(cls()) + + +def _MissingArgument(arg_name, die_fn): + """Print error message about missing argument and die.""" + print >>sys.stderr, '%s argument required' % arg_name + die_fn() + + +def ProcessArguments(app_id=None, + url=None, + filename=None, + batch_size=DEFAULT_BATCH_SIZE, + kind=None, + num_threads=DEFAULT_THREAD_COUNT, + bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT, + rps_limit=DEFAULT_RPS_LIMIT, + http_limit=DEFAULT_REQUEST_LIMIT, + db_filename=None, + config_file=None, + auth_domain='gmail.com', + die_fn=lambda: sys.exit(1)): + """Processes non command-line input arguments.""" + if db_filename is None: + db_filename = time.strftime('bulkloader-progress-%Y%m%d.%H%M%S.sql3') + + if batch_size <= 0: + print >>sys.stderr, 'batch_size must be 1 or larger' + die_fn() + + if url is None: + _MissingArgument('url', die_fn) + + if filename is None: + _MissingArgument('filename', die_fn) + + if kind is None: + _MissingArgument('kind', die_fn) + + if config_file is None: + _MissingArgument('config_file', die_fn) + + if app_id is None: + (unused_scheme, host_port, unused_url_path, + unused_query, unused_fragment) = urlparse.urlsplit(url) + suffix_idx = host_port.find('.appspot.com') + if suffix_idx > -1: + app_id = 
host_port[:suffix_idx]
+  elif host_port.split(':')[0].endswith('google.com'):
+    app_id = host_port.split('.')[0]
+  else:
+    print >>sys.stderr, 'app_id required for non-appspot.com domains'
+    die_fn()
+
+  return (app_id, url, filename, batch_size, kind, num_threads,
+          bandwidth_limit, rps_limit, http_limit, db_filename, config_file,
+          auth_domain)
+
+
+def _PerformBulkload(app_id=None,
+                     url=None,
+                     filename=None,
+                     batch_size=DEFAULT_BATCH_SIZE,
+                     kind=None,
+                     num_threads=DEFAULT_THREAD_COUNT,
+                     bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT,
+                     rps_limit=DEFAULT_RPS_LIMIT,
+                     http_limit=DEFAULT_REQUEST_LIMIT,
+                     db_filename=None,
+                     config_file=None,
+                     auth_domain='gmail.com'):
+  """Runs the bulkloader, given the options as keyword arguments.
+
+  Args:
+    app_id: The application id.
+    url: The url of the remote_api endpoint.
+    filename: The name of the file containing the CSV data.
+    batch_size: The number of records to send per request.
+    kind: The kind of entity to transfer.
+    num_threads: The number of threads to use to transfer data.
+    bandwidth_limit: Maximum bytes/second to transfer.
+    rps_limit: Maximum records/second to transfer.
+    http_limit: Maximum requests/second for transfers.
+    db_filename: The name of the SQLite3 progress database file.
+    config_file: The name of the configuration file.
+    auth_domain: The auth domain to use for logins and UserProperty.
+
+  Returns:
+    An exit code.
+  """
+  os.environ['AUTH_DOMAIN'] = auth_domain
+  LoadConfig(config_file)
+
+  throttle_layout = ThrottleLayout(bandwidth_limit, http_limit, rps_limit)
+
+  throttle = Throttle(layout=throttle_layout)
+
+
+  workitem_generator_factory = GetCSVGeneratorFactory(filename, batch_size)
+
+  if db_filename == 'skip':
+    progress_db = StubProgressDatabase()
+  else:
+    progress_db = ProgressDatabase(db_filename)
+
+
+  max_queue_size = max(DEFAULT_QUEUE_SIZE, 2 * num_threads + 5)
+
+  PerformBulkUpload(app_id,
+                    url,
+                    kind,
+                    workitem_generator_factory,
+                    num_threads,
+                    throttle,
+                    progress_db,
+                    max_queue_size=max_queue_size)
+
+  return 0
+
+
+def Run(app_id=None,
+        url=None,
+        filename=None,
+        batch_size=DEFAULT_BATCH_SIZE,
+        kind=None,
+        num_threads=DEFAULT_THREAD_COUNT,
+        bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT,
+        rps_limit=DEFAULT_RPS_LIMIT,
+        http_limit=DEFAULT_REQUEST_LIMIT,
+        db_filename=None,
+        auth_domain='gmail.com',
+        config_file=None):
+  """Sets up and runs the bulkloader, given the options as keyword arguments.
+
+  Args:
+    app_id: The application id.
+    url: The url of the remote_api endpoint.
+    filename: The name of the file containing the CSV data.
+    batch_size: The number of records to send per request.
+    kind: The kind of entity to transfer.
+    num_threads: The number of threads to use to transfer data.
+    bandwidth_limit: Maximum bytes/second to transfer.
+    rps_limit: Maximum records/second to transfer.
+    http_limit: Maximum requests/second for transfers.
+    db_filename: The name of the SQLite3 progress database file.
+    config_file: The name of the configuration file.
+    auth_domain: The auth domain to use for logins and UserProperty.
+
+  Returns:
+    An exit code.
+  """
+  logging.basicConfig(
+      format='%(levelname)-8s %(asctime)s %(filename)s] %(message)s')
+  args = ProcessArguments(app_id=app_id,
+                          url=url,
+                          filename=filename,
+                          batch_size=batch_size,
+                          kind=kind,
+                          num_threads=num_threads,
+                          bandwidth_limit=bandwidth_limit,
+                          rps_limit=rps_limit,
+                          http_limit=http_limit,
+                          db_filename=db_filename,
+                          config_file=config_file,
+                          auth_domain=auth_domain)
+
+  (app_id, url, filename, batch_size, kind, num_threads, bandwidth_limit,
+   rps_limit, http_limit, db_filename, config_file, auth_domain) = args
+
+  return _PerformBulkload(app_id=app_id,
+                          url=url,
+                          filename=filename,
+                          batch_size=batch_size,
+                          kind=kind,
+                          num_threads=num_threads,
+                          bandwidth_limit=bandwidth_limit,
+                          rps_limit=rps_limit,
+                          http_limit=http_limit,
+                          db_filename=db_filename,
+                          config_file=config_file,
+                          auth_domain=auth_domain)
+
+
+def main(argv):
+  """Runs the importer from the command line."""
+  logging.basicConfig(
+      level=logging.INFO,
+      format='%(levelname)-8s %(asctime)s %(filename)s] %(message)s')
+
+  args = ParseArguments(argv)
+  if None in args:
+    print >>sys.stderr, 'Invalid arguments'
+    PrintUsageExit(1)
+
+  (app_id, url, filename, batch_size, kind, num_threads,
+   bandwidth_limit, rps_limit, http_limit, db_filename, config_file,
+   auth_domain) = args
+
+  return _PerformBulkload(app_id=app_id,
+                          url=url,
+                          filename=filename,
+                          batch_size=batch_size,
+                          kind=kind,
+                          num_threads=num_threads,
+                          bandwidth_limit=bandwidth_limit,
+                          rps_limit=rps_limit,
+                          http_limit=http_limit,
+                          db_filename=db_filename,
+                          config_file=config_file,
+                          auth_domain=auth_domain)
+
+
+if __name__ == '__main__':
+  sys.exit(main(sys.argv))
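
For reference, the Loader API added above is driven entirely by the module
named in --config_file: LoadConfig() loads that module with execfile(),
passing in this module's globals, and registers every Loader subclass it
defines. The sketch below is a hypothetical loader_config.py for an invented
'Album' kind; the model, its properties, the CSV column layout, and the file
and application names are assumptions made for illustration, not part of this
patch. Because the config runs with this module's globals, it can refer to
Loader and db directly, and defining the model next to the loader satisfies
the db.class_for_kind() check in Loader.__init__.

# Loader and db come from bulkloader's globals when run through LoadConfig().
import datetime


class Album(db.Model):
  """Hypothetical model; the kind name must match the --kind flag."""
  title = db.StringProperty()
  artist = db.StringProperty()
  release_date = db.DateTimeProperty()


class AlbumLoader(Loader):
  """Example Loader picked up by LoadConfig() via Loader.__subclasses__()."""

  def __init__(self):
    # One (column_name, converter) pair per CSV column, in column order.
    Loader.__init__(self, 'Album',
                    [('title', str),
                     ('artist', str),
                     ('release_date',
                      lambda x: datetime.datetime.strptime(x, '%Y-%m-%d')),
                    ])

  def GenerateKey(self, i, values):
    # Optional: return a string key_name instead of letting the server assign
    # a numeric id. Key names must be unique and must not start with a digit.
    return 'album%d' % i

  def HandleEntity(self, entity):
    # Optional hook: skip rows with an empty title instead of uploading them.
    if not entity.title:
      return None
    return entity

With a config like that in place, an upload would be started with the
command-line flags this script defines (for example --kind=Album
--filename=albums.csv --config_file=loader_config.py together with the
remote_api --url), or programmatically via Run(), which accepts the same
options as keyword arguments and returns an exit code. When the URL host ends
in .appspot.com, ProcessArguments() derives the app_id automatically:

from google.appengine.tools import bulkloader

exit_code = bulkloader.Run(url='http://myapp.appspot.com/remote_api',
                           kind='Album',
                           filename='albums.csv',
                           config_file='loader_config.py')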