Browse Source

Revert "chore(hybridcloud) Remove role switching and update guard method (#52673)"

This reverts commit 07c498ab53b528ffc364bd1fa682e67207facf37.

Co-authored-by: markstory <24086+markstory@users.noreply.github.com>
getsentry-bot 1 year ago
parent
commit
ace3ab5861

+ 10 - 0
src/sentry/db/postgres/roles.py

@@ -3,11 +3,15 @@ from __future__ import annotations
 import contextlib
 import os
 import sys
+from collections import defaultdict
+from typing import MutableMapping
 
 from django.db.transaction import get_connection
 
 from sentry.silo.patches.silo_aware_transaction_patch import determine_using_by_silo_mode
 
+_fencing_counters: MutableMapping[str, int] = defaultdict(int)
+
 
 @contextlib.contextmanager
 def in_test_psql_role_override(role_name: str, using: str | None = None):
@@ -22,7 +26,12 @@ def in_test_psql_role_override(role_name: str, using: str | None = None):
 
     using = determine_using_by_silo_mode(using)
 
+    # TODO(mark) Move this closer to other silo code.
+    _fencing_counters[using] += 1
+
     with get_connection(using).cursor() as conn:
+        fence_value = _fencing_counters[using]
+        conn.execute("SELECT %s", [f"start_role_override_{fence_value}"])
         conn.execute("SELECT user")
         (cur,) = conn.fetchone()
         conn.execute("SET ROLE %s", [role_name])
@@ -30,3 +39,4 @@ def in_test_psql_role_override(role_name: str, using: str | None = None):
             yield
         finally:
             conn.execute("SET ROLE %s", [cur])
+            conn.execute("SELECT %s", [f"end_role_override_{fence_value}"])

+ 2 - 3
src/sentry/models/counter.py

@@ -13,6 +13,7 @@ from sentry.db.models import (
     region_silo_only_model,
     sane_repr,
 )
+from sentry.db.postgres.roles import in_test_psql_role_override
 from sentry.silo import SiloMode
 
 
@@ -93,8 +94,6 @@ def increment_project_counter(project, delta=1, using="default"):
 # this must be idempotent because it seems to execute twice
 # (at least during test runs)
 def create_counter_function(app_config, using, **kwargs):
-    from sentry.testutils.silo import unguarded_write
-
     if app_config and app_config.name != "sentry":
         return
 
@@ -104,7 +103,7 @@ def create_counter_function(app_config, using, **kwargs):
     if SiloMode.get_current_mode() == SiloMode.CONTROL:
         return
 
-    with unguarded_write(using), connections[using].cursor() as cursor:
+    with in_test_psql_role_override("postgres", using), connections[using].cursor() as cursor:
         cursor.execute(
             """
             create or replace function sentry_increment_project_counter(

+ 2 - 2
src/sentry/models/outbox.py

@@ -26,6 +26,7 @@ from sentry.db.models import (
     region_silo_only_model,
     sane_repr,
 )
+from sentry.db.postgres.roles import in_test_psql_role_override
 from sentry.db.postgres.transactions import (
     django_test_transaction_water_mark,
     in_test_assert_no_transaction,
@@ -467,7 +468,6 @@ _outbox_context = OutboxContext()
 
 @contextlib.contextmanager
 def outbox_context(inner: Atomic | None = None, flush: bool | None = None) -> ContextManager[None]:
-    from sentry.testutils.silo import unguarded_write
 
     # If we don't specify our flush, use the outer specified override
     if flush is None:
@@ -481,7 +481,7 @@ def outbox_context(inner: Atomic | None = None, flush: bool | None = None) -> Co
     original = _outbox_context.flushing_enabled
 
     if inner:
-        with unguarded_write(using=inner.using), inner:
+        with in_test_psql_role_override("postgres", using=inner.using), inner:
             _outbox_context.flushing_enabled = flush
             try:
                 yield

+ 3 - 3
src/sentry/services/hybrid_cloud/organization/impl.py

@@ -7,6 +7,7 @@ from django.dispatch import Signal
 
 from sentry import roles
 from sentry.api.serializers import serialize
+from sentry.db.postgres.roles import in_test_psql_role_override
 from sentry.models import (
     Activity,
     ControlOutbox,
@@ -47,7 +48,6 @@ from sentry.services.hybrid_cloud.organization.serial import (
 )
 from sentry.services.hybrid_cloud.user import RpcUser
 from sentry.services.hybrid_cloud.util import flags_to_bits
-from sentry.testutils.silo import unguarded_write
 from sentry.types.region import find_regions_for_orgs
 
 
@@ -470,7 +470,7 @@ class DatabaseBackedOrganizationService(OrganizationService):
                     pass
 
     def reset_idp_flags(self, *, organization_id: int) -> None:
-        with unguarded_write():
+        with in_test_psql_role_override("postgres"):
             # Flags are not replicated -- these updates are safe without outbox application.
             OrganizationMember.objects.filter(
                 organization_id=organization_id,
@@ -485,7 +485,7 @@ class DatabaseBackedOrganizationService(OrganizationService):
         # Normally, calling update on a QS for organization member fails because we need to ensure that updates to
         # OrganizationMember objects produces outboxes.  In this case, it is safe to do the update directly because
         # the attribute we are changing never needs to produce an outbox.
-        with unguarded_write():
+        with in_test_psql_role_override("postgres"):
             OrganizationMember.objects.filter(user_id=user.id).update(
                 user_is_active=user.is_active, user_email=user.email
             )

+ 3 - 3
src/sentry/services/hybrid_cloud/organization_mapping/impl.py

@@ -1,5 +1,6 @@
 from typing import List, Optional
 
+from sentry.db.postgres.roles import in_test_psql_role_override
 from sentry.models.organizationmapping import OrganizationMapping
 from sentry.services.hybrid_cloud.organization_mapping import (
     OrganizationMappingService,
@@ -7,7 +8,6 @@ from sentry.services.hybrid_cloud.organization_mapping import (
     RpcOrganizationMappingUpdate,
 )
 from sentry.services.hybrid_cloud.organization_mapping.serial import serialize_organization_mapping
-from sentry.testutils.silo import unguarded_write
 
 
 class DatabaseBackedOrganizationMappingService(OrganizationMappingService):
@@ -61,7 +61,7 @@ class DatabaseBackedOrganizationMappingService(OrganizationMappingService):
 
     def update(self, organization_id: int, update: RpcOrganizationMappingUpdate) -> None:
         # TODO: REMOVE FROM GETSENTRY!
-        with unguarded_write():
+        with in_test_psql_role_override("postgres"):
             try:
                 OrganizationMapping.objects.get(organization_id=organization_id).update(**update)
             except OrganizationMapping.DoesNotExist:
@@ -70,7 +70,7 @@ class DatabaseBackedOrganizationMappingService(OrganizationMappingService):
     def upsert(
         self, organization_id: int, update: RpcOrganizationMappingUpdate
     ) -> RpcOrganizationMapping:
-        with unguarded_write():
+        with in_test_psql_role_override("postgres"):
             org_mapping, _created = OrganizationMapping.objects.update_or_create(
                 organization_id=organization_id, defaults=update
             )

+ 2 - 2
src/sentry/services/hybrid_cloud/organizationmember_mapping/impl.py

@@ -7,6 +7,7 @@ from typing import Optional
 
 from django.db import IntegrityError, transaction
 
+from sentry.db.postgres.roles import in_test_psql_role_override
 from sentry.models import outbox_context
 from sentry.models.organizationmembermapping import OrganizationMemberMapping
 from sentry.models.user import User
@@ -18,7 +19,6 @@ from sentry.services.hybrid_cloud.organizationmember_mapping import (
 from sentry.services.hybrid_cloud.organizationmember_mapping.serial import (
     serialize_org_member_mapping,
 )
-from sentry.testutils.silo import unguarded_write
 
 
 class DatabaseBackedOrganizationMemberMappingService(OrganizationMemberMappingService):
@@ -101,5 +101,5 @@ class DatabaseBackedOrganizationMemberMappingService(OrganizationMemberMappingSe
             organizationmember_id=organizationmember_id,
         )
         if org_member_map:
-            with unguarded_write():
+            with in_test_psql_role_override("postgres"):
                 org_member_map.delete()

+ 2 - 2
src/sentry/tasks/check_auth.py

@@ -4,11 +4,11 @@ from datetime import timedelta
 from django.utils import timezone
 
 from sentry.auth.exceptions import IdentityNotValid
+from sentry.db.postgres.roles import in_test_psql_role_override
 from sentry.models import AuthIdentity
 from sentry.services.hybrid_cloud.organization import RpcOrganizationMember, organization_service
 from sentry.silo.base import SiloMode
 from sentry.tasks.base import instrumented_task
-from sentry.testutils.silo import unguarded_write
 from sentry.utils import metrics
 
 logger = logging.getLogger("sentry.auth")
@@ -95,7 +95,7 @@ def check_auth_identity(auth_identity_id, **kwargs):
         is_valid = True
 
     if getattr(om.flags, "sso:linked") != is_linked:
-        with unguarded_write():
+        with in_test_psql_role_override("postgres"):
             # flags are not replicated, so it's ok not to create outboxes here.
             setattr(om.flags, "sso:linked", is_linked)
             setattr(om.flags, "sso:invalid", not is_valid)

+ 3 - 3
src/sentry/tasks/organization_mapping.py

@@ -6,11 +6,11 @@ from django.db import router
 from django.db.models import Count
 from django.utils import timezone
 
+from sentry.db.postgres.roles import in_test_psql_role_override
 from sentry.models.organizationmapping import OrganizationMapping
 from sentry.services.hybrid_cloud.organization import organization_service
 from sentry.silo.base import SiloMode
 from sentry.tasks.base import instrumented_task, retry
-from sentry.testutils.silo import unguarded_write
 from sentry.utils import metrics
 from sentry.utils.query import RangeQuerySetWrapper
 
@@ -42,7 +42,7 @@ def _verify_mappings(expiration_threshold_time: datetime) -> None:
         org = organization_service.get_organization_by_id(
             id=mapping.organization_id, slug=mapping.slug
         )
-        with unguarded_write(using=router.db_for_write(OrganizationMapping)):
+        with in_test_psql_role_override("postgres", using=router.db_for_write(OrganizationMapping)):
             if org is None and mapping.date_created <= expiration_threshold_time:
                 mapping.delete()
             elif org is not None:
@@ -65,7 +65,7 @@ def _remove_duplicate_mappings(expiration_threshold_time: datetime) -> None:
         )
         organization_id = dupe["organization_id"]
 
-        with unguarded_write(using=router.db_for_write(OrganizationMapping)):
+        with in_test_psql_role_override("postgres", using=router.db_for_write(OrganizationMapping)):
             # Organization exists in the region silo
             found_org = organization_service.get_organization_by_id(id=organization_id)
             if found_org is None:

+ 2 - 2
src/sentry/testutils/helpers/api_gateway.py

@@ -10,9 +10,9 @@ from rest_framework.response import Response
 
 from sentry.api.base import control_silo_endpoint, region_silo_endpoint
 from sentry.api.bases.organization import OrganizationEndpoint
+from sentry.db.postgres.roles import in_test_psql_role_override
 from sentry.models.organizationmapping import OrganizationMapping
 from sentry.testutils import APITestCase
-from sentry.testutils.silo import unguarded_write
 from sentry.types.region import Region, RegionCategory, clear_global_regions
 from sentry.utils import json
 
@@ -136,7 +136,7 @@ class ApiGatewayTestCase(APITestCase):
             adding_headers={"test": "header"},
         )
 
-        with unguarded_write(using=router.db_for_write(OrganizationMapping)):
+        with in_test_psql_role_override("postgres", using=router.db_for_write(OrganizationMapping)):
             OrganizationMapping.objects.get(organization_id=self.organization.id).update(
                 region_name="region1"
             )

+ 15 - 42
src/sentry/testutils/silo.py

@@ -7,11 +7,10 @@ import re
 import sys
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Iterable, List, MutableMapping, MutableSet, Set, Tuple, Type
+from typing import Any, Callable, Dict, Iterable, MutableMapping, MutableSet, Set, Tuple, Type
 from unittest import TestCase
 
 import pytest
-from django.apps import apps
 from django.conf import settings
 from django.db import connections, router
 from django.db.models import Model
@@ -199,7 +198,6 @@ def assume_test_silo_mode(desired_silo: SiloMode) -> Any:
 
 
 def reset_test_role(role: str, using: str, create_role: bool) -> None:
-    # Deprecated, will remove once getsentry is updated.
     connection_names = [conn.alias for conn in connections.all()]
 
     if create_role:
@@ -234,7 +232,6 @@ _role_privileges_created: MutableMapping[str, bool] = {}
 
 
 def create_model_role_guards(app_config: Any, using: str, **kwargs: Any):
-    # Deprecated, will remove once getsentry is updated.
     global _role_created
     if "pytest" not in sys.argv[0] or not settings.USE_ROLE_SWAPPING_IN_TESTS:
         return
@@ -299,7 +296,6 @@ post_migrate.connect(create_model_role_guards, dispatch_uid="create_model_role_g
 
 
 def restrict_role(role: str, model: Any, revocation_type: str, using: str = "default") -> None:
-    # Deprecated, will remove once getsentry is updated.
     if router.db_for_write(model) != using:
         return
 
@@ -341,43 +337,18 @@ def protected_table(table: str, operation: str) -> re.Pattern:
     return re.compile(f'{operation}[^"]+"{table}"', re.IGNORECASE)
 
 
-_protected_operations: List[re.Pattern] = []
+protected_operations = (
+    protected_table("sentry_organizationmember", "insert"),
+    protected_table("sentry_organizationmember", "update"),
+    protected_table("sentry_organizationmember", "delete"),
+    protected_table("sentry_organization", "insert"),
+    protected_table("sentry_organization", "update"),
+    protected_table("sentry_organizationmapping", "insert"),
+    protected_table("sentry_organizationmapping", "update"),
+    protected_table("sentry_organizationmembermapping", "insert"),
+)
 
-
-def get_protected_operations() -> List[re.Pattern]:
-    if len(_protected_operations):
-        return _protected_operations
-
-    # Protect Foreign Keys using hybrid cloud models from being deleted without using the
-    # privileged user. Deletion should only occur when the developer is actively aware
-    # of the need to generate outboxes.
-    seen_models: MutableSet[type] = set()
-    for app_config in apps.get_app_configs():
-        for model in iter_models(app_config.name):
-            for field in model._meta.fields:
-                if not isinstance(field, HybridCloudForeignKey):
-                    continue
-                fk_model = field.foreign_model
-                if fk_model is None or fk_model in seen_models:
-                    continue
-                seen_models.add(fk_model)
-                _protected_operations.append(protected_table(fk_model._meta.db_table, "delete"))
-
-    # Protect inserts/updates that require outbox messages.
-    _protected_operations.extend(
-        [
-            protected_table("sentry_organizationmember", "insert"),
-            protected_table("sentry_organizationmember", "update"),
-            protected_table("sentry_organizationmember", "delete"),
-            protected_table("sentry_organization", "insert"),
-            protected_table("sentry_organization", "update"),
-            protected_table("sentry_organizationmapping", "insert"),
-            protected_table("sentry_organizationmapping", "update"),
-            protected_table("sentry_organizationmembermapping", "insert"),
-        ]
-    )
-
-    return _protected_operations
+fence_re = re.compile(r"select\s*\'(?P<operation>start|end)_role_override", re.IGNORECASE)
 
 
 def validate_protected_queries(queries: Iterable[Dict[str, str]]) -> None:
@@ -400,7 +371,7 @@ def validate_protected_queries(queries: Iterable[Dict[str, str]]) -> None:
             else:
                 raise AssertionError("Invalid fencing operation encounted")
 
-        for protected in get_protected_operations():
+        for protected in protected_operations:
             if protected.match(sql):
                 if fence_depth == 0:
                     msg = [
@@ -425,6 +396,8 @@ def validate_protected_queries(queries: Iterable[Dict[str, str]]) -> None:
 
 
 def iter_models(app_name: str | None = None) -> Iterable[Type[Model]]:
+    from django.apps import apps
+
     for app, app_models in apps.all_models.items():
         if app == app_name or app_name is None:
             for model in app_models.values():

Some files were not shown because too many files changed in this diff