views.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from allauth.account.models import EmailAddress
  2. from dj_rest_auth.registration.views import (
  3. SocialAccountDisconnectView as BaseSocialAccountDisconnectView,
  4. )
  5. from django.core.exceptions import ObjectDoesNotExist, ValidationError
  6. from django.http import Http404
  7. from django.shortcuts import get_object_or_404
  8. from rest_framework import exceptions, mixins, status, viewsets
  9. from rest_framework.decorators import action
  10. from rest_framework.response import Response
  11. from apps.projects.models import UserProjectAlert
  12. from .models import User
  13. from .serializers import (
  14. ConfirmEmailAddressSerializer,
  15. CurrentUserSerializer,
  16. EmailAddressSerializer,
  17. UserNotificationsSerializer,
  18. UserSerializer,
  19. )
  20. class UserViewSet(viewsets.ReadOnlyModelViewSet):
  21. queryset = User.objects.all()
  22. serializer_class = UserSerializer
  23. def get_queryset(self):
  24. queryset = super().get_queryset()
  25. organization_slug = self.kwargs.get("organization_slug")
  26. if organization_slug:
  27. queryset = queryset.filter(
  28. organizations_ext_organization__slug=organization_slug,
  29. organizations_ext_organization__users=self.request.user,
  30. )
  31. else:
  32. queryset = queryset.filter(id=self.request.user.id)
  33. return queryset
  34. def get_object(self):
  35. if self.kwargs.get("pk") == "me":
  36. return self.request.user
  37. return super().get_object()
  38. def get_serializer_class(self):
  39. if self.kwargs.get("pk") == "me":
  40. return CurrentUserSerializer
  41. return super().get_serializer_class()
  42. @action(detail=True, methods=["get", "post", "put"])
  43. def notifications(self, request, pk=None):
  44. user = self.get_object()
  45. if request.method == "GET":
  46. serializer = UserNotificationsSerializer(user)
  47. return Response(serializer.data)
  48. serializer = UserNotificationsSerializer(user, data=request.data)
  49. if serializer.is_valid():
  50. serializer.save()
  51. return Response(serializer.data)
  52. else:
  53. return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
  54. @action(
  55. detail=True, methods=["get", "post", "put"], url_path="notifications/alerts"
  56. )
  57. def alerts(self, request, pk=None):
  58. """
  59. Returns dictionary of project_id: status. Now project_id status means it's "default"
  60. To update, submit `{project_id: status}` where status is -1 (default), 0, or 1
  61. """
  62. user = self.get_object()
  63. alerts = user.userprojectalert_set.all()
  64. if request.method == "GET":
  65. data = {}
  66. for alert in alerts:
  67. data[alert.project_id] = alert.status
  68. return Response(data)
  69. data = request.data
  70. try:
  71. items = [x for x in data.items()]
  72. except AttributeError as err:
  73. raise exceptions.ValidationError(
  74. "Invalid alert format, expected dictionary"
  75. ) from err
  76. if len(data) != 1:
  77. raise exceptions.ValidationError("Invalid alert format, expected one value")
  78. project_id, alert_status = items[0]
  79. if alert_status not in [1, 0, -1]:
  80. raise exceptions.ValidationError("Invalid status, must be -1, 0, or 1")
  81. alert = alerts.filter(project_id=project_id).first()
  82. if alert and alert_status == -1:
  83. alert.delete()
  84. else:
  85. UserProjectAlert.objects.update_or_create(
  86. user=user, project_id=project_id, defaults={"status": alert_status}
  87. )
  88. return Response(status=204)
  89. class EmailAddressViewSet(
  90. mixins.CreateModelMixin,
  91. mixins.ListModelMixin,
  92. viewsets.GenericViewSet,
  93. ):
  94. queryset = EmailAddress.objects.all()
  95. serializer_class = EmailAddressSerializer
  96. pagination_class = None
  97. def get_user(self, user_pk):
  98. if user_pk == "me":
  99. return self.request.user
  100. raise exceptions.ValidationError(
  101. "Can only change primary email address on own account"
  102. )
  103. def get_queryset(self):
  104. user = self.get_user(self.kwargs.get("user_pk"))
  105. queryset = super().get_queryset().filter(user=user)
  106. return queryset
  107. def put(self, request, user_pk, format=None):
  108. """
  109. Set a new primary email (must be verified) this will also set the email used when a user logs in.
  110. """
  111. user = self.get_user(user_pk)
  112. try:
  113. email_address = user.emailaddress_set.get(
  114. email=request.data.get("email"), verified=True
  115. )
  116. email_address.set_as_primary()
  117. except ObjectDoesNotExist as err:
  118. raise Http404 from err
  119. serializer = self.serializer_class(
  120. instance=email_address, data=request.data, context={"request": request}
  121. )
  122. if serializer.is_valid():
  123. serializer.save()
  124. return Response(serializer.data)
  125. return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
  126. def delete(self, request, user_pk, format=None):
  127. user = self.get_user(user_pk)
  128. try:
  129. email_address = user.emailaddress_set.get(
  130. email=request.data.get("email"), primary=False
  131. )
  132. except ObjectDoesNotExist as err:
  133. raise Http404 from err
  134. email_address.delete()
  135. return Response(status=status.HTTP_204_NO_CONTENT)
  136. @action(detail=False, methods=["post"])
  137. def confirm(self, request, user_pk):
  138. serializer = ConfirmEmailAddressSerializer(data=request.data)
  139. serializer.is_valid(raise_exception=True)
  140. email_address = get_object_or_404(
  141. self.get_queryset(), email=serializer.validated_data.get("email")
  142. )
  143. email_address.send_confirmation(request)
  144. return Response(status=204)
  145. class SocialAccountDisconnectView(BaseSocialAccountDisconnectView):
  146. def post(self, request, *args, **kwargs):
  147. try:
  148. return super().post(request, *args, **kwargs)
  149. except ValidationError as e:
  150. raise exceptions.ValidationError(e.message)