server_payments_test.go 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828
  1. package server
  2. import (
  3. "encoding/json"
  4. "github.com/stretchr/testify/mock"
  5. "github.com/stretchr/testify/require"
  6. "github.com/stripe/stripe-go/v74"
  7. "golang.org/x/time/rate"
  8. "heckel.io/ntfy/user"
  9. "heckel.io/ntfy/util"
  10. "io"
  11. "net/netip"
  12. "path/filepath"
  13. "strings"
  14. "sync"
  15. "testing"
  16. "time"
  17. )
  18. func TestPayments_Tiers(t *testing.T) {
  19. stripeMock := &testStripeAPI{}
  20. defer stripeMock.AssertExpectations(t)
  21. c := newTestConfigWithAuthFile(t)
  22. c.StripeSecretKey = "secret key"
  23. c.StripeWebhookKey = "webhook key"
  24. c.VisitorRequestLimitReplenish = 12 * time.Hour
  25. c.CacheDuration = 13 * time.Hour
  26. c.AttachmentFileSizeLimit = 111
  27. c.VisitorAttachmentTotalSizeLimit = 222
  28. c.AttachmentExpiryDuration = 123 * time.Second
  29. s := newTestServer(t, c)
  30. s.stripe = stripeMock
  31. // Define how the mock should react
  32. stripeMock.
  33. On("ListPrices", mock.Anything).
  34. Return([]*stripe.Price{
  35. {ID: "price_123", UnitAmount: 500},
  36. {ID: "price_456", UnitAmount: 1000},
  37. {ID: "price_999", UnitAmount: 9999},
  38. }, nil)
  39. // Create tiers
  40. require.Nil(t, s.userManager.AddTier(&user.Tier{
  41. ID: "ti_1",
  42. Code: "admin",
  43. Name: "Admin",
  44. }))
  45. require.Nil(t, s.userManager.AddTier(&user.Tier{
  46. ID: "ti_123",
  47. Code: "pro",
  48. Name: "Pro",
  49. MessageLimit: 1000,
  50. MessageExpiryDuration: time.Hour,
  51. EmailLimit: 123,
  52. ReservationLimit: 777,
  53. AttachmentFileSizeLimit: 999,
  54. AttachmentTotalSizeLimit: 888,
  55. AttachmentExpiryDuration: time.Minute,
  56. StripePriceID: "price_123",
  57. }))
  58. require.Nil(t, s.userManager.AddTier(&user.Tier{
  59. ID: "ti_444",
  60. Code: "business",
  61. Name: "Business",
  62. MessageLimit: 2000,
  63. MessageExpiryDuration: 10 * time.Hour,
  64. EmailLimit: 123123,
  65. ReservationLimit: 777333,
  66. AttachmentFileSizeLimit: 999111,
  67. AttachmentTotalSizeLimit: 888111,
  68. AttachmentExpiryDuration: time.Hour,
  69. StripePriceID: "price_456",
  70. }))
  71. response := request(t, s, "GET", "/v1/tiers", "", nil)
  72. require.Equal(t, 200, response.Code)
  73. var tiers []apiAccountBillingTier
  74. require.Nil(t, json.NewDecoder(response.Body).Decode(&tiers))
  75. require.Equal(t, 3, len(tiers))
  76. // Free tier
  77. tier := tiers[0]
  78. require.Equal(t, "", tier.Code)
  79. require.Equal(t, "", tier.Name)
  80. require.Equal(t, "ip", tier.Limits.Basis)
  81. require.Equal(t, int64(0), tier.Limits.Reservations)
  82. require.Equal(t, int64(2), tier.Limits.Messages) // :-(
  83. require.Equal(t, int64(13*3600), tier.Limits.MessagesExpiryDuration)
  84. require.Equal(t, int64(24), tier.Limits.Emails)
  85. require.Equal(t, int64(111), tier.Limits.AttachmentFileSize)
  86. require.Equal(t, int64(222), tier.Limits.AttachmentTotalSize)
  87. require.Equal(t, int64(123), tier.Limits.AttachmentExpiryDuration)
  88. // Admin tier is not included, because it is not paid!
  89. tier = tiers[1]
  90. require.Equal(t, "pro", tier.Code)
  91. require.Equal(t, "Pro", tier.Name)
  92. require.Equal(t, "tier", tier.Limits.Basis)
  93. require.Equal(t, int64(777), tier.Limits.Reservations)
  94. require.Equal(t, int64(1000), tier.Limits.Messages)
  95. require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
  96. require.Equal(t, int64(123), tier.Limits.Emails)
  97. require.Equal(t, int64(999), tier.Limits.AttachmentFileSize)
  98. require.Equal(t, int64(888), tier.Limits.AttachmentTotalSize)
  99. require.Equal(t, int64(60), tier.Limits.AttachmentExpiryDuration)
  100. tier = tiers[2]
  101. require.Equal(t, "business", tier.Code)
  102. require.Equal(t, "Business", tier.Name)
  103. require.Equal(t, "tier", tier.Limits.Basis)
  104. require.Equal(t, int64(777333), tier.Limits.Reservations)
  105. require.Equal(t, int64(2000), tier.Limits.Messages)
  106. require.Equal(t, int64(36000), tier.Limits.MessagesExpiryDuration)
  107. require.Equal(t, int64(123123), tier.Limits.Emails)
  108. require.Equal(t, int64(999111), tier.Limits.AttachmentFileSize)
  109. require.Equal(t, int64(888111), tier.Limits.AttachmentTotalSize)
  110. require.Equal(t, int64(3600), tier.Limits.AttachmentExpiryDuration)
  111. }
  112. func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
  113. stripeMock := &testStripeAPI{}
  114. defer stripeMock.AssertExpectations(t)
  115. c := newTestConfigWithAuthFile(t)
  116. c.StripeSecretKey = "secret key"
  117. c.StripeWebhookKey = "webhook key"
  118. s := newTestServer(t, c)
  119. s.stripe = stripeMock
  120. // Define how the mock should react
  121. stripeMock.
  122. On("NewCheckoutSession", mock.Anything).
  123. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  124. // Create tier and user
  125. require.Nil(t, s.userManager.AddTier(&user.Tier{
  126. ID: "ti_123",
  127. Code: "pro",
  128. StripePriceID: "price_123",
  129. }))
  130. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  131. // Create subscription
  132. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  133. "Authorization": util.BasicAuth("phil", "phil"),
  134. })
  135. require.Equal(t, 200, response.Code)
  136. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  137. require.Nil(t, err)
  138. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  139. }
  140. func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
  141. stripeMock := &testStripeAPI{}
  142. defer stripeMock.AssertExpectations(t)
  143. c := newTestConfigWithAuthFile(t)
  144. c.StripeSecretKey = "secret key"
  145. c.StripeWebhookKey = "webhook key"
  146. s := newTestServer(t, c)
  147. s.stripe = stripeMock
  148. // Define how the mock should react
  149. stripeMock.
  150. On("GetCustomer", "acct_123").
  151. Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
  152. stripeMock.
  153. On("NewCheckoutSession", mock.Anything).
  154. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  155. // Create tier and user
  156. require.Nil(t, s.userManager.AddTier(&user.Tier{
  157. ID: "ti_123",
  158. Code: "pro",
  159. StripePriceID: "price_123",
  160. }))
  161. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  162. u, err := s.userManager.User("phil")
  163. require.Nil(t, err)
  164. billing := &user.Billing{
  165. StripeCustomerID: "acct_123",
  166. }
  167. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  168. // Create subscription
  169. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  170. "Authorization": util.BasicAuth("phil", "phil"),
  171. })
  172. require.Equal(t, 200, response.Code)
  173. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  174. require.Nil(t, err)
  175. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  176. }
  177. func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
  178. stripeMock := &testStripeAPI{}
  179. defer stripeMock.AssertExpectations(t)
  180. c := newTestConfigWithAuthFile(t)
  181. c.EnableSignup = true
  182. c.StripeSecretKey = "secret key"
  183. c.StripeWebhookKey = "webhook key"
  184. s := newTestServer(t, c)
  185. s.stripe = stripeMock
  186. // Define how the mock should react
  187. stripeMock.
  188. On("CancelSubscription", "sub_123").
  189. Return(&stripe.Subscription{}, nil)
  190. // Create tier and user
  191. require.Nil(t, s.userManager.AddTier(&user.Tier{
  192. ID: "ti_123",
  193. Code: "pro",
  194. StripePriceID: "price_123",
  195. }))
  196. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  197. u, err := s.userManager.User("phil")
  198. require.Nil(t, err)
  199. billing := &user.Billing{
  200. StripeCustomerID: "acct_123",
  201. StripeSubscriptionID: "sub_123",
  202. }
  203. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  204. // Delete account
  205. rr := request(t, s, "DELETE", "/v1/account", `{"password": "phil"}`, map[string]string{
  206. "Authorization": util.BasicAuth("phil", "phil"),
  207. })
  208. require.Equal(t, 200, rr.Code)
  209. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  210. "Authorization": util.BasicAuth("phil", "mypass"),
  211. })
  212. require.Equal(t, 401, rr.Code)
  213. }
  214. func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
  215. // This test is too overloaded, but it's also a great end-to-end a test.
  216. //
  217. // It tests:
  218. // - A successful checkout flow (not a paying customer -> paying customer)
  219. // - Tier-changes reset the rate limits for the user
  220. // - The request limits for tier-less user and a tier-user
  221. // - The message limits for a tier-user
  222. stripeMock := &testStripeAPI{}
  223. defer stripeMock.AssertExpectations(t)
  224. c := newTestConfigWithAuthFile(t)
  225. c.StripeSecretKey = "secret key"
  226. c.StripeWebhookKey = "webhook key"
  227. c.VisitorRequestLimitBurst = 5
  228. c.VisitorRequestLimitReplenish = time.Hour
  229. c.CacheBatchSize = 500
  230. c.CacheBatchTimeout = time.Second
  231. s := newTestServer(t, c)
  232. s.stripe = stripeMock
  233. // Create a user with a Stripe subscription and 3 reservations
  234. require.Nil(t, s.userManager.AddTier(&user.Tier{
  235. ID: "ti_123",
  236. Code: "starter",
  237. StripePriceID: "price_1234",
  238. ReservationLimit: 1,
  239. MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
  240. MessageExpiryDuration: time.Hour,
  241. }))
  242. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
  243. u, err := s.userManager.User("phil")
  244. require.Nil(t, err)
  245. // Define how the mock should react
  246. stripeMock.
  247. On("GetSession", "SOMETOKEN").
  248. Return(&stripe.CheckoutSession{
  249. ClientReferenceID: u.ID, // ntfy user ID
  250. Customer: &stripe.Customer{
  251. ID: "acct_5555",
  252. },
  253. Subscription: &stripe.Subscription{
  254. ID: "sub_1234",
  255. },
  256. }, nil)
  257. stripeMock.
  258. On("GetSubscription", "sub_1234").
  259. Return(&stripe.Subscription{
  260. ID: "sub_1234",
  261. Status: stripe.SubscriptionStatusActive,
  262. CurrentPeriodEnd: 123456789,
  263. CancelAt: 0,
  264. Items: &stripe.SubscriptionItemList{
  265. Data: []*stripe.SubscriptionItem{
  266. {
  267. Price: &stripe.Price{ID: "price_1234"},
  268. },
  269. },
  270. },
  271. }, nil)
  272. stripeMock.
  273. On("UpdateCustomer", "acct_5555", &stripe.CustomerParams{
  274. Params: stripe.Params{
  275. Metadata: map[string]string{
  276. "user_id": u.ID,
  277. "user_name": u.Name,
  278. },
  279. },
  280. }).
  281. Return(&stripe.Customer{}, nil)
  282. // Send messages until rate limit of free tier is hit
  283. for i := 0; i < 5; i++ {
  284. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  285. "Authorization": util.BasicAuth("phil", "phil"),
  286. })
  287. require.Equal(t, 200, rr.Code)
  288. }
  289. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  290. "Authorization": util.BasicAuth("phil", "phil"),
  291. })
  292. require.Equal(t, 429, rr.Code)
  293. // Verify some "before-stats"
  294. u, err = s.userManager.User("phil")
  295. require.Nil(t, err)
  296. require.Nil(t, u.Tier)
  297. require.Equal(t, "", u.Billing.StripeCustomerID)
  298. require.Equal(t, "", u.Billing.StripeSubscriptionID)
  299. require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
  300. require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
  301. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  302. require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
  303. require.Equal(t, int64(0), u.Stats.Emails)
  304. // Simulate Stripe success return URL call (no user context)
  305. rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
  306. require.Equal(t, 303, rr.Code)
  307. // Verify that database columns were updated
  308. u, err = s.userManager.User("phil")
  309. require.Nil(t, err)
  310. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  311. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  312. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  313. require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
  314. require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
  315. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  316. require.Equal(t, int64(0), u.Stats.Messages)
  317. require.Equal(t, int64(0), u.Stats.Emails)
  318. // Now for the fun part: Verify that new rate limits are immediately applied
  319. // This only tests the request limiter, which kicks in before the message limiter.
  320. for i := 0; i < 11; i++ {
  321. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  322. "Authorization": util.BasicAuth("phil", "phil"),
  323. })
  324. require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
  325. }
  326. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  327. "Authorization": util.BasicAuth("phil", "phil"),
  328. })
  329. require.Equal(t, 429, rr.Code)
  330. // Now let's test the message limiter by faking a ridiculously generous rate limiter
  331. v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
  332. v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
  333. var wg sync.WaitGroup
  334. for i := 0; i < 209; i++ {
  335. wg.Add(1)
  336. go func(i int) {
  337. defer wg.Done()
  338. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  339. "Authorization": util.BasicAuth("phil", "phil"),
  340. })
  341. require.Equal(t, 200, rr.Code, "Failed on %d", i)
  342. }(i)
  343. }
  344. wg.Wait()
  345. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  346. "Authorization": util.BasicAuth("phil", "phil"),
  347. })
  348. require.Equal(t, 429, rr.Code)
  349. // And now let's cross-check that the stats are correct too
  350. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  351. "Authorization": util.BasicAuth("phil", "phil"),
  352. })
  353. require.Equal(t, 200, rr.Code)
  354. account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
  355. require.Equal(t, int64(220), account.Limits.Messages)
  356. require.Equal(t, int64(220), account.Stats.Messages)
  357. require.Equal(t, int64(0), account.Stats.MessagesRemaining)
  358. }
  359. func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
  360. // This tests incoming webhooks from Stripe to update a subscription:
  361. // - All Stripe columns are updated in the user table
  362. // - When downgrading, excess reservations are deleted, including messages and attachments in
  363. // the corresponding topics
  364. stripeMock := &testStripeAPI{}
  365. defer stripeMock.AssertExpectations(t)
  366. c := newTestConfigWithAuthFile(t)
  367. c.StripeSecretKey = "secret key"
  368. c.StripeWebhookKey = "webhook key"
  369. s := newTestServer(t, c)
  370. s.stripe = stripeMock
  371. // Define how the mock should react
  372. stripeMock.
  373. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  374. Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
  375. // Create a user with a Stripe subscription and 3 reservations
  376. require.Nil(t, s.userManager.AddTier(&user.Tier{
  377. ID: "ti_1",
  378. Code: "starter",
  379. StripePriceID: "price_1234", // !
  380. ReservationLimit: 1, // !
  381. MessageLimit: 100,
  382. MessageExpiryDuration: time.Hour,
  383. AttachmentExpiryDuration: time.Hour,
  384. AttachmentFileSizeLimit: 1000000,
  385. AttachmentTotalSizeLimit: 1000000,
  386. AttachmentBandwidthLimit: 1000000,
  387. }))
  388. require.Nil(t, s.userManager.AddTier(&user.Tier{
  389. ID: "ti_2",
  390. Code: "pro",
  391. StripePriceID: "price_1111", // !
  392. ReservationLimit: 3, // !
  393. MessageLimit: 200,
  394. MessageExpiryDuration: time.Hour,
  395. AttachmentExpiryDuration: time.Hour,
  396. AttachmentFileSizeLimit: 1000000,
  397. AttachmentTotalSizeLimit: 1000000,
  398. AttachmentBandwidthLimit: 1000000,
  399. }))
  400. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  401. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  402. require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
  403. require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
  404. // Add billing details
  405. u, err := s.userManager.User("phil")
  406. require.Nil(t, err)
  407. billing := &user.Billing{
  408. StripeCustomerID: "acct_5555",
  409. StripeSubscriptionID: "sub_1234",
  410. StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
  411. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  412. StripeSubscriptionCancelAt: time.Unix(456, 0),
  413. }
  414. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  415. // Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted
  416. rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
  417. "Authorization": util.BasicAuth("phil", "phil"),
  418. })
  419. require.Equal(t, 200, rr.Code)
  420. rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
  421. "Authorization": util.BasicAuth("phil", "phil"),
  422. })
  423. require.Equal(t, 200, rr.Code)
  424. a2 := toMessage(t, rr.Body.String())
  425. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  426. rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
  427. "Authorization": util.BasicAuth("phil", "phil"),
  428. })
  429. require.Equal(t, 200, rr.Code)
  430. rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
  431. "Authorization": util.BasicAuth("phil", "phil"),
  432. })
  433. require.Equal(t, 200, rr.Code)
  434. z2 := toMessage(t, rr.Body.String())
  435. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  436. // Call the webhook: This does all the magic
  437. rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  438. "Stripe-Signature": "stripe signature",
  439. })
  440. require.Equal(t, 200, rr.Code)
  441. // Verify that database columns were updated
  442. u, err = s.userManager.User("phil")
  443. require.Nil(t, err)
  444. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  445. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  446. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  447. require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
  448. require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
  449. require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
  450. // Verify that reservations were deleted
  451. r, err := s.userManager.Reservations("phil")
  452. require.Nil(t, err)
  453. require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
  454. require.Equal(t, "atopic", r[0].Topic)
  455. // Verify that messages and attachments were deleted
  456. time.Sleep(time.Second)
  457. s.execManager()
  458. ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
  459. require.Nil(t, err)
  460. require.Equal(t, 2, len(ms))
  461. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  462. ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
  463. require.Nil(t, err)
  464. require.Equal(t, 0, len(ms))
  465. require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  466. }
  467. func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
  468. // This tests incoming webhooks from Stripe to delete a subscription. It verifies that the database is
  469. // updated (all Stripe fields are deleted, and the tier is removed).
  470. //
  471. // It doesn't fully test the message/attachment deletion. That is tested above in the subscription update call.
  472. stripeMock := &testStripeAPI{}
  473. defer stripeMock.AssertExpectations(t)
  474. c := newTestConfigWithAuthFile(t)
  475. c.StripeSecretKey = "secret key"
  476. c.StripeWebhookKey = "webhook key"
  477. s := newTestServer(t, c)
  478. s.stripe = stripeMock
  479. // Define how the mock should react
  480. stripeMock.
  481. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  482. Return(jsonToStripeEvent(t, subscriptionDeletedEventJSON), nil)
  483. // Create a user with a Stripe subscription and 3 reservations
  484. require.Nil(t, s.userManager.AddTier(&user.Tier{
  485. ID: "ti_1",
  486. Code: "pro",
  487. StripePriceID: "price_1234",
  488. ReservationLimit: 1,
  489. }))
  490. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  491. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  492. require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
  493. // Add billing details
  494. u, err := s.userManager.User("phil")
  495. require.Nil(t, err)
  496. require.Nil(t, s.userManager.ChangeBilling(u.Name, &user.Billing{
  497. StripeCustomerID: "acct_5555",
  498. StripeSubscriptionID: "sub_1234",
  499. StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
  500. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  501. StripeSubscriptionCancelAt: time.Unix(0, 0),
  502. }))
  503. // Call the webhook: This does all the magic
  504. rr := request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  505. "Stripe-Signature": "stripe signature",
  506. })
  507. require.Equal(t, 200, rr.Code)
  508. // Verify that database columns were updated
  509. u, err = s.userManager.User("phil")
  510. require.Nil(t, err)
  511. require.Nil(t, u.Tier)
  512. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  513. require.Equal(t, "", u.Billing.StripeSubscriptionID)
  514. require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
  515. require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
  516. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  517. // Verify that reservations were deleted
  518. r, err := s.userManager.Reservations("phil")
  519. require.Nil(t, err)
  520. require.Equal(t, 0, len(r))
  521. }
  522. func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
  523. stripeMock := &testStripeAPI{}
  524. defer stripeMock.AssertExpectations(t)
  525. c := newTestConfigWithAuthFile(t)
  526. c.StripeSecretKey = "secret key"
  527. c.StripeWebhookKey = "webhook key"
  528. s := newTestServer(t, c)
  529. s.stripe = stripeMock
  530. // Define how the mock should react
  531. stripeMock.
  532. On("GetSubscription", "sub_123").
  533. Return(&stripe.Subscription{
  534. ID: "sub_123",
  535. Items: &stripe.SubscriptionItemList{
  536. Data: []*stripe.SubscriptionItem{
  537. {
  538. ID: "someid_123",
  539. Price: &stripe.Price{ID: "price_123"},
  540. },
  541. },
  542. },
  543. }, nil)
  544. stripeMock.
  545. On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{
  546. CancelAtPeriodEnd: stripe.Bool(false),
  547. ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
  548. Items: []*stripe.SubscriptionItemsParams{
  549. {
  550. ID: stripe.String("someid_123"),
  551. Price: stripe.String("price_456"),
  552. },
  553. },
  554. }).
  555. Return(&stripe.Subscription{}, nil)
  556. // Create tier and user
  557. require.Nil(t, s.userManager.AddTier(&user.Tier{
  558. ID: "ti_123",
  559. Code: "pro",
  560. StripePriceID: "price_123",
  561. }))
  562. require.Nil(t, s.userManager.AddTier(&user.Tier{
  563. ID: "ti_456",
  564. Code: "business",
  565. StripePriceID: "price_456",
  566. }))
  567. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  568. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  569. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  570. StripeCustomerID: "acct_123",
  571. StripeSubscriptionID: "sub_123",
  572. }))
  573. // Call endpoint to change subscription
  574. rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business"}`, map[string]string{
  575. "Authorization": util.BasicAuth("phil", "phil"),
  576. })
  577. require.Equal(t, 200, rr.Code)
  578. }
  579. func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
  580. stripeMock := &testStripeAPI{}
  581. defer stripeMock.AssertExpectations(t)
  582. c := newTestConfigWithAuthFile(t)
  583. c.StripeSecretKey = "secret key"
  584. c.StripeWebhookKey = "webhook key"
  585. s := newTestServer(t, c)
  586. s.stripe = stripeMock
  587. // Define how the mock should react
  588. stripeMock.
  589. On("UpdateSubscription", "sub_123", mock.MatchedBy(func(s *stripe.SubscriptionParams) bool {
  590. return *s.CancelAtPeriodEnd // Is true
  591. })).
  592. Return(&stripe.Subscription{}, nil)
  593. // Create user
  594. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  595. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  596. StripeCustomerID: "acct_123",
  597. StripeSubscriptionID: "sub_123",
  598. }))
  599. // Delete subscription
  600. rr := request(t, s, "DELETE", "/v1/account/billing/subscription", "", map[string]string{
  601. "Authorization": util.BasicAuth("phil", "phil"),
  602. })
  603. require.Equal(t, 200, rr.Code)
  604. }
  605. func TestPayments_CreatePortalSession(t *testing.T) {
  606. stripeMock := &testStripeAPI{}
  607. defer stripeMock.AssertExpectations(t)
  608. c := newTestConfigWithAuthFile(t)
  609. c.StripeSecretKey = "secret key"
  610. c.StripeWebhookKey = "webhook key"
  611. s := newTestServer(t, c)
  612. s.stripe = stripeMock
  613. // Define how the mock should react
  614. stripeMock.
  615. On("NewPortalSession", &stripe.BillingPortalSessionParams{
  616. Customer: stripe.String("acct_123"),
  617. ReturnURL: stripe.String(s.config.BaseURL),
  618. }).
  619. Return(&stripe.BillingPortalSession{
  620. URL: "https://billing.stripe.com/blablabla",
  621. }, nil)
  622. // Create user
  623. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  624. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  625. StripeCustomerID: "acct_123",
  626. StripeSubscriptionID: "sub_123",
  627. }))
  628. // Create portal session
  629. rr := request(t, s, "POST", "/v1/account/billing/portal", "", map[string]string{
  630. "Authorization": util.BasicAuth("phil", "phil"),
  631. })
  632. require.Equal(t, 200, rr.Code)
  633. ps, _ := util.UnmarshalJSON[apiAccountBillingPortalRedirectResponse](io.NopCloser(rr.Body))
  634. require.Equal(t, "https://billing.stripe.com/blablabla", ps.RedirectURL)
  635. }
  636. type testStripeAPI struct {
  637. mock.Mock
  638. }
  639. var _ stripeAPI = (*testStripeAPI)(nil)
  640. func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  641. args := s.Called(params)
  642. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  643. }
  644. func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  645. args := s.Called(params)
  646. return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
  647. }
  648. func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  649. args := s.Called(params)
  650. return args.Get(0).([]*stripe.Price), args.Error(1)
  651. }
  652. func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  653. args := s.Called(id)
  654. return args.Get(0).(*stripe.Customer), args.Error(1)
  655. }
  656. func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  657. args := s.Called(id)
  658. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  659. }
  660. func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  661. args := s.Called(id)
  662. return args.Get(0).(*stripe.Subscription), args.Error(1)
  663. }
  664. func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
  665. args := s.Called(id, params)
  666. return args.Get(0).(*stripe.Customer), args.Error(1)
  667. }
  668. func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  669. args := s.Called(id, params)
  670. return args.Get(0).(*stripe.Subscription), args.Error(1)
  671. }
  672. func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  673. args := s.Called(id)
  674. return args.Get(0).(*stripe.Subscription), args.Error(1)
  675. }
  676. func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  677. args := s.Called(payload, header, secret)
  678. return args.Get(0).(stripe.Event), args.Error(1)
  679. }
  680. func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
  681. var e stripe.Event
  682. if err := json.Unmarshal([]byte(v), &e); err != nil {
  683. t.Fatal(err)
  684. }
  685. return e
  686. }
  // subscriptionUpdatedEventJSON is a minimal "customer.subscription.updated"
  // webhook event: subscription "sub_1234" of customer "acct_5555" is active and
  // carries a single item priced at "price_1234".
  const subscriptionUpdatedEventJSON = `
  {
    "type": "customer.subscription.updated",
    "data": {
      "object": {
        "id": "sub_1234",
        "customer": "acct_5555",
        "status": "active",
        "current_period_end": 1674268231,
        "cancel_at": 1674299999,
        "items": {
          "data": [
            {
              "price": {
                "id": "price_1234"
              }
            }
          ]
        }
      }
    }
  }`
  // subscriptionDeletedEventJSON is a minimal "customer.subscription.deleted"
  // webhook event for subscription "sub_1234" of customer "acct_5555". Apart from
  // the event type, the payload matches the "updated" fixture.
  const subscriptionDeletedEventJSON = `
  {
    "type": "customer.subscription.deleted",
    "data": {
      "object": {
        "id": "sub_1234",
        "customer": "acct_5555",
        "status": "active",
        "current_period_end": 1674268231,
        "cancel_at": 1674299999,
        "items": {
          "data": [
            {
              "price": {
                "id": "price_1234"
              }
            }
          ]
        }
      }
    }
  }`