user.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package mysql
  2. import (
  3. "context"
  4. "strings"
  5. "github.com/pkg/errors"
  6. "github.com/usememos/memos/store"
  7. )
  8. func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
  9. fields := []string{"`username`", "`role`", "`email`", "`nickname`", "`password_hash`", "`avatar_url`"}
  10. placeholder := []string{"?", "?", "?", "?", "?", "?"}
  11. args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
  12. if create.RowStatus != "" {
  13. fields = append(fields, "`row_status`")
  14. placeholder = append(placeholder, "?")
  15. args = append(args, create.RowStatus)
  16. }
  17. if create.CreatedTs != 0 {
  18. fields = append(fields, "`created_ts`")
  19. placeholder = append(placeholder, "FROM_UNIXTIME(?)")
  20. args = append(args, create.CreatedTs)
  21. }
  22. if create.UpdatedTs != 0 {
  23. fields = append(fields, "`updated_ts`")
  24. placeholder = append(placeholder, "FROM_UNIXTIME(?)")
  25. args = append(args, create.UpdatedTs)
  26. }
  27. if create.ID != 0 {
  28. fields = append(fields, "`id`")
  29. placeholder = append(placeholder, "?")
  30. args = append(args, create.ID)
  31. }
  32. stmt := "INSERT INTO user (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
  33. result, err := d.db.ExecContext(ctx, stmt, args...)
  34. if err != nil {
  35. return nil, err
  36. }
  37. id, err := result.LastInsertId()
  38. if err != nil {
  39. return nil, err
  40. }
  41. id64 := int32(id)
  42. list, err := d.ListUsers(ctx, &store.FindUser{ID: &id64})
  43. if err != nil {
  44. return nil, err
  45. }
  46. if len(list) != 1 {
  47. return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
  48. }
  49. return list[0], nil
  50. }
  51. func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
  52. set, args := []string{}, []any{}
  53. if v := update.UpdatedTs; v != nil {
  54. set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
  55. }
  56. if v := update.RowStatus; v != nil {
  57. set, args = append(set, "`row_status` = ?"), append(args, *v)
  58. }
  59. if v := update.Username; v != nil {
  60. set, args = append(set, "`username` = ?"), append(args, *v)
  61. }
  62. if v := update.Email; v != nil {
  63. set, args = append(set, "`email` = ?"), append(args, *v)
  64. }
  65. if v := update.Nickname; v != nil {
  66. set, args = append(set, "`nickname` = ?"), append(args, *v)
  67. }
  68. if v := update.AvatarURL; v != nil {
  69. set, args = append(set, "`avatar_url` = ?"), append(args, *v)
  70. }
  71. if v := update.PasswordHash; v != nil {
  72. set, args = append(set, "`password_hash` = ?"), append(args, *v)
  73. }
  74. args = append(args, update.ID)
  75. query := "UPDATE `user` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
  76. if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
  77. return nil, err
  78. }
  79. user := &store.User{}
  80. query = "SELECT `id`, `username`, `role`, `email`, `nickname`, `password_hash`, `avatar_url`, UNIX_TIMESTAMP(`created_ts`), UNIX_TIMESTAMP(`updated_ts`), `row_status` FROM `user` WHERE `id` = ?"
  81. if err := d.db.QueryRowContext(ctx, query, update.ID).Scan(
  82. &user.ID,
  83. &user.Username,
  84. &user.Role,
  85. &user.Email,
  86. &user.Nickname,
  87. &user.PasswordHash,
  88. &user.AvatarURL,
  89. &user.CreatedTs,
  90. &user.UpdatedTs,
  91. &user.RowStatus,
  92. ); err != nil {
  93. return nil, err
  94. }
  95. return user, nil
  96. }
  97. func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
  98. where, args := []string{"1 = 1"}, []any{}
  99. if v := find.ID; v != nil {
  100. where, args = append(where, "`id` = ?"), append(args, *v)
  101. }
  102. if v := find.Username; v != nil {
  103. where, args = append(where, "`username` = ?"), append(args, *v)
  104. }
  105. if v := find.Role; v != nil {
  106. where, args = append(where, "`role` = ?"), append(args, *v)
  107. }
  108. if v := find.Email; v != nil {
  109. where, args = append(where, "`email` = ?"), append(args, *v)
  110. }
  111. if v := find.Nickname; v != nil {
  112. where, args = append(where, "`nickname` = ?"), append(args, *v)
  113. }
  114. query := "SELECT `id`, `username`, `role`, `email`, `nickname`, `password_hash`, `avatar_url`, UNIX_TIMESTAMP(`created_ts`), UNIX_TIMESTAMP(`updated_ts`), `row_status` FROM `user` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC, `row_status` DESC"
  115. rows, err := d.db.QueryContext(ctx, query, args...)
  116. if err != nil {
  117. return nil, err
  118. }
  119. defer rows.Close()
  120. list := make([]*store.User, 0)
  121. for rows.Next() {
  122. var user store.User
  123. if err := rows.Scan(
  124. &user.ID,
  125. &user.Username,
  126. &user.Role,
  127. &user.Email,
  128. &user.Nickname,
  129. &user.PasswordHash,
  130. &user.AvatarURL,
  131. &user.CreatedTs,
  132. &user.UpdatedTs,
  133. &user.RowStatus,
  134. ); err != nil {
  135. return nil, err
  136. }
  137. list = append(list, &user)
  138. }
  139. if err := rows.Err(); err != nil {
  140. return nil, err
  141. }
  142. return list, nil
  143. }
  144. func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
  145. result, err := d.db.ExecContext(ctx, "DELETE FROM `user` WHERE `id` = ?", delete.ID)
  146. if err != nil {
  147. return err
  148. }
  149. if _, err := result.RowsAffected(); err != nil {
  150. return err
  151. }
  152. if err := d.Vacuum(ctx); err != nil {
  153. // Prevent linter warning.
  154. return err
  155. }
  156. return nil
  157. }