user_setting.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "strings"
  6. "github.com/usememos/memos/api"
  7. )
  8. type userSettingRaw struct {
  9. UserID int
  10. Key api.UserSettingKey
  11. Value string
  12. }
  13. func (raw *userSettingRaw) toUserSetting() *api.UserSetting {
  14. return &api.UserSetting{
  15. UserID: raw.UserID,
  16. Key: raw.Key,
  17. Value: raw.Value,
  18. }
  19. }
  20. func (s *Store) UpsertUserSetting(ctx context.Context, upsert *api.UserSettingUpsert) (*api.UserSetting, error) {
  21. tx, err := s.db.BeginTx(ctx, nil)
  22. if err != nil {
  23. return nil, FormatError(err)
  24. }
  25. defer tx.Rollback()
  26. userSettingRaw, err := upsertUserSetting(ctx, tx, upsert)
  27. if err != nil {
  28. return nil, err
  29. }
  30. if err := tx.Commit(); err != nil {
  31. return nil, err
  32. }
  33. userSetting := userSettingRaw.toUserSetting()
  34. return userSetting, nil
  35. }
  36. func (s *Store) FindUserSettingList(ctx context.Context, find *api.UserSettingFind) ([]*api.UserSetting, error) {
  37. tx, err := s.db.BeginTx(ctx, nil)
  38. if err != nil {
  39. return nil, FormatError(err)
  40. }
  41. defer tx.Rollback()
  42. userSettingRawList, err := findUserSettingList(ctx, tx, find)
  43. if err != nil {
  44. return nil, err
  45. }
  46. list := []*api.UserSetting{}
  47. for _, raw := range userSettingRawList {
  48. list = append(list, raw.toUserSetting())
  49. }
  50. return list, nil
  51. }
  52. func (s *Store) FindUserSetting(ctx context.Context, find *api.UserSettingFind) (*api.UserSetting, error) {
  53. tx, err := s.db.BeginTx(ctx, nil)
  54. if err != nil {
  55. return nil, FormatError(err)
  56. }
  57. defer tx.Rollback()
  58. list, err := findUserSettingList(ctx, tx, find)
  59. if err != nil {
  60. return nil, err
  61. }
  62. if len(list) == 0 {
  63. return nil, nil
  64. }
  65. userSetting := list[0].toUserSetting()
  66. return userSetting, nil
  67. }
  68. func upsertUserSetting(ctx context.Context, tx *sql.Tx, upsert *api.UserSettingUpsert) (*userSettingRaw, error) {
  69. query := `
  70. INSERT INTO user_setting (
  71. user_id, key, value
  72. )
  73. VALUES (?, ?, ?)
  74. ON CONFLICT(user_id, key) DO UPDATE
  75. SET
  76. value = EXCLUDED.value
  77. RETURNING user_id, key, value
  78. `
  79. var userSettingRaw userSettingRaw
  80. if err := tx.QueryRowContext(ctx, query, upsert.UserID, upsert.Key, upsert.Value).Scan(
  81. &userSettingRaw.UserID,
  82. &userSettingRaw.Key,
  83. &userSettingRaw.Value,
  84. ); err != nil {
  85. return nil, FormatError(err)
  86. }
  87. return &userSettingRaw, nil
  88. }
  89. func findUserSettingList(ctx context.Context, tx *sql.Tx, find *api.UserSettingFind) ([]*userSettingRaw, error) {
  90. where, args := []string{"1 = 1"}, []interface{}{}
  91. if v := find.Key; v != nil {
  92. where, args = append(where, "key = ?"), append(args, v.String())
  93. }
  94. where, args = append(where, "user_id = ?"), append(args, find.UserID)
  95. query := `
  96. SELECT
  97. user_id,
  98. key,
  99. value
  100. FROM user_setting
  101. WHERE ` + strings.Join(where, " AND ")
  102. rows, err := tx.QueryContext(ctx, query, args...)
  103. if err != nil {
  104. return nil, FormatError(err)
  105. }
  106. defer rows.Close()
  107. userSettingRawList := make([]*userSettingRaw, 0)
  108. for rows.Next() {
  109. var userSettingRaw userSettingRaw
  110. if err := rows.Scan(
  111. &userSettingRaw.UserID,
  112. &userSettingRaw.Key,
  113. &userSettingRaw.Value,
  114. ); err != nil {
  115. return nil, FormatError(err)
  116. }
  117. userSettingRawList = append(userSettingRawList, &userSettingRaw)
  118. }
  119. if err := rows.Err(); err != nil {
  120. return nil, FormatError(err)
  121. }
  122. return userSettingRawList, nil
  123. }
  124. func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
  125. stmt := `
  126. DELETE FROM
  127. user_setting
  128. WHERE
  129. user_id NOT IN (
  130. SELECT
  131. id
  132. FROM
  133. user
  134. )`
  135. _, err := tx.ExecContext(ctx, stmt)
  136. if err != nil {
  137. return FormatError(err)
  138. }
  139. return nil
  140. }