Add argument validation to base.Logic
authorSverre Rabbelier <srabbelier@gmail.com>
Thu, 26 Feb 2009 16:49:45 +0000
changeset 1513 13efb916a73c
parent 1512 97c4a718d6f4
child 1514 4a233f5a4264
Add argument validation to base.Logic Patch by: Sverre Rabbelier
app/soc/logic/models/base.py
--- a/app/soc/logic/models/base.py	Thu Feb 26 16:49:06 2009 +0000
+++ b/app/soc/logic/models/base.py	Thu Feb 26 16:49:45 2009 +0000
@@ -39,6 +39,21 @@
 class Error(Exception):
   """Base class for all exceptions raised by this module.
   """
+
+  pass
+
+class InvalidArgumentError(Error):
+  """Raised when an invalid argument is passed to a method.
+
+  For example, if an argument is None, but must always be non-False.
+  """
+
+  pass
+
+class NoEntityError(InvalidArgumentError):
+  """Raised when no entity is passed to a method that requires one.
+  """
+
   pass
 
 
@@ -105,31 +120,42 @@
       value: the new value
     """
 
+    if not entity:
+      raise NoEntityError
+
+    if not entity_properties or (name not in entity_properties):
+      raise InvalidArgumentError
+
     return True
-  
+
   def _onCreate(self, entity):
     """Called when an entity has been created.
 
     Classes that override this can use it to do any post-creation operations.
     """
 
+    if not entity:
+      raise NoEntityError
+
     sidebar.flush()
-  
+
   def _onUpdate(self, entity):
     """Called when an entity has been updated.
-    
+
     Classes that override this can use it to do any post-update operations.
     """
-    
-    pass
-  
+
+    if not entity:
+      raise NoEntityError
+
   def _onDelete(self, entity):
     """Called when an entity has been deleted.
-    
+
     Classes that override this can use it to do any post-deletion operations.
     """
-    
-    pass
+
+    if not entity:
+      raise NoEntityError
 
   def getKeyNameFromFields(self, fields):
     """Returns the KeyName constructed from fields dict for this type of entity.
@@ -138,14 +164,17 @@
     <key_value1>/<key_value2>/.../<key_valueN>
     """
 
+    if not fields:
+      raise InvalidArgumentError
+
     key_field_names = self.getKeyFieldNames()
 
     # check if all key_field_names for this entity are present in fields
     if not all(field in fields.keys() for field in key_field_names):
-      raise Error("Not all the required key fields are present")
+      raise InvalidArgumentError("Not all the required key fields are present")
 
     if not all(fields.get(field) for field in key_field_names):
-      raise Error("Not all KeyValues are non-false")
+      raise InvalidArgumentError("Not all KeyValues are non-false")
 
     # construct the KeyValues in the order given by getKeyFieldNames()
     keyvalues = []
@@ -157,7 +186,8 @@
 
   def getFullModelClassName(self):
     """Returns fully-qualified model module.class name string.
-    """ 
+    """
+
     return '%s.%s' % (self._model.__module__, self._model.__name__)
 
   def getKeyValuesFromEntity(self, entity):
@@ -169,6 +199,9 @@
       entity: the entity from which to extract the key values
     """
 
+    if not entity:
+      raise NoEntityError
+
     return [entity.scope_path, entity.link_id]
 
   def getKeyValuesFromFields(self, fields):
@@ -180,6 +213,9 @@
       fields: the dict from which to extract the key values
     """
 
+    if not all( (i in fields for i in ['scope_path', 'link_id']) ):
+      raise InvalidArgumentError
+
     return [fields['scope_path'], fields['link_id']]
 
   def getKeyFieldNames(self):
@@ -215,6 +251,9 @@
       dictionary: The arguments to massage
     """
 
+    if not dictionary:
+      raise InvalidArgumentError
+
     keys = self.getKeyFieldNames()
     values = self.getKeyValuesFromFields(dictionary)
     key_fields = dicts.zip(keys, values)
@@ -228,16 +267,22 @@
       key_name: key name of entity
     """
 
+    if not key_name:
+      raise InvalidArgumentError
+
     return self._model.get_by_key_name(key_name)
 
   def getFromKeyFields(self, fields):
     """Returns the entity for the specified key names, or None if not found.
 
     Args:
-      fields: a dict containing the fields of the entity that 
+      fields: a dict containing the fields of the entity that
         uniquely identifies it
     """
 
+    if not fields:
+      raise InvalidArgumentError
+
     key_fields = self.getKeyFieldsFromFields(fields)
 
     if all(key_fields.values()):
@@ -335,6 +380,12 @@
       The original entity with any supplied properties changed.
     """
 
+    if not entity:
+      raise NoEntityError
+
+    if not entity_properties:
+      raise InvalidArgumentError
+
     def update():
       return self._unsafeUpdateEntityProperties(entity, entity_properties)
 
@@ -380,14 +431,14 @@
     """
 
     entity = self.getFromKeyName(key_name)
-    
+
     create_entity = not entity
-    
+
     if create_entity:
       # entity did not exist, so create one in a transaction
       entity = self._model.get_or_insert(key_name, **properties)
-      
-    
+
+
     # there is no way to be sure if get_or_insert() returned a new entity or
     # got an existing one due to a race, so update with properties anyway,
     # in a transaction
@@ -399,21 +450,21 @@
     else:
       # the entity has been updated call _onUpdate
       self._onUpdate(entity)
-      
+
     return entity
 
   def isDeletable(self, entity):
     """Returns whether the specified entity can be deleted.
-    
+
     Args:
       entity: an existing entity in datastore
     """
-    
+
     return True
 
   def delete(self, entity):
     """Delete existing entity from datastore.
-    
+
     Args:
       entity: an existing entity in datastore
     """