social.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from allauth.account import app_settings as allauth_account_settings
  2. from allauth.account.adapter import DefaultAccountAdapter
  3. from allauth.account.auth_backends import AuthenticationBackend
  4. from allauth.socialaccount.adapter import DefaultSocialAccountAdapter, get_adapter
  5. from allauth.socialaccount.helpers import complete_social_login
  6. from allauth.socialaccount.providers.oauth2.client import OAuth2Client, OAuth2Error
  7. from dj_rest_auth.registration.serializers import (
  8. SocialLoginSerializer as BaseSocialLoginSerializer,
  9. )
  10. from dj_rest_auth.registration.views import SocialConnectView, SocialLoginView
  11. from django.conf import settings
  12. from django.contrib.auth import get_backends, get_user_model
  13. from django.http import HttpResponseBadRequest
  14. from django.utils.translation import gettext_lazy as _
  15. from django_rest_mfa.helpers import has_mfa
  16. from requests.exceptions import HTTPError
  17. from rest_framework import serializers
  18. from rest_framework.response import Response
  19. from users.utils import is_user_registration_open
  20. from .constants import SOCIAL_ADAPTER_MAP
  21. DOMAIN = settings.GLITCHTIP_URL.geturl()
  22. class MFAAccountAdapter(DefaultAccountAdapter):
  23. """
  24. If user requires MFA, do not actually log in
  25. """
  26. def login(self, request, user):
  27. """Extend to check for MFA status, backend hack is copied from super method"""
  28. if not hasattr(user, "backend"):
  29. backends = get_backends()
  30. backend = None
  31. for b in backends: # pylint: disable=invalid-name
  32. if isinstance(b, AuthenticationBackend):
  33. # prefer our own backend
  34. backend = b
  35. break
  36. elif not backend and hasattr(b, "get_user"):
  37. # Pick the first valid one
  38. backend = b
  39. backend_path = ".".join([backend.__module__, backend.__class__.__name__])
  40. user.backend = backend_path
  41. if has_mfa(request, user):
  42. user.mfa = True # Store for later, to avoid multiple DB checks
  43. else:
  44. super().login(request, user)
  45. def get_login_redirect_url(self, request):
  46. """Ignore login redirect when not logged in"""
  47. try:
  48. return super().get_login_redirect_url(request)
  49. except AssertionError:
  50. pass
  51. class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
  52. def is_open_for_signup(self, request, sociallogin):
  53. return is_user_registration_open()
  54. class SocialLoginSerializer(BaseSocialLoginSerializer):
  55. tags = serializers.CharField(
  56. allow_blank=True, required=False, allow_null=True, write_only=True
  57. )
  58. # Overriding to add check for django-allauth's is_open_for_signup() at end, to prevent
  59. # creation of new user on first-time social auth login
  60. # https://github.com/iMerica/dj-rest-auth/blob/master/dj_rest_auth/registration/serializers.py#L79
  61. def validate(self, attrs):
  62. view = self.context.get("view")
  63. request = self._get_request()
  64. if not view:
  65. raise serializers.ValidationError(
  66. _("View is not defined, pass it as a context variable"),
  67. )
  68. adapter_class = getattr(view, "adapter_class", None)
  69. if not adapter_class:
  70. raise serializers.ValidationError(_("Define adapter_class in view"))
  71. if adapter_class.provider_id == "openid_connect":
  72. adapter = adapter_class(request, adapter_class.provider_id)
  73. else:
  74. adapter = adapter_class(request)
  75. app = adapter.get_provider().app
  76. access_token = attrs.get("access_token")
  77. code = attrs.get("code")
  78. if access_token:
  79. tokens_to_parse = {"access_token": access_token}
  80. token = access_token
  81. id_token = attrs.get("id_token")
  82. if id_token:
  83. tokens_to_parse["id_token"] = id_token
  84. elif code:
  85. self.set_callback_url(view=view, adapter_class=adapter_class)
  86. self.client_class = getattr(view, "client_class", None)
  87. if not self.client_class:
  88. raise serializers.ValidationError(
  89. _("Define client_class in view"),
  90. )
  91. provider = adapter.get_provider()
  92. scope = provider.get_scope(request)
  93. client = self.client_class(
  94. request,
  95. app.client_id,
  96. app.secret,
  97. adapter.access_token_method,
  98. adapter.access_token_url,
  99. self.callback_url,
  100. scope,
  101. scope_delimiter=adapter.scope_delimiter,
  102. headers=adapter.headers,
  103. basic_auth=adapter.basic_auth,
  104. )
  105. try:
  106. token = client.get_access_token(code)
  107. except OAuth2Error as ex:
  108. raise serializers.ValidationError(
  109. _("Failed to exchange code for access token")
  110. ) from ex
  111. access_token = token["access_token"]
  112. tokens_to_parse = {"access_token": access_token}
  113. for key in ["refresh_token", "id_token", adapter.expires_in_key]:
  114. if key in token:
  115. tokens_to_parse[key] = token[key]
  116. else:
  117. raise serializers.ValidationError(
  118. _("Incorrect input. access_token or code is required."),
  119. )
  120. social_token = adapter.parse_token(tokens_to_parse)
  121. social_token.app = app
  122. try:
  123. if adapter.provider_id == "google" and not code:
  124. login = self.get_social_login(
  125. adapter, app, social_token, response={"id_token": id_token}
  126. )
  127. else:
  128. login = self.get_social_login(adapter, app, social_token, token)
  129. ret = complete_social_login(request, login)
  130. except HTTPError:
  131. raise serializers.ValidationError(_("Incorrect value"))
  132. if isinstance(ret, HttpResponseBadRequest):
  133. raise serializers.ValidationError(ret.content)
  134. if not login.is_existing:
  135. if allauth_account_settings.UNIQUE_EMAIL:
  136. account_exists = (
  137. get_user_model()
  138. .objects.filter(
  139. email=login.user.email,
  140. )
  141. .exists()
  142. )
  143. if account_exists:
  144. raise serializers.ValidationError(
  145. _("User is already registered with this e-mail address."),
  146. )
  147. # Added check for open signup
  148. if not get_adapter(request).is_open_for_signup(request, login):
  149. raise serializers.ValidationError(_("User registration is closed."))
  150. else:
  151. login.lookup()
  152. login.save(request, connect=True)
  153. self.post_signup(login, attrs)
  154. attrs["user"] = login.account.user
  155. return attrs
  156. class GenericMFAMixin:
  157. client_class = OAuth2Client # Only OAuth2 client is supported
  158. @property
  159. def callback_url(self):
  160. # Set dynamic OIDC provider ID
  161. provider_id = self.kwargs.get("provider", self.adapter_class.provider_id)
  162. return DOMAIN + "/auth/" + provider_id
  163. @property
  164. def adapter_class(self):
  165. provider = self.kwargs.get("provider")
  166. adapter_class = SOCIAL_ADAPTER_MAP.get(
  167. provider, SOCIAL_ADAPTER_MAP["openid_connect"]
  168. )
  169. # Set dynamic OIDC provider ID
  170. adapter_class.provider_id = provider
  171. return adapter_class
  172. class GlitchTipSocialConnectView(GenericMFAMixin, SocialConnectView):
  173. pass
  174. class MFASocialLoginView(GenericMFAMixin, SocialLoginView):
  175. serializer_class = SocialLoginSerializer
  176. def process_login(self):
  177. tags = self.serializer.validated_data.get("tags")
  178. if tags and self.user.analytics is None:
  179. self.user.set_register_analytics_tags(tags)
  180. self.user.save(update_fields=["analytics"])
  181. if not getattr(self.user, "mfa", False):
  182. super().process_login()
  183. def get_response(self):
  184. if getattr(self.user, "mfa", False):
  185. user_key_types = (
  186. self.user.userkey_set.all()
  187. .values_list("key_type", flat=True)
  188. .distinct()
  189. )
  190. return Response({"requires_mfa": True, "valid_auth": user_key_types})
  191. return super().get_response()