123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- from __future__ import annotations
- import functools
- from collections.abc import Callable
- from mypy.build import PRI_MYPY
- from mypy.errorcodes import ATTR_DEFINED
- from mypy.messages import format_type
- from mypy.nodes import ARG_POS, MypyFile, TypeInfo
- from mypy.plugin import (
- AttributeContext,
- ClassDefContext,
- FunctionSigContext,
- Plugin,
- SemanticAnalyzerPluginInterface,
- )
- from mypy.plugins.common import add_attribute_to_class
- from mypy.subtypes import find_member
- from mypy.typeanal import make_optional_type
- from mypy.types import (
- AnyType,
- CallableType,
- FunctionLike,
- Instance,
- NoneType,
- Type,
- TypeOfAny,
- UnionType,
- )
def _make_using_required_str(ctx: FunctionSigContext) -> CallableType:
    """Return the default signature with ``using`` turned into a required ``str``.

    The ``using`` argument keeps its position but its kind becomes ``ARG_POS``
    and its type becomes ``builtins.str``.

    An error is reported through ``ctx.api.fail`` and the default signature is
    returned untouched when:

    * there is no ``using`` argument (or it has no position), or
    * any argument ahead of ``using`` is not ``ARG_POS``.
    """
    default_sig = ctx.default_signature

    using = default_sig.argument_by_name("using")
    if using is None or using.pos is None:
        ctx.api.fail("The using parameter is required", ctx.context)
        return default_sig

    idx = using.pos
    # everything ahead of `using` has to be a plain positional argument
    if any(kind != ARG_POS for kind in default_sig.arg_kinds[:idx]):
        ctx.api.fail("Expected using to be the first optional", ctx.context)
        return default_sig

    # copy the kind / type lists and overwrite only the `using` slot
    kinds = list(default_sig.arg_kinds)
    kinds[idx] = ARG_POS
    types = list(default_sig.arg_types)
    types[idx] = ctx.api.named_generic_type("builtins.str", [])
    return default_sig.copy_modified(arg_kinds=kinds, arg_types=types)
def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
    """Signature hook that requires ``using`` on ``Instance``-returning signatures.

    The default signature is returned unchanged when it has no ``using``
    argument or when its return type is not an ``Instance``; otherwise the
    rewrite is delegated to ``_make_using_required_str``.
    """
    default_sig = ctx.default_signature

    has_using = bool(default_sig.argument_by_name("using"))
    # We care about context managers: only signatures returning an Instance
    # are rewritten.
    if has_using and isinstance(default_sig.ret_type, Instance):
        return _make_using_required_str(ctx)

    return default_sig
# Fully-qualified function name -> signature hook used for calls to it.
_FUNCTION_SIGNATURE_HOOKS = {
    # `atomic` has its own callback which only rewrites some signatures
    "django.db.transaction.atomic": replace_transaction_atomic_sig_callback,
    # the remaining helpers always get `using` rewritten to a required `str`
    **dict.fromkeys(
        (
            "django.db.transaction.get_connection",
            "django.db.transaction.on_commit",
            "django.db.transaction.set_rollback",
        ),
        _make_using_required_str,
    ),
}

# Fully-qualified name of the type used for the request `auth` attribute.
_AUTH_TOKEN_TP = "sentry.auth.services.auth.model.AuthenticatedToken"
def _has_symbols(api: SemanticAnalyzerPluginInterface, *symbols: str) -> bool:
    """Return whether every name in *symbols* resolves to a truthy lookup result.

    Each name is passed to ``api.lookup_fully_qualified_or_none``; checking
    stops at the first name whose lookup is falsy.  With no names given the
    result is ``True``.
    """
    return all(api.lookup_fully_qualified_or_none(name) for name in symbols)
def _request_auth_tp(api: SemanticAnalyzerPluginInterface) -> Type:
    """Return the optional form of the type named by ``_AUTH_TOKEN_TP``.

    The name is resolved through ``api.lookup_fully_qualified`` and must point
    at a ``TypeInfo``; the resulting argument-less ``Instance`` is wrapped with
    ``make_optional_type``.
    """
    symbol = api.lookup_fully_qualified(_AUTH_TOKEN_TP)
    info = symbol.node
    assert isinstance(info, TypeInfo), info
    token_tp = Instance(info, ())
    return make_optional_type(token_tp)
def _adjust_http_request_members(ctx: ClassDefContext) -> None:
    """Declare extra attributes on the class named ``HttpRequest``.

    Classes with any other name are left alone.  If the type named by
    ``_AUTH_TOKEN_TP`` cannot be looked up yet, the class is deferred instead
    of being modified.
    """
    if ctx.cls.name != "HttpRequest":
        return

    api, cls = ctx.api, ctx.cls
    if not _has_symbols(api, _AUTH_TOKEN_TP):
        return api.defer()

    def declare(name: str, tp: Type) -> None:
        # thin wrapper so each attribute below is a single readable line
        add_attribute_to_class(api, cls, name, tp)

    # added by sentry.api.base and sentry.web.frontend.base
    # TODO: idk why I can't use the real type here :/
    declare("access", AnyType(TypeOfAny.explicit))

    # added by sentry.middleware.auth
    declare("auth", _request_auth_tp(api))

    # added by csp.middleware.CSPMiddleware
    declare("csp_nonce", api.named_type("builtins.str"))

    # added by sudo.middleware.SudoMiddleware
    # this is slightly better than a method returning bool for overriding
    is_sudo_tp = CallableType(
        arg_types=[],
        arg_kinds=[],
        arg_names=[],
        ret_type=api.named_type("builtins.bool"),
        fallback=api.named_type("builtins.function"),
        name="is_sudo",
    )
    declare("is_sudo", is_sudo_tp)

    # added by sentry.middleware.subdomain
    declare("subdomain", UnionType([NoneType(), api.named_type("builtins.str")]))

    # added by sentry.middleware.superuser
    # TODO: figure out how to get the real types here
    declare("superuser", AnyType(TypeOfAny.explicit))
def _adjust_request_members(ctx: ClassDefContext) -> None:
    """Declare the ``auth`` attribute on the class named ``Request``.

    Classes with any other name are left alone.  If the type named by
    ``_AUTH_TOKEN_TP`` cannot be looked up yet, the class is deferred instead
    of being modified.
    """
    if ctx.cls.name != "Request":
        return

    api = ctx.api
    if not _has_symbols(api, _AUTH_TOKEN_TP):
        return api.defer()

    # sentry.auth.middleware / sentry.api.authentication
    add_attribute_to_class(api, ctx.cls, "auth", _request_auth_tp(api))
def _lazy_service_wrapper_attribute(ctx: AttributeContext, *, attr: str) -> Type:
    """Resolve *attr* against the single type argument of the accessed instance.

    A default attribute type that is not ``Any`` is returned as-is.  Otherwise
    ``ctx.type`` must be an ``Instance`` with exactly one ``Instance`` type
    argument, and *attr* is looked up on that argument with ``find_member``.
    When the lookup finds nothing, an ``attr-defined`` error is reported and
    the default attribute type is returned.
    """
    default_tp = ctx.default_attr_type

    # we use `Any` as the `__getattr__` return value
    # allow existing attributes to be returned as normal if they are not `Any`
    if not isinstance(default_tp, AnyType):
        return default_tp

    wrapper_tp = ctx.type
    assert isinstance(wrapper_tp, Instance), wrapper_tp
    assert len(wrapper_tp.args) == 1, wrapper_tp
    (wrapped_tp,) = wrapper_tp.args
    assert isinstance(wrapped_tp, Instance), ctx.type

    resolved = find_member(attr, wrapped_tp, wrapped_tp)
    if resolved is not None:
        return resolved

    ctx.api.fail(
        f'{format_type(ctx.type, ctx.api.options)} has no attribute "{attr}"',
        ctx.context,
        code=ATTR_DEFINED,
    )
    return default_tp
class SentryMypyPlugin(Plugin):
    """mypy plugin that wires this module's callbacks into mypy's hook lookups."""

    def get_function_signature_hook(
        self, fullname: str
    ) -> Callable[[FunctionSigContext], FunctionLike] | None:
        """Return the signature hook registered for *fullname*, if any."""
        return _FUNCTION_SIGNATURE_HOOKS.get(fullname)

    def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
        """Return the class-adjusting callback for the base class *fullname*."""
        # XXX: this is a hack -- I don't know if there's a better callback to modify a class
        # NOTE(review): the hook is keyed on the *base* class, so the
        # HttpRequest callback is presumably reached through its `_io.BytesIO`
        # base and the Request callback through its HttpRequest base -- confirm
        # against the stubs in use.
        base_class_hooks = {
            "_io.BytesIO": _adjust_http_request_members,
            "django.http.request.HttpRequest": _adjust_request_members,
        }
        return base_class_hooks.get(fullname)

    def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
        """Return an attribute hook for members of ``LazyServiceWrapper``."""
        if not fullname.startswith("sentry.utils.lazy_service_wrapper.LazyServiceWrapper."):
            return None
        # the attribute name is whatever follows the last dot
        attr = fullname.rpartition(".")[2]
        return functools.partial(_lazy_service_wrapper_attribute, attr=attr)

    def get_additional_deps(self, file: MypyFile) -> list[tuple[int, str, int]]:
        """Make the request modules depend on the auth model module."""
        request_modules = {"django.http", "django.http.request", "rest_framework.request"}
        if file.fullname not in request_modules:
            return []
        return [(PRI_MYPY, "sentry.auth.services.auth.model", -1)]
def plugin(version: str) -> type[SentryMypyPlugin]:
    """Return the plugin class; *version* is accepted but not inspected."""
    return SentryMypyPlugin
|