Factor out the setup() method in interactive
Also allow specifying a custom context dictionary in remote, which
will be used by the stats module to add helper methods.
Patch by: Sverre Rabbelier
#!/usr/bin/env python## Copyright 2007 Google Inc.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.#"""Imports CSV data over HTTP.Usage: %(arg0)s [flags] --debug Show debugging information. (Optional) --app_id=<string> Application ID of endpoint (Optional for *.appspot.com) --auth_domain=<domain> The auth domain to use for logging in and for UserProperties. (Default: gmail.com) --bandwidth_limit=<int> The maximum number of bytes per second for the aggregate transfer of data to the server. Bursts --batch_size=<int> Number of Entity objects to include in each post to the URL endpoint. The more data per row/Entity, the smaller the batch size should be. (Default 10) --config_file=<path> File containing Model and Loader definitions. (Required) --db_filename=<path> Specific progress database to write to, or to resume from. If not supplied, then a new database will be started, named: bulkloader-progress-TIMESTAMP. The special filename "skip" may be used to simply skip reading/writing any progress information. --filename=<path> Path to the CSV file to import. (Required) --http_limit=<int> The maximum numer of HTTP requests per second to send to the server. (Default: 8) --kind=<string> Name of the Entity object kind to put in the datastore. (Required) --num_threads=<int> Number of threads to use for uploading entities (Default 10) may exceed this, but overall transfer rate is restricted to this rate. (Default 250000) --rps_limit=<int> The maximum number of records per second to transfer to the server. (Default: 20) --url=<string> URL endpoint to post to for importing data. (Required)The exit status will be 0 on success, non-zero on import failure.Works with the remote_api mix-in library for google.appengine.ext.remote_api.Please look there for documentation about how to setup the server side.Example:%(arg0)s --url=http://app.appspot.com/remote_api --kind=Model \ --filename=data.csv --config_file=loader_config.py"""import csvimport getoptimport getpassimport loggingimport newimport osimport Queueimport signalimport sysimport threadingimport timeimport tracebackimport urllib2import urlparsefrom google.appengine.ext import dbfrom google.appengine.ext.remote_api import remote_api_stubfrom google.appengine.tools import appengine_rpctry: import sqlite3except ImportError: passUPLOADER_VERSION = '1'DEFAULT_THREAD_COUNT = 10DEFAULT_BATCH_SIZE = 10DEFAULT_QUEUE_SIZE = DEFAULT_THREAD_COUNT * 10_THREAD_SHOULD_EXIT = '_THREAD_SHOULD_EXIT'STATE_READ = 0STATE_SENDING = 1STATE_SENT = 2STATE_NOT_SENT = 3MINIMUM_THROTTLE_SLEEP_DURATION = 0.001DATA_CONSUMED_TO_HERE = 'DATA_CONSUMED_TO_HERE'INITIAL_BACKOFF = 1.0BACKOFF_FACTOR = 2.0DEFAULT_BANDWIDTH_LIMIT = 250000DEFAULT_RPS_LIMIT = 20DEFAULT_REQUEST_LIMIT = 8BANDWIDTH_UP = 'http-bandwidth-up'BANDWIDTH_DOWN = 'http-bandwidth-down'REQUESTS = 'http-requests'HTTPS_BANDWIDTH_UP = 'https-bandwidth-up'HTTPS_BANDWIDTH_DOWN = 'https-bandwidth-down'HTTPS_REQUESTS = 'https-requests'RECORDS = 'records'def StateMessage(state): """Converts a numeric state identifier to a status message.""" return ({ STATE_READ: 'Batch read from file.', STATE_SENDING: 'Sending batch to server.', STATE_SENT: 'Batch successfully sent.', STATE_NOT_SENT: 'Error while sending batch.' }[state])class Error(Exception): """Base-class for exceptions in this module."""class FatalServerError(Error): """An unrecoverable error occurred while trying to post data to the server."""class ResumeError(Error): """Error while trying to resume a partial upload."""class ConfigurationError(Error): """Error in configuration options."""class AuthenticationError(Error): """Error while trying to authenticate with the server."""def GetCSVGeneratorFactory(csv_filename, batch_size, openfile=open, create_csv_reader=csv.reader): """Return a factory that creates a CSV-based WorkItem generator. Args: csv_filename: File on disk containing CSV data. batch_size: Maximum number of CSV rows to stash into a WorkItem. openfile: Used for dependency injection. create_csv_reader: Used for dependency injection. Returns: A callable (accepting the Progress Queue and Progress Generators as input) which creates the WorkItem generator. """ def CreateGenerator(progress_queue, progress_generator): """Initialize a CSV generator linked to a progress generator and queue. Args: progress_queue: A ProgressQueue instance to send progress information. progress_generator: A generator of progress information or None. Returns: A CSVGenerator instance. """ return CSVGenerator(progress_queue, progress_generator, csv_filename, batch_size, openfile, create_csv_reader) return CreateGeneratorclass CSVGenerator(object): """Reads a CSV file and generates WorkItems containing batches of records.""" def __init__(self, progress_queue, progress_generator, csv_filename, batch_size, openfile, create_csv_reader): """Initializes a CSV generator. Args: progress_queue: A queue used for tracking progress information. progress_generator: A generator of prior progress information, or None if there is no prior status. csv_filename: File on disk containing CSV data. batch_size: Maximum number of CSV rows to stash into a WorkItem. openfile: Used for dependency injection of 'open'. create_csv_reader: Used for dependency injection of 'csv.reader'. """ self.progress_queue = progress_queue self.progress_generator = progress_generator self.csv_filename = csv_filename self.batch_size = batch_size self.openfile = openfile self.create_csv_reader = create_csv_reader self.line_number = 1 self.column_count = None self.read_rows = [] self.reader = None self.row_count = 0 self.sent_count = 0 def _AdvanceTo(self, line): """Advance the reader to the given line. Args: line: A line number to advance to. """ while self.line_number < line: self.reader.next() self.line_number += 1 self.row_count += 1 self.sent_count += 1 def _ReadRows(self, key_start, key_end): """Attempts to read and encode rows [key_start, key_end]. The encoded rows are stored in self.read_rows. Args: key_start: The starting line number. key_end: The ending line number. Raises: StopIteration: if the reader runs out of rows ResumeError: if there are an inconsistent number of columns. """ assert self.line_number == key_start self.read_rows = [] while self.line_number <= key_end: row = self.reader.next() self.row_count += 1 if self.column_count is None: self.column_count = len(row) else: if self.column_count != len(row): raise ResumeError('Column count mismatch, %d: %s' % (self.column_count, str(row))) self.read_rows.append((self.line_number, row)) self.line_number += 1 def _MakeItem(self, key_start, key_end, rows, progress_key=None): """Makes a WorkItem containing the given rows, with the given keys. Args: key_start: The start key for the WorkItem. key_end: The end key for the WorkItem. rows: A list of the rows for the WorkItem. progress_key: The progress key for the WorkItem Returns: A WorkItem instance for the given batch. """ assert rows item = WorkItem(self.progress_queue, rows, key_start, key_end, progress_key=progress_key) return item def Batches(self): """Reads the CSV data file and generates WorkItems. Yields: Instances of class WorkItem Raises: ResumeError: If the progress database and data file indicate a different number of rows. """ csv_file = self.openfile(self.csv_filename, 'r') csv_content = csv_file.read() if csv_content: has_headers = csv.Sniffer().has_header(csv_content) else: has_headers = False csv_file.seek(0) self.reader = self.create_csv_reader(csv_file, skipinitialspace=True) if has_headers: logging.info('The CSV file appears to have a header line, skipping.') self.reader.next() exhausted = False self.line_number = 1 self.column_count = None logging.info('Starting import; maximum %d entities per post', self.batch_size) state = None if self.progress_generator is not None: for progress_key, state, key_start, key_end in self.progress_generator: if key_start: try: self._AdvanceTo(key_start) self._ReadRows(key_start, key_end) yield self._MakeItem(key_start, key_end, self.read_rows, progress_key=progress_key) except StopIteration: logging.error('Mismatch between data file and progress database') raise ResumeError( 'Mismatch between data file and progress database') elif state == DATA_CONSUMED_TO_HERE: try: self._AdvanceTo(key_end + 1) except StopIteration: state = None if self.progress_generator is None or state == DATA_CONSUMED_TO_HERE: while not exhausted: key_start = self.line_number key_end = self.line_number + self.batch_size - 1 try: self._ReadRows(key_start, key_end) except StopIteration: exhausted = True key_end = self.line_number - 1 if key_start <= key_end: yield self._MakeItem(key_start, key_end, self.read_rows)class ReQueue(object): """A special thread-safe queue. A ReQueue allows unfinished work items to be returned with a call to reput(). When an item is reput, task_done() should *not* be called in addition, getting an item that has been reput does not increase the number of outstanding tasks. This class shares an interface with Queue.Queue and provides the additional Reput method. """ def __init__(self, queue_capacity, requeue_capacity=None, queue_factory=Queue.Queue, get_time=time.time): """Initialize a ReQueue instance. Args: queue_capacity: The number of items that can be put in the ReQueue. requeue_capacity: The numer of items that can be reput in the ReQueue. queue_factory: Used for dependency injection. get_time: Used for dependency injection. """ if requeue_capacity is None: requeue_capacity = queue_capacity self.get_time = get_time self.queue = queue_factory(queue_capacity) self.requeue = queue_factory(requeue_capacity) self.lock = threading.Lock() self.put_cond = threading.Condition(self.lock) self.get_cond = threading.Condition(self.lock) def _DoWithTimeout(self, action, exc, wait_cond, done_cond, lock, timeout=None, block=True): """Performs the given action with a timeout. The action must be non-blocking, and raise an instance of exc on a recoverable failure. If the action fails with an instance of exc, we wait on wait_cond before trying again. Failure after the timeout is reached is propagated as an exception. Success is signalled by notifying on done_cond and returning the result of the action. If action raises any exception besides an instance of exc, it is immediately propagated. Args: action: A callable that performs a non-blocking action. exc: An exception type that is thrown by the action to indicate a recoverable error. wait_cond: A condition variable which should be waited on when action throws exc. done_cond: A condition variable to signal if the action returns. lock: The lock used by wait_cond and done_cond. timeout: A non-negative float indicating the maximum time to wait. block: Whether to block if the action cannot complete immediately. Returns: The result of the action, if it is successful. Raises: ValueError: If the timeout argument is negative. """ if timeout is not None and timeout < 0.0: raise ValueError('\'timeout\' must not be a negative number') if not block: timeout = 0.0 result = None success = False start_time = self.get_time() lock.acquire() try: while not success: try: result = action() success = True except Exception, e: if not isinstance(e, exc): raise e if timeout is not None: elapsed_time = self.get_time() - start_time timeout -= elapsed_time if timeout <= 0.0: raise e wait_cond.wait(timeout) finally: if success: done_cond.notify() lock.release() return result def put(self, item, block=True, timeout=None): """Put an item into the requeue. Args: item: An item to add to the requeue. block: Whether to block if the requeue is full. timeout: Maximum on how long to wait until the queue is non-full. Raises: Queue.Full if the queue is full and the timeout expires. """ def PutAction(): self.queue.put(item, block=False) self._DoWithTimeout(PutAction, Queue.Full, self.get_cond, self.put_cond, self.lock, timeout=timeout, block=block) def reput(self, item, block=True, timeout=None): """Re-put an item back into the requeue. Re-putting an item does not increase the number of outstanding tasks, so the reput item should be uniquely associated with an item that was previously removed from the requeue and for which task_done has not been called. Args: item: An item to add to the requeue. block: Whether to block if the requeue is full. timeout: Maximum on how long to wait until the queue is non-full. Raises: Queue.Full is the queue is full and the timeout expires. """ def ReputAction(): self.requeue.put(item, block=False) self._DoWithTimeout(ReputAction, Queue.Full, self.get_cond, self.put_cond, self.lock, timeout=timeout, block=block) def get(self, block=True, timeout=None): """Get an item from the requeue. Args: block: Whether to block if the requeue is empty. timeout: Maximum on how long to wait until the requeue is non-empty. Returns: An item from the requeue. Raises: Queue.Empty if the queue is empty and the timeout expires. """ def GetAction(): try: result = self.requeue.get(block=False) self.requeue.task_done() except Queue.Empty: result = self.queue.get(block=False) return result return self._DoWithTimeout(GetAction, Queue.Empty, self.put_cond, self.get_cond, self.lock, timeout=timeout, block=block) def join(self): """Blocks until all of the items in the requeue have been processed.""" self.queue.join() def task_done(self): """Indicate that a previously enqueued item has been fully processed.""" self.queue.task_done() def empty(self): """Returns true if the requeue is empty.""" return self.queue.empty() and self.requeue.empty() def get_nowait(self): """Try to get an item from the queue without blocking.""" return self.get(block=False)class ThrottleHandler(urllib2.BaseHandler): """A urllib2 handler for http and https requests that adds to a throttle.""" def __init__(self, throttle): """Initialize a ThrottleHandler. Args: throttle: A Throttle instance to call for bandwidth and http/https request throttling. """ self.throttle = throttle def AddRequest(self, throttle_name, req): """Add to bandwidth throttle for given request. Args: throttle_name: The name of the bandwidth throttle to add to. req: The request whose size will be added to the throttle. """ size = 0 for key, value in req.headers.iteritems(): size += len('%s: %s\n' % (key, value)) for key, value in req.unredirected_hdrs.iteritems(): size += len('%s: %s\n' % (key, value)) (unused_scheme, unused_host_port, url_path, unused_query, unused_fragment) = urlparse.urlsplit(req.get_full_url()) size += len('%s %s HTTP/1.1\n' % (req.get_method(), url_path)) data = req.get_data() if data: size += len(data) self.throttle.AddTransfer(throttle_name, size) def AddResponse(self, throttle_name, res): """Add to bandwidth throttle for given response. Args: throttle_name: The name of the bandwidth throttle to add to. res: The response whose size will be added to the throttle. """ content = res.read() def ReturnContent(): return content res.read = ReturnContent size = len(content) headers = res.info() for key, value in headers.items(): size += len('%s: %s\n' % (key, value)) self.throttle.AddTransfer(throttle_name, size) def http_request(self, req): """Process an HTTP request. If the throttle is over quota, sleep first. Then add request size to throttle before returning it to be sent. Args: req: A urllib2.Request object. Returns: The request passed in. """ self.throttle.Sleep() self.AddRequest(BANDWIDTH_UP, req) return req def https_request(self, req): """Process an HTTPS request. If the throttle is over quota, sleep first. Then add request size to throttle before returning it to be sent. Args: req: A urllib2.Request object. Returns: The request passed in. """ self.throttle.Sleep() self.AddRequest(HTTPS_BANDWIDTH_UP, req) return req def http_response(self, unused_req, res): """Process an HTTP response. The size of the response is added to the bandwidth throttle and the request throttle is incremented by one. Args: unused_req: The urllib2 request for this response. res: A urllib2 response object. Returns: The response passed in. """ self.AddResponse(BANDWIDTH_DOWN, res) self.throttle.AddTransfer(REQUESTS, 1) return res def https_response(self, unused_req, res): """Process an HTTPS response. The size of the response is added to the bandwidth throttle and the request throttle is incremented by one. Args: unused_req: The urllib2 request for this response. res: A urllib2 response object. Returns: The response passed in. """ self.AddResponse(HTTPS_BANDWIDTH_DOWN, res) self.throttle.AddTransfer(HTTPS_REQUESTS, 1) return resclass ThrottledHttpRpcServer(appengine_rpc.HttpRpcServer): """Provides a simplified RPC-style interface for HTTP requests. This RPC server uses a Throttle to prevent exceeding quotas. """ def __init__(self, throttle, request_manager, *args, **kwargs): """Initialize a ThrottledHttpRpcServer. Also sets request_manager.rpc_server to the ThrottledHttpRpcServer instance. Args: throttle: A Throttles instance. request_manager: A RequestManager instance. args: Positional arguments to pass through to appengine_rpc.HttpRpcServer.__init__ kwargs: Keyword arguments to pass through to appengine_rpc.HttpRpcServer.__init__ """ self.throttle = throttle appengine_rpc.HttpRpcServer.__init__(self, *args, **kwargs) request_manager.rpc_server = self def _GetOpener(self): """Returns an OpenerDirector that supports cookies and ignores redirects. Returns: A urllib2.OpenerDirector object. """ opener = appengine_rpc.HttpRpcServer._GetOpener(self) opener.add_handler(ThrottleHandler(self.throttle)) return openerdef ThrottledHttpRpcServerFactory(throttle, request_manager): """Create a factory to produce ThrottledHttpRpcServer for a given throttle. Args: throttle: A Throttle instance to use for the ThrottledHttpRpcServer. request_manager: A RequestManager instance. Returns: A factory to produce a ThrottledHttpRpcServer. """ def MakeRpcServer(*args, **kwargs): kwargs['account_type'] = 'HOSTED_OR_GOOGLE' kwargs['save_cookies'] = True return ThrottledHttpRpcServer(throttle, request_manager, *args, **kwargs) return MakeRpcServerclass RequestManager(object): """A class which wraps a connection to the server.""" source = 'google-bulkloader-%s' % UPLOADER_VERSION user_agent = source def __init__(self, app_id, host_port, url_path, kind, throttle): """Initialize a RequestManager object. Args: app_id: String containing the application id for requests. host_port: String containing the "host:port" pair; the port is optional. url_path: partial URL (path) to post entity data to. kind: Kind of the Entity records being posted. throttle: A Throttle instance. """ self.app_id = app_id self.host_port = host_port self.host = host_port.split(':')[0] if url_path and url_path[0] != '/': url_path = '/' + url_path self.url_path = url_path self.kind = kind self.throttle = throttle self.credentials = None throttled_rpc_server_factory = ThrottledHttpRpcServerFactory( self.throttle, self) logging.debug('Configuring remote_api. app_id = %s, url_path = %s, ' 'servername = %s' % (app_id, url_path, host_port)) remote_api_stub.ConfigureRemoteDatastore( app_id, url_path, self.AuthFunction, servername=host_port, rpc_server_factory=throttled_rpc_server_factory) self.authenticated = False def Authenticate(self): """Invoke authentication if necessary.""" self.rpc_server.Send(self.url_path, payload=None) self.authenticated = True def AuthFunction(self, raw_input_fn=raw_input, password_input_fn=getpass.getpass): """Prompts the user for a username and password. Caches the results the first time it is called and returns the same result every subsequent time. Args: raw_input_fn: Used for dependency injection. password_input_fn: Used for dependency injection. Returns: A pair of the username and password. """ if self.credentials is not None: return self.credentials print 'Please enter login credentials for %s (%s)' % ( self.host, self.app_id) email = raw_input_fn('Email: ') if email: password_prompt = 'Password for %s: ' % email password = password_input_fn(password_prompt) else: password = None self.credentials = (email, password) return self.credentials def _GetHeaders(self): """Constructs a dictionary of extra headers to send with a request.""" headers = { 'GAE-Uploader-Version': UPLOADER_VERSION, 'GAE-Uploader-Kind': self.kind } return headers def EncodeContent(self, rows): """Encodes row data to the wire format. Args: rows: A list of pairs of a line number and a list of column values. Returns: A list of db.Model instances. """ try: loader = Loader.RegisteredLoaders()[self.kind] except KeyError: logging.error('No Loader defined for kind %s.' % self.kind) raise ConfigurationError('No Loader defined for kind %s.' % self.kind) entities = [] for line_number, values in rows: key = loader.GenerateKey(line_number, values) entity = loader.CreateEntity(values, key_name=key) entities.extend(entity) return entities def PostEntities(self, item): """Posts Entity records to a remote endpoint over HTTP. Args: item: A workitem containing the entities to post. Returns: A pair of the estimated size of the request in bytes and the response from the server as a str. """ entities = item.content db.put(entities)class WorkItem(object): """Holds a unit of uploading work. A WorkItem represents a number of entities that need to be uploaded to Google App Engine. These entities are encoded in the "content" field of the WorkItem, and will be POST'd as-is to the server. The entities are identified by a range of numeric keys, inclusively. In the case of a resumption of an upload, or a replay to correct errors, these keys must be able to identify the same set of entities. Note that keys specify a range. The entities do not have to sequentially fill the entire range, they must simply bound a range of valid keys. """ def __init__(self, progress_queue, rows, key_start, key_end, progress_key=None): """Initialize the WorkItem instance. Args: progress_queue: A queue used for tracking progress information. rows: A list of pairs of a line number and a list of column values key_start: The (numeric) starting key, inclusive. key_end: The (numeric) ending key, inclusive. progress_key: If this WorkItem represents state from a prior run, then this will be the key within the progress database. """ self.state = STATE_READ self.progress_queue = progress_queue assert isinstance(key_start, (int, long)) assert isinstance(key_end, (int, long)) assert key_start <= key_end self.key_start = key_start self.key_end = key_end self.progress_key = progress_key self.progress_event = threading.Event() self.rows = rows self.content = None self.count = len(rows) def MarkAsRead(self): """Mark this WorkItem as read/consumed from the data source.""" assert self.state == STATE_READ self._StateTransition(STATE_READ, blocking=True) assert self.progress_key is not None def MarkAsSending(self): """Mark this WorkItem as in-process on being uploaded to the server.""" assert self.state == STATE_READ or self.state == STATE_NOT_SENT assert self.progress_key is not None self._StateTransition(STATE_SENDING, blocking=True) def MarkAsSent(self): """Mark this WorkItem as sucessfully-sent to the server.""" assert self.state == STATE_SENDING assert self.progress_key is not None self._StateTransition(STATE_SENT, blocking=False) def MarkAsError(self): """Mark this WorkItem as required manual error recovery.""" assert self.state == STATE_SENDING assert self.progress_key is not None self._StateTransition(STATE_NOT_SENT, blocking=True) def _StateTransition(self, new_state, blocking=False): """Transition the work item to a new state, storing progress information. Args: new_state: The state to transition to. blocking: Whether to block for the progress thread to acknowledge the transition. """ logging.debug('[%s-%s] %s' % (self.key_start, self.key_end, StateMessage(self.state))) assert not self.progress_event.isSet() self.state = new_state self.progress_queue.put(self) if blocking: self.progress_event.wait() self.progress_event.clear()def InterruptibleSleep(sleep_time): """Puts thread to sleep, checking this threads exit_flag twice a second. Args: sleep_time: Time to sleep. """ slept = 0.0 epsilon = .0001 thread = threading.currentThread() while slept < sleep_time - epsilon: remaining = sleep_time - slept this_sleep_time = min(remaining, 0.5) time.sleep(this_sleep_time) slept += this_sleep_time if thread.exit_flag: returnclass ThreadGate(object): """Manage the number of active worker threads. The ThreadGate limits the number of threads that are simultaneously uploading batches of records in order to implement adaptive rate control. The number of simultaneous upload threads that it takes to start causing timeout varies widely over the course of the day, so adaptive rate control allows the uploader to do many uploads while reducing the error rate and thus increasing the throughput. Initially the ThreadGate allows only one uploader thread to be active. For each successful upload, another thread is activated and for each failed upload, the number of active threads is reduced by one. """ def __init__(self, enabled, sleep=InterruptibleSleep): self.enabled = enabled self.enabled_count = 1 self.lock = threading.Lock() self.thread_semaphore = threading.Semaphore(self.enabled_count) self._threads = [] self.backoff_time = 0 self.sleep = sleep def Register(self, thread): """Register a thread with the thread gate.""" self._threads.append(thread) def Threads(self): """Yields the registered threads.""" for thread in self._threads: yield thread def EnableThread(self): """Enable one more worker thread.""" self.lock.acquire() try: self.enabled_count += 1 finally: self.lock.release() self.thread_semaphore.release() def EnableAllThreads(self): """Enable all worker threads.""" for unused_idx in range(len(self._threads) - self.enabled_count): self.EnableThread() def StartWork(self): """Starts a critical section in which the number of workers is limited. If thread throttling is enabled then this method starts a critical section which allows self.enabled_count simultaneously operating threads. The critical section is ended by calling self.FinishWork(). """ if self.enabled: self.thread_semaphore.acquire() if self.backoff_time > 0.0: if not threading.currentThread().exit_flag: logging.info('Backing off: %.1f seconds', self.backoff_time) self.sleep(self.backoff_time) def FinishWork(self): """Ends a critical section started with self.StartWork().""" if self.enabled: self.thread_semaphore.release() def IncreaseWorkers(self): """Informs the throttler that an item was successfully sent. If thread throttling is enabled, this method will cause an additional thread to run in the critical section. """ if self.enabled: if self.backoff_time > 0.0: logging.info('Resetting backoff to 0.0') self.backoff_time = 0.0 do_enable = False self.lock.acquire() try: if self.enabled and len(self._threads) > self.enabled_count: do_enable = True self.enabled_count += 1 finally: self.lock.release() if do_enable: self.thread_semaphore.release() def DecreaseWorkers(self): """Informs the thread_gate that an item failed to send. If thread throttling is enabled, this method will cause the throttler to allow one fewer thread in the critical section. If there is only one thread remaining, failures will result in exponential backoff until there is a success. """ if self.enabled: do_disable = False self.lock.acquire() try: if self.enabled: if self.enabled_count > 1: do_disable = True self.enabled_count -= 1 else: if self.backoff_time == 0.0: self.backoff_time = INITIAL_BACKOFF else: self.backoff_time *= BACKOFF_FACTOR finally: self.lock.release() if do_disable: self.thread_semaphore.acquire()class Throttle(object): """A base class for upload rate throttling. Transferring large number of records, too quickly, to an application could trigger quota limits and cause the transfer process to halt. In order to stay within the application's quota, we throttle the data transfer to a specified limit (across all transfer threads). This limit defaults to about half of the Google App Engine default for an application, but can be manually adjusted faster/slower as appropriate. This class tracks a moving average of some aspect of the transfer rate (bandwidth, records per second, http connections per second). It keeps two windows of counts of bytes transferred, on a per-thread basis. One block is the "current" block, and the other is the "prior" block. It will rotate the counts from current to prior when ROTATE_PERIOD has passed. Thus, the current block will represent from 0 seconds to ROTATE_PERIOD seconds of activity (determined by: time.time() - self.last_rotate). The prior block will always represent a full ROTATE_PERIOD. Sleeping is performed just before a transfer of another block, and is based on the counts transferred *before* the next transfer. It really does not matter how much will be transferred, but only that for all the data transferred SO FAR that we have interspersed enough pauses to ensure the aggregate transfer rate is within the specified limit. These counts are maintained on a per-thread basis, so we do not require any interlocks around incrementing the counts. There IS an interlock on the rotation of the counts because we do not want multiple threads to multiply-rotate the counts. There are various race conditions in the computation and collection of these counts. We do not require precise values, but simply to keep the overall transfer within the bandwidth limits. If a given pause is a little short, or a little long, then the aggregate delays will be correct. """ ROTATE_PERIOD = 600 def __init__(self, get_time=time.time, thread_sleep=InterruptibleSleep, layout=None): self.get_time = get_time self.thread_sleep = thread_sleep self.start_time = get_time() self.transferred = {} self.prior_block = {} self.totals = {} self.throttles = {} self.last_rotate = {} self.rotate_mutex = {} if layout: self.AddThrottles(layout) def AddThrottle(self, name, limit): self.throttles[name] = limit self.transferred[name] = {} self.prior_block[name] = {} self.totals[name] = {} self.last_rotate[name] = self.get_time() self.rotate_mutex[name] = threading.Lock() def AddThrottles(self, layout): for key, value in layout.iteritems(): self.AddThrottle(key, value) def Register(self, thread): """Register this thread with the throttler.""" thread_name = thread.getName() for throttle_name in self.throttles.iterkeys(): self.transferred[throttle_name][thread_name] = 0 self.prior_block[throttle_name][thread_name] = 0 self.totals[throttle_name][thread_name] = 0 def VerifyName(self, throttle_name): if throttle_name not in self.throttles: raise AssertionError('%s is not a registered throttle' % throttle_name) def AddTransfer(self, throttle_name, token_count): """Add a count to the amount this thread has transferred. Each time a thread transfers some data, it should call this method to note the amount sent. The counts may be rotated if sufficient time has passed since the last rotation. Note: this method should only be called by the BulkLoaderThread instances. The token count is allocated towards the "current thread". Args: throttle_name: The name of the throttle to add to. token_count: The number to add to the throttle counter. """ self.VerifyName(throttle_name) transferred = self.transferred[throttle_name] transferred[threading.currentThread().getName()] += token_count if self.last_rotate[throttle_name] + self.ROTATE_PERIOD < self.get_time(): self._RotateCounts(throttle_name) def Sleep(self, throttle_name=None): """Possibly sleep in order to limit the transfer rate. Note that we sleep based on *prior* transfers rather than what we may be about to transfer. The next transfer could put us under/over and that will be rectified *after* that transfer. Net result is that the average transfer rate will remain within bounds. Spiky behavior or uneven rates among the threads could possibly bring the transfer rate above the requested limit for short durations. Args: throttle_name: The name of the throttle to sleep on. If None or omitted, then sleep on all throttles. """ if throttle_name is None: for throttle_name in self.throttles: self.Sleep(throttle_name=throttle_name) return self.VerifyName(throttle_name) thread = threading.currentThread() while True: duration = self.get_time() - self.last_rotate[throttle_name] total = 0 for count in self.prior_block[throttle_name].values(): total += count if total: duration += self.ROTATE_PERIOD for count in self.transferred[throttle_name].values(): total += count sleep_time = (float(total) / self.throttles[throttle_name]) - duration if sleep_time < MINIMUM_THROTTLE_SLEEP_DURATION: break logging.debug('[%s] Throttling on %s. Sleeping for %.1f ms ' '(duration=%.1f ms, total=%d)', thread.getName(), throttle_name, sleep_time * 1000, duration * 1000, total) self.thread_sleep(sleep_time) if thread.exit_flag: break self._RotateCounts(throttle_name) def _RotateCounts(self, throttle_name): """Rotate the transfer counters. If sufficient time has passed, then rotate the counters from active to the prior-block of counts. This rotation is interlocked to ensure that multiple threads do not over-rotate the counts. Args: throttle_name: The name of the throttle to rotate. """ self.VerifyName(throttle_name) self.rotate_mutex[throttle_name].acquire() try: next_rotate_time = self.last_rotate[throttle_name] + self.ROTATE_PERIOD if next_rotate_time >= self.get_time(): return for name, count in self.transferred[throttle_name].items(): self.prior_block[throttle_name][name] = count self.transferred[throttle_name][name] = 0 self.totals[throttle_name][name] += count self.last_rotate[throttle_name] = self.get_time() finally: self.rotate_mutex[throttle_name].release() def TotalTransferred(self, throttle_name): """Return the total transferred, and over what period. Args: throttle_name: The name of the throttle to total. Returns: A tuple of the total count and running time for the given throttle name. """ total = 0 for count in self.totals[throttle_name].values(): total += count for count in self.transferred[throttle_name].values(): total += count return total, self.get_time() - self.start_timeclass _ThreadBase(threading.Thread): """Provide some basic features for the threads used in the uploader. This abstract base class is used to provide some common features: * Flag to ask thread to exit as soon as possible. * Record exit/error status for the primary thread to pick up. * Capture exceptions and record them for pickup. * Some basic logging of thread start/stop. * All threads are "daemon" threads. * Friendly names for presenting to users. Concrete sub-classes must implement PerformWork(). Either self.NAME should be set or GetFriendlyName() be overridden to return a human-friendly name for this thread. The run() method starts the thread and prints start/exit messages. self.exit_flag is intended to signal that this thread should exit when it gets the chance. PerformWork() should check self.exit_flag whenever it has the opportunity to exit gracefully. """ def __init__(self): threading.Thread.__init__(self) self.setDaemon(True) self.exit_flag = False self.error = None def run(self): """Perform the work of the thread.""" logging.info('[%s] %s: started', self.getName(), self.__class__.__name__) try: self.PerformWork() except: self.error = sys.exc_info()[1] logging.exception('[%s] %s:', self.getName(), self.__class__.__name__) logging.info('[%s] %s: exiting', self.getName(), self.__class__.__name__) def PerformWork(self): """Perform the thread-specific work.""" raise NotImplementedError() def CheckError(self): """If an error is present, then log it.""" if self.error: logging.error('Error in %s: %s', self.GetFriendlyName(), self.error) def GetFriendlyName(self): """Returns a human-friendly description of the thread.""" if hasattr(self, 'NAME'): return self.NAME return 'unknown thread'class BulkLoaderThread(_ThreadBase): """A thread which transmits entities to the server application. This thread will read WorkItem instances from the work_queue and upload the entities to the server application. Progress information will be pushed into the progress_queue as the work is being performed. If a BulkLoaderThread encounters a transient error, the entities will be resent, if a fatal error is encoutered the BulkLoaderThread exits. """ def __init__(self, work_queue, throttle, thread_gate, request_manager): """Initialize the BulkLoaderThread instance. Args: work_queue: A queue containing WorkItems for processing. throttle: A Throttles to control upload bandwidth. thread_gate: A ThreadGate to control number of simultaneous uploads. request_manager: A RequestManager instance. """ _ThreadBase.__init__(self) self.work_queue = work_queue self.throttle = throttle self.thread_gate = thread_gate self.request_manager = request_manager def PerformWork(self): """Perform the work of a BulkLoaderThread.""" while not self.exit_flag: success = False self.thread_gate.StartWork() try: try: item = self.work_queue.get(block=True, timeout=1.0) except Queue.Empty: continue if item == _THREAD_SHOULD_EXIT: break logging.debug('[%s] Got work item [%d-%d]', self.getName(), item.key_start, item.key_end) try: item.MarkAsSending() try: if item.content is None: item.content = self.request_manager.EncodeContent(item.rows) try: self.request_manager.PostEntities(item) success = True logging.debug( '[%d-%d] Sent %d entities', item.key_start, item.key_end, item.count) self.throttle.AddTransfer(RECORDS, item.count) except (db.InternalError, db.NotSavedError, db.Timeout), e: logging.debug('Caught non-fatal error: %s', e) except urllib2.HTTPError, e: if e.code == 403 or (e.code >= 500 and e.code < 600): logging.debug('Caught HTTP error %d', e.code) logging.debug('%s', e.read()) else: raise e except: self.error = sys.exc_info()[1] logging.exception('[%s] %s: caught exception %s', self.getName(), self.__class__.__name__, str(sys.exc_info())) raise finally: if success: item.MarkAsSent() self.thread_gate.IncreaseWorkers() self.work_queue.task_done() else: item.MarkAsError() self.thread_gate.DecreaseWorkers() try: self.work_queue.reput(item, block=False) except Queue.Full: logging.error('[%s] Failed to reput work item.', self.getName()) raise Error('Failed to reput work item') logging.info('[%d-%d] %s', item.key_start, item.key_end, StateMessage(item.state)) finally: self.thread_gate.FinishWork() def GetFriendlyName(self): """Returns a human-friendly name for this thread.""" return 'worker [%s]' % self.getName()class DataSourceThread(_ThreadBase): """A thread which reads WorkItems and pushes them into queue. This thread will read/consume WorkItems from a generator (produced by the generator factory). These WorkItems will then be pushed into the work_queue. Note that reading will block if/when the work_queue becomes full. Information on content consumed from the generator will be pushed into the progress_queue. """ NAME = 'data source thread' def __init__(self, work_queue, progress_queue, workitem_generator_factory, progress_generator_factory): """Initialize the DataSourceThread instance. Args: work_queue: A queue containing WorkItems for processing. progress_queue: A queue used for tracking progress information. workitem_generator_factory: A factory that creates a WorkItem generator progress_generator_factory: A factory that creates a generator which produces prior progress status, or None if there is no prior status to use. """ _ThreadBase.__init__(self) self.work_queue = work_queue self.progress_queue = progress_queue self.workitem_generator_factory = workitem_generator_factory self.progress_generator_factory = progress_generator_factory self.entity_count = 0 def PerformWork(self): """Performs the work of a DataSourceThread.""" if self.progress_generator_factory: progress_gen = self.progress_generator_factory() else: progress_gen = None content_gen = self.workitem_generator_factory(self.progress_queue, progress_gen) self.sent_count = 0 self.read_count = 0 self.read_all = False for item in content_gen.Batches(): item.MarkAsRead() while not self.exit_flag: try: self.work_queue.put(item, block=True, timeout=1.0) self.entity_count += item.count break except Queue.Full: pass if self.exit_flag: break if not self.exit_flag: self.read_all = True self.read_count = content_gen.row_count self.sent_count = content_gen.sent_countdef _RunningInThread(thread): """Return True if we are running within the specified thread.""" return threading.currentThread().getName() == thread.getName()class ProgressDatabase(object): """Persistently record all progress information during an upload. This class wraps a very simple SQLite database which records each of the relevant details from the WorkItem instances. If the uploader is resumed, then data is replayed out of the database. """ def __init__(self, db_filename, commit_periodicity=100): """Initialize the ProgressDatabase instance. Args: db_filename: The name of the SQLite database to use. commit_periodicity: How many operations to perform between commits. """ self.db_filename = db_filename logging.info('Using progress database: %s', db_filename) self.primary_conn = sqlite3.connect(db_filename, isolation_level=None) self.primary_thread = threading.currentThread() self.progress_conn = None self.progress_thread = None self.operation_count = 0 self.commit_periodicity = commit_periodicity self.prior_key_end = None try: self.primary_conn.execute( """create table progress ( id integer primary key autoincrement, state integer not null, key_start integer not null, key_end integer not null ) """) except sqlite3.OperationalError, e: if 'already exists' not in e.message: raise try: self.primary_conn.execute('create index i_state on progress (state)') except sqlite3.OperationalError, e: if 'already exists' not in e.message: raise def ThreadComplete(self): """Finalize any operations the progress thread has performed. The database aggregates lots of operations into a single commit, and this method is used to commit any pending operations as the thread is about to shut down. """ if self.progress_conn: self._MaybeCommit(force_commit=True) def _MaybeCommit(self, force_commit=False): """Periodically commit changes into the SQLite database. Committing every operation is quite expensive, and slows down the operation of the script. Thus, we only commit after every N operations, as determined by the self.commit_periodicity value. Optionally, the caller can force a commit. Args: force_commit: Pass True in order for a commit to occur regardless of the current operation count. """ self.operation_count += 1 if force_commit or (self.operation_count % self.commit_periodicity) == 0: self.progress_conn.commit() def _OpenProgressConnection(self): """Possibly open a database connection for the progress tracker thread. If the connection is not open (for the calling thread, which is assumed to be the progress tracker thread), then open it. We also open a couple cursors for later use (and reuse). """ if self.progress_conn: return assert not _RunningInThread(self.primary_thread) self.progress_thread = threading.currentThread() self.progress_conn = sqlite3.connect(self.db_filename) self.insert_cursor = self.progress_conn.cursor() self.update_cursor = self.progress_conn.cursor() def HasUnfinishedWork(self): """Returns True if the database has progress information. Note there are two basic cases for progress information: 1) All saved records indicate a successful upload. In this case, we need to skip everything transmitted so far and then send the rest. 2) Some records for incomplete transfer are present. These need to be sent again, and then we resume sending after all the successful data. Returns: True if the database has progress information, False otherwise. Raises: ResumeError: If there is an error reading the progress database. """ assert _RunningInThread(self.primary_thread) cursor = self.primary_conn.cursor() cursor.execute('select count(*) from progress') row = cursor.fetchone() if row is None: raise ResumeError('Error reading progress information.') return row[0] != 0 def StoreKeys(self, key_start, key_end): """Record a new progress record, returning a key for later updates. The specified progress information will be persisted into the database. A unique key will be returned that identifies this progress state. The key is later used to (quickly) update this record. For the progress resumption to proceed properly, calls to StoreKeys MUST specify monotonically increasing key ranges. This will result in a database whereby the ID, KEY_START, and KEY_END rows are all increasing (rather than having ranges out of order). NOTE: the above precondition is NOT tested by this method (since it would imply an additional table read or two on each invocation). Args: key_start: The starting key of the WorkItem (inclusive) key_end: The end key of the WorkItem (inclusive) Returns: A string to later be used as a unique key to update this state. """ self._OpenProgressConnection() assert _RunningInThread(self.progress_thread) assert isinstance(key_start, int) assert isinstance(key_end, int) assert key_start <= key_end if self.prior_key_end is not None: assert key_start > self.prior_key_end self.prior_key_end = key_end self.insert_cursor.execute( 'insert into progress (state, key_start, key_end) values (?, ?, ?)', (STATE_READ, key_start, key_end)) progress_key = self.insert_cursor.lastrowid self._MaybeCommit() return progress_key def UpdateState(self, key, new_state): """Update a specified progress record with new information. Args: key: The key for this progress record, returned from StoreKeys new_state: The new state to associate with this progress record. """ self._OpenProgressConnection() assert _RunningInThread(self.progress_thread) assert isinstance(new_state, int) self.update_cursor.execute('update progress set state=? where id=?', (new_state, key)) self._MaybeCommit() def GetProgressStatusGenerator(self): """Get a generator which returns progress information. The returned generator will yield a series of 4-tuples that specify progress information about a prior run of the uploader. The 4-tuples have the following values: progress_key: The unique key to later update this record with new progress information. state: The last state saved for this progress record. key_start: The starting key of the items for uploading (inclusive). key_end: The ending key of the items for uploading (inclusive). After all incompletely-transferred records are provided, then one more 4-tuple will be generated: None DATA_CONSUMED_TO_HERE: A unique string value indicating this record is being provided. None key_end: An integer value specifying the last data source key that was handled by the previous run of the uploader. The caller should begin uploading records which occur after key_end. Yields: Progress information as tuples (progress_key, state, key_start, key_end). """ conn = sqlite3.connect(self.db_filename, isolation_level=None) cursor = conn.cursor() cursor.execute('select max(id) from progress') batch_id = cursor.fetchone()[0] cursor.execute('select key_end from progress where id = ?', (batch_id,)) key_end = cursor.fetchone()[0] self.prior_key_end = key_end cursor.execute( 'select id, state, key_start, key_end from progress' ' where state != ?' ' order by id', (STATE_SENT,)) rows = cursor.fetchall() for row in rows: if row is None: break yield row yield None, DATA_CONSUMED_TO_HERE, None, key_endclass StubProgressDatabase(object): """A stub implementation of ProgressDatabase which does nothing.""" def HasUnfinishedWork(self): """Whether the stub database has progress information (it doesn't).""" return False def StoreKeys(self, unused_key_start, unused_key_end): """Pretend to store a key in the stub database.""" return 'fake-key' def UpdateState(self, unused_key, unused_new_state): """Pretend to update the state of a progress item.""" pass def ThreadComplete(self): """Finalize operations on the stub database (i.e. do nothing).""" passclass ProgressTrackerThread(_ThreadBase): """A thread which records progress information for the upload process. The progress information is stored into the provided progress database. This class is not responsible for replaying a prior run's progress information out of the database. Separate mechanisms must be used to resume a prior upload attempt. """ NAME = 'progress tracking thread' def __init__(self, progress_queue, progress_db): """Initialize the ProgressTrackerThread instance. Args: progress_queue: A Queue used for tracking progress information. progress_db: The database for tracking progress information; should be an instance of ProgressDatabase. """ _ThreadBase.__init__(self) self.progress_queue = progress_queue self.db = progress_db self.entities_sent = 0 def PerformWork(self): """Performs the work of a ProgressTrackerThread.""" while not self.exit_flag: try: item = self.progress_queue.get(block=True, timeout=1.0) except Queue.Empty: continue if item == _THREAD_SHOULD_EXIT: break if item.state == STATE_READ and item.progress_key is None: item.progress_key = self.db.StoreKeys(item.key_start, item.key_end) else: assert item.progress_key is not None self.db.UpdateState(item.progress_key, item.state) if item.state == STATE_SENT: self.entities_sent += item.count item.progress_event.set() self.progress_queue.task_done() self.db.ThreadComplete()def Validate(value, typ): """Checks that value is non-empty and of the right type. Args: value: any value typ: a type or tuple of types Raises: ValueError if value is None or empty. TypeError if it's not the given type. """ if not value: raise ValueError('Value should not be empty; received %s.' % value) elif not isinstance(value, typ): raise TypeError('Expected a %s, but received %s (a %s).' % (typ, value, value.__class__))class Loader(object): """A base class for creating datastore entities from input data. To add a handler for bulk loading a new entity kind into your datastore, write a subclass of this class that calls Loader.__init__ from your class's __init__. If you need to run extra code to convert entities from the input data, create new properties, or otherwise modify the entities before they're inserted, override HandleEntity. See the CreateEntity method for the creation of entities from the (parsed) input data. """ __loaders = {} __kind = None __properties = None def __init__(self, kind, properties): """Constructor. Populates this Loader's kind and properties map. Also registers it with the bulk loader, so that all you need to do is instantiate your Loader, and the bulkload handler will automatically use it. Args: kind: a string containing the entity kind that this loader handles properties: list of (name, converter) tuples. This is used to automatically convert the CSV columns into properties. The converter should be a function that takes one argument, a string value from the CSV file, and returns a correctly typed property value that should be inserted. The tuples in this list should match the columns in your CSV file, in order. For example: [('name', str), ('id_number', int), ('email', datastore_types.Email), ('user', users.User), ('birthdate', lambda x: datetime.datetime.fromtimestamp(float(x))), ('description', datastore_types.Text), ] """ Validate(kind, basestring) self.__kind = kind db.class_for_kind(kind) Validate(properties, list) for name, fn in properties: Validate(name, basestring) assert callable(fn), ( 'Conversion function %s for property %s is not callable.' % (fn, name)) self.__properties = properties @staticmethod def RegisterLoader(loader): Loader.__loaders[loader.__kind] = loader def kind(self): """ Return the entity kind that this Loader handes. """ return self.__kind def CreateEntity(self, values, key_name=None): """Creates a entity from a list of property values. Args: values: list/tuple of str key_name: if provided, the name for the (single) resulting entity Returns: list of db.Model The returned entities are populated with the property values from the argument, converted to native types using the properties map given in the constructor, and passed through HandleEntity. They're ready to be inserted. Raises: AssertionError if the number of values doesn't match the number of properties in the properties map. ValueError if any element of values is None or empty. TypeError if values is not a list or tuple. """ Validate(values, (list, tuple)) assert len(values) == len(self.__properties), ( 'Expected %d CSV columns, found %d.' % (len(self.__properties), len(values))) model_class = db.class_for_kind(self.__kind) properties = {'key_name': key_name} for (name, converter), val in zip(self.__properties, values): if converter is bool and val.lower() in ('0', 'false', 'no'): val = False properties[name] = converter(val) entity = model_class(**properties) entities = self.HandleEntity(entity) if entities: if not isinstance(entities, (list, tuple)): entities = [entities] for entity in entities: if not isinstance(entity, db.Model): raise TypeError('Expected a db.Model, received %s (a %s).' % (entity, entity.__class__)) return entities def GenerateKey(self, i, values): """Generates a key_name to be used in creating the underlying object. The default implementation returns None. This method can be overridden to control the key generation for uploaded entities. The value returned should be None (to use a server generated numeric key), or a string which neither starts with a digit nor has the form __*__. (See http://code.google.com/appengine/docs/python/datastore/keysandentitygroups.html) If you generate your own string keys, keep in mind: 1. The key name for each entity must be unique. 2. If an entity of the same kind and key already exists in the datastore, it will be overwritten. Args: i: Number corresponding to this object (assume it's run in a loop, this is your current count. values: list/tuple of str. Returns: A string to be used as the key_name for an entity. """ return None def HandleEntity(self, entity): """Subclasses can override this to add custom entity conversion code. This is called for each entity, after its properties are populated from CSV but before it is stored. Subclasses can override this to add custom entity handling code. The entity to be inserted should be returned. If multiple entities should be inserted, return a list of entities. If no entities should be inserted, return None or []. Args: entity: db.Model Returns: db.Model or list of db.Model """ return entity @staticmethod def RegisteredLoaders(): """Returns a list of the Loader instances that have been created. """ return dict(Loader.__loaders)class QueueJoinThread(threading.Thread): """A thread that joins a queue and exits. Queue joins do not have a timeout. To simulate a queue join with timeout, run this thread and join it with a timeout. """ def __init__(self, queue): """Initialize a QueueJoinThread. Args: queue: The queue for this thread to join. """ threading.Thread.__init__(self) assert isinstance(queue, (Queue.Queue, ReQueue)) self.queue = queue def run(self): """Perform the queue join in this thread.""" self.queue.join()def InterruptibleQueueJoin(queue, thread_local, thread_gate, queue_join_thread_factory=QueueJoinThread): """Repeatedly joins the given ReQueue or Queue.Queue with short timeout. Between each timeout on the join, worker threads are checked. Args: queue: A Queue.Queue or ReQueue instance. thread_local: A threading.local instance which indicates interrupts. thread_gate: A ThreadGate instance. queue_join_thread_factory: Used for dependency injection. Returns: True unless the queue join is interrupted by SIGINT or worker death. """ thread = queue_join_thread_factory(queue) thread.start() while True: thread.join(timeout=.5) if not thread.isAlive(): return True if thread_local.shut_down: logging.debug('Queue join interrupted') return False for worker_thread in thread_gate.Threads(): if not worker_thread.isAlive(): return Falsedef ShutdownThreads(data_source_thread, work_queue, thread_gate): """Shuts down the worker and data source threads. Args: data_source_thread: A running DataSourceThread instance. work_queue: The work queue. thread_gate: A ThreadGate instance with workers registered. """ logging.info('An error occurred. Shutting down...') data_source_thread.exit_flag = True for thread in thread_gate.Threads(): thread.exit_flag = True for unused_thread in thread_gate.Threads(): thread_gate.EnableThread() data_source_thread.join(timeout=3.0) if data_source_thread.isAlive(): logging.warn('%s hung while trying to exit', data_source_thread.GetFriendlyName()) while not work_queue.empty(): try: unused_item = work_queue.get_nowait() work_queue.task_done() except Queue.Empty: passdef PerformBulkUpload(app_id, post_url, kind, workitem_generator_factory, num_threads, throttle, progress_db, max_queue_size=DEFAULT_QUEUE_SIZE, request_manager_factory=RequestManager, bulkloaderthread_factory=BulkLoaderThread, progresstrackerthread_factory=ProgressTrackerThread, datasourcethread_factory=DataSourceThread, work_queue_factory=ReQueue, progress_queue_factory=Queue.Queue): """Uploads data into an application using a series of HTTP POSTs. This function will spin up a number of threads to read entities from the data source, pass those to a number of worker ("uploader") threads for sending to the application, and track all of the progress in a small database in case an error or pause/termination requires a restart/resumption of the upload process. Args: app_id: String containing application id. post_url: URL to post the Entity data to. kind: Kind of the Entity records being posted. workitem_generator_factory: A factory that creates a WorkItem generator. num_threads: How many uploader threads should be created. throttle: A Throttle instance. progress_db: The database to use for replaying/recording progress. max_queue_size: Maximum size of the queues before they should block. request_manager_factory: Used for dependency injection. bulkloaderthread_factory: Used for dependency injection. progresstrackerthread_factory: Used for dependency injection. datasourcethread_factory: Used for dependency injection. work_queue_factory: Used for dependency injection. progress_queue_factory: Used for dependency injection. Raises: AuthenticationError: If authentication is required and fails. """ thread_gate = ThreadGate(True) (unused_scheme, host_port, url_path, unused_query, unused_fragment) = urlparse.urlsplit(post_url) work_queue = work_queue_factory(max_queue_size) progress_queue = progress_queue_factory(max_queue_size) request_manager = request_manager_factory(app_id, host_port, url_path, kind, throttle) throttle.Register(threading.currentThread()) try: request_manager.Authenticate() except Exception, e: logging.exception(e) raise AuthenticationError('Authentication failed') if (request_manager.credentials is not None and not request_manager.authenticated): raise AuthenticationError('Authentication failed') for unused_idx in range(num_threads): thread = bulkloaderthread_factory(work_queue, throttle, thread_gate, request_manager) throttle.Register(thread) thread_gate.Register(thread) progress_thread = progresstrackerthread_factory(progress_queue, progress_db) if progress_db.HasUnfinishedWork(): logging.debug('Restarting upload using progress database') progress_generator_factory = progress_db.GetProgressStatusGenerator else: progress_generator_factory = None data_source_thread = datasourcethread_factory(work_queue, progress_queue, workitem_generator_factory, progress_generator_factory) thread_local = threading.local() thread_local.shut_down = False def Interrupt(unused_signum, unused_frame): """Shutdown gracefully in response to a signal.""" thread_local.shut_down = True signal.signal(signal.SIGINT, Interrupt) progress_thread.start() data_source_thread.start() for thread in thread_gate.Threads(): thread.start() while not thread_local.shut_down: data_source_thread.join(timeout=0.25) if data_source_thread.isAlive(): for thread in list(thread_gate.Threads()) + [progress_thread]: if not thread.isAlive(): logging.info('Unexpected thread death: %s', thread.getName()) thread_local.shut_down = True break else: break if thread_local.shut_down: ShutdownThreads(data_source_thread, work_queue, thread_gate) def _Join(ob, msg): logging.debug('Waiting for %s...', msg) if isinstance(ob, threading.Thread): ob.join(timeout=3.0) if ob.isAlive(): logging.debug('Joining %s failed', ob.GetFriendlyName()) else: logging.debug('... done.') elif isinstance(ob, (Queue.Queue, ReQueue)): if not InterruptibleQueueJoin(ob, thread_local, thread_gate): ShutdownThreads(data_source_thread, work_queue, thread_gate) else: ob.join() logging.debug('... done.') _Join(work_queue, 'work_queue to flush') for unused_thread in thread_gate.Threads(): work_queue.put(_THREAD_SHOULD_EXIT) for unused_thread in thread_gate.Threads(): thread_gate.EnableThread() for thread in thread_gate.Threads(): _Join(thread, 'thread [%s] to terminate' % thread.getName()) thread.CheckError() if progress_thread.isAlive(): _Join(progress_queue, 'progress_queue to finish') else: logging.warn('Progress thread exited prematurely') progress_queue.put(_THREAD_SHOULD_EXIT) _Join(progress_thread, 'progress_thread to terminate') progress_thread.CheckError() data_source_thread.CheckError() total_up, duration = throttle.TotalTransferred(BANDWIDTH_UP) s_total_up, unused_duration = throttle.TotalTransferred(HTTPS_BANDWIDTH_UP) total_up += s_total_up logging.info('%d entites read, %d previously transferred', data_source_thread.read_count, data_source_thread.sent_count) logging.info('%d entities (%d bytes) transferred in %.1f seconds', progress_thread.entities_sent, total_up, duration) if (data_source_thread.read_all and progress_thread.entities_sent + data_source_thread.sent_count >= data_source_thread.read_count): logging.info('All entities successfully uploaded') else: logging.info('Some entities not successfully uploaded')def PrintUsageExit(code): """Prints usage information and exits with a status code. Args: code: Status code to pass to sys.exit() after displaying usage information. """ print __doc__ % {'arg0': sys.argv[0]} sys.stdout.flush() sys.stderr.flush() sys.exit(code)def ParseArguments(argv): """Parses command-line arguments. Prints out a help message if -h or --help is supplied. Args: argv: List of command-line arguments. Returns: Tuple (url, filename, cookie, batch_size, kind) containing the values from each corresponding command-line flag. """ opts, unused_args = getopt.getopt( argv[1:], 'h', ['debug', 'help', 'url=', 'filename=', 'batch_size=', 'kind=', 'num_threads=', 'bandwidth_limit=', 'rps_limit=', 'http_limit=', 'db_filename=', 'app_id=', 'config_file=', 'auth_domain=', ]) url = None filename = None batch_size = DEFAULT_BATCH_SIZE kind = None num_threads = DEFAULT_THREAD_COUNT bandwidth_limit = DEFAULT_BANDWIDTH_LIMIT rps_limit = DEFAULT_RPS_LIMIT http_limit = DEFAULT_REQUEST_LIMIT db_filename = None app_id = None config_file = None auth_domain = 'gmail.com' for option, value in opts: if option == '--debug': logging.getLogger().setLevel(logging.DEBUG) elif option in ('-h', '--help'): PrintUsageExit(0) elif option == '--url': url = value elif option == '--filename': filename = value elif option == '--batch_size': batch_size = int(value) elif option == '--kind': kind = value elif option == '--num_threads': num_threads = int(value) elif option == '--bandwidth_limit': bandwidth_limit = int(value) elif option == '--rps_limit': rps_limit = int(value) elif option == '--http_limit': http_limit = int(value) elif option == '--db_filename': db_filename = value elif option == '--app_id': app_id = value elif option == '--config_file': config_file = value elif option == '--auth_domain': auth_domain = value return ProcessArguments(app_id=app_id, url=url, filename=filename, batch_size=batch_size, kind=kind, num_threads=num_threads, bandwidth_limit=bandwidth_limit, rps_limit=rps_limit, http_limit=http_limit, db_filename=db_filename, config_file=config_file, auth_domain=auth_domain, die_fn=lambda: PrintUsageExit(1))def ThrottleLayout(bandwidth_limit, http_limit, rps_limit): return { BANDWIDTH_UP: bandwidth_limit, BANDWIDTH_DOWN: bandwidth_limit, REQUESTS: http_limit, HTTPS_BANDWIDTH_UP: bandwidth_limit / 5, HTTPS_BANDWIDTH_DOWN: bandwidth_limit / 5, HTTPS_REQUESTS: http_limit / 5, RECORDS: rps_limit, }def LoadConfig(config_file): """Loads a config file and registers any Loader classes present.""" if config_file: global_dict = dict(globals()) execfile(config_file, global_dict) for cls in Loader.__subclasses__(): Loader.RegisterLoader(cls())def _MissingArgument(arg_name, die_fn): """Print error message about missing argument and die.""" print >>sys.stderr, '%s argument required' % arg_name die_fn()def ProcessArguments(app_id=None, url=None, filename=None, batch_size=DEFAULT_BATCH_SIZE, kind=None, num_threads=DEFAULT_THREAD_COUNT, bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT, rps_limit=DEFAULT_RPS_LIMIT, http_limit=DEFAULT_REQUEST_LIMIT, db_filename=None, config_file=None, auth_domain='gmail.com', die_fn=lambda: sys.exit(1)): """Processes non command-line input arguments.""" if db_filename is None: db_filename = time.strftime('bulkloader-progress-%Y%m%d.%H%M%S.sql3') if batch_size <= 0: print >>sys.stderr, 'batch_size must be 1 or larger' die_fn() if url is None: _MissingArgument('url', die_fn) if filename is None: _MissingArgument('filename', die_fn) if kind is None: _MissingArgument('kind', die_fn) if config_file is None: _MissingArgument('config_file', die_fn) if app_id is None: (unused_scheme, host_port, unused_url_path, unused_query, unused_fragment) = urlparse.urlsplit(url) suffix_idx = host_port.find('.appspot.com') if suffix_idx > -1: app_id = host_port[:suffix_idx] elif host_port.split(':')[0].endswith('google.com'): app_id = host_port.split('.')[0] else: print >>sys.stderr, 'app_id required for non appspot.com domains' die_fn() return (app_id, url, filename, batch_size, kind, num_threads, bandwidth_limit, rps_limit, http_limit, db_filename, config_file, auth_domain)def _PerformBulkload(app_id=None, url=None, filename=None, batch_size=DEFAULT_BATCH_SIZE, kind=None, num_threads=DEFAULT_THREAD_COUNT, bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT, rps_limit=DEFAULT_RPS_LIMIT, http_limit=DEFAULT_REQUEST_LIMIT, db_filename=None, config_file=None, auth_domain='gmail.com'): """Runs the bulkloader, given the options as keyword arguments. Args: app_id: The application id. url: The url of the remote_api endpoint. filename: The name of the file containing the CSV data. batch_size: The number of records to send per request. kind: The kind of entity to transfer. num_threads: The number of threads to use to transfer data. bandwidth_limit: Maximum bytes/second to transfers. rps_limit: Maximum records/second to transfer. http_limit: Maximum requests/second for transfers. db_filename: The name of the SQLite3 progress database file. config_file: The name of the configuration file. auth_domain: The auth domain to use for logins and UserProperty. Returns: An exit code. """ os.environ['AUTH_DOMAIN'] = auth_domain LoadConfig(config_file) throttle_layout = ThrottleLayout(bandwidth_limit, http_limit, rps_limit) throttle = Throttle(layout=throttle_layout) workitem_generator_factory = GetCSVGeneratorFactory(filename, batch_size) if db_filename == 'skip': progress_db = StubProgressDatabase() else: progress_db = ProgressDatabase(db_filename) max_queue_size = max(DEFAULT_QUEUE_SIZE, 2 * num_threads + 5) PerformBulkUpload(app_id, url, kind, workitem_generator_factory, num_threads, throttle, progress_db, max_queue_size=max_queue_size) return 0def Run(app_id=None, url=None, filename=None, batch_size=DEFAULT_BATCH_SIZE, kind=None, num_threads=DEFAULT_THREAD_COUNT, bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT, rps_limit=DEFAULT_RPS_LIMIT, http_limit=DEFAULT_REQUEST_LIMIT, db_filename=None, auth_domain='gmail.com', config_file=None): """Sets up and runs the bulkloader, given the options as keyword arguments. Args: app_id: The application id. url: The url of the remote_api endpoint. filename: The name of the file containing the CSV data. batch_size: The number of records to send per request. kind: The kind of entity to transfer. num_threads: The number of threads to use to transfer data. bandwidth_limit: Maximum bytes/second to transfers. rps_limit: Maximum records/second to transfer. http_limit: Maximum requests/second for transfers. db_filename: The name of the SQLite3 progress database file. config_file: The name of the configuration file. auth_domain: The auth domain to use for logins and UserProperty. Returns: An exit code. """ logging.basicConfig( format='%(levelname)-8s %(asctime)s %(filename)s] %(message)s') args = ProcessArguments(app_id=app_id, url=url, filename=filename, batch_size=batch_size, kind=kind, num_threads=num_threads, bandwidth_limit=bandwidth_limit, rps_limit=rps_limit, http_limit=http_limit, db_filename=db_filename, config_file=config_file) (app_id, url, filename, batch_size, kind, num_threads, bandwidth_limit, rps_limit, http_limit, db_filename, config_file, auth_domain) = args return _PerformBulkload(app_id=app_id, url=url, filename=filename, batch_size=batch_size, kind=kind, num_threads=num_threads, bandwidth_limit=bandwidth_limit, rps_limit=rps_limit, http_limit=http_limit, db_filename=db_filename, config_file=config_file, auth_domain=auth_domain)def main(argv): """Runs the importer from the command line.""" logging.basicConfig( level=logging.INFO, format='%(levelname)-8s %(asctime)s %(filename)s] %(message)s') args = ParseArguments(argv) if None in args: print >>sys.stderr, 'Invalid arguments' PrintUsageExit(1) (app_id, url, filename, batch_size, kind, num_threads, bandwidth_limit, rps_limit, http_limit, db_filename, config_file, auth_domain) = args return _PerformBulkload(app_id=app_id, url=url, filename=filename, batch_size=batch_size, kind=kind, num_threads=num_threads, bandwidth_limit=bandwidth_limit, rps_limit=rps_limit, http_limit=http_limit, db_filename=db_filename, config_file=config_file, auth_domain=auth_domain)if __name__ == '__main__': sys.exit(main(sys.argv))