app/django/utils/simplejson/decoder.py
changeset 323 ff1a9aa48cfd
parent 54 03e267d67478
--- a/app/django/utils/simplejson/decoder.py	Tue Oct 14 12:36:55 2008 +0000
+++ b/app/django/utils/simplejson/decoder.py	Tue Oct 14 16:00:59 2008 +0000
@@ -2,8 +2,13 @@
 Implementation of JSONDecoder
 """
 import re
+import sys
 
 from django.utils.simplejson.scanner import Scanner, pattern
+try:
+    from django.utils.simplejson._speedups import scanstring as c_scanstring
+except ImportError:
+    pass
 
 FLAGS = re.VERBOSE | re.MULTILINE | re.DOTALL
 
@@ -18,6 +23,7 @@
 
 NaN, PosInf, NegInf = _floatconstants()
 
+
 def linecol(doc, pos):
     lineno = doc.count('\n', 0, pos) + 1
     if lineno == 1:
@@ -26,6 +32,7 @@
         colno = pos - doc.rindex('\n', 0, pos)
     return lineno, colno
 
+
 def errmsg(msg, doc, pos, end=None):
     lineno, colno = linecol(doc, pos)
     if end is None:
@@ -34,6 +41,7 @@
     return '%s: line %d column %d - line %d column %d (char %d - %d)' % (
         msg, lineno, colno, endlineno, endcolno, pos, end)
 
+
 _CONSTANTS = {
     '-Infinity': NegInf,
     'Infinity': PosInf,
@@ -44,20 +52,30 @@
 }
 
 def JSONConstant(match, context, c=_CONSTANTS):
-    return c[match.group(0)], None
+    s = match.group(0)
+    fn = getattr(context, 'parse_constant', None)
+    if fn is None:
+        rval = c[s]
+    else:
+        rval = fn(s)
+    return rval, None
 pattern('(-?Infinity|NaN|true|false|null)')(JSONConstant)
 
+
 def JSONNumber(match, context):
     match = JSONNumber.regex.match(match.string, *match.span())
     integer, frac, exp = match.groups()
     if frac or exp:
-        res = float(integer + (frac or '') + (exp or ''))
+        fn = getattr(context, 'parse_float', None) or float
+        res = fn(integer + (frac or '') + (exp or ''))
     else:
-        res = int(integer)
+        fn = getattr(context, 'parse_int', None) or int
+        res = fn(integer)
     return res, None
 pattern(r'(-?(?:0|[1-9]\d*))(\.\d+)?([eE][-+]?\d+)?')(JSONNumber)
 
-STRINGCHUNK = re.compile(r'(.*?)(["\\])', FLAGS)
+
+STRINGCHUNK = re.compile(r'(.*?)(["\\\x00-\x1f])', FLAGS)
 BACKSLASH = {
     '"': u'"', '\\': u'\\', '/': u'/',
     'b': u'\b', 'f': u'\f', 'n': u'\n', 'r': u'\r', 't': u'\t',
@@ -65,7 +83,7 @@
 
 DEFAULT_ENCODING = "utf-8"
 
-def scanstring(s, end, encoding=None, _b=BACKSLASH, _m=STRINGCHUNK.match):
+def py_scanstring(s, end, encoding=None, strict=True, _b=BACKSLASH, _m=STRINGCHUNK.match):
     if encoding is None:
         encoding = DEFAULT_ENCODING
     chunks = []
@@ -84,6 +102,12 @@
             _append(content)
         if terminator == '"':
             break
+        elif terminator != '\\':
+            if strict:
+                raise ValueError(errmsg("Invalid control character %r at", s, end))
+            else:
+                _append(terminator)
+                continue
         try:
             esc = s[end]
         except IndexError:
@@ -98,21 +122,43 @@
             end += 1
         else:
             esc = s[end + 1:end + 5]
+            next_end = end + 5
+            msg = "Invalid \\uXXXX escape"
             try:
-                m = unichr(int(esc, 16))
-                if len(esc) != 4 or not esc.isalnum():
+                if len(esc) != 4:
                     raise ValueError
+                uni = int(esc, 16)
+                if 0xd800 <= uni <= 0xdbff and sys.maxunicode > 65535:
+                    msg = "Invalid \\uXXXX\\uXXXX surrogate pair"
+                    if not s[end + 5:end + 7] == '\\u':
+                        raise ValueError
+                    esc2 = s[end + 7:end + 11]
+                    if len(esc2) != 4:
+                        raise ValueError
+                    uni2 = int(esc2, 16)
+                    uni = 0x10000 + (((uni - 0xd800) << 10) | (uni2 - 0xdc00))
+                    next_end += 6
+                m = unichr(uni)
             except ValueError:
-                raise ValueError(errmsg("Invalid \\uXXXX escape", s, end))
-            end += 5
+                raise ValueError(errmsg(msg, s, end))
+            end = next_end
         _append(m)
     return u''.join(chunks), end
 
+
+# Use speedup
+try:
+    scanstring = c_scanstring
+except NameError:
+    scanstring = py_scanstring
+
 def JSONString(match, context):
     encoding = getattr(context, 'encoding', None)
-    return scanstring(match.string, match.end(), encoding)
+    strict = getattr(context, 'strict', True)
+    return scanstring(match.string, match.end(), encoding, strict)
 pattern(r'"')(JSONString)
 
+
 WHITESPACE = re.compile(r'\s*', FLAGS)
 
 def JSONObject(match, context, _w=WHITESPACE.match):
@@ -120,16 +166,17 @@
     s = match.string
     end = _w(s, match.end()).end()
     nextchar = s[end:end + 1]
-    # trivial empty object
+    # Trivial empty object
     if nextchar == '}':
         return pairs, end + 1
     if nextchar != '"':
         raise ValueError(errmsg("Expecting property name", s, end))
     end += 1
     encoding = getattr(context, 'encoding', None)
+    strict = getattr(context, 'strict', True)
     iterscan = JSONScanner.iterscan
     while True:
-        key, end = scanstring(s, end, encoding)
+        key, end = scanstring(s, end, encoding, strict)
         end = _w(s, end).end()
         if s[end:end + 1] != ':':
             raise ValueError(errmsg("Expecting : delimiter", s, end))
@@ -156,12 +203,13 @@
         pairs = object_hook(pairs)
     return pairs, end
 pattern(r'{')(JSONObject)
-            
+
+
 def JSONArray(match, context, _w=WHITESPACE.match):
     values = []
     s = match.string
     end = _w(s, match.end()).end()
-    # look-ahead for trivial empty array
+    # Look-ahead for trivial empty array
     nextchar = s[end:end + 1]
     if nextchar == ']':
         return values, end + 1
@@ -182,7 +230,8 @@
         end = _w(s, end).end()
     return values, end
 pattern(r'\[')(JSONArray)
- 
+
+
 ANYTHING = [
     JSONObject,
     JSONArray,
@@ -193,11 +242,12 @@
 
 JSONScanner = Scanner(ANYTHING)
 
+
 class JSONDecoder(object):
     """
     Simple JSON <http://json.org> decoder
 
-    Performs the following translations in decoding:
+    Performs the following translations in decoding by default:
     
     +---------------+-------------------+
     | JSON          | Python            |
@@ -226,7 +276,8 @@
     _scanner = Scanner(ANYTHING)
     __all__ = ['__init__', 'decode', 'raw_decode']
 
-    def __init__(self, encoding=None, object_hook=None):
+    def __init__(self, encoding=None, object_hook=None, parse_float=None,
+            parse_int=None, parse_constant=None, strict=True):
         """
         ``encoding`` determines the encoding used to interpret any ``str``
         objects decoded by this instance (utf-8 by default).  It has no
@@ -239,9 +290,28 @@
         of every JSON object decoded and its return value will be used in
         place of the given ``dict``.  This can be used to provide custom
         deserializations (e.g. to support JSON-RPC class hinting).
+
+        ``parse_float``, if specified, will be called with the string
+        of every JSON float to be decoded. By default this is equivalent to
+        float(num_str). This can be used to use another datatype or parser
+        for JSON floats (e.g. decimal.Decimal).
+
+        ``parse_int``, if specified, will be called with the string
+        of every JSON int to be decoded. By default this is equivalent to
+        int(num_str). This can be used to use another datatype or parser
+        for JSON integers (e.g. float).
+
+        ``parse_constant``, if specified, will be called with one of the
+        following strings: -Infinity, Infinity, NaN, null, true, false.
+        This can be used to raise an exception if invalid JSON numbers
+        are encountered.
         """
         self.encoding = encoding
         self.object_hook = object_hook
+        self.parse_float = parse_float
+        self.parse_int = parse_int
+        self.parse_constant = parse_constant
+        self.strict = strict
 
     def decode(self, s, _w=WHITESPACE.match):
         """