# social.py
  1. from allauth.account.adapter import DefaultAccountAdapter
  2. from allauth.account import app_settings as allauth_settings
  3. from allauth.account.auth_backends import AuthenticationBackend
  4. from allauth.socialaccount.adapter import DefaultSocialAccountAdapter, get_adapter
  5. from allauth.socialaccount.helpers import complete_social_login
  6. from allauth.socialaccount.providers.oauth2.client import OAuth2Client
  7. from dj_rest_auth.registration.serializers import (
  8. SocialLoginSerializer as BaseSocialLoginSerializer,
  9. )
  10. from dj_rest_auth.registration.views import SocialConnectView, SocialLoginView
  11. from django.conf import settings
  12. from django.contrib.auth import get_backends, get_user_model
  13. from django.http import HttpResponseBadRequest
  14. from django.utils.translation import gettext_lazy as _
  15. from django_rest_mfa.helpers import has_mfa
  16. from requests.exceptions import HTTPError
  17. from rest_framework import serializers
  18. from rest_framework.response import Response
  19. from users.utils import is_user_registration_open
  20. from .constants import SOCIAL_ADAPTER_MAP
# Base URL of this GlitchTip instance, as a string. GLITCHTIP_URL appears to be a
# parsed URL object whose .geturl() re-serializes it — NOTE(review): confirm in settings.
# Used to build the OAuth callback URL in GenericMFAMixin.callback_url.
DOMAIN = settings.GLITCHTIP_URL.geturl()
  22. class MFAAccountAdapter(DefaultAccountAdapter):
  23. """
  24. If user requires MFA, do not actually log in
  25. """
  26. def login(self, request, user):
  27. """Extend to check for MFA status, backend hack is copied from super method"""
  28. if not hasattr(user, "backend"):
  29. backends = get_backends()
  30. backend = None
  31. for b in backends: # pylint: disable=invalid-name
  32. if isinstance(b, AuthenticationBackend):
  33. # prefer our own backend
  34. backend = b
  35. break
  36. elif not backend and hasattr(b, "get_user"):
  37. # Pick the first valid one
  38. backend = b
  39. backend_path = ".".join([backend.__module__, backend.__class__.__name__])
  40. user.backend = backend_path
  41. if has_mfa(request, user):
  42. user.mfa = True # Store for later, to avoid multiple DB checks
  43. else:
  44. super().login(request, user)
  45. def get_login_redirect_url(self, request):
  46. """Ignore login redirect when not logged in"""
  47. try:
  48. return super().get_login_redirect_url(request)
  49. except AssertionError:
  50. pass
  51. class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
  52. def is_open_for_signup(self, request, sociallogin):
  53. return is_user_registration_open()
  54. class SocialLoginSerializer(BaseSocialLoginSerializer):
  55. tags = serializers.CharField(
  56. allow_blank=True, required=False, allow_null=True, write_only=True
  57. )
  58. # Overriding to add check for django-allauth's is_open_for_signup() at end, to prevent
  59. # creation of new user on first-time social auth login
  60. # https://github.com/iMerica/dj-rest-auth/blob/master/dj_rest_auth/registration/serializers.py#L79
  61. def validate(self, attrs):
  62. view = self.context.get("view")
  63. request = self._get_request()
  64. if not view:
  65. raise serializers.ValidationError(
  66. _("View is not defined, pass it as a context variable"),
  67. )
  68. adapter_class = getattr(view, "adapter_class", None)
  69. if not adapter_class:
  70. raise serializers.ValidationError(_("Define adapter_class in view"))
  71. adapter = adapter_class(request)
  72. app = adapter.get_provider().get_app(request)
  73. access_token = attrs.get("access_token")
  74. code = attrs.get("code")
  75. if access_token:
  76. tokens_to_parse = {"access_token": access_token}
  77. token = access_token
  78. id_token = attrs.get("id_token")
  79. if id_token:
  80. tokens_to_parse["id_token"] = id_token
  81. elif code:
  82. self.set_callback_url(view=view, adapter_class=adapter_class)
  83. self.client_class = getattr(view, "client_class", None)
  84. if not self.client_class:
  85. raise serializers.ValidationError(
  86. _("Define client_class in view"),
  87. )
  88. provider = adapter.get_provider()
  89. scope = provider.get_scope(request)
  90. client = self.client_class(
  91. request,
  92. app.client_id,
  93. app.secret,
  94. adapter.access_token_method,
  95. adapter.access_token_url,
  96. self.callback_url,
  97. scope,
  98. scope_delimiter=adapter.scope_delimiter,
  99. headers=adapter.headers,
  100. basic_auth=adapter.basic_auth,
  101. )
  102. token = client.get_access_token(code)
  103. access_token = token["access_token"]
  104. tokens_to_parse = {"access_token": access_token}
  105. for key in ["refresh_token", "id_token", adapter.expires_in_key]:
  106. if key in token:
  107. tokens_to_parse[key] = token[key]
  108. else:
  109. raise serializers.ValidationError(
  110. _("Incorrect input. access_token or code is required."),
  111. )
  112. social_token = adapter.parse_token(tokens_to_parse)
  113. social_token.app = app
  114. try:
  115. login = self.get_social_login(adapter, app, social_token, token)
  116. ret = complete_social_login(request, login)
  117. except HTTPError:
  118. raise serializers.ValidationError(_("Incorrect value"))
  119. if isinstance(ret, HttpResponseBadRequest):
  120. raise serializers.ValidationError(ret.content)
  121. if not login.is_existing:
  122. if allauth_settings.UNIQUE_EMAIL:
  123. account_exists = (
  124. get_user_model()
  125. .objects.filter(
  126. email=login.user.email,
  127. )
  128. .exists()
  129. )
  130. if account_exists:
  131. raise serializers.ValidationError(
  132. _("User is already registered with this e-mail address."),
  133. )
  134. # Added check for open signup
  135. if not get_adapter(request).is_open_for_signup(request, login):
  136. raise serializers.ValidationError(_("User registration is closed."))
  137. else:
  138. login.lookup()
  139. login.save(request, connect=True)
  140. attrs["user"] = login.account.user
  141. return attrs
  142. class GenericMFAMixin:
  143. client_class = OAuth2Client # Needed for Github. Would this ever break a provider?
  144. @property
  145. def callback_url(self):
  146. provider_id = self.adapter_class.provider_id
  147. return DOMAIN + "/auth/" + provider_id
  148. @property
  149. def adapter_class(self):
  150. provider = self.kwargs.get("provider")
  151. return SOCIAL_ADAPTER_MAP[provider]
  152. class GlitchTipSocialConnectView(GenericMFAMixin, SocialConnectView):
  153. pass
  154. class MFASocialLoginView(GenericMFAMixin, SocialLoginView):
  155. serializer_class = SocialLoginSerializer
  156. def process_login(self):
  157. tags = self.serializer.validated_data.get("tags")
  158. if tags and self.user.analytics is None:
  159. self.user.set_register_analytics_tags(tags)
  160. self.user.save(update_fields=["analytics"])
  161. if not getattr(self.user, "mfa", False):
  162. super().process_login()
  163. def get_response(self):
  164. if getattr(self.user, "mfa", False):
  165. user_key_types = (
  166. self.user.userkey_set.all()
  167. .values_list("key_type", flat=True)
  168. .distinct()
  169. )
  170. return Response({"requires_mfa": True, "valid_auth": user_key_types})
  171. return super().get_response()