manager.go 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492
  1. // Package user deals with authentication and authorization against topics
  2. package user
  3. import (
  4. "database/sql"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. _ "github.com/mattn/go-sqlite3" // SQLite driver
  9. "github.com/stripe/stripe-go/v74"
  10. "golang.org/x/crypto/bcrypt"
  11. "heckel.io/ntfy/log"
  12. "heckel.io/ntfy/util"
  13. "net/netip"
  14. "strings"
  15. "sync"
  16. "time"
  17. )
  18. const (
  19. tierIDPrefix = "ti_"
  20. tierIDLength = 8
  21. syncTopicPrefix = "st_"
  22. syncTopicLength = 16
  23. userIDPrefix = "u_"
  24. userIDLength = 12
  25. userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match DefaultUserPasswordBcryptCost
  26. userHardDeleteAfterDuration = 7 * 24 * time.Hour
  27. tokenPrefix = "tk_"
  28. tokenLength = 32
  29. tokenMaxCount = 20 // Only keep this many tokens in the table per user
  30. tag = "user_manager"
  31. )
  32. // Default constants that may be overridden by configs
  33. const (
  34. DefaultUserStatsQueueWriterInterval = 33 * time.Second
  35. DefaultUserPasswordBcryptCost = 10
  36. )
  37. var (
  38. errNoTokenProvided = errors.New("no token provided")
  39. errTopicOwnedByOthers = errors.New("topic owned by others")
  40. errNoRows = errors.New("no rows found")
  41. )
  42. // Manager-related queries
  43. const (
  44. createTablesQueries = `
  45. BEGIN;
  46. CREATE TABLE IF NOT EXISTS tier (
  47. id TEXT PRIMARY KEY,
  48. code TEXT NOT NULL,
  49. name TEXT NOT NULL,
  50. messages_limit INT NOT NULL,
  51. messages_expiry_duration INT NOT NULL,
  52. emails_limit INT NOT NULL,
  53. reservations_limit INT NOT NULL,
  54. attachment_file_size_limit INT NOT NULL,
  55. attachment_total_size_limit INT NOT NULL,
  56. attachment_expiry_duration INT NOT NULL,
  57. attachment_bandwidth_limit INT NOT NULL,
  58. stripe_monthly_price_id TEXT,
  59. stripe_yearly_price_id TEXT
  60. );
  61. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  62. CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
  63. CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
  64. CREATE TABLE IF NOT EXISTS user (
  65. id TEXT PRIMARY KEY,
  66. tier_id TEXT,
  67. user TEXT NOT NULL,
  68. pass TEXT NOT NULL,
  69. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  70. prefs JSON NOT NULL DEFAULT '{}',
  71. sync_topic TEXT NOT NULL,
  72. stats_messages INT NOT NULL DEFAULT (0),
  73. stats_emails INT NOT NULL DEFAULT (0),
  74. stripe_customer_id TEXT,
  75. stripe_subscription_id TEXT,
  76. stripe_subscription_status TEXT,
  77. stripe_subscription_interval TEXT,
  78. stripe_subscription_paid_until INT,
  79. stripe_subscription_cancel_at INT,
  80. created INT NOT NULL,
  81. deleted INT,
  82. FOREIGN KEY (tier_id) REFERENCES tier (id)
  83. );
  84. CREATE UNIQUE INDEX idx_user ON user (user);
  85. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  86. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  87. CREATE TABLE IF NOT EXISTS user_access (
  88. user_id TEXT NOT NULL,
  89. topic TEXT NOT NULL,
  90. read INT NOT NULL,
  91. write INT NOT NULL,
  92. owner_user_id INT,
  93. PRIMARY KEY (user_id, topic),
  94. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  95. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  96. );
  97. CREATE TABLE IF NOT EXISTS user_token (
  98. user_id TEXT NOT NULL,
  99. token TEXT NOT NULL,
  100. label TEXT NOT NULL,
  101. last_access INT NOT NULL,
  102. last_origin TEXT NOT NULL,
  103. expires INT NOT NULL,
  104. PRIMARY KEY (user_id, token),
  105. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  106. );
  107. CREATE TABLE IF NOT EXISTS schemaVersion (
  108. id INT PRIMARY KEY,
  109. version INT NOT NULL
  110. );
  111. INSERT INTO user (id, user, pass, role, sync_topic, created)
  112. VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
  113. ON CONFLICT (id) DO NOTHING;
  114. COMMIT;
  115. `
  116. builtinStartupQueries = `
  117. PRAGMA foreign_keys = ON;
  118. `
  119. selectUserByIDQuery = `
  120. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  121. FROM user u
  122. LEFT JOIN tier t on t.id = u.tier_id
  123. WHERE u.id = ?
  124. `
  125. selectUserByNameQuery = `
  126. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  127. FROM user u
  128. LEFT JOIN tier t on t.id = u.tier_id
  129. WHERE user = ?
  130. `
  131. selectUserByTokenQuery = `
  132. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  133. FROM user u
  134. JOIN user_token tk on u.id = tk.user_id
  135. LEFT JOIN tier t on t.id = u.tier_id
  136. WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
  137. `
  138. selectUserByStripeCustomerIDQuery = `
  139. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  140. FROM user u
  141. LEFT JOIN tier t on t.id = u.tier_id
  142. WHERE u.stripe_customer_id = ?
  143. `
  144. selectTopicPermsQuery = `
  145. SELECT read, write
  146. FROM user_access a
  147. JOIN user u ON u.id = a.user_id
  148. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  149. ORDER BY u.user DESC
  150. `
  151. insertUserQuery = `
  152. INSERT INTO user (id, user, pass, role, sync_topic, created)
  153. VALUES (?, ?, ?, ?, ?, ?)
  154. `
  155. selectUsernamesQuery = `
  156. SELECT user
  157. FROM user
  158. ORDER BY
  159. CASE role
  160. WHEN 'admin' THEN 1
  161. WHEN 'anonymous' THEN 3
  162. ELSE 2
  163. END, user
  164. `
  165. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  166. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  167. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
  168. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
  169. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  170. updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
  171. deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
  172. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  173. upsertUserAccessQuery = `
  174. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  175. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  176. ON CONFLICT (user_id, topic)
  177. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  178. `
  179. selectUserAccessQuery = `
  180. SELECT topic, read, write
  181. FROM user_access
  182. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  183. ORDER BY write DESC, read DESC, topic
  184. `
  185. selectUserReservationsQuery = `
  186. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  187. FROM user_access a_user
  188. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  189. WHERE a_user.user_id = a_user.owner_user_id
  190. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  191. ORDER BY a_user.topic
  192. `
  193. selectUserReservationsCountQuery = `
  194. SELECT COUNT(*)
  195. FROM user_access
  196. WHERE user_id = owner_user_id
  197. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  198. `
  199. selectUserReservationsOwnerQuery = `
  200. SELECT owner_user_id
  201. FROM user_access
  202. WHERE topic = ?
  203. AND user_id = owner_user_id
  204. `
  205. selectUserHasReservationQuery = `
  206. SELECT COUNT(*)
  207. FROM user_access
  208. WHERE user_id = owner_user_id
  209. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  210. AND topic = ?
  211. `
  212. selectOtherAccessCountQuery = `
  213. SELECT COUNT(*)
  214. FROM user_access
  215. WHERE (topic = ? OR ? LIKE topic)
  216. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  217. `
  218. deleteAllAccessQuery = `DELETE FROM user_access`
  219. deleteUserAccessQuery = `
  220. DELETE FROM user_access
  221. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  222. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  223. `
  224. deleteTopicAccessQuery = `
  225. DELETE FROM user_access
  226. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  227. AND topic = ?
  228. `
  229. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
  230. selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
  231. selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
  232. insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
  233. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
  234. updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
  235. updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
  236. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
  237. deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
  238. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
  239. deleteExcessTokensQuery = `
  240. DELETE FROM user_token
  241. WHERE (user_id, token) NOT IN (
  242. SELECT user_id, token
  243. FROM user_token
  244. WHERE user_id = ?
  245. ORDER BY expires DESC
  246. LIMIT ?
  247. )
  248. `
  249. insertTierQuery = `
  250. INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
  251. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  252. `
  253. updateTierQuery = `
  254. UPDATE tier
  255. SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
  256. WHERE code = ?
  257. `
  258. selectTiersQuery = `
  259. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
  260. FROM tier
  261. `
  262. selectTierByCodeQuery = `
  263. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
  264. FROM tier
  265. WHERE code = ?
  266. `
  267. selectTierByPriceIDQuery = `
  268. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
  269. FROM tier
  270. WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
  271. `
  272. updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
  273. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  274. deleteTierQuery = `DELETE FROM tier WHERE code = ?`
  275. updateBillingQuery = `
  276. UPDATE user
  277. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  278. WHERE user = ?
  279. `
  280. )
// Schema management queries
const (
	// currentSchemaVersion is the latest schema version. Databases on older versions are upgraded
	// one step at a time by the functions in the migrations map (1 -> 2, then 2 -> 3).
	currentSchemaVersion = 3
	// The schemaVersion table is only ever addressed through its row with id = 1, which stores
	// the version number.
	insertSchemaVersion      = `INSERT INTO schemaVersion VALUES (1, ?)`
	updateSchemaVersion      = `UPDATE schemaVersion SET version = ? WHERE id = 1`
	selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`

	// 1 -> 2 (complex migration!)
	//
	// The migration statements are a historical snapshot. Do not edit them to match
	// createTablesQueries. They create the version-2 schema: a single tier.stripe_price_id
	// column and no user.stripe_subscription_interval. The 2 -> 3 migration below then
	// alters that schema.
	//
	// migrate1To2CreateTablesQueries moves the old user table aside as user_old, creates the
	// version-2 tables, and seeds the anonymous '*' user.
	// NOTE(review): the '*' user's ID is hard-coded as 'u_everyone' here, whereas
	// createTablesQueries uses the everyoneID constant — presumably the same value; confirm.
	migrate1To2CreateTablesQueries = `
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
	// migrate1To2SelectAllOldUsernamesNoTx lists the usernames in the version-1 user table.
	// (The "NoTx" suffix presumably marks statements run outside an explicit transaction — confirm.)
	migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
	// migrate1To2InsertUserNoTx copies one old user into the new user table.
	// Parameters: new user ID, sync topic, old username.
	migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
	// migrate1To2InsertFromOldTablesAndDropNoTx copies the version-1 access table into user_access
	// (matching rows by username), then drops the old access and user_old tables.
	migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
JOIN access a ON u.user = a.user;
DROP TABLE access;
DROP TABLE user_old;
`

	// 2 -> 3
	// This step does four things:
	//   - adds user.stripe_subscription_interval
	//   - renames tier.stripe_price_id to stripe_monthly_price_id
	//   - adds tier.stripe_yearly_price_id
	//   - replaces the old price index with one unique index per price column
	migrate2To3UpdateQueries = `
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
DROP INDEX IF EXISTS idx_tier_price_id;
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
`
)
  379. var (
  380. migrations = map[int]func(db *sql.DB) error{
  381. 1: migrateFrom1,
  382. 2: migrateFrom2,
  383. }
  384. )
  385. // Manager is an implementation of Manager. It stores users and access control list
  386. // in a SQLite database.
  387. type Manager struct {
  388. db *sql.DB
  389. defaultAccess Permission // Default permission if no ACL matches
  390. statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
  391. tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
  392. bcryptCost int // Makes testing easier
  393. mu sync.Mutex
  394. }
  395. var _ Auther = (*Manager)(nil)
  396. // NewManager creates a new Manager instance
  397. func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, queueWriterInterval time.Duration) (*Manager, error) {
  398. db, err := sql.Open("sqlite3", filename)
  399. if err != nil {
  400. return nil, err
  401. }
  402. if err := setupDB(db); err != nil {
  403. return nil, err
  404. }
  405. if err := runStartupQueries(db, startupQueries); err != nil {
  406. return nil, err
  407. }
  408. manager := &Manager{
  409. db: db,
  410. defaultAccess: defaultAccess,
  411. statsQueue: make(map[string]*Stats),
  412. tokenQueue: make(map[string]*TokenUpdate),
  413. bcryptCost: bcryptCost,
  414. }
  415. go manager.asyncQueueWriter(queueWriterInterval)
  416. return manager, nil
  417. }
  418. // Authenticate checks username and password and returns a User if correct, and the user has not been
  419. // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
  420. // the password is correct or incorrect.
  421. func (a *Manager) Authenticate(username, password string) (*User, error) {
  422. if username == Everyone {
  423. return nil, ErrUnauthenticated
  424. }
  425. user, err := a.User(username)
  426. if err != nil {
  427. log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
  428. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  429. return nil, ErrUnauthenticated
  430. } else if user.Deleted {
  431. log.Tag(tag).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
  432. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  433. return nil, ErrUnauthenticated
  434. } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  435. log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
  436. return nil, ErrUnauthenticated
  437. }
  438. return user, nil
  439. }
  440. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  441. // The method sets the User.Token value to the token that was used for authentication.
  442. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  443. if len(token) != tokenLength {
  444. return nil, ErrUnauthenticated
  445. }
  446. user, err := a.userByToken(token)
  447. if err != nil {
  448. log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed")
  449. return nil, ErrUnauthenticated
  450. }
  451. user.Token = token
  452. return user, nil
  453. }
  454. // CreateToken generates a random token for the given user and returns it. The token expires
  455. // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
  456. // given user, if there are too many of them.
  457. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
  458. token := util.RandomStringPrefix(tokenPrefix, tokenLength)
  459. tx, err := a.db.Begin()
  460. if err != nil {
  461. return nil, err
  462. }
  463. defer tx.Rollback()
  464. access := time.Now()
  465. if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
  466. return nil, err
  467. }
  468. rows, err := tx.Query(selectTokenCountQuery, userID)
  469. if err != nil {
  470. return nil, err
  471. }
  472. defer rows.Close()
  473. if !rows.Next() {
  474. return nil, errNoRows
  475. }
  476. var tokenCount int
  477. if err := rows.Scan(&tokenCount); err != nil {
  478. return nil, err
  479. }
  480. if tokenCount >= tokenMaxCount {
  481. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  482. // on two indices, whereas the query below is a full table scan.
  483. if _, err := tx.Exec(deleteExcessTokensQuery, userID, tokenMaxCount); err != nil {
  484. return nil, err
  485. }
  486. }
  487. if err := tx.Commit(); err != nil {
  488. return nil, err
  489. }
  490. return &Token{
  491. Value: token,
  492. Label: label,
  493. LastAccess: access,
  494. LastOrigin: origin,
  495. Expires: expires,
  496. }, nil
  497. }
  498. // Tokens returns all existing tokens for the user with the given user ID
  499. func (a *Manager) Tokens(userID string) ([]*Token, error) {
  500. rows, err := a.db.Query(selectTokensQuery, userID)
  501. if err != nil {
  502. return nil, err
  503. }
  504. defer rows.Close()
  505. tokens := make([]*Token, 0)
  506. for {
  507. token, err := a.readToken(rows)
  508. if err == ErrTokenNotFound {
  509. break
  510. } else if err != nil {
  511. return nil, err
  512. }
  513. tokens = append(tokens, token)
  514. }
  515. return tokens, nil
  516. }
  517. // Token returns a specific token for a user
  518. func (a *Manager) Token(userID, token string) (*Token, error) {
  519. rows, err := a.db.Query(selectTokenQuery, userID, token)
  520. if err != nil {
  521. return nil, err
  522. }
  523. defer rows.Close()
  524. return a.readToken(rows)
  525. }
  526. func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
  527. var token, label, lastOrigin string
  528. var lastAccess, expires int64
  529. if !rows.Next() {
  530. return nil, ErrTokenNotFound
  531. }
  532. if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
  533. return nil, err
  534. } else if err := rows.Err(); err != nil {
  535. return nil, err
  536. }
  537. lastOriginIP, err := netip.ParseAddr(lastOrigin)
  538. if err != nil {
  539. lastOriginIP = netip.IPv4Unspecified()
  540. }
  541. return &Token{
  542. Value: token,
  543. Label: label,
  544. LastAccess: time.Unix(lastAccess, 0),
  545. LastOrigin: lastOriginIP,
  546. Expires: time.Unix(expires, 0),
  547. }, nil
  548. }
  549. // ChangeToken updates a token's label and/or expiry date
  550. func (a *Manager) ChangeToken(userID, token string, label *string, expires *time.Time) (*Token, error) {
  551. if token == "" {
  552. return nil, errNoTokenProvided
  553. }
  554. tx, err := a.db.Begin()
  555. if err != nil {
  556. return nil, err
  557. }
  558. defer tx.Rollback()
  559. if label != nil {
  560. if _, err := tx.Exec(updateTokenLabelQuery, *label, userID, token); err != nil {
  561. return nil, err
  562. }
  563. }
  564. if expires != nil {
  565. if _, err := tx.Exec(updateTokenExpiryQuery, expires.Unix(), userID, token); err != nil {
  566. return nil, err
  567. }
  568. }
  569. if err := tx.Commit(); err != nil {
  570. return nil, err
  571. }
  572. return a.Token(userID, token)
  573. }
  574. // RemoveToken deletes the token defined in User.Token
  575. func (a *Manager) RemoveToken(userID, token string) error {
  576. if token == "" {
  577. return errNoTokenProvided
  578. }
  579. if _, err := a.db.Exec(deleteTokenQuery, userID, token); err != nil {
  580. return err
  581. }
  582. return nil
  583. }
  584. // RemoveExpiredTokens deletes all expired tokens from the database
  585. func (a *Manager) RemoveExpiredTokens() error {
  586. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  587. return err
  588. }
  589. return nil
  590. }
  591. // RemoveDeletedUsers deletes all users that have been marked deleted for
  592. func (a *Manager) RemoveDeletedUsers() error {
  593. if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
  594. return err
  595. }
  596. return nil
  597. }
  598. // ChangeSettings persists the user settings
  599. func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error {
  600. b, err := json.Marshal(prefs)
  601. if err != nil {
  602. return err
  603. }
  604. if _, err := a.db.Exec(updateUserPrefsQuery, string(b), userID); err != nil {
  605. return err
  606. }
  607. return nil
  608. }
  609. // ResetStats resets all user stats in the user database. This touches all users.
  610. func (a *Manager) ResetStats() error {
  611. a.mu.Lock() // Includes database query to avoid races!
  612. defer a.mu.Unlock()
  613. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  614. return err
  615. }
  616. a.statsQueue = make(map[string]*Stats)
  617. return nil
  618. }
  619. // EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  620. // batches at a regular interval
  621. func (a *Manager) EnqueueUserStats(userID string, stats *Stats) {
  622. a.mu.Lock()
  623. defer a.mu.Unlock()
  624. a.statsQueue[userID] = stats
  625. }
  626. // EnqueueTokenUpdate adds the token update to a queue which writes out token access times
  627. // in batches at a regular interval
  628. func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
  629. a.mu.Lock()
  630. defer a.mu.Unlock()
  631. a.tokenQueue[tokenID] = update
  632. }
  633. func (a *Manager) asyncQueueWriter(interval time.Duration) {
  634. ticker := time.NewTicker(interval)
  635. for range ticker.C {
  636. if err := a.writeUserStatsQueue(); err != nil {
  637. log.Tag(tag).Err(err).Warn("Writing user stats queue failed")
  638. }
  639. if err := a.writeTokenUpdateQueue(); err != nil {
  640. log.Tag(tag).Err(err).Warn("Writing token update queue failed")
  641. }
  642. }
  643. }
  644. func (a *Manager) writeUserStatsQueue() error {
  645. a.mu.Lock()
  646. if len(a.statsQueue) == 0 {
  647. a.mu.Unlock()
  648. log.Tag(tag).Trace("No user stats updates to commit")
  649. return nil
  650. }
  651. statsQueue := a.statsQueue
  652. a.statsQueue = make(map[string]*Stats)
  653. a.mu.Unlock()
  654. tx, err := a.db.Begin()
  655. if err != nil {
  656. return err
  657. }
  658. defer tx.Rollback()
  659. log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
  660. for userID, update := range statsQueue {
  661. log.
  662. Tag(tag).
  663. Fields(log.Context{
  664. "user_id": userID,
  665. "messages_count": update.Messages,
  666. "emails_count": update.Emails,
  667. }).
  668. Trace("Updating stats for user %s", userID)
  669. if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, userID); err != nil {
  670. return err
  671. }
  672. }
  673. return tx.Commit()
  674. }
  675. func (a *Manager) writeTokenUpdateQueue() error {
  676. a.mu.Lock()
  677. if len(a.tokenQueue) == 0 {
  678. a.mu.Unlock()
  679. log.Tag(tag).Trace("No token updates to commit")
  680. return nil
  681. }
  682. tokenQueue := a.tokenQueue
  683. a.tokenQueue = make(map[string]*TokenUpdate)
  684. a.mu.Unlock()
  685. tx, err := a.db.Begin()
  686. if err != nil {
  687. return err
  688. }
  689. defer tx.Rollback()
  690. log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
  691. for tokenID, update := range tokenQueue {
  692. log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
  693. if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
  694. return err
  695. }
  696. }
  697. return tx.Commit()
  698. }
  699. // Authorize returns nil if the given user has access to the given topic using the desired
  700. // permission. The user param may be nil to signal an anonymous user.
  701. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  702. if user != nil && user.Role == RoleAdmin {
  703. return nil // Admin can do everything
  704. }
  705. username := Everyone
  706. if user != nil {
  707. username = user.Name
  708. }
  709. // Select the read/write permissions for this user/topic combo. The query may return two
  710. // rows (one for everyone, and one for the user), but prioritizes the user.
  711. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  712. if err != nil {
  713. return err
  714. }
  715. defer rows.Close()
  716. if !rows.Next() {
  717. return a.resolvePerms(a.defaultAccess, perm)
  718. }
  719. var read, write bool
  720. if err := rows.Scan(&read, &write); err != nil {
  721. return err
  722. } else if err := rows.Err(); err != nil {
  723. return err
  724. }
  725. return a.resolvePerms(NewPermission(read, write), perm)
  726. }
  727. func (a *Manager) resolvePerms(base, perm Permission) error {
  728. if perm == PermissionRead && base.IsRead() {
  729. return nil
  730. } else if perm == PermissionWrite && base.IsWrite() {
  731. return nil
  732. }
  733. return ErrUnauthorized
  734. }
  735. // AddUser adds a user with the given username, password and role
  736. func (a *Manager) AddUser(username, password string, role Role) error {
  737. if !AllowedUsername(username) || !AllowedRole(role) {
  738. return ErrInvalidArgument
  739. }
  740. hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
  741. if err != nil {
  742. return err
  743. }
  744. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  745. syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
  746. if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
  747. return err
  748. }
  749. return nil
  750. }
  751. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  752. // if the user did not exist in the first place.
  753. func (a *Manager) RemoveUser(username string) error {
  754. if !AllowedUsername(username) {
  755. return ErrInvalidArgument
  756. }
  757. // Rows in user_access, user_token, etc. are deleted via foreign keys
  758. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  759. return err
  760. }
  761. return nil
  762. }
  763. // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
  764. // successful auth via Authenticate. A background process will delete the user at a later date.
  765. func (a *Manager) MarkUserRemoved(user *User) error {
  766. if !AllowedUsername(user.Name) {
  767. return ErrInvalidArgument
  768. }
  769. tx, err := a.db.Begin()
  770. if err != nil {
  771. return err
  772. }
  773. defer tx.Rollback()
  774. if _, err := tx.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
  775. return err
  776. }
  777. if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
  778. return err
  779. }
  780. if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
  781. return err
  782. }
  783. return tx.Commit()
  784. }
  785. // Users returns a list of users. It always also returns the Everyone user ("*").
  786. func (a *Manager) Users() ([]*User, error) {
  787. rows, err := a.db.Query(selectUsernamesQuery)
  788. if err != nil {
  789. return nil, err
  790. }
  791. defer rows.Close()
  792. usernames := make([]string, 0)
  793. for rows.Next() {
  794. var username string
  795. if err := rows.Scan(&username); err != nil {
  796. return nil, err
  797. } else if err := rows.Err(); err != nil {
  798. return nil, err
  799. }
  800. usernames = append(usernames, username)
  801. }
  802. rows.Close()
  803. users := make([]*User, 0)
  804. for _, username := range usernames {
  805. user, err := a.User(username)
  806. if err != nil {
  807. return nil, err
  808. }
  809. users = append(users, user)
  810. }
  811. return users, nil
  812. }
  813. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  814. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  815. func (a *Manager) User(username string) (*User, error) {
  816. rows, err := a.db.Query(selectUserByNameQuery, username)
  817. if err != nil {
  818. return nil, err
  819. }
  820. return a.readUser(rows)
  821. }
  822. // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
  823. func (a *Manager) UserByID(id string) (*User, error) {
  824. rows, err := a.db.Query(selectUserByIDQuery, id)
  825. if err != nil {
  826. return nil, err
  827. }
  828. return a.readUser(rows)
  829. }
  830. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  831. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  832. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  833. if err != nil {
  834. return nil, err
  835. }
  836. return a.readUser(rows)
  837. }
  838. func (a *Manager) userByToken(token string) (*User, error) {
  839. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  840. if err != nil {
  841. return nil, err
  842. }
  843. return a.readUser(rows)
  844. }
  845. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  846. defer rows.Close()
  847. var id, username, hash, role, prefs, syncTopic string
  848. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
  849. var messages, emails int64
  850. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
  851. if !rows.Next() {
  852. return nil, ErrUserNotFound
  853. }
  854. if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
  855. return nil, err
  856. } else if err := rows.Err(); err != nil {
  857. return nil, err
  858. }
  859. user := &User{
  860. ID: id,
  861. Name: username,
  862. Hash: hash,
  863. Role: Role(role),
  864. Prefs: &Prefs{},
  865. SyncTopic: syncTopic,
  866. Stats: &Stats{
  867. Messages: messages,
  868. Emails: emails,
  869. },
  870. Billing: &Billing{
  871. StripeCustomerID: stripeCustomerID.String, // May be empty
  872. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  873. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  874. StripeSubscriptionInterval: stripe.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty
  875. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  876. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  877. },
  878. Deleted: deleted.Valid,
  879. }
  880. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  881. return nil, err
  882. }
  883. if tierCode.Valid {
  884. // See readTier() when this is changed!
  885. user.Tier = &Tier{
  886. ID: tierID.String,
  887. Code: tierCode.String,
  888. Name: tierName.String,
  889. MessageLimit: messagesLimit.Int64,
  890. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  891. EmailLimit: emailsLimit.Int64,
  892. ReservationLimit: reservationsLimit.Int64,
  893. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  894. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  895. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  896. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  897. StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
  898. StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
  899. }
  900. }
  901. return user, nil
  902. }
  903. // Grants returns all user-specific access control entries
  904. func (a *Manager) Grants(username string) ([]Grant, error) {
  905. rows, err := a.db.Query(selectUserAccessQuery, username)
  906. if err != nil {
  907. return nil, err
  908. }
  909. defer rows.Close()
  910. grants := make([]Grant, 0)
  911. for rows.Next() {
  912. var topic string
  913. var read, write bool
  914. if err := rows.Scan(&topic, &read, &write); err != nil {
  915. return nil, err
  916. } else if err := rows.Err(); err != nil {
  917. return nil, err
  918. }
  919. grants = append(grants, Grant{
  920. TopicPattern: fromSQLWildcard(topic),
  921. Allow: NewPermission(read, write),
  922. })
  923. }
  924. return grants, nil
  925. }
  926. // Reservations returns all user-owned topics, and the associated everyone-access
  927. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  928. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  929. if err != nil {
  930. return nil, err
  931. }
  932. defer rows.Close()
  933. reservations := make([]Reservation, 0)
  934. for rows.Next() {
  935. var topic string
  936. var ownerRead, ownerWrite bool
  937. var everyoneRead, everyoneWrite sql.NullBool
  938. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  939. return nil, err
  940. } else if err := rows.Err(); err != nil {
  941. return nil, err
  942. }
  943. reservations = append(reservations, Reservation{
  944. Topic: topic,
  945. Owner: NewPermission(ownerRead, ownerWrite),
  946. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  947. })
  948. }
  949. return reservations, nil
  950. }
  951. // HasReservation returns true if the given topic access is owned by the user
  952. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  953. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  954. if err != nil {
  955. return false, err
  956. }
  957. defer rows.Close()
  958. if !rows.Next() {
  959. return false, errNoRows
  960. }
  961. var count int64
  962. if err := rows.Scan(&count); err != nil {
  963. return false, err
  964. }
  965. return count > 0, nil
  966. }
  967. // ReservationsCount returns the number of reservations owned by this user
  968. func (a *Manager) ReservationsCount(username string) (int64, error) {
  969. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  970. if err != nil {
  971. return 0, err
  972. }
  973. defer rows.Close()
  974. if !rows.Next() {
  975. return 0, errNoRows
  976. }
  977. var count int64
  978. if err := rows.Scan(&count); err != nil {
  979. return 0, err
  980. }
  981. return count, nil
  982. }
  983. // ReservationOwner returns user ID of the user that owns this topic, or an
  984. // empty string if it's not owned by anyone
  985. func (a *Manager) ReservationOwner(topic string) (string, error) {
  986. rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic)
  987. if err != nil {
  988. return "", err
  989. }
  990. defer rows.Close()
  991. if !rows.Next() {
  992. return "", nil
  993. }
  994. var ownerUserID string
  995. if err := rows.Scan(&ownerUserID); err != nil {
  996. return "", err
  997. }
  998. return ownerUserID, nil
  999. }
  1000. // ChangePassword changes a user's password
  1001. func (a *Manager) ChangePassword(username, password string) error {
  1002. hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
  1003. if err != nil {
  1004. return err
  1005. }
  1006. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  1007. return err
  1008. }
  1009. return nil
  1010. }
  1011. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  1012. // all existing access control entries (Grant) are removed, since they are no longer needed.
  1013. func (a *Manager) ChangeRole(username string, role Role) error {
  1014. if !AllowedUsername(username) || !AllowedRole(role) {
  1015. return ErrInvalidArgument
  1016. }
  1017. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  1018. return err
  1019. }
  1020. if role == RoleAdmin {
  1021. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  1022. return err
  1023. }
  1024. }
  1025. return nil
  1026. }
  1027. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  1028. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  1029. func (a *Manager) ChangeTier(username, tier string) error {
  1030. if !AllowedUsername(username) {
  1031. return ErrInvalidArgument
  1032. }
  1033. t, err := a.Tier(tier)
  1034. if err != nil {
  1035. return err
  1036. } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
  1037. return err
  1038. }
  1039. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  1040. return err
  1041. }
  1042. return nil
  1043. }
  1044. // ResetTier removes the tier from the given user
  1045. func (a *Manager) ResetTier(username string) error {
  1046. if !AllowedUsername(username) && username != Everyone && username != "" {
  1047. return ErrInvalidArgument
  1048. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  1049. return err
  1050. }
  1051. _, err := a.db.Exec(deleteUserTierQuery, username)
  1052. return err
  1053. }
  1054. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  1055. u, err := a.User(username)
  1056. if err != nil {
  1057. return err
  1058. }
  1059. if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
  1060. reservations, err := a.Reservations(username)
  1061. if err != nil {
  1062. return err
  1063. } else if int64(len(reservations)) > reservationsLimit {
  1064. return ErrTooManyReservations
  1065. }
  1066. }
  1067. return nil
  1068. }
  1069. // AllowReservation tests if a user may create an access control entry for the given topic.
  1070. // If there are any ACL entries that are not owned by the user, an error is returned.
  1071. func (a *Manager) AllowReservation(username string, topic string) error {
  1072. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  1073. return ErrInvalidArgument
  1074. }
  1075. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  1076. if err != nil {
  1077. return err
  1078. }
  1079. defer rows.Close()
  1080. if !rows.Next() {
  1081. return errNoRows
  1082. }
  1083. var otherCount int
  1084. if err := rows.Scan(&otherCount); err != nil {
  1085. return err
  1086. }
  1087. if otherCount > 0 {
  1088. return errTopicOwnedByOthers
  1089. }
  1090. return nil
  1091. }
  1092. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  1093. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  1094. // owner may either be a user (username), or the system (empty).
  1095. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  1096. if !AllowedUsername(username) && username != Everyone {
  1097. return ErrInvalidArgument
  1098. } else if !AllowedTopicPattern(topicPattern) {
  1099. return ErrInvalidArgument
  1100. }
  1101. owner := ""
  1102. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
  1103. return err
  1104. }
  1105. return nil
  1106. }
  1107. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  1108. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  1109. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  1110. if !AllowedUsername(username) && username != Everyone && username != "" {
  1111. return ErrInvalidArgument
  1112. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  1113. return ErrInvalidArgument
  1114. }
  1115. if username == "" && topicPattern == "" {
  1116. _, err := a.db.Exec(deleteAllAccessQuery, username)
  1117. return err
  1118. } else if topicPattern == "" {
  1119. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  1120. return err
  1121. }
  1122. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  1123. return err
  1124. }
  1125. // AddReservation creates two access control entries for the given topic: one with full read/write access for the
  1126. // given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
  1127. // can modify or delete them.
  1128. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
  1129. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  1130. return ErrInvalidArgument
  1131. }
  1132. tx, err := a.db.Begin()
  1133. if err != nil {
  1134. return err
  1135. }
  1136. defer tx.Rollback()
  1137. if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
  1138. return err
  1139. }
  1140. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
  1141. return err
  1142. }
  1143. return tx.Commit()
  1144. }
  1145. // RemoveReservations deletes the access control entries associated with the given username/topic, as
  1146. // well as all entries with Everyone/topic. This is the counterpart for AddReservation.
  1147. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  1148. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  1149. return ErrInvalidArgument
  1150. }
  1151. for _, topic := range topics {
  1152. if !AllowedTopic(topic) {
  1153. return ErrInvalidArgument
  1154. }
  1155. }
  1156. tx, err := a.db.Begin()
  1157. if err != nil {
  1158. return err
  1159. }
  1160. defer tx.Rollback()
  1161. for _, topic := range topics {
  1162. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
  1163. return err
  1164. }
  1165. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
  1166. return err
  1167. }
  1168. }
  1169. return tx.Commit()
  1170. }
  1171. // DefaultAccess returns the default read/write access if no access control entry matches
  1172. func (a *Manager) DefaultAccess() Permission {
  1173. return a.defaultAccess
  1174. }
  1175. // AddTier creates a new tier in the database
  1176. func (a *Manager) AddTier(tier *Tier) error {
  1177. if tier.ID == "" {
  1178. tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
  1179. }
  1180. if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
  1181. return err
  1182. }
  1183. return nil
  1184. }
  1185. // UpdateTier updates a tier's properties in the database
  1186. func (a *Manager) UpdateTier(tier *Tier) error {
  1187. if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
  1188. return err
  1189. }
  1190. return nil
  1191. }
  1192. // RemoveTier deletes the tier with the given code
  1193. func (a *Manager) RemoveTier(code string) error {
  1194. if !AllowedTier(code) {
  1195. return ErrInvalidArgument
  1196. }
  1197. // This fails if any user has this tier
  1198. if _, err := a.db.Exec(deleteTierQuery, code); err != nil {
  1199. return err
  1200. }
  1201. return nil
  1202. }
  1203. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  1204. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  1205. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  1206. return err
  1207. }
  1208. return nil
  1209. }
  1210. // Tiers returns a list of all Tier structs
  1211. func (a *Manager) Tiers() ([]*Tier, error) {
  1212. rows, err := a.db.Query(selectTiersQuery)
  1213. if err != nil {
  1214. return nil, err
  1215. }
  1216. defer rows.Close()
  1217. tiers := make([]*Tier, 0)
  1218. for {
  1219. tier, err := a.readTier(rows)
  1220. if err == ErrTierNotFound {
  1221. break
  1222. } else if err != nil {
  1223. return nil, err
  1224. }
  1225. tiers = append(tiers, tier)
  1226. }
  1227. return tiers, nil
  1228. }
  1229. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  1230. func (a *Manager) Tier(code string) (*Tier, error) {
  1231. rows, err := a.db.Query(selectTierByCodeQuery, code)
  1232. if err != nil {
  1233. return nil, err
  1234. }
  1235. defer rows.Close()
  1236. return a.readTier(rows)
  1237. }
  1238. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  1239. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  1240. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID, priceID)
  1241. if err != nil {
  1242. return nil, err
  1243. }
  1244. defer rows.Close()
  1245. return a.readTier(rows)
  1246. }
  1247. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  1248. var id, code, name string
  1249. var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
  1250. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
  1251. if !rows.Next() {
  1252. return nil, ErrTierNotFound
  1253. }
  1254. if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
  1255. return nil, err
  1256. } else if err := rows.Err(); err != nil {
  1257. return nil, err
  1258. }
  1259. // When changed, note readUser() as well
  1260. return &Tier{
  1261. ID: id,
  1262. Code: code,
  1263. Name: name,
  1264. MessageLimit: messagesLimit.Int64,
  1265. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1266. EmailLimit: emailsLimit.Int64,
  1267. ReservationLimit: reservationsLimit.Int64,
  1268. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1269. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1270. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1271. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  1272. StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
  1273. StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
  1274. }, nil
  1275. }
  1276. // Close closes the underlying database
  1277. func (a *Manager) Close() error {
  1278. return a.db.Close()
  1279. }
  1280. func toSQLWildcard(s string) string {
  1281. return strings.ReplaceAll(s, "*", "%")
  1282. }
  1283. func fromSQLWildcard(s string) string {
  1284. return strings.ReplaceAll(s, "%", "*")
  1285. }
  1286. func runStartupQueries(db *sql.DB, startupQueries string) error {
  1287. if _, err := db.Exec(startupQueries); err != nil {
  1288. return err
  1289. }
  1290. if _, err := db.Exec(builtinStartupQueries); err != nil {
  1291. return err
  1292. }
  1293. return nil
  1294. }
  1295. func setupDB(db *sql.DB) error {
  1296. // If 'schemaVersion' table does not exist, this must be a new database
  1297. rowsSV, err := db.Query(selectSchemaVersionQuery)
  1298. if err != nil {
  1299. return setupNewDB(db)
  1300. }
  1301. defer rowsSV.Close()
  1302. // If 'schemaVersion' table exists, read version and potentially upgrade
  1303. schemaVersion := 0
  1304. if !rowsSV.Next() {
  1305. return errors.New("cannot determine schema version: database file may be corrupt")
  1306. }
  1307. if err := rowsSV.Scan(&schemaVersion); err != nil {
  1308. return err
  1309. }
  1310. rowsSV.Close()
  1311. // Do migrations
  1312. if schemaVersion == currentSchemaVersion {
  1313. return nil
  1314. } else if schemaVersion > currentSchemaVersion {
  1315. return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
  1316. }
  1317. for i := schemaVersion; i < currentSchemaVersion; i++ {
  1318. fn, ok := migrations[i]
  1319. if !ok {
  1320. return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
  1321. } else if err := fn(db); err != nil {
  1322. return err
  1323. }
  1324. }
  1325. return nil
  1326. }
  1327. func setupNewDB(db *sql.DB) error {
  1328. if _, err := db.Exec(createTablesQueries); err != nil {
  1329. return err
  1330. }
  1331. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  1332. return err
  1333. }
  1334. return nil
  1335. }
  1336. func migrateFrom1(db *sql.DB) error {
  1337. log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
  1338. tx, err := db.Begin()
  1339. if err != nil {
  1340. return err
  1341. }
  1342. defer tx.Rollback()
  1343. // Rename user -> user_old, and create new tables
  1344. if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
  1345. return err
  1346. }
  1347. // Insert users from user_old into new user table, with ID and sync_topic
  1348. rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
  1349. if err != nil {
  1350. return err
  1351. }
  1352. defer rows.Close()
  1353. usernames := make([]string, 0)
  1354. for rows.Next() {
  1355. var username string
  1356. if err := rows.Scan(&username); err != nil {
  1357. return err
  1358. }
  1359. usernames = append(usernames, username)
  1360. }
  1361. if err := rows.Close(); err != nil {
  1362. return err
  1363. }
  1364. for _, username := range usernames {
  1365. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1366. syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
  1367. if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
  1368. return err
  1369. }
  1370. }
  1371. // Migrate old "access" table to "user_access" and drop "access" and "user_old"
  1372. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1373. return err
  1374. }
  1375. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1376. return err
  1377. }
  1378. if err := tx.Commit(); err != nil {
  1379. return err
  1380. }
  1381. return nil
  1382. }
  1383. func migrateFrom2(db *sql.DB) error {
  1384. log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
  1385. tx, err := db.Begin()
  1386. if err != nil {
  1387. return err
  1388. }
  1389. defer tx.Rollback()
  1390. if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
  1391. return err
  1392. }
  1393. if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
  1394. return err
  1395. }
  1396. return tx.Commit()
  1397. }
  1398. func nullString(s string) sql.NullString {
  1399. if s == "" {
  1400. return sql.NullString{}
  1401. }
  1402. return sql.NullString{String: s, Valid: true}
  1403. }
  1404. func nullInt64(v int64) sql.NullInt64 {
  1405. if v == 0 {
  1406. return sql.NullInt64{}
  1407. }
  1408. return sql.NullInt64{Int64: v, Valid: true}
  1409. }