--- a/app/django/test/client.py Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/test/client.py Tue Oct 14 16:00:59 2008 +0000
@@ -1,12 +1,16 @@
import urllib
import sys
-from cStringIO import StringIO
+import os
+try:
+ from cStringIO import StringIO
+except ImportError:
+ from StringIO import StringIO
+
from django.conf import settings
from django.contrib.auth import authenticate, login
from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import WSGIRequest
from django.core.signals import got_request_exception
-from django.dispatch import dispatcher
from django.http import SimpleCookie, HttpRequest
from django.template import TemplateDoesNotExist
from django.test import signals
@@ -18,6 +22,27 @@
BOUNDARY = 'BoUnDaRyStRiNg'
MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY
+
+class FakePayload(object):
+ """
+ A wrapper around StringIO that restricts what can be read since data from
+ the network can't be seeked and cannot be read outside of its content
+ length. This makes sure that views can't do anything under the test client
+ that wouldn't work in Real Life.
+ """
+ def __init__(self, content):
+ self.__content = StringIO(content)
+ self.__len = len(content)
+
+ def read(self, num_bytes=None):
+ if num_bytes is None:
+ num_bytes = self.__len or 1
+ assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data."
+ content = self.__content.read(num_bytes)
+ self.__len -= num_bytes
+ return content
+
+
class ClientHandler(BaseHandler):
"""
A HTTP Handler that can be used for testing purposes.
@@ -33,29 +58,30 @@
if self._request_middleware is None:
self.load_middleware()
- dispatcher.send(signal=signals.request_started)
+ signals.request_started.send(sender=self.__class__)
try:
request = WSGIRequest(environ)
response = self.get_response(request)
- # Apply response middleware
+ # Apply response middleware.
for middleware_method in self._response_middleware:
response = middleware_method(request, response)
response = self.apply_response_fixes(request, response)
finally:
- dispatcher.send(signal=signals.request_finished)
+ signals.request_finished.send(sender=self.__class__)
return response
-def store_rendered_templates(store, signal, sender, template, context):
- "A utility function for storing templates and contexts that are rendered"
+def store_rendered_templates(store, signal, sender, template, context, **kwargs):
+ """
+ Stores templates and contexts that are rendered.
+ """
store.setdefault('template',[]).append(template)
store.setdefault('context',[]).append(context)
def encode_multipart(boundary, data):
"""
- A simple method for encoding multipart POST data from a dictionary of
- form values.
+ Encodes multipart POST data from a dictionary of form values.
The key will be used as the form data name; the value will be transmitted
as content. If the value is a file, the contents of the file will be sent
@@ -63,31 +89,34 @@
"""
lines = []
to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET)
+
+ # Not by any means perfect, but good enough for our purposes.
+ is_file = lambda thing: hasattr(thing, "read") and callable(thing.read)
+
+ # Each bit of the multipart form data could be either a form value or a
+ # file, or a *list* of form values and/or files. Remember that HTTP field
+ # names can be duplicated!
for (key, value) in data.items():
- if isinstance(value, file):
- lines.extend([
- '--' + boundary,
- 'Content-Disposition: form-data; name="%s"; filename="%s"' % (to_str(key), to_str(value.name)),
- 'Content-Type: application/octet-stream',
- '',
- value.read()
- ])
- else:
- if not isinstance(value, basestring) and is_iterable(value):
- for item in value:
+ if is_file(value):
+ lines.extend(encode_file(boundary, key, value))
+ elif not isinstance(value, basestring) and is_iterable(value):
+ for item in value:
+ if is_file(item):
+ lines.extend(encode_file(boundary, key, item))
+ else:
lines.extend([
'--' + boundary,
'Content-Disposition: form-data; name="%s"' % to_str(key),
'',
to_str(item)
])
- else:
- lines.extend([
- '--' + boundary,
- 'Content-Disposition: form-data; name="%s"' % to_str(key),
- '',
- to_str(value)
- ])
+ else:
+ lines.extend([
+ '--' + boundary,
+ 'Content-Disposition: form-data; name="%s"' % to_str(key),
+ '',
+ to_str(value)
+ ])
lines.extend([
'--' + boundary + '--',
@@ -95,7 +124,18 @@
])
return '\r\n'.join(lines)
-class Client:
+def encode_file(boundary, key, file):
+ to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET)
+ return [
+ '--' + boundary,
+ 'Content-Disposition: form-data; name="%s"; filename="%s"' \
+ % (to_str(key), to_str(os.path.basename(file.name))),
+ 'Content-Type: application/octet-stream',
+ '',
+ file.read()
+ ]
+
+class Client(object):
"""
A class that can act as a client for testing purposes.
@@ -119,15 +159,16 @@
self.cookies = SimpleCookie()
self.exc_info = None
- def store_exc_info(self, *args, **kwargs):
+ def store_exc_info(self, **kwargs):
"""
- Utility method that can be used to store exceptions when they are
- generated by a view.
+ Stores exceptions when they are generated by a view.
"""
self.exc_info = sys.exc_info()
def _session(self):
- "Obtain the current session variables"
+ """
+ Obtains the current session variables.
+ """
if 'django.contrib.sessions' in settings.INSTALLED_APPS:
engine = __import__(settings.SESSION_ENGINE, {}, {}, [''])
cookie = self.cookies.get(settings.SESSION_COOKIE_NAME, None)
@@ -143,28 +184,27 @@
Assumes defaults for the query environment, which can be overridden
using the arguments to the request.
"""
-
environ = {
'HTTP_COOKIE': self.cookies,
'PATH_INFO': '/',
'QUERY_STRING': '',
'REQUEST_METHOD': 'GET',
- 'SCRIPT_NAME': None,
+ 'SCRIPT_NAME': '',
'SERVER_NAME': 'testserver',
- 'SERVER_PORT': 80,
+ 'SERVER_PORT': '80',
'SERVER_PROTOCOL': 'HTTP/1.1',
}
environ.update(self.defaults)
environ.update(request)
- # Curry a data dictionary into an instance of
- # the template renderer callback function
+ # Curry a data dictionary into an instance of the template renderer
+ # callback function.
data = {}
on_template_render = curry(store_rendered_templates, data)
- dispatcher.connect(on_template_render, signal=signals.template_rendered)
+ signals.template_rendered.connect(on_template_render)
- # Capture exceptions created by the handler
- dispatcher.connect(self.store_exc_info, signal=got_request_exception)
+ # Capture exceptions created by the handler.
+ got_request_exception.connect(self.store_exc_info)
try:
response = self.handler(environ)
@@ -178,17 +218,22 @@
if e.args != ('500.html',):
raise
- # Look for a signalled exception and reraise it
+ # Look for a signalled exception, clear the current context
+ # exception data, then re-raise the signalled exception.
+ # Also make sure that the signalled exception is cleared from
+ # the local cache!
if self.exc_info:
- raise self.exc_info[1], None, self.exc_info[2]
+ exc_info = self.exc_info
+ self.exc_info = None
+ raise exc_info[1], None, exc_info[2]
- # Save the client and request that stimulated the response
+ # Save the client and request that stimulated the response.
response.client = self
response.request = request
- # Add any rendered template detail to the response
+ # Add any rendered template detail to the response.
# If there was only one template rendered (the most likely case),
- # flatten the list to a single element
+ # flatten the list to a single element.
for detail in ('template', 'context'):
if data.get(detail):
if len(data[detail]) == 1:
@@ -198,14 +243,16 @@
else:
setattr(response, detail, None)
- # Update persistent cookie data
+ # Update persistent cookie data.
if response.cookies:
self.cookies.update(response.cookies)
return response
def get(self, path, data={}, **extra):
- "Request a response from the server using GET."
+ """
+ Requests a response from the server using GET.
+ """
r = {
'CONTENT_LENGTH': None,
'CONTENT_TYPE': 'text/html; charset=utf-8',
@@ -218,8 +265,9 @@
return self.request(**r)
def post(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
- "Request a response from the server using POST."
-
+ """
+ Requests a response from the server using POST.
+ """
if content_type is MULTIPART_CONTENT:
post_data = encode_multipart(BOUNDARY, data)
else:
@@ -230,37 +278,109 @@
'CONTENT_TYPE': content_type,
'PATH_INFO': urllib.unquote(path),
'REQUEST_METHOD': 'POST',
- 'wsgi.input': StringIO(post_data),
+ 'wsgi.input': FakePayload(post_data),
+ }
+ r.update(extra)
+
+ return self.request(**r)
+
+ def head(self, path, data={}, **extra):
+ """
+ Request a response from the server using HEAD.
+ """
+ r = {
+ 'CONTENT_LENGTH': None,
+ 'CONTENT_TYPE': 'text/html; charset=utf-8',
+ 'PATH_INFO': urllib.unquote(path),
+ 'QUERY_STRING': urlencode(data, doseq=True),
+ 'REQUEST_METHOD': 'HEAD',
+ }
+ r.update(extra)
+
+ return self.request(**r)
+
+ def options(self, path, data={}, **extra):
+ """
+ Request a response from the server using OPTIONS.
+ """
+ r = {
+ 'CONTENT_LENGTH': None,
+ 'CONTENT_TYPE': None,
+ 'PATH_INFO': urllib.unquote(path),
+ 'QUERY_STRING': urlencode(data, doseq=True),
+ 'REQUEST_METHOD': 'OPTIONS',
}
r.update(extra)
return self.request(**r)
+ def put(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
+ """
+ Send a resource to the server using PUT.
+ """
+ if content_type is MULTIPART_CONTENT:
+ post_data = encode_multipart(BOUNDARY, data)
+ else:
+ post_data = data
+ r = {
+ 'CONTENT_LENGTH': len(post_data),
+ 'CONTENT_TYPE': content_type,
+ 'PATH_INFO': urllib.unquote(path),
+ 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': FakePayload(post_data),
+ }
+ r.update(extra)
+
+ return self.request(**r)
+
+ def delete(self, path, data={}, **extra):
+ """
+ Send a DELETE request to the server.
+ """
+ r = {
+ 'CONTENT_LENGTH': None,
+ 'CONTENT_TYPE': None,
+ 'PATH_INFO': urllib.unquote(path),
+ 'REQUEST_METHOD': 'DELETE',
+ }
+ r.update(extra)
+
+ return self.request(**r)
+
def login(self, **credentials):
- """Set the Client to appear as if it has sucessfully logged into a site.
+ """
+ Sets the Client to appear as if it has successfully logged into a site.
Returns True if login is possible; False if the provided credentials
are incorrect, or the user is inactive, or if the sessions framework is
not available.
"""
user = authenticate(**credentials)
- if user and user.is_active and 'django.contrib.sessions' in settings.INSTALLED_APPS:
+ if user and user.is_active \
+ and 'django.contrib.sessions' in settings.INSTALLED_APPS:
engine = __import__(settings.SESSION_ENGINE, {}, {}, [''])
- # Create a fake request to store login details
+ # Create a fake request to store login details.
request = HttpRequest()
- request.session = engine.SessionStore()
+ if self.session:
+ request.session = self.session
+ else:
+ request.session = engine.SessionStore()
login(request, user)
- # Set the cookie to represent the session
- self.cookies[settings.SESSION_COOKIE_NAME] = request.session.session_key
- self.cookies[settings.SESSION_COOKIE_NAME]['max-age'] = None
- self.cookies[settings.SESSION_COOKIE_NAME]['path'] = '/'
- self.cookies[settings.SESSION_COOKIE_NAME]['domain'] = settings.SESSION_COOKIE_DOMAIN
- self.cookies[settings.SESSION_COOKIE_NAME]['secure'] = settings.SESSION_COOKIE_SECURE or None
- self.cookies[settings.SESSION_COOKIE_NAME]['expires'] = None
+ # Set the cookie to represent the session.
+ session_cookie = settings.SESSION_COOKIE_NAME
+ self.cookies[session_cookie] = request.session.session_key
+ cookie_data = {
+ 'max-age': None,
+ 'path': '/',
+ 'domain': settings.SESSION_COOKIE_DOMAIN,
+ 'secure': settings.SESSION_COOKIE_SECURE or None,
+ 'expires': None,
+ }
+ self.cookies[session_cookie].update(cookie_data)
- # Save the session values
+ # Save the session values.
request.session.save()
return True
@@ -268,7 +388,8 @@
return False
def logout(self):
- """Removes the authenticated user's cookies.
+ """
+ Removes the authenticated user's cookies.
Causes the authenticated user to be logged out.
"""