# plugin.py (7.5 KB)
  1. from __future__ import annotations
  2. import functools
  3. from collections.abc import Callable
  4. from mypy.build import PRI_MYPY
  5. from mypy.errorcodes import ATTR_DEFINED
  6. from mypy.messages import format_type
  7. from mypy.nodes import ARG_POS, MDEF, MypyFile, SymbolTableNode, TypeInfo, Var
  8. from mypy.plugin import (
  9. AttributeContext,
  10. ClassDefContext,
  11. FunctionSigContext,
  12. Plugin,
  13. SemanticAnalyzerPluginInterface,
  14. )
  15. from mypy.plugins.common import add_attribute_to_class
  16. from mypy.subtypes import find_member
  17. from mypy.typeanal import make_optional_type
  18. from mypy.types import (
  19. AnyType,
  20. CallableType,
  21. FunctionLike,
  22. Instance,
  23. NoneType,
  24. Type,
  25. TypeOfAny,
  26. UnionType,
  27. )
  28. def _make_using_required_str(ctx: FunctionSigContext) -> CallableType:
  29. sig = ctx.default_signature
  30. using_arg = sig.argument_by_name("using")
  31. if using_arg is None or using_arg.pos is None:
  32. ctx.api.fail("The using parameter is required", ctx.context)
  33. return sig
  34. for kind in sig.arg_kinds[: using_arg.pos]:
  35. if kind != ARG_POS:
  36. ctx.api.fail("Expected using to be the first optional", ctx.context)
  37. return sig
  38. str_type = ctx.api.named_generic_type("builtins.str", [])
  39. arg_kinds = [*sig.arg_kinds[: using_arg.pos], ARG_POS, *sig.arg_kinds[using_arg.pos + 1 :]]
  40. arg_types = [*sig.arg_types[: using_arg.pos], str_type, *sig.arg_types[using_arg.pos + 1 :]]
  41. return sig.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types)
  42. def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
  43. sig = ctx.default_signature
  44. if not sig.argument_by_name("using"):
  45. # No using arg in the signature, bail
  46. return sig
  47. # We care about context managers.
  48. if not isinstance(sig.ret_type, Instance):
  49. return sig
  50. return _make_using_required_str(ctx)
  51. _FUNCTION_SIGNATURE_HOOKS = {
  52. "django.db.transaction.atomic": replace_transaction_atomic_sig_callback,
  53. "django.db.transaction.get_connection": _make_using_required_str,
  54. "django.db.transaction.on_commit": _make_using_required_str,
  55. "django.db.transaction.set_rollback": _make_using_required_str,
  56. }
  57. _AUTH_TOKEN_TP = "sentry.auth.services.auth.model.AuthenticatedToken"
  58. def _has_symbols(api: SemanticAnalyzerPluginInterface, *symbols: str) -> bool:
  59. for symbol in symbols:
  60. if not api.lookup_fully_qualified_or_none(symbol):
  61. return False
  62. else:
  63. return True
  64. def _request_auth_tp(api: SemanticAnalyzerPluginInterface) -> Type:
  65. st = api.lookup_fully_qualified(_AUTH_TOKEN_TP)
  66. assert isinstance(st.node, TypeInfo), st.node
  67. return make_optional_type(Instance(st.node, ()))
  68. def _adjust_http_request_members(ctx: ClassDefContext) -> None:
  69. if ctx.cls.name == "HttpRequest":
  70. if not _has_symbols(ctx.api, _AUTH_TOKEN_TP):
  71. return ctx.api.defer()
  72. # added by sentry.api.base and sentry.web.frontend.base
  73. # TODO: idk why I can't use the real type here :/
  74. add_attribute_to_class(ctx.api, ctx.cls, "access", AnyType(TypeOfAny.explicit))
  75. # added by sentry.middleware.auth
  76. add_attribute_to_class(ctx.api, ctx.cls, "auth", _request_auth_tp(ctx.api))
  77. # added by csp.middleware.CSPMiddleware
  78. add_attribute_to_class(ctx.api, ctx.cls, "csp_nonce", ctx.api.named_type("builtins.str"))
  79. # added by sudo.middleware.SudoMiddleware
  80. # this is slightly better than a method returning bool for overriding
  81. returns_bool = CallableType(
  82. arg_types=[],
  83. arg_kinds=[],
  84. arg_names=[],
  85. ret_type=ctx.api.named_type("builtins.bool"),
  86. fallback=ctx.api.named_type("builtins.function"),
  87. name="is_sudo",
  88. )
  89. add_attribute_to_class(ctx.api, ctx.cls, "is_sudo", returns_bool)
  90. # added by sentry.middleware.subdomain
  91. subdomain_tp = UnionType([NoneType(), ctx.api.named_type("builtins.str")])
  92. add_attribute_to_class(ctx.api, ctx.cls, "subdomain", subdomain_tp)
  93. # added by sentry.middleware.superuser
  94. # TODO: figure out how to get the real types here
  95. add_attribute_to_class(ctx.api, ctx.cls, "superuser", AnyType(TypeOfAny.explicit))
  96. def _adjust_request_members(ctx: ClassDefContext) -> None:
  97. if ctx.cls.name == "Request":
  98. if not _has_symbols(ctx.api, _AUTH_TOKEN_TP):
  99. return ctx.api.defer()
  100. # sentry.auth.middleware / sentry.api.authentication
  101. add_attribute_to_class(ctx.api, ctx.cls, "auth", _request_auth_tp(ctx.api))
  102. def _add_name_to_info(ti: TypeInfo, name: str, tp: Type) -> None:
  103. node = Var(name, tp)
  104. node.info = ti
  105. node._fullname = f"{ti.fullname}.{name}"
  106. ti.names[name] = SymbolTableNode(MDEF, node, plugin_generated=True)
  107. def _adjust_http_response_members(ctx: ClassDefContext) -> None:
  108. # there isn't a good plugin point for HttpResponseBase so we add it here?
  109. if ctx.cls.name == "HttpResponse":
  110. dict_str_list_str = ctx.api.named_type(
  111. "builtins.dict",
  112. [
  113. ctx.api.named_type("builtins.str"),
  114. ctx.api.named_type("builtins.list", [ctx.api.named_type("builtins.str")]),
  115. ],
  116. )
  117. base = ctx.cls.info.bases[0].type
  118. assert base.name == "HttpResponseBase", base.name
  119. _add_name_to_info(base, "_csp_replace", dict_str_list_str)
  120. def _lazy_service_wrapper_attribute(ctx: AttributeContext, *, attr: str) -> Type:
  121. # we use `Any` as the `__getattr__` return value
  122. # allow existing attributes to be returned as normal if they are not `Any`
  123. if not isinstance(ctx.default_attr_type, AnyType):
  124. return ctx.default_attr_type
  125. assert isinstance(ctx.type, Instance), ctx.type
  126. assert len(ctx.type.args) == 1, ctx.type
  127. assert isinstance(ctx.type.args[0], Instance), ctx.type
  128. generic_type = ctx.type.args[0]
  129. member = find_member(attr, generic_type, generic_type)
  130. if member is None:
  131. ctx.api.fail(
  132. f'{format_type(ctx.type, ctx.api.options)} has no attribute "{attr}"',
  133. ctx.context,
  134. code=ATTR_DEFINED,
  135. )
  136. return ctx.default_attr_type
  137. else:
  138. return member
  139. class SentryMypyPlugin(Plugin):
  140. def get_function_signature_hook(
  141. self, fullname: str
  142. ) -> Callable[[FunctionSigContext], FunctionLike] | None:
  143. return _FUNCTION_SIGNATURE_HOOKS.get(fullname)
  144. def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
  145. # XXX: this is a hack -- I don't know if there's a better callback to modify a class
  146. if fullname == "_io.BytesIO":
  147. return _adjust_http_request_members
  148. elif fullname == "django.http.request.HttpRequest":
  149. return _adjust_request_members
  150. elif fullname == "django.http.response.HttpResponseBase":
  151. return _adjust_http_response_members
  152. else:
  153. return None
  154. def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
  155. if fullname.startswith("sentry.utils.lazy_service_wrapper.LazyServiceWrapper."):
  156. _, attr = fullname.rsplit(".", 1)
  157. return functools.partial(_lazy_service_wrapper_attribute, attr=attr)
  158. else:
  159. return None
  160. def get_additional_deps(self, file: MypyFile) -> list[tuple[int, str, int]]:
  161. if file.fullname in {"django.http", "django.http.request", "rest_framework.request"}:
  162. return [(PRI_MYPY, "sentry.auth.services.auth.model", -1)]
  163. else:
  164. return []
  165. def plugin(version: str) -> type[SentryMypyPlugin]:
  166. return SentryMypyPlugin