Browse Source

chore(hybrid-cloud): Add client proxy for vsts integration (#51892)

Co-authored-by: getsantry[bot] <66042841+getsantry[bot]@users.noreply.github.com>
Alberto Leal 1 year ago
parent
commit
d2856987b9

+ 38 - 16
fixtures/vsts.py

@@ -128,22 +128,44 @@ class VstsIntegrationTestCase(IntegrationTestCase):
             },
         )
 
-        responses.add(
-            responses.GET,
-            "https://{}.visualstudio.com/{}/_apis/wit/workitemtypes/{}/states".format(
-                self.vsts_account_name.lower(), self.project_a["name"], "Bug"
-            ),
-            json={
-                "value": [
-                    {"name": "resolve_status"},
-                    {"name": "resolve_when"},
-                    {"name": "regression_status"},
-                    {"name": "sync_comments"},
-                    {"name": "sync_forward_assignment"},
-                    {"name": "sync_reverse_assignment"},
-                ]
-            },
-        )
+        for project in [self.project_a, self.project_b]:
+            responses.add(
+                responses.GET,
+                "https://{}.visualstudio.com/{}/_apis/wit/workitemtypes/{}/states".format(
+                    self.vsts_account_name.lower(), project["id"], "Bug"
+                ),
+                json={
+                    "count": 6,
+                    "value": [
+                        {"name": "resolve_status"},
+                        {"name": "resolve_when"},
+                        {"name": "regression_status"},
+                        {"name": "sync_comments"},
+                        {"name": "sync_forward_assignment"},
+                        {"name": "sync_reverse_assignment"},
+                    ],
+                },
+            )
+            responses.add(
+                responses.GET,
+                "https://{}.visualstudio.com/{}/_apis/wit/workitemtypes/{}/states".format(
+                    self.vsts_account_name.lower(), project["id"], "Issue"
+                ),
+                json={
+                    "count": 0,
+                    "value": [],
+                },
+            )
+            responses.add(
+                responses.GET,
+                "https://{}.visualstudio.com/{}/_apis/wit/workitemtypes/{}/states".format(
+                    self.vsts_account_name.lower(), project["id"], "Task"
+                ),
+                json={
+                    "count": 0,
+                    "value": [],
+                },
+            )
 
     def make_init_request(self, path=None, body=None):
         return self.client.get(path or self.init_path, body or {})

+ 0 - 3
pyproject.toml

@@ -555,7 +555,6 @@ module = [
     "sentry.integrations.jira_server.client",
     "sentry.integrations.jira_server.integration",
     "sentry.integrations.jira_server.search",
-    "sentry.integrations.jira_server.utils.api",
     "sentry.integrations.jira_server.webhooks",
     "sentry.integrations.manager",
     "sentry.integrations.message_builder",
@@ -609,7 +608,6 @@ module = [
     "sentry.integrations.vsts.integration",
     "sentry.integrations.vsts.issues",
     "sentry.integrations.vsts.repository",
-    "sentry.integrations.vsts.webhooks",
     "sentry.integrations.vsts_extension.integration",
     "sentry.interfaces.breadcrumbs",
     "sentry.interfaces.contexts",
@@ -1331,7 +1329,6 @@ module = [
     "tests.sentry.models.test_authenticator",
     "tests.sentry.models.test_avatar",
     "tests.sentry.models.test_eventuser",
-    "tests.sentry.models.test_groupassignee",
     "tests.sentry.models.test_grouphistory",
     "tests.sentry.models.test_groupsnooze",
     "tests.sentry.models.test_notificationaction",

+ 2 - 0
src/sentry/integrations/client.py

@@ -1,6 +1,7 @@
 from time import time
 
 from sentry.exceptions import InvalidIdentity
+from sentry.services.hybrid_cloud.util import control_silo_function
 from sentry.shared_integrations.client import BaseApiClient
 
 
@@ -16,6 +17,7 @@ class ApiClient(BaseApiClient):
 
 
 class OAuth2RefreshMixin:
+    @control_silo_function
     def check_auth(self, *args, **kwargs):
         """
         Checks if auth is expired and if so refreshes it

+ 9 - 2
src/sentry/integrations/utils/sync.py

@@ -7,12 +7,14 @@ from sentry import features
 from sentry.models import Group, GroupAssignee, Organization, Project
 from sentry.services.hybrid_cloud.integration import integration_service
 from sentry.services.hybrid_cloud.user.service import user_service
+from sentry.services.hybrid_cloud.util import region_silo_function
 from sentry.tasks.integrations import sync_assignee_outbound
 
 if TYPE_CHECKING:
     from sentry.services.hybrid_cloud.integration import RpcIntegration
 
 
+@region_silo_function
 def where_should_sync(
     integration: RpcIntegration,
     key: str,
@@ -24,10 +26,14 @@ def where_should_sync(
     check the integration for that organization.
     """
     kwargs = dict()
-    if organization_id:
+    if organization_id is not None:
         kwargs["id"] = organization_id
+        ois = integration_service.get_organization_integrations(
+            integration_id=integration.id, organization_id=organization_id
+        )
+    else:
+        ois = integration_service.get_organization_integrations(integration_id=integration.id)
 
-    ois = integration_service.get_organization_integrations(integration_id=integration.id)
     organizations = Organization.objects.filter(id__in=[oi.organization_id for oi in ois])
 
     return [
@@ -52,6 +58,7 @@ def get_user_id(projects_by_user: Mapping[int, Sequence[int]], group: Group) ->
     return user_ids[0]
 
 
+@region_silo_function
 def sync_group_assignee_inbound(
     integration: RpcIntegration,
     email: str | None,

+ 111 - 46
src/sentry/integrations/vsts/client.py

@@ -1,12 +1,19 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Union
 
+from requests import PreparedRequest
 from rest_framework.response import Response
 
 from sentry.integrations.client import ApiClient, OAuth2RefreshMixin
+from sentry.models import Identity
+from sentry.services.hybrid_cloud.util import control_silo_function
+from sentry.shared_integrations.client.base import BaseApiResponseX
+from sentry.shared_integrations.client.proxy import IntegrationProxyClient
 from sentry.utils.http import absolute_uri
 
 if TYPE_CHECKING:
-    from sentry.models import Identity, Project
+    from sentry.models import Project
 
 UNSET = object()
 
@@ -72,49 +79,124 @@ class VstsApiPath:
     work_item_categories = "{instance}{project}/_apis/wit/workitemtypecategories"
 
 
-class VstsApiClient(ApiClient, OAuth2RefreshMixin):
+def prepare_headers(
+    api_version: str,
+    method: str,
+    api_version_preview: str,
+):
+
+    headers = {
+        "Accept": f"application/json; api-version={api_version}{api_version_preview}",
+        "Content-Type": "application/json-patch+json" if method == "PATCH" else "application/json",
+        "X-HTTP-Method-Override": method,
+        "X-TFS-FedAuthRedirect": "Suppress",
+    }
+    return headers
+
+
+def prepare_auth_header(
+    access_token: str,
+):
+    headers = {
+        "Authorization": f"Bearer {access_token}",
+    }
+    return headers
+
+
+class VstsApiMixin:
     api_version = "4.1"  # TODO: update api version
     api_version_preview = "-preview.1"
+
+    def create_subscription(self, instance: Optional[str], shared_secret: str) -> Response:
+        return self.post(
+            VstsApiPath.subscriptions.format(instance=instance),
+            data={
+                "publisherId": "tfs",
+                "eventType": "workitem.updated",
+                "resourceVersion": "1.0",
+                "consumerId": "webHooks",
+                "consumerActionId": "httpRequest",
+                "consumerInputs": {
+                    "url": absolute_uri("/extensions/vsts/issue-updated/"),
+                    "resourceDetailsToSend": "all",
+                    "httpHeaders": f"shared-secret:{shared_secret}",
+                },
+            },
+        )
+
+
+class VstsSetupApiClient(ApiClient, VstsApiMixin):
+    integration_name = "vsts"
+
+    def __init__(self, oauth_redirect_url: str, access_token: str):
+        super().__init__()
+        self.oauth_redirect_url = oauth_redirect_url
+        self.access_token = access_token
+
+    def request(
+        self, method, path, data=None, params=None, api_preview: bool = False
+    ) -> BaseApiResponseX:
+        headers = prepare_headers(
+            api_version=self.api_version,
+            method=method,
+            api_version_preview=self.api_version_preview if api_preview else "",
+        )
+        headers.update(prepare_auth_header(access_token=self.access_token))
+        return self._request(method, path, headers=headers, data=data, params=params)
+
+
+class VstsApiClient(IntegrationProxyClient, OAuth2RefreshMixin, VstsApiMixin):
     integration_name = "vsts"
+    _identity: Identity | None = None
 
     def __init__(
-        self, identity: "Identity", oauth_redirect_url: str, *args: Any, **kwargs: Any
+        self,
+        base_url: str,
+        oauth_redirect_url: str,
+        org_integration_id: int,
+        identity_id: int | None = None,
     ) -> None:
-        super().__init__(*args, **kwargs)
-        self.identity = identity
+        self.base_url = base_url
+        self.identity_id = identity_id
         self.oauth_redirect_url = oauth_redirect_url
-        if "access_token" not in self.identity.data:
-            raise ValueError("Vsts Identity missing access token")
+        super().__init__(org_integration_id=org_integration_id)
+
+    @property
+    def identity(self):
+        if self._identity:
+            return self._identity
+        self._identity = Identity.objects.get(id=self.identity_id)
+        return self._identity
+
+    def request(self, method: str, *args: Any, **kwargs: Any) -> BaseApiResponseX:
+        api_preview = kwargs.pop("api_preview", False)
+        headers = kwargs.pop("headers", {})
+        new_headers = prepare_headers(
+            api_version=self.api_version,
+            method=method,
+            api_version_preview=self.api_version_preview if api_preview else "",
+        )
+        headers.update(new_headers)
 
-    def request(
+        return self._request(method, *args, headers=headers, **kwargs)
+
+    @control_silo_function
+    def authorize_request(
         self,
-        method: str,
-        path: str,
-        data: Optional[Mapping[str, Any]] = None,
-        params: Optional[Sequence[Any]] = None,
-        api_preview: bool = False,
-        timeout: Optional[int] = None,
-    ) -> Response:
+        prepared_request: PreparedRequest,
+    ) -> PreparedRequest:
         self.check_auth(redirect_url=self.oauth_redirect_url)
-        headers = {
-            "Accept": "application/json; api-version={}{}".format(
-                self.api_version, self.api_version_preview if api_preview else ""
-            ),
-            "Content-Type": "application/json-patch+json"
-            if method == "PATCH"
-            else "application/json",
-            "X-HTTP-Method-Override": method,
-            "X-TFS-FedAuthRedirect": "Suppress",
-            "Authorization": "Bearer {}".format(self.identity.data["access_token"]),
-        }
-        return self._request(
-            method, path, headers=headers, data=data, params=params, timeout=timeout
+        access_token = self.identity.data["access_token"]
+        headers = prepare_auth_header(
+            access_token=access_token,
         )
+        prepared_request.headers.update(headers)
+        return prepared_request
 
     def create_work_item(
         self,
         instance: str,
-        project: "Project",
+        project: Project,
         item_type: Optional[str] = None,
         title: Optional[str] = None,
         description: Optional[str] = None,
@@ -294,23 +376,6 @@ class VstsApiClient(ApiClient, OAuth2RefreshMixin):
             params={"continuationToken": continuation_token},
         )
 
-    def create_subscription(self, instance: Optional[str], shared_secret: str) -> Response:
-        return self.post(
-            VstsApiPath.subscriptions.format(instance=instance),
-            data={
-                "publisherId": "tfs",
-                "eventType": "workitem.updated",
-                "resourceVersion": "1.0",
-                "consumerId": "webHooks",
-                "consumerActionId": "httpRequest",
-                "consumerInputs": {
-                    "url": absolute_uri("/extensions/vsts/issue-updated/"),
-                    "resourceDetailsToSend": "all",
-                    "httpHeaders": f"shared-secret:{shared_secret}",
-                },
-            },
-        )
-
     def get_subscription(self, instance: str, subscription_id: str) -> Response:
         return self.get(
             VstsApiPath.subscription.format(instance=instance, subscription_id=subscription_id)

+ 21 - 7
src/sentry/integrations/vsts/integration.py

@@ -42,12 +42,13 @@ from sentry.shared_integrations.exceptions import (
     IntegrationError,
     IntegrationProviderError,
 )
+from sentry.silo import SiloMode
 from sentry.tasks.integrations import migrate_repo
 from sentry.utils.http import absolute_uri
 from sentry.utils.json import JSONData
 from sentry.web.helpers import render_to_response
 
-from .client import VstsApiClient
+from .client import VstsApiClient, VstsSetupApiClient
 from .repository import VstsRepositoryProvider
 
 DESCRIPTION = """
@@ -161,11 +162,21 @@ class VstsIntegration(IntegrationInstallation, RepositoryMixin, VstsIssueSync):
         return True
 
     def get_client(self) -> VstsApiClient:
-        if self.default_identity is None:
-            self.default_identity = self.get_default_identity()
-
-        self.check_domain_name(self.default_identity)
-        return VstsApiClient(self.default_identity, VstsIntegrationProvider.oauth_redirect_url)
+        if SiloMode.get_current_mode() != SiloMode.REGION:
+            if self.default_identity is None:
+                self.default_identity = self.get_default_identity()
+            self.check_domain_name(self.default_identity)
+
+        if self.org_integration is None:
+            raise Exception("self.org_integration is not defined")
+        if self.org_integration.default_auth_id is None:
+            raise Exception("self.org_integration.default_auth_id is not defined")
+        return VstsApiClient(
+            base_url=self.instance,
+            oauth_redirect_url=VstsIntegrationProvider.oauth_redirect_url,
+            org_integration_id=self.org_integration.id,
+            identity_id=self.org_integration.default_auth_id,
+        )
 
     def check_domain_name(self, default_identity: Identity) -> None:
         if re.match("^https://.+/$", self.model.metadata["domain_name"]):
@@ -444,7 +455,10 @@ class VstsIntegrationProvider(IntegrationProvider):
     def create_subscription(
         self, instance: str | None, oauth_data: Mapping[str, Any]
     ) -> tuple[int, str]:
-        client = VstsApiClient(Identity(data=oauth_data), self.oauth_redirect_url)
+        client = VstsSetupApiClient(
+            oauth_redirect_url=self.oauth_redirect_url,
+            access_token=oauth_data["access_token"],
+        )
         shared_secret = generate_token()
         try:
             subscription = client.create_subscription(instance, shared_secret)

+ 4 - 2
src/sentry/integrations/vsts/repository.py

@@ -1,4 +1,6 @@
-from typing import Any, Mapping, MutableMapping, Optional, Sequence
+from __future__ import annotations
+
+from typing import Any, Mapping, MutableMapping, Sequence
 
 from sentry.models import Commit, Organization, Repository
 from sentry.plugins.providers import IntegrationRepositoryProvider
@@ -93,7 +95,7 @@ class VstsRepositoryProvider(IntegrationRepositoryProvider):
         return commit_list
 
     def compare_commits(
-        self, repo: Repository, start_sha: Optional[str], end_sha: str
+        self, repo: Repository, start_sha: str | None, end_sha: str
     ) -> Sequence[Mapping[str, str]]:
         """TODO(mgaeta): This function is kinda a mess."""
         installation = self.get_installation(repo.integration_id, repo.organization_id)

+ 25 - 15
src/sentry/integrations/vsts/webhooks.py

@@ -1,25 +1,29 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Mapping
+from typing import TYPE_CHECKING, Any, Mapping
 
 from django.utils.crypto import constant_time_compare
 from rest_framework import status
 from rest_framework.request import Request
 from rest_framework.response import Response
 
-from sentry.api.base import Endpoint, control_silo_endpoint
+from sentry.api.base import Endpoint, region_silo_endpoint
 from sentry.integrations.mixins import IssueSyncMixin
 from sentry.integrations.utils import sync_group_assignee_inbound
-from sentry.models import Integration
+from sentry.services.hybrid_cloud.integration import integration_service
 from sentry.utils.email import parse_email
 
+if TYPE_CHECKING:
+    from sentry.services.hybrid_cloud.integration import RpcIntegration
+
+
 UNSET = object()
 logger = logging.getLogger("sentry.integrations")
 PROVIDER_KEY = "vsts"
 
 
-@control_silo_endpoint
+@region_silo_endpoint
 class WorkItemWebhook(Endpoint):
     authentication_classes = ()
     permission_classes = ()
@@ -35,11 +39,11 @@ class WorkItemWebhook(Endpoint):
 
         # https://docs.microsoft.com/en-us/azure/devops/service-hooks/events?view=azure-devops#workitem.updated
         if event_type == "workitem.updated":
-            try:
-                integration = Integration.objects.get(
-                    provider=PROVIDER_KEY, external_id=external_id
-                )
-            except Integration.DoesNotExist:
+
+            integration = integration_service.get_integration(
+                provider=PROVIDER_KEY, external_id=external_id
+            )
+            if integration is None:
                 logger.info(
                     "vsts.integration-in-webhook-payload-does-not-exist",
                     extra={"external_id": external_id, "event_type": event_type},
@@ -56,7 +60,7 @@ class WorkItemWebhook(Endpoint):
         return self.respond()
 
 
-def check_webhook_secret(request: Request, integration: Integration, event_type: str) -> bool:
+def check_webhook_secret(request: Request, integration: RpcIntegration, event_type: str) -> bool:
     integration_secret = integration.metadata.get("subscription", {}).get("secret")
     webhook_payload_secret = request.META.get("HTTP_SHARED_SECRET")
 
@@ -72,7 +76,7 @@ def check_webhook_secret(request: Request, integration: Integration, event_type:
 
 
 def handle_assign_to(
-    integration: Integration,
+    integration: RpcIntegration,
     external_issue_key: str | None,
     assigned_to: Mapping[str, str] | None,
 ) -> None:
@@ -107,7 +111,7 @@ def handle_assign_to(
 
 
 def handle_status_change(
-    integration: Integration,
+    integration: RpcIntegration,
     external_issue_key: str,
     status_change: Mapping[str, str] | None,
     project: str | None,
@@ -115,8 +119,14 @@ def handle_status_change(
     if status_change is None:
         return
 
-    for org_id in integration.organizationintegration_set.values_list("organization_id", flat=True):
-        installation = integration.get_installation(organization_id=org_id)
+    org_integrations = integration_service.get_organization_integrations(
+        integration_id=integration.id
+    )
+
+    for org_integration in org_integrations:
+        installation = integration_service.get_installation(
+            integration=integration, organization_id=org_integration.organization_id
+        )
         if isinstance(installation, IssueSyncMixin):
             installation.sync_status_inbound(
                 external_issue_key,
@@ -129,7 +139,7 @@ def handle_status_change(
             )
 
 
-def handle_updated_workitem(data: Mapping[str, Any], integration: Integration) -> None:
+def handle_updated_workitem(data: Mapping[str, Any], integration: RpcIntegration) -> None:
     project: str | None = None
     try:
         external_issue_key = data["resource"]["workItemId"]

+ 5 - 2
src/sentry/shared_integrations/client/proxy.py

@@ -107,8 +107,11 @@ class IntegrationProxyClient(ApiClient):
             return prepared_request
 
         # E.g. client.get("/chat.postMessage") -> proxy_path = 'chat.postMessage'
-        proxy_path = trim_leading_slashes(prepared_request.url[len(self.base_url) :])
-        url = f"{self.proxy_url}/{proxy_path}"
+        assert self.base_url and self.proxy_url
+        base_url = self.base_url.rstrip("/")
+        proxy_path = trim_leading_slashes(prepared_request.url[len(base_url) :])
+        proxy_url = self.proxy_url.rstrip("/")
+        url = f"{proxy_url}/{proxy_path}"
 
         request_body = prepared_request.body
         if not isinstance(request_body, bytes):

+ 8 - 2
src/sentry/tasks/integrations/sync_status_inbound.py

@@ -2,6 +2,7 @@ from typing import Any, Mapping
 
 from sentry import analytics
 from sentry.models import Group, GroupStatus, Integration, Organization
+from sentry.services.hybrid_cloud.integration import integration_service
 from sentry.tasks.base import instrumented_task, retry, track_group_async_operation
 from sentry.types.activity import ActivityType
 from sentry.types.group import GroupSubStatus
@@ -20,7 +21,10 @@ def sync_status_inbound(
 ) -> None:
     from sentry.integrations.mixins import ResolveSyncAction
 
-    integration = Integration.objects.get(id=integration_id)
+    integration = integration_service.get_integration(integration_id=integration_id)
+    if integration is None:
+        raise Integration.DoesNotExist
+
     organizations = Organization.objects.filter(id=organization_id)
     affected_groups = Group.objects.get_groups_by_external_issue(
         integration, organizations, issue_key
@@ -28,7 +32,9 @@ def sync_status_inbound(
     if not affected_groups:
         return
 
-    installation = integration.get_installation(organization_id=organization_id)
+    installation = integration_service.get_installation(
+        integration=integration, organization_id=organization_id
+    )
 
     try:
         # This makes an API call.

Some files were not shown because too many files changed in this diff