user.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package mysql
  2. import (
  3. "context"
  4. "fmt"
  5. "slices"
  6. "strings"
  7. "github.com/pkg/errors"
  8. "github.com/usememos/memos/store"
  9. )
  10. func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
  11. fields := []string{"`username`", "`role`", "`email`", "`nickname`", "`password_hash`", "`avatar_url`"}
  12. placeholder := []string{"?", "?", "?", "?", "?", "?"}
  13. args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
  14. stmt := "INSERT INTO user (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
  15. result, err := d.db.ExecContext(ctx, stmt, args...)
  16. if err != nil {
  17. return nil, err
  18. }
  19. id, err := result.LastInsertId()
  20. if err != nil {
  21. return nil, err
  22. }
  23. id32 := int32(id)
  24. list, err := d.ListUsers(ctx, &store.FindUser{ID: &id32})
  25. if err != nil {
  26. return nil, err
  27. }
  28. if len(list) != 1 {
  29. return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
  30. }
  31. return list[0], nil
  32. }
  33. func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
  34. set, args := []string{}, []any{}
  35. if v := update.UpdatedTs; v != nil {
  36. set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
  37. }
  38. if v := update.RowStatus; v != nil {
  39. set, args = append(set, "`row_status` = ?"), append(args, *v)
  40. }
  41. if v := update.Username; v != nil {
  42. set, args = append(set, "`username` = ?"), append(args, *v)
  43. }
  44. if v := update.Email; v != nil {
  45. set, args = append(set, "`email` = ?"), append(args, *v)
  46. }
  47. if v := update.Nickname; v != nil {
  48. set, args = append(set, "`nickname` = ?"), append(args, *v)
  49. }
  50. if v := update.AvatarURL; v != nil {
  51. set, args = append(set, "`avatar_url` = ?"), append(args, *v)
  52. }
  53. if v := update.PasswordHash; v != nil {
  54. set, args = append(set, "`password_hash` = ?"), append(args, *v)
  55. }
  56. if v := update.Description; v != nil {
  57. set, args = append(set, "`description` = ?"), append(args, *v)
  58. }
  59. args = append(args, update.ID)
  60. query := "UPDATE `user` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
  61. if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
  62. return nil, err
  63. }
  64. user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID})
  65. if err != nil {
  66. return nil, err
  67. }
  68. return user, nil
  69. }
  70. func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
  71. where, args := []string{"1 = 1"}, []any{}
  72. if v := find.ID; v != nil {
  73. where, args = append(where, "`id` = ?"), append(args, *v)
  74. }
  75. if v := find.Username; v != nil {
  76. where, args = append(where, "`username` = ?"), append(args, *v)
  77. }
  78. if v := find.Role; v != nil {
  79. where, args = append(where, "`role` = ?"), append(args, *v)
  80. }
  81. if v := find.Email; v != nil {
  82. where, args = append(where, "`email` = ?"), append(args, *v)
  83. }
  84. if v := find.Nickname; v != nil {
  85. where, args = append(where, "`nickname` = ?"), append(args, *v)
  86. }
  87. orderBy := []string{"`created_ts` DESC", "`row_status` DESC"}
  88. if find.Random {
  89. orderBy = slices.Concat([]string{"RAND()"}, orderBy)
  90. }
  91. query := "SELECT `id`, `username`, `role`, `email`, `nickname`, `password_hash`, `avatar_url`, `description`, UNIX_TIMESTAMP(`created_ts`), UNIX_TIMESTAMP(`updated_ts`), `row_status` FROM `user` WHERE " + strings.Join(where, " AND ") + " ORDER BY " + strings.Join(orderBy, ", ")
  92. if v := find.Limit; v != nil {
  93. query += fmt.Sprintf(" LIMIT %d", *v)
  94. }
  95. rows, err := d.db.QueryContext(ctx, query, args...)
  96. if err != nil {
  97. return nil, err
  98. }
  99. defer rows.Close()
  100. list := make([]*store.User, 0)
  101. for rows.Next() {
  102. var user store.User
  103. if err := rows.Scan(
  104. &user.ID,
  105. &user.Username,
  106. &user.Role,
  107. &user.Email,
  108. &user.Nickname,
  109. &user.PasswordHash,
  110. &user.AvatarURL,
  111. &user.Description,
  112. &user.CreatedTs,
  113. &user.UpdatedTs,
  114. &user.RowStatus,
  115. ); err != nil {
  116. return nil, err
  117. }
  118. list = append(list, &user)
  119. }
  120. if err := rows.Err(); err != nil {
  121. return nil, err
  122. }
  123. return list, nil
  124. }
  125. func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
  126. list, err := d.ListUsers(ctx, find)
  127. if err != nil {
  128. return nil, err
  129. }
  130. if len(list) != 1 {
  131. return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
  132. }
  133. return list[0], nil
  134. }
  135. func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
  136. result, err := d.db.ExecContext(ctx, "DELETE FROM `user` WHERE `id` = ?", delete.ID)
  137. if err != nil {
  138. return err
  139. }
  140. if _, err := result.RowsAffected(); err != nil {
  141. return err
  142. }
  143. return nil
  144. }