Browse Source

ref(hc): Fixes for split db tests (#53494)

Zach Collins 1 year ago
parent
commit
20d4a8f5a0

+ 74 - 80
src/sentry/auth/helper.py

@@ -35,7 +35,7 @@ from sentry.auth.provider import MigratingIdentityId, Provider
 from sentry.auth.providers.fly.provider import FlyOAuth2Provider
 from sentry.auth.superuser import is_active_superuser
 from sentry.locks import locks
-from sentry.models import AuditLogEntry, AuthIdentity, AuthProvider, User
+from sentry.models import AuditLogEntry, AuthIdentity, AuthProvider, User, outbox_context
 from sentry.pipeline import Pipeline, PipelineSessionStore
 from sentry.pipeline.provider import PipelineProvider
 from sentry.services.hybrid_cloud.organization import (
@@ -298,65 +298,65 @@ class AuthIdentityHandler:
 
     def _get_auth_identity(self, **params: Any) -> AuthIdentity | None:
         try:
-            return AuthIdentity.objects.get(auth_provider=self.auth_provider, **params)
+            return AuthIdentity.objects.get(auth_provider_id=self.auth_provider.id, **params)
         except AuthIdentity.DoesNotExist:
             return None
 
-    @transaction.atomic
     def handle_attach_identity(self, member: RpcOrganizationMember | None = None) -> AuthIdentity:
         """
         Given an already authenticated user, attach or re-attach an identity.
         """
         # prioritize identifying by the SSO provider's user ID
-        auth_identity = self._get_auth_identity(ident=self.identity["id"])
-        if auth_identity is None:
-            # otherwise look for an already attached identity
-            # this can happen if the SSO provider's internal ID changes
-            auth_identity = self._get_auth_identity(user=self.user)
-
-        if auth_identity is None:
-            auth_is_new = True
-            auth_identity = AuthIdentity.objects.create(
-                auth_provider=self.auth_provider,
-                user=self.user,
-                ident=self.identity["id"],
-                data=self.identity.get("data", {}),
-            )
-        else:
-            auth_is_new = False
-
-            # TODO(dcramer): this might leave the user with duplicate accounts,
-            # and in that kind of situation its very reasonable that we could
-            # test email addresses + is_managed to determine if we can auto
-            # merge
-            if auth_identity.user != self.user:
-                wipe = self._wipe_existing_identity(auth_identity)
+        with transaction.atomic(router.db_for_write(AuthIdentity)):
+            auth_identity = self._get_auth_identity(ident=self.identity["id"])
+            if auth_identity is None:
+                # otherwise look for an already attached identity
+                # this can happen if the SSO provider's internal ID changes
+                auth_identity = self._get_auth_identity(user_id=self.user.id)
+
+            if auth_identity is None:
+                auth_is_new = True
+                auth_identity = AuthIdentity.objects.create(
+                    auth_provider=self.auth_provider,
+                    user_id=self.user.id,
+                    ident=self.identity["id"],
+                    data=self.identity.get("data", {}),
+                )
             else:
-                wipe = None
-
-            now = timezone.now()
-            auth_identity.update(
-                user=self.user,
-                ident=self.identity["id"],
-                data=self.provider.update_identity(
-                    new_data=self.identity.get("data", {}), current_data=auth_identity.data
-                ),
-                last_verified=now,
-                last_synced=now,
-            )
+                auth_is_new = False
+
+                # TODO(dcramer): this might leave the user with duplicate accounts,
+                # and in that kind of situation its very reasonable that we could
+                # test email addresses + is_managed to determine if we can auto
+                # merge
+                if auth_identity.user_id != self.user.id:
+                    wipe = self._wipe_existing_identity(auth_identity)
+                else:
+                    wipe = None
 
-            logger.info(
-                "sso.login-pipeline.attach-existing-identity",
-                extra={
-                    "wipe_result": repr(wipe),
-                    "organization_id": self.organization.id,
-                    "user_id": self.user.id,
-                    "auth_identity_user_id": auth_identity.user.id,
-                    "auth_provider_id": self.auth_provider.id,
-                    "idp_identity_id": self.identity["id"],
-                    "idp_identity_email": self.identity.get("email"),
-                },
-            )
+                now = timezone.now()
+                auth_identity.update(
+                    user_id=self.user.id,
+                    ident=self.identity["id"],
+                    data=self.provider.update_identity(
+                        new_data=self.identity.get("data", {}), current_data=auth_identity.data
+                    ),
+                    last_verified=now,
+                    last_synced=now,
+                )
+
+                logger.info(
+                    "sso.login-pipeline.attach-existing-identity",
+                    extra={
+                        "wipe_result": repr(wipe),
+                        "organization_id": self.organization.id,
+                        "user_id": self.user.id,
+                        "auth_identity_user_id": auth_identity.user.id,
+                        "auth_provider_id": self.auth_provider.id,
+                        "idp_identity_id": self.identity["id"],
+                        "idp_identity_email": self.identity.get("email"),
+                    },
+                )
 
         if member is None:
             member = self._get_organization_member(auth_identity)
@@ -386,23 +386,15 @@ class AuthIdentityHandler:
         # so that the new identifier gets used (other we'll hit a constraint)
         # violation since one might exist for (provider, user) as well as
         # (provider, ident)
-        deletion_result = (
-            AuthIdentity.objects.exclude(id=auth_identity.id)
-            .filter(auth_provider=self.auth_provider, user=self.user)
-            .delete()
-        )
-
-        # since we've identified an identity which is no longer valid
-        # lets preemptively mark it as such
-        other_member = organization_service.check_membership_by_id(
-            user_id=auth_identity.user_id, organization_id=self.organization.id
-        )
-        if other_member is None:
-            return
+        with outbox_context(transaction.atomic(router.db_for_write(AuthIdentity))):
+            deletion_result = (
+                AuthIdentity.objects.exclude(id=auth_identity.id)
+                .filter(auth_provider=self.auth_provider, user_id=self.user.id)
+                .delete()
+            )
 
-        other_member.flags.sso__invalid = True
-        other_member.flags.sso__linked = False
-        organization_service.update_membership_flags(organization_member=other_member)
+            for outbox in self.auth_provider.outboxes_for_mark_invalid_sso(auth_identity.user_id):
+                outbox.save()
 
         return deletion_result
 
@@ -785,7 +777,6 @@ class AuthHelper(Pipeline):
             self.provider_model, self.provider, self.organization, self.request, identity
         )
 
-    @transaction.atomic
     def _finish_login_pipeline(self, identity: Mapping[str, Any]) -> HttpResponse:
         """
         The login flow executes both with anonymous and authenticated users.
@@ -843,7 +834,6 @@ class AuthHelper(Pipeline):
 
             return auth_handler.handle_existing_identity(self.state, auth_identity)
 
-    @transaction.atomic
     def _finish_setup_pipeline(self, identity: Mapping[str, Any]) -> HttpResponseRedirect:
         """
         the setup flow here is configuring SSO for an organization.
@@ -877,20 +867,24 @@ class AuthHelper(Pipeline):
 
         auth.mark_sso_complete(request, self.organization.id)
 
-        sso_enabled.send_robust(
-            organization=self.organization,
-            user=request.user,
-            provider=self.provider.key,
-            sender=self.__class__,
+        organization_service.schedule_signal(
+            sso_enabled,
+            organization_id=self.organization.id,
+            args=dict(
+                user_id=request.user.id,
+                provider=self.provider.key,
+            ),
         )
 
-        AuditLogEntry.objects.create(
-            organization_id=self.organization.id,
-            actor=request.user,
-            ip_address=request.META["REMOTE_ADDR"],
-            target_object=self.provider_model.id,
-            event=audit_log.get_event_id("SSO_ENABLE"),
-            data=self.provider_model.get_audit_log_data(),
+        log_service.record_audit_log(
+            event=AuditLogEvent(
+                organization_id=self.organization.id,
+                actor_user_id=request.user.id,
+                ip_address=request.META["REMOTE_ADDR"],
+                target_object_id=self.provider_model.id,
+                event_id=audit_log.get_event_id("SSO_ENABLE"),
+                data=self.provider_model.get_audit_log_data(),
+            )
         )
 
         email_missing_links.delay(self.organization.id, request.user.id, self.provider.key)

+ 33 - 6
src/sentry/models/authprovider.py

@@ -1,4 +1,5 @@
 import logging
+from typing import List
 
 from django.db import models
 from django.utils import timezone
@@ -13,7 +14,8 @@ from sentry.db.models import (
 )
 from sentry.db.models.fields.hybrid_cloud_foreign_key import HybridCloudForeignKey
 from sentry.db.models.fields.jsonfield import JSONField
-from sentry.services.hybrid_cloud.organization import organization_service
+from sentry.models import ControlOutbox, OutboxCategory, OutboxScope
+from sentry.types.region import find_regions_for_orgs
 
 logger = logging.getLogger("sentry.authprovider")
 
@@ -80,11 +82,11 @@ class AuthProvider(Model):
         return self.get_provider().name
 
     def get_scim_token(self):
-        from sentry.models import SentryAppInstallationToken
+        from sentry.services.hybrid_cloud.app import app_service
 
         if self.flags.scim_enabled:
-            return SentryAppInstallationToken.objects.get_token(
-                self.organization_id, f"{self.provider}_scim"
+            return app_service.get_installation_token(
+                organization_id=self.organization_id, provider=f"{self.provider}_scim"
             )
         else:
             logger.warning(
@@ -141,7 +143,19 @@ class AuthProvider(Model):
         )
         self.flags.scim_enabled = True
 
-    def disable_scim(self, user):
+    def outboxes_for_reset_idp_flags(self) -> List[ControlOutbox]:
+        return [
+            ControlOutbox(
+                shard_scope=OutboxScope.ORGANIZATION_SCOPE,
+                shard_identifier=self.organization_id,
+                category=OutboxCategory.RESET_IDP_FLAGS,
+                object_identifier=self.organization_id,
+                region_name=region_name,
+            )
+            for region_name in find_regions_for_orgs([self.organization_id])
+        ]
+
+    def disable_scim(self):
         from sentry import deletions
         from sentry.models import SentryAppInstallationForProvider
 
@@ -152,7 +166,8 @@ class AuthProvider(Model):
             # Only one SCIM installation allowed per organization. So we can reset the idp flags for the orgs
             # We run this update before the app is uninstalled to avoid ending up in a situation where there are
             # members locked out because we failed to drop the IDP flag
-            organization_service.reset_idp_flags(organization_id=self.organization_id)
+            for outbox in self.outboxes_for_reset_idp_flags():
+                outbox.save()
             sentry_app = install.sentry_app_installation.sentry_app
             assert (
                 sentry_app.is_internal
@@ -162,3 +177,15 @@ class AuthProvider(Model):
 
     def get_audit_log_data(self):
         return {"provider": self.provider, "config": self.config}
+
+    def outboxes_for_mark_invalid_sso(self, user_id: int) -> List[ControlOutbox]:
+        return [
+            ControlOutbox(
+                shard_scope=OutboxScope.ORGANIZATION_SCOPE,
+                shard_identifier=self.organization_id,
+                category=OutboxCategory.MARK_INVALID_SSO,
+                object_identifier=user_id,
+                region_name=region_name,
+            )
+            for region_name in find_regions_for_orgs([self.organization_id])
+        ]

+ 3 - 0
src/sentry/models/outbox.py

@@ -84,6 +84,9 @@ class OutboxCategory(IntEnum):
     PROVISION_ORGANIZATION = 17
     PROVISION_SUBSCRIPTION = 18
     SEND_MODEL_SIGNAL = 19
+    DISABLE_AUTH_PROVIDER = 20
+    RESET_IDP_FLAGS = 21
+    MARK_INVALID_SSO = 22
 
     @classmethod
     def as_choices(cls):

+ 3 - 3
src/sentry/receivers/features.py

@@ -378,13 +378,13 @@ def record_plugin_enabled(plugin, project, user, **kwargs):
 
 
 @sso_enabled.connect(weak=False)
-def record_sso_enabled(organization, user, provider, **kwargs):
+def record_sso_enabled(organization_id, user_id, provider, **kwargs):
     FeatureAdoption.objects.record(
-        organization_id=organization.id, feature_slug="sso", complete=True
+        organization_id=organization_id, feature_slug="sso", complete=True
     )
 
     analytics.record(
-        "sso.enabled", user_id=user.id, organization_id=organization.id, provider=provider
+        "sso.enabled", user_id=user_id, organization_id=organization_id, provider=provider
     )
 
 

+ 21 - 0
src/sentry/receivers/outbox/control.py

@@ -120,3 +120,24 @@ def process_send_signal(
             args=payload["args"],
             signal=RpcOrganizationSignal(payload["signal"]),
         )
+
+
+@receiver(process_control_outbox, sender=OutboxCategory.RESET_IDP_FLAGS)
+def process_reset_idp_flags(shard_identifier: int, **kwds: Any):
+    organization_service.reset_idp_flags(organization_id=shard_identifier)
+
+
+@receiver(process_control_outbox, sender=OutboxCategory.MARK_INVALID_SSO)
+def process_mark_invalid_sso(object_identifier: int, shard_identifier: int, **kwds: Any):
+    # since we've identified an identity which is no longer valid
+    # lets preemptively mark it as such
+    other_member = organization_service.check_membership_by_id(
+        user_id=object_identifier,
+        organization_id=shard_identifier,
+    )
+    if other_member is None:
+        return
+
+    other_member.flags.sso__invalid = True
+    other_member.flags.sso__linked = False
+    organization_service.update_membership_flags(organization_member=other_member)

+ 6 - 0
src/sentry/receivers/outbox/region.py

@@ -20,6 +20,7 @@ from sentry.models import (
 )
 from sentry.models.team import Team
 from sentry.receivers.outbox import maybe_process_tombstone
+from sentry.services.hybrid_cloud.auth import auth_service
 from sentry.services.hybrid_cloud.identity import identity_service
 from sentry.services.hybrid_cloud.log import AuditLogEvent, UserIpEvent, log_rpc_service
 from sentry.services.hybrid_cloud.organization_mapping import organization_mapping_service
@@ -120,3 +121,8 @@ def process_organization_mapping_customer_id_update(
         organization_mapping_service.update(
             organization_id=org.id, update={"customer_id": payload["customer_id"]}
         )
+
+
+@receiver(process_region_outbox, sender=OutboxCategory.DISABLE_AUTH_PROVIDER)
+def process_disable_auth_provider(object_identifier: int, **kwds: Any):
+    auth_service.disable_provider(provider_id=object_identifier)

+ 9 - 1
src/sentry/services/hybrid_cloud/app/impl.py

@@ -7,7 +7,12 @@ from django.db.models import QuerySet
 from sentry.api.serializers import SentryAppAlertRuleActionSerializer, Serializer, serialize
 from sentry.constants import SentryAppInstallationStatus
 from sentry.mediators import alert_rule_actions
-from sentry.models import SentryApp, SentryAppComponent, SentryAppInstallation
+from sentry.models import (
+    SentryApp,
+    SentryAppComponent,
+    SentryAppInstallation,
+    SentryAppInstallationToken,
+)
 from sentry.models.integrations.sentry_app_installation import prepare_sentry_app_components
 from sentry.services.hybrid_cloud.app import (
     AppService,
@@ -183,6 +188,9 @@ class DatabaseBackedAppService(AppService):
 
         return serialize_sentry_app_installation(installation, sentry_app)
 
+    def get_installation_token(self, *, organization_id: int, provider: str) -> Optional[str]:
+        return SentryAppInstallationToken.objects.get_token(organization_id, provider)
+
     def trigger_sentry_app_action_creators(
         self, *, fields: List[Mapping[str, Any]], install_uuid: str | None
     ) -> RpcAlertRuleActionResult:

+ 5 - 0
src/sentry/services/hybrid_cloud/app/service.py

@@ -71,6 +71,11 @@ class AppService(RpcService):
     def get_sentry_app_by_slug(self, *, slug: str) -> Optional[RpcSentryApp]:
         pass
 
+    @rpc_method
+    @abc.abstractmethod
+    def get_installation_token(self, *, organization_id: int, provider: str) -> Optional[str]:
+        pass
+
     @rpc_method
     @abc.abstractmethod
     def find_alertable_services(self, *, organization_id: int) -> List[RpcSentryAppService]:

+ 38 - 0
src/sentry/services/hybrid_cloud/auth/impl.py

@@ -21,6 +21,7 @@ from sentry.models import (
     OrgAuthToken,
     SentryAppInstallationToken,
     User,
+    outbox_context,
 )
 from sentry.services.hybrid_cloud.auth import (
     AuthenticatedToken,
@@ -41,6 +42,7 @@ from sentry.services.hybrid_cloud.organization import (
 )
 from sentry.services.hybrid_cloud.user import RpcUser
 from sentry.services.hybrid_cloud.user.service import user_service
+from sentry.silo import unguarded_write
 from sentry.silo.base import SiloMode
 from sentry.utils.auth import AuthUserPasswordExpired
 
@@ -235,6 +237,42 @@ class DatabaseBackedAuthService(AuthService):
     def get_auth_providers(self, organization_id: int) -> List[RpcAuthProvider]:
         return list(AuthProvider.objects.filter(organization_id=organization_id))
 
+    def change_scim(
+        self, *, user_id: int, provider_id: int, enabled: bool, allow_unlinked: bool
+    ) -> None:
+        try:
+            auth_provider = AuthProvider.objects.get(id=provider_id)
+            user = User.objects.get(id=user_id)
+        except (AuthProvider.DoesNotExist, User.DoesNotExist):
+            return
+
+        with outbox_context(transaction.atomic(router.db_for_write(AuthProvider))):
+            auth_provider.flags.allow_unlinked = allow_unlinked
+            if auth_provider.flags.scim_enabled != enabled:
+                if enabled:
+                    auth_provider.enable_scim(user)
+                else:
+                    auth_provider.disable_scim()
+
+            auth_provider.save()
+
+    def disable_provider(self, *, provider_id: int) -> None:
+        with outbox_context(transaction.atomic(router.db_for_write(AuthProvider))):
+            try:
+                auth_provider: AuthProvider = AuthProvider.objects.get(id=provider_id)
+            except AuthProvider.DoesNotExist:
+                return
+
+            user_ids = OrganizationMemberMapping.objects.filter(
+                organization_id=auth_provider.organization_id
+            ).values_list("user_id", flat=True)
+            with unguarded_write(router.db_for_write(User)):
+                User.objects.filter(id__in=user_ids).update(is_managed=False)
+
+            if auth_provider.flags.scim_enabled:
+                auth_provider.disable_scim()
+            auth_provider.delete()
+
 
 class FakeRequestDict:
     d: Mapping[str, str | bytes | None]

+ 9 - 0
src/sentry/services/hybrid_cloud/auth/model.py

@@ -315,10 +315,19 @@ class RpcAuthProvider(RpcModel):
     organization_id: int = -1
     provider: str = ""
     flags: RpcAuthProviderFlags = Field(default_factory=lambda: RpcAuthProviderFlags())
+    config: Mapping[str, Any]
 
     def __hash__(self) -> int:
         return hash((self.id, self.organization_id, self.provider))
 
+    def get_audit_log_data(self):
+        return {"provider": self.provider, "config": self.config}
+
+    def get_provider(self):
+        from sentry.auth import manager
+
+        return manager.get(self.provider, **self.config)
+
 
 class RpcAuthIdentity(RpcModel):
     id: int = -1

Some files were not shown because too many files changed in this diff