manager.go 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160
  1. package user
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. _ "github.com/mattn/go-sqlite3" // SQLite driver
  8. "github.com/stripe/stripe-go/v74"
  9. "golang.org/x/crypto/bcrypt"
  10. "heckel.io/ntfy/log"
  11. "heckel.io/ntfy/util"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  16. const (
  17. tierIDPrefix = "ti_"
  18. tierIDLength = 8
  19. syncTopicPrefix = "st_"
  20. syncTopicLength = 16
  21. userIDPrefix = "u_"
  22. userIDLength = 12
  23. userPasswordBcryptCost = 10
  24. userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost
  25. userStatsQueueWriterInterval = 33 * time.Second
  26. tokenPrefix = "tk_"
  27. tokenLength = 32
  28. tokenMaxCount = 10 // Only keep this many tokens in the table per user
  29. tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
  30. )
  31. var (
  32. errNoTokenProvided = errors.New("no token provided")
  33. errTopicOwnedByOthers = errors.New("topic owned by others")
  34. errNoRows = errors.New("no rows found")
  35. )
  36. // Manager-related queries
  37. const (
  38. createTablesQueriesNoTx = `
  39. CREATE TABLE IF NOT EXISTS tier (
  40. id TEXT PRIMARY KEY,
  41. code TEXT NOT NULL,
  42. name TEXT NOT NULL,
  43. messages_limit INT NOT NULL,
  44. messages_expiry_duration INT NOT NULL,
  45. emails_limit INT NOT NULL,
  46. reservations_limit INT NOT NULL,
  47. attachment_file_size_limit INT NOT NULL,
  48. attachment_total_size_limit INT NOT NULL,
  49. attachment_expiry_duration INT NOT NULL,
  50. stripe_price_id TEXT
  51. );
  52. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  53. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  54. CREATE TABLE IF NOT EXISTS user (
  55. id TEXT PRIMARY KEY,
  56. tier_id INT,
  57. user TEXT NOT NULL,
  58. pass TEXT NOT NULL,
  59. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  60. prefs JSON NOT NULL DEFAULT '{}',
  61. sync_topic TEXT NOT NULL,
  62. stats_messages INT NOT NULL DEFAULT (0),
  63. stats_emails INT NOT NULL DEFAULT (0),
  64. stripe_customer_id TEXT,
  65. stripe_subscription_id TEXT,
  66. stripe_subscription_status TEXT,
  67. stripe_subscription_paid_until INT,
  68. stripe_subscription_cancel_at INT,
  69. created_by TEXT NOT NULL,
  70. created_at INT NOT NULL,
  71. FOREIGN KEY (tier_id) REFERENCES tier (id)
  72. );
  73. CREATE UNIQUE INDEX idx_user ON user (user);
  74. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  75. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  76. CREATE TABLE IF NOT EXISTS user_access (
  77. user_id TEXT NOT NULL,
  78. topic TEXT NOT NULL,
  79. read INT NOT NULL,
  80. write INT NOT NULL,
  81. owner_user_id INT,
  82. PRIMARY KEY (user_id, topic),
  83. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  84. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  85. );
  86. CREATE TABLE IF NOT EXISTS user_token (
  87. user_id TEXT NOT NULL,
  88. token TEXT NOT NULL,
  89. expires INT NOT NULL,
  90. PRIMARY KEY (user_id, token),
  91. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  92. );
  93. CREATE TABLE IF NOT EXISTS schemaVersion (
  94. id INT PRIMARY KEY,
  95. version INT NOT NULL
  96. );
  97. INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
  98. VALUES ('u_everyone', '*', '', 'anonymous', '', 'system', UNIXEPOCH())
  99. ON CONFLICT (id) DO NOTHING;
  100. `
  101. createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
  102. builtinStartupQueries = `
  103. PRAGMA foreign_keys = ON;
  104. `
  105. selectUserByIDQuery = `
  106. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  107. FROM user u
  108. LEFT JOIN tier t on t.id = u.tier_id
  109. WHERE u.id = ?
  110. `
  111. selectUserByNameQuery = `
  112. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  113. FROM user u
  114. LEFT JOIN tier t on t.id = u.tier_id
  115. WHERE user = ?
  116. `
  117. selectUserByTokenQuery = `
  118. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  119. FROM user u
  120. JOIN user_token t on u.id = t.user_id
  121. LEFT JOIN tier t on t.id = u.tier_id
  122. WHERE t.token = ? AND t.expires >= ?
  123. `
  124. selectUserByStripeCustomerIDQuery = `
  125. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  126. FROM user u
  127. LEFT JOIN tier t on t.id = u.tier_id
  128. WHERE u.stripe_customer_id = ?
  129. `
  130. selectTopicPermsQuery = `
  131. SELECT read, write
  132. FROM user_access a
  133. JOIN user u ON u.id = a.user_id
  134. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  135. ORDER BY u.user DESC
  136. `
  137. insertUserQuery = `
  138. INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
  139. VALUES (?, ?, ?, ?, ?, ?, ?)
  140. `
  141. selectUsernamesQuery = `
  142. SELECT user
  143. FROM user
  144. ORDER BY
  145. CASE role
  146. WHEN 'admin' THEN 1
  147. WHEN 'anonymous' THEN 3
  148. ELSE 2
  149. END, user
  150. `
  151. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  152. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  153. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
  154. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
  155. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  156. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  157. upsertUserAccessQuery = `
  158. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  159. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  160. ON CONFLICT (user_id, topic)
  161. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  162. `
  163. selectUserAccessQuery = `
  164. SELECT topic, read, write
  165. FROM user_access
  166. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  167. ORDER BY write DESC, read DESC, topic
  168. `
  169. selectUserReservationsQuery = `
  170. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  171. FROM user_access a_user
  172. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  173. WHERE a_user.user_id = a_user.owner_user_id
  174. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  175. ORDER BY a_user.topic
  176. `
  177. selectUserReservationsCountQuery = `
  178. SELECT COUNT(*)
  179. FROM user_access
  180. WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  181. `
  182. selectUserHasReservationQuery = `
  183. SELECT COUNT(*)
  184. FROM user_access
  185. WHERE user_id = owner_user_id
  186. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  187. AND topic = ?
  188. `
  189. selectOtherAccessCountQuery = `
  190. SELECT COUNT(*)
  191. FROM user_access
  192. WHERE (topic = ? OR ? LIKE topic)
  193. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  194. `
  195. deleteAllAccessQuery = `DELETE FROM user_access`
  196. deleteUserAccessQuery = `
  197. DELETE FROM user_access
  198. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  199. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  200. `
  201. deleteTopicAccessQuery = `
  202. DELETE FROM user_access
  203. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  204. AND topic = ?
  205. `
  206. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
  207. insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES (?, ?, ?)`
  208. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  209. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  210. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
  211. deleteExcessTokensQuery = `
  212. DELETE FROM user_token
  213. WHERE (user_id, token) NOT IN (
  214. SELECT user_id, token
  215. FROM user_token
  216. WHERE user_id = ?
  217. ORDER BY expires DESC
  218. LIMIT ?
  219. )
  220. `
  221. insertTierQuery = `
  222. INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
  223. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  224. `
  225. selectTiersQuery = `
  226. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  227. FROM tier
  228. `
  229. selectTierByCodeQuery = `
  230. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  231. FROM tier
  232. WHERE code = ?
  233. `
  234. selectTierByPriceIDQuery = `
  235. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  236. FROM tier
  237. WHERE stripe_price_id = ?
  238. `
  239. updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
  240. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  241. updateBillingQuery = `
  242. UPDATE user
  243. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  244. WHERE user = ?
  245. `
  246. )
  247. // Schema management queries
  248. const (
  249. currentSchemaVersion = 2
  250. insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
  251. updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
  252. selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  253. // 1 -> 2 (complex migration!)
  254. migrate1To2RenameUserTableQueryNoTx = `
  255. ALTER TABLE user RENAME TO user_old;
  256. `
  257. migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
  258. migrate1To2InsertUserNoTx = `
  259. INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
  260. SELECT ?, user, pass, role, ?, 'admin', UNIXEPOCH() FROM user_old WHERE user = ?
  261. `
  262. migrate1To2InsertFromOldTablesAndDropNoTx = `
  263. INSERT INTO user_access (user_id, topic, read, write)
  264. SELECT u.id, a.topic, a.read, a.write
  265. FROM user u
  266. JOIN access a ON u.user = a.user;
  267. DROP TABLE access;
  268. DROP TABLE user_old;
  269. `
  270. migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?`
  271. )
  272. // Manager is an implementation of Manager. It stores users and access control list
  273. // in a SQLite database.
  274. type Manager struct {
  275. db *sql.DB
  276. defaultAccess Permission // Default permission if no ACL matches
  277. statsQueue map[string]*User // Username -> User, for "unimportant" user updates
  278. mu sync.Mutex
  279. }
  280. var _ Auther = (*Manager)(nil)
  281. // NewManager creates a new Manager instance
  282. func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
  283. return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
  284. }
  285. // NewManager creates a new Manager instance
  286. func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
  287. db, err := sql.Open("sqlite3", filename)
  288. if err != nil {
  289. return nil, err
  290. }
  291. if err := setupDB(db); err != nil {
  292. return nil, err
  293. }
  294. if err := runStartupQueries(db, startupQueries); err != nil {
  295. return nil, err
  296. }
  297. manager := &Manager{
  298. db: db,
  299. defaultAccess: defaultAccess,
  300. statsQueue: make(map[string]*User),
  301. }
  302. go manager.userStatsQueueWriter(statsWriterInterval)
  303. return manager, nil
  304. }
  305. // Authenticate checks username and password and returns a User if correct. The method
  306. // returns in constant-ish time, regardless of whether the user exists or the password is
  307. // correct or incorrect.
  308. func (a *Manager) Authenticate(username, password string) (*User, error) {
  309. if username == Everyone {
  310. return nil, ErrUnauthenticated
  311. }
  312. user, err := a.User(username)
  313. if err != nil {
  314. log.Trace("authentication of user %s failed (1): %s", username, err.Error())
  315. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  316. return nil, ErrUnauthenticated
  317. }
  318. if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  319. log.Trace("authentication of user %s failed (2): %s", username, err.Error())
  320. return nil, ErrUnauthenticated
  321. }
  322. return user, nil
  323. }
  324. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  325. // The method sets the User.Token value to the token that was used for authentication.
  326. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  327. if len(token) != tokenLength {
  328. return nil, ErrUnauthenticated
  329. }
  330. user, err := a.userByToken(token)
  331. if err != nil {
  332. return nil, ErrUnauthenticated
  333. }
  334. user.Token = token
  335. return user, nil
  336. }
  337. // CreateToken generates a random token for the given user and returns it. The token expires
  338. // after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
  339. // given user, if there are too many of them.
  340. func (a *Manager) CreateToken(user *User) (*Token, error) {
  341. token, expires := util.RandomStringPrefix(tokenPrefix, tokenLength), time.Now().Add(tokenExpiryDuration)
  342. tx, err := a.db.Begin()
  343. if err != nil {
  344. return nil, err
  345. }
  346. defer tx.Rollback()
  347. if _, err := tx.Exec(insertTokenQuery, user.ID, token, expires.Unix()); err != nil {
  348. return nil, err
  349. }
  350. rows, err := tx.Query(selectTokenCountQuery, user.ID)
  351. if err != nil {
  352. return nil, err
  353. }
  354. defer rows.Close()
  355. if !rows.Next() {
  356. return nil, errNoRows
  357. }
  358. var tokenCount int
  359. if err := rows.Scan(&tokenCount); err != nil {
  360. return nil, err
  361. }
  362. if tokenCount >= tokenMaxCount {
  363. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  364. // on two indices, whereas the query below is a full table scan.
  365. if _, err := tx.Exec(deleteExcessTokensQuery, user.ID, tokenMaxCount); err != nil {
  366. return nil, err
  367. }
  368. }
  369. if err := tx.Commit(); err != nil {
  370. return nil, err
  371. }
  372. return &Token{
  373. Value: token,
  374. Expires: expires,
  375. }, nil
  376. }
  377. // ExtendToken sets the new expiry date for a token, thereby extending its use further into the future.
  378. func (a *Manager) ExtendToken(user *User) (*Token, error) {
  379. if user.Token == "" {
  380. return nil, errNoTokenProvided
  381. }
  382. newExpires := time.Now().Add(tokenExpiryDuration)
  383. if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
  384. return nil, err
  385. }
  386. return &Token{
  387. Value: user.Token,
  388. Expires: newExpires,
  389. }, nil
  390. }
  391. // RemoveToken deletes the token defined in User.Token
  392. func (a *Manager) RemoveToken(user *User) error {
  393. if user.Token == "" {
  394. return ErrUnauthorized
  395. }
  396. if _, err := a.db.Exec(deleteTokenQuery, user.Name, user.Token); err != nil {
  397. return err
  398. }
  399. return nil
  400. }
  401. // RemoveExpiredTokens deletes all expired tokens from the database
  402. func (a *Manager) RemoveExpiredTokens() error {
  403. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  404. return err
  405. }
  406. return nil
  407. }
  408. // ChangeSettings persists the user settings
  409. func (a *Manager) ChangeSettings(user *User) error {
  410. prefs, err := json.Marshal(user.Prefs)
  411. if err != nil {
  412. return err
  413. }
  414. if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
  415. return err
  416. }
  417. return nil
  418. }
  419. // ResetStats resets all user stats in the user database. This touches all users.
  420. func (a *Manager) ResetStats() error {
  421. a.mu.Lock()
  422. defer a.mu.Unlock()
  423. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  424. return err
  425. }
  426. a.statsQueue = make(map[string]*User)
  427. return nil
  428. }
  429. // EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  430. // batches at a regular interval
  431. func (a *Manager) EnqueueStats(user *User) {
  432. a.mu.Lock()
  433. defer a.mu.Unlock()
  434. a.statsQueue[user.ID] = user
  435. }
  436. func (a *Manager) userStatsQueueWriter(interval time.Duration) {
  437. ticker := time.NewTicker(interval)
  438. for range ticker.C {
  439. if err := a.writeUserStatsQueue(); err != nil {
  440. log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
  441. }
  442. }
  443. }
  444. func (a *Manager) writeUserStatsQueue() error {
  445. a.mu.Lock()
  446. if len(a.statsQueue) == 0 {
  447. a.mu.Unlock()
  448. log.Trace("User Manager: No user stats updates to commit")
  449. return nil
  450. }
  451. statsQueue := a.statsQueue
  452. a.statsQueue = make(map[string]*User)
  453. a.mu.Unlock()
  454. tx, err := a.db.Begin()
  455. if err != nil {
  456. return err
  457. }
  458. defer tx.Rollback()
  459. log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
  460. for userID, u := range statsQueue {
  461. log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", userID, u.Stats.Messages, u.Stats.Emails)
  462. if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, userID); err != nil {
  463. return err
  464. }
  465. }
  466. return tx.Commit()
  467. }
  468. // Authorize returns nil if the given user has access to the given topic using the desired
  469. // permission. The user param may be nil to signal an anonymous user.
  470. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  471. if user != nil && user.Role == RoleAdmin {
  472. return nil // Admin can do everything
  473. }
  474. username := Everyone
  475. if user != nil {
  476. username = user.Name
  477. }
  478. // Select the read/write permissions for this user/topic combo. The query may return two
  479. // rows (one for everyone, and one for the user), but prioritizes the user.
  480. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  481. if err != nil {
  482. return err
  483. }
  484. defer rows.Close()
  485. if !rows.Next() {
  486. return a.resolvePerms(a.defaultAccess, perm)
  487. }
  488. var read, write bool
  489. if err := rows.Scan(&read, &write); err != nil {
  490. return err
  491. } else if err := rows.Err(); err != nil {
  492. return err
  493. }
  494. return a.resolvePerms(NewPermission(read, write), perm)
  495. }
  496. func (a *Manager) resolvePerms(base, perm Permission) error {
  497. if perm == PermissionRead && base.IsRead() {
  498. return nil
  499. } else if perm == PermissionWrite && base.IsWrite() {
  500. return nil
  501. }
  502. return ErrUnauthorized
  503. }
  504. // AddUser adds a user with the given username, password and role
  505. func (a *Manager) AddUser(username, password string, role Role, createdBy string) error {
  506. if !AllowedUsername(username) || !AllowedRole(role) {
  507. return ErrInvalidArgument
  508. }
  509. hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
  510. if err != nil {
  511. return err
  512. }
  513. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  514. syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
  515. if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, createdBy, now); err != nil {
  516. return err
  517. }
  518. return nil
  519. }
  520. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  521. // if the user did not exist in the first place.
  522. func (a *Manager) RemoveUser(username string) error {
  523. if !AllowedUsername(username) {
  524. return ErrInvalidArgument
  525. }
  526. // Rows in user_access, user_token, etc. are deleted via foreign keys
  527. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  528. return err
  529. }
  530. return nil
  531. }
  532. // Users returns a list of users. It always also returns the Everyone user ("*").
  533. func (a *Manager) Users() ([]*User, error) {
  534. rows, err := a.db.Query(selectUsernamesQuery)
  535. if err != nil {
  536. return nil, err
  537. }
  538. defer rows.Close()
  539. usernames := make([]string, 0)
  540. for rows.Next() {
  541. var username string
  542. if err := rows.Scan(&username); err != nil {
  543. return nil, err
  544. } else if err := rows.Err(); err != nil {
  545. return nil, err
  546. }
  547. usernames = append(usernames, username)
  548. }
  549. rows.Close()
  550. users := make([]*User, 0)
  551. for _, username := range usernames {
  552. user, err := a.User(username)
  553. if err != nil {
  554. return nil, err
  555. }
  556. users = append(users, user)
  557. }
  558. return users, nil
  559. }
  560. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  561. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  562. func (a *Manager) User(username string) (*User, error) {
  563. rows, err := a.db.Query(selectUserByNameQuery, username)
  564. if err != nil {
  565. return nil, err
  566. }
  567. return a.readUser(rows)
  568. }
  569. // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
  570. func (a *Manager) UserByID(id string) (*User, error) {
  571. rows, err := a.db.Query(selectUserByIDQuery, id)
  572. if err != nil {
  573. return nil, err
  574. }
  575. return a.readUser(rows)
  576. }
  577. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  578. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  579. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  580. if err != nil {
  581. return nil, err
  582. }
  583. return a.readUser(rows)
  584. }
  585. func (a *Manager) userByToken(token string) (*User, error) {
  586. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  587. if err != nil {
  588. return nil, err
  589. }
  590. return a.readUser(rows)
  591. }
  592. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  593. defer rows.Close()
  594. var id, username, hash, role, prefs, syncTopic string
  595. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
  596. var messages, emails int64
  597. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
  598. if !rows.Next() {
  599. return nil, ErrUserNotFound
  600. }
  601. if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
  602. return nil, err
  603. } else if err := rows.Err(); err != nil {
  604. return nil, err
  605. }
  606. user := &User{
  607. ID: id,
  608. Name: username,
  609. Hash: hash,
  610. Role: Role(role),
  611. Prefs: &Prefs{},
  612. SyncTopic: syncTopic,
  613. Stats: &Stats{
  614. Messages: messages,
  615. Emails: emails,
  616. },
  617. Billing: &Billing{
  618. StripeCustomerID: stripeCustomerID.String, // May be empty
  619. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  620. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  621. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  622. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  623. },
  624. }
  625. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  626. return nil, err
  627. }
  628. if tierCode.Valid {
  629. // See readTier() when this is changed!
  630. user.Tier = &Tier{
  631. Code: tierCode.String,
  632. Name: tierName.String,
  633. Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
  634. MessagesLimit: messagesLimit.Int64,
  635. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  636. EmailsLimit: emailsLimit.Int64,
  637. ReservationsLimit: reservationsLimit.Int64,
  638. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  639. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  640. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  641. StripePriceID: stripePriceID.String, // May be empty
  642. }
  643. }
  644. return user, nil
  645. }
  646. // Grants returns all user-specific access control entries
  647. func (a *Manager) Grants(username string) ([]Grant, error) {
  648. rows, err := a.db.Query(selectUserAccessQuery, username)
  649. if err != nil {
  650. return nil, err
  651. }
  652. defer rows.Close()
  653. grants := make([]Grant, 0)
  654. for rows.Next() {
  655. var topic string
  656. var read, write bool
  657. if err := rows.Scan(&topic, &read, &write); err != nil {
  658. return nil, err
  659. } else if err := rows.Err(); err != nil {
  660. return nil, err
  661. }
  662. grants = append(grants, Grant{
  663. TopicPattern: fromSQLWildcard(topic),
  664. Allow: NewPermission(read, write),
  665. })
  666. }
  667. return grants, nil
  668. }
  669. // Reservations returns all user-owned topics, and the associated everyone-access
  670. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  671. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  672. if err != nil {
  673. return nil, err
  674. }
  675. defer rows.Close()
  676. reservations := make([]Reservation, 0)
  677. for rows.Next() {
  678. var topic string
  679. var ownerRead, ownerWrite bool
  680. var everyoneRead, everyoneWrite sql.NullBool
  681. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  682. return nil, err
  683. } else if err := rows.Err(); err != nil {
  684. return nil, err
  685. }
  686. reservations = append(reservations, Reservation{
  687. Topic: topic,
  688. Owner: NewPermission(ownerRead, ownerWrite),
  689. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  690. })
  691. }
  692. return reservations, nil
  693. }
  694. // HasReservation returns true if the given topic access is owned by the user
  695. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  696. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  697. if err != nil {
  698. return false, err
  699. }
  700. defer rows.Close()
  701. if !rows.Next() {
  702. return false, errNoRows
  703. }
  704. var count int64
  705. if err := rows.Scan(&count); err != nil {
  706. return false, err
  707. }
  708. return count > 0, nil
  709. }
  710. // ReservationsCount returns the number of reservations owned by this user
  711. func (a *Manager) ReservationsCount(username string) (int64, error) {
  712. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  713. if err != nil {
  714. return 0, err
  715. }
  716. defer rows.Close()
  717. if !rows.Next() {
  718. return 0, errNoRows
  719. }
  720. var count int64
  721. if err := rows.Scan(&count); err != nil {
  722. return 0, err
  723. }
  724. return count, nil
  725. }
  726. // ChangePassword changes a user's password
  727. func (a *Manager) ChangePassword(username, password string) error {
  728. hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
  729. if err != nil {
  730. return err
  731. }
  732. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  733. return err
  734. }
  735. return nil
  736. }
  737. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  738. // all existing access control entries (Grant) are removed, since they are no longer needed.
  739. func (a *Manager) ChangeRole(username string, role Role) error {
  740. if !AllowedUsername(username) || !AllowedRole(role) {
  741. return ErrInvalidArgument
  742. }
  743. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  744. return err
  745. }
  746. if role == RoleAdmin {
  747. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  748. return err
  749. }
  750. }
  751. return nil
  752. }
  753. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  754. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  755. func (a *Manager) ChangeTier(username, tier string) error {
  756. if !AllowedUsername(username) {
  757. return ErrInvalidArgument
  758. }
  759. t, err := a.Tier(tier)
  760. if err != nil {
  761. return err
  762. } else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil {
  763. return err
  764. }
  765. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  766. return err
  767. }
  768. return nil
  769. }
  770. // ResetTier removes the tier from the given user
  771. func (a *Manager) ResetTier(username string) error {
  772. if !AllowedUsername(username) && username != Everyone && username != "" {
  773. return ErrInvalidArgument
  774. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  775. return err
  776. }
  777. _, err := a.db.Exec(deleteUserTierQuery, username)
  778. return err
  779. }
  780. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  781. u, err := a.User(username)
  782. if err != nil {
  783. return err
  784. }
  785. if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit {
  786. reservations, err := a.Reservations(username)
  787. if err != nil {
  788. return err
  789. } else if int64(len(reservations)) > reservationsLimit {
  790. return ErrTooManyReservations
  791. }
  792. }
  793. return nil
  794. }
  795. // CheckAllowAccess tests if a user may create an access control entry for the given topic.
  796. // If there are any ACL entries that are not owned by the user, an error is returned.
  797. // FIXME is this the same as HasReservation?
  798. func (a *Manager) CheckAllowAccess(username string, topic string) error {
  799. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  800. return ErrInvalidArgument
  801. }
  802. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  803. if err != nil {
  804. return err
  805. }
  806. defer rows.Close()
  807. if !rows.Next() {
  808. return errNoRows
  809. }
  810. var otherCount int
  811. if err := rows.Scan(&otherCount); err != nil {
  812. return err
  813. }
  814. if otherCount > 0 {
  815. return errTopicOwnedByOthers
  816. }
  817. return nil
  818. }
  819. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  820. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  821. // owner may either be a user (username), or the system (empty).
  822. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  823. if !AllowedUsername(username) && username != Everyone {
  824. return ErrInvalidArgument
  825. } else if !AllowedTopicPattern(topicPattern) {
  826. return ErrInvalidArgument
  827. }
  828. owner := ""
  829. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
  830. return err
  831. }
  832. return nil
  833. }
  834. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  835. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  836. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  837. if !AllowedUsername(username) && username != Everyone && username != "" {
  838. return ErrInvalidArgument
  839. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  840. return ErrInvalidArgument
  841. }
  842. if username == "" && topicPattern == "" {
  843. _, err := a.db.Exec(deleteAllAccessQuery, username)
  844. return err
  845. } else if topicPattern == "" {
  846. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  847. return err
  848. }
  849. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  850. return err
  851. }
  852. // AddReservation creates two access control entries for the given topic: one with full read/write access for the
  853. // given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
  854. // can modify or delete them.
  855. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
  856. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  857. return ErrInvalidArgument
  858. }
  859. tx, err := a.db.Begin()
  860. if err != nil {
  861. return err
  862. }
  863. defer tx.Rollback()
  864. if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
  865. return err
  866. }
  867. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
  868. return err
  869. }
  870. return tx.Commit()
  871. }
  872. // RemoveReservations deletes the access control entries associated with the given username/topic, as
  873. // well as all entries with Everyone/topic. This is the counterpart for AddReservation.
  874. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  875. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  876. return ErrInvalidArgument
  877. }
  878. for _, topic := range topics {
  879. if !AllowedTopic(topic) {
  880. return ErrInvalidArgument
  881. }
  882. }
  883. tx, err := a.db.Begin()
  884. if err != nil {
  885. return err
  886. }
  887. defer tx.Rollback()
  888. for _, topic := range topics {
  889. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
  890. return err
  891. }
  892. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
  893. return err
  894. }
  895. }
  896. return tx.Commit()
  897. }
  898. // DefaultAccess returns the default read/write access if no access control entry matches
  899. func (a *Manager) DefaultAccess() Permission {
  900. return a.defaultAccess
  901. }
  902. // CreateTier creates a new tier in the database
  903. func (a *Manager) CreateTier(tier *Tier) error {
  904. tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
  905. if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil {
  906. return err
  907. }
  908. return nil
  909. }
  910. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  911. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  912. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  913. return err
  914. }
  915. return nil
  916. }
  917. // Tiers returns a list of all Tier structs
  918. func (a *Manager) Tiers() ([]*Tier, error) {
  919. rows, err := a.db.Query(selectTiersQuery)
  920. if err != nil {
  921. return nil, err
  922. }
  923. defer rows.Close()
  924. tiers := make([]*Tier, 0)
  925. for {
  926. tier, err := a.readTier(rows)
  927. if err == ErrTierNotFound {
  928. break
  929. } else if err != nil {
  930. return nil, err
  931. }
  932. tiers = append(tiers, tier)
  933. }
  934. return tiers, nil
  935. }
  936. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  937. func (a *Manager) Tier(code string) (*Tier, error) {
  938. rows, err := a.db.Query(selectTierByCodeQuery, code)
  939. if err != nil {
  940. return nil, err
  941. }
  942. defer rows.Close()
  943. return a.readTier(rows)
  944. }
  945. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  946. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  947. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
  948. if err != nil {
  949. return nil, err
  950. }
  951. defer rows.Close()
  952. return a.readTier(rows)
  953. }
  954. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  955. var id, code, name string
  956. var stripePriceID sql.NullString
  957. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
  958. if !rows.Next() {
  959. return nil, ErrTierNotFound
  960. }
  961. if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
  962. return nil, err
  963. } else if err := rows.Err(); err != nil {
  964. return nil, err
  965. }
  966. // When changed, note readUser() as well
  967. return &Tier{
  968. ID: id,
  969. Code: code,
  970. Name: name,
  971. Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
  972. MessagesLimit: messagesLimit.Int64,
  973. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  974. EmailsLimit: emailsLimit.Int64,
  975. ReservationsLimit: reservationsLimit.Int64,
  976. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  977. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  978. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  979. StripePriceID: stripePriceID.String, // May be empty
  980. }, nil
  981. }
  // toSQLWildcard converts the '*' wildcards of an ntfy topic pattern into SQL '%' wildcards.
// NOTE(review): '_' is also a single-character wildcard in SQL LIKE and is not escaped here; if the
// access queries match with LIKE, a pattern containing '_' may match more topics than intended.
func toSQLWildcard(s string) string {
	return strings.Replace(s, "*", "%", -1)
}
  // fromSQLWildcard is the inverse of toSQLWildcard: it turns SQL '%' wildcards back into '*'.
func fromSQLWildcard(s string) string {
	return strings.Replace(s, "%", "*", -1)
}
  988. func runStartupQueries(db *sql.DB, startupQueries string) error {
  989. if _, err := db.Exec(startupQueries); err != nil {
  990. return err
  991. }
  992. if _, err := db.Exec(builtinStartupQueries); err != nil {
  993. return err
  994. }
  995. return nil
  996. }
  997. func setupDB(db *sql.DB) error {
  998. // If 'schemaVersion' table does not exist, this must be a new database
  999. rowsSV, err := db.Query(selectSchemaVersionQuery)
  1000. if err != nil {
  1001. return setupNewDB(db)
  1002. }
  1003. defer rowsSV.Close()
  1004. // If 'schemaVersion' table exists, read version and potentially upgrade
  1005. schemaVersion := 0
  1006. if !rowsSV.Next() {
  1007. return errors.New("cannot determine schema version: database file may be corrupt")
  1008. }
  1009. if err := rowsSV.Scan(&schemaVersion); err != nil {
  1010. return err
  1011. }
  1012. rowsSV.Close()
  1013. // Do migrations
  1014. if schemaVersion == currentSchemaVersion {
  1015. return nil
  1016. } else if schemaVersion == 1 {
  1017. return migrateFrom1(db)
  1018. }
  1019. return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
  1020. }
  1021. func setupNewDB(db *sql.DB) error {
  1022. if _, err := db.Exec(createTablesQueries); err != nil {
  1023. return err
  1024. }
  1025. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  1026. return err
  1027. }
  1028. return nil
  1029. }
  1030. func migrateFrom1(db *sql.DB) error {
  1031. log.Info("Migrating user database schema: from 1 to 2")
  1032. tx, err := db.Begin()
  1033. if err != nil {
  1034. return err
  1035. }
  1036. defer tx.Rollback()
  1037. // Rename user -> user_old, and create new tables
  1038. if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
  1039. return err
  1040. }
  1041. if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
  1042. return err
  1043. }
  1044. // Insert users from user_old into new user table, with ID and sync_topic
  1045. rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
  1046. if err != nil {
  1047. return err
  1048. }
  1049. defer rows.Close()
  1050. usernames := make([]string, 0)
  1051. for rows.Next() {
  1052. var username string
  1053. if err := rows.Scan(&username); err != nil {
  1054. return err
  1055. }
  1056. usernames = append(usernames, username)
  1057. }
  1058. if err := rows.Close(); err != nil {
  1059. return err
  1060. }
  1061. for _, username := range usernames {
  1062. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1063. syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
  1064. if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
  1065. return err
  1066. }
  1067. }
  1068. // Migrate old "access" table to "user_access" and drop "access" and "user_old"
  1069. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1070. return err
  1071. }
  1072. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1073. return err
  1074. }
  1075. if err := tx.Commit(); err != nil {
  1076. return err
  1077. }
  1078. return nil // Update this when a new version is added
  1079. }
  // nullString maps the empty string to SQL NULL and any other string to a valid NullString.
func nullString(s string) sql.NullString {
	return sql.NullString{String: s, Valid: s != ""}
}
  // nullInt64 maps 0 to SQL NULL and any other value (including negatives) to a valid NullInt64.
func nullInt64(v int64) sql.NullInt64 {
	return sql.NullInt64{Int64: v, Valid: v != 0}
}