123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449 |
- import ast
- import builtins
- import operator
- from collections import ChainMap, OrderedDict, deque
- from contextlib import suppress
- from types import FrameType
- from typing import Any, Tuple, Iterable, List, Mapping, Dict, Union, Set
- from pure_eval.my_getattr_static import getattr_static
- from pure_eval.utils import (
- CannotEval,
- has_ast_name,
- copy_ast_without_context,
- is_standard_types,
- of_standard_types,
- is_any,
- of_type,
- ensure_dict,
- )
class Evaluator:
    """
    Evaluates AST expression nodes against a mapping of variable names.

    Use `evaluator[node]` to evaluate a node. Anything that cannot be
    evaluated safely raises `CannotEval`. Results (and failures) are
    cached per node.
    """

    # Operator lookup tables. They are built once here instead of being
    # rebuilt on every evaluation (and on every link of a comparison chain).
    _UNARY_OPS = {
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
        ast.Not: operator.not_,
        ast.Invert: operator.invert,
    }

    # Operators missing from this table (e.g. `@`) are not evaluated.
    _BINARY_OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.LShift: operator.lshift,
        ast.RShift: operator.rshift,
        ast.BitOr: operator.or_,
        ast.BitXor: operator.xor,
        ast.BitAnd: operator.and_,
    }

    _COMPARE_OPS = {
        ast.Eq: operator.eq,
        ast.NotEq: operator.ne,
        ast.Lt: operator.lt,
        ast.LtE: operator.le,
        ast.Gt: operator.gt,
        ast.GtE: operator.ge,
        ast.Is: operator.is_,
        ast.IsNot: operator.is_not,
        ast.In: (lambda a, b: a in b),
        ast.NotIn: (lambda a, b: a not in b),
    }

    def __init__(self, names: Mapping[str, Any]):
        """
        Construct a new evaluator with the given variable names.
        This is a low level API, typically you will use `Evaluator.from_frame(frame)`.

        :param names: a mapping from variable names to their values.
        """
        self.names = names
        # Maps each evaluated node to its value, or to `CannotEval` itself
        # when evaluation of that node failed.
        self._cache = {}  # type: Dict[ast.expr, Any]

    @classmethod
    def from_frame(cls, frame: FrameType) -> 'Evaluator':
        """
        Construct an Evaluator that can look up variables from the given frame.

        Names are looked up in the frame's locals, then globals, then builtins.

        :param frame: a frame object, e.g. from a traceback or `inspect.currentframe().f_back`.
        """
        return cls(ChainMap(
            ensure_dict(frame.f_locals),
            ensure_dict(frame.f_globals),
            ensure_dict(frame.f_builtins),
        ))

    def __getitem__(self, node: ast.expr) -> Any:
        """
        Find the value of the given node.
        If it cannot be evaluated safely, this raises `CannotEval`.
        The result is cached either way.

        :param node: an AST expression to evaluate
        :return: the value of the node
        :raises TypeError: if `node` is not an `ast.expr`
        :raises CannotEval: if the node cannot be evaluated safely
        """
        if not isinstance(node, ast.expr):
            raise TypeError(
                "node should be an ast.expr, not {!r}".format(type(node).__name__)
            )

        try:
            cached = self._cache[node]
        except KeyError:
            pass
        else:
            # A cached `CannotEval` marks a node that already failed.
            if cached is CannotEval:
                raise CannotEval
            return cached

        try:
            self._cache[node] = result = self._handle(node)
            return result
        except CannotEval:
            self._cache[node] = CannotEval
            raise

    def _handle(self, node: ast.expr) -> Any:
        """
        This is where the evaluation happens.
        Users should use `__getitem__`, i.e. `evaluator[node]`,
        as it provides caching.

        :param node: an AST expression to evaluate
        :return: the value of the node
        :raises CannotEval: if the node type is unsupported or evaluation fails
        """
        # Pure literals are handled directly; any failure just means
        # the node is not a literal and falls through to the handlers below.
        with suppress(Exception):
            return ast.literal_eval(node)

        if isinstance(node, ast.Name):
            try:
                return self.names[node.id]
            except KeyError as e:
                raise CannotEval from e
        elif isinstance(node, ast.Attribute):
            value = self[node.value]
            attr = node.attr
            return getattr_static(value, attr)
        elif isinstance(node, ast.Subscript):
            return self._handle_subscript(node)
        elif isinstance(node, (ast.List, ast.Tuple, ast.Set, ast.Dict)):
            return self._handle_container(node)
        elif isinstance(node, ast.UnaryOp):
            return self._handle_unary(node)
        elif isinstance(node, ast.BinOp):
            return self._handle_binop(node)
        elif isinstance(node, ast.BoolOp):
            return self._handle_boolop(node)
        elif isinstance(node, ast.Compare):
            return self._handle_compare(node)
        elif isinstance(node, ast.Call):
            return self._handle_call(node)
        raise CannotEval

    def _handle_call(self, node: ast.Call) -> Any:
        """
        Evaluate a call to one of a fixed set of builtin callables.

        Calls with keyword arguments, and calls to any callable not listed
        below, raise `CannotEval`. Arguments are validated with
        `of_standard_types` / `of_type` before the call is made
        (NOTE(review): these helpers are assumed to raise `CannotEval`
        for unacceptable values - confirm in pure_eval.utils).

        :param node: the call node
        :return: the result of the call
        :raises CannotEval: if the call is not allowed or raises any exception
        """
        if node.keywords:
            raise CannotEval
        func = self[node.func]
        args = [self[arg] for arg in node.args]
        arg_count = len(args)

        # Callables accepted with a shallow check of each argument:
        # some with any number of arguments, others only with none,
        # others only with two or more.
        shallow_check_only = (
            is_any(
                func,
                slice,
                int,
                range,
                round,
                complex,
                list,
                tuple,
                abs,
                hex,
                bin,
                oct,
                bool,
                ord,
                float,
                len,
                chr,
            )
            or (
                arg_count == 0
                and is_any(func, set, dict, str, frozenset, bytes, bytearray, object)
            )
            or (
                arg_count >= 2
                and is_any(func, str, divmod, bytes, bytearray, pow)
            )
        )
        if shallow_check_only:
            args = [
                of_standard_types(arg, check_dict_values=False, deep=False)
                for arg in args
            ]
            try:
                return func(*args)
            except Exception as e:
                raise CannotEval from e

        if arg_count == 1:
            arg = args[0]

            # No validation of the argument is done for these.
            if is_any(func, id, type):
                try:
                    return func(arg)
                except Exception as e:
                    raise CannotEval from e

            # The argument must be one of the listed container types and
            # each item it yields is checked shallowly.
            if is_any(func, all, any, sum):
                of_type(arg, tuple, frozenset, list, set, dict, OrderedDict, deque)
                for x in arg:
                    of_standard_types(x, check_dict_values=False, deep=False)
                try:
                    return func(arg)
                except Exception as e:
                    raise CannotEval from e

            # The argument is checked deeply, including dict values.
            if is_any(
                func, sorted, min, max, hash, set, dict, ascii, str, repr, frozenset
            ):
                of_standard_types(arg, check_dict_values=True, deep=True)
                try:
                    return func(arg)
                except Exception as e:
                    raise CannotEval from e

        raise CannotEval

    def _handle_compare(self, node: ast.Compare) -> Any:
        """
        Evaluate a (possibly chained) comparison such as `a < b <= c`.

        Operands are evaluated left to right and evaluation stops at the
        first falsy result, which is returned as is.
        Identity comparisons (`is`, `is not`) skip the operand type check.

        :param node: the comparison node
        :return: the result of the last comparison performed
        :raises CannotEval: if an operand cannot be evaluated, the operator
            is unknown, or the comparison raises
        """
        left = self[node.left]
        result = True

        for op, right in zip(node.ops, node.comparators):
            right = self[right]

            op_type = type(op)
            op_func = self._COMPARE_OPS.get(op_type)
            if op_func is None:
                raise CannotEval

            if op_type not in (ast.Is, ast.IsNot):
                of_standard_types(left, check_dict_values=False, deep=True)
                of_standard_types(right, check_dict_values=False, deep=True)

            try:
                result = op_func(left, right)
            except Exception as e:
                raise CannotEval from e

            if not result:
                return result

            # The right operand becomes the left operand of the next link.
            left = right

        return result

    def _handle_boolop(self, node: ast.BoolOp) -> Any:
        """
        Evaluate an `and` / `or` expression with normal short circuiting.

        :param node: the boolean operation node
        :return: the value the expression evaluates to
        :raises CannotEval: if an operand that needs evaluating cannot be evaluated
        """
        left = of_standard_types(
            self[node.values[0]], check_dict_values=False, deep=False
        )

        for right in node.values[1:]:
            # We need short circuiting so that the whole operation can be evaluated
            # even if the right operand can't
            if isinstance(node.op, ast.Or):
                left = left or of_standard_types(
                    self[right], check_dict_values=False, deep=False
                )
            else:
                assert isinstance(node.op, ast.And)
                left = left and of_standard_types(
                    self[right], check_dict_values=False, deep=False
                )
        return left

    def _handle_binop(self, node: ast.BinOp) -> Any:
        """
        Evaluate a binary operation such as `a + b`.

        :param node: the binary operation node
        :return: the result of applying the operator
        :raises CannotEval: if the operator is unsupported, an operand cannot
            be evaluated, or the operation raises
        """
        op_type = type(node.op)
        op = self._BINARY_OPS.get(op_type)
        if op is None:
            raise CannotEval

        left = self[node.left]
        # Sets and dicts get a deep check of both operands.
        hash_type = is_any(type(left), set, frozenset, dict, OrderedDict)
        left = of_standard_types(left, check_dict_values=False, deep=hash_type)
        # `%` with a str/bytes left operand is string formatting, so the
        # right operand is checked deeply, including dict values.
        formatting = type(left) in (str, bytes) and op_type == ast.Mod

        right = of_standard_types(
            self[node.right],
            check_dict_values=formatting,
            deep=formatting or hash_type,
        )

        try:
            return op(left, right)
        except Exception as e:
            raise CannotEval from e

    def _handle_unary(self, node: ast.UnaryOp) -> Any:
        """
        Evaluate a unary operation: `-x`, `+x`, `not x` or `~x`.

        :param node: the unary operation node
        :return: the result of applying the operator
        :raises CannotEval: if the operand cannot be evaluated, the operator
            is unknown, or the operation raises
        """
        value = of_standard_types(
            self[node.operand], check_dict_values=False, deep=False
        )
        op = self._UNARY_OPS.get(type(node.op))
        if op is None:
            raise CannotEval

        try:
            return op(value)
        except Exception as e:
            raise CannotEval from e

    def _handle_subscript(self, node: ast.Subscript) -> Any:
        """
        Evaluate an indexing or slicing expression such as `a[i]` or `a[i:j]`.

        :param node: the subscript node
        :return: the selected item or slice
        :raises CannotEval: if the value or index cannot be evaluated,
            extended slices are used, or the lookup raises
        """
        value = self[node.value]
        of_standard_types(
            value, check_dict_values=False, deep=is_any(type(value), dict, OrderedDict)
        )

        index = node.slice
        if isinstance(index, ast.Slice):
            index = slice(
                *[
                    None if p is None else self[p]
                    for p in [index.lower, index.upper, index.step]
                ]
            )
        elif isinstance(index, ast.ExtSlice):
            raise CannotEval
        else:
            # Older Python versions wrap a plain index in ast.Index.
            if isinstance(index, ast.Index):
                index = index.value
            index = self[index]
        of_standard_types(index, check_dict_values=False, deep=True)

        try:
            return value[index]
        except Exception as e:
            raise CannotEval from e

    def _handle_container(
            self,
            node: Union[ast.List, ast.Tuple, ast.Set, ast.Dict]
    ) -> Union[List, Tuple, Set, Dict]:
        """
        Handle container nodes, including List, Set, Tuple and Dict.

        :param node: the container display node
        :return: a new container holding the evaluated elements
        :raises CannotEval: if an element cannot be evaluated, a dict uses
            `**` unpacking, or a set element / dict key is unacceptable
        """
        if isinstance(node, ast.Dict):
            elts = node.keys
            if None in elts:  # ** unpacking inside {}, not yet supported
                raise CannotEval
        else:
            elts = node.elts
        elts = [self[elt] for elt in elts]

        if isinstance(node, ast.List):
            return elts
        if isinstance(node, ast.Tuple):
            return tuple(elts)

        # Set and Dict
        if not all(
            is_standard_types(elt, check_dict_values=False, deep=True) for elt in elts
        ):
            raise CannotEval

        if isinstance(node, ast.Set):
            try:
                return set(elts)
            except TypeError as e:
                raise CannotEval from e

        assert isinstance(node, ast.Dict)

        pairs = [(elt, self[val]) for elt, val in zip(elts, node.values)]
        try:
            return dict(pairs)
        except TypeError as e:
            raise CannotEval from e

    def find_expressions(self, root: ast.AST) -> Iterable[Tuple[ast.expr, Any]]:
        """
        Find all expressions in the given tree that can be safely evaluated.
        This is a low level API, typically you will use `interesting_expressions_grouped`.

        :param root: any AST node
        :return: generator of pairs (tuples) of expression nodes and their corresponding values.
        """
        for node in ast.walk(root):
            if not isinstance(node, ast.expr):
                continue

            try:
                value = self[node]
            except CannotEval:
                continue

            yield node, value

    def interesting_expressions_grouped(self, root: ast.AST) -> List[Tuple[List[ast.expr], Any]]:
        """
        Find all interesting expressions in the given tree that can be safely evaluated,
        grouping equivalent nodes together.

        For more control and details, see:
         - Evaluator.find_expressions
         - is_expression_interesting
         - group_expressions

        :param root: any AST node
        :return: A list of pairs (tuples) containing:
                    - A list of equivalent AST expressions
                    - The value of the first expression node
                       (which should be the same for all nodes, unless threads are involved)
        """
        return group_expressions(
            pair
            for pair in self.find_expressions(root)
            if is_expression_interesting(*pair)
        )
def is_expression_interesting(node: ast.expr, value: Any) -> bool:
    """
    Determines if an expression is potentially interesting, at least in my opinion.
    Returns False for the following expressions whose value is generally obvious:
        - Literals (e.g. 123, 'abc', [1, 2, 3], {'a': (), 'b': ([1, 2], [3])})
        - Variables or attributes whose name is equal to the value's __name__.
            For example, a function `def foo(): ...` is not interesting when referred to
            as `foo` as it usually would, but `bar` can be interesting if `bar is foo`.
            Similarly the method `self.foo` is not interesting.
        - Builtins (e.g. `len`) referred to by their usual name.

    This is a low level API, typically you will use `interesting_expressions_grouped`.

    :param node: an AST expression
    :param value: the value of the node
    :return: a boolean: True if the expression is interesting, False otherwise
    """
    # literal_eval can fail with more than ValueError (e.g. TypeError for an
    # unhashable dict key such as `{[1]: 2}`), so suppress every failure,
    # as Evaluator._handle does. A failure only means "not a literal".
    with suppress(Exception):
        ast.literal_eval(node)
        return False

    # TODO exclude inner modules, e.g. numpy.random.__name__ == 'numpy.random' != 'random'
    # TODO exclude common module abbreviations, e.g. numpy as np, pandas as pd
    if has_ast_name(value, node):
        return False

    # A builtin referred to by its usual name; the fresh object() default
    # never matches `value` when the name is not a builtin.
    if (
            isinstance(node, ast.Name)
            and getattr(builtins, node.id, object()) is value
    ):
        return False

    return True
def group_expressions(expressions: Iterable[Tuple[ast.expr, Any]]) -> List[Tuple[List[ast.expr], Any]]:
    """
    Organise expression nodes and their values such that equivalent nodes are together.
    Two nodes are considered equivalent if they have the same structure,
    ignoring context (Load, Store, or Delete) and location (lineno, col_offset).
    For example, this will group together the same variable name mentioned multiple times in an expression.

    This will not check the values of the nodes. Equivalent nodes should have the same values,
    unless threads are involved.

    This is a low level API, typically you will use `interesting_expressions_grouped`.

    :param expressions: pairs of AST expressions and their values, as obtained from
                          `Evaluator.find_expressions`, or `(node, evaluator[node])`.
    :return: A list of pairs (tuples) containing:
                - A list of equivalent AST expressions
                - The value of the first expression node
                   (which should be the same for all nodes, unless threads are involved)
    """
    # Keyed by the dump of the context-free copy of each node; every entry
    # is (nodes seen so far, value of the first such node).
    groups = {}
    for expr, expr_value in expressions:
        key = ast.dump(copy_ast_without_context(expr))
        try:
            members, _ = groups[key]
        except KeyError:
            members = []
            groups[key] = (members, expr_value)
        members.append(expr)
    return list(groups.values())
|