Browse Source

ref(hc): Adding helpers for managing transactions (#52943)

1. Adds helper such that during a transaction, no executes occur on a
differing connection
2. Adds helper to collect transaction spans queries run on them, making
it possible to write a test assertion about the transactional spans of
methods.
Zach Collins 1 year ago
parent
commit
2be79be8b5

+ 1 - 1
src/sentry/db/postgres/transactions.py

@@ -88,6 +88,6 @@ def in_test_assert_no_transaction(msg: str):
     from sentry.testutils import hybrid_cloud
 
     for conn in connections.all():
-        assert not hybrid_cloud.simulated_transaction_watermarks.connection_above_watermark(
+        assert not hybrid_cloud.simulated_transaction_watermarks.connection_transaction_depth_above_watermark(
             connection=conn
         ), msg

+ 107 - 5
src/sentry/testutils/hybrid_cloud.py

@@ -4,7 +4,20 @@ import contextlib
 import functools
 import threading
 from types import TracebackType
-from typing import Any, Callable, Generator, List, Mapping, Optional, Sequence, Tuple, Type
+from typing import (
+    Any,
+    Callable,
+    Generator,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    TypedDict,
+)
 
 from django.db import connections, transaction
 from django.db.backends.base.base import BaseDatabaseWrapper
@@ -144,17 +157,104 @@ class SimulatedTransactionWatermarks(threading.local):
             total += 1
         return total
 
-    def connection_above_watermark(
+    def connection_transaction_depth_above_watermark(
         self, using: str | None = None, connection: BaseDatabaseWrapper | None = None
-    ) -> bool:
+    ) -> int:
         if connection is None:
             connection = transaction.get_connection(using)
-        return self.get_transaction_depth(connection) > self.state.get(connection.alias, 0)
+        return max(self.get_transaction_depth(connection) - self.state.get(connection.alias, 0), 0)
+
+    def connections_above_watermark(self) -> Set[str]:
+        result = set()
+        for connection in connections.all():
+            if self.connection_transaction_depth_above_watermark(connection=connection):
+                result.add(connection.alias)
+        return result
 
 
 simulated_transaction_watermarks = SimulatedTransactionWatermarks()
 
 
+class EnforceNoCrossTransactionWrapper:
+    alias: str
+
+    def __init__(self, alias: str):
+        self.alias = alias
+
+    def __call__(self, execute: Callable[..., Any], *params: Any) -> Any:
+        open_transactions = simulated_transaction_watermarks.connections_above_watermark()
+        # If you are hitting this, it means you have two open transactions working in differing databases at the same
+        # time.  This is problematic in general for a variety of reasons -- it will never be possible to atomically
+        # transact in both databases (one may succeed and the other fail), but more likely, it means a bug in attempting
+        # to transact with resources that may not even co-exist in production (split silo db is a good example).
+        # Ideally, restructure transactions that span different databases into separate discrete blocks.
+        # It is fine to nest transactions so long as they are operating on the same database.
+        # Alternatively, it may be possible you are hitting this due to limitations in the test environment, such as
+        # when celery tasks fire synchronously, or other work is done in a test that would normally be separated by
+        # different connections / processes.  If you believe this is the case, context the #project-hybrid-cloud channel
+        # for assistance.
+        assert (
+            len(open_transactions) < 2
+        ), f"Found mixed open transactions between dbs {open_transactions}"
+        if open_transactions:
+            assert (
+                self.alias in open_transactions
+            ), f"Transaction opened for db {open_transactions}, but command running against db {self.alias}"
+
+        return execute(*params)
+
+
+@contextlib.contextmanager
+def enforce_no_cross_transaction_interactions():
+    with contextlib.ExitStack() as stack:
+        for conn in connections.all():
+            stack.enter_context(conn.execute_wrapper(EnforceNoCrossTransactionWrapper(conn.alias)))
+        yield
+
+
+class TransactionDetails(TypedDict):
+    transaction: str | None
+    queries: List[str]
+
+
+class TransactionDetailsWrapper:
+    result: List[TransactionDetails]
+    alias: str
+
+    def __init__(self, alias: str, result: List[TransactionDetails]):
+        self.result = result
+        self.alias = alias
+
+    def __call__(self, execute: Callable[..., Any], query: str, *args: Any) -> Any:
+        release = query.startswith("RELEASE")
+        savepoint = query.startswith("SAVEPOINT")
+        depth = simulated_transaction_watermarks.connection_transaction_depth_above_watermark(
+            using=self.alias
+        )
+        active_transaction = self.alias if release or savepoint or depth else None
+        if (
+            (savepoint and depth == 0)
+            or not self.result
+            or self.result[-1]["transaction"] != active_transaction
+        ):
+            cur: TransactionDetails = {"transaction": active_transaction, "queries": []}
+            self.result.append(cur)
+        else:
+            cur = self.result[-1]
+        cur["queries"].append(query)
+        return execute(query, *args)
+
+
+@contextlib.contextmanager
+def collect_transaction_queries() -> Iterator[List[TransactionDetails]]:
+    result: List[TransactionDetails] = []
+
+    with contextlib.ExitStack() as stack:
+        for conn in connections.all():
+            stack.enter_context(conn.execute_wrapper(TransactionDetailsWrapper(conn.alias, result)))
+        yield result
+
+
 @contextlib.contextmanager
 def simulate_on_commit(request: Any):
     """
@@ -180,7 +280,9 @@ def simulate_on_commit(request: Any):
         if connection.closed_in_transaction or connection.needs_rollback:
             return
 
-        if simulated_transaction_watermarks.connection_above_watermark(connection=connection):
+        if simulated_transaction_watermarks.connection_transaction_depth_above_watermark(
+            connection=connection
+        ):
             return
 
         old_validate = connection.validate_no_atomic_block

+ 6 - 1
src/sentry/testutils/silo.py

@@ -244,6 +244,11 @@ def validate_protected_queries(queries: Iterable[Dict[str, str]]) -> None:
     fence_depth = 0
     for query in queries:
         sql = query["sql"]
+        # The real type of queries is Iterable[Dict[str, str | None]], due to some weird bugs in django which can result
+        # in None sql query dicts.  However, typing the parameter that way breaks things due to a lack of covariance in
+        # the VT TypeVar for Dict.
+        if sql is None:
+            continue  # type: ignore
         match = match_fence_query(sql)
         if match:
             operation = match.group("operation")
@@ -273,7 +278,7 @@ def validate_protected_queries(queries: Iterable[Dict[str, str]]) -> None:
                         "Full query log:",
                         "",
                     ]
-                    msg.extend([q["sql"] for q in queries])
+                    msg.extend([q["sql"] for q in queries if q["sql"]])
 
                     raise AssertionError("\n".join(msg))
 

+ 8 - 0
tests/conftest.py

@@ -151,6 +151,14 @@ def setup_simulate_on_commit(request):
         yield
 
 
+@pytest.fixture(autouse=True)
+def setup_enforce_monotonic_transactions(request):
+    from sentry.testutils.hybrid_cloud import enforce_no_cross_transaction_interactions
+
+    with enforce_no_cross_transaction_interactions():
+        yield
+
+
 @pytest.fixture(autouse=True)
 def audit_hybrid_cloud_writes_and_deletes(request):
     """

+ 91 - 7
tests/sentry/db/test_transactions.py

@@ -1,17 +1,85 @@
 import pytest
-from django.db import transaction
+from django.db import IntegrityError, router, transaction
 
+from sentry.conf.server import env
 from sentry.db.postgres.transactions import (
     django_test_transaction_water_mark,
     in_test_assert_no_transaction,
     in_test_hide_transaction_boundary,
 )
+from sentry.models import Organization, User, outbox_context
 from sentry.testutils import TestCase, TransactionTestCase
-from sentry.testutils.silo import all_silo_test
+from sentry.testutils.factories import Factories
+from sentry.testutils.hybrid_cloud import collect_transaction_queries
+from sentry.testutils.silo import no_silo_test
 from sentry.utils.pytest.fixtures import django_db_all
+from sentry.utils.snowflake import MaxSnowflakeRetryError
 
 
 class CaseMixin:
+    def test_collect_transaction_queries(self):
+        with collect_transaction_queries() as queries, outbox_context(flush=False):
+            Organization.objects.filter(name="org1").first()
+            User.objects.filter(username="user1").first()
+
+            with transaction.atomic(using=router.db_for_write(Organization)):
+                try:
+                    with transaction.atomic(using=router.db_for_write(Organization)):
+                        Organization.objects.create(name=None)
+                except (IntegrityError, MaxSnowflakeRetryError):
+                    pass
+
+            with transaction.atomic(using=router.db_for_write(Organization)):
+                Organization.objects.create(name="org3")
+
+            with transaction.atomic(using=router.db_for_write(User)):
+                User.objects.create(username="user2")
+                User.objects.create(username="user3")
+
+        if env("SENTRY_USE_SPLIT_DBS", 0):
+            assert [(s["transaction"]) for s in queries] == [None, "default", "default", "control"]
+        else:
+            assert [(s["transaction"]) for s in queries] == [None, "default", "default", "default"]
+
+    def test_bad_transaction_boundaries(self):
+        if not env("SENTRY_USE_SPLIT_DBS", 0):
+            return
+
+        Factories.create_organization()
+        Factories.create_user()
+
+        with pytest.raises(AssertionError):
+            with transaction.atomic(using=router.db_for_write(User)):
+                Factories.create_organization()
+
+    def test_safe_transaction_boundaries(self):
+        Factories.create_organization()
+        Factories.create_user()
+
+        with transaction.atomic(using=router.db_for_write(Organization)):
+            Factories.create_organization()
+
+            with django_test_transaction_water_mark():
+                Factories.create_user()
+
+            with django_test_transaction_water_mark(), transaction.atomic(
+                using=router.db_for_write(User)
+            ):
+                Factories.create_user()
+
+                with django_test_transaction_water_mark():
+                    Factories.create_organization()
+
+                Factories.create_user()
+
+                with django_test_transaction_water_mark():
+                    Factories.create_organization()
+                    Factories.create_user()
+
+            Factories.create_organization()
+            with django_test_transaction_water_mark():
+                Factories.create_user()
+
     def test_in_test_assert_no_transaction(self):
         def do_assertions():
             in_test_assert_no_transaction("Not, in transaction, should not fail")
@@ -50,23 +118,39 @@ class CaseMixin:
             do_assertions()
 
 
-@all_silo_test(stable=True)
+@no_silo_test(stable=True)
 class TestDjangoTestCaseTransactions(CaseMixin, TestCase):
     pass
 
 
-@all_silo_test(stable=True)
+@no_silo_test(stable=True)
 class TestDjangoTransactionTestCaseTransactions(CaseMixin, TransactionTestCase):
-    pass
+    def test_collect_transaction_queries(self):
+        return
 
 
 class TestPytestDjangoDbAll(CaseMixin):
-    @all_silo_test(stable=True)
+    @no_silo_test(stable=True)
     @django_db_all
     def test_in_test_assert_no_transaction(self):
         super().test_in_test_assert_no_transaction()
 
-    @all_silo_test(stable=True)
+    @no_silo_test(stable=True)
     @django_db_all
     def test_transaction_on_commit(self):
         super().test_transaction_on_commit()
+
+    @no_silo_test(stable=True)
+    @django_db_all
+    def test_safe_transaction_boundaries(self):
+        super().test_safe_transaction_boundaries()
+
+    @no_silo_test(stable=True)
+    @django_db_all
+    def test_bad_transaction_boundaries(self):
+        super().test_bad_transaction_boundaries()
+
+    @no_silo_test(stable=True)
+    @django_db_all
+    def test_collect_transaction_queries(self):
+        super().test_collect_transaction_queries()