
chore(typing): Initial test of typing and introducing a mypy config file (#23606)

* chore(typing): Initial test of typing and introducing a mypy config file

This is a test of converting a file to be entirely typed. I chose something fairly self-contained,
and typed any functions imported from other modules as well.

I typed our json.loads function as `Any`. I'm open to being more specific here, although from what
I've been reading it's hard to do this well until MyPy supports recursive types.
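
As a point of reference, here is a minimal sketch of what a more precise return type might look like. The `JSONData` alias below is hypothetical and not part of this change; mypy currently rejects a recursive alias like this under strict checking, which is why `loads` stays typed as `Any`:

```python
from typing import Any, Dict, List, Union

# Hypothetical recursive alias for decoded JSON values (not part of this change).
# mypy cannot resolve the cyclic reference yet, so this definition fails under
# strict checking; hence loads() returns Any for now.
JSONData = Union[None, bool, int, float, str, List["JSONData"], Dict[str, "JSONData"]]


def loads(value: str) -> Any:  # would ideally be `-> JSONData`
    ...
```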

Also added a mypy.ini, which enables strict checking and specifies a list of files to run the
checks on. I figure that the way we'd use this is to slowly convert files over time and add them to
this list, and once an entire module is added we'd add a glob match for that module instead.
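
Roughly, the idea is that the `files` entry would evolve something like this (the extra paths here are illustrative, not part of the committed config):

```ini
[mypy]
; Individually converted files get appended over time; once a whole module is
; converted, its entries collapse into a single directory (or glob) match,
; e.g. src/sentry/snuba/
files = src/sentry/snuba/query_subscription_consumer.py,
    src/sentry/snuba/tasks.py
```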

Looking for feedback on this approach. I think that we can fairly easily have this run in CI using
the config file.
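
For a concrete sense of what strict checking buys us here, this is a rough sketch (handler name, body, and subscriber key are made up; import paths assumed) of a callback registered through the now-typed decorator:

```python
from typing import Any, Dict

from sentry.snuba.models import QuerySubscription  # import path assumed
from sentry.snuba.query_subscription_consumer import register_subscriber


# Because register_subscriber() now returns
# Callable[[TQuerySubscriptionCallable], TQuerySubscriptionCallable],
# mypy --strict verifies that every registered handler matches this signature.
@register_subscriber("example-subscriber")  # hypothetical key
def handle_subscription_update(data: Dict[str, Any], subscription: QuerySubscription) -> None:
    # Placeholder body; a real handler would act on the parsed payload in `data`.
    ...
```
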
Dan Fuller · 4 years ago · commit 35a6e81cf2

mypy.ini (+26, -0)

@@ -0,0 +1,26 @@
+[mypy]
+python_version = 3.6
+files = src/sentry/snuba/query_subscription_consumer.py
+
+; Enable all options used with --strict
+warn_unused_configs=True
+disallow_any_generics=True
+disallow_subclassing_any=True
+disallow_untyped_calls=True
+disallow_untyped_defs=True
+disallow_incomplete_defs=True
+check_untyped_defs=True
+disallow_untyped_decorators=True
+no_implicit_optional=True
+warn_redundant_casts=True
+warn_unused_ignores=True
+warn_return_any=True
+no_implicit_reexport=True
+
+
+[mypy-confluent_kafka.*]
+ignore_missing_imports = True
+[mypy-jsonschema]
+ignore_missing_imports = True
+[mypy-django.*]
+ignore_missing_imports = True

src/sentry/db/models/manager.py (+1, -1)

@@ -262,7 +262,7 @@ class BaseManager(Manager):
         super().contribute_to_class(model, name)
         class_prepared.connect(self.__class_prepared, sender=model)
 
-    def get_from_cache(self, **kwargs):
+    def get_from_cache(self, **kwargs) -> Model:  # TODO(typing): Properly type this
         """
         Wrapper around QuerySet.get which supports caching of the
         intermediate value.  Callee is responsible for making sure

src/sentry/snuba/json_schemas.py (+3, -1)

@@ -1,3 +1,5 @@
+from typing import Any, Mapping
+
 SUBSCRIPTION_WRAPPER_SCHEMA = {
     "type": "object",
     "properties": {"version": {"type": "integer"}, "payload": {"type": "object"}},
@@ -6,7 +8,7 @@ SUBSCRIPTION_WRAPPER_SCHEMA = {
 }
 
 
-SUBSCRIPTION_PAYLOAD_VERSIONS = {
+SUBSCRIPTION_PAYLOAD_VERSIONS: Mapping[int, Mapping[str, Any]] = {
     1: {
         "type": "object",
         "properties": {

src/sentry/snuba/query_subscription_consumer.py (+38, -29)

@@ -1,9 +1,10 @@
 import logging
+from typing import Any, Callable, cast, Dict, Iterable, List, Optional
 
 import jsonschema
 import pytz
 import sentry_sdk
-from confluent_kafka import Consumer, KafkaException, OFFSET_INVALID, TopicPartition
+from confluent_kafka import Consumer, KafkaException, Message, OFFSET_INVALID, TopicPartition
 from confluent_kafka.admin import AdminClient
 from dateutil.parser import parse as parse_date
 from django.conf import settings
@@ -16,12 +17,15 @@ from sentry.utils.batching_kafka_consumer import wait_for_topics
 
 logger = logging.getLogger(__name__)
 
+TQuerySubscriptionCallable = Callable[[Dict[str, Any], QuerySubscription], None]
 
-subscriber_registry = {}
+subscriber_registry: Dict[str, TQuerySubscriptionCallable] = {}
 
 
-def register_subscriber(subscriber_key):
-    def inner(func):
+def register_subscriber(
+    subscriber_key: str,
+) -> Callable[[TQuerySubscriptionCallable], TQuerySubscriptionCallable]:
+    def inner(func: TQuerySubscriptionCallable) -> TQuerySubscriptionCallable:
         if subscriber_key in subscriber_registry:
             raise Exception("Handler already registered for %s" % subscriber_key)
         subscriber_registry[subscriber_key] = func
@@ -45,28 +49,30 @@ class QuerySubscriptionConsumer:
     These values are passed along to a callback associated with the subscription.
     """
 
-    topic_to_dataset = {
+    topic_to_dataset: Dict[str, QueryDatasets] = {
         settings.KAFKA_EVENTS_SUBSCRIPTIONS_RESULTS: QueryDatasets.EVENTS,
         settings.KAFKA_TRANSACTIONS_SUBSCRIPTIONS_RESULTS: QueryDatasets.TRANSACTIONS,
     }
 
     def __init__(
         self,
-        group_id,
-        topic=None,
-        commit_batch_size=100,
-        initial_offset_reset="earliest",
-        force_offset_reset=None,
+        group_id: str,
+        topic: Optional[str] = None,
+        commit_batch_size: int = 100,
+        initial_offset_reset: str = "earliest",
+        force_offset_reset: Optional[str] = None,
     ):
         self.group_id = group_id
         if not topic:
-            topic = settings.KAFKA_EVENTS_SUBSCRIPTIONS_RESULTS
+            # TODO(typing): Need a way to get the actual value of settings to avoid this
+            topic = cast(str, settings.KAFKA_EVENTS_SUBSCRIPTIONS_RESULTS)
+
         self.topic = topic
-        cluster_name = settings.KAFKA_TOPICS[topic]["cluster"]
+        cluster_name: str = settings.KAFKA_TOPICS[topic]["cluster"]
         self.commit_batch_size = commit_batch_size
         self.initial_offset_reset = initial_offset_reset
-        self.offsets = {}
-        self.consumer = None
+        self.offsets: Dict[int, Optional[int]] = {}
+        self.consumer: Consumer = None
         self.cluster_options = kafka_config.get_kafka_consumer_cluster_options(
             cluster_name,
             {
@@ -84,26 +90,29 @@ class QuerySubscriptionConsumer:
         )
         self.resolve_partition_force_offset = self.offset_reset_name_to_func(force_offset_reset)
 
-    def offset_reset_name_to_func(self, offset_reset):
+    def offset_reset_name_to_func(
+        self, offset_reset: Optional[str]
+    ) -> Optional[Callable[[TopicPartition], TopicPartition]]:
         if offset_reset in {"smallest", "earliest", "beginning"}:
             return self.resolve_partition_offset_earliest
         elif offset_reset in {"largest", "latest", "end"}:
             return self.resolve_partition_offset_latest
+        return None
 
-    def resolve_partition_offset_earliest(self, partition):
+    def resolve_partition_offset_earliest(self, partition: TopicPartition) -> TopicPartition:
         low, high = self.consumer.get_watermark_offsets(partition)
         return TopicPartition(partition.topic, partition.partition, low)
 
-    def resolve_partition_offset_latest(self, partition):
+    def resolve_partition_offset_latest(self, partition: TopicPartition) -> TopicPartition:
         low, high = self.consumer.get_watermark_offsets(partition)
         return TopicPartition(partition.topic, partition.partition, high)
 
-    def run(self):
+    def run(self) -> None:
         logger.debug("Starting snuba query subscriber")
         self.offsets.clear()
 
-        def on_assign(consumer, partitions):
-            updated_partitions = []
+        def on_assign(consumer: Consumer, partitions: List[TopicPartition]) -> None:
+            updated_partitions: List[TopicPartition] = []
             for partition in partitions:
                 if self.resolve_partition_force_offset:
                     partition = self.resolve_partition_force_offset(partition)
@@ -124,7 +133,7 @@ class QuerySubscriptionConsumer:
                 },
             )
 
-        def on_revoke(consumer, partitions):
+        def on_revoke(consumer: Consumer, partitions: List[TopicPartition]) -> None:
             partition_numbers = [partition.partition for partition in partitions]
             self.commit_offsets(partition_numbers)
             for partition_number in partition_numbers:
@@ -177,7 +186,7 @@ class QuerySubscriptionConsumer:
 
         self.shutdown()
 
-    def commit_offsets(self, partitions=None):
+    def commit_offsets(self, partitions: Optional[Iterable[int]] = None) -> None:
         logger.info(
             "query-subscription-consumer.commit_offsets",
             extra={"offsets": str(self.offsets), "partitions": str(partitions)},
@@ -196,12 +205,12 @@ class QuerySubscriptionConsumer:
 
             self.consumer.commit(offsets=to_commit)
 
-    def shutdown(self):
+    def shutdown(self) -> None:
         logger.debug("Committing offsets and closing consumer")
         self.commit_offsets()
         self.consumer.close()
 
-    def handle_message(self, message):
+    def handle_message(self, message: Message) -> None:
         """
         Parses the value from Kafka, and if valid passes the payload to the callback defined by the
         subscription. If the subscription has been removed, or no longer has a valid callback then
@@ -229,7 +238,7 @@ class QuerySubscriptionConsumer:
 
             try:
                 with metrics.timer("snuba_query_subscriber.fetch_subscription"):
-                    subscription = QuerySubscription.objects.get_from_cache(
+                    subscription: QuerySubscription = QuerySubscription.objects.get_from_cache(
                         subscription_id=contents["subscription_id"]
                     )
                     if subscription.status != QuerySubscription.Status.ACTIVE.value:
@@ -284,7 +293,7 @@ class QuerySubscriptionConsumer:
 
                 callback(contents, subscription)
 
-    def parse_message_value(self, value):
+    def parse_message_value(self, value: str) -> Dict[str, Any]:
         """
         Parses the value received via the Kafka consumer and verifies that it
         matches the expected schema.
@@ -292,7 +301,7 @@ class QuerySubscriptionConsumer:
         :return: A dict with the parsed message
         """
         with metrics.timer("snuba_query_subscriber.parse_message_value.json_parse"):
-            wrapper = json.loads(value)
+            wrapper: Dict[str, Any] = json.loads(value)
 
         with metrics.timer("snuba_query_subscriber.parse_message_value.json_validate_wrapper"):
             try:
@@ -301,12 +310,12 @@ class QuerySubscriptionConsumer:
                 metrics.incr("snuba_query_subscriber.message_wrapper_invalid")
                 raise InvalidSchemaError("Message wrapper does not match schema")
 
-        schema_version = wrapper["version"]
+        schema_version: int = wrapper["version"]
         if schema_version not in SUBSCRIPTION_PAYLOAD_VERSIONS:
             metrics.incr("snuba_query_subscriber.message_wrapper_invalid_version")
             raise InvalidMessageError("Version specified in wrapper has no schema")
 
-        payload = wrapper["payload"]
+        payload: Dict[str, Any] = wrapper["payload"]
         with metrics.timer("snuba_query_subscriber.parse_message_value.json_validate_payload"):
             try:
                 jsonschema.validate(payload, SUBSCRIPTION_PAYLOAD_VERSIONS[schema_version])

src/sentry/snuba/tasks.py (+1, -1)

@@ -205,7 +205,7 @@ def _create_in_snuba(subscription):
     return json.loads(response.data)["subscription_id"]
 
 
-def _delete_from_snuba(dataset, subscription_id):
+def _delete_from_snuba(dataset: QueryDatasets, subscription_id: str) -> None:
     response = _snuba_pool.urlopen("DELETE", f"/{dataset.value}/subscriptions/{subscription_id}")
     if response.status != 202:
         raise SnubaError("HTTP %s response from Snuba!" % response.status)

src/sentry/utils/batching_kafka_consumer.py (+2, -1)

@@ -1,6 +1,7 @@
 import abc
 import logging
 import time
+from typing import List
 
 from confluent_kafka import (
     Consumer,
@@ -23,7 +24,7 @@ DEFAULT_QUEUED_MAX_MESSAGE_KBYTES = 50000
 DEFAULT_QUEUED_MIN_MESSAGES = 10000
 
 
-def wait_for_topics(admin_client, topics, timeout=10):
+def wait_for_topics(admin_client: AdminClient, topics: List[str], timeout: int = 10) -> None:
     """
     Make sure that the provided topics exist and have non-zero partitions in them.
     """

src/sentry/utils/json.py (+4, -2)

@@ -3,11 +3,13 @@
 # XXX(epurkhiser): We import JSONDecodeError just to have it be exported as
 # part of this module. We don't use it directly within the module, but modules
 # that import it from here will. Do not remove.
+
 from simplejson import JSONEncoder, JSONDecodeError, _default_decoder  # NOQA
 from enum import Enum
 import datetime
-import uuid
 import decimal
+import uuid
+from typing import Any
 
 from bitfield.types import BitHandler
 from django.utils.encoding import force_text
@@ -109,7 +111,7 @@ def load(fp, **kwargs):
     return loads(fp.read())
 
 
-def loads(value, **kwargs):
+def loads(value: str, **kwargs) -> Any:
     return _default_decoder.decode(value)
 
 

src/sentry/utils/kafka_config.py (+8, -2)

@@ -1,3 +1,5 @@
+from typing import Any, MutableMapping, Optional
+
 from django.conf import settings
 
 SUPPORTED_KAFKA_CONFIGURATION = (
@@ -75,13 +77,17 @@ def get_kafka_producer_cluster_options(cluster_name):
     return _get_kafka_cluster_options(cluster_name, PRODUCERS_SECTION)
 
 
-def get_kafka_consumer_cluster_options(cluster_name, override_params=None):
+def get_kafka_consumer_cluster_options(
+    cluster_name: str, override_params: Optional[MutableMapping[str, Any]] = None
+) -> MutableMapping[Any, Any]:
     return _get_kafka_cluster_options(
         cluster_name, CONSUMERS_SECTION, only_bootstrap=True, override_params=override_params
     )
 
 
-def get_kafka_admin_cluster_options(cluster_name, override_params=None):
+def get_kafka_admin_cluster_options(
+    cluster_name: str, override_params: Optional[MutableMapping[str, Any]] = None
+) -> MutableMapping[Any, Any]:
     return _get_kafka_cluster_options(
         cluster_name, ADMIN_SECTION, only_bootstrap=True, override_params=override_params
     )

src/sentry/utils/metrics.py (+11, -9)

@@ -1,10 +1,12 @@
 __all__ = ["timing", "incr"]
 
-import logging
-import time
 
 import functools
+import logging
+import time
 from contextlib import contextmanager
+from typing import Mapping, Optional
+
 from django.conf import settings
 from random import random
 from threading import Thread, local
@@ -123,13 +125,13 @@ internal = InternalMetrics()
 
 
 def incr(
-    key,
-    amount=1,
-    instance=None,
-    tags=None,
-    skip_internal=True,
-    sample_rate=settings.SENTRY_METRICS_SAMPLE_RATE,
-):
+    key: str,
+    amount: int = 1,
+    instance: Optional[str] = None,
+    tags: Optional[Mapping[str, str]] = None,
+    skip_internal: bool = True,
+    sample_rate: float = settings.SENTRY_METRICS_SAMPLE_RATE,
+) -> None:
     current_tags = _get_current_global_tags()
     if tags is not None:
         current_tags.update(tags)