server_payments_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. package server
  2. import (
  3. "encoding/json"
  4. "github.com/stretchr/testify/mock"
  5. "github.com/stretchr/testify/require"
  6. "github.com/stripe/stripe-go/v74"
  7. "heckel.io/ntfy/user"
  8. "heckel.io/ntfy/util"
  9. "io"
  10. "path/filepath"
  11. "strings"
  12. "testing"
  13. "time"
  14. )
  15. func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
  16. stripeMock := &testStripeAPI{}
  17. defer stripeMock.AssertExpectations(t)
  18. c := newTestConfigWithAuthFile(t)
  19. c.StripeSecretKey = "secret key"
  20. c.StripeWebhookKey = "webhook key"
  21. s := newTestServer(t, c)
  22. s.stripe = stripeMock
  23. // Define how the mock should react
  24. stripeMock.
  25. On("NewCheckoutSession", mock.Anything).
  26. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  27. // Create tier and user
  28. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  29. Code: "pro",
  30. StripePriceID: "price_123",
  31. }))
  32. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
  33. // Create subscription
  34. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  35. "Authorization": util.BasicAuth("phil", "phil"),
  36. })
  37. require.Equal(t, 200, response.Code)
  38. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  39. require.Nil(t, err)
  40. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  41. }
  42. func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
  43. stripeMock := &testStripeAPI{}
  44. defer stripeMock.AssertExpectations(t)
  45. c := newTestConfigWithAuthFile(t)
  46. c.StripeSecretKey = "secret key"
  47. c.StripeWebhookKey = "webhook key"
  48. s := newTestServer(t, c)
  49. s.stripe = stripeMock
  50. // Define how the mock should react
  51. stripeMock.
  52. On("GetCustomer", "acct_123").
  53. Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
  54. stripeMock.
  55. On("NewCheckoutSession", mock.Anything).
  56. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  57. // Create tier and user
  58. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  59. Code: "pro",
  60. StripePriceID: "price_123",
  61. }))
  62. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
  63. u, err := s.userManager.User("phil")
  64. require.Nil(t, err)
  65. billing := &user.Billing{
  66. StripeCustomerID: "acct_123",
  67. }
  68. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  69. // Create subscription
  70. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  71. "Authorization": util.BasicAuth("phil", "phil"),
  72. })
  73. require.Equal(t, 200, response.Code)
  74. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  75. require.Nil(t, err)
  76. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  77. }
  78. func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
  79. stripeMock := &testStripeAPI{}
  80. defer stripeMock.AssertExpectations(t)
  81. c := newTestConfigWithAuthFile(t)
  82. c.EnableSignup = true
  83. c.StripeSecretKey = "secret key"
  84. c.StripeWebhookKey = "webhook key"
  85. s := newTestServer(t, c)
  86. s.stripe = stripeMock
  87. // Define how the mock should react
  88. stripeMock.
  89. On("CancelSubscription", "sub_123").
  90. Return(&stripe.Subscription{}, nil)
  91. // Create tier and user
  92. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  93. Code: "pro",
  94. StripePriceID: "price_123",
  95. }))
  96. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
  97. u, err := s.userManager.User("phil")
  98. require.Nil(t, err)
  99. billing := &user.Billing{
  100. StripeCustomerID: "acct_123",
  101. StripeSubscriptionID: "sub_123",
  102. }
  103. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  104. // Delete account
  105. rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{
  106. "Authorization": util.BasicAuth("phil", "phil"),
  107. })
  108. require.Equal(t, 200, rr.Code)
  109. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  110. "Authorization": util.BasicAuth("phil", "mypass"),
  111. })
  112. require.Equal(t, 401, rr.Code)
  113. }
  114. func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
  115. // This tests incoming webhooks from Stripe to update a subscription:
  116. // - All Stripe columns are updated in the user table
  117. // - When downgrading, excess reservations are deleted, including messages and attachments in
  118. // the corresponding topics
  119. stripeMock := &testStripeAPI{}
  120. defer stripeMock.AssertExpectations(t)
  121. c := newTestConfigWithAuthFile(t)
  122. c.StripeSecretKey = "secret key"
  123. c.StripeWebhookKey = "webhook key"
  124. s := newTestServer(t, c)
  125. s.stripe = stripeMock
  126. // Define how the mock should react
  127. stripeMock.
  128. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  129. Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
  130. // Create a user with a Stripe subscription and 3 reservations
  131. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  132. Code: "starter",
  133. StripePriceID: "price_1234", // !
  134. ReservationsLimit: 1, // !
  135. MessagesLimit: 100,
  136. MessagesExpiryDuration: time.Hour,
  137. AttachmentExpiryDuration: time.Hour,
  138. AttachmentFileSizeLimit: 1000000,
  139. AttachmentTotalSizeLimit: 1000000,
  140. }))
  141. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  142. Code: "pro",
  143. StripePriceID: "price_1111", // !
  144. ReservationsLimit: 3, // !
  145. MessagesLimit: 200,
  146. MessagesExpiryDuration: time.Hour,
  147. AttachmentExpiryDuration: time.Hour,
  148. AttachmentFileSizeLimit: 1000000,
  149. AttachmentTotalSizeLimit: 1000000,
  150. }))
  151. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
  152. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  153. require.Nil(t, s.userManager.ReserveAccess("phil", "atopic", user.PermissionDenyAll))
  154. require.Nil(t, s.userManager.ReserveAccess("phil", "ztopic", user.PermissionDenyAll))
  155. // Add billing details
  156. u, err := s.userManager.User("phil")
  157. require.Nil(t, err)
  158. billing := &user.Billing{
  159. StripeCustomerID: "acct_5555",
  160. StripeSubscriptionID: "sub_1234",
  161. StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
  162. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  163. StripeSubscriptionCancelAt: time.Unix(456, 0),
  164. }
  165. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  166. // Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted
  167. rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
  168. "Authorization": util.BasicAuth("phil", "phil"),
  169. })
  170. require.Equal(t, 200, rr.Code)
  171. rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
  172. "Authorization": util.BasicAuth("phil", "phil"),
  173. })
  174. require.Equal(t, 200, rr.Code)
  175. a2 := toMessage(t, rr.Body.String())
  176. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  177. rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
  178. "Authorization": util.BasicAuth("phil", "phil"),
  179. })
  180. require.Equal(t, 200, rr.Code)
  181. rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
  182. "Authorization": util.BasicAuth("phil", "phil"),
  183. })
  184. require.Equal(t, 200, rr.Code)
  185. z2 := toMessage(t, rr.Body.String())
  186. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  187. // Call the webhook: This does all the magic
  188. rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  189. "Stripe-Signature": "stripe signature",
  190. })
  191. require.Equal(t, 200, rr.Code)
  192. // Verify that database columns were updated
  193. u, err = s.userManager.User("phil")
  194. require.Nil(t, err)
  195. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  196. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  197. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  198. require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
  199. require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
  200. require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
  201. // Verify that reservations were deleted
  202. r, err := s.userManager.Reservations("phil")
  203. require.Nil(t, err)
  204. require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
  205. require.Equal(t, "atopic", r[0].Topic)
  206. // Verify that messages and attachments were deleted
  207. time.Sleep(time.Second)
  208. s.execManager()
  209. ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
  210. require.Nil(t, err)
  211. require.Equal(t, 2, len(ms))
  212. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  213. ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
  214. require.Nil(t, err)
  215. require.Equal(t, 0, len(ms))
  216. require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  217. }
  218. type testStripeAPI struct {
  219. mock.Mock
  220. }
  221. func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  222. args := s.Called(params)
  223. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  224. }
  225. func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  226. args := s.Called(params)
  227. return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
  228. }
  229. func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  230. args := s.Called(params)
  231. return args.Get(0).([]*stripe.Price), args.Error(1)
  232. }
  233. func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  234. args := s.Called(id)
  235. return args.Get(0).(*stripe.Customer), args.Error(1)
  236. }
  237. func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  238. args := s.Called(id)
  239. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  240. }
  241. func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  242. args := s.Called(id)
  243. return args.Get(0).(*stripe.Subscription), args.Error(1)
  244. }
  245. func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  246. args := s.Called(id)
  247. return args.Get(0).(*stripe.Subscription), args.Error(1)
  248. }
  249. func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  250. args := s.Called(id)
  251. return args.Get(0).(*stripe.Subscription), args.Error(1)
  252. }
  253. func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  254. args := s.Called(payload, header, secret)
  255. return args.Get(0).(stripe.Event), args.Error(1)
  256. }
  257. var _ stripeAPI = (*testStripeAPI)(nil)
  258. func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
  259. var e stripe.Event
  260. if err := json.Unmarshal([]byte(v), &e); err != nil {
  261. t.Fatal(err)
  262. }
  263. return e
  264. }
  265. const subscriptionUpdatedEventJSON = `
  266. {
  267. "type": "customer.subscription.updated",
  268. "data": {
  269. "object": {
  270. "id": "sub_1234",
  271. "customer": "acct_5555",
  272. "status": "active",
  273. "current_period_end": 1674268231,
  274. "cancel_at": 1674299999,
  275. "items": {
  276. "data": [
  277. {
  278. "price": {
  279. "id": "price_1234"
  280. }
  281. }
  282. ]
  283. }
  284. }
  285. }
  286. }`