user_setting.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "strings"
  7. "google.golang.org/protobuf/encoding/protojson"
  8. storepb "github.com/usememos/memos/proto/gen/store"
  9. )
  10. type UserSetting struct {
  11. UserID int32
  12. Key string
  13. Value string
  14. }
  15. type FindUserSetting struct {
  16. UserID *int32
  17. Key string
  18. }
  19. func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) {
  20. stmt := `
  21. INSERT INTO user_setting (
  22. user_id, key, value
  23. )
  24. VALUES (?, ?, ?)
  25. ON CONFLICT(user_id, key) DO UPDATE
  26. SET value = EXCLUDED.value
  27. `
  28. if _, err := s.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key, upsert.Value); err != nil {
  29. return nil, err
  30. }
  31. userSetting := upsert
  32. s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
  33. return userSetting, nil
  34. }
  35. func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) {
  36. where, args := []string{"1 = 1"}, []any{}
  37. if v := find.Key; v != "" {
  38. where, args = append(where, "key = ?"), append(args, v)
  39. }
  40. if v := find.UserID; v != nil {
  41. where, args = append(where, "user_id = ?"), append(args, *find.UserID)
  42. }
  43. query := `
  44. SELECT
  45. user_id,
  46. key,
  47. value
  48. FROM user_setting
  49. WHERE ` + strings.Join(where, " AND ")
  50. rows, err := s.db.QueryContext(ctx, query, args...)
  51. if err != nil {
  52. return nil, err
  53. }
  54. defer rows.Close()
  55. userSettingList := make([]*UserSetting, 0)
  56. for rows.Next() {
  57. var userSetting UserSetting
  58. if err := rows.Scan(
  59. &userSetting.UserID,
  60. &userSetting.Key,
  61. &userSetting.Value,
  62. ); err != nil {
  63. return nil, err
  64. }
  65. userSettingList = append(userSettingList, &userSetting)
  66. }
  67. if err := rows.Err(); err != nil {
  68. return nil, err
  69. }
  70. for _, userSetting := range userSettingList {
  71. s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
  72. }
  73. return userSettingList, nil
  74. }
  75. func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) {
  76. if find.UserID != nil {
  77. if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok {
  78. return cache.(*UserSetting), nil
  79. }
  80. }
  81. list, err := s.ListUserSettings(ctx, find)
  82. if err != nil {
  83. return nil, err
  84. }
  85. if len(list) == 0 {
  86. return nil, nil
  87. }
  88. userSetting := list[0]
  89. s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
  90. return userSetting, nil
  91. }
  92. type FindUserSettingV1 struct {
  93. UserID *int32
  94. Key storepb.UserSettingKey
  95. }
  96. func (s *Store) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
  97. stmt := `
  98. INSERT INTO user_setting (
  99. user_id, key, value
  100. )
  101. VALUES (?, ?, ?)
  102. ON CONFLICT(user_id, key) DO UPDATE
  103. SET value = EXCLUDED.value
  104. `
  105. var valueString string
  106. if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
  107. valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
  108. if err != nil {
  109. return nil, err
  110. }
  111. valueString = string(valueBytes)
  112. } else {
  113. return nil, errors.New("invalid user setting key")
  114. }
  115. if _, err := s.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil {
  116. return nil, err
  117. }
  118. userSettingMessage := upsert
  119. s.userSettingCache.Store(getUserSettingV1CacheKey(userSettingMessage.UserId, userSettingMessage.Key.String()), userSettingMessage)
  120. return userSettingMessage, nil
  121. }
  122. func (s *Store) ListUserSettingsV1(ctx context.Context, find *FindUserSettingV1) ([]*storepb.UserSetting, error) {
  123. where, args := []string{"1 = 1"}, []any{}
  124. if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
  125. where, args = append(where, "key = ?"), append(args, v.String())
  126. }
  127. if v := find.UserID; v != nil {
  128. where, args = append(where, "user_id = ?"), append(args, *find.UserID)
  129. }
  130. query := `
  131. SELECT
  132. user_id,
  133. key,
  134. value
  135. FROM user_setting
  136. WHERE ` + strings.Join(where, " AND ")
  137. rows, err := s.db.QueryContext(ctx, query, args...)
  138. if err != nil {
  139. return nil, err
  140. }
  141. defer rows.Close()
  142. userSettingList := make([]*storepb.UserSetting, 0)
  143. for rows.Next() {
  144. userSetting := &storepb.UserSetting{}
  145. var keyString, valueString string
  146. if err := rows.Scan(
  147. &userSetting.UserId,
  148. &keyString,
  149. &valueString,
  150. ); err != nil {
  151. return nil, err
  152. }
  153. userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
  154. if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
  155. accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
  156. if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
  157. return nil, err
  158. }
  159. userSetting.Value = &storepb.UserSetting_AccessTokens{
  160. AccessTokens: accessTokensUserSetting,
  161. }
  162. } else {
  163. // Skip unknown user setting v1 key.
  164. continue
  165. }
  166. userSettingList = append(userSettingList, userSetting)
  167. }
  168. if err := rows.Err(); err != nil {
  169. return nil, err
  170. }
  171. for _, userSetting := range userSettingList {
  172. s.userSettingCache.Store(getUserSettingV1CacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
  173. }
  174. return userSettingList, nil
  175. }
  176. func (s *Store) GetUserSettingV1(ctx context.Context, find *FindUserSettingV1) (*storepb.UserSetting, error) {
  177. if find.UserID != nil {
  178. if cache, ok := s.userSettingCache.Load(getUserSettingV1CacheKey(*find.UserID, find.Key.String())); ok {
  179. return cache.(*storepb.UserSetting), nil
  180. }
  181. }
  182. list, err := s.ListUserSettingsV1(ctx, find)
  183. if err != nil {
  184. return nil, err
  185. }
  186. if len(list) == 0 {
  187. return nil, nil
  188. }
  189. userSetting := list[0]
  190. s.userSettingCache.Store(getUserSettingV1CacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
  191. return userSetting, nil
  192. }
  193. // GetUserAccessTokens returns the access tokens of the user.
  194. func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*storepb.AccessTokensUserSetting_AccessToken, error) {
  195. userSetting, err := s.GetUserSettingV1(ctx, &FindUserSettingV1{
  196. UserID: &userID,
  197. Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
  198. })
  199. if err != nil {
  200. return nil, err
  201. }
  202. if userSetting == nil {
  203. return []*storepb.AccessTokensUserSetting_AccessToken{}, nil
  204. }
  205. accessTokensUserSetting := userSetting.GetAccessTokens()
  206. return accessTokensUserSetting.AccessTokens, nil
  207. }
  208. func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
  209. stmt := `
  210. DELETE FROM
  211. user_setting
  212. WHERE
  213. user_id NOT IN (
  214. SELECT
  215. id
  216. FROM
  217. user
  218. )`
  219. _, err := tx.ExecContext(ctx, stmt)
  220. if err != nil {
  221. return err
  222. }
  223. return nil
  224. }