diff -r 97c4a718d6f4 -r 13efb916a73c app/soc/logic/models/base.py --- a/app/soc/logic/models/base.py Thu Feb 26 16:49:06 2009 +0000 +++ b/app/soc/logic/models/base.py Thu Feb 26 16:49:45 2009 +0000 @@ -39,6 +39,21 @@ class Error(Exception): """Base class for all exceptions raised by this module. """ + + pass + +class InvalidArgumentError(Error): + """Raised when an invalid argument is passed to a method. + + For example, if an argument is None, but must always be non-False. + """ + + pass + +class NoEntityError(InvalidArgumentError): + """Raised when no entity is passed to a method that requires one. + """ + pass @@ -105,31 +120,42 @@ value: the new value """ + if not entity: + raise NoEntityError + + if not entity_properties or (name not in entity_properties): + raise InvalidArgumentError + return True - + def _onCreate(self, entity): """Called when an entity has been created. Classes that override this can use it to do any post-creation operations. """ + if not entity: + raise NoEntityError + sidebar.flush() - + def _onUpdate(self, entity): """Called when an entity has been updated. - + Classes that override this can use it to do any post-update operations. """ - - pass - + + if not entity: + raise NoEntityError + def _onDelete(self, entity): """Called when an entity has been deleted. - + Classes that override this can use it to do any post-deletion operations. """ - - pass + + if not entity: + raise NoEntityError def getKeyNameFromFields(self, fields): """Returns the KeyName constructed from fields dict for this type of entity. @@ -138,14 +164,17 @@ //.../ """ + if not fields: + raise InvalidArgumentError + key_field_names = self.getKeyFieldNames() # check if all key_field_names for this entity are present in fields if not all(field in fields.keys() for field in key_field_names): - raise Error("Not all the required key fields are present") + raise InvalidArgumentError("Not all the required key fields are present") if not all(fields.get(field) for field in key_field_names): - raise Error("Not all KeyValues are non-false") + raise InvalidArgumentError("Not all KeyValues are non-false") # construct the KeyValues in the order given by getKeyFieldNames() keyvalues = [] @@ -157,7 +186,8 @@ def getFullModelClassName(self): """Returns fully-qualified model module.class name string. - """ + """ + return '%s.%s' % (self._model.__module__, self._model.__name__) def getKeyValuesFromEntity(self, entity): @@ -169,6 +199,9 @@ entity: the entity from which to extract the key values """ + if not entity: + raise NoEntityError + return [entity.scope_path, entity.link_id] def getKeyValuesFromFields(self, fields): @@ -180,6 +213,9 @@ fields: the dict from which to extract the key values """ + if not all( (i in fields for i in ['scope_path', 'link_id']) ): + raise InvalidArgumentError + return [fields['scope_path'], fields['link_id']] def getKeyFieldNames(self): @@ -215,6 +251,9 @@ dictionary: The arguments to massage """ + if not dictionary: + raise InvalidArgumentError + keys = self.getKeyFieldNames() values = self.getKeyValuesFromFields(dictionary) key_fields = dicts.zip(keys, values) @@ -228,16 +267,22 @@ key_name: key name of entity """ + if not key_name: + raise InvalidArgumentError + return self._model.get_by_key_name(key_name) def getFromKeyFields(self, fields): """Returns the entity for the specified key names, or None if not found. Args: - fields: a dict containing the fields of the entity that + fields: a dict containing the fields of the entity that uniquely identifies it """ + if not fields: + raise InvalidArgumentError + key_fields = self.getKeyFieldsFromFields(fields) if all(key_fields.values()): @@ -335,6 +380,12 @@ The original entity with any supplied properties changed. """ + if not entity: + raise NoEntityError + + if not entity_properties: + raise InvalidArgumentError + def update(): return self._unsafeUpdateEntityProperties(entity, entity_properties) @@ -380,14 +431,14 @@ """ entity = self.getFromKeyName(key_name) - + create_entity = not entity - + if create_entity: # entity did not exist, so create one in a transaction entity = self._model.get_or_insert(key_name, **properties) - - + + # there is no way to be sure if get_or_insert() returned a new entity or # got an existing one due to a race, so update with properties anyway, # in a transaction @@ -399,21 +450,21 @@ else: # the entity has been updated call _onUpdate self._onUpdate(entity) - + return entity def isDeletable(self, entity): """Returns whether the specified entity can be deleted. - + Args: entity: an existing entity in datastore """ - + return True def delete(self, entity): """Delete existing entity from datastore. - + Args: entity: an existing entity in datastore """