#!/usr/bin/python2.4
#
# Copyright 2008 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
"""Mox, an object-mocking framework for Python.

Mox works in the record-replay-verify paradigm.  When you first create
a mock object, it is in record mode.  You then programmatically set
the expected behavior of the mock object (what methods are to be
called on it, with what parameters, what they should return, and in
what order).

Once you have set up the expected mock behavior, you put it in replay
mode.  Now the mock responds to method calls just as you told it to.
If an unexpected method (or an expected method with unexpected
parameters) is called, then an exception will be raised.

Once you are done interacting with the mock, you need to verify that
all the expected interactions occurred.  (Maybe your code exited
prematurely without calling some cleanup method!)  The verify phase
ensures that every expected method was called; otherwise, an exception
will be raised.

Suggested usage / workflow:

  # Create Mox factory
  my_mox = Mox()

  # Create a mock data access object
  mock_dao = my_mox.CreateMock(DAOClass)

  # Set up expected behavior
  mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
  mock_dao.DeletePerson(person)

  # Put mocks in replay mode
  my_mox.ReplayAll()

  # Inject mock object and run test
  controller.SetDao(mock_dao)
  controller.DeletePersonById('1')

  # Verify all methods were called as expected
  my_mox.VerifyAll()
"""
|
58 |
|
59 from collections import deque |
|
60 import inspect |
|
61 import re |
|
62 import types |
|
63 import unittest |
|
64 |
|
65 import stubout |
|
66 |
|
class Error(AssertionError):
  """Root of the exception hierarchy raised by this module.

  Because it subclasses AssertionError, every Mox error is caught wherever
  assertion failures are caught.
  """
|
71 |
|
72 |
|
class ExpectedMethodCallsError(Error):
  """Raised by Verify() while some expected methods are still uncalled."""

  def __init__(self, expected_methods):
    """Keep the list of methods that were expected but never called.

    Args:
      # expected_methods: the MockMethod objects that should have been
      # called, in queue order.
      expected_methods: [MockMethod]

    Raises:
      ValueError: if expected_methods is empty.
    """
    if not expected_methods:
      raise ValueError("There must be at least one expected method")
    Error.__init__(self)
    self._expected_methods = expected_methods

  def __str__(self):
    # One numbered line per pending method; numbering starts at 0.
    lines = []
    for index, method in enumerate(self._expected_methods):
      lines.append("%3d. %s" % (index, method))
    return "Verify: Expected methods never called:\n" + "\n".join(lines)
|
98 |
|
99 |
|
class UnexpectedMethodCallError(Error):
  """Raised when a mocked method is called that was not expected.

  Typical causes are calling a method with parameters other than the recorded
  ones, or calling it out of the recorded order.
  """

  def __init__(self, unexpected_method, expected):
    """Remember the offending call and what was expected instead.

    Args:
      # unexpected_method: the MockMethod that was called but did not match
      #   the head of the expected-method queue.
      # expected: the MockMethod or UnorderedGroup the call should have
      #   matched; None when the expected-call queue was already empty.
      unexpected_method: MockMethod
      expected: MockMethod or UnorderedGroup or None
    """
    Error.__init__(self)
    self._unexpected_method, self._expected = unexpected_method, expected

  def __str__(self):
    template = "Unexpected method call: %s. Expecting: %s"
    return template % (self._unexpected_method, self._expected)
|
126 |
|
127 |
|
class UnknownMethodCallError(Error):
  """Raised when a mock is asked for a method the mocked object lacks."""

  def __init__(self, unknown_method_name):
    """Remember the name that was requested.

    Args:
      # unknown_method_name: name that is not part of the mocked class's
      #   public interface.
      unknown_method_name: str
    """
    Error.__init__(self)
    self._unknown_method_name = unknown_method_name

  def __str__(self):
    return ("Method called is not a member of the object: %s"
            % self._unknown_method_name)
|
146 |
|
147 |
|
class Mox(object):
  """Mox: a factory for creating mock objects.

  Every mock made through CreateMock or CreateMockAnything is remembered, so
  ReplayAll, VerifyAll and ResetAll can act on all of them at once.
  """

  # Attribute types that StubOutWithMock replaces with a MockObject; any
  # other attribute type is replaced with a MockAnything.
  _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
                      types.ObjectType, types.TypeType]

  def __init__(self):
    """Create a factory with no mocks and a fresh stub manager."""
    self.stubs = stubout.StubOutForTesting()
    self._mock_objects = []

  def CreateMock(self, class_to_mock):
    """Create a MockObject that enforces the interface of class_to_mock.

    Args:
      # class_to_mock: the class to be mocked
      class_to_mock: class

    Returns:
      MockObject that can be used as the class_to_mock would be.
    """
    mock_object = MockObject(class_to_mock)
    self._mock_objects.append(mock_object)
    return mock_object

  def CreateMockAnything(self):
    """Create a MockAnything, a mock that accepts any method call.

    No interface is enforced on the returned mock.
    """
    mock_anything = MockAnything()
    self._mock_objects.append(mock_anything)
    return mock_anything

  def ReplayAll(self):
    """Switch every mock created by this factory to replay mode."""
    for tracked in self._mock_objects:
      tracked._Replay()

  def VerifyAll(self):
    """Verify every mock created by this factory, in creation order."""
    for tracked in self._mock_objects:
      tracked._Verify()

  def ResetAll(self):
    """Reset every mock created by this factory.

    Stubs installed through self.stubs are left in place; see UnsetStubs.
    """
    for tracked in self._mock_objects:
      tracked._Reset()

  def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
    """Replace obj.attr_name with a freshly created mock.

    Classes, instances, modules and types (see _USE_MOCK_OBJECT) become a
    MockObject; anything else (method, function, etc.) becomes a
    MockAnything.  Passing use_mock_anything=True always picks MockAnything.

    Args:
      obj: A Python object (class, module, instance, callable).
      attr_name: str. The name of the attribute to replace with a mock.
      use_mock_anything: bool. True if a MockAnything should be used regardless
          of the type of attribute.
    """
    current_value = getattr(obj, attr_name)
    use_mock_object = (not use_mock_anything and
                       type(current_value) in self._USE_MOCK_OBJECT)
    if use_mock_object:
      replacement = self.CreateMock(current_value)
    else:
      replacement = self.CreateMockAnything()
    self.stubs.Set(obj, attr_name, replacement)

  def UnsetStubs(self):
    """Restore stubbed attributes by delegating to self.stubs.UnsetAll()."""
    self.stubs.UnsetAll()
|
232 |
|
def Replay(*args):
  """Switch each of the given mocks into replay mode.

  Args:
    # args: any number of mocks; each has its _Replay() called, in order.
  """
  for mock_obj in args:
    mock_obj._Replay()
|
242 |
|
243 |
|
def Verify(*args):
  """Verify each of the given mocks.

  Args:
    # args: any number of mocks; each has its _Verify() called, in order.
    # The first failure propagates and later mocks are not verified.
  """
  for mock_obj in args:
    mock_obj._Verify()
|
253 |
|
254 |
|
def Reset(*args):
  """Reset each of the given mocks.

  Args:
    # args: any number of mocks; each has its _Reset() called, in order.
  """
  for mock_obj in args:
    mock_obj._Reset()
|
264 |
|
265 |
|
class MockAnything:
  """A mock that can be used to mock anything.

  This is helpful for mocking classes that do not provide a public interface.
  Any attribute not found through normal lookup is handed to __getattr__,
  which turns it into a MockMethod, so no interface is enforced.

  The mock starts in record mode (calls are appended to an expected-call
  queue) and is switched to replay mode by _Replay().
  """
  # NOTE(review): declared without an explicit base class (a classic class
  # under Python 2). MockObject's __init__ calls this class's __init__ through
  # MockAnything.__dict__ because it "is not a proper object". Confirm the
  # intent before adding `object` as a base: under Python 2 this changes how
  # special-method lookups reach __getattr__.

  def __init__(self):
    """Create the mock in record mode with an empty expected-call queue."""
    self._Reset()

  def __getattr__(self, method_name):
    """Intercept method calls on this object.

    Only reached for attributes that normal lookup does not find. A new
    MockMethod is returned that is aware of the MockAnything's state (record
    or replay). The call will be recorded or replayed by the MockMethod's
    __call__.

    Args:
      # method name: the name of the method being called.
      method_name: str

    Returns:
      A new MockMethod aware of MockAnything's state (record or replay).
    """

    return self._CreateMockMethod(method_name)

  def _CreateMockMethod(self, method_name, method_to_mock=None):
    """Create a new mock method call and return it.

    The MockMethod shares this mock's expected-call queue (a deque, so it is
    not copied) but receives the replay flag by value: it reflects the mode
    at the moment it was created.

    Args:
      # method_name: the name of the method being called.
      # method_to_mock: The actual method being mocked, used for introspection.
      method_name: str
      method_to_mock: a method object

    Returns:
      A new MockMethod aware of MockAnything's state (record or replay).
    """

    return MockMethod(method_name, self._expected_calls_queue,
                      self._replay_mode, method_to_mock=method_to_mock)

  def __nonzero__(self):
    """Return 1 for nonzero so the mock can be used as a conditional.

    NOTE(review): Python 3 consults __bool__ for truth testing, not this hook.
    """

    return 1

  def __eq__(self, rhs):
    """Equal to another MockAnything in the same mode with an equal queue."""

    return (isinstance(rhs, MockAnything) and
            self._replay_mode == rhs._replay_mode and
            self._expected_calls_queue == rhs._expected_calls_queue)

  def __ne__(self, rhs):
    """Negation of __eq__."""

    return not self == rhs

  def _Replay(self):
    """Start replaying expected method calls.

    Only MockMethods created after this call see replay mode, because the
    flag is copied into each MockMethod when it is created.
    """

    self._replay_mode = True

  def _Verify(self):
    """Verify that all of the expected calls have been made.

    Raises:
      ExpectedMethodCallsError: if there are still more method calls in the
        expected queue.
    """

    # If the list of expected calls is not empty, raise an exception
    if self._expected_calls_queue:
      # The last MultipleTimesGroup is not popped from the queue.
      # A lone, satisfied MultipleTimesGroup left at the end therefore
      # counts as "all calls made".
      if (len(self._expected_calls_queue) == 1 and
          isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and
          self._expected_calls_queue[0].IsSatisfied()):
        pass
      else:
        raise ExpectedMethodCallsError(self._expected_calls_queue)

  def _Reset(self):
    """Reset the state of this mock to record mode with an empty queue."""

    # Maintain a list of method calls we are expecting
    self._expected_calls_queue = deque()

    # Make sure we are in setup mode, not replay mode
    self._replay_mode = False
|
357 |
|
358 |
|
class MockObject(MockAnything, object):
  """A mock object that simulates the public/protected interface of a class.

  Attribute access is limited to the names dir() reports for the mocked
  class (including inherited ones). Non-callable names are read straight
  from the mocked class, and callable names become MockMethods. Item
  assignment, subscription, membership tests and direct calls are only
  allowed when the mocked class provides the matching special method.
  """

  def __init__(self, class_to_mock):
    """Initialize a mock object.

    This determines the methods and properties of the class and stores them.

    Args:
      # class_to_mock: class to be mocked
      class_to_mock: class
    """

    # This is used to hack around the mixin/inheritance of MockAnything, which
    # is not a proper object (it can be anything. :-)
    MockAnything.__dict__['__init__'](self)

    # Get a list of all the public and special methods we should mock.
    # dir() includes inherited attributes, so methods defined on base classes
    # (including special methods such as __getitem__) end up here too.
    self._known_methods = set()
    self._known_vars = set()
    self._class_to_mock = class_to_mock
    for method in dir(class_to_mock):
      if callable(getattr(class_to_mock, method)):
        self._known_methods.add(method)
      else:
        self._known_vars.add(method)

  def __getattr__(self, name):
    """Intercept attribute request on this object.

    If the attribute is a public class variable, it will be returned and not
    recorded as a call.

    If the attribute is not a variable, it is handled like a method
    call. The method name is checked against the set of mockable
    methods, and a new MockMethod is returned that is aware of the
    MockObject's state (record or replay).  The call will be recorded
    or replayed by the MockMethod's __call__.

    Args:
      # name: the name of the attribute being requested.
      name: str

    Returns:
      Either a class variable or a new MockMethod that is aware of the state
      of the mock (record or replay).

    Raises:
      UnknownMethodCallError if the MockObject does not mock the requested
      method.
    """

    if name in self._known_vars:
      return getattr(self._class_to_mock, name)

    if name in self._known_methods:
      return self._CreateMockMethod(
          name,
          method_to_mock=getattr(self._class_to_mock, name))

    raise UnknownMethodCallError(name)

  def __eq__(self, rhs):
    """Equal to another MockObject of the same class, mode and queue."""

    return (isinstance(rhs, MockObject) and
            self._class_to_mock == rhs._class_to_mock and
            self._replay_mode == rhs._replay_mode and
            self._expected_calls_queue == rhs._expected_calls_queue)

  def __setitem__(self, key, value):
    """Provide custom logic for mocking classes that support item assignment.

    Args:
      key: Key to set the value for.
      value: Value to set.

    Returns:
      Expected return value in replay mode. A MockMethod object for the
      __setitem__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class does not support item assignment.
      UnexpectedMethodCallError if the object does not expect the call to
        __setitem__.
    """
    # _known_methods is built from dir(), so an inherited __setitem__ counts.
    # A __setitem__ explicitly set to None is not callable and is rejected.
    if '__setitem__' not in self._known_methods:
      raise TypeError('object does not support item assignment')

    # The MockMethod records the call in record mode and verifies it against
    # the expected-call queue in replay mode.
    return self._CreateMockMethod('__setitem__')(key, value)

  def __getitem__(self, key):
    """Provide custom logic for mocking classes that are subscriptable.

    Args:
      key: Key to return the value for.

    Returns:
      Expected return value in replay mode. A MockMethod object for the
      __getitem__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class is not subscriptable.
      UnexpectedMethodCallError if the object does not expect the call to
        __getitem__.
    """
    # Verify the class (or one of its bases) is subscriptable.
    if '__getitem__' not in self._known_methods:
      raise TypeError('unsubscriptable object')

    return self._CreateMockMethod('__getitem__')(key)

  def __contains__(self, key):
    """Provide custom logic for mocking classes that contain items.

    Args:
      key: Key to look in container for.

    Returns:
      Expected return value in replay mode. A MockMethod object for the
      __contains__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class does not implement __contains__
      UnexpectedMethodCallError if the object does not expect the call to
        __contains__.
    """
    if '__contains__' not in self._known_methods:
      raise TypeError('unsubscriptable object')

    return self._CreateMockMethod('__contains__')(key)

  def __call__(self, *params, **named_params):
    """Provide custom logic for mocking classes that are callable.

    Raises:
      TypeError if the underlying class does not define (or inherit)
        __call__.
    """
    # Verify the class we are mocking is callable.
    if '__call__' not in self._known_methods:
      raise TypeError('Not callable')

    # Because the call is happening directly on this object instead of a method,
    # the call on the mock method is made right here
    mock_method = self._CreateMockMethod('__call__')
    return mock_method(*params, **named_params)

  @property
  def __class__(self):
    """Return the class that is being mocked.

    Makes isinstance() checks against the mocked class succeed.
    """

    return self._class_to_mock
|
537 |
|
538 |
|
class MethodCallChecker(object):
  """Ensures that methods are called correctly.

  Built from the argument specification of a real callable, a checker tells
  whether a recorded call (positional plus keyword parameters) could bind to
  that callable's signature.
  """

  # Per-argument states tracked while checking one call.
  _NEEDED, _DEFAULT, _GIVEN = range(3)

  def __init__(self, method):
    """Creates a checker.

    Args:
      # method: A method to check.
      method: function

    Raises:
      ValueError: method could not be inspected, so checks aren't possible.
        Some methods and functions like built-ins can't be inspected.
    """
    # inspect.getargspec was removed in Python 3.11. Prefer getfullargspec
    # when the interpreter has it (it also reports keyword-only arguments)
    # and fall back to getargspec on older interpreters.
    getfullargspec = getattr(inspect, 'getfullargspec', None)
    try:
      if getfullargspec is not None:
        spec = getfullargspec(method)
        args, varargs, varkw, defaults = spec[:4]
        kwonlyargs = spec[4]
        kwonlydefaults = spec[5] or {}
      else:
        args, varargs, varkw, defaults = inspect.getargspec(method)
        kwonlyargs = []
        kwonlydefaults = {}
    except TypeError:
      raise ValueError('Could not get argument specification for %r'
                       % (method,))
    if inspect.ismethod(method):
      args = args[1:]  # Skip 'self'.
    self._args = args
    self._method = method

    self._has_varargs = varargs is not None
    self._has_varkw = varkw is not None
    # `not defaults` also covers an empty tuple; slicing with -0 would
    # otherwise mark every argument as defaulted and none as required.
    if not defaults:
      self._required_args = self._args
      self._default_args = []
    else:
      self._required_args = self._args[:-len(defaults)]
      self._default_args = self._args[-len(defaults):]

    # Keyword-only arguments (Python 3) can only be supplied by name.
    self._kwonly_required = [a for a in kwonlyargs if a not in kwonlydefaults]
    self._kwonly_default = [a for a in kwonlyargs if a in kwonlydefaults]

  def _MethodName(self):
    """Return a printable name for the checked callable."""
    # Callables such as functools.partial objects have no __name__.
    return getattr(self._method, '__name__', repr(self._method))

  def _RecordArgumentGiven(self, arg_name, arg_status):
    """Mark an argument as being given.

    Args:
      # arg_name: The name of the argument to mark in arg_status.
      # arg_status: Maps argument names to one of _NEEDED, _DEFAULT, _GIVEN.
      arg_name: string
      arg_status: dict

    Raises:
      AttributeError: arg_name is already marked as _GIVEN.
    """
    if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN:
      raise AttributeError('%s provided more than once' % (arg_name,))
    arg_status[arg_name] = MethodCallChecker._GIVEN

  def Check(self, params, named_params):
    """Ensures that the parameters used while recording a call are valid.

    Args:
      # params: A list of positional parameters.
      # named_params: A dict of named parameters.
      params: list
      named_params: dict

    Raises:
      AttributeError: the given parameters don't work with the given method.
    """
    arg_status = dict((a, MethodCallChecker._NEEDED)
                      for a in self._required_args)
    for arg in self._default_args:
      arg_status[arg] = MethodCallChecker._DEFAULT
    for arg in self._kwonly_required:
      arg_status[arg] = MethodCallChecker._NEEDED
    for arg in self._kwonly_default:
      arg_status[arg] = MethodCallChecker._DEFAULT

    # Bind positional params to the named positional arguments, in order.
    for arg_name in self._args[:len(params)]:
      self._RecordArgumentGiven(arg_name, arg_status)
    # Anything beyond the named positional arguments needs *varargs.
    if len(params) > len(self._args) and not self._has_varargs:
      raise AttributeError('%s does not take %d or more positional '
                           'arguments'
                           % (self._MethodName(), len(self._args) + 1))

    # Check each keyword argument.
    for arg_name in named_params:
      if arg_name not in arg_status and not self._has_varkw:
        raise AttributeError('%s is not expecting keyword argument %s'
                             % (self._MethodName(), arg_name))
      self._RecordArgumentGiven(arg_name, arg_status)

    # Ensure all the required arguments have been given.
    still_needed = [k for k, v in arg_status.items()
                    if v == MethodCallChecker._NEEDED]
    if still_needed:
      raise AttributeError('No values given for arguments %s'
                           % (' '.join(sorted(still_needed))))
|
630 |
|
631 |
|
class MockMethod(object):
  """Callable mock method.

  A MockMethod should act exactly like the method it mocks, accepting parameters
  and returning a value, or throwing an exception (as specified).  When this
  method is called, it can optionally verify whether the called method (name and
  signature) matches the expected method.

  In record mode, calling a MockMethod appends it to the shared call queue and
  returns the MockMethod itself, so builders such as AndReturn, AndRaise,
  WithSideEffects, InAnyOrder and MultipleTimes can be applied to the call.
  In replay mode, a call is matched against the head of the queue.
  """

  def __init__(self, method_name, call_queue, replay_mode, method_to_mock=None):
    """Construct a new mock method.

    Args:
      # method_name: the name of the method
      # call_queue: deque of calls, verify this call against the head, or add
      #   this call to the queue.
      # replay_mode: False if we are recording, True if we are verifying calls
      #   against the call queue.
      # method_to_mock: The actual method being mocked, used for introspection.
      method_name: str
      call_queue: list or deque
      replay_mode: bool
      method_to_mock: a method object
    """

    self._name = method_name
    self._call_queue = call_queue
    # A non-deque queue is copied into a new deque, so later appends/pops
    # affect only that copy and not the caller's sequence.
    if not isinstance(call_queue, deque):
      self._call_queue = deque(self._call_queue)
    self._replay_mode = replay_mode

    # _params/_named_params are filled in by __call__; the rest are set by
    # the AndReturn / AndRaise / WithSideEffects builders.
    self._params = None
    self._named_params = None
    self._return_value = None
    self._exception = None
    self._side_effects = None

    # Signature checking is optional: MethodCallChecker raises ValueError when
    # method_to_mock cannot be introspected (presumably including the default
    # None), and then recorded calls are not checked.
    try:
      self._checker = MethodCallChecker(method_to_mock)
    except ValueError:
      self._checker = None

  def __call__(self, *params, **named_params):
    """Log parameters and return the specified return value.

    If the Mock(Anything/Object) associated with this call is in record mode,
    this MockMethod will be pushed onto the expected call queue.  If the mock
    is in replay mode, this will pop a MockMethod off the top of the queue and
    verify this call is equal to the expected call.

    Returns:
      In record mode, this MockMethod. In replay mode, the return value
      configured on the matching expected method.

    Raises:
      AttributeError if, in record mode, the parameters do not fit the mocked
        method's signature (only when a checker could be built).
      UnexpectedMethodCallError if this call is supposed to match an expected
        method call and it does not.
      The exception configured with AndRaise on the matching expected method.
    """

    self._params = params
    self._named_params = named_params

    if not self._replay_mode:
      if self._checker is not None:
        self._checker.Check(params, named_params)
      self._call_queue.append(self)
      return self

    expected_method = self._VerifyMethodCall()

    # Side effects run before any configured exception is raised.
    if expected_method._side_effects:
      expected_method._side_effects(*params, **named_params)

    if expected_method._exception:
      raise expected_method._exception

    return expected_method._return_value

  def __getattr__(self, name):
    """Raise an AttributeError with a helpful message.

    Only reached for attributes MockMethod does not define.
    """

    raise AttributeError('MockMethod has no attribute "%s". '
                         'Did you remember to put your mocks in replay mode?' % name)

  def _PopNextMethod(self):
    """Pop the next method from our call queue.

    Raises:
      UnexpectedMethodCallError (with expected=None) if the queue is empty.
    """
    try:
      return self._call_queue.popleft()
    except IndexError:
      raise UnexpectedMethodCallError(self, None)

  def _VerifyMethodCall(self):
    """Verify the called method is expected.

    This can be an ordered method, or part of an unordered set.

    Returns:
      The expected mock method.

    Raises:
      UnexpectedMethodCallError if the method called was not expected.
    """

    expected = self._PopNextMethod()

    # Loop here, because we might have a MethodGroup followed by another
    # group.
    # NOTE(review): MethodGroup.MethodCalled is defined elsewhere; judging by
    # the unpacking it returns (next expected entry, matched method or None).
    while isinstance(expected, MethodGroup):
      expected, method = expected.MethodCalled(self)
      if method is not None:
        return method

    # This is a mock method, so just check equality.
    if expected != self:
      raise UnexpectedMethodCallError(self, expected)

    return expected

  def __str__(self):
    # Renders e.g. "Foo(1, 'x', key=2) -> None"; named params sorted by name.
    params = ', '.join(
        [repr(p) for p in self._params or []] +
        ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
    desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
    return desc

  def __eq__(self, rhs):
    """Test whether this MockMethod is equivalent to another MockMethod.

    Only the name and the call parameters are compared; return values,
    exceptions and side effects are ignored.

    Args:
      # rhs: the right hand side of the test
      rhs: MockMethod
    """

    return (isinstance(rhs, MockMethod) and
            self._name == rhs._name and
            self._params == rhs._params and
            self._named_params == rhs._named_params)

  def __ne__(self, rhs):
    """Test whether this MockMethod is not equivalent to another MockMethod.

    Args:
      # rhs: the right hand side of the test
      rhs: MockMethod
    """

    return not self == rhs

  def GetPossibleGroup(self):
    """Returns a possible group from the end of the call queue or None if no
    other methods are on the stack.

    Side effect: removes this method from the tail of the queue, where the
    record-mode __call__ just appended it.
    """

    # Remove this method from the tail of the queue so we can add it to a group.
    this_method = self._call_queue.pop()
    assert this_method == self

    # Determine if the tail of the queue is a group, or just a regular ordered
    # mock method.
    group = None
    try:
      group = self._call_queue[-1]
    except IndexError:
      pass

    return group

  def _CheckAndCreateNewGroup(self, group_name, group_class):
    """Checks if the last method (a possible group) is an instance of our
    group_class. Adds the current method to this group or creates a new one.

    Args:

      group_name: the name of the group.
      group_class: the class used to create instance of this new group

    Returns:
      self
    """
    group = self.GetPossibleGroup()

    # If this is a group, and it is the correct group, add the method.
    if isinstance(group, group_class) and group.group_name() == group_name:
      group.AddMethod(self)
      return self

    # Create a new group and add the method.
    new_group = group_class(group_name)
    new_group.AddMethod(self)
    self._call_queue.append(new_group)
    return self

  def InAnyOrder(self, group_name="default"):
    """Move this method into a group of unordered calls.

    A group of unordered calls must be defined together, and must be executed
    in full before the next expected method can be called.  There can be
    multiple groups that are expected serially, if they are given
    different group names.  The same group name can be reused if there is a
    standard method call, or a group with a different name, spliced between
    usages.

    Args:
      group_name: the name of the unordered group.

    Returns:
      self
    """
    return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)

  def MultipleTimes(self, group_name="default"):
    """Move this method into group of calls which may be called multiple times.

    A group of repeating calls must be defined together, and must be executed in
    full before the next expected method can be called.

    Args:
      group_name: the name of the group of repeating calls.

    Returns:
      self
    """
    return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)

  def AndReturn(self, return_value):
    """Set the value to return when this method is called.

    Args:
      # return_value can be anything.

    Returns:
      return_value itself (not self), so AndReturn ends a builder chain.
    """

    self._return_value = return_value
    return return_value

  def AndRaise(self, exception):
    """Set the exception to raise when this method is called.

    Args:
      # exception: the exception to raise when this method is called.
      exception: Exception

    Returns:
      None.
    """

    self._exception = exception

  def WithSideEffects(self, side_effects):
    """Set the side effects that are simulated when this method is called.

    Args:
      side_effects: A callable which modifies the parameters or other relevant
        state which a given test case depends on. It is called with the same
        positional and keyword parameters as the replayed call.

    Returns:
      Self for chaining with AndReturn and AndRaise.
    """
    self._side_effects = side_effects
    return self
|
881 |
|
class Comparator:
  """Base class for all Mox comparators.

  A Comparator can be used as a parameter to a mocked method when the exact
  value is not known.  For example, the code you are testing might build up a
  long SQL string that is passed to your mock DAO. You're only interested that
  the IN clause contains the proper primary keys, so you can set your mock
  up as follows:

  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)

  Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.

  A Comparator may replace one or more parameters, for example:
  # return at most 10 rows
  mock_dao.RunQuery(StrContains('SELECT'), 10)

  or

  # Return some non-deterministic number of rows
  mock_dao.RunQuery(StrContains('SELECT'), IsA(int))

  Subclasses implement equals(); == and != delegate to it.
  """

  def equals(self, rhs):
    """Special equals method that all comparators must implement.

    Args:
      rhs: any python object

    Returns:
      bool: True if rhs matches this comparator.

    Raises:
      NotImplementedError: always, in this base class.
    """

    # Call form works on every Python version; the old "raise E, msg" form
    # is a syntax error on Python 3.
    raise NotImplementedError('method must be implemented by a subclass.')

  def __eq__(self, rhs):
    return self.equals(rhs)

  def __ne__(self, rhs):
    return not self.equals(rhs)
|
919 |
|
920 |
|
class IsA(Comparator):
  """Comparator that matches arguments of a given type or class.

  Example:
    mock_dao.Connect(IsA(DbConnectInfo))
  """

  def __init__(self, class_name):
    """Remember the type to match against.

    Args:
      class_name: basic python type or a class
    """
    self._class_name = class_name

  def equals(self, rhs):
    """Return whether rhs is an instance of the stored type or class.

    Args:
      # rhs: the right hand side of the test
      rhs: object

    Returns:
      bool
    """
    wanted = self._class_name
    try:
      matched = isinstance(rhs, wanted)
    except TypeError:
      # isinstance() rejected `wanted` as a type, so fall back to comparing
      # the exact raw types of the two objects. This is helpful for things
      # like cStringIO.StringIO.
      matched = type(rhs) == type(wanted)
    return matched

  def __repr__(self):
    return '%s' % (self._class_name,)
|
958 |
|
class IsAlmost(Comparator):
  """Matches a parameter that is nearly equal to a given value.

  Generally useful for floating point numbers.

  Example:
    mock_dao.SetTimeout(IsAlmost(3.9))
  """

  def __init__(self, float_value, places=7):
    """Remember the reference value and the comparison precision.

    Args:
      float_value: The value for making the comparison.
      places: The number of decimal places the difference is rounded to.
    """
    self._float_value = float_value
    self._places = places

  def equals(self, rhs):
    """Report whether rhs and float_value agree to `places` decimal places.

    Args:
      rhs: the value to compare to float_value.

    Returns:
      bool; False when the subtraction or rounding is not possible.
    """
    try:
      difference = rhs - self._float_value
      return round(difference, self._places) == 0
    except TypeError:
      # Most likely float_value or rhs is not a number.
      return False

  def __repr__(self):
    return str(self._float_value)
|
995 |
|
class StrContains(Comparator):
  """Matches a string parameter that contains a given substring.

  This can be useful in mocking a database with SQL passed in as a string
  parameter, for example.

  Example:
    mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
  """

  def __init__(self, search_string):
    """Remember the substring to look for.

    Args:
      search_string: str; the string you are searching for.
    """
    self._search_string = search_string

  def equals(self, rhs):
    """Report whether search_string occurs in rhs.

    Args:
      rhs: object; the right hand side of the test.

    Returns:
      bool; False if rhs.find() is unavailable or fails.
    """
    try:
      found = rhs.find(self._search_string) > -1
    except Exception:
      # Deliberately broad: anything that cannot be searched simply does not
      # match.
      found = False
    return found

  def __repr__(self):
    return "<str containing '%s'>" % self._search_string
|
1033 |
|
1034 |
|
class Regex(Comparator):
  """Checks if a string matches a regular expression.

  This uses a given regular expression to determine equality. Values that the
  compiled pattern cannot search (for example None or numbers) do not match.
  """

  def __init__(self, pattern, flags=0):
    """Initialize.

    Args:
      pattern: str; the regular expression to search for.
      flags: int; passed to re.compile as its flags argument.
    """
    self.regex = re.compile(pattern, flags=flags)

  def equals(self, rhs):
    """Check to see if rhs matches regular expression pattern.

    re.search semantics are used, so the pattern may match anywhere in rhs.

    Args:
      rhs: any python object.

    Returns:
      bool; False when rhs is not something the pattern can be searched in.
    """
    try:
      return self.regex.search(rhs) is not None
    except TypeError:
      # e.g. None, ints, or (under Python 3) bytes against a str pattern.
      # Report a mismatch, like the other comparators, instead of letting the
      # exception escape from ==.
      return False

  def __repr__(self):
    s = '<regular expression \'%s\'' % self.regex.pattern
    if self.regex.flags:
      s += ', flags=%d' % self.regex.flags
    s += '>'
    return s
|
1068 |
|
1069 |
|
class In(Comparator):
  """Checks whether an item (or key) is in a list (or dict) parameter.

  Example:
    mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
  """

  def __init__(self, key):
    """Initialize.

    Args:
      key: anything that could be an item in a list or a key in a dict.
    """
    self._key = key

  def equals(self, rhs):
    """Check to see whether key is in rhs.

    Args:
      rhs: a container such as a list or dict.

    Returns:
      bool; False if the membership test raises TypeError.
    """
    try:
      return self._key in rhs
    except TypeError:
      # rhs is not a container (e.g. None or an int), or key is unhashable
      # and rhs is a dict/set; either way key cannot be in rhs.
      return False

  def __repr__(self):
    # Format a 1-tuple so that a tuple key is rendered as a single value
    # rather than being unpacked into the format arguments.
    return '<sequence or map containing \'%s\'>' % (self._key,)
|
1100 |
|
1101 |
|
class Not(Comparator):
  """Checks whether a predicate is False.

  Example:
    mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm', stevepm_user_info)))
  """

  def __init__(self, predicate):
    """Initialize.

    Args:
      predicate: a Comparator instance.

    Raises:
      AssertionError: if predicate is not a Comparator (only while assertions
        are enabled).
    """
    # The message is formatted from a 1-tuple so that a tuple passed as
    # predicate still yields this AssertionError instead of a TypeError.
    assert isinstance(predicate, Comparator), (
        "predicate %r must be a Comparator." % (predicate,))
    self._predicate = predicate

  def equals(self, rhs):
    """Check to see whether the predicate is False.

    Args:
      rhs: A value that will be given in argument of the predicate.

    Returns:
      bool
    """
    return not self._predicate.equals(rhs)

  def __repr__(self):
    return '<not \'%s\'>' % self._predicate
|
1134 |
|
1135 |
|
class ContainsKeyValue(Comparator):
  """Matches a dict parameter that holds a given key/value pair.

  Example:
    mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
  """

  def __init__(self, key, value):
    """Remember the entry to look for.

    Args:
      key: a key in a dict.
      value: the value expected under that key.
    """
    self._key = key
    self._value = value

  def equals(self, rhs):
    """Report whether rhs[key] compares equal to value.

    Args:
      rhs: the object to index with key.

    Returns:
      bool; False if the lookup or the comparison raises.
    """
    try:
      matched = rhs[self._key] == self._value
    except Exception:
      # Missing key, non-subscriptable rhs or a failing comparison all count
      # as "entry not present".
      matched = False
    return matched

  def __repr__(self):
    return "<map containing the entry '%s: %s'>" % (self._key, self._value)
|
1168 |
|
1169 |
|
class SameElementsAs(Comparator):
  """Checks whether iterables contain the same elements (ignoring order).

  Example:
    mock_dao.ProcessUsers(SameElementsAs(['stevepm', 'salomaki']))
  """

  def __init__(self, expected_seq):
    """Initialize.

    Args:
      expected_seq: a sequence (any iterable) of the expected elements.
    """
    self._expected_seq = expected_seq

  def equals(self, actual_seq):
    """Check to see whether actual_seq has same elements as expected_seq.

    When every element is hashable only the distinct elements are compared,
    so repetitions are not counted. Otherwise the elements are compared as
    multisets, i.e. repetitions must match as well.

    Args:
      actual_seq: sequence

    Returns:
      bool; False if actual_seq is not iterable.
    """
    try:
      # Materialize once, so a one-shot iterator is not half-consumed by the
      # hashable fast path before a fallback gets to look at it.
      actual = list(actual_seq)
    except TypeError:
      # Not iterable, so it cannot contain the expected elements.
      return False
    expected = list(self._expected_seq)

    try:
      return (dict([(element, None) for element in expected]) ==
              dict([(element, None) for element in actual]))
    except TypeError:
      # Fall back to slower comparisons if any of the objects are unhashable.
      pass

    try:
      expected.sort()
      actual.sort()
    except TypeError:
      # The elements cannot be ordered either (e.g. dicts or Comparators
      # under Python 3); pair them up using == instead.
      return self._MatchUnordered(expected, actual)
    return expected == actual

  @staticmethod
  def _MatchUnordered(expected, actual):
    """Return True if expected and actual are equal as multisets, using ==."""
    if len(expected) != len(actual):
      return False
    remaining = list(actual)
    for element in expected:
      for index, candidate in enumerate(remaining):
        if element == candidate:
          del remaining[index]
          break
      else:
        return False
    return True

  def __repr__(self):
    # Format a 1-tuple so that a tuple expected_seq is rendered as a single
    # value rather than being unpacked into the format arguments.
    return '<sequence with same elements as \'%s\'>' % (self._expected_seq,)
|
1209 |
|
1210 |
|
class And(Comparator):
  """Evaluates one or more Comparators on RHS and returns an AND of the results.
  """

  def __init__(self, *args):
    """Store the comparators to combine.

    Args:
      *args: One or more Comparator
    """
    self._comparators = args

  def equals(self, rhs):
    """Report whether every comparator accepts rhs.

    Comparators are consulted in order and evaluation stops at the first one
    that rejects rhs.

    Args:
      rhs: can be anything

    Returns:
      bool
    """
    for comparator in self._comparators:
      if not comparator.equals(rhs):
        break
    else:
      # No comparator rejected rhs (vacuously true when there are none).
      return True
    return False

  def __repr__(self):
    return '<AND %s>' % (str(self._comparators),)
|
1242 |
|
1243 |
|
class Or(Comparator):
  """Evaluates one or more Comparators on RHS and returns an OR of the results.
  """

  def __init__(self, *args):
    """Store the comparators to combine.

    Args:
      *args: One or more Mox comparators
    """
    self._comparators = args

  def equals(self, rhs):
    """Report whether at least one comparator accepts rhs.

    Comparators are consulted in order and evaluation stops at the first one
    that accepts rhs.

    Args:
      rhs: can be anything

    Returns:
      bool
    """
    matched = False
    for comparator in self._comparators:
      if comparator.equals(rhs):
        matched = True
        break
    return matched

  def __repr__(self):
    return '<OR %s>' % (str(self._comparators),)
|
1275 |
|
1276 |
|
class Func(Comparator):
  """Call a function that should verify the parameter passed in is correct.

  You may need the ability to perform more advanced operations on the parameter
  in order to validate it. You can use this to have a callable validate any
  parameter. The callable should return either True or False.

  Example:

    def myParamValidator(param):
      # Advanced logic here
      return True

    mock_dao.DoSomething(Func(myParamValidator), True)
  """

  def __init__(self, func):
    """Remember the validating callable.

    Args:
      func: callable that takes one parameter and returns a bool.
    """
    self._func = func

  def equals(self, rhs):
    """Run the validating callable on rhs.

    Args:
      rhs: any python object; passed straight to func.

    Returns:
      whatever func(rhs) returns, unchanged.
    """
    validate = self._func
    result = validate(rhs)
    return result

  def __repr__(self):
    return '%s' % (self._func,)
|
1319 |
|
1320 |
|
class IgnoreArg(Comparator):
  """Matches any argument at all.

  Use it when a method call's argument is irrelevant to the test.

  Example:
    # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
    mymock.CastMagic(3, IgnoreArg(), 'disappear')
  """

  def equals(self, unused_rhs):
    """Accept the argument without looking at it.

    Args:
      unused_rhs: any python object

    Returns:
      True, unconditionally.
    """
    del unused_rhs  # The value is intentionally irrelevant.
    return True

  def __repr__(self):
    return '<IgnoreArg>'
|
1345 |
|
1346 |
|
class MethodGroup(object):
  """Common base for groups of expected mock method calls.

  Subclasses define how calls in the group are matched by overriding
  AddMethod, MethodCalled and IsSatisfied; the base versions just raise
  NotImplementedError.
  """

  def __init__(self, group_name):
    """Remember the name this group was created with.

    Args:
      group_name: identifier of the group; returned by group_name() and shown
        by str().
    """
    self._group_name = group_name

  def group_name(self):
    """Return the name given to this group at construction time."""
    return self._group_name

  def __str__(self):
    kind = self.__class__.__name__
    return '<%s "%s">' % (kind, self._group_name)

  def AddMethod(self, mock_method):
    """Add mock_method to the group; must be overridden."""
    raise NotImplementedError

  def MethodCalled(self, mock_method):
    """Handle a call of mock_method against the group; must be overridden."""
    raise NotImplementedError

  def IsSatisfied(self):
    """Report whether the group's expectations are met; must be overridden."""
    raise NotImplementedError
|
1367 |
|
class UnorderedGroup(MethodGroup):
  """UnorderedGroup holds a set of method calls that may occur in any order.

  This construct is helpful for non-deterministic events, such as iterating
  over the keys of a dict.
  """

  def __init__(self, group_name):
    super(UnorderedGroup, self).__init__(group_name)
    # Expected methods not yet called, in the order they were added.
    self._methods = []

  def AddMethod(self, mock_method):
    """Add a method to this group.

    Args:
      mock_method: A mock method to be added to this group.
    """
    self._methods.append(mock_method)

  def MethodCalled(self, mock_method):
    """Remove a method call from the group.

    If the method is not in the set, an UnexpectedMethodCallError will be
    raised.

    Args:
      mock_method: a mock method that should be equal to a method in the group.

    Returns:
      A (group, method) tuple: this group and the expected method from the
      group that compared equal to mock_method.

    Raises:
      UnexpectedMethodCallError if the mock_method was not in the group.
    """
    # Check to see if this method exists, and if so, remove it from the set
    # and return it.
    for method in self._methods:
      if method == mock_method:
        # Remove the called mock_method instead of the method in the group.
        # The called method will match any comparators when equality is checked
        # during removal. The method in the group could pass a comparator to
        # another comparator during the equality check.
        self._methods.remove(mock_method)

        # If this group is not empty, put it back at the head of the queue.
        if not self.IsSatisfied():
          mock_method._call_queue.appendleft(self)

        return self, method

    raise UnexpectedMethodCallError(mock_method, self)

  def IsSatisfied(self):
    """Return True if there are not any methods in this group."""
    return not self._methods
|
1426 |
|
1427 |
|
class MultipleTimesGroup(MethodGroup):
  """MultipleTimesGroup holds methods that may be called any number of times.

  Note: Each method must be called at least once.

  This is helpful, if you don't know or care how many times a method is called.
  """

  def __init__(self, group_name):
    super(MultipleTimesGroup, self).__init__(group_name)
    # Every expected method of the group.
    self._methods = set()
    # The actual calls seen so far that matched one of the expected methods.
    self._methods_called = set()

  def AddMethod(self, mock_method):
    """Add a method to this group.

    Args:
      mock_method: A mock method to be added to this group.
    """
    self._methods.add(mock_method)

  def MethodCalled(self, mock_method):
    """Record a method call against the group.

    If the method is not in the group and the group is not yet satisfied, an
    UnexpectedMethodCallError will be raised.

    Args:
      mock_method: a mock method that should be equal to a method in the group.

    Returns:
      A (group, method) tuple: this group and the expected method that compared
      equal to mock_method. If nothing in the group matches but the group is
      already satisfied, (next_method, None), where next_method is the result
      of mock_method._PopNextMethod().

    Raises:
      UnexpectedMethodCallError if the mock_method was not in the group and
      the group is not yet satisfied.
    """
    # Check to see if this method exists, and if so add it to the set of
    # called methods.
    for method in self._methods:
      if method == mock_method:
        self._methods_called.add(mock_method)
        # Always put this group back on top of the queue, because we don't know
        # when we are done.
        mock_method._call_queue.appendleft(self)
        return self, method

    if self.IsSatisfied():
      next_method = mock_method._PopNextMethod()
      return next_method, None
    else:
      raise UnexpectedMethodCallError(mock_method, self)

  def IsSatisfied(self):
    """Return True if all methods in this group are called at least once."""
    # NOTE(psycho): We can't use the simple set difference here because we want
    # to match different parameters which are considered the same e.g. IsA(str)
    # and some string. This solution is O(n^2) but n should be small.
    tmp = self._methods.copy()
    if not tmp:
      # Nothing is expected, so nothing is left to call (consistent with an
      # empty UnorderedGroup).
      return True
    for called in self._methods_called:
      for expected in tmp:
        if called == expected:
          # Safe despite iterating tmp: we break right after removing.
          tmp.remove(expected)
          if not tmp:
            return True
          break
    return False
|
1497 |
|
1498 |
|
class MoxMetaTestBase(type):
  """Metaclass to add mox cleanup and verification to every test.

  As the mox unit testing class is being constructed (MoxTestBase or a
  subclass), this metaclass will modify all test functions to call the
  CleanUpMox method of the test class after they finish. This means that
  unstubbing and verifying will happen for every test with no additional code,
  and any failures will result in test failures as opposed to errors.
  """

  def __init__(cls, name, bases, d):
    type.__init__(cls, name, bases, d)

    # Also consider the attributes of the base classes, to account for the
    # case where the test class is not the immediate child of MoxTestBase.
    # Attributes the class defines itself must win over inherited ones,
    # otherwise an overriding test method would be replaced by the base
    # class's version. Among the bases, the first one listed wins. A local
    # dict is used so the caller's namespace dict is left untouched.
    attrs = {}
    for base in bases:
      for attr_name in dir(base):
        if attr_name not in attrs:
          attrs[attr_name] = getattr(base, attr_name)
    attrs.update(d)

    for func_name, func in attrs.items():
      if func_name.startswith('test') and callable(func):
        setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))

  @staticmethod
  def CleanUpTest(cls, func):
    """Adds Mox cleanup code to any MoxTestBase method.

    Always unsets stubs after a test. Will verify all mocks for tests that
    otherwise pass.

    Args:
      cls: MoxTestBase or subclass; the class whose test method we are altering.
      func: method; the method of the MoxTestBase test class we wish to alter.

    Returns:
      The modified method.
    """
    def new_method(self, *args, **kwargs):
      mox_obj = getattr(self, 'mox', None)
      cleanup_mox = False
      if mox_obj and isinstance(mox_obj, Mox):
        cleanup_mox = True
      try:
        func(self, *args, **kwargs)
      finally:
        # Stubs are removed even if the test raised.
        if cleanup_mox:
          mox_obj.UnsetStubs()
      # Only reached when the test body did not raise.
      if cleanup_mox:
        mox_obj.VerifyAll()
    new_method.__name__ = func.__name__
    new_method.__doc__ = func.__doc__
    new_method.__module__ = func.__module__
    return new_method
|
1552 |
|
1553 |
|
1554 class MoxTestBase(unittest.TestCase): |
|
1555 """Convenience test class to make stubbing easier. |
|
1556 |
|
1557 Sets up a "mox" attribute which is an instance of Mox - any mox tests will |
|
1558 want this. Also automatically unsets any stubs and verifies that all mock |
|
1559 methods have been called at the end of each test, eliminating boilerplate |
|
1560 code. |
|
1561 """ |
|
1562 |
|
1563 __metaclass__ = MoxMetaTestBase |
|
1564 |
|
1565 def setUp(self): |
|
1566 super(MoxTestBase, self).setUp() |
|
1567 self.mox = Mox() |