app/django/test/testcases.py
changeset 323 ff1a9aa48cfd
parent 54 03e267d67478
equal deleted inserted replaced
322:6641e941ef1e 323:ff1a9aa48cfd
     1 import re
     1 import re
     2 import unittest
     2 import unittest
     3 from urlparse import urlsplit, urlunsplit
     3 from urlparse import urlsplit, urlunsplit
     4 
     4 from xml.dom.minidom import parseString, Node
     5 from django.http import QueryDict
     5 
     6 from django.db import transaction
     6 from django.conf import settings
     7 from django.core import mail
     7 from django.core import mail
     8 from django.core.management import call_command
     8 from django.core.management import call_command
       
     9 from django.core.urlresolvers import clear_url_caches
       
    10 from django.db import transaction
       
    11 from django.http import QueryDict
     9 from django.test import _doctest as doctest
    12 from django.test import _doctest as doctest
    10 from django.test.client import Client
    13 from django.test.client import Client
       
    14 from django.utils import simplejson
    11 
    15 
    12 normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
    16 normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
    13 
    17 
    14 def to_list(value):
    18 def to_list(value):
    15     """
    19     """
    23     return value
    27     return value
    24 
    28 
    25 
    29 
    26 class OutputChecker(doctest.OutputChecker):
    30 class OutputChecker(doctest.OutputChecker):
    27     def check_output(self, want, got, optionflags):
    31     def check_output(self, want, got, optionflags):
    28         ok = doctest.OutputChecker.check_output(self, want, got, optionflags)
    32         "The entry method for doctest output checking. Defers to a sequence of child checkers"
    29 
    33         checks = (self.check_output_default,
    30         # Doctest does an exact string comparison of output, which means long
    34                   self.check_output_long,
    31         # integers aren't equal to normal integers ("22L" vs. "22"). The
    35                   self.check_output_xml,
    32         # following code normalizes long integers so that they equal normal
    36                   self.check_output_json)
    33         # integers.
    37         for check in checks:
    34         if not ok:
    38             if check(want, got, optionflags):
    35             return normalize_long_ints(want) == normalize_long_ints(got)
    39                 return True
    36         return ok
    40         return False
       
    41 
       
    42     def check_output_default(self, want, got, optionflags):
       
    43         "The default comparator provided by doctest - not perfect, but good for most purposes"
       
    44         return doctest.OutputChecker.check_output(self, want, got, optionflags)
       
    45 
       
    46     def check_output_long(self, want, got, optionflags):
       
    47         """Doctest does an exact string comparison of output, which means long
       
    48         integers aren't equal to normal integers ("22L" vs. "22"). The
       
    49         following code normalizes long integers so that they equal normal
       
    50         integers.
       
    51         """
       
    52         return normalize_long_ints(want) == normalize_long_ints(got)
       
    53 
       
    54     def check_output_xml(self, want, got, optionsflags):
       
    55         """Tries to do a 'xml-comparision' of want and got.  Plain string
       
    56         comparision doesn't always work because, for example, attribute
       
    57         ordering should not be important.
       
    58         
       
    59         Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
       
    60         """
       
    61         _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
       
    62         def norm_whitespace(v):
       
    63             return _norm_whitespace_re.sub(' ', v)
       
    64 
       
    65         def child_text(element):
       
    66             return ''.join([c.data for c in element.childNodes
       
    67                             if c.nodeType == Node.TEXT_NODE])
       
    68 
       
    69         def children(element):
       
    70             return [c for c in element.childNodes
       
    71                     if c.nodeType == Node.ELEMENT_NODE]
       
    72 
       
    73         def norm_child_text(element):
       
    74             return norm_whitespace(child_text(element))
       
    75 
       
    76         def attrs_dict(element):
       
    77             return dict(element.attributes.items())
       
    78 
       
    79         def check_element(want_element, got_element):
       
    80             if want_element.tagName != got_element.tagName:
       
    81                 return False
       
    82             if norm_child_text(want_element) != norm_child_text(got_element):
       
    83                 return False
       
    84             if attrs_dict(want_element) != attrs_dict(got_element):
       
    85                 return False
       
    86             want_children = children(want_element)
       
    87             got_children = children(got_element)
       
    88             if len(want_children) != len(got_children):
       
    89                 return False
       
    90             for want, got in zip(want_children, got_children):
       
    91                 if not check_element(want, got):
       
    92                     return False
       
    93             return True
       
    94 
       
    95         want, got = self._strip_quotes(want, got)
       
    96         want = want.replace('\\n','\n')
       
    97         got = got.replace('\\n','\n')
       
    98 
       
    99         # If the string is not a complete xml document, we may need to add a
       
   100         # root element. This allow us to compare fragments, like "<foo/><bar/>"
       
   101         if not want.startswith('<?xml'):
       
   102             wrapper = '<root>%s</root>'
       
   103             want = wrapper % want
       
   104             got = wrapper % got
       
   105             
       
   106         # Parse the want and got strings, and compare the parsings.
       
   107         try:
       
   108             want_root = parseString(want).firstChild
       
   109             got_root = parseString(got).firstChild
       
   110         except:
       
   111             return False
       
   112         return check_element(want_root, got_root)
       
   113 
       
   114     def check_output_json(self, want, got, optionsflags):
       
   115         "Tries to compare want and got as if they were JSON-encoded data"
       
   116         want, got = self._strip_quotes(want, got)
       
   117         try:
       
   118             want_json = simplejson.loads(want)
       
   119             got_json = simplejson.loads(got)
       
   120         except:
       
   121             return False
       
   122         return want_json == got_json
       
   123 
       
   124     def _strip_quotes(self, want, got):
       
   125         """
       
   126         Strip quotes of doctests output values:
       
   127 
       
   128         >>> o = OutputChecker()
       
   129         >>> o._strip_quotes("'foo'")
       
   130         "foo"
       
   131         >>> o._strip_quotes('"foo"')
       
   132         "foo"
       
   133         >>> o._strip_quotes("u'foo'")
       
   134         "foo"
       
   135         >>> o._strip_quotes('u"foo"')
       
   136         "foo"
       
   137         """
       
   138         def is_quoted_string(s):
       
   139             s = s.strip()
       
   140             return (len(s) >= 2
       
   141                     and s[0] == s[-1]
       
   142                     and s[0] in ('"', "'"))
       
   143 
       
   144         def is_quoted_unicode(s):
       
   145             s = s.strip()
       
   146             return (len(s) >= 3
       
   147                     and s[0] == 'u'
       
   148                     and s[1] == s[-1]
       
   149                     and s[1] in ('"', "'"))
       
   150 
       
   151         if is_quoted_string(want) and is_quoted_string(got):
       
   152             want = want.strip()[1:-1]
       
   153             got = got.strip()[1:-1]
       
   154         elif is_quoted_unicode(want) and is_quoted_unicode(got):
       
   155             want = want.strip()[2:-1]
       
   156             got = got.strip()[2:-1]
       
   157         return want, got
       
   158 
    37 
   159 
    38 class DocTestRunner(doctest.DocTestRunner):
   160 class DocTestRunner(doctest.DocTestRunner):
    39     def __init__(self, *args, **kwargs):
   161     def __init__(self, *args, **kwargs):
    40         doctest.DocTestRunner.__init__(self, *args, **kwargs)
   162         doctest.DocTestRunner.__init__(self, *args, **kwargs)
    41         self.optionflags = doctest.ELLIPSIS
   163         self.optionflags = doctest.ELLIPSIS
    52         """Performs any pre-test setup. This includes:
   174         """Performs any pre-test setup. This includes:
    53 
   175 
    54             * Flushing the database.
   176             * Flushing the database.
    55             * If the Test Case class has a 'fixtures' member, installing the 
   177             * If the Test Case class has a 'fixtures' member, installing the 
    56               named fixtures.
   178               named fixtures.
       
   179             * If the Test Case class has a 'urls' member, replace the
       
   180               ROOT_URLCONF with it.
    57             * Clearing the mail test outbox.
   181             * Clearing the mail test outbox.
    58         """
   182         """
    59         call_command('flush', verbosity=0, interactive=False)
   183         call_command('flush', verbosity=0, interactive=False)
    60         if hasattr(self, 'fixtures'):
   184         if hasattr(self, 'fixtures'):
    61             # We have to use this slightly awkward syntax due to the fact
   185             # We have to use this slightly awkward syntax due to the fact
    62             # that we're using *args and **kwargs together.
   186             # that we're using *args and **kwargs together.
    63             call_command('loaddata', *self.fixtures, **{'verbosity': 0})
   187             call_command('loaddata', *self.fixtures, **{'verbosity': 0})
       
   188         if hasattr(self, 'urls'):
       
   189             self._old_root_urlconf = settings.ROOT_URLCONF
       
   190             settings.ROOT_URLCONF = self.urls
       
   191             clear_url_caches()
    64         mail.outbox = []
   192         mail.outbox = []
    65 
   193 
    66     def __call__(self, result=None):
   194     def __call__(self, result=None):
    67         """
   195         """
    68         Wrapper around default __call__ method to perform common Django test
   196         Wrapper around default __call__ method to perform common Django test
    77         except Exception:
   205         except Exception:
    78             import sys
   206             import sys
    79             result.addError(self, sys.exc_info())
   207             result.addError(self, sys.exc_info())
    80             return
   208             return
    81         super(TestCase, self).__call__(result)
   209         super(TestCase, self).__call__(result)
       
   210         try:
       
   211             self._post_teardown()
       
   212         except (KeyboardInterrupt, SystemExit):
       
   213             raise
       
   214         except Exception:
       
   215             import sys
       
   216             result.addError(self, sys.exc_info())
       
   217             return
       
   218 
       
   219     def _post_teardown(self):
       
   220         """ Performs any post-test things. This includes:
       
   221 
       
   222             * Putting back the original ROOT_URLCONF if it was changed.
       
   223         """
       
   224         if hasattr(self, '_old_root_urlconf'):
       
   225             settings.ROOT_URLCONF = self._old_root_urlconf
       
   226             clear_url_caches()
    82 
   227 
    83     def assertRedirects(self, response, expected_url, status_code=302,
   228     def assertRedirects(self, response, expected_url, status_code=302,
    84                         target_status_code=200, host=None):
   229                         target_status_code=200, host=None):
    85         """Asserts that a response redirected to a specific URL, and that the
   230         """Asserts that a response redirected to a specific URL, and that the
    86         redirect URL can be loaded.
   231         redirect URL can be loaded.
   125                 "Found %d instances of '%s' in response (expected %d)" %
   270                 "Found %d instances of '%s' in response (expected %d)" %
   126                     (real_count, text, count))
   271                     (real_count, text, count))
   127         else:
   272         else:
   128             self.failUnless(real_count != 0,
   273             self.failUnless(real_count != 0,
   129                             "Couldn't find '%s' in response" % text)
   274                             "Couldn't find '%s' in response" % text)
       
   275 
       
   276     def assertNotContains(self, response, text, status_code=200):
       
   277         """
       
   278         Asserts that a response indicates that a page was retrieved
       
   279         successfully, (i.e., the HTTP status code was as expected), and that
       
   280         ``text`` doesn't occurs in the content of the response.
       
   281         """
       
   282         self.assertEqual(response.status_code, status_code,
       
   283             "Couldn't retrieve page: Response code was %d (expected %d)'" %
       
   284                 (response.status_code, status_code))
       
   285         self.assertEqual(response.content.count(text), 0,
       
   286                          "Response should not contain '%s'" % text)
   130 
   287 
   131     def assertFormError(self, response, form, field, errors):
   288     def assertFormError(self, response, form, field, errors):
   132         """
   289         """
   133         Asserts that a form used to render the response has a specific field
   290         Asserts that a form used to render the response has a specific field
   134         error.
   291         error.