ref(dynamic-sampling) Switch recalibrate-orgs to Timed iterator and cleanup (#53059)

Radu Woinaroski · 1 year ago · commit de331c216e

+ 215 - 4
src/sentry/dynamic_sampling/tasks/common.py

@@ -1,6 +1,8 @@
 import math
 import time
+from dataclasses import dataclass
 from datetime import datetime, timedelta
+from functools import wraps
 from typing import Any, Generator, Iterator, List, Mapping, Optional, Protocol, Tuple

 import sentry_sdk
@@ -24,6 +26,7 @@ from sentry.dynamic_sampling.tasks.constants import (
     MAX_ORGS_PER_QUERY,
     MAX_PROJECTS_PER_QUERY,
     MAX_SECONDS,
+    RECALIBRATE_ORGS_QUERY_INTERVAL,
 )
 from sentry.dynamic_sampling.tasks.helpers.sliding_window import (
     extrapolate_monthly_volume,
@@ -39,6 +42,10 @@ from sentry.snuba.metrics.naming_layer.mri import TransactionMRI
 from sentry.snuba.referrer import Referrer
 from sentry.utils.snuba import raw_snql_query

+ACTIVE_ORGS_DEFAULT_TIME_INTERVAL = timedelta(hours=1)
+ACTIVE_ORGS_DEFAULT_GRANULARITY = Granularity(3600)
+ACTIVE_ORGS_VOLUMES_DEFAULT_GRANULARITY = Granularity(60)
+

 class TimeoutException(Exception):
     def __init__(self, task_context: TaskContext, *args):
@@ -48,6 +55,46 @@ class TimeoutException(Exception):
         self.task_context = task_context


+class LogStateCallable(Protocol):
+    """
+    A function that gets a DynamicSamplingLogState argument as the first parameter
+
+    This protocol is used by the timed_function decorator to convert a function that
+    can update its state into one that accepts a TaskContext and a Timer, followed by
+    the rest of the wrapped function's parameters, and that keeps the function's
+    DynamicSamplingLogState in the context up to date as it executes.
+    """
+
+    def __call__(self, state: DynamicSamplingLogState, *args, **kwargs) -> Any:
+        ...
+
+    __name__: str
+
+
+def timed_function(name=None):
+    def timed_function_decorator(inner: LogStateCallable):
+        if name is not None:
+            func_name = name
+        else:
+            func_name = inner.__name__
+
+        @wraps(inner)
+        def wrapped(context: TaskContext, timer: Timer, *args, **kwargs):
+            if time.monotonic() > context.expiration_time:
+                raise TimeoutException(context)
+            with timer:
+                state = context.get_function_state(func_name)
+                val = inner(state, *args, **kwargs)
+                state.execution_time = timer.current()
+                context.set_function_state(func_name, state)
+                return val
+
+        return wrapped
+
+    return timed_function_decorator
+
+
 class ContextIterator(Protocol):
     """
     An iterator that can also return its current state (used for logging)
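
The timed_function decorator above converts a function taking (state, *args) into one invoked as (context, timer, *args): it raises TimeoutException once the context's deadline has passed, times the call, and writes the updated DynamicSamplingLogState back into the TaskContext. A minimal usage sketch (the names come from this diff, but the wrapped function and task name are made up):

    from sentry.dynamic_sampling.tasks.common import timed_function
    from sentry.dynamic_sampling.tasks.task_context import DynamicSamplingLogState, TaskContext
    from sentry.dynamic_sampling.tasks.utils import Timer

    @timed_function()
    def count_batch(state: DynamicSamplingLogState, batch: list) -> int:
        # the wrapped function only sees its own log state plus its real arguments
        state.num_iterations += 1
        return len(batch)

    context = TaskContext(name="my-task", num_seconds=60.0)  # made-up task name
    timer = Timer()
    # callers pass the context and timer first; the decorator checks the deadline,
    # times the call, and stores the state under the function's name
    assert count_batch(context, timer, [1, 2, 3]) == 3
    assert context.get_function_state("count_batch").num_iterations == 1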
@@ -142,7 +189,14 @@ class GetActiveOrgs:
     are no more orgs
     """

-    def __init__(self, max_orgs: int = MAX_ORGS_PER_QUERY, max_projects: Optional[int] = None):
+    def __init__(
+        self,
+        max_orgs: int = MAX_ORGS_PER_QUERY,
+        max_projects: Optional[int] = None,
+        time_interval: timedelta = ACTIVE_ORGS_DEFAULT_TIME_INTERVAL,
+        granularity: Granularity = ACTIVE_ORGS_DEFAULT_GRANULARITY,
+    ):
+
         self.metric_id = indexer.resolve_shared_org(
             str(TransactionMRI.COUNT_PER_ROOT_PROJECT.value)
         )
@@ -152,11 +206,13 @@ class GetActiveOrgs:
         self.max_orgs = max_orgs
         self.max_projects = max_projects
         self.log_state = DynamicSamplingLogState()
+        self.time_interval = time_interval
+        self.granularity = granularity

     def __iter__(self):
         return self

-    def __next__(self):
+    def __next__(self) -> List[int]:
         self.log_state.num_iterations += 1
         if self._enough_results_cached():
             # we have enough in the cache to satisfy the current iteration
@@ -176,12 +232,12 @@ class GetActiveOrgs:
                     ],
                     where=[
                         Condition(
-                            Column("timestamp"), Op.GTE, datetime.utcnow() - timedelta(hours=1)
+                            Column("timestamp"), Op.GTE, datetime.utcnow() - self.time_interval
                         ),
                         Condition(Column("timestamp"), Op.LT, datetime.utcnow()),
                         Condition(Column("metric_id"), Op.EQ, self.metric_id),
                     ],
-                    granularity=Granularity(3600),
+                    granularity=self.granularity,
                     orderby=[
                         OrderBy(Column("org_id"), Direction.ASC),
                     ],
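
With time_interval and granularity now exposed, callers can tune how far back the active-org lookup reaches instead of relying on the previously hard-coded one hour. A hedged caller sketch (the 24-hour window and process_batch are invented for illustration):

    from datetime import timedelta

    from snuba_sdk import Granularity

    from sentry.dynamic_sampling.tasks.common import GetActiveOrgs

    def process_batch(org_ids):
        # placeholder for real per-batch work
        print(len(org_ids), "orgs")

    # invented example: look at the last 24 hours at hourly granularity,
    # receiving batches of at most 100 org ids
    for org_ids in GetActiveOrgs(
        max_orgs=100, time_interval=timedelta(hours=24), granularity=Granularity(3600)
    ):
        process_batch(org_ids)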
@@ -269,6 +325,7 @@ class GetActiveOrgs:
         return ret_val


+# TODO: this is obsolete, replace its usages with GetActiveOrgs
 def get_active_orgs_with_projects_counts(
     max_orgs: int = MAX_ORGS_PER_QUERY, max_projects: int = MAX_PROJECTS_PER_QUERY
 ) -> Generator[List[int], None, None]:
@@ -345,6 +402,160 @@ def get_active_orgs_with_projects_counts(
         yield [org_id for org_id, _ in last_result]


+@dataclass(frozen=True)
+class OrganizationDataVolume:
+    """
+    Represents the total and indexed number of transactions received by an organisation
+    (in a particular interval of time).
+    """
+
+    # organisation id
+    org_id: int
+    # total number of transactions
+    total: int
+    # number of transactions indexed (i.e. stored)
+    indexed: int
+
+    def is_valid_for_recalibration(self):
+        return self.total > 0 and self.indexed > 0
+
+
+class GetActiveOrgsVolumes:
+    """
+    Fetch organisations volumes in batches.
+    A batch will return at most max_orgs elements
+    """
+
+    def __init__(
+        self,
+        max_orgs: int = MAX_ORGS_PER_QUERY,
+        time_interval: timedelta = RECALIBRATE_ORGS_QUERY_INTERVAL,
+        granularity: Granularity = ACTIVE_ORGS_VOLUMES_DEFAULT_GRANULARITY,
+    ):
+
+        self.metric_id = indexer.resolve_shared_org(
+            str(TransactionMRI.COUNT_PER_ROOT_PROJECT.value)
+        )
+        decision_string_id = indexer.resolve_shared_org("decision")
+        decision_tag = f"tags_raw[{decision_string_id}]"
+
+        self.keep_count_column = Function(
+            "sumIf",
+            [
+                Column("value"),
+                Function(
+                    "equals",
+                    [Column(decision_tag), "keep"],
+                ),
+            ],
+            alias="keep_count",
+        )
+
+        self.offset = 0
+        self.last_result: List[OrganizationDataVolume] = []
+        self.has_more_results = True
+        self.max_orgs = max_orgs
+        self.log_state = DynamicSamplingLogState()
+        self.granularity = granularity
+        self.time_interval = time_interval
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> List[OrganizationDataVolume]:
+        self.log_state.num_iterations += 1
+        if self._enough_results_cached():
+            # we have enough in the cache to satisfy the current iteration
+            return self._get_from_cache()
+
+        if self.has_more_results:
+            # not enough for the current iteration and there is still data in the db, top it up from the db
+            query = (
+                Query(
+                    match=Entity(EntityKey.GenericOrgMetricsCounters.value),
+                    select=[
+                        Function("sum", [Column("value")], "total_count"),
+                        Column("org_id"),
+                        self.keep_count_column,
+                    ],
+                    groupby=[
+                        Column("org_id"),
+                    ],
+                    where=[
+                        Condition(
+                            Column("timestamp"), Op.GTE, datetime.utcnow() - self.time_interval
+                        ),
+                        Condition(Column("timestamp"), Op.LT, datetime.utcnow()),
+                        Condition(Column("metric_id"), Op.EQ, self.metric_id),
+                    ],
+                    granularity=self.granularity,
+                    orderby=[
+                        OrderBy(Column("org_id"), Direction.ASC),
+                    ],
+                )
+                .set_limit(CHUNK_SIZE + 1)
+                .set_offset(self.offset)
+            )
+            request = Request(
+                dataset=Dataset.PerformanceMetrics.value, app_id="dynamic_sampling", query=query
+            )
+            self.log_state.num_db_calls += 1
+            data = raw_snql_query(
+                request,
+                referrer=Referrer.DYNAMIC_SAMPLING_COUNTERS_GET_ORG_TRANSACTION_VOLUMES.value,
+            )["data"]
+            count = len(data)
+
+            self.has_more_results = count > CHUNK_SIZE
+            self.offset += CHUNK_SIZE
+            if self.has_more_results:
+                data = data[:-1]
+            self.log_state.num_rows_total += len(data)
+            for row in data:
+                self.last_result.append(
+                    OrganizationDataVolume(
+                        org_id=row["org_id"], total=row["total_count"], indexed=row["keep_count"]
+                    )
+                )
+
+        if len(self.last_result) > 0:
+            # we have some data left, return up to the max amount
+            return self._get_from_cache()  # we still have something left in cache
+        else:
+            # nothing left in the DB or cache
+            raise StopIteration()
+
+    def get_current_state(self):
+        """
+        Returns the current state of the iterator (how many orgs it has iterated over)
+
+        part of the ContextIterator protocol
+        """
+        return self.log_state
+
+    def set_current_state(self, log_state: DynamicSamplingLogState):
+        self.log_state = log_state
+
+    def _enough_results_cached(self):
+        """
+        Return true if we have enough data to return a full batch in the cache (i.e. last_result)
+        """
+        return len(self.last_result) >= self.max_orgs
+
+    def _get_from_cache(self) -> List[OrganizationDataVolume]:
+        """
+        Returns a batch from cache and removes the elements returned from the cache
+        """
+        if len(self.last_result) >= self.max_orgs:
+            ret_val = self.last_result[: self.max_orgs]
+            self.last_result = self.last_result[self.max_orgs :]
+        else:
+            ret_val = self.last_result
+            self.last_result = []
+        return ret_val
+
+
 def fetch_orgs_with_total_root_transactions_count(
     org_ids: List[int], window_size: int
 ) -> Mapping[OrganizationId, int]:
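
GetActiveOrgsVolumes yields batches of OrganizationDataVolume, and is_valid_for_recalibration guards against orgs with zero totals. A consumption sketch; treating indexed/total as the effective sample rate is an inference from the recalibration logic below, not something this hunk spells out:

    from sentry.dynamic_sampling.tasks.common import GetActiveOrgsVolumes

    for volumes in GetActiveOrgsVolumes(max_orgs=500):
        for volume in volumes:
            # total > 0 and indexed > 0, so the division below is safe
            if not volume.is_valid_for_recalibration():
                continue
            effective_rate = volume.indexed / volume.total  # fraction of transactions kept
            print(volume.org_id, effective_rate)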

+ 46 - 183
src/sentry/dynamic_sampling/tasks/recalibrate_orgs.py

@@ -1,24 +1,17 @@
-import time
-from dataclasses import dataclass
-from datetime import datetime, timedelta
-from typing import Generator, List
-
-from snuba_sdk import (
-    Column,
-    Condition,
-    Direction,
-    Entity,
-    Function,
-    Granularity,
-    Op,
-    OrderBy,
-    Query,
-    Request,
+from sentry_sdk import capture_message, set_extra
+from snuba_sdk import Granularity
+
+from sentry.dynamic_sampling.tasks.common import (
+    GetActiveOrgsVolumes,
+    OrganizationDataVolume,
+    TimedIterator,
+    TimeoutException,
+    get_adjusted_base_rate_from_cache_or_compute,
 )
-
-from sentry.dynamic_sampling.tasks.common import get_adjusted_base_rate_from_cache_or_compute
 from sentry.dynamic_sampling.tasks.constants import (
+    CHUNK_SIZE,
     MAX_REBALANCE_FACTOR,
+    MAX_SECONDS,
     MIN_REBALANCE_FACTOR,
     RECALIBRATE_ORGS_QUERY_INTERVAL,
 )
@@ -30,18 +23,15 @@ from sentry.dynamic_sampling.tasks.helpers.recalibrate_orgs import (
 )
 from sentry.dynamic_sampling.tasks.logging import (
     log_action_if,
-    log_query_timeout,
     log_recalibrate_org_error,
     log_recalibrate_org_state,
     log_sample_rate_source,
+    log_task_execution,
+    log_task_timeout,
 )
+from sentry.dynamic_sampling.tasks.task_context import TaskContext
 from sentry.dynamic_sampling.tasks.utils import dynamic_sampling_task
-from sentry.sentry_metrics import indexer
-from sentry.snuba.dataset import Dataset, EntityKey
-from sentry.snuba.metrics.naming_layer.mri import TransactionMRI
-from sentry.snuba.referrer import Referrer
 from sentry.tasks.base import instrumented_task
-from sentry.utils.snuba import raw_snql_query

 # Since we are using a granularity of 60 (minute granularity), we want to have a higher time upper limit for executing
 # multiple queries on Snuba.
@@ -55,24 +45,6 @@ class RecalibrationError(Exception):
         super().__init__(self.message)


-@dataclass(frozen=True)
-class OrganizationDataVolume:
-    """
-    Represents the total and indexed number of transactions received by an organisation
-    (in a particular interval of time).
-    """
-
-    # organisation id
-    org_id: int
-    # total number of transactions
-    total: int
-    # number of transactions indexed (i.e. stored)
-    indexed: int
-
-    def is_valid_for_recalibration(self):
-        return self.total > 0 and self.indexed > 0
-
-
 def orgs_to_check(org_volume: OrganizationDataVolume):
     return lambda: org_volume.org_id in [1, 1407395]

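Note that orgs_to_check returns a zero-argument predicate rather than a boolean, so the membership test is deferred until log_action_if decides whether to log. A tiny sketch of that lazy-guard pattern; consume is a hypothetical stand-in, not the real logging helper:

    def consume(action: str, data: dict, should_log) -> None:
        # the predicate is only evaluated here, so callers pay nothing
        # for log calls that end up suppressed
        if should_log():
            print(action, data)

    consume("starting_recalibration", {"org_id": 1407395}, lambda: 1407395 in [1, 1407395])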
@@ -87,20 +59,36 @@ def orgs_to_check(org_volume: OrganizationDataVolume):
 )
 @dynamic_sampling_task
 def recalibrate_orgs() -> None:
-    for orgs in get_active_orgs(1000):
-        log_action_if("fetching_orgs", {"orgs": orgs}, lambda: True)
-
-        for org_volume in fetch_org_volumes(orgs):
-            try:
-                log_action_if(
-                    "starting_recalibration",
-                    {"org_id": org_volume.org_id},
-                    orgs_to_check(org_volume),
-                )
-
-                recalibrate_org(org_volume)
-            except Exception as e:
-                log_recalibrate_org_error(org_volume.org_id, str(e))
+    context = TaskContext("sentry.dynamic_sampling.tasks.recalibrate_orgs", MAX_SECONDS)
+
+    try:
+        for org_volumes in TimedIterator(
+            context,
+            GetActiveOrgsVolumes(
+                max_orgs=CHUNK_SIZE,
+                time_interval=RECALIBRATE_ORGS_QUERY_INTERVAL,
+                granularity=Granularity(60),
+            ),
+        ):
+            for org_volume in org_volumes:
+                try:
+                    log_action_if(
+                        "starting_recalibration",
+                        {"org_id": org_volume.org_id},
+                        orgs_to_check(org_volume),
+                    )
+                    recalibrate_org(org_volume)
+                except RecalibrationError as e:
+                    set_extra("context-data", context.to_dict())
+                    log_recalibrate_org_error(org_volume.org_id, str(e))
+    except TimeoutException:
+        set_extra("context-data", context.to_dict())
+        log_task_timeout(context)
+        raise
+    else:
+        set_extra("context-data", context.to_dict())
+        capture_message("timing for sentry.dynamic_sampling.tasks.recalibrate_orgs")
+        log_task_execution(context)


 def recalibrate_org(org_volume: OrganizationDataVolume) -> None:
@@ -150,136 +138,11 @@ def recalibrate_org(org_volume: OrganizationDataVolume) -> None:
         delete_adjusted_factor(org_volume.org_id)
         raise RecalibrationError(
             org_id=org_volume.org_id,
-            message=f"factor {adjusted_factor} outside of the acceptable range [{MIN_REBALANCE_FACTOR}..{MAX_REBALANCE_FACTOR}]",
+            message=f"factor {adjusted_factor} outside of the acceptable range [{MIN_REBALANCE_FACTOR}.."
+            f"{MAX_REBALANCE_FACTOR}]",
         )

     # At the end we set the adjusted factor.
     set_guarded_adjusted_factor(org_volume.org_id, adjusted_factor)

     log_action_if("set_adjusted_factor", {"org_id": org_volume.org_id}, orgs_to_check(org_volume))
-
-
-def get_active_orgs(
-    max_orgs: int, time_interval: timedelta = RECALIBRATE_ORGS_QUERY_INTERVAL
-) -> Generator[List[int], None, None]:
-    """
-    Fetch organisations in batches.
-    A batch will return at max max_orgs elements
-    """
-    start_time = time.time()
-    metric_id = indexer.resolve_shared_org(str(TransactionMRI.COUNT_PER_ROOT_PROJECT.value))
-    offset = 0
-
-    while (time.time() - start_time) < RECALIBRATE_ORGS_MAX_SECONDS:
-        query = (
-            Query(
-                match=Entity(EntityKey.GenericOrgMetricsCounters.value),
-                select=[
-                    Function("uniq", [Column("project_id")], "num_projects"),
-                    Column("org_id"),
-                ],
-                groupby=[
-                    Column("org_id"),
-                ],
-                where=[
-                    Condition(Column("timestamp"), Op.GTE, datetime.utcnow() - time_interval),
-                    Condition(Column("timestamp"), Op.LT, datetime.utcnow()),
-                    Condition(Column("metric_id"), Op.EQ, metric_id),
-                ],
-                granularity=Granularity(60),
-                orderby=[
-                    OrderBy(Column("org_id"), Direction.ASC),
-                ],
-            )
-            .set_limit(max_orgs + 1)
-            .set_offset(offset)
-        )
-        request = Request(
-            dataset=Dataset.PerformanceMetrics.value, app_id="dynamic_sampling", query=query
-        )
-        data = raw_snql_query(
-            request,
-            referrer=Referrer.DYNAMIC_SAMPLING_COUNTERS_GET_ACTIVE_ORGS.value,
-        )["data"]
-        count = len(data)
-        more_results = count > max_orgs
-        offset += max_orgs
-        if more_results:
-            data = data[:-1]
-
-        ret_val = []
-
-        for row in data:
-            ret_val.append(row["org_id"])
-
-        yield ret_val
-
-        if not more_results:
-            return
-    else:
-        log_query_timeout(
-            query="get_active_orgs", offset=offset, timeout_seconds=RECALIBRATE_ORGS_MAX_SECONDS
-        )
-
-
-def fetch_org_volumes(
-    org_ids: List[int], query_interval: timedelta = RECALIBRATE_ORGS_QUERY_INTERVAL
-) -> List[OrganizationDataVolume]:
-    """
-    Returns the number of total and indexed transactions received by all organisations in the
-    specified interval.
-    """
-    transaction_string_id = indexer.resolve_shared_org("decision")
-    transaction_tag = f"tags_raw[{transaction_string_id}]"
-    metric_id = indexer.resolve_shared_org(str(TransactionMRI.COUNT_PER_ROOT_PROJECT.value))
-    where = [
-        Condition(Column("timestamp"), Op.GTE, datetime.utcnow() - query_interval),
-        Condition(Column("timestamp"), Op.LT, datetime.utcnow()),
-        Condition(Column("metric_id"), Op.EQ, metric_id),
-        Condition(Column("org_id"), Op.IN, org_ids),
-    ]
-
-    keep_count = Function(
-        "sumIf",
-        [
-            Column("value"),
-            Function(
-                "equals",
-                [Column(transaction_tag), "keep"],
-            ),
-        ],
-        alias="keep_count",
-    )
-
-    ret_val: List[OrganizationDataVolume] = []
-
-    query = Query(
-        match=Entity(EntityKey.GenericOrgMetricsCounters.value),
-        select=[
-            Function("sum", [Column("value")], "total_count"),
-            Column("org_id"),
-            keep_count,
-        ],
-        groupby=[Column("org_id")],
-        where=where,
-        granularity=Granularity(60),
-        orderby=[
-            OrderBy(Column("org_id"), Direction.ASC),
-        ],
-    )
-    request = Request(
-        dataset=Dataset.PerformanceMetrics.value, app_id="dynamic_sampling", query=query
-    )
-    data = raw_snql_query(
-        request,
-        referrer=Referrer.DYNAMIC_SAMPLING_COUNTERS_GET_ORG_TRANSACTION_VOLUMES.value,
-    )["data"]
-
-    for row in data:
-        ret_val.append(
-            OrganizationDataVolume(
-                org_id=row["org_id"], total=row["total_count"], indexed=row["keep_count"]
-            )
-        )
-
-    return ret_val

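Both the deleted get_active_orgs generator above and the new iterators in common.py page through Snuba the same way: request one row more than a full page and treat its presence as proof that more data exists. A standalone sketch of that pattern, with fetch standing in for the Snuba query:

    from typing import Callable, Generator, List

    def paginate(
        fetch: Callable[[int, int], List], page_size: int
    ) -> Generator[List, None, None]:
        # ask for page_size + 1 rows; getting the extra row means more pages remain
        offset = 0
        while True:
            rows = fetch(page_size + 1, offset)
            has_more = len(rows) > page_size
            if has_more:
                rows = rows[:-1]  # drop the sentinel row, it belongs to the next page
            yield rows
            if not has_more:
                return
            offset += page_size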
+ 0 - 83
tests/sentry/dynamic_sampling/tasks/test_boost_low_volume_transactions.py

@@ -1,5 +1,4 @@
 from datetime import timedelta
-from typing import Tuple

 from django.utils import timezone
 from freezegun import freeze_time
@@ -17,7 +16,6 @@ from sentry.dynamic_sampling.tasks.boost_low_volume_transactions import (
     transactions_zip,
 )
 from sentry.dynamic_sampling.tasks.common import GetActiveOrgs
-from sentry.dynamic_sampling.tasks.recalibrate_orgs import fetch_org_volumes, get_active_orgs
 from sentry.snuba.metrics.naming_layer.mri import TransactionMRI
 from sentry.testutils import BaseMetricsLayerTestCase, SnubaTestCase, TestCase

@@ -362,84 +360,3 @@ def test_next_totals():
     assert get_totals(pi(3, 1)) is None
     assert get_totals(pi(3, 2)) is None
     assert get_totals(pi(3, 3)) is None
-
-
-@freeze_time(MOCK_DATETIME)
-class FetchOrgVolumesSnubaQueryTest(BaseMetricsLayerTestCase, TestCase, SnubaTestCase):
-    @property
-    def now(self):
-        return MOCK_DATETIME
-
-    def setUp(self):
-        super().setUp()
-        self.orgs_info = []
-        self.num_orgs = 3
-        self.num_proj = 2
-        self.num_trans = 2
-        # create some orgs, projects and transactions
-        for org_idx in range(self.num_orgs):
-            org = self.create_organization(f"test-org{org_idx}")
-            org_info = {"org_id": org.id, "project_ids": []}
-            self.orgs_info.append(org_info)
-            for proj_idx in range(self.num_proj):
-                p = self.create_project(organization=org)
-                org_info["project_ids"].append(p.id)
-                for name_idx in range(1, self.num_trans + 1):
-                    # make up some unique count
-                    count_transactions_keep = name_idx + proj_idx * 10
-                    count_transactions_drop = count_transactions_keep + org_idx * 100
-
-                    self.store_performance_metric(
-                        name=TransactionMRI.COUNT_PER_ROOT_PROJECT.value,
-                        tags={"transaction": f"t-{name_idx}", "decision": "drop"},
-                        minutes_before_now=2,
-                        value=count_transactions_drop,
-                        project_id=p.id,
-                        org_id=org.id,
-                    )
-                    self.store_performance_metric(
-                        name=TransactionMRI.COUNT_PER_ROOT_PROJECT.value,
-                        tags={"transaction": f"t-{name_idx}", "decision": "keep"},
-                        minutes_before_now=2,
-                        value=count_transactions_keep,
-                        project_id=p.id,
-                        org_id=org.id,
-                    )
-        self.org_ids = [org["org_id"] for org in self.orgs_info]
-
-    def get_counts_for_org(self, org_id: int) -> Tuple[int, int]:
-        for org_idx, cur_org_id in enumerate(self.org_ids):
-            if cur_org_id == org_id:
-                break
-        else:
-            raise ValueError("Invalid organisation passed")
-
-        trans = self.num_trans
-        projs = self.num_proj
-
-        # calculate how many transactions we created in the setup for this org
-        total_keep = int(trans * (trans + 1) / 2) * projs + 10 * (projs - 1) * trans
-        total_drop = total_keep + 100 * org_idx * trans * projs
-
-        return total_drop + total_keep, total_keep
-
-    def test_fetch_org_volumes(self):
-        org_counts = fetch_org_volumes(self.org_ids, timedelta(minutes=5))
-
-        for actual_counts in org_counts:
-            current_org = actual_counts.org_id
-            expected_total, expected_keep = self.get_counts_for_org(current_org)
-
-            assert actual_counts.total == expected_total
-            assert actual_counts.indexed == expected_keep
-
-    def test_get_active_orgs(self):
-        all_orgs = []
-        for orgs in get_active_orgs(2, time_interval=timedelta(minutes=5)):
-            all_orgs += orgs
-
-        # test we return the expected number of unique orgs
-        assert len(set(all_orgs)) == len(self.org_ids)
-
-        for org in self.org_ids:
-            assert org in self.org_ids

+ 145 - 1
tests/sentry/dynamic_sampling/tasks/test_common.py

@@ -4,8 +4,15 @@ import pytest
 from django.utils import timezone
 from freezegun import freeze_time

-from sentry.dynamic_sampling.tasks.common import GetActiveOrgs, TimedIterator, TimeoutException
+from sentry.dynamic_sampling.tasks.common import (
+    GetActiveOrgs,
+    GetActiveOrgsVolumes,
+    TimedIterator,
+    TimeoutException,
+    timed_function,
+)
 from sentry.dynamic_sampling.tasks.task_context import DynamicSamplingLogState, TaskContext
+from sentry.dynamic_sampling.tasks.utils import Timer
 from sentry.snuba.metrics.naming_layer import TransactionMRI
 from sentry.testutils import BaseMetricsLayerTestCase, SnubaTestCase, TestCase

@@ -138,3 +145,140 @@ class TestGetActiveOrgs(BaseMetricsLayerTestCase, TestCase, SnubaTestCase):
             total_orgs += num_orgs
             assert num_orgs == 2  # only 2 orgs since we limit the number of projects
         assert total_orgs == 10
+
+
+def test_timed_function_decorator_updates_state():
+    """
+    Tests that the decorator correctly extracts the state
+    and passes it to the inner function.
+
+    At the end, the TaskContext should be properly updated for the wrapped function.
+    It works with the default function name and also with custom names.
+
+    """
+    t = Timer()
+    context = TaskContext(name="TC", num_seconds=60.0)
+
+    @timed_function()
+    def f1(state: DynamicSamplingLogState, x: int, y: str):
+        state.num_iterations = 1
+
+    @timed_function("f2x")
+    def f2(state: DynamicSamplingLogState, x: int, y: str):
+        state.num_iterations = 2
+
+    f1(context, t, 1, "x")
+    f2(context, t, 1, "x")
+
+    f1_state = context.get_function_state("f1")
+    assert f1_state is not None
+    assert f1_state.num_iterations == 1
+
+    f2_state = context.get_function_state("f2x")
+    assert f2_state is not None
+    assert f2_state.num_iterations == 2
+
+
+def test_timed_function_correctly_times_inner_function():
+    with freeze_time("2023-07-14 10:00:00") as frozen_time:
+        t = Timer()
+        context = TaskContext(name="TC", num_seconds=60.0)
+
+        @timed_function()
+        def f1(state: DynamicSamplingLogState, x: int, y: str):
+            state.num_iterations = 1
+            frozen_time.tick()
+
+        f1(context, t, 1, "x")
+        frozen_time.tick()
+        f1(context, t, 1, "x")
+
+        # two seconds passed inside f1 (one for each call)
+        assert t.current() == 2.0
+
+
+def test_timed_function_correctly_raises_when_task_expires():
+    with freeze_time("2023-07-14 10:00:00") as frozen_time:
+        t = Timer()
+        context = TaskContext(name="TC", num_seconds=2.0)
+
+        @timed_function()
+        def f1(state: DynamicSamplingLogState, x: int, y: str):
+            state.num_iterations = 1
+            frozen_time.tick()
+
+        f1(context, t, 1, "x")
+        assert t.current() == 1.0
+        frozen_time.tick()
+        assert t.current() == 1.0  # timer should not be moving outside the function
+        f1(context, t, 1, "x")
+
+        # two seconds passed inside f1 (one for each call)
+        assert t.current() == 2.0
+
+        with pytest.raises(TimeoutException):
+            f1(context, t, 1, "x")
+
+        # the tick should not advance (the function should not have been called)
+        assert t.current() == 2.0
+
+
+NOW_ISH = timezone.now().replace(second=0, microsecond=0)
+
+
+@freeze_time(MOCK_DATETIME)
+class TestGetActiveOrgsVolumes(BaseMetricsLayerTestCase, TestCase, SnubaTestCase):
+    def setUp(self):
+        # create 12 orgs, each with one project and transactions at a 2/1 drop/keep ratio
+        for i in range(12):
+            org = self.create_organization(f"org-{i}")
+            project = self.create_project(organization=org)
+            for decision, value in [("drop", 2), ("keep", 1)]:
+                self.store_performance_metric(
+                    name=TransactionMRI.COUNT_PER_ROOT_PROJECT.value,
+                    tags={"transaction": "foo_transaction", "decision": decision},
+                    minutes_before_now=1,
+                    value=value,
+                    project_id=project.id,
+                    org_id=org.id,
+                )
+
+    @property
+    def now(self):
+        return MOCK_DATETIME
+
+    def test_get_active_orgs_volumes_exact_batch_match(self):
+        """
+        gets active org volumes, with the number of elements an exact
+        multiple of the batch size
+        """
+        total_orgs = 0
+        for orgs in GetActiveOrgsVolumes(max_orgs=3):
+            num_orgs = len(orgs)
+            total_orgs += num_orgs
+            assert num_orgs == 3  # every batch should be full
+            for org in orgs:
+                assert org.total == 3
+                assert org.indexed == 1
+        assert total_orgs == 12
+
+    def test_get_active_orgs_volumes(self):
+        """
+        gets active org volumes, with the number of elements in the DB
+        not an exact multiple of the batch size
+        """
+        total_orgs = 0
+        for idx, orgs in enumerate(GetActiveOrgsVolumes(max_orgs=5)):
+            num_orgs = len(orgs)
+            total_orgs += num_orgs
+            if idx in [0, 1]:
+                assert num_orgs == 5  # first two batches should be full
+            elif idx == 2:
+                assert num_orgs == 2  # last batch not full
+            else:
+                pytest.fail(f"Unexpected index {idx}; only 3 iterations expected.")
+            for org in orgs:
+                assert org.total == 3
+                assert org.indexed == 1
+
+        assert total_orgs == 12