|
1 import re |
|
2 import unittest |
|
3 from urlparse import urlsplit, urlunsplit |
|
4 |
|
5 from django.http import QueryDict |
|
6 from django.db import transaction |
|
7 from django.core import mail |
|
8 from django.core.management import call_command |
|
9 from django.test import _doctest as doctest |
|
10 from django.test.client import Client |
|
11 |
|
12 normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s) |
|
13 |
|
14 def to_list(value): |
|
15 """ |
|
16 Puts value into a list if it's not already one. |
|
17 Returns an empty list if value is None. |
|
18 """ |
|
19 if value is None: |
|
20 value = [] |
|
21 elif not isinstance(value, list): |
|
22 value = [value] |
|
23 return value |
|
24 |
|
25 |
|
26 class OutputChecker(doctest.OutputChecker): |
|
27 def check_output(self, want, got, optionflags): |
|
28 ok = doctest.OutputChecker.check_output(self, want, got, optionflags) |
|
29 |
|
30 # Doctest does an exact string comparison of output, which means long |
|
31 # integers aren't equal to normal integers ("22L" vs. "22"). The |
|
32 # following code normalizes long integers so that they equal normal |
|
33 # integers. |
|
34 if not ok: |
|
35 return normalize_long_ints(want) == normalize_long_ints(got) |
|
36 return ok |
|
37 |
|
38 class DocTestRunner(doctest.DocTestRunner): |
|
39 def __init__(self, *args, **kwargs): |
|
40 doctest.DocTestRunner.__init__(self, *args, **kwargs) |
|
41 self.optionflags = doctest.ELLIPSIS |
|
42 |
|
43 def report_unexpected_exception(self, out, test, example, exc_info): |
|
44 doctest.DocTestRunner.report_unexpected_exception(self, out, test, |
|
45 example, exc_info) |
|
46 # Rollback, in case of database errors. Otherwise they'd have |
|
47 # side effects on other tests. |
|
48 transaction.rollback_unless_managed() |
|
49 |
|
50 class TestCase(unittest.TestCase): |
|
51 def _pre_setup(self): |
|
52 """Performs any pre-test setup. This includes: |
|
53 |
|
54 * Flushing the database. |
|
55 * If the Test Case class has a 'fixtures' member, installing the |
|
56 named fixtures. |
|
57 * Clearing the mail test outbox. |
|
58 """ |
|
59 call_command('flush', verbosity=0, interactive=False) |
|
60 if hasattr(self, 'fixtures'): |
|
61 # We have to use this slightly awkward syntax due to the fact |
|
62 # that we're using *args and **kwargs together. |
|
63 call_command('loaddata', *self.fixtures, **{'verbosity': 0}) |
|
64 mail.outbox = [] |
|
65 |
|
66 def __call__(self, result=None): |
|
67 """ |
|
68 Wrapper around default __call__ method to perform common Django test |
|
69 set up. This means that user-defined Test Cases aren't required to |
|
70 include a call to super().setUp(). |
|
71 """ |
|
72 self.client = Client() |
|
73 try: |
|
74 self._pre_setup() |
|
75 except (KeyboardInterrupt, SystemExit): |
|
76 raise |
|
77 except Exception: |
|
78 import sys |
|
79 result.addError(self, sys.exc_info()) |
|
80 return |
|
81 super(TestCase, self).__call__(result) |
|
82 |
|
83 def assertRedirects(self, response, expected_url, status_code=302, |
|
84 target_status_code=200, host=None): |
|
85 """Asserts that a response redirected to a specific URL, and that the |
|
86 redirect URL can be loaded. |
|
87 |
|
88 Note that assertRedirects won't work for external links since it uses |
|
89 TestClient to do a request. |
|
90 """ |
|
91 self.assertEqual(response.status_code, status_code, |
|
92 ("Response didn't redirect as expected: Response code was %d" |
|
93 " (expected %d)" % (response.status_code, status_code))) |
|
94 url = response['Location'] |
|
95 scheme, netloc, path, query, fragment = urlsplit(url) |
|
96 e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url) |
|
97 if not (e_scheme or e_netloc): |
|
98 expected_url = urlunsplit(('http', host or 'testserver', e_path, |
|
99 e_query, e_fragment)) |
|
100 self.assertEqual(url, expected_url, |
|
101 "Response redirected to '%s', expected '%s'" % (url, expected_url)) |
|
102 |
|
103 # Get the redirection page, using the same client that was used |
|
104 # to obtain the original response. |
|
105 redirect_response = response.client.get(path, QueryDict(query)) |
|
106 self.assertEqual(redirect_response.status_code, target_status_code, |
|
107 ("Couldn't retrieve redirection page '%s': response code was %d" |
|
108 " (expected %d)") % |
|
109 (path, redirect_response.status_code, target_status_code)) |
|
110 |
|
111 def assertContains(self, response, text, count=None, status_code=200): |
|
112 """ |
|
113 Asserts that a response indicates that a page was retrieved |
|
114 successfully, (i.e., the HTTP status code was as expected), and that |
|
115 ``text`` occurs ``count`` times in the content of the response. |
|
116 If ``count`` is None, the count doesn't matter - the assertion is true |
|
117 if the text occurs at least once in the response. |
|
118 """ |
|
119 self.assertEqual(response.status_code, status_code, |
|
120 "Couldn't retrieve page: Response code was %d (expected %d)'" % |
|
121 (response.status_code, status_code)) |
|
122 real_count = response.content.count(text) |
|
123 if count is not None: |
|
124 self.assertEqual(real_count, count, |
|
125 "Found %d instances of '%s' in response (expected %d)" % |
|
126 (real_count, text, count)) |
|
127 else: |
|
128 self.failUnless(real_count != 0, |
|
129 "Couldn't find '%s' in response" % text) |
|
130 |
|
131 def assertFormError(self, response, form, field, errors): |
|
132 """ |
|
133 Asserts that a form used to render the response has a specific field |
|
134 error. |
|
135 """ |
|
136 # Put context(s) into a list to simplify processing. |
|
137 contexts = to_list(response.context) |
|
138 if not contexts: |
|
139 self.fail('Response did not use any contexts to render the' |
|
140 ' response') |
|
141 |
|
142 # Put error(s) into a list to simplify processing. |
|
143 errors = to_list(errors) |
|
144 |
|
145 # Search all contexts for the error. |
|
146 found_form = False |
|
147 for i,context in enumerate(contexts): |
|
148 if form not in context: |
|
149 continue |
|
150 found_form = True |
|
151 for err in errors: |
|
152 if field: |
|
153 if field in context[form].errors: |
|
154 field_errors = context[form].errors[field] |
|
155 self.failUnless(err in field_errors, |
|
156 "The field '%s' on form '%s' in" |
|
157 " context %d does not contain the" |
|
158 " error '%s' (actual errors: %s)" % |
|
159 (field, form, i, err, |
|
160 repr(field_errors))) |
|
161 elif field in context[form].fields: |
|
162 self.fail("The field '%s' on form '%s' in context %d" |
|
163 " contains no errors" % (field, form, i)) |
|
164 else: |
|
165 self.fail("The form '%s' in context %d does not" |
|
166 " contain the field '%s'" % |
|
167 (form, i, field)) |
|
168 else: |
|
169 non_field_errors = context[form].non_field_errors() |
|
170 self.failUnless(err in non_field_errors, |
|
171 "The form '%s' in context %d does not contain the" |
|
172 " non-field error '%s' (actual errors: %s)" % |
|
173 (form, i, err, non_field_errors)) |
|
174 if not found_form: |
|
175 self.fail("The form '%s' was not used to render the response" % |
|
176 form) |
|
177 |
|
178 def assertTemplateUsed(self, response, template_name): |
|
179 """ |
|
180 Asserts that the template with the provided name was used in rendering |
|
181 the response. |
|
182 """ |
|
183 template_names = [t.name for t in to_list(response.template)] |
|
184 if not template_names: |
|
185 self.fail('No templates used to render the response') |
|
186 self.failUnless(template_name in template_names, |
|
187 (u"Template '%s' was not a template used to render the response." |
|
188 u" Actual template(s) used: %s") % (template_name, |
|
189 u', '.join(template_names))) |
|
190 |
|
191 def assertTemplateNotUsed(self, response, template_name): |
|
192 """ |
|
193 Asserts that the template with the provided name was NOT used in |
|
194 rendering the response. |
|
195 """ |
|
196 template_names = [t.name for t in to_list(response.template)] |
|
197 self.failIf(template_name in template_names, |
|
198 (u"Template '%s' was used unexpectedly in rendering the" |
|
199 u" response") % template_name) |