server_webpush_test.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package server
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "github.com/stretchr/testify/require"
  6. "heckel.io/ntfy/user"
  7. "heckel.io/ntfy/util"
  8. "io"
  9. "net/http"
  10. "net/http/httptest"
  11. "net/netip"
  12. "strings"
  13. "sync/atomic"
  14. "testing"
  15. "time"
  16. )
  // testWebPushEndpoint is an endpoint on the Mozilla push service host, used as
// the subscription endpoint in tests that go through POST /v1/webpush. The
// server rejects endpoints it does not recognize (see
// TestServer_WebPush_TopicAdd_InvalidEndpoint), so this must stay a known host.
const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
  // TestServer_WebPush_Disabled verifies that POST /v1/webpush answers 404 when
// the server is built from newTestConfig, i.e. without the web push setup that
// newTestConfigWithWebPush provides.
func TestServer_WebPush_Disabled(t *testing.T) {
	srv := newTestServer(t, newTestConfig(t))
	body := payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint)

	rr := request(t, srv, "POST", "/v1/webpush", body, nil)
	require.Equal(t, http.StatusNotFound, rr.Code)
}
  // TestServer_WebPush_TopicAdd verifies that an anonymous POST to /v1/webpush
// stores one subscription for the topic, carrying the submitted endpoint and
// keys and an empty user ID.
func TestServer_WebPush_TopicAdd(t *testing.T) {
	s := newTestServer(t, newTestConfigWithWebPush(t))

	response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
	require.Equal(t, 200, response.Code)
	require.Equal(t, `{"success":true}`+"\n", response.Body.String())

	subs, err := s.webPush.SubscriptionsForTopic("test-topic")
	require.Nil(t, err)
	require.Len(t, subs, 1)
	// testify's Equal takes (expected, actual); keeping that order makes
	// failure messages label the two values correctly.
	require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
	require.Equal(t, "p256dh-key", subs[0].P256dh)
	require.Equal(t, "auth-key", subs[0].Auth)
	require.Equal(t, "", subs[0].UserID)
}
  38. func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
  39. s := newTestServer(t, newTestConfigWithWebPush(t))
  40. response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
  41. require.Equal(t, 400, response.Code)
  42. require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
  43. }
  // TestServer_WebPush_TopicAdd_TooManyTopics verifies that a subscription
// request listing 51 topics is answered with 400 / error code 40040.
// NOTE(review): 51 is presumably one over the server's per-subscription topic
// limit — confirm against the server constant.
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
	srv := newTestServer(t, newTestConfigWithWebPush(t))

	topics := make([]string, 0, 51)
	for len(topics) < cap(topics) {
		topics = append(topics, util.RandomString(5))
	}

	rr := request(t, srv, "POST", "/v1/webpush", payloadForTopics(t, topics, testWebPushEndpoint), nil)
	require.Equal(t, http.StatusBadRequest, rr.Code)
	want := `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}` + "\n"
	require.Equal(t, want, rr.Body.String())
}
  // TestServer_WebPush_TopicUnsubscribe verifies that re-posting an already
// subscribed endpoint with an empty topic list removes its topic subscription.
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
	srv := newTestServer(t, newTestConfigWithWebPush(t))
	addSubscription(t, srv, testWebPushEndpoint, "test-topic")
	requireSubscriptionCount(t, srv, "test-topic", 1)

	// Must be an empty (non-nil) slice so the payload contains "topics": [].
	noTopics := []string{}
	rr := request(t, srv, "POST", "/v1/webpush", payloadForTopics(t, noTopics, testWebPushEndpoint), nil)
	require.Equal(t, http.StatusOK, rr.Code)
	require.Equal(t, `{"success":true}`+"\n", rr.Body.String())

	requireSubscriptionCount(t, srv, "test-topic", 0)
}
  // TestServer_WebPush_Delete verifies that DELETE /v1/webpush with an endpoint
// removes the subscriptions stored for that endpoint.
func TestServer_WebPush_Delete(t *testing.T) {
	srv := newTestServer(t, newTestConfigWithWebPush(t))
	addSubscription(t, srv, testWebPushEndpoint, "test-topic")
	requireSubscriptionCount(t, srv, "test-topic", 1)

	deleteBody := fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint)
	rr := request(t, srv, "DELETE", "/v1/webpush", deleteBody, nil)
	require.Equal(t, http.StatusOK, rr.Code)
	require.Equal(t, `{"success":true}`+"\n", rr.Body.String())

	requireSubscriptionCount(t, srv, "test-topic", 0)
}
  // TestServer_WebPush_TopicSubscribeProtected_Allowed verifies that, with
// deny-all default access, a user granted read-write on a topic may subscribe
// to it, and that the stored subscription carries a user ID with the "u_"
// prefix.
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
	cfg := configureAuth(t, newTestConfigWithWebPush(t))
	cfg.AuthDefault = user.PermissionDenyAll
	srv := newTestServer(t, cfg)
	require.Nil(t, srv.userManager.AddUser("ben", "ben", user.RoleUser))
	require.Nil(t, srv.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))

	body := payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint)
	headers := map[string]string{"Authorization": util.BasicAuth("ben", "ben")}
	rr := request(t, srv, "POST", "/v1/webpush", body, headers)
	require.Equal(t, http.StatusOK, rr.Code)
	require.Equal(t, `{"success":true}`+"\n", rr.Body.String())

	stored, err := srv.webPush.SubscriptionsForTopic("test-topic")
	require.Nil(t, err)
	require.Len(t, stored, 1)
	ownerID := stored[0].UserID
	require.True(t, strings.HasPrefix(ownerID, "u_"))
}
  // TestServer_WebPush_TopicSubscribeProtected_Denied verifies that, with
// deny-all default access, an anonymous subscription request gets 403 and
// nothing is stored for the topic.
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
	cfg := configureAuth(t, newTestConfigWithWebPush(t))
	cfg.AuthDefault = user.PermissionDenyAll
	srv := newTestServer(t, cfg)

	body := payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint)
	rr := request(t, srv, "POST", "/v1/webpush", body, nil)
	require.Equal(t, http.StatusForbidden, rr.Code)

	requireSubscriptionCount(t, srv, "test-topic", 0)
}
  // TestServer_WebPush_DeleteAccountUnsubscribe verifies that deleting a user's
// account (DELETE /v1/account) also removes the web push subscriptions that
// user created.
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
	config := configureAuth(t, newTestConfigWithWebPush(t))
	s := newTestServer(t, config)
	require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
	require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
	authHeaders := map[string]string{
		"Authorization": util.BasicAuth("ben", "ben"),
	}

	response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), authHeaders)
	require.Equal(t, 200, response.Code)
	require.Equal(t, `{"success":true}`+"\n", response.Body.String())
	requireSubscriptionCount(t, s, "test-topic", 1)

	// Check that the deletion itself succeeded; otherwise a failed delete would
	// be misreported below as "subscription not removed".
	response = request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, authHeaders)
	require.Equal(t, 200, response.Code)

	// should've been deleted with the account
	requireSubscriptionCount(t, s, "test-topic", 0)
}
  // TestServer_WebPush_Publish verifies that publishing to a topic sends a web
// push request to the subscribed endpoint. The request must carry
// "Urgency: high" and no Topic header.
//
// The fake push service only records what it received. All assertions run on
// the test goroutine, because require (t.FailNow) must not be called from the
// HTTP handler's goroutine.
func TestServer_WebPush_Publish(t *testing.T) {
	s := newTestServer(t, newTestConfigWithWebPush(t))

	type pushRequest struct {
		readErr error
		path    string
		urgency string
		topic   string
	}
	var received atomic.Pointer[pushRequest]
	pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := io.ReadAll(r.Body)
		received.Store(&pushRequest{
			readErr: err,
			path:    r.URL.Path,
			urgency: r.Header.Get("Urgency"),
			topic:   r.Header.Get("Topic"),
		})
	}))
	defer pushService.Close()

	addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
	response := request(t, s, "POST", "/test-topic", "web push test", nil)
	require.Equal(t, 200, response.Code)

	waitFor(t, func() bool {
		return received.Load() != nil
	})
	got := received.Load()
	require.Nil(t, got.readErr)
	require.Equal(t, "/push-receive", got.path)
	require.Equal(t, "high", got.urgency)
	require.Equal(t, "", got.topic)
}
  // TestServer_WebPush_Publish_RemoveOnError verifies the handling of a 410 Gone
// from the push service. Every subscription sharing that endpoint must be
// removed, not only the one for the topic that was published to.
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
	s := newTestServer(t, newTestConfigWithWebPush(t))

	// The handler only records the body-read error. Asserting inside it would
	// call t.FailNow from a goroutine other than the test's, which is not allowed.
	type pushResult struct{ readErr error }
	var received atomic.Pointer[pushResult]
	pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := io.ReadAll(r.Body)
		w.WriteHeader(http.StatusGone)
		received.Store(&pushResult{readErr: err})
	}))
	defer pushService.Close()

	addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
	requireSubscriptionCount(t, s, "test-topic", 1)
	requireSubscriptionCount(t, s, "test-topic-abc", 1)

	request(t, s, "POST", "/test-topic", "web push test", nil)
	waitFor(t, func() bool {
		return received.Load() != nil
	})
	require.Nil(t, received.Load().readErr)

	// Receiving the 410 should've caused the publisher to expire all
	// subscriptions on the endpoint. `received` is set inside the handler, before
	// the server has even read the 410 response, so the removal can lag behind.
	// Poll instead of asserting immediately.
	waitFor(t, func() bool {
		subs, err := s.webPush.SubscriptionsForTopic("test-topic")
		require.Nil(t, err)
		subsAbc, err := s.webPush.SubscriptionsForTopic("test-topic-abc")
		require.Nil(t, err)
		return len(subs) == 0 && len(subsAbc) == 0
	})
	requireSubscriptionCount(t, s, "test-topic", 0)
	requireSubscriptionCount(t, s, "test-topic-abc", 0)
}
  // TestServer_WebPush_Expiry verifies subscription pruning by backdating each
// subscription's updated_at in the database:
//   - 7 days old: pruning keeps the subscription but sends it a push
//     (presumably an expiry warning; TODO confirm).
//   - 9 days old: pruning removes it.
func TestServer_WebPush_Expiry(t *testing.T) {
	s := newTestServer(t, newTestConfigWithWebPush(t))

	// The handler only records the body-read error. Asserting inside it would
	// call t.FailNow from a goroutine other than the test's, which is not allowed.
	type pushResult struct{ readErr error }
	var received atomic.Pointer[pushResult]
	pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := io.ReadAll(r.Body)
		w.WriteHeader(http.StatusOK)
		received.Store(&pushResult{readErr: err})
	}))
	defer pushService.Close()

	addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
	requireSubscriptionCount(t, s, "test-topic", 1)

	// 7 days old: still stored, but a push is sent to the endpoint.
	_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-7*24*time.Hour).Unix())
	require.Nil(t, err)
	s.pruneAndNotifyWebPushSubscriptions()
	requireSubscriptionCount(t, s, "test-topic", 1)
	waitFor(t, func() bool {
		return received.Load() != nil
	})
	require.Nil(t, received.Load().readErr)

	// 9 days old: the subscription is removed.
	_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-9*24*time.Hour).Unix())
	require.Nil(t, err)
	s.pruneAndNotifyWebPushSubscriptions()
	waitFor(t, func() bool {
		subs, err := s.webPush.SubscriptionsForTopic("test-topic")
		require.Nil(t, err)
		return len(subs) == 0
	})
}
  // payloadForTopics returns a JSON body for POST /v1/webpush. The body subscribes
// endpoint to topics, using the fixed keys "p256dh-key" and "auth-key".
//
// The body is built with json.Marshal rather than string formatting, so any
// endpoint value (including quotes or backslashes) is escaped correctly. A
// nil topics slice encodes as null; an empty slice encodes as [].
func payloadForTopics(t *testing.T, topics []string, endpoint string) string {
	t.Helper()
	payload := struct {
		Topics   []string `json:"topics"`
		Endpoint string   `json:"endpoint"`
		P256dh   string   `json:"p256dh"`
		Auth     string   `json:"auth"`
	}{
		Topics:   topics,
		Endpoint: endpoint,
		P256dh:   "p256dh-key",
		Auth:     "auth-key",
	}
	body, err := json.Marshal(payload)
	require.Nil(t, err)
	return string(body)
}
  191. func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) {
  192. require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", netip.MustParseAddr("1.2.3.4"), topics)) // Test auth and p256dh
  193. }
  // requireSubscriptionCount fails the test unless exactly expectedLength web
// push subscriptions are stored for topic. It is marked as a helper so that
// failures are reported at the caller's line.
func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
	t.Helper()
	subs, err := s.webPush.SubscriptionsForTopic(topic)
	require.Nil(t, err)
	require.Len(t, subs, expectedLength)
}