app/django/test/client.py
changeset 323 ff1a9aa48cfd
parent 54 03e267d67478
--- a/app/django/test/client.py	Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/test/client.py	Tue Oct 14 16:00:59 2008 +0000
@@ -1,12 +1,16 @@
 import urllib
 import sys
-from cStringIO import StringIO
+import os
+try:
+    from cStringIO import StringIO
+except ImportError:
+    from StringIO import StringIO
+
 from django.conf import settings
 from django.contrib.auth import authenticate, login
 from django.core.handlers.base import BaseHandler
 from django.core.handlers.wsgi import WSGIRequest
 from django.core.signals import got_request_exception
-from django.dispatch import dispatcher
 from django.http import SimpleCookie, HttpRequest
 from django.template import TemplateDoesNotExist
 from django.test import signals
@@ -18,6 +22,27 @@
 BOUNDARY = 'BoUnDaRyStRiNg'
 MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY
 
+
+class FakePayload(object):
+    """
+    A wrapper around StringIO that restricts what can be read since data from
+    the network can't be seeked and cannot be read outside of its content
+    length. This makes sure that views can't do anything under the test client
+    that wouldn't work in Real Life.
+    """
+    def __init__(self, content):
+        self.__content = StringIO(content)
+        self.__len = len(content)
+
+    def read(self, num_bytes=None):
+        if num_bytes is None:
+            num_bytes = self.__len or 1
+        assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data."
+        content = self.__content.read(num_bytes)
+        self.__len -= num_bytes
+        return content
+
+
 class ClientHandler(BaseHandler):
     """
     A HTTP Handler that can be used for testing purposes.
@@ -33,29 +58,30 @@
         if self._request_middleware is None:
             self.load_middleware()
 
-        dispatcher.send(signal=signals.request_started)
+        signals.request_started.send(sender=self.__class__)
         try:
             request = WSGIRequest(environ)
             response = self.get_response(request)
 
-            # Apply response middleware
+            # Apply response middleware.
             for middleware_method in self._response_middleware:
                 response = middleware_method(request, response)
             response = self.apply_response_fixes(request, response)
         finally:
-            dispatcher.send(signal=signals.request_finished)
+            signals.request_finished.send(sender=self.__class__)
 
         return response
 
-def store_rendered_templates(store, signal, sender, template, context):
-    "A utility function for storing templates and contexts that are rendered"
+def store_rendered_templates(store, signal, sender, template, context, **kwargs):
+    """
+    Stores templates and contexts that are rendered.
+    """
     store.setdefault('template',[]).append(template)
     store.setdefault('context',[]).append(context)
 
 def encode_multipart(boundary, data):
     """
-    A simple method for encoding multipart POST data from a dictionary of
-    form values.
+    Encodes multipart POST data from a dictionary of form values.
 
     The key will be used as the form data name; the value will be transmitted
     as content. If the value is a file, the contents of the file will be sent
@@ -63,31 +89,34 @@
     """
     lines = []
     to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET)
+
+    # Not by any means perfect, but good enough for our purposes.
+    is_file = lambda thing: hasattr(thing, "read") and callable(thing.read)
+
+    # Each bit of the multipart form data could be either a form value or a
+    # file, or a *list* of form values and/or files. Remember that HTTP field
+    # names can be duplicated!
     for (key, value) in data.items():
-        if isinstance(value, file):
-            lines.extend([
-                '--' + boundary,
-                'Content-Disposition: form-data; name="%s"; filename="%s"' % (to_str(key), to_str(value.name)),
-                'Content-Type: application/octet-stream',
-                '',
-                value.read()
-            ])
-        else:
-            if not isinstance(value, basestring) and is_iterable(value):
-                for item in value:
+        if is_file(value):
+            lines.extend(encode_file(boundary, key, value))
+        elif not isinstance(value, basestring) and is_iterable(value):
+            for item in value:
+                if is_file(item):
+                    lines.extend(encode_file(boundary, key, item))
+                else:
                     lines.extend([
                         '--' + boundary,
                         'Content-Disposition: form-data; name="%s"' % to_str(key),
                         '',
                         to_str(item)
                     ])
-            else:
-                lines.extend([
-                    '--' + boundary,
-                    'Content-Disposition: form-data; name="%s"' % to_str(key),
-                    '',
-                    to_str(value)
-                ])
+        else:
+            lines.extend([
+                '--' + boundary,
+                'Content-Disposition: form-data; name="%s"' % to_str(key),
+                '',
+                to_str(value)
+            ])
 
     lines.extend([
         '--' + boundary + '--',
@@ -95,7 +124,18 @@
     ])
     return '\r\n'.join(lines)
 
-class Client:
+def encode_file(boundary, key, file):
+    to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET)
+    return [
+        '--' + boundary,
+        'Content-Disposition: form-data; name="%s"; filename="%s"' \
+            % (to_str(key), to_str(os.path.basename(file.name))),
+        'Content-Type: application/octet-stream',
+        '',
+        file.read()
+    ]
+
+class Client(object):
     """
     A class that can act as a client for testing purposes.
 
@@ -119,15 +159,16 @@
         self.cookies = SimpleCookie()
         self.exc_info = None
 
-    def store_exc_info(self, *args, **kwargs):
+    def store_exc_info(self, **kwargs):
         """
-        Utility method that can be used to store exceptions when they are
-        generated by a view.
+        Stores exceptions when they are generated by a view.
         """
         self.exc_info = sys.exc_info()
 
     def _session(self):
-        "Obtain the current session variables"
+        """
+        Obtains the current session variables.
+        """
         if 'django.contrib.sessions' in settings.INSTALLED_APPS:
             engine = __import__(settings.SESSION_ENGINE, {}, {}, [''])
             cookie = self.cookies.get(settings.SESSION_COOKIE_NAME, None)
@@ -143,28 +184,27 @@
         Assumes defaults for the query environment, which can be overridden
         using the arguments to the request.
         """
-
         environ = {
             'HTTP_COOKIE':      self.cookies,
             'PATH_INFO':         '/',
             'QUERY_STRING':      '',
             'REQUEST_METHOD':    'GET',
-            'SCRIPT_NAME':       None,
+            'SCRIPT_NAME':       '',
             'SERVER_NAME':       'testserver',
-            'SERVER_PORT':       80,
+            'SERVER_PORT':       '80',
             'SERVER_PROTOCOL':   'HTTP/1.1',
         }
         environ.update(self.defaults)
         environ.update(request)
 
-        # Curry a data dictionary into an instance of
-        # the template renderer callback function
+        # Curry a data dictionary into an instance of the template renderer
+        # callback function.
         data = {}
         on_template_render = curry(store_rendered_templates, data)
-        dispatcher.connect(on_template_render, signal=signals.template_rendered)
+        signals.template_rendered.connect(on_template_render)
 
-        # Capture exceptions created by the handler
-        dispatcher.connect(self.store_exc_info, signal=got_request_exception)
+        # Capture exceptions created by the handler.
+        got_request_exception.connect(self.store_exc_info)
 
         try:
             response = self.handler(environ)
@@ -178,17 +218,22 @@
             if e.args != ('500.html',):
                 raise
 
-        # Look for a signalled exception and reraise it
+        # Look for a signalled exception, clear the current context
+        # exception data, then re-raise the signalled exception.
+        # Also make sure that the signalled exception is cleared from
+        # the local cache!
         if self.exc_info:
-            raise self.exc_info[1], None, self.exc_info[2]
+            exc_info = self.exc_info
+            self.exc_info = None
+            raise exc_info[1], None, exc_info[2]
 
-        # Save the client and request that stimulated the response
+        # Save the client and request that stimulated the response.
         response.client = self
         response.request = request
 
-        # Add any rendered template detail to the response
+        # Add any rendered template detail to the response.
         # If there was only one template rendered (the most likely case),
-        # flatten the list to a single element
+        # flatten the list to a single element.
         for detail in ('template', 'context'):
             if data.get(detail):
                 if len(data[detail]) == 1:
@@ -198,14 +243,16 @@
             else:
                 setattr(response, detail, None)
 
-        # Update persistent cookie data
+        # Update persistent cookie data.
         if response.cookies:
             self.cookies.update(response.cookies)
 
         return response
 
     def get(self, path, data={}, **extra):
-        "Request a response from the server using GET."
+        """
+        Requests a response from the server using GET.
+        """
         r = {
             'CONTENT_LENGTH':  None,
             'CONTENT_TYPE':    'text/html; charset=utf-8',
@@ -218,8 +265,9 @@
         return self.request(**r)
 
     def post(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
-        "Request a response from the server using POST."
-
+        """
+        Requests a response from the server using POST.
+        """
         if content_type is MULTIPART_CONTENT:
             post_data = encode_multipart(BOUNDARY, data)
         else:
@@ -230,37 +278,109 @@
             'CONTENT_TYPE':   content_type,
             'PATH_INFO':      urllib.unquote(path),
             'REQUEST_METHOD': 'POST',
-            'wsgi.input':     StringIO(post_data),
+            'wsgi.input':     FakePayload(post_data),
+        }
+        r.update(extra)
+
+        return self.request(**r)
+
+    def head(self, path, data={}, **extra):
+        """
+        Request a response from the server using HEAD.
+        """
+        r = {
+            'CONTENT_LENGTH':  None,
+            'CONTENT_TYPE':    'text/html; charset=utf-8',
+            'PATH_INFO':       urllib.unquote(path),
+            'QUERY_STRING':    urlencode(data, doseq=True),
+            'REQUEST_METHOD': 'HEAD',
+        }
+        r.update(extra)
+
+        return self.request(**r)
+
+    def options(self, path, data={}, **extra):
+        """
+        Request a response from the server using OPTIONS.
+        """
+        r = {
+            'CONTENT_LENGTH':  None,
+            'CONTENT_TYPE':    None,
+            'PATH_INFO':       urllib.unquote(path),
+            'QUERY_STRING':    urlencode(data, doseq=True),
+            'REQUEST_METHOD': 'OPTIONS',
         }
         r.update(extra)
 
         return self.request(**r)
 
+    def put(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
+        """
+        Send a resource to the server using PUT.
+        """
+        if content_type is MULTIPART_CONTENT:
+            post_data = encode_multipart(BOUNDARY, data)
+        else:
+            post_data = data
+        r = {
+            'CONTENT_LENGTH': len(post_data),
+            'CONTENT_TYPE':   content_type,
+            'PATH_INFO':      urllib.unquote(path),
+            'REQUEST_METHOD': 'PUT',
+            'wsgi.input':     FakePayload(post_data),
+        }
+        r.update(extra)
+
+        return self.request(**r)
+
+    def delete(self, path, data={}, **extra):
+        """
+        Send a DELETE request to the server.
+        """
+        r = {
+            'CONTENT_LENGTH':  None,
+            'CONTENT_TYPE':    None,
+            'PATH_INFO':       urllib.unquote(path),
+            'REQUEST_METHOD': 'DELETE',
+            }
+        r.update(extra)
+
+        return self.request(**r)
+
     def login(self, **credentials):
-        """Set the Client to appear as if it has sucessfully logged into a site.
+        """
+        Sets the Client to appear as if it has successfully logged into a site.
 
         Returns True if login is possible; False if the provided credentials
         are incorrect, or the user is inactive, or if the sessions framework is
         not available.
         """
         user = authenticate(**credentials)
-        if user and user.is_active and 'django.contrib.sessions' in settings.INSTALLED_APPS:
+        if user and user.is_active \
+                and 'django.contrib.sessions' in settings.INSTALLED_APPS:
             engine = __import__(settings.SESSION_ENGINE, {}, {}, [''])
 
-            # Create a fake request to store login details
+            # Create a fake request to store login details.
             request = HttpRequest()
-            request.session = engine.SessionStore()
+            if self.session:
+                request.session = self.session
+            else:
+                request.session = engine.SessionStore()
             login(request, user)
 
-            # Set the cookie to represent the session
-            self.cookies[settings.SESSION_COOKIE_NAME] = request.session.session_key
-            self.cookies[settings.SESSION_COOKIE_NAME]['max-age'] = None
-            self.cookies[settings.SESSION_COOKIE_NAME]['path'] = '/'
-            self.cookies[settings.SESSION_COOKIE_NAME]['domain'] = settings.SESSION_COOKIE_DOMAIN
-            self.cookies[settings.SESSION_COOKIE_NAME]['secure'] = settings.SESSION_COOKIE_SECURE or None
-            self.cookies[settings.SESSION_COOKIE_NAME]['expires'] = None
+            # Set the cookie to represent the session.
+            session_cookie = settings.SESSION_COOKIE_NAME
+            self.cookies[session_cookie] = request.session.session_key
+            cookie_data = {
+                'max-age': None,
+                'path': '/',
+                'domain': settings.SESSION_COOKIE_DOMAIN,
+                'secure': settings.SESSION_COOKIE_SECURE or None,
+                'expires': None,
+            }
+            self.cookies[session_cookie].update(cookie_data)
 
-            # Save the session values
+            # Save the session values.
             request.session.save()
 
             return True
@@ -268,7 +388,8 @@
             return False
 
     def logout(self):
-        """Removes the authenticated user's cookies.
+        """
+        Removes the authenticated user's cookies.
 
         Causes the authenticated user to be logged out.
         """