
feat(dynamic-sampling) Rebalancing model V2 (#45992)

Radu Woinaroski 1 year ago
Parent
Commit
dccad0d897

+ 87 - 134
src/sentry/dynamic_sampling/models/transaction_adjustment_model.py

@@ -1,66 +1,105 @@
 import operator
 from collections import namedtuple
 from copy import copy
-from typing import List, Mapping, MutableMapping, Tuple
+from typing import List, Mapping, MutableMapping, Optional, Tuple
 
 AdjustedSampleRate = namedtuple("AdjustedSampleRate", "explicit_rates, global_rate")
 
 
 def adjust_sample_rate(
-    transactions: List[Tuple[str, int]], rate: float, max_explicit_transactions: int
-) -> Tuple[Mapping[str, float], float]:
+    classes: List[Tuple[str, float]],
+    rate: float,
+    total_num_classes: Optional[int],
+    total: Optional[float],
+) -> Tuple[MutableMapping[str, float], float]:
     """
-    Calculates transaction sample size in order to maximize the number of small transactions
-
-    :param transactions: the transaction types as an array of (name, count) tuples
-    :param rate: the overall desired rate
-    :param max_explicit_transactions: the maximum number of transactions that can have individual
-        rates set, the rest will have a common rate
-    :return: a tuple with the first element a mapping transaction-name->sampling-rate and the
-    second element the transaction rate for all other transactions (that are not in the dict)
+    Adjusts sampling rates to bring the number of samples kept in each class as close to
+    the same value as possible while maintaining the overall sampling rate.
+
+    The algorithm adjusts the explicitly given classes individually to bring them to
+    the ideal sample rate and then adjusts the global sample rate for all the remaining classes.
+
+    :param classes: a list of (class id, num_samples) tuples for the explicitly specified classes
+    :param rate: the desired overall sampling rate
+    :param total_num_classes: total number of classes (including those explicitly specified in classes)
+    :param total: total number of samples in all classes (including the explicitly specified classes)
+
+    :return: a tuple of a dictionary with explicit rates for individual classes (class_name -> rate)
+    and a rate for all other (unspecified) classes.
     """
-    # sort by transaction count
-    transactions = sorted(transactions, key=operator.itemgetter(1))
-    if len(transactions) <= max_explicit_transactions:
-        # we can get the ideal rate to all do a full resample
-        return AdjustedSampleRate(
-            explicit_rates=adjust_sample_rate_full(transactions, rate), global_rate=rate
-        )
-
-    # TODO I think we can find out which is the best option by looking at the distribution
-    #   if we compare the smallest rate with the max_explicit_rate_1 rate and the ratio is
-    #   around  1/rate then I think it is safe to go with resample_min otherwise resample_max
-    #   need to investigate and refine this idea.
-
-    # See what's gives better results, setting the rate of the smallest transactions or of
-    # the largest transactions (if we have just a few very large transactions is better to
-    # give them individual rates), if we have a few very small transactions we are better off
-    # specifying sample rate for the small transactions.
-    # The way we evaluate what's best is by maximizing the minimum number of samples returned
-    # by a transaction which is not sampled at 1.0
-    min_sample_size_x, explicit_rates_x, global_rate_x = adjust_sample_rate_max(
-        transactions, rate, max_explicit_transactions
-    )
-
-    min_sample_size_n, explicit_rates_n, global_rate_n = adjust_sample_rate_min(
-        transactions, rate, max_explicit_transactions
-    )
-
-    if min_sample_size_n < min_sample_size_x:
-        return AdjustedSampleRate(explicit_rates=explicit_rates_x, global_rate=global_rate_x)
+
+    classes = sorted(classes, key=operator.itemgetter(1))
+
+    if total_num_classes is None or total is None:
+        # we don't have totals information; the best we can do is a full adjustment on
+        # the explicit classes and keep the original rate for the rest
+        return adjust_sample_rate_full(classes, rate), rate
+
+    # total count for the explicitly specified classes
+    total_explicit = get_total(classes)
+    # total count for the unspecified classes
+    total_implicit = total - total_explicit
+    # total number of specified classes
+    num_explicit_classes = len(classes)
+    # total number of unspecified classes
+    num_implicit_classes = total_num_classes - num_explicit_classes
+
+    total_budget = total * rate
+    budget_per_class = total_budget / total_num_classes
+
+    implicit_budget = budget_per_class * num_implicit_classes
+    explicit_budget = budget_per_class * num_explicit_classes
+
+    if num_explicit_classes == total_num_classes:
+        # we have specified all classes
+        explicit_rates = adjust_sample_rate_full(classes, rate)
+        implicit_rate = rate  # doesn't really matter since everything is explicit
+    elif total_implicit < implicit_budget:
+        # we cannot spend the entire implicit budget; we can keep at most
+        # total_implicit samples, so set the implicit rate to 1
+        # and reevaluate the available budget for the explicit classes
+        implicit_rate = 1
+        # we spent all we could on the implicit classes; see what budget we
+        # have left
+        explicit_budget = total_budget - total_implicit
+        # calculate the new global rate for the explicit transactions that
+        # would bring the overall rate to the desired rate
+        explicit_rate = explicit_budget / total_explicit
+        explicit_rates = adjust_sample_rate_full(classes, explicit_rate)
+    elif total_explicit < explicit_budget:
+        # we cannot spend the entire explicit budget; we can keep at most
+        # total_explicit samples, so set the explicit rate to 1 for all
+        # explicit classes and reevaluate the available budget for the implicit classes
+        explicit_rates = {name: 1.0 for name, _count in classes}
+
+        # calculate the new global rate for the implicit transactions
+        implicit_budget = total_budget - total_explicit
+        implicit_rate = implicit_budget / total_implicit
     else:
-        return AdjustedSampleRate(explicit_rates=explicit_rates_n, global_rate=global_rate_n)
+        # we can spend all the implicit budget on the implicit classes
+        # and all the explicit budget on the explicit classes
+        # the calculation of rates can be done independently for explicit and
+        # implicit classes
+        implicit_rate = implicit_budget / total_implicit
+        explicit_rate = explicit_budget / total_explicit
+        explicit_rates = adjust_sample_rate_full(classes, explicit_rate)
+    return explicit_rates, implicit_rate
 
 
 def adjust_sample_rate_full(
-    transactions: List[Tuple[str, int]], rate: float
+    transactions: List[Tuple[str, float]], rate: float
 ) -> MutableMapping[str, float]:
     """
-    resample all transactions to their ideal size
+    Resample all classes to their ideal size.
+
+    Ideal size is defined as the minimum of:
+    - num_samples_in_class (i.e. no sampling, rate 1.0)
+    - total_num_samples * rate / num_classes
+
     """
     transactions = copy(transactions)
     ret_val = {}
-    num_transactions = total_transactions(transactions)
+    num_transactions = get_total(transactions)
     # calculate how many transactions we are allowed to keep overall
     # this will allow us to pass transactions between different transaction types
     total_budget = num_transactions * rate
@@ -85,101 +124,15 @@ def adjust_sample_rate_full(
     return ret_val
 
 
-_SampleRates = namedtuple("_SampleRates", "min_sample_size, explicit_rates, global_rate")
-
-
-def adjust_sample_rate_max(
-    transactions: List[Tuple[str, int]], rate: float, max_explicit_transactions: int
-) -> _SampleRates:
-    """
-    Calculates explicit rates for the transactions with the biggest number of elements and
-    with the remaining space created adjusts the sampling rate for the rest
-    """
-    transactions = copy(transactions)
-    num_transactions, num_types, total_budget, budget_per_transaction_type = _sampling_info(
-        transactions, rate
-    )
-    # first see if we can get sample rate 1 for all small transactions if that's the case then it is
-    # the optimal solution
-    small_transactions = transactions[0 : num_types - max_explicit_transactions]
-    count_small_transactions = total_transactions(small_transactions)
-    transaction_dict: MutableMapping[str, float] = {}
-    if count_small_transactions < budget_per_transaction_type * len(small_transactions):
-        # we can set all small transactions to rate 1 and then adjust large transactions
-        budget_all_big_transactions = total_budget - count_small_transactions
-        big_transactions = transactions[-max_explicit_transactions:]
-        num_big_transactions = total_transactions(big_transactions)
-        # calculate the new rate for big transaction and treat them as a full resample
-        new_rate = budget_all_big_transactions / num_big_transactions
-        transaction_dict = adjust_sample_rate_full(big_transactions, new_rate)
-        # since all small transactions are sampled at 1 the min_sample_size can be found
-        # in any big transaction that is not sampled at 1 (since they all will have the
-        # same size
-        for name, count in big_transactions:
-            rate = transaction_dict[name]
-            if rate != 1.0:
-                min_sample_size = count * rate
-                return _SampleRates(min_sample_size, transaction_dict, 1)
-        # if we are here we are sampling at 1.0 (a bit silly but no reason to crash)
-        return _SampleRates(budget_per_transaction_type, transaction_dict, 1)
-    else:
-        # push all big transactions at the ideal sample size
-        for _ in range(max_explicit_transactions):
-            name, count = transactions.pop(-1)
-            transaction_rate = budget_per_transaction_type / count
-            transaction_dict[name] = transaction_rate
-            total_budget -= budget_per_transaction_type
-        global_rate = total_budget / total_transactions(transactions)
-        # min sample size would be for the first transaction
-        min_sample_size = transactions[0][1] * global_rate
-        return _SampleRates(min_sample_size, transaction_dict, global_rate)
-
-
-def adjust_sample_rate_min(
-    transactions: List[Tuple[str, int]], rate: float, max_explicit_transactions: int
-) -> _SampleRates:
-    transactions = copy(transactions)
-    num_transactions, num_types, total_budget, budget_per_transaction_type = _sampling_info(
-        transactions, rate
-    )
-
-    transactions_dict = {}
-    # push all small transactions at either rate=1 or ideal sample size
-    for idx in range(max_explicit_transactions):
-        name, count = transactions.pop(0)
-        if count < budget_per_transaction_type:
-            transactions_dict[name] = 1.0
-            total_budget -= count
-        else:
-            transactions_dict[name] = budget_per_transaction_type / count
-            total_budget -= budget_per_transaction_type
-        num_types = len(transactions)
-        budget_per_transaction_type = total_budget / num_types
-    # calculate rate for all other transactions
-    global_rate = total_budget / total_transactions(transactions)
-    min_sample_size = global_rate * transactions[0][1]
-    return _SampleRates(min_sample_size, transactions_dict, global_rate)
-
-
-def _sampling_info(
-    transactions: List[Tuple[str, int]], rate: float
-) -> Tuple[int, int, float, float]:
-    num_types = len(transactions)
-    num_transactions = total_transactions(transactions)
-    total_budget = num_transactions * rate
-    budget_per_transaction_type = total_budget / num_types
-    return num_transactions, num_types, total_budget, budget_per_transaction_type
-
-
-def total_transactions(transactions: List[Tuple[str, int]]) -> int:
-    ret_val = 0
+def get_total(transactions: List[Tuple[str, float]]) -> float:
+    ret_val = 0.0
     for _, v in transactions:
         ret_val += v
     return ret_val
 
 
-def get_num_sampled_transactions(
-    transactions: List[Tuple[str, int]], trans_dict: Mapping[str, float], global_rate: float
+def get_num_sampled_elements(
+    transactions: List[Tuple[str, float]], trans_dict: Mapping[str, float], global_rate: float
 ) -> float:
     num_transactions = 0.0
     for name, count in transactions:

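Reviewer note: a worked example of the budget arithmetic in adjust_sample_rate above, with hypothetical numbers (the import path is the file changed in this hunk). It exercises the branch where the implicit classes cannot absorb their share of the budget, so they are kept whole and the explicit budget is re-evaluated:

from sentry.dynamic_sampling.models.transaction_adjustment_model import adjust_sample_rate

# 4 classes in total holding 1000 samples, desired rate 0.5:
# total_budget = 500, budget_per_class = 125.
# The 2 explicit classes hold 900 samples, the 2 implicit ones hold 100.
explicit = [("t1", 100.0), ("t2", 800.0)]
rates, implicit_rate = adjust_sample_rate(explicit, 0.5, total_num_classes=4, total=1000.0)

# total_implicit (100) < implicit_budget (250), so the implicit classes are
# kept whole and the explicit budget becomes 500 - 100 = 400.
assert implicit_rate == 1

# Inside adjust_sample_rate_full the per-class budget is 400 / 2 = 200:
# t1 fits under it (100 <= 200) and keeps rate 1.0; its unused budget passes
# to t2, which gets the remaining 300 samples of budget -> rate 300 / 800.
assert rates["t1"] == 1.0
assert abs(rates["t2"] - 0.375) < 1e-9
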
+ 221 - 43
src/sentry/dynamic_sampling/prioritise_transactions.py

@@ -1,7 +1,7 @@
 import logging
 import time
 from datetime import datetime, timedelta
-from typing import Iterator, List, Optional, Tuple, TypedDict, cast
+from typing import Callable, Iterator, List, Optional, Tuple, TypedDict, cast
 
 from snuba_sdk import (
     AliasedExpression,
@@ -32,10 +32,46 @@ CHUNK_SIZE = 9998  # Snuba's limit is 10000 and we fetch CHUNK_SIZE+1
 QUERY_TIME_INTERVAL = timedelta(hours=1)
 
 
+class ProjectIdentity(TypedDict, total=True):
+    """
+    Project identity, used to match projects and also to
+    order them
+    """
+
+    project_id: int
+    org_id: int
+
+
 class ProjectTransactions(TypedDict, total=True):
+    """
+    Information about the project transactions
+    """
+
+    project_id: int
+    org_id: int
+    transaction_counts: List[Tuple[str, float]]
+    total_num_transactions: Optional[float]
+    total_num_classes: Optional[int]
+
+
+class ProjectTransactionsTotals(TypedDict, total=True):
     project_id: int
     org_id: int
-    transaction_counts: List[Tuple[str, int]]
+    total_num_transactions: float
+    total_num_classes: int
+
+
+def is_same_project(left: Optional[ProjectIdentity], right: Optional[ProjectIdentity]) -> bool:
+    if left is None or right is None:
+        return False
+
+    return left["project_id"] == right["project_id"] and left["org_id"] == right["org_id"]
+
+
+def is_project_identity_before(left: ProjectIdentity, right: ProjectIdentity) -> bool:
+    return left["org_id"] < right["org_id"] or (
+        left["org_id"] == right["org_id"] and left["project_id"] < right["project_id"]
+    )
 
 
 def get_orgs_with_project_counts(max_orgs: int, max_projects: int) -> Iterator[List[int]]:
@@ -109,6 +145,85 @@ def get_orgs_with_project_counts(max_orgs: int, max_projects: int) -> Iterator[List[int]]:
         yield [org_id for org_id, _ in last_result]
 
 
+def fetch_project_transaction_totals(org_ids: List[int]) -> Iterator[ProjectTransactionsTotals]:
+    """
+    Fetches the total number of transactions and the number of distinct transaction types for each
+    project in the given organisations
+    :param org_ids: the organisations for which totals are fetched
+    :return: an iterator of ProjectTransactionsTotals, one entry per project
+    """
+    start_time = time.time()
+    offset = 0
+    org_ids = list(org_ids)  # just to be sure it is not some other sequence
+    transaction_string_id = indexer.resolve_shared_org("transaction")
+    transaction_tag = f"tags_raw[{transaction_string_id}]"
+    metric_id = indexer.resolve_shared_org(str(TransactionMRI.COUNT_PER_ROOT_PROJECT.value))
+    more_results = True
+
+    while more_results and (time.time() - start_time) < MAX_SECONDS:
+        query = (
+            Query(
+                match=Entity(EntityKey.GenericOrgMetricsCounters.value),
+                select=[
+                    Function("sum", [Column("value")], "num_transactions"),
+                    Function("uniq", [Column(transaction_tag)], "num_classes"),
+                    Column("org_id"),
+                    Column("project_id"),
+                ],
+                groupby=[
+                    Column("org_id"),
+                    Column("project_id"),
+                ],
+                where=[
+                    Condition(Column("timestamp"), Op.GTE, datetime.utcnow() - QUERY_TIME_INTERVAL),
+                    Condition(Column("timestamp"), Op.LT, datetime.utcnow()),
+                    Condition(Column("metric_id"), Op.EQ, metric_id),
+                    Condition(Column("org_id"), Op.IN, org_ids),
+                ],
+                granularity=Granularity(3600),
+                orderby=[
+                    OrderBy(Column("org_id"), Direction.ASC),
+                    OrderBy(Column("project_id"), Direction.ASC),
+                ],
+            )
+            .set_limit(CHUNK_SIZE + 1)
+            .set_offset(offset)
+        )
+        request = Request(
+            dataset=Dataset.PerformanceMetrics.value, app_id="dynamic_sampling", query=query
+        )
+        data = raw_snql_query(
+            request,
+            referrer=Referrer.DYNAMIC_SAMPLING_COUNTERS_FETCH_PROJECTS_WITH_TRANSACTION_TOTALS.value,
+        )["data"]
+        count = len(data)
+        more_results = count > CHUNK_SIZE
+        offset += CHUNK_SIZE
+
+        if more_results:
+            data = data[:-1]
+
+        for row in data:
+            proj_id = row["project_id"]
+            org_id = row["org_id"]
+            num_transactions = row["num_transactions"]
+            num_classes = row["num_classes"]
+            yield {
+                "project_id": proj_id,
+                "org_id": org_id,
+                "total_num_transactions": num_transactions,
+                "total_num_classes": num_classes,
+            }
+
+    if more_results:
+        logger.error(
+            "fetch_project_transaction_totals timed out before fetching all project totals",
+            extra={"offset": offset},
+        )
+
+    return None
+
+
 def fetch_transactions_with_total_volumes(
     org_ids: List[int], large_transactions: bool, max_transactions: int
 ) -> Iterator[ProjectTransactions]:
@@ -122,6 +237,11 @@ def fetch_transactions_with_total_volumes(
 
     max_transactions: maximum number of transactions to return
     """
+
+    if max_transactions == 0:
+        # no transactions required from this end (probably we only need transactions from the other end)
+        return None
+
     start_time = time.time()
     offset = 0
     org_ids = list(org_ids)  # just to be sure it is not some other sequence
@@ -130,7 +250,7 @@ def fetch_transactions_with_total_volumes(
     metric_id = indexer.resolve_shared_org(str(TransactionMRI.COUNT_PER_ROOT_PROJECT.value))
     current_org_id: Optional[int] = None
     current_proj_id: Optional[int] = None
-    transaction_counts: List[Tuple[str, int]] = []
+    transaction_counts: List[Tuple[str, float]] = []
 
     if large_transactions:
         transaction_ordering = Direction.DESC
@@ -184,7 +304,6 @@ def fetch_transactions_with_total_volumes(
 
         if more_results:
             data = data[:-1]
-
         for row in data:
             proj_id = row["project_id"]
             org_id = row["org_id"]
@@ -196,6 +315,8 @@ def fetch_transactions_with_total_volumes(
                         "project_id": cast(int, current_proj_id),
                         "org_id": cast(int, current_org_id),
                         "transaction_counts": transaction_counts,
+                        "total_num_transactions": None,
+                        "total_num_classes": None,
                     }
                 transaction_counts = []
                 current_org_id = org_id
@@ -207,6 +328,8 @@ def fetch_transactions_with_total_volumes(
                     "project_id": cast(int, current_proj_id),
                     "org_id": cast(int, current_org_id),
                     "transaction_counts": transaction_counts,
+                    "total_num_transactions": None,
+                    "total_num_classes": None,
                 }
             break
     else:
@@ -219,48 +342,110 @@ def fetch_transactions_with_total_volumes(
 
 
 def merge_transactions(
-    left: ProjectTransactions, right: ProjectTransactions
+    left: ProjectTransactions,
+    right: Optional[ProjectTransactions],
+    totals: Optional[ProjectTransactionsTotals],
 ) -> ProjectTransactions:
-    if left["org_id"] != right["org_id"]:
+
+    if right is None and left is None:
+        raise ValueError(
+            "no transactions passed to merge",
+        )
+
+    if left is not None and right is not None and not is_same_project(left, right):
         raise ValueError(
-            "missmatched orgs while merging transactions", left["org_id"], right["org_id"]
+            "mismatched project transactions",
+            (left["org_id"], left["project_id"]),
+            (right["org_id"], right["project_id"]),
         )
-    if left["project_id"] != right["project_id"]:
+
+    if totals is not None and not is_same_project(left, totals):
         raise ValueError(
-            "missmatched projects while merging transactions",
-            left["project_id"],
-            right["project_id"],
+            "mismatched ProjectTransactions and ProjectTransactionsTotals",
+            (left["org_id"], left["project_id"]),
+            (totals["org_id"], totals["project_id"]),
         )
 
-    transactions = set()
-    merged_transactions = [*left["transaction_counts"]]
-    for transaction_name, _ in merged_transactions:
-        transactions.add(transaction_name)
+    if right is None:
+        merged_transactions = left["transaction_counts"]
+    else:
+        # we have both left and right we need to merge
+        transactions = set()
+        merged_transactions = [*left["transaction_counts"]]
+        for transaction_name, _ in merged_transactions:
+            transactions.add(transaction_name)
 
-    for transaction_name, count in right["transaction_counts"]:
-        if transaction_name not in transactions:
-            # not already in left, add it
-            merged_transactions.append((transaction_name, count))
+        for transaction_name, count in right["transaction_counts"]:
+            if transaction_name not in transactions:
+                # not already in left, add it
+                merged_transactions.append((transaction_name, count))
 
     return {
         "org_id": left["org_id"],
         "project_id": left["project_id"],
         "transaction_counts": merged_transactions,
+        "total_num_transactions": totals["total_num_transactions"] if totals is not None else None,
+        "total_num_classes": totals["total_num_classes"] if totals is not None else None,
     }
 
 
+def next_totals(
+    totals: Iterator[ProjectTransactionsTotals],
+) -> Callable[[ProjectIdentity], Optional[ProjectTransactionsTotals]]:
+    """
+    Returns a function that advances the totals iterator until it reaches the
+    requested project identity.
+
+    The returned function returns None if it cannot find the identity (i.e. the
+    iterator is already past it) or the iterator is exhausted; it never raises
+    StopIteration, so there is no end-of-sequence signal (DO NOT treat it as an
+    iterator). If the identity is found, it returns the matching totals entry.
+
+    """
+    current: List[Optional[ProjectTransactionsTotals]] = [None]
+    # protection for the case when the caller passes a list instead of an iterator
+    totals = iter(totals)
+
+    def inner(match: ProjectIdentity) -> Optional[ProjectTransactionsTotals]:
+        if is_same_project(current[0], match):
+            temp = current[0]
+            current[0] = None
+            return temp
+
+        if current[0] is not None and is_project_identity_before(match, current[0]):
+            # still haven't reached current, no point looking further
+            return None
+
+        for total in totals:
+            if is_same_project(total, match):
+                # found it
+                return total
+
+            if is_project_identity_before(match, total):
+                # we are past the match; remember where we are, no need to go further
+                current[0] = total
+                return None
+        return None
+
+    return inner
+
+
 def transactions_zip(
+    totals: Iterator[ProjectTransactionsTotals],
     left: Iterator[ProjectTransactions],
     right: Iterator[ProjectTransactions],
 ) -> Iterator[ProjectTransactions]:
     """
     Returns a generator that zips left and right when they match, re-aligning the two sequences when they don't.
+
+    if a matching totals entry is found, the result is consolidated with the totals information as well
     """
 
     more_right = True
     more_left = True
     left_elm = None
     right_elm = None
+
+    get_next_total = next_totals(totals)
+
     while more_left or more_right:
         if more_right and right_elm is None:
             try:
@@ -275,35 +460,28 @@ def transactions_zip(
                 more_left = False
                 left_elm = None
 
+        if left_elm is None and right_elm is None:
+            return
+
         if right_elm is not None and left_elm is not None:
             # we have both right and left try to merge them if they point to the same entity
-            if (
-                left_elm["org_id"] == right_elm["org_id"]
-                and left_elm["project_id"] == right_elm["project_id"]
-            ):
-                yield merge_transactions(left_elm, right_elm)
+            if is_same_project(left_elm, right_elm):
+                yield merge_transactions(left_elm, right_elm, get_next_total(left_elm))
+                left_elm = None
+                right_elm = None
+            elif is_project_identity_before(left_elm, right_elm):
+                # left is before right (return left, keep right for next iteration)
+                yield merge_transactions(left_elm, None, get_next_total(left_elm))
                 left_elm = None
+            else:  # is_project_identity_before(right_elm, left_elm)
+                # right is before left (return right, keep left for next iteration)
+                yield merge_transactions(right_elm, None, get_next_total(right_elm))
                 right_elm = None
-            else:
-                # the two elements do not match see which one is "smaller" and return it, keep the other
-                # for the next iteration
-                if left_elm["org_id"] < right_elm["org_id"]:
-                    yield left_elm
-                    left_elm = None
-                elif left_elm["org_id"] > right_elm["org_id"]:
-                    yield right_elm
-                    right_elm = None
-                # orgs are the sam try projects
-                elif left_elm["project_id"] < right_elm["project_id"]:
-                    yield left_elm
-                    left_elm = None
-                else:  # right_elm["project_id"] > left_elm["project_id"]
-                    yield right_elm
-                    right_elm = None
         else:
+            # only one is not None
             if left_elm is not None:
-                yield left_elm
+                yield merge_transactions(left_elm, None, get_next_total(left_elm))
                 left_elm = None
-            if right_elm is not None:
-                yield right_elm
+            elif right_elm is not None:
+                yield merge_transactions(right_elm, None, get_next_total(right_elm))
                 right_elm = None

+ 20 - 9
src/sentry/dynamic_sampling/tasks.py

@@ -1,13 +1,14 @@
 import logging
 from typing import Optional, Sequence, Tuple
 
-from sentry import features, quotas
+from sentry import features, options, quotas
 from sentry.dynamic_sampling.models.adjustment_models import AdjustedModel
 from sentry.dynamic_sampling.models.transaction_adjustment_model import adjust_sample_rate
 from sentry.dynamic_sampling.models.utils import DSElement
 from sentry.dynamic_sampling.prioritise_projects import fetch_projects_with_total_volumes
 from sentry.dynamic_sampling.prioritise_transactions import (
     ProjectTransactions,
+    fetch_project_transaction_totals,
     fetch_transactions_with_total_volumes,
     get_orgs_with_project_counts,
     transactions_zip,
@@ -24,7 +25,6 @@ from sentry.utils import metrics
 
 CACHE_KEY_TTL = 24 * 60 * 60 * 1000  # in milliseconds
 
-# TODO RaduW validate assumptions
 MAX_ORGS_PER_QUERY = 100
 MAX_PROJECTS_PER_QUERY = 5000
 MAX_TRANSACTIONS_PER_PROJECT = 20
@@ -130,20 +130,28 @@ def prioritise_transactions() -> None:
     metrics.incr("sentry.tasks.dynamic_sampling.prioritise_transactions.start", sample_rate=1.0)
     current_org: Optional[Organization] = None
     current_org_enabled = False
+
+    num_big_trans = int(
+        options.get("dynamic-sampling.prioritise_transactions.num_explicit_large_transactions")
+    )
+    num_small_trans = int(
+        options.get("dynamic-sampling.prioritise_transactions.num_explicit_small_transactions")
+    )
+
     with metrics.timer("sentry.tasks.dynamic_sampling.prioritise_transactions", sample_rate=1.0):
         for orgs in get_orgs_with_project_counts(MAX_ORGS_PER_QUERY, MAX_PROJECTS_PER_QUERY):
             # get the low and high transactions
-            # TODO can we do this in one query rather than two
             for project_transactions in transactions_zip(
+                fetch_project_transaction_totals(orgs),
                 fetch_transactions_with_total_volumes(
                     orgs,
                     large_transactions=True,
-                    max_transactions=MAX_TRANSACTIONS_PER_PROJECT // 2,
+                    max_transactions=num_big_trans,
                 ),
                 fetch_transactions_with_total_volumes(
                     orgs,
                     large_transactions=False,
-                    max_transactions=MAX_TRANSACTIONS_PER_PROJECT // 2,
+                    max_transactions=num_small_trans,
                 ),
             ):
 
@@ -175,6 +183,8 @@ def process_transaction_biases(project_transactions: ProjectTransactions) -> Non
     org_id = project_transactions["org_id"]
     project_id = project_transactions["project_id"]
     transactions = project_transactions["transaction_counts"]
+    total_num_transactions = project_transactions.get("total_num_transactions")
+    total_num_classes = project_transactions.get("total_num_classes")
     project = Project.objects.get_from_cache(id=project_id)
     sample_rate = quotas.get_blended_sample_rate(project)
 
@@ -182,17 +192,18 @@ def process_transaction_biases(project_transactions: ProjectTransactions) -> Non
         # no sampling => no rebalancing
         return
 
-    named_rates, global_rate = adjust_sample_rate(
-        transactions=transactions,
+    named_rates, implicit_rate = adjust_sample_rate(
+        classes=transactions,
         rate=sample_rate,
-        max_explicit_transactions=MAX_TRANSACTIONS_PER_PROJECT,
+        total_num_classes=total_num_classes,
+        total=total_num_transactions,
     )
 
     set_transactions_resampling_rates(
         org_id=org_id,
         proj_id=project_id,
         named_rates=named_rates,
-        default_rate=global_rate,
+        default_rate=implicit_rate,
         ttl_ms=CACHE_KEY_TTL,
     )
 
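Reviewer note: condensed, the new control flow of the prioritise_transactions task reads roughly as below (a sketch only; the feature-flag and org-boundary handling from the surrounding code is elided, and every name appears in this diff):

num_big_trans = int(options.get("dynamic-sampling.prioritise_transactions.num_explicit_large_transactions"))
num_small_trans = int(options.get("dynamic-sampling.prioritise_transactions.num_explicit_small_transactions"))

for orgs in get_orgs_with_project_counts(MAX_ORGS_PER_QUERY, MAX_PROJECTS_PER_QUERY):
    for project_transactions in transactions_zip(
        fetch_project_transaction_totals(orgs),  # per-project totals (one extra query)
        fetch_transactions_with_total_volumes(orgs, large_transactions=True, max_transactions=num_big_trans),
        fetch_transactions_with_total_volumes(orgs, large_transactions=False, max_transactions=num_small_trans),
    ):
        # each element carries the explicit counts plus, when found, the totals,
        # which process_transaction_biases forwards to adjust_sample_rate
        process_transaction_biases(project_transactions)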

+ 4 - 1
src/sentry/options/defaults.py

@@ -668,7 +668,10 @@ register("dynamic-sampling.prioritise_projects.sample_rate", default=0.0)
 # controls how many orgs will be queried by the prioritise by transaction task
 # 0-> no orgs , 0.5 -> half of the orgs, 1.0 -> all orgs
 register("dynamic-sampling.prioritise_transactions.load_rate", default=0.0)
-
+# the number of large transactions to retrieve from Snuba for transaction re-balancing
+register("dynamic-sampling.prioritise_transactions.num_explicit_large_transactions", 30)
+# the number of small transactions to retrieve from Snuba for transaction re-balancing
+register("dynamic-sampling.prioritise_transactions.num_explicit_small_transactions", 0)
 # Killswitch for deriving code mappings
 register("post_process.derive-code-mappings", default=True)
 # Allows adjusting the GA percentage

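Reviewer note: with these defaults, 30 large and 0 small transactions are rebalanced explicitly; the split can be tuned at runtime through Sentry's options store, e.g. (hypothetical values):

from sentry import options

options.set("dynamic-sampling.prioritise_transactions.num_explicit_large_transactions", 20)
options.set("dynamic-sampling.prioritise_transactions.num_explicit_small_transactions", 10)
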
+ 3 - 0
src/sentry/snuba/referrer.py

@@ -358,6 +358,9 @@ class ReferrerBase(Enum):
     DYNAMIC_SAMPLING_COUNTERS_FETCH_PROJECTS_WITH_COUNT_PER_TRANSACTION = (
         "dynamic_sampling.counters.fetch_projects_with_count_per_transaction_volumes"
     )
+    DYNAMIC_SAMPLING_COUNTERS_FETCH_PROJECTS_WITH_TRANSACTION_TOTALS = (
+        "dynamic_sampling.counters.fetch_projects_with_transaction_totals"
+    )
     DYNAMIC_SAMPLING_COUNTERS_FETCH_ACTIVE_ORGS = "dynamic_sampling.counters.fetch_active_orgs"
     EVENTSTORE_GET_EVENT_BY_ID_NODESTORE = "eventstore.get_event_by_id_nodestore"
     EVENTSTORE_GET_EVENTS = "eventstore.get_events"

+ 55 - 167
tests/sentry/dynamic_sampling/models/test_transaction_adjustment_model.py

@@ -2,12 +2,12 @@ import pytest
 
 from sentry.dynamic_sampling.models.transaction_adjustment_model import (
     adjust_sample_rate,
-    get_num_sampled_transactions,
-    total_transactions,
+    get_num_sampled_elements,
+    get_total,
 )
 
 
-def transactions(big: int, med: int, small: int):
+def create_transaction_counts(big: int, med: int, small: int):
     big_t = [(f"tb{i}", 1000 + i) for i in range(big)]
     med_t = [(f"tm{i}", 100 + i) for i in range(med)]
     small_t = [(f"ts{i}", 1 + i) for i in range(small)]
@@ -15,185 +15,73 @@ def transactions(big: int, med: int, small: int):
 
 
 test_resample_cases = [
-    transactions(big=3, med=0, small=2),
-    transactions(big=6, med=0, small=2),
-    transactions(big=3, med=0, small=0),
-    transactions(big=3, med=100, small=3),
-    transactions(big=3, med=100, small=30),
-    transactions(big=30, med=10, small=30),
+    create_transaction_counts(big=3, med=4, small=2),
+    create_transaction_counts(big=6, med=0, small=2),
+    create_transaction_counts(big=3, med=0, small=4),
+    create_transaction_counts(big=3, med=100, small=3),
+    create_transaction_counts(big=3, med=100, small=30),
+    create_transaction_counts(big=30, med=10, small=30),
+    create_transaction_counts(big=30, med=3, small=5),
 ]
-
-
-@pytest.mark.parametrize("sample_rate", [0.1, 0.5, 0.9, 1.0])
-@pytest.mark.parametrize("transactions", test_resample_cases)
-def test_maintains_overall_sample_rate(sample_rate, transactions):
-
-    MAX_EXPLICIT_TRANSACTIONS = 3
-    trans, global_rate = adjust_sample_rate(transactions, sample_rate, MAX_EXPLICIT_TRANSACTIONS)
-
-    assert len(trans) == MAX_EXPLICIT_TRANSACTIONS
-
-    # make sure we maintain the required sample rate
-    old_sampled_transactions = get_num_sampled_transactions(transactions, {}, sample_rate)
-    new_sampled_transactions = get_num_sampled_transactions(transactions, trans, global_rate)
-
-    assert old_sampled_transactions == pytest.approx(new_sampled_transactions)
-
-
-small_transactions = [
-    [
-        ("tb1", 200),
-        ("tb2", 300),
-        ("tb3", 400),
-        ("tb4", 400),
-        ("tb5", 400),
-        ("tb6", 400),
-        ("ts1", 7),
-        ("ts2", 4),
-        ("ts3", 7),
-    ],
+sample_rates = [0.01, 0.1, 0.5, 0.9, 0.99, 1.0]
+excluded_transactions = [
+    (0, None),  # full resample
+    (3, None),  # exclude first 3
+    (0, -3),  # exclude last 3
+    (3, -3),  # take 3 from both ends
 ]
 
 
-@pytest.mark.parametrize("sample_rate", [0.1, 0.5])
-@pytest.mark.parametrize("transactions", small_transactions)
-def test_few_small_transactions(sample_rate, transactions):
+@pytest.mark.parametrize("sample_rate", sample_rates)
+@pytest.mark.parametrize("transactions", test_resample_cases)
+@pytest.mark.parametrize("idx_low,idx_high", excluded_transactions)
+def test_maintains_overall_sample_rate(sample_rate, transactions, idx_low, idx_high):
     """
-    Test that when we only have a few small transactions, they are sampled at
-    their ideal size and the rest is globally adjusted
+    Tests that the overall sampling rate is maintained after applying new rates
     """
-    MAX_EXPLICIT_TRANSACTIONS = 3
-    explicit_transactions, global_rate = adjust_sample_rate(
-        transactions, sample_rate, MAX_EXPLICIT_TRANSACTIONS
-    )
+    explicit_transactions = transactions[idx_low:idx_high]
+    total = get_total(transactions)
+    total_classes = len(transactions)
 
-    assert len(explicit_transactions) == MAX_EXPLICIT_TRANSACTIONS
-    # make sure we maintain the required sample rate
-    old_sampled_transactions = get_num_sampled_transactions(transactions, {}, sample_rate)
-    new_sampled_transactions = get_num_sampled_transactions(
-        transactions, explicit_transactions, global_rate
+    trans, global_rate = adjust_sample_rate(
+        explicit_transactions, sample_rate, total_num_classes=total_classes, total=total
     )
-    assert old_sampled_transactions == pytest.approx(new_sampled_transactions)
-
-    # check that we selected the small transactions and that they are sample at the ideal rate
-    transactions = sorted(transactions, key=lambda x: x[1])
-    smallest = transactions[:MAX_EXPLICIT_TRANSACTIONS]
-
-    num_transactions = total_transactions(transactions)
-    total_budget = num_transactions * sample_rate
-    num_transactions_rate_1 = 0
-    # the budget per transaction is calculated excluding the transactions that are sampled at 1
-    for name, count in smallest:
-        assert name in explicit_transactions
-        rate = explicit_transactions[name]
-        if rate == 1:
-            total_budget -= count
-            num_transactions_rate_1 += 1
-
-    budget_per_transaction = total_budget / (len(transactions) - num_transactions_rate_1)
-    # now check that all explicit transactions that are not sampled at 1 are sampled
-    # at their optimal rate
-    for name, count in smallest:
-        rate = explicit_transactions[name]
-        if rate != 1:
-            assert rate * count == pytest.approx(budget_per_transaction)
-
-
-big_transactions = [
-    [
-        ("tb1", 1000),
-        ("tb2", 3000),
-        ("tb3", 4000),
-        ("tm1", 40),
-        ("tm2", 50),
-        ("ts1", 7),
-        ("ts2", 4),
-        ("ts3", 7),
-    ],
-]
 
-
-@pytest.mark.parametrize("sample_rate", [0.1, 0.5])
-@pytest.mark.parametrize("transactions", big_transactions)
-def test_few_big_transactions(sample_rate, transactions):
-    """
-    Test that when we only have a few bit transactions,
-    they are sampled at their ideal size and the
-    rest is globally adjusted
-    """
-    MAX_EXPLICIT_TRANSACTIONS = 3
-    explicit_transactions, global_rate = adjust_sample_rate(
-        transactions, sample_rate, MAX_EXPLICIT_TRANSACTIONS
-    )
-
-    assert len(explicit_transactions) == MAX_EXPLICIT_TRANSACTIONS
     # make sure we maintain the required sample rate
-    old_sampled_transactions = get_num_sampled_transactions(transactions, {}, sample_rate)
-    new_sampled_transactions = get_num_sampled_transactions(
-        transactions, explicit_transactions, global_rate
-    )
-    assert old_sampled_transactions == pytest.approx(new_sampled_transactions)
-
-    # check that we selected the small transactions and that they are sample at the ideal rate
-    transactions = sorted(transactions, key=lambda x: x[1])
-    largest = transactions[-MAX_EXPLICIT_TRANSACTIONS:]
+    old_sampled_transactions = get_num_sampled_elements(transactions, {}, sample_rate)
+    new_sampled_transactions = get_num_sampled_elements(transactions, trans, global_rate)
 
-    for name, count in largest:
-        assert name in explicit_transactions
-
-
-full_resample_transactions = [
-    [("tb1", 1000), ("tm1", 100), ("tm2", 200), ("ts1", 5)],
-    [("tb1", 1000), ("tb2", 2000), ("ts1", 2), ("ts2", 5)],
-    [("tb1", 1000), ("tb2", 1100), ("tb3", 2000), ("tb4", 5000)],
-]
+    assert new_sampled_transactions == pytest.approx(old_sampled_transactions)
 
 
-@pytest.mark.parametrize("sample_rate", [0.1, 0.5, 0.9, 1.0])
-@pytest.mark.parametrize("transactions", full_resample_transactions)
-def test_full_resample(sample_rate, transactions):
+@pytest.mark.parametrize("sample_rate", sample_rates)
+@pytest.mark.parametrize("transactions", test_resample_cases)
+@pytest.mark.parametrize("idx_low,idx_high", excluded_transactions)
+def test_explicit_elements_ideal_rate(sample_rate, transactions, idx_low, idx_high):
     """
-    Test that when we can adjust all transactions we adjust them
-    either at rate=1.0 or at a sampling rate that generates
-    an ideal number of samples.
+    Tests that the explicitly specified elements are sampled at their ideal rate.
+
+    Ideal sample rate means that the resulting number of sampled elements is the minimum between:
+    * all transactions in the class (sampled at rate 1.0)
+    * the budget per transaction
     """
+    explicit_transactions = transactions[idx_low:idx_high]
+    total = get_total(transactions)
+    total_classes = len(transactions)
 
-    num_explicit_transaction_types = len(transactions)
-    explicit_transactions, global_rate = adjust_sample_rate(
-        transactions, sample_rate, num_explicit_transaction_types
+    trans, global_rate = adjust_sample_rate(
+        explicit_transactions, sample_rate, total_num_classes=total_classes, total=total
     )
 
-    assert len(explicit_transactions) == num_explicit_transaction_types
-    # make sure we maintain the required sample rate
-    old_sampled_transactions = get_num_sampled_transactions(transactions, {}, sample_rate)
-    new_sampled_transactions = get_num_sampled_transactions(
-        transactions, explicit_transactions, global_rate
-    )
-    assert old_sampled_transactions == pytest.approx(new_sampled_transactions)
-
-    # check that transactions are either sampled at 1 or their ideal rate
-    num_transaction_types_fully_sampled = 0
-    full_budget = total_transactions(transactions) * sample_rate
-    for name, count in transactions:
-        # everything should be explicitly specified
-        assert name in explicit_transactions
-        rate = explicit_transactions[name]
-        if rate == 1.0:
-            num_transaction_types_fully_sampled += 1
-            full_budget -= count
-
-    if num_transaction_types_fully_sampled == num_explicit_transaction_types:
-        # all transactions sampled at 1.0 this means we must have specified the
-        # overall sample rate to be 1.0
-        assert sample_rate == 1.0
-    else:
-        budget_per_transaction_type = full_budget / (
-            num_explicit_transaction_types - num_transaction_types_fully_sampled
-        )
-        # everything that is not fully sampled should be at budget_per_transaction_type
-        for name, count in transactions:
-            # everything should be explicitly specified
-            assert name in explicit_transactions
-            rate = explicit_transactions[name]
-            if rate != 1.0:
-                assert rate * count == pytest.approx(budget_per_transaction_type)
+    ideal_number_of_elements_per_class = total * sample_rate / total_classes
+
+    for name, count in explicit_transactions:
+        actual_rate = trans[name]
+
+        if ideal_number_of_elements_per_class > count:
+            assert actual_rate == 1.0  # small classes are kept whole, not downsampled
+        else:
+            assert (
+                actual_rate * count == pytest.approx(ideal_number_of_elements_per_class)
+                or actual_rate * count >= ideal_number_of_elements_per_class
+            )

+ 198 - 12
tests/sentry/dynamic_sampling/test_prioritise_transactions.py

@@ -3,9 +3,15 @@ from datetime import datetime, timezone
 from freezegun import freeze_time
 
 from sentry.dynamic_sampling.prioritise_transactions import (
+    ProjectIdentity,
+    ProjectTransactionsTotals,
+    fetch_project_transaction_totals,
     fetch_transactions_with_total_volumes,
     get_orgs_with_project_counts,
+    is_project_identity_before,
+    is_same_project,
     merge_transactions,
+    next_totals,
     transactions_zip,
 )
 from sentry.snuba.metrics import TransactionMRI
@@ -60,6 +66,12 @@ class PrioritiseProjectsSnubaQueryTest(BaseMetricsLayerTestCase, TestCase, Snuba
         }
         return idx + counts[name]
 
+    def get_total_counts_for_project(self, idx: int):
+        """
+        Get the total number of transactions and the number of transaction classes for a proj_idx
+        """
+        return 1 + 100 + 1000 + 2000 + 3000 + idx * 5, 5
+
     def test_get_orgs_with_transactions_respects_max_orgs(self):
         with self.options({"dynamic-sampling.prioritise_transactions.load_rate": 1.0}):
             actual = list(get_orgs_with_project_counts(2, 20))
@@ -111,17 +123,96 @@ class PrioritiseProjectsSnubaQueryTest(BaseMetricsLayerTestCase, TestCase, Snuba
                     assert name in expected_names
                     assert count == self.get_count_for_transaction(idx, name)
 
+    def test_fetch_project_transaction_totals(self):
+        """
+        Create transactions in some orgs and projects and verify that the
+        total counts and the number of transaction classes per project are
+        correctly returned
+        """
+
+        orgs = self.org_ids
+
+        for idx, totals in enumerate(fetch_project_transaction_totals(orgs)):
+            total_counts, num_classes = self.get_total_counts_for_project(idx)
+            assert totals["total_num_transactions"] == total_counts
+            assert totals["total_num_classes"] == num_classes
+
+
+def test_merge_transactions_full():
+    t1 = {
+        "project_id": 1,
+        "org_id": 2,
+        "transaction_counts": [("ts1", 10), ("tm2", 100)],
+        "total_num_transactions": None,
+        "total_num_classes": None,
+    }
+    t2 = {
+        "project_id": 1,
+        "org_id": 2,
+        "transaction_counts": [("tm2", 100), ("tl3", 1000)],
+        "total_num_transactions": None,
+        "total_num_classes": None,
+    }
+    counts = {"project_id": 1, "org_id": 2, "total_num_transactions": 5555, "total_num_classes": 20}
+    actual = merge_transactions(t1, t2, counts)
+
+    expected = {
+        "project_id": 1,
+        "org_id": 2,
+        "transaction_counts": [("ts1", 10), ("tm2", 100), ("tl3", 1000)],
+        "total_num_transactions": 5555,
+        "total_num_classes": 20,
+    }
 
-def test_merge_transactions():
-    t1 = {"project_id": 1, "org_id": 2, "transaction_counts": [("ts1", 10), ("tm2", 100)]}
-    t2 = {"project_id": 1, "org_id": 2, "transaction_counts": [("tm2", 100), ("tl3", 1000)]}
+    assert actual == expected
+
+
+def test_merge_transactions_missing_totals():
+    t1 = {
+        "project_id": 1,
+        "org_id": 2,
+        "transaction_counts": [("ts1", 10), ("tm2", 100)],
+        "total_num_transactions": None,
+        "total_num_classes": None,
+    }
+    t2 = {
+        "project_id": 1,
+        "org_id": 2,
+        "transaction_counts": [("tm2", 100), ("tl3", 1000)],
+        "total_num_transactions": None,
+        "total_num_classes": None,
+    }
 
-    actual = merge_transactions(t1, t2)
+    actual = merge_transactions(t1, t2, None)
 
     expected = {
         "project_id": 1,
         "org_id": 2,
         "transaction_counts": [("ts1", 10), ("tm2", 100), ("tl3", 1000)],
+        "total_num_transactions": None,
+        "total_num_classes": None,
+    }
+
+    assert actual == expected
+
+
+def test_merge_transactions_missing_right():
+    t1 = {
+        "project_id": 1,
+        "org_id": 2,
+        "transaction_counts": [("ts1", 10), ("tm2", 100)],
+        "total_num_transactions": None,
+        "total_num_classes": None,
+    }
+    counts = {"project_id": 1, "org_id": 2, "total_num_transactions": 5555, "total_num_classes": 20}
+    actual = merge_transactions(t1, None, counts)
+
+    expected = {
+        "project_id": 1,
+        "org_id": 2,
+        "transaction_counts": [("ts1", 10), ("tm2", 100)],
+        "total_num_transactions": 5555,
+        "total_num_classes": 20,
     }
 
     assert actual == expected
@@ -132,33 +223,128 @@ def test_transactions_zip():
     low = 2
     both = 3
 
-    def pt(org_id: int, proj_id: int, what: int):
+    def pt(org_id: int, proj_id: int, what: int, add_totals: bool = False):
         if what == high:
             transaction_counts = [("tm2", 100), ("tl3", 1000)]
         elif what == low:
             transaction_counts = [("ts1", 10), ("tm2", 100)]
-        else:
+        else:  # what == both
             transaction_counts = [("ts1", 10), ("tm2", 100), ("tl3", 1000)]
         return {
             "project_id": proj_id,
             "org_id": org_id,
-            "transaction_counts": transaction_counts,  # not relevant in zipping
+            "transaction_counts": transaction_counts,
+            "total_num_transactions": 5000 if add_totals else None,
+            "total_num_classes": 5 if add_totals else None,
+        }
+
+    def tot(org_id, proj_id):
+        return {
+            "project_id": proj_id,
+            "org_id": org_id,
+            "total_num_transactions": 5000,
+            "total_num_classes": 5,
         }
 
     trans_low = [pt(1, 1, low), pt(1, 2, low), pt(2, 1, low), pt(2, 3, low), pt(3, 2, low)]
     trans_high = [pt(2, 1, high), (pt(2, 2, high)), pt(3, 1, high), pt(3, 2, high), pt(3, 3, high)]
+    totals = [tot(1, 0), tot(1, 2), tot(1, 3), tot(2, 1), tot(2, 4), tot(3, 1), tot(3, 3)]
 
     expected = [
         pt(1, 1, low),
-        pt(1, 2, low),
-        pt(2, 1, both),
+        pt(1, 2, low, True),
+        pt(2, 1, both, True),
         pt(2, 2, high),
         pt(2, 3, low),
-        pt(3, 1, high),
+        pt(3, 1, high, True),
         pt(3, 2, both),
-        pt(3, 3, high),
+        pt(3, 3, high, True),
     ]
 
-    actual = list(transactions_zip((x for x in trans_low), (x for x in trans_high)))
+    actual = list(
+        transactions_zip((x for x in totals), (x for x in trans_low), (x for x in trans_high))
+    )
 
     assert actual == expected
+
+
+def test_same_project():
+    p1 = {"project_id": 1, "org_id": 2}
+    p1bis = {"project_id": 1, "org_id": 2}
+    p2 = {"project_id": 1, "org_id": 3}
+    p3 = {"project_id": 2, "org_id": 1}
+    p4 = {"project_id": 3, "org_id": 4}
+
+    assert is_same_project(p1, p1bis)
+    assert not is_same_project(p1, p2)
+    assert not is_same_project(p1, p3)
+    assert not is_same_project(p1, p4)
+
+
+def test_project_before():
+    p1 = {"project_id": 1, "org_id": 2}
+    p1bis = {"project_id": 1, "org_id": 2}
+    p2 = {"project_id": 1, "org_id": 3}
+    p3 = {"project_id": 2, "org_id": 2}
+    p4 = {"project_id": 2, "org_id": 1}
+
+    # same project
+    assert not is_project_identity_before(p1, p1bis)
+    assert not is_project_identity_before(p1bis, p1)
+
+    # different project_id
+    assert is_project_identity_before(p1, p2)
+    assert not is_project_identity_before(p2, p1)
+
+    # different org_id
+    assert is_project_identity_before(p1, p3)
+    assert not is_project_identity_before(p3, p1)
+
+    # just different
+    assert is_project_identity_before(p4, p1)
+    assert not is_project_identity_before(p1, p4)
+
+
+def test_next_totals():
+    def ct(org_id: int, project_id: int) -> ProjectTransactionsTotals:
+        return {
+            "project_id": project_id,
+            "org_id": org_id,
+            "total_num_transactions": 123,
+            "total_num_classes": 5,
+        }
+
+    def pi(org_id: int, project_id: int) -> ProjectIdentity:
+        return {
+            "project_id": project_id,
+            "org_id": org_id,
+        }
+
+    my_totals = iter([ct(1, 2), ct(1, 4), ct(1, 5), ct(1, 6), ct(1, 9), ct(2, 1)])
+
+    get_totals = next_totals(my_totals)
+
+    # current should be 1,2
+    # ask for something before 1,2
+    assert get_totals(pi(0, 1)) is None
+    assert get_totals(pi(0, 2)) is None
+    assert get_totals(pi(1, 1)) is None
+
+    # ask for 1,2
+    assert get_totals(pi(1, 2)) == ct(1, 2)
+    # ask again
+    assert get_totals(pi(1, 2)) is None
+    # jump a few totals
+    assert get_totals(pi(1, 6)) == ct(1, 6)
+    # make sure we don't go back
+    assert get_totals(pi(1, 5)) is None
+    # forcing it to go forward jumps just enough
+    assert get_totals(pi(1, 10)) is None
+    # but not too much
+    assert get_totals(pi(1, 11)) is None
+    assert get_totals(pi(1, 12)) is None
+    assert get_totals(pi(2, 1)) == ct(2, 1)
+    # and from now on we return None
+    assert get_totals(pi(3, 1)) is None
+    assert get_totals(pi(3, 2)) is None
+    assert get_totals(pi(3, 3)) is None

+ 43 - 0
tests/sentry/dynamic_sampling/test_tasks.py

@@ -152,3 +152,46 @@ class TestPrioritiseTransactionsTask(BaseMetricsLayerTestCase, TestCase, SnubaTe
                     assert (
                         transaction_name in tran_rate
                     )  # check we have some rate calculated for each transaction
+
+    @patch("sentry.dynamic_sampling.rules.base.quotas.get_blended_sample_rate")
+    def test_prioritise_transactions_partial(self, get_blended_sample_rate):
+        """
+        Test that the V2 algorithm is used: only the explicitly specified transactions
+        are rebalanced individually and the rest get a global (implicit) rate
+
+        Create orgs, projects & transactions and then check that the task creates
+        rebalancing data in Redis
+        """
+        BLENDED_RATE = 0.25
+        get_blended_sample_rate.return_value = BLENDED_RATE
+
+        with self.options(
+            {
+                "dynamic-sampling.prioritise_transactions.load_rate": 1.0,
+                "dynamic-sampling.prioritise_transactions.num_explicit_large_transactions": 1,
+                "dynamic-sampling.prioritise_transactions.num_explicit_small_transactions": 1,
+            }
+        ):
+            with self.feature({"organizations:ds-prioritise-by-transaction-bias": True}):
+                with self.tasks():
+                    prioritise_transactions()
+
+        # now redis should contain rebalancing data for our projects
+        for org in self.orgs_info:
+            org_id = org["org_id"]
+            for proj_id in org["project_ids"]:
+                tran_rate, implicit_rate = get_transactions_resampling_rates(
+                    org_id=org_id, proj_id=proj_id, default_rate=0.1
+                )
+                # explicit transactions
+                for transaction_name in ["ts1", "tl5"]:
+                    assert (
+                        transaction_name in tran_rate
+                    )  # check we have some rate calculated for each transaction
+                # implicit transactions
+                for transaction_name in ["ts2", "tm3", "tl4"]:
+                    assert (
+                        transaction_name not in tran_rate
+                    )  # check implicit transactions do not get explicit rates
+                # the implicit transactions get a rate different from the blended rate
+                assert implicit_rate != BLENDED_RATE