tests/pymox/mox.py
changeset 1000 9af147fc1f1c
equal deleted inserted replaced
999:71f15c023847 1000:9af147fc1f1c
       
     1 #!/usr/bin/python2.4
       
     2 #
       
     3 # Copyright 2008 Google Inc.
       
     4 #
       
     5 # Licensed under the Apache License, Version 2.0 (the "License");
       
     6 # you may not use this file except in compliance with the License.
       
     7 # You may obtain a copy of the License at
       
     8 #
       
     9 #      http://www.apache.org/licenses/LICENSE-2.0
       
    10 #
       
    11 # Unless required by applicable law or agreed to in writing, software
       
    12 # distributed under the License is distributed on an "AS IS" BASIS,
       
    13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
       
    14 # See the License for the specific language governing permissions and
       
    15 # limitations under the License.
       
    16 
       
    17 """Mox, an object-mocking framework for Python.
       
    18 
       
    19 Mox works in the record-replay-verify paradigm.  When you first create
       
    20 a mock object, it is in record mode.  You then programmatically set
       
    21 the expected behavior of the mock object (what methods are to be
       
    22 called on it, with what parameters, what they should return, and in
       
    23 what order).
       
    24 
       
    25 Once you have set up the expected mock behavior, you put it in replay
       
    26 mode.  Now the mock responds to method calls just as you told it to.
       
    27 If an unexpected method (or an expected method with unexpected
       
    28 parameters) is called, then an exception will be raised.
       
    29 
       
    30 Once you are done interacting with the mock, you need to verify that
       
    31 all the expected interactions occurred.  (Maybe your code exited
       
    32 prematurely without calling some cleanup method!)  The verify phase
       
    33 ensures that every expected method was called; otherwise, an exception
       
    34 will be raised.
       
    35 
       
    36 Suggested usage / workflow:
       
    37 
       
    38   # Create Mox factory
       
    39   my_mox = Mox()
       
    40 
       
    41   # Create a mock data access object
       
    42   mock_dao = my_mox.CreateMock(DAOClass)
       
    43 
       
    44   # Set up expected behavior
       
    45   mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
       
    46   mock_dao.DeletePerson(person)
       
    47 
       
    48   # Put mocks in replay mode
       
    49   my_mox.ReplayAll()
       
    50 
       
    51   # Inject mock object and run test
       
    52   controller.SetDao(mock_dao)
       
    53   controller.DeletePersonById('1')
       
    54 
       
    55   # Verify all methods were called as expected
       
    56   my_mox.VerifyAll()
       
    57 """
       
    58 
       
    59 from collections import deque
       
    60 import inspect
       
    61 import re
       
    62 import types
       
    63 import unittest
       
    64 
       
    65 import stubout
       
    66 
       
    67 class Error(AssertionError):
       
    68   """Base exception for this module."""
       
    69 
       
    70   pass
       
    71 
       
    72 
       
    73 class ExpectedMethodCallsError(Error):
       
    74   """Raised when Verify() is called before all expected methods have been called
       
    75   """
       
    76 
       
    77   def __init__(self, expected_methods):
       
    78     """Init exception.
       
    79 
       
    80     Args:
       
    81       # expected_methods: A sequence of MockMethod objects that should have been
       
    82       #   called.
       
    83       expected_methods: [MockMethod]
       
    84 
       
    85     Raises:
       
    86       ValueError: if expected_methods contains no methods.
       
    87     """
       
    88 
       
    89     if not expected_methods:
       
    90       raise ValueError("There must be at least one expected method")
       
    91     Error.__init__(self)
       
    92     self._expected_methods = expected_methods
       
    93 
       
    94   def __str__(self):
       
    95     calls = "\n".join(["%3d.  %s" % (i, m)
       
    96                        for i, m in enumerate(self._expected_methods)])
       
    97     return "Verify: Expected methods never called:\n%s" % (calls,)
       
    98 
       
    99 
       
   100 class UnexpectedMethodCallError(Error):
       
   101   """Raised when an unexpected method is called.
       
   102 
       
   103   This can occur if a method is called with incorrect parameters, or out of the
       
   104   specified order.
       
   105   """
       
   106 
       
   107   def __init__(self, unexpected_method, expected):
       
   108     """Init exception.
       
   109 
       
   110     Args:
       
   111       # unexpected_method: MockMethod that was called but was not at the head of
       
   112       #   the expected_method queue.
       
   113       # expected: MockMethod or UnorderedGroup the method should have
       
   114       #   been in.
       
   115       unexpected_method: MockMethod
       
   116       expected: MockMethod or UnorderedGroup
       
   117     """
       
   118 
       
   119     Error.__init__(self)
       
   120     self._unexpected_method = unexpected_method
       
   121     self._expected = expected
       
   122 
       
   123   def __str__(self):
       
   124     return "Unexpected method call: %s.  Expecting: %s" % \
       
   125       (self._unexpected_method, self._expected)
       
   126 
       
   127 
       
   128 class UnknownMethodCallError(Error):
       
   129   """Raised if an unknown method is requested of the mock object."""
       
   130 
       
   131   def __init__(self, unknown_method_name):
       
   132     """Init exception.
       
   133 
       
   134     Args:
       
   135       # unknown_method_name: Method call that is not part of the mocked class's
       
   136       #   public interface.
       
   137       unknown_method_name: str
       
   138     """
       
   139 
       
   140     Error.__init__(self)
       
   141     self._unknown_method_name = unknown_method_name
       
   142 
       
   143   def __str__(self):
       
   144     return "Method called is not a member of the object: %s" % \
       
   145       self._unknown_method_name
       
   146 
       
   147 
       
   148 class Mox(object):
       
   149   """Mox: a factory for creating mock objects."""
       
   150 
       
   151   # A list of types that should be stubbed out with MockObjects (as
       
   152   # opposed to MockAnythings).
       
   153   _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
       
   154                       types.ObjectType, types.TypeType]
       
   155 
       
   156   def __init__(self):
       
   157     """Initialize a new Mox."""
       
   158 
       
   159     self._mock_objects = []
       
   160     self.stubs = stubout.StubOutForTesting()
       
   161 
       
   162   def CreateMock(self, class_to_mock):
       
   163     """Create a new mock object.
       
   164 
       
   165     Args:
       
   166       # class_to_mock: the class to be mocked
       
   167       class_to_mock: class
       
   168 
       
   169     Returns:
       
   170       MockObject that can be used as the class_to_mock would be.
       
   171     """
       
   172 
       
   173     new_mock = MockObject(class_to_mock)
       
   174     self._mock_objects.append(new_mock)
       
   175     return new_mock
       
   176 
       
   177   def CreateMockAnything(self):
       
   178     """Create a mock that will accept any method calls.
       
   179 
       
   180     This does not enforce an interface.
       
   181     """
       
   182 
       
   183     new_mock = MockAnything()
       
   184     self._mock_objects.append(new_mock)
       
   185     return new_mock
       
   186 
       
   187   def ReplayAll(self):
       
   188     """Set all mock objects to replay mode."""
       
   189 
       
   190     for mock_obj in self._mock_objects:
       
   191       mock_obj._Replay()
       
   192 
       
   193 
       
   194   def VerifyAll(self):
       
   195     """Call verify on all mock objects created."""
       
   196 
       
   197     for mock_obj in self._mock_objects:
       
   198       mock_obj._Verify()
       
   199 
       
   200   def ResetAll(self):
       
   201     """Call reset on all mock objects.  This does not unset stubs."""
       
   202 
       
   203     for mock_obj in self._mock_objects:
       
   204       mock_obj._Reset()
       
   205 
       
   206   def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
       
   207     """Replace a method, attribute, etc. with a Mock.
       
   208 
       
   209     This will replace a class or module with a MockObject, and everything else
       
   210     (method, function, etc) with a MockAnything.  This can be overridden to
       
   211     always use a MockAnything by setting use_mock_anything to True.
       
   212 
       
   213     Args:
       
   214       obj: A Python object (class, module, instance, callable).
       
   215       attr_name: str.  The name of the attribute to replace with a mock.
       
   216       use_mock_anything: bool. True if a MockAnything should be used regardless
       
   217         of the type of attribute.
       
   218     """
       
   219 
       
   220     attr_to_replace = getattr(obj, attr_name)
       
   221     if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything:
       
   222       stub = self.CreateMock(attr_to_replace)
       
   223     else:
       
   224       stub = self.CreateMockAnything()
       
   225 
       
   226     self.stubs.Set(obj, attr_name, stub)
       
   227 
       
   228   def UnsetStubs(self):
       
   229     """Restore stubs to their original state."""
       
   230 
       
   231     self.stubs.UnsetAll()
       
   232 
       
   233 def Replay(*args):
       
   234   """Put mocks into Replay mode.
       
   235 
       
   236   Args:
       
   237     # args is any number of mocks to put into replay mode.
       
   238   """
       
   239 
       
   240   for mock in args:
       
   241     mock._Replay()
       
   242 
       
   243 
       
   244 def Verify(*args):
       
   245   """Verify mocks.
       
   246 
       
   247   Args:
       
   248     # args is any number of mocks to be verified.
       
   249   """
       
   250 
       
   251   for mock in args:
       
   252     mock._Verify()
       
   253 
       
   254 
       
   255 def Reset(*args):
       
   256   """Reset mocks.
       
   257 
       
   258   Args:
       
   259     # args is any number of mocks to be reset.
       
   260   """
       
   261 
       
   262   for mock in args:
       
   263     mock._Reset()
       
   264 
       
   265 
       
   266 class MockAnything:
       
   267   """A mock that can be used to mock anything.
       
   268 
       
   269   This is helpful for mocking classes that do not provide a public interface.
       
   270   """
       
   271 
       
   272   def __init__(self):
       
   273     """ """
       
   274     self._Reset()
       
   275 
       
   276   def __getattr__(self, method_name):
       
   277     """Intercept method calls on this object.
       
   278 
       
   279      A new MockMethod is returned that is aware of the MockAnything's
       
   280      state (record or replay).  The call will be recorded or replayed
       
   281      by the MockMethod's __call__.
       
   282 
       
   283     Args:
       
   284       # method name: the name of the method being called.
       
   285       method_name: str
       
   286 
       
   287     Returns:
       
   288       A new MockMethod aware of MockAnything's state (record or replay).
       
   289     """
       
   290 
       
   291     return self._CreateMockMethod(method_name)
       
   292 
       
   293   def _CreateMockMethod(self, method_name, method_to_mock=None):
       
   294     """Create a new mock method call and return it.
       
   295 
       
   296     Args:
       
   297       # method_name: the name of the method being called.
       
   298       # method_to_mock: The actual method being mocked, used for introspection.
       
   299       method_name: str
       
   300       method_to_mock: a method object
       
   301 
       
   302     Returns:
       
   303       A new MockMethod aware of MockAnything's state (record or replay).
       
   304     """
       
   305 
       
   306     return MockMethod(method_name, self._expected_calls_queue,
       
   307                       self._replay_mode, method_to_mock=method_to_mock)
       
   308 
       
   309   def __nonzero__(self):
       
   310     """Return 1 for nonzero so the mock can be used as a conditional."""
       
   311 
       
   312     return 1
       
   313 
       
   314   def __eq__(self, rhs):
       
   315     """Provide custom logic to compare objects."""
       
   316 
       
   317     return (isinstance(rhs, MockAnything) and
       
   318             self._replay_mode == rhs._replay_mode and
       
   319             self._expected_calls_queue == rhs._expected_calls_queue)
       
   320 
       
   321   def __ne__(self, rhs):
       
   322     """Provide custom logic to compare objects."""
       
   323 
       
   324     return not self == rhs
       
   325 
       
   326   def _Replay(self):
       
   327     """Start replaying expected method calls."""
       
   328 
       
   329     self._replay_mode = True
       
   330 
       
   331   def _Verify(self):
       
   332     """Verify that all of the expected calls have been made.
       
   333 
       
   334     Raises:
       
   335       ExpectedMethodCallsError: if there are still more method calls in the
       
   336         expected queue.
       
   337     """
       
   338 
       
   339     # If the list of expected calls is not empty, raise an exception
       
   340     if self._expected_calls_queue:
       
   341       # The last MultipleTimesGroup is not popped from the queue.
       
   342       if (len(self._expected_calls_queue) == 1 and
       
   343           isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and
       
   344           self._expected_calls_queue[0].IsSatisfied()):
       
   345         pass
       
   346       else:
       
   347         raise ExpectedMethodCallsError(self._expected_calls_queue)
       
   348 
       
   349   def _Reset(self):
       
   350     """Reset the state of this mock to record mode with an empty queue."""
       
   351 
       
   352     # Maintain a list of method calls we are expecting
       
   353     self._expected_calls_queue = deque()
       
   354 
       
   355     # Make sure we are in setup mode, not replay mode
       
   356     self._replay_mode = False
       
   357 
       
   358 
       
   359 class MockObject(MockAnything, object):
       
   360   """A mock object that simulates the public/protected interface of a class."""
       
   361 
       
   362   def __init__(self, class_to_mock):
       
   363     """Initialize a mock object.
       
   364 
       
   365     This determines the methods and properties of the class and stores them.
       
   366 
       
   367     Args:
       
   368       # class_to_mock: class to be mocked
       
   369       class_to_mock: class
       
   370     """
       
   371 
       
   372     # This is used to hack around the mixin/inheritance of MockAnything, which
       
   373     # is not a proper object (it can be anything. :-)
       
   374     MockAnything.__dict__['__init__'](self)
       
   375 
       
   376     # Get a list of all the public and special methods we should mock.
       
   377     self._known_methods = set()
       
   378     self._known_vars = set()
       
   379     self._class_to_mock = class_to_mock
       
   380     for method in dir(class_to_mock):
       
   381       if callable(getattr(class_to_mock, method)):
       
   382         self._known_methods.add(method)
       
   383       else:
       
   384         self._known_vars.add(method)
       
   385 
       
   386   def __getattr__(self, name):
       
   387     """Intercept attribute request on this object.
       
   388 
       
   389     If the attribute is a public class variable, it will be returned and not
       
   390     recorded as a call.
       
   391 
       
   392     If the attribute is not a variable, it is handled like a method
       
   393     call. The method name is checked against the set of mockable
       
   394     methods, and a new MockMethod is returned that is aware of the
       
   395     MockObject's state (record or replay).  The call will be recorded
       
   396     or replayed by the MockMethod's __call__.
       
   397 
       
   398     Args:
       
   399       # name: the name of the attribute being requested.
       
   400       name: str
       
   401 
       
   402     Returns:
       
   403       Either a class variable or a new MockMethod that is aware of the state
       
   404       of the mock (record or replay).
       
   405 
       
   406     Raises:
       
   407       UnknownMethodCallError if the MockObject does not mock the requested
       
   408           method.
       
   409     """
       
   410 
       
   411     if name in self._known_vars:
       
   412       return getattr(self._class_to_mock, name)
       
   413 
       
   414     if name in self._known_methods:
       
   415       return self._CreateMockMethod(
       
   416           name,
       
   417           method_to_mock=getattr(self._class_to_mock, name))
       
   418 
       
   419     raise UnknownMethodCallError(name)
       
   420 
       
   421   def __eq__(self, rhs):
       
   422     """Provide custom logic to compare objects."""
       
   423 
       
   424     return (isinstance(rhs, MockObject) and
       
   425             self._class_to_mock == rhs._class_to_mock and
       
   426             self._replay_mode == rhs._replay_mode and
       
   427             self._expected_calls_queue == rhs._expected_calls_queue)
       
   428 
       
   429   def __setitem__(self, key, value):
       
   430     """Provide custom logic for mocking classes that support item assignment.
       
   431 
       
   432     Args:
       
   433       key: Key to set the value for.
       
   434       value: Value to set.
       
   435 
       
   436     Returns:
       
   437       Expected return value in replay mode.  A MockMethod object for the
       
   438       __setitem__ method that has already been called if not in replay mode.
       
   439 
       
   440     Raises:
       
   441       TypeError if the underlying class does not support item assignment.
       
   442       UnexpectedMethodCallError if the object does not expect the call to
       
   443         __setitem__.
       
   444 
       
   445     """
       
   446     setitem = self._class_to_mock.__dict__.get('__setitem__', None)
       
   447 
       
   448     # Verify the class supports item assignment.
       
   449     if setitem is None:
       
   450       raise TypeError('object does not support item assignment')
       
   451 
       
   452     # If we are in replay mode then simply call the mock __setitem__ method.
       
   453     if self._replay_mode:
       
   454       return MockMethod('__setitem__', self._expected_calls_queue,
       
   455                         self._replay_mode)(key, value)
       
   456 
       
   457 
       
   458     # Otherwise, create a mock method __setitem__.
       
   459     return self._CreateMockMethod('__setitem__')(key, value)
       
   460 
       
   461   def __getitem__(self, key):
       
   462     """Provide custom logic for mocking classes that are subscriptable.
       
   463 
       
   464     Args:
       
   465       key: Key to return the value for.
       
   466 
       
   467     Returns:
       
   468       Expected return value in replay mode.  A MockMethod object for the
       
   469       __getitem__ method that has already been called if not in replay mode.
       
   470 
       
   471     Raises:
       
   472       TypeError if the underlying class is not subscriptable.
       
   473       UnexpectedMethodCallError if the object does not expect the call to
       
   474         __setitem__.
       
   475 
       
   476     """
       
   477     getitem = self._class_to_mock.__dict__.get('__getitem__', None)
       
   478 
       
   479     # Verify the class supports item assignment.
       
   480     if getitem is None:
       
   481       raise TypeError('unsubscriptable object')
       
   482 
       
   483     # If we are in replay mode then simply call the mock __getitem__ method.
       
   484     if self._replay_mode:
       
   485       return MockMethod('__getitem__', self._expected_calls_queue,
       
   486                         self._replay_mode)(key)
       
   487 
       
   488 
       
   489     # Otherwise, create a mock method __getitem__.
       
   490     return self._CreateMockMethod('__getitem__')(key)
       
   491 
       
   492   def __contains__(self, key):
       
   493     """Provide custom logic for mocking classes that contain items.
       
   494 
       
   495     Args:
       
   496       key: Key to look in container for.
       
   497 
       
   498     Returns:
       
   499       Expected return value in replay mode.  A MockMethod object for the
       
   500       __contains__ method that has already been called if not in replay mode.
       
   501 
       
   502     Raises:
       
   503       TypeError if the underlying class does not implement __contains__
       
   504       UnexpectedMethodCaller if the object does not expect the call to
       
   505       __contains__.
       
   506 
       
   507     """
       
   508     contains = self._class_to_mock.__dict__.get('__contains__', None)
       
   509 
       
   510     if contains is None:
       
   511       raise TypeError('unsubscriptable object')
       
   512 
       
   513     if self._replay_mode:
       
   514       return MockMethod('__contains__', self._expected_calls_queue,
       
   515                         self._replay_mode)(key)
       
   516 
       
   517     return self._CreateMockMethod('__contains__')(key)
       
   518 
       
   519   def __call__(self, *params, **named_params):
       
   520     """Provide custom logic for mocking classes that are callable."""
       
   521 
       
   522     # Verify the class we are mocking is callable
       
   523     callable = self._class_to_mock.__dict__.get('__call__', None)
       
   524     if callable is None:
       
   525       raise TypeError('Not callable')
       
   526 
       
   527     # Because the call is happening directly on this object instead of a method,
       
   528     # the call on the mock method is made right here
       
   529     mock_method = self._CreateMockMethod('__call__')
       
   530     return mock_method(*params, **named_params)
       
   531 
       
   532   @property
       
   533   def __class__(self):
       
   534     """Return the class that is being mocked."""
       
   535 
       
   536     return self._class_to_mock
       
   537 
       
   538 
       
   539 class MethodCallChecker(object):
       
   540   """Ensures that methods are called correctly."""
       
   541 
       
   542   _NEEDED, _DEFAULT, _GIVEN = range(3)
       
   543 
       
   544   def __init__(self, method):
       
   545     """Creates a checker.
       
   546 
       
   547     Args:
       
   548       # method: A method to check.
       
   549       method: function
       
   550 
       
   551     Raises:
       
   552       ValueError: method could not be inspected, so checks aren't possible.
       
   553         Some methods and functions like built-ins can't be inspected.
       
   554     """
       
   555     try:
       
   556       self._args, varargs, varkw, defaults = inspect.getargspec(method)
       
   557     except TypeError:
       
   558       raise ValueError('Could not get argument specification for %r'
       
   559                        % (method,))
       
   560     if inspect.ismethod(method):
       
   561       self._args = self._args[1:]  # Skip 'self'.
       
   562     self._method = method
       
   563 
       
   564     self._has_varargs = varargs is not None
       
   565     self._has_varkw = varkw is not None
       
   566     if defaults is None:
       
   567       self._required_args = self._args
       
   568       self._default_args = []
       
   569     else:
       
   570       self._required_args = self._args[:-len(defaults)]
       
   571       self._default_args = self._args[-len(defaults):]
       
   572 
       
   573   def _RecordArgumentGiven(self, arg_name, arg_status):
       
   574     """Mark an argument as being given.
       
   575 
       
   576     Args:
       
   577       # arg_name: The name of the argument to mark in arg_status.
       
   578       # arg_status: Maps argument names to one of _NEEDED, _DEFAULT, _GIVEN.
       
   579       arg_name: string
       
   580       arg_status: dict
       
   581 
       
   582     Raises:
       
   583       AttributeError: arg_name is already marked as _GIVEN.
       
   584     """
       
   585     if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN:
       
   586       raise AttributeError('%s provided more than once' % (arg_name,))
       
   587     arg_status[arg_name] = MethodCallChecker._GIVEN
       
   588 
       
   589   def Check(self, params, named_params):
       
   590     """Ensures that the parameters used while recording a call are valid.
       
   591 
       
   592     Args:
       
   593       # params: A list of positional parameters.
       
   594       # named_params: A dict of named parameters.
       
   595       params: list
       
   596       named_params: dict
       
   597 
       
   598     Raises:
       
   599       AttributeError: the given parameters don't work with the given method.
       
   600     """
       
   601     arg_status = dict((a, MethodCallChecker._NEEDED)
       
   602                       for a in self._required_args)
       
   603     for arg in self._default_args:
       
   604       arg_status[arg] = MethodCallChecker._DEFAULT
       
   605 
       
   606     # Check that each positional param is valid.
       
   607     for i in range(len(params)):
       
   608       try:
       
   609         arg_name = self._args[i]
       
   610       except IndexError:
       
   611         if not self._has_varargs:
       
   612           raise AttributeError('%s does not take %d or more positional '
       
   613                                'arguments' % (self._method.__name__, i))
       
   614       else:
       
   615         self._RecordArgumentGiven(arg_name, arg_status)
       
   616 
       
   617     # Check each keyword argument.
       
   618     for arg_name in named_params:
       
   619       if arg_name not in arg_status and not self._has_varkw:
       
   620         raise AttributeError('%s is not expecting keyword argument %s'
       
   621                              % (self._method.__name__, arg_name))
       
   622       self._RecordArgumentGiven(arg_name, arg_status)
       
   623 
       
   624     # Ensure all the required arguments have been given.
       
   625     still_needed = [k for k, v in arg_status.iteritems()
       
   626                     if v == MethodCallChecker._NEEDED]
       
   627     if still_needed:
       
   628       raise AttributeError('No values given for arguments %s'
       
   629                            % (' '.join(sorted(still_needed))))
       
   630 
       
   631 
       
   632 class MockMethod(object):
       
   633   """Callable mock method.
       
   634 
       
   635   A MockMethod should act exactly like the method it mocks, accepting parameters
       
   636   and returning a value, or throwing an exception (as specified).  When this
       
   637   method is called, it can optionally verify whether the called method (name and
       
   638   signature) matches the expected method.
       
   639   """
       
   640 
       
   641   def __init__(self, method_name, call_queue, replay_mode, method_to_mock=None):
       
   642     """Construct a new mock method.
       
   643 
       
   644     Args:
       
   645       # method_name: the name of the method
       
   646       # call_queue: deque of calls, verify this call against the head, or add
       
   647       #     this call to the queue.
       
   648       # replay_mode: False if we are recording, True if we are verifying calls
       
   649       #     against the call queue.
       
   650       # method_to_mock: The actual method being mocked, used for introspection.
       
   651       method_name: str
       
   652       call_queue: list or deque
       
   653       replay_mode: bool
       
   654       method_to_mock: a method object
       
   655     """
       
   656 
       
   657     self._name = method_name
       
   658     self._call_queue = call_queue
       
   659     if not isinstance(call_queue, deque):
       
   660       self._call_queue = deque(self._call_queue)
       
   661     self._replay_mode = replay_mode
       
   662 
       
   663     self._params = None
       
   664     self._named_params = None
       
   665     self._return_value = None
       
   666     self._exception = None
       
   667     self._side_effects = None
       
   668 
       
   669     try:
       
   670       self._checker = MethodCallChecker(method_to_mock)
       
   671     except ValueError:
       
   672       self._checker = None
       
   673 
       
   674   def __call__(self, *params, **named_params):
       
   675     """Log parameters and return the specified return value.
       
   676 
       
   677     If the Mock(Anything/Object) associated with this call is in record mode,
       
   678     this MockMethod will be pushed onto the expected call queue.  If the mock
       
   679     is in replay mode, this will pop a MockMethod off the top of the queue and
       
   680     verify this call is equal to the expected call.
       
   681 
       
   682     Raises:
       
   683       UnexpectedMethodCall if this call is supposed to match an expected method
       
   684         call and it does not.
       
   685     """
       
   686 
       
   687     self._params = params
       
   688     self._named_params = named_params
       
   689 
       
   690     if not self._replay_mode:
       
   691       if self._checker is not None:
       
   692         self._checker.Check(params, named_params)
       
   693       self._call_queue.append(self)
       
   694       return self
       
   695 
       
   696     expected_method = self._VerifyMethodCall()
       
   697 
       
   698     if expected_method._side_effects:
       
   699       expected_method._side_effects(*params, **named_params)
       
   700 
       
   701     if expected_method._exception:
       
   702       raise expected_method._exception
       
   703 
       
   704     return expected_method._return_value
       
   705 
       
   706   def __getattr__(self, name):
       
   707     """Raise an AttributeError with a helpful message."""
       
   708 
       
   709     raise AttributeError('MockMethod has no attribute "%s". '
       
   710         'Did you remember to put your mocks in replay mode?' % name)
       
   711 
       
   712   def _PopNextMethod(self):
       
   713     """Pop the next method from our call queue."""
       
   714     try:
       
   715       return self._call_queue.popleft()
       
   716     except IndexError:
       
   717       raise UnexpectedMethodCallError(self, None)
       
   718 
       
   719   def _VerifyMethodCall(self):
       
   720     """Verify the called method is expected.
       
   721 
       
   722     This can be an ordered method, or part of an unordered set.
       
   723 
       
   724     Returns:
       
   725       The expected mock method.
       
   726 
       
   727     Raises:
       
   728       UnexpectedMethodCall if the method called was not expected.
       
   729     """
       
   730 
       
   731     expected = self._PopNextMethod()
       
   732 
       
   733     # Loop here, because we might have a MethodGroup followed by another
       
   734     # group.
       
   735     while isinstance(expected, MethodGroup):
       
   736       expected, method = expected.MethodCalled(self)
       
   737       if method is not None:
       
   738         return method
       
   739 
       
   740     # This is a mock method, so just check equality.
       
   741     if expected != self:
       
   742       raise UnexpectedMethodCallError(self, expected)
       
   743 
       
   744     return expected
       
   745 
       
   746   def __str__(self):
       
   747     params = ', '.join(
       
   748         [repr(p) for p in self._params or []] +
       
   749         ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
       
   750     desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
       
   751     return desc
       
   752 
       
   753   def __eq__(self, rhs):
       
   754     """Test whether this MockMethod is equivalent to another MockMethod.
       
   755 
       
   756     Args:
       
   757       # rhs: the right hand side of the test
       
   758       rhs: MockMethod
       
   759     """
       
   760 
       
   761     return (isinstance(rhs, MockMethod) and
       
   762             self._name == rhs._name and
       
   763             self._params == rhs._params and
       
   764             self._named_params == rhs._named_params)
       
   765 
       
   766   def __ne__(self, rhs):
       
   767     """Test whether this MockMethod is not equivalent to another MockMethod.
       
   768 
       
   769     Args:
       
   770       # rhs: the right hand side of the test
       
   771       rhs: MockMethod
       
   772     """
       
   773 
       
   774     return not self == rhs
       
   775 
       
   776   def GetPossibleGroup(self):
       
   777     """Returns a possible group from the end of the call queue or None if no
       
   778     other methods are on the stack.
       
   779     """
       
   780 
       
   781     # Remove this method from the tail of the queue so we can add it to a group.
       
   782     this_method = self._call_queue.pop()
       
   783     assert this_method == self
       
   784 
       
   785     # Determine if the tail of the queue is a group, or just a regular ordered
       
   786     # mock method.
       
   787     group = None
       
   788     try:
       
   789       group = self._call_queue[-1]
       
   790     except IndexError:
       
   791       pass
       
   792 
       
   793     return group
       
   794 
       
   795   def _CheckAndCreateNewGroup(self, group_name, group_class):
       
   796     """Checks if the last method (a possible group) is an instance of our
       
   797     group_class. Adds the current method to this group or creates a new one.
       
   798 
       
   799     Args:
       
   800 
       
   801       group_name: the name of the group.
       
   802       group_class: the class used to create instance of this new group
       
   803     """
       
   804     group = self.GetPossibleGroup()
       
   805 
       
   806     # If this is a group, and it is the correct group, add the method.
       
   807     if isinstance(group, group_class) and group.group_name() == group_name:
       
   808       group.AddMethod(self)
       
   809       return self
       
   810 
       
   811     # Create a new group and add the method.
       
   812     new_group = group_class(group_name)
       
   813     new_group.AddMethod(self)
       
   814     self._call_queue.append(new_group)
       
   815     return self
       
   816 
       
   817   def InAnyOrder(self, group_name="default"):
       
   818     """Move this method into a group of unordered calls.
       
   819 
       
   820     A group of unordered calls must be defined together, and must be executed
       
   821     in full before the next expected method can be called.  There can be
       
   822     multiple groups that are expected serially, if they are given
       
   823     different group names.  The same group name can be reused if there is a
       
   824     standard method call, or a group with a different name, spliced between
       
   825     usages.
       
   826 
       
   827     Args:
       
   828       group_name: the name of the unordered group.
       
   829 
       
   830     Returns:
       
   831       self
       
   832     """
       
   833     return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)
       
   834 
       
   835   def MultipleTimes(self, group_name="default"):
       
   836     """Move this method into group of calls which may be called multiple times.
       
   837 
       
   838     A group of repeating calls must be defined together, and must be executed in
       
   839     full before the next expected mehtod can be called.
       
   840 
       
   841     Args:
       
   842       group_name: the name of the unordered group.
       
   843 
       
   844     Returns:
       
   845       self
       
   846     """
       
   847     return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)
       
   848 
       
   849   def AndReturn(self, return_value):
       
   850     """Set the value to return when this method is called.
       
   851 
       
   852     Args:
       
   853       # return_value can be anything.
       
   854     """
       
   855 
       
   856     self._return_value = return_value
       
   857     return return_value
       
   858 
       
   859   def AndRaise(self, exception):
       
   860     """Set the exception to raise when this method is called.
       
   861 
       
   862     Args:
       
   863       # exception: the exception to raise when this method is called.
       
   864       exception: Exception
       
   865     """
       
   866 
       
   867     self._exception = exception
       
   868 
       
   869   def WithSideEffects(self, side_effects):
       
   870     """Set the side effects that are simulated when this method is called.
       
   871 
       
   872     Args:
       
   873       side_effects: A callable which modifies the parameters or other relevant
       
   874         state which a given test case depends on.
       
   875 
       
   876     Returns:
       
   877       Self for chaining with AndReturn and AndRaise.
       
   878     """
       
   879     self._side_effects = side_effects
       
   880     return self
       
   881 
       
   882 class Comparator:
       
   883   """Base class for all Mox comparators.
       
   884 
       
   885   A Comparator can be used as a parameter to a mocked method when the exact
       
   886   value is not known.  For example, the code you are testing might build up a
       
   887   long SQL string that is passed to your mock DAO. You're only interested that
       
   888   the IN clause contains the proper primary keys, so you can set your mock
       
   889   up as follows:
       
   890 
       
   891   mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
       
   892 
       
   893   Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.
       
   894 
       
   895   A Comparator may replace one or more parameters, for example:
       
   896   # return at most 10 rows
       
   897   mock_dao.RunQuery(StrContains('SELECT'), 10)
       
   898 
       
   899   or
       
   900 
       
   901   # Return some non-deterministic number of rows
       
   902   mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
       
   903   """
       
   904 
       
   905   def equals(self, rhs):
       
   906     """Special equals method that all comparators must implement.
       
   907 
       
   908     Args:
       
   909       rhs: any python object
       
   910     """
       
   911 
       
   912     raise NotImplementedError, 'method must be implemented by a subclass.'
       
   913 
       
   914   def __eq__(self, rhs):
       
   915     return self.equals(rhs)
       
   916 
       
   917   def __ne__(self, rhs):
       
   918     return not self.equals(rhs)
       
   919 
       
   920 
       
   921 class IsA(Comparator):
       
   922   """This class wraps a basic Python type or class.  It is used to verify
       
   923   that a parameter is of the given type or class.
       
   924 
       
   925   Example:
       
   926   mock_dao.Connect(IsA(DbConnectInfo))
       
   927   """
       
   928 
       
   929   def __init__(self, class_name):
       
   930     """Initialize IsA
       
   931 
       
   932     Args:
       
   933       class_name: basic python type or a class
       
   934     """
       
   935 
       
   936     self._class_name = class_name
       
   937 
       
   938   def equals(self, rhs):
       
   939     """Check to see if the RHS is an instance of class_name.
       
   940 
       
   941     Args:
       
   942       # rhs: the right hand side of the test
       
   943       rhs: object
       
   944 
       
   945     Returns:
       
   946       bool
       
   947     """
       
   948 
       
   949     try:
       
   950       return isinstance(rhs, self._class_name)
       
   951     except TypeError:
       
   952       # Check raw types if there was a type error.  This is helpful for
       
   953       # things like cStringIO.StringIO.
       
   954       return type(rhs) == type(self._class_name)
       
   955 
       
   956   def __repr__(self):
       
   957     return str(self._class_name)
       
   958 
       
   959 class IsAlmost(Comparator):
       
   960   """Comparison class used to check whether a parameter is nearly equal
       
   961   to a given value.  Generally useful for floating point numbers.
       
   962 
       
   963   Example mock_dao.SetTimeout((IsAlmost(3.9)))
       
   964   """
       
   965 
       
   966   def __init__(self, float_value, places=7):
       
   967     """Initialize IsAlmost.
       
   968 
       
   969     Args:
       
   970       float_value: The value for making the comparison.
       
   971       places: The number of decimal places to round to.
       
   972     """
       
   973 
       
   974     self._float_value = float_value
       
   975     self._places = places
       
   976 
       
   977   def equals(self, rhs):
       
   978     """Check to see if RHS is almost equal to float_value
       
   979 
       
   980     Args:
       
   981       rhs: the value to compare to float_value
       
   982 
       
   983     Returns:
       
   984       bool
       
   985     """
       
   986 
       
   987     try:
       
   988       return round(rhs-self._float_value, self._places) == 0
       
   989     except TypeError:
       
   990       # This is probably because either float_value or rhs is not a number.
       
   991       return False
       
   992 
       
   993   def __repr__(self):
       
   994     return str(self._float_value)
       
   995 
       
   996 class StrContains(Comparator):
       
   997   """Comparison class used to check whether a substring exists in a
       
   998   string parameter.  This can be useful in mocking a database with SQL
       
   999   passed in as a string parameter, for example.
       
  1000 
       
  1001   Example:
       
  1002   mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
       
  1003   """
       
  1004 
       
  1005   def __init__(self, search_string):
       
  1006     """Initialize.
       
  1007 
       
  1008     Args:
       
  1009       # search_string: the string you are searching for
       
  1010       search_string: str
       
  1011     """
       
  1012 
       
  1013     self._search_string = search_string
       
  1014 
       
  1015   def equals(self, rhs):
       
  1016     """Check to see if the search_string is contained in the rhs string.
       
  1017 
       
  1018     Args:
       
  1019       # rhs: the right hand side of the test
       
  1020       rhs: object
       
  1021 
       
  1022     Returns:
       
  1023       bool
       
  1024     """
       
  1025 
       
  1026     try:
       
  1027       return rhs.find(self._search_string) > -1
       
  1028     except Exception:
       
  1029       return False
       
  1030 
       
  1031   def __repr__(self):
       
  1032     return '<str containing \'%s\'>' % self._search_string
       
  1033 
       
  1034 
       
  1035 class Regex(Comparator):
       
  1036   """Checks if a string matches a regular expression.
       
  1037 
       
  1038   This uses a given regular expression to determine equality.
       
  1039   """
       
  1040 
       
  1041   def __init__(self, pattern, flags=0):
       
  1042     """Initialize.
       
  1043 
       
  1044     Args:
       
  1045       # pattern is the regular expression to search for
       
  1046       pattern: str
       
  1047       # flags passed to re.compile function as the second argument
       
  1048       flags: int
       
  1049     """
       
  1050 
       
  1051     self.regex = re.compile(pattern, flags=flags)
       
  1052 
       
  1053   def equals(self, rhs):
       
  1054     """Check to see if rhs matches regular expression pattern.
       
  1055 
       
  1056     Returns:
       
  1057       bool
       
  1058     """
       
  1059 
       
  1060     return self.regex.search(rhs) is not None
       
  1061 
       
  1062   def __repr__(self):
       
  1063     s = '<regular expression \'%s\'' % self.regex.pattern
       
  1064     if self.regex.flags:
       
  1065       s += ', flags=%d' % self.regex.flags
       
  1066     s += '>'
       
  1067     return s
       
  1068 
       
  1069 
       
  1070 class In(Comparator):
       
  1071   """Checks whether an item (or key) is in a list (or dict) parameter.
       
  1072 
       
  1073   Example:
       
  1074   mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
       
  1075   """
       
  1076 
       
  1077   def __init__(self, key):
       
  1078     """Initialize.
       
  1079 
       
  1080     Args:
       
  1081       # key is any thing that could be in a list or a key in a dict
       
  1082     """
       
  1083 
       
  1084     self._key = key
       
  1085 
       
  1086   def equals(self, rhs):
       
  1087     """Check to see whether key is in rhs.
       
  1088 
       
  1089     Args:
       
  1090       rhs: dict
       
  1091 
       
  1092     Returns:
       
  1093       bool
       
  1094     """
       
  1095 
       
  1096     return self._key in rhs
       
  1097 
       
  1098   def __repr__(self):
       
  1099     return '<sequence or map containing \'%s\'>' % self._key
       
  1100 
       
  1101 
       
  1102 class Not(Comparator):
       
  1103   """Checks whether a predicates is False.
       
  1104 
       
  1105   Example:
       
  1106     mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm', stevepm_user_info)))
       
  1107   """
       
  1108 
       
  1109   def __init__(self, predicate):
       
  1110     """Initialize.
       
  1111 
       
  1112     Args:
       
  1113       # predicate: a Comparator instance.
       
  1114     """
       
  1115 
       
  1116     assert isinstance(predicate, Comparator), ("predicate %r must be a"
       
  1117                                                " Comparator." % predicate)
       
  1118     self._predicate = predicate
       
  1119 
       
  1120   def equals(self, rhs):
       
  1121     """Check to see whether the predicate is False.
       
  1122 
       
  1123     Args:
       
  1124       rhs: A value that will be given in argument of the predicate.
       
  1125 
       
  1126     Returns:
       
  1127       bool
       
  1128     """
       
  1129 
       
  1130     return not self._predicate.equals(rhs)
       
  1131 
       
  1132   def __repr__(self):
       
  1133     return '<not \'%s\'>' % self._predicate
       
  1134 
       
  1135 
       
  1136 class ContainsKeyValue(Comparator):
       
  1137   """Checks whether a key/value pair is in a dict parameter.
       
  1138 
       
  1139   Example:
       
  1140   mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
       
  1141   """
       
  1142 
       
  1143   def __init__(self, key, value):
       
  1144     """Initialize.
       
  1145 
       
  1146     Args:
       
  1147       # key: a key in a dict
       
  1148       # value: the corresponding value
       
  1149     """
       
  1150 
       
  1151     self._key = key
       
  1152     self._value = value
       
  1153 
       
  1154   def equals(self, rhs):
       
  1155     """Check whether the given key/value pair is in the rhs dict.
       
  1156 
       
  1157     Returns:
       
  1158       bool
       
  1159     """
       
  1160 
       
  1161     try:
       
  1162       return rhs[self._key] == self._value
       
  1163     except Exception:
       
  1164       return False
       
  1165 
       
  1166   def __repr__(self):
       
  1167     return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
       
  1168 
       
  1169 
       
  1170 class SameElementsAs(Comparator):
       
  1171   """Checks whether iterables contain the same elements (ignoring order).
       
  1172 
       
  1173   Example:
       
  1174   mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki'))
       
  1175   """
       
  1176 
       
  1177   def __init__(self, expected_seq):
       
  1178     """Initialize.
       
  1179 
       
  1180     Args:
       
  1181       expected_seq: a sequence
       
  1182     """
       
  1183 
       
  1184     self._expected_seq = expected_seq
       
  1185 
       
  1186   def equals(self, actual_seq):
       
  1187     """Check to see whether actual_seq has same elements as expected_seq.
       
  1188 
       
  1189     Args:
       
  1190       actual_seq: sequence
       
  1191 
       
  1192     Returns:
       
  1193       bool
       
  1194     """
       
  1195 
       
  1196     try:
       
  1197       expected = dict([(element, None) for element in self._expected_seq])
       
  1198       actual = dict([(element, None) for element in actual_seq])
       
  1199     except TypeError:
       
  1200       # Fall back to slower list-compare if any of the objects are unhashable.
       
  1201       expected = list(self._expected_seq)
       
  1202       actual = list(actual_seq)
       
  1203       expected.sort()
       
  1204       actual.sort()
       
  1205     return expected == actual
       
  1206 
       
  1207   def __repr__(self):
       
  1208     return '<sequence with same elements as \'%s\'>' % self._expected_seq
       
  1209 
       
  1210 
       
  1211 class And(Comparator):
       
  1212   """Evaluates one or more Comparators on RHS and returns an AND of the results.
       
  1213   """
       
  1214 
       
  1215   def __init__(self, *args):
       
  1216     """Initialize.
       
  1217 
       
  1218     Args:
       
  1219       *args: One or more Comparator
       
  1220     """
       
  1221 
       
  1222     self._comparators = args
       
  1223 
       
  1224   def equals(self, rhs):
       
  1225     """Checks whether all Comparators are equal to rhs.
       
  1226 
       
  1227     Args:
       
  1228       # rhs: can be anything
       
  1229 
       
  1230     Returns:
       
  1231       bool
       
  1232     """
       
  1233 
       
  1234     for comparator in self._comparators:
       
  1235       if not comparator.equals(rhs):
       
  1236         return False
       
  1237 
       
  1238     return True
       
  1239 
       
  1240   def __repr__(self):
       
  1241     return '<AND %s>' % str(self._comparators)
       
  1242 
       
  1243 
       
  1244 class Or(Comparator):
       
  1245   """Evaluates one or more Comparators on RHS and returns an OR of the results.
       
  1246   """
       
  1247 
       
  1248   def __init__(self, *args):
       
  1249     """Initialize.
       
  1250 
       
  1251     Args:
       
  1252       *args: One or more Mox comparators
       
  1253     """
       
  1254 
       
  1255     self._comparators = args
       
  1256 
       
  1257   def equals(self, rhs):
       
  1258     """Checks whether any Comparator is equal to rhs.
       
  1259 
       
  1260     Args:
       
  1261       # rhs: can be anything
       
  1262 
       
  1263     Returns:
       
  1264       bool
       
  1265     """
       
  1266 
       
  1267     for comparator in self._comparators:
       
  1268       if comparator.equals(rhs):
       
  1269         return True
       
  1270 
       
  1271     return False
       
  1272 
       
  1273   def __repr__(self):
       
  1274     return '<OR %s>' % str(self._comparators)
       
  1275 
       
  1276 
       
  1277 class Func(Comparator):
       
  1278   """Call a function that should verify the parameter passed in is correct.
       
  1279 
       
  1280   You may need the ability to perform more advanced operations on the parameter
       
  1281   in order to validate it.  You can use this to have a callable validate any
       
  1282   parameter. The callable should return either True or False.
       
  1283 
       
  1284 
       
  1285   Example:
       
  1286 
       
  1287   def myParamValidator(param):
       
  1288     # Advanced logic here
       
  1289     return True
       
  1290 
       
  1291   mock_dao.DoSomething(Func(myParamValidator), true)
       
  1292   """
       
  1293 
       
  1294   def __init__(self, func):
       
  1295     """Initialize.
       
  1296 
       
  1297     Args:
       
  1298       func: callable that takes one parameter and returns a bool
       
  1299     """
       
  1300 
       
  1301     self._func = func
       
  1302 
       
  1303   def equals(self, rhs):
       
  1304     """Test whether rhs passes the function test.
       
  1305 
       
  1306     rhs is passed into func.
       
  1307 
       
  1308     Args:
       
  1309       rhs: any python object
       
  1310 
       
  1311     Returns:
       
  1312       the result of func(rhs)
       
  1313     """
       
  1314 
       
  1315     return self._func(rhs)
       
  1316 
       
  1317   def __repr__(self):
       
  1318     return str(self._func)
       
  1319 
       
  1320 
       
  1321 class IgnoreArg(Comparator):
       
  1322   """Ignore an argument.
       
  1323 
       
  1324   This can be used when we don't care about an argument of a method call.
       
  1325 
       
  1326   Example:
       
  1327   # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
       
  1328   mymock.CastMagic(3, IgnoreArg(), 'disappear')
       
  1329   """
       
  1330 
       
  1331   def equals(self, unused_rhs):
       
  1332     """Ignores arguments and returns True.
       
  1333 
       
  1334     Args:
       
  1335       unused_rhs: any python object
       
  1336 
       
  1337     Returns:
       
  1338       always returns True
       
  1339     """
       
  1340 
       
  1341     return True
       
  1342 
       
  1343   def __repr__(self):
       
  1344     return '<IgnoreArg>'
       
  1345 
       
  1346 
       
  1347 class MethodGroup(object):
       
  1348   """Base class containing common behaviour for MethodGroups."""
       
  1349 
       
  1350   def __init__(self, group_name):
       
  1351     self._group_name = group_name
       
  1352 
       
  1353   def group_name(self):
       
  1354     return self._group_name
       
  1355 
       
  1356   def __str__(self):
       
  1357     return '<%s "%s">' % (self.__class__.__name__, self._group_name)
       
  1358 
       
  1359   def AddMethod(self, mock_method):
       
  1360     raise NotImplementedError
       
  1361 
       
  1362   def MethodCalled(self, mock_method):
       
  1363     raise NotImplementedError
       
  1364 
       
  1365   def IsSatisfied(self):
       
  1366     raise NotImplementedError
       
  1367 
       
  1368 class UnorderedGroup(MethodGroup):
       
  1369   """UnorderedGroup holds a set of method calls that may occur in any order.
       
  1370 
       
  1371   This construct is helpful for non-deterministic events, such as iterating
       
  1372   over the keys of a dict.
       
  1373   """
       
  1374 
       
  1375   def __init__(self, group_name):
       
  1376     super(UnorderedGroup, self).__init__(group_name)
       
  1377     self._methods = []
       
  1378 
       
  1379   def AddMethod(self, mock_method):
       
  1380     """Add a method to this group.
       
  1381 
       
  1382     Args:
       
  1383       mock_method: A mock method to be added to this group.
       
  1384     """
       
  1385 
       
  1386     self._methods.append(mock_method)
       
  1387 
       
  1388   def MethodCalled(self, mock_method):
       
  1389     """Remove a method call from the group.
       
  1390 
       
  1391     If the method is not in the set, an UnexpectedMethodCallError will be
       
  1392     raised.
       
  1393 
       
  1394     Args:
       
  1395       mock_method: a mock method that should be equal to a method in the group.
       
  1396 
       
  1397     Returns:
       
  1398       The mock method from the group
       
  1399 
       
  1400     Raises:
       
  1401       UnexpectedMethodCallError if the mock_method was not in the group.
       
  1402     """
       
  1403 
       
  1404     # Check to see if this method exists, and if so, remove it from the set
       
  1405     # and return it.
       
  1406     for method in self._methods:
       
  1407       if method == mock_method:
       
  1408         # Remove the called mock_method instead of the method in the group.
       
  1409         # The called method will match any comparators when equality is checked
       
  1410         # during removal.  The method in the group could pass a comparator to
       
  1411         # another comparator during the equality check.
       
  1412         self._methods.remove(mock_method)
       
  1413 
       
  1414         # If this group is not empty, put it back at the head of the queue.
       
  1415         if not self.IsSatisfied():
       
  1416           mock_method._call_queue.appendleft(self)
       
  1417 
       
  1418         return self, method
       
  1419 
       
  1420     raise UnexpectedMethodCallError(mock_method, self)
       
  1421 
       
  1422   def IsSatisfied(self):
       
  1423     """Return True if there are not any methods in this group."""
       
  1424 
       
  1425     return len(self._methods) == 0
       
  1426 
       
  1427 
       
  1428 class MultipleTimesGroup(MethodGroup):
       
  1429   """MultipleTimesGroup holds methods that may be called any number of times.
       
  1430 
       
  1431   Note: Each method must be called at least once.
       
  1432 
       
  1433   This is helpful, if you don't know or care how many times a method is called.
       
  1434   """
       
  1435 
       
  1436   def __init__(self, group_name):
       
  1437     super(MultipleTimesGroup, self).__init__(group_name)
       
  1438     self._methods = set()
       
  1439     self._methods_called = set()
       
  1440 
       
  1441   def AddMethod(self, mock_method):
       
  1442     """Add a method to this group.
       
  1443 
       
  1444     Args:
       
  1445       mock_method: A mock method to be added to this group.
       
  1446     """
       
  1447 
       
  1448     self._methods.add(mock_method)
       
  1449 
       
  1450   def MethodCalled(self, mock_method):
       
  1451     """Remove a method call from the group.
       
  1452 
       
  1453     If the method is not in the set, an UnexpectedMethodCallError will be
       
  1454     raised.
       
  1455 
       
  1456     Args:
       
  1457       mock_method: a mock method that should be equal to a method in the group.
       
  1458 
       
  1459     Returns:
       
  1460       The mock method from the group
       
  1461 
       
  1462     Raises:
       
  1463       UnexpectedMethodCallError if the mock_method was not in the group.
       
  1464     """
       
  1465 
       
  1466     # Check to see if this method exists, and if so add it to the set of
       
  1467     # called methods.
       
  1468 
       
  1469     for method in self._methods:
       
  1470       if method == mock_method:
       
  1471         self._methods_called.add(mock_method)
       
  1472         # Always put this group back on top of the queue, because we don't know
       
  1473         # when we are done.
       
  1474         mock_method._call_queue.appendleft(self)
       
  1475         return self, method
       
  1476 
       
  1477     if self.IsSatisfied():
       
  1478       next_method = mock_method._PopNextMethod();
       
  1479       return next_method, None
       
  1480     else:
       
  1481       raise UnexpectedMethodCallError(mock_method, self)
       
  1482 
       
  1483   def IsSatisfied(self):
       
  1484     """Return True if all methods in this group are called at least once."""
       
  1485     # NOTE(psycho): We can't use the simple set difference here because we want
       
  1486     # to match different parameters which are considered the same e.g. IsA(str)
       
  1487     # and some string. This solution is O(n^2) but n should be small.
       
  1488     tmp = self._methods.copy()
       
  1489     for called in self._methods_called:
       
  1490       for expected in tmp:
       
  1491         if called == expected:
       
  1492           tmp.remove(expected)
       
  1493           if not tmp:
       
  1494             return True
       
  1495           break
       
  1496     return False
       
  1497 
       
  1498 
       
  1499 class MoxMetaTestBase(type):
       
  1500   """Metaclass to add mox cleanup and verification to every test.
       
  1501 
       
  1502   As the mox unit testing class is being constructed (MoxTestBase or a
       
  1503   subclass), this metaclass will modify all test functions to call the
       
  1504   CleanUpMox method of the test class after they finish. This means that
       
  1505   unstubbing and verifying will happen for every test with no additional code,
       
  1506   and any failures will result in test failures as opposed to errors.
       
  1507   """
       
  1508 
       
  1509   def __init__(cls, name, bases, d):
       
  1510     type.__init__(cls, name, bases, d)
       
  1511 
       
  1512     # also get all the attributes from the base classes to account
       
  1513     # for a case when test class is not the immediate child of MoxTestBase
       
  1514     for base in bases:
       
  1515       for attr_name in dir(base):
       
  1516         d[attr_name] = getattr(base, attr_name)
       
  1517 
       
  1518     for func_name, func in d.items():
       
  1519       if func_name.startswith('test') and callable(func):
       
  1520         setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))
       
  1521 
       
  1522   @staticmethod
       
  1523   def CleanUpTest(cls, func):
       
  1524     """Adds Mox cleanup code to any MoxTestBase method.
       
  1525 
       
  1526     Always unsets stubs after a test. Will verify all mocks for tests that
       
  1527     otherwise pass.
       
  1528 
       
  1529     Args:
       
  1530       cls: MoxTestBase or subclass; the class whose test method we are altering.
       
  1531       func: method; the method of the MoxTestBase test class we wish to alter.
       
  1532 
       
  1533     Returns:
       
  1534       The modified method.
       
  1535     """
       
  1536     def new_method(self, *args, **kwargs):
       
  1537       mox_obj = getattr(self, 'mox', None)
       
  1538       cleanup_mox = False
       
  1539       if mox_obj and isinstance(mox_obj, Mox):
       
  1540         cleanup_mox = True
       
  1541       try:
       
  1542         func(self, *args, **kwargs)
       
  1543       finally:
       
  1544         if cleanup_mox:
       
  1545           mox_obj.UnsetStubs()
       
  1546       if cleanup_mox:
       
  1547         mox_obj.VerifyAll()
       
  1548     new_method.__name__ = func.__name__
       
  1549     new_method.__doc__ = func.__doc__
       
  1550     new_method.__module__ = func.__module__
       
  1551     return new_method
       
  1552 
       
  1553 
       
  1554 class MoxTestBase(unittest.TestCase):
       
  1555   """Convenience test class to make stubbing easier.
       
  1556 
       
  1557   Sets up a "mox" attribute which is an instance of Mox - any mox tests will
       
  1558   want this. Also automatically unsets any stubs and verifies that all mock
       
  1559   methods have been called at the end of each test, eliminating boilerplate
       
  1560   code.
       
  1561   """
       
  1562 
       
  1563   __metaclass__ = MoxMetaTestBase
       
  1564 
       
  1565   def setUp(self):
       
  1566     super(MoxTestBase, self).setUp()
       
  1567     self.mox = Mox()