Browse Source

fix(hybrid-cloud): Mark DiscordInteractionsEndpoint as an all silo endpoint (#60938)

Alberto Leal 1 year ago
parent
commit
2fe6fa4fa9

+ 2 - 2
src/sentry/integrations/discord/webhooks/base.py

@@ -8,7 +8,7 @@ from rest_framework.request import Request
 
 
 from sentry import analytics
 from sentry import analytics
 from sentry.api.api_publish_status import ApiPublishStatus
 from sentry.api.api_publish_status import ApiPublishStatus
-from sentry.api.base import Endpoint, region_silo_endpoint
+from sentry.api.base import Endpoint, all_silo_endpoint
 from sentry.integrations.discord.requests.base import DiscordRequest, DiscordRequestError
 from sentry.integrations.discord.requests.base import DiscordRequest, DiscordRequestError
 from sentry.integrations.discord.webhooks.command import DiscordCommandHandler
 from sentry.integrations.discord.webhooks.command import DiscordCommandHandler
 from sentry.integrations.discord.webhooks.message_component import DiscordMessageComponentHandler
 from sentry.integrations.discord.webhooks.message_component import DiscordMessageComponentHandler
@@ -19,7 +19,7 @@ from .types import DiscordResponseTypes
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-@region_silo_endpoint
+@all_silo_endpoint
 class DiscordInteractionsEndpoint(Endpoint):
 class DiscordInteractionsEndpoint(Endpoint):
     publish_status = {
     publish_status = {
         "POST": ApiPublishStatus.UNKNOWN,
         "POST": ApiPublishStatus.UNKNOWN,

+ 35 - 18
src/sentry/integrations/discord/webhooks/message_component.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 from collections.abc import Mapping
 from collections.abc import Mapping
+from functools import cached_property
 
 
 from rest_framework.response import Response
 from rest_framework.response import Response
 
 
@@ -40,6 +41,7 @@ RESOLVED_IN_CURRENT_RELEASE = "The issue will be resolved in the current release
 UNRESOLVED = "The issue has been unresolved."
 UNRESOLVED = "The issue has been unresolved."
 MARKED_ONGOING = "The issue has been marked as ongoing."
 MARKED_ONGOING = "The issue has been marked as ongoing."
 ARCHIVE_UNTIL_ESCALATES = "The issue will be archived until it escalates."
 ARCHIVE_UNTIL_ESCALATES = "The issue will be archived until it escalates."
+INVALID_GROUP_ID = "Invalid group ID."
 
 
 
 
 class DiscordMessageComponentHandler(DiscordInteractionHandler):
 class DiscordMessageComponentHandler(DiscordInteractionHandler):
@@ -54,8 +56,16 @@ class DiscordMessageComponentHandler(DiscordInteractionHandler):
         self.custom_id: str = request.get_component_custom_id()
         self.custom_id: str = request.get_component_custom_id()
         self.user: RpcUser
         self.user: RpcUser
         # Everything after the colon is the group id in a custom_id
         # Everything after the colon is the group id in a custom_id
-        self.group_id: str = self.custom_id.split(":")[1]
-        self.group: Group = Group.objects.get(id=self.group_id)
+        custom_id_parts = self.custom_id.split(":")
+        self.group_id: str = custom_id_parts[1] if len(custom_id_parts) > 1 else ""
+
+    @cached_property
+    def group(self) -> Group | None:
+        try:
+            group_id = int(self.group_id)
+            return Group.objects.filter(id=group_id).first()
+        except Exception:
+            return None
 
 
     def handle(self) -> Response:
     def handle(self) -> Response:
         logging_data = self.request.logging_data
         logging_data = self.request.logging_data
@@ -65,6 +75,9 @@ class DiscordMessageComponentHandler(DiscordInteractionHandler):
             return self.send_message(NO_IDENTITY)
             return self.send_message(NO_IDENTITY)
         self.user = self.request.user
         self.user = self.request.user
 
 
+        if (not self.group_id) or (not self.group):
+            return self.send_message(INVALID_GROUP_ID)
+
         if not self.group.organization.has_access(self.user):
         if not self.group.organization.has_access(self.user):
             logger.warning(
             logger.warning(
                 "discord.interaction.component.not_in_org",
                 "discord.interaction.component.not_in_org",
@@ -104,9 +117,12 @@ class DiscordMessageComponentHandler(DiscordInteractionHandler):
             return self.archive()
             return self.archive()
 
 
         logger.warning("discord.interaction.component.unknown_custom_id", extra={**logging_data})
         logger.warning("discord.interaction.component.unknown_custom_id", extra={**logging_data})
-        return Response(status=404)
+        return self.send_message(INVALID_GROUP_ID)
 
 
     def assign_dialog(self) -> Response:
     def assign_dialog(self) -> Response:
+        if (not self.group_id) or (not self.group):
+            return self.send_message(INVALID_GROUP_ID)
+
         assign_selector = DiscordSelectMenu(
         assign_selector = DiscordSelectMenu(
             custom_id=f"{CustomIds.ASSIGN}:{self.group_id}",
             custom_id=f"{CustomIds.ASSIGN}:{self.group_id}",
             placeholder="Select Assignee...",
             placeholder="Select Assignee...",
@@ -203,21 +219,22 @@ class DiscordMessageComponentHandler(DiscordInteractionHandler):
         return self.send_message(ARCHIVE_UNTIL_ESCALATES)
         return self.send_message(ARCHIVE_UNTIL_ESCALATES)
 
 
     def update_group(self, data: Mapping[str, object]) -> None:
     def update_group(self, data: Mapping[str, object]) -> None:
-        analytics.record(
-            "integrations.discord.status",
-            organization_id=self.group.organization.id,
-            user_id=self.user.id,
-            status=data,
-        )
-        update_groups(
-            request=self.request.request,
-            group_ids=[self.group.id],
-            projects=[self.group.project],
-            organization_id=self.group.organization.id,
-            search_fn=None,
-            user=self.user,  # type: ignore
-            data=data,
-        )
+        if self.group:
+            analytics.record(
+                "integrations.discord.status",
+                organization_id=self.group.organization.id,
+                user_id=self.user.id,
+                status=data,
+            )
+            update_groups(
+                request=self.request.request,
+                group_ids=[self.group.id],
+                projects=[self.group.project],
+                organization_id=self.group.organization.id,
+                search_fn=None,
+                user=self.user,  # type: ignore
+                data=data,
+            )
 
 
 
 
 def get_assign_selector_options(group: Group) -> list[DiscordSelectMenuOption]:
 def get_assign_selector_options(group: Group) -> list[DiscordSelectMenuOption]:

+ 1 - 0
static/app/data/controlsiloUrlPatterns.ts

@@ -172,6 +172,7 @@ const patterns: RegExp[] = [
   new RegExp('^extensions/msteams/configure/$'),
   new RegExp('^extensions/msteams/configure/$'),
   new RegExp('^extensions/msteams/link-identity/[^/]+/$'),
   new RegExp('^extensions/msteams/link-identity/[^/]+/$'),
   new RegExp('^extensions/msteams/unlink-identity/[^/]+/$'),
   new RegExp('^extensions/msteams/unlink-identity/[^/]+/$'),
+  new RegExp('^extensions/discord/interactions/$'),
   new RegExp('^extensions/discord/link-identity/[^/]+/$'),
   new RegExp('^extensions/discord/link-identity/[^/]+/$'),
   new RegExp('^extensions/discord/unlink-identity/[^/]+/$'),
   new RegExp('^extensions/discord/unlink-identity/[^/]+/$'),
   new RegExp('^share/(?:group|issue)/[^/]+/$'),
   new RegExp('^share/(?:group|issue)/[^/]+/$'),

+ 63 - 3
tests/sentry/integrations/discord/webhooks/test_message_component.py

@@ -13,6 +13,7 @@ from sentry.integrations.discord.requests.base import (
 from sentry.integrations.discord.webhooks.message_component import (
 from sentry.integrations.discord.webhooks.message_component import (
     ARCHIVE_UNTIL_ESCALATES,
     ARCHIVE_UNTIL_ESCALATES,
     ASSIGNEE_UPDATED,
     ASSIGNEE_UPDATED,
+    INVALID_GROUP_ID,
     MARKED_ONGOING,
     MARKED_ONGOING,
     NO_IDENTITY,
     NO_IDENTITY,
     NOT_IN_ORG,
     NOT_IN_ORG,
@@ -23,11 +24,14 @@ from sentry.integrations.discord.webhooks.message_component import (
     UNRESOLVED,
     UNRESOLVED,
 )
 )
 from sentry.models.release import Release
 from sentry.models.release import Release
+from sentry.silo.base import SiloMode
 from sentry.testutils.cases import APITestCase
 from sentry.testutils.cases import APITestCase
+from sentry.testutils.silo import assume_test_silo_mode, region_silo_test
 
 
 WEBHOOK_URL = "/extensions/discord/interactions/"
 WEBHOOK_URL = "/extensions/discord/interactions/"
 
 
 
 
+@region_silo_test
 class DiscordMessageComponentInteractionTest(APITestCase):
 class DiscordMessageComponentInteractionTest(APITestCase):
     def setUp(self):
     def setUp(self):
         patcher = mock.patch(
         patcher = mock.patch(
@@ -79,25 +83,61 @@ class DiscordMessageComponentInteractionTest(APITestCase):
     def get_select_options(self, response: Any) -> Any:
     def get_select_options(self, response: Any) -> Any:
         return self.get_message_components(response)[0]["components"][0]["options"]
         return self.get_message_components(response)[0]["components"][0]["options"]
 
 
-    def test_unknown_id_interaction(self):
+    def test_unknown_custom_id_interaction(self):
         response = self.send_interaction({"custom_id": f"unknown:{self.group.id}"})
         response = self.send_interaction({"custom_id": f"unknown:{self.group.id}"})
-        assert response.status_code == 404
+        assert response.status_code == 200
+        assert self.get_message_content(response) == INVALID_GROUP_ID
+
+    def test_empty_custom_id_interaction(self):
+        response = self.send_interaction({"custom_id": ""})
+        assert response.status_code == 200
+        assert self.get_message_content(response) == INVALID_GROUP_ID
 
 
     def test_no_user(self):
     def test_no_user(self):
         response = self.send_interaction(member={"user": {"id": "not-our-user"}})
         response = self.send_interaction(member={"user": {"id": "not-our-user"}})
+        assert response.status_code == 200
+        assert self.get_message_content(response) == NO_IDENTITY
+
+    def test_no_guild_id(self):
+        response = self.client.post(
+            path=WEBHOOK_URL,
+            data={
+                "type": DiscordRequestTypes.MESSAGE_COMPONENT,
+            },
+            format="json",
+            HTTP_X_SIGNATURE_ED25519="signature",
+            HTTP_X_SIGNATURE_TIMESTAMP="timestamp",
+        )
+        assert response.status_code == 200
+        assert self.get_message_content(response) == NO_IDENTITY
+
+    def test_invalid_guild_id(self):
+        response = self.client.post(
+            path=WEBHOOK_URL,
+            data={
+                "type": DiscordRequestTypes.MESSAGE_COMPONENT,
+                "guild_id": "invalid_guild_id",
+            },
+            format="json",
+            HTTP_X_SIGNATURE_ED25519="signature",
+            HTTP_X_SIGNATURE_TIMESTAMP="timestamp",
+        )
+        assert response.status_code == 200
         assert self.get_message_content(response) == NO_IDENTITY
         assert self.get_message_content(response) == NO_IDENTITY
 
 
     def test_not_in_org(self):
     def test_not_in_org(self):
         other_user = self.create_user()
         other_user = self.create_user()
         other_user_discord_id = "other-user1234"
         other_user_discord_id = "other-user1234"
         other_org = self.create_organization()
         other_org = self.create_organization()
-        self.discord_integration.add_organization(other_org)
+        with assume_test_silo_mode(SiloMode.CONTROL):
+            self.discord_integration.add_organization(other_org)
         self.create_identity(
         self.create_identity(
             user=other_user, identity_provider=self.provider, external_id=other_user_discord_id
             user=other_user, identity_provider=self.provider, external_id=other_user_discord_id
         )
         )
 
 
         response = self.send_interaction(member={"user": {"id": other_user_discord_id}})
         response = self.send_interaction(member={"user": {"id": other_user_discord_id}})
 
 
+        assert response.status_code == 200
         assert self.get_message_content(response) == NOT_IN_ORG
         assert self.get_message_content(response) == NOT_IN_ORG
 
 
     def test_assign_dialog(self):
     def test_assign_dialog(self):
@@ -107,11 +147,22 @@ class DiscordMessageComponentInteractionTest(APITestCase):
                 "custom_id": f"{CustomIds.ASSIGN_DIALOG}:{self.group.id}",
                 "custom_id": f"{CustomIds.ASSIGN_DIALOG}:{self.group.id}",
             }
             }
         )
         )
+        assert response.status_code == 200
         assert self.get_select_options(response) == [
         assert self.get_select_options(response) == [
             {"label": f"#{self.team.slug}", "value": f"team:{self.team.id}", "default": False},
             {"label": f"#{self.team.slug}", "value": f"team:{self.team.id}", "default": False},
             {"label": self.user.email, "value": f"user:{self.user.id}", "default": False},
             {"label": self.user.email, "value": f"user:{self.user.id}", "default": False},
         ]
         ]
 
 
+    def test_assign_dialog_invalid_group_id(self):
+        response = self.send_interaction(
+            {
+                "component_type": DiscordMessageComponentTypes.BUTTON,
+                "custom_id": f"{CustomIds.ASSIGN_DIALOG}:invalid",
+            }
+        )
+        assert response.status_code == 200
+        assert self.get_message_content(response) == INVALID_GROUP_ID
+
     def test_assign(self):
     def test_assign(self):
         response = self.send_interaction(
         response = self.send_interaction(
             {
             {
@@ -120,6 +171,7 @@ class DiscordMessageComponentInteractionTest(APITestCase):
                 "values": [f"user:{self.user.id}"],
                 "values": [f"user:{self.user.id}"],
             }
             }
         )
         )
+        assert response.status_code == 200
         assert self.get_message_content(response) == ASSIGNEE_UPDATED
         assert self.get_message_content(response) == ASSIGNEE_UPDATED
 
 
     def test_resolve_dialog(self):
     def test_resolve_dialog(self):
@@ -129,6 +181,7 @@ class DiscordMessageComponentInteractionTest(APITestCase):
                 "custom_id": f"{CustomIds.RESOLVE_DIALOG}:{self.group.id}",
                 "custom_id": f"{CustomIds.RESOLVE_DIALOG}:{self.group.id}",
             }
             }
         )
         )
+        assert response.status_code == 200
         assert self.get_select_options(response) == [
         assert self.get_select_options(response) == [
             option.build() for option in RESOLVE_DIALOG_OPTIONS
             option.build() for option in RESOLVE_DIALOG_OPTIONS
         ]
         ]
@@ -140,6 +193,7 @@ class DiscordMessageComponentInteractionTest(APITestCase):
                 "custom_id": f"{CustomIds.RESOLVE}:{self.group.id}",
                 "custom_id": f"{CustomIds.RESOLVE}:{self.group.id}",
             }
             }
         )
         )
+        assert response.status_code == 200
         assert self.get_message_content(response) == RESOLVED
         assert self.get_message_content(response) == RESOLVED
 
 
     def test_resolve_now_from_dialog(self):
     def test_resolve_now_from_dialog(self):
@@ -150,6 +204,7 @@ class DiscordMessageComponentInteractionTest(APITestCase):
                 "values": [""],
                 "values": [""],
             }
             }
         )
         )
+        assert response.status_code == 200
         assert self.get_message_content(response) == RESOLVED
         assert self.get_message_content(response) == RESOLVED
 
 
     def test_resolve_in_next_release(self):
     def test_resolve_in_next_release(self):
@@ -165,6 +220,7 @@ class DiscordMessageComponentInteractionTest(APITestCase):
                 "values": ["inNextRelease"],
                 "values": ["inNextRelease"],
             }
             }
         )
         )
+        assert response.status_code == 200
         assert self.get_message_content(response) == RESOLVED_IN_NEXT_RELEASE
         assert self.get_message_content(response) == RESOLVED_IN_NEXT_RELEASE
 
 
     def test_resolve_in_current_release(self):
     def test_resolve_in_current_release(self):
@@ -180,6 +236,7 @@ class DiscordMessageComponentInteractionTest(APITestCase):
                 "values": ["inCurrentRelease"],
                 "values": ["inCurrentRelease"],
             }
             }
         )
         )
+        assert response.status_code == 200
         assert self.get_message_content(response) == RESOLVED_IN_CURRENT_RELEASE
         assert self.get_message_content(response) == RESOLVED_IN_CURRENT_RELEASE
 
 
     def test_unresolve(self):
     def test_unresolve(self):
@@ -189,6 +246,7 @@ class DiscordMessageComponentInteractionTest(APITestCase):
                 "custom_id": f"{CustomIds.UNRESOLVE}:{self.group.id}",
                 "custom_id": f"{CustomIds.UNRESOLVE}:{self.group.id}",
             }
             }
         )
         )
+        assert response.status_code == 200
         assert self.get_message_content(response) == UNRESOLVED
         assert self.get_message_content(response) == UNRESOLVED
 
 
     def test_mark_ongoing(self):
     def test_mark_ongoing(self):
@@ -198,6 +256,7 @@ class DiscordMessageComponentInteractionTest(APITestCase):
                 "custom_id": f"{CustomIds.MARK_ONGOING}:{self.group.id}",
                 "custom_id": f"{CustomIds.MARK_ONGOING}:{self.group.id}",
             }
             }
         )
         )
+        assert response.status_code == 200
         assert self.get_message_content(response) == MARKED_ONGOING
         assert self.get_message_content(response) == MARKED_ONGOING
 
 
     def test_archive(self):
     def test_archive(self):
@@ -207,4 +266,5 @@ class DiscordMessageComponentInteractionTest(APITestCase):
                 "custom_id": f"{CustomIds.ARCHIVE}:{self.group.id}",
                 "custom_id": f"{CustomIds.ARCHIVE}:{self.group.id}",
             }
             }
         )
         )
+        assert response.status_code == 200
         assert self.get_message_content(response) == ARCHIVE_UNTIL_ESCALATES
         assert self.get_message_content(response) == ARCHIVE_UNTIL_ESCALATES

+ 21 - 1
tests/sentry/middleware/integrations/parsers/test_discord.py

@@ -112,7 +112,7 @@ class DiscordRequestParserTest(TestCase):
         mock_verify_signature.return_value = None
         mock_verify_signature.return_value = None
         data = {
         data = {
             "guild_id": self.integration.external_id,
             "guild_id": self.integration.external_id,
-            "name": "command_name",
+            "data": {"name": "command_name"},
             "type": int(DiscordRequestTypes.COMMAND),
             "type": int(DiscordRequestTypes.COMMAND),
         }
         }
         parser = self.get_parser(reverse("sentry-integration-discord-interactions"), data=data)
         parser = self.get_parser(reverse("sentry-integration-discord-interactions"), data=data)
@@ -126,6 +126,26 @@ class DiscordRequestParserTest(TestCase):
         integration = parser.get_integration_from_request()
         integration = parser.get_integration_from_request()
         assert integration == self.integration
         assert integration == self.integration
 
 
+    @patch("sentry.integrations.discord.requests.base.verify_signature")
+    def test_interactions_endpoint_routing_command_no_integration(self, mock_verify_signature):
+        mock_verify_signature.return_value = None
+        data = {
+            "data": {"name": "command_name"},
+            "type": int(DiscordRequestTypes.COMMAND),
+        }
+        parser = self.get_parser(reverse("sentry-integration-discord-interactions"), data=data)
+        with patch.object(parser, "get_regions_from_organizations", return_value=[]), patch.object(
+            parser, "get_response_from_first_region"
+        ) as mock_respond_from_first_region, patch.object(
+            parser, "get_response_from_control_silo"
+        ) as mock_response_from_control, assume_test_silo_mode(
+            SiloMode.CONTROL, can_be_monolith=False
+        ):
+            parser.get_response()
+            assert not mock_respond_from_first_region.called
+            assert mock_response_from_control.called
+        assert parser.get_integration_from_request() is None
+
     @patch("sentry.integrations.discord.requests.base.verify_signature")
     @patch("sentry.integrations.discord.requests.base.verify_signature")
     def test_interactions_endpoint_routing_message_component(self, mock_verify_signature):
     def test_interactions_endpoint_routing_message_component(self, mock_verify_signature):
         mock_verify_signature.return_value = None
         mock_verify_signature.return_value = None