|
@@ -1,9 +1,11 @@
|
|
from __future__ import annotations
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
+import ast
|
|
import os
|
|
import os
|
|
from collections import defaultdict
|
|
from collections import defaultdict
|
|
from contextlib import ExitStack, contextmanager
|
|
from contextlib import ExitStack, contextmanager
|
|
-from typing import Any, Collection, Dict, Generator, Iterable, Set, Type
|
|
|
|
|
|
+from pathlib import Path
|
|
|
|
+from typing import Any, Dict, Generator, Iterable, Iterator, MutableMapping, Set, Tuple, Type, Union
|
|
|
|
|
|
import django.apps
|
|
import django.apps
|
|
|
|
|
|
@@ -15,69 +17,88 @@ from sentry.utils import json
|
|
class ModelManifest:
|
|
class ModelManifest:
|
|
"""For auditing which models are touched by each test case."""
|
|
"""For auditing which models are touched by each test case."""
|
|
|
|
|
|
|
|
+ file_path: str
|
|
|
|
+ connections: MutableMapping[int, Set[int]]
|
|
|
|
+ model_names: MutableMapping[str, int]
|
|
|
|
+ test_names: MutableMapping[str, int]
|
|
|
|
+ reverse_lookup: MutableMapping[int, str]
|
|
|
|
+ next_id: int
|
|
|
|
+
|
|
|
|
+ def get_or_create_id(self, cache: MutableMapping[str, int], name: str) -> int:
|
|
|
|
+ if name in cache:
|
|
|
|
+ return cache[name]
|
|
|
|
+ next_id = self.next_id
|
|
|
|
+ cache[name] = next_id
|
|
|
|
+ self.reverse_lookup[next_id] = name
|
|
|
|
+ self.next_id += 1
|
|
|
|
+ return next_id
|
|
|
|
+
|
|
class Entry:
|
|
class Entry:
|
|
|
|
+ hits: set[Type[Model]]
|
|
|
|
+
|
|
def __init__(self) -> None:
|
|
def __init__(self) -> None:
|
|
- self.hits: Dict[Type[Model], Set[ModelManagerTriggerCondition]] = defaultdict(set)
|
|
|
|
|
|
+ self.hits: Set[Type[Model]] = set()
|
|
|
|
|
|
def create_trigger_action(
|
|
def create_trigger_action(
|
|
self, condition: ModelManagerTriggerCondition
|
|
self, condition: ModelManagerTriggerCondition
|
|
) -> ModelManagerTriggerAction:
|
|
) -> ModelManagerTriggerAction:
|
|
def action(model_class: Type[Model]) -> None:
|
|
def action(model_class: Type[Model]) -> None:
|
|
- self.hits[model_class].add(condition)
|
|
|
|
|
|
+ self.hits.add(model_class)
|
|
|
|
|
|
return action
|
|
return action
|
|
|
|
|
|
def __init__(self, file_path: str) -> None:
|
|
def __init__(self, file_path: str) -> None:
|
|
self.file_path = file_path
|
|
self.file_path = file_path
|
|
- self.tests: Dict[str, Collection[ModelManifest.Entry]] = {}
|
|
|
|
-
|
|
|
|
- def _load_json(self, content: Any) -> None:
|
|
|
|
- models = {model.__qualname__: model for model in django.apps.apps.get_models()}
|
|
|
|
- conditions = {condition.name: condition for condition in ModelManagerTriggerCondition}
|
|
|
|
-
|
|
|
|
- entry_objects = []
|
|
|
|
-
|
|
|
|
- for (test_id, entry_inputs) in content.items():
|
|
|
|
- entry_objects.append(entry_obj := ModelManifest.Entry())
|
|
|
|
-
|
|
|
|
- for entry_input in entry_inputs:
|
|
|
|
- for (model_name, condition_names) in entry_input.items():
|
|
|
|
- model_class = models[model_name]
|
|
|
|
- for condition_name in condition_names:
|
|
|
|
- condition = conditions[condition_name]
|
|
|
|
- entry_obj.hits[model_class].add(condition)
|
|
|
|
-
|
|
|
|
- self.tests[test_id] = entry_objects
|
|
|
|
-
|
|
|
|
- def _to_json(self) -> Dict[str, Any]:
|
|
|
|
- return {
|
|
|
|
- test_id: [
|
|
|
|
- {
|
|
|
|
- model_class.__qualname__: [condition.name for condition in conditions]
|
|
|
|
- for (model_class, conditions) in entry.hits.items()
|
|
|
|
- }
|
|
|
|
- for entry in entries
|
|
|
|
- if entry.hits
|
|
|
|
- ]
|
|
|
|
- for (test_id, entries) in self.tests.items()
|
|
|
|
- }
|
|
|
|
|
|
+ self.connections = defaultdict(set)
|
|
|
|
+ self.model_names = {}
|
|
|
|
+ self.test_names = {}
|
|
|
|
+ self.reverse_lookup = {}
|
|
|
|
+ self.next_id = 0
|
|
|
|
+
|
|
|
|
+ @classmethod
|
|
|
|
+ def from_json_file(cls, file_path: str) -> ModelManifest:
|
|
|
|
+ with open(file_path) as f:
|
|
|
|
+ content = json.load(f)
|
|
|
|
+
|
|
|
|
+ manifest = ModelManifest(file_path)
|
|
|
|
+ highest_id = 0
|
|
|
|
+ for model_name, model_id in content["model_names"].items():
|
|
|
|
+ manifest.model_names[model_name] = model_id
|
|
|
|
+ highest_id = max(model_id, highest_id)
|
|
|
|
+ manifest.reverse_lookup[model_id] = model_name
|
|
|
|
+
|
|
|
|
+ for test_name, test_id in content["test_names"].items():
|
|
|
|
+ manifest.test_names[test_name] = test_id
|
|
|
|
+ highest_id = max(test_id, highest_id)
|
|
|
|
+ manifest.reverse_lookup[test_id] = test_name
|
|
|
|
+
|
|
|
|
+ for id, connections in content["connections"].items():
|
|
|
|
+ for connection in connections:
|
|
|
|
+ manifest.connections[int(id)].add(int(connection))
|
|
|
|
+
|
|
|
|
+ manifest.next_id = highest_id + 1
|
|
|
|
+ return manifest
|
|
|
|
+
|
|
|
|
+ def to_json(self) -> Dict[str, Any]:
|
|
|
|
+ return dict(
|
|
|
|
+ connections=self.connections,
|
|
|
|
+ test_names=self.test_names,
|
|
|
|
+ model_names=self.model_names,
|
|
|
|
+ )
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
def open(cls, file_path: str) -> ModelManifest:
|
|
def open(cls, file_path: str) -> ModelManifest:
|
|
- manifest = cls(file_path)
|
|
|
|
if os.path.exists(file_path):
|
|
if os.path.exists(file_path):
|
|
- with open(file_path) as f:
|
|
|
|
- content = json.load(f)
|
|
|
|
- manifest._load_json(content)
|
|
|
|
- return manifest
|
|
|
|
|
|
+ return cls.from_json_file(file_path)
|
|
|
|
+ return cls(file_path)
|
|
|
|
|
|
@contextmanager
|
|
@contextmanager
|
|
def write(self) -> Generator[None, None, None]:
|
|
def write(self) -> Generator[None, None, None]:
|
|
try:
|
|
try:
|
|
- yield # Populate self.tests
|
|
|
|
|
|
+ yield # allow population via register
|
|
finally:
|
|
finally:
|
|
with open(self.file_path, mode="w") as f:
|
|
with open(self.file_path, mode="w") as f:
|
|
- json.dump(self._to_json(), f)
|
|
|
|
|
|
+ json.dump(self.to_json(), f)
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
def _get_all_model_managers() -> Iterable[BaseManager]:
|
|
def _get_all_model_managers() -> Iterable[BaseManager]:
|
|
@@ -87,7 +108,7 @@ class ModelManifest:
|
|
yield manager
|
|
yield manager
|
|
|
|
|
|
@contextmanager
|
|
@contextmanager
|
|
- def register(self, test_id: str) -> Generator[None, None, None]:
|
|
|
|
|
|
+ def register(self, test_name: str) -> Generator[None, None, None]:
|
|
with ExitStack() as stack:
|
|
with ExitStack() as stack:
|
|
entries = []
|
|
entries = []
|
|
|
|
|
|
@@ -102,4 +123,114 @@ class ModelManifest:
|
|
finally:
|
|
finally:
|
|
# Overwrite the entire test in place, in case it used to touch a
|
|
# Overwrite the entire test in place, in case it used to touch a
|
|
# model and doesn't anymore
|
|
# model and doesn't anymore
|
|
- self.tests[test_id] = entries
|
|
|
|
|
|
+ test_id = self.get_or_create_id(self.test_names, test_name)
|
|
|
|
+ self.connections[test_id] = set()
|
|
|
|
+ for key in list(self.connections.keys()):
|
|
|
|
+ self.connections[key].remove(test_id)
|
|
|
|
+
|
|
|
|
+ for entry in entries:
|
|
|
|
+ for model in entry.hits:
|
|
|
|
+ model_id = self.get_or_create_id(self.model_names, model.__name__)
|
|
|
|
+ self.connections[test_id].add(model_id)
|
|
|
|
+ self.connections[model_id].add(test_id)
|
|
|
|
+
|
|
|
|
+ def each_hybrid_cloud_test(
|
|
|
|
+ self, path_refix: Path
|
|
|
|
+ ) -> Iterator[Tuple[int, HybridCloudTestVisitor]]:
|
|
|
|
+ for test_node_name, test_id in self.test_names.items():
|
|
|
|
+ test_file_path: str
|
|
|
|
+ test_case_name: str
|
|
|
|
+ test_node_name = test_node_name.split("[")[0]
|
|
|
|
+ test_file_path, test_case_name = test_node_name.split("::")
|
|
|
|
+ test_file_path = os.path.abspath(str(path_refix.joinpath(test_file_path)))
|
|
|
|
+
|
|
|
|
+ test_visitor = HybridCloudTestVisitor(test_file_path, test_case_name)
|
|
|
|
+ if test_visitor.exists:
|
|
|
|
+ yield test_id, test_visitor
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class HybridCloudTestDecoratorVisitor(ast.NodeVisitor):
|
|
|
|
+ match_line: Tuple[int, int] | None
|
|
|
|
+
|
|
|
|
+ def __init__(self) -> None:
|
|
|
|
+ self.match_line = None
|
|
|
|
+ self.stable = False
|
|
|
|
+
|
|
|
|
+ def visit_keyword(self, node: ast.keyword) -> Any:
|
|
|
|
+ if node.arg == "stable":
|
|
|
|
+ if isinstance(node.value, ast.Constant):
|
|
|
|
+ self.stable = node.value.value
|
|
|
|
+
|
|
|
|
+ def visit_Name(self, node: ast.Name) -> Any:
|
|
|
|
+ if node.id.endswith("_silo_test"):
|
|
|
|
+ self.match_line = (node.lineno, node.col_offset - 1)
|
|
|
|
+ return ast.NodeVisitor.generic_visit(self, node)
|
|
|
|
+
|
|
|
|
+ def visit_Attribute(self, node: ast.Attribute) -> Any:
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class HybridCloudTestVisitor(ast.NodeVisitor):
|
|
|
|
+ import_match_line: Tuple[int, int] | None
|
|
|
|
+ class_node: ast.ClassDef | None
|
|
|
|
+ func_match_line: Tuple[int, int] | None
|
|
|
|
+ decorator_match_line: Tuple[int, int] | None
|
|
|
|
+
|
|
|
|
+ def __init__(self, test_file_path: str, test_name: str):
|
|
|
|
+ self.test_file_path = test_file_path
|
|
|
|
+ self.test_name = test_name
|
|
|
|
+ self.target_symbol_parts = test_name.split(".")
|
|
|
|
+ self.import_match_line = None
|
|
|
|
+ self.decorator_match_line = None
|
|
|
|
+ self.func_match_line = None
|
|
|
|
+ self.class_node = None
|
|
|
|
+ self.decorator_was_stable = False
|
|
|
|
+
|
|
|
|
+ @property
|
|
|
|
+ def exists(self) -> bool:
|
|
|
|
+ return os.path.exists(self.test_file_path)
|
|
|
|
+
|
|
|
|
+ def load(self) -> None:
|
|
|
|
+ with open(self.test_file_path) as f:
|
|
|
|
+ file_ast = ast.parse(f.read())
|
|
|
|
+ self.visit(file_ast)
|
|
|
|
+
|
|
|
|
+ def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
|
|
|
|
+ if node.module == "sentry.testutils.silo":
|
|
|
|
+ for name in node.names:
|
|
|
|
+ if isinstance(name, ast.alias):
|
|
|
|
+ if name.name.endswith("_silo_test"):
|
|
|
|
+ self.import_match_line = (node.lineno, node.col_offset)
|
|
|
|
+
|
|
|
|
+ def visit_ClassDef(self, node: ast.ClassDef) -> Any:
|
|
|
|
+ if len(self.target_symbol_parts) == 2 and self.target_symbol_parts[0] == node.name:
|
|
|
|
+ self.class_node = node
|
|
|
|
+ self.generic_visit(node)
|
|
|
|
+ self.class_node = None
|
|
|
|
+ elif len(self.target_symbol_parts) == 1:
|
|
|
|
+ if self.target_symbol_parts[-1] == node.name or self.target_symbol_parts[-1] in {
|
|
|
|
+ e.id for e in node.bases if isinstance(e, ast.Name)
|
|
|
|
+ }:
|
|
|
|
+ self.mark_target(node)
|
|
|
|
+
|
|
|
|
+ def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.ClassDef]) -> Any:
|
|
|
|
+ if self.target_symbol_parts[-1] == node.name:
|
|
|
|
+ if self.class_node:
|
|
|
|
+ node = self.class_node
|
|
|
|
+ elif len(self.target_symbol_parts) != 1:
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ self.mark_target(node)
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ return self.generic_visit(node)
|
|
|
|
+
|
|
|
|
+ def mark_target(self, node: Union[ast.FunctionDef, ast.ClassDef]) -> None:
|
|
|
|
+ self.func_match_line = (node.lineno, node.col_offset)
|
|
|
|
+ for expr in node.decorator_list:
|
|
|
|
+ decorator_visitor = HybridCloudTestDecoratorVisitor()
|
|
|
|
+ decorator_visitor.visit(expr)
|
|
|
|
+ if decorator_visitor.match_line:
|
|
|
|
+ self.decorator_match_line = decorator_visitor.match_line
|
|
|
|
+ self.decorator_was_stable = decorator_visitor.stable
|
|
|
|
+ break
|