123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- r"""Evaluate match expressions, as used by `-k` and `-m`.
- The grammar is:
- expression: expr? EOF
- expr: and_expr ('or' and_expr)*
- and_expr: not_expr ('and' not_expr)*
- not_expr: 'not' not_expr | '(' expr ')' | ident
- ident: (\w|:|\+|-|\.|\[|\]|\\|/)+
- The semantics are:
- - Empty expression evaluates to False.
- - ident evaluates to True or False according to a provided matcher function.
- - or/and/not evaluate according to the usual boolean semantics.
- """
- import ast
- import dataclasses
- import enum
- import re
- import sys
- import types
- from typing import Callable
- from typing import Iterator
- from typing import Mapping
- from typing import NoReturn
- from typing import Optional
- from typing import Sequence
# Python 3.8 made ast.Constant the node class for all literal constants and
# deprecated ast.NameConstant; choose whichever this interpreter should use
# when building the literal ``False`` node for an empty expression.
astNameConstant = ast.Constant if sys.version_info >= (3, 8) else ast.NameConstant
# Names exported by ``from <module> import *``.
__all__ = ["Expression", "ParseError"]
class TokenType(enum.Enum):
    """Kinds of token produced by the match-expression lexer.

    Each member's value is the human-readable name that appears in
    ParseError messages ("expected ...; got ...").
    """

    # Grouping punctuation.
    LPAREN = "left parenthesis"
    RPAREN = "right parenthesis"
    # Boolean keyword operators.
    OR = "or"
    AND = "and"
    NOT = "not"
    # Any other word accepted by the lexer.
    IDENT = "identifier"
    # Emitted once, after the last real token.
    EOF = "end of input"
@dataclasses.dataclass(frozen=True)
class Token:
    """A single lexed token: its kind, its source text, and where it starts.

    ``pos`` is a 0-based offset into the input string; error messages report
    it as a 1-based column.
    """

    # Explicit slots (no per-instance __dict__), listed in field order.
    __slots__ = ("type", "value", "pos")

    type: TokenType
    value: str
    pos: int
class ParseError(Exception):
    """The expression contains invalid syntax.

    :param column: The column in the line where the error occurred (1-based).
    :param message: A description of the error.
    """

    def __init__(self, column: int, message: str) -> None:
        # Forward both values to Exception so ``args`` is populated. Without
        # this, ``args`` is empty and pickling/copying the exception fails,
        # because reconstruction calls ``ParseError(*args)`` == ``ParseError()``.
        super().__init__(column, message)
        self.column = column
        self.message = message

    def __str__(self) -> str:
        return f"at column {self.column}: {self.message}"
class Scanner:
    """Lexer plus one-token lookahead for match expressions.

    ``current`` always holds the next unconsumed token; the parser functions
    inspect and consume it through ``accept``.
    """

    __slots__ = ("tokens", "current")

    # One identifier: a run of word characters and the punctuation listed in
    # the module grammar. The group is non-capturing ``(?:``; the previous
    # ``(:?`` was a typo (an optional ``:`` inside a capturing group) that
    # happened to match the same strings because ``:`` is also an
    # alternative. Compiled once and applied at an offset so lexing does not
    # copy the remaining input for every token.
    _IDENT_RE = re.compile(r"(?:\w|:|\+|-|\.|\[|\]|\\|/)+")

    def __init__(self, input: str) -> None:
        self.tokens = self.lex(input)
        # Prime the lookahead; lex() always yields at least the EOF token.
        self.current = next(self.tokens)

    def lex(self, input: str) -> Iterator[Token]:
        """Yield the tokens of ``input``, always ending with one EOF token.

        Spaces and tabs separate tokens and are otherwise ignored. Token
        positions are 0-based offsets into ``input``. Because this is a
        generator, errors surface when the offending token is pulled.

        :raises ParseError: On a character that cannot start any token
            (reported with a 1-based column).
        """
        keywords = {
            "or": TokenType.OR,
            "and": TokenType.AND,
            "not": TokenType.NOT,
        }
        pos = 0
        while pos < len(input):
            char = input[pos]
            if char in (" ", "\t"):
                pos += 1
            elif char == "(":
                yield Token(TokenType.LPAREN, "(", pos)
                pos += 1
            elif char == ")":
                yield Token(TokenType.RPAREN, ")", pos)
                pos += 1
            else:
                match = self._IDENT_RE.match(input, pos)
                if match is None:
                    raise ParseError(
                        pos + 1,
                        f'unexpected character "{char}"',
                    )
                value = match.group(0)
                # Keywords are recognised only as whole words: "orange" is an
                # identifier because the regex consumes all of it first.
                yield Token(keywords.get(value, TokenType.IDENT), value, pos)
                pos = match.end()
        yield Token(TokenType.EOF, "", pos)

    def accept(self, type: TokenType, *, reject: bool = False) -> Optional[Token]:
        """Consume and return the current token if it has type ``type``.

        The EOF token is returned but never consumed, so it can be accepted
        repeatedly. On a type mismatch, return None, or raise ParseError
        when ``reject`` is true.
        """
        if self.current.type is type:
            token = self.current
            if token.type is not TokenType.EOF:
                self.current = next(self.tokens)
            return token
        if reject:
            self.reject((type,))
        return None

    def reject(self, expected: Sequence[TokenType]) -> NoReturn:
        """Raise ParseError at the current token, listing the expected types."""
        raise ParseError(
            self.current.pos + 1,
            "expected {}; got {}".format(
                " OR ".join(type.value for type in expected),
                self.current.type.value,
            ),
        )
# ``True``, ``False`` and ``None`` are valid identifiers in a match
# expression but cannot be used as Python variable names. Every identifier
# is therefore prefixed with this string when it is turned into an
# ``ast.Name``; the prefix is stripped again before the matcher sees it.
IDENT_PREFIX = "$"
def expression(s: Scanner) -> ast.Expression:
    """Parse ``expression: expr? EOF`` into an ``ast.Expression`` for eval().

    An empty input produces the constant ``False``. Raises ParseError if
    any tokens remain after a complete ``expr``.
    """
    body: ast.expr
    if s.accept(TokenType.EOF) is None:
        body = expr(s)
        # Anything left over after a full expr is a syntax error.
        s.accept(TokenType.EOF, reject=True)
    else:
        body = astNameConstant(False)
    return ast.fix_missing_locations(ast.Expression(body))
def expr(s: Scanner) -> ast.expr:
    """Parse ``expr: and_expr ('or' and_expr)*``.

    Operands are folded left-associatively into nested two-operand
    ``ast.BoolOp`` nodes with an ``Or`` operator.
    """
    left = and_expr(s)
    while True:
        if s.accept(TokenType.OR) is None:
            return left
        right = and_expr(s)
        left = ast.BoolOp(op=ast.Or(), values=[left, right])
def and_expr(s: Scanner) -> ast.expr:
    """Parse ``and_expr: not_expr ('and' not_expr)*``.

    Operands are folded left-associatively into nested two-operand
    ``ast.BoolOp`` nodes with an ``And`` operator.
    """
    left = not_expr(s)
    while True:
        if s.accept(TokenType.AND) is None:
            return left
        right = not_expr(s)
        left = ast.BoolOp(op=ast.And(), values=[left, right])
def not_expr(s: Scanner) -> ast.expr:
    """Parse ``not_expr: 'not' not_expr | '(' expr ')' | ident``.

    Raises ParseError (via ``Scanner.reject``) when none of the three
    alternatives starts at the current token.
    """
    if s.accept(TokenType.NOT):
        operand = not_expr(s)
        return ast.UnaryOp(ast.Not(), operand)
    if s.accept(TokenType.LPAREN):
        inner = expr(s)
        s.accept(TokenType.RPAREN, reject=True)
        return inner
    token = s.accept(TokenType.IDENT)
    if token is None:
        s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
    # Prefix so identifiers like "True" become legal Python names.
    return ast.Name(IDENT_PREFIX + token.value, ast.Load())
class MatcherAdapter(Mapping[str, bool]):
    """Present a matcher function as the locals mapping passed to eval().

    Only item lookup is supported; iteration and ``len()`` raise
    NotImplementedError.
    """

    def __init__(self, matcher: Callable[[str], bool]) -> None:
        self.matcher = matcher

    def __getitem__(self, key: str) -> bool:
        # Strip IDENT_PREFIX so the matcher receives the identifier exactly
        # as it was written in the match expression.
        prefix_length = len(IDENT_PREFIX)
        identifier = key[prefix_length:]
        return self.matcher(identifier)

    def __iter__(self) -> Iterator[str]:
        raise NotImplementedError()

    def __len__(self) -> int:
        raise NotImplementedError()
class Expression:
    """A compiled match expression as used by -k and -m.

    The expression can be evaluated against different matchers.
    """

    __slots__ = ("code",)

    def __init__(self, code: types.CodeType) -> None:
        self.code = code

    @classmethod
    def compile(cls, input: str) -> "Expression":
        """Compile a match expression.

        :param input: The input expression - one line.
        :returns: An instance of the class this was called on.
        :raises ParseError: If ``input`` is not a valid match expression.
        """
        astexpr = expression(Scanner(input))
        # Inside a method body this name resolves to the builtin compile(),
        # not to this classmethod.
        code: types.CodeType = compile(
            astexpr,
            filename="<pytest match expression>",
            mode="eval",
        )
        # Use ``cls`` rather than a hard-coded ``Expression`` so subclasses
        # get an instance of themselves.
        return cls(code)

    def evaluate(self, matcher: Callable[[str], bool]) -> bool:
        """Evaluate the match expression.

        :param matcher:
            Given an identifier, should return whether it matches or not.
            Should be prepared to handle arbitrary strings as input.

        :returns: Whether the expression matches or not.
        """
        # Code built by ``compile`` contains only prefixed Name lookups,
        # boolean operators and constants; builtins are emptied and every
        # name lookup goes through MatcherAdapter to the matcher.
        ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher))
        return ret
|