123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- #!/usr/bin/env python
- import ast
- import os
- import os.path
- import pathlib
- import shutil
- import tempfile
- from typing import Union
- import click
- from sentry.utils import json
- from sentry.utils.types import Any
# Each key is the name of an environment variable that holds the path of a
# model-manifest JSON file; each value is the directory that the test file
# paths recorded in that manifest are joined onto.
_manifest_files_env = {
    "SENTRY_MODEL_MANIFEST_FILE_PATH": pathlib.Path(__file__).absolute().parent.parent,
    "GETSENTRY_MODEL_MANIFEST_FILE_PATH": (
        pathlib.Path(__file__).absolute().parent.parent.parent / "getsentry"
    ),
}
def find_test_cases_matching(model_name: str):
    """Yield ``(test_file_path, test_case_name)`` for manifest entries hitting ``model_name``.

    For every entry of ``_manifest_files_env`` the environment variable named
    by the key is read as the path of a JSON manifest. The manifest maps test
    node ids (``<file>::<name>...``) to a list of mappings; a test matches when
    ``model_name`` equals the first key of any of those mappings.

    Yields:
        A tuple of the test file path (manifest-relative path joined onto the
        root directory for that manifest, as ``str``) and the second
        ``::``-separated component of the node id.

    Raises:
        KeyError: if one of the manifest environment variables is not set.
    """
    for env, path_root in _manifest_files_env.items():
        # Use a context manager so the manifest handle is closed right away
        # instead of lingering until garbage collection.
        with open(os.environ[env]) as manifest_file:
            manifest = json.loads(manifest_file.read())

        for test_node_id, hits in manifest.items():
            hit_set = {list(v.keys())[0] for v in hits}
            if model_name in hit_set:
                parts = test_node_id.split("::")
                yield str(path_root.joinpath(parts[0])), parts[1]
def pick(prompt, options):
    """Ask on stdin until the reply is one of ``options`` and return that reply.

    The question shown is ``prompt`` followed by the comma-separated options
    in parentheses. Replies that are not in ``options`` cause a re-prompt.
    """
    answer = ""
    while answer not in options:
        menu = ", ".join(options)
        answer = input(f"{prompt} ({menu}) ")  # noqa
    return answer
@click.command()
@click.argument("target_model", required=True)
def main(target_model: str):
    """
    Script to decorate the given target test for silo tests, making it easier to deploy changes to given tests.
    """
    # NOTE: the docstring above doubles as the click help text, so details
    # live in comments instead.
    #
    # Walk every test case the manifests associate with ``target_model``; for
    # the first one whose silo decorator is not already marked stable, ask for
    # the silo mode / stability and rewrite the test file.
    for file_name, test_case_name in find_test_cases_matching(target_model):
        print(f"Trying {test_case_name} in {file_name}")  # noqa
        # remove any parameterization off the test case
        test_case_name = test_case_name.split("[")[0]
        file_path = os.path.abspath(file_name)

        # Read as UTF-8 (the default Python source encoding) and close the
        # handle promptly rather than leaking it.
        with open(file_path, encoding="utf-8") as source_file:
            file_ast = ast.parse(source_file.read())

        test_visitor = TestVisitor(test_case_name)
        test_visitor.visit(file_ast)

        if not test_visitor.decorator_was_stable:
            print(f"Updating {test_case_name}")  # noqa
            try:
                silo_mode = pick("silo mode?", ["all", "no", "control", "region"])
                set_stable = pick("set stable?", ["y", "n"]) == "y"
            except EOFError:
                # No more interactive input: skip this test case.
                continue

            test_visitor.rewrite(f"{silo_mode}_silo_test", set_stable, file_path)
            # NOTE(review): stops after the first rewritten test case --
            # confirm this early exit is the intended placement.
            return
- class TestVisitor(ast.NodeVisitor):
- def __init__(self, target_symbol_path: str):
- self.target_symbol_parts = target_symbol_path.split(".")
- self.import_match_line = False
- self.decorator_match_line = None
- self.func_match_line = None
- self.class_node = None
- self.decorator_was_stable = False
- def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
- if node.module == "sentry.testutils.silo":
- for name in node.names:
- if isinstance(name, ast.alias):
- if name.name.endswith("_silo_test"):
- self.import_match_line = (node.lineno, node.col_offset)
- def visit_ClassDef(self, node: ast.ClassDef) -> Any:
- if len(self.target_symbol_parts) == 2 and self.target_symbol_parts[0] == node.name:
- self.class_node = node
- self.generic_visit(node)
- self.class_node = None
- elif len(self.target_symbol_parts) == 1:
- if self.target_symbol_parts[-1] == node.name or self.target_symbol_parts[-1] in {
- e.id for e in node.bases if isinstance(e, ast.Name)
- }:
- self.mark_target(node)
- def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.ClassDef]) -> Any:
- if self.target_symbol_parts[-1] == node.name:
- if self.class_node:
- node = self.class_node
- elif len(self.target_symbol_parts) != 1:
- return
- self.mark_target(node)
- return
- return self.generic_visit(node)
- def mark_target(self, node: Union[ast.FunctionDef, ast.ClassDef]):
- self.func_match_line = (node.lineno, node.col_offset)
- for expr in node.decorator_list:
- decorator_visitor = DecoratorVisitor()
- decorator_visitor.visit(expr)
- if decorator_visitor.match_line:
- self.decorator_match_line = decorator_visitor.match_line
- self.decorator_was_stable = decorator_visitor.stable
- break
- def _decorate(self, target_test_silo_mode, set_stable, lineno, match_line):
- if not match_line:
- return False
- if not match_line[0] == lineno:
- return False
- ws = b" " * match_line[1]
- if set_stable:
- return ws + f"@{target_test_silo_mode}(stable=True)\n".encode()
- else:
- return ws + f"@{target_test_silo_mode}\n".encode()
- def rewrite(self, target_test_silo_mode, set_stable, path):
- import_line = f"from sentry.testutils.silo import {target_test_silo_mode}\n".encode()
- if not self.decorator_match_line and not self.func_match_line:
- raise Exception(f"Could not find test case {self.target_symbol_parts}!")
- with tempfile.NamedTemporaryFile(delete=False) as tf:
- with open(path) as f:
- if not self.import_match_line:
- tf.write(import_line)
- for i, line in enumerate(f.readlines()):
- i += 1
- if self.import_match_line and self.import_match_line[0] == i:
- tf.write(import_line)
- continue
- if newline := self._decorate(
- target_test_silo_mode, set_stable, i, self.decorator_match_line
- ):
- # If the decorator type is not changing, keep the original line.
- # If the decorator that existed was stable, don't replace it, keep it.
- if self.decorator_was_stable:
- tf.write(line.encode("utf8"))
- else:
- tf.write(newline)
- continue
- if not self.decorator_match_line and (
- newline := self._decorate(
- target_test_silo_mode, set_stable, i, self.func_match_line
- )
- ):
- tf.write(newline)
- tf.write(line.encode("utf8"))
- tf.close()
- shutil.move(tf.name, path)
class DecoratorVisitor(ast.NodeVisitor):
    """Inspect one decorator expression for a ``*_silo_test`` decorator.

    After ``visit()``, ``match_line`` is ``(lineno, col_offset - 1)`` of the
    matching name (``None`` when nothing matched) and ``stable`` holds the
    constant passed as the ``stable=`` keyword (``False`` when absent).
    """

    def __init__(self):
        self.stable = False
        self.match_line = None

    def visit_keyword(self, node: ast.keyword) -> Any:
        # Only a literal ``stable=<constant>`` keyword is taken into account.
        value = node.value
        if node.arg == "stable" and isinstance(value, ast.Constant):
            self.stable = value.value

    def visit_Name(self, node: ast.Name) -> Any:
        is_silo_decorator = node.id.endswith("_silo_test")
        if is_silo_decorator:
            # Column shifted back by one from the name's own offset.
            self.match_line = (node.lineno, node.col_offset - 1)
        return super().generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        # Dotted expressions are deliberately not descended into.
        return None
# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
|