test_base.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. from unittest import mock
  2. from unittest.mock import MagicMock
  3. from django.http import HttpRequest, QueryDict, StreamingHttpResponse
  4. from django.test import override_settings
  5. from pytest import raises
  6. from rest_framework.response import Response
  7. from sentry_sdk import Scope
  8. from sentry_sdk.utils import exc_info_from_error
  9. from sentry.api.base import Endpoint, EndpointSiloLimit, resolve_region
  10. from sentry.api.paginator import GenericOffsetPaginator
  11. from sentry.models import ApiKey
  12. from sentry.services.hybrid_cloud.util import FunctionSiloLimit
  13. from sentry.silo import SiloMode
  14. from sentry.testutils.cases import APITestCase
  15. from sentry.testutils.helpers.options import override_options
  16. from sentry.testutils.region import override_region_config
  17. from sentry.types.region import RegionCategory, clear_global_regions
  18. from sentry.utils.cursors import Cursor
  19. # Though it looks weird to have a method outside a class, this isn't a mistake but rather
  20. # a mock for a method in Django REST Framework's `APIView` class
  21. def reraise(self, e: Exception):
  22. raise e
  23. class DummyEndpoint(Endpoint):
  24. permission_classes = ()
  25. def get(self, request):
  26. return Response({"ok": True})
  27. class DummyErroringEndpoint(Endpoint):
  28. permission_classes = ()
  29. # `as_view` requires that any init args passed to it match attributes already on the
  30. # class, so even though they're really meant to be instance attributes, we have to
  31. # add them here as class attributes first
  32. error = None
  33. handler_context_arg = None
  34. scope_arg = None
  35. def __init__(
  36. self,
  37. *args,
  38. error: Exception,
  39. handler_context_arg=None,
  40. scope_arg=None,
  41. **kwargs,
  42. ):
  43. # The error which will be thrown when a GET request is made
  44. self.error = error
  45. # The argumets which will be passed on to `Endpoint.handle_exception` via `super`
  46. self.handler_context_arg = handler_context_arg
  47. self.scope_arg = scope_arg
  48. super().__init__(*args, **kwargs)
  49. def get(self, request):
  50. raise self.error
  51. def handle_exception(self, request, exc, handler_context=None, scope=None):
  52. return super().handle_exception(request, exc, self.handler_context_arg, self.scope_arg)
  53. class DummyPaginationEndpoint(Endpoint):
  54. permission_classes = ()
  55. def get(self, request):
  56. values = [x for x in range(0, 100)]
  57. def data_fn(offset, limit):
  58. page_offset = offset * limit
  59. return values[page_offset : page_offset + limit]
  60. return self.paginate(
  61. request=request,
  62. paginator=GenericOffsetPaginator(data_fn),
  63. on_results=lambda results: results,
  64. )
# Pre-built Django view for `DummyEndpoint`, shared by the `EndpointTest` cases.
_dummy_endpoint = DummyEndpoint.as_view()
  66. class DummyPaginationStreamingEndpoint(Endpoint):
  67. permission_classes = ()
  68. def get(self, request):
  69. values = [x for x in range(0, 100)]
  70. def data_fn(offset, limit):
  71. page_offset = offset * limit
  72. return values[page_offset : page_offset + limit]
  73. return self.paginate(
  74. request=request,
  75. paginator=GenericOffsetPaginator(data_fn),
  76. on_results=lambda results: iter(results),
  77. response_cls=StreamingHttpResponse,
  78. response_kwargs={"content_type": "application/json"},
  79. )
# Pre-built Django view for `DummyPaginationStreamingEndpoint`, used by
# `PaginateTest.test_custom_response_type`.
_dummy_streaming_endpoint = DummyPaginationStreamingEndpoint.as_view()
  81. class EndpointTest(APITestCase):
  82. def test_basic_cors(self):
  83. org = self.create_organization()
  84. apikey = ApiKey.objects.create(organization_id=org.id, allowed_origins="*")
  85. request = self.make_request(method="GET")
  86. request.META["HTTP_ORIGIN"] = "http://example.com"
  87. request.META["HTTP_AUTHORIZATION"] = self.create_basic_auth_header(apikey.key)
  88. response = _dummy_endpoint(request)
  89. response.render()
  90. assert response.status_code == 200, response.content
  91. assert response["Access-Control-Allow-Origin"] == "http://example.com"
  92. assert response["Access-Control-Allow-Headers"] == (
  93. "X-Sentry-Auth, X-Requested-With, Origin, Accept, "
  94. "Content-Type, Authentication, Authorization, Content-Encoding, "
  95. "sentry-trace, baggage, X-CSRFToken"
  96. )
  97. assert response["Access-Control-Expose-Headers"] == "X-Sentry-Error, Retry-After"
  98. assert response["Access-Control-Allow-Methods"] == "GET, HEAD, OPTIONS"
  99. assert "Access-Control-Allow-Credentials" not in response
  100. @override_options({"system.base-hostname": "example.com"})
  101. def test_allow_credentials(self):
  102. org = self.create_organization()
  103. apikey = ApiKey.objects.create(organization_id=org.id, allowed_origins="*")
  104. request = self.make_request(method="GET")
  105. request.META["HTTP_ORIGIN"] = "http://acme.example.com"
  106. request.META["HTTP_AUTHORIZATION"] = self.create_basic_auth_header(apikey.key)
  107. response = _dummy_endpoint(request)
  108. response.render()
  109. assert response.status_code == 200, response.content
  110. assert response["Access-Control-Allow-Origin"] == "http://acme.example.com"
  111. assert response["Access-Control-Allow-Headers"] == (
  112. "X-Sentry-Auth, X-Requested-With, Origin, Accept, "
  113. "Content-Type, Authentication, Authorization, Content-Encoding, "
  114. "sentry-trace, baggage, X-CSRFToken"
  115. )
  116. assert response["Access-Control-Expose-Headers"] == "X-Sentry-Error, Retry-After"
  117. assert response["Access-Control-Allow-Methods"] == "GET, HEAD, OPTIONS"
  118. assert response["Access-Control-Allow-Credentials"] == "true"
  119. @override_options({"system.base-hostname": "acme.com"})
  120. def test_allow_credentials_incorrect(self):
  121. org = self.create_organization()
  122. apikey = ApiKey.objects.create(organization_id=org.id, allowed_origins="*")
  123. for http_origin in ["http://acme.example.com", "http://fakeacme.com"]:
  124. request = self.make_request(method="GET")
  125. request.META["HTTP_ORIGIN"] = http_origin
  126. request.META["HTTP_AUTHORIZATION"] = self.create_basic_auth_header(apikey.key)
  127. response = _dummy_endpoint(request)
  128. response.render()
  129. assert "Access-Control-Allow-Credentials" not in response
  130. def test_invalid_cors_without_auth(self):
  131. request = self.make_request(method="GET")
  132. request.META["HTTP_ORIGIN"] = "http://example.com"
  133. with self.settings(SENTRY_ALLOW_ORIGIN="https://sentry.io"):
  134. response = _dummy_endpoint(request)
  135. response.render()
  136. assert response.status_code == 400, response.content
  137. def test_valid_cors_without_auth(self):
  138. request = self.make_request(method="GET")
  139. request.META["HTTP_ORIGIN"] = "http://example.com"
  140. with self.settings(SENTRY_ALLOW_ORIGIN="*"):
  141. response = _dummy_endpoint(request)
  142. response.render()
  143. assert response.status_code == 200, response.content
  144. assert response["Access-Control-Allow-Origin"] == "http://example.com"
  145. # XXX(dcramer): The default setting needs to allow requests to work or it will be a regression
  146. def test_cors_not_configured_is_valid(self):
  147. request = self.make_request(method="GET")
  148. request.META["HTTP_ORIGIN"] = "http://example.com"
  149. with self.settings(SENTRY_ALLOW_ORIGIN=None):
  150. response = _dummy_endpoint(request)
  151. response.render()
  152. assert response.status_code == 200, response.content
  153. assert response["Access-Control-Allow-Origin"] == "http://example.com"
  154. assert response["Access-Control-Allow-Headers"] == (
  155. "X-Sentry-Auth, X-Requested-With, Origin, Accept, "
  156. "Content-Type, Authentication, Authorization, Content-Encoding, "
  157. "sentry-trace, baggage, X-CSRFToken"
  158. )
  159. assert response["Access-Control-Expose-Headers"] == "X-Sentry-Error, Retry-After"
  160. assert response["Access-Control-Allow-Methods"] == "GET, HEAD, OPTIONS"
  161. @mock.patch("sentry.api.base.Endpoint.convert_args")
  162. def test_method_not_allowed(self, mock_convert_args):
  163. request = self.make_request(method="POST")
  164. response = _dummy_endpoint(request)
  165. response.render()
  166. assert response.status_code == 405, response.content
  167. # did not try to convert args
  168. assert not mock_convert_args.info.called
  169. class EndpointHandleExceptionTest(APITestCase):
  170. @mock.patch("rest_framework.views.APIView.handle_exception", return_value=Response(status=500))
  171. def test_handle_exception_when_super_returns_response(
  172. self, mock_super_handle_exception: MagicMock
  173. ):
  174. mock_endpoint = DummyErroringEndpoint.as_view(error=Exception("nope"))
  175. response = mock_endpoint(self.make_request(method="GET"))
  176. # The endpoint should pass along the response generated by `APIView.handle_exception`
  177. assert response == mock_super_handle_exception.return_value
  178. @mock.patch("rest_framework.views.APIView.handle_exception", new=reraise)
  179. @mock.patch("sentry.api.base.capture_exception", return_value="1231201211212012")
  180. def test_handle_exception_when_super_reraises(
  181. self,
  182. mock_capture_exception: MagicMock,
  183. ):
  184. handler_context = {"api_request_URL": "http://dogs.are.great/"}
  185. scope = Scope()
  186. tags = {"maisey": "silly", "charlie": "goofy"}
  187. for tag, value in tags.items():
  188. scope.set_tag(tag, value)
  189. cases = [
  190. # The first half of each tuple is what's passed to `handle_exception`, and the second
  191. # half is what we expect in the scope passed to `capture_exception`
  192. (None, None, {}, {}),
  193. (handler_context, None, {"Request Handler Data": handler_context}, {}),
  194. (None, scope, {}, tags),
  195. (
  196. handler_context,
  197. scope,
  198. {"Request Handler Data": handler_context},
  199. tags,
  200. ),
  201. ]
  202. for handler_context_arg, scope_arg, expected_scope_contexts, expected_scope_tags in cases:
  203. handler_error = Exception("nope")
  204. mock_endpoint = DummyErroringEndpoint.as_view(
  205. error=handler_error,
  206. handler_context_arg=handler_context_arg,
  207. scope_arg=scope_arg,
  208. )
  209. with mock.patch("sys.exc_info", return_value=exc_info_from_error(handler_error)):
  210. with mock.patch("sys.stderr.write") as mock_stderr_write:
  211. response = mock_endpoint(self.make_request(method="GET"))
  212. assert response.status_code == 500
  213. assert response.data == {
  214. "detail": "Internal Error",
  215. "errorId": "1231201211212012",
  216. }
  217. assert response.exception is True
  218. mock_stderr_write.assert_called_with("Exception: nope\n")
  219. capture_exception_handler_context_arg = mock_capture_exception.call_args.args[0]
  220. capture_exception_scope_kwarg = mock_capture_exception.call_args.kwargs.get(
  221. "scope"
  222. )
  223. assert capture_exception_handler_context_arg == handler_error
  224. assert isinstance(capture_exception_scope_kwarg, Scope)
  225. assert capture_exception_scope_kwarg._contexts == expected_scope_contexts
  226. assert capture_exception_scope_kwarg._tags == expected_scope_tags
  227. class CursorGenerationTest(APITestCase):
  228. def test_serializes_params(self):
  229. request = self.make_request(method="GET", path="/api/0/organizations/")
  230. request.GET = QueryDict("member=1&cursor=foo")
  231. endpoint = Endpoint()
  232. result = endpoint.build_cursor_link(request, "next", "1492107369532:0:0")
  233. assert result == (
  234. "<http://testserver/api/0/organizations/?member=1&cursor=1492107369532:0:0>;"
  235. ' rel="next"; results="true"; cursor="1492107369532:0:0"'
  236. )
  237. def test_preserves_ssl_proto(self):
  238. request = self.make_request(method="GET", path="/api/0/organizations/", secure_scheme=True)
  239. request.GET = QueryDict("member=1&cursor=foo")
  240. endpoint = Endpoint()
  241. with override_options({"system.url-prefix": "https://testserver"}):
  242. result = endpoint.build_cursor_link(request, "next", "1492107369532:0:0")
  243. assert result == (
  244. "<https://testserver/api/0/organizations/?member=1&cursor=1492107369532:0:0>;"
  245. ' rel="next"; results="true"; cursor="1492107369532:0:0"'
  246. )
  247. def test_handles_customer_domains(self):
  248. request = self.make_request(
  249. method="GET", path="/api/0/organizations/", secure_scheme=True, subdomain="bebe"
  250. )
  251. request.GET = QueryDict("member=1&cursor=foo")
  252. endpoint = Endpoint()
  253. with override_options(
  254. {
  255. "system.url-prefix": "https://testserver",
  256. "system.organization-url-template": "https://{hostname}",
  257. }
  258. ):
  259. result = endpoint.build_cursor_link(request, "next", "1492107369532:0:0")
  260. assert result == (
  261. "<https://bebe.testserver/api/0/organizations/?member=1&cursor=1492107369532:0:0>;"
  262. ' rel="next"; results="true"; cursor="1492107369532:0:0"'
  263. )
  264. def test_unicode_path(self):
  265. request = self.make_request(method="GET", path="/api/0/organizations/üuuuu/")
  266. endpoint = Endpoint()
  267. result = endpoint.build_cursor_link(request, "next", "1492107369532:0:0")
  268. assert result == (
  269. "<http://testserver/api/0/organizations/%C3%BCuuuu/?&cursor=1492107369532:0:0>;"
  270. ' rel="next"; results="true"; cursor="1492107369532:0:0"'
  271. )
  272. def test_encodes_url(self):
  273. endpoint = Endpoint()
  274. request = self.make_request(method="GET", path="/foo/bar/lol:what/")
  275. result = endpoint.build_cursor_link(request, "next", cursor=Cursor(0, 0, 0))
  276. assert (
  277. result
  278. == '<http://testserver/foo/bar/lol%3Awhat/?&cursor=0:0:0>; rel="next"; results="false"; cursor="0:0:0"'
  279. )
  280. class PaginateTest(APITestCase):
  281. def setUp(self):
  282. super().setUp()
  283. self.request = self.make_request(method="GET")
  284. self.view = DummyPaginationEndpoint().as_view()
  285. def test_success(self):
  286. response = self.view(self.request)
  287. assert response.status_code == 200, response.content
  288. assert (
  289. response["Link"]
  290. == '<http://testserver/?&cursor=0:0:1>; rel="previous"; results="false"; cursor="0:0:1", <http://testserver/?&cursor=0:100:0>; rel="next"; results="false"; cursor="0:100:0"'
  291. )
  292. def test_invalid_per_page(self):
  293. self.request.GET = {"per_page": "nope"}
  294. response = self.view(self.request)
  295. assert response.status_code == 400
  296. def test_invalid_cursor(self):
  297. self.request.GET = {"cursor": "no:no:no"}
  298. response = self.view(self.request)
  299. assert response.status_code == 400
  300. def test_per_page_out_of_bounds(self):
  301. self.request.GET = {"per_page": "101"}
  302. response = self.view(self.request)
  303. assert response.status_code == 400
  304. def test_custom_response_type(self):
  305. response = _dummy_streaming_endpoint(self.request)
  306. assert response.status_code == 200
  307. assert type(response) == StreamingHttpResponse
  308. assert response.has_header("content-type")
  309. class EndpointJSONBodyTest(APITestCase):
  310. def setUp(self):
  311. super().setUp()
  312. self.request = HttpRequest()
  313. self.request.method = "GET"
  314. self.request.META["CONTENT_TYPE"] = "application/json"
  315. def test_json(self):
  316. self.request._body = '{"foo":"bar"}'
  317. Endpoint().load_json_body(self.request)
  318. assert self.request.json_body == {"foo": "bar"}
  319. def test_invalid_json(self):
  320. self.request._body = "hello"
  321. Endpoint().load_json_body(self.request)
  322. assert not self.request.json_body
  323. def test_empty_request_body(self):
  324. self.request._body = ""
  325. Endpoint().load_json_body(self.request)
  326. assert not self.request.json_body
  327. def test_non_json_content_type(self):
  328. self.request.META["CONTENT_TYPE"] = "text/plain"
  329. Endpoint().load_json_body(self.request)
  330. assert not self.request.json_body
  331. class CustomerDomainTest(APITestCase):
  332. def test_resolve_region(self):
  333. clear_global_regions()
  334. def request_with_subdomain(subdomain):
  335. request = self.make_request(method="GET")
  336. request.subdomain = subdomain
  337. return resolve_region(request)
  338. region_config = [
  339. {
  340. "name": "na",
  341. "snowflake_id": 1,
  342. "address": "http://na.testserver",
  343. "category": RegionCategory.MULTI_TENANT.name,
  344. },
  345. {
  346. "name": "eu",
  347. "snowflake_id": 1,
  348. "address": "http://eu.testserver",
  349. "category": RegionCategory.MULTI_TENANT.name,
  350. },
  351. ]
  352. with override_region_config(region_config):
  353. assert request_with_subdomain("na") == "na"
  354. assert request_with_subdomain("eu") == "eu"
  355. assert request_with_subdomain("sentry") is None
  356. class EndpointSiloLimitTest(APITestCase):
  357. def _test_active_on(self, endpoint_mode, active_mode, expect_to_be_active):
  358. @EndpointSiloLimit(endpoint_mode)
  359. class DecoratedEndpoint(DummyEndpoint):
  360. pass
  361. class EndpointWithDecoratedMethod(DummyEndpoint):
  362. @EndpointSiloLimit(endpoint_mode)
  363. def get(self, request):
  364. return super().get(request)
  365. with override_settings(SILO_MODE=active_mode):
  366. request = self.make_request(method="GET")
  367. for endpoint_class in (DecoratedEndpoint, EndpointWithDecoratedMethod):
  368. view = endpoint_class.as_view()
  369. with override_settings(FAIL_ON_UNAVAILABLE_API_CALL=False):
  370. response = view(request)
  371. assert response.status_code == (200 if expect_to_be_active else 404)
  372. if not expect_to_be_active:
  373. with override_settings(FAIL_ON_UNAVAILABLE_API_CALL=True):
  374. with raises(EndpointSiloLimit.AvailabilityError):
  375. DecoratedEndpoint.as_view()(request)
  376. # TODO: Make work with EndpointWithDecoratedMethod
  377. def test_with_active_mode(self):
  378. self._test_active_on(SiloMode.REGION, SiloMode.REGION, True)
  379. self._test_active_on(SiloMode.CONTROL, SiloMode.CONTROL, True)
  380. def test_with_inactive_mode(self):
  381. self._test_active_on(SiloMode.REGION, SiloMode.CONTROL, False)
  382. self._test_active_on(SiloMode.CONTROL, SiloMode.REGION, False)
  383. def test_with_monolith_mode(self):
  384. self._test_active_on(SiloMode.REGION, SiloMode.MONOLITH, True)
  385. self._test_active_on(SiloMode.CONTROL, SiloMode.MONOLITH, True)
  386. class FunctionSiloLimitTest(APITestCase):
  387. def _test_active_on(self, endpoint_mode, active_mode, expect_to_be_active):
  388. @FunctionSiloLimit(endpoint_mode)
  389. def decorated_function():
  390. pass
  391. with override_settings(SILO_MODE=active_mode):
  392. if expect_to_be_active:
  393. decorated_function()
  394. else:
  395. with raises(FunctionSiloLimit.AvailabilityError):
  396. decorated_function()
  397. def test_with_active_mode(self):
  398. self._test_active_on(SiloMode.REGION, SiloMode.REGION, True)
  399. self._test_active_on(SiloMode.CONTROL, SiloMode.CONTROL, True)
  400. def test_with_inactive_mode(self):
  401. self._test_active_on(SiloMode.REGION, SiloMode.CONTROL, False)
  402. self._test_active_on(SiloMode.CONTROL, SiloMode.REGION, False)
  403. def test_with_monolith_mode(self):
  404. self._test_active_on(SiloMode.REGION, SiloMode.MONOLITH, True)
  405. self._test_active_on(SiloMode.CONTROL, SiloMode.MONOLITH, True)