app/django/test/testcases.py
changeset 323 ff1a9aa48cfd
parent 54 03e267d67478
--- a/app/django/test/testcases.py	Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/test/testcases.py	Tue Oct 14 16:00:59 2008 +0000
@@ -1,13 +1,17 @@
 import re
 import unittest
 from urlparse import urlsplit, urlunsplit
+from xml.dom.minidom import parseString, Node
 
-from django.http import QueryDict
-from django.db import transaction
+from django.conf import settings
 from django.core import mail
 from django.core.management import call_command
+from django.core.urlresolvers import clear_url_caches
+from django.db import transaction
+from django.http import QueryDict
 from django.test import _doctest as doctest
 from django.test.client import Client
+from django.utils import simplejson
 
 normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
 
@@ -25,15 +29,133 @@
 
 class OutputChecker(doctest.OutputChecker):
     def check_output(self, want, got, optionflags):
-        ok = doctest.OutputChecker.check_output(self, want, got, optionflags)
+        "The entry method for doctest output checking. Defers to a sequence of child checkers"
+        checks = (self.check_output_default,
+                  self.check_output_long,
+                  self.check_output_xml,
+                  self.check_output_json)
+        for check in checks:
+            if check(want, got, optionflags):
+                return True
+        return False
+
+    def check_output_default(self, want, got, optionflags):
+        "The default comparator provided by doctest - not perfect, but good for most purposes"
+        return doctest.OutputChecker.check_output(self, want, got, optionflags)
+
+    def check_output_long(self, want, got, optionflags):
+        """Doctest does an exact string comparison of output, which means long
+        integers aren't equal to normal integers ("22L" vs. "22"). The
+        following code normalizes long integers so that they equal normal
+        integers.
+        """
+        return normalize_long_ints(want) == normalize_long_ints(got)
+
+    def check_output_xml(self, want, got, optionsflags):
+        """Tries to do a 'xml-comparision' of want and got.  Plain string
+        comparision doesn't always work because, for example, attribute
+        ordering should not be important.
+        
+        Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
+        """
+        _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
+        def norm_whitespace(v):
+            return _norm_whitespace_re.sub(' ', v)
+
+        def child_text(element):
+            return ''.join([c.data for c in element.childNodes
+                            if c.nodeType == Node.TEXT_NODE])
+
+        def children(element):
+            return [c for c in element.childNodes
+                    if c.nodeType == Node.ELEMENT_NODE]
+
+        def norm_child_text(element):
+            return norm_whitespace(child_text(element))
+
+        def attrs_dict(element):
+            return dict(element.attributes.items())
+
+        def check_element(want_element, got_element):
+            if want_element.tagName != got_element.tagName:
+                return False
+            if norm_child_text(want_element) != norm_child_text(got_element):
+                return False
+            if attrs_dict(want_element) != attrs_dict(got_element):
+                return False
+            want_children = children(want_element)
+            got_children = children(got_element)
+            if len(want_children) != len(got_children):
+                return False
+            for want, got in zip(want_children, got_children):
+                if not check_element(want, got):
+                    return False
+            return True
 
-        # Doctest does an exact string comparison of output, which means long
-        # integers aren't equal to normal integers ("22L" vs. "22"). The
-        # following code normalizes long integers so that they equal normal
-        # integers.
-        if not ok:
-            return normalize_long_ints(want) == normalize_long_ints(got)
-        return ok
+        want, got = self._strip_quotes(want, got)
+        want = want.replace('\\n','\n')
+        got = got.replace('\\n','\n')
+
+        # If the string is not a complete xml document, we may need to add a
+        # root element. This allow us to compare fragments, like "<foo/><bar/>"
+        if not want.startswith('<?xml'):
+            wrapper = '<root>%s</root>'
+            want = wrapper % want
+            got = wrapper % got
+            
+        # Parse the want and got strings, and compare the parsings.
+        try:
+            want_root = parseString(want).firstChild
+            got_root = parseString(got).firstChild
+        except:
+            return False
+        return check_element(want_root, got_root)
+
+    def check_output_json(self, want, got, optionsflags):
+        "Tries to compare want and got as if they were JSON-encoded data"
+        want, got = self._strip_quotes(want, got)
+        try:
+            want_json = simplejson.loads(want)
+            got_json = simplejson.loads(got)
+        except:
+            return False
+        return want_json == got_json
+
+    def _strip_quotes(self, want, got):
+        """
+        Strip quotes of doctests output values:
+
+        >>> o = OutputChecker()
+        >>> o._strip_quotes("'foo'")
+        "foo"
+        >>> o._strip_quotes('"foo"')
+        "foo"
+        >>> o._strip_quotes("u'foo'")
+        "foo"
+        >>> o._strip_quotes('u"foo"')
+        "foo"
+        """
+        def is_quoted_string(s):
+            s = s.strip()
+            return (len(s) >= 2
+                    and s[0] == s[-1]
+                    and s[0] in ('"', "'"))
+
+        def is_quoted_unicode(s):
+            s = s.strip()
+            return (len(s) >= 3
+                    and s[0] == 'u'
+                    and s[1] == s[-1]
+                    and s[1] in ('"', "'"))
+
+        if is_quoted_string(want) and is_quoted_string(got):
+            want = want.strip()[1:-1]
+            got = got.strip()[1:-1]
+        elif is_quoted_unicode(want) and is_quoted_unicode(got):
+            want = want.strip()[2:-1]
+            got = got.strip()[2:-1]
+        return want, got
+
 
 class DocTestRunner(doctest.DocTestRunner):
     def __init__(self, *args, **kwargs):
@@ -54,6 +176,8 @@
             * Flushing the database.
             * If the Test Case class has a 'fixtures' member, installing the 
               named fixtures.
+            * If the Test Case class has a 'urls' member, replace the
+              ROOT_URLCONF with it.
             * Clearing the mail test outbox.
         """
         call_command('flush', verbosity=0, interactive=False)
@@ -61,6 +185,10 @@
             # We have to use this slightly awkward syntax due to the fact
             # that we're using *args and **kwargs together.
             call_command('loaddata', *self.fixtures, **{'verbosity': 0})
+        if hasattr(self, 'urls'):
+            self._old_root_urlconf = settings.ROOT_URLCONF
+            settings.ROOT_URLCONF = self.urls
+            clear_url_caches()
         mail.outbox = []
 
     def __call__(self, result=None):
@@ -79,6 +207,23 @@
             result.addError(self, sys.exc_info())
             return
         super(TestCase, self).__call__(result)
+        try:
+            self._post_teardown()
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except Exception:
+            import sys
+            result.addError(self, sys.exc_info())
+            return
+
+    def _post_teardown(self):
+        """ Performs any post-test things. This includes:
+
+            * Putting back the original ROOT_URLCONF if it was changed.
+        """
+        if hasattr(self, '_old_root_urlconf'):
+            settings.ROOT_URLCONF = self._old_root_urlconf
+            clear_url_caches()
 
     def assertRedirects(self, response, expected_url, status_code=302,
                         target_status_code=200, host=None):
@@ -128,6 +273,18 @@
             self.failUnless(real_count != 0,
                             "Couldn't find '%s' in response" % text)
 
+    def assertNotContains(self, response, text, status_code=200):
+        """
+        Asserts that a response indicates that a page was retrieved
+        successfully, (i.e., the HTTP status code was as expected), and that
+        ``text`` doesn't occurs in the content of the response.
+        """
+        self.assertEqual(response.status_code, status_code,
+            "Couldn't retrieve page: Response code was %d (expected %d)'" %
+                (response.status_code, status_code))
+        self.assertEqual(response.content.count(text), 0,
+                         "Response should not contain '%s'" % text)
+
     def assertFormError(self, response, form, field, errors):
         """
         Asserts that a form used to render the response has a specific field