123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337 |
- package server
- import (
- "encoding/json"
- "github.com/stretchr/testify/mock"
- "github.com/stretchr/testify/require"
- "github.com/stripe/stripe-go/v74"
- "heckel.io/ntfy/user"
- "heckel.io/ntfy/util"
- "io"
- "path/filepath"
- "strings"
- "testing"
- "time"
- )
- func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
- stripeMock := &testStripeAPI{}
- defer stripeMock.AssertExpectations(t)
- c := newTestConfigWithAuthFile(t)
- c.StripeSecretKey = "secret key"
- c.StripeWebhookKey = "webhook key"
- s := newTestServer(t, c)
- s.stripe = stripeMock
- // Define how the mock should react
- stripeMock.
- On("NewCheckoutSession", mock.Anything).
- Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
- // Create tier and user
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- Code: "pro",
- StripePriceID: "price_123",
- }))
- require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
- // Create subscription
- response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
- "Authorization": util.BasicAuth("phil", "phil"),
- })
- require.Equal(t, 200, response.Code)
- redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
- require.Nil(t, err)
- require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
- }
- func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
- stripeMock := &testStripeAPI{}
- defer stripeMock.AssertExpectations(t)
- c := newTestConfigWithAuthFile(t)
- c.StripeSecretKey = "secret key"
- c.StripeWebhookKey = "webhook key"
- s := newTestServer(t, c)
- s.stripe = stripeMock
- // Define how the mock should react
- stripeMock.
- On("GetCustomer", "acct_123").
- Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
- stripeMock.
- On("NewCheckoutSession", mock.Anything).
- Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
- // Create tier and user
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- Code: "pro",
- StripePriceID: "price_123",
- }))
- require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
- u, err := s.userManager.User("phil")
- require.Nil(t, err)
- billing := &user.Billing{
- StripeCustomerID: "acct_123",
- }
- require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
- // Create subscription
- response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
- "Authorization": util.BasicAuth("phil", "phil"),
- })
- require.Equal(t, 200, response.Code)
- redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
- require.Nil(t, err)
- require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
- }
- func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
- stripeMock := &testStripeAPI{}
- defer stripeMock.AssertExpectations(t)
- c := newTestConfigWithAuthFile(t)
- c.EnableSignup = true
- c.StripeSecretKey = "secret key"
- c.StripeWebhookKey = "webhook key"
- s := newTestServer(t, c)
- s.stripe = stripeMock
- // Define how the mock should react
- stripeMock.
- On("CancelSubscription", "sub_123").
- Return(&stripe.Subscription{}, nil)
- // Create tier and user
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- Code: "pro",
- StripePriceID: "price_123",
- }))
- require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
- u, err := s.userManager.User("phil")
- require.Nil(t, err)
- billing := &user.Billing{
- StripeCustomerID: "acct_123",
- StripeSubscriptionID: "sub_123",
- }
- require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
- // Delete account
- rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{
- "Authorization": util.BasicAuth("phil", "phil"),
- })
- require.Equal(t, 200, rr.Code)
- rr = request(t, s, "GET", "/v1/account", "", map[string]string{
- "Authorization": util.BasicAuth("phil", "mypass"),
- })
- require.Equal(t, 401, rr.Code)
- }
- func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
- // This tests incoming webhooks from Stripe to update a subscription:
- // - All Stripe columns are updated in the user table
- // - When downgrading, excess reservations are deleted, including messages and attachments in
- // the corresponding topics
- stripeMock := &testStripeAPI{}
- defer stripeMock.AssertExpectations(t)
- c := newTestConfigWithAuthFile(t)
- c.StripeSecretKey = "secret key"
- c.StripeWebhookKey = "webhook key"
- s := newTestServer(t, c)
- s.stripe = stripeMock
- // Define how the mock should react
- stripeMock.
- On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
- Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
- // Create a user with a Stripe subscription and 3 reservations
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- Code: "starter",
- StripePriceID: "price_1234", // !
- ReservationsLimit: 1, // !
- MessagesLimit: 100,
- MessagesExpiryDuration: time.Hour,
- AttachmentExpiryDuration: time.Hour,
- AttachmentFileSizeLimit: 1000000,
- AttachmentTotalSizeLimit: 1000000,
- }))
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- Code: "pro",
- StripePriceID: "price_1111", // !
- ReservationsLimit: 3, // !
- MessagesLimit: 200,
- MessagesExpiryDuration: time.Hour,
- AttachmentExpiryDuration: time.Hour,
- AttachmentFileSizeLimit: 1000000,
- AttachmentTotalSizeLimit: 1000000,
- }))
- require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
- require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
- require.Nil(t, s.userManager.ReserveAccess("phil", "atopic", user.PermissionDenyAll))
- require.Nil(t, s.userManager.ReserveAccess("phil", "ztopic", user.PermissionDenyAll))
- // Add billing details
- u, err := s.userManager.User("phil")
- require.Nil(t, err)
- billing := &user.Billing{
- StripeCustomerID: "acct_5555",
- StripeSubscriptionID: "sub_1234",
- StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
- StripeSubscriptionPaidUntil: time.Unix(123, 0),
- StripeSubscriptionCancelAt: time.Unix(456, 0),
- }
- require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
- // Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted
- rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
- "Authorization": util.BasicAuth("phil", "phil"),
- })
- require.Equal(t, 200, rr.Code)
- rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
- "Authorization": util.BasicAuth("phil", "phil"),
- })
- require.Equal(t, 200, rr.Code)
- a2 := toMessage(t, rr.Body.String())
- require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
- rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
- "Authorization": util.BasicAuth("phil", "phil"),
- })
- require.Equal(t, 200, rr.Code)
- rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
- "Authorization": util.BasicAuth("phil", "phil"),
- })
- require.Equal(t, 200, rr.Code)
- z2 := toMessage(t, rr.Body.String())
- require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
- // Call the webhook: This does all the magic
- rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
- "Stripe-Signature": "stripe signature",
- })
- require.Equal(t, 200, rr.Code)
- // Verify that database columns were updated
- u, err = s.userManager.User("phil")
- require.Nil(t, err)
- require.Equal(t, "starter", u.Tier.Code) // Not "pro"
- require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
- require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
- require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
- require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
- require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
- // Verify that reservations were deleted
- r, err := s.userManager.Reservations("phil")
- require.Nil(t, err)
- require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
- require.Equal(t, "atopic", r[0].Topic)
- // Verify that messages and attachments were deleted
- time.Sleep(time.Second)
- s.execManager()
- ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
- require.Nil(t, err)
- require.Equal(t, 2, len(ms))
- require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
- ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
- require.Nil(t, err)
- require.Equal(t, 0, len(ms))
- require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
- }
- type testStripeAPI struct {
- mock.Mock
- }
- func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
- args := s.Called(params)
- return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
- }
- func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
- args := s.Called(params)
- return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
- }
- func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
- args := s.Called(params)
- return args.Get(0).([]*stripe.Price), args.Error(1)
- }
- func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.Customer), args.Error(1)
- }
- func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
- }
- func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.Subscription), args.Error(1)
- }
- func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.Subscription), args.Error(1)
- }
- func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.Subscription), args.Error(1)
- }
- func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
- args := s.Called(payload, header, secret)
- return args.Get(0).(stripe.Event), args.Error(1)
- }
- var _ stripeAPI = (*testStripeAPI)(nil)
- func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
- var e stripe.Event
- if err := json.Unmarshal([]byte(v), &e); err != nil {
- t.Fatal(err)
- }
- return e
- }
// subscriptionUpdatedEventJSON is a canned Stripe "customer.subscription.updated" event. It is
// fed to the webhook handler through the ConstructWebhookEvent mock. It describes:
//   - subscription "sub_1234" of customer "acct_5555", now "active";
//   - current_period_end 1674268231 and cancel_at 1674299999 (Unix seconds);
//   - a single item billed against price "price_1234".
const subscriptionUpdatedEventJSON = `
{
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234"
}
}
]
}
}
}
}`
|