// server_payments_test.go (29 KB)


  1. package server
  2. import (
  3. "encoding/json"
  4. "github.com/stretchr/testify/mock"
  5. "github.com/stretchr/testify/require"
  6. "github.com/stripe/stripe-go/v74"
  7. "golang.org/x/time/rate"
  8. "heckel.io/ntfy/v2/user"
  9. "heckel.io/ntfy/v2/util"
  10. "io"
  11. "net/netip"
  12. "path/filepath"
  13. "strings"
  14. "sync"
  15. "testing"
  16. "time"
  17. )
  18. func TestPayments_Tiers(t *testing.T) {
  19. stripeMock := &testStripeAPI{}
  20. defer stripeMock.AssertExpectations(t)
  21. c := newTestConfigWithAuthFile(t)
  22. c.StripeSecretKey = "secret key"
  23. c.StripeWebhookKey = "webhook key"
  24. c.VisitorRequestLimitReplenish = 12 * time.Hour
  25. c.CacheDuration = 13 * time.Hour
  26. c.AttachmentFileSizeLimit = 111
  27. c.VisitorAttachmentTotalSizeLimit = 222
  28. c.AttachmentExpiryDuration = 123 * time.Second
  29. s := newTestServer(t, c)
  30. s.stripe = stripeMock
  31. // Define how the mock should react
  32. stripeMock.
  33. On("ListPrices", mock.Anything).
  34. Return([]*stripe.Price{
  35. {ID: "price_123", UnitAmount: 500},
  36. {ID: "price_124", UnitAmount: 5000},
  37. {ID: "price_456", UnitAmount: 1000},
  38. {ID: "price_457", UnitAmount: 10000},
  39. {ID: "price_999", UnitAmount: 9999},
  40. }, nil)
  41. // Create tiers
  42. require.Nil(t, s.userManager.AddTier(&user.Tier{
  43. ID: "ti_1",
  44. Code: "admin",
  45. Name: "Admin",
  46. }))
  47. require.Nil(t, s.userManager.AddTier(&user.Tier{
  48. ID: "ti_123",
  49. Code: "pro",
  50. Name: "Pro",
  51. MessageLimit: 1000,
  52. MessageExpiryDuration: time.Hour,
  53. EmailLimit: 123,
  54. ReservationLimit: 777,
  55. AttachmentFileSizeLimit: 999,
  56. AttachmentTotalSizeLimit: 888,
  57. AttachmentExpiryDuration: time.Minute,
  58. StripeMonthlyPriceID: "price_123",
  59. StripeYearlyPriceID: "price_124",
  60. }))
  61. require.Nil(t, s.userManager.AddTier(&user.Tier{
  62. ID: "ti_444",
  63. Code: "business",
  64. Name: "Business",
  65. MessageLimit: 2000,
  66. MessageExpiryDuration: 10 * time.Hour,
  67. EmailLimit: 123123,
  68. ReservationLimit: 777333,
  69. AttachmentFileSizeLimit: 999111,
  70. AttachmentTotalSizeLimit: 888111,
  71. AttachmentExpiryDuration: time.Hour,
  72. StripeMonthlyPriceID: "price_456",
  73. StripeYearlyPriceID: "price_457",
  74. }))
  75. response := request(t, s, "GET", "/v1/tiers", "", nil)
  76. require.Equal(t, 200, response.Code)
  77. var tiers []apiAccountBillingTier
  78. require.Nil(t, json.NewDecoder(response.Body).Decode(&tiers))
  79. require.Equal(t, 3, len(tiers))
  80. // Free tier
  81. tier := tiers[0]
  82. require.Equal(t, "", tier.Code)
  83. require.Equal(t, "", tier.Name)
  84. require.Equal(t, "ip", tier.Limits.Basis)
  85. require.Equal(t, int64(0), tier.Limits.Reservations)
  86. require.Equal(t, int64(2), tier.Limits.Messages) // :-(
  87. require.Equal(t, int64(13*3600), tier.Limits.MessagesExpiryDuration)
  88. require.Equal(t, int64(24), tier.Limits.Emails)
  89. require.Equal(t, int64(111), tier.Limits.AttachmentFileSize)
  90. require.Equal(t, int64(222), tier.Limits.AttachmentTotalSize)
  91. require.Equal(t, int64(123), tier.Limits.AttachmentExpiryDuration)
  92. // Admin tier is not included, because it is not paid!
  93. tier = tiers[1]
  94. require.Equal(t, "pro", tier.Code)
  95. require.Equal(t, "Pro", tier.Name)
  96. require.Equal(t, "tier", tier.Limits.Basis)
  97. require.Equal(t, int64(500), tier.Prices.Month)
  98. require.Equal(t, int64(5000), tier.Prices.Year)
  99. require.Equal(t, int64(777), tier.Limits.Reservations)
  100. require.Equal(t, int64(1000), tier.Limits.Messages)
  101. require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
  102. require.Equal(t, int64(123), tier.Limits.Emails)
  103. require.Equal(t, int64(999), tier.Limits.AttachmentFileSize)
  104. require.Equal(t, int64(888), tier.Limits.AttachmentTotalSize)
  105. require.Equal(t, int64(60), tier.Limits.AttachmentExpiryDuration)
  106. tier = tiers[2]
  107. require.Equal(t, "business", tier.Code)
  108. require.Equal(t, "Business", tier.Name)
  109. require.Equal(t, int64(1000), tier.Prices.Month)
  110. require.Equal(t, int64(10000), tier.Prices.Year)
  111. require.Equal(t, "tier", tier.Limits.Basis)
  112. require.Equal(t, int64(777333), tier.Limits.Reservations)
  113. require.Equal(t, int64(2000), tier.Limits.Messages)
  114. require.Equal(t, int64(36000), tier.Limits.MessagesExpiryDuration)
  115. require.Equal(t, int64(123123), tier.Limits.Emails)
  116. require.Equal(t, int64(999111), tier.Limits.AttachmentFileSize)
  117. require.Equal(t, int64(888111), tier.Limits.AttachmentTotalSize)
  118. require.Equal(t, int64(3600), tier.Limits.AttachmentExpiryDuration)
  119. }
  120. func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
  121. stripeMock := &testStripeAPI{}
  122. defer stripeMock.AssertExpectations(t)
  123. c := newTestConfigWithAuthFile(t)
  124. c.StripeSecretKey = "secret key"
  125. c.StripeWebhookKey = "webhook key"
  126. s := newTestServer(t, c)
  127. s.stripe = stripeMock
  128. // Define how the mock should react
  129. stripeMock.
  130. On("NewCheckoutSession", mock.Anything).
  131. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  132. // Create tier and user
  133. require.Nil(t, s.userManager.AddTier(&user.Tier{
  134. ID: "ti_123",
  135. Code: "pro",
  136. StripeMonthlyPriceID: "price_123",
  137. }))
  138. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  139. // Create subscription
  140. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
  141. "Authorization": util.BasicAuth("phil", "phil"),
  142. })
  143. require.Equal(t, 200, response.Code)
  144. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  145. require.Nil(t, err)
  146. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  147. }
  148. func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
  149. stripeMock := &testStripeAPI{}
  150. defer stripeMock.AssertExpectations(t)
  151. c := newTestConfigWithAuthFile(t)
  152. c.StripeSecretKey = "secret key"
  153. c.StripeWebhookKey = "webhook key"
  154. s := newTestServer(t, c)
  155. s.stripe = stripeMock
  156. // Define how the mock should react
  157. stripeMock.
  158. On("GetCustomer", "acct_123").
  159. Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
  160. stripeMock.
  161. On("NewCheckoutSession", mock.Anything).
  162. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  163. // Create tier and user
  164. require.Nil(t, s.userManager.AddTier(&user.Tier{
  165. ID: "ti_123",
  166. Code: "pro",
  167. StripeMonthlyPriceID: "price_123",
  168. }))
  169. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  170. u, err := s.userManager.User("phil")
  171. require.Nil(t, err)
  172. billing := &user.Billing{
  173. StripeCustomerID: "acct_123",
  174. }
  175. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  176. // Create subscription
  177. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
  178. "Authorization": util.BasicAuth("phil", "phil"),
  179. })
  180. require.Equal(t, 200, response.Code)
  181. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  182. require.Nil(t, err)
  183. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  184. }
  185. func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
  186. stripeMock := &testStripeAPI{}
  187. defer stripeMock.AssertExpectations(t)
  188. c := newTestConfigWithAuthFile(t)
  189. c.EnableSignup = true
  190. c.StripeSecretKey = "secret key"
  191. c.StripeWebhookKey = "webhook key"
  192. s := newTestServer(t, c)
  193. s.stripe = stripeMock
  194. // Define how the mock should react
  195. stripeMock.
  196. On("CancelSubscription", "sub_123").
  197. Return(&stripe.Subscription{}, nil)
  198. // Create tier and user
  199. require.Nil(t, s.userManager.AddTier(&user.Tier{
  200. ID: "ti_123",
  201. Code: "pro",
  202. StripeMonthlyPriceID: "price_123",
  203. }))
  204. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  205. u, err := s.userManager.User("phil")
  206. require.Nil(t, err)
  207. billing := &user.Billing{
  208. StripeCustomerID: "acct_123",
  209. StripeSubscriptionID: "sub_123",
  210. }
  211. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  212. // Delete account
  213. rr := request(t, s, "DELETE", "/v1/account", `{"password": "phil"}`, map[string]string{
  214. "Authorization": util.BasicAuth("phil", "phil"),
  215. })
  216. require.Equal(t, 200, rr.Code)
  217. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  218. "Authorization": util.BasicAuth("phil", "mypass"),
  219. })
  220. require.Equal(t, 401, rr.Code)
  221. }
  222. func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
  223. // This test is too overloaded, but it's also a great end-to-end a test.
  224. //
  225. // It tests:
  226. // - A successful checkout flow (not a paying customer -> paying customer)
  227. // - Tier-changes reset the rate limits for the user
  228. // - The request limits for tier-less user and a tier-user
  229. // - The message limits for a tier-user
  230. stripeMock := &testStripeAPI{}
  231. defer stripeMock.AssertExpectations(t)
  232. c := newTestConfigWithAuthFile(t)
  233. c.StripeSecretKey = "secret key"
  234. c.StripeWebhookKey = "webhook key"
  235. c.VisitorRequestLimitBurst = 5
  236. c.VisitorRequestLimitReplenish = time.Hour
  237. c.CacheBatchSize = 500
  238. c.CacheBatchTimeout = time.Second
  239. s := newTestServer(t, c)
  240. s.stripe = stripeMock
  241. // Create a user with a Stripe subscription and 3 reservations
  242. require.Nil(t, s.userManager.AddTier(&user.Tier{
  243. ID: "ti_123",
  244. Code: "starter",
  245. StripeMonthlyPriceID: "price_1234",
  246. ReservationLimit: 1,
  247. MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
  248. MessageExpiryDuration: time.Hour,
  249. }))
  250. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
  251. u, err := s.userManager.User("phil")
  252. require.Nil(t, err)
  253. // Define how the mock should react
  254. stripeMock.
  255. On("GetSession", "SOMETOKEN").
  256. Return(&stripe.CheckoutSession{
  257. ClientReferenceID: u.ID, // ntfy user ID
  258. Customer: &stripe.Customer{
  259. ID: "acct_5555",
  260. },
  261. Subscription: &stripe.Subscription{
  262. ID: "sub_1234",
  263. },
  264. }, nil)
  265. stripeMock.
  266. On("GetSubscription", "sub_1234").
  267. Return(&stripe.Subscription{
  268. ID: "sub_1234",
  269. Status: stripe.SubscriptionStatusActive,
  270. CurrentPeriodEnd: 123456789,
  271. CancelAt: 0,
  272. Items: &stripe.SubscriptionItemList{
  273. Data: []*stripe.SubscriptionItem{
  274. {
  275. Price: &stripe.Price{
  276. ID: "price_1234",
  277. Recurring: &stripe.PriceRecurring{
  278. Interval: stripe.PriceRecurringIntervalMonth,
  279. },
  280. },
  281. },
  282. },
  283. },
  284. }, nil)
  285. stripeMock.
  286. On("UpdateCustomer", "acct_5555", &stripe.CustomerParams{
  287. Params: stripe.Params{
  288. Metadata: map[string]string{
  289. "user_id": u.ID,
  290. "user_name": u.Name,
  291. },
  292. },
  293. }).
  294. Return(&stripe.Customer{}, nil)
  295. // Send messages until rate limit of free tier is hit
  296. for i := 0; i < 5; i++ {
  297. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  298. "Authorization": util.BasicAuth("phil", "phil"),
  299. })
  300. require.Equal(t, 200, rr.Code)
  301. }
  302. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  303. "Authorization": util.BasicAuth("phil", "phil"),
  304. })
  305. require.Equal(t, 429, rr.Code)
  306. // Verify some "before-stats"
  307. u, err = s.userManager.User("phil")
  308. require.Nil(t, err)
  309. require.Nil(t, u.Tier)
  310. require.Equal(t, "", u.Billing.StripeCustomerID)
  311. require.Equal(t, "", u.Billing.StripeSubscriptionID)
  312. require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
  313. require.Equal(t, stripe.PriceRecurringInterval(""), u.Billing.StripeSubscriptionInterval)
  314. require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
  315. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  316. require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
  317. require.Equal(t, int64(0), u.Stats.Emails)
  318. // Simulate Stripe success return URL call (no user context)
  319. rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
  320. require.Equal(t, 303, rr.Code)
  321. // Verify that database columns were updated
  322. u, err = s.userManager.User("phil")
  323. require.Nil(t, err)
  324. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  325. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  326. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  327. require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
  328. require.Equal(t, stripe.PriceRecurringIntervalMonth, u.Billing.StripeSubscriptionInterval)
  329. require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
  330. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  331. require.Equal(t, int64(0), u.Stats.Messages)
  332. require.Equal(t, int64(0), u.Stats.Emails)
  333. // Now for the fun part: Verify that new rate limits are immediately applied
  334. // This only tests the request limiter, which kicks in before the message limiter.
  335. for i := 0; i < 11; i++ {
  336. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  337. "Authorization": util.BasicAuth("phil", "phil"),
  338. })
  339. require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
  340. }
  341. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  342. "Authorization": util.BasicAuth("phil", "phil"),
  343. })
  344. require.Equal(t, 429, rr.Code)
  345. // Now let's test the message limiter by faking a ridiculously generous rate limiter
  346. v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
  347. v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
  348. var wg sync.WaitGroup
  349. for i := 0; i < 209; i++ {
  350. wg.Add(1)
  351. go func(i int) {
  352. defer wg.Done()
  353. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  354. "Authorization": util.BasicAuth("phil", "phil"),
  355. })
  356. require.Equal(t, 200, rr.Code, "Failed on %d", i)
  357. }(i)
  358. }
  359. wg.Wait()
  360. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  361. "Authorization": util.BasicAuth("phil", "phil"),
  362. })
  363. require.Equal(t, 429, rr.Code)
  364. // And now let's cross-check that the stats are correct too
  365. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  366. "Authorization": util.BasicAuth("phil", "phil"),
  367. })
  368. require.Equal(t, 200, rr.Code)
  369. account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
  370. require.Equal(t, int64(220), account.Limits.Messages)
  371. require.Equal(t, int64(220), account.Stats.Messages)
  372. require.Equal(t, int64(0), account.Stats.MessagesRemaining)
  373. }
  374. func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
  375. t.Parallel()
  376. // This tests incoming webhooks from Stripe to update a subscription:
  377. // - All Stripe columns are updated in the user table
  378. // - When downgrading, excess reservations are deleted, including messages and attachments in
  379. // the corresponding topics
  380. stripeMock := &testStripeAPI{}
  381. defer stripeMock.AssertExpectations(t)
  382. c := newTestConfigWithAuthFile(t)
  383. c.StripeSecretKey = "secret key"
  384. c.StripeWebhookKey = "webhook key"
  385. s := newTestServer(t, c)
  386. s.stripe = stripeMock
  387. // Define how the mock should react
  388. stripeMock.
  389. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  390. Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
  391. // Create a user with a Stripe subscription and 3 reservations
  392. require.Nil(t, s.userManager.AddTier(&user.Tier{
  393. ID: "ti_1",
  394. Code: "starter",
  395. StripeMonthlyPriceID: "price_1234", // !
  396. ReservationLimit: 1, // !
  397. MessageLimit: 100,
  398. MessageExpiryDuration: time.Hour,
  399. AttachmentExpiryDuration: time.Hour,
  400. AttachmentFileSizeLimit: 1000000,
  401. AttachmentTotalSizeLimit: 1000000,
  402. AttachmentBandwidthLimit: 1000000,
  403. }))
  404. require.Nil(t, s.userManager.AddTier(&user.Tier{
  405. ID: "ti_2",
  406. Code: "pro",
  407. StripeMonthlyPriceID: "price_1111", // !
  408. ReservationLimit: 3, // !
  409. MessageLimit: 200,
  410. MessageExpiryDuration: time.Hour,
  411. AttachmentExpiryDuration: time.Hour,
  412. AttachmentFileSizeLimit: 1000000,
  413. AttachmentTotalSizeLimit: 1000000,
  414. AttachmentBandwidthLimit: 1000000,
  415. }))
  416. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  417. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  418. require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
  419. require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
  420. // Add billing details
  421. u, err := s.userManager.User("phil")
  422. require.Nil(t, err)
  423. billing := &user.Billing{
  424. StripeCustomerID: "acct_5555",
  425. StripeSubscriptionID: "sub_1234",
  426. StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
  427. StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
  428. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  429. StripeSubscriptionCancelAt: time.Unix(456, 0),
  430. }
  431. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  432. // Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted
  433. rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
  434. "Authorization": util.BasicAuth("phil", "phil"),
  435. })
  436. require.Equal(t, 200, rr.Code)
  437. rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
  438. "Authorization": util.BasicAuth("phil", "phil"),
  439. })
  440. require.Equal(t, 200, rr.Code)
  441. a2 := toMessage(t, rr.Body.String())
  442. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  443. rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
  444. "Authorization": util.BasicAuth("phil", "phil"),
  445. })
  446. require.Equal(t, 200, rr.Code)
  447. rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
  448. "Authorization": util.BasicAuth("phil", "phil"),
  449. })
  450. require.Equal(t, 200, rr.Code)
  451. z2 := toMessage(t, rr.Body.String())
  452. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  453. // Call the webhook: This does all the magic
  454. rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  455. "Stripe-Signature": "stripe signature",
  456. })
  457. require.Equal(t, 200, rr.Code)
  458. // Verify that database columns were updated
  459. u, err = s.userManager.User("phil")
  460. require.Nil(t, err)
  461. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  462. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  463. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  464. require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
  465. require.Equal(t, stripe.PriceRecurringIntervalYear, u.Billing.StripeSubscriptionInterval) // Not "month"
  466. require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
  467. require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
  468. // Verify that reservations were deleted
  469. r, err := s.userManager.Reservations("phil")
  470. require.Nil(t, err)
  471. require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
  472. require.Equal(t, "atopic", r[0].Topic)
  473. // Verify that messages and attachments were deleted
  474. time.Sleep(time.Second)
  475. s.execManager()
  476. ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
  477. require.Nil(t, err)
  478. require.Equal(t, 2, len(ms))
  479. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  480. ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
  481. require.Nil(t, err)
  482. require.Equal(t, 0, len(ms))
  483. require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  484. }
  485. func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
  486. // This tests incoming webhooks from Stripe to delete a subscription. It verifies that the database is
  487. // updated (all Stripe fields are deleted, and the tier is removed).
  488. //
  489. // It doesn't fully test the message/attachment deletion. That is tested above in the subscription update call.
  490. stripeMock := &testStripeAPI{}
  491. defer stripeMock.AssertExpectations(t)
  492. c := newTestConfigWithAuthFile(t)
  493. c.StripeSecretKey = "secret key"
  494. c.StripeWebhookKey = "webhook key"
  495. s := newTestServer(t, c)
  496. s.stripe = stripeMock
  497. // Define how the mock should react
  498. stripeMock.
  499. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  500. Return(jsonToStripeEvent(t, subscriptionDeletedEventJSON), nil)
  501. // Create a user with a Stripe subscription and 3 reservations
  502. require.Nil(t, s.userManager.AddTier(&user.Tier{
  503. ID: "ti_1",
  504. Code: "pro",
  505. StripeMonthlyPriceID: "price_1234",
  506. ReservationLimit: 1,
  507. }))
  508. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  509. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  510. require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
  511. // Add billing details
  512. u, err := s.userManager.User("phil")
  513. require.Nil(t, err)
  514. require.Nil(t, s.userManager.ChangeBilling(u.Name, &user.Billing{
  515. StripeCustomerID: "acct_5555",
  516. StripeSubscriptionID: "sub_1234",
  517. StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
  518. StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
  519. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  520. StripeSubscriptionCancelAt: time.Unix(0, 0),
  521. }))
  522. // Call the webhook: This does all the magic
  523. rr := request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  524. "Stripe-Signature": "stripe signature",
  525. })
  526. require.Equal(t, 200, rr.Code)
  527. // Verify that database columns were updated
  528. u, err = s.userManager.User("phil")
  529. require.Nil(t, err)
  530. require.Nil(t, u.Tier)
  531. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  532. require.Equal(t, "", u.Billing.StripeSubscriptionID)
  533. require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
  534. require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
  535. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  536. // Verify that reservations were deleted
  537. r, err := s.userManager.Reservations("phil")
  538. require.Nil(t, err)
  539. require.Equal(t, 0, len(r))
  540. }
  541. func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
  542. stripeMock := &testStripeAPI{}
  543. defer stripeMock.AssertExpectations(t)
  544. c := newTestConfigWithAuthFile(t)
  545. c.StripeSecretKey = "secret key"
  546. c.StripeWebhookKey = "webhook key"
  547. s := newTestServer(t, c)
  548. s.stripe = stripeMock
  549. // Define how the mock should react
  550. stripeMock.
  551. On("GetSubscription", "sub_123").
  552. Return(&stripe.Subscription{
  553. ID: "sub_123",
  554. Items: &stripe.SubscriptionItemList{
  555. Data: []*stripe.SubscriptionItem{
  556. {
  557. ID: "someid_123",
  558. Price: &stripe.Price{ID: "price_123"},
  559. },
  560. },
  561. },
  562. }, nil)
  563. stripeMock.
  564. On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{
  565. CancelAtPeriodEnd: stripe.Bool(false),
  566. ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
  567. Items: []*stripe.SubscriptionItemsParams{
  568. {
  569. ID: stripe.String("someid_123"),
  570. Price: stripe.String("price_457"),
  571. },
  572. },
  573. }).
  574. Return(&stripe.Subscription{}, nil)
  575. // Create tier and user
  576. require.Nil(t, s.userManager.AddTier(&user.Tier{
  577. ID: "ti_123",
  578. Code: "pro",
  579. StripeMonthlyPriceID: "price_123",
  580. StripeYearlyPriceID: "price_124",
  581. }))
  582. require.Nil(t, s.userManager.AddTier(&user.Tier{
  583. ID: "ti_456",
  584. Code: "business",
  585. StripeMonthlyPriceID: "price_456",
  586. StripeYearlyPriceID: "price_457",
  587. }))
  588. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  589. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  590. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  591. StripeCustomerID: "acct_123",
  592. StripeSubscriptionID: "sub_123",
  593. }))
  594. // Call endpoint to change subscription
  595. rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business","interval":"year"}`, map[string]string{
  596. "Authorization": util.BasicAuth("phil", "phil"),
  597. })
  598. require.Equal(t, 200, rr.Code)
  599. }
  600. func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
  601. stripeMock := &testStripeAPI{}
  602. defer stripeMock.AssertExpectations(t)
  603. c := newTestConfigWithAuthFile(t)
  604. c.StripeSecretKey = "secret key"
  605. c.StripeWebhookKey = "webhook key"
  606. s := newTestServer(t, c)
  607. s.stripe = stripeMock
  608. // Define how the mock should react
  609. stripeMock.
  610. On("UpdateSubscription", "sub_123", mock.MatchedBy(func(s *stripe.SubscriptionParams) bool {
  611. return *s.CancelAtPeriodEnd // Is true
  612. })).
  613. Return(&stripe.Subscription{}, nil)
  614. // Create user
  615. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  616. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  617. StripeCustomerID: "acct_123",
  618. StripeSubscriptionID: "sub_123",
  619. }))
  620. // Delete subscription
  621. rr := request(t, s, "DELETE", "/v1/account/billing/subscription", "", map[string]string{
  622. "Authorization": util.BasicAuth("phil", "phil"),
  623. })
  624. require.Equal(t, 200, rr.Code)
  625. }
  626. func TestPayments_CreatePortalSession(t *testing.T) {
  627. stripeMock := &testStripeAPI{}
  628. defer stripeMock.AssertExpectations(t)
  629. c := newTestConfigWithAuthFile(t)
  630. c.StripeSecretKey = "secret key"
  631. c.StripeWebhookKey = "webhook key"
  632. s := newTestServer(t, c)
  633. s.stripe = stripeMock
  634. // Define how the mock should react
  635. stripeMock.
  636. On("NewPortalSession", &stripe.BillingPortalSessionParams{
  637. Customer: stripe.String("acct_123"),
  638. ReturnURL: stripe.String(s.config.BaseURL),
  639. }).
  640. Return(&stripe.BillingPortalSession{
  641. URL: "https://billing.stripe.com/blablabla",
  642. }, nil)
  643. // Create user
  644. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  645. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  646. StripeCustomerID: "acct_123",
  647. StripeSubscriptionID: "sub_123",
  648. }))
  649. // Create portal session
  650. rr := request(t, s, "POST", "/v1/account/billing/portal", "", map[string]string{
  651. "Authorization": util.BasicAuth("phil", "phil"),
  652. })
  653. require.Equal(t, 200, rr.Code)
  654. ps, _ := util.UnmarshalJSON[apiAccountBillingPortalRedirectResponse](io.NopCloser(rr.Body))
  655. require.Equal(t, "https://billing.stripe.com/blablabla", ps.RedirectURL)
  656. }
  657. type testStripeAPI struct {
  658. mock.Mock
  659. }
  660. var _ stripeAPI = (*testStripeAPI)(nil)
  661. func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  662. args := s.Called(params)
  663. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  664. }
  665. func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  666. args := s.Called(params)
  667. return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
  668. }
  669. func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  670. args := s.Called(params)
  671. return args.Get(0).([]*stripe.Price), args.Error(1)
  672. }
  673. func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  674. args := s.Called(id)
  675. return args.Get(0).(*stripe.Customer), args.Error(1)
  676. }
  677. func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  678. args := s.Called(id)
  679. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  680. }
  681. func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  682. args := s.Called(id)
  683. return args.Get(0).(*stripe.Subscription), args.Error(1)
  684. }
  685. func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
  686. args := s.Called(id, params)
  687. return args.Get(0).(*stripe.Customer), args.Error(1)
  688. }
  689. func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  690. args := s.Called(id, params)
  691. return args.Get(0).(*stripe.Subscription), args.Error(1)
  692. }
  693. func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  694. args := s.Called(id)
  695. return args.Get(0).(*stripe.Subscription), args.Error(1)
  696. }
  697. func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  698. args := s.Called(payload, header, secret)
  699. return args.Get(0).(stripe.Event), args.Error(1)
  700. }
  701. func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
  702. var e stripe.Event
  703. if err := json.Unmarshal([]byte(v), &e); err != nil {
  704. t.Fatal(err)
  705. }
  706. return e
  707. }
  // Hand-written Stripe webhook payloads, decoded with jsonToStripeEvent and returned by the mocked
// ConstructWebhookEvent in the webhook tests.
const (
	// subscriptionUpdatedEventJSON reports subscription sub_1234 of customer acct_5555 as "active"
	// on the yearly price_1234, with period end 1674268231 and cancel_at 1674299999.
	subscriptionUpdatedEventJSON = `
{
  "type": "customer.subscription.updated",
  "data": {
    "object": {
      "id": "sub_1234",
      "customer": "acct_5555",
      "status": "active",
      "current_period_end": 1674268231,
      "cancel_at": 1674299999,
      "items": {
        "data": [
          {
            "price": {
              "id": "price_1234",
              "recurring": {
                "interval": "year"
              }
            }
          }
        ]
      }
    }
  }
}`

	// subscriptionDeletedEventJSON reports that subscription sub_1234 of customer acct_5555 has
	// ended. Its status is "canceled", which is what Stripe sends for an ended subscription.
	subscriptionDeletedEventJSON = `
{
  "type": "customer.subscription.deleted",
  "data": {
    "object": {
      "id": "sub_1234",
      "customer": "acct_5555",
      "status": "canceled",
      "current_period_end": 1674268231,
      "cancel_at": 1674299999,
      "items": {
        "data": [
          {
            "price": {
              "id": "price_1234",
              "recurring": {
                "interval": "month"
              }
            }
          }
        ]
      }
    }
  }
}`
)