
Process delayed alert conditions in batches of 10,000 (#75302)

# Description
Some orgs are sending 100k+ events per minute, and the processing is
taking too long for a single task.

This PR checks the size of the hash and determines whether it needs to be
processed in batches.

There are some restrictions around the Celery task / Redis; the details are
outlined in a code comment here:
https://github.com/getsentry/sentry/pull/75302/files#diff-f906e75a0e4419db4870fa45ca5a1608ca79beaa052c8bc50b4805607a665d27R482-R486
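
In short, the new flow looks roughly like this (a condensed, hedged sketch of `process_rulegroups_in_batches` from the diff below; `chunks` stands in for the actual `islice`-based chunking, and deleting the original hash entries is omitted):

```python
batch_size = options.get("delayed_processing.batch_size")  # new option, defaults to 10,000
event_count = buffer.backend.get_hash_length(Project, {"project_id": project_id})

if event_count < batch_size:
    apply_delayed.delayed(project_id)  # small hash: a single task, same as before
else:
    for batch_key, batch in chunks:  # each chunk gets a uuid batch_key (see the diff)
        buffer.backend.push_to_hash_bulk(
            model=Project,
            filters={"project_id": project_id, "batch_key": batch_key},
            data=batch,
        )
        apply_delayed.delayed(project_id, batch_key)  # one task per batch
```
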
Josh Callender, 7 months ago · commit 215491de2b

+ 14 - 0
src/sentry/buffer/base.py

@@ -34,6 +34,7 @@ class Buffer(Service):
         "push_to_hash",
         "get_sorted_set",
         "get_hash",
+        "get_hash_length",
         "delete_hash",
         "delete_key",
     )
@@ -54,6 +55,11 @@ class Buffer(Service):
     ) -> dict[str, str]:
         return {}
 
+    def get_hash_length(
+        self, model: type[models.Model], field: dict[str, models.Model | str | int]
+    ) -> int:
+        raise NotImplementedError
+
     def get_sorted_set(self, key: str, min: float, max: float) -> list[tuple[int, datetime]]:
         return []
 
@@ -69,6 +75,14 @@ class Buffer(Service):
     ) -> None:
         return None
 
+    def push_to_hash_bulk(
+        self,
+        model: type[models.Model],
+        filters: dict[str, models.Model | str | int],
+        data: dict[str, str],
+    ) -> None:
+        raise NotImplementedError
+
     def delete_hash(
         self,
         model: type[models.Model],

+ 18 - 0
src/sentry/buffer/redis.py

@@ -81,13 +81,16 @@ class BufferHookRegistry:
 redis_buffer_registry = BufferHookRegistry()
 
 
+# Note: HMSET is deprecated as of Redis 4.0.0; once we upgrade we can use HSET (which accepts multiple field/value pairs) directly.
 class RedisOperation(Enum):
     SORTED_SET_ADD = "zadd"
     SORTED_SET_GET_RANGE = "zrangebyscore"
     SORTED_SET_DELETE_RANGE = "zremrangebyscore"
     HASH_ADD = "hset"
+    HASH_ADD_BULK = "hmset"
     HASH_GET_ALL = "hgetall"
     HASH_DELETE = "hdel"
+    HASH_LENGTH = "hlen"
 
 
 class PendingBuffer:
@@ -296,6 +299,15 @@ class RedisBuffer(Buffer):
         key = self._make_key(model, filters)
         self._execute_redis_operation(key, RedisOperation.HASH_ADD, field, value)
 
+    def push_to_hash_bulk(
+        self,
+        model: type[models.Model],
+        filters: dict[str, models.Model | str | int],
+        data: dict[str, str],
+    ) -> None:
+        key = self._make_key(model, filters)
+        self._execute_redis_operation(key, RedisOperation.HASH_ADD_BULK, data)
+
     def get_hash(
         self, model: type[models.Model], field: dict[str, models.Model | str | int]
     ) -> dict[str, str]:
@@ -311,6 +323,12 @@ class RedisBuffer(Buffer):
 
         return decoded_hash
 
+    def get_hash_length(
+        self, model: type[models.Model], field: dict[str, models.Model | str | int]
+    ) -> int:
+        key = self._make_key(model, field)
+        return self._execute_redis_operation(key, RedisOperation.HASH_LENGTH)
+
     def process_batch(self) -> None:
         try:
             redis_buffer_registry.callback(BufferHookEvent.FLUSH)
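
The HMSET note above hints at the eventual migration path; a hedged sketch of what the bulk write could look like after the upgrade (assuming redis-py's `hset(name, mapping=...)` form, which this PR does not use):

```python
def push_to_hash_bulk_via_hset(client, key: str, data: dict[str, str]) -> None:
    # HSET has accepted multiple field/value pairs since Redis 4.0.0, so a single
    # call with mapping= covers what the HMSET call does today.
    client.hset(key, mapping=data)
```
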

+ 5 - 0
src/sentry/options/defaults.py

@@ -2636,3 +2636,8 @@ register(
     default=1,
     flags=FLAG_AUTOMATOR_MODIFIABLE,
 )
+register(
+    "delayed_processing.batch_size",
+    default=10000,
+    flags=FLAG_AUTOMATOR_MODIFIABLE,
+)

+ 73 - 28
src/sentry/rules/processing/delayed_processing.py

@@ -4,13 +4,15 @@ import uuid
 from collections import defaultdict
 from collections.abc import Sequence
 from datetime import datetime, timedelta, timezone
+from itertools import islice
 from typing import Any, DefaultDict, NamedTuple
 
 import sentry_sdk
 from django.db.models import OuterRef, Subquery
 
-from sentry import buffer, nodestore
+from sentry import buffer, nodestore, options
 from sentry.buffer.redis import BufferHookEvent, redis_buffer_registry
+from sentry.db import models
 from sentry.eventstore.models import Event, GroupEvent
 from sentry.issues.issue_occurrence import IssueOccurrence
 from sentry.models.group import Group
@@ -85,8 +87,15 @@ def fetch_project(project_id: int) -> Project | None:
         return None
 
 
-def fetch_rulegroup_to_event_data(project_id: int) -> dict[str, str]:
-    return buffer.backend.get_hash(model=Project, field={"project_id": project_id})
+def fetch_rulegroup_to_event_data(project_id: int, batch_key: str | None = None) -> dict[str, str]:
+    field: dict[str, models.Model | int | str] = {
+        "project_id": project_id,
+    }
+
+    if batch_key:
+        field["batch_key"] = batch_key
+
+    return buffer.backend.get_hash(model=Project, field=field)
 
 
 def get_rules_to_groups(rulegroup_to_event_data: dict[str, str]) -> DefaultDict[int, set[int]]:
@@ -447,13 +456,17 @@ def fire_rules(
                 safe_execute(callback, groupevent, futures)
 
 
-def cleanup_redis_buffer(project_id: int, rules_to_groups: DefaultDict[int, set[int]]) -> None:
+def cleanup_redis_buffer(
+    project_id: int, rules_to_groups: DefaultDict[int, set[int]], batch_key: str | None
+) -> None:
     hashes_to_delete = [
         f"{rule}:{group}" for rule, groups in rules_to_groups.items() for group in groups
     ]
-    buffer.backend.delete_hash(
-        model=Project, filters={"project_id": project_id}, fields=hashes_to_delete
-    )
+    filters: dict[str, models.Model | str | int] = {"project_id": project_id}
+    if batch_key:
+        filters["batch_key"] = batch_key
+
+    buffer.backend.delete_hash(model=Project, filters=filters, fields=hashes_to_delete)
 
 
 def bucket_num_groups(num_groups: int) -> str:
@@ -463,6 +476,55 @@ def bucket_num_groups(num_groups: int) -> str:
     return "1"
 
 
+def process_rulegroups_in_batches(project_id: int):
+    """
+    This will check the number of rulegroup_to_event_data items in the Redis buffer for a project.
+
+    If the number is larger than the batch size, it will chunk the items and process them in batches.
+
+    The batches are replicated into a new redis hash with a unique filter (a uuid) to identify the batch.
+    We need to use a UUID because these batches can be created in multiple processes and we need to ensure
+    uniqueness across all of them for the centralized redis buffer. The batches are stored in redis because
+    we shouldn't pass objects that need to be pickled, and passing 10k items as celery task arguments could
+    be problematic. Finally, we can't use a pagination system on the data because redis doesn't maintain the
+    sort order of the hash keys.
+
+    `apply_delayed` will fetch the batch from redis and process the rules.
+    """
+    batch_size = options.get("delayed_processing.batch_size")
+    event_count = buffer.backend.get_hash_length(Project, {"project_id": project_id})
+
+    if event_count < batch_size:
+        return apply_delayed.delayed(project_id)
+
+    logger.info(
+        "delayed_processing.process_large_batch",
+        extra={"project_id": project_id, "count": event_count},
+    )
+
+    # if the dictionary is large, get the items and chunk them.
+    rulegroup_to_event_data = fetch_rulegroup_to_event_data(project_id)
+
+    with metrics.timer("delayed_processing.process_batch.duration"):
+        items = iter(rulegroup_to_event_data.items())
+
+        while batch := dict(islice(items, batch_size)):
+            batch_key = str(uuid.uuid4())
+
+            buffer.backend.push_to_hash_bulk(
+                model=Project,
+                filters={"project_id": project_id, "batch_key": batch_key},
+                data=batch,
+            )
+
+            # remove the batched items from the project rulegroup_to_event_data
+            buffer.backend.delete_hash(
+                model=Project, filters={"project_id": project_id}, fields=list(batch.keys())
+            )
+
+            apply_delayed.delayed(project_id, batch_key)
+
+
 def process_delayed_alert_conditions() -> None:
     with metrics.timer("delayed_processing.process_all_conditions.duration"):
         fetch_time = datetime.now(tz=timezone.utc)
@@ -473,7 +535,7 @@ def process_delayed_alert_conditions() -> None:
         logger.info("delayed_processing.project_id_list", extra={"project_ids": log_str})
 
         for project_id, _ in project_ids:
-            apply_delayed.delay(project_id)
+            process_rulegroups_in_batches(project_id)
 
         buffer.backend.delete_key(PROJECT_ID_BUFFER_LIST_KEY, min=0, max=fetch_time.timestamp())
 
@@ -487,32 +549,15 @@ def process_delayed_alert_conditions() -> None:
     time_limit=60,
     silo_mode=SiloMode.REGION,
 )
-def apply_delayed(project_id: int, *args: Any, **kwargs: Any) -> None:
+def apply_delayed(project_id: int, batch_key: str | None = None, *args: Any, **kwargs: Any) -> None:
     """
     Grab rules, groups, and events from the Redis buffer, evaluate the "slow" conditions in a bulk snuba query, and fire them if they pass
     """
     project = fetch_project(project_id)
     if not project:
-        # Should we remove the project_id from the redis queue?
-        return
-
-    rulegroup_to_event_data = fetch_rulegroup_to_event_data(project_id)
-    num_groups = len(rulegroup_to_event_data)
-    num_groups_bucketed = bucket_num_groups(num_groups)
-    metrics.incr("delayed_processing.num_groups", tags={"num_groups": num_groups_bucketed})
-
-    if num_groups >= 10000:
-        logger.error(
-            "delayed_processing.too_many_groups",
-            extra={
-                "project_id": project_id,
-                "num_groups": num_groups,
-                "organization_id": project.organization_id,
-            },
-        )
-        # TODO @saponifi3d - Split the processing from here into smaller groups
         return
 
+    rulegroup_to_event_data = fetch_rulegroup_to_event_data(project_id, batch_key)
     rules_to_groups = get_rules_to_groups(rulegroup_to_event_data)
     alert_rules = fetch_alert_rules(list(rules_to_groups.keys()))
     condition_groups = get_condition_query_groups(alert_rules, rules_to_groups)
@@ -542,7 +587,7 @@ def apply_delayed(project_id: int, *args: Any, **kwargs: Any) -> None:
     with metrics.timer("delayed_processing.fire_rules.duration"):
         fire_rules(rules_to_fire, parsed_rulegroup_to_event_data, alert_rules, project)
 
-    cleanup_redis_buffer(project_id, rules_to_groups)
+    cleanup_redis_buffer(project_id, rules_to_groups, batch_key)
 
 
 if not redis_buffer_registry.has(BufferHookEvent.FLUSH):
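
The chunking above leans on the `while batch := dict(islice(items, batch_size))` idiom; a tiny standalone demo of how it consumes a shared iterator (toy data, not from the PR):

```python
from itertools import islice

items = iter({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}.items())
batches = []
while batch := dict(islice(items, 2)):  # take the next 2 items until the iterator is exhausted
    batches.append(batch)

assert batches == [{"a": 1, "b": 2}, {"c": 3, "d": 4}, {"e": 5}]
```
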

+ 7 - 0
tests/sentry/buffer/test_base.py

@@ -2,6 +2,7 @@ from datetime import timedelta
 from unittest import mock
 
 from django.utils import timezone
+from pytest import raises
 
 from sentry.buffer.base import Buffer
 from sentry.db import models
@@ -77,3 +78,9 @@ class BufferTest(TestCase):
         self.buf.process(Group, columns, filters, {"last_seen": the_date}, signal_only=True)
         group.refresh_from_db()
         assert group.times_seen == prev_times_seen
+
+    def test_push_to_hash_bulk(self):
+        raises(NotImplementedError, self.buf.push_to_hash_bulk, Group, {"id": 1}, {"foo": "bar"})
+
+    def test_get_hash_length(self):
+        raises(NotImplementedError, self.buf.get_hash_length, Group, {"id": 1})

+ 31 - 0
tests/sentry/buffer/test_redis.py

@@ -1,6 +1,7 @@
 import datetime
 import pickle
 from collections import defaultdict
+from collections.abc import Mapping
 from unittest import mock
 from unittest.mock import Mock
 
@@ -367,6 +368,36 @@ class TestRedisBuffer:
         self.buf.process("foo")
         process.assert_called_once_with(mock.Mock, {"times_seen": 1}, {"pk": 1}, {}, True)
 
+    @mock.patch("sentry.buffer.redis.RedisBuffer._make_key", mock.Mock(return_value="foo"))
+    def test_get_hash_length(self):
+        client = get_cluster_routing_client(self.buf.cluster, self.buf.is_redis_cluster)
+        data: Mapping[str | bytes, bytes | float | int | str] = {
+            "f": '{"pk": ["i","1"]}',
+            "i+times_seen": "1",
+            "m": "unittest.mock.Mock",
+            "s": "1",
+        }
+
+        client.hmset("foo", data)
+        buffer_length = self.buf.get_hash_length("foo", field={"bar": 1})
+        assert buffer_length == len(data)
+
+    @mock.patch("sentry.buffer.redis.RedisBuffer._make_key", mock.Mock(return_value="foo"))
+    def test_push_to_hash_bulk(self):
+        def decode_dict(d):
+            return {k: v.decode("utf-8") if isinstance(v, bytes) else v for k, v in d.items()}
+
+        client = get_cluster_routing_client(self.buf.cluster, self.buf.is_redis_cluster)
+        data = {
+            "f": '{"pk": ["i","1"]}',
+            "i+times_seen": "1",
+            "m": "unittest.mock.Mock",
+            "s": "1",
+        }
+        self.buf.push_to_hash_bulk(model=Project, filters={"project_id": 1}, data=data)
+        result = _hgetall_decode_keys(client, "foo", self.buf.is_redis_cluster)
+        assert decode_dict(result) == data
+
 
 #    @mock.patch("sentry.buffer.redis.RedisBuffer._make_key", mock.Mock(return_value="foo"))
 #    def test_incr_uses_signal_only(self):

+ 132 - 17
tests/sentry/rules/processing/test_delayed_processing.py

@@ -25,6 +25,7 @@ from sentry.rules.processing.delayed_processing import (
     apply_delayed,
     bucket_num_groups,
     bulk_fetch_events,
+    cleanup_redis_buffer,
     generate_unique_queries,
     get_condition_group_results,
     get_condition_query_groups,
@@ -34,11 +35,13 @@ from sentry.rules.processing.delayed_processing import (
     get_slow_conditions,
     parse_rulegroup_to_event_data,
     process_delayed_alert_conditions,
+    process_rulegroups_in_batches,
 )
 from sentry.rules.processing.processor import PROJECT_ID_BUFFER_LIST_KEY
 from sentry.testutils.cases import PerformanceIssueTestCase, RuleTestCase, TestCase
 from sentry.testutils.factories import EventType
 from sentry.testutils.helpers.datetime import before_now, freeze_time, iso_format
+from sentry.testutils.helpers.options import override_options
 from sentry.testutils.helpers.redis import mock_redis_buffer
 from sentry.utils import json
 from sentry.utils.safe import safe_execute
@@ -83,6 +86,23 @@ def mock_get_condition_group(descending=False):
 
 @freeze_time(FROZEN_TIME)
 class CreateEventTestCase(TestCase, BaseEventFrequencyPercentTest):
+    def setUp(self):
+        super().setUp()
+        self.mock_redis_buffer = mock_redis_buffer()
+        self.mock_redis_buffer.__enter__()
+
+    def tearDown(self):
+        self.mock_redis_buffer.__exit__(None, None, None)
+
+    def push_to_hash(self, project_id, rule_id, group_id, event_id=None, occurrence_id=None):
+        value = json.dumps({"event_id": event_id, "occurrence_id": occurrence_id})
+        buffer.backend.push_to_hash(
+            model=Project,
+            filters={"project_id": project_id},
+            field=f"{rule_id}:{group_id}",
+            value=value,
+        )
+
     def create_event(
         self,
         project_id: int,
@@ -643,23 +663,12 @@ class ParseRuleGroupToEventDataTest(TestCase):
 class ProcessDelayedAlertConditionsTest(CreateEventTestCase, PerformanceIssueTestCase):
     buffer_timestamp = (FROZEN_TIME + timedelta(seconds=1)).timestamp()
 
-    def push_to_hash(self, project_id, rule_id, group_id, event_id=None, occurrence_id=None):
-        value = json.dumps({"event_id": event_id, "occurrence_id": occurrence_id})
-        buffer.backend.push_to_hash(
-            model=Project,
-            filters={"project_id": project_id},
-            field=f"{rule_id}:{group_id}",
-            value=value,
-        )
-
     def assert_buffer_cleared(self, project_id):
         rule_group_data = buffer.backend.get_hash(Project, {"project_id": project_id})
         assert rule_group_data == {}
 
     def setUp(self):
         super().setUp()
-        self.mock_redis_buffer = mock_redis_buffer()
-        self.mock_redis_buffer.__enter__()
 
         self.tag_filter = {
             "id": "sentry.rules.filters.tagged_event.TaggedEventFilter",
@@ -747,11 +756,8 @@ class ProcessDelayedAlertConditionsTest(CreateEventTestCase, PerformanceIssueTes
         self.push_to_hash(self.project_two.id, self.rule3.id, self.group3.id, self.event3.event_id)
         self.push_to_hash(self.project_two.id, self.rule4.id, self.group4.id, self.event4.event_id)
 
-    def tearDown(self):
-        self.mock_redis_buffer.__exit__(None, None, None)
-
-    @patch("sentry.rules.processing.delayed_processing.apply_delayed")
-    def test_fetches_from_buffer_and_executes(self, mock_apply_delayed):
+    @patch("sentry.rules.processing.delayed_processing.process_rulegroups_in_batches")
+    def test_fetches_from_buffer_and_executes(self, mock_process_in_batches):
         self._push_base_events()
         # To get the correct mapping, we need to return the correct
         # rulegroup_event mapping based on the project_id input
@@ -761,7 +767,7 @@ class ProcessDelayedAlertConditionsTest(CreateEventTestCase, PerformanceIssueTes
             (self.project, self.rulegroup_event_mapping_one),
             (self.project_two, self.rulegroup_event_mapping_two),
         ):
-            assert mock_apply_delayed.delay.call_count == 2
+            assert mock_process_in_batches.call_count == 2
 
         project_ids = buffer.backend.get_sorted_set(
             PROJECT_ID_BUFFER_LIST_KEY, 0, self.buffer_timestamp
@@ -1326,6 +1332,58 @@ class ProcessDelayedAlertConditionsTest(CreateEventTestCase, PerformanceIssueTes
         self._assert_count_percent_results(safe_execute_callthrough)
 
 
+class ProcessRuleGroupsInBatchesTest(CreateEventTestCase):
+    def setUp(self):
+        super().setUp()
+
+        self.project = self.create_project()
+        self.group = self.create_group(self.project)
+        self.group_two = self.create_group(self.project)
+        self.group_three = self.create_group(self.project)
+        self.rule = self.create_alert_rule()
+
+    @patch("sentry.rules.processing.delayed_processing.apply_delayed")
+    def test_no_redis_data(self, mock_apply_delayed):
+        process_rulegroups_in_batches(self.project.id)
+        mock_apply_delayed.delayed.assert_called_once_with(self.project.id)
+
+    @patch("sentry.rules.processing.delayed_processing.apply_delayed")
+    def test_basic(self, mock_apply_delayed):
+        self.push_to_hash(self.project.id, self.rule.id, self.group.id)
+        self.push_to_hash(self.project.id, self.rule.id, self.group_two.id)
+        self.push_to_hash(self.project.id, self.rule.id, self.group_three.id)
+
+        process_rulegroups_in_batches(self.project.id)
+        mock_apply_delayed.delayed.assert_called_once_with(self.project.id)
+
+    @override_options({"delayed_processing.batch_size": 2})
+    @patch("sentry.rules.processing.delayed_processing.apply_delayed")
+    def test_batch(self, mock_apply_delayed):
+        mock_delayed = mock_apply_delayed.delayed
+        self.push_to_hash(self.project.id, self.rule.id, self.group.id)
+        self.push_to_hash(self.project.id, self.rule.id, self.group_two.id)
+        self.push_to_hash(self.project.id, self.rule.id, self.group_three.id)
+
+        process_rulegroups_in_batches(self.project.id)
+        assert mock_delayed.call_count == 2
+
+        # Validate the batches are created correctly
+        batch_one_key = mock_delayed.call_args_list[0][0][1]
+        batch_one = buffer.backend.get_hash(
+            model=Project, field={"project_id": self.project.id, "batch_key": batch_one_key}
+        )
+        batch_two_key = mock_delayed.call_args_list[1][0][1]
+        batch_two = buffer.backend.get_hash(
+            model=Project, field={"project_id": self.project.id, "batch_key": batch_two_key}
+        )
+
+        assert len(batch_one) == 2
+        assert len(batch_two) == 1
+
+        # Validate that we've cleared the original data to reduce storage usage
+        assert not buffer.backend.get_hash(model=Project, field={"project_id": self.project.id})
+
+
 class UniqueConditionQueryTest(TestCase):
     """
     Tests for the UniqueConditionQuery class. Currently, this is just to pass codecov.
@@ -1352,3 +1410,60 @@ class DataAndGroupsTest(TestCase):
             repr(condition)
             == "<DataAndGroups data: {'id': 'sentry.rules.conditions.event_frequency.EventFrequencyCondition', 'value': 1, 'interval': '1h'} group_ids: {1, 2}>"
         )
+
+
+class CleanupRedisBufferTest(CreateEventTestCase):
+    def setUp(self):
+        super().setUp()
+
+        self.project = self.create_project()
+        self.group = self.create_group(self.project)
+        self.rule = self.create_alert_rule()
+
+    def test_cleanup_redis(self):
+        self.push_to_hash(self.project.id, self.rule.id, self.group.id)
+        rules_to_groups: defaultdict[int, set[int]] = defaultdict(set)
+        rules_to_groups[self.rule.id].add(self.group.id)
+
+        cleanup_redis_buffer(self.project.id, rules_to_groups, None)
+        rule_group_data = buffer.backend.get_hash(Project, {"project_id": self.project.id})
+        assert rule_group_data == {}
+
+    @override_options({"delayed_processing.batch_size": 2})
+    @patch("sentry.rules.processing.delayed_processing.apply_delayed")
+    def test_batched_cleanup(self, mock_apply_delayed):
+        group_two = self.create_group(self.project)
+        group_three = self.create_group(self.project)
+
+        self.push_to_hash(self.project.id, self.rule.id, self.group.id)
+        self.push_to_hash(self.project.id, self.rule.id, group_two.id)
+        self.push_to_hash(self.project.id, self.rule.id, group_three.id)
+
+        rules_to_groups: defaultdict[int, set[int]] = defaultdict(set)
+        rules_to_groups[self.rule.id].add(self.group.id)
+        rules_to_groups[self.rule.id].add(group_two.id)
+        rules_to_groups[self.rule.id].add(group_three.id)
+
+        process_rulegroups_in_batches(self.project.id)
+        batch_one_key = mock_apply_delayed.delayed.call_args_list[0][0][1]
+        batch_two_key = mock_apply_delayed.delayed.call_args_list[1][0][1]
+
+        # Verify process_rulegroups_in_batches removed the data from the buffer
+        rule_group_data = buffer.backend.get_hash(Project, {"project_id": self.project.id})
+        assert rule_group_data == {}
+
+        cleanup_redis_buffer(self.project.id, rules_to_groups, batch_one_key)
+
+        # Verify the batch we "executed" is removed
+        rule_group_data = buffer.backend.get_hash(
+            Project, {"project_id": self.project.id, "batch_key": batch_one_key}
+        )
+        assert rule_group_data == {}
+
+        # Verify the batch we didn't execute is still in redis
+        rule_group_data = buffer.backend.get_hash(
+            Project, {"project_id": self.project.id, "batch_key": batch_two_key}
+        )
+        assert rule_group_data == {
+            f"{self.rule.id}:{group_three.id}": '{"event_id":null,"occurrence_id":null}',
+        }