Просмотр исходного кода

Start of SaaS sign up flow with stripe

David Burke 4 лет назад
Родитель
Сommit
7b23ef1811

+ 35 - 1
djstripe_ext/serializers.py

@@ -1,8 +1,11 @@
+from django.core.exceptions import SuspiciousOperation
+from rest_framework import serializers
 from rest_framework.serializers import ModelSerializer
-from djstripe.models import Plan
+from djstripe.models import Plan, Customer
 from djstripe.contrib.rest_framework.serializers import (
     SubscriptionSerializer as BaseSubscriptionSerializer,
 )
+from organizations_ext.models import OrganizationUserRole
 
 
 class PlanSerializer(ModelSerializer):
@@ -13,3 +16,34 @@ class PlanSerializer(ModelSerializer):
 
 class SubscriptionSerializer(BaseSubscriptionSerializer):
     plan = PlanSerializer(read_only=True)
+
+
+class OrganizationPrimaryKeySerializer(serializers.PrimaryKeyRelatedField):
+    def get_queryset(self):
+        user = self.context["request"].user
+        return user.organizations_ext_organization.filter(
+            organization_users__role=OrganizationUserRole.OWNER
+        )
+
+
+class CreateSubscriptionSerializer(serializers.Serializer):
+    """A serializer used to create a Subscription. Only works with free plans. """
+
+    plan = serializers.SlugRelatedField(queryset=Plan.objects.all(), slug_field="id")
+    organization = OrganizationPrimaryKeySerializer()
+    subscription = SubscriptionSerializer(read_only=True)
+
+    def create(self, data):
+        organization = data["organization"]
+        plan = data["plan"]
+        if plan.amount != 0.0:
+            raise SuspiciousOperation(
+                "Cannot subscribe to non-free plan without payment"
+            )
+        customer, _ = Customer.get_or_create(subscriber=organization)
+        subscription = customer.subscribe(plan)
+        return {
+            "plan": plan,
+            "organization": organization,
+            "subscription": subscription,
+        }

+ 55 - 0
djstripe_ext/tests.py

@@ -1,3 +1,4 @@
+from unittest.mock import patch
 from django.shortcuts import reverse
 from django.utils import timezone
 from rest_framework.test import APITestCase
@@ -54,3 +55,57 @@ class SubscriptionAPITestCase(APITestCase):
         url = reverse("subscription-detail", args=[self.organization.slug])
         res = self.client.get(url)
         self.assertContains(res, subscription.id)
+
+    @patch("djstripe.models.Customer.subscribe")
+    def test_create(self, djstripe_customer_subscribe_mock):
+        customer = baker.make("djstripe.Customer", subscriber=self.organization)
+        plan = baker.make("djstripe.Plan", amount=0)
+        subscription = baker.make(
+            "djstripe.Subscription", customer=customer, livemode=False,
+        )
+        djstripe_customer_subscribe_mock.return_value = subscription
+        data = {"plan": plan.id, "organization": self.organization.id}
+        res = self.client.post(self.url, data)
+        self.assertEqual(res.data["plan"], plan.id)
+
+    def test_create_invalid_org(self):
+        """ Only owners may create subscriptions """
+        user = baker.make("users.user")  # Non owner member
+        plan = baker.make("djstripe.Plan", amount=0)
+        self.organization.add_user(user)
+        self.client.force_login(user)
+        data = {"plan": plan.id, "organization": self.organization.id}
+        res = self.client.post(self.url, data)
+        self.assertEqual(res.status_code, 400)
+
+
+class SubscriptionIntegrationTestCase(APITestCase):
+    def setUp(self):
+        self.user = baker.make("users.user")
+        self.organization = baker.make("organizations_ext.Organization")
+        self.organization.add_user(self.user)
+        # Make these in this manner to avoid syncing data to stripe actual
+        self.plan = baker.make("djstripe.Plan", active=True, amount=0)
+        self.customer = baker.make("djstripe.Customer", subscriber=self.organization)
+        self.client.force_login(self.user)
+        self.list_url = reverse("subscription-list")
+        self.detail_url = reverse("subscription-detail", args=[self.organization.slug])
+
+    @patch("djstripe.models.Customer.subscribe")
+    def test_new_org_flow(self, djstripe_customer_subscribe_mock):
+        """ Test checking if subscription exists and when not, creating a free tier one """
+        res = self.client.get(self.detail_url)
+        self.assertFalse(res.data["id"])  # No subscription, user should create one
+
+        subscription = baker.make(
+            "djstripe.Subscription", customer=self.customer, livemode=False,
+        )
+        djstripe_customer_subscribe_mock.return_value = subscription
+
+        data = {"plan": self.plan.id, "organization": self.organization.id}
+        res = self.client.post(self.list_url, data)
+        self.assertContains(res, self.plan.id, status_code=201)
+        djstripe_customer_subscribe_mock.assert_called_once()
+
+        res = self.client.get(self.detail_url)
+        self.assertEqual(res.data["id"], subscription.id)

+ 7 - 2
djstripe_ext/views.py

@@ -2,10 +2,10 @@ from django.conf import settings
 from django.http import Http404
 from rest_framework import viewsets
 from djstripe.models import Subscription
-from .serializers import SubscriptionSerializer
+from .serializers import SubscriptionSerializer, CreateSubscriptionSerializer
 
 
-class SubscriptionViewSet(viewsets.ReadOnlyModelViewSet):
+class SubscriptionViewSet(viewsets.ModelViewSet):
     """
     View subscription status
 
@@ -16,6 +16,11 @@ class SubscriptionViewSet(viewsets.ReadOnlyModelViewSet):
     serializer_class = SubscriptionSerializer
     lookup_field = "customer__subscriber__slug"
 
+    def get_serializer_class(self):
+        if self.action == "create":
+            return CreateSubscriptionSerializer
+        return super().get_serializer_class()
+
     def get_queryset(self):
         """ Any user in an org may view subscription data """
         return self.queryset.filter(

+ 6 - 0
glitchtip/test_utils/generators.py

@@ -1,4 +1,10 @@
 from model_bakery import baker
 from model_bakery.random_gen import gen_slug
 
+
+def currency_code():
+    return "USD"
+
+
 baker.generators.add("organizations.fields.SlugField", gen_slug)
+baker.generators.add("djstripe.fields.StripeCurrencyCodeField", currency_code)

+ 17 - 0
organizations_ext/migrations/0005_remove_organization_throttling_cycle_anchor.py

@@ -0,0 +1,17 @@
+# Generated by Django 3.0.6 on 2020-05-23 14:48
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('organizations_ext', '0004_organization_throttling_cycle_anchor'),
+    ]
+
+    operations = [
+        migrations.RemoveField(
+            model_name='organization',
+            name='throttling_cycle_anchor',
+        ),
+    ]

+ 0 - 5
organizations_ext/models.py

@@ -1,5 +1,4 @@
 from django.db import models
-from django.utils import timezone
 from django.utils.translation import ugettext_lazy as _
 from organizations.base import (
     OrganizationBase,
@@ -30,10 +29,6 @@ class Organization(SharedBaseModel, OrganizationBase):
     is_accepting_events = models.BooleanField(
         default=True, help_text="Used for throttling at org level"
     )
-    throttling_cycle_anchor = models.DateTimeField(
-        default=timezone.now,
-        help_text="Useful for organization level throttling or free tier usage plans",
-    )
 
     def add_user(self, user, role=OrganizationUserRole.MEMBER):
         """