decorate-silo-mode-tests 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. #!/usr/bin/env python
  2. import ast
  3. import os
  4. import os.path
  5. import shutil
  6. import tempfile
  7. from typing import Union
  8. import click
  9. from sentry.utils import json
  10. from sentry.utils.types import Any
  11. def find_test_cases_matching(model_name: str):
  12. manifest = json.loads(open(os.environ["SENTRY_MODEL_MANIFEST_FILE_PATH"]).read())
  13. for test_node_id, hits in manifest.items():
  14. if model_name in hits:
  15. parts = test_node_id.split("::")
  16. yield parts[0], parts[1]
  17. @click.command()
  18. @click.option(
  19. "silo_mode",
  20. "--silo-mode",
  21. required=True,
  22. help="Which mode to apply to tests",
  23. type=click.Choice(
  24. [
  25. "control",
  26. "region",
  27. ]
  28. ),
  29. )
  30. @click.option("set_stable", "--set-stable", default=False, is_flag=True, help="Set tests as stable")
  31. @click.argument("target_model", required=True)
  32. def main(target_model: str, silo_mode: str, set_stable: bool):
  33. """
  34. Script to decorate the given target test for silo tests, making it easier to deploy changes to given tests.
  35. """
  36. for file_name, test_case_name in find_test_cases_matching(target_model):
  37. print(f"Trying {test_case_name} in {file_name}") # noqa
  38. test_case_name = test_case_name.split("[")[
  39. 0
  40. ] # remove any parameterization off the test case
  41. file_path = os.path.abspath(file_name)
  42. file_ast = ast.parse(open(file_path).read())
  43. test_visitor = TestVisitor(test_case_name, f"{silo_mode}_silo_test", set_stable)
  44. test_visitor.visit(file_ast)
  45. test_visitor.rewrite(file_path)
  46. class TestVisitor(ast.NodeVisitor):
  47. def __init__(self, target_symbol_path: str, target_test_silo_mode: str, set_stable: bool):
  48. self.set_stable = set_stable
  49. self.target_test_silo_mode = target_test_silo_mode
  50. self.target_symbol_parts = target_symbol_path.split(".")
  51. self.import_match_line = False
  52. self.decorator_match_line = None
  53. self.func_match_line = None
  54. self.class_node = None
  55. def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
  56. if node.module == "sentry.testutils.silo":
  57. for name in node.names:
  58. if isinstance(name, ast.alias):
  59. if name.name.endswith("_silo_test"):
  60. self.import_match_line = (node.lineno, node.col_offset)
  61. def visit_ClassDef(self, node: ast.ClassDef) -> Any:
  62. if len(self.target_symbol_parts) == 2 and self.target_test_silo_mode[0] == node.name:
  63. self.class_node = node
  64. self.generic_visit(node)
  65. self.class_node = None
  66. elif len(self.target_symbol_parts) == 1:
  67. if self.target_symbol_parts[-1] == node.name or self.target_symbol_parts[-1] in {
  68. e.id for e in node.bases if isinstance(e, ast.Name)
  69. }:
  70. self.mark_target(node)
  71. def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.ClassDef]) -> Any:
  72. if self.target_symbol_parts[-1] == node.name:
  73. if self.class_node:
  74. node = self.class_node
  75. elif len(self.target_symbol_parts) != 1:
  76. return
  77. self.mark_target(node)
  78. return
  79. return self.generic_visit(node)
  80. def mark_target(self, node: Union[ast.FunctionDef, ast.ClassDef]):
  81. self.func_match_line = (node.lineno, node.col_offset)
  82. for expr in node.decorator_list:
  83. decorator_visitor = DecoratorVisitor(self.target_test_silo_mode)
  84. decorator_visitor.visit(expr)
  85. if decorator_visitor.match_line:
  86. self.decorator_match_line = decorator_visitor.match_line
  87. break
  88. def _decorate(self, lineno, match_line):
  89. if not match_line:
  90. return False
  91. if not match_line[0] == lineno:
  92. return False
  93. ws = b" " * match_line[1]
  94. if self.set_stable:
  95. return ws + f"@{self.target_test_silo_mode}(stable=True)\n".encode()
  96. else:
  97. return ws + f"@{self.target_test_silo_mode}\n".encode()
  98. def rewrite(self, path):
  99. import_line = f"from sentry.testutils.silo import {self.target_test_silo_mode}\n".encode()
  100. if not self.decorator_match_line and not self.func_match_line:
  101. raise Exception(f"Could not find test case {self.target_symbol_parts}!")
  102. with tempfile.NamedTemporaryFile(delete=False) as tf:
  103. with open(path) as f:
  104. if not self.import_match_line:
  105. tf.write(import_line)
  106. for i, line in enumerate(f.readlines()):
  107. i += 1
  108. if self.import_match_line and self.import_match_line[0] == i:
  109. tf.write(import_line)
  110. continue
  111. if newline := self._decorate(i, self.decorator_match_line):
  112. # If the decorator type is not changing, keep the original line.
  113. if self.target_test_silo_mode in line:
  114. tf.write(line.encode("utf8"))
  115. else:
  116. tf.write(newline)
  117. continue
  118. if not self.decorator_match_line and (
  119. newline := self._decorate(i, self.func_match_line)
  120. ):
  121. tf.write(newline)
  122. tf.write(line.encode("utf8"))
  123. tf.close()
  124. shutil.move(tf.name, path)
  125. class DecoratorVisitor(ast.NodeVisitor):
  126. def __init__(self, target_test_silo_mode: str):
  127. self.target_test_silo_mode = target_test_silo_mode
  128. self.match_line = None
  129. def visit_Name(self, node: ast.Name) -> Any:
  130. if node.id.endswith("_silo_test"):
  131. self.match_line = (node.lineno, node.col_offset - 1)
  132. return ast.NodeVisitor.generic_visit(self, node)
  133. def visit_Attribute(self, node: ast.Attribute) -> Any:
  134. pass
  135. if __name__ == "__main__":
  136. main()