# test_query_subscription_consumer.py
  1. import time
  2. from copy import deepcopy
  3. from datetime import timedelta
  4. from unittest import mock
  5. from unittest.mock import Mock, call
  6. from uuid import uuid4
  7. import pytz
  8. from confluent_kafka import Producer
  9. from dateutil.parser import parse as parse_date
  10. from django.conf import settings
  11. from django.test.utils import override_settings
  12. from exam import fixture
  13. from sentry.snuba.dataset import Dataset
  14. from sentry.snuba.models import SnubaQuery
  15. from sentry.snuba.query_subscription_consumer import (
  16. QuerySubscriptionConsumer,
  17. register_subscriber,
  18. subscriber_registry,
  19. )
  20. from sentry.snuba.subscriptions import create_snuba_query, create_snuba_subscription
  21. from sentry.testutils.cases import SnubaTestCase, TestCase
  22. from sentry.utils import json
  23. class QuerySubscriptionConsumerTest(TestCase, SnubaTestCase):
  24. @fixture
  25. def subscription_id(self):
  26. return "1234"
  27. @fixture
  28. def old_valid_wrapper(self):
  29. return {"version": 2, "payload": self.old_payload}
  30. @fixture
  31. def old_payload(self):
  32. return {
  33. "subscription_id": self.subscription_id,
  34. "result": {"data": [{"hello": 50}]},
  35. "request": {"some": "data"},
  36. "timestamp": "2020-01-01T01:23:45.1234",
  37. }
  38. @fixture
  39. def valid_wrapper(self):
  40. return {"version": 3, "payload": self.valid_payload}
  41. @fixture
  42. def valid_payload(self):
  43. return {
  44. "subscription_id": "1234",
  45. "result": {"data": [{"hello": 50}]},
  46. "request": {
  47. "some": "data",
  48. "query": """MATCH (metrics_counters) SELECT sum(value) AS value BY
  49. tags[3] WHERE org_id = 1 AND project_id IN tuple(1) AND metric_id = 16
  50. AND tags[3] IN tuple(13, 4)""",
  51. },
  52. "entity": "metrics_counters",
  53. "timestamp": "2020-01-01T01:23:45.1234",
  54. }
  55. @fixture
  56. def topic(self):
  57. return uuid4().hex
  58. @fixture
  59. def producer(self):
  60. cluster_name = settings.KAFKA_TOPICS[self.topic]["cluster"]
  61. conf = {
  62. "bootstrap.servers": settings.KAFKA_CLUSTERS[cluster_name]["common"][
  63. "bootstrap.servers"
  64. ],
  65. "session.timeout.ms": 6000,
  66. }
  67. return Producer(conf)
  68. def setUp(self):
  69. super().setUp()
  70. self.override_settings_cm = override_settings(
  71. KAFKA_TOPICS={self.topic: {"cluster": "default"}}
  72. )
  73. self.override_settings_cm.__enter__()
  74. self.orig_registry = deepcopy(subscriber_registry)
  75. def tearDown(self):
  76. super().tearDown()
  77. self.override_settings_cm.__exit__(None, None, None)
  78. subscriber_registry.clear()
  79. subscriber_registry.update(self.orig_registry)
  80. @fixture
  81. def registration_key(self):
  82. return "registered_keyboard_interrupt"
  83. def create_subscription(self):
  84. with self.tasks():
  85. snuba_query = create_snuba_query(
  86. SnubaQuery.Type.ERROR,
  87. Dataset.Events,
  88. "hello",
  89. "count()",
  90. timedelta(minutes=1),
  91. timedelta(minutes=1),
  92. None,
  93. )
  94. sub = create_snuba_subscription(self.project, self.registration_key, snuba_query)
  95. sub.subscription_id = self.subscription_id
  96. sub.status = 0
  97. sub.save()
  98. return sub
  99. def test_old(self):
  100. cluster_name = settings.KAFKA_TOPICS[self.topic]["cluster"]
  101. conf = {
  102. "bootstrap.servers": settings.KAFKA_CLUSTERS[cluster_name]["common"][
  103. "bootstrap.servers"
  104. ],
  105. "session.timeout.ms": 6000,
  106. }
  107. producer = Producer(conf)
  108. producer.produce(self.topic, json.dumps(self.old_valid_wrapper))
  109. producer.flush()
  110. consumer = QuerySubscriptionConsumer("hi", topic=self.topic, commit_batch_size=1)
  111. mock_callback = Mock(side_effect=lambda *a, **k: consumer.shutdown())
  112. register_subscriber(self.registration_key)(mock_callback)
  113. sub = self.create_subscription()
  114. consumer.run()
  115. payload = self.old_payload
  116. payload["values"] = payload["result"]
  117. payload["timestamp"] = parse_date(payload["timestamp"]).replace(tzinfo=pytz.utc)
  118. mock_callback.assert_called_once_with(payload, sub)
  119. def test_normal(self):
  120. cluster_name = settings.KAFKA_TOPICS[self.topic]["cluster"]
  121. conf = {
  122. "bootstrap.servers": settings.KAFKA_CLUSTERS[cluster_name]["common"][
  123. "bootstrap.servers"
  124. ],
  125. "session.timeout.ms": 6000,
  126. }
  127. producer = Producer(conf)
  128. producer.produce(self.topic, json.dumps(self.valid_wrapper))
  129. producer.flush()
  130. consumer = QuerySubscriptionConsumer("hi", topic=self.topic, commit_batch_size=1)
  131. mock_callback = Mock(side_effect=lambda *a, **k: consumer.shutdown())
  132. register_subscriber(self.registration_key)(mock_callback)
  133. sub = self.create_subscription()
  134. consumer.run()
  135. payload = self.valid_payload
  136. payload["values"] = payload["result"]
  137. payload["timestamp"] = parse_date(payload["timestamp"]).replace(tzinfo=pytz.utc)
  138. mock_callback.assert_called_once_with(payload, sub)
  139. def test_shutdown(self):
  140. valid_wrapper_2 = deepcopy(self.valid_wrapper)
  141. valid_wrapper_2["payload"]["result"]["hello"] = 25
  142. valid_wrapper_3 = deepcopy(self.valid_wrapper)
  143. valid_wrapper_3["payload"]["result"]["hello"] = 5000
  144. self.producer.produce(self.topic, json.dumps(self.valid_wrapper))
  145. self.producer.produce(self.topic, json.dumps(valid_wrapper_2))
  146. self.producer.produce(self.topic, json.dumps(valid_wrapper_3))
  147. self.producer.flush()
  148. def normalize_payload(payload):
  149. return {
  150. **payload,
  151. "values": payload["result"],
  152. "timestamp": parse_date(payload["timestamp"]).replace(tzinfo=pytz.utc),
  153. }
  154. consumer = QuerySubscriptionConsumer("hi", topic=self.topic, commit_batch_size=100)
  155. def mock_callback(*args, **kwargs):
  156. if mock.call_count >= len(expected_calls):
  157. consumer.shutdown()
  158. mock = Mock(side_effect=mock_callback)
  159. register_subscriber(self.registration_key)(mock)
  160. sub = self.create_subscription()
  161. expected_calls = [
  162. call(normalize_payload(self.valid_payload), sub),
  163. call(normalize_payload(valid_wrapper_2["payload"]), sub),
  164. ]
  165. consumer.run()
  166. mock.assert_has_calls(expected_calls)
  167. expected_calls = [call(normalize_payload(valid_wrapper_3["payload"]), sub)]
  168. mock.reset_mock()
  169. consumer.run()
  170. mock.assert_has_calls(expected_calls)
  171. @mock.patch("sentry.snuba.query_subscription_consumer.QuerySubscriptionConsumer.commit_offsets")
  172. def test_batch_timeout(self, commit_offset_mock):
  173. self.producer.produce(self.topic, json.dumps(self.valid_wrapper))
  174. self.producer.flush()
  175. consumer = QuerySubscriptionConsumer(
  176. "hi", topic=self.topic, commit_batch_size=100, commit_batch_timeout_ms=1
  177. )
  178. def mock_callback(*args, **kwargs):
  179. time.sleep(0.1)
  180. consumer.shutdown()
  181. mock = Mock(side_effect=mock_callback)
  182. register_subscriber(self.registration_key)(mock)
  183. self.create_subscription()
  184. consumer.run()
  185. # Once on revoke, once on shutdown, and once due to batch timeout
  186. assert len(commit_offset_mock.call_args_list) == 3