# models.py (6.0 KB)
  1. from __future__ import annotations
  2. import re
  3. from typing import Any
  4. from django.apps import apps
  5. from django.conf import settings
  6. from django.db import models
  7. from sentry.db.models import control_silo_model
  8. from .fields import JSONField
  9. from .utils import setting
  10. AUTH_USER_MODEL = settings.AUTH_USER_MODEL
  11. UID_LENGTH = setting("SOCIAL_AUTH_UID_LENGTH", 255)
  12. NONCE_SERVER_URL_LENGTH = setting("SOCIAL_AUTH_NONCE_SERVER_URL_LENGTH", 255)
  13. ASSOCIATION_SERVER_URL_LENGTH = setting("SOCIAL_AUTH_ASSOCIATION_SERVER_URL_LENGTH", 255)
  14. ASSOCIATION_HANDLE_LENGTH = setting("SOCIAL_AUTH_ASSOCIATION_HANDLE_LENGTH", 255)
  15. CLEAN_USERNAME_REGEX = re.compile(r"[^\w.@+-_]+", re.UNICODE)
  16. @control_silo_model
  17. class UserSocialAuth(models.Model):
  18. """Social Auth association model"""
  19. id = models.AutoField(primary_key=True)
  20. user = models.ForeignKey(AUTH_USER_MODEL, related_name="social_auth", on_delete=models.CASCADE)
  21. provider = models.CharField(max_length=32)
  22. uid = models.CharField(max_length=UID_LENGTH)
  23. extra_data: models.Field[dict[str, Any] | None, dict[str, Any]] = JSONField(default="{}")
  24. class Meta:
  25. """Meta data"""
  26. unique_together = ("provider", "uid", "user")
  27. app_label = "social_auth"
  28. def __str__(self):
  29. """Return associated user unicode representation"""
  30. return f"{self.user} - {self.provider.title()}"
  31. def get_backend(self):
  32. from .utils import get_backend
  33. return get_backend(instance=self)
  34. @property
  35. def tokens(self):
  36. from .utils import tokens
  37. return tokens(instance=self)
  38. def revoke_token(self, drop_token=True):
  39. """Attempts to revoke permissions for provider."""
  40. if "access_token" in self.tokens:
  41. success = self.get_backend().revoke_token(self.tokens["access_token"], self.uid)
  42. if success and drop_token:
  43. self.extra_data.pop("access_token", None)
  44. self.save()
  45. def refresh_token(self):
  46. refresh_token = self.extra_data.get("refresh_token")
  47. if refresh_token:
  48. backend = self.get_backend()
  49. if hasattr(backend, "refresh_token"):
  50. response = backend.refresh_token(refresh_token, self.provider)
  51. new_access_token = response.get("access_token")
  52. # We have not got a new access token, so don't lose the
  53. # existing one.
  54. if not new_access_token:
  55. return
  56. self.extra_data["access_token"] = new_access_token
  57. # New refresh token might be given.
  58. new_refresh_token = response.get("refresh_token")
  59. if new_refresh_token:
  60. self.extra_data["refresh_token"] = new_refresh_token
  61. self.save()
  62. @classmethod
  63. def clean_username(cls, value):
  64. return CLEAN_USERNAME_REGEX.sub("", value)
  65. @classmethod
  66. def user_username(cls, user):
  67. if hasattr(user, "USERNAME_FIELD"):
  68. # Django 1.5 custom user model, 'username' is just for internal
  69. # use, doesn't imply that the model should have an username field
  70. field_name = user.USERNAME_FIELD
  71. else:
  72. field_name = "username"
  73. return getattr(user, field_name)
  74. @classmethod
  75. def username_field(cls, values):
  76. user_model = cls.user_model()
  77. if hasattr(user_model, "USERNAME_FIELD"):
  78. # Django 1.5 custom user model, 'username' is just for internal
  79. # use, doesn't imply that the model should have an username field
  80. values[user_model.USERNAME_FIELD] = values.pop("username")
  81. return values
  82. @classmethod
  83. def simple_user_exists(cls, *args, **kwargs):
  84. """
  85. Return True/False if a User instance exists with the given arguments.
  86. Arguments are directly passed to filter() manager method.
  87. TODO: consider how to ensure case-insensitive email matching
  88. """
  89. kwargs = cls.username_field(kwargs)
  90. return cls.user_model().objects.filter(*args, **kwargs).exists()
  91. @classmethod
  92. def create_user(cls, *args, **kwargs):
  93. kwargs = cls.username_field(kwargs)
  94. return cls.user_model().objects.create_user(*args, **kwargs)
  95. @classmethod
  96. def get_user(cls, pk):
  97. try:
  98. return cls.user_model().objects.get(pk=pk)
  99. except cls.user_model().DoesNotExist:
  100. return None
  101. @classmethod
  102. def get_user_by_email(cls, email):
  103. """Case insensitive search"""
  104. # Do case-insensitive match, since real-world email address is
  105. # case-insensitive.
  106. return cls.user_model().objects.get(email__iexact=email)
  107. @classmethod
  108. def resolve_user_or_id(cls, user_or_id):
  109. if isinstance(user_or_id, cls.user_model()):
  110. return user_or_id
  111. return cls.user_model().objects.get(pk=user_or_id)
  112. @classmethod
  113. def get_social_auth_for_user(cls, user):
  114. return user.social_auth.all()
  115. @classmethod
  116. def create_social_auth(cls, user, uid, provider):
  117. if not isinstance(uid, str):
  118. uid = str(uid)
  119. return cls.objects.create(user_id=user.id, uid=uid, provider=provider)
  120. @classmethod
  121. def get_social_auth(cls, provider, uid, user):
  122. try:
  123. instance = cls.objects.get(provider=provider, uid=uid, user_id=user.id)
  124. return instance
  125. except UserSocialAuth.DoesNotExist:
  126. return None
  127. @classmethod
  128. def username_max_length(cls):
  129. return cls._field_length("USERNAME_FIELD", "username")
  130. @classmethod
  131. def email_max_length(cls):
  132. return cls._field_length("EMAIL_FIELD", "email")
  133. @classmethod
  134. def _field_length(self, setting_name, default_name):
  135. model = UserSocialAuth.user_model()
  136. field_name = getattr(model, setting_name, default_name)
  137. return model._meta.get_field(field_name).max_length
  138. @classmethod
  139. def user_model(cls):
  140. db, name = AUTH_USER_MODEL.split(".")
  141. return apps.get_model(db, name)