
feat(stats-detectors): Add per project ratelimit for ema detection (#58969)

As a precaution, we want to be able to rate limit the number of possible
regressions surfaced per project. This prevents too many 14-day queries
from being made in the next step of the analysis. This change introduces
an option `statistical_detectors.ratelimit.ema` that allows us to cap
the number of possible regressions per project. It defaults to -1,
meaning no limit; any non-negative value is the number of possible
regressions allowed per project.
Tony Xiao 1 year ago
commit 911561cf50
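
For reference, a minimal sketch (not Sentry code) of the option's semantics: -1 (or any negative value) disables the limit entirely, while any non-negative value caps the number of regressions surfaced per project.

    def within_ratelimit(surfaced_so_far: int, ratelimit: int) -> bool:
        # A negative ratelimit means unlimited; otherwise cap at `ratelimit`.
        return ratelimit < 0 or surfaced_so_far < ratelimit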

+ 6 - 0
src/sentry/options/defaults.py

@@ -1580,6 +1580,12 @@ register(
     default=14,
     flags=FLAG_PRIORITIZE_DISK | FLAG_AUTOMATOR_MODIFIABLE,
 )
+register(
+    "statistical_detectors.ratelimit.ema",
+    type=Int,
+    default=-1,
+    flags=FLAG_PRIORITIZE_DISK | FLAG_AUTOMATOR_MODIFIABLE,
+)
 
 register(
     "options_automator_slack_webhook_enabled",

+ 15 - 11
src/sentry/statistical_detectors/algorithm.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 from dataclasses import dataclass
 from datetime import datetime, timezone
-from typing import Any, Callable, Mapping, MutableMapping, Optional
+from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple
 
 from sentry.statistical_detectors.detector import (
     DetectorAlgorithm,
@@ -100,7 +100,7 @@ class MovingAverageDetector(DetectorAlgorithm):
 
 
 class MovingAverageCrossOverDetector(MovingAverageDetector):
-    def update(self, payload: DetectorPayload) -> Optional[TrendType]:
+    def update(self, payload: DetectorPayload) -> Tuple[Optional[TrendType], float]:
         if self.timestamp is not None and self.timestamp > payload.timestamp:
             # If the payload's timestamp is before the detector's last seen
             # timestamp, we do not want to process this payload.
@@ -111,7 +111,7 @@ class MovingAverageCrossOverDetector(MovingAverageDetector):
                 payload.timestamp.isoformat(),
                 self.timestamp.isoformat(),
             )
-            return None
+            return None, 0
 
         old_moving_avg_short = self.moving_avg_short.value
         old_moving_avg_long = self.moving_avg_long.value
@@ -125,6 +125,8 @@ class MovingAverageCrossOverDetector(MovingAverageDetector):
         # number of data points before looking for a regression.
         stablized = self.count > self.config.min_data_points
 
+        score = abs(self.moving_avg_short.value - self.moving_avg_long.value)
+
         if (
             stablized
             and self.moving_avg_short.value > self.moving_avg_long.value
@@ -133,7 +135,7 @@ class MovingAverageCrossOverDetector(MovingAverageDetector):
             # The new fast moving average is above the new slow moving average.
             # The old fast moving average is below the old slow moving average.
             # This indicates an upwards trend.
-            return TrendType.Regressed
+            return TrendType.Regressed, score
 
         elif (
             stablized
@@ -143,9 +145,9 @@ class MovingAverageCrossOverDetector(MovingAverageDetector):
             # The new fast moving average is below the new slow moving average.
             # The old fast moving average is above the old slow moving average.
             # This indicates a downwards trend.
-            return TrendType.Improved
+            return TrendType.Improved, score
 
-        return TrendType.Unchanged
+        return TrendType.Unchanged, score
 
 
 @dataclass(frozen=True)
@@ -164,7 +166,7 @@ class MovingAverageRelativeChangeDetector(MovingAverageDetector):
         self.change_metric = config.change_metric
         self.threshold = abs(config.threshold)
 
-    def update(self, payload: DetectorPayload) -> Optional[TrendType]:
+    def update(self, payload: DetectorPayload) -> Tuple[Optional[TrendType], float]:
         if self.timestamp is not None and self.timestamp > payload.timestamp:
             # If the payload's timestamp is before the detector's last seen
             # timestamp, we do not want to process this payload.
@@ -175,7 +177,7 @@ class MovingAverageRelativeChangeDetector(MovingAverageDetector):
                 payload.timestamp.isoformat(),
                 self.timestamp.isoformat(),
             )
-            return None
+            return None, 0
 
         old_moving_avg_short = self.moving_avg_short.value
         old_moving_avg_long = self.moving_avg_long.value
@@ -189,6 +191,8 @@ class MovingAverageRelativeChangeDetector(MovingAverageDetector):
         # number of data points before looking for a regression.
         stablized = self.count > self.config.min_data_points
 
+        score = abs(self.moving_avg_short.value - self.moving_avg_long.value)
+
         try:
             relative_change_old = (old_moving_avg_short - old_moving_avg_long) / abs(
                 old_moving_avg_long
@@ -208,13 +212,13 @@ class MovingAverageRelativeChangeDetector(MovingAverageDetector):
             and relative_change_old < self.threshold
             and relative_change_new > self.threshold
         ):
-            return TrendType.Regressed
+            return TrendType.Regressed, score
 
         elif (
             stablized
             and relative_change_old > -self.threshold
             and relative_change_new < -self.threshold
         ):
-            return TrendType.Improved
+            return TrendType.Improved, score
 
-        return TrendType.Unchanged
+        return TrendType.Unchanged, score
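
Both detectors now attach a score to every classification: the absolute gap between the fast and slow moving averages at the moment the trend is classified. A minimal standalone sketch of that shared crossover logic (simplified, illustrative names; not the Sentry implementation):

    from typing import Tuple

    def classify(
        new_short: float, new_long: float,
        old_short: float, old_long: float,
    ) -> Tuple[str, float]:
        # A bigger gap between the fast and slow averages means a
        # stronger signal, so it doubles as a ranking score.
        score = abs(new_short - new_long)
        if new_short > new_long and old_short < old_long:
            return "regressed", score  # fast average crossed above slow
        if new_short < new_long and old_short > old_long:
            return "improved", score   # fast average crossed below slow
        return "unchanged", score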

+ 2 - 2
src/sentry/statistical_detectors/detector.py

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any, Generic, List, Mapping, Optional, TypeVar
+from typing import Any, Generic, List, Mapping, Optional, Tuple, TypeVar
 
 
 class TrendType(Enum):
@@ -55,5 +55,5 @@ class DetectorStore(ABC, Generic[T]):
 
 class DetectorAlgorithm(ABC, Generic[T]):
     @abstractmethod
-    def update(self, payload: DetectorPayload) -> Optional[TrendType]:
+    def update(self, payload: DetectorPayload) -> Tuple[Optional[TrendType], float]:
         ...
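
Every concrete detector must now return a (trend, score) pair. A hypothetical no-op subclass, using the imports shown in the diff above, just to illustrate the new contract (`NoopDetector` is not a real class in the codebase):

    class NoopDetector(DetectorAlgorithm):
        def update(self, payload: DetectorPayload) -> Tuple[Optional[TrendType], float]:
            # Never flags anything; a score now accompanies every result.
            return TrendType.Unchanged, 0.0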

+ 39 - 15
src/sentry/tasks/statistical_detectors.py

@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import heapq
 import logging
+from collections import defaultdict
 from datetime import datetime, timedelta
-from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
+from typing import Any, DefaultDict, Dict, Generator, List, Optional, Set, Tuple, Union
 
 import sentry_sdk
 from django.utils import timezone as django_timezone
@@ -158,18 +160,17 @@ def detect_transaction_trends(
     if not options.get("statistical_detectors.enable"):
         return
 
-    regressions = filter(
-        lambda trend: trend[0] == TrendType.Regressed,
-        _detect_transaction_trends(org_ids, project_ids, start),
-    )
+    ratelimit = options.get("statistical_detectors.ratelimit.ema")
+    trends = _detect_transaction_trends(org_ids, project_ids, start)
+    regressions = limit_regressions_by_project(trends, ratelimit)
 
     delay = 12  # hours
     delayed_start = start + timedelta(hours=delay)
 
-    for trends in chunked(regressions, TRANSACTIONS_PER_BATCH):
+    for regression_chunk in chunked(regressions, TRANSACTIONS_PER_BATCH):
         detect_transaction_change_points.apply_async(
             args=[
-                [(payload.project_id, payload.group) for _, payload in trends],
+                [(payload.project_id, payload.group) for payload in regression_chunk],
                 delayed_start,
             ],
             # delay the check by delay hours because we want to make sure there
@@ -288,7 +289,7 @@ def get_all_transaction_payloads(
 
 def _detect_transaction_trends(
     org_ids: List[int], project_ids: List[int], start: datetime
-) -> Generator[Tuple[Optional[TrendType], DetectorPayload], None, None]:
+) -> Generator[Tuple[Optional[TrendType], float, DetectorPayload], None, None]:
     unique_project_ids: Set[int] = set()
 
     transactions_count = 0
@@ -329,7 +330,7 @@ def _detect_transaction_trends(
                     sentry_sdk.capture_exception(e)
 
             detector = MovingAverageRelativeChangeDetector(state, detector_config)
-            trend_type = detector.update(payload)
+            trend_type, score = detector.update(payload)
             states.append(None if trend_type is None else detector.state.to_redis_dict())
 
             if trend_type == TrendType.Regressed:
@@ -339,7 +340,7 @@ def _detect_transaction_trends(
 
             unique_project_ids.add(payload.project_id)
 
-            yield (trend_type, payload)
+            yield (trend_type, score, payload)
 
         detector_store.bulk_write_states(payloads, states)
 
@@ -544,8 +545,10 @@ def detect_function_trends(project_ids: List[int], start: datetime, *args, **kwa
     if not options.get("statistical_detectors.enable"):
         return
 
+    ratelimit = options.get("statistical_detectors.ratelimit.ema")
+
     trends = _detect_function_trends(project_ids, start)
-    regressions = filter(lambda trend: trend[0] == TrendType.Regressed, trends)
+    regressions = limit_regressions_by_project(trends, ratelimit)
 
     delay = 12  # hours
     delayed_start = start + timedelta(hours=delay)
@@ -553,7 +556,7 @@ def detect_function_trends(project_ids: List[int], start: datetime, *args, **kwa
     for regression_chunk in chunked(regressions, FUNCTIONS_PER_BATCH):
         detect_function_change_points.apply_async(
             args=[
-                [(payload.project_id, payload.group) for _, payload in regression_chunk],
+                [(payload.project_id, payload.group) for payload in regression_chunk],
                 delayed_start,
             ],
             # delay the check by delay hours because we want to make sure there
@@ -600,7 +603,7 @@ def detect_function_change_points(
 
 def _detect_function_trends(
     project_ids: List[int], start: datetime
-) -> Generator[Tuple[Optional[TrendType], DetectorPayload], None, None]:
+) -> Generator[Tuple[Optional[TrendType], float, DetectorPayload], None, None]:
     unique_project_ids: Set[int] = set()
 
     functions_count = 0
@@ -638,7 +641,7 @@ def _detect_function_trends(
                     sentry_sdk.capture_exception(e)
 
             detector = MovingAverageRelativeChangeDetector(state, detector_config)
-            trend_type = detector.update(payload)
+            trend_type, score = detector.update(payload)
 
             states.append(None if trend_type is None else detector.state.to_redis_dict())
 
@@ -649,7 +652,7 @@ def _detect_function_trends(
 
             unique_project_ids.add(payload.project_id)
 
-            yield (trend_type, payload)
+            yield (trend_type, score, payload)
 
         detector_store.bulk_write_states(payloads, states)
 
@@ -1094,3 +1097,24 @@ def query_functions_timeseries(
             )
             continue
         yield project.id, fingerprint, results[key]
+
+
+def limit_regressions_by_project(
+    trends: Generator[Tuple[Optional[TrendType], float, DetectorPayload], None, None],
+    ratelimit: int,
+) -> Generator[DetectorPayload, None, None]:
+    regressions_by_project: DefaultDict[int, List[Tuple[float, DetectorPayload]]] = defaultdict(
+        list
+    )
+
+    for trend_type, score, payload in trends:
+        if trend_type != TrendType.Regressed:
+            continue
+        heapq.heappush(regressions_by_project[payload.project_id], (score, payload))
+
+        while ratelimit >= 0 and len(regressions_by_project[payload.project_id]) > ratelimit:
+            heapq.heappop(regressions_by_project[payload.project_id])
+
+    for regressions in regressions_by_project.values():
+        for _, regression in regressions:
+            yield regression
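
The pruning relies on Python's min-heap: pushing `(score, payload)` and popping whenever the heap exceeds the limit keeps only the `ratelimit` highest-scoring regressions per project. A toy demonstration of the pattern (not Sentry code):

    import heapq

    limit = 2
    heap: list = []
    for score, name in [(3.0, "a"), (1.0, "b"), (2.0, "c")]:
        heapq.heappush(heap, (score, name))
        # With limit >= 0, drop the lowest-scoring entry once over the cap;
        # a negative limit never prunes, i.e. no rate limit.
        while limit >= 0 and len(heap) > limit:
            heapq.heappop(heap)

    assert sorted(heap) == [(2.0, "c"), (3.0, "a")]  # "b" (score 1.0) was dropped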

+ 6 - 4
tests/sentry/statistical_detectors/test_algorithm.py

@@ -261,7 +261,8 @@ def test_moving_average_cross_over_detector(
     )
 
     for payload in payloads:
-        trend_type = detector.update(payload)
+        trend_type, score = detector.update(payload)
+        assert score >= 0
         if trend_type == TrendType.Regressed:
             all_regressed.append(payload)
         elif trend_type == TrendType.Improved:
@@ -304,7 +305,7 @@ def test_moving_average_cross_over_detector_bad_order(
         value=100,
         timestamp=now,
     )
-    trend_type = detector.update(payload)
+    trend_type, _ = detector.update(payload)
     assert trend_type is not None
 
     payload = DetectorPayload(
@@ -314,7 +315,7 @@ def test_moving_average_cross_over_detector_bad_order(
         value=100,
         timestamp=now - timedelta(hours=1),
     )
-    trend_type = detector.update(payload)
+    trend_type, _ = detector.update(payload)
     assert trend_type is None
 
 
@@ -394,7 +395,8 @@ def test_moving_average_relative_change_detector(
     )
 
     for payload in payloads:
-        trend_type = detector.update(payload)
+        trend_type, score = detector.update(payload)
+        assert score >= 0
         if trend_type == TrendType.Regressed:
             all_regressed.append(payload)
         elif trend_type == TrendType.Improved:

+ 191 - 3
tests/sentry/tasks/test_statistical_detectors.py

@@ -10,13 +10,14 @@ from sentry.seer.utils import BreakpointData
 from sentry.sentry_metrics.use_case_id_registry import UseCaseID
 from sentry.snuba.discover import zerofill
 from sentry.snuba.metrics.naming_layer.mri import TransactionMRI
-from sentry.statistical_detectors.detector import DetectorPayload
+from sentry.statistical_detectors.detector import DetectorPayload, TrendType
 from sentry.tasks.statistical_detectors import (
     detect_function_change_points,
     detect_function_trends,
     detect_transaction_change_points,
     detect_transaction_trends,
     emit_function_regression_issue,
+    limit_regressions_by_project,
     query_functions,
     query_transactions,
     query_transactions_timeseries,
@@ -254,7 +255,6 @@ def test_detect_transaction_trends(
     query_transactions,
     timestamp,
     project,
-    organization,
 ):
     n = 20
     timestamps = [timestamp - timedelta(hours=n - i) for i in range(n)]
@@ -279,10 +279,126 @@ def test_detect_transaction_trends(
         }
     ), TaskRunner():
         for ts in timestamps:
-            detect_transaction_trends([organization.id], [project.id], ts)
+            detect_transaction_trends([project.organization.id], [project.id], ts)
     assert detect_transaction_change_points.apply_async.called
 
 
+@pytest.mark.parametrize(
+    ["ratelimit", "expected_calls"],
+    [(-1, 3), (0, 0), (1, 1), (2, 2), (3, 3)],
+)
+@mock.patch("sentry.tasks.statistical_detectors.query_transactions")
+@mock.patch("sentry.tasks.statistical_detectors.detect_transaction_change_points")
+@django_db_all
+def test_detect_transaction_trends_ratelimit(
+    detect_transaction_change_points,
+    query_transactions,
+    ratelimit,
+    expected_calls,
+    timestamp,
+    organization,
+    project,
+):
+    n = 20
+    timestamps = [timestamp - timedelta(hours=n - i) for i in range(n)]
+
+    query_transactions.side_effect = [
+        [
+            DetectorPayload(
+                project_id=project.id,
+                group="/1",
+                count=100,
+                value=100 if i < n / 2 else 301,
+                timestamp=ts,
+            ),
+            DetectorPayload(
+                project_id=project.id,
+                group="/2",
+                count=100,
+                value=100 if i < n / 2 else 302,
+                timestamp=ts,
+            ),
+            DetectorPayload(
+                project_id=project.id,
+                group="/3",
+                count=100,
+                value=100 if i < n / 2 else 303,
+                timestamp=ts,
+            ),
+        ]
+        for i, ts in enumerate(timestamps)
+    ]
+
+    with override_options(
+        {
+            "statistical_detectors.enable": True,
+            "statistical_detectors.enable.projects.performance": [project.id],
+            "statistical_detectors.ratelimit.ema": ratelimit,
+        }
+    ), TaskRunner():
+        for ts in timestamps:
+            detect_transaction_trends([project.organization.id], [project.id], ts)
+
+    if expected_calls > 0:
+        detect_transaction_change_points.apply_async.assert_has_calls(
+            [
+                mock.call(
+                    args=[
+                        [(project.id, "/1"), (project.id, "/2"), (project.id, "/3")][
+                            -expected_calls:
+                        ],
+                        timestamp + timedelta(hours=5),
+                    ],
+                    countdown=12 * 60 * 60,
+                ),
+            ],
+        )
+        assert detect_transaction_change_points.apply_async.call_count == 1
+    else:
+        assert detect_transaction_change_points.apply_async.call_count == 0
+
+
+@pytest.mark.parametrize(
+    ["ratelimit", "expected_idx"],
+    [
+        pytest.param(-1, 4, id="all"),
+        pytest.param(0, 0, id="zero per project"),
+        pytest.param(1, 2, id="one per project"),
+        pytest.param(2, 3, id="two per project"),
+        pytest.param(3, 4, id="three per project"),
+    ],
+)
+def test_limit_regressions_by_project(ratelimit, timestamp, expected_idx):
+    payloads = {
+        (project_id, group): DetectorPayload(
+            project_id=project_id,
+            group=f"{project_id}_{group}",
+            count=int(f"{project_id}_{group}"),
+            value=int(f"{project_id}_{group}"),
+            timestamp=timestamp,
+        )
+        for project_id in range(1, 4)
+        for group in range(1, project_id + 1)
+    }
+
+    def trends():
+        yield (None, 0, payloads[(1, 1)])
+        yield (TrendType.Improved, 0, payloads[(2, 1)])
+        yield (TrendType.Regressed, 0, payloads[(2, 2)])
+        yield (TrendType.Regressed, 0, payloads[(3, 1)])
+        yield (TrendType.Regressed, 1, payloads[(3, 2)])
+        yield (TrendType.Regressed, 2, payloads[(3, 3)])
+
+    expected_regressions = [
+        payloads[(2, 2)],
+        payloads[(3, 3)],
+        payloads[(3, 2)],
+        payloads[(3, 1)],
+    ][:expected_idx]
+    regressions = limit_regressions_by_project(trends(), ratelimit)
+    assert set(regressions) == set(expected_regressions)
+
+
 @mock.patch("sentry.tasks.statistical_detectors.query_functions")
 @mock.patch("sentry.tasks.statistical_detectors.detect_function_change_points")
 @django_db_all
@@ -319,6 +435,78 @@ def test_detect_function_trends(
     assert detect_function_change_points.apply_async.called
 
 
+@pytest.mark.parametrize(
+    ["ratelimit", "expected_calls"],
+    [(-1, 3), (0, 0), (1, 1), (2, 2), (3, 3)],
+)
+@mock.patch("sentry.tasks.statistical_detectors.query_functions")
+@mock.patch("sentry.tasks.statistical_detectors.detect_function_change_points")
+@django_db_all
+def test_detect_function_trends_ratelimit(
+    detect_function_change_points,
+    query_functions,
+    ratelimit,
+    expected_calls,
+    timestamp,
+    project,
+):
+    n = 20
+    timestamps = [timestamp - timedelta(hours=n - i) for i in range(n)]
+
+    query_functions.side_effect = [
+        [
+            DetectorPayload(
+                project_id=project.id,
+                group=1,
+                count=100,
+                value=100 if i < n / 2 else 301,
+                timestamp=ts,
+            ),
+            DetectorPayload(
+                project_id=project.id,
+                group=2,
+                count=100,
+                value=100 if i < n / 2 else 302,
+                timestamp=ts,
+            ),
+            DetectorPayload(
+                project_id=project.id,
+                group=3,
+                count=100,
+                value=100 if i < n / 2 else 303,
+                timestamp=ts,
+            ),
+        ]
+        for i, ts in enumerate(timestamps)
+    ]
+
+    with override_options(
+        {
+            "statistical_detectors.enable": True,
+            "statistical_detectors.enable.projects.profiling": [project.id],
+            "statistical_detectors.ratelimit.ema": ratelimit,
+        }
+    ), TaskRunner():
+        for ts in timestamps:
+            detect_function_trends([project.id], ts)
+
+    if expected_calls > 0:
+        detect_function_change_points.apply_async.assert_has_calls(
+            [
+                mock.call(
+                    args=[
+                        [(project.id, 1), (project.id, 2), (project.id, 3)][-expected_calls:],
+                        timestamp + timedelta(hours=5),
+                    ],
+                    countdown=12 * 60 * 60,
+                ),
+            ],
+        )
+        assert detect_function_change_points.apply_async.call_count == 1
+    else:
+        assert detect_function_change_points.apply_async.call_count == 0
+
+
 @mock.patch("sentry.tasks.statistical_detectors.emit_function_regression_issue")
 @mock.patch("sentry.tasks.statistical_detectors.detect_breakpoints")
 @mock.patch("sentry.tasks.statistical_detectors.raw_snql_query")