test_rpc.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. from __future__ import annotations
  2. from typing import Any, cast
  3. from unittest import mock
  4. import pytest
  5. import responses
  6. from django.db import router
  7. from django.test import override_settings
  8. from sentry.models import OrganizationMapping
  9. from sentry.services.hybrid_cloud.actor import RpcActor
  10. from sentry.services.hybrid_cloud.auth import AuthService
  11. from sentry.services.hybrid_cloud.organization import (
  12. OrganizationService,
  13. RpcOrganizationMemberFlags,
  14. RpcUserOrganizationContext,
  15. )
  16. from sentry.services.hybrid_cloud.organization.serial import serialize_rpc_organization
  17. from sentry.services.hybrid_cloud.rpc import (
  18. RpcSendException,
  19. dispatch_remote_call,
  20. dispatch_to_local_service,
  21. )
  22. from sentry.services.hybrid_cloud.user import RpcUser
  23. from sentry.services.hybrid_cloud.user.serial import serialize_rpc_user
  24. from sentry.silo import SiloMode, unguarded_write
  25. from sentry.testutils import TestCase
  26. from sentry.testutils.region import override_regions
  27. from sentry.types.region import Region, RegionCategory
  28. from sentry.utils import json
  29. _REGIONS = [
  30. Region("north_america", 1, "http://na.sentry.io", RegionCategory.MULTI_TENANT, "swordfish"),
  31. Region("europe", 2, "http://eu.sentry.io", RegionCategory.MULTI_TENANT, "courage"),
  32. ]
  33. class RpcServiceTest(TestCase):
  34. @mock.patch("sentry.services.hybrid_cloud.rpc.dispatch_remote_call")
  35. def test_remote_service(self, mock_dispatch_remote_call):
  36. target_region = _REGIONS[0]
  37. user = self.create_user()
  38. organization = self.create_organization()
  39. with unguarded_write(using=router.db_for_write(OrganizationMapping)):
  40. OrganizationMapping.objects.update_or_create(
  41. organization_id=organization.id,
  42. defaults={
  43. "slug": organization.slug,
  44. "name": organization.name,
  45. "region_name": target_region.name,
  46. },
  47. )
  48. serial_user = RpcUser(id=user.id)
  49. serial_org = serialize_rpc_organization(organization)
  50. service = OrganizationService.create_delegation()
  51. with override_regions(_REGIONS), override_settings(SILO_MODE=SiloMode.CONTROL):
  52. service.add_organization_member(
  53. organization_id=serial_org.id,
  54. default_org_role=serial_org.default_role,
  55. user=serial_user,
  56. flags=RpcOrganizationMemberFlags(),
  57. role=None,
  58. )
  59. assert mock_dispatch_remote_call.called
  60. (
  61. region,
  62. service_name,
  63. method_name,
  64. serial_arguments,
  65. ) = mock_dispatch_remote_call.call_args.args
  66. assert region == target_region
  67. assert service_name == OrganizationService.key
  68. assert method_name == "add_organization_member"
  69. assert serial_arguments.keys() == {
  70. "organization_id",
  71. "default_org_role",
  72. "user_id",
  73. "email",
  74. "flags",
  75. "role",
  76. "inviter_id",
  77. "invite_status",
  78. }
  79. assert serial_arguments["organization_id"] == organization.id
  80. @mock.patch("sentry.services.hybrid_cloud.report_pydantic_type_validation_error")
  81. def test_models_tolerate_invalid_types(self, mock_report):
  82. # Create an RpcModel instance whose fields don't obey type annotations and
  83. # ensure that it does not raise an exception.
  84. RpcActor(
  85. id="hey, this isn't an int",
  86. actor_id=None, # this one is okay
  87. actor_type=None, # should not be Optional
  88. )
  89. assert mock_report.call_count == 2
  90. field_names = {c.args[0].name for c in mock_report.call_args_list}
  91. model_classes = [c.args[3] for c in mock_report.call_args_list]
  92. assert field_names == {"id", "actor_type"}
  93. assert model_classes == [RpcActor] * 2
  94. def test_dispatch_to_local_service(self):
  95. user = self.create_user()
  96. organization = self.create_organization()
  97. serial_org = serialize_rpc_organization(organization)
  98. serial_arguments = dict(
  99. organization_id=serial_org.id,
  100. default_org_role=serial_org.default_role,
  101. user_id=user.id,
  102. flags=RpcOrganizationMemberFlags().dict(),
  103. role=None,
  104. )
  105. with override_settings(SILO_MODE=SiloMode.REGION):
  106. service = OrganizationService.create_delegation()
  107. dispatch_to_local_service(service.key, "add_organization_member", serial_arguments)
  108. def test_dispatch_to_local_service_list_result(self):
  109. organization = self.create_organization()
  110. args = {"organization_ids": [organization.id]}
  111. with override_settings(SILO_MODE=SiloMode.CONTROL):
  112. service = AuthService.create_delegation()
  113. response = dispatch_to_local_service(service.key, "get_org_auth_config", args)
  114. result = response["value"]
  115. assert len(result) == 1
  116. assert result[0]["organization_id"] == organization.id
  117. control_address = "https://control.example.com"
  118. shared_secret = ["a-long-token-you-could-not-guess"]
  119. class DispatchRemoteCallTest(TestCase):
  120. def test_while_not_allowed(self):
  121. with pytest.raises(RpcSendException):
  122. dispatch_remote_call(None, "user", "get_user", {"id": 0})
  123. @staticmethod
  124. def _set_up_mock_response(service_name: str, response_value: Any, address: str | None = None):
  125. address = address or control_address
  126. responses.add(
  127. responses.POST,
  128. f"{address}/api/0/internal/rpc/{service_name}/",
  129. content_type="json",
  130. body=json.dumps({"meta": {}, "value": response_value}),
  131. )
  132. @responses.activate
  133. def test_region_to_control_happy_path(self):
  134. org = self.create_organization()
  135. with override_settings(
  136. RPC_SHARED_SECRET=shared_secret, SENTRY_CONTROL_ADDRESS=control_address
  137. ):
  138. response_value = RpcUserOrganizationContext(
  139. organization=serialize_rpc_organization(org)
  140. )
  141. self._set_up_mock_response("organization/get_organization_by_id", response_value.dict())
  142. result = dispatch_remote_call(
  143. None, "organization", "get_organization_by_id", {"id": org.id}
  144. )
  145. assert result == response_value
  146. @responses.activate
  147. @override_settings(
  148. SILO_MODE=SiloMode.REGION,
  149. RPC_SHARED_SECRET=shared_secret,
  150. SENTRY_CONTROL_ADDRESS=control_address,
  151. )
  152. def test_region_to_control_null_result(self):
  153. self._set_up_mock_response("organization/get_organization_by_id", None)
  154. result = dispatch_remote_call(None, "organization", "get_organization_by_id", {"id": 0})
  155. assert result is None
  156. @responses.activate
  157. @override_regions(_REGIONS)
  158. @override_settings(
  159. SILO_MODE=SiloMode.CONTROL,
  160. RPC_SHARED_SECRET=shared_secret,
  161. SENTRY_CONTROL_ADDRESS=control_address,
  162. )
  163. def test_control_to_region_happy_path(self):
  164. user = self.create_user()
  165. serial = serialize_rpc_user(user)
  166. self._set_up_mock_response("user/get_user", serial.dict(), address="http://na.sentry.io")
  167. result = dispatch_remote_call(_REGIONS[0], "user", "get_user", {"id": 0})
  168. assert result == serial
  169. @responses.activate
  170. @override_regions(_REGIONS)
  171. @override_settings(
  172. SILO_MODE=SiloMode.CONTROL,
  173. RPC_SHARED_SECRET=shared_secret,
  174. SENTRY_CONTROL_ADDRESS=control_address,
  175. )
  176. def test_region_to_control_with_list_result(self):
  177. users = [self.create_user() for _ in range(3)]
  178. serial = [serialize_rpc_user(user) for user in users]
  179. self._set_up_mock_response("user/get_many", [m.dict() for m in serial])
  180. result = dispatch_remote_call(None, "user", "get_many", {"filter": {}})
  181. assert result == serial
  182. @responses.activate
  183. @override_regions(_REGIONS)
  184. @override_settings(SILO_MODE=SiloMode.CONTROL, DEV_HYBRID_CLOUD_RPC_SENDER={"is_allowed": True})
  185. def test_early_halt_from_null_region_resolution(self):
  186. with override_settings(SILO_MODE=SiloMode.CONTROL):
  187. org_service_delgn = cast(OrganizationService, OrganizationService.create_delegation())
  188. result = org_service_delgn.get_org_by_slug(slug="this_is_not_a_valid_slug")
  189. assert result is None