manager.go 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123
  1. package user
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. _ "github.com/mattn/go-sqlite3" // SQLite driver
  8. "github.com/stripe/stripe-go/v74"
  9. "golang.org/x/crypto/bcrypt"
  10. "heckel.io/ntfy/log"
  11. "heckel.io/ntfy/util"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  16. const (
  17. bcryptCost = 10
  18. intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
  19. userStatsQueueWriterInterval = 33 * time.Second
  20. tokenLength = 32
  21. tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
  22. syncTopicLength = 16
  23. tokenMaxCount = 10 // Only keep this many tokens in the table per user
  24. )
  25. var (
  26. errNoTokenProvided = errors.New("no token provided")
  27. errTopicOwnedByOthers = errors.New("topic owned by others")
  28. errNoRows = errors.New("no rows found")
  29. )
  30. // Manager-related queries
  31. const (
  32. createTablesQueriesNoTx = `
  33. CREATE TABLE IF NOT EXISTS tier (
  34. id INTEGER PRIMARY KEY AUTOINCREMENT,
  35. code TEXT NOT NULL,
  36. name TEXT NOT NULL,
  37. messages_limit INT NOT NULL,
  38. messages_expiry_duration INT NOT NULL,
  39. emails_limit INT NOT NULL,
  40. reservations_limit INT NOT NULL,
  41. attachment_file_size_limit INT NOT NULL,
  42. attachment_total_size_limit INT NOT NULL,
  43. attachment_expiry_duration INT NOT NULL,
  44. stripe_price_id TEXT
  45. );
  46. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  47. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  48. CREATE TABLE IF NOT EXISTS user (
  49. id INTEGER PRIMARY KEY AUTOINCREMENT,
  50. tier_id INT,
  51. user TEXT NOT NULL,
  52. pass TEXT NOT NULL,
  53. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  54. prefs JSON NOT NULL DEFAULT '{}',
  55. sync_topic TEXT NOT NULL,
  56. stats_messages INT NOT NULL DEFAULT (0),
  57. stats_emails INT NOT NULL DEFAULT (0),
  58. stripe_customer_id TEXT,
  59. stripe_subscription_id TEXT,
  60. stripe_subscription_status TEXT,
  61. stripe_subscription_paid_until INT,
  62. stripe_subscription_cancel_at INT,
  63. created_by TEXT NOT NULL,
  64. created_at INT NOT NULL,
  65. FOREIGN KEY (tier_id) REFERENCES tier (id)
  66. );
  67. CREATE UNIQUE INDEX idx_user ON user (user);
  68. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  69. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  70. CREATE TABLE IF NOT EXISTS user_access (
  71. user_id INT NOT NULL,
  72. topic TEXT NOT NULL,
  73. read INT NOT NULL,
  74. write INT NOT NULL,
  75. owner_user_id INT,
  76. PRIMARY KEY (user_id, topic),
  77. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  78. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  79. );
  80. CREATE TABLE IF NOT EXISTS user_token (
  81. user_id INT NOT NULL,
  82. token TEXT NOT NULL,
  83. expires INT NOT NULL,
  84. PRIMARY KEY (user_id, token),
  85. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  86. );
  87. CREATE TABLE IF NOT EXISTS schemaVersion (
  88. id INT PRIMARY KEY,
  89. version INT NOT NULL
  90. );
  91. INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
  92. VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH())
  93. ON CONFLICT (id) DO NOTHING;
  94. `
  95. createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
  96. builtinStartupQueries = `
  97. PRAGMA foreign_keys = ON;
  98. `
  99. selectUserByNameQuery = `
  100. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  101. FROM user u
  102. LEFT JOIN tier t on t.id = u.tier_id
  103. WHERE user = ?
  104. `
  105. selectUserByTokenQuery = `
  106. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  107. FROM user u
  108. JOIN user_token t on u.id = t.user_id
  109. LEFT JOIN tier t on t.id = u.tier_id
  110. WHERE t.token = ? AND t.expires >= ?
  111. `
  112. selectUserByStripeCustomerIDQuery = `
  113. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  114. FROM user u
  115. LEFT JOIN tier t on t.id = u.tier_id
  116. WHERE u.stripe_customer_id = ?
  117. `
  118. selectTopicPermsQuery = `
  119. SELECT read, write
  120. FROM user_access a
  121. JOIN user u ON u.id = a.user_id
  122. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  123. ORDER BY u.user DESC
  124. `
  125. insertUserQuery = `
  126. INSERT INTO user (user, pass, role, sync_topic, created_by, created_at)
  127. VALUES (?, ?, ?, ?, ?, ?)
  128. `
  129. selectUsernamesQuery = `
  130. SELECT user
  131. FROM user
  132. ORDER BY
  133. CASE role
  134. WHEN 'admin' THEN 1
  135. WHEN 'anonymous' THEN 3
  136. ELSE 2
  137. END, user
  138. `
  139. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  140. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  141. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
  142. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE user = ?`
  143. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  144. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  145. upsertUserAccessQuery = `
  146. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  147. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  148. ON CONFLICT (user_id, topic)
  149. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  150. `
  151. selectUserAccessQuery = `
  152. SELECT topic, read, write
  153. FROM user_access
  154. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  155. ORDER BY write DESC, read DESC, topic
  156. `
  157. selectUserReservationsQuery = `
  158. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  159. FROM user_access a_user
  160. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  161. WHERE a_user.user_id = a_user.owner_user_id
  162. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  163. ORDER BY a_user.topic
  164. `
  165. selectUserReservationsCountQuery = `
  166. SELECT COUNT(*)
  167. FROM user_access
  168. WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  169. `
  170. selectUserHasReservationQuery = `
  171. SELECT COUNT(*)
  172. FROM user_access
  173. WHERE user_id = owner_user_id
  174. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  175. AND topic = ?
  176. `
  177. selectOtherAccessCountQuery = `
  178. SELECT COUNT(*)
  179. FROM user_access
  180. WHERE (topic = ? OR ? LIKE topic)
  181. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  182. `
  183. deleteAllAccessQuery = `DELETE FROM user_access`
  184. deleteUserAccessQuery = `
  185. DELETE FROM user_access
  186. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  187. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  188. `
  189. deleteTopicAccessQuery = `
  190. DELETE FROM user_access
  191. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  192. AND topic = ?
  193. `
  194. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)`
  195. insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
  196. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  197. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  198. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
  199. deleteExcessTokensQuery = `
  200. DELETE FROM user_token
  201. WHERE (user_id, token) NOT IN (
  202. SELECT user_id, token
  203. FROM user_token
  204. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  205. ORDER BY expires DESC
  206. LIMIT ?
  207. )
  208. `
  209. insertTierQuery = `
  210. INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
  211. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  212. `
  213. selectTiersQuery = `
  214. SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  215. FROM tier
  216. `
  217. selectTierByCodeQuery = `
  218. SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  219. FROM tier
  220. WHERE code = ?
  221. `
  222. selectTierByPriceIDQuery = `
  223. SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  224. FROM tier
  225. WHERE stripe_price_id = ?
  226. `
  227. updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
  228. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  229. updateBillingQuery = `
  230. UPDATE user
  231. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  232. WHERE user = ?
  233. `
  234. )
// Schema management queries
const (
	// currentSchemaVersion is the newest schema version this code knows about; the
	// "1 -> 2" migration queries below bring a version 1 database up to it.
	currentSchemaVersion = 2
	// The schemaVersion table is keyed by id = 1; these queries only ever touch that row.
	insertSchemaVersion      = `INSERT INTO schemaVersion VALUES (1, ?)`
	updateSchemaVersion      = `UPDATE schemaVersion SET version = ? WHERE id = 1`
	selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
	// 1 -> 2 (complex migration!)
	// Step 1: move the v1 user table out of the way as user_old.
	// NOTE(review): the new tables are presumably created by the migration code (not shown)
	// between this step and the next one.
	migrate1To2RenameUserTableQueryNoTx = `
ALTER TABLE user RENAME TO user_old;
`
	// Step 2: copy all users from user_old (with an empty sync_topic, created_by 'admin' and
	// created_at = now), copy the v1 access rows by joining on the username, then drop
	// both old tables.
	migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user (user, pass, role, sync_topic, created_by, created_at)
SELECT user, pass, role, '', 'admin', UNIXEPOCH() FROM user_old;
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
JOIN access a ON u.user = a.user;
DROP TABLE access;
DROP TABLE user_old;
`
	// Step 3: list all user IDs and update sync_topic per user. Presumably this is used to
	// give each migrated user its own sync topic, since step 2 leaves it empty.
	migrate1To2SelectAllUsersIDsNoTx = `SELECT id FROM user`
	migrate1To2UpdateSyncTopicNoTx   = `UPDATE user SET sync_topic = ? WHERE id = ?`
)
  258. // Manager is an implementation of Manager. It stores users and access control list
  259. // in a SQLite database.
  260. type Manager struct {
  261. db *sql.DB
  262. defaultAccess Permission // Default permission if no ACL matches
  263. statsQueue map[string]*User // Username -> User, for "unimportant" user updates
  264. mu sync.Mutex
  265. }
  266. var _ Auther = (*Manager)(nil)
  267. // NewManager creates a new Manager instance
  268. func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
  269. return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
  270. }
  271. // NewManager creates a new Manager instance
  272. func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
  273. db, err := sql.Open("sqlite3", filename)
  274. if err != nil {
  275. return nil, err
  276. }
  277. if err := setupDB(db); err != nil {
  278. return nil, err
  279. }
  280. if err := runStartupQueries(db, startupQueries); err != nil {
  281. return nil, err
  282. }
  283. manager := &Manager{
  284. db: db,
  285. defaultAccess: defaultAccess,
  286. statsQueue: make(map[string]*User),
  287. }
  288. go manager.userStatsQueueWriter(statsWriterInterval)
  289. return manager, nil
  290. }
  291. // Authenticate checks username and password and returns a User if correct. The method
  292. // returns in constant-ish time, regardless of whether the user exists or the password is
  293. // correct or incorrect.
  294. func (a *Manager) Authenticate(username, password string) (*User, error) {
  295. if username == Everyone {
  296. return nil, ErrUnauthenticated
  297. }
  298. user, err := a.User(username)
  299. if err != nil {
  300. log.Trace("authentication of user %s failed (1): %s", username, err.Error())
  301. bcrypt.CompareHashAndPassword([]byte(intentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  302. return nil, ErrUnauthenticated
  303. }
  304. if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  305. log.Trace("authentication of user %s failed (2): %s", username, err.Error())
  306. return nil, ErrUnauthenticated
  307. }
  308. return user, nil
  309. }
  310. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  311. // The method sets the User.Token value to the token that was used for authentication.
  312. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  313. if len(token) != tokenLength {
  314. return nil, ErrUnauthenticated
  315. }
  316. user, err := a.userByToken(token)
  317. if err != nil {
  318. return nil, ErrUnauthenticated
  319. }
  320. user.Token = token
  321. return user, nil
  322. }
  323. // CreateToken generates a random token for the given user and returns it. The token expires
  324. // after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
  325. // given user, if there are too many of them.
  326. func (a *Manager) CreateToken(user *User) (*Token, error) {
  327. token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration)
  328. tx, err := a.db.Begin()
  329. if err != nil {
  330. return nil, err
  331. }
  332. defer tx.Rollback()
  333. if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
  334. return nil, err
  335. }
  336. rows, err := tx.Query(selectTokenCountQuery, user.Name)
  337. if err != nil {
  338. return nil, err
  339. }
  340. defer rows.Close()
  341. if !rows.Next() {
  342. return nil, errNoRows
  343. }
  344. var tokenCount int
  345. if err := rows.Scan(&tokenCount); err != nil {
  346. return nil, err
  347. }
  348. if tokenCount >= tokenMaxCount {
  349. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  350. // on two indices, whereas the query below is a full table scan.
  351. if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil {
  352. return nil, err
  353. }
  354. }
  355. if err := tx.Commit(); err != nil {
  356. return nil, err
  357. }
  358. return &Token{
  359. Value: token,
  360. Expires: expires,
  361. }, nil
  362. }
  363. // ExtendToken sets the new expiry date for a token, thereby extending its use further into the future.
  364. func (a *Manager) ExtendToken(user *User) (*Token, error) {
  365. if user.Token == "" {
  366. return nil, errNoTokenProvided
  367. }
  368. newExpires := time.Now().Add(tokenExpiryDuration)
  369. if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
  370. return nil, err
  371. }
  372. return &Token{
  373. Value: user.Token,
  374. Expires: newExpires,
  375. }, nil
  376. }
  377. // RemoveToken deletes the token defined in User.Token
  378. func (a *Manager) RemoveToken(user *User) error {
  379. if user.Token == "" {
  380. return ErrUnauthorized
  381. }
  382. if _, err := a.db.Exec(deleteTokenQuery, user.Name, user.Token); err != nil {
  383. return err
  384. }
  385. return nil
  386. }
  387. // RemoveExpiredTokens deletes all expired tokens from the database
  388. func (a *Manager) RemoveExpiredTokens() error {
  389. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  390. return err
  391. }
  392. return nil
  393. }
  394. // ChangeSettings persists the user settings
  395. func (a *Manager) ChangeSettings(user *User) error {
  396. prefs, err := json.Marshal(user.Prefs)
  397. if err != nil {
  398. return err
  399. }
  400. if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
  401. return err
  402. }
  403. return nil
  404. }
  405. // ResetStats resets all user stats in the user database. This touches all users.
  406. func (a *Manager) ResetStats() error {
  407. a.mu.Lock()
  408. defer a.mu.Unlock()
  409. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  410. return err
  411. }
  412. a.statsQueue = make(map[string]*User)
  413. return nil
  414. }
  415. // EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  416. // batches at a regular interval
  417. func (a *Manager) EnqueueStats(user *User) {
  418. a.mu.Lock()
  419. defer a.mu.Unlock()
  420. a.statsQueue[user.Name] = user
  421. }
  422. func (a *Manager) userStatsQueueWriter(interval time.Duration) {
  423. ticker := time.NewTicker(interval)
  424. for range ticker.C {
  425. if err := a.writeUserStatsQueue(); err != nil {
  426. log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
  427. }
  428. }
  429. }
  430. func (a *Manager) writeUserStatsQueue() error {
  431. a.mu.Lock()
  432. if len(a.statsQueue) == 0 {
  433. a.mu.Unlock()
  434. log.Trace("User Manager: No user stats updates to commit")
  435. return nil
  436. }
  437. statsQueue := a.statsQueue
  438. a.statsQueue = make(map[string]*User)
  439. a.mu.Unlock()
  440. tx, err := a.db.Begin()
  441. if err != nil {
  442. return err
  443. }
  444. defer tx.Rollback()
  445. log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
  446. for username, u := range statsQueue {
  447. log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails)
  448. if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil {
  449. return err
  450. }
  451. }
  452. return tx.Commit()
  453. }
  454. // Authorize returns nil if the given user has access to the given topic using the desired
  455. // permission. The user param may be nil to signal an anonymous user.
  456. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  457. if user != nil && user.Role == RoleAdmin {
  458. return nil // Admin can do everything
  459. }
  460. username := Everyone
  461. if user != nil {
  462. username = user.Name
  463. }
  464. // Select the read/write permissions for this user/topic combo. The query may return two
  465. // rows (one for everyone, and one for the user), but prioritizes the user.
  466. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  467. if err != nil {
  468. return err
  469. }
  470. defer rows.Close()
  471. if !rows.Next() {
  472. return a.resolvePerms(a.defaultAccess, perm)
  473. }
  474. var read, write bool
  475. if err := rows.Scan(&read, &write); err != nil {
  476. return err
  477. } else if err := rows.Err(); err != nil {
  478. return err
  479. }
  480. return a.resolvePerms(NewPermission(read, write), perm)
  481. }
  482. func (a *Manager) resolvePerms(base, perm Permission) error {
  483. if perm == PermissionRead && base.IsRead() {
  484. return nil
  485. } else if perm == PermissionWrite && base.IsWrite() {
  486. return nil
  487. }
  488. return ErrUnauthorized
  489. }
  490. // AddUser adds a user with the given username, password and role
  491. func (a *Manager) AddUser(username, password string, role Role, createdBy string) error {
  492. if !AllowedUsername(username) || !AllowedRole(role) {
  493. return ErrInvalidArgument
  494. }
  495. hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
  496. if err != nil {
  497. return err
  498. }
  499. syncTopic, now := util.RandomString(syncTopicLength), time.Now().Unix()
  500. if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now); err != nil {
  501. return err
  502. }
  503. return nil
  504. }
  505. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  506. // if the user did not exist in the first place.
  507. func (a *Manager) RemoveUser(username string) error {
  508. if !AllowedUsername(username) {
  509. return ErrInvalidArgument
  510. }
  511. // Rows in user_access, user_token, etc. are deleted via foreign keys
  512. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  513. return err
  514. }
  515. return nil
  516. }
  517. // Users returns a list of users. It always also returns the Everyone user ("*").
  518. func (a *Manager) Users() ([]*User, error) {
  519. rows, err := a.db.Query(selectUsernamesQuery)
  520. if err != nil {
  521. return nil, err
  522. }
  523. defer rows.Close()
  524. usernames := make([]string, 0)
  525. for rows.Next() {
  526. var username string
  527. if err := rows.Scan(&username); err != nil {
  528. return nil, err
  529. } else if err := rows.Err(); err != nil {
  530. return nil, err
  531. }
  532. usernames = append(usernames, username)
  533. }
  534. rows.Close()
  535. users := make([]*User, 0)
  536. for _, username := range usernames {
  537. user, err := a.User(username)
  538. if err != nil {
  539. return nil, err
  540. }
  541. users = append(users, user)
  542. }
  543. return users, nil
  544. }
  545. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  546. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  547. func (a *Manager) User(username string) (*User, error) {
  548. rows, err := a.db.Query(selectUserByNameQuery, username)
  549. if err != nil {
  550. return nil, err
  551. }
  552. return a.readUser(rows)
  553. }
  554. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  555. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  556. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  557. if err != nil {
  558. return nil, err
  559. }
  560. return a.readUser(rows)
  561. }
  562. func (a *Manager) userByToken(token string) (*User, error) {
  563. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  564. if err != nil {
  565. return nil, err
  566. }
  567. return a.readUser(rows)
  568. }
  569. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  570. defer rows.Close()
  571. var username, hash, role, prefs, syncTopic string
  572. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
  573. var messages, emails int64
  574. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
  575. if !rows.Next() {
  576. return nil, ErrUserNotFound
  577. }
  578. if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
  579. return nil, err
  580. } else if err := rows.Err(); err != nil {
  581. return nil, err
  582. }
  583. user := &User{
  584. Name: username,
  585. Hash: hash,
  586. Role: Role(role),
  587. Prefs: &Prefs{},
  588. SyncTopic: syncTopic,
  589. Stats: &Stats{
  590. Messages: messages,
  591. Emails: emails,
  592. },
  593. Billing: &Billing{
  594. StripeCustomerID: stripeCustomerID.String, // May be empty
  595. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  596. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  597. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  598. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  599. },
  600. }
  601. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  602. return nil, err
  603. }
  604. if tierCode.Valid {
  605. // See readTier() when this is changed!
  606. user.Tier = &Tier{
  607. Code: tierCode.String,
  608. Name: tierName.String,
  609. Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
  610. MessagesLimit: messagesLimit.Int64,
  611. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  612. EmailsLimit: emailsLimit.Int64,
  613. ReservationsLimit: reservationsLimit.Int64,
  614. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  615. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  616. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  617. StripePriceID: stripePriceID.String, // May be empty
  618. }
  619. }
  620. return user, nil
  621. }
  622. // Grants returns all user-specific access control entries
  623. func (a *Manager) Grants(username string) ([]Grant, error) {
  624. rows, err := a.db.Query(selectUserAccessQuery, username)
  625. if err != nil {
  626. return nil, err
  627. }
  628. defer rows.Close()
  629. grants := make([]Grant, 0)
  630. for rows.Next() {
  631. var topic string
  632. var read, write bool
  633. if err := rows.Scan(&topic, &read, &write); err != nil {
  634. return nil, err
  635. } else if err := rows.Err(); err != nil {
  636. return nil, err
  637. }
  638. grants = append(grants, Grant{
  639. TopicPattern: fromSQLWildcard(topic),
  640. Allow: NewPermission(read, write),
  641. })
  642. }
  643. return grants, nil
  644. }
  645. // Reservations returns all user-owned topics, and the associated everyone-access
  646. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  647. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  648. if err != nil {
  649. return nil, err
  650. }
  651. defer rows.Close()
  652. reservations := make([]Reservation, 0)
  653. for rows.Next() {
  654. var topic string
  655. var ownerRead, ownerWrite bool
  656. var everyoneRead, everyoneWrite sql.NullBool
  657. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  658. return nil, err
  659. } else if err := rows.Err(); err != nil {
  660. return nil, err
  661. }
  662. reservations = append(reservations, Reservation{
  663. Topic: topic,
  664. Owner: NewPermission(ownerRead, ownerWrite),
  665. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  666. })
  667. }
  668. return reservations, nil
  669. }
  670. // HasReservation returns true if the given topic access is owned by the user
  671. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  672. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  673. if err != nil {
  674. return false, err
  675. }
  676. defer rows.Close()
  677. if !rows.Next() {
  678. return false, errNoRows
  679. }
  680. var count int64
  681. if err := rows.Scan(&count); err != nil {
  682. return false, err
  683. }
  684. return count > 0, nil
  685. }
  686. // ReservationsCount returns the number of reservations owned by this user
  687. func (a *Manager) ReservationsCount(username string) (int64, error) {
  688. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  689. if err != nil {
  690. return 0, err
  691. }
  692. defer rows.Close()
  693. if !rows.Next() {
  694. return 0, errNoRows
  695. }
  696. var count int64
  697. if err := rows.Scan(&count); err != nil {
  698. return 0, err
  699. }
  700. return count, nil
  701. }
  702. // ChangePassword changes a user's password
  703. func (a *Manager) ChangePassword(username, password string) error {
  704. hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
  705. if err != nil {
  706. return err
  707. }
  708. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  709. return err
  710. }
  711. return nil
  712. }
  713. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  714. // all existing access control entries (Grant) are removed, since they are no longer needed.
  715. func (a *Manager) ChangeRole(username string, role Role) error {
  716. if !AllowedUsername(username) || !AllowedRole(role) {
  717. return ErrInvalidArgument
  718. }
  719. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  720. return err
  721. }
  722. if role == RoleAdmin {
  723. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  724. return err
  725. }
  726. }
  727. return nil
  728. }
  729. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  730. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  731. func (a *Manager) ChangeTier(username, tier string) error {
  732. if !AllowedUsername(username) {
  733. return ErrInvalidArgument
  734. }
  735. t, err := a.Tier(tier)
  736. if err != nil {
  737. return err
  738. } else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil {
  739. return err
  740. }
  741. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  742. return err
  743. }
  744. return nil
  745. }
  746. // ResetTier removes the tier from the given user
  747. func (a *Manager) ResetTier(username string) error {
  748. if !AllowedUsername(username) && username != Everyone && username != "" {
  749. return ErrInvalidArgument
  750. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  751. return err
  752. }
  753. _, err := a.db.Exec(deleteUserTierQuery, username)
  754. return err
  755. }
  756. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  757. u, err := a.User(username)
  758. if err != nil {
  759. return err
  760. }
  761. if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit {
  762. reservations, err := a.Reservations(username)
  763. if err != nil {
  764. return err
  765. } else if int64(len(reservations)) > reservationsLimit {
  766. return ErrTooManyReservations
  767. }
  768. }
  769. return nil
  770. }
  771. // CheckAllowAccess tests if a user may create an access control entry for the given topic.
  772. // If there are any ACL entries that are not owned by the user, an error is returned.
  773. func (a *Manager) CheckAllowAccess(username string, topic string) error {
  774. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  775. return ErrInvalidArgument
  776. }
  777. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  778. if err != nil {
  779. return err
  780. }
  781. defer rows.Close()
  782. if !rows.Next() {
  783. return errNoRows
  784. }
  785. var otherCount int
  786. if err := rows.Scan(&otherCount); err != nil {
  787. return err
  788. }
  789. if otherCount > 0 {
  790. return errTopicOwnedByOthers
  791. }
  792. return nil
  793. }
  794. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  795. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  796. // owner may either be a user (username), or the system (empty).
  797. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  798. if !AllowedUsername(username) && username != Everyone {
  799. return ErrInvalidArgument
  800. } else if !AllowedTopicPattern(topicPattern) {
  801. return ErrInvalidArgument
  802. }
  803. owner := ""
  804. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
  805. return err
  806. }
  807. return nil
  808. }
  809. func (a *Manager) ReserveAccess(username string, topic string, everyone Permission) error {
  810. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  811. return ErrInvalidArgument
  812. }
  813. tx, err := a.db.Begin()
  814. if err != nil {
  815. return err
  816. }
  817. defer tx.Rollback()
  818. if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
  819. return err
  820. }
  821. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
  822. return err
  823. }
  824. return tx.Commit()
  825. }
  826. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  827. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  828. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  829. if !AllowedUsername(username) && username != Everyone && username != "" {
  830. return ErrInvalidArgument
  831. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  832. return ErrInvalidArgument
  833. }
  834. if username == "" && topicPattern == "" {
  835. _, err := a.db.Exec(deleteAllAccessQuery, username)
  836. return err
  837. } else if topicPattern == "" {
  838. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  839. return err
  840. }
  841. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  842. return err
  843. }
  844. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  845. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  846. return ErrInvalidArgument
  847. }
  848. for _, topic := range topics {
  849. if !AllowedTopic(topic) {
  850. return ErrInvalidArgument
  851. }
  852. }
  853. tx, err := a.db.Begin()
  854. if err != nil {
  855. return err
  856. }
  857. defer tx.Rollback()
  858. for _, topic := range topics {
  859. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
  860. return err
  861. }
  862. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
  863. return err
  864. }
  865. }
  866. return tx.Commit()
  867. }
  868. // DefaultAccess returns the default read/write access if no access control entry matches
  869. func (a *Manager) DefaultAccess() Permission {
  870. return a.defaultAccess
  871. }
  872. // CreateTier creates a new tier in the database
  873. func (a *Manager) CreateTier(tier *Tier) error {
  874. if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil {
  875. return err
  876. }
  877. return nil
  878. }
  879. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  880. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  881. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  882. return err
  883. }
  884. return nil
  885. }
  886. // Tiers returns a list of all Tier structs
  887. func (a *Manager) Tiers() ([]*Tier, error) {
  888. rows, err := a.db.Query(selectTiersQuery)
  889. if err != nil {
  890. return nil, err
  891. }
  892. defer rows.Close()
  893. tiers := make([]*Tier, 0)
  894. for {
  895. tier, err := a.readTier(rows)
  896. if err == ErrTierNotFound {
  897. break
  898. } else if err != nil {
  899. return nil, err
  900. }
  901. tiers = append(tiers, tier)
  902. }
  903. return tiers, nil
  904. }
  905. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  906. func (a *Manager) Tier(code string) (*Tier, error) {
  907. rows, err := a.db.Query(selectTierByCodeQuery, code)
  908. if err != nil {
  909. return nil, err
  910. }
  911. defer rows.Close()
  912. return a.readTier(rows)
  913. }
  914. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  915. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  916. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
  917. if err != nil {
  918. return nil, err
  919. }
  920. defer rows.Close()
  921. return a.readTier(rows)
  922. }
  923. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  924. var code, name string
  925. var stripePriceID sql.NullString
  926. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
  927. if !rows.Next() {
  928. return nil, ErrTierNotFound
  929. }
  930. if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
  931. return nil, err
  932. } else if err := rows.Err(); err != nil {
  933. return nil, err
  934. }
  935. // When changed, note readUser() as well
  936. return &Tier{
  937. Code: code,
  938. Name: name,
  939. Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
  940. MessagesLimit: messagesLimit.Int64,
  941. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  942. EmailsLimit: emailsLimit.Int64,
  943. ReservationsLimit: reservationsLimit.Int64,
  944. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  945. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  946. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  947. StripePriceID: stripePriceID.String, // May be empty
  948. }, nil
  949. }
  // toSQLWildcard converts a user-facing topic pattern to SQL wildcard syntax by replacing every
// '*' with '%'. No other characters are touched; in particular '_' and '%' are not escaped.
func toSQLWildcard(s string) string {
	const userWildcard, sqlWildcard = "*", "%"
	return strings.ReplaceAll(s, userWildcard, sqlWildcard)
}
  // fromSQLWildcard converts a stored SQL wildcard pattern back to the user-facing syntax by
// replacing every '%' with '*'. All other characters are left as they are.
func fromSQLWildcard(s string) string {
	const sqlWildcard, userWildcard = "%", "*"
	return strings.ReplaceAll(s, sqlWildcard, userWildcard)
}
  956. func runStartupQueries(db *sql.DB, startupQueries string) error {
  957. if _, err := db.Exec(startupQueries); err != nil {
  958. return err
  959. }
  960. if _, err := db.Exec(builtinStartupQueries); err != nil {
  961. return err
  962. }
  963. return nil
  964. }
  965. func setupDB(db *sql.DB) error {
  966. // If 'schemaVersion' table does not exist, this must be a new database
  967. rowsSV, err := db.Query(selectSchemaVersionQuery)
  968. if err != nil {
  969. return setupNewDB(db)
  970. }
  971. defer rowsSV.Close()
  972. // If 'schemaVersion' table exists, read version and potentially upgrade
  973. schemaVersion := 0
  974. if !rowsSV.Next() {
  975. return errors.New("cannot determine schema version: database file may be corrupt")
  976. }
  977. if err := rowsSV.Scan(&schemaVersion); err != nil {
  978. return err
  979. }
  980. rowsSV.Close()
  981. // Do migrations
  982. if schemaVersion == currentSchemaVersion {
  983. return nil
  984. } else if schemaVersion == 1 {
  985. return migrateFrom1(db)
  986. }
  987. return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
  988. }
  989. func setupNewDB(db *sql.DB) error {
  990. if _, err := db.Exec(createTablesQueries); err != nil {
  991. return err
  992. }
  993. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  994. return err
  995. }
  996. return nil
  997. }
  998. func migrateFrom1(db *sql.DB) error {
  999. log.Info("Migrating user database schema: from 1 to 2")
  1000. tx, err := db.Begin()
  1001. if err != nil {
  1002. return err
  1003. }
  1004. defer tx.Rollback()
  1005. if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
  1006. return err
  1007. }
  1008. if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
  1009. return err
  1010. }
  1011. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1012. return err
  1013. }
  1014. rows, err := tx.Query(migrate1To2SelectAllUsersIDsNoTx)
  1015. if err != nil {
  1016. return err
  1017. }
  1018. defer rows.Close()
  1019. syncTopics := make(map[int]string)
  1020. for rows.Next() {
  1021. var userID int
  1022. if err := rows.Scan(&userID); err != nil {
  1023. return err
  1024. }
  1025. syncTopics[userID] = util.RandomString(syncTopicLength)
  1026. }
  1027. if err := rows.Close(); err != nil {
  1028. return err
  1029. }
  1030. for userID, syncTopic := range syncTopics {
  1031. if _, err := tx.Exec(migrate1To2UpdateSyncTopicNoTx, syncTopic, userID); err != nil {
  1032. return err
  1033. }
  1034. }
  1035. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1036. return err
  1037. }
  1038. if err := tx.Commit(); err != nil {
  1039. return err
  1040. }
  1041. return nil // Update this when a new version is added
  1042. }
  // nullString converts s to a sql.NullString: the empty string becomes SQL NULL (the zero
// NullString, Valid == false); any other value becomes a valid NullString holding s.
func nullString(s string) sql.NullString {
	return sql.NullString{String: s, Valid: s != ""}
}
  // nullInt64 converts v to a sql.NullInt64: 0 becomes SQL NULL (the zero NullInt64,
// Valid == false); any other value, including negatives, becomes a valid NullInt64 holding v.
func nullInt64(v int64) sql.NullInt64 {
	return sql.NullInt64{Int64: v, Valid: v != 0}
}