Browse Source

test: Add mode to audit which models are touched by each test case (#39952)

When a "model manifest file path" env var is set, write the audit to a
JSON file at that path.

Add a set of triggers to BaseManager that execute provided callbacks
when the model is queried, saved, or deleted, in order to define what
counts as "touching" the model.

Co-authored-by: Zach Collins <recursive.cookie.jar@gmail.com>
Ryan Skonnord 2 years ago
parent
commit
7dedc5c463

+ 1 - 1
bin/decorate-silo-mode-tests

@@ -15,7 +15,7 @@ from sentry.utils.types import Any
 def find_test_cases_matching(model_name: str):
     manifest = json.loads(open(os.environ["SENTRY_MODEL_MANIFEST_FILE_PATH"]).read())
     for test_node_id, hits in manifest.items():
-        if any(hit["model"] == model_name for hit in hits):
+        if model_name in hits:
             parts = test_node_id.split("::")
             yield parts[0], parts[1]
 

+ 1 - 1
bin/run-model-tests

@@ -10,7 +10,7 @@ from sentry.utils import json
 def find_test_cases_matching(model_name: str):
     manifest = json.loads(open(os.environ["SENTRY_MODEL_MANIFEST_FILE_PATH"]).read())
     for test_node_id, hits in manifest.items():
-        if any(hit["model"] == model_name for hit in hits):
+        if model_name in hits:
             yield test_node_id.split("::")[1]
 
 

+ 1 - 0
mypy.ini

@@ -113,6 +113,7 @@ files = fixtures/mypy-stubs,
         src/sentry/tasks/store.py,
         src/sentry/tasks/symbolication.py,
         src/sentry/tasks/update_user_reports.py,
+        src/sentry/testutils/modelmanifest.py,
         src/sentry/testutils/silo.py,
         src/sentry/types/region.py,
         src/sentry/unmerge.py,

+ 55 - 0
src/sentry/db/models/manager/base.py

@@ -4,8 +4,11 @@ import logging
 import threading
 import weakref
 from contextlib import contextmanager
+from enum import IntEnum, auto
 from typing import (
     Any,
+    Callable,
+    Dict,
     Generator,
     Generic,
     Iterable,
@@ -14,6 +17,7 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
+    Type,
 )
 
 from django.conf import settings
@@ -36,6 +40,15 @@ _local_cache_generation = 0
 _local_cache_enabled = False
 
 
+class ModelManagerTriggerCondition(IntEnum):
+    QUERY = auto()
+    SAVE = auto()
+    DELETE = auto()
+
+
+ModelManagerTriggerAction = Callable[[Type[Model]], None]
+
+
 class BaseManager(DjangoBaseManager.from_queryset(BaseQuerySet), Generic[M]):  # type: ignore
     lookup_handlers = {"iexact": lambda x: x.upper()}
     use_for_related_fields = True
@@ -53,6 +66,10 @@ class BaseManager(DjangoBaseManager.from_queryset(BaseQuerySet), Generic[M]):  #
         self.cache_ttl = kwargs.pop("cache_ttl", 60 * 5)
         self._cache_version: Optional[str] = kwargs.pop("cache_version", None)
         self.__local_cache = threading.local()
+
+        self._triggers: Dict[
+            object, Tuple[ModelManagerTriggerCondition, ModelManagerTriggerAction]
+        ] = {}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -195,6 +212,8 @@ class BaseManager(DjangoBaseManager.from_queryset(BaseQuerySet), Generic[M]):  #
 
         self.__cache_state(instance)
 
+        self._execute_triggers(ModelManagerTriggerCondition.SAVE)
+
     def __post_delete(self, instance: M, **kwargs: Any) -> None:
         """
         Drops instance from all cache storages.
@@ -213,6 +232,8 @@ class BaseManager(DjangoBaseManager.from_queryset(BaseQuerySet), Generic[M]):  #
             key=self.__get_lookup_cache_key(**{pk_name: instance.pk}), version=self.cache_version
         )
 
+        self._execute_triggers(ModelManagerTriggerCondition.DELETE)
+
     def __get_lookup_cache_key(self, **kwargs: Any) -> str:
         return make_key(self.model, "modelcache", kwargs)
 
@@ -448,10 +469,44 @@ class BaseManager(DjangoBaseManager.from_queryset(BaseQuerySet), Generic[M]):  #
         Returns a new QuerySet object.  Subclasses can override this method to
         easily customize the behavior of the Manager.
         """
+
+        # TODO: This is a quick-and-dirty place to put the trigger hook that won't
+        #  work for all model classes, because some custom managers override
+        #  get_queryset without a `super` call.
+        self._execute_triggers(ModelManagerTriggerCondition.QUERY)
+
         if hasattr(self, "_hints"):
             return self._queryset_class(self.model, using=self._db, hints=self._hints)
         return self._queryset_class(self.model, using=self._db)
 
+    @contextmanager
+    def register_trigger(
+        self, condition: ModelManagerTriggerCondition, action: ModelManagerTriggerAction
+    ) -> Generator[None, None, None]:
+        """Register a callback for when an operation is executed inside the context.
+
+        There is no guarantee whether the action will be called before or after the
+        triggering operation is executed, nor whether it will or will not be called
+        if the triggering operation raises an exception.
+
+        Both the registration of the trigger and the execution of the action are NOT
+        THREADSAFE. This is intended for offline use in single-threaded contexts such
+        as pytest. We must add synchronization if we intend to adapt it for
+        production use.
+        """
+
+        key = object()
+        self._triggers[key] = (condition, action)
+        try:
+            yield
+        finally:
+            del self._triggers[key]
+
+    def _execute_triggers(self, condition: ModelManagerTriggerCondition) -> None:
+        for (next_condition, next_action) in self._triggers.values():
+            if condition == next_condition:
+                next_action(self.model)
+
     def create_silo_limited_copy(
         self, limit: SiloLimit, read_modes: Iterable[SiloMode]
     ) -> BaseManager[M]:

+ 105 - 0
src/sentry/testutils/modelmanifest.py

@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+import os
+from collections import defaultdict
+from contextlib import ExitStack, contextmanager
+from typing import Any, Collection, Dict, Generator, Iterable, Set, Type
+
+import django.apps
+
+from sentry.db.models import BaseManager, Model
+from sentry.db.models.manager.base import ModelManagerTriggerAction, ModelManagerTriggerCondition
+from sentry.utils import json
+
+
+class ModelManifest:
+    """For auditing which models are touched by each test case."""
+
+    class Entry:
+        def __init__(self) -> None:
+            self.hits: Dict[Type[Model], Set[ModelManagerTriggerCondition]] = defaultdict(set)
+
+        def create_trigger_action(
+            self, condition: ModelManagerTriggerCondition
+        ) -> ModelManagerTriggerAction:
+            def action(model_class: Type[Model]) -> None:
+                self.hits[model_class].add(condition)
+
+            return action
+
+    def __init__(self, file_path: str) -> None:
+        self.file_path = file_path
+        self.tests: Dict[str, Collection[ModelManifest.Entry]] = {}
+
+    def _load_json(self, content: Any) -> None:
+        models = {model.__qualname__: model for model in django.apps.apps.get_models()}
+        conditions = {condition.name: condition for condition in ModelManagerTriggerCondition}
+
+        entry_objects = []
+
+        for (test_id, entry_inputs) in content.items():
+            entry_objects.append(entry_obj := ModelManifest.Entry())
+
+            for entry_input in entry_inputs:
+                for (model_name, condition_names) in entry_input.items():
+                    model_class = models[model_name]
+                    for condition_name in condition_names:
+                        condition = conditions[condition_name]
+                        entry_obj.hits[model_class].add(condition)
+
+            self.tests[test_id] = entry_objects
+
+    def _to_json(self) -> Dict[str, Any]:
+        return {
+            test_id: [
+                {
+                    model_class.__qualname__: [condition.name for condition in conditions]
+                    for (model_class, conditions) in entry.hits.items()
+                }
+                for entry in entries
+                if entry.hits
+            ]
+            for (test_id, entries) in self.tests.items()
+        }
+
+    @classmethod
+    def open(cls, file_path: str) -> ModelManifest:
+        manifest = cls(file_path)
+        if os.path.exists(file_path):
+            with open(file_path) as f:
+                content = json.load(f)
+            manifest._load_json(content)
+        return manifest
+
+    @contextmanager
+    def write(self) -> Generator[None, None, None]:
+        try:
+            yield  # Populate self.tests
+        finally:
+            with open(self.file_path, mode="w") as f:
+                json.dump(self._to_json(), f)
+
+    @staticmethod
+    def _get_all_model_managers() -> Iterable[BaseManager]:
+        for model_class in django.apps.apps.get_models():
+            manager = model_class.objects
+            if isinstance(manager, BaseManager):
+                yield manager
+
+    @contextmanager
+    def register(self, test_id: str) -> Generator[None, None, None]:
+        with ExitStack() as stack:
+            entries = []
+
+            for model_manager in self._get_all_model_managers():
+                entries.append(entry := ModelManifest.Entry())
+                for condition in ModelManagerTriggerCondition:
+                    action = entry.create_trigger_action(condition)
+                    stack.enter_context(model_manager.register_trigger(condition, action))
+
+            try:
+                yield
+            finally:
+                # Overwrite the entire test in place, in case it used to touch a
+                # model and doesn't anymore
+                self.tests[test_id] = entries

+ 30 - 0
tests/conftest.py

@@ -90,3 +90,33 @@ def _error_workflow_command(filesystempath, lineno, longrepr):
 
 def _escape(s):
     return s.replace("%", "%25").replace("\r", "%0D").replace("\n", "%0A")
+
+
+_MODEL_MANIFEST_FILE_PATH = os.getenv("SENTRY_MODEL_MANIFEST_FILE_PATH")
+_model_manifest = None
+
+
+@pytest.fixture(scope="session", autouse=True)
+def create_model_manifest_file():
+    """Audit which models are touched by each test case and write it to file."""
+
+    # We have to construct the ModelManifest lazily, because importing
+    # sentry.testutils.modelmanifest too early causes a dependency cycle.
+    from sentry.testutils.modelmanifest import ModelManifest
+
+    if _MODEL_MANIFEST_FILE_PATH:
+        global _model_manifest
+        _model_manifest = ModelManifest.open(_MODEL_MANIFEST_FILE_PATH)
+        with _model_manifest.write():
+            yield
+    else:
+        yield
+
+
+@pytest.fixture(scope="class", autouse=True)
+def register_class_in_model_manifest(request: pytest.FixtureRequest):
+    if _model_manifest:
+        with _model_manifest.register(request.node.nodeid):
+            yield
+    else:
+        yield