Browse Source

chore(hybrid-cloud): Improve the decorate-silo-mode-tests script (#40345)

Changes to improve the decorate silo mode script
Zach Collins 2 years ago
parent
commit
3c8e973042
1 changed files with 59 additions and 37 deletions
  1. 59 37
      bin/decorate-silo-mode-tests

+ 59 - 37
bin/decorate-silo-mode-tests

@@ -2,6 +2,7 @@
 import ast
 import os
 import os.path
+import pathlib
 import shutil
 import tempfile
 from typing import Union
@@ -11,31 +12,34 @@ import click
 from sentry.utils import json
 from sentry.utils.types import Any
 
+_manifest_files_env = {
+    "SENTRY_MODEL_MANIFEST_FILE_PATH": pathlib.Path(__file__).absolute().parent.parent,
+    "GETSENTRY_MODEL_MANIFEST_FILE_PATH": pathlib.Path(__file__)
+    .absolute()
+    .parent.parent.parent.joinpath("getsentry"),
+}
+
 
 def find_test_cases_matching(model_name: str):
-    manifest = json.loads(open(os.environ["SENTRY_MODEL_MANIFEST_FILE_PATH"]).read())
-    for test_node_id, hits in manifest.items():
-        if model_name in hits:
-            parts = test_node_id.split("::")
-            yield parts[0], parts[1]
+    for env, path_root in _manifest_files_env.items():
+        manifest = json.loads(open(os.environ[env]).read())
+        for test_node_id, hits in manifest.items():
+            hit_set = {list(v.keys())[0] for v in hits}
+            if model_name in hit_set:
+                parts = test_node_id.split("::")
+                yield str(path_root.joinpath(parts[0])), parts[1]
+
+
+def pick(prompt, options):
+    choice = ""
+    while choice not in options:
+        choice = input(prompt + " (" + ", ".join(options) + ") ")  # noqa
+    return choice
 
 
 @click.command()
-@click.option(
-    "silo_mode",
-    "--silo-mode",
-    required=True,
-    help="Which mode to apply to tests",
-    type=click.Choice(
-        [
-            "control",
-            "region",
-        ]
-    ),
-)
-@click.option("set_stable", "--set-stable", default=False, is_flag=True, help="Set tests as stable")
 @click.argument("target_model", required=True)
-def main(target_model: str, silo_mode: str, set_stable: bool):
+def main(target_model: str):
     """
     Script to decorate the given target test for silo tests, making it easier to deploy changes to given tests.
     """
@@ -47,21 +51,28 @@ def main(target_model: str, silo_mode: str, set_stable: bool):
         ]  # remove any parameterization off the test case
         file_path = os.path.abspath(file_name)
         file_ast = ast.parse(open(file_path).read())
-        test_visitor = TestVisitor(test_case_name, f"{silo_mode}_silo_test", set_stable)
+        test_visitor = TestVisitor(test_case_name)
         test_visitor.visit(file_ast)
 
-        test_visitor.rewrite(file_path)
+        if not test_visitor.decorator_was_stable:
+            print(f"Updating {test_case_name}")  # noqa
+            try:
+                silo_mode = pick("silo mode?", ["all", "no", "control", "region"])
+                set_stable = pick("set stable?", ["y", "n"]) == "y"
+            except EOFError:
+                continue
+            test_visitor.rewrite(f"{silo_mode}_silo_test", set_stable, file_path)
+            return
 
 
 class TestVisitor(ast.NodeVisitor):
-    def __init__(self, target_symbol_path: str, target_test_silo_mode: str, set_stable: bool):
-        self.set_stable = set_stable
-        self.target_test_silo_mode = target_test_silo_mode
+    def __init__(self, target_symbol_path: str):
         self.target_symbol_parts = target_symbol_path.split(".")
         self.import_match_line = False
         self.decorator_match_line = None
         self.func_match_line = None
         self.class_node = None
+        self.decorator_was_stable = False
 
     def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
         if node.module == "sentry.testutils.silo":
@@ -71,7 +82,7 @@ class TestVisitor(ast.NodeVisitor):
                         self.import_match_line = (node.lineno, node.col_offset)
 
     def visit_ClassDef(self, node: ast.ClassDef) -> Any:
-        if len(self.target_symbol_parts) == 2 and self.target_test_silo_mode[0] == node.name:
+        if len(self.target_symbol_parts) == 2 and self.target_symbol_parts[0] == node.name:
             self.class_node = node
             self.generic_visit(node)
             self.class_node = None
@@ -96,13 +107,14 @@ class TestVisitor(ast.NodeVisitor):
     def mark_target(self, node: Union[ast.FunctionDef, ast.ClassDef]):
         self.func_match_line = (node.lineno, node.col_offset)
         for expr in node.decorator_list:
-            decorator_visitor = DecoratorVisitor(self.target_test_silo_mode)
+            decorator_visitor = DecoratorVisitor()
             decorator_visitor.visit(expr)
             if decorator_visitor.match_line:
                 self.decorator_match_line = decorator_visitor.match_line
+                self.decorator_was_stable = decorator_visitor.stable
                 break
 
-    def _decorate(self, lineno, match_line):
+    def _decorate(self, target_test_silo_mode, set_stable, lineno, match_line):
         if not match_line:
             return False
 
@@ -110,13 +122,13 @@ class TestVisitor(ast.NodeVisitor):
             return False
 
         ws = b" " * match_line[1]
-        if self.set_stable:
-            return ws + f"@{self.target_test_silo_mode}(stable=True)\n".encode()
+        if set_stable:
+            return ws + f"@{target_test_silo_mode}(stable=True)\n".encode()
         else:
-            return ws + f"@{self.target_test_silo_mode}\n".encode()
+            return ws + f"@{target_test_silo_mode}\n".encode()
 
-    def rewrite(self, path):
-        import_line = f"from sentry.testutils.silo import {self.target_test_silo_mode}\n".encode()
+    def rewrite(self, target_test_silo_mode, set_stable, path):
+        import_line = f"from sentry.testutils.silo import {target_test_silo_mode}\n".encode()
         if not self.decorator_match_line and not self.func_match_line:
             raise Exception(f"Could not find test case {self.target_symbol_parts}!")
 
@@ -132,16 +144,21 @@ class TestVisitor(ast.NodeVisitor):
                         tf.write(import_line)
                         continue
 
-                    if newline := self._decorate(i, self.decorator_match_line):
+                    if newline := self._decorate(
+                        target_test_silo_mode, set_stable, i, self.decorator_match_line
+                    ):
                         # If the decorator type is not changing, keep the original line.
-                        if self.target_test_silo_mode in line:
+                        # If the decorator that existed was stable, don't replace it, keep it.
+                        if self.decorator_was_stable:
                             tf.write(line.encode("utf8"))
                         else:
                             tf.write(newline)
                         continue
 
                     if not self.decorator_match_line and (
-                        newline := self._decorate(i, self.func_match_line)
+                        newline := self._decorate(
+                            target_test_silo_mode, set_stable, i, self.func_match_line
+                        )
                     ):
                         tf.write(newline)
 
@@ -152,9 +169,14 @@ class TestVisitor(ast.NodeVisitor):
 
 
 class DecoratorVisitor(ast.NodeVisitor):
-    def __init__(self, target_test_silo_mode: str):
-        self.target_test_silo_mode = target_test_silo_mode
+    def __init__(self):
         self.match_line = None
+        self.stable = False
+
+    def visit_keyword(self, node: ast.keyword) -> Any:
+        if node.arg == "stable":
+            if isinstance(node.value, ast.Constant):
+                self.stable = node.value.value
 
     def visit_Name(self, node: ast.Name) -> Any:
         if node.id.endswith("_silo_test"):