123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- #!/usr/bin/env python
- import ast
- import os
- import os.path
- import shutil
- import tempfile
- from typing import Union
- import click
- from sentry.utils import json
- from sentry.utils.types import Any
def find_test_cases_matching(model_name: str):
    """Yield ``(file_name, test_case_name)`` for every manifest entry that hit ``model_name``.

    The manifest is the JSON file named by the ``SENTRY_MODEL_MANIFEST_FILE_PATH``
    environment variable. Its keys are pytest node ids (``path::Name[::...]``); only
    the first two ``::``-separated parts are yielded. Each value is tested with
    ``model_name in hits`` (presumably a list of model names — verify against the
    manifest producer).

    Raises:
        KeyError: if ``SENTRY_MODEL_MANIFEST_FILE_PATH`` is not set.
    """
    # Close the manifest handle promptly instead of leaking it until GC.
    with open(os.environ["SENTRY_MODEL_MANIFEST_FILE_PATH"], encoding="utf-8") as f:
        manifest = json.loads(f.read())
    for test_node_id, hits in manifest.items():
        if model_name in hits:
            parts = test_node_id.split("::")
            yield parts[0], parts[1]
@click.command()
@click.option(
    "silo_mode",
    "--silo-mode",
    required=True,
    help="Which mode to apply to tests",
    type=click.Choice(
        [
            "control",
            "region",
        ]
    ),
)
@click.option("set_stable", "--set-stable", default=False, is_flag=True, help="Set tests as stable")
@click.argument("target_model", required=True)
def main(target_model: str, silo_mode: str, set_stable: bool):
    """
    Script to decorate the given target test for silo tests, making it easier to deploy changes to given tests.
    """
    # For every manifest entry that touched `target_model`, apply the
    # `<silo_mode>_silo_test` decorator to the test named by that entry.
    for file_name, test_case_name in find_test_cases_matching(target_model):
        print(f"Trying {test_case_name} in {file_name}")  # noqa
        # Drop any pytest parametrization suffix: "test_foo[a-b]" -> "test_foo".
        test_case_name = test_case_name.split("[")[0]
        file_path = os.path.abspath(file_name)
        # Close the source file after parsing; read as UTF-8 to match how it is rewritten.
        with open(file_path, encoding="utf-8") as f:
            file_ast = ast.parse(f.read(), filename=file_path)
        test_visitor = TestVisitor(test_case_name, f"{silo_mode}_silo_test", set_stable)
        test_visitor.visit(file_ast)
        test_visitor.rewrite(file_path)
- class TestVisitor(ast.NodeVisitor):
- def __init__(self, target_symbol_path: str, target_test_silo_mode: str, set_stable: bool):
- self.set_stable = set_stable
- self.target_test_silo_mode = target_test_silo_mode
- self.target_symbol_parts = target_symbol_path.split(".")
- self.import_match_line = False
- self.decorator_match_line = None
- self.func_match_line = None
- self.class_node = None
- def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
- if node.module == "sentry.testutils.silo":
- for name in node.names:
- if isinstance(name, ast.alias):
- if name.name.endswith("_silo_test"):
- self.import_match_line = (node.lineno, node.col_offset)
- def visit_ClassDef(self, node: ast.ClassDef) -> Any:
- if len(self.target_symbol_parts) == 2 and self.target_test_silo_mode[0] == node.name:
- self.class_node = node
- self.generic_visit(node)
- self.class_node = None
- elif len(self.target_symbol_parts) == 1:
- if self.target_symbol_parts[-1] == node.name or self.target_symbol_parts[-1] in {
- e.id for e in node.bases if isinstance(e, ast.Name)
- }:
- self.mark_target(node)
- def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.ClassDef]) -> Any:
- if self.target_symbol_parts[-1] == node.name:
- if self.class_node:
- node = self.class_node
- elif len(self.target_symbol_parts) != 1:
- return
- self.mark_target(node)
- return
- return self.generic_visit(node)
- def mark_target(self, node: Union[ast.FunctionDef, ast.ClassDef]):
- self.func_match_line = (node.lineno, node.col_offset)
- for expr in node.decorator_list:
- decorator_visitor = DecoratorVisitor(self.target_test_silo_mode)
- decorator_visitor.visit(expr)
- if decorator_visitor.match_line:
- self.decorator_match_line = decorator_visitor.match_line
- break
- def _decorate(self, lineno, match_line):
- if not match_line:
- return False
- if not match_line[0] == lineno:
- return False
- ws = b" " * match_line[1]
- if self.set_stable:
- return ws + f"@{self.target_test_silo_mode}(stable=True)\n".encode()
- else:
- return ws + f"@{self.target_test_silo_mode}\n".encode()
- def rewrite(self, path):
- import_line = f"from sentry.testutils.silo import {self.target_test_silo_mode}\n".encode()
- if not self.decorator_match_line and not self.func_match_line:
- raise Exception(f"Could not find test case {self.target_symbol_parts}!")
- with tempfile.NamedTemporaryFile(delete=False) as tf:
- with open(path) as f:
- if not self.import_match_line:
- tf.write(import_line)
- for i, line in enumerate(f.readlines()):
- i += 1
- if self.import_match_line and self.import_match_line[0] == i:
- tf.write(import_line)
- continue
- if newline := self._decorate(i, self.decorator_match_line):
- # If the decorator type is not changing, keep the original line.
- if self.target_test_silo_mode in line:
- tf.write(line.encode("utf8"))
- else:
- tf.write(newline)
- continue
- if not self.decorator_match_line and (
- newline := self._decorate(i, self.func_match_line)
- ):
- tf.write(newline)
- tf.write(line.encode("utf8"))
- tf.close()
- shutil.move(tf.name, path)
class DecoratorVisitor(ast.NodeVisitor):
    """Search one decorator expression for a name ending in ``_silo_test``.

    After ``visit(expr)``, ``match_line`` holds ``(lineno, col_offset - 1)`` of the last
    such name seen (one column left of it, presumably the ``@``), or ``None``. Attribute
    expressions are not descended into, so dotted references never match.
    """

    def __init__(self, target_test_silo_mode: str):
        # Stored for callers; matching below does not consult it.
        self.target_test_silo_mode = target_test_silo_mode
        self.match_line = None

    def visit_Name(self, node: ast.Name) -> Any:
        if node.id.endswith("_silo_test"):
            # Step back one column from the identifier itself.
            self.match_line = node.lineno, node.col_offset - 1
        return super().generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        # Intentionally a dead end: do not recurse into `a.b` expressions.
        return None
# Script entry point: hand control to the click command, which parses sys.argv.
if __name__ == "__main__":
    main()
|