app/soc/logic/models/base.py
author Sverre Rabbelier <srabbelier@gmail.com>
Tue, 22 Sep 2009 13:07:27 +0200
changeset 2968 7ba28890eb75
parent 2897 c0e78185444c
child 2980 cbfd8e12527a
permissions -rw-r--r--
Prevent modification of key fields. Only non-id-based entities are affected, as id-based entities do not have any unmodifiable fields.

#!/usr/bin/python2.5
#
# Copyright 2008 the Melange authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers functions for updating different kinds of models in datastore.
"""

__authors__ = [
  '"Madhusudan C.S." <madhusudancs@gmail.com>',
  '"Todd Larsen" <tlarsen@google.com>',
  '"Sverre Rabbelier" <sverre@rabbelier.nl>',
  '"Lennard de Rijk" <ljvderijk@gmail.com>',
  '"Pawel Solyga" <pawel.solyga@gmail.com>',
  ]


import logging

from google.appengine.ext import db

from django.utils.translation import ugettext

from soc.cache import sidebar
from soc.logic import dicts
from soc.views import out_of_band


class Error(Exception):
  """Root of the exception hierarchy used by this module.

  Catching this type handles every exception class defined here.
  """


class InvalidArgumentError(Error):
  """Signals that a method received an argument it cannot work with.

  A typical case is an argument that is None (or otherwise false) where
  a non-false value is required.
  """


class NoEntityError(InvalidArgumentError):
  """Signals that a method that requires an entity did not receive one.
  """


class Logic(object):
  """Base logic for entity classes.

  The Logic class is made specific to an Entity class by the arguments
  passed to __init__.
  """

  def __init__(self, model, base_model=None, scope_logic=None,
               name=None, skip_properties=None, id_based=False):
    """Defines the name, key_name and model for this entity.
    """

    self._model = model
    self._base_model = base_model
    self._scope_logic = scope_logic
    self._id_based = id_based

    if name:
      self._name = name
    else:
      self._name =  self._model.__name__

    if skip_properties:
      self._skip_properties = skip_properties
    else:
      self._skip_properties = []

  def skipField(self, name):
    """Returns whether a field with the specified name should be saved.
    """

    if name in self._skip_properties:
      return True

    if self._id_based:
      return False

    if name in self.getKeyFieldNames():
      return True

    return False

  def getModel(self):
    """Returns the model this logic class uses.
    """

    return self._model

  def getScopeLogic(self):
    """Returns the logic of the enclosing scope.
    """

    return self._scope_logic

  def getScopeDepth(self):
    """Returns the scope depth for this entity.

    Returns None if any of the parent scopes return None.
    """

    if not self._scope_logic:
      return 0

    depth = self._scope_logic.logic.getScopeDepth()
    return None if (depth is None) else (depth + 1)

  def getKeyNameFromFields(self, fields):
    """Returns the KeyName constructed from fields dict for this type of entity.

    The KeyName is in the following format:
    <key_value1>/<key_value2>/.../<key_valueN>
    """

    if not fields:
      raise InvalidArgumentError

    key_field_names = self.getKeyFieldNames()

    # check if all key_field_names for this entity are present in fields
    if not all(field in fields.keys() for field in key_field_names):
      raise InvalidArgumentError("Not all the required key fields are present")

    if not all(fields.get(field) for field in key_field_names):
      raise InvalidArgumentError("Not all KeyValues are non-false")

    # construct the KeyValues in the order given by getKeyFieldNames()
    keyvalues = []
    for key_field_name in key_field_names:
      keyvalues.append(fields[key_field_name])

    # construct the KeyName in the appropriate format
    return '/'.join(keyvalues)

  def getFullModelClassName(self):
    """Returns fully-qualified model module.class name string.
    """

    return '%s.%s' % (self._model.__module__, self._model.__name__)

  def getKeyValuesFromEntity(self, entity):
    """Extracts the key values from entity and returns them.

    The default implementation uses the scope and link_id as key values.

    Args:
      entity: the entity from which to extract the key values
    """

    if not entity:
      raise NoEntityError

    return [entity.scope_path, entity.link_id]

  def getKeyValuesFromFields(self, fields):
    """Returns the values of the key fields found in the specified dict.

    This default implementation returns the values stored under
    scope_path and link_id, in that order.

    Args:
      fields: the dict from which to extract the key values

    Raises:
      InvalidArgumentError: if scope_path or link_id is not in fields
    """

    required = ('scope_path', 'link_id')

    if not all(name in fields for name in required):
      raise InvalidArgumentError

    return [fields[name] for name in required]

  def getKeyFieldNames(self):
    """Returns a new list with the names of the key fields.

    This default implementation names scope_path and link_id, in that
    order.
    """

    key_field_names = ['scope_path', 'link_id']
    return key_field_names

  def getKeyFieldsFromFields(self, dictionary):
    """Reduces dictionary to the key fields of this type of entity.

    The key field names from getKeyFieldNames() are combined with the
    values from getKeyValuesFromFields() using dicts.zip, so any
    translation performed by those methods ends up in the result.

    Args:
      dictionary: The arguments to massage

    Raises:
      InvalidArgumentError: if dictionary is false
    """

    if not dictionary:
      raise InvalidArgumentError

    return dicts.zip(self.getKeyFieldNames(),
                     self.getKeyValuesFromFields(dictionary))

  def getFromKeyName(self, key_name, parent=None):
    """Returns entity for key_name or None if not found.

    Args:
      key_name: key name of entity
      parent: parent of the entity
    """

    if self._id_based:
      raise Error("getFromKeyName called on an id based logic")

    if not key_name:
      raise InvalidArgumentError

    return self._model.get_by_key_name(key_name, parent=parent)

  def getFromID(self, id, parent=None):
    """Returns entity for id or None if not found.

    Args:
      id: id of entity
      parent: parent of the entity
    """

    if not self._id_based:
      raise Error("getFromID called on a not id based logic")

    if not id:
      raise InvalidArgumentError

    return self._model.get_by_id(id, parent=parent)

  def getFromKeyNameOr404(self, key_name):
    """Like getFromKeyName but expects to find an entity.

    Raises:
      out_of_band.Error if no entity is found
    """

    entity = self.getFromKeyName(key_name)

    if entity:
      return entity

    msg = ugettext('There is no "%(name)s" named %(key_name)s.') % {
        'name': self._name, 'key_name': key_name}

    raise out_of_band.Error(msg, status=404)

  def getFromIDOr404(self, id):
    """Like getFromID but expects to find an entity.

    Raises:
      out_of_band.Error if no entity is found
    """

    entity = self.getFromID(id)

    if entity:
      return entity

    msg = ugettext('There is no "%(name)s" with id %(id)s.') % {
        'name': self._name, 'id': id}

    raise out_of_band.Error(msg, status=404)

  def getFromKeyFields(self, fields):
    """Looks up the entity that is identified by the key fields in fields.

    Args:
      fields: a dict containing the fields of the entity that
        uniquely identifies it

    Returns:
      the result of getFromKeyName for the key name built from the key
      fields, or None if any of the key field values is false

    Raises:
      InvalidArgumentError: if fields is false
    """

    if not fields:
      raise InvalidArgumentError

    key_fields = self.getKeyFieldsFromFields(fields)

    # without a complete set of key values there is nothing to look up
    if not all(key_fields.values()):
      return None

    return self.getFromKeyName(self.getKeyNameFromFields(key_fields))

  def getFromKeyFieldsOr404(self, fields):
    """Returns the entity for the key fields, or raises if there is none.

    Args:
      fields: a dict containing the fields of the entity that
        uniquely identifies it

    Raises:
      out_of_band.Error with status 404 if no entity is found
    """

    entity = self.getFromKeyFields(fields)

    if not entity:
      key_fields = self.getKeyFieldsFromFields(fields)
      format_text = ugettext('"%(key)s" is "%(value)s"')

      # describe every key field that was used for the lookup
      msg_pairs = []
      for key, value in key_fields.iteritems():
        msg_pairs.append(format_text % {'key': key, 'value': value})

      msg = ugettext(
        'There is no "%(name)s" where %(pairs)s.') % {
          'name': self._name, 'pairs': ' and '.join(msg_pairs)}

      raise out_of_band.Error(msg, status=404)

    return entity

  def prefetchField(self, field, data):
    """Prefetches all fields in data from the datastore in one fetch.

    Args:
      field: the field that should be fetched
      data: the data for which the prefetch should be done
    """

    prop = getattr(self._model, field, None)

    if not prop:
      logging.exception("Model %s does not have attribute %s" %
                        (self._model, field))
      return

    if not isinstance(prop, db.ReferenceProperty):
      logging.exception("Property %s of %s is not a ReferenceProperty but a %s" %
                        (field, self._model.kind(), prop.__class__.__name__))
      return

    keys = [prop.get_value_for_datastore(i) for i in data]

    prefetched_entities = db.get(keys)
    prefetched_dict = dict((i.key(), i) for i in prefetched_entities)

    for i in data:
      value = prefetched_dict[prop.get_value_for_datastore(i)]
      setattr(i, field, value)

  def getForFields(self, filter=None, unique=False, limit=1000, offset=0,
                   ancestors=None, order=None, prefetch=None):
    """Returns all entities that have the specified properties.

    Args:
      filter: a dict for the properties that the entities should have
      unique: if set, only the first item from the resultset will be returned
      limit: the amount of entities to fetch at most
      offset: the position to start at
      ancestors: list of ancestors properties to set for this query
      order: a list with the sort order
      prefetch: the fields of the data that should be pre-fetched,
          has no effect if unique is True
    """

    if unique:
      limit = 1

    if not prefetch:
      prefetch = []

    query = self.getQueryForFields(filter=filter, 
                                   ancestors=ancestors, order=order)

    try:
      result = query.fetch(limit, offset)
    except db.NeedIndexError, exception:
      result = []
      logging.exception("%s, model: %s filter: %s, ancestors: %s, order: %s" % 
                        (exception, self._model, filter, ancestors, order))
      # TODO: send email

    if unique:
      return result[0] if result else None

    for field in prefetch:
      self.prefetchField(field, result)

    return result

  def getQueryForFields(self, filter=None, ancestors=None, order=None):
    """Returns a query with the specified properties.

    Args:
      filter: a dict for the properties that the entities should have;
        a list value is turned into an IN filter, unless the list has
        exactly one element, then that element is filtered on directly
      ancestors: list with ancestor properties to set for this query
      order: a list with the sort order, a property may be listed only
        once regardless of any '-' attached to it

    Returns:
      - Query object instantiated with the given properties

    Raises:
      InvalidArgumentError: if order lists the same property more than once
    """

    if not filter:
      filter = {}

    if not ancestors:
      ancestors = []

    if not order:
      order = []

    # 'name' and '-name' refer to the same property, so they collapse into
    # a single element of the set when a property is listed twice
    orderset = set(i.strip('-') for i in order)
    if len(orderset) != len(order):
      raise InvalidArgumentError

    query = db.Query(self._model)

    for key, value in filter.iteritems():
      # a list with a single element is the same as that element
      if isinstance(value, list) and len(value) == 1:
        value = value[0]
      if isinstance(value, list):
        op = '%s IN' % key
        query.filter(op, value)
      else:
        query.filter(key, value)

    for ancestor in ancestors:
      query.ancestor(ancestor)

    for key in order:
      query.order(key)

    return query

  def updateEntityProperties(self, entity, entity_properties, silent=False,
                             store=True):
    """Update existing entity using supplied properties.

    Args:
      entity: a model entity
      entity_properties: keyword arguments that correspond to entity
        properties and their values
      silent: iff True does not call _onUpdate method
      store: iff True updated entity is actually stored in the data model

    Returns:
      The original entity with any supplied properties changed.
    """

    if not entity:
      raise NoEntityError

    if not entity_properties:
      raise InvalidArgumentError

    properties = self._model.properties()

    for name, prop in properties.iteritems():
      # if the property is not updateable or is not updated, skip it
      if self.skipField(name) or (name not in entity_properties):
        continue

      if self._updateField(entity, entity_properties, name):
        value = entity_properties[name]
        prop.__set__(entity, value)

    if store:
      entity.put()

    # call the _onUpdate method
    if not silent:
      self._onUpdate(entity)

    return entity

  def updateOrCreateFromKeyName(self, properties, key_name, silent=False):
    """Update existing entity, or create new one with supplied properties.

    Args:
      properties: dict with entity properties and their values
      key_name: the key_name of the entity that uniquely identifies it
      silent: if True, do not run the _onCreate hook

    Returns:
      the entity corresponding to the key_name, with any supplied
      properties changed, or a new entity now associated with the
      supplied key_name and properties.
    """

    entity = self.getFromKeyName(key_name)

    create_entity = not entity

    if create_entity:
      for property_name in properties:
        self._createField(properties, property_name)

      # entity did not exist, so create one in a transaction
      entity = self._model.get_or_insert(key_name, **properties)

      if not silent:
        # a new entity has been created call _onCreate
        self._onCreate(entity)
    else:
      # If someone else already created the entity (due to a race), we
      # should not update the propties (as they 'won' the race).
      entity = self.updateEntityProperties(entity, properties, silent=silent)

    return entity

  def updateOrCreateFromFields(self, properties, silent=False):
    """Update existing entity or creates a new entity with the supplied 
    properties containing the key fields.

    Args:
      properties: dict with entity properties and their values
      silent: if True, do not run the _onCreate hook
    """

    for property_name in properties:
      self._createField(properties, property_name)

    if self._id_based:
      entity = self._model(**properties)
      entity.put()

      if not silent:
        # call the _onCreate hook
        self._onCreate(entity)
    else:
      key_name = self.getKeyNameFromFields(properties)
      entity = self.updateOrCreateFromKeyName(properties, key_name,
                                              silent=silent)

    return entity

  def isDeletable(self, entity):
    """Tells whether the specified entity can be deleted.

    This default implementation allows every entity to be deleted.

    Args:
      entity: an existing entity in datastore
    """

    deletable = True
    return deletable

  def delete(self, entity):
    """Delete existing entity from datastore.

    Args:
      entity: an existing entity in datastore
    """

    entity.delete()
    # entity has been deleted call _onDelete
    self._onDelete(entity)

  def getAll(self, query):
    """Retrieves all entities for the specified query.

    The entities are fetched in chunks of 999, every fetch asks for one
    entity more than the chunk size to find out whether more remain.

    Args:
      query: the query to fetch from, its fetch method is called with a
        limit and an offset

    Returns:
      a list with all fetched entities
    """

    chunk = 999
    entities = []
    offset = 0

    while True:
      batch = query.fetch(chunk + 1, offset)

      if len(batch) <= chunk:
        # this was the last chunk
        entities.extend(batch)
        return entities

      # the surplus entity will be fetched again as part of the next chunk
      batch.pop(chunk)
      entities.extend(batch)
      offset += chunk
  # pylint: disable-msg=C0103
  def entityIterator(self, queryGen, batch_size = 100):
    """Iterator that yields entities one by one, fetching them in batches.

    Every batch is fetched with a new query from queryGen. From the second
    batch on, the query is filtered on keys greater than the key of the
    last entity of the previous batch.

    Args:
      queryGen: should return a Query object, is called without arguments
        once for every batch
      batch_size: how many entities to retrieve in one datastore call,
        anything above 1000 is reduced to 1000

    Retrieved from http://tinyurl.com/d887ll (AppEngine cookbook).
    """

    # AppEngine will not fetch more than 1000 results
    batch_size = min(batch_size, 1000)

    key = None

    while True:
      query = queryGen()
      if key:
        query.filter("__key__ > ", key)
      results = query.fetch(batch_size)

      for result in results:
        yield result

      # an empty batch has no last entity to continue from, and a batch
      # that is not full means that there is nothing left to fetch
      if not results or batch_size > len(results):
        break

      key = results[-1].key()

  def _createField(self, entity_properties, name):
    """Hook called when a field is created.

    To be exact, this method is called for each field (that has a value
    specified) on an entity that is being created.

    Base classes should override if any special actions need to be
    taken when a field is created.

    Args:
      entity_properties: keyword arguments that correspond to entity
        properties and their values
      name: the name of the field to be created
    """

    if not entity_properties or (name not in entity_properties):
      raise InvalidArgumentError

  def _updateField(self, entity, entity_properties, name):
    """Hook called when a field is updated.

    Base classes should override if any special actions need to be
    taken when a field is updated. The field is not updated if the
    method does not return a True value.

    Args:
      entity: the unaltered entity
      entity_properties: keyword arguments that correspond to entity
        properties and their values
      name: the name of the field to be changed
    """

    if not entity:
      raise NoEntityError

    if not entity_properties or (name not in entity_properties):
      raise InvalidArgumentError

    return True

  def _onCreate(self, entity):
    """Hook that is run after an entity has been created.

    This default implementation flushes the sidebar. Classes that
    override this can use it to do any post-creation operations.

    Raises:
      NoEntityError: if entity is false
    """

    if entity:
      sidebar.flush()
      return

    raise NoEntityError

  def _onUpdate(self, entity):
    """Called when an entity has been updated.

    Classes that override this can use it to do any post-update operations.
    """

    if not entity:
      raise NoEntityError

  def _onDelete(self, entity):
    """Called when an entity has been deleted.

    Classes that override this can use it to do any post-deletion operations.
    """

    if not entity:
      raise NoEntityError