--- a/app/django/test/testcases.py Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/test/testcases.py Tue Oct 14 16:00:59 2008 +0000
@@ -1,13 +1,17 @@
import re
import unittest
from urlparse import urlsplit, urlunsplit
+from xml.dom.minidom import parseString, Node
-from django.http import QueryDict
-from django.db import transaction
+from django.conf import settings
from django.core import mail
from django.core.management import call_command
+from django.core.urlresolvers import clear_url_caches
+from django.db import transaction
+from django.http import QueryDict
from django.test import _doctest as doctest
from django.test.client import Client
+from django.utils import simplejson
normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
@@ -25,15 +29,133 @@
class OutputChecker(doctest.OutputChecker):
def check_output(self, want, got, optionflags):
- ok = doctest.OutputChecker.check_output(self, want, got, optionflags)
+ "The entry method for doctest output checking. Defers to a sequence of child checkers"
+ checks = (self.check_output_default,
+ self.check_output_long,
+ self.check_output_xml,
+ self.check_output_json)
+ for check in checks:
+ if check(want, got, optionflags):
+ return True
+ return False
+
+ def check_output_default(self, want, got, optionflags):
+ "The default comparator provided by doctest - not perfect, but good for most purposes"
+ return doctest.OutputChecker.check_output(self, want, got, optionflags)
+
+ def check_output_long(self, want, got, optionflags):
+ """Doctest does an exact string comparison of output, which means long
+ integers aren't equal to normal integers ("22L" vs. "22"). The
+ following code normalizes long integers so that they equal normal
+ integers.
+ """
+ return normalize_long_ints(want) == normalize_long_ints(got)
+
+ def check_output_xml(self, want, got, optionsflags):
+ """Tries to do a 'xml-comparision' of want and got. Plain string
+ comparision doesn't always work because, for example, attribute
+ ordering should not be important.
+
+ Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
+ """
+ _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
+ def norm_whitespace(v):
+ return _norm_whitespace_re.sub(' ', v)
+
+ def child_text(element):
+ return ''.join([c.data for c in element.childNodes
+ if c.nodeType == Node.TEXT_NODE])
+
+ def children(element):
+ return [c for c in element.childNodes
+ if c.nodeType == Node.ELEMENT_NODE]
+
+ def norm_child_text(element):
+ return norm_whitespace(child_text(element))
+
+ def attrs_dict(element):
+ return dict(element.attributes.items())
+
+ def check_element(want_element, got_element):
+ if want_element.tagName != got_element.tagName:
+ return False
+ if norm_child_text(want_element) != norm_child_text(got_element):
+ return False
+ if attrs_dict(want_element) != attrs_dict(got_element):
+ return False
+ want_children = children(want_element)
+ got_children = children(got_element)
+ if len(want_children) != len(got_children):
+ return False
+ for want, got in zip(want_children, got_children):
+ if not check_element(want, got):
+ return False
+ return True
- # Doctest does an exact string comparison of output, which means long
- # integers aren't equal to normal integers ("22L" vs. "22"). The
- # following code normalizes long integers so that they equal normal
- # integers.
- if not ok:
- return normalize_long_ints(want) == normalize_long_ints(got)
- return ok
+ want, got = self._strip_quotes(want, got)
+ want = want.replace('\\n','\n')
+ got = got.replace('\\n','\n')
+
+ # If the string is not a complete xml document, we may need to add a
+ # root element. This allow us to compare fragments, like "<foo/><bar/>"
+ if not want.startswith('<?xml'):
+ wrapper = '<root>%s</root>'
+ want = wrapper % want
+ got = wrapper % got
+
+ # Parse the want and got strings, and compare the parsings.
+ try:
+ want_root = parseString(want).firstChild
+ got_root = parseString(got).firstChild
+ except:
+ return False
+ return check_element(want_root, got_root)
+
+ def check_output_json(self, want, got, optionsflags):
+ "Tries to compare want and got as if they were JSON-encoded data"
+ want, got = self._strip_quotes(want, got)
+ try:
+ want_json = simplejson.loads(want)
+ got_json = simplejson.loads(got)
+ except:
+ return False
+ return want_json == got_json
+
+ def _strip_quotes(self, want, got):
+ """
+ Strip quotes of doctests output values:
+
+ >>> o = OutputChecker()
+ >>> o._strip_quotes("'foo'")
+ "foo"
+ >>> o._strip_quotes('"foo"')
+ "foo"
+ >>> o._strip_quotes("u'foo'")
+ "foo"
+ >>> o._strip_quotes('u"foo"')
+ "foo"
+ """
+ def is_quoted_string(s):
+ s = s.strip()
+ return (len(s) >= 2
+ and s[0] == s[-1]
+ and s[0] in ('"', "'"))
+
+ def is_quoted_unicode(s):
+ s = s.strip()
+ return (len(s) >= 3
+ and s[0] == 'u'
+ and s[1] == s[-1]
+ and s[1] in ('"', "'"))
+
+ if is_quoted_string(want) and is_quoted_string(got):
+ want = want.strip()[1:-1]
+ got = got.strip()[1:-1]
+ elif is_quoted_unicode(want) and is_quoted_unicode(got):
+ want = want.strip()[2:-1]
+ got = got.strip()[2:-1]
+ return want, got
+
class DocTestRunner(doctest.DocTestRunner):
def __init__(self, *args, **kwargs):
@@ -54,6 +176,8 @@
* Flushing the database.
* If the Test Case class has a 'fixtures' member, installing the
named fixtures.
+ * If the Test Case class has a 'urls' member, replace the
+ ROOT_URLCONF with it.
* Clearing the mail test outbox.
"""
call_command('flush', verbosity=0, interactive=False)
@@ -61,6 +185,10 @@
# We have to use this slightly awkward syntax due to the fact
# that we're using *args and **kwargs together.
call_command('loaddata', *self.fixtures, **{'verbosity': 0})
+ if hasattr(self, 'urls'):
+ self._old_root_urlconf = settings.ROOT_URLCONF
+ settings.ROOT_URLCONF = self.urls
+ clear_url_caches()
mail.outbox = []
def __call__(self, result=None):
@@ -79,6 +207,23 @@
result.addError(self, sys.exc_info())
return
super(TestCase, self).__call__(result)
+ try:
+ self._post_teardown()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except Exception:
+ import sys
+ result.addError(self, sys.exc_info())
+ return
+
+ def _post_teardown(self):
+ """ Performs any post-test things. This includes:
+
+ * Putting back the original ROOT_URLCONF if it was changed.
+ """
+ if hasattr(self, '_old_root_urlconf'):
+ settings.ROOT_URLCONF = self._old_root_urlconf
+ clear_url_caches()
def assertRedirects(self, response, expected_url, status_code=302,
target_status_code=200, host=None):
@@ -128,6 +273,18 @@
self.failUnless(real_count != 0,
"Couldn't find '%s' in response" % text)
+ def assertNotContains(self, response, text, status_code=200):
+ """
+ Asserts that a response indicates that a page was retrieved
+ successfully, (i.e., the HTTP status code was as expected), and that
+ ``text`` doesn't occurs in the content of the response.
+ """
+ self.assertEqual(response.status_code, status_code,
+ "Couldn't retrieve page: Response code was %d (expected %d)'" %
+ (response.status_code, status_code))
+ self.assertEqual(response.content.count(text), 0,
+ "Response should not contain '%s'" % text)
+
def assertFormError(self, response, form, field, errors):
"""
Asserts that a form used to render the response has a specific field