--- a/app/soc/logic/models/base.py Thu Feb 26 16:49:06 2009 +0000
+++ b/app/soc/logic/models/base.py Thu Feb 26 16:49:45 2009 +0000
@@ -39,6 +39,21 @@
class Error(Exception):
"""Base class for all exceptions raised by this module.
"""
+
+ pass
+
+class InvalidArgumentError(Error):
+ """Raised when an invalid argument is passed to a method.
+
+ For example, if an argument is None, but must always be non-False.
+ """
+
+ pass
+
+class NoEntityError(InvalidArgumentError):
+ """Raised when no entity is passed to a method that requires one.
+ """
+
pass
@@ -105,31 +120,42 @@
value: the new value
"""
+ if not entity:
+ raise NoEntityError
+
+ if not entity_properties or (name not in entity_properties):
+ raise InvalidArgumentError
+
return True
-
+
def _onCreate(self, entity):
"""Called when an entity has been created.
Classes that override this can use it to do any post-creation operations.
"""
+ if not entity:
+ raise NoEntityError
+
sidebar.flush()
-
+
def _onUpdate(self, entity):
"""Called when an entity has been updated.
-
+
Classes that override this can use it to do any post-update operations.
"""
-
- pass
-
+
+ if not entity:
+ raise NoEntityError
+
def _onDelete(self, entity):
"""Called when an entity has been deleted.
-
+
Classes that override this can use it to do any post-deletion operations.
"""
-
- pass
+
+ if not entity:
+ raise NoEntityError
def getKeyNameFromFields(self, fields):
"""Returns the KeyName constructed from fields dict for this type of entity.
@@ -138,14 +164,17 @@
<key_value1>/<key_value2>/.../<key_valueN>
"""
+ if not fields:
+ raise InvalidArgumentError
+
key_field_names = self.getKeyFieldNames()
# check if all key_field_names for this entity are present in fields
if not all(field in fields.keys() for field in key_field_names):
- raise Error("Not all the required key fields are present")
+ raise InvalidArgumentError("Not all the required key fields are present")
if not all(fields.get(field) for field in key_field_names):
- raise Error("Not all KeyValues are non-false")
+ raise InvalidArgumentError("Not all KeyValues are non-false")
# construct the KeyValues in the order given by getKeyFieldNames()
keyvalues = []
@@ -157,7 +186,8 @@
def getFullModelClassName(self):
"""Returns fully-qualified model module.class name string.
- """
+ """
+
return '%s.%s' % (self._model.__module__, self._model.__name__)
def getKeyValuesFromEntity(self, entity):
@@ -169,6 +199,9 @@
entity: the entity from which to extract the key values
"""
+ if not entity:
+ raise NoEntityError
+
return [entity.scope_path, entity.link_id]
def getKeyValuesFromFields(self, fields):
@@ -180,6 +213,9 @@
fields: the dict from which to extract the key values
"""
+ if not all( (i in fields for i in ['scope_path', 'link_id']) ):
+ raise InvalidArgumentError
+
return [fields['scope_path'], fields['link_id']]
def getKeyFieldNames(self):
@@ -215,6 +251,9 @@
dictionary: The arguments to massage
"""
+ if not dictionary:
+ raise InvalidArgumentError
+
keys = self.getKeyFieldNames()
values = self.getKeyValuesFromFields(dictionary)
key_fields = dicts.zip(keys, values)
@@ -228,16 +267,22 @@
key_name: key name of entity
"""
+ if not key_name:
+ raise InvalidArgumentError
+
return self._model.get_by_key_name(key_name)
def getFromKeyFields(self, fields):
"""Returns the entity for the specified key names, or None if not found.
Args:
- fields: a dict containing the fields of the entity that
+ fields: a dict containing the fields of the entity that
uniquely identifies it
"""
+ if not fields:
+ raise InvalidArgumentError
+
key_fields = self.getKeyFieldsFromFields(fields)
if all(key_fields.values()):
@@ -335,6 +380,12 @@
The original entity with any supplied properties changed.
"""
+ if not entity:
+ raise NoEntityError
+
+ if not entity_properties:
+ raise InvalidArgumentError
+
def update():
return self._unsafeUpdateEntityProperties(entity, entity_properties)
@@ -380,14 +431,14 @@
"""
entity = self.getFromKeyName(key_name)
-
+
create_entity = not entity
-
+
if create_entity:
# entity did not exist, so create one in a transaction
entity = self._model.get_or_insert(key_name, **properties)
-
-
+
+
# there is no way to be sure if get_or_insert() returned a new entity or
# got an existing one due to a race, so update with properties anyway,
# in a transaction
@@ -399,21 +450,21 @@
else:
# the entity has been updated call _onUpdate
self._onUpdate(entity)
-
+
return entity
def isDeletable(self, entity):
"""Returns whether the specified entity can be deleted.
-
+
Args:
entity: an existing entity in datastore
"""
-
+
return True
def delete(self, entity):
"""Delete existing entity from datastore.
-
+
Args:
entity: an existing entity in datastore
"""