plugin.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. from __future__ import annotations
  2. import functools
  3. from collections.abc import Callable
  4. from mypy.build import PRI_MYPY
  5. from mypy.errorcodes import ATTR_DEFINED
  6. from mypy.messages import format_type
  7. from mypy.nodes import ARG_POS, MypyFile, TypeInfo
  8. from mypy.plugin import (
  9. AttributeContext,
  10. ClassDefContext,
  11. FunctionSigContext,
  12. Plugin,
  13. SemanticAnalyzerPluginInterface,
  14. )
  15. from mypy.plugins.common import add_attribute_to_class
  16. from mypy.subtypes import find_member
  17. from mypy.typeanal import make_optional_type
  18. from mypy.types import (
  19. AnyType,
  20. CallableType,
  21. FunctionLike,
  22. Instance,
  23. NoneType,
  24. Type,
  25. TypeOfAny,
  26. UnionType,
  27. )
  28. def _make_using_required_str(ctx: FunctionSigContext) -> CallableType:
  29. sig = ctx.default_signature
  30. using_arg = sig.argument_by_name("using")
  31. if using_arg is None or using_arg.pos is None:
  32. ctx.api.fail("The using parameter is required", ctx.context)
  33. return sig
  34. for kind in sig.arg_kinds[: using_arg.pos]:
  35. if kind != ARG_POS:
  36. ctx.api.fail("Expected using to be the first optional", ctx.context)
  37. return sig
  38. str_type = ctx.api.named_generic_type("builtins.str", [])
  39. arg_kinds = [*sig.arg_kinds[: using_arg.pos], ARG_POS, *sig.arg_kinds[using_arg.pos + 1 :]]
  40. arg_types = [*sig.arg_types[: using_arg.pos], str_type, *sig.arg_types[using_arg.pos + 1 :]]
  41. return sig.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types)
  42. def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
  43. sig = ctx.default_signature
  44. if not sig.argument_by_name("using"):
  45. # No using arg in the signature, bail
  46. return sig
  47. # We care about context managers.
  48. if not isinstance(sig.ret_type, Instance):
  49. return sig
  50. return _make_using_required_str(ctx)
  51. _FUNCTION_SIGNATURE_HOOKS = {
  52. "django.db.transaction.atomic": replace_transaction_atomic_sig_callback,
  53. "django.db.transaction.get_connection": _make_using_required_str,
  54. "django.db.transaction.on_commit": _make_using_required_str,
  55. "django.db.transaction.set_rollback": _make_using_required_str,
  56. }
  57. _AUTH_TOKEN_TP = "sentry.auth.services.auth.model.AuthenticatedToken"
  58. def _has_symbols(api: SemanticAnalyzerPluginInterface, *symbols: str) -> bool:
  59. for symbol in symbols:
  60. if not api.lookup_fully_qualified_or_none(symbol):
  61. return False
  62. else:
  63. return True
  64. def _request_auth_tp(api: SemanticAnalyzerPluginInterface) -> Type:
  65. st = api.lookup_fully_qualified(_AUTH_TOKEN_TP)
  66. assert isinstance(st.node, TypeInfo), st.node
  67. return make_optional_type(Instance(st.node, ()))
  68. def _adjust_http_request_members(ctx: ClassDefContext) -> None:
  69. if ctx.cls.name == "HttpRequest":
  70. if not _has_symbols(ctx.api, _AUTH_TOKEN_TP):
  71. return ctx.api.defer()
  72. # added by sentry.api.base and sentry.web.frontend.base
  73. # TODO: idk why I can't use the real type here :/
  74. add_attribute_to_class(ctx.api, ctx.cls, "access", AnyType(TypeOfAny.explicit))
  75. # added by sentry.middleware.auth
  76. add_attribute_to_class(ctx.api, ctx.cls, "auth", _request_auth_tp(ctx.api))
  77. # added by csp.middleware.CSPMiddleware
  78. add_attribute_to_class(ctx.api, ctx.cls, "csp_nonce", ctx.api.named_type("builtins.str"))
  79. # added by sudo.middleware.SudoMiddleware
  80. # this is slightly better than a method returning bool for overriding
  81. returns_bool = CallableType(
  82. arg_types=[],
  83. arg_kinds=[],
  84. arg_names=[],
  85. ret_type=ctx.api.named_type("builtins.bool"),
  86. fallback=ctx.api.named_type("builtins.function"),
  87. name="is_sudo",
  88. )
  89. add_attribute_to_class(ctx.api, ctx.cls, "is_sudo", returns_bool)
  90. # added by sentry.middleware.subdomain
  91. subdomain_tp = UnionType([NoneType(), ctx.api.named_type("builtins.str")])
  92. add_attribute_to_class(ctx.api, ctx.cls, "subdomain", subdomain_tp)
  93. # added by sentry.middleware.superuser
  94. # TODO: figure out how to get the real types here
  95. add_attribute_to_class(ctx.api, ctx.cls, "superuser", AnyType(TypeOfAny.explicit))
  96. def _adjust_request_members(ctx: ClassDefContext) -> None:
  97. if ctx.cls.name == "Request":
  98. if not _has_symbols(ctx.api, _AUTH_TOKEN_TP):
  99. return ctx.api.defer()
  100. # sentry.auth.middleware / sentry.api.authentication
  101. add_attribute_to_class(ctx.api, ctx.cls, "auth", _request_auth_tp(ctx.api))
  102. def _lazy_service_wrapper_attribute(ctx: AttributeContext, *, attr: str) -> Type:
  103. # we use `Any` as the `__getattr__` return value
  104. # allow existing attributes to be returned as normal if they are not `Any`
  105. if not isinstance(ctx.default_attr_type, AnyType):
  106. return ctx.default_attr_type
  107. assert isinstance(ctx.type, Instance), ctx.type
  108. assert len(ctx.type.args) == 1, ctx.type
  109. assert isinstance(ctx.type.args[0], Instance), ctx.type
  110. generic_type = ctx.type.args[0]
  111. member = find_member(attr, generic_type, generic_type)
  112. if member is None:
  113. ctx.api.fail(
  114. f'{format_type(ctx.type, ctx.api.options)} has no attribute "{attr}"',
  115. ctx.context,
  116. code=ATTR_DEFINED,
  117. )
  118. return ctx.default_attr_type
  119. else:
  120. return member
  121. class SentryMypyPlugin(Plugin):
  122. def get_function_signature_hook(
  123. self, fullname: str
  124. ) -> Callable[[FunctionSigContext], FunctionLike] | None:
  125. return _FUNCTION_SIGNATURE_HOOKS.get(fullname)
  126. def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
  127. # XXX: this is a hack -- I don't know if there's a better callback to modify a class
  128. if fullname == "_io.BytesIO":
  129. return _adjust_http_request_members
  130. elif fullname == "django.http.request.HttpRequest":
  131. return _adjust_request_members
  132. else:
  133. return None
  134. def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
  135. if fullname.startswith("sentry.utils.lazy_service_wrapper.LazyServiceWrapper."):
  136. _, attr = fullname.rsplit(".", 1)
  137. return functools.partial(_lazy_service_wrapper_attribute, attr=attr)
  138. else:
  139. return None
  140. def get_additional_deps(self, file: MypyFile) -> list[tuple[int, str, int]]:
  141. if file.fullname in {"django.http", "django.http.request", "rest_framework.request"}:
  142. return [(PRI_MYPY, "sentry.auth.services.auth.model", -1)]
  143. else:
  144. return []
  145. def plugin(version: str) -> type[SentryMypyPlugin]:
  146. return SentryMypyPlugin