diff -r 6641e941ef1e -r ff1a9aa48cfd app/django/test/client.py --- a/app/django/test/client.py Tue Oct 14 12:36:55 2008 +0000 +++ b/app/django/test/client.py Tue Oct 14 16:00:59 2008 +0000 @@ -1,12 +1,16 @@ import urllib import sys -from cStringIO import StringIO +import os +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + from django.conf import settings from django.contrib.auth import authenticate, login from django.core.handlers.base import BaseHandler from django.core.handlers.wsgi import WSGIRequest from django.core.signals import got_request_exception -from django.dispatch import dispatcher from django.http import SimpleCookie, HttpRequest from django.template import TemplateDoesNotExist from django.test import signals @@ -18,6 +22,27 @@ BOUNDARY = 'BoUnDaRyStRiNg' MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY + +class FakePayload(object): + """ + A wrapper around StringIO that restricts what can be read since data from + the network can't be seeked and cannot be read outside of its content + length. This makes sure that views can't do anything under the test client + that wouldn't work in Real Life. + """ + def __init__(self, content): + self.__content = StringIO(content) + self.__len = len(content) + + def read(self, num_bytes=None): + if num_bytes is None: + num_bytes = self.__len or 1 + assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data." + content = self.__content.read(num_bytes) + self.__len -= num_bytes + return content + + class ClientHandler(BaseHandler): """ A HTTP Handler that can be used for testing purposes. @@ -33,29 +58,30 @@ if self._request_middleware is None: self.load_middleware() - dispatcher.send(signal=signals.request_started) + signals.request_started.send(sender=self.__class__) try: request = WSGIRequest(environ) response = self.get_response(request) - # Apply response middleware + # Apply response middleware. for middleware_method in self._response_middleware: response = middleware_method(request, response) response = self.apply_response_fixes(request, response) finally: - dispatcher.send(signal=signals.request_finished) + signals.request_finished.send(sender=self.__class__) return response -def store_rendered_templates(store, signal, sender, template, context): - "A utility function for storing templates and contexts that are rendered" +def store_rendered_templates(store, signal, sender, template, context, **kwargs): + """ + Stores templates and contexts that are rendered. + """ store.setdefault('template',[]).append(template) store.setdefault('context',[]).append(context) def encode_multipart(boundary, data): """ - A simple method for encoding multipart POST data from a dictionary of - form values. + Encodes multipart POST data from a dictionary of form values. The key will be used as the form data name; the value will be transmitted as content. If the value is a file, the contents of the file will be sent @@ -63,31 +89,34 @@ """ lines = [] to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET) + + # Not by any means perfect, but good enough for our purposes. + is_file = lambda thing: hasattr(thing, "read") and callable(thing.read) + + # Each bit of the multipart form data could be either a form value or a + # file, or a *list* of form values and/or files. Remember that HTTP field + # names can be duplicated! for (key, value) in data.items(): - if isinstance(value, file): - lines.extend([ - '--' + boundary, - 'Content-Disposition: form-data; name="%s"; filename="%s"' % (to_str(key), to_str(value.name)), - 'Content-Type: application/octet-stream', - '', - value.read() - ]) - else: - if not isinstance(value, basestring) and is_iterable(value): - for item in value: + if is_file(value): + lines.extend(encode_file(boundary, key, value)) + elif not isinstance(value, basestring) and is_iterable(value): + for item in value: + if is_file(item): + lines.extend(encode_file(boundary, key, item)) + else: lines.extend([ '--' + boundary, 'Content-Disposition: form-data; name="%s"' % to_str(key), '', to_str(item) ]) - else: - lines.extend([ - '--' + boundary, - 'Content-Disposition: form-data; name="%s"' % to_str(key), - '', - to_str(value) - ]) + else: + lines.extend([ + '--' + boundary, + 'Content-Disposition: form-data; name="%s"' % to_str(key), + '', + to_str(value) + ]) lines.extend([ '--' + boundary + '--', @@ -95,7 +124,18 @@ ]) return '\r\n'.join(lines) -class Client: +def encode_file(boundary, key, file): + to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET) + return [ + '--' + boundary, + 'Content-Disposition: form-data; name="%s"; filename="%s"' \ + % (to_str(key), to_str(os.path.basename(file.name))), + 'Content-Type: application/octet-stream', + '', + file.read() + ] + +class Client(object): """ A class that can act as a client for testing purposes. @@ -119,15 +159,16 @@ self.cookies = SimpleCookie() self.exc_info = None - def store_exc_info(self, *args, **kwargs): + def store_exc_info(self, **kwargs): """ - Utility method that can be used to store exceptions when they are - generated by a view. + Stores exceptions when they are generated by a view. """ self.exc_info = sys.exc_info() def _session(self): - "Obtain the current session variables" + """ + Obtains the current session variables. + """ if 'django.contrib.sessions' in settings.INSTALLED_APPS: engine = __import__(settings.SESSION_ENGINE, {}, {}, ['']) cookie = self.cookies.get(settings.SESSION_COOKIE_NAME, None) @@ -143,28 +184,27 @@ Assumes defaults for the query environment, which can be overridden using the arguments to the request. """ - environ = { 'HTTP_COOKIE': self.cookies, 'PATH_INFO': '/', 'QUERY_STRING': '', 'REQUEST_METHOD': 'GET', - 'SCRIPT_NAME': None, + 'SCRIPT_NAME': '', 'SERVER_NAME': 'testserver', - 'SERVER_PORT': 80, + 'SERVER_PORT': '80', 'SERVER_PROTOCOL': 'HTTP/1.1', } environ.update(self.defaults) environ.update(request) - # Curry a data dictionary into an instance of - # the template renderer callback function + # Curry a data dictionary into an instance of the template renderer + # callback function. data = {} on_template_render = curry(store_rendered_templates, data) - dispatcher.connect(on_template_render, signal=signals.template_rendered) + signals.template_rendered.connect(on_template_render) - # Capture exceptions created by the handler - dispatcher.connect(self.store_exc_info, signal=got_request_exception) + # Capture exceptions created by the handler. + got_request_exception.connect(self.store_exc_info) try: response = self.handler(environ) @@ -178,17 +218,22 @@ if e.args != ('500.html',): raise - # Look for a signalled exception and reraise it + # Look for a signalled exception, clear the current context + # exception data, then re-raise the signalled exception. + # Also make sure that the signalled exception is cleared from + # the local cache! if self.exc_info: - raise self.exc_info[1], None, self.exc_info[2] + exc_info = self.exc_info + self.exc_info = None + raise exc_info[1], None, exc_info[2] - # Save the client and request that stimulated the response + # Save the client and request that stimulated the response. response.client = self response.request = request - # Add any rendered template detail to the response + # Add any rendered template detail to the response. # If there was only one template rendered (the most likely case), - # flatten the list to a single element + # flatten the list to a single element. for detail in ('template', 'context'): if data.get(detail): if len(data[detail]) == 1: @@ -198,14 +243,16 @@ else: setattr(response, detail, None) - # Update persistent cookie data + # Update persistent cookie data. if response.cookies: self.cookies.update(response.cookies) return response def get(self, path, data={}, **extra): - "Request a response from the server using GET." + """ + Requests a response from the server using GET. + """ r = { 'CONTENT_LENGTH': None, 'CONTENT_TYPE': 'text/html; charset=utf-8', @@ -218,8 +265,9 @@ return self.request(**r) def post(self, path, data={}, content_type=MULTIPART_CONTENT, **extra): - "Request a response from the server using POST." - + """ + Requests a response from the server using POST. + """ if content_type is MULTIPART_CONTENT: post_data = encode_multipart(BOUNDARY, data) else: @@ -230,37 +278,109 @@ 'CONTENT_TYPE': content_type, 'PATH_INFO': urllib.unquote(path), 'REQUEST_METHOD': 'POST', - 'wsgi.input': StringIO(post_data), + 'wsgi.input': FakePayload(post_data), + } + r.update(extra) + + return self.request(**r) + + def head(self, path, data={}, **extra): + """ + Request a response from the server using HEAD. + """ + r = { + 'CONTENT_LENGTH': None, + 'CONTENT_TYPE': 'text/html; charset=utf-8', + 'PATH_INFO': urllib.unquote(path), + 'QUERY_STRING': urlencode(data, doseq=True), + 'REQUEST_METHOD': 'HEAD', + } + r.update(extra) + + return self.request(**r) + + def options(self, path, data={}, **extra): + """ + Request a response from the server using OPTIONS. + """ + r = { + 'CONTENT_LENGTH': None, + 'CONTENT_TYPE': None, + 'PATH_INFO': urllib.unquote(path), + 'QUERY_STRING': urlencode(data, doseq=True), + 'REQUEST_METHOD': 'OPTIONS', } r.update(extra) return self.request(**r) + def put(self, path, data={}, content_type=MULTIPART_CONTENT, **extra): + """ + Send a resource to the server using PUT. + """ + if content_type is MULTIPART_CONTENT: + post_data = encode_multipart(BOUNDARY, data) + else: + post_data = data + r = { + 'CONTENT_LENGTH': len(post_data), + 'CONTENT_TYPE': content_type, + 'PATH_INFO': urllib.unquote(path), + 'REQUEST_METHOD': 'PUT', + 'wsgi.input': FakePayload(post_data), + } + r.update(extra) + + return self.request(**r) + + def delete(self, path, data={}, **extra): + """ + Send a DELETE request to the server. + """ + r = { + 'CONTENT_LENGTH': None, + 'CONTENT_TYPE': None, + 'PATH_INFO': urllib.unquote(path), + 'REQUEST_METHOD': 'DELETE', + } + r.update(extra) + + return self.request(**r) + def login(self, **credentials): - """Set the Client to appear as if it has sucessfully logged into a site. + """ + Sets the Client to appear as if it has successfully logged into a site. Returns True if login is possible; False if the provided credentials are incorrect, or the user is inactive, or if the sessions framework is not available. """ user = authenticate(**credentials) - if user and user.is_active and 'django.contrib.sessions' in settings.INSTALLED_APPS: + if user and user.is_active \ + and 'django.contrib.sessions' in settings.INSTALLED_APPS: engine = __import__(settings.SESSION_ENGINE, {}, {}, ['']) - # Create a fake request to store login details + # Create a fake request to store login details. request = HttpRequest() - request.session = engine.SessionStore() + if self.session: + request.session = self.session + else: + request.session = engine.SessionStore() login(request, user) - # Set the cookie to represent the session - self.cookies[settings.SESSION_COOKIE_NAME] = request.session.session_key - self.cookies[settings.SESSION_COOKIE_NAME]['max-age'] = None - self.cookies[settings.SESSION_COOKIE_NAME]['path'] = '/' - self.cookies[settings.SESSION_COOKIE_NAME]['domain'] = settings.SESSION_COOKIE_DOMAIN - self.cookies[settings.SESSION_COOKIE_NAME]['secure'] = settings.SESSION_COOKIE_SECURE or None - self.cookies[settings.SESSION_COOKIE_NAME]['expires'] = None + # Set the cookie to represent the session. + session_cookie = settings.SESSION_COOKIE_NAME + self.cookies[session_cookie] = request.session.session_key + cookie_data = { + 'max-age': None, + 'path': '/', + 'domain': settings.SESSION_COOKIE_DOMAIN, + 'secure': settings.SESSION_COOKIE_SECURE or None, + 'expires': None, + } + self.cookies[session_cookie].update(cookie_data) - # Save the session values + # Save the session values. request.session.save() return True @@ -268,7 +388,8 @@ return False def logout(self): - """Removes the authenticated user's cookies. + """ + Removes the authenticated user's cookies. Causes the authenticated user to be logged out. """