eggs/py-1.4.0-py2.6.egg/py/_code/_assertionnew.py
changeset 69 c6bca38c1cbf
equal deleted inserted replaced
68:5ff1fc726848 69:c6bca38c1cbf
       
     1 """
       
     2 Find intermediate evaluation results in assert statements through builtin AST.
       
     3 This should replace _assertionold.py eventually.
       
     4 """
       
     5 
       
     6 import sys
       
     7 import ast
       
     8 
       
     9 import py
       
    10 from py._code.assertion import _format_explanation, BuiltinAssertionError
       
    11 
       
    12 
       
    13 if sys.platform.startswith("java") and sys.version_info < (2, 5, 2):
       
    14     # See http://bugs.jython.org/issue1497
       
    15     _exprs = ("BoolOp", "BinOp", "UnaryOp", "Lambda", "IfExp", "Dict",
       
    16               "ListComp", "GeneratorExp", "Yield", "Compare", "Call",
       
    17               "Repr", "Num", "Str", "Attribute", "Subscript", "Name",
       
    18               "List", "Tuple")
       
    19     _stmts = ("FunctionDef", "ClassDef", "Return", "Delete", "Assign",
       
    20               "AugAssign", "Print", "For", "While", "If", "With", "Raise",
       
    21               "TryExcept", "TryFinally", "Assert", "Import", "ImportFrom",
       
    22               "Exec", "Global", "Expr", "Pass", "Break", "Continue")
       
    23     _expr_nodes = set(getattr(ast, name) for name in _exprs)
       
    24     _stmt_nodes = set(getattr(ast, name) for name in _stmts)
       
    25     def _is_ast_expr(node):
       
    26         return node.__class__ in _expr_nodes
       
    27     def _is_ast_stmt(node):
       
    28         return node.__class__ in _stmt_nodes
       
    29 else:
       
    30     def _is_ast_expr(node):
       
    31         return isinstance(node, ast.expr)
       
    32     def _is_ast_stmt(node):
       
    33         return isinstance(node, ast.stmt)
       
    34 
       
    35 
       
    36 class Failure(Exception):
       
    37     """Error found while interpreting AST."""
       
    38 
       
    39     def __init__(self, explanation=""):
       
    40         self.cause = sys.exc_info()
       
    41         self.explanation = explanation
       
    42 
       
    43 
       
    44 def interpret(source, frame, should_fail=False):
       
    45     mod = ast.parse(source)
       
    46     visitor = DebugInterpreter(frame)
       
    47     try:
       
    48         visitor.visit(mod)
       
    49     except Failure:
       
    50         failure = sys.exc_info()[1]
       
    51         return getfailure(failure)
       
    52     if should_fail:
       
    53         return ("(assertion failed, but when it was re-run for "
       
    54                 "printing intermediate values, it did not fail.  Suggestions: "
       
    55                 "compute assert expression before the assert or use --no-assert)")
       
    56 
       
    57 def run(offending_line, frame=None):
       
    58     if frame is None:
       
    59         frame = py.code.Frame(sys._getframe(1))
       
    60     return interpret(offending_line, frame)
       
    61 
       
    62 def getfailure(failure):
       
    63     explanation = _format_explanation(failure.explanation)
       
    64     value = failure.cause[1]
       
    65     if str(value):
       
    66         lines = explanation.splitlines()
       
    67         if not lines:
       
    68             lines.append("")
       
    69         lines[0] += " << %s" % (value,)
       
    70         explanation = "\n".join(lines)
       
    71     text = "%s: %s" % (failure.cause[0].__name__, explanation)
       
    72     if text.startswith("AssertionError: assert "):
       
    73         text = text[16:]
       
    74     return text
       
    75 
       
    76 
       
    77 operator_map = {
       
    78     ast.BitOr : "|",
       
    79     ast.BitXor : "^",
       
    80     ast.BitAnd : "&",
       
    81     ast.LShift : "<<",
       
    82     ast.RShift : ">>",
       
    83     ast.Add : "+",
       
    84     ast.Sub : "-",
       
    85     ast.Mult : "*",
       
    86     ast.Div : "/",
       
    87     ast.FloorDiv : "//",
       
    88     ast.Mod : "%",
       
    89     ast.Eq : "==",
       
    90     ast.NotEq : "!=",
       
    91     ast.Lt : "<",
       
    92     ast.LtE : "<=",
       
    93     ast.Gt : ">",
       
    94     ast.GtE : ">=",
       
    95     ast.Pow : "**",
       
    96     ast.Is : "is",
       
    97     ast.IsNot : "is not",
       
    98     ast.In : "in",
       
    99     ast.NotIn : "not in"
       
   100 }
       
   101 
       
   102 unary_map = {
       
   103     ast.Not : "not %s",
       
   104     ast.Invert : "~%s",
       
   105     ast.USub : "-%s",
       
   106     ast.UAdd : "+%s"
       
   107 }
       
   108 
       
   109 
       
   110 class DebugInterpreter(ast.NodeVisitor):
       
   111     """Interpret AST nodes to gleam useful debugging information. """
       
   112 
       
   113     def __init__(self, frame):
       
   114         self.frame = frame
       
   115 
       
   116     def generic_visit(self, node):
       
   117         # Fallback when we don't have a special implementation.
       
   118         if _is_ast_expr(node):
       
   119             mod = ast.Expression(node)
       
   120             co = self._compile(mod)
       
   121             try:
       
   122                 result = self.frame.eval(co)
       
   123             except Exception:
       
   124                 raise Failure()
       
   125             explanation = self.frame.repr(result)
       
   126             return explanation, result
       
   127         elif _is_ast_stmt(node):
       
   128             mod = ast.Module([node])
       
   129             co = self._compile(mod, "exec")
       
   130             try:
       
   131                 self.frame.exec_(co)
       
   132             except Exception:
       
   133                 raise Failure()
       
   134             return None, None
       
   135         else:
       
   136             raise AssertionError("can't handle %s" %(node,))
       
   137 
       
   138     def _compile(self, source, mode="eval"):
       
   139         return compile(source, "<assertion interpretation>", mode)
       
   140 
       
   141     def visit_Expr(self, expr):
       
   142         return self.visit(expr.value)
       
   143 
       
   144     def visit_Module(self, mod):
       
   145         for stmt in mod.body:
       
   146             self.visit(stmt)
       
   147 
       
   148     def visit_Name(self, name):
       
   149         explanation, result = self.generic_visit(name)
       
   150         # See if the name is local.
       
   151         source = "%r in locals() is not globals()" % (name.id,)
       
   152         co = self._compile(source)
       
   153         try:
       
   154             local = self.frame.eval(co)
       
   155         except Exception:
       
   156             # have to assume it isn't
       
   157             local = False
       
   158         if not local:
       
   159             return name.id, result
       
   160         return explanation, result
       
   161 
       
   162     def visit_Compare(self, comp):
       
   163         left = comp.left
       
   164         left_explanation, left_result = self.visit(left)
       
   165         for op, next_op in zip(comp.ops, comp.comparators):
       
   166             next_explanation, next_result = self.visit(next_op)
       
   167             op_symbol = operator_map[op.__class__]
       
   168             explanation = "%s %s %s" % (left_explanation, op_symbol,
       
   169                                         next_explanation)
       
   170             source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,)
       
   171             co = self._compile(source)
       
   172             try:
       
   173                 result = self.frame.eval(co, __exprinfo_left=left_result,
       
   174                                          __exprinfo_right=next_result)
       
   175             except Exception:
       
   176                 raise Failure(explanation)
       
   177             if not result:
       
   178                 break
       
   179             left_explanation, left_result = next_explanation, next_result
       
   180 
       
   181         rcomp = py.code._reprcompare
       
   182         if rcomp:
       
   183             res = rcomp(op_symbol, left_result, next_result)
       
   184             if res:
       
   185                 explanation = res
       
   186         return explanation, result
       
   187 
       
   188     def visit_BoolOp(self, boolop):
       
   189         is_or = isinstance(boolop.op, ast.Or)
       
   190         explanations = []
       
   191         for operand in boolop.values:
       
   192             explanation, result = self.visit(operand)
       
   193             explanations.append(explanation)
       
   194             if result == is_or:
       
   195                 break
       
   196         name = is_or and " or " or " and "
       
   197         explanation = "(" + name.join(explanations) + ")"
       
   198         return explanation, result
       
   199 
       
   200     def visit_UnaryOp(self, unary):
       
   201         pattern = unary_map[unary.op.__class__]
       
   202         operand_explanation, operand_result = self.visit(unary.operand)
       
   203         explanation = pattern % (operand_explanation,)
       
   204         co = self._compile(pattern % ("__exprinfo_expr",))
       
   205         try:
       
   206             result = self.frame.eval(co, __exprinfo_expr=operand_result)
       
   207         except Exception:
       
   208             raise Failure(explanation)
       
   209         return explanation, result
       
   210 
       
   211     def visit_BinOp(self, binop):
       
   212         left_explanation, left_result = self.visit(binop.left)
       
   213         right_explanation, right_result = self.visit(binop.right)
       
   214         symbol = operator_map[binop.op.__class__]
       
   215         explanation = "(%s %s %s)" % (left_explanation, symbol,
       
   216                                       right_explanation)
       
   217         source = "__exprinfo_left %s __exprinfo_right" % (symbol,)
       
   218         co = self._compile(source)
       
   219         try:
       
   220             result = self.frame.eval(co, __exprinfo_left=left_result,
       
   221                                      __exprinfo_right=right_result)
       
   222         except Exception:
       
   223             raise Failure(explanation)
       
   224         return explanation, result
       
   225 
       
   226     def visit_Call(self, call):
       
   227         func_explanation, func = self.visit(call.func)
       
   228         arg_explanations = []
       
   229         ns = {"__exprinfo_func" : func}
       
   230         arguments = []
       
   231         for arg in call.args:
       
   232             arg_explanation, arg_result = self.visit(arg)
       
   233             arg_name = "__exprinfo_%s" % (len(ns),)
       
   234             ns[arg_name] = arg_result
       
   235             arguments.append(arg_name)
       
   236             arg_explanations.append(arg_explanation)
       
   237         for keyword in call.keywords:
       
   238             arg_explanation, arg_result = self.visit(keyword.value)
       
   239             arg_name = "__exprinfo_%s" % (len(ns),)
       
   240             ns[arg_name] = arg_result
       
   241             keyword_source = "%s=%%s" % (keyword.arg)
       
   242             arguments.append(keyword_source % (arg_name,))
       
   243             arg_explanations.append(keyword_source % (arg_explanation,))
       
   244         if call.starargs:
       
   245             arg_explanation, arg_result = self.visit(call.starargs)
       
   246             arg_name = "__exprinfo_star"
       
   247             ns[arg_name] = arg_result
       
   248             arguments.append("*%s" % (arg_name,))
       
   249             arg_explanations.append("*%s" % (arg_explanation,))
       
   250         if call.kwargs:
       
   251             arg_explanation, arg_result = self.visit(call.kwargs)
       
   252             arg_name = "__exprinfo_kwds"
       
   253             ns[arg_name] = arg_result
       
   254             arguments.append("**%s" % (arg_name,))
       
   255             arg_explanations.append("**%s" % (arg_explanation,))
       
   256         args_explained = ", ".join(arg_explanations)
       
   257         explanation = "%s(%s)" % (func_explanation, args_explained)
       
   258         args = ", ".join(arguments)
       
   259         source = "__exprinfo_func(%s)" % (args,)
       
   260         co = self._compile(source)
       
   261         try:
       
   262             result = self.frame.eval(co, **ns)
       
   263         except Exception:
       
   264             raise Failure(explanation)
       
   265         # Only show result explanation if it's not a builtin call or returns a
       
   266         # bool.
       
   267         if not isinstance(call.func, ast.Name) or \
       
   268                 not self._is_builtin_name(call.func):
       
   269             source = "isinstance(__exprinfo_value, bool)"
       
   270             co = self._compile(source)
       
   271             try:
       
   272                 is_bool = self.frame.eval(co, __exprinfo_value=result)
       
   273             except Exception:
       
   274                 is_bool = False
       
   275             if not is_bool:
       
   276                 pattern = "%s\n{%s = %s\n}"
       
   277                 rep = self.frame.repr(result)
       
   278                 explanation = pattern % (rep, rep, explanation)
       
   279         return explanation, result
       
   280 
       
   281     def _is_builtin_name(self, name):
       
   282         pattern = "%r not in globals() and %r not in locals()"
       
   283         source = pattern % (name.id, name.id)
       
   284         co = self._compile(source)
       
   285         try:
       
   286             return self.frame.eval(co)
       
   287         except Exception:
       
   288             return False
       
   289 
       
   290     def visit_Attribute(self, attr):
       
   291         if not isinstance(attr.ctx, ast.Load):
       
   292             return self.generic_visit(attr)
       
   293         source_explanation, source_result = self.visit(attr.value)
       
   294         explanation = "%s.%s" % (source_explanation, attr.attr)
       
   295         source = "__exprinfo_expr.%s" % (attr.attr,)
       
   296         co = self._compile(source)
       
   297         try:
       
   298             result = self.frame.eval(co, __exprinfo_expr=source_result)
       
   299         except Exception:
       
   300             raise Failure(explanation)
       
   301         explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result),
       
   302                                               self.frame.repr(result),
       
   303                                               source_explanation, attr.attr)
       
   304         # Check if the attr is from an instance.
       
   305         source = "%r in getattr(__exprinfo_expr, '__dict__', {})"
       
   306         source = source % (attr.attr,)
       
   307         co = self._compile(source)
       
   308         try:
       
   309             from_instance = self.frame.eval(co, __exprinfo_expr=source_result)
       
   310         except Exception:
       
   311             from_instance = True
       
   312         if from_instance:
       
   313             rep = self.frame.repr(result)
       
   314             pattern = "%s\n{%s = %s\n}"
       
   315             explanation = pattern % (rep, rep, explanation)
       
   316         return explanation, result
       
   317 
       
   318     def visit_Assert(self, assrt):
       
   319         test_explanation, test_result = self.visit(assrt.test)
       
   320         if test_explanation.startswith("False\n{False =") and \
       
   321                 test_explanation.endswith("\n"):
       
   322             test_explanation = test_explanation[15:-2]
       
   323         explanation = "assert %s" % (test_explanation,)
       
   324         if not test_result:
       
   325             try:
       
   326                 raise BuiltinAssertionError
       
   327             except Exception:
       
   328                 raise Failure(explanation)
       
   329         return explanation, test_result
       
   330 
       
   331     def visit_Assign(self, assign):
       
   332         value_explanation, value_result = self.visit(assign.value)
       
   333         explanation = "... = %s" % (value_explanation,)
       
   334         name = ast.Name("__exprinfo_expr", ast.Load(),
       
   335                         lineno=assign.value.lineno,
       
   336                         col_offset=assign.value.col_offset)
       
   337         new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno,
       
   338                                 col_offset=assign.col_offset)
       
   339         mod = ast.Module([new_assign])
       
   340         co = self._compile(mod, "exec")
       
   341         try:
       
   342             self.frame.exec_(co, __exprinfo_expr=value_result)
       
   343         except Exception:
       
   344             raise Failure(explanation)
       
   345         return explanation, value_result