
feat(dynamic sampling): Implements tasks for generating transaction boosting rules [TET-519] (#45043)

Transaction boosting infrastructure
Radu Woinaroski · commit 149881d80c

+ 307 - 0
src/sentry/dynamic_sampling/prioritise_transactions.py

@@ -0,0 +1,307 @@
+import logging
+import time
+from datetime import datetime, timedelta
+from typing import Iterator, List, Optional, Tuple, TypedDict, cast
+
+from snuba_sdk import (
+    AliasedExpression,
+    Column,
+    Condition,
+    Direction,
+    Entity,
+    Function,
+    Granularity,
+    LimitBy,
+    Op,
+    OrderBy,
+    Query,
+    Request,
+)
+
+from sentry import options
+from sentry.sentry_metrics import indexer
+from sentry.snuba.dataset import Dataset, EntityKey
+from sentry.snuba.metrics.naming_layer.mri import TransactionMRI
+from sentry.snuba.referrer import Referrer
+from sentry.utils.snuba import raw_snql_query
+
+logger = logging.getLogger(__name__)
+MAX_SECONDS = 60
+CHUNK_SIZE = 9998  # Snuba's limit is 10000 and we fetch CHUNK_SIZE+1
+
+
+class ProjectTransactions(TypedDict, total=True):
+    project_id: int
+    org_id: int
+    transaction_counts: List[Tuple[str, int]]
+
+
+def get_orgs_with_project_counts(max_orgs: int, max_projects: int) -> Iterator[List[int]]:
+    """
+    Fetch organisations in batches.
+    A batch will return at max max_orgs elements
+    It will accumulate org ids in the list until either it accumulates max_orgs or the
+    number of projects in the already accumulated orgs is more than max_projects or there
+    are no more orgs
+    """
+    start_time = time.time()
+    metric_id = indexer.resolve_shared_org(str(TransactionMRI.COUNT_PER_ROOT_PROJECT.value))
+    offset = 0
+    load_rate = int(options.get("dynamic-sampling.prioritise_transactions.load_rate") * 100)
+    last_result: List[Tuple[int, int]] = []
+    while (time.time() - start_time) < MAX_SECONDS:
+        query = (
+            Query(
+                match=Entity(EntityKey.GenericOrgMetricsCounters.value),
+                select=[
+                    Function("uniq", [Column("project_id")], "num_projects"),
+                    Column("org_id"),
+                ],
+                groupby=[
+                    Column("org_id"),
+                ],
+                where=[
+                    Condition(Function("modulo", [Column("org_id"), 100]), Op.LT, load_rate),
+                    Condition(Column("timestamp"), Op.GTE, datetime.utcnow() - timedelta(hours=6)),
+                    Condition(Column("timestamp"), Op.LT, datetime.utcnow()),
+                    Condition(Column("metric_id"), Op.EQ, metric_id),
+                ],
+                granularity=Granularity(3600),
+                orderby=[
+                    OrderBy(Column("org_id"), Direction.ASC),
+                ],
+            )
+            .set_limit(CHUNK_SIZE + 1)
+            .set_offset(offset)
+        )
+        request = Request(
+            dataset=Dataset.PerformanceMetrics.value, app_id="dynamic_sampling", query=query
+        )
+        data = raw_snql_query(
+            request,
+            referrer=Referrer.DYNAMIC_SAMPLING_COUNTERS_FETCH_PROJECTS_WITH_COUNT_PER_TRANSACTION.value,
+        )["data"]
+        count = len(data)
+        more_results = count > CHUNK_SIZE
+        offset += CHUNK_SIZE
+        if more_results:
+            data = data[:-1]
+        for row in data:
+            last_result.append((row["org_id"], row["num_projects"]))
+
+        first_idx = 0
+        count_projects = 0
+        for idx, (org_id, num_projects) in enumerate(last_result):
+            count_projects += num_projects
+            if idx - first_idx >= max_orgs - 1 or count_projects >= max_projects:
+                # we reached the desired batch size (max_orgs or max_projects); yield the batch
+                yield [o for o, _ in last_result[first_idx : idx + 1]]
+                first_idx = idx + 1
+                count_projects = 0
+
+        # keep what is left unused from last_result for the next iteration or final result
+        last_result = last_result[first_idx:]
+        if not more_results:
+            break
+    if len(last_result) > 0:
+        yield [org_id for org_id, _ in last_result]
+
+
+def fetch_transactions_with_total_volumes(
+    org_ids: List[int], large_transactions: bool, max_transactions: int
+) -> Iterator[ProjectTransactions]:
+    """
+    Fetch transactions for all orgs and all projects  with pagination orgs and projects with count per root project
+
+    org_ids: the orgs for which the projects & transactions should be returned
+
+    large_transactions: if True it returns transactions with the largest count
+                        if False it returns transactions with the smallest count
+
+    max_transactions: maximum number of transactions to return
+    """
+    start_time = time.time()
+    offset = 0
+    org_ids = list(org_ids)  # just to be sure it is not some other sequence
+    transaction_string_id = indexer.resolve_shared_org("transaction")
+    transaction_tag = f"tags_raw[{transaction_string_id}]"
+    metric_id = indexer.resolve_shared_org(str(TransactionMRI.COUNT_PER_ROOT_PROJECT.value))
+    current_org_id: Optional[int] = None
+    current_proj_id: Optional[int] = None
+    transaction_counts: List[Tuple[str, int]] = []
+
+    if large_transactions:
+        transaction_ordering = Direction.DESC
+    else:
+        transaction_ordering = Direction.ASC
+
+    while (time.time() - start_time) < MAX_SECONDS:
+        query = (
+            Query(
+                match=Entity(EntityKey.GenericOrgMetricsCounters.value),
+                select=[
+                    Function("sum", [Column("value")], "num_transactions"),
+                    Column("org_id"),
+                    Column("project_id"),
+                    AliasedExpression(Column(transaction_tag), "transaction_name"),
+                ],
+                groupby=[
+                    Column("org_id"),
+                    Column("project_id"),
+                    AliasedExpression(Column(transaction_tag), "transaction_name"),
+                ],
+                where=[
+                    Condition(Column("timestamp"), Op.GTE, datetime.utcnow() - timedelta(hours=6)),
+                    Condition(Column("timestamp"), Op.LT, datetime.utcnow()),
+                    Condition(Column("metric_id"), Op.EQ, metric_id),
+                    Condition(Column("org_id"), Op.IN, org_ids),
+                ],
+                granularity=Granularity(3600),
+                orderby=[
+                    OrderBy(Column("org_id"), Direction.ASC),
+                    OrderBy(Column("project_id"), Direction.ASC),
+                    OrderBy(Column("num_transactions"), transaction_ordering),
+                ],
+            )
+            .set_limitby(
+                LimitBy(columns=[Column("org_id"), Column("project_id")], count=max_transactions)
+            )
+            .set_limit(CHUNK_SIZE + 1)
+            .set_offset(offset)
+        )
+        request = Request(
+            dataset=Dataset.PerformanceMetrics.value, app_id="dynamic_sampling", query=query
+        )
+        data = raw_snql_query(
+            request,
+            referrer=Referrer.DYNAMIC_SAMPLING_COUNTERS_FETCH_PROJECTS_WITH_COUNT_PER_TRANSACTION.value,
+        )["data"]
+        count = len(data)
+        more_results = count > CHUNK_SIZE
+        offset += CHUNK_SIZE
+
+        if more_results:
+            data = data[:-1]
+
+        for row in data:
+            proj_id = row["project_id"]
+            org_id = row["org_id"]
+            transaction_name = row["transaction_name"]
+            num_transactions = row["num_transactions"]
+            if current_proj_id != proj_id or current_org_id != org_id:
+                if len(transaction_counts) > 0:
+                    yield {
+                        "project_id": cast(int, current_proj_id),
+                        "org_id": cast(int, current_org_id),
+                        "transaction_counts": transaction_counts,
+                    }
+                transaction_counts = []
+                current_org_id = org_id
+                current_proj_id = proj_id
+            transaction_counts.append((transaction_name, num_transactions))
+        if not more_results:
+            if len(transaction_counts) > 0:
+                yield {
+                    "project_id": cast(int, current_proj_id),
+                    "org_id": cast(int, current_org_id),
+                    "transaction_counts": transaction_counts,
+                }
+            break
+    else:
+        logger.error(
+            "",
+            extra={"offset": offset},
+        )
+
+    return None
+
+
+def merge_transactions(
+    left: ProjectTransactions, right: ProjectTransactions
+) -> ProjectTransactions:
+    if left["org_id"] != right["org_id"]:
+        raise ValueError(
+            "missmatched orgs while merging transactions", left["org_id"], right["org_id"]
+        )
+    if left["project_id"] != right["project_id"]:
+        raise ValueError(
+            "missmatched projects while merging transactions",
+            left["project_id"],
+            right["project_id"],
+        )
+
+    transactions = set()
+    merged_transactions = [*left["transaction_counts"]]
+    for transaction_name, _ in merged_transactions:
+        transactions.add(transaction_name)
+
+    for transaction_name, count in right["transaction_counts"]:
+        if transaction_name not in transactions:
+            # not already in left, add it
+            merged_transactions.append((transaction_name, count))
+
+    return {
+        "org_id": left["org_id"],
+        "project_id": left["project_id"],
+        "transaction_counts": merged_transactions,
+    }
+
+
+def transactions_zip(
+    left: Iterator[ProjectTransactions],
+    right: Iterator[ProjectTransactions],
+) -> Iterator[ProjectTransactions]:
+    """
+    Returns a generator that zips left and right: elements that match (same org and project)
+    are merged; when they do not match, the sequences are re-aligned and the unmatched element is yielded as-is.
+    """
+
+    more_right = True
+    more_left = True
+    left_elm = None
+    right_elm = None
+    while more_left or more_right:
+        if more_right and right_elm is None:
+            try:
+                right_elm = next(right)
+            except StopIteration:
+                more_right = False
+                right_elm = None
+        if more_left and left_elm is None:
+            try:
+                left_elm = next(left)
+            except StopIteration:
+                more_left = False
+                left_elm = None
+
+        if right_elm is not None and left_elm is not None:
+            # we have both left and right; merge them if they point to the same entity
+            if (
+                left_elm["org_id"] == right_elm["org_id"]
+                and left_elm["project_id"] == right_elm["project_id"]
+            ):
+                yield merge_transactions(left_elm, right_elm)
+                left_elm = None
+                right_elm = None
+            else:
+                # the two elements do not match; yield whichever is "smaller" and keep the other
+                # for the next iteration
+                if left_elm["org_id"] < right_elm["org_id"]:
+                    yield left_elm
+                    left_elm = None
+                elif left_elm["org_id"] > right_elm["org_id"]:
+                    yield right_elm
+                    right_elm = None
+                # orgs are the same, compare projects
+                elif left_elm["project_id"] < right_elm["project_id"]:
+                    yield left_elm
+                    left_elm = None
+                else:  # left_elm["project_id"] > right_elm["project_id"]
+                    yield right_elm
+                    right_elm = None
+        else:
+            if left_elm is not None:
+                yield left_elm
+                left_elm = None
+            if right_elm is not None:
+                yield right_elm
+                right_elm = None
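
For orientation, a minimal sketch (not part of the commit) of how the helpers in this file compose; it mirrors the prioritise_transactions task added to tasks.py below, and the batch sizes are only illustrative:

    from sentry.dynamic_sampling.prioritise_transactions import (
        fetch_transactions_with_total_volumes,
        get_orgs_with_project_counts,
        transactions_zip,
    )

    def iter_project_transactions(max_orgs=100, max_projects=5000, max_transactions=20):
        # fetch orgs in batches, then zip the "largest" and "smallest" transaction
        # streams so each project yields both its highest- and lowest-volume transactions
        for org_batch in get_orgs_with_project_counts(max_orgs, max_projects):
            yield from transactions_zip(
                fetch_transactions_with_total_volumes(
                    org_batch, large_transactions=True, max_transactions=max_transactions // 2
                ),
                fetch_transactions_with_total_volumes(
                    org_batch, large_transactions=False, max_transactions=max_transactions // 2
                ),
            )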

+ 38 - 0
src/sentry/dynamic_sampling/rules/helpers/prioritize_transactions.py

@@ -0,0 +1,38 @@
+from typing import Mapping, Tuple, cast
+
+import sentry_sdk
+
+from sentry.dynamic_sampling.rules.utils import get_redis_client_for_ds
+from sentry.utils import json
+
+
+def _get_cache_key(org_id: int, proj_id: int) -> str:
+    return f"ds::o:{org_id}:p:{proj_id}:pri_tran"
+
+
+def get_transactions_resampling_rates(
+    org_id: int, proj_id: int, default_rate: float
+) -> Tuple[Mapping[str, float], float]:
+    redis_client = get_redis_client_for_ds()
+    cache_key = _get_cache_key(org_id=org_id, proj_id=proj_id)
+    try:
+        serialised_val = redis_client.get(cache_key)
+        if serialised_val:
+            val = json.loads(serialised_val)
+            ret_val = cast(Tuple[Mapping[str, float], float], val)
+            return ret_val
+    except (TypeError, ValueError) as e:
+        sentry_sdk.capture_exception(e)
+
+    return {}, default_rate
+
+
+def set_transactions_resampling_rates(
+    org_id: int, proj_id: int, named_rates: Mapping[str, float], default_rate: float, ttl_ms: int
+) -> None:
+    redis_client = get_redis_client_for_ds()
+    cache_key = _get_cache_key(org_id=org_id, proj_id=proj_id)
+    val = [named_rates, default_rate]
+    val_str = json.dumps(val)
+    redis_client.set(cache_key, val_str)
+    redis_client.pexpire(cache_key, ttl_ms)
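
The two helpers above round-trip a two-element JSON array under a per-project Redis key; a minimal sketch of the payload shape (key format and layout taken from the code above, the concrete rates are made up):

    import json

    named_rates = {"/api/checkout": 0.6, "/healthcheck": 0.1}  # made-up per-transaction rates
    default_rate = 0.25

    cache_key = "ds::o:1:p:10:pri_tran"  # _get_cache_key(org_id=1, proj_id=10)
    payload = json.dumps([named_rates, default_rate])  # what set_transactions_resampling_rates writes

    rates, default = json.loads(payload)  # what get_transactions_resampling_rates reads back
    assert rates == named_rates and default == default_rate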

+ 104 - 2
src/sentry/dynamic_sampling/tasks.py

@@ -1,11 +1,21 @@
 import logging
-from typing import Sequence, Tuple
+from typing import Optional, Sequence, Tuple
 
 from sentry import features, quotas
 from sentry.dynamic_sampling.models.adjustment_models import AdjustedModel
+from sentry.dynamic_sampling.models.transaction_adjustment_model import adjust_sample_rate
 from sentry.dynamic_sampling.models.utils import DSElement
 from sentry.dynamic_sampling.prioritise_projects import fetch_projects_with_total_volumes
+from sentry.dynamic_sampling.prioritise_transactions import (
+    ProjectTransactions,
+    fetch_transactions_with_total_volumes,
+    get_orgs_with_project_counts,
+    transactions_zip,
+)
 from sentry.dynamic_sampling.rules.helpers.prioritise_project import _generate_cache_key
+from sentry.dynamic_sampling.rules.helpers.prioritize_transactions import (
+    set_transactions_resampling_rates,
+)
 from sentry.dynamic_sampling.rules.utils import OrganizationId, ProjectId, get_redis_client_for_ds
 from sentry.models import Organization, Project
 from sentry.tasks.base import instrumented_task
@@ -14,7 +24,12 @@ from sentry.utils import metrics
 
 CHUNK_SIZE = 1000
 MAX_SECONDS = 60
-CACHE_KEY_TTL = 24 * 60 * 60 * 1000
+CACHE_KEY_TTL = 24 * 60 * 60 * 1000  # in milliseconds
+
+# TODO RaduW validate assumptions
+MAX_ORGS_PER_QUERY = 100
+MAX_PROJECTS_PER_QUERY = 5000
+MAX_TRANSACTIONS_PER_PROJECT = 20
 
 logger = logging.getLogger(__name__)
 
@@ -102,3 +117,90 @@ def adjust_sample_rates(
                 project_id=ds_project.id, trigger="dynamic_sampling_prioritise_project_bias"
             )
         pipeline.execute()
+
+
+@instrumented_task(
+    name="sentry.dynamic_sampling.tasks.prioritise_transactions",
+    queue="dynamicsampling",
+    default_retry_delay=5,
+    max_retries=5,
+    soft_time_limit=2 * 60 * 60,  # 2 hours
+    time_limit=2 * 60 * 60 + 5,
+)  # type: ignore
+def prioritise_transactions() -> None:
+    """
+    A task that retrieves the relative transaction counts from all projects in all orgs
+    and invokes a task that rebalances transaction sampling rates within each project.
+    """
+    metrics.incr("sentry.tasks.dynamic_sampling.prioritise_transactions.start", sample_rate=1.0)
+    current_org: Optional[Organization] = None
+    current_org_enabled = False
+    with metrics.timer("sentry.tasks.dynamic_sampling.prioritise_transactions", sample_rate=1.0):
+        for orgs in get_orgs_with_project_counts(MAX_ORGS_PER_QUERY, MAX_PROJECTS_PER_QUERY):
+            # get the low and high transactions
+            # TODO can we do this in one query rather than two
+            for project_transactions in transactions_zip(
+                fetch_transactions_with_total_volumes(
+                    orgs,
+                    large_transactions=True,
+                    max_transactions=MAX_TRANSACTIONS_PER_PROJECT // 2,
+                ),
+                fetch_transactions_with_total_volumes(
+                    orgs,
+                    large_transactions=False,
+                    max_transactions=MAX_TRANSACTIONS_PER_PROJECT // 2,
+                ),
+            ):
+
+                if not current_org or current_org.id != project_transactions["org_id"]:
+                    current_org = Organization.objects.get_from_cache(
+                        id=project_transactions["org_id"]
+                    )
+                    current_org_enabled = features.has(
+                        "organizations:ds-prioritise-by-transaction-bias", current_org
+                    )
+                if current_org_enabled:
+                    process_transaction_biases.delay(project_transactions)
+
+
+@instrumented_task(
+    name="sentry.dynamic_sampling.process_transaction_biases",
+    queue="dynamicsampling",
+    default_retry_delay=5,
+    max_retries=5,
+    soft_time_limit=25 * 60,  # 25 mins
+    time_limit=25 * 60 + 5,
+)  # type: ignore
+def process_transaction_biases(project_transactions: ProjectTransactions) -> None:
+    """
+    A task that, given the relative transaction counts of a project, calculates rebalanced
+    sampling rates based on the overall desired project sampling rate.
+    """
+
+    org_id = project_transactions["org_id"]
+    project_id = project_transactions["project_id"]
+    transactions = project_transactions["transaction_counts"]
+    project = Project.objects.get_from_cache(id=project_id)
+    sample_rate = quotas.get_blended_sample_rate(project)
+
+    if sample_rate is None:
+        # no sampling => no rebalancing
+        return
+
+    named_rates, global_rate = adjust_sample_rate(
+        transactions=transactions,
+        rate=sample_rate,
+        max_explicit_transactions=MAX_TRANSACTIONS_PER_PROJECT,
+    )
+
+    set_transactions_resampling_rates(
+        org_id=org_id,
+        proj_id=project_id,
+        named_rates=named_rates,
+        default_rate=global_rate,
+        ttl_ms=CACHE_KEY_TTL,
+    )
+
+    schedule_invalidate_project_config(
+        project_id=project_id, trigger="dynamic_sampling_prioritise_transaction_bias"
+    )
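
For context, a minimal sketch of the rebalancing call made in process_transaction_biases, assuming only the adjust_sample_rate signature visible above; the transaction counts and rate are made up:

    from sentry.dynamic_sampling.models.transaction_adjustment_model import adjust_sample_rate

    transactions = [("ts1", 10), ("tm2", 100), ("tl3", 1000)]  # (name, count) pairs
    named_rates, global_rate = adjust_sample_rate(
        transactions=transactions,
        rate=0.25,  # blended sample rate for the project
        max_explicit_transactions=20,  # MAX_TRANSACTIONS_PER_PROJECT
    )
    # named_rates maps transaction names to per-transaction sample rates;
    # global_rate is the default applied to transactions not listed explicitly.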

+ 3 - 0
src/sentry/options/defaults.py

@@ -630,6 +630,9 @@ register("dynamic-sampling:enabled-biases", default=True)
 # project config computation. This is temporary option to monitor the performance of this feature.
 register("dynamic-sampling:boost-latest-release", default=False)
 register("dynamic-sampling.prioritise_projects.sample_rate", default=0.0)
+# controls what fraction of orgs will be queried by the prioritise-by-transaction task
+# 0 -> no orgs, 0.5 -> half of the orgs, 1.0 -> all orgs
+register("dynamic-sampling.prioritise_transactions.load_rate", default=0.0)
 
 # Killswitch for deriving code mappings
 register("post_process.derive-code-mappings", default=True)

+ 4 - 0
src/sentry/snuba/referrer.py

@@ -355,6 +355,10 @@ class ReferrerBase(Enum):
     DYNAMIC_SAMPLING_DISTRIBUTION_FETCH_PROJECTS_WITH_COUNT_PER_ROOT = (
         "dynamic_sampling.distribution.fetch_projects_with_count_per_root_total_volumes"
     )
+    DYNAMIC_SAMPLING_COUNTERS_FETCH_PROJECTS_WITH_COUNT_PER_TRANSACTION = (
+        "dynamic_sampling.counters.fetch_projects_with_count_per_transaction_volumes"
+    )
+    DYNAMIC_SAMPLING_COUNTERS_FETCH_ACTIVE_ORGS = "dynamic_sampling.counters.fetch_active_orgs"
     EVENTSTORE_GET_EVENT_BY_ID_NODESTORE = "eventstore.get_event_by_id_nodestore"
     EVENTSTORE_GET_EVENTS = "eventstore.get_events"
     EVENTSTORE_GET_NEXT_OR_PREV_EVENT_ID = "eventstore.get_next_or_prev_event_id"

+ 17 - 13
tests/sentry/dynamic_sampling/models/test_transaction_adjustment_model.py

@@ -24,7 +24,7 @@ test_resample_cases = [
 ]
 
 
-@pytest.mark.parametrize("sample_rate", [0.1, 0.5, 0.9])
+@pytest.mark.parametrize("sample_rate", [0.1, 0.5, 0.9, 1.0])
 @pytest.mark.parametrize("transactions", test_resample_cases)
 def test_maintains_overall_sample_rate(sample_rate, transactions):
 
@@ -149,7 +149,7 @@ full_resample_transactions = [
 ]
 
 
-@pytest.mark.parametrize("sample_rate", [0.1, 0.5, 0.9])
+@pytest.mark.parametrize("sample_rate", [0.1, 0.5, 0.9, 1.0])
 @pytest.mark.parametrize("transactions", full_resample_transactions)
 def test_full_resample(sample_rate, transactions):
     """
@@ -182,14 +182,18 @@ def test_full_resample(sample_rate, transactions):
             num_transaction_types_fully_sampled += 1
             full_budget -= count
 
-    budget_per_transaction_type = full_budget / (
-        num_explicit_transaction_types - num_transaction_types_fully_sampled
-    )
-
-    # everything that is not fully sampled should be at budget_per_transaction_type
-    for name, count in transactions:
-        # everything should be explicitly specified
-        assert name in explicit_transactions
-        rate = explicit_transactions[name]
-        if rate != 1.0:
-            assert rate * count == pytest.approx(budget_per_transaction_type)
+    if num_transaction_types_fully_sampled == num_explicit_transaction_types:
+        # all transactions sampled at 1.0 this means we must have specified the
+        # overall sample rate to be 1.0
+        assert sample_rate == 1.0
+    else:
+        budget_per_transaction_type = full_budget / (
+            num_explicit_transaction_types - num_transaction_types_fully_sampled
+        )
+        # everything that is not fully sampled should be at budget_per_transaction_type
+        for name, count in transactions:
+            # everything should be explicitly specified
+            assert name in explicit_transactions
+            rate = explicit_transactions[name]
+            if rate != 1.0:
+                assert rate * count == pytest.approx(budget_per_transaction_type)
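
The rewritten assertion above is plain budget arithmetic; a toy worked example with made-up counts, independent of the actual model output:

    transactions = [("a", 100), ("b", 1000), ("c", 10000)]
    sample_rate = 0.1
    full_budget = sum(count for _, count in transactions) * sample_rate  # 11100 * 0.1 = 1110.0

    # suppose "a" ends up fully sampled (rate 1.0): it consumes its whole count
    full_budget -= 100  # 1010.0
    budget_per_transaction_type = full_budget / (3 - 1)  # 505.0

    # the transactions that are not fully sampled are expected to land on that budget
    expected_rates = {"a": 1.0, "b": 505.0 / 1000, "c": 505.0 / 10000}  # 1.0, 0.505, 0.0505
    total = sum(expected_rates[name] * count for name, count in transactions)
    assert abs(total - 1110.0) < 1e-9  # overall sample rate 0.1 is preserved (1110 / 11100)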

+ 90 - 0
tests/sentry/dynamic_sampling/rules/helpers/test_prioritize_transactions.py

@@ -0,0 +1,90 @@
+from sentry.dynamic_sampling.rules.helpers.prioritize_transactions import (
+    get_transactions_resampling_rates,
+    set_transactions_resampling_rates,
+)
+
+
+def test_resampling_rates_in_cache():
+    """
+    Tests that we can correctly store and retrieve resampling rates without
+    key clashes
+    """
+    org_id = 1
+    proj_id = 10
+    other_org_id = 2
+    other_proj_id = 20
+
+    expected_global_rate = 0.3
+    expected_trans_rates = {"t1": 0.6, "t2": 0.7}
+
+    other_rate = 0.1
+    other_rates = {"t11": 0.1}
+
+    # store our desired transaction rate
+    set_transactions_resampling_rates(
+        org_id=org_id,
+        proj_id=proj_id,
+        named_rates=expected_trans_rates,
+        default_rate=expected_global_rate,
+        ttl_ms=100 * 1000,
+    )
+
+    # store some garbage to check we don't accidentally override our info
+    set_transactions_resampling_rates(
+        org_id=other_org_id,
+        proj_id=proj_id,
+        named_rates=other_rates,
+        default_rate=other_rate,
+        ttl_ms=100 * 1000,
+    )
+    set_transactions_resampling_rates(
+        org_id=org_id,
+        proj_id=other_proj_id,
+        named_rates=other_rates,
+        default_rate=other_rate,
+        ttl_ms=100 * 1000,
+    )
+
+    actual_trans_rates, actual_global_rate = get_transactions_resampling_rates(
+        org_id=org_id, proj_id=proj_id, default_rate=1.0
+    )
+
+    assert actual_trans_rates == expected_trans_rates
+    assert actual_global_rate == expected_global_rate
+
+
+def test_resampling_rates_missing():
+    """
+    Tests that if the resampling rates are not in cache the default values are returned
+    """
+    org_id = 1
+    proj_id = 10
+    other_org_id = 2
+    other_proj_id = 20
+
+    other_rate = 0.1
+    other_rates = {"t11": 0.1}
+
+    # store some garbage to check we don't accidentally return other keys
+    set_transactions_resampling_rates(
+        org_id=other_org_id,
+        proj_id=proj_id,
+        named_rates=other_rates,
+        default_rate=other_rate,
+        ttl_ms=100 * 1000,
+    )
+    set_transactions_resampling_rates(
+        org_id=org_id,
+        proj_id=other_proj_id,
+        named_rates=other_rates,
+        default_rate=other_rate,
+        ttl_ms=100 * 1000,
+    )
+
+    expected_global_rate = 0.33
+    actual_trans_rates, actual_global_rate = get_transactions_resampling_rates(
+        org_id=org_id, proj_id=proj_id, default_rate=expected_global_rate
+    )
+
+    assert actual_trans_rates == {}
+    assert actual_global_rate == expected_global_rate

+ 164 - 0
tests/sentry/dynamic_sampling/test_prioritise_transactions.py

@@ -0,0 +1,164 @@
+from datetime import datetime, timezone
+
+from freezegun import freeze_time
+
+from sentry.dynamic_sampling.prioritise_transactions import (
+    fetch_transactions_with_total_volumes,
+    get_orgs_with_project_counts,
+    merge_transactions,
+    transactions_zip,
+)
+from sentry.snuba.metrics import TransactionMRI
+from sentry.testutils import BaseMetricsLayerTestCase, SnubaTestCase, TestCase
+
+MOCK_DATETIME = datetime(2023, 8, 7, 0, 0, 0, tzinfo=timezone.utc)
+
+
+@freeze_time(MOCK_DATETIME)
+class PrioritiseProjectsSnubaQueryTest(BaseMetricsLayerTestCase, TestCase, SnubaTestCase):
+    @property
+    def now(self):
+        return MOCK_DATETIME
+
+    def setUp(self):
+        super().setUp()
+        self.orgs_info = []
+        num_orgs = 3
+        num_proj_per_org = 3
+        for org_idx in range(num_orgs):
+            org = self.create_organization(f"test-org{org_idx}")
+            org_info = {"org_id": org.id, "project_ids": []}
+            self.orgs_info.append(org_info)
+            for proj_idx in range(num_proj_per_org):
+                p = self.create_project(organization=org)
+                org_info["project_ids"].append(p.id)
+                # create 5 transaction types
+                for name in ["ts1", "ts2", "tm3", "tl4", "tl5"]:
+                    # make up some unique count
+                    idx = org_idx * num_orgs + proj_idx
+                    num_transactions = self.get_count_for_transaction(idx, name)
+                    self.store_performance_metric(
+                        name=TransactionMRI.COUNT_PER_ROOT_PROJECT.value,
+                        tags={"transaction": name},
+                        hours_before_now=1,
+                        value=num_transactions,
+                        project_id=p.id,
+                        org_id=org.id,
+                    )
+        self.org_ids = [org["org_id"] for org in self.orgs_info]
+
+    def get_count_for_transaction(self, idx: int, name: str):
+        """
+        Create some known count based on transaction name and the order (based on org and project)
+        """
+        counts = {
+            "ts1": 1,
+            "ts2": 100,
+            "tm3": 1000,
+            "tl4": 2000,
+            "tl5": 3000,
+        }
+        return idx + counts[name]
+
+    def test_get_orgs_with_transactions_respects_max_orgs(self):
+        with self.options({"dynamic-sampling.prioritise_transactions.load_rate": 1.0}):
+            actual = list(get_orgs_with_project_counts(2, 20))
+
+        orgs = self.org_ids
+        # we should return groups of 2 orgs at a time
+        assert actual == [[orgs[0], orgs[1]], [orgs[2]]]
+
+    def test_get_orgs_with_transactions_respects_max_projs(self):
+        with self.options({"dynamic-sampling.prioritise_transactions.load_rate": 1.0}):
+            actual = list(get_orgs_with_project_counts(10, 5))
+
+        orgs = [org["org_id"] for org in self.orgs_info]
+        # since each org has 3 projects and we have a limit of 5 proj
+        # we should return 2 orgs at a time
+        assert actual == [[orgs[0], orgs[1]], [orgs[2]]]
+
+    def test_fetch_transactions_with_total_volumes_large(self):
+        """
+        Create some transactions in some orgs and projects and verify that the
+        largest ones are correctly returned by fetch_transactions_with_total_volumes
+        """
+
+        # get the transaction counts from snuba and check that they match what we put in
+        orgs = self.org_ids
+
+        expected_names = {"tm3", "tl5", "tl4"}
+        for idx, p_tran in enumerate(fetch_transactions_with_total_volumes(orgs, True, 3)):
+            if p_tran is not None:
+                assert len(p_tran["transaction_counts"]) == 3
+                for name, count in p_tran["transaction_counts"]:
+                    assert name in expected_names
+                    assert count == self.get_count_for_transaction(idx, name)
+
+    def test_fetch_transactions_with_total_volumes_small(self):
+        """
+        Create some transactions in some orgs and projects and verify that the
+        smallest ones are correctly returned by fetch_transactions_with_total_volumes
+        """
+
+        # get the transaction counts from snuba and check that they match what we put in
+        orgs = self.org_ids
+
+        expected_names = {"ts1", "ts2"}
+        for idx, p_tran in enumerate(fetch_transactions_with_total_volumes(orgs, False, 2)):
+            if p_tran is not None:
+                assert len(p_tran["transaction_counts"]) == 2
+                for name, count in p_tran["transaction_counts"]:
+                    assert name in expected_names
+                    assert count == self.get_count_for_transaction(idx, name)
+
+
+def test_merge_transactions():
+    t1 = {"project_id": 1, "org_id": 2, "transaction_counts": [("ts1", 10), ("tm2", 100)]}
+    t2 = {"project_id": 1, "org_id": 2, "transaction_counts": [("tm2", 100), ("tl3", 1000)]}
+
+    actual = merge_transactions(t1, t2)
+
+    expected = {
+        "project_id": 1,
+        "org_id": 2,
+        "transaction_counts": [("ts1", 10), ("tm2", 100), ("tl3", 1000)],
+    }
+
+    assert actual == expected
+
+
+def test_transactions_zip():
+    high = 1
+    low = 2
+    both = 3
+
+    def pt(org_id: int, proj_id: int, what: int):
+        if what == high:
+            transaction_counts = [("tm2", 100), ("tl3", 1000)]
+        elif what == low:
+            transaction_counts = [("ts1", 10), ("tm2", 100)]
+        else:
+            transaction_counts = [("ts1", 10), ("tm2", 100), ("tl3", 1000)]
+        return {
+            "project_id": proj_id,
+            "org_id": org_id,
+            "transaction_counts": transaction_counts,  # not relevant in zipping
+        }
+
+    trans_low = [pt(1, 1, low), pt(1, 2, low), pt(2, 1, low), pt(2, 3, low), pt(3, 2, low)]
+    trans_high = [pt(2, 1, high), pt(2, 2, high), pt(3, 1, high), pt(3, 2, high), pt(3, 3, high)]
+
+    expected = [
+        pt(1, 1, low),
+        pt(1, 2, low),
+        pt(2, 1, both),
+        pt(2, 2, high),
+        pt(2, 3, low),
+        pt(3, 1, high),
+        pt(3, 2, both),
+        pt(3, 3, high),
+    ]
+
+    actual = list(transactions_zip((x for x in trans_low), (x for x in trans_high)))
+
+    assert actual == expected

+ 76 - 1
tests/sentry/dynamic_sampling/test_tasks.py

@@ -6,7 +6,10 @@ from django.utils import timezone
 from freezegun import freeze_time
 
 from sentry.dynamic_sampling import generate_rules
-from sentry.dynamic_sampling.tasks import prioritise_projects
+from sentry.dynamic_sampling.rules.helpers.prioritize_transactions import (
+    get_transactions_resampling_rates,
+)
+from sentry.dynamic_sampling.tasks import prioritise_projects, prioritise_transactions
 from sentry.snuba.metrics import TransactionMRI
 from sentry.testutils import BaseMetricsLayerTestCase, SnubaTestCase, TestCase
 
@@ -78,3 +81,75 @@ class TestPrioritiseProjectsTask(BaseMetricsLayerTestCase, TestCase, SnubaTestCa
             "value": pytest.approx(0.4444444444444444),
         }
         assert generate_rules(proj_d)[0]["samplingValue"] == {"type": "sampleRate", "value": 1.0}
+
+
+@freeze_time(MOCK_DATETIME)
+class TestPrioritiseTransactionsTask(BaseMetricsLayerTestCase, TestCase, SnubaTestCase):
+    @property
+    def now(self):
+        return MOCK_DATETIME
+
+    def setUp(self):
+        super().setUp()
+        self.orgs_info = []
+        num_orgs = 3
+        num_proj_per_org = 3
+        for org_idx in range(num_orgs):
+            org = self.create_organization(f"test-org{org_idx}")
+            org_info = {"org_id": org.id, "project_ids": []}
+            self.orgs_info.append(org_info)
+            for proj_idx in range(num_proj_per_org):
+                p = self.create_project(organization=org)
+                org_info["project_ids"].append(p.id)
+                # create 5 transaction types
+                for name in ["ts1", "ts2", "tm3", "tl4", "tl5"]:
+                    # make up some unique count
+                    idx = org_idx * num_orgs + proj_idx
+                    num_transactions = self.get_count_for_transaction(idx, name)
+                    self.store_performance_metric(
+                        name=TransactionMRI.COUNT_PER_ROOT_PROJECT.value,
+                        tags={"transaction": name},
+                        hours_before_now=1,
+                        value=num_transactions,
+                        project_id=p.id,
+                        org_id=org.id,
+                    )
+        self.org_ids = [org["org_id"] for org in self.orgs_info]
+
+    def get_count_for_transaction(self, idx: int, name: str):
+        """
+        Create some known count based on transaction name and the order (based on org and project)
+        """
+        counts = {
+            "ts1": 1,
+            "ts2": 100,
+            "tm3": 1000,
+            "tl4": 2000,
+            "tl5": 3000,
+        }
+        return idx + counts[name]
+
+    @patch("sentry.dynamic_sampling.rules.base.quotas.get_blended_sample_rate")
+    def test_prioritise_transactions_simple(self, get_blended_sample_rate):
+        """
+        Create orgs projects & transactions and then check that the task creates rebalancing data
+        in Redis
+        """
+        get_blended_sample_rate.return_value = 0.25
+
+        with self.options({"dynamic-sampling.prioritise_transactions.load_rate": 1.0}):
+            with self.feature({"organizations:ds-prioritise-by-transaction-bias": True}):
+                with self.tasks():
+                    prioritise_transactions()
+
+        # now redis should contain rebalancing data for our projects
+        for org in self.orgs_info:
+            org_id = org["org_id"]
+            for proj_id in org["project_ids"]:
+                tran_rate, global_rate = get_transactions_resampling_rates(
+                    org_id=org_id, proj_id=proj_id, default_rate=0.1
+                )
+                for transaction_name in ["ts1", "ts2", "tm3", "tl4", "tl5"]:
+                    assert (
+                        transaction_name in tran_rate
+                    )  # check we have some rate calculated for each transaction