user.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strings"
  7. "github.com/usememos/memos/api"
  8. "github.com/usememos/memos/common"
  9. )
  10. // Role is the type of a role.
  11. type Role string
  12. const (
  13. // Host is the HOST role.
  14. Host Role = "HOST"
  15. // Admin is the ADMIN role.
  16. Admin Role = "ADMIN"
  17. // NormalUser is the USER role.
  18. NormalUser Role = "USER"
  19. )
  20. func (e Role) String() string {
  21. switch e {
  22. case Host:
  23. return "HOST"
  24. case Admin:
  25. return "ADMIN"
  26. case NormalUser:
  27. return "USER"
  28. }
  29. return "USER"
  30. }
  31. type UserMessage struct {
  32. ID int
  33. // Standard fields
  34. RowStatus RowStatus
  35. CreatedTs int64
  36. UpdatedTs int64
  37. // Domain specific fields
  38. Username string
  39. Role Role
  40. Email string
  41. Nickname string
  42. PasswordHash string
  43. OpenID string
  44. AvatarURL string
  45. }
  46. type FindUserMessage struct {
  47. ID *int
  48. // Standard fields
  49. RowStatus *RowStatus
  50. // Domain specific fields
  51. Username *string
  52. Role *Role
  53. Email *string
  54. Nickname *string
  55. OpenID *string
  56. }
  57. func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMessage, error) {
  58. tx, err := s.db.BeginTx(ctx, nil)
  59. if err != nil {
  60. return nil, FormatError(err)
  61. }
  62. defer tx.Rollback()
  63. query := `
  64. INSERT INTO user (
  65. username,
  66. role,
  67. email,
  68. nickname,
  69. password_hash,
  70. open_id
  71. )
  72. VALUES (?, ?, ?, ?, ?, ?)
  73. RETURNING id, avatar_url, created_ts, updated_ts, row_status
  74. `
  75. if err := tx.QueryRowContext(ctx, query,
  76. create.Username,
  77. create.Role,
  78. create.Email,
  79. create.Nickname,
  80. create.PasswordHash,
  81. create.OpenID,
  82. ).Scan(
  83. &create.ID,
  84. &create.AvatarURL,
  85. &create.CreatedTs,
  86. &create.UpdatedTs,
  87. &create.RowStatus,
  88. ); err != nil {
  89. return nil, FormatError(err)
  90. }
  91. if err := tx.Commit(); err != nil {
  92. return nil, FormatError(err)
  93. }
  94. userMessage := create
  95. return userMessage, nil
  96. }
  97. func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMessage, error) {
  98. tx, err := s.db.BeginTx(ctx, nil)
  99. if err != nil {
  100. return nil, FormatError(err)
  101. }
  102. defer tx.Rollback()
  103. list, err := listUsers(ctx, tx, find)
  104. if err != nil {
  105. return nil, err
  106. }
  107. return list, nil
  108. }
  109. func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessage, error) {
  110. tx, err := s.db.BeginTx(ctx, nil)
  111. if err != nil {
  112. return nil, FormatError(err)
  113. }
  114. defer tx.Rollback()
  115. list, err := listUsers(ctx, tx, find)
  116. if err != nil {
  117. return nil, err
  118. }
  119. if len(list) == 0 {
  120. return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user not found")}
  121. }
  122. memoMessage := list[0]
  123. return memoMessage, nil
  124. }
  125. func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserMessage, error) {
  126. where, args := []string{"1 = 1"}, []any{}
  127. if v := find.ID; v != nil {
  128. where, args = append(where, "id = ?"), append(args, *v)
  129. }
  130. if v := find.Username; v != nil {
  131. where, args = append(where, "username = ?"), append(args, *v)
  132. }
  133. if v := find.Role; v != nil {
  134. where, args = append(where, "role = ?"), append(args, *v)
  135. }
  136. if v := find.Email; v != nil {
  137. where, args = append(where, "email = ?"), append(args, *v)
  138. }
  139. if v := find.Nickname; v != nil {
  140. where, args = append(where, "nickname = ?"), append(args, *v)
  141. }
  142. if v := find.OpenID; v != nil {
  143. where, args = append(where, "open_id = ?"), append(args, *v)
  144. }
  145. query := `
  146. SELECT
  147. id,
  148. username,
  149. role,
  150. email,
  151. nickname,
  152. password_hash,
  153. open_id,
  154. avatar_url,
  155. created_ts,
  156. updated_ts,
  157. row_status
  158. FROM user
  159. WHERE ` + strings.Join(where, " AND ") + `
  160. ORDER BY created_ts DESC, row_status DESC
  161. `
  162. rows, err := tx.QueryContext(ctx, query, args...)
  163. if err != nil {
  164. return nil, FormatError(err)
  165. }
  166. defer rows.Close()
  167. userMessageList := make([]*UserMessage, 0)
  168. for rows.Next() {
  169. var userMessage UserMessage
  170. if err := rows.Scan(
  171. &userMessage.ID,
  172. &userMessage.Username,
  173. &userMessage.Role,
  174. &userMessage.Email,
  175. &userMessage.Nickname,
  176. &userMessage.PasswordHash,
  177. &userMessage.OpenID,
  178. &userMessage.AvatarURL,
  179. &userMessage.CreatedTs,
  180. &userMessage.UpdatedTs,
  181. &userMessage.RowStatus,
  182. ); err != nil {
  183. return nil, FormatError(err)
  184. }
  185. userMessageList = append(userMessageList, &userMessage)
  186. }
  187. if err := rows.Err(); err != nil {
  188. return nil, FormatError(err)
  189. }
  190. return userMessageList, nil
  191. }
  192. // userRaw is the store model for an User.
  193. // Fields have exactly the same meanings as User.
  194. type userRaw struct {
  195. ID int
  196. // Standard fields
  197. RowStatus api.RowStatus
  198. CreatedTs int64
  199. UpdatedTs int64
  200. // Domain specific fields
  201. Username string
  202. Role api.Role
  203. Email string
  204. Nickname string
  205. PasswordHash string
  206. OpenID string
  207. AvatarURL string
  208. }
  209. func (raw *userRaw) toUser() *api.User {
  210. return &api.User{
  211. ID: raw.ID,
  212. RowStatus: raw.RowStatus,
  213. CreatedTs: raw.CreatedTs,
  214. UpdatedTs: raw.UpdatedTs,
  215. Username: raw.Username,
  216. Role: raw.Role,
  217. Email: raw.Email,
  218. Nickname: raw.Nickname,
  219. PasswordHash: raw.PasswordHash,
  220. OpenID: raw.OpenID,
  221. AvatarURL: raw.AvatarURL,
  222. }
  223. }
  224. func (s *Store) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) {
  225. tx, err := s.db.BeginTx(ctx, nil)
  226. if err != nil {
  227. return nil, FormatError(err)
  228. }
  229. defer tx.Rollback()
  230. userRaw, err := createUser(ctx, tx, create)
  231. if err != nil {
  232. return nil, err
  233. }
  234. if err := tx.Commit(); err != nil {
  235. return nil, FormatError(err)
  236. }
  237. s.userCache.Store(userRaw.ID, userRaw)
  238. user := userRaw.toUser()
  239. return user, nil
  240. }
  241. func (s *Store) PatchUser(ctx context.Context, patch *api.UserPatch) (*api.User, error) {
  242. tx, err := s.db.BeginTx(ctx, nil)
  243. if err != nil {
  244. return nil, FormatError(err)
  245. }
  246. defer tx.Rollback()
  247. userRaw, err := patchUser(ctx, tx, patch)
  248. if err != nil {
  249. return nil, err
  250. }
  251. if err := tx.Commit(); err != nil {
  252. return nil, FormatError(err)
  253. }
  254. s.userCache.Store(userRaw.ID, userRaw)
  255. user := userRaw.toUser()
  256. return user, nil
  257. }
  258. func (s *Store) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) {
  259. tx, err := s.db.BeginTx(ctx, nil)
  260. if err != nil {
  261. return nil, FormatError(err)
  262. }
  263. defer tx.Rollback()
  264. userRawList, err := findUserList(ctx, tx, find)
  265. if err != nil {
  266. return nil, err
  267. }
  268. list := []*api.User{}
  269. for _, raw := range userRawList {
  270. list = append(list, raw.toUser())
  271. }
  272. return list, nil
  273. }
  274. func (s *Store) FindUser(ctx context.Context, find *api.UserFind) (*api.User, error) {
  275. if find.ID != nil {
  276. if user, ok := s.userCache.Load(*find.ID); ok {
  277. return user.(*userRaw).toUser(), nil
  278. }
  279. }
  280. tx, err := s.db.BeginTx(ctx, nil)
  281. if err != nil {
  282. return nil, FormatError(err)
  283. }
  284. defer tx.Rollback()
  285. list, err := findUserList(ctx, tx, find)
  286. if err != nil {
  287. return nil, err
  288. }
  289. if len(list) == 0 {
  290. return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found user with filter %+v", find)}
  291. }
  292. userRaw := list[0]
  293. s.userCache.Store(userRaw.ID, userRaw)
  294. user := userRaw.toUser()
  295. return user, nil
  296. }
  297. func (s *Store) DeleteUser(ctx context.Context, delete *api.UserDelete) error {
  298. tx, err := s.db.BeginTx(ctx, nil)
  299. if err != nil {
  300. return FormatError(err)
  301. }
  302. defer tx.Rollback()
  303. if err := deleteUser(ctx, tx, delete); err != nil {
  304. return err
  305. }
  306. if err := s.vacuumImpl(ctx, tx); err != nil {
  307. return err
  308. }
  309. if err := tx.Commit(); err != nil {
  310. return err
  311. }
  312. s.userCache.Delete(delete.ID)
  313. return nil
  314. }
  315. func createUser(ctx context.Context, tx *sql.Tx, create *api.UserCreate) (*userRaw, error) {
  316. query := `
  317. INSERT INTO user (
  318. username,
  319. role,
  320. email,
  321. nickname,
  322. password_hash,
  323. open_id
  324. )
  325. VALUES (?, ?, ?, ?, ?, ?)
  326. RETURNING id, username, role, email, nickname, password_hash, open_id, avatar_url, created_ts, updated_ts, row_status
  327. `
  328. var userRaw userRaw
  329. if err := tx.QueryRowContext(ctx, query,
  330. create.Username,
  331. create.Role,
  332. create.Email,
  333. create.Nickname,
  334. create.PasswordHash,
  335. create.OpenID,
  336. ).Scan(
  337. &userRaw.ID,
  338. &userRaw.Username,
  339. &userRaw.Role,
  340. &userRaw.Email,
  341. &userRaw.Nickname,
  342. &userRaw.PasswordHash,
  343. &userRaw.OpenID,
  344. &userRaw.AvatarURL,
  345. &userRaw.CreatedTs,
  346. &userRaw.UpdatedTs,
  347. &userRaw.RowStatus,
  348. ); err != nil {
  349. return nil, FormatError(err)
  350. }
  351. return &userRaw, nil
  352. }
  353. func patchUser(ctx context.Context, tx *sql.Tx, patch *api.UserPatch) (*userRaw, error) {
  354. set, args := []string{}, []any{}
  355. if v := patch.UpdatedTs; v != nil {
  356. set, args = append(set, "updated_ts = ?"), append(args, *v)
  357. }
  358. if v := patch.RowStatus; v != nil {
  359. set, args = append(set, "row_status = ?"), append(args, *v)
  360. }
  361. if v := patch.Username; v != nil {
  362. set, args = append(set, "username = ?"), append(args, *v)
  363. }
  364. if v := patch.Email; v != nil {
  365. set, args = append(set, "email = ?"), append(args, *v)
  366. }
  367. if v := patch.Nickname; v != nil {
  368. set, args = append(set, "nickname = ?"), append(args, *v)
  369. }
  370. if v := patch.AvatarURL; v != nil {
  371. set, args = append(set, "avatar_url = ?"), append(args, *v)
  372. }
  373. if v := patch.PasswordHash; v != nil {
  374. set, args = append(set, "password_hash = ?"), append(args, *v)
  375. }
  376. if v := patch.OpenID; v != nil {
  377. set, args = append(set, "open_id = ?"), append(args, *v)
  378. }
  379. args = append(args, patch.ID)
  380. query := `
  381. UPDATE user
  382. SET ` + strings.Join(set, ", ") + `
  383. WHERE id = ?
  384. RETURNING id, username, role, email, nickname, password_hash, open_id, avatar_url, created_ts, updated_ts, row_status
  385. `
  386. var userRaw userRaw
  387. if err := tx.QueryRowContext(ctx, query, args...).Scan(
  388. &userRaw.ID,
  389. &userRaw.Username,
  390. &userRaw.Role,
  391. &userRaw.Email,
  392. &userRaw.Nickname,
  393. &userRaw.PasswordHash,
  394. &userRaw.OpenID,
  395. &userRaw.AvatarURL,
  396. &userRaw.CreatedTs,
  397. &userRaw.UpdatedTs,
  398. &userRaw.RowStatus,
  399. ); err != nil {
  400. return nil, FormatError(err)
  401. }
  402. return &userRaw, nil
  403. }
  404. func findUserList(ctx context.Context, tx *sql.Tx, find *api.UserFind) ([]*userRaw, error) {
  405. where, args := []string{"1 = 1"}, []any{}
  406. if v := find.ID; v != nil {
  407. where, args = append(where, "id = ?"), append(args, *v)
  408. }
  409. if v := find.Username; v != nil {
  410. where, args = append(where, "username = ?"), append(args, *v)
  411. }
  412. if v := find.Role; v != nil {
  413. where, args = append(where, "role = ?"), append(args, *v)
  414. }
  415. if v := find.Email; v != nil {
  416. where, args = append(where, "email = ?"), append(args, *v)
  417. }
  418. if v := find.Nickname; v != nil {
  419. where, args = append(where, "nickname = ?"), append(args, *v)
  420. }
  421. if v := find.OpenID; v != nil {
  422. where, args = append(where, "open_id = ?"), append(args, *v)
  423. }
  424. query := `
  425. SELECT
  426. id,
  427. username,
  428. role,
  429. email,
  430. nickname,
  431. password_hash,
  432. open_id,
  433. avatar_url,
  434. created_ts,
  435. updated_ts,
  436. row_status
  437. FROM user
  438. WHERE ` + strings.Join(where, " AND ") + `
  439. ORDER BY created_ts DESC, row_status DESC
  440. `
  441. rows, err := tx.QueryContext(ctx, query, args...)
  442. if err != nil {
  443. return nil, FormatError(err)
  444. }
  445. defer rows.Close()
  446. userRawList := make([]*userRaw, 0)
  447. for rows.Next() {
  448. var userRaw userRaw
  449. if err := rows.Scan(
  450. &userRaw.ID,
  451. &userRaw.Username,
  452. &userRaw.Role,
  453. &userRaw.Email,
  454. &userRaw.Nickname,
  455. &userRaw.PasswordHash,
  456. &userRaw.OpenID,
  457. &userRaw.AvatarURL,
  458. &userRaw.CreatedTs,
  459. &userRaw.UpdatedTs,
  460. &userRaw.RowStatus,
  461. ); err != nil {
  462. return nil, FormatError(err)
  463. }
  464. userRawList = append(userRawList, &userRaw)
  465. }
  466. if err := rows.Err(); err != nil {
  467. return nil, FormatError(err)
  468. }
  469. return userRawList, nil
  470. }
  471. func deleteUser(ctx context.Context, tx *sql.Tx, delete *api.UserDelete) error {
  472. result, err := tx.ExecContext(ctx, `
  473. DELETE FROM user WHERE id = ?
  474. `, delete.ID)
  475. if err != nil {
  476. return FormatError(err)
  477. }
  478. rows, err := result.RowsAffected()
  479. if err != nil {
  480. return err
  481. }
  482. if rows == 0 {
  483. return &common.Error{Code: common.NotFound, Err: fmt.Errorf("user not found")}
  484. }
  485. return nil
  486. }