manager.go 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371
  1. // Package user deals with authentication and authorization against topics
  2. package user
  3. import (
  4. "database/sql"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. _ "github.com/mattn/go-sqlite3" // SQLite driver
  9. "github.com/stripe/stripe-go/v74"
  10. "golang.org/x/crypto/bcrypt"
  11. "heckel.io/ntfy/log"
  12. "heckel.io/ntfy/util"
  13. "net/netip"
  14. "strings"
  15. "sync"
  16. "time"
  17. )
  // Internal constants of the user manager.
const (
	// Prefixes and lengths for randomly generated identifiers and sync topics.
	tierIDPrefix    = "ti_"
	tierIDLength    = 8
	syncTopicPrefix = "st_"
	syncTopicLength = 16
	userIDPrefix    = "u_"
	userIDLength    = 12

	// bcrypt hash that Authenticate compares against when the user does not exist or is marked
	// deleted, so that those failure paths cost about as much time as a real password check.
	userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match DefaultUserPasswordBcryptCost

	// Presumably the grace period before a user marked as removed is hard-deleted (TODO confirm
	// against MarkUserRemoved).
	userHardDeleteAfterDuration = 7 * 24 * time.Hour

	// Access tokens.
	tokenPrefix   = "tk_"
	tokenLength   = 32
	tokenMaxCount = 20 // Only keep this many tokens in the table per user

	// Log tag attached to all log messages of this file.
	tag = "user_manager"
)
  // Default constants that may be overridden by configs.
const (
	// DefaultUserStatsQueueWriterInterval is a default for the queueWriterInterval argument of
	// NewManager: how often queued user stats and token updates are flushed to the database.
	DefaultUserStatsQueueWriterInterval = 33 * time.Second

	// DefaultUserPasswordBcryptCost is a default bcrypt cost for hashing user passwords. It must
	// match the cost of userAuthIntentionalSlowDownHash for the timing-attack mitigation to work.
	DefaultUserPasswordBcryptCost = 10
)
  // Package-internal errors returned by Manager methods.
var (
	errNoTokenProvided    = errors.New("no token provided") // An empty token was passed
	errTopicOwnedByOthers = errors.New("topic owned by others")
	errNoRows             = errors.New("no rows found") // A query returned no row where one was expected
)
  42. // Manager-related queries
  43. const (
  44. createTablesQueriesNoTx = `
  45. CREATE TABLE IF NOT EXISTS tier (
  46. id TEXT PRIMARY KEY,
  47. code TEXT NOT NULL,
  48. name TEXT NOT NULL,
  49. messages_limit INT NOT NULL,
  50. messages_expiry_duration INT NOT NULL,
  51. emails_limit INT NOT NULL,
  52. reservations_limit INT NOT NULL,
  53. attachment_file_size_limit INT NOT NULL,
  54. attachment_total_size_limit INT NOT NULL,
  55. attachment_expiry_duration INT NOT NULL,
  56. attachment_bandwidth_limit INT NOT NULL,
  57. stripe_price_id TEXT
  58. );
  59. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  60. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  61. CREATE TABLE IF NOT EXISTS user (
  62. id TEXT PRIMARY KEY,
  63. tier_id TEXT,
  64. user TEXT NOT NULL,
  65. pass TEXT NOT NULL,
  66. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  67. prefs JSON NOT NULL DEFAULT '{}',
  68. sync_topic TEXT NOT NULL,
  69. stats_messages INT NOT NULL DEFAULT (0),
  70. stats_emails INT NOT NULL DEFAULT (0),
  71. stripe_customer_id TEXT,
  72. stripe_subscription_id TEXT,
  73. stripe_subscription_status TEXT,
  74. stripe_subscription_paid_until INT,
  75. stripe_subscription_cancel_at INT,
  76. created INT NOT NULL,
  77. deleted INT,
  78. FOREIGN KEY (tier_id) REFERENCES tier (id)
  79. );
  80. CREATE UNIQUE INDEX idx_user ON user (user);
  81. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  82. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  83. CREATE TABLE IF NOT EXISTS user_access (
  84. user_id TEXT NOT NULL,
  85. topic TEXT NOT NULL,
  86. read INT NOT NULL,
  87. write INT NOT NULL,
  88. owner_user_id INT,
  89. PRIMARY KEY (user_id, topic),
  90. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  91. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  92. );
  93. CREATE TABLE IF NOT EXISTS user_token (
  94. user_id TEXT NOT NULL,
  95. token TEXT NOT NULL,
  96. label TEXT NOT NULL,
  97. last_access INT NOT NULL,
  98. last_origin TEXT NOT NULL,
  99. expires INT NOT NULL,
  100. PRIMARY KEY (user_id, token),
  101. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  102. );
  103. CREATE TABLE IF NOT EXISTS schemaVersion (
  104. id INT PRIMARY KEY,
  105. version INT NOT NULL
  106. );
  107. INSERT INTO user (id, user, pass, role, sync_topic, created)
  108. VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
  109. ON CONFLICT (id) DO NOTHING;
  110. `
  111. createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
  112. builtinStartupQueries = `
  113. PRAGMA foreign_keys = ON;
  114. `
  115. selectUserByIDQuery = `
  116. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  117. FROM user u
  118. LEFT JOIN tier t on t.id = u.tier_id
  119. WHERE u.id = ?
  120. `
  121. selectUserByNameQuery = `
  122. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  123. FROM user u
  124. LEFT JOIN tier t on t.id = u.tier_id
  125. WHERE user = ?
  126. `
  127. selectUserByTokenQuery = `
  128. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  129. FROM user u
  130. JOIN user_token tk on u.id = tk.user_id
  131. LEFT JOIN tier t on t.id = u.tier_id
  132. WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
  133. `
  134. selectUserByStripeCustomerIDQuery = `
  135. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  136. FROM user u
  137. LEFT JOIN tier t on t.id = u.tier_id
  138. WHERE u.stripe_customer_id = ?
  139. `
  140. selectTopicPermsQuery = `
  141. SELECT read, write
  142. FROM user_access a
  143. JOIN user u ON u.id = a.user_id
  144. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  145. ORDER BY u.user DESC
  146. `
  147. insertUserQuery = `
  148. INSERT INTO user (id, user, pass, role, sync_topic, created)
  149. VALUES (?, ?, ?, ?, ?, ?)
  150. `
  151. selectUsernamesQuery = `
  152. SELECT user
  153. FROM user
  154. ORDER BY
  155. CASE role
  156. WHEN 'admin' THEN 1
  157. WHEN 'anonymous' THEN 3
  158. ELSE 2
  159. END, user
  160. `
  161. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  162. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  163. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
  164. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
  165. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  166. updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
  167. deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
  168. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  169. upsertUserAccessQuery = `
  170. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  171. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  172. ON CONFLICT (user_id, topic)
  173. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  174. `
  175. selectUserAccessQuery = `
  176. SELECT topic, read, write
  177. FROM user_access
  178. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  179. ORDER BY write DESC, read DESC, topic
  180. `
  181. selectUserReservationsQuery = `
  182. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  183. FROM user_access a_user
  184. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  185. WHERE a_user.user_id = a_user.owner_user_id
  186. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  187. ORDER BY a_user.topic
  188. `
  189. selectUserReservationsCountQuery = `
  190. SELECT COUNT(*)
  191. FROM user_access
  192. WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  193. `
  194. selectUserHasReservationQuery = `
  195. SELECT COUNT(*)
  196. FROM user_access
  197. WHERE user_id = owner_user_id
  198. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  199. AND topic = ?
  200. `
  201. selectOtherAccessCountQuery = `
  202. SELECT COUNT(*)
  203. FROM user_access
  204. WHERE (topic = ? OR ? LIKE topic)
  205. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  206. `
  207. deleteAllAccessQuery = `DELETE FROM user_access`
  208. deleteUserAccessQuery = `
  209. DELETE FROM user_access
  210. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  211. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  212. `
  213. deleteTopicAccessQuery = `
  214. DELETE FROM user_access
  215. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  216. AND topic = ?
  217. `
  218. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
  219. selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
  220. selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
  221. insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
  222. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
  223. updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
  224. updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
  225. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
  226. deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
  227. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
  228. deleteExcessTokensQuery = `
  229. DELETE FROM user_token
  230. WHERE (user_id, token) NOT IN (
  231. SELECT user_id, token
  232. FROM user_token
  233. WHERE user_id = ?
  234. ORDER BY expires DESC
  235. LIMIT ?
  236. )
  237. `
  238. insertTierQuery = `
  239. INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id)
  240. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  241. `
  242. updateTierQuery = `
  243. UPDATE tier
  244. SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_price_id = ?
  245. WHERE code = ?
  246. `
  247. selectTiersQuery = `
  248. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  249. FROM tier
  250. `
  251. selectTierByCodeQuery = `
  252. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  253. FROM tier
  254. WHERE code = ?
  255. `
  256. selectTierByPriceIDQuery = `
  257. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  258. FROM tier
  259. WHERE stripe_price_id = ?
  260. `
  261. updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
  262. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  263. deleteTierQuery = `DELETE FROM tier WHERE code = ?`
  264. updateBillingQuery = `
  265. UPDATE user
  266. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  267. WHERE user = ?
  268. `
  269. )
  // Schema management queries
const (
	// Schema version bookkeeping. The schemaVersion table holds a single row with id = 1.
	currentSchemaVersion     = 2
	insertSchemaVersion      = `INSERT INTO schemaVersion VALUES (1, ?)`
	updateSchemaVersion      = `UPDATE schemaVersion SET version = ? WHERE id = 1`
	selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`

	// 1 -> 2 (complex migration!)
	// The old user table is renamed to user_old. Each old user is then re-inserted into the new
	// user table with a caller-supplied ID and sync topic (args: id, sync topic, username).
	// Finally the old access rows are copied into user_access and the old tables are dropped.
	migrate1To2RenameUserTableQueryNoTx = `
		ALTER TABLE user RENAME TO user_old;
	`
	migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
	migrate1To2InsertUserNoTx            = `
		INSERT INTO user (id, user, pass, role, sync_topic, created)
		SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
	`
	migrate1To2InsertFromOldTablesAndDropNoTx = `
		INSERT INTO user_access (user_id, topic, read, write)
		SELECT u.id, a.topic, a.read, a.write
		FROM user u
		JOIN access a ON u.user = a.user;

		DROP TABLE access;
		DROP TABLE user_old;
	`
)
  294. var (
  295. migrations = map[int]func(db *sql.DB) error{
  296. 1: migrateFrom1,
  297. }
  298. )
  299. // Manager is an implementation of Manager. It stores users and access control list
  300. // in a SQLite database.
  301. type Manager struct {
  302. db *sql.DB
  303. defaultAccess Permission // Default permission if no ACL matches
  304. statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
  305. tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
  306. bcryptCost int // Makes testing easier
  307. mu sync.Mutex
  308. }
  309. var _ Auther = (*Manager)(nil)
  310. // NewManager creates a new Manager instance
  311. func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, queueWriterInterval time.Duration) (*Manager, error) {
  312. db, err := sql.Open("sqlite3", filename)
  313. if err != nil {
  314. return nil, err
  315. }
  316. if err := setupDB(db); err != nil {
  317. return nil, err
  318. }
  319. if err := runStartupQueries(db, startupQueries); err != nil {
  320. return nil, err
  321. }
  322. manager := &Manager{
  323. db: db,
  324. defaultAccess: defaultAccess,
  325. statsQueue: make(map[string]*Stats),
  326. tokenQueue: make(map[string]*TokenUpdate),
  327. bcryptCost: bcryptCost,
  328. }
  329. go manager.asyncQueueWriter(queueWriterInterval)
  330. return manager, nil
  331. }
  332. // Authenticate checks username and password and returns a User if correct, and the user has not been
  333. // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
  334. // the password is correct or incorrect.
  335. func (a *Manager) Authenticate(username, password string) (*User, error) {
  336. if username == Everyone {
  337. return nil, ErrUnauthenticated
  338. }
  339. user, err := a.User(username)
  340. if err != nil {
  341. log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
  342. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  343. return nil, ErrUnauthenticated
  344. } else if user.Deleted {
  345. log.Tag(tag).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
  346. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  347. return nil, ErrUnauthenticated
  348. } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  349. log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
  350. return nil, ErrUnauthenticated
  351. }
  352. return user, nil
  353. }
  354. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  355. // The method sets the User.Token value to the token that was used for authentication.
  356. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  357. if len(token) != tokenLength {
  358. return nil, ErrUnauthenticated
  359. }
  360. user, err := a.userByToken(token)
  361. if err != nil {
  362. log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed")
  363. return nil, ErrUnauthenticated
  364. }
  365. user.Token = token
  366. return user, nil
  367. }
  368. // CreateToken generates a random token for the given user and returns it. The token expires
  369. // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
  370. // given user, if there are too many of them.
  371. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
  372. token := util.RandomStringPrefix(tokenPrefix, tokenLength)
  373. tx, err := a.db.Begin()
  374. if err != nil {
  375. return nil, err
  376. }
  377. defer tx.Rollback()
  378. access := time.Now()
  379. if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
  380. return nil, err
  381. }
  382. rows, err := tx.Query(selectTokenCountQuery, userID)
  383. if err != nil {
  384. return nil, err
  385. }
  386. defer rows.Close()
  387. if !rows.Next() {
  388. return nil, errNoRows
  389. }
  390. var tokenCount int
  391. if err := rows.Scan(&tokenCount); err != nil {
  392. return nil, err
  393. }
  394. if tokenCount >= tokenMaxCount {
  395. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  396. // on two indices, whereas the query below is a full table scan.
  397. if _, err := tx.Exec(deleteExcessTokensQuery, userID, tokenMaxCount); err != nil {
  398. return nil, err
  399. }
  400. }
  401. if err := tx.Commit(); err != nil {
  402. return nil, err
  403. }
  404. return &Token{
  405. Value: token,
  406. Label: label,
  407. LastAccess: access,
  408. LastOrigin: origin,
  409. Expires: expires,
  410. }, nil
  411. }
  412. // Tokens returns all existing tokens for the user with the given user ID
  413. func (a *Manager) Tokens(userID string) ([]*Token, error) {
  414. rows, err := a.db.Query(selectTokensQuery, userID)
  415. if err != nil {
  416. return nil, err
  417. }
  418. defer rows.Close()
  419. tokens := make([]*Token, 0)
  420. for {
  421. token, err := a.readToken(rows)
  422. if err == ErrTokenNotFound {
  423. break
  424. } else if err != nil {
  425. return nil, err
  426. }
  427. tokens = append(tokens, token)
  428. }
  429. return tokens, nil
  430. }
  431. // Token returns a specific token for a user
  432. func (a *Manager) Token(userID, token string) (*Token, error) {
  433. rows, err := a.db.Query(selectTokenQuery, userID, token)
  434. if err != nil {
  435. return nil, err
  436. }
  437. defer rows.Close()
  438. return a.readToken(rows)
  439. }
  440. func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
  441. var token, label, lastOrigin string
  442. var lastAccess, expires int64
  443. if !rows.Next() {
  444. return nil, ErrTokenNotFound
  445. }
  446. if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
  447. return nil, err
  448. } else if err := rows.Err(); err != nil {
  449. return nil, err
  450. }
  451. lastOriginIP, err := netip.ParseAddr(lastOrigin)
  452. if err != nil {
  453. lastOriginIP = netip.IPv4Unspecified()
  454. }
  455. return &Token{
  456. Value: token,
  457. Label: label,
  458. LastAccess: time.Unix(lastAccess, 0),
  459. LastOrigin: lastOriginIP,
  460. Expires: time.Unix(expires, 0),
  461. }, nil
  462. }
  463. // ChangeToken updates a token's label and/or expiry date
  464. func (a *Manager) ChangeToken(userID, token string, label *string, expires *time.Time) (*Token, error) {
  465. if token == "" {
  466. return nil, errNoTokenProvided
  467. }
  468. tx, err := a.db.Begin()
  469. if err != nil {
  470. return nil, err
  471. }
  472. defer tx.Rollback()
  473. if label != nil {
  474. if _, err := tx.Exec(updateTokenLabelQuery, *label, userID, token); err != nil {
  475. return nil, err
  476. }
  477. }
  478. if expires != nil {
  479. if _, err := tx.Exec(updateTokenExpiryQuery, expires.Unix(), userID, token); err != nil {
  480. return nil, err
  481. }
  482. }
  483. if err := tx.Commit(); err != nil {
  484. return nil, err
  485. }
  486. return a.Token(userID, token)
  487. }
  488. // RemoveToken deletes the token defined in User.Token
  489. func (a *Manager) RemoveToken(userID, token string) error {
  490. if token == "" {
  491. return errNoTokenProvided
  492. }
  493. if _, err := a.db.Exec(deleteTokenQuery, userID, token); err != nil {
  494. return err
  495. }
  496. return nil
  497. }
  498. // RemoveExpiredTokens deletes all expired tokens from the database
  499. func (a *Manager) RemoveExpiredTokens() error {
  500. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  501. return err
  502. }
  503. return nil
  504. }
  505. // RemoveDeletedUsers deletes all users that have been marked deleted for
  506. func (a *Manager) RemoveDeletedUsers() error {
  507. if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
  508. return err
  509. }
  510. return nil
  511. }
  512. // ChangeSettings persists the user settings
  513. func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error {
  514. b, err := json.Marshal(prefs)
  515. if err != nil {
  516. return err
  517. }
  518. if _, err := a.db.Exec(updateUserPrefsQuery, string(b), userID); err != nil {
  519. return err
  520. }
  521. return nil
  522. }
  523. // ResetStats resets all user stats in the user database. This touches all users.
  524. func (a *Manager) ResetStats() error {
  525. a.mu.Lock() // Includes database query to avoid races!
  526. defer a.mu.Unlock()
  527. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  528. return err
  529. }
  530. a.statsQueue = make(map[string]*Stats)
  531. return nil
  532. }
  533. // EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  534. // batches at a regular interval
  535. func (a *Manager) EnqueueUserStats(userID string, stats *Stats) {
  536. a.mu.Lock()
  537. defer a.mu.Unlock()
  538. a.statsQueue[userID] = stats
  539. }
  540. // EnqueueTokenUpdate adds the token update to a queue which writes out token access times
  541. // in batches at a regular interval
  542. func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
  543. a.mu.Lock()
  544. defer a.mu.Unlock()
  545. a.tokenQueue[tokenID] = update
  546. }
  547. func (a *Manager) asyncQueueWriter(interval time.Duration) {
  548. ticker := time.NewTicker(interval)
  549. for range ticker.C {
  550. if err := a.writeUserStatsQueue(); err != nil {
  551. log.Tag(tag).Err(err).Warn("Writing user stats queue failed")
  552. }
  553. if err := a.writeTokenUpdateQueue(); err != nil {
  554. log.Tag(tag).Err(err).Warn("Writing token update queue failed")
  555. }
  556. }
  557. }
  558. func (a *Manager) writeUserStatsQueue() error {
  559. a.mu.Lock()
  560. if len(a.statsQueue) == 0 {
  561. a.mu.Unlock()
  562. log.Tag(tag).Trace("No user stats updates to commit")
  563. return nil
  564. }
  565. statsQueue := a.statsQueue
  566. a.statsQueue = make(map[string]*Stats)
  567. a.mu.Unlock()
  568. tx, err := a.db.Begin()
  569. if err != nil {
  570. return err
  571. }
  572. defer tx.Rollback()
  573. log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
  574. for userID, update := range statsQueue {
  575. log.
  576. Tag(tag).
  577. Fields(log.Context{
  578. "user_id": userID,
  579. "messages_count": update.Messages,
  580. "emails_count": update.Emails,
  581. }).
  582. Trace("Updating stats for user %s", userID)
  583. if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, userID); err != nil {
  584. return err
  585. }
  586. }
  587. return tx.Commit()
  588. }
  589. func (a *Manager) writeTokenUpdateQueue() error {
  590. a.mu.Lock()
  591. if len(a.tokenQueue) == 0 {
  592. a.mu.Unlock()
  593. log.Tag(tag).Trace("No token updates to commit")
  594. return nil
  595. }
  596. tokenQueue := a.tokenQueue
  597. a.tokenQueue = make(map[string]*TokenUpdate)
  598. a.mu.Unlock()
  599. tx, err := a.db.Begin()
  600. if err != nil {
  601. return err
  602. }
  603. defer tx.Rollback()
  604. log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
  605. for tokenID, update := range tokenQueue {
  606. log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
  607. if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
  608. return err
  609. }
  610. }
  611. return tx.Commit()
  612. }
  613. // Authorize returns nil if the given user has access to the given topic using the desired
  614. // permission. The user param may be nil to signal an anonymous user.
  615. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  616. if user != nil && user.Role == RoleAdmin {
  617. return nil // Admin can do everything
  618. }
  619. username := Everyone
  620. if user != nil {
  621. username = user.Name
  622. }
  623. // Select the read/write permissions for this user/topic combo. The query may return two
  624. // rows (one for everyone, and one for the user), but prioritizes the user.
  625. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  626. if err != nil {
  627. return err
  628. }
  629. defer rows.Close()
  630. if !rows.Next() {
  631. return a.resolvePerms(a.defaultAccess, perm)
  632. }
  633. var read, write bool
  634. if err := rows.Scan(&read, &write); err != nil {
  635. return err
  636. } else if err := rows.Err(); err != nil {
  637. return err
  638. }
  639. return a.resolvePerms(NewPermission(read, write), perm)
  640. }
  641. func (a *Manager) resolvePerms(base, perm Permission) error {
  642. if perm == PermissionRead && base.IsRead() {
  643. return nil
  644. } else if perm == PermissionWrite && base.IsWrite() {
  645. return nil
  646. }
  647. return ErrUnauthorized
  648. }
  649. // AddUser adds a user with the given username, password and role
  650. func (a *Manager) AddUser(username, password string, role Role) error {
  651. if !AllowedUsername(username) || !AllowedRole(role) {
  652. return ErrInvalidArgument
  653. }
  654. hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
  655. if err != nil {
  656. return err
  657. }
  658. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  659. syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
  660. if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
  661. return err
  662. }
  663. return nil
  664. }
  665. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  666. // if the user did not exist in the first place.
  667. func (a *Manager) RemoveUser(username string) error {
  668. if !AllowedUsername(username) {
  669. return ErrInvalidArgument
  670. }
  671. // Rows in user_access, user_token, etc. are deleted via foreign keys
  672. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  673. return err
  674. }
  675. return nil
  676. }
  677. // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
  678. // successful auth via Authenticate. A background process will delete the user at a later date.
  679. func (a *Manager) MarkUserRemoved(user *User) error {
  680. if !AllowedUsername(user.Name) {
  681. return ErrInvalidArgument
  682. }
  683. tx, err := a.db.Begin()
  684. if err != nil {
  685. return err
  686. }
  687. defer tx.Rollback()
  688. if _, err := tx.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
  689. return err
  690. }
  691. if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
  692. return err
  693. }
  694. if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
  695. return err
  696. }
  697. return tx.Commit()
  698. }
  699. // Users returns a list of users. It always also returns the Everyone user ("*").
  700. func (a *Manager) Users() ([]*User, error) {
  701. rows, err := a.db.Query(selectUsernamesQuery)
  702. if err != nil {
  703. return nil, err
  704. }
  705. defer rows.Close()
  706. usernames := make([]string, 0)
  707. for rows.Next() {
  708. var username string
  709. if err := rows.Scan(&username); err != nil {
  710. return nil, err
  711. } else if err := rows.Err(); err != nil {
  712. return nil, err
  713. }
  714. usernames = append(usernames, username)
  715. }
  716. rows.Close()
  717. users := make([]*User, 0)
  718. for _, username := range usernames {
  719. user, err := a.User(username)
  720. if err != nil {
  721. return nil, err
  722. }
  723. users = append(users, user)
  724. }
  725. return users, nil
  726. }
  727. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  728. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  729. func (a *Manager) User(username string) (*User, error) {
  730. rows, err := a.db.Query(selectUserByNameQuery, username)
  731. if err != nil {
  732. return nil, err
  733. }
  734. return a.readUser(rows)
  735. }
  736. // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
  737. func (a *Manager) UserByID(id string) (*User, error) {
  738. rows, err := a.db.Query(selectUserByIDQuery, id)
  739. if err != nil {
  740. return nil, err
  741. }
  742. return a.readUser(rows)
  743. }
  744. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  745. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  746. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  747. if err != nil {
  748. return nil, err
  749. }
  750. return a.readUser(rows)
  751. }
  752. func (a *Manager) userByToken(token string) (*User, error) {
  753. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  754. if err != nil {
  755. return nil, err
  756. }
  757. return a.readUser(rows)
  758. }
  759. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  760. defer rows.Close()
  761. var id, username, hash, role, prefs, syncTopic string
  762. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
  763. var messages, emails int64
  764. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
  765. if !rows.Next() {
  766. return nil, ErrUserNotFound
  767. }
  768. if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
  769. return nil, err
  770. } else if err := rows.Err(); err != nil {
  771. return nil, err
  772. }
  773. user := &User{
  774. ID: id,
  775. Name: username,
  776. Hash: hash,
  777. Role: Role(role),
  778. Prefs: &Prefs{},
  779. SyncTopic: syncTopic,
  780. Stats: &Stats{
  781. Messages: messages,
  782. Emails: emails,
  783. },
  784. Billing: &Billing{
  785. StripeCustomerID: stripeCustomerID.String, // May be empty
  786. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  787. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  788. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  789. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  790. },
  791. Deleted: deleted.Valid,
  792. }
  793. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  794. return nil, err
  795. }
  796. if tierCode.Valid {
  797. // See readTier() when this is changed!
  798. user.Tier = &Tier{
  799. ID: tierID.String,
  800. Code: tierCode.String,
  801. Name: tierName.String,
  802. MessageLimit: messagesLimit.Int64,
  803. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  804. EmailLimit: emailsLimit.Int64,
  805. ReservationLimit: reservationsLimit.Int64,
  806. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  807. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  808. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  809. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  810. StripePriceID: stripePriceID.String, // May be empty
  811. }
  812. }
  813. return user, nil
  814. }
  815. // Grants returns all user-specific access control entries
  816. func (a *Manager) Grants(username string) ([]Grant, error) {
  817. rows, err := a.db.Query(selectUserAccessQuery, username)
  818. if err != nil {
  819. return nil, err
  820. }
  821. defer rows.Close()
  822. grants := make([]Grant, 0)
  823. for rows.Next() {
  824. var topic string
  825. var read, write bool
  826. if err := rows.Scan(&topic, &read, &write); err != nil {
  827. return nil, err
  828. } else if err := rows.Err(); err != nil {
  829. return nil, err
  830. }
  831. grants = append(grants, Grant{
  832. TopicPattern: fromSQLWildcard(topic),
  833. Allow: NewPermission(read, write),
  834. })
  835. }
  836. return grants, nil
  837. }
  838. // Reservations returns all user-owned topics, and the associated everyone-access
  839. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  840. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  841. if err != nil {
  842. return nil, err
  843. }
  844. defer rows.Close()
  845. reservations := make([]Reservation, 0)
  846. for rows.Next() {
  847. var topic string
  848. var ownerRead, ownerWrite bool
  849. var everyoneRead, everyoneWrite sql.NullBool
  850. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  851. return nil, err
  852. } else if err := rows.Err(); err != nil {
  853. return nil, err
  854. }
  855. reservations = append(reservations, Reservation{
  856. Topic: topic,
  857. Owner: NewPermission(ownerRead, ownerWrite),
  858. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  859. })
  860. }
  861. return reservations, nil
  862. }
  863. // HasReservation returns true if the given topic access is owned by the user
  864. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  865. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  866. if err != nil {
  867. return false, err
  868. }
  869. defer rows.Close()
  870. if !rows.Next() {
  871. return false, errNoRows
  872. }
  873. var count int64
  874. if err := rows.Scan(&count); err != nil {
  875. return false, err
  876. }
  877. return count > 0, nil
  878. }
  879. // ReservationsCount returns the number of reservations owned by this user
  880. func (a *Manager) ReservationsCount(username string) (int64, error) {
  881. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  882. if err != nil {
  883. return 0, err
  884. }
  885. defer rows.Close()
  886. if !rows.Next() {
  887. return 0, errNoRows
  888. }
  889. var count int64
  890. if err := rows.Scan(&count); err != nil {
  891. return 0, err
  892. }
  893. return count, nil
  894. }
  895. // ChangePassword changes a user's password
  896. func (a *Manager) ChangePassword(username, password string) error {
  897. hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
  898. if err != nil {
  899. return err
  900. }
  901. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  902. return err
  903. }
  904. return nil
  905. }
  906. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  907. // all existing access control entries (Grant) are removed, since they are no longer needed.
  908. func (a *Manager) ChangeRole(username string, role Role) error {
  909. if !AllowedUsername(username) || !AllowedRole(role) {
  910. return ErrInvalidArgument
  911. }
  912. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  913. return err
  914. }
  915. if role == RoleAdmin {
  916. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  917. return err
  918. }
  919. }
  920. return nil
  921. }
  922. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  923. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  924. func (a *Manager) ChangeTier(username, tier string) error {
  925. if !AllowedUsername(username) {
  926. return ErrInvalidArgument
  927. }
  928. t, err := a.Tier(tier)
  929. if err != nil {
  930. return err
  931. } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
  932. return err
  933. }
  934. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  935. return err
  936. }
  937. return nil
  938. }
  939. // ResetTier removes the tier from the given user
  940. func (a *Manager) ResetTier(username string) error {
  941. if !AllowedUsername(username) && username != Everyone && username != "" {
  942. return ErrInvalidArgument
  943. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  944. return err
  945. }
  946. _, err := a.db.Exec(deleteUserTierQuery, username)
  947. return err
  948. }
  949. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  950. u, err := a.User(username)
  951. if err != nil {
  952. return err
  953. }
  954. if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
  955. reservations, err := a.Reservations(username)
  956. if err != nil {
  957. return err
  958. } else if int64(len(reservations)) > reservationsLimit {
  959. return ErrTooManyReservations
  960. }
  961. }
  962. return nil
  963. }
  964. // AllowReservation tests if a user may create an access control entry for the given topic.
  965. // If there are any ACL entries that are not owned by the user, an error is returned.
  966. func (a *Manager) AllowReservation(username string, topic string) error {
  967. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  968. return ErrInvalidArgument
  969. }
  970. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  971. if err != nil {
  972. return err
  973. }
  974. defer rows.Close()
  975. if !rows.Next() {
  976. return errNoRows
  977. }
  978. var otherCount int
  979. if err := rows.Scan(&otherCount); err != nil {
  980. return err
  981. }
  982. if otherCount > 0 {
  983. return errTopicOwnedByOthers
  984. }
  985. return nil
  986. }
  987. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  988. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  989. // owner may either be a user (username), or the system (empty).
  990. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  991. if !AllowedUsername(username) && username != Everyone {
  992. return ErrInvalidArgument
  993. } else if !AllowedTopicPattern(topicPattern) {
  994. return ErrInvalidArgument
  995. }
  996. owner := ""
  997. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
  998. return err
  999. }
  1000. return nil
  1001. }
  1002. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  1003. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  1004. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  1005. if !AllowedUsername(username) && username != Everyone && username != "" {
  1006. return ErrInvalidArgument
  1007. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  1008. return ErrInvalidArgument
  1009. }
  1010. if username == "" && topicPattern == "" {
  1011. _, err := a.db.Exec(deleteAllAccessQuery, username)
  1012. return err
  1013. } else if topicPattern == "" {
  1014. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  1015. return err
  1016. }
  1017. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  1018. return err
  1019. }
  1020. // AddReservation creates two access control entries for the given topic: one with full read/write access for the
  1021. // given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
  1022. // can modify or delete them.
  1023. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
  1024. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  1025. return ErrInvalidArgument
  1026. }
  1027. tx, err := a.db.Begin()
  1028. if err != nil {
  1029. return err
  1030. }
  1031. defer tx.Rollback()
  1032. if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
  1033. return err
  1034. }
  1035. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
  1036. return err
  1037. }
  1038. return tx.Commit()
  1039. }
  1040. // RemoveReservations deletes the access control entries associated with the given username/topic, as
  1041. // well as all entries with Everyone/topic. This is the counterpart for AddReservation.
  1042. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  1043. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  1044. return ErrInvalidArgument
  1045. }
  1046. for _, topic := range topics {
  1047. if !AllowedTopic(topic) {
  1048. return ErrInvalidArgument
  1049. }
  1050. }
  1051. tx, err := a.db.Begin()
  1052. if err != nil {
  1053. return err
  1054. }
  1055. defer tx.Rollback()
  1056. for _, topic := range topics {
  1057. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
  1058. return err
  1059. }
  1060. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
  1061. return err
  1062. }
  1063. }
  1064. return tx.Commit()
  1065. }
  1066. // DefaultAccess returns the default read/write access if no access control entry matches
  1067. func (a *Manager) DefaultAccess() Permission {
  1068. return a.defaultAccess
  1069. }
  1070. // AddTier creates a new tier in the database
  1071. func (a *Manager) AddTier(tier *Tier) error {
  1072. if tier.ID == "" {
  1073. tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
  1074. }
  1075. if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID)); err != nil {
  1076. return err
  1077. }
  1078. return nil
  1079. }
  1080. // UpdateTier updates a tier's properties in the database
  1081. func (a *Manager) UpdateTier(tier *Tier) error {
  1082. if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID), tier.Code); err != nil {
  1083. return err
  1084. }
  1085. return nil
  1086. }
  1087. // RemoveTier deletes the tier with the given code
  1088. func (a *Manager) RemoveTier(code string) error {
  1089. if !AllowedTier(code) {
  1090. return ErrInvalidArgument
  1091. }
  1092. // This fails if any user has this tier
  1093. if _, err := a.db.Exec(deleteTierQuery, code); err != nil {
  1094. return err
  1095. }
  1096. return nil
  1097. }
  1098. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  1099. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  1100. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  1101. return err
  1102. }
  1103. return nil
  1104. }
  1105. // Tiers returns a list of all Tier structs
  1106. func (a *Manager) Tiers() ([]*Tier, error) {
  1107. rows, err := a.db.Query(selectTiersQuery)
  1108. if err != nil {
  1109. return nil, err
  1110. }
  1111. defer rows.Close()
  1112. tiers := make([]*Tier, 0)
  1113. for {
  1114. tier, err := a.readTier(rows)
  1115. if err == ErrTierNotFound {
  1116. break
  1117. } else if err != nil {
  1118. return nil, err
  1119. }
  1120. tiers = append(tiers, tier)
  1121. }
  1122. return tiers, nil
  1123. }
  1124. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  1125. func (a *Manager) Tier(code string) (*Tier, error) {
  1126. rows, err := a.db.Query(selectTierByCodeQuery, code)
  1127. if err != nil {
  1128. return nil, err
  1129. }
  1130. defer rows.Close()
  1131. return a.readTier(rows)
  1132. }
  1133. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  1134. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  1135. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
  1136. if err != nil {
  1137. return nil, err
  1138. }
  1139. defer rows.Close()
  1140. return a.readTier(rows)
  1141. }
  1142. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  1143. var id, code, name string
  1144. var stripePriceID sql.NullString
  1145. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
  1146. if !rows.Next() {
  1147. return nil, ErrTierNotFound
  1148. }
  1149. if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
  1150. return nil, err
  1151. } else if err := rows.Err(); err != nil {
  1152. return nil, err
  1153. }
  1154. // When changed, note readUser() as well
  1155. return &Tier{
  1156. ID: id,
  1157. Code: code,
  1158. Name: name,
  1159. MessageLimit: messagesLimit.Int64,
  1160. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1161. EmailLimit: emailsLimit.Int64,
  1162. ReservationLimit: reservationsLimit.Int64,
  1163. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1164. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1165. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1166. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  1167. StripePriceID: stripePriceID.String, // May be empty
  1168. }, nil
  1169. }
  1170. // Close closes the underlying database
  1171. func (a *Manager) Close() error {
  1172. return a.db.Close()
  1173. }
  1174. func toSQLWildcard(s string) string {
  1175. return strings.ReplaceAll(s, "*", "%")
  1176. }
  1177. func fromSQLWildcard(s string) string {
  1178. return strings.ReplaceAll(s, "%", "*")
  1179. }
  1180. func runStartupQueries(db *sql.DB, startupQueries string) error {
  1181. if _, err := db.Exec(startupQueries); err != nil {
  1182. return err
  1183. }
  1184. if _, err := db.Exec(builtinStartupQueries); err != nil {
  1185. return err
  1186. }
  1187. return nil
  1188. }
  1189. func setupDB(db *sql.DB) error {
  1190. // If 'schemaVersion' table does not exist, this must be a new database
  1191. rowsSV, err := db.Query(selectSchemaVersionQuery)
  1192. if err != nil {
  1193. return setupNewDB(db)
  1194. }
  1195. defer rowsSV.Close()
  1196. // If 'schemaVersion' table exists, read version and potentially upgrade
  1197. schemaVersion := 0
  1198. if !rowsSV.Next() {
  1199. return errors.New("cannot determine schema version: database file may be corrupt")
  1200. }
  1201. if err := rowsSV.Scan(&schemaVersion); err != nil {
  1202. return err
  1203. }
  1204. rowsSV.Close()
  1205. // Do migrations
  1206. if schemaVersion == currentSchemaVersion {
  1207. return nil
  1208. } else if schemaVersion > currentSchemaVersion {
  1209. return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
  1210. }
  1211. for i := schemaVersion; i < currentSchemaVersion; i++ {
  1212. fn, ok := migrations[i]
  1213. if !ok {
  1214. return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
  1215. } else if err := fn(db); err != nil {
  1216. return err
  1217. }
  1218. }
  1219. return nil
  1220. }
  1221. func setupNewDB(db *sql.DB) error {
  1222. if _, err := db.Exec(createTablesQueries); err != nil {
  1223. return err
  1224. }
  1225. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  1226. return err
  1227. }
  1228. return nil
  1229. }
  1230. func migrateFrom1(db *sql.DB) error {
  1231. log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
  1232. tx, err := db.Begin()
  1233. if err != nil {
  1234. return err
  1235. }
  1236. defer tx.Rollback()
  1237. // Rename user -> user_old, and create new tables
  1238. if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
  1239. return err
  1240. }
  1241. if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
  1242. return err
  1243. }
  1244. // Insert users from user_old into new user table, with ID and sync_topic
  1245. rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
  1246. if err != nil {
  1247. return err
  1248. }
  1249. defer rows.Close()
  1250. usernames := make([]string, 0)
  1251. for rows.Next() {
  1252. var username string
  1253. if err := rows.Scan(&username); err != nil {
  1254. return err
  1255. }
  1256. usernames = append(usernames, username)
  1257. }
  1258. if err := rows.Close(); err != nil {
  1259. return err
  1260. }
  1261. for _, username := range usernames {
  1262. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1263. syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
  1264. if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
  1265. return err
  1266. }
  1267. }
  1268. // Migrate old "access" table to "user_access" and drop "access" and "user_old"
  1269. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1270. return err
  1271. }
  1272. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1273. return err
  1274. }
  1275. if err := tx.Commit(); err != nil {
  1276. return err
  1277. }
  1278. return nil
  1279. }
  1280. func nullString(s string) sql.NullString {
  1281. if s == "" {
  1282. return sql.NullString{}
  1283. }
  1284. return sql.NullString{String: s, Valid: true}
  1285. }
  1286. func nullInt64(v int64) sql.NullInt64 {
  1287. if v == 0 {
  1288. return sql.NullInt64{}
  1289. }
  1290. return sql.NullInt64{Int64: v, Valid: true}
  1291. }