Browse Source

feat(escalating-issues): Optimize query getting until-escalating groups (#47802)

## Objective:
Replace the suboptimal query that fetched all groups at once with one that
uses RangeQuerySetWrapper and iterates over the query results in batches.

First, we grab the list of all project IDs in batches of 10k at a time
(otherwise the query takes too long and gets killed). Then, for each batch
of projects, we schedule a task with the array of project IDs.

In each task, we query the list of ignored-until-escalating groups for each
project and collect them into a single list. We then pass the group IDs to
the function that queries Snuba (query_groups_past_counts) and creates the
forecasts.
NisanthanNanthakumar 1 year ago
parent
commit
f35cfb0889

+ 8 - 6
src/sentry/issues/forecasts.py

@@ -30,13 +30,15 @@ def save_forecast_per_group(
     """
     time = datetime.now()
     group_dict = {group.id: group for group in until_escalating_groups}
-    for group_id in group_counts.keys():
-        forecasts = generate_issue_forecast(group_counts[group_id], time)
+    for group_id, group_count in group_counts.items():
+        forecasts = generate_issue_forecast(group_count, time)
         forecasts_list = [forecast["forecasted_value"] for forecast in forecasts]
-        escalating_group_forecast = EscalatingGroupForecast(
-            group_dict[group_id].project.id, group_id, forecasts_list, datetime.now()
-        )
-        escalating_group_forecast.save()
+
+        if group_dict.get(group_id):
+            escalating_group_forecast = EscalatingGroupForecast(
+                group_dict[group_id].project.id, group_id, forecasts_list, time
+            )
+            escalating_group_forecast.save()
     logger.info(
         "Saved forecasts in nodestore",
         extra={"num_groups": len(group_counts.keys())},

+ 38 - 18
src/sentry/tasks/weekly_escalating_forecast.py

@@ -1,12 +1,15 @@
 import logging
+from datetime import datetime, timedelta
 from typing import Dict, List, TypedDict
 
 from sentry_sdk.crons.decorator import monitor
 
 from sentry.issues.forecasts import generate_and_save_forecasts
-from sentry.models import Group, GroupStatus
+from sentry.models import Group, GroupStatus, ObjectStatus, Project
 from sentry.tasks.base import instrumented_task
 from sentry.types.group import GroupSubStatus
+from sentry.utils.iterators import chunked
+from sentry.utils.query import RangeQuerySetWrapper
 
 
 class GroupCount(TypedDict):
@@ -18,6 +21,8 @@ ParsedGroupsCount = Dict[int, GroupCount]
 
 logger = logging.getLogger(__name__)
 
+ITERATOR_CHUNK = 10_000
+
 
 @instrumented_task(
     name="sentry.tasks.weekly_escalating_forecast.run_escalating_forecast",
@@ -30,20 +35,35 @@ def run_escalating_forecast() -> None:
     Run the escalating forecast algorithm on archived until escalating issues.
     """
     logger.info("Starting task for sentry.tasks.weekly_escalating_forecast.run_escalating_forecast")
-    # TODO: Do not limit to project id = 1 and limit 10 once these topics are clarified
-    # TODO: If possible, fetch group_id instead of the entire group model
-    until_escalating_groups = list(
-        Group.objects.filter(
-            status=GroupStatus.IGNORED,
-            substatus=GroupSubStatus.UNTIL_ESCALATING,
-            project__id=1,
-        )[:10]
-    )
-    logger.info(
-        "Checking for archived until escalating groups",
-        extra={"num_groups": len(until_escalating_groups)},
-    )
-    if not until_escalating_groups:
-        return
-
-    generate_and_save_forecasts(until_escalating_groups)
+
+    for project_ids in chunked(
+        RangeQuerySetWrapper(
+            Project.objects.filter(status=ObjectStatus.VISIBLE).values_list("id", flat=True),
+            result_value_getter=lambda item: item,
+            step=ITERATOR_CHUNK,
+        ),
+        ITERATOR_CHUNK,
+    ):
+        generate_forecasts_for_projects.delay(project_ids=project_ids)
+
+
+@instrumented_task(
+    name="sentry.tasks.weekly_escalating_forecast.generate_forecasts_for_projects",
+    queue="weekly_escalating_forecast",
+    max_retries=3,
+    default_retry_delay=60,
+)  # type: ignore
+def generate_forecasts_for_projects(project_ids: List[int]) -> None:
+    for until_escalating_groups in chunked(
+        RangeQuerySetWrapper(
+            Group.objects.filter(
+                status=GroupStatus.IGNORED,
+                substatus=GroupSubStatus.UNTIL_ESCALATING,
+                project_id__in=project_ids,
+                last_seen__gte=datetime.now() - timedelta(days=7),
+            ),
+            step=ITERATOR_CHUNK,
+        ),
+        ITERATOR_CHUNK,
+    ):
+        generate_and_save_forecasts(groups=until_escalating_groups)

+ 72 - 64
tests/sentry/tasks/test_weekly_escalating_forecast.py

@@ -33,91 +33,99 @@ class TestWeeklyEscalatingForecast(APITestCase, SnubaTestCase):  # type: ignore
     def test_empty_escalating_forecast(
         self, mock_query_groups_past_counts: MagicMock, record_mock: MagicMock
     ) -> None:
-        group_list = self.create_archived_until_escalating_groups(num_groups=1)
+        with self.tasks():
+            group_list = self.create_archived_until_escalating_groups(num_groups=1)
 
-        mock_query_groups_past_counts.return_value = {}
+            mock_query_groups_past_counts.return_value = {}
 
-        run_escalating_forecast()
-        fetched_forecast = EscalatingGroupForecast.fetch(group_list[0].project.id, group_list[0].id)
-        assert fetched_forecast is not None
-        assert fetched_forecast.project_id == group_list[0].project.id
-        assert fetched_forecast.group_id == group_list[0].id
-        assert fetched_forecast.forecast == DEFAULT_MINIMUM_CEILING_FORECAST
-        record_mock.assert_called_with("issue_forecasts.saved", num_groups=0)
+            run_escalating_forecast()
+            fetched_forecast = EscalatingGroupForecast.fetch(
+                group_list[0].project.id, group_list[0].id
+            )
+            assert fetched_forecast is not None
+            assert fetched_forecast.project_id == group_list[0].project.id
+            assert fetched_forecast.group_id == group_list[0].id
+            assert fetched_forecast.forecast == DEFAULT_MINIMUM_CEILING_FORECAST
+            record_mock.assert_called_with("issue_forecasts.saved", num_groups=0)
 
     @patch("sentry.analytics.record")
     @patch("sentry.issues.forecasts.query_groups_past_counts")
     def test_single_group_escalating_forecast(
         self, mock_query_groups_past_counts: MagicMock, record_mock: MagicMock
     ) -> None:
-        group_list = self.create_archived_until_escalating_groups(num_groups=1)
-
-        mock_query_groups_past_counts.return_value = get_mock_groups_past_counts_response(
-            num_days=7, num_hours=1, groups=group_list
-        )
-
-        run_escalating_forecast()
-        approximate_date_added = datetime.now(pytz.utc)
-        fetched_forecast = EscalatingGroupForecast.fetch(group_list[0].project.id, group_list[0].id)
-        assert fetched_forecast is not None
-        assert fetched_forecast.project_id == group_list[0].project.id
-        assert fetched_forecast.group_id == group_list[0].id
-        assert fetched_forecast.forecast == DEFAULT_MINIMUM_CEILING_FORECAST
-        assert fetched_forecast.date_added.replace(
-            second=0, microsecond=0
-        ) == approximate_date_added.replace(second=0, microsecond=0)
-        assert fetched_forecast.date_added < approximate_date_added
-        record_mock.assert_called_with("issue_forecasts.saved", num_groups=1)
+        with self.tasks():
+            group_list = self.create_archived_until_escalating_groups(num_groups=1)
 
-    @patch("sentry.analytics.record")
-    @patch("sentry.issues.forecasts.query_groups_past_counts")
-    def test_multiple_groups_escalating_forecast(
-        self, mock_query_groups_past_counts: MagicMock, record_mock: MagicMock
-    ) -> None:
-        group_list = self.create_archived_until_escalating_groups(num_groups=3)
-
-        mock_query_groups_past_counts.return_value = get_mock_groups_past_counts_response(
-            num_days=7, num_hours=23, groups=group_list
-        )
+            mock_query_groups_past_counts.return_value = get_mock_groups_past_counts_response(
+                num_days=7, num_hours=1, groups=group_list
+            )
 
-        run_escalating_forecast()
-        approximate_date_added = datetime.now(pytz.utc)
-        for i in range(len(group_list)):
+            run_escalating_forecast()
+            approximate_date_added = datetime.now(pytz.utc)
             fetched_forecast = EscalatingGroupForecast.fetch(
-                group_list[i].project.id, group_list[i].id
+                group_list[0].project.id, group_list[0].id
             )
             assert fetched_forecast is not None
-            assert fetched_forecast.project_id == group_list[i].project.id
-            assert fetched_forecast.group_id == group_list[i].id
+            assert fetched_forecast.project_id == group_list[0].project.id
+            assert fetched_forecast.group_id == group_list[0].id
             assert fetched_forecast.forecast == DEFAULT_MINIMUM_CEILING_FORECAST
             assert fetched_forecast.date_added.replace(
                 second=0, microsecond=0
             ) == approximate_date_added.replace(second=0, microsecond=0)
             assert fetched_forecast.date_added < approximate_date_added
-            record_mock.assert_called_with("issue_forecasts.saved", num_groups=3)
+            record_mock.assert_called_with("issue_forecasts.saved", num_groups=1)
+
+    @patch("sentry.analytics.record")
+    @patch("sentry.issues.forecasts.query_groups_past_counts")
+    def test_multiple_groups_escalating_forecast(
+        self, mock_query_groups_past_counts: MagicMock, record_mock: MagicMock
+    ) -> None:
+        with self.tasks():
+            group_list = self.create_archived_until_escalating_groups(num_groups=3)
+
+            mock_query_groups_past_counts.return_value = get_mock_groups_past_counts_response(
+                num_days=7, num_hours=23, groups=group_list
+            )
+
+            run_escalating_forecast()
+            approximate_date_added = datetime.now(pytz.utc)
+            for i in range(len(group_list)):
+                fetched_forecast = EscalatingGroupForecast.fetch(
+                    group_list[i].project.id, group_list[i].id
+                )
+                assert fetched_forecast is not None
+                assert fetched_forecast.project_id == group_list[i].project.id
+                assert fetched_forecast.group_id == group_list[i].id
+                assert fetched_forecast.forecast == DEFAULT_MINIMUM_CEILING_FORECAST
+                assert fetched_forecast.date_added.replace(
+                    second=0, microsecond=0
+                ) == approximate_date_added.replace(second=0, microsecond=0)
+                assert fetched_forecast.date_added < approximate_date_added
+                record_mock.assert_called_with("issue_forecasts.saved", num_groups=3)
 
     @patch("sentry.analytics.record")
     @patch("sentry.issues.forecasts.query_groups_past_counts")
     def test_update_group_escalating_forecast(
         self, mock_query_groups_past_counts: MagicMock, record_mock: MagicMock
     ) -> None:
-        group_list = self.create_archived_until_escalating_groups(num_groups=1)
-
-        mock_query_groups_past_counts.return_value = get_mock_groups_past_counts_response(
-            num_days=7, num_hours=2, groups=group_list
-        )
-
-        run_escalating_forecast()
-        first_fetched_forecast = EscalatingGroupForecast.fetch(
-            group_list[0].project.id, group_list[0].id
-        )
-
-        # Assert update when this is run twice
-        run_escalating_forecast()
-        second_fetched_forecast = EscalatingGroupForecast.fetch(
-            group_list[0].project.id, group_list[0].id
-        )
-        assert first_fetched_forecast is not None
-        assert second_fetched_forecast is not None
-        assert first_fetched_forecast.date_added < second_fetched_forecast.date_added
-        record_mock.assert_called_with("issue_forecasts.saved", num_groups=1)
+        with self.tasks():
+            group_list = self.create_archived_until_escalating_groups(num_groups=1)
+
+            mock_query_groups_past_counts.return_value = get_mock_groups_past_counts_response(
+                num_days=7, num_hours=2, groups=group_list
+            )
+
+            run_escalating_forecast()
+            first_fetched_forecast = EscalatingGroupForecast.fetch(
+                group_list[0].project.id, group_list[0].id
+            )
+
+            # Assert update when this is run twice
+            run_escalating_forecast()
+            second_fetched_forecast = EscalatingGroupForecast.fetch(
+                group_list[0].project.id, group_list[0].id
+            )
+            assert first_fetched_forecast is not None
+            assert second_fetched_forecast is not None
+            assert first_fetched_forecast.date_added < second_fetched_forecast.date_added
+            record_mock.assert_called_with("issue_forecasts.saved", num_groups=1)