plugin.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. from __future__ import annotations
  2. from typing import Callable
  3. from mypy.nodes import ARG_POS, TypeInfo
  4. from mypy.plugin import ClassDefContext, FunctionSigContext, MethodSigContext, Plugin
  5. from mypy.plugins.common import add_attribute_to_class
  6. from mypy.types import AnyType, CallableType, FunctionLike, Instance, NoneType, TypeOfAny, UnionType
  7. def _make_using_required_str(ctx: FunctionSigContext) -> CallableType:
  8. sig = ctx.default_signature
  9. using_arg = sig.argument_by_name("using")
  10. if using_arg is None or using_arg.pos is None:
  11. ctx.api.fail("The using parameter is required", ctx.context)
  12. return sig
  13. for kind in sig.arg_kinds[: using_arg.pos]:
  14. if kind != ARG_POS:
  15. ctx.api.fail("Expected using to be the first optional", ctx.context)
  16. return sig
  17. str_type = ctx.api.named_generic_type("builtins.str", [])
  18. arg_kinds = [*sig.arg_kinds[: using_arg.pos], ARG_POS, *sig.arg_kinds[using_arg.pos + 1 :]]
  19. arg_types = [*sig.arg_types[: using_arg.pos], str_type, *sig.arg_types[using_arg.pos + 1 :]]
  20. return sig.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types)
  21. def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
  22. sig = ctx.default_signature
  23. if not sig.argument_by_name("using"):
  24. # No using arg in the signature, bail
  25. return sig
  26. # We care about context managers.
  27. if not isinstance(sig.ret_type, Instance):
  28. return sig
  29. return _make_using_required_str(ctx)
  30. _FUNCTION_SIGNATURE_HOOKS = {
  31. "django.db.transaction.atomic": replace_transaction_atomic_sig_callback,
  32. "django.db.transaction.get_connection": _make_using_required_str,
  33. "django.db.transaction.on_commit": _make_using_required_str,
  34. "django.db.transaction.set_rollback": _make_using_required_str,
  35. }
  36. def field_descriptor_no_overloads(ctx: MethodSigContext) -> FunctionLike:
  37. # ignore the class / non-model instance descriptor overloads
  38. signature = ctx.default_signature
  39. # replace `def __get__(self, inst: Model, owner: Any) -> _GT:`
  40. # with `def __get__(self, inst: Any, owner: Any) -> _GT:`
  41. if str(signature.arg_types[0]) == "django.db.models.base.Model":
  42. return signature.copy_modified(arg_types=[signature.arg_types[1]] * 2)
  43. else:
  44. return signature
  45. def _adjust_http_request_members(ctx: ClassDefContext) -> None:
  46. if ctx.cls.name == "HttpRequest":
  47. # added by sentry.api.base and sentry.web.frontend.base
  48. # TODO: idk why I can't use the real type here :/
  49. add_attribute_to_class(ctx.api, ctx.cls, "access", AnyType(TypeOfAny.explicit))
  50. # added by sentry.middleware.auth
  51. # TODO: figure out how to get the real types here
  52. add_attribute_to_class(ctx.api, ctx.cls, "auth", AnyType(TypeOfAny.explicit))
  53. # added by csp.middleware.CSPMiddleware
  54. add_attribute_to_class(ctx.api, ctx.cls, "csp_nonce", ctx.api.named_type("builtins.str"))
  55. # added by sudo.middleware.SudoMiddleware
  56. # this is slightly better than a method returning bool for overriding
  57. returns_bool = CallableType(
  58. arg_types=[],
  59. arg_kinds=[],
  60. arg_names=[],
  61. ret_type=ctx.api.named_type("builtins.bool"),
  62. fallback=ctx.api.named_type("builtins.function"),
  63. name="is_sudo",
  64. )
  65. add_attribute_to_class(ctx.api, ctx.cls, "is_sudo", returns_bool)
  66. # added by sentry.middleware.subdomain
  67. subdomain_tp = UnionType([NoneType(), ctx.api.named_type("builtins.str")])
  68. add_attribute_to_class(ctx.api, ctx.cls, "subdomain", subdomain_tp)
  69. # added by sentry.middleware.superuser
  70. # TODO: figure out how to get the real types here
  71. add_attribute_to_class(ctx.api, ctx.cls, "superuser", AnyType(TypeOfAny.explicit))
  72. class SentryMypyPlugin(Plugin):
  73. def get_function_signature_hook(
  74. self, fullname: str
  75. ) -> Callable[[FunctionSigContext], FunctionLike] | None:
  76. return _FUNCTION_SIGNATURE_HOOKS.get(fullname)
  77. def get_method_signature_hook(
  78. self, fullname: str
  79. ) -> Callable[[MethodSigContext], FunctionLike] | None:
  80. if fullname == "django.db.models.fields.Field":
  81. return field_descriptor_no_overloads
  82. clsname, _, methodname = fullname.rpartition(".")
  83. if methodname != "__get__":
  84. return None
  85. clsinfo = self.lookup_fully_qualified(clsname)
  86. if clsinfo is None or not isinstance(clsinfo.node, TypeInfo):
  87. return None
  88. fieldinfo = self.lookup_fully_qualified("django.db.models.fields.Field")
  89. if fieldinfo is None:
  90. return None
  91. if fieldinfo.node in clsinfo.node.mro:
  92. return field_descriptor_no_overloads
  93. else:
  94. return None
  95. def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
  96. # XXX: this is a hack -- I don't know if there's a better callback to modify a class
  97. if fullname == "io.BytesIO":
  98. return _adjust_http_request_members
  99. else:
  100. return None
  101. def plugin(version: str) -> type[SentryMypyPlugin]:
  102. return SentryMypyPlugin