thirdparty/google_appengine/google/appengine/tools/bulkloader.py
author Pawel Solyga <Pawel.Solyga@gmail.com>
Thu, 12 Feb 2009 12:30:36 +0000
changeset 1278 a7766286a7be
child 2273 e4cb9c53db3e
permissions -rwxr-xr-x
Load /Users/solydzajs/Downloads/google_appengine into trunk/thirdparty/google_appengine.

#!/usr/bin/env python
#
# Copyright 2007 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Imports CSV data over HTTP.

Usage:
  %(arg0)s [flags]

    --debug                 Show debugging information. (Optional)
    --app_id=<string>       Application ID of endpoint (Optional for
                            *.appspot.com)
    --auth_domain=<domain>  The auth domain to use for logging in and for
                            UserProperties. (Default: gmail.com)
    --bandwidth_limit=<int> The maximum number of bytes per second for the
                            aggregate transfer of data to the server. Bursts
                            may exceed this, but overall transfer rate is
                            restricted to this rate. (Default 250000)
    --batch_size=<int>      Number of Entity objects to include in each post to
                            the URL endpoint. The more data per row/Entity, the
                            smaller the batch size should be. (Default 10)
    --config_file=<path>    File containing Model and Loader definitions.
                            (Required)
    --db_filename=<path>    Specific progress database to write to, or to
                            resume from. If not supplied, then a new database
                            will be started, named:
                            bulkloader-progress-TIMESTAMP.
                            The special filename "skip" may be used to simply
                            skip reading/writing any progress information.
    --filename=<path>       Path to the CSV file to import. (Required)
    --http_limit=<int>      The maximum number of HTTP requests per second to
                            send to the server. (Default: 8)
    --kind=<string>         Name of the Entity object kind to put in the
                            datastore. (Required)
    --num_threads=<int>     Number of threads to use for uploading entities
                            (Default 10)
    --rps_limit=<int>       The maximum number of records per second to
                            transfer to the server. (Default: 20)
    --url=<string>          URL endpoint to post to for importing data.
                            (Required)

The exit status will be 0 on success, non-zero on import failure.

Works with the remote_api mix-in library for google.appengine.ext.remote_api.
Please look there for documentation about how to setup the server side.

Example:

%(arg0)s --url=http://app.appspot.com/remote_api --kind=Model \
 --filename=data.csv --config_file=loader_config.py

"""



import csv
import getopt
import getpass
import logging
import new
import os
import Queue
import signal
import sys
import threading
import time
import traceback
import urllib2
import urlparse

from google.appengine.ext import db
from google.appengine.ext.remote_api import remote_api_stub
from google.appengine.tools import appengine_rpc

try:
  import sqlite3
except ImportError:
  pass

UPLOADER_VERSION = '1'

DEFAULT_THREAD_COUNT = 10

DEFAULT_BATCH_SIZE = 10

DEFAULT_QUEUE_SIZE = DEFAULT_THREAD_COUNT * 10

_THREAD_SHOULD_EXIT = '_THREAD_SHOULD_EXIT'

STATE_READ = 0
STATE_SENDING = 1
STATE_SENT = 2
STATE_NOT_SENT = 3

MINIMUM_THROTTLE_SLEEP_DURATION = 0.001

DATA_CONSUMED_TO_HERE = 'DATA_CONSUMED_TO_HERE'

INITIAL_BACKOFF = 1.0

BACKOFF_FACTOR = 2.0


DEFAULT_BANDWIDTH_LIMIT = 250000

DEFAULT_RPS_LIMIT = 20

DEFAULT_REQUEST_LIMIT = 8

BANDWIDTH_UP = 'http-bandwidth-up'
BANDWIDTH_DOWN = 'http-bandwidth-down'
REQUESTS = 'http-requests'
HTTPS_BANDWIDTH_UP = 'https-bandwidth-up'
HTTPS_BANDWIDTH_DOWN = 'https-bandwidth-down'
HTTPS_REQUESTS = 'https-requests'
RECORDS = 'records'


def StateMessage(state):
  """Returns the human-readable status message for a WorkItem state.

  Args:
    state: One of STATE_READ, STATE_SENDING, STATE_SENT or STATE_NOT_SENT.

  Returns:
    A short sentence describing the state.

  Raises:
    KeyError: If state is not one of the known state identifiers.
  """
  messages = {}
  messages[STATE_READ] = 'Batch read from file.'
  messages[STATE_SENDING] = 'Sending batch to server.'
  messages[STATE_SENT] = 'Batch successfully sent.'
  messages[STATE_NOT_SENT] = 'Error while sending batch.'
  return messages[state]


class Error(Exception):
  """Root of every exception type raised by this module."""


class FatalServerError(Error):
  """Posting data to the server failed in a way that cannot be recovered."""


class ResumeError(Error):
  """A partial upload could not be resumed from its progress information."""


class ConfigurationError(Error):
  """The supplied configuration options are invalid or incomplete."""


class AuthenticationError(Error):
  """Authenticating with the server failed."""


def GetCSVGeneratorFactory(csv_filename, batch_size,
                           openfile=open, create_csv_reader=csv.reader):
  """Builds a factory that produces CSVGenerator instances for one file.

  Args:
    csv_filename: File on disk containing CSV data.
    batch_size: Maximum number of CSV rows to place in a single WorkItem.
    openfile: Replacement for the builtin 'open' (dependency injection).
    create_csv_reader: Replacement for 'csv.reader' (dependency injection).

  Returns:
    A callable taking (progress_queue, progress_generator) and returning a
    CSVGenerator bound to the arguments given here.
  """
  # Everything after the two progress arguments is fixed per factory.
  bound_settings = (csv_filename, batch_size, openfile, create_csv_reader)

  def CreateGenerator(progress_queue, progress_generator):
    """Creates a CSVGenerator wired to the given progress queue/generator.

    Args:
      progress_queue: A ProgressQueue instance to send progress information.
      progress_generator: A generator of progress information or None.

    Returns:
      A CSVGenerator instance.
    """
    return CSVGenerator(progress_queue, progress_generator, *bound_settings)

  return CreateGenerator


class CSVGenerator(object):
  """Reads a CSV file and generates WorkItems containing batches of records."""

  def __init__(self,
               progress_queue,
               progress_generator,
               csv_filename,
               batch_size,
               openfile,
               create_csv_reader):
    """Initializes a CSV generator.

    Args:
      progress_queue: A queue used for tracking progress information.
      progress_generator: A generator of prior progress information, or None
        if there is no prior status.
      csv_filename: File on disk containing CSV data.
      batch_size: Maximum number of CSV rows to stash into a WorkItem.
      openfile: Used for dependency injection of 'open'.
      create_csv_reader: Used for dependency injection of 'csv.reader'.
    """
    self.progress_queue = progress_queue
    self.progress_generator = progress_generator
    self.csv_filename = csv_filename
    self.batch_size = batch_size
    self.openfile = openfile
    self.create_csv_reader = create_csv_reader
    # 1-based number of the next data row to read (a skipped header row is
    # not counted; Batches resets this to 1 after skipping it).
    self.line_number = 1
    # Column count of the first row read; every later row must match it.
    self.column_count = None
    self.read_rows = []
    self.reader = None
    self.row_count = 0
    self.sent_count = 0

  def _AdvanceTo(self, line):
    """Advance the reader to the given line, discarding the rows skipped.

    Skipped rows are counted as both read and sent, because they were
    already handled by a previous run.

    Args:
      line: A line number to advance to.

    Raises:
      StopIteration: if the reader runs out of rows.
    """
    while self.line_number < line:
      self.reader.next()
      self.line_number += 1
      self.row_count += 1
      self.sent_count += 1

  def _ReadRows(self, key_start, key_end):
    """Attempts to read rows [key_start, key_end] into self.read_rows.

    On StopIteration, self.read_rows keeps the rows read so far.

    Args:
      key_start: The starting line number; must equal self.line_number.
      key_end: The ending line number, inclusive.

    Raises:
      StopIteration: if the reader runs out of rows
      ResumeError: if there are an inconsistent number of columns.
    """
    assert self.line_number == key_start
    self.read_rows = []
    while self.line_number <= key_end:
      row = self.reader.next()
      self.row_count += 1
      if self.column_count is None:
        self.column_count = len(row)
      else:
        if self.column_count != len(row):
          raise ResumeError('Column count mismatch, %d: %s' %
                            (self.column_count, str(row)))
      self.read_rows.append((self.line_number, row))
      self.line_number += 1

  def _MakeItem(self, key_start, key_end, rows, progress_key=None):
    """Makes a WorkItem containing the given rows, with the given keys.

    Args:
      key_start: The start key for the WorkItem.
      key_end: The end key for the WorkItem.
      rows: A non-empty list of (line number, row) pairs for the WorkItem.
      progress_key: The progress key for the WorkItem

    Returns:
      A WorkItem instance for the given batch.
    """
    assert rows

    item = WorkItem(self.progress_queue, rows,
                    key_start, key_end,
                    progress_key=progress_key)

    return item

  def _HasHeaders(self, csv_content):
    """Guesses whether csv_content starts with a header row.

    Args:
      csv_content: The full text of the CSV file.

    Returns:
      True if csv.Sniffer judges the first row to be a header. Empty content,
      and content the sniffer cannot analyze (it raises csv.Error, e.g. when
      no delimiter can be determined), are treated as having no header.
    """
    if not csv_content:
      return False
    try:
      return csv.Sniffer().has_header(csv_content)
    except csv.Error:
      # Previously this aborted the whole import; fall back to assuming the
      # first line is data and tell the user.
      logging.warning('Could not determine whether the CSV file has a header '
                      'line; assuming it does not.')
      return False

  def Batches(self):
    """Reads the CSV data file and generates WorkItems.

    If a progress generator is present, batches recorded there are
    regenerated first (records in state DATA_CONSUMED_TO_HERE only advance
    the reader). New batches of up to batch_size rows are then generated
    until the file is exhausted, provided there was no prior progress or
    the last progress record was DATA_CONSUMED_TO_HERE. The data file is
    closed when the generator finishes, fails or is closed.

    Yields:
      Instances of class WorkItem

    Raises:
      ResumeError: If the progress database and data file indicate a different
        number of rows.
    """
    csv_file = self.openfile(self.csv_filename, 'r')
    try:
      has_headers = self._HasHeaders(csv_file.read())
      csv_file.seek(0)
      self.reader = self.create_csv_reader(csv_file, skipinitialspace=True)
      if has_headers:
        logging.info('The CSV file appears to have a header line, skipping.')
        self.reader.next()

      exhausted = False

      self.line_number = 1
      self.column_count = None

      logging.info('Starting import; maximum %d entities per post',
                   self.batch_size)

      state = None
      if self.progress_generator is not None:
        for progress_key, state, key_start, key_end in self.progress_generator:
          if key_start:
            try:
              self._AdvanceTo(key_start)
              self._ReadRows(key_start, key_end)
              yield self._MakeItem(key_start,
                                   key_end,
                                   self.read_rows,
                                   progress_key=progress_key)
            except StopIteration:
              logging.error('Mismatch between data file and progress database')
              raise ResumeError(
                  'Mismatch between data file and progress database')
          elif state == DATA_CONSUMED_TO_HERE:
            try:
              self._AdvanceTo(key_end + 1)
            except StopIteration:
              state = None

      if self.progress_generator is None or state == DATA_CONSUMED_TO_HERE:
        while not exhausted:
          key_start = self.line_number
          key_end = self.line_number + self.batch_size - 1
          try:
            self._ReadRows(key_start, key_end)
          except StopIteration:
            exhausted = True
            # Emit whatever partial batch was read before the end of file.
            key_end = self.line_number - 1
          if key_start <= key_end:
            yield self._MakeItem(key_start, key_end, self.read_rows)
    finally:
      # The original never closed the data file.
      csv_file.close()


class ReQueue(object):
  """A special thread-safe queue.

  A ReQueue allows unfinished work items to be returned with a call to
  reput().  When an item is reput, task_done() should *not* be called;
  in addition, getting an item that has been reput does not increase
  the number of outstanding tasks.

  This class shares an interface with Queue.Queue and provides the
  additional reput method.  Reput items are handed out by get() before
  any newly put items.
  """

  def __init__(self,
               queue_capacity,
               requeue_capacity=None,
               queue_factory=Queue.Queue,
               get_time=time.time):
    """Initialize a ReQueue instance.

    Args:
      queue_capacity: The number of items that can be put in the ReQueue.
      requeue_capacity: The number of items that can be reput in the ReQueue;
        defaults to queue_capacity.
      queue_factory: Used for dependency injection.
      get_time: Used for dependency injection.
    """
    if requeue_capacity is None:
      requeue_capacity = queue_capacity

    self.get_time = get_time
    self.queue = queue_factory(queue_capacity)
    self.requeue = queue_factory(requeue_capacity)
    # put_cond and get_cond share one lock so a single acquisition guards
    # both underlying queues.
    self.lock = threading.Lock()
    self.put_cond = threading.Condition(self.lock)
    self.get_cond = threading.Condition(self.lock)

  def _DoWithTimeout(self,
                     action,
                     exc,
                     wait_cond,
                     done_cond,
                     lock,
                     timeout=None,
                     block=True):
    """Performs the given action with a timeout.

    The action must be non-blocking, and raise an instance of exc on a
    recoverable failure.  If the action fails with an instance of exc,
    we wait on wait_cond before trying again.  Failure after the
    timeout is reached is propagated as an exception.  Success is
    signalled by notifying on done_cond and returning the result of
    the action.  If action raises any exception besides an instance of
    exc, it is immediately propagated.

    Args:
      action: A callable that performs a non-blocking action.
      exc: An exception type that is thrown by the action to indicate
        a recoverable error.
      wait_cond: A condition variable which should be waited on when
        action throws exc.
      done_cond: A condition variable to signal if the action returns.
      lock: The lock used by wait_cond and done_cond.
      timeout: A non-negative float giving the maximum total time, measured
        with self.get_time from the start of the call, to keep retrying.
        None means wait indefinitely.
      block: Whether to block if the action cannot complete immediately.

    Returns:
      The result of the action, if it is successful.

    Raises:
      ValueError: If the timeout argument is negative.
      exc: If the action still fails once the timeout has elapsed (or
        immediately, when block is False).
    """
    if timeout is not None and timeout < 0.0:
      raise ValueError('\'timeout\' must not be a negative  number')
    if not block:
      timeout = 0.0
    result = None
    success = False
    start_time = self.get_time()
    lock.acquire()
    try:
      while not success:
        try:
          result = action()
          success = True
        except exc:
          remaining = None
          if timeout is not None:
            # Measure against the fixed start time. The original subtracted
            # the total elapsed time from an already-reduced timeout on each
            # retry, which made it give up early.
            remaining = timeout - (self.get_time() - start_time)
            if remaining <= 0.0:
              # Bare raise keeps the original traceback.
              raise
          wait_cond.wait(remaining)
    finally:
      if success:
        done_cond.notify()
      lock.release()
    return result

  def put(self, item, block=True, timeout=None):
    """Put an item into the requeue.

    Args:
      item: An item to add to the requeue.
      block: Whether to block if the requeue is full.
      timeout: Maximum on how long to wait until the queue is non-full.

    Raises:
      Queue.Full if the queue is full and the timeout expires.
    """
    def PutAction():
      self.queue.put(item, block=False)
    self._DoWithTimeout(PutAction,
                        Queue.Full,
                        self.get_cond,
                        self.put_cond,
                        self.lock,
                        timeout=timeout,
                        block=block)

  def reput(self, item, block=True, timeout=None):
    """Re-put an item back into the requeue.

    Re-putting an item does not increase the number of outstanding
    tasks, so the reput item should be uniquely associated with an
    item that was previously removed from the requeue and for which
    task_done has not been called.

    Args:
      item: An item to add to the requeue.
      block: Whether to block if the requeue is full.
      timeout: Maximum on how long to wait until the queue is non-full.

    Raises:
      Queue.Full if the queue is full and the timeout expires.
    """
    def ReputAction():
      self.requeue.put(item, block=False)
    self._DoWithTimeout(ReputAction,
                        Queue.Full,
                        self.get_cond,
                        self.put_cond,
                        self.lock,
                        timeout=timeout,
                        block=block)

  def get(self, block=True, timeout=None):
    """Get an item from the requeue, preferring reput items.

    Args:
      block: Whether to block if the requeue is empty.
      timeout: Maximum on how long to wait until the requeue is non-empty.

    Returns:
      An item from the requeue.

    Raises:
      Queue.Empty if the queue is empty and the timeout expires.
    """
    def GetAction():
      try:
        result = self.requeue.get(block=False)
        # Reput items are not tracked as outstanding tasks, so mark them
        # done immediately in the inner requeue.
        self.requeue.task_done()
      except Queue.Empty:
        result = self.queue.get(block=False)
      return result
    return self._DoWithTimeout(GetAction,
                               Queue.Empty,
                               self.put_cond,
                               self.get_cond,
                               self.lock,
                               timeout=timeout,
                               block=block)

  def join(self):
    """Blocks until all of the items in the requeue have been processed."""
    self.queue.join()

  def task_done(self):
    """Indicate that a previously enqueued item has been fully processed."""
    self.queue.task_done()

  def empty(self):
    """Returns true if the requeue is empty."""
    return self.queue.empty() and self.requeue.empty()

  def get_nowait(self):
    """Try to get an item from the queue without blocking."""
    return self.get(block=False)


class ThrottleHandler(urllib2.BaseHandler):
  """A urllib2 handler that charges HTTP and HTTPS traffic to a Throttle.

  Outgoing requests first wait on throttle.Sleep() and then have their
  estimated size recorded; incoming responses have their size recorded and
  count as one request against the matching request throttle.
  """

  def __init__(self, throttle):
    """Initialize a ThrottleHandler.

    Args:
      throttle: A Throttle instance used for bandwidth and http/https request
        throttling; it must provide Sleep() and AddTransfer(name, amount).
    """
    self.throttle = throttle

  def AddRequest(self, throttle_name, req):
    """Charge the estimated wire size of a request to a bandwidth throttle.

    The estimate covers every header line, the request line (method and URL
    path; the query string is not counted) and the request body, if any.

    Args:
      throttle_name: The name of the bandwidth throttle to add to.
      req: The request whose size will be added to the throttle.
    """
    size = sum(len('%s: %s\n' % header)
               for header in req.headers.iteritems())
    size += sum(len('%s: %s\n' % header)
                for header in req.unredirected_hdrs.iteritems())
    path = urlparse.urlsplit(req.get_full_url())[2]
    size += len('%s %s HTTP/1.1\n' % (req.get_method(), path))
    body = req.get_data()
    if body:
      size += len(body)
    self.throttle.AddTransfer(throttle_name, size)

  def AddResponse(self, throttle_name, res):
    """Charge the size of a response (body plus headers) to a throttle.

    The body is read eagerly so it can be measured; res.read is then replaced
    by a zero-argument function that returns the buffered body, so later
    res.read() calls still see the content.

    Args:
      throttle_name: The name of the bandwidth throttle to add to.
      res: The response whose size will be added to the throttle.
    """
    body = res.read()
    res.read = lambda: body
    size = len(body)
    for header in res.info().items():
      size += len('%s: %s\n' % header)
    self.throttle.AddTransfer(throttle_name, size)

  def _ChargeRequest(self, bandwidth_name, req):
    """Waits for the throttle, then charges req to bandwidth_name."""
    self.throttle.Sleep()
    self.AddRequest(bandwidth_name, req)
    return req

  def _ChargeResponse(self, bandwidth_name, requests_name, res):
    """Charges res to bandwidth_name and one request to requests_name."""
    self.AddResponse(bandwidth_name, res)
    self.throttle.AddTransfer(requests_name, 1)
    return res

  def http_request(self, req):
    """Process an HTTP request: sleep if over quota, then record its size.

    Args:
      req: A urllib2.Request object.

    Returns:
      The request passed in.
    """
    return self._ChargeRequest(BANDWIDTH_UP, req)

  def https_request(self, req):
    """Process an HTTPS request: sleep if over quota, then record its size.

    Args:
      req: A urllib2.Request object.

    Returns:
      The request passed in.
    """
    return self._ChargeRequest(HTTPS_BANDWIDTH_UP, req)

  def http_response(self, unused_req, res):
    """Process an HTTP response: record its size and count one request.

    Args:
      unused_req: The urllib2 request for this response.
      res: A urllib2 response object.

    Returns:
      The response passed in.
    """
    return self._ChargeResponse(BANDWIDTH_DOWN, REQUESTS, res)

  def https_response(self, unused_req, res):
    """Process an HTTPS response: record its size and count one request.

    Args:
      unused_req: The urllib2 request for this response.
      res: A urllib2 response object.

    Returns:
      The response passed in.
    """
    return self._ChargeResponse(HTTPS_BANDWIDTH_DOWN, HTTPS_REQUESTS, res)


class ThrottledHttpRpcServer(appengine_rpc.HttpRpcServer):
  """An appengine_rpc.HttpRpcServer whose HTTP traffic is throttled.

  Every opener it builds carries a ThrottleHandler, so requests and
  responses are charged against (and delayed by) the supplied Throttle.
  """

  def __init__(self, throttle, request_manager, *args, **kwargs):
    """Initialize a ThrottledHttpRpcServer and register it with a manager.

    Args:
      throttle: A Throttle instance.
      request_manager: A RequestManager instance; its rpc_server attribute is
        set to this server.
      args: Positional arguments forwarded to
        appengine_rpc.HttpRpcServer.__init__
      kwargs: Keyword arguments forwarded to
        appengine_rpc.HttpRpcServer.__init__
    """
    # Stored before delegating: _GetOpener reads self.throttle, and the base
    # constructor presumably builds the opener (not visible here).
    self.throttle = throttle
    appengine_rpc.HttpRpcServer.__init__(self, *args, **kwargs)
    request_manager.rpc_server = self

  def _GetOpener(self):
    """Returns the base class's opener with a ThrottleHandler attached.

    Returns:
      A urllib2.OpenerDirector object.
    """
    base_opener = appengine_rpc.HttpRpcServer._GetOpener(self)
    throttle_handler = ThrottleHandler(self.throttle)
    base_opener.add_handler(throttle_handler)
    return base_opener


def ThrottledHttpRpcServerFactory(throttle, request_manager):
  """Builds an rpc_server_factory bound to one throttle and request manager.

  Args:
    throttle: A Throttle instance to use for the ThrottledHttpRpcServer.
    request_manager: A RequestManager instance.

  Returns:
    A callable accepting HttpRpcServer constructor arguments and returning a
    ThrottledHttpRpcServer.
  """
  def MakeRpcServer(*args, **kwargs):
    """Creates a server; always forces HOSTED_OR_GOOGLE auth and cookies."""
    kwargs.update(account_type='HOSTED_OR_GOOGLE', save_cookies=True)
    return ThrottledHttpRpcServer(throttle, request_manager, *args, **kwargs)

  return MakeRpcServer


class RequestManager(object):
  """A class which wraps a connection to the server."""

  source = 'google-bulkloader-%s' % UPLOADER_VERSION
  user_agent = source

  def __init__(self,
               app_id,
               host_port,
               url_path,
               kind,
               throttle):
    """Initialize a RequestManager object.

    Args:
      app_id: String containing the application id for requests.
      host_port: String containing the "host:port" pair; the port is optional.
      url_path: partial URL (path) to post entity data to.
      kind: Kind of the Entity records being posted.
      throttle: A Throttle instance.
    """
    self.app_id = app_id
    self.host_port = host_port
    self.host = host_port.split(':')[0]
    if url_path and url_path[0] != '/':
      url_path = '/' + url_path
    self.url_path = url_path
    self.kind = kind
    self.throttle = throttle
    self.credentials = None
    throttled_rpc_server_factory = ThrottledHttpRpcServerFactory(
        self.throttle, self)
    logging.debug('Configuring remote_api. app_id = %s, url_path = %s, '
                  'servername = %s' % (app_id, url_path, host_port))
    remote_api_stub.ConfigureRemoteDatastore(
        app_id,
        url_path,
        self.AuthFunction,
        servername=host_port,
        rpc_server_factory=throttled_rpc_server_factory)
    self.authenticated = False

  def Authenticate(self):
    """Invoke authentication if necessary."""
    self.rpc_server.Send(self.url_path, payload=None)
    self.authenticated = True

  def AuthFunction(self,
                   raw_input_fn=raw_input,
                   password_input_fn=getpass.getpass):
    """Prompts the user for a username and password.

    Caches the results the first time it is called and returns the
    same result every subsequent time.

    Args:
      raw_input_fn: Used for dependency injection.
      password_input_fn: Used for dependency injection.

    Returns:
      A pair of the username and password.
    """
    if self.credentials is not None:
      return self.credentials
    print 'Please enter login credentials for %s (%s)' % (
        self.host, self.app_id)
    email = raw_input_fn('Email: ')
    if email:
      password_prompt = 'Password for %s: ' % email
      password = password_input_fn(password_prompt)
    else:
      password = None
    self.credentials = (email, password)
    return self.credentials

  def _GetHeaders(self):
    """Constructs a dictionary of extra headers to send with a request."""
    headers = {
        'GAE-Uploader-Version': UPLOADER_VERSION,
        'GAE-Uploader-Kind': self.kind
        }
    return headers

  def EncodeContent(self, rows):
    """Encodes row data to the wire format.

    Args:
      rows: A list of pairs of a line number and a list of column values.

    Returns:
      A list of db.Model instances.
    """
    try:
      loader = Loader.RegisteredLoaders()[self.kind]
    except KeyError:
      logging.error('No Loader defined for kind %s.' % self.kind)
      raise ConfigurationError('No Loader defined for kind %s.' % self.kind)
    entities = []
    for line_number, values in rows:
      key = loader.GenerateKey(line_number, values)
      entity = loader.CreateEntity(values, key_name=key)
      entities.extend(entity)

    return entities

  def PostEntities(self, item):
    """Posts Entity records to a remote endpoint over HTTP.

    Args:
      item: A workitem containing the entities to post.

    Returns:
      A pair of the estimated size of the request in bytes and the response
        from the server as a str.
    """
    entities = item.content
    db.put(entities)


class WorkItem(object):
  """Holds a unit of uploading work.

  A WorkItem represents a number of entities that need to be uploaded to
  Google App Engine. These entities are encoded in the "content" field of
  the WorkItem, and will be POST'd as-is to the server.

  The entities are identified by a range of numeric keys, inclusively. In
  the case of a resumption of an upload, or a replay to correct errors,
  these keys must be able to identify the same set of entities.

  Note that keys specify a range. The entities do not have to sequentially
  fill the entire range, they must simply bound a range of valid keys.
  """

  def __init__(self, progress_queue, rows, key_start, key_end,
               progress_key=None):
    """Initialize the WorkItem instance.

    Args:
      progress_queue: A queue used for tracking progress information.
      rows: A list of pairs of a line number and a list of column values
      key_start: The (numeric) starting key, inclusive.
      key_end: The (numeric) ending key, inclusive.
      progress_key: If this WorkItem represents state from a prior run,
        then this will be the key within the progress database.
    """
    self.state = STATE_READ

    self.progress_queue = progress_queue

    assert isinstance(key_start, (int, long))
    assert isinstance(key_end, (int, long))
    assert key_start <= key_end

    self.key_start = key_start
    self.key_end = key_end
    self.progress_key = progress_key

    # Set by whoever consumes progress_queue to acknowledge a blocking
    # state transition (see _StateTransition).
    self.progress_event = threading.Event()

    self.rows = rows
    self.content = None
    self.count = len(rows)

  def MarkAsRead(self):
    """Mark this WorkItem as read/consumed from the data source.

    Re-publishes STATE_READ and waits for acknowledgement; afterwards
    progress_key must be set (presumably assigned by the progress consumer).
    """

    assert self.state == STATE_READ

    self._StateTransition(STATE_READ, blocking=True)

    assert self.progress_key is not None

  def MarkAsSending(self):
    """Mark this WorkItem as in-process on being uploaded to the server."""

    assert self.state == STATE_READ or self.state == STATE_NOT_SENT
    assert self.progress_key is not None

    self._StateTransition(STATE_SENDING, blocking=True)

  def MarkAsSent(self):
    """Mark this WorkItem as successfully sent (does not block)."""

    assert self.state == STATE_SENDING
    assert self.progress_key is not None

    self._StateTransition(STATE_SENT, blocking=False)

  def MarkAsError(self):
    """Mark this WorkItem as requiring manual error recovery."""

    assert self.state == STATE_SENDING
    assert self.progress_key is not None

    self._StateTransition(STATE_NOT_SENT, blocking=True)

  def _StateTransition(self, new_state, blocking=False):
    """Transition the work item to a new state, storing progress information.

    The item itself is put on progress_queue after its state is updated.

    Args:
      new_state: The state to transition to.
      blocking: Whether to block for the progress thread to acknowledge the
        transition (by setting progress_event, which is then cleared).
    """
    # Log the state being entered. Logging self.state (the state being left)
    # meant STATE_SENT's message was never emitted and the other messages
    # appeared one transition late.
    logging.debug('[%s-%s] %s', self.key_start, self.key_end,
                  StateMessage(new_state))
    assert not self.progress_event.isSet()

    self.state = new_state

    self.progress_queue.put(self)

    if blocking:
      self.progress_event.wait()

      self.progress_event.clear()



def InterruptibleSleep(sleep_time):
  """Sleeps for sleep_time seconds, waking every half second to check for exit.

  The sleep is cut short (after the current half-second slice) as soon as the
  current thread's exit_flag attribute is true.  Threads that have no
  exit_flag attribute are treated as never being asked to exit, so they
  sleep for the full duration instead of raising AttributeError.

  Args:
    sleep_time: Time to sleep, in seconds.
  """
  slept = 0.0
  # Tolerance so float accumulation error does not cause an extra tiny sleep.
  epsilon = .0001
  thread = threading.currentThread()
  while slept < sleep_time - epsilon:
    remaining = sleep_time - slept
    this_sleep_time = min(remaining, 0.5)
    time.sleep(this_sleep_time)
    slept += this_sleep_time
    if getattr(thread, 'exit_flag', False):
      return


class ThreadGate(object):
  """Manage the number of active worker threads.

  The ThreadGate limits the number of threads that are simultaneously
  uploading batches of records in order to implement adaptive rate
  control.  The number of simultaneous upload threads that it takes to
  start causing timeouts varies widely over the course of the day, so
  adaptive rate control allows the uploader to do many uploads while
  reducing the error rate and thus increasing the throughput.

  Initially the ThreadGate allows only one uploader thread to be active.
  For each successful upload, another thread is activated and for each
  failed upload, the number of active threads is reduced by one.  A failure
  while only one thread is enabled instead grows a backoff delay that every
  StartWork sleeps through until the next success resets it.
  """

  def __init__(self, enabled, sleep=InterruptibleSleep):
    """Initialize the ThreadGate.

    Args:
      enabled: Whether throttling is active.  When False, StartWork,
        FinishWork, IncreaseWorkers and DecreaseWorkers do nothing.
      sleep: Function called with a number of seconds to sleep while
        backing off.
    """
    self.enabled = enabled
    # Number of threads currently allowed into the critical section.
    self.enabled_count = 1
    # Serializes updates to enabled_count and backoff_time.
    self.lock = threading.Lock()
    # Permits track enabled_count; StartWork takes one, FinishWork returns it.
    self.thread_semaphore = threading.Semaphore(self.enabled_count)
    self._threads = []
    # Seconds StartWork sleeps before proceeding; 0 when not backing off.
    self.backoff_time = 0
    self.sleep = sleep

  def Register(self, thread):
    """Register a thread with the thread gate."""
    self._threads.append(thread)

  def Threads(self):
    """Yields the registered threads."""
    for thread in self._threads:
      yield thread

  def EnableThread(self):
    """Enable one more worker thread."""
    self.lock.acquire()
    try:
      self.enabled_count += 1
    finally:
      self.lock.release()
    self.thread_semaphore.release()

  def EnableAllThreads(self):
    """Enable all worker threads."""
    for unused_idx in range(len(self._threads) - self.enabled_count):
      self.EnableThread()

  def StartWork(self):
    """Starts a critical section in which the number of workers is limited.

    If thread throttling is enabled then this method starts a critical
    section which allows self.enabled_count simultaneously operating
    threads. The critical section is ended by calling self.FinishWork().
    While backing off, the caller sleeps for self.backoff_time seconds
    after entering.
    """
    if not self.enabled:
      return
    self.thread_semaphore.acquire()
    # Snapshot once so the logged and the slept durations agree even if
    # another thread changes backoff_time concurrently.
    backoff_time = self.backoff_time
    if backoff_time > 0.0:
      # getattr: threads that are not _ThreadBase instances have no flag.
      if not getattr(threading.currentThread(), 'exit_flag', False):
        logging.info('Backing off: %.1f seconds', backoff_time)
      self.sleep(backoff_time)

  def FinishWork(self):
    """Ends a critical section started with self.StartWork()."""
    if self.enabled:
      self.thread_semaphore.release()

  def IncreaseWorkers(self):
    """Informs the throttler that an item was successfully sent.

    If thread throttling is enabled, this method resets any backoff and
    lets one additional thread into the critical section (up to the number
    of registered threads).
    """
    if not self.enabled:
      return
    do_enable = False
    self.lock.acquire()
    try:
      # Reset under the same lock DecreaseWorkers uses to grow the backoff.
      if self.backoff_time > 0.0:
        logging.info('Resetting backoff to 0.0')
        self.backoff_time = 0.0
      if len(self._threads) > self.enabled_count:
        do_enable = True
        self.enabled_count += 1
    finally:
      self.lock.release()
    if do_enable:
      self.thread_semaphore.release()

  def DecreaseWorkers(self):
    """Informs the thread_gate that an item failed to send.

    If thread throttling is enabled, this method will cause the
    throttler to allow one fewer thread in the critical section. If
    there is only one thread remaining, failures will result in
    exponential backoff until there is a success.
    """
    if not self.enabled:
      return
    do_disable = False
    self.lock.acquire()
    try:
      if self.enabled_count > 1:
        do_disable = True
        self.enabled_count -= 1
      elif self.backoff_time == 0.0:
        self.backoff_time = INITIAL_BACKOFF
      else:
        self.backoff_time *= BACKOFF_FACTOR
    finally:
      self.lock.release()
    # Withdraw one permit.  This is done outside the lock because it may block
    # until another thread calls FinishWork.
    if do_disable:
      self.thread_semaphore.acquire()


class Throttle(object):
  """A base class for upload rate throttling.

  Transferring large number of records, too quickly, to an application
  could trigger quota limits and cause the transfer process to halt.
  In order to stay within the application's quota, we throttle the
  data transfer to a specified limit (across all transfer threads).
  This limit defaults to about half of the Google App Engine default
  for an application, but can be manually adjusted faster/slower as
  appropriate.

  This class tracks a moving average of some aspect of the transfer
  rate (bandwidth, records per second, http connections per
  second). It keeps two windows of counts of tokens transferred (bytes,
  records, requests, ...), on a per-thread basis. One block is the
  "current" block, and the other is the "prior" block. It will rotate the
  counts from current to prior when ROTATE_PERIOD has passed.  Thus, the
  current block will represent from 0 seconds to ROTATE_PERIOD seconds of
  activity (determined by: time.time() - self.last_rotate).  The prior
  block will always represent a full ROTATE_PERIOD.

  Sleeping is performed just before a transfer of another block, and is
  based on the counts transferred *before* the next transfer. It really
  does not matter how much will be transferred, but only that for all the
  data transferred SO FAR that we have interspersed enough pauses to
  ensure the aggregate transfer rate is within the specified limit.

  These counts are maintained on a per-thread basis, so we do not require
  any interlocks around incrementing the counts. There IS an interlock on
  the rotation of the counts because we do not want multiple threads to
  multiply-rotate the counts.

  There are various race conditions in the computation and collection
  of these counts. We do not require precise values, but simply to
  keep the overall transfer within the bandwidth limits. If a given
  pause is a little short, or a little long, then the aggregate delays
  will be correct.
  """

  # Seconds of activity held in the "current" block before it is rotated.
  ROTATE_PERIOD = 600

  def __init__(self,
               get_time=time.time,
               thread_sleep=InterruptibleSleep,
               layout=None):
    """Initialize the Throttle instance.

    Args:
      get_time: Function returning the current time in seconds.
      thread_sleep: Function used to sleep for a number of seconds.
      layout: Optional mapping of throttle name to limit (tokens per
        second); each entry is added with AddThrottle.
    """
    self.get_time = get_time
    self.thread_sleep = thread_sleep

    self.start_time = get_time()
    # Each of these maps throttle name -> {thread name: token count}.
    self.transferred = {}
    self.prior_block = {}
    self.totals = {}
    # Throttle name -> limit in tokens per second.
    self.throttles = {}

    self.last_rotate = {}
    self.rotate_mutex = {}
    if layout:
      self.AddThrottles(layout)

  def AddThrottle(self, name, limit):
    """Add a throttle with empty counters.

    Threads registered before this call have no counters for the new
    throttle; call Register() again for them.

    Args:
      name: The name of the throttle.
      limit: The allowed rate, in tokens per second.
    """
    self.throttles[name] = limit
    self.transferred[name] = {}
    self.prior_block[name] = {}
    self.totals[name] = {}
    self.last_rotate[name] = self.get_time()
    self.rotate_mutex[name] = threading.Lock()

  def AddThrottles(self, layout):
    """Add every (name, limit) pair of the given mapping as a throttle."""
    for key, value in layout.iteritems():
      self.AddThrottle(key, value)

  def Register(self, thread):
    """Register this thread with the throttler.

    Creates zeroed counters for the thread in every throttle that exists at
    the time of the call.
    """
    thread_name = thread.getName()
    for throttle_name in self.throttles.iterkeys():
      self.transferred[throttle_name][thread_name] = 0
      self.prior_block[throttle_name][thread_name] = 0
      self.totals[throttle_name][thread_name] = 0

  def VerifyName(self, throttle_name):
    """Raise AssertionError unless throttle_name has been added."""
    if throttle_name not in self.throttles:
      raise AssertionError('%s is not a registered throttle' % throttle_name)

  def AddTransfer(self, throttle_name, token_count):
    """Add a count to the amount this thread has transferred.

    Each time a thread transfers some data, it should call this method to
    note the amount sent. The counts may be rotated if sufficient time
    has passed since the last rotation.

    Note: this method should only be called by the BulkLoaderThread
    instances. The token count is allocated towards the
    "current thread".

    Args:
      throttle_name: The name of the throttle to add to.
      token_count: The number to add to the throttle counter.

    Raises:
      KeyError: If the current thread was not registered via Register().
    """
    self.VerifyName(throttle_name)
    transferred = self.transferred[throttle_name]
    transferred[threading.currentThread().getName()] += token_count

    if self.last_rotate[throttle_name] + self.ROTATE_PERIOD < self.get_time():
      self._RotateCounts(throttle_name)

  def Sleep(self, throttle_name=None):
    """Possibly sleep in order to limit the transfer rate.

    Note that we sleep based on *prior* transfers rather than what we
    may be about to transfer. The next transfer could put us under/over
    and that will be rectified *after* that transfer. Net result is that
    the average transfer rate will remain within bounds. Spiky behavior
    or uneven rates among the threads could possibly bring the transfer
    rate above the requested limit for short durations.

    Args:
      throttle_name: The name of the throttle to sleep on.  If None or
        omitted, then sleep on all throttles.
    """
    if throttle_name is None:
      for name in self.throttles:
        self.Sleep(throttle_name=name)
      return

    self.VerifyName(throttle_name)

    thread = threading.currentThread()

    while True:
      duration = self.get_time() - self.last_rotate[throttle_name]

      total = sum(self.prior_block[throttle_name].values())
      # The prior block's full ROTATE_PERIOD only counts as elapsed time when
      # it actually holds transfers.
      if total:
        duration += self.ROTATE_PERIOD

      total += sum(self.transferred[throttle_name].values())

      # Time the transfers should have taken at the limit, minus the time
      # that has actually elapsed.
      sleep_time = (float(total) / self.throttles[throttle_name]) - duration

      if sleep_time < MINIMUM_THROTTLE_SLEEP_DURATION:
        break

      logging.debug('[%s] Throttling on %s. Sleeping for %.1f ms '
                    '(duration=%.1f ms, total=%d)',
                    thread.getName(), throttle_name,
                    sleep_time * 1000, duration * 1000, total)
      self.thread_sleep(sleep_time)
      # getattr: threads that are not _ThreadBase instances have no flag.
      if getattr(thread, 'exit_flag', False):
        break
      self._RotateCounts(throttle_name)

  def _RotateCounts(self, throttle_name):
    """Rotate the transfer counters.

    If sufficient time has passed, then rotate the counters from active to
    the prior-block of counts.

    This rotation is interlocked to ensure that multiple threads do not
    over-rotate the counts.

    Args:
      throttle_name: The name of the throttle to rotate.
    """
    self.VerifyName(throttle_name)
    self.rotate_mutex[throttle_name].acquire()
    try:
      next_rotate_time = self.last_rotate[throttle_name] + self.ROTATE_PERIOD
      # Re-checked under the mutex: another thread may have just rotated.
      if next_rotate_time >= self.get_time():
        return

      for name, count in self.transferred[throttle_name].items():
        self.prior_block[throttle_name][name] = count
        self.transferred[throttle_name][name] = 0
        self.totals[throttle_name][name] += count

      self.last_rotate[throttle_name] = self.get_time()

    finally:
      self.rotate_mutex[throttle_name].release()

  def TotalTransferred(self, throttle_name):
    """Return the total transferred, and over what period.

    Args:
      throttle_name: The name of the throttle to total.

    Returns:
      A tuple of the total count and running time for the given throttle name.

    Raises:
      AssertionError: If throttle_name is not a registered throttle.
    """
    self.VerifyName(throttle_name)
    # totals holds everything already rotated out; transferred holds the
    # current, not-yet-rotated block.
    total = (sum(self.totals[throttle_name].values()) +
             sum(self.transferred[throttle_name].values()))
    return total, self.get_time() - self.start_time


class _ThreadBase(threading.Thread):
  """Shared scaffolding for the threads used by the uploader.

  Every subclass gets:

  * exit_flag, which asks the thread to stop at its next opportunity;
  * error, which holds any exception raised by PerformWork() so that the
    primary thread can pick it up;
  * start/exit logging around PerformWork(), plus logging of exceptions;
  * daemon status;
  * a human-friendly name via GetFriendlyName().

  Concrete subclasses must implement PerformWork(), and should either set a
  NAME attribute or override GetFriendlyName().  PerformWork() should check
  self.exit_flag whenever it has a chance to exit gracefully.
  """

  def __init__(self):
    threading.Thread.__init__(self)
    self.exit_flag = False
    self.error = None
    self.setDaemon(True)

  def run(self):
    """Run PerformWork(), logging start/exit and capturing any exception."""
    class_name = self.__class__.__name__
    logging.info('[%s] %s: started', self.getName(), class_name)
    try:
      self.PerformWork()
    except:
      # Thread boundary: record whatever escaped for the primary thread.
      self.error = sys.exc_info()[1]
      logging.exception('[%s] %s:', self.getName(), class_name)
    logging.info('[%s] %s: exiting', self.getName(), class_name)

  def PerformWork(self):
    """Do this thread's job; concrete subclasses must override this."""
    raise NotImplementedError()

  def CheckError(self):
    """Log the captured error, if there is one."""
    if not self.error:
      return
    logging.error('Error in %s: %s', self.GetFriendlyName(), self.error)

  def GetFriendlyName(self):
    """Return self.NAME, or a generic label when no NAME is defined."""
    return getattr(self, 'NAME', 'unknown thread')


class BulkLoaderThread(_ThreadBase):
  """A thread which transmits entities to the server application.

  This thread will read WorkItem instances from the work_queue and upload
  the entities to the server application. Progress information will be
  pushed into the progress_queue as the work is being performed.

  If a BulkLoaderThread encounters a transient error, the entities will be
  resent, if a fatal error is encoutered the BulkLoaderThread exits.
  """

  def __init__(self,
               work_queue,
               throttle,
               thread_gate,
               request_manager):
    """Initialize the BulkLoaderThread instance.

    Args:
      work_queue: A queue containing WorkItems for processing.
      throttle: A Throttles to control upload bandwidth.
      thread_gate: A ThreadGate to control number of simultaneous uploads.
      request_manager: A RequestManager instance.
    """
    _ThreadBase.__init__(self)

    self.work_queue = work_queue
    self.throttle = throttle
    self.thread_gate = thread_gate

    self.request_manager = request_manager

  def PerformWork(self):
    """Perform the work of a BulkLoaderThread."""
    while not self.exit_flag:
      success = False
      self.thread_gate.StartWork()
      try:
        try:
          item = self.work_queue.get(block=True, timeout=1.0)
        except Queue.Empty:
          continue
        if item == _THREAD_SHOULD_EXIT:
          break

        logging.debug('[%s] Got work item [%d-%d]',
                      self.getName(), item.key_start, item.key_end)

        try:

          item.MarkAsSending()
          try:
            if item.content is None:
              item.content = self.request_manager.EncodeContent(item.rows)
            try:
              self.request_manager.PostEntities(item)
              success = True
              logging.debug(
                  '[%d-%d] Sent %d entities',
                  item.key_start, item.key_end, item.count)
              self.throttle.AddTransfer(RECORDS, item.count)
            except (db.InternalError, db.NotSavedError, db.Timeout), e:
              logging.debug('Caught non-fatal error: %s', e)
            except urllib2.HTTPError, e:
              if e.code == 403 or (e.code >= 500 and e.code < 600):
                logging.debug('Caught HTTP error %d', e.code)
                logging.debug('%s', e.read())
              else:
                raise e

          except:
            self.error = sys.exc_info()[1]
            logging.exception('[%s] %s: caught exception %s', self.getName(),
                              self.__class__.__name__, str(sys.exc_info()))
            raise

        finally:
          if success:
            item.MarkAsSent()
            self.thread_gate.IncreaseWorkers()
            self.work_queue.task_done()
          else:
            item.MarkAsError()
            self.thread_gate.DecreaseWorkers()
            try:
              self.work_queue.reput(item, block=False)
            except Queue.Full:
              logging.error('[%s] Failed to reput work item.', self.getName())
              raise Error('Failed to reput work item')
          logging.info('[%d-%d] %s',
                       item.key_start, item.key_end, StateMessage(item.state))

      finally:
        self.thread_gate.FinishWork()


  def GetFriendlyName(self):
    """Returns a human-friendly name for this thread."""
    return 'worker [%s]' % self.getName()


class DataSourceThread(_ThreadBase):
  """A thread which reads WorkItems and pushes them into queue.

  This thread will read/consume WorkItems from a generator (produced by
  the generator factory). These WorkItems will then be pushed into the
  work_queue. Note that reading will block if/when the work_queue becomes
  full. Information on content consumed from the generator will be pushed
  into the progress_queue.
  """

  NAME = 'data source thread'

  def __init__(self,
               work_queue,
               progress_queue,
               workitem_generator_factory,
               progress_generator_factory):
    """Initialize the DataSourceThread instance.

    Args:
      work_queue: A queue containing WorkItems for processing.
      progress_queue: A queue used for tracking progress information.
      workitem_generator_factory: A factory that creates a WorkItem generator
      progress_generator_factory: A factory that creates a generator which
        produces prior progress status, or None if there is no prior status
        to use.
    """
    _ThreadBase.__init__(self)

    self.work_queue = work_queue
    self.progress_queue = progress_queue
    self.workitem_generator_factory = workitem_generator_factory
    self.progress_generator_factory = progress_generator_factory
    # Sum of item.count over every WorkItem placed on work_queue.
    self.entity_count = 0
    # Summary values that PerformWork fills in at the end.  They are also
    # initialized here so they can be read even if PerformWork never ran or
    # failed before finishing.
    self.sent_count = 0
    self.read_count = 0
    self.read_all = False

  def PerformWork(self):
    """Performs the work of a DataSourceThread.

    Each WorkItem produced by the generator is marked as read, which blocks
    until the progress thread acknowledges it.  The item is then put on
    work_queue, retrying every second while the queue is full.  The loop
    stops early once exit_flag is set.  read_all is set to True only when
    the loop finishes without exit_flag being set.
    """
    if self.progress_generator_factory:
      progress_gen = self.progress_generator_factory()
    else:
      progress_gen = None

    content_gen = self.workitem_generator_factory(self.progress_queue,
                                                  progress_gen)

    self.sent_count = 0
    self.read_count = 0
    self.read_all = False

    for item in content_gen.Batches():
      item.MarkAsRead()

      # Time out periodically so a full queue cannot hide exit_flag.
      while not self.exit_flag:
        try:
          self.work_queue.put(item, block=True, timeout=1.0)
          self.entity_count += item.count
          break
        except Queue.Full:
          pass

      if self.exit_flag:
        break

    if not self.exit_flag:
      self.read_all = True
    # row_count / sent_count come from the generator object; presumably
    # rows read and rows already sent in a prior run -- TODO confirm.
    self.read_count = content_gen.row_count
    self.sent_count = content_gen.sent_count



def _RunningInThread(thread):
  """Return True if we are running within the specified thread.

  Thread objects are compared by identity.  Thread names are not guaranteed
  to be unique, since any string may be passed as a name, so comparing
  names could match an unrelated thread.
  """
  return threading.currentThread() is thread


class ProgressDatabase(object):
  """Persistently record all progress information during an upload.

  This class wraps a very simple SQLite database which records each of
  the relevant details from the WorkItem instances. If the uploader is
  resumed, then data is replayed out of the database.

  Two connections are used.  primary_conn belongs to the thread that created
  the instance.  progress_conn is opened lazily by the first thread that
  calls StoreKeys/UpdateState, which must not be the primary thread.  The
  _RunningInThread assertions keep each connection on its own thread.
  """

  def __init__(self, db_filename, commit_periodicity=100):
    """Initialize the ProgressDatabase instance.

    Creates the progress table and its state index unless they already
    exist, so an existing database can be reopened.

    Args:
      db_filename: The name of the SQLite database to use.
      commit_periodicity: How many operations to perform between commits.
    """
    self.db_filename = db_filename

    logging.info('Using progress database: %s', db_filename)
    # isolation_level=None: the primary connection runs in autocommit mode.
    self.primary_conn = sqlite3.connect(db_filename, isolation_level=None)
    self.primary_thread = threading.currentThread()

    # Opened lazily by _OpenProgressConnection.
    self.progress_conn = None
    self.progress_thread = None

    self.operation_count = 0
    self.commit_periodicity = commit_periodicity

    # key_end of the last stored range; StoreKeys asserts ranges increase.
    self.prior_key_end = None

    self._CreateIfMissing(
        """create table progress (
          id integer primary key autoincrement,
          state integer not null,
          key_start integer not null,
          key_end integer not null
          )
          """)
    self._CreateIfMissing('create index i_state on progress (state)')

  def _CreateIfMissing(self, statement):
    """Run a CREATE statement on primary_conn, ignoring 'already exists'.

    Args:
      statement: The SQL CREATE statement to execute.

    Raises:
      sqlite3.OperationalError: For any failure other than the object
        already existing.
    """
    try:
      self.primary_conn.execute(statement)
    except sqlite3.OperationalError:
      # str() rather than BaseException.message, which is deprecated since
      # Python 2.6 and absent in Python 3.
      if 'already exists' not in str(sys.exc_info()[1]):
        raise

  def ThreadComplete(self):
    """Finalize any operations the progress thread has performed.

    The database aggregates lots of operations into a single commit, and
    this method is used to commit any pending operations as the thread
    is about to shut down.  Does nothing if the progress connection was
    never opened.
    """
    if self.progress_conn:
      self._MaybeCommit(force_commit=True)

  def _MaybeCommit(self, force_commit=False):
    """Periodically commit changes into the SQLite database.

    Committing every operation is quite expensive, and slows down the
    operation of the script. Thus, we only commit after every N operations,
    as determined by the self.commit_periodicity value. Optionally, the
    caller can force a commit.

    Args:
      force_commit: Pass True in order for a commit to occur regardless
        of the current operation count.
    """
    self.operation_count += 1
    if force_commit or (self.operation_count % self.commit_periodicity) == 0:
      self.progress_conn.commit()

  def _OpenProgressConnection(self):
    """Possibly open a database connection for the progress tracker thread.

    If the connection is not open (for the calling thread, which is assumed
    to be the progress tracker thread), then open it. We also open a couple
    cursors for later use (and reuse).
    """
    if self.progress_conn:
      return

    assert not _RunningInThread(self.primary_thread)

    self.progress_thread = threading.currentThread()

    # Default isolation level: changes stay pending until _MaybeCommit.
    self.progress_conn = sqlite3.connect(self.db_filename)

    self.insert_cursor = self.progress_conn.cursor()
    self.update_cursor = self.progress_conn.cursor()

  def HasUnfinishedWork(self):
    """Returns True if the database has progress information.

    Note there are two basic cases for progress information:
    1) All saved records indicate a successful upload. In this case, we
       need to skip everything transmitted so far and then send the rest.
    2) Some records for incomplete transfer are present. These need to be
       sent again, and then we resume sending after all the successful
       data.

    Returns:
      True if the database has progress information, False otherwise.

    Raises:
      ResumeError: If there is an error reading the progress database.
    """
    assert _RunningInThread(self.primary_thread)

    cursor = self.primary_conn.cursor()
    cursor.execute('select count(*) from progress')
    row = cursor.fetchone()
    if row is None:
      raise ResumeError('Error reading progress information.')

    return row[0] != 0

  def StoreKeys(self, key_start, key_end):
    """Record a new progress record, returning a key for later updates.

    The specified progress information will be persisted into the database.
    A unique key will be returned that identifies this progress state. The
    key is later used to (quickly) update this record.

    For the progress resumption to proceed properly, calls to StoreKeys
    MUST specify monotonically increasing key ranges. This will result in
    a database whereby the ID, KEY_START, and KEY_END rows are all
    increasing (rather than having ranges out of order).

    NOTE: the above precondition is only checked with assertions against
    the previously stored range; no additional table reads are done.

    Args:
      key_start: The starting key of the WorkItem (inclusive)
      key_end: The end key of the WorkItem (inclusive)

    Returns:
      The row id of the new record, to be used as a unique key when
      updating this state.
    """
    self._OpenProgressConnection()

    assert _RunningInThread(self.progress_thread)
    assert isinstance(key_start, int)
    assert isinstance(key_end, int)
    assert key_start <= key_end

    if self.prior_key_end is not None:
      assert key_start > self.prior_key_end
    self.prior_key_end = key_end

    self.insert_cursor.execute(
        'insert into progress (state, key_start, key_end) values (?, ?, ?)',
        (STATE_READ, key_start, key_end))

    progress_key = self.insert_cursor.lastrowid

    self._MaybeCommit()

    return progress_key

  def UpdateState(self, key, new_state):
    """Update a specified progress record with new information.

    Args:
      key: The key for this progress record, returned from StoreKeys
      new_state: The new state to associate with this progress record.
    """
    self._OpenProgressConnection()

    assert _RunningInThread(self.progress_thread)
    assert isinstance(new_state, int)

    self.update_cursor.execute('update progress set state=? where id=?',
                               (new_state, key))

    self._MaybeCommit()

  def GetProgressStatusGenerator(self):
    """Get a generator which returns progress information.

    The returned generator will yield a series of 4-tuples that specify
    progress information about a prior run of the uploader. The 4-tuples
    have the following values:

      progress_key: The unique key to later update this record with new
                    progress information.
      state: The last state saved for this progress record.
      key_start: The starting key of the items for uploading (inclusive).
      key_end: The ending key of the items for uploading (inclusive).

    After all incompletely-transferred records are provided, then one
    more 4-tuple will be generated:

      None
      DATA_CONSUMED_TO_HERE: A unique string value indicating this record
                             is being provided.
      None
      key_end: An integer value specifying the last data source key that
               was handled by the previous run of the uploader.

    The caller should begin uploading records which occur after key_end.

    Yields:
      Progress information as tuples (progress_key, state, key_start, key_end).

    Raises:
      ResumeError: If the progress table holds no records.
    """
    conn = sqlite3.connect(self.db_filename, isolation_level=None)
    try:
      cursor = conn.cursor()

      cursor.execute('select max(id) from progress')
      batch_id = cursor.fetchone()[0]

      cursor.execute('select key_end from progress where id = ?', (batch_id,))
      last_row = cursor.fetchone()
      # An empty table gives max(id) = NULL, which matches no row.
      if last_row is None:
        raise ResumeError('Error reading progress information.')
      key_end = last_row[0]

      self.prior_key_end = key_end

      cursor.execute(
          'select id, state, key_start, key_end from progress'
          '  where state != ?'
          '  order by id',
          (STATE_SENT,))

      rows = cursor.fetchall()
    finally:
      # All rows are in memory, so the connection need not outlive the
      # (possibly long-lived) generator.
      conn.close()

    for row in rows:
      yield row

    yield None, DATA_CONSUMED_TO_HERE, None, key_end


class StubProgressDatabase(object):
  """A do-nothing stand-in for ProgressDatabase.

  Each method accepts the same arguments as its ProgressDatabase
  counterpart, ignores them, and performs no storage.
  """

  def HasUnfinishedWork(self):
    """Report that there is no stored progress (always False)."""
    return False

  def StoreKeys(self, unused_key_start, unused_key_end):
    """Accept a key range without storing it; returns a placeholder key."""
    placeholder_key = 'fake-key'
    return placeholder_key

  def UpdateState(self, unused_key, unused_new_state):
    """Accept a state update without recording it."""
    return None

  def ThreadComplete(self):
    """Nothing is pending in the stub, so there is nothing to finalize."""
    return None


class ProgressTrackerThread(_ThreadBase):
  """A thread which persists upload progress for WorkItems.

  Each WorkItem taken from the progress queue has either its key range or
  its new state written to the progress database.  The item's
  progress_event is then set to acknowledge the update.  Replaying a prior
  run's progress out of the database is not this class's job; resuming a
  previous upload is handled by separate mechanisms.
  """

  NAME = 'progress tracking thread'

  def __init__(self, progress_queue, progress_db):
    """Initialize the ProgressTrackerThread instance.

    Args:
      progress_queue: A Queue used for tracking progress information.
      progress_db: The database for tracking progress information; should
        be an instance of ProgressDatabase.
    """
    _ThreadBase.__init__(self)

    self.progress_queue = progress_queue
    self.db = progress_db
    # Running total of entities in items recorded as STATE_SENT.
    self.entities_sent = 0

  def _RecordProgress(self, item):
    """Write one WorkItem's key range (first sighting) or state to the db."""
    if item.state == STATE_READ and item.progress_key is None:
      item.progress_key = self.db.StoreKeys(item.key_start, item.key_end)
      return

    assert item.progress_key is not None
    self.db.UpdateState(item.progress_key, item.state)
    if item.state == STATE_SENT:
      self.entities_sent += item.count

  def PerformWork(self):
    """Record progress updates until told to exit, then finalize the db."""
    while not self.exit_flag:
      try:
        item = self.progress_queue.get(block=True, timeout=1.0)
      except Queue.Empty:
        continue
      if item == _THREAD_SHOULD_EXIT:
        break

      self._RecordProgress(item)
      # Acknowledge; a blocking _StateTransition is waiting on this event.
      item.progress_event.set()
      self.progress_queue.task_done()

    self.db.ThreadComplete()



def Validate(value, typ):
  """Checks that value is non-empty and of the right type.

  Args:
    value: any value
    typ: a type or tuple of types

  Raises:
    ValueError if value is falsy (None, empty, 0, False, ...).
    TypeError if it's not the given type.
  """
  if not value:
    # Wrap value in a 1-tuple.  A bare tuple operand (e.g. ()) would be
    # taken as the format arguments and raise TypeError instead.
    raise ValueError('Value should not be empty; received %s.' % (value,))
  elif not isinstance(value, typ):
    raise TypeError('Expected a %s, but received %s (a %s).' %
                    (typ, value, value.__class__))


class Loader(object):
  """A base class for creating datastore entities from input data.

  To add a handler for bulk loading a new entity kind into your datastore,
  write a subclass of this class that calls Loader.__init__ from your
  class's __init__.

  If you need to run extra code to convert entities from the input
  data, create new properties, or otherwise modify the entities before
  they're inserted, override HandleEntity.

  See the CreateEntity method for the creation of entities from the
  (parsed) input data.
  """

  # Class-wide registry mapping entity kind -> Loader instance, filled in by
  # RegisterLoader and shared by every subclass.
  __loaders = {}
  # Per-instance settings assigned in __init__; None until then.
  __kind = None
  __properties = None

  def __init__(self, kind, properties):
    """Constructor.

    Populates this Loader's kind and properties map. The constructor does
    not register the instance by itself: registration is done through
    Loader.RegisterLoader. LoadConfig calls it for every direct subclass of
    Loader defined in the config file, instantiating each with no arguments.

    Args:
      kind: a string containing the entity kind that this loader handles.
        It is looked up with db.class_for_kind, so a Model class for it
        must already be defined (presumably db.KindError is raised
        otherwise -- confirm against the db module).

      properties: list of (name, converter) tuples.

        This is used to automatically convert the CSV columns into
        properties.  The converter should be a function that takes one
        argument, a string value from the CSV file, and returns a
        correctly typed property value that should be inserted. The
        tuples in this list should match the columns in your CSV file,
        in order.

        For example:
          [('name', str),
           ('id_number', int),
           ('email', datastore_types.Email),
           ('user', users.User),
           ('birthdate', lambda x: datetime.datetime.fromtimestamp(float(x))),
           ('description', datastore_types.Text),
           ]

    Raises:
      ValueError: if kind, properties, or a property name is empty.
      TypeError: if kind or a property name is not a string, or properties
        is not a list.
      AssertionError: if a converter is not callable.
    """
    Validate(kind, basestring)
    self.__kind = kind

    # Called only for its validation side effect: fails early if no Model
    # class is registered for this kind.
    db.class_for_kind(kind)

    Validate(properties, list)
    for name, fn in properties:
      Validate(name, basestring)
      assert callable(fn), (
        'Conversion function %s for property %s is not callable.' % (fn, name))

    self.__properties = properties

  @staticmethod
  def RegisterLoader(loader):
    """Adds loader to the class-wide registry, keyed by its kind.

    A later loader for the same kind replaces an earlier one.

    Args:
      loader: a Loader instance.
    """
    Loader.__loaders[loader.__kind] = loader

  def kind(self):
    """Returns the entity kind that this Loader handles."""
    return self.__kind

  def CreateEntity(self, values, key_name=None):
    """Creates entities from a list of property values.

    Args:
      values: list/tuple of str, one per entry in the properties map.
      key_name: if provided, the key_name passed to the Model constructor
        for the entity built from values.

    Returns:
      Whatever HandleEntity produced for the new entity. A truthy result
      that is not a list or tuple is wrapped in a one-element list. A falsy
      result (None or []) is returned unchanged, meaning nothing should be
      inserted.

      The entity is populated with the property values from the argument,
      converted to native types using the properties map given in the
      constructor, and passed through HandleEntity before being returned.

    Raises:
      AssertionError: if the number of values doesn't match the number
        of properties in the properties map.
      ValueError: if values is empty.
      TypeError: if values is not a list or tuple, or if HandleEntity
        returned something that is not a db.Model.
      Any exception raised by a converter propagates unchanged.
    """
    Validate(values, (list, tuple))
    assert len(values) == len(self.__properties), (
      'Expected %d CSV columns, found %d.' %
      (len(self.__properties), len(values)))

    model_class = db.class_for_kind(self.__kind)

    properties = {'key_name': key_name}
    for (name, converter), val in zip(self.__properties, values):
      # bool() of any non-empty string is True, so map the common textual
      # spellings of "false" explicitly before converting.
      if converter is bool and val.lower() in ('0', 'false', 'no'):
        val = False
      properties[name] = converter(val)

    entity = model_class(**properties)
    entities = self.HandleEntity(entity)

    if entities:
      if not isinstance(entities, (list, tuple)):
        entities = [entities]

      for entity in entities:
        if not isinstance(entity, db.Model):
          raise TypeError('Expected a db.Model, received %s (a %s).' %
                          (entity, entity.__class__))

    return entities

  def GenerateKey(self, i, values):
    """Generates a key_name to be used in creating the underlying object.

    The default implementation returns None.

    This method can be overridden to control the key generation for
    uploaded entities. The value returned should be None (to use a
    server generated numeric key), or a string which neither starts
    with a digit nor has the form __*__. (See
    http://code.google.com/appengine/docs/python/datastore/keysandentitygroups.html)

    If you generate your own string keys, keep in mind:

    1. The key name for each entity must be unique.
    2. If an entity of the same kind and key already exists in the
       datastore, it will be overwritten.

    Args:
      i: Number corresponding to this object (assume it's run in a loop;
        this is your current count).
      values: list/tuple of str.

    Returns:
      A string to be used as the key_name for an entity, or None.
    """
    return None

  def HandleEntity(self, entity):
    """Subclasses can override this to add custom entity conversion code.

    This is called for each entity, after its properties are populated from
    CSV but before it is stored. Subclasses can override this to add custom
    entity handling code.

    The entity to be inserted should be returned. If multiple entities should
    be inserted, return a list of entities. If no entities should be inserted,
    return None or [].

    Args:
      entity: db.Model

    Returns:
      db.Model or list of db.Model
    """
    return entity


  @staticmethod
  def RegisteredLoaders():
    """Returns a copy of the registry as a dict mapping kind -> Loader.

    Modifying the returned dict does not affect the registry.
    """
    return dict(Loader.__loaders)


class QueueJoinThread(threading.Thread):
  """Helper thread whose only job is to block on queue.join().

  A queue join cannot be given a timeout, so callers start this thread and
  join the *thread* with a timeout instead.
  """

  def __init__(self, queue):
    """Remembers the queue that run() will join.

    Args:
      queue: A Queue.Queue or ReQueue instance (checked with an assert).
    """
    threading.Thread.__init__(self)
    accepted_types = (Queue.Queue, ReQueue)
    assert isinstance(queue, accepted_types)
    self.queue = queue

  def run(self):
    """Blocks in queue.join() on this thread, then lets the thread exit."""
    target = self.queue
    target.join()


def InterruptibleQueueJoin(queue,
                           thread_local,
                           thread_gate,
                           queue_join_thread_factory=QueueJoinThread):
  """Joins a ReQueue or Queue.Queue while staying responsive to shutdowns.

  The blocking join runs in a helper thread built by
  queue_join_thread_factory. This function polls that helper every half
  second. Between polls it checks thread_local.shut_down and the liveness of
  every worker registered with thread_gate.

  Args:
    queue: A Queue.Queue or ReQueue instance.
    thread_local: A threading.local instance which indicates interrupts.
    thread_gate: A ThreadGate instance.
    queue_join_thread_factory: Used for dependency injection.

  Returns:
    True if the queue join completed. False if it was abandoned because
    thread_local.shut_down was set or a worker thread is no longer alive;
    in that case the helper thread is left running.
  """
  joiner = queue_join_thread_factory(queue)
  joiner.start()
  while True:
    joiner.join(timeout=.5)
    if not joiner.isAlive():
      return True
    if thread_local.shut_down:
      logging.debug('Queue join interrupted')
      return False
    if any(not worker.isAlive() for worker in thread_gate.Threads()):
      return False


def ShutdownThreads(data_source_thread, work_queue, thread_gate):
  """Asks the worker and data source threads to stop and drains work_queue.

  Used when an upload must be abandoned, either because a shutdown was
  requested or because a thread died.

  Args:
    data_source_thread: A running DataSourceThread instance.
    work_queue: The work queue.
    thread_gate: A ThreadGate instance with workers registered.
  """
  logging.info('An error occurred. Shutting down...')

  # Flag every thread to stop at its next check of exit_flag.
  data_source_thread.exit_flag = True
  for worker in thread_gate.Threads():
    worker.exit_flag = True

  # One EnableThread() per worker, presumably so that no worker stays
  # blocked on the gate and misses its exit_flag.
  for _ in thread_gate.Threads():
    thread_gate.EnableThread()

  data_source_thread.join(timeout=3.0)
  if data_source_thread.isAlive():
    logging.warn('%s hung while trying to exit',
                 data_source_thread.GetFriendlyName())

  # Discard queued work, marking each item done so queue joins can finish.
  while not work_queue.empty():
    try:
      work_queue.get_nowait()
      work_queue.task_done()
    except Queue.Empty:
      continue


def PerformBulkUpload(app_id,
                      post_url,
                      kind,
                      workitem_generator_factory,
                      num_threads,
                      throttle,
                      progress_db,
                      max_queue_size=DEFAULT_QUEUE_SIZE,
                      request_manager_factory=RequestManager,
                      bulkloaderthread_factory=BulkLoaderThread,
                      progresstrackerthread_factory=ProgressTrackerThread,
                      datasourcethread_factory=DataSourceThread,
                      work_queue_factory=ReQueue,
                      progress_queue_factory=Queue.Queue):
  """Uploads data into an application using a series of HTTP POSTs.

  This function will spin up a number of threads to read entities from
  the data source, pass those to a number of worker ("uploader") threads
  for sending to the application, and track all of the progress in a
  small database in case an error or pause/termination requires a
  restart/resumption of the upload process.

  Side effects: installs a SIGINT handler that requests a graceful
  shutdown (the previous handler is not restored before returning), and
  logs a transfer summary at INFO level when finished.

  Args:
    app_id: String containing application id.
    post_url: URL to post the Entity data to.
    kind: Kind of the Entity records being posted.
    workitem_generator_factory: A factory that creates a WorkItem generator.
    num_threads: How many uploader threads should be created.
    throttle: A Throttle instance.
    progress_db: The database to use for replaying/recording progress.
    max_queue_size: Maximum size of the queues before they should block.
    request_manager_factory: Used for dependency injection.
    bulkloaderthread_factory: Used for dependency injection.
    progresstrackerthread_factory: Used for dependency injection.
    datasourcethread_factory: Used for dependency injection.
    work_queue_factory: Used for dependency injection.
    progress_queue_factory: Used for dependency injection.

  Raises:
    AuthenticationError: If authentication is required and fails.
    Errors recorded by worker, progress or data source threads may also
    surface through their CheckError() calls below (CheckError is defined
    elsewhere -- confirm it re-raises).
  """
  thread_gate = ThreadGate(True)

  # Only the host[:port] and path of post_url are used.
  (unused_scheme,
   host_port, url_path,
   unused_query, unused_fragment) = urlparse.urlsplit(post_url)

  work_queue = work_queue_factory(max_queue_size)
  progress_queue = progress_queue_factory(max_queue_size)
  request_manager = request_manager_factory(app_id,
                                            host_port,
                                            url_path,
                                            kind,
                                            throttle)

  # The main thread is registered with the throttle too, presumably because
  # Authenticate() below sends throttled requests from this thread.
  throttle.Register(threading.currentThread())
  try:
    request_manager.Authenticate()
  except Exception, e:
    # Any failure while authenticating is logged and reported uniformly.
    logging.exception(e)
    raise AuthenticationError('Authentication failed')
  # Credentials were present but did not produce an authenticated session.
  if (request_manager.credentials is not None and
      not request_manager.authenticated):
    raise AuthenticationError('Authentication failed')

  # Create (but do not yet start) the uploader threads.
  for unused_idx in range(num_threads):
    thread = bulkloaderthread_factory(work_queue,
                                      throttle,
                                      thread_gate,
                                      request_manager)
    throttle.Register(thread)
    thread_gate.Register(thread)

  progress_thread = progresstrackerthread_factory(progress_queue, progress_db)

  # When resuming, the data source replays unfinished work recorded in the
  # progress database.
  if progress_db.HasUnfinishedWork():
    logging.debug('Restarting upload using progress database')
    progress_generator_factory = progress_db.GetProgressStatusGenerator
  else:
    progress_generator_factory = None

  data_source_thread = datasourcethread_factory(work_queue,
                                                progress_queue,
                                                workitem_generator_factory,
                                                progress_generator_factory)

  # shut_down is set by the SIGINT handler below or by the monitoring loop
  # when a thread dies unexpectedly.
  thread_local = threading.local()
  thread_local.shut_down = False

  def Interrupt(unused_signum, unused_frame):
    """Shutdown gracefully in response to a signal."""
    thread_local.shut_down = True

  # NOTE(review): the previous SIGINT handler is never restored.
  signal.signal(signal.SIGINT, Interrupt)

  progress_thread.start()
  data_source_thread.start()
  for thread in thread_gate.Threads():
    thread.start()


  # Wait for the data source to finish reading. Meanwhile, the death of any
  # worker or of the progress thread is treated as fatal.
  while not thread_local.shut_down:
    data_source_thread.join(timeout=0.25)

    if data_source_thread.isAlive():
      for thread in list(thread_gate.Threads()) + [progress_thread]:
        if not thread.isAlive():
          logging.info('Unexpected thread death: %s', thread.getName())
          thread_local.shut_down = True
          break
    else:
      break

  if thread_local.shut_down:
    ShutdownThreads(data_source_thread, work_queue, thread_gate)

  def _Join(ob, msg):
    """Waits for ob, logging progress at DEBUG level.

    A thread is joined with a 3 second timeout. A Queue.Queue or ReQueue is
    joined via InterruptibleQueueJoin, and ShutdownThreads is invoked if that
    join is abandoned. Anything else is joined without a timeout.
    """
    logging.debug('Waiting for %s...', msg)
    if isinstance(ob, threading.Thread):
      ob.join(timeout=3.0)
      if ob.isAlive():
        # GetFriendlyName is presumably provided by _ThreadBase.
        logging.debug('Joining %s failed', ob.GetFriendlyName())
      else:
        logging.debug('... done.')
    elif isinstance(ob, (Queue.Queue, ReQueue)):
      if not InterruptibleQueueJoin(ob, thread_local, thread_gate):
        ShutdownThreads(data_source_thread, work_queue, thread_gate)
    else:
      ob.join()
      logging.debug('... done.')

  _Join(work_queue, 'work_queue to flush')

  # One exit sentinel per worker.
  for unused_thread in thread_gate.Threads():
    work_queue.put(_THREAD_SHOULD_EXIT)

  # One EnableThread() per worker, presumably so gated workers can run and
  # pick up their sentinel.
  for unused_thread in thread_gate.Threads():
    thread_gate.EnableThread()

  for thread in thread_gate.Threads():
    _Join(thread, 'thread [%s] to terminate' % thread.getName())

    thread.CheckError()

  if progress_thread.isAlive():
    _Join(progress_queue, 'progress_queue to finish')
  else:
    logging.warn('Progress thread exited prematurely')

  # NOTE(review): if the progress thread has already died and progress_queue
  # is full, this blocking put may never return -- confirm.
  progress_queue.put(_THREAD_SHOULD_EXIT)
  _Join(progress_thread, 'progress_thread to terminate')
  progress_thread.CheckError()

  data_source_thread.CheckError()

  # Total bytes sent over both the plain and the HTTPS bandwidth throttles;
  # the reported duration is the one for BANDWIDTH_UP.
  total_up, duration = throttle.TotalTransferred(BANDWIDTH_UP)
  s_total_up, unused_duration = throttle.TotalTransferred(HTTPS_BANDWIDTH_UP)
  total_up += s_total_up
  # NOTE(review): 'entites' in the message below is a typo for 'entities'.
  logging.info('%d entites read, %d previously transferred',
               data_source_thread.read_count,
               data_source_thread.sent_count)
  logging.info('%d entities (%d bytes) transferred in %.1f seconds',
               progress_thread.entities_sent, total_up, duration)
  # Success means the source was fully read and everything read was sent,
  # counting entities already sent in a previous (resumed) run.
  if (data_source_thread.read_all and
      progress_thread.entities_sent + data_source_thread.sent_count >=
      data_source_thread.read_count):
    logging.info('All entities successfully uploaded')
  else:
    logging.info('Some entities not successfully uploaded')


def PrintUsageExit(code):
  """Prints the module usage text to stdout and exits.

  Args:
    code: Status code to pass to sys.exit() after displaying usage information.
  """
  usage = __doc__ % {'arg0': sys.argv[0]}
  print(usage)
  for stream in (sys.stdout, sys.stderr):
    stream.flush()
  sys.exit(code)


def ParseArguments(argv):
  """Parses command-line arguments.

  Prints out a help message and exits with status 0 if -h or --help is
  supplied. Prints an error followed by the help message and exits with
  status 1 if a flag is unrecognized or malformed, or if a numeric flag has
  a value that is not an integer.

  Args:
    argv: List of command-line arguments; argv[0] (the program name) is
      skipped.

  Returns:
    The tuple produced by ProcessArguments: (app_id, url, filename,
    batch_size, kind, num_threads, bandwidth_limit, rps_limit, http_limit,
    db_filename, config_file, auth_domain).
  """
  try:
    opts, unused_args = getopt.getopt(
        argv[1:],
        'h',
        ['debug',
         'help',
         'url=',
         'filename=',
         'batch_size=',
         'kind=',
         'num_threads=',
         'bandwidth_limit=',
         'rps_limit=',
         'http_limit=',
         'db_filename=',
         'app_id=',
         'config_file=',
         'auth_domain=',
        ])
  except getopt.GetoptError:
    # Report the bad flag instead of dying with a traceback. sys.exc_info()
    # keeps this compatible with Pythons that lack "except ... as".
    sys.stderr.write('%s\n' % sys.exc_info()[1])
    PrintUsageExit(1)

  url = None
  filename = None
  batch_size = DEFAULT_BATCH_SIZE
  kind = None
  num_threads = DEFAULT_THREAD_COUNT
  bandwidth_limit = DEFAULT_BANDWIDTH_LIMIT
  rps_limit = DEFAULT_RPS_LIMIT
  http_limit = DEFAULT_REQUEST_LIMIT
  db_filename = None
  app_id = None
  config_file = None
  auth_domain = 'gmail.com'

  for option, value in opts:
    try:
      if option == '--debug':
        logging.getLogger().setLevel(logging.DEBUG)
      elif option in ('-h', '--help'):
        PrintUsageExit(0)
      elif option == '--url':
        url = value
      elif option == '--filename':
        filename = value
      elif option == '--batch_size':
        batch_size = int(value)
      elif option == '--kind':
        kind = value
      elif option == '--num_threads':
        num_threads = int(value)
      elif option == '--bandwidth_limit':
        bandwidth_limit = int(value)
      elif option == '--rps_limit':
        rps_limit = int(value)
      elif option == '--http_limit':
        http_limit = int(value)
      elif option == '--db_filename':
        db_filename = value
      elif option == '--app_id':
        app_id = value
      elif option == '--config_file':
        config_file = value
      elif option == '--auth_domain':
        auth_domain = value
    except ValueError:
      # Only int() above can raise ValueError: a numeric flag got a
      # non-integer value.
      sys.stderr.write('Invalid value for %s: %r\n' % (option, value))
      PrintUsageExit(1)

  return ProcessArguments(app_id=app_id,
                          url=url,
                          filename=filename,
                          batch_size=batch_size,
                          kind=kind,
                          num_threads=num_threads,
                          bandwidth_limit=bandwidth_limit,
                          rps_limit=rps_limit,
                          http_limit=http_limit,
                          db_filename=db_filename,
                          config_file=config_file,
                          auth_domain=auth_domain,
                          die_fn=lambda: PrintUsageExit(1))


def ThrottleLayout(bandwidth_limit, http_limit, rps_limit):
  """Builds the limit table used to construct a Throttle.

  Upload and download bandwidth share bandwidth_limit. The HTTPS bandwidth
  and request limits are one fifth of the corresponding plain limits,
  computed with `/`, so integer inputs use floor division under Python 2.

  Args:
    bandwidth_limit: Bandwidth limit applied to both directions.
    http_limit: Limit on HTTP requests.
    rps_limit: Limit on records.

  Returns:
    A dict mapping each throttle name to its limit.
  """
  https_bandwidth_limit = bandwidth_limit / 5
  https_request_limit = http_limit / 5
  return dict([
      (BANDWIDTH_UP, bandwidth_limit),
      (BANDWIDTH_DOWN, bandwidth_limit),
      (REQUESTS, http_limit),
      (HTTPS_BANDWIDTH_UP, https_bandwidth_limit),
      (HTTPS_BANDWIDTH_DOWN, https_bandwidth_limit),
      (HTTPS_REQUESTS, https_request_limit),
      (RECORDS, rps_limit),
  ])


def LoadConfig(config_file):
  """Executes config_file and registers the Loader subclasses it defines.

  The file runs in a copy of this module's globals, so it can refer to
  names such as Loader directly. Every direct subclass of Loader is then
  instantiated with no arguments and registered. Note that only direct
  subclasses are found.

  The config file is executed as arbitrary Python code, so it must come
  from a trusted source. A falsy config_file is a no-op.
  """
  if not config_file:
    return
  namespace = dict(globals())
  execfile(config_file, namespace)
  for loader_class in Loader.__subclasses__():
    Loader.RegisterLoader(loader_class())


def _MissingArgument(arg_name, die_fn):
  """Reports a missing required argument on stderr, then calls die_fn."""
  message = '%s argument required' % arg_name
  sys.stderr.write(message + '\n')
  die_fn()


def ProcessArguments(app_id=None,
                     url=None,
                     filename=None,
                     batch_size=DEFAULT_BATCH_SIZE,
                     kind=None,
                     num_threads=DEFAULT_THREAD_COUNT,
                     bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT,
                     rps_limit=DEFAULT_RPS_LIMIT,
                     http_limit=DEFAULT_REQUEST_LIMIT,
                     db_filename=None,
                     config_file=None,
                     auth_domain='gmail.com',
                     die_fn=lambda: sys.exit(1)):
  """Validates bulkloader options and fills in derived defaults.

  Shared by the command-line path (ParseArguments) and by Run.

  Behavior:
    - A missing db_filename becomes a timestamped progress-file name.
    - batch_size must be at least 1.
    - url, filename, kind and config_file are required.
    - A missing app_id is derived from the url's host: the text before
      '.appspot.com', or the first dotted label when the host (port
      stripped) ends with 'google.com'. Any other host is an error.

  Each failed check writes a message to stderr and calls die_fn.

  Returns:
    Tuple (app_id, url, filename, batch_size, kind, num_threads,
    bandwidth_limit, rps_limit, http_limit, db_filename, config_file,
    auth_domain).
  """
  if db_filename is None:
    db_filename = time.strftime('bulkloader-progress-%Y%m%d.%H%M%S.sql3')

  if batch_size <= 0:
    sys.stderr.write('batch_size must be 1 or larger\n')
    die_fn()

  required = (('url', url),
              ('filename', filename),
              ('kind', kind),
              ('config_file', config_file))
  for arg_name, arg_value in required:
    if arg_value is None:
      _MissingArgument(arg_name, die_fn)

  if app_id is None:
    host_port = urlparse.urlsplit(url)[1]
    appspot_idx = host_port.find('.appspot.com')
    if appspot_idx > -1:
      app_id = host_port[:appspot_idx]
    elif host_port.split(':')[0].endswith('google.com'):
      app_id = host_port.split('.')[0]
    else:
      sys.stderr.write('app_id required for non appspot.com domains\n')
      die_fn()

  return (app_id, url, filename, batch_size, kind, num_threads,
          bandwidth_limit, rps_limit, http_limit, db_filename, config_file,
          auth_domain)


def _PerformBulkload(app_id=None,
                     url=None,
                     filename=None,
                     batch_size=DEFAULT_BATCH_SIZE,
                     kind=None,
                     num_threads=DEFAULT_THREAD_COUNT,
                     bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT,
                     rps_limit=DEFAULT_RPS_LIMIT,
                     http_limit=DEFAULT_REQUEST_LIMIT,
                     db_filename=None,
                     config_file=None,
                     auth_domain='gmail.com'):
  """Runs the bulkloader, given already-validated options.

  Sets AUTH_DOMAIN in the environment, loads the config file, builds the
  throttle, CSV work-item factory and progress database, then performs the
  upload. A db_filename of 'skip' uses a StubProgressDatabase instead of a
  real progress file.

  Args:
    app_id: The application id.
    url: The url of the remote_api endpoint.
    filename: The name of the file containing the CSV data.
    batch_size: The number of records to send per request.
    kind: The kind of entity to transfer.
    num_threads: The number of threads to use to transfer data.
    bandwidth_limit: Maximum bytes/second to transfers.
    rps_limit: Maximum records/second to transfer.
    http_limit: Maximum requests/second for transfers.
    db_filename: The name of the SQLite3 progress database file.
    config_file: The name of the configuration file.
    auth_domain: The auth domain to use for logins and UserProperty.

  Returns:
    An exit code (always 0; failures surface as exceptions).
  """
  os.environ['AUTH_DOMAIN'] = auth_domain
  LoadConfig(config_file)

  throttle = Throttle(
      layout=ThrottleLayout(bandwidth_limit, http_limit, rps_limit))

  workitem_generator_factory = GetCSVGeneratorFactory(filename, batch_size)

  if db_filename != 'skip':
    progress_db = ProgressDatabase(db_filename)
  else:
    progress_db = StubProgressDatabase()

  # Queues must hold at least DEFAULT_QUEUE_SIZE items, and more when many
  # threads are used.
  queue_size = max(DEFAULT_QUEUE_SIZE, 2 * num_threads + 5)

  PerformBulkUpload(app_id,
                    url,
                    kind,
                    workitem_generator_factory,
                    num_threads,
                    throttle,
                    progress_db,
                    max_queue_size=queue_size)

  return 0


def Run(app_id=None,
        url=None,
        filename=None,
        batch_size=DEFAULT_BATCH_SIZE,
        kind=None,
        num_threads=DEFAULT_THREAD_COUNT,
        bandwidth_limit=DEFAULT_BANDWIDTH_LIMIT,
        rps_limit=DEFAULT_RPS_LIMIT,
        http_limit=DEFAULT_REQUEST_LIMIT,
        db_filename=None,
        auth_domain='gmail.com',
        config_file=None):
  """Sets up and runs the bulkloader, given the options as keyword arguments.

  Args:
    app_id: The application id.
    url: The url of the remote_api endpoint.
    filename: The name of the file containing the CSV data.
    batch_size: The number of records to send per request.
    kind: The kind of entity to transfer.
    num_threads: The number of threads to use to transfer data.
    bandwidth_limit: Maximum bytes/second to transfers.
    rps_limit: Maximum records/second to transfer.
    http_limit: Maximum requests/second for transfers.
    db_filename: The name of the SQLite3 progress database file.
    auth_domain: The auth domain to use for logins and UserProperty.
    config_file: The name of the configuration file.

  Returns:
    An exit code.
  """
  logging.basicConfig(
      format='%(levelname)-8s %(asctime)s %(filename)s] %(message)s')
  # auth_domain must be forwarded: ProcessArguments echoes it back in the
  # tuple unpacked below, so omitting it silently reset the caller's value
  # to the 'gmail.com' default.
  args = ProcessArguments(app_id=app_id,
                          url=url,
                          filename=filename,
                          batch_size=batch_size,
                          kind=kind,
                          num_threads=num_threads,
                          bandwidth_limit=bandwidth_limit,
                          rps_limit=rps_limit,
                          http_limit=http_limit,
                          db_filename=db_filename,
                          config_file=config_file,
                          auth_domain=auth_domain)

  (app_id, url, filename, batch_size, kind, num_threads, bandwidth_limit,
   rps_limit, http_limit, db_filename, config_file, auth_domain) = args

  return _PerformBulkload(app_id=app_id,
                          url=url,
                          filename=filename,
                          batch_size=batch_size,
                          kind=kind,
                          num_threads=num_threads,
                          bandwidth_limit=bandwidth_limit,
                          rps_limit=rps_limit,
                          http_limit=http_limit,
                          db_filename=db_filename,
                          config_file=config_file,
                          auth_domain=auth_domain)


def main(argv):
  """Command-line entry point: parses argv and runs the bulk upload.

  Args:
    argv: Full argument vector, including the program name.

  Returns:
    The exit code from _PerformBulkload.
  """
  log_format = '%(levelname)-8s %(asctime)s %(filename)s] %(message)s'
  logging.basicConfig(level=logging.INFO, format=log_format)

  parsed = ParseArguments(argv)
  if None in parsed:
    sys.stderr.write('Invalid arguments\n')
    PrintUsageExit(1)

  (app_id, url, filename, batch_size, kind, num_threads,
   bandwidth_limit, rps_limit, http_limit, db_filename, config_file,
   auth_domain) = parsed

  options = dict(app_id=app_id,
                 url=url,
                 filename=filename,
                 batch_size=batch_size,
                 kind=kind,
                 num_threads=num_threads,
                 bandwidth_limit=bandwidth_limit,
                 rps_limit=rps_limit,
                 http_limit=http_limit,
                 db_filename=db_filename,
                 config_file=config_file,
                 auth_domain=auth_domain)
  return _PerformBulkload(**options)


if __name__ == '__main__':
  # Propagate main()'s return value as the process exit status.
  exit_status = main(sys.argv)
  sys.exit(exit_status)