Browse Source

feat(hybrid): Scripts for silo mode decorators (#38288)

Co-authored-by: Megan Heskett <meg.heskett@gmail.com>
Ryan Skonnord 2 years ago
parent
commit
d7107be63f

+ 1 - 0
mypy.ini

@@ -90,6 +90,7 @@ files = src/sentry/analytics/,
         src/sentry/tasks/store.py,
         src/sentry/tasks/symbolication.py,
         src/sentry/tasks/update_user_reports.py,
+        src/sentry/testutils/silo.py,
         src/sentry/unmerge.py,
         src/sentry/utils/appleconnect/,
         src/sentry/utils/hashlib.py,

+ 0 - 0
scripts/silo/__init__.py


+ 14 - 142
scripts/silo/add_silo_decorators.py

@@ -1,157 +1,29 @@
-#!.venv/bin/python
+#!/usr/bin/env sentry exec
 
 from __future__ import annotations
 
-import os
-import re
-import sys
-from dataclasses import dataclass
-from enum import Enum, auto
-from typing import Iterable
-
-from sentry.utils import json
+from sentry.utils.silo.add_silo_decorators import add_silo_decorators
+from sentry.utils.silo.common import Keywords
 
 """
 Instructions for use:
 
 1. Commit or stash any Git changes in progress.
-2. Scroll down to "Fill these predicates in..." and write what you want to do.
+2. Add keywords to identify model and api classes in each silo.
 3. From the Sentry project root, do
      ./scripts/silo/audit_silo_decorators.py | ./scripts/silo/add_silo_decorators.py
 4. Do `git status` or `git diff` to observe the results. Commit if you're happy.
 """
 
-
-class ClassCategory(Enum):
-    MODEL = auto()
-    ENDPOINT = auto()
-
-
-@dataclass
-class TargetClass:
-    module: str
-    name: str
-    category: ClassCategory
-    is_decorated: bool
-
-
-def parse_audit(audit) -> Iterable[TargetClass]:
-    def split_qualname(value):
-        dot_index = value.rindex(".")
-        module = value[:dot_index]
-        name = value[dot_index + 1 :]
-        return module, name
-
-    def parse_group(category, dec_group):
-        is_decorated = dec_group["decorator"] is not None
-        for value in dec_group["values"]:
-            module, name = split_qualname(value)
-            yield TargetClass(module, name, category, is_decorated)
-
-    for dec_group in audit["models"]["decorators"]:
-        yield from parse_group(ClassCategory.MODEL, dec_group)
-    for dec_group in audit["endpoints"]["decorators"]:
-        yield from parse_group(ClassCategory.ENDPOINT, dec_group)
-
-
-def read_audit():
-    pipe_input = sys.stdin.read()
-    brace_index = pipe_input.index("{")
-    pipe_input = pipe_input[brace_index:]  # strip leading junk
-    silo_audit = json.loads(pipe_input)
-    return list(parse_audit(silo_audit))
-
-
-def find_source_paths():
-    for (dirpath, dirnames, filenames) in os.walk("./src/sentry"):
-        for filename in filenames:
-            if filename.endswith(".py"):
-                yield os.path.join(dirpath, filename)
-
-
-def find_class_declarations():
-    for src_path in find_source_paths():
-        with open(src_path) as f:
-            src_code = f.read()
-        for match in re.findall(r"\nclass\s+(\w+)\(", src_code):
-            yield src_path, match
-
-
-def insert_import(src_code: str, import_stmt: str) -> str:
-    future_import = None
-    for future_import in re.finditer(r"from\s+__future__\s+.*\n+", src_code):
-        pass  # iterate to last match
-    if future_import:
-        start, end = future_import.span()
-        return src_code[:end] + import_stmt + "\n" + src_code[end:]
-    else:
-        return import_stmt + "\n" + src_code
-
-
-def apply_decorators(
-    decorator_name: str,
-    import_stmt: str,
-    target_classes: Iterable[TargetClass],
-) -> None:
-    target_names = {c.name for c in target_classes if not c.is_decorated}
-    for src_path, class_name in find_class_declarations():
-        if class_name in target_names:
-            with open(src_path) as f:
-                src_code = f.read()
-            new_code = re.sub(
-                rf"\nclass\s+{class_name}\(",
-                f"\n@{decorator_name}\nclass {class_name}(",
-                src_code,
-            )
-            new_code = insert_import(new_code, import_stmt)
-            with open(src_path, mode="w") as f:
-                f.write(new_code)
-
-
-def main():
-    classes = read_audit()
-
-    def filter_classes(category, predicate):
-        return (c for c in classes if c.category == category and predicate(c))
-
-    ####################################################################
-    # Fill these predicates in with the logic you want to apply
-
-    def control_model_predicate(c: TargetClass) -> bool:
-        return False
-
-    def customer_model_predicate(c: TargetClass) -> bool:
-        return False
-
-    def control_endpoint_predicate(c: TargetClass) -> bool:
-        return False
-
-    def customer_endpoint_predicate(c: TargetClass) -> bool:
-        return False
-
-    ####################################################################
-
-    apply_decorators(
-        "control_silo_model",
-        "from sentry.db.models import control_silo_model",
-        filter_classes(ClassCategory.MODEL, control_model_predicate),
-    )
-    apply_decorators(
-        "customer_silo_model",
-        "from sentry.db.models import customer_silo_model",
-        filter_classes(ClassCategory.MODEL, customer_model_predicate),
-    )
-    apply_decorators(
-        "control_silo_endpoint",
-        "from sentry.api.base import control_silo_endpoint",
-        filter_classes(ClassCategory.ENDPOINT, control_endpoint_predicate),
-    )
-    apply_decorators(
-        "customer_silo_endpoint",
-        "from sentry.api.base import customer_silo_endpoint",
-        filter_classes(ClassCategory.ENDPOINT, customer_endpoint_predicate),
-    )
-
+SILO_KEYWORDS = {
+    "control": Keywords(
+        include_words=["User", "Auth", "Identity"],
+    ),
+    "customer": Keywords(
+        include_words=["Organization", "Project", "Team", "Group", "Event", "Issue"],
+        exclude_words=["JiraIssue"],
+    ),
+}
 
 if __name__ == "__main__":
-    main()
+    add_silo_decorators(silo_keywords=SILO_KEYWORDS)

+ 7 - 0
scripts/silo/all_decorators.sh

@@ -0,0 +1,7 @@
+#!/bin/sh
+
+# Run from project root
+
+./scripts/silo/audit_silo_decorators.py | ./scripts/silo/add_silo_decorators.py
+./scripts/silo/decorate_models_by_relation.py
+pytest --collect-only | ./scripts/silo/decorate_unit_tests.py

+ 4 - 183
scripts/silo/audit_silo_decorators.py

@@ -2,190 +2,11 @@
 
 from __future__ import annotations
 
-import abc
-import json  # noqa - I want the `indent` param
-import sys
-from collections import defaultdict
+from sentry.runner import configure
 
-import django.apps
-import django.urls
-
-
-def audit_silo_limits(format="json"):
-    """Lists which classes have had silo decorators applied."""
-
-    from sentry.runner import configure
-
-    configure()
-    model_table = create_model_table()
-    endpoint_table = create_endpoint_table()
-
-    if format == "json":
-        json_repr = {
-            "models": ModelPresentation().as_json_repr(model_table),
-            "endpoints": EndpointPresentation().as_json_repr(endpoint_table),
-        }
-        json.dump(json_repr, sys.stdout, indent=4)
-    elif format == "markdown":
-        ModelPresentation().print_markdown(model_table)
-        EndpointPresentation().print_markdown(endpoint_table)
-    else:
-        raise ValueError
-
-
-def create_model_table():
-    table = defaultdict(list)
-    for model_class in django.apps.apps.get_models():
-        if model_class._meta.app_label != "sentry":
-            continue
-        limit = getattr(model_class._meta, "_ModelSiloLimit__silo_limit", None)
-        key = (limit.modes, limit.read_only) if limit else None
-        table[key].append(model_class)
-    return table
-
-
-def create_endpoint_table():
-    from sentry.api.base import Endpoint
-
-    def is_endpoint(view_function, bindings):
-        view_class = getattr(view_function, "view_class", None)
-        return view_class and issubclass(view_class, Endpoint)
-
-    def get_endpoint_classes():
-        url_mappings = list(django.urls.get_resolver().reverse_dict.items())
-        for (view_function, bindings) in url_mappings:
-            if is_endpoint(view_function, bindings):
-                yield view_function.view_class
-
-    table = defaultdict(list)
-    for endpoint_class in get_endpoint_classes():
-        limit = getattr(endpoint_class, "__silo_limit", None)
-        key = limit.modes if limit else None
-        table[key].append(endpoint_class)
-
-    return table
-
-
-class ConsolePresentation(abc.ABC):
-    @property
-    @abc.abstractmethod
-    def table_label(self):
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def order(self, group):
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def get_group_label(self, key):
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def get_key_repr(self, key):
-        raise NotImplementedError
-
-    @staticmethod
-    def format_mode_set(modes):
-        if modes is None:
-            return None
-        return sorted(str(x) for x in modes)
-
-    @staticmethod
-    def format_value(value):
-        return f"{value.__module__}.{value.__name__}"
-
-    def normalize_table(self, table):
-        return {
-            key: sorted({self.format_value(value) for value in group})
-            for (key, group) in (sorted(table.items(), key=self.order))
-        }
-
-    def as_json_repr(self, table):
-        table = self.normalize_table(table)
-        return {
-            "total_count": sum(len(group) for group in table.values()),
-            "decorators": [
-                {
-                    "decorator": self.get_key_repr(group_key),
-                    "count": len(group),
-                    "values": group,
-                }
-                for group_key, group in table.items()
-            ],
-        }
-
-    def print_markdown(self, table):
-        table = self.normalize_table(table)
-
-        total_count = sum(len(group) for group in table.values())
-        table_header = f"{self.table_label} ({total_count})"
-        print("\n" + table_header)  # noqa
-        print("=" * len(table_header), end="\n\n")  # noqa
-
-        for (group_key, group) in table.items():
-            group_label = self.get_group_label(group_key)
-            group_header = f"{group_label} ({len(group)})"
-            print(group_header)  # noqa
-            print("-" * len(group_header), end="\n\n")  # noqa
-
-            for value in group:
-                print("  - " + value)  # noqa
-            print()  # noqa
-
-
-class ModelPresentation(ConsolePresentation):
-    @property
-    def table_label(self):
-        return "MODELS"
-
-    def order(self, group):
-        group_key, _model_group = group
-        if group_key is None:
-            return ()
-        write_modes, read_modes = group_key
-        return (
-            len(write_modes),
-            len(read_modes),
-            self.format_mode_set(write_modes),
-            self.format_mode_set(read_modes),
-        )
-
-    def get_key_repr(self, key):
-        if key is None:
-            return None
-        write_modes, read_modes = key
-        return {
-            "write_modes": self.format_mode_set(write_modes),
-            "read_modes": self.format_mode_set(read_modes),
-        }
-
-    def get_group_label(self, key):
-        if key is None:
-            return "No decorator"
-        write_modes, read_modes = key
-        if read_modes:
-            return (
-                f"{self.format_mode_set(write_modes)}, read_only={self.format_mode_set(read_modes)}"
-            )
-        else:
-            return self.format_mode_set(write_modes)
-
-
-class EndpointPresentation(ConsolePresentation):
-    @property
-    def table_label(self):
-        return "VIEWS"
-
-    def order(self, group):
-        mode_set, _endpoint_group = group
-        return len(mode_set or ()), self.format_mode_set(mode_set)
-
-    def get_group_label(self, key):
-        return self.format_mode_set(key) if key else "No decorator"
-
-    def get_key_repr(self, key):
-        return self.format_mode_set(key)
+configure()
 
+from sentry.utils.silo.audit_silo_decorators import audit_silo_decorators
 
 if __name__ == "__main__":
-    audit_silo_limits()
+    audit_silo_decorators()

+ 43 - 0
scripts/silo/decorate_models_by_relation.py

@@ -0,0 +1,43 @@
+#!/usr/bin/env sentry exec
+
+from __future__ import annotations
+
+from sentry.models.group import Group
+from sentry.models.organization import Organization
+from sentry.models.project import Project
+from sentry.models.release import Release
+from sentry.utils.silo.decorate_models_by_relation import (
+    TargetRelations,
+    decorate_models_by_relation,
+)
+
+"""
+This is an alternative to add_silo_decorators.py that uses an algorithmic definition of
+the silos and aims for 100% coverage. It examines the fields of model classes and
+uses a graph traversal algorithm to find all models that point to the `Organization`
+model, either directly or through a number of steps. Those models are tagged for the
+customer silo, and all others for the control silo.
+
+Instructions for use:
+
+1. Commit or stash any Git changes in progress.
+2. Update foreign key relationships to identify models in the customer silo.
+2. From the Sentry project root, do
+     ./scripts/silo/decorate_models_by_relation.py
+3. Do `git status` or `git diff` to observe the results. Commit if you're happy.
+"""
+
+CUSTOMER_TARGET_RELATIONS = TargetRelations(
+    # Foreign key relationships
+    models=[Organization],
+    naming_conventions={
+        # Covers BoundedBigIntegerFields used as soft foreign keys
+        "organization_id": Organization,
+        "project_id": Project,
+        "group_id": Group,
+        "release_id": Release,
+    },
+)
+
+if __name__ == "__main__":
+    decorate_models_by_relation(target_relations=CUSTOMER_TARGET_RELATIONS)

+ 27 - 0
scripts/silo/decorate_unit_tests.py

@@ -0,0 +1,27 @@
+#!/usr/bin/env sentry exec
+
+from __future__ import annotations
+
+from scripts.silo.add_silo_decorators import SILO_KEYWORDS
+from sentry.utils.silo.decorate_unit_tests import decorate_unit_tests
+
+"""Add silo mode decorators to unit test cases en masse.
+
+Unlike `add_silo_decorators`, this script can't really reflect on interpreted
+Python code in order to distinguish unit tests. It instead relies on an external
+`pytest` run to collect the list of test cases, and does some kludgey regex
+business in order to apply the decorators.
+
+Instructions for use:
+
+From the Sentry project root, do
+    pytest --collect-only | ./scripts/silo/decorate_unit_tests.py
+
+Running `pytest` to collect unit test cases can be quite slow. To speed up
+repeated runs, you can instead do
+    pytest --collect-only > pytest-collect.txt
+    ./scripts/silo/decorate_unit_tests.py < pytest-collect.txt
+"""
+
+if __name__ == "__main__":
+    decorate_unit_tests(silo_keywords=SILO_KEYWORDS)

+ 2 - 0
src/sentry/conf/server.py

@@ -2773,8 +2773,10 @@ SENTRY_FUNCTIONS_PROJECT_NAME = None
 
 SENTRY_FUNCTIONS_REGION = "us-central1"
 
+# Settings related to SiloMode
 SILO_MODE = os.environ.get("SENTRY_SILO_MODE", None)
 FAIL_ON_UNAVAILABLE_API_CALL = False
+SILO_MODE_SPLICE_TESTS = bool(os.environ.get("SENTRY_SILO_MODE_SPLICE_TESTS", False))
 
 DISALLOWED_CUSTOMER_DOMAINS = []
 

+ 59 - 0
src/sentry/testutils/factories.py

@@ -92,6 +92,7 @@ from sentry.models.integrations.integration_feature import Feature, IntegrationT
 from sentry.models.releasefile import update_artifact_index
 from sentry.signals import project_created
 from sentry.snuba.dataset import Dataset
+from sentry.testutils.silo import exempt_from_silo_limits
 from sentry.types.activity import ActivityType
 from sentry.types.integrations import ExternalProviders
 from sentry.utils import json, loremipsum
@@ -240,6 +241,7 @@ def _patch_artifact_manifest(path, org, release, project=None, extra_files=None)
 # TODO(dcramer): consider moving to something more scalable like factoryboy
 class Factories:
     @staticmethod
+    @exempt_from_silo_limits()
     def create_organization(name=None, owner=None, **kwargs):
         if not name:
             name = petname.Generate(2, " ", letters=10).title()
@@ -250,6 +252,7 @@ class Factories:
         return org
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_member(teams=None, **kwargs):
         kwargs.setdefault("role", "member")
 
@@ -260,6 +263,7 @@ class Factories:
         return om
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_team_membership(team, member=None, user=None, role=None):
         if member is None:
             member, _ = OrganizationMember.objects.get_or_create(
@@ -271,6 +275,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_team(organization, **kwargs):
         if not kwargs.get("name"):
             kwargs["name"] = petname.Generate(2, " ", letters=10).title()
@@ -285,6 +290,7 @@ class Factories:
         return team
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_environment(project, **kwargs):
         name = kwargs.get("name", petname.Generate(3, " ", letters=10)[:64])
 
@@ -298,6 +304,7 @@ class Factories:
         return env
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_project(organization=None, teams=None, fire_project_created=False, **kwargs):
         if not kwargs.get("name"):
             kwargs["name"] = petname.Generate(2, " ", letters=10).title()
@@ -318,10 +325,12 @@ class Factories:
         return project
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_project_bookmark(project, user):
         return ProjectBookmark.objects.create(project_id=project.id, user=user)
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_project_rule(project, action_data=None, condition_data=None):
         action_data = action_data or [
             {
@@ -350,6 +359,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_slack_project_rule(project, integration_id, channel_id=None, channel_name=None):
         action_data = [
             {
@@ -363,10 +373,12 @@ class Factories:
         return Factories.create_project_rule(project, action_data)
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_project_key(project):
         return project.key_set.get_or_create()[0]
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_release(
         project: Project,
         user: Optional[User] = None,
@@ -437,6 +449,7 @@ class Factories:
         return release
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_release_file(release_id, file=None, name=None, dist_id=None):
         if file is None:
             file = Factories.create_file(
@@ -460,6 +473,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_artifact_bundle(org, release, project=None, extra_files=None):
         import zipfile
 
@@ -483,6 +497,7 @@ class Factories:
         return bundle.getvalue()
 
     @classmethod
+    @exempt_from_silo_limits()
     def create_release_archive(cls, org, release: str, project=None, dist=None):
         bundle = cls.create_artifact_bundle(org, release, project)
         file_ = File.objects.create(name="release-artifacts.zip")
@@ -491,6 +506,7 @@ class Factories:
         return update_artifact_index(release, dist, file_)
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_code_mapping(project, repo=None, organization_integration=None, **kwargs):
         kwargs.setdefault("stack_root", "")
         kwargs.setdefault("source_root", "")
@@ -506,6 +522,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_repo(project, name=None, provider=None, integration_id=None, url=None):
         repo = Repository.objects.create(
             organization_id=project.organization_id,
@@ -518,6 +535,7 @@ class Factories:
         return repo
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_commit(
         repo, project=None, author=None, release=None, message=None, key=None, date_added=None
     ):
@@ -550,6 +568,7 @@ class Factories:
         return commit
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_commit_author(organization_id=None, project=None, user=None):
         return CommitAuthor.objects.get_or_create(
             organization_id=organization_id or project.organization_id,
@@ -558,12 +577,14 @@ class Factories:
         )[0]
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_commit_file_change(commit, filename):
         return CommitFileChange.objects.get_or_create(
             organization_id=commit.organization_id, commit=commit, filename=filename, type="M"
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_user(email=None, **kwargs):
         if email is None:
             email = uuid4().hex + "@example.com"
@@ -584,6 +605,7 @@ class Factories:
         return user
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_useremail(user, email, **kwargs):
         if not email:
             email = uuid4().hex + "@example.com"
@@ -611,6 +633,7 @@ class Factories:
         return event
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_group(project, **kwargs):
         kwargs.setdefault("message", "Hello world")
         kwargs.setdefault("data", {})
@@ -621,10 +644,12 @@ class Factories:
         return Group.objects.create(project=project, **kwargs)
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_file(**kwargs):
         return File.objects.create(**kwargs)
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_file_from_path(path, name=None, **kwargs):
         if name is None:
             name = os.path.basename(path)
@@ -635,6 +660,7 @@ class Factories:
         return file
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_event_attachment(event, file=None, **kwargs):
         if file is None:
             file = Factories.create_file(
@@ -653,6 +679,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_dif_file(
         project,
         debug_id=None,
@@ -696,6 +723,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_dif_from_path(path, object_name=None, **kwargs):
         if object_name is None:
             object_name = os.path.basename(path)
@@ -709,6 +737,7 @@ class Factories:
         UserPermission.objects.create(user=user, permission=permission)
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_sentry_app(**kwargs):
         app = sentry_apps.Creator.run(is_internal=False, **Factories._sentry_app_kwargs(**kwargs))
 
@@ -718,12 +747,14 @@ class Factories:
         return app
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_internal_integration(**kwargs):
         return sentry_apps.InternalCreator.run(
             is_internal=True, **Factories._sentry_app_kwargs(**kwargs)
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_internal_integration_token(install, **kwargs):
         return sentry_app_installation_tokens.Creator.run(sentry_app_installation=install, **kwargs)
 
@@ -745,6 +776,7 @@ class Factories:
         return _kwargs
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_sentry_app_installation(
         organization=None, slug=None, user=None, status=None, prevent_token_exchange=False
     ):
@@ -773,10 +805,12 @@ class Factories:
         return install
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_stacktrace_link_schema():
         return {"type": "stacktrace-link", "uri": "/redirect/"}
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_issue_link_schema():
         return {
             "type": "issue-link",
@@ -815,6 +849,7 @@ class Factories:
         }
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_alert_rule_action_schema():
         return {
             "type": "alert-rule-action",
@@ -844,6 +879,7 @@ class Factories:
         }
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_service_hook(actor=None, org=None, project=None, events=None, url=None, **kwargs):
         if not actor:
             actor = Factories.create_user()
@@ -869,6 +905,7 @@ class Factories:
         return service_hooks.Creator.run(**_kwargs)
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_sentry_app_feature(feature=None, sentry_app=None, description=None):
         if not sentry_app:
             sentry_app = Factories.create_sentry_app()
@@ -900,6 +937,7 @@ class Factories:
         return _kwargs
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_doc_integration(features=None, has_avatar: bool = False, **kwargs) -> DocIntegration:
         doc = DocIntegration.objects.create(**Factories._doc_integration_kwargs(**kwargs))
         if features:
@@ -909,6 +947,7 @@ class Factories:
         return doc
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_doc_integration_features(
         features=None, doc_integration=None
     ) -> List[IntegrationFeature]:
@@ -928,6 +967,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_doc_integration_avatar(doc_integration=None, **kwargs) -> DocIntegrationAvatar:
         if not doc_integration:
             doc_integration = Factories.create_doc_integration()
@@ -938,6 +978,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_userreport(group, project=None, event_id=None, **kwargs):
         return UserReport.objects.create(
             group_id=group.id,
@@ -950,6 +991,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_session():
         engine = import_module(settings.SESSION_ENGINE)
 
@@ -958,6 +1000,7 @@ class Factories:
         return session
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_platform_external_issue(
         group=None, service_type=None, display_name=None, web_url=None
     ):
@@ -970,6 +1013,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_integration_external_issue(group=None, integration=None, key=None):
         external_issue = ExternalIssue.objects.create(
             organization_id=group.organization.id, integration_id=integration.id, key=key
@@ -986,6 +1030,7 @@ class Factories:
         return external_issue
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_incident(
         organization,
         projects,
@@ -1025,12 +1070,14 @@ class Factories:
         return incident
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_incident_activity(incident, type, comment=None, user=None):
         return IncidentActivity.objects.create(
             incident=incident, type=type, comment=comment, user=user
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_alert_rule(
         organization,
         projects,
@@ -1085,6 +1132,7 @@ class Factories:
         return alert_rule
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_alert_rule_trigger(alert_rule, label=None, alert_threshold=100):
         if not label:
             label = petname.Generate(2, " ", letters=10).title()
@@ -1092,6 +1140,7 @@ class Factories:
         return create_alert_rule_trigger(alert_rule, label, alert_threshold)
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_incident_trigger(incident, alert_rule_trigger, status=None):
         if status is None:
             status = TriggerStatus.ACTIVE.value
@@ -1101,6 +1150,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_alert_rule_trigger_action(
         trigger,
         type=AlertRuleTriggerAction.Type.EMAIL,
@@ -1121,6 +1171,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_external_user(user: User, **kwargs: Any) -> ExternalActor:
         kwargs.setdefault("provider", ExternalProviders.GITHUB.value)
         kwargs.setdefault("external_name", "")
@@ -1128,6 +1179,7 @@ class Factories:
         return ExternalActor.objects.create(actor=user.actor, **kwargs)
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_external_team(team: Team, **kwargs: Any) -> ExternalActor:
         kwargs.setdefault("provider", ExternalProviders.GITHUB.value)
         kwargs.setdefault("external_name", "@getsentry/ecosystem")
@@ -1135,6 +1187,7 @@ class Factories:
         return ExternalActor.objects.create(actor=team.actor, **kwargs)
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_codeowners(project, code_mapping, **kwargs):
         kwargs.setdefault("raw", "")
 
@@ -1143,6 +1196,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_slack_integration(
         organization: Organization, external_id: str, **kwargs: Any
     ) -> Integration:
@@ -1159,6 +1213,7 @@ class Factories:
         return integration
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_integration(
         organization: Organization, external_id: str, **kwargs: Any
     ) -> Integration:
@@ -1168,6 +1223,7 @@ class Factories:
         return integration
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_identity_provider(integration: Integration, **kwargs: Any) -> IdentityProvider:
         return IdentityProvider.objects.create(
             type=integration.provider,
@@ -1176,6 +1232,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_identity(
         user: User, identity_provider: IdentityProvider, external_id: str, **kwargs: Any
     ) -> Identity:
@@ -1188,6 +1245,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_group_history(
         group: Group,
         status: int,
@@ -1216,6 +1274,7 @@ class Factories:
         )
 
     @staticmethod
+    @exempt_from_silo_limits()
     def create_comment(issue, project, user, text="hello world"):
         data = {"text": text}
         return Activity.objects.create(

+ 107 - 0
src/sentry/testutils/silo.py

@@ -0,0 +1,107 @@
+from __future__ import annotations
+
+import functools
+from contextlib import contextmanager
+from typing import Any, Callable, Generator, Iterable, Tuple
+from unittest import TestCase
+
+import pytest
+from django.conf import settings
+from django.test import override_settings
+
+from sentry.silo import SiloMode
+
+TestMethod = Callable[..., None]
+
+
+class SiloModeTest:
+    """Decorate a test case that is expected to work in a given silo mode.
+
+    By default, the test is executed if the environment is in that silo mode or
+    in monolith mode. The test is skipped in an incompatible mode.
+
+    If the SILO_MODE_SPLICE_TESTS environment flag is set, any decorated test
+    class will be modified by having new test methods inserted. These new
+    methods run in the given modes and have generated names (such as
+    "test_response__in_customer_silo"). This can be used in a dev environment to
+    test in multiple modes conveniently during a single test run. Individually
+    decorated methods and stand-alone functions are treated as normal.
+    """
+
+    def __init__(self, *silo_modes: SiloMode) -> None:
+        self.silo_modes = frozenset(silo_modes)
+        self.splice = bool(settings.SILO_MODE_SPLICE_TESTS)
+
+    @staticmethod
+    def _find_all_test_methods(test_class: type) -> Iterable[Tuple[str, TestMethod]]:
+        for attr_name in dir(test_class):
+            if attr_name.startswith("test_"):
+                attr = getattr(test_class, attr_name)
+                if callable(attr):
+                    yield attr_name, attr
+
+    def _create_mode_methods_to_splice(
+        self, test_method: TestMethod
+    ) -> Iterable[Tuple[str, TestMethod]]:
+        for mode in self.silo_modes:
+
+            def replacement_test_method(*args: Any, **kwargs: Any) -> None:
+                with override_settings(SILO_MODE=mode):
+                    test_method(*args, **kwargs)
+
+            functools.update_wrapper(replacement_test_method, test_method)
+            modified_name = f"{test_method.__name__}__in_{str(mode).lower()}_silo"
+            replacement_test_method.__name__ = modified_name
+            yield modified_name, replacement_test_method
+
+    def _splice_mode_methods(self, test_class: type) -> type:
+        for (method_name, test_method) in self._find_all_test_methods(test_class):
+            for (new_name, new_method) in self._create_mode_methods_to_splice(test_method):
+                setattr(test_class, new_name, new_method)
+        return test_class
+
+    def __call__(self, decorated_obj: Any) -> Any:
+        is_test_case_class = isinstance(decorated_obj, type) and issubclass(decorated_obj, TestCase)
+        is_function = callable(decorated_obj)
+        if not (is_test_case_class or is_function):
+            raise ValueError("@SiloModeTest must decorate a function or TestCase class")
+
+        if self.splice and is_test_case_class:
+            return self._splice_mode_methods(decorated_obj)
+
+        current_silo_mode = SiloMode.get_current_mode()
+        is_skipped = (
+            current_silo_mode != SiloMode.MONOLITH and current_silo_mode not in self.silo_modes
+        )
+        reason = f"Test case is not part of {current_silo_mode} mode"
+        return pytest.mark.skipif(is_skipped, reason=reason)(decorated_obj)
+
+
+control_silo_test = SiloModeTest(SiloMode.CONTROL)
+customer_silo_test = SiloModeTest(SiloMode.CUSTOMER)
+
+
+@contextmanager
+def exempt_from_silo_limits() -> Generator[None, None, None]:
+    """Exempt test setup code from silo mode checks.
+
+    This can be used to decorate functions that are used exclusively in setting
+    up test cases, so that those functions don't produce false exceptions from
+    writing to tables that wouldn't be allowed in a certain SiloModeTest case.
+
+    It can also be used as a context manager to enclose setup code within a test
+    method. Such setup code would ideally be moved to the test class's `setUp`
+    method or a helper function where possible, but this is available as a
+    kludge when that's too inconvenient. For example:
+
+    ```
+    @SiloModeTest(SiloMode.CUSTOMER)
+    class MyTest(TestCase):
+        def test_something(self):
+            with exempt_from_mode_limits():
+                org = self.create_organization()  # would be wrong if under test
+            do_something(org)  # the actual code under test
+    ```
+    """
+    with override_settings(SILO_MODE=SiloMode.MONOLITH):
+        yield

Some files were not shown because too many files changed in this diff