Browse Source

ref: type sentry.utils.sms and sentry.utils.otp (#70084)

planning to refactor this to be a little more typesafe -- but typing it
first so the refactor is more apparent and not lost in this

<!-- Describe your PR here. -->
anthony sottile 10 months ago
parent
commit
2535c544a0

+ 2 - 0
pyproject.toml

@@ -593,8 +593,10 @@ module = [
     "sentry.utils.email.*",
     "sentry.utils.iterators",
     "sentry.utils.locking.backends.redis",
+    "sentry.utils.otp",
     "sentry.utils.redis",
     "sentry.utils.redis_metrics",
+    "sentry.utils.sms",
     "sentry_plugins.base",
     "tests.sentry.api.endpoints.issues.*",
     "tests.sentry.grouping.test_fingerprinting",

+ 2 - 2
src/sentry/auth/authenticators/__init__.py

@@ -9,14 +9,14 @@ AUTHENTICATOR_INTERFACES_BY_TYPE = {}
 AUTHENTICATOR_CHOICES = []
 
 
-def register_authenticator(cls):
+def register_authenticator(cls: type[AuthenticatorInterface]) -> None:
     AUTHENTICATOR_INTERFACES[cls.interface_id] = cls
     AUTHENTICATOR_INTERFACES_BY_TYPE[cls.type] = cls
     AUTHENTICATOR_CHOICES.append((cls.type, cls.name))
     AUTHENTICATOR_CHOICES.sort(key=lambda x: x[0])
 
 
-def available_authenticators(ignore_backup=False):
+def available_authenticators(ignore_backup: bool = False) -> list[type[AuthenticatorInterface]]:
     interfaces = AUTHENTICATOR_INTERFACES.values()
     if not ignore_backup:
         return [v for v in interfaces if v.is_available]

+ 31 - 22
src/sentry/auth/authenticators/base.py

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal, Self
 
 from django.core.cache import cache
 from django.utils import timezone
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
     from django.utils.functional import _StrPromise
 
     from sentry.models.authenticator import Authenticator
+    from sentry.models.user import User
 
 
 class ActivationResult:
@@ -22,7 +23,11 @@ class ActivationResult:
 
 
 class ActivationMessageResult(ActivationResult):
-    def __init__(self, message, type="info"):
+    def __init__(
+        self,
+        message: str | _StrPromise,
+        type: Literal["error", "warning", "info"] = "info",
+    ) -> None:
         assert type in ("error", "warning", "info")
         self.type = type
         self.message = message
@@ -30,14 +35,14 @@ class ActivationMessageResult(ActivationResult):
     def __str__(self):
         return self.message
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"<{type(self).__name__}: {self.message}>"
 
 
 class ActivationChallengeResult(ActivationResult):
     type = "challenge"
 
-    def __init__(self, challenge):
+    def __init__(self, challenge: bytes) -> None:
         self.challenge = challenge
 
 
@@ -66,25 +71,27 @@ class AuthenticatorInterface:
     allow_multi_enrollment = False
     allow_rotation_in_place = False
 
-    def __init__(self, authenticator=None, status=EnrollmentStatus.EXISTING):
+    def __init__(
+        self, authenticator=None, status: EnrollmentStatus = EnrollmentStatus.EXISTING
+    ) -> None:
         self.authenticator = authenticator
         self.status = status
 
     @classmethod
-    def generate(cls, status):
+    def generate(cls, status: EnrollmentStatus) -> Self:
         # Convenience method to build new instances either from the
         # class or existing instances. That is, it's nicer than doing
         # `type(interface)()`.
         return cls(status=status)
 
-    def is_enrolled(self):
+    def is_enrolled(self) -> bool:
         """Returns `True` if the interfaces is enrolled (eg: has an
         authenticator for a user attached).
         """
         return self.authenticator is not None
 
     @property
-    def disallow_new_enrollment(self):
+    def disallow_new_enrollment(self) -> bool:
         """If new enrollments of this 2FA interface type are no allowed
         this returns `True`.
 
@@ -95,7 +102,7 @@ class AuthenticatorInterface:
         return bool(options.get(f"{self.interface_id}.disallow-new-enrollment"))
 
     @property
-    def requires_activation(self):
+    def requires_activation(self) -> bool:
         """If the interface has an activation method that needs to be
         called this returns `True`.
         """
@@ -109,7 +116,7 @@ class AuthenticatorInterface:
         return type(self).validate_otp is not AuthenticatorInterface.validate_otp
 
     @property
-    def config(self):
+    def config(self) -> dict[str, Any]:
         """Returns the configuration dictionary for this interface.  If
         the interface is registered with an authenticator (eg: it is
         enrolled) then the authenticator's config is returned, otherwise
@@ -125,7 +132,7 @@ class AuthenticatorInterface:
             rv = self._unbound_config = self.generate_new_config()
         return rv
 
-    def generate_new_config(self):
+    def generate_new_config(self) -> dict[str, Any]:
         """This method is invoked if a new config is required."""
         return {}
 
@@ -137,7 +144,7 @@ class AuthenticatorInterface:
         # This method needs to be empty for the default
         # `requires_activation` property to make sense.
 
-    def enroll(self, user):
+    def enroll(self, user: User) -> None:
         """Invoked to enroll a user for this interface.  If already enrolled
         an error is raised.
 
@@ -158,7 +165,7 @@ class AuthenticatorInterface:
             self.authenticator.config = self.config
             self.authenticator.save()
 
-    def rotate_in_place(self):
+    def rotate_in_place(self) -> None:
         if not self.allow_rotation_in_place:
             raise Exception("This interface does not allow rotation in place")
         if self.authenticator is None:
@@ -169,7 +176,7 @@ class AuthenticatorInterface:
         self.authenticator.last_used_at = None
         self.authenticator.save()
 
-    def validate_otp(self, otp):
+    def validate_otp(self, otp: str) -> bool:
         """This method is invoked for an OTP response and has to return
         `True` or `False` based on the validity of the OTP response.  Note
         that this can be called with otp responses from other interfaces.
@@ -190,25 +197,27 @@ class OtpMixin:
     config: dict[str, Any]
     authenticator: Authenticator | None
 
-    def generate_new_config(self):
+    def generate_new_config(self) -> dict[str, Any]:
         return {"secret": generate_secret_key()}
 
     @property
-    def secret(self):
+    def secret(self) -> str:
         return self.config["secret"]
 
     @secret.setter
-    def secret(self, secret):
+    def secret(self, secret: str) -> None:
         self.config["secret"] = secret
 
-    def make_otp(self):
+    def make_otp(self) -> TOTP:
         return TOTP(self.secret)
 
-    def _get_otp_counter_cache_key(self, counter):
+    def _get_otp_counter_cache_key(self, counter: int) -> str | None:
         if self.authenticator is not None:
             return f"used-otp-counters:{self.authenticator.user_id}:{counter}"
+        else:
+            return None
 
-    def check_otp_counter(self, counter):
+    def check_otp_counter(self, counter: int) -> bool:
         # OTP uses an internal counter that increments every 30 seconds.
         # A hash function generates a six digit code based on the counter
         # and a secret key.  If the generated PIN was used it is marked in
@@ -217,13 +226,13 @@ class OtpMixin:
         cache_key = self._get_otp_counter_cache_key(counter)
         return cache_key is None or cache.get(cache_key) != "1"
 
-    def mark_otp_counter_used(self, counter):
+    def mark_otp_counter_used(self, counter: int) -> None:
         cache_key = self._get_otp_counter_cache_key(counter)
         if cache_key is not None:
             # Mark us used for three windows
             cache.set(cache_key, "1", timeout=120)
 
-    def validate_otp(self, otp):
+    def validate_otp(self, otp: str) -> bool:
         if not otp:
             return False
         otp = otp.strip().replace("-", "").replace(" ", "")

+ 4 - 3
src/sentry/auth/authenticators/sms.py

@@ -4,6 +4,7 @@ import logging
 from hashlib import md5
 from typing import TYPE_CHECKING
 
+from django.http.request import HttpRequest
 from django.utils.translation import gettext_lazy as _
 
 from sentry.ratelimits import backend as ratelimiter
@@ -50,7 +51,7 @@ class SmsInterface(OtpMixin, AuthenticatorInterface):
         config["phone_number"] = None
         return config
 
-    def make_otp(self):
+    def make_otp(self) -> TOTP:
         return TOTP(self.config["secret"], digits=6, interval=self.code_ttl, default_window=1)
 
     @property
@@ -61,7 +62,7 @@ class SmsInterface(OtpMixin, AuthenticatorInterface):
     def phone_number(self, value):
         self.config["phone_number"] = value
 
-    def activate(self, request):
+    def activate(self, request: HttpRequest) -> ActivationMessageResult:
         phone_number = self.config["phone_number"]
         if len(phone_number) == 10:
             mask = "(***) ***-**%s" % (phone_number[-2:])
@@ -84,7 +85,7 @@ class SmsInterface(OtpMixin, AuthenticatorInterface):
             type="error",
         )
 
-    def send_text(self, for_enrollment=False, request=None):
+    def send_text(self, *, request: HttpRequest, for_enrollment: bool = False) -> bool:
         ctx = {"code": self.make_otp().generate_otp()}
 
         if for_enrollment:

+ 1 - 1
src/sentry/auth/authenticators/totp.py

@@ -24,5 +24,5 @@ class TotpInterface(OtpMixin, AuthenticatorInterface):
         "replaced and will no longer work to access your account."
     )
 
-    def get_provision_url(self, user, issuer=None):
+    def get_provision_url(self, user: str, issuer: str | None = None) -> str:
         return self.make_otp().get_provision_url(user, issuer=issuer)

+ 23 - 7
src/sentry/utils/otp.py

@@ -2,17 +2,18 @@ import base64
 import hashlib
 import hmac
 import time
+from collections.abc import Callable
 from datetime import datetime
 from urllib.parse import quote
 
 from django.utils.crypto import constant_time_compare, get_random_string
 
 
-def generate_secret_key(length=32):
+def generate_secret_key(length: int = 32) -> str:
     return get_random_string(length, "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567")
 
 
-def _pack_int(i):
+def _pack_int(i: int) -> bytes:
     result = bytearray()
     while i != 0:
         result.append(i & 0xFF)
@@ -20,7 +21,7 @@ def _pack_int(i):
     return bytes(bytearray(reversed(result)).rjust(8, b"\0"))
 
 
-def _get_ts(ts):
+def _get_ts(ts: float | datetime | None) -> int:
     if ts is None:
         return int(time.time())
     if isinstance(ts, datetime):
@@ -29,7 +30,13 @@ def _get_ts(ts):
 
 
 class TOTP:
-    def __init__(self, secret=None, digits=6, interval=30, default_window=2):
+    def __init__(
+        self,
+        secret: str | None = None,
+        digits: int = 6,
+        interval: int = 30,
+        default_window: int = 2,
+    ) -> None:
         if secret is None:
             secret = generate_secret_key()
         if len(secret) % 8 != 0:
@@ -39,7 +46,9 @@ class TOTP:
         self.interval = interval
         self.default_window = default_window
 
-    def generate_otp(self, ts=None, offset=0, counter=None):
+    def generate_otp(
+        self, ts: float | datetime | None = None, offset: int = 0, counter: int | None = None
+    ) -> str:
         if counter is None:
             ts = _get_ts(ts)
             counter = int(ts) // self.interval + offset
@@ -60,7 +69,14 @@ class TOTP:
         str_code = str(code % 10**self.digits)
         return ("0" * (self.digits - len(str_code))) + str_code
 
-    def verify(self, otp, ts=None, window=None, return_counter=False, check_counter_func=None):
+    def verify(
+        self,
+        otp: str,
+        ts: float | datetime | None = None,
+        window: int | None = None,
+        return_counter: bool = False,
+        check_counter_func: Callable[[int], bool] | None = None,
+    ) -> bool | int | None:
         ts = _get_ts(ts)
         if window is None:
             window = self.default_window
@@ -78,7 +94,7 @@ class TOTP:
             return None
         return False
 
-    def get_provision_url(self, user, issuer=None):
+    def get_provision_url(self, user: str, issuer: str | None = None) -> str:
         if issuer is None:
             issuer = "Sentry"
         rv = "otpauth://totp/{}?issuer={}&secret={}".format(

+ 4 - 13
src/sentry/utils/sms.py

@@ -10,17 +10,8 @@ logger = logging.getLogger(__name__)
 
 
 class InvalidPhoneNumber(Exception):
-    def __init__(self, *args):
-        if args:
-            self.message = args[0]
-        else:
-            self.message = None
-
-    def __str__(self):
-        if self.message:
-            return f"InvalidPhoneNumber: {self.message}"
-        else:
-            return "InvalidPhoneNumber"
+    def __str__(self) -> str:
+        return "InvalidPhoneNumber"
 
 
 def validate_phone_number(phone_number: str) -> bool:
@@ -50,11 +41,11 @@ def phone_number_as_e164(num: str) -> str:
         raise InvalidPhoneNumber
 
 
-def sms_available():
+def sms_available() -> bool:
     return bool(options.get("sms.twilio-account"))
 
 
-def send_sms(body, to, from_=None):
+def send_sms(body: str, to: str, from_: str | None = None) -> bool:
     account = options.get("sms.twilio-account")
     if not account:
         raise RuntimeError("SMS backend is not configured.")