Zachary Collins 1 year ago
parent
commit
a40203f4d1
3 changed files with 69 additions and 35 deletions
  1. 12 2
      src/sentry/models/outbox.py
  2. 34 32
      src/sentry/tasks/check_auth.py
  3. 23 1
      tests/sentry/models/test_outbox.py

+ 12 - 2
src/sentry/models/outbox.py

@@ -25,7 +25,7 @@ from typing import (
 import mmh3
 import sentry_sdk
 from django import db
-from django.db import connections, models, router, transaction
+from django.db import OperationalError, connections, models, router, transaction
 from django.db.models import Max, Min
 from django.db.transaction import Atomic
 from django.dispatch import Signal
@@ -286,7 +286,6 @@ class OutboxScope(IntEnum):
             OutboxCategory.ORGANIZATION_MAPPING_CUSTOMER_ID_UPDATE,
             OutboxCategory.TEAM_UPDATE,
             OutboxCategory.AUTH_PROVIDER_UPDATE,
-            OutboxCategory.AUTH_IDENTITY_UPDATE,
             OutboxCategory.ORGANIZATION_MEMBER_TEAM_UPDATE,
             OutboxCategory.API_KEY_UPDATE,
             OutboxCategory.ORGANIZATION_SLUG_RESERVATION_UPDATE,
@@ -299,6 +298,7 @@ class OutboxScope(IntEnum):
             OutboxCategory.UNUSED_ONE,
             OutboxCategory.UNUSED_TWO,
             OutboxCategory.UNUSUED_THREE,
+            OutboxCategory.AUTH_IDENTITY_UPDATE,
         },
     )
     WEBHOOK_SCOPE = scope_categories(2, {OutboxCategory.WEBHOOK_PROXY})
@@ -539,6 +539,16 @@ class OutboxBase(Model):
                 yield next_shard_row
             else:
                 yield None
+        except OperationalError as e:
+            next_shard_row = self.selected_messages_in_shard(
+                latest_shard_row=latest_shard_row
+            ).first()
+            if next_shard_row is None:
+                yield None
+            else:
+                raise OutboxFlushError(
+                    f"Could not flush shard category={self.category}", self
+                ) from e
         finally:
             try:
                 with connections[using].cursor() as cursor:

+ 34 - 32
src/sentry/tasks/check_auth.py

@@ -7,6 +7,7 @@ from django.utils import timezone
 from sentry.auth.exceptions import IdentityNotValid
 from sentry.models.authidentity import AuthIdentity
 from sentry.models.organizationmembermapping import OrganizationMemberMapping
+from sentry.models.outbox import outbox_context
 from sentry.services.hybrid_cloud.organization import RpcOrganizationMember, organization_service
 from sentry.silo import unguarded_write
 from sentry.silo.base import SiloMode
@@ -70,40 +71,41 @@ def check_auth_identity(auth_identity_id, **kwargs):
     prev_is_valid = not getattr(om.flags, "sso:invalid")
 
     provider = auth_provider.get_provider()
-    try:
-        provider.refresh_identity(auth_identity)
-    except IdentityNotValid as exc:
-        if prev_is_valid:
-            logger.warning(
-                "AuthIdentity(id=%s) notified as not valid: %s",
+    with outbox_context(flush=False):
+        try:
+            provider.refresh_identity(auth_identity)
+        except IdentityNotValid as exc:
+            if prev_is_valid:
+                logger.warning(
+                    "AuthIdentity(id=%s) notified as not valid: %s",
+                    auth_identity_id,
+                    str(exc),
+                    exc_info=True,
+                )
+                metrics.incr("auth.identities.invalidated", skip_internal=False)
+            is_linked = False
+            is_valid = False
+        except Exception as exc:
+            # to ensure security we count any kind of error as an invalidation
+            # event
+            metrics.incr("auth.identities.refresh_error", skip_internal=False)
+            logger.exception(
+                "AuthIdentity(id=%s) returned an error during validation: %s",
                 auth_identity_id,
                 str(exc),
-                exc_info=True,
             )
-            metrics.incr("auth.identities.invalidated", skip_internal=False)
-        is_linked = False
-        is_valid = False
-    except Exception as exc:
-        # to ensure security we count any kind of error as an invalidation
-        # event
-        metrics.incr("auth.identities.refresh_error", skip_internal=False)
-        logger.exception(
-            "AuthIdentity(id=%s) returned an error during validation: %s",
-            auth_identity_id,
-            str(exc),
-        )
-        is_linked = True
-        is_valid = False
-    else:
-        is_linked = True
-        is_valid = True
+            is_linked = True
+            is_valid = False
+        else:
+            is_linked = True
+            is_valid = True
 
-    if getattr(om.flags, "sso:linked") != is_linked:
-        with unguarded_write(using=router.db_for_write(OrganizationMemberMapping)):
-            # flags are not replicated, so it's ok not to create outboxes here.
-            setattr(om.flags, "sso:linked", is_linked)
-            setattr(om.flags, "sso:invalid", not is_valid)
-            organization_service.update_membership_flags(organization_member=om)
+        if getattr(om.flags, "sso:linked") != is_linked:
+            with unguarded_write(using=router.db_for_write(OrganizationMemberMapping)):
+                # flags are not replicated, so it's ok not to create outboxes here.
+                setattr(om.flags, "sso:linked", is_linked)
+                setattr(om.flags, "sso:invalid", not is_valid)
+                organization_service.update_membership_flags(organization_member=om)
 
-    now = timezone.now()
-    auth_identity.update(last_verified=now, last_synced=now)
+        now = timezone.now()
+        auth_identity.update(last_verified=now, last_synced=now)

+ 23 - 1
tests/sentry/models/test_outbox.py

@@ -8,7 +8,8 @@ from unittest.mock import call, patch
 import pytest
 import responses
 from django.conf import settings
-from django.db import connections
+from django.db import connections, router
+from django.db.transaction import get_connection
 from django.test import RequestFactory
 from pytest import raises
 from rest_framework import status
@@ -390,6 +391,27 @@ class OutboxDrainTest(TransactionTestCase):
 
         assert mock_process_region_outbox.call_count == 2
 
+    def test_holding_lock_too_long(self):
+        outbox = OrganizationMember(id=1, organization_id=3, user_id=1).outbox_for_update()
+        with outbox_context(flush=False):
+            outbox.save()
+
+        def test_inside_locked():
+            nonlocal outbox
+            conn = get_connection(router.db_for_write(RegionOutbox))
+            with conn.cursor() as cursor:
+                cursor.execute("SET lock_timeout = '1s'")
+            with outbox.process_shard(RegionOutbox(id=0)) as shard_outbox:
+                assert shard_outbox is None
+            with pytest.raises(OutboxFlushError):
+                with outbox.process_shard(RegionOutbox(id=outbox.id + 1)) as outbox:
+                    pass
+
+        thread = threading.Thread(target=wrap_with_connection_closure(test_inside_locked))
+        with outbox.process_shard(None):
+            thread.start()
+            thread.join()
+
 
 @region_silo_test(stable=True)
 class RegionOutboxTest(TestCase):