social.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. from allauth.account import app_settings as allauth_account_settings
  2. from allauth.account.adapter import DefaultAccountAdapter
  3. from allauth.account.auth_backends import AuthenticationBackend
  4. from allauth.socialaccount.adapter import DefaultSocialAccountAdapter, get_adapter
  5. from allauth.socialaccount.helpers import complete_social_login
  6. from allauth.socialaccount.providers.oauth2.client import OAuth2Client, OAuth2Error
  7. from allauth.socialaccount.providers.openid_connect.views import (
  8. OpenIDConnectOAuth2Adapter,
  9. )
  10. from dj_rest_auth.registration.serializers import (
  11. SocialConnectMixin,
  12. )
  13. from dj_rest_auth.registration.serializers import (
  14. SocialLoginSerializer as BaseSocialLoginSerializer,
  15. )
  16. from dj_rest_auth.registration.views import SocialConnectView, SocialLoginView
  17. from django.conf import settings
  18. from django.contrib.auth import get_backends, get_user_model
  19. from django.db import IntegrityError
  20. from django.http import HttpResponseBadRequest
  21. from django.utils.translation import gettext_lazy as _
  22. from django_rest_mfa.helpers import has_mfa
  23. from requests.exceptions import HTTPError
  24. from rest_framework import serializers
  25. from rest_framework.response import Response
  26. from apps.users.utils import is_user_registration_open
  27. from .constants import SOCIAL_ADAPTER_MAP
# Base URL of this deployment as a plain string. ``settings.GLITCHTIP_URL``
# exposes ``geturl()`` (presumably a ``urllib.parse`` result — confirm in
# settings); the string is used below to build OAuth callback URLs.
DOMAIN = settings.GLITCHTIP_URL.geturl()
  29. class MFAAccountAdapter(DefaultAccountAdapter):
  30. """
  31. If user requires MFA, do not actually log in
  32. """
  33. def login(self, request, user):
  34. """Extend to check for MFA status, backend hack is copied from super method"""
  35. if not hasattr(user, "backend"):
  36. backends = get_backends()
  37. backend = None
  38. for b in backends: # pylint: disable=invalid-name
  39. if isinstance(b, AuthenticationBackend):
  40. # prefer our own backend
  41. backend = b
  42. break
  43. elif not backend and hasattr(b, "get_user"):
  44. # Pick the first valid one
  45. backend = b
  46. backend_path = ".".join([backend.__module__, backend.__class__.__name__])
  47. user.backend = backend_path
  48. if has_mfa(request, user):
  49. user.mfa = True # Store for later, to avoid multiple DB checks
  50. else:
  51. super().login(request, user)
  52. def get_login_redirect_url(self, request):
  53. """Ignore login redirect when not logged in"""
  54. try:
  55. return super().get_login_redirect_url(request)
  56. except AssertionError:
  57. pass
  58. class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
  59. def is_open_for_signup(self, request, sociallogin):
  60. return is_user_registration_open()
  61. class SocialLoginSerializer(BaseSocialLoginSerializer):
  62. tags = serializers.CharField(
  63. allow_blank=True, required=False, allow_null=True, write_only=True
  64. )
  65. # Overriding to add check for django-allauth's is_open_for_signup() at end, to prevent
  66. # creation of new user on first-time social auth login
  67. # https://github.com/iMerica/dj-rest-auth/blob/master/dj_rest_auth/registration/serializers.py#L79
  68. def validate(self, attrs):
  69. view = self.context.get("view")
  70. request = self._get_request()
  71. if not view:
  72. raise serializers.ValidationError(
  73. _("View is not defined, pass it as a context variable"),
  74. )
  75. adapter_class = getattr(view, "adapter_class", None)
  76. if not adapter_class:
  77. raise serializers.ValidationError(_("Define adapter_class in view"))
  78. # The OIDC provider has a dynamic provider id. Fetch it from the request.
  79. if adapter_class == OpenIDConnectOAuth2Adapter:
  80. provider = request.resolver_match.captured_kwargs.get("provider")
  81. adapter = adapter_class(request, provider)
  82. else:
  83. adapter = adapter_class(request)
  84. app = adapter.get_provider().app
  85. access_token = attrs.get("access_token")
  86. code = attrs.get("code")
  87. if access_token:
  88. tokens_to_parse = {"access_token": access_token}
  89. token = access_token
  90. id_token = attrs.get("id_token")
  91. if id_token:
  92. tokens_to_parse["id_token"] = id_token
  93. elif code:
  94. self.set_callback_url(view=view, adapter_class=adapter_class)
  95. self.client_class = getattr(view, "client_class", None)
  96. if not self.client_class:
  97. raise serializers.ValidationError(
  98. _("Define client_class in view"),
  99. )
  100. provider = adapter.get_provider()
  101. client = self.client_class(
  102. request,
  103. app.client_id,
  104. app.secret,
  105. adapter.access_token_method,
  106. adapter.access_token_url,
  107. self.callback_url,
  108. scope_delimiter=adapter.scope_delimiter,
  109. headers=adapter.headers,
  110. basic_auth=adapter.basic_auth,
  111. )
  112. try:
  113. token = client.get_access_token(code)
  114. except OAuth2Error as ex:
  115. raise serializers.ValidationError(
  116. _("Failed to exchange code for access token")
  117. ) from ex
  118. access_token = token["access_token"]
  119. tokens_to_parse = {"access_token": access_token}
  120. for key in ["refresh_token", "id_token", adapter.expires_in_key]:
  121. if key in token:
  122. tokens_to_parse[key] = token[key]
  123. else:
  124. raise serializers.ValidationError(
  125. _("Incorrect input. access_token or code is required."),
  126. )
  127. social_token = adapter.parse_token(tokens_to_parse)
  128. social_token.app = app
  129. try:
  130. if adapter.provider_id == "google" and not code:
  131. login = self.get_social_login(
  132. adapter, app, social_token, response={"id_token": id_token}
  133. )
  134. else:
  135. login = self.get_social_login(adapter, app, social_token, token)
  136. # In allauth 0.53, we need this here instead
  137. account_exists = (
  138. get_user_model()
  139. .objects.filter(
  140. email=login.user.email,
  141. )
  142. .exists()
  143. )
  144. ret = complete_social_login(request, login)
  145. except HTTPError:
  146. raise serializers.ValidationError(_("Incorrect value"))
  147. if isinstance(ret, HttpResponseBadRequest):
  148. raise serializers.ValidationError(ret.content)
  149. if not login.is_existing:
  150. if allauth_account_settings.UNIQUE_EMAIL:
  151. if account_exists:
  152. raise serializers.ValidationError(
  153. _("User is already registered with this e-mail address."),
  154. )
  155. # Added check for open signup
  156. if not get_adapter(request).is_open_for_signup(request, login):
  157. raise serializers.ValidationError(_("User registration is closed."))
  158. else:
  159. login.lookup()
  160. try:
  161. login.save(request, connect=True)
  162. except IntegrityError as ex:
  163. raise serializers.ValidationError(
  164. _("User is already registered with this e-mail address."),
  165. ) from ex
  166. self.post_signup(login, attrs)
  167. attrs["user"] = login.account.user
  168. return attrs
  169. class GenericMFAMixin:
  170. client_class = OAuth2Client # Only OAuth2 client is supported
  171. @property
  172. def callback_url(self):
  173. # Set dynamic OIDC provider ID
  174. provider_id = self.kwargs.get("provider", self.adapter_class.provider_id)
  175. return DOMAIN + "/auth/" + provider_id
  176. @property
  177. def adapter_class(self):
  178. provider = self.kwargs.get("provider")
  179. adapter_class = SOCIAL_ADAPTER_MAP.get(
  180. provider, SOCIAL_ADAPTER_MAP["openid_connect"]
  181. )
  182. # Set dynamic OIDC provider ID
  183. adapter_class.provider_id = provider
  184. return adapter_class
  185. class GlitchTipSocialConnectSerializer(SocialConnectMixin, SocialLoginSerializer):
  186. pass
class GlitchTipSocialConnectView(GenericMFAMixin, SocialConnectView):
    # Social-connect endpoint built on dj-rest-auth's SocialConnectView; the
    # adapter class, client class and callback URL come from GenericMFAMixin.
    # (A class docstring is avoided on purpose: DRF uses view docstrings as the
    # browsable-API / OPTIONS description.)
    serializer_class = GlitchTipSocialConnectSerializer
  189. class MFASocialLoginView(GenericMFAMixin, SocialLoginView):
  190. serializer_class = SocialLoginSerializer
  191. def process_login(self):
  192. tags = self.serializer.validated_data.get("tags")
  193. if tags and self.user.analytics is None:
  194. self.user.set_register_analytics_tags(tags)
  195. self.user.save(update_fields=["analytics"])
  196. if not getattr(self.user, "mfa", False):
  197. super().process_login()
  198. def get_response(self):
  199. if getattr(self.user, "mfa", False):
  200. user_key_types = (
  201. self.user.userkey_set.all()
  202. .values_list("key_type", flat=True)
  203. .distinct()
  204. )
  205. return Response({"requires_mfa": True, "valid_auth": user_key_types})
  206. return super().get_response()