#!/usr/bin/env python
import ast
import os
import os.path
import shutil
import tempfile
from typing import Union

import click

from sentry.utils import json
from sentry.utils.types import Any


def find_test_cases_matching(model_name: str):
    """Yield ``(file, test_case)`` pairs for manifest entries that hit ``model_name``.

    The manifest is a JSON object read from the path stored in the
    ``SENTRY_MODEL_MANIFEST_FILE_PATH`` environment variable.  Each key is
    split on ``"::"`` and its first two components are yielded whenever
    ``model_name`` is contained in the associated value.

    Raises:
        KeyError: if the environment variable is not set.
        IndexError: if a matching key contains no ``"::"`` separator.
    """
    # Load the manifest eagerly so the handle is closed before the first yield.
    with open(os.environ["SENTRY_MODEL_MANIFEST_FILE_PATH"], encoding="utf-8") as manifest_file:
        manifest = json.loads(manifest_file.read())

    for test_node_id, hits in manifest.items():
        if model_name in hits:
            parts = test_node_id.split("::")
            yield parts[0], parts[1]


# NOTE: the docstring of ``main`` is rendered by click as the command's help
# text, so it is user-facing output rather than plain documentation.
@click.command()
@click.option(
    "silo_mode",
    "--silo-mode",
    required=True,
    help="Which mode to apply to tests",
    type=click.Choice(
        [
            "control",
            "region",
        ]
    ),
)
@click.option("set_stable", "--set-stable", default=False, is_flag=True, help="Set tests as stable")
@click.argument("target_model", required=True)
def main(target_model: str, silo_mode: str, set_stable: bool):
    """
    Script to decorate the given target test for silo tests, making it easier to deploy changes to given tests.
    """
    for file_name, test_case_name in find_test_cases_matching(target_model):
        print(f"Trying {test_case_name} in {file_name}")  # noqa
        # remove any parameterization off the test case
        test_case_name = test_case_name.split("[")[0]
        file_path = os.path.abspath(file_name)

        # Decode as UTF-8 so parsing agrees with the UTF-8 bytes that
        # ``TestVisitor.rewrite`` writes back, and close the handle promptly.
        with open(file_path, encoding="utf-8") as source_file:
            file_ast = ast.parse(source_file.read())

        test_visitor = TestVisitor(test_case_name, f"{silo_mode}_silo_test", set_stable)
        test_visitor.visit(file_ast)
        test_visitor.rewrite(file_path)


class TestVisitor(ast.NodeVisitor):
    """Locate a target test in a parsed module and rewrite its silo decorator.

    Usage is two-phase: ``visit(tree)`` records the positions of the existing
    ``sentry.testutils.silo`` import, of the target definition and of any
    ``*_silo_test`` decorator already on it; ``rewrite(path)`` then re-emits
    the file with the import and decorator lines added or replaced.
    """

    def __init__(self, target_symbol_path: str, target_test_silo_mode: str, set_stable: bool):
        """
        Args:
            target_symbol_path: dotted path of the target, either ``name`` or
                ``ClassName.method_name``.
            target_test_silo_mode: decorator name to apply; ``main`` passes
                ``"control_silo_test"`` or ``"region_silo_test"``.
            set_stable: emit the decorator with ``(stable=True)``.
        """
        self.set_stable = set_stable
        self.target_test_silo_mode = target_test_silo_mode
        self.target_symbol_parts = target_symbol_path.split(".")
        # Each ``*_match_line`` is a ``(lineno, col_offset)`` tuple once found.
        self.import_match_line = None
        self.decorator_match_line = None
        self.func_match_line = None
        # Class currently being descended into for a ``Class.method`` target.
        self.class_node = None

    def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
        """Record the position of a ``sentry.testutils.silo`` import of a ``*_silo_test`` name."""
        if node.module == "sentry.testutils.silo" and any(
            isinstance(alias, ast.alias) and alias.name.endswith("_silo_test")
            for alias in node.names
        ):
            self.import_match_line = (node.lineno, node.col_offset)

    def visit_ClassDef(self, node: ast.ClassDef) -> Any:
        """Match the target class, or descend into it when the target is ``Class.method``."""
        if len(self.target_symbol_parts) == 2 and self.target_symbol_parts[0] == node.name:
            # ``Class.method`` target: the first path component names the
            # class.  Remember it while visiting the body so that
            # ``visit_FunctionDef`` can mark the class once the method is found.
            self.class_node = node
            self.generic_visit(node)
            self.class_node = None
        elif len(self.target_symbol_parts) == 1:
            # Bare target: match the class itself or any class that lists the
            # target as a plain-name base.
            target = self.target_symbol_parts[-1]
            base_names = {base.id for base in node.bases if isinstance(base, ast.Name)}
            if target == node.name or target in base_names:
                self.mark_target(node)

    def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.ClassDef]) -> Any:
        """Mark the target when a function with the target's final name is found.

        Inside a matched class the enclosing class is marked instead of the
        method.  A dotted target whose class did not match is ignored.
        """
        if self.target_symbol_parts[-1] != node.name:
            return self.generic_visit(node)

        if self.class_node:
            node = self.class_node
        elif len(self.target_symbol_parts) != 1:
            return

        self.mark_target(node)
        return

    def mark_target(self, node: Union[ast.FunctionDef, ast.ClassDef]):
        """Record where ``node`` starts and where its first ``*_silo_test`` decorator is, if any."""
        self.func_match_line = (node.lineno, node.col_offset)
        for expr in node.decorator_list:
            decorator_visitor = DecoratorVisitor(self.target_test_silo_mode)
            decorator_visitor.visit(expr)
            if decorator_visitor.match_line:
                self.decorator_match_line = decorator_visitor.match_line
                break

    def _decorate(self, lineno, match_line):
        """Return the encoded decorator line for ``lineno``, or ``False`` if it is not the match.

        ``match_line`` is a ``(lineno, col_offset)`` tuple or a falsy value;
        the column is used as the number of spaces of indentation.
        """
        if not match_line:
            return False

        if match_line[0] != lineno:
            return False

        ws = b" " * match_line[1]
        if self.set_stable:
            return ws + f"@{self.target_test_silo_mode}(stable=True)\n".encode()
        else:
            return ws + f"@{self.target_test_silo_mode}\n".encode()

    def rewrite(self, path):
        """Rewrite ``path`` in place with the silo import and decorator applied.

        The new content is written to a temporary file which then replaces
        ``path``.  If writing fails the temporary file is removed and ``path``
        is left untouched.

        Raises:
            Exception: if ``visit`` found neither the target definition nor a
                silo decorator on it.
        """
        if not self.decorator_match_line and not self.func_match_line:
            raise Exception(f"Could not find test case {self.target_symbol_parts}!")

        import_line = f"from sentry.testutils.silo import {self.target_test_silo_mode}\n".encode()

        tf = tempfile.NamedTemporaryFile(delete=False)
        try:
            with tf, open(path, encoding="utf-8") as f:
                # NOTE(review): with no existing silo import the new import is
                # emitted as the very first line of the file, ahead of any
                # shebang, module docstring or ``from __future__`` import.
                if not self.import_match_line:
                    tf.write(import_line)

                for i, line in enumerate(f, start=1):
                    # NOTE(review): only the first physical line of the matched
                    # import is replaced, and any other names it imported are
                    # dropped.
                    if self.import_match_line and self.import_match_line[0] == i:
                        tf.write(import_line)
                        continue

                    if newline := self._decorate(i, self.decorator_match_line):
                        # If the decorator type is not changing, keep the original line.
                        if self.target_test_silo_mode in line:
                            tf.write(line.encode("utf8"))
                        else:
                            tf.write(newline)
                        continue

                    # No silo decorator yet: insert one directly above the
                    # definition line, which is then written as usual.
                    if not self.decorator_match_line and (
                        newline := self._decorate(i, self.func_match_line)
                    ):
                        tf.write(newline)

                    tf.write(line.encode("utf8"))
        except BaseException:
            # ``delete=False`` means nothing else cleans this up on failure.
            os.unlink(tf.name)
            raise

        # Temporary files are created owner-only; carry over the original
        # permission bits before replacing the file.
        shutil.copymode(path, tf.name)
        shutil.move(tf.name, path)


class DecoratorVisitor(ast.NodeVisitor):
    """Find a ``*_silo_test`` name inside a single decorator expression."""

    def __init__(self, target_test_silo_mode: str):
        self.target_test_silo_mode = target_test_silo_mode
        # ``(lineno, col_offset - 1)`` of the matching name, once found.
        self.match_line = None

    def visit_Name(self, node: ast.Name) -> Any:
        """Record the position of any name ending in ``_silo_test``."""
        if node.id.endswith("_silo_test"):
            # One column left of the name; presumably this accounts for the
            # leading "@", since the value is later used as the indentation
            # width of the rewritten decorator line.
            self.match_line = (node.lineno, node.col_offset - 1)
        return ast.NodeVisitor.generic_visit(self, node)

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Do not descend into attribute expressions such as ``module.name``."""
        pass


if __name__ == "__main__":
    main()