Browse Source

:recycle: ref(metric alerts): refactor action handler factory (#86459)

this pr breaks the `AlertRuleTriggerAction`, `Incident`, `Project`
dependency for the `ActionHandlerFactory`.

this will allow us to define new entrypoints for the action handlers for
each integration which can invoke the underlying business logic using
legacy models of new models using a translation layer (pr for this comes
later).

instead of putting `action`, `incident`, `project` as member variables,
u pass it to the `fire`, `resolve` methods directly.

i verified that firing notifications works locally

![image](https://github.com/user-attachments/assets/73851292-1f3a-40c6-b36b-bc199a3afbfa)


## Review Note
i have a local commit to update the 47 broken tests, but waiting for
cursory reviews first since it will massively increase the diff. the
failing tests are for `test_subscription_processor` and i need to change
assertions for majority of the 86 tests , which inflates the diff.

UPDATE: i just pushed the commit. i would recommend ignoring the commit
when reviewing and look at it separately
Raj Joshi 5 days ago
parent
commit
090cdb7e1e

+ 97 - 43
src/sentry/incidents/action_handlers.py

@@ -51,19 +51,12 @@ class ActionHandler(metaclass=abc.ABCMeta):
     def provider(self) -> str:
         raise NotImplementedError
 
-    def __init__(
+    @abc.abstractmethod
+    def fire(
         self,
         action: AlertRuleTriggerAction,
         incident: Incident,
         project: Project,
-    ) -> None:
-        self.action = action
-        self.incident = incident
-        self.project = project
-
-    @abc.abstractmethod
-    def fire(
-        self,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
@@ -73,6 +66,9 @@ class ActionHandler(metaclass=abc.ABCMeta):
     @abc.abstractmethod
     def resolve(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
@@ -80,14 +76,19 @@ class ActionHandler(metaclass=abc.ABCMeta):
         pass
 
     def record_alert_sent_analytics(
-        self, external_id: int | str | None = None, notification_uuid: str | None = None
+        self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
+        external_id: int | str | None = None,
+        notification_uuid: str | None = None,
     ) -> None:
         analytics.record(
             "alert.sent",
-            organization_id=self.incident.organization_id,
-            project_id=self.project.id,
+            organization_id=incident.organization_id,
+            project_id=project.id,
             provider=self.provider,
-            alert_id=self.incident.alert_rule_id,
+            alert_id=incident.alert_rule_id,
             alert_type="metric_alert",
             external_id=str(external_id) if external_id is not None else "",
             notification_uuid=notification_uuid or "",
@@ -97,25 +98,34 @@ class ActionHandler(metaclass=abc.ABCMeta):
 class DefaultActionHandler(ActionHandler):
     def fire(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
     ) -> None:
-        if not RuleSnooze.objects.is_snoozed_for_all(alert_rule=self.incident.alert_rule):
-            self.send_alert(metric_value, new_status, notification_uuid)
+        if not RuleSnooze.objects.is_snoozed_for_all(alert_rule=incident.alert_rule):
+            self.send_alert(action, incident, project, metric_value, new_status, notification_uuid)
 
     def resolve(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
     ) -> None:
-        if not RuleSnooze.objects.is_snoozed_for_all(alert_rule=self.incident.alert_rule):
-            self.send_alert(metric_value, new_status, notification_uuid)
+        if not RuleSnooze.objects.is_snoozed_for_all(alert_rule=incident.alert_rule):
+            self.send_alert(action, incident, project, metric_value, new_status, notification_uuid)
 
     @abc.abstractmethod
     def send_alert(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
@@ -133,24 +143,26 @@ class EmailActionHandler(ActionHandler):
     def provider(self) -> str:
         return "email"
 
-    def _get_targets(self) -> set[int]:
-        target = self.action.target
+    def _get_targets(
+        self, action: AlertRuleTriggerAction, incident: Incident, project: Project
+    ) -> set[int]:
+        target = action.target
         if not target:
             return set()
 
-        if RuleSnooze.objects.is_snoozed_for_all(alert_rule=self.incident.alert_rule):
+        if RuleSnooze.objects.is_snoozed_for_all(alert_rule=incident.alert_rule):
             return set()
 
-        if self.action.target_type == AlertRuleTriggerAction.TargetType.USER.value:
+        if action.target_type == AlertRuleTriggerAction.TargetType.USER.value:
             assert isinstance(target, RpcUser)
             if RuleSnooze.objects.is_snoozed_for_user(
-                user_id=target.id, alert_rule=self.incident.alert_rule
+                user_id=target.id, alert_rule=incident.alert_rule
             ):
                 return set()
 
             return {target.id}
 
-        elif self.action.target_type == AlertRuleTriggerAction.TargetType.TEAM.value:
+        elif action.target_type == AlertRuleTriggerAction.TargetType.TEAM.value:
             assert isinstance(target, Team)
             out = get_notification_recipients(
                 recipients=list(
@@ -158,29 +170,41 @@ class EmailActionHandler(ActionHandler):
                     for member in target.member_set
                 ),
                 type=NotificationSettingEnum.ISSUE_ALERTS,
-                organization_id=self.project.organization_id,
-                project_ids=[self.project.id],
+                organization_id=incident.organization_id,
+                project_ids=[project.id],
                 actor_type=ActorType.USER,
             )
             users = out[ExternalProviders.EMAIL]
 
             snoozed_users = RuleSnooze.objects.filter(
-                alert_rule=self.incident.alert_rule, user_id__in=[user.id for user in users]
+                alert_rule=incident.alert_rule, user_id__in=[user.id for user in users]
             ).values_list("user_id", flat=True)
             return {user.id for user in users if user.id not in snoozed_users}
 
         return set()
 
-    def get_targets(self) -> Sequence[tuple[int, str]]:
-        return list(get_email_addresses(self._get_targets(), project=self.project).items())
+    def get_targets(
+        self, action: AlertRuleTriggerAction, incident: Incident, project: Project
+    ) -> Sequence[tuple[int, str]]:
+        return list(
+            get_email_addresses(
+                self._get_targets(action, incident, project), project=project
+            ).items()
+        )
 
     def fire(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
     ) -> None:
         self.email_users(
+            action,
+            incident,
+            project,
             trigger_status=TriggerStatus.ACTIVE,
             incident_status=new_status,
             notification_uuid=notification_uuid,
@@ -188,11 +212,17 @@ class EmailActionHandler(ActionHandler):
 
     def resolve(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
     ) -> None:
         self.email_users(
+            action,
+            incident,
+            project,
             trigger_status=TriggerStatus.RESOLVED,
             incident_status=new_status,
             notification_uuid=notification_uuid,
@@ -200,25 +230,30 @@ class EmailActionHandler(ActionHandler):
 
     def email_users(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         trigger_status: TriggerStatus,
         incident_status: IncidentStatus,
         notification_uuid: str | None = None,
     ) -> None:
-        targets = [(user_id, email) for user_id, email in self.get_targets()]
+        targets = [
+            (user_id, email) for user_id, email in self.get_targets(action, incident, project)
+        ]
         users = user_service.get_many_by_id(ids=[user_id for user_id, _ in targets])
         for index, (user_id, email) in enumerate(targets):
             user = users[index]
             email_context = generate_incident_trigger_email_context(
-                project=self.project,
-                incident=self.incident,
-                alert_rule_trigger=self.action.alert_rule_trigger,
+                project=project,
+                incident=incident,
+                alert_rule_trigger=action.alert_rule_trigger,
                 trigger_status=trigger_status,
                 incident_status=incident_status,
                 user=user,
                 notification_uuid=notification_uuid,
             )
             self.build_message(email_context, trigger_status, user_id).send_async(to=[email])
-            self.record_alert_sent_analytics(user_id, notification_uuid)
+            self.record_alert_sent_analytics(action, incident, project, user_id, notification_uuid)
 
     def build_message(
         self, context: dict[str, Any], status: TriggerStatus, user_id: int
@@ -227,7 +262,7 @@ class EmailActionHandler(ActionHandler):
 
         return MessageBuilder(
             subject="[{}] {} - {}".format(
-                context["status"], context["incident_name"], self.project.slug
+                context["status"], context["incident_name"], context["project_slug"]
             ),
             template="sentry/emails/incidents/trigger.txt",
             html_template="sentry/emails/incidents/trigger.html",
@@ -250,6 +285,9 @@ class PagerDutyActionHandler(DefaultActionHandler):
 
     def send_alert(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
@@ -257,14 +295,16 @@ class PagerDutyActionHandler(DefaultActionHandler):
         from sentry.integrations.pagerduty.utils import send_incident_alert_notification
 
         success = send_incident_alert_notification(
-            action=self.action,
-            incident=self.incident,
+            action=action,
+            incident=incident,
             new_status=new_status,
             metric_value=metric_value,
             notification_uuid=notification_uuid,
         )
         if success:
-            self.record_alert_sent_analytics(self.action.target_identifier, notification_uuid)
+            self.record_alert_sent_analytics(
+                action, incident, project, action.target_identifier, notification_uuid
+            )
 
 
 @AlertRuleTriggerAction.register_type(
@@ -280,6 +320,9 @@ class OpsgenieActionHandler(DefaultActionHandler):
 
     def send_alert(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
@@ -287,14 +330,16 @@ class OpsgenieActionHandler(DefaultActionHandler):
         from sentry.integrations.opsgenie.utils import send_incident_alert_notification
 
         success = send_incident_alert_notification(
-            action=self.action,
-            incident=self.incident,
+            action=action,
+            incident=incident,
             new_status=new_status,
             metric_value=metric_value,
             notification_uuid=notification_uuid,
         )
         if success:
-            self.record_alert_sent_analytics(self.action.target_identifier, notification_uuid)
+            self.record_alert_sent_analytics(
+                action, incident, project, action.target_identifier, notification_uuid
+            )
 
 
 @AlertRuleTriggerAction.register_type(
@@ -309,6 +354,9 @@ class SentryAppActionHandler(DefaultActionHandler):
 
     def send_alert(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
@@ -316,10 +364,16 @@ class SentryAppActionHandler(DefaultActionHandler):
         from sentry.rules.actions.notify_event_service import send_incident_alert_notification
 
         success = send_incident_alert_notification(
-            self.action, self.incident, new_status, metric_value, notification_uuid
+            action=action,
+            incident=incident,
+            new_status=new_status,
+            metric_value=metric_value,
+            notification_uuid=notification_uuid,
         )
         if success:
-            self.record_alert_sent_analytics(self.action.sentry_app_id, notification_uuid)
+            self.record_alert_sent_analytics(
+                action, incident, project, action.sentry_app_id, notification_uuid
+            )
 
 
 def format_duration(minutes):
@@ -344,7 +398,7 @@ def format_duration(minutes):
 
 
 def generate_incident_trigger_email_context(
-    project,
+    project: Project,
     incident: Incident,
     alert_rule_trigger: AlertRuleTrigger,
     trigger_status: TriggerStatus,

+ 16 - 21
src/sentry/incidents/models/alert_rule.py

@@ -376,12 +376,7 @@ class ActionHandlerFactory(abc.ABC):
         self.integration_provider = integration_provider
 
     @abc.abstractmethod
-    def build_handler(
-        self,
-        action: AlertRuleTriggerAction,
-        incident: Incident,
-        project: Project,
-    ) -> ActionHandler:
+    def build_handler(self) -> ActionHandler:
         raise NotImplementedError
 
 
@@ -403,10 +398,8 @@ class _AlertRuleActionHandlerClassFactory(ActionHandlerFactory):
         super().__init__(slug, service_type, supported_target_types, integration_provider)
         self.trigger_action_class = trigger_action_class
 
-    def build_handler(
-        self, action: AlertRuleTriggerAction, incident: Incident, project: Project
-    ) -> ActionHandler:
-        return self.trigger_action_class(action, incident, project)
+    def build_handler(self) -> ActionHandler:
+        return self.trigger_action_class()
 
 
 class _FactoryRegistry:
@@ -492,15 +485,13 @@ class AlertRuleTriggerAction(AbstractNotificationAction):
             return self.target_identifier
         return None
 
-    def build_handler(
-        self, action: AlertRuleTriggerAction, incident: Incident, project: Project
-    ) -> ActionHandler | None:
-        service_type = AlertRuleTriggerAction.Type(self.type)
-        factory = self._factory_registrations.by_action_service.get(service_type)
+    @staticmethod
+    def build_handler(type: ActionService) -> ActionHandler | None:
+        factory = AlertRuleTriggerAction._factory_registrations.by_action_service.get(type)
         if factory is not None:
-            return factory.build_handler(action, incident, project)
+            return factory.build_handler()
         else:
-            metrics.incr(f"alert_rule_trigger.unhandled_type.{self.type}")
+            metrics.incr(f"alert_rule_trigger.unhandled_type.{type}")
             return None
 
     def fire(
@@ -512,9 +503,11 @@ class AlertRuleTriggerAction(AbstractNotificationAction):
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
     ) -> None:
-        handler = self.build_handler(action, incident, project)
+        handler = AlertRuleTriggerAction.build_handler(AlertRuleTriggerAction.Type(self.type))
         if handler:
-            return handler.fire(metric_value, new_status, notification_uuid)
+            return handler.fire(
+                action, incident, project, metric_value, new_status, notification_uuid
+            )
 
     def resolve(
         self,
@@ -525,9 +518,11 @@ class AlertRuleTriggerAction(AbstractNotificationAction):
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
     ) -> None:
-        handler = self.build_handler(action, incident, project)
+        handler = AlertRuleTriggerAction.build_handler(AlertRuleTriggerAction.Type(self.type))
         if handler:
-            return handler.resolve(metric_value, new_status, notification_uuid)
+            return handler.resolve(
+                action, incident, project, metric_value, new_status, notification_uuid
+            )
 
     def get_single_sentry_app_config(self) -> dict[str, Any] | None:
         value = self.sentry_app_config

+ 11 - 14
src/sentry/integrations/messaging/spec.py

@@ -161,14 +161,8 @@ class MessagingIntegrationSpec(ABC):
 
 
 class MessagingActionHandler(DefaultActionHandler):
-    def __init__(
-        self,
-        action: AlertRuleTriggerAction,
-        incident: Incident,
-        project: Project,
-        spec: MessagingIntegrationSpec,
-    ):
-        super().__init__(action, incident, project)
+    def __init__(self, spec: MessagingIntegrationSpec):
+        super().__init__()
         self._spec = spec
 
     @property
@@ -177,15 +171,20 @@ class MessagingActionHandler(DefaultActionHandler):
 
     def send_alert(
         self,
+        action: AlertRuleTriggerAction,
+        incident: Incident,
+        project: Project,
         metric_value: int | float,
         new_status: IncidentStatus,
         notification_uuid: str | None = None,
     ) -> None:
         success = self._spec.send_incident_alert_notification(
-            self.action, self.incident, metric_value, new_status, notification_uuid
+            action, incident, metric_value, new_status, notification_uuid
         )
         if success:
-            self.record_alert_sent_analytics(self.action.target_identifier, notification_uuid)
+            self.record_alert_sent_analytics(
+                action, incident, project, action.target_identifier, notification_uuid
+            )
 
 
 class _MessagingHandlerFactory(ActionHandlerFactory):
@@ -198,7 +197,5 @@ class _MessagingHandlerFactory(ActionHandlerFactory):
         )
         self.spec = spec
 
-    def build_handler(
-        self, action: AlertRuleTriggerAction, incident: Incident, project: Project
-    ) -> ActionHandler:
-        return MessagingActionHandler(action, incident, project, self.spec)
+    def build_handler(self) -> ActionHandler:
+        return MessagingActionHandler(self.spec)

+ 19 - 13
tests/sentry/incidents/action_handlers/test_discord.py

@@ -21,7 +21,7 @@ class DiscordActionHandlerTest(FireTest):
     @responses.activate
     def setUp(self):
         self.spec = DiscordMessagingSpec()
-
+        self.handler = MessagingActionHandler(self.spec)
         self.guild_id = "guild-id"
         self.channel_id = "12345678910"
         self.discord_user_id = "user1234"
@@ -59,10 +59,11 @@ class DiscordActionHandlerTest(FireTest):
             status=200,
         )
 
-        handler = MessagingActionHandler(self.action, incident, self.project, self.spec)
         metric_value = 1000
         with self.tasks():
-            getattr(handler, method)(metric_value, IncidentStatus(incident.status))
+            getattr(self.handler, method)(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         data = orjson.loads(responses.calls[0].request.body)
         return data
@@ -86,10 +87,11 @@ class DiscordActionHandlerTest(FireTest):
             status=200,
         )
 
-        handler = MessagingActionHandler(self.action, incident, self.project, self.spec)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         assert len(responses.calls) == 0
 
@@ -99,10 +101,11 @@ class DiscordActionHandlerTest(FireTest):
     def test_metric_alert_failure(self, mock_record_event, mock_send_message):
         alert_rule = self.create_alert_rule()
         incident = self.create_incident(alert_rule=alert_rule, status=IncidentStatus.CLOSED.value)
-        handler = MessagingActionHandler(self.action, incident, self.project, self.spec)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus.WARNING)
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus.WARNING
+            )
 
         assert_slo_metric(mock_record_event, EventLifecycleOutcome.FAILURE)
 
@@ -114,10 +117,11 @@ class DiscordActionHandlerTest(FireTest):
     def test_metric_alert_halt_for_rate_limited(self, mock_record_event, mock_send_message):
         alert_rule = self.create_alert_rule()
         incident = self.create_incident(alert_rule=alert_rule, status=IncidentStatus.CLOSED.value)
-        handler = MessagingActionHandler(self.action, incident, self.project, self.spec)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus.WARNING)
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus.WARNING
+            )
 
         assert_slo_metric(mock_record_event, EventLifecycleOutcome.HALTED)
 
@@ -132,10 +136,11 @@ class DiscordActionHandlerTest(FireTest):
     def test_metric_alert_halt_for_missing_access(self, mock_record_event, mock_send_message):
         alert_rule = self.create_alert_rule()
         incident = self.create_incident(alert_rule=alert_rule, status=IncidentStatus.CLOSED.value)
-        handler = MessagingActionHandler(self.action, incident, self.project, self.spec)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus.WARNING)
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus.WARNING
+            )
 
         assert_slo_metric(mock_record_event, EventLifecycleOutcome.HALTED)
 
@@ -147,9 +152,10 @@ class DiscordActionHandlerTest(FireTest):
     def test_metric_alert_halt_for_other_api_error(self, mock_record_event, mock_send_message):
         alert_rule = self.create_alert_rule()
         incident = self.create_incident(alert_rule=alert_rule, status=IncidentStatus.CLOSED.value)
-        handler = MessagingActionHandler(self.action, incident, self.project, self.spec)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus.WARNING)
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus.WARNING
+            )
 
         assert_slo_metric(mock_record_event, EventLifecycleOutcome.FAILURE)

+ 25 - 23
tests/sentry/incidents/action_handlers/test_email.py

@@ -50,9 +50,9 @@ class EmailActionHandlerTest(FireTest):
             target_identifier=str(self.user.id),
             triggered_for_incident=incident,
         )
-        handler = EmailActionHandler(action, incident, self.project)
+        handler = EmailActionHandler()
         with self.tasks():
-            handler.fire(1000, IncidentStatus(incident.status))
+            handler.fire(action, incident, self.project, 1000, IncidentStatus(incident.status))
         out = mail.outbox[0]
         assert out.to == [self.user.email]
         assert out.subject == "[{}] {} - {}".format(
@@ -81,6 +81,10 @@ class EmailActionHandlerTest(FireTest):
 
 
 class EmailActionHandlerGetTargetsTest(TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.handler = EmailActionHandler()
+
     @cached_property
     def incident(self):
         return self.create_incident()
@@ -90,8 +94,9 @@ class EmailActionHandlerGetTargetsTest(TestCase):
             target_type=AlertRuleTriggerAction.TargetType.USER,
             target_identifier=str(self.user.id),
         )
-        handler = EmailActionHandler(action, self.incident, self.project)
-        assert handler.get_targets() == [(self.user.id, self.user.email)]
+        assert self.handler.get_targets(action, self.incident, self.project) == [
+            (self.user.id, self.user.email)
+        ]
 
     def test_rule_snoozed_by_user(self):
         action = self.create_alert_rule_trigger_action(
@@ -99,18 +104,16 @@ class EmailActionHandlerGetTargetsTest(TestCase):
             target_identifier=str(self.user.id),
         )
 
-        handler = EmailActionHandler(action, self.incident, self.project)
         self.snooze_rule(user_id=self.user.id, alert_rule=self.incident.alert_rule)
-        assert handler.get_targets() == []
+        assert self.handler.get_targets(action, self.incident, self.project) == []
 
     def test_user_rule_snoozed(self):
         action = self.create_alert_rule_trigger_action(
             target_type=AlertRuleTriggerAction.TargetType.USER,
             target_identifier=str(self.user.id),
         )
-        handler = EmailActionHandler(action, self.incident, self.project)
         self.snooze_rule(alert_rule=self.incident.alert_rule)
-        assert handler.get_targets() == []
+        assert self.handler.get_targets(action, self.incident, self.project) == []
 
     def test_user_alerts_disabled(self):
         with assume_test_silo_mode_of(NotificationSettingOption):
@@ -125,8 +128,9 @@ class EmailActionHandlerGetTargetsTest(TestCase):
             target_type=AlertRuleTriggerAction.TargetType.USER,
             target_identifier=str(self.user.id),
         )
-        handler = EmailActionHandler(action, self.incident, self.project)
-        assert handler.get_targets() == [(self.user.id, self.user.email)]
+        assert self.handler.get_targets(action, self.incident, self.project) == [
+            (self.user.id, self.user.email)
+        ]
 
     def test_team(self):
         new_user = self.create_user()
@@ -135,8 +139,7 @@ class EmailActionHandlerGetTargetsTest(TestCase):
             target_type=AlertRuleTriggerAction.TargetType.TEAM,
             target_identifier=str(self.team.id),
         )
-        handler = EmailActionHandler(action, self.incident, self.project)
-        assert set(handler.get_targets()) == {
+        assert set(self.handler.get_targets(action, self.incident, self.project)) == {
             (self.user.id, self.user.email),
             (new_user.id, new_user.email),
         }
@@ -148,9 +151,8 @@ class EmailActionHandlerGetTargetsTest(TestCase):
             target_type=AlertRuleTriggerAction.TargetType.TEAM,
             target_identifier=str(self.team.id),
         )
-        handler = EmailActionHandler(action, self.incident, self.project)
         self.snooze_rule(user_id=new_user.id, alert_rule=self.incident.alert_rule)
-        assert set(handler.get_targets()) == {
+        assert set(self.handler.get_targets(action, self.incident, self.project)) == {
             (self.user.id, self.user.email),
         }
 
@@ -161,9 +163,8 @@ class EmailActionHandlerGetTargetsTest(TestCase):
             target_type=AlertRuleTriggerAction.TargetType.TEAM,
             target_identifier=str(self.team.id),
         )
-        handler = EmailActionHandler(action, self.incident, self.project)
         self.snooze_rule(alert_rule=self.incident.alert_rule)
-        assert handler.get_targets() == []
+        assert self.handler.get_targets(action, self.incident, self.project) == []
 
     def test_team_alert_disabled(self):
         with assume_test_silo_mode_of(NotificationSettingOption):
@@ -189,8 +190,9 @@ class EmailActionHandlerGetTargetsTest(TestCase):
             target_type=AlertRuleTriggerAction.TargetType.TEAM,
             target_identifier=str(self.team.id),
         )
-        handler = EmailActionHandler(action, self.incident, self.project)
-        assert set(handler.get_targets()) == {(new_user.id, new_user.email)}
+        assert set(self.handler.get_targets(action, self.incident, self.project)) == {
+            (new_user.id, new_user.email),
+        }
 
     def test_user_email_routing(self):
         new_email = "marcos@sentry.io"
@@ -207,8 +209,9 @@ class EmailActionHandlerGetTargetsTest(TestCase):
             target_type=AlertRuleTriggerAction.TargetType.USER,
             target_identifier=str(self.user.id),
         )
-        handler = EmailActionHandler(action, self.incident, self.project)
-        assert handler.get_targets() == [(self.user.id, new_email)]
+        assert self.handler.get_targets(action, self.incident, self.project) == [
+            (self.user.id, new_email),
+        ]
 
     def test_team_email_routing(self):
         new_email = "marcos@sentry.io"
@@ -232,11 +235,10 @@ class EmailActionHandlerGetTargetsTest(TestCase):
             target_type=AlertRuleTriggerAction.TargetType.TEAM,
             target_identifier=str(self.team.id),
         )
-        handler = EmailActionHandler(action, self.incident, self.project)
-        assert set(handler.get_targets()) == {
+        assert self.handler.get_targets(action, self.incident, self.project) == [
             (self.user.id, new_email),
             (new_user.id, new_email),
-        }
+        ]
 
 
 @freeze_time()

+ 13 - 8
tests/sentry/incidents/action_handlers/test_msteams.py

@@ -41,6 +41,7 @@ class MsTeamsActionHandlerTest(FireTest):
     @responses.activate
     def setUp(self):
         self.spec = MsTeamsMessagingSpec()
+        self.handler = MessagingActionHandler(self.spec)
 
         with assume_test_silo_mode(SiloMode.CONTROL):
             integration = self.create_provider_integration(
@@ -86,10 +87,11 @@ class MsTeamsActionHandlerTest(FireTest):
             json={},
         )
 
-        handler = MessagingActionHandler(self.action, incident, self.project, self.spec)
         metric_value = 1000
         with self.tasks():
-            getattr(handler, method)(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
         data = json.loads(responses.calls[0].request.body)
 
         assert data["attachments"][0]["content"] == build_incident_attachment(
@@ -231,10 +233,11 @@ class MsTeamsActionHandlerTest(FireTest):
             json={},
         )
 
-        handler = MessagingActionHandler(self.action, incident, self.project, self.spec)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         assert len(responses.calls) == 0
 
@@ -253,10 +256,11 @@ class MsTeamsActionHandlerTest(FireTest):
             json={},
         )
 
-        handler = MessagingActionHandler(self.action, incident, self.project, self.spec)
         metric_value = 1000
         with self.tasks():
-            getattr(handler, "fire")(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         assert_slo_metric(mock_record, EventLifecycleOutcome.FAILURE)
 
@@ -280,9 +284,10 @@ class MsTeamsActionHandlerTest(FireTest):
             },
         )
 
-        handler = MessagingActionHandler(self.action, incident, self.project, self.spec)
         metric_value = 1000
         with self.tasks():
-            getattr(handler, "fire")(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         assert_slo_metric(mock_record, EventLifecycleOutcome.HALTED)

+ 13 - 8
tests/sentry/incidents/action_handlers/test_opsgenie.py

@@ -35,6 +35,7 @@ METADATA = {
 class OpsgenieActionHandlerTest(FireTest):
     @responses.activate
     def setUp(self):
+        self.handler = OpsgenieActionHandler()
         self.og_team = {"id": "123-id", "team": "cool-team", "integration_key": "1234-5678"}
         self.integration = self.create_provider_integration(
             provider="opsgenie", name="hello-world", external_id="hello-world", metadata=METADATA
@@ -210,10 +211,11 @@ class OpsgenieActionHandlerTest(FireTest):
             )
             expected_payload = attach_custom_priority(expected_payload, self.action, new_status)
 
-        handler = OpsgenieActionHandler(self.action, incident, self.project)
         metric_value = 1000
         with self.tasks():
-            getattr(handler, method)(metric_value, IncidentStatus(incident.status))
+            getattr(self.handler, method)(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
         data = responses.calls[0].request.body
 
         assert json.loads(data) == expected_payload
@@ -247,10 +249,11 @@ class OpsgenieActionHandlerTest(FireTest):
             json={},
             status=202,
         )
-        handler = OpsgenieActionHandler(self.action, incident, self.project)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         assert len(responses.calls) == 0
 
@@ -263,10 +266,11 @@ class OpsgenieActionHandlerTest(FireTest):
         with assume_test_silo_mode_of(Integration):
             self.integration.delete()
 
-        handler = OpsgenieActionHandler(self.action, incident, self.project)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         assert len(responses.calls) == 0
         assert (
@@ -284,10 +288,11 @@ class OpsgenieActionHandlerTest(FireTest):
         with assume_test_silo_mode_of(OrganizationIntegration):
             self.org_integration.save()
 
-        handler = OpsgenieActionHandler(self.action, incident, self.project)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         assert len(responses.calls) == 0
         assert (

+ 8 - 4
tests/sentry/incidents/action_handlers/test_pagerduty.py

@@ -32,6 +32,7 @@ from . import FireTest
 class PagerDutyActionHandlerTest(FireTest):
     def setUp(self):
         self.integration_key = "pfc73e8cb4s44d519f3d63d45b5q77g9"
+        self.handler = PagerDutyActionHandler()
         service = [
             {
                 "type": "service",
@@ -176,11 +177,13 @@ class PagerDutyActionHandlerTest(FireTest):
             status=202,
             content_type="application/json",
         )
-        handler = PagerDutyActionHandler(self.action, incident, self.project)
+
         metric_value = 1000
         new_status = IncidentStatus(incident.status)
         with self.tasks():
-            getattr(handler, method)(metric_value, new_status)
+            getattr(self.handler, method)(
+                self.action, incident, self.project, metric_value, new_status
+            )
         data = responses.calls[0].request.body
 
         expected_payload = build_incident_attachment(
@@ -242,10 +245,11 @@ class PagerDutyActionHandlerTest(FireTest):
             status=202,
             content_type="application/json",
         )
-        handler = PagerDutyActionHandler(self.action, incident, self.project)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         assert len(responses.calls) == 0
 

+ 12 - 6
tests/sentry/incidents/action_handlers/test_sentry_app.py

@@ -31,6 +31,8 @@ class SentryAppActionHandlerTest(FireTest):
             sentry_app=self.sentry_app,
         )
 
+        self.handler = SentryAppActionHandler()
+
     @responses.activate
     def run_test(self, incident, method):
         from sentry.rules.actions.notify_event_service import build_incident_attachment
@@ -43,10 +45,11 @@ class SentryAppActionHandlerTest(FireTest):
             body=json.dumps({"ok": "true"}),
         )
 
-        handler = SentryAppActionHandler(self.action, incident, self.project)
         metric_value = 1000
         with self.tasks():
-            getattr(handler, method)(metric_value, IncidentStatus(incident.status))
+            getattr(self.handler, method)(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
         data = responses.calls[0].request.body
         assert (
             json.dumps(
@@ -69,10 +72,11 @@ class SentryAppActionHandlerTest(FireTest):
             body=json.dumps({"ok": "true"}),
         )
 
-        handler = SentryAppActionHandler(self.action, incident, self.project)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         assert len(responses.calls) == 0
 
@@ -100,6 +104,7 @@ class SentryAppAlertRuleUIComponentActionHandlerTest(FireTest):
         self.create_sentry_app_installation(
             slug=self.sentry_app.slug, organization=self.organization, user=self.user
         )
+        self.handler = SentryAppActionHandler()
 
     @responses.activate
     def run_test(self, incident, method):
@@ -129,10 +134,11 @@ class SentryAppAlertRuleUIComponentActionHandlerTest(FireTest):
             body=json.dumps({"ok": "true"}),
         )
 
-        handler = SentryAppActionHandler(self.action, incident, self.project)
         metric_value = 1000
         with self.tasks():
-            getattr(handler, method)(metric_value, IncidentStatus(incident.status))
+            getattr(self.handler, method)(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
         data = responses.calls[0].request.body
         assert (
             json.dumps(

+ 12 - 14
tests/sentry/incidents/action_handlers/test_slack.py

@@ -9,7 +9,7 @@ from slack_sdk.web import SlackResponse
 from sentry.constants import ObjectStatus
 from sentry.incidents.logic import update_incident_status
 from sentry.incidents.models.alert_rule import AlertRuleTriggerAction
-from sentry.incidents.models.incident import Incident, IncidentStatus, IncidentStatusMethod
+from sentry.incidents.models.incident import IncidentStatus, IncidentStatusMethod
 from sentry.integrations.messaging.spec import MessagingActionHandler
 from sentry.integrations.metric_alerts import AlertContext
 from sentry.integrations.slack.message_builder.incidents import SlackIncidentsMessageBuilder
@@ -45,6 +45,7 @@ class SlackActionHandlerTest(FireTest):
     @responses.activate
     def setUp(self):
         self.spec = SlackMessagingSpec()
+        self.handler = MessagingActionHandler(self.spec)
 
         token = "xoxp-xxxxxxxxx-xxxxxxxxxx-xxxxxxxxxxxx"
         self.integration = self.create_integration(
@@ -79,18 +80,12 @@ class SlackActionHandlerTest(FireTest):
         )
         self.alert_rule = self.create_alert_rule()
 
-    def _build_action_handler(
-        self, action: AlertRuleTriggerAction, incident: Incident
-    ) -> MessagingActionHandler:
-        return MessagingActionHandler(action, incident, self.project, self.spec)
-
     def run_test(self, incident, method, **kwargs):
         chart_url = kwargs.get("chart_url")
-        handler = self._build_action_handler(self.action, incident)
         metric_value = 1000
         status = IncidentStatus(incident.status)
         with self.tasks():
-            getattr(handler, method)(metric_value, status)
+            getattr(self.handler, method)(self.action, incident, self.project, metric_value, status)
 
         return incident, chart_url
 
@@ -256,10 +251,11 @@ class SlackActionHandlerTest(FireTest):
             sentry_app_id=None,
         )
 
-        handler = self._build_action_handler(action, incident)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
     @patch("sentry.integrations.slack.sdk_client.SlackSdkClient.chat_postMessage")
     def test_rule_snoozed(self, mock_post):
@@ -267,10 +263,11 @@ class SlackActionHandlerTest(FireTest):
         incident = self.create_incident(alert_rule=alert_rule, status=IncidentStatus.CLOSED.value)
         self.snooze_rule(alert_rule=alert_rule)
 
-        handler = self._build_action_handler(self.action, incident)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         assert not mock_post.called
 
@@ -283,10 +280,11 @@ class SlackActionHandlerTest(FireTest):
         incident = self.create_incident(alert_rule=alert_rule, status=IncidentStatus.CLOSED.value)
         self.snooze_rule(user_id=self.user.id, alert_rule=alert_rule)
 
-        handler = self._build_action_handler(self.action, incident)
         metric_value = 1000
         with self.tasks():
-            handler.fire(metric_value, IncidentStatus(incident.status))
+            self.handler.fire(
+                self.action, incident, self.project, metric_value, IncidentStatus(incident.status)
+            )
 
         mock_post.assert_called
 

Some files were not shown because too many files changed in this diff