webpush_store.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. package server
  2. import (
  3. "database/sql"
  4. "errors"
  5. "heckel.io/ntfy/v2/util"
  6. "net/netip"
  7. "time"
  8. _ "github.com/mattn/go-sqlite3" // SQLite driver
  9. )
  10. const (
  11. subscriptionIDPrefix = "wps_"
  12. subscriptionIDLength = 10
  13. subscriptionEndpointLimitPerSubscriberIP = 10
  14. )
  15. var (
  16. errWebPushNoRows = errors.New("no rows found")
  17. errWebPushTooManySubscriptions = errors.New("too many subscriptions")
  18. errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
  19. )
// SQL statements used by webPushStore. Each push endpoint is one row in the subscription table; the topics a
// subscription is registered for are rows in subscription_topic. The updated_at and warned_at timestamps are
// written as Unix seconds by the callers in this file.
const (
	// createWebPushSubscriptionsTableQuery creates the whole schema for a new database inside one explicit
	// transaction. subscription_topic references subscription(id) with ON DELETE CASCADE; SQLite only enforces
	// foreign keys on connections where PRAGMA foreign_keys is enabled (see builtinStartupQueries).
	createWebPushSubscriptionsTableQuery = `
		BEGIN;
		CREATE TABLE IF NOT EXISTS subscription (
			id TEXT PRIMARY KEY,
			endpoint TEXT NOT NULL,
			key_auth TEXT NOT NULL,
			key_p256dh TEXT NOT NULL,
			user_id TEXT NOT NULL,
			subscriber_ip TEXT NOT NULL,
			updated_at INT NOT NULL,
			warned_at INT NOT NULL DEFAULT 0
		);
		CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
		CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
		CREATE TABLE IF NOT EXISTS subscription_topic (
			subscription_id TEXT NOT NULL,
			topic TEXT NOT NULL,
			PRIMARY KEY (subscription_id, topic),
			FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
		);
		CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
		CREATE TABLE IF NOT EXISTS schemaVersion (
			id INT PRIMARY KEY,
			version INT NOT NULL
		);
		COMMIT;
	`
	// builtinStartupQueries is executed on every start, after the user-supplied startup queries.
	// NOTE(review): PRAGMA foreign_keys is a per-connection setting in SQLite, so it only takes effect on the
	// pooled connection that happens to run it — confirm whether other connections need it too.
	builtinStartupQueries = `
		PRAGMA foreign_keys = ON;
	`
	// selectWebPushSubscriptionIDByEndpoint looks up the subscription ID for an endpoint (endpoints are unique
	// via idx_endpoint).
	selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
	// selectWebPushSubscriptionCountBySubscriberIP counts the subscriptions registered from one subscriber IP.
	selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
	// selectWebPushSubscriptionsForTopicQuery returns the subscriptions registered for a topic, ordered by
	// endpoint. The column order is the one subscriptionsFromRows scans.
	selectWebPushSubscriptionsForTopicQuery = `
		SELECT id, endpoint, key_auth, key_p256dh, user_id
		FROM subscription_topic st
		JOIN subscription s ON s.id = st.subscription_id
		WHERE st.topic = ?
		ORDER BY endpoint
	`
	// selectWebPushSubscriptionsExpiringSoonQuery returns subscriptions that have not been warned yet
	// (warned_at = 0) and whose updated_at is at or before the given cutoff. Same column order as above.
	selectWebPushSubscriptionsExpiringSoonQuery = `
		SELECT id, endpoint, key_auth, key_p256dh, user_id
		FROM subscription
		WHERE warned_at = 0 AND updated_at <= ?
	`
	// insertWebPushSubscriptionQuery upserts a subscription keyed on endpoint. On conflict every column except
	// id is overwritten, so an existing endpoint keeps its original subscription ID.
	insertWebPushSubscriptionQuery = `
		INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT (endpoint)
		DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
	`
	// updateWebPushSubscriptionWarningSentQuery sets warned_at (first parameter) for one subscription ID.
	updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
	// The delete queries below rely on ON DELETE CASCADE to remove the matching subscription_topic rows.
	deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
	deleteWebPushSubscriptionByUserIDQuery   = `DELETE FROM subscription WHERE user_id = ?`
	deleteWebPushSubscriptionByAgeQuery      = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan! (no index on updated_at)
	// insertWebPushSubscriptionTopicQuery links a subscription to one topic; (subscription_id, topic) is the
	// primary key, so inserting the same pair twice fails.
	insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
	// deleteWebPushSubscriptionTopicAllQuery removes every topic link of one subscription.
	deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
)
// Schema management queries. The schema version lives in the schemaVersion table in the row with id = 1, the only
// row this code writes or reads.
const (
	// currentWebPushSchemaVersion is the version recorded when a new database is created.
	// NOTE(review): nothing in this file compares the stored version against this value on startup.
	currentWebPushSchemaVersion = 1
	// insertWebPushSchemaVersion stores the schema version (single parameter) in row id = 1.
	insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
	// selectWebPushSchemaVersionQuery reads the stored schema version; it fails if the table does not exist.
	selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
// webPushStore persists Web Push subscriptions and their topics in a SQLite database (see newWebPushStore).
type webPushStore struct {
	db *sql.DB // handle opened with the "sqlite3" driver
}
  87. func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
  88. db, err := sql.Open("sqlite3", filename)
  89. if err != nil {
  90. return nil, err
  91. }
  92. if err := setupWebPushDB(db); err != nil {
  93. return nil, err
  94. }
  95. if err := runWebPushStartupQueries(db, startupQueries); err != nil {
  96. return nil, err
  97. }
  98. return &webPushStore{
  99. db: db,
  100. }, nil
  101. }
  102. func setupWebPushDB(db *sql.DB) error {
  103. // If 'schemaVersion' table does not exist, this must be a new database
  104. rows, err := db.Query(selectWebPushSchemaVersionQuery)
  105. if err != nil {
  106. return setupNewWebPushDB(db)
  107. }
  108. return rows.Close()
  109. }
  110. func setupNewWebPushDB(db *sql.DB) error {
  111. if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
  112. return err
  113. }
  114. if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
  115. return err
  116. }
  117. return nil
  118. }
  119. func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
  120. if _, err := db.Exec(startupQueries); err != nil {
  121. return err
  122. }
  123. if _, err := db.Exec(builtinStartupQueries); err != nil {
  124. return err
  125. }
  126. return nil
  127. }
  128. // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
  129. // existing entries for a given endpoint.
  130. func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
  131. tx, err := c.db.Begin()
  132. if err != nil {
  133. return err
  134. }
  135. defer tx.Rollback()
  136. // Read number of subscriptions for subscriber IP address
  137. rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
  138. if err != nil {
  139. return err
  140. }
  141. defer rowsCount.Close()
  142. var subscriptionCount int
  143. if !rowsCount.Next() {
  144. return errWebPushNoRows
  145. }
  146. if err := rowsCount.Scan(&subscriptionCount); err != nil {
  147. return err
  148. }
  149. if err := rowsCount.Close(); err != nil {
  150. return err
  151. }
  152. // Read existing subscription ID for endpoint (or create new ID)
  153. rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
  154. if err != nil {
  155. return err
  156. }
  157. defer rows.Close()
  158. var subscriptionID string
  159. if rows.Next() {
  160. if err := rows.Scan(&subscriptionID); err != nil {
  161. return err
  162. }
  163. } else {
  164. if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
  165. return errWebPushTooManySubscriptions
  166. }
  167. subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
  168. }
  169. if err := rows.Close(); err != nil {
  170. return err
  171. }
  172. // Insert or update subscription
  173. updatedAt, warnedAt := time.Now().Unix(), 0
  174. if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
  175. return err
  176. }
  177. // Replace all subscription topics
  178. if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
  179. return err
  180. }
  181. for _, topic := range topics {
  182. if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
  183. return err
  184. }
  185. }
  186. return tx.Commit()
  187. }
  188. // SubscriptionsForTopic returns all subscriptions for the given topic
  189. func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
  190. rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
  191. if err != nil {
  192. return nil, err
  193. }
  194. defer rows.Close()
  195. return c.subscriptionsFromRows(rows)
  196. }
  197. // SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
  198. func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
  199. rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
  200. if err != nil {
  201. return nil, err
  202. }
  203. defer rows.Close()
  204. return c.subscriptionsFromRows(rows)
  205. }
  206. // MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
  207. func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
  208. tx, err := c.db.Begin()
  209. if err != nil {
  210. return err
  211. }
  212. defer tx.Rollback()
  213. for _, subscription := range subscriptions {
  214. if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
  215. return err
  216. }
  217. }
  218. return tx.Commit()
  219. }
  220. func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
  221. subscriptions := make([]*webPushSubscription, 0)
  222. for rows.Next() {
  223. var id, endpoint, auth, p256dh, userID string
  224. if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
  225. return nil, err
  226. }
  227. subscriptions = append(subscriptions, &webPushSubscription{
  228. ID: id,
  229. Endpoint: endpoint,
  230. Auth: auth,
  231. P256dh: p256dh,
  232. UserID: userID,
  233. })
  234. }
  235. return subscriptions, nil
  236. }
  237. // RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
  238. func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
  239. _, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
  240. return err
  241. }
  242. // RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
  243. func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
  244. if userID == "" {
  245. return errWebPushUserIDCannotBeEmpty
  246. }
  247. _, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
  248. return err
  249. }
  250. // RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
  251. func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
  252. _, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
  253. return err
  254. }
  255. // Close closes the underlying database connection
  256. func (c *webPushStore) Close() error {
  257. return c.db.Close()
  258. }