user.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package mysql
  2. import (
  3. "context"
  4. "strings"
  5. "github.com/pkg/errors"
  6. "github.com/usememos/memos/store"
  7. )
  8. func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
  9. fields := []string{"`username`", "`role`", "`email`", "`nickname`", "`password_hash`", "`avatar_url`"}
  10. placeholder := []string{"?", "?", "?", "?", "?", "?"}
  11. args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
  12. stmt := "INSERT INTO user (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
  13. result, err := d.db.ExecContext(ctx, stmt, args...)
  14. if err != nil {
  15. return nil, err
  16. }
  17. id, err := result.LastInsertId()
  18. if err != nil {
  19. return nil, err
  20. }
  21. id32 := int32(id)
  22. list, err := d.ListUsers(ctx, &store.FindUser{ID: &id32})
  23. if err != nil {
  24. return nil, err
  25. }
  26. if len(list) != 1 {
  27. return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
  28. }
  29. return list[0], nil
  30. }
  31. func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
  32. set, args := []string{}, []any{}
  33. if v := update.UpdatedTs; v != nil {
  34. set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
  35. }
  36. if v := update.RowStatus; v != nil {
  37. set, args = append(set, "`row_status` = ?"), append(args, *v)
  38. }
  39. if v := update.Username; v != nil {
  40. set, args = append(set, "`username` = ?"), append(args, *v)
  41. }
  42. if v := update.Email; v != nil {
  43. set, args = append(set, "`email` = ?"), append(args, *v)
  44. }
  45. if v := update.Nickname; v != nil {
  46. set, args = append(set, "`nickname` = ?"), append(args, *v)
  47. }
  48. if v := update.AvatarURL; v != nil {
  49. set, args = append(set, "`avatar_url` = ?"), append(args, *v)
  50. }
  51. if v := update.PasswordHash; v != nil {
  52. set, args = append(set, "`password_hash` = ?"), append(args, *v)
  53. }
  54. args = append(args, update.ID)
  55. query := "UPDATE `user` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
  56. if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
  57. return nil, err
  58. }
  59. user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID})
  60. if err != nil {
  61. return nil, err
  62. }
  63. return user, nil
  64. }
  65. func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
  66. where, args := []string{"1 = 1"}, []any{}
  67. if v := find.ID; v != nil {
  68. where, args = append(where, "`id` = ?"), append(args, *v)
  69. }
  70. if v := find.Username; v != nil {
  71. where, args = append(where, "`username` = ?"), append(args, *v)
  72. }
  73. if v := find.Role; v != nil {
  74. where, args = append(where, "`role` = ?"), append(args, *v)
  75. }
  76. if v := find.Email; v != nil {
  77. where, args = append(where, "`email` = ?"), append(args, *v)
  78. }
  79. if v := find.Nickname; v != nil {
  80. where, args = append(where, "`nickname` = ?"), append(args, *v)
  81. }
  82. query := "SELECT `id`, `username`, `role`, `email`, `nickname`, `password_hash`, `avatar_url`, UNIX_TIMESTAMP(`created_ts`), UNIX_TIMESTAMP(`updated_ts`), `row_status` FROM `user` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC, `row_status` DESC"
  83. rows, err := d.db.QueryContext(ctx, query, args...)
  84. if err != nil {
  85. return nil, err
  86. }
  87. defer rows.Close()
  88. list := make([]*store.User, 0)
  89. for rows.Next() {
  90. var user store.User
  91. if err := rows.Scan(
  92. &user.ID,
  93. &user.Username,
  94. &user.Role,
  95. &user.Email,
  96. &user.Nickname,
  97. &user.PasswordHash,
  98. &user.AvatarURL,
  99. &user.CreatedTs,
  100. &user.UpdatedTs,
  101. &user.RowStatus,
  102. ); err != nil {
  103. return nil, err
  104. }
  105. list = append(list, &user)
  106. }
  107. if err := rows.Err(); err != nil {
  108. return nil, err
  109. }
  110. return list, nil
  111. }
  112. func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
  113. list, err := d.ListUsers(ctx, find)
  114. if err != nil {
  115. return nil, err
  116. }
  117. if len(list) != 1 {
  118. return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
  119. }
  120. return list[0], nil
  121. }
  122. func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
  123. result, err := d.db.ExecContext(ctx, "DELETE FROM `user` WHERE `id` = ?", delete.ID)
  124. if err != nil {
  125. return err
  126. }
  127. if _, err := result.RowsAffected(); err != nil {
  128. return err
  129. }
  130. if err := d.Vacuum(ctx); err != nil {
  131. // Prevent linter warning.
  132. return err
  133. }
  134. return nil
  135. }