thirdparty/google_appengine/google/appengine/ext/bulkload/__init__.py
changeset 297 35211afcd563
parent 149 f2e327a7c5de
child 686 df109be0567c
--- a/thirdparty/google_appengine/google/appengine/ext/bulkload/__init__.py	Fri Oct 10 06:56:56 2008 +0000
+++ b/thirdparty/google_appengine/google/appengine/ext/bulkload/__init__.py	Fri Oct 10 13:14:24 2008 +0000
@@ -28,6 +28,7 @@
   'Person',
   [('name', str),
    ('email', datastore_types.Email),
+   ('cool', bool), # ('0', 'False', 'No', '')=False, otherwise bool(value)
    ('birthdate', lambda x: datetime.datetime.fromtimestamp(float(x))),
   ])
 
@@ -108,7 +109,7 @@
 import traceback
 import types
 import struct
-
+import zlib
 
 import google
 import wsgiref.handlers
@@ -227,6 +228,8 @@
 
     entity = datastore.Entity(self.__kind, name=key_name)
     for (name, converter), val in zip(self.__properties, values):
+      if converter is bool and val.lower() in ('0', 'false', 'no'):
+          val = False
       entity[name] = converter(val)
 
     entities = self.HandleEntity(entity)
@@ -341,9 +344,51 @@
     page += '</body></html>'
     return page
 
+  def IterRows(self, reader):
+    """Yields a (line number, row) tuple per CSV row, numbering from 1.
+
+    Args:
+      reader: a csv reader for the input data.
+    """
+    row_index = 0
+    for row in reader:
+      row_index += 1
+      yield (row_index, row)
+
+  def LoadEntities(self, iter, loader, key_format=None):
+    """Creates and stores entities; returns a tuple of HTTP code and reply.
+
+    Args:
+      iter: an iterator yielding pairs of a line number and row contents.
+      loader: its CreateEntity(columns, key_name=...) is called for each
+        non-empty row and returns the entities to store, if any.
+      key_format: a format string to convert a line number into an
+        entity id. If None, then entity ID's are automatically generated.
+    """
+    pending, log = [], []
+    for row_id, row in iter:
+      name = None
+      if key_format is not None:
+        name = key_format % row_id
+      if not row:
+        continue
+      try:
+        log.append('\nLoading from line %d...' % row_id)
+        created = loader.CreateEntity(row, key_name=name)
+        if created:
+          pending.extend(created)
+        log.append('done.')
+      except:
+        log.append('error:\n%s' % traceback.format_exc())
+        return (httplib.BAD_REQUEST, ''.join(log))
+    # Nothing is stored until every row has converted without an error.
+    for entity in pending:
+      datastore.Put(entity)
+
+    return (httplib.OK, ''.join(log))
 
   def Load(self, kind, data):
-    """ Parses CSV data, uses a Loader to convert to entities, and stores them.
+    """Parses CSV data, uses a Loader to convert to entities, and stores them.
 
     On error, fails fast. Returns a "bad request" HTTP response code and
     includes the traceback in the output.
@@ -375,28 +420,34 @@
     except AttributeError:
       pass
 
-    entities = []
+    return self.LoadEntities(self.IterRows(reader), loader)
+
+  def IterRowsV1(self, data):
+    """Yields an (id number, columns) pair for each row in the uploaded data.
+
+    Args:
+      data: a string containing the unzipped v1 format data to load. It
+        starts with a big-endian int32 column count, followed by the rows.
+    """
+    column_count, = struct.unpack_from('!i', data)
+    offset = 4
+    # Each row: an int32 id, one int32 length per column, then the values.
+    lengths_format = '!%di' % (column_count,)
 
-    line_num = 1
-    for columns in reader:
-      if columns:
-        try:
-          output.append('\nLoading from line %d...' % line_num)
-          new_entities = loader.CreateEntity(columns)
-          if new_entities:
-            entities.extend(new_entities)
-          output.append('done.')
-        except:
-          stacktrace = traceback.format_exc()
-          output.append('error:\n%s' % stacktrace)
-          return (httplib.BAD_REQUEST, ''.join(output))
+    while offset < len(data):
+      id_num, = struct.unpack_from('!i', data, offset=offset)
+      offset += 4
+
+      value_lengths = struct.unpack_from(lengths_format, data, offset=offset)
+      offset += 4 * column_count
 
-      line_num += 1
+      columns = struct.unpack_from(''.join('%ds' % length
+                                           for length in value_lengths), data,
+                                   offset=offset)
+      offset += sum(value_lengths)
 
-    for entity in entities:
-      datastore.Put(entity)
+      yield (id_num, columns)
 
-    return (httplib.OK, ''.join(output))
 
   def LoadV1(self, kind, data):
     """Parses version-1 format data, converts to entities, and stores them.
@@ -421,46 +472,19 @@
       loader = Loader.RegisteredLoaders()[kind]
     except KeyError:
       output.append('Error: no Loader defined for kind %s.' % kind)
-      return httplib.BAD_REQUEST, ''.join(output)
-
-    entities = []
-
-    column_count, = struct.unpack_from('!i', data)
-
-    offset = 4
-
-    lengths_format = '!%di' % (column_count,)
-
-    while offset < len(data):
-      id_num = struct.unpack_from('!i', data, offset=offset)
-      offset += 4
-
-      key_name = 'i%010d' % id_num
-
-      value_lengths = struct.unpack_from(lengths_format, data, offset=offset)
-      offset += 4 * column_count
+      return (httplib.BAD_REQUEST, ''.join(output))
 
-      columns = struct.unpack_from(''.join('%ds' % length
-                                           for length in value_lengths), data,
-                                   offset=offset)
-      offset += sum(value_lengths)
+    try:
+      data = zlib.decompress(data)
+    except:
+      stacktrace = traceback.format_exc()
+      output.append('Error: Could not decompress data\n%s' % stacktrace)
+      return (httplib.BAD_REQUEST, ''.join(output))
 
-      try:
-        output.append('Loading key_name=%s... ' % key_name)
-        new_entities = loader.CreateEntity(columns, key_name=key_name)
-        if new_entities:
-          entities.extend(new_entities)
-        output.append('done.\n')
-      except:
-        stacktrace = traceback.format_exc()
-        output.append('error:\n%s' % stacktrace)
-        return httplib.BAD_REQUEST, ''.join(output)
-
-    for entity in entities:
-      datastore.Put(entity)
-
-    return httplib.OK, ''.join(output)
-
+    key_format = 'i%010d'
+    return self.LoadEntities(self.IterRowsV1(data),
+                             loader,
+                             key_format=key_format)
 
 def main(*loaders):
   """Starts bulk upload.