decorate-silo-mode-tests 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. #!/usr/bin/env python
  2. import ast
  3. import os
  4. import os.path
  5. import pathlib
  6. import shutil
  7. import tempfile
  8. from typing import Union
  9. import click
  10. from sentry.utils import json
  11. from sentry.utils.types import Any
  12. _manifest_files_env = {
  13. "SENTRY_MODEL_MANIFEST_FILE_PATH": pathlib.Path(__file__).absolute().parent.parent,
  14. "GETSENTRY_MODEL_MANIFEST_FILE_PATH": pathlib.Path(__file__)
  15. .absolute()
  16. .parent.parent.parent.joinpath("getsentry"),
  17. }
  18. def find_test_cases_matching(model_name: str):
  19. for env, path_root in _manifest_files_env.items():
  20. manifest = json.loads(open(os.environ[env]).read())
  21. for test_node_id, hits in manifest.items():
  22. hit_set = {list(v.keys())[0] for v in hits}
  23. if model_name in hit_set:
  24. parts = test_node_id.split("::")
  25. yield str(path_root.joinpath(parts[0])), parts[1]
  26. def pick(prompt, options):
  27. choice = ""
  28. while choice not in options:
  29. choice = input(prompt + " (" + ", ".join(options) + ") ") # noqa
  30. return choice
  31. @click.command()
  32. @click.argument("target_model", required=True)
  33. def main(target_model: str):
  34. """
  35. Script to decorate the given target test for silo tests, making it easier to deploy changes to given tests.
  36. """
  37. for file_name, test_case_name in find_test_cases_matching(target_model):
  38. print(f"Trying {test_case_name} in {file_name}") # noqa
  39. test_case_name = test_case_name.split("[")[
  40. 0
  41. ] # remove any parameterization off the test case
  42. file_path = os.path.abspath(file_name)
  43. file_ast = ast.parse(open(file_path).read())
  44. test_visitor = TestVisitor(test_case_name)
  45. test_visitor.visit(file_ast)
  46. if not test_visitor.decorator_was_stable:
  47. print(f"Updating {test_case_name}") # noqa
  48. try:
  49. silo_mode = pick("silo mode?", ["all", "no", "control", "region"])
  50. set_stable = pick("set stable?", ["y", "n"]) == "y"
  51. except EOFError:
  52. continue
  53. test_visitor.rewrite(f"{silo_mode}_silo_test", set_stable, file_path)
  54. return
  55. class TestVisitor(ast.NodeVisitor):
  56. def __init__(self, target_symbol_path: str):
  57. self.target_symbol_parts = target_symbol_path.split(".")
  58. self.import_match_line = False
  59. self.decorator_match_line = None
  60. self.func_match_line = None
  61. self.class_node = None
  62. self.decorator_was_stable = False
  63. def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
  64. if node.module == "sentry.testutils.silo":
  65. for name in node.names:
  66. if isinstance(name, ast.alias):
  67. if name.name.endswith("_silo_test"):
  68. self.import_match_line = (node.lineno, node.col_offset)
  69. def visit_ClassDef(self, node: ast.ClassDef) -> Any:
  70. if len(self.target_symbol_parts) == 2 and self.target_symbol_parts[0] == node.name:
  71. self.class_node = node
  72. self.generic_visit(node)
  73. self.class_node = None
  74. elif len(self.target_symbol_parts) == 1:
  75. if self.target_symbol_parts[-1] == node.name or self.target_symbol_parts[-1] in {
  76. e.id for e in node.bases if isinstance(e, ast.Name)
  77. }:
  78. self.mark_target(node)
  79. def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.ClassDef]) -> Any:
  80. if self.target_symbol_parts[-1] == node.name:
  81. if self.class_node:
  82. node = self.class_node
  83. elif len(self.target_symbol_parts) != 1:
  84. return
  85. self.mark_target(node)
  86. return
  87. return self.generic_visit(node)
  88. def mark_target(self, node: Union[ast.FunctionDef, ast.ClassDef]):
  89. self.func_match_line = (node.lineno, node.col_offset)
  90. for expr in node.decorator_list:
  91. decorator_visitor = DecoratorVisitor()
  92. decorator_visitor.visit(expr)
  93. if decorator_visitor.match_line:
  94. self.decorator_match_line = decorator_visitor.match_line
  95. self.decorator_was_stable = decorator_visitor.stable
  96. break
  97. def _decorate(self, target_test_silo_mode, set_stable, lineno, match_line):
  98. if not match_line:
  99. return False
  100. if not match_line[0] == lineno:
  101. return False
  102. ws = b" " * match_line[1]
  103. if set_stable:
  104. return ws + f"@{target_test_silo_mode}(stable=True)\n".encode()
  105. else:
  106. return ws + f"@{target_test_silo_mode}\n".encode()
  107. def rewrite(self, target_test_silo_mode, set_stable, path):
  108. import_line = f"from sentry.testutils.silo import {target_test_silo_mode}\n".encode()
  109. if not self.decorator_match_line and not self.func_match_line:
  110. raise Exception(f"Could not find test case {self.target_symbol_parts}!")
  111. with tempfile.NamedTemporaryFile(delete=False) as tf:
  112. with open(path) as f:
  113. if not self.import_match_line:
  114. tf.write(import_line)
  115. for i, line in enumerate(f.readlines()):
  116. i += 1
  117. if self.import_match_line and self.import_match_line[0] == i:
  118. tf.write(import_line)
  119. continue
  120. if newline := self._decorate(
  121. target_test_silo_mode, set_stable, i, self.decorator_match_line
  122. ):
  123. # If the decorator type is not changing, keep the original line.
  124. # If the decorator that existed was stable, don't replace it, keep it.
  125. if self.decorator_was_stable:
  126. tf.write(line.encode("utf8"))
  127. else:
  128. tf.write(newline)
  129. continue
  130. if not self.decorator_match_line and (
  131. newline := self._decorate(
  132. target_test_silo_mode, set_stable, i, self.func_match_line
  133. )
  134. ):
  135. tf.write(newline)
  136. tf.write(line.encode("utf8"))
  137. tf.close()
  138. shutil.move(tf.name, path)
  139. class DecoratorVisitor(ast.NodeVisitor):
  140. def __init__(self):
  141. self.match_line = None
  142. self.stable = False
  143. def visit_keyword(self, node: ast.keyword) -> Any:
  144. if node.arg == "stable":
  145. if isinstance(node.value, ast.Constant):
  146. self.stable = node.value.value
  147. def visit_Name(self, node: ast.Name) -> Any:
  148. if node.id.endswith("_silo_test"):
  149. self.match_line = (node.lineno, node.col_offset - 1)
  150. return ast.NodeVisitor.generic_visit(self, node)
  151. def visit_Attribute(self, node: ast.Attribute) -> Any:
  152. pass
  153. if __name__ == "__main__":
  154. main()