storage.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "strings"
  8. "github.com/usememos/memos/api"
  9. "github.com/usememos/memos/common"
  10. )
  11. type storageRaw struct {
  12. ID int
  13. Name string
  14. Type api.StorageType
  15. Config *api.StorageConfig
  16. }
  17. func (raw *storageRaw) toStorage() *api.Storage {
  18. return &api.Storage{
  19. ID: raw.ID,
  20. Name: raw.Name,
  21. Type: raw.Type,
  22. Config: raw.Config,
  23. }
  24. }
  25. func (s *Store) CreateStorage(ctx context.Context, create *api.StorageCreate) (*api.Storage, error) {
  26. tx, err := s.db.BeginTx(ctx, nil)
  27. if err != nil {
  28. return nil, FormatError(err)
  29. }
  30. defer tx.Rollback()
  31. storageRaw, err := createStorageRaw(ctx, tx, create)
  32. if err != nil {
  33. return nil, err
  34. }
  35. if err := tx.Commit(); err != nil {
  36. return nil, FormatError(err)
  37. }
  38. return storageRaw.toStorage(), nil
  39. }
  40. func (s *Store) PatchStorage(ctx context.Context, patch *api.StoragePatch) (*api.Storage, error) {
  41. tx, err := s.db.BeginTx(ctx, nil)
  42. if err != nil {
  43. return nil, FormatError(err)
  44. }
  45. defer tx.Rollback()
  46. storageRaw, err := patchStorageRaw(ctx, tx, patch)
  47. if err != nil {
  48. return nil, err
  49. }
  50. if err := tx.Commit(); err != nil {
  51. return nil, FormatError(err)
  52. }
  53. return storageRaw.toStorage(), nil
  54. }
  55. func (s *Store) FindStorageList(ctx context.Context, find *api.StorageFind) ([]*api.Storage, error) {
  56. tx, err := s.db.BeginTx(ctx, nil)
  57. if err != nil {
  58. return nil, FormatError(err)
  59. }
  60. defer tx.Rollback()
  61. storageRawList, err := findStorageRawList(ctx, tx, find)
  62. if err != nil {
  63. return nil, err
  64. }
  65. list := []*api.Storage{}
  66. for _, raw := range storageRawList {
  67. list = append(list, raw.toStorage())
  68. }
  69. return list, nil
  70. }
  71. func (s *Store) FindStorage(ctx context.Context, find *api.StorageFind) (*api.Storage, error) {
  72. tx, err := s.db.BeginTx(ctx, nil)
  73. if err != nil {
  74. return nil, FormatError(err)
  75. }
  76. defer tx.Rollback()
  77. list, err := findStorageRawList(ctx, tx, find)
  78. if err != nil {
  79. return nil, err
  80. }
  81. if len(list) == 0 {
  82. return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
  83. }
  84. storageRaw := list[0]
  85. return storageRaw.toStorage(), nil
  86. }
  87. func (s *Store) DeleteStorage(ctx context.Context, delete *api.StorageDelete) error {
  88. tx, err := s.db.BeginTx(ctx, nil)
  89. if err != nil {
  90. return FormatError(err)
  91. }
  92. defer tx.Rollback()
  93. if err := deleteStorage(ctx, tx, delete); err != nil {
  94. return FormatError(err)
  95. }
  96. if err := tx.Commit(); err != nil {
  97. return FormatError(err)
  98. }
  99. return nil
  100. }
  101. func createStorageRaw(ctx context.Context, tx *sql.Tx, create *api.StorageCreate) (*storageRaw, error) {
  102. set := []string{"name", "type", "config"}
  103. args := []any{create.Name, create.Type}
  104. placeholder := []string{"?", "?", "?"}
  105. var configBytes []byte
  106. var err error
  107. if create.Type == api.StorageS3 {
  108. configBytes, err = json.Marshal(create.Config.S3Config)
  109. if err != nil {
  110. return nil, err
  111. }
  112. } else {
  113. return nil, fmt.Errorf("unsupported storage type %s", string(create.Type))
  114. }
  115. args = append(args, string(configBytes))
  116. query := `
  117. INSERT INTO storage (
  118. ` + strings.Join(set, ", ") + `
  119. )
  120. VALUES (` + strings.Join(placeholder, ",") + `)
  121. RETURNING id
  122. `
  123. storageRaw := storageRaw{
  124. Name: create.Name,
  125. Type: create.Type,
  126. Config: create.Config,
  127. }
  128. if err := tx.QueryRowContext(ctx, query, args...).Scan(
  129. &storageRaw.ID,
  130. ); err != nil {
  131. return nil, FormatError(err)
  132. }
  133. return &storageRaw, nil
  134. }
  135. func patchStorageRaw(ctx context.Context, tx *sql.Tx, patch *api.StoragePatch) (*storageRaw, error) {
  136. set, args := []string{}, []any{}
  137. if v := patch.Name; v != nil {
  138. set, args = append(set, "name = ?"), append(args, *v)
  139. }
  140. if v := patch.Config; v != nil {
  141. var configBytes []byte
  142. var err error
  143. if patch.Type == api.StorageS3 {
  144. configBytes, err = json.Marshal(patch.Config.S3Config)
  145. if err != nil {
  146. return nil, err
  147. }
  148. } else {
  149. return nil, fmt.Errorf("unsupported storage type %s", string(patch.Type))
  150. }
  151. set, args = append(set, "config = ?"), append(args, string(configBytes))
  152. }
  153. args = append(args, patch.ID)
  154. query := `
  155. UPDATE storage
  156. SET ` + strings.Join(set, ", ") + `
  157. WHERE id = ?
  158. RETURNING id, name, type, config
  159. `
  160. var storageRaw storageRaw
  161. var storageConfig string
  162. if err := tx.QueryRowContext(ctx, query, args...).Scan(
  163. &storageRaw.ID,
  164. &storageRaw.Name,
  165. &storageRaw.Type,
  166. &storageConfig,
  167. ); err != nil {
  168. return nil, FormatError(err)
  169. }
  170. if storageRaw.Type == api.StorageS3 {
  171. s3Config := &api.StorageS3Config{}
  172. if err := json.Unmarshal([]byte(storageConfig), s3Config); err != nil {
  173. return nil, err
  174. }
  175. storageRaw.Config = &api.StorageConfig{
  176. S3Config: s3Config,
  177. }
  178. } else {
  179. return nil, fmt.Errorf("unsupported storage type %s", string(storageRaw.Type))
  180. }
  181. return &storageRaw, nil
  182. }
  183. func findStorageRawList(ctx context.Context, tx *sql.Tx, find *api.StorageFind) ([]*storageRaw, error) {
  184. where, args := []string{"1 = 1"}, []any{}
  185. if v := find.ID; v != nil {
  186. where, args = append(where, "id = ?"), append(args, *v)
  187. }
  188. query := `
  189. SELECT
  190. id,
  191. name,
  192. type,
  193. config
  194. FROM storage
  195. WHERE ` + strings.Join(where, " AND ") + `
  196. ORDER BY id DESC
  197. `
  198. rows, err := tx.QueryContext(ctx, query, args...)
  199. if err != nil {
  200. return nil, FormatError(err)
  201. }
  202. defer rows.Close()
  203. storageRawList := make([]*storageRaw, 0)
  204. for rows.Next() {
  205. var storageRaw storageRaw
  206. var storageConfig string
  207. if err := rows.Scan(
  208. &storageRaw.ID,
  209. &storageRaw.Name,
  210. &storageRaw.Type,
  211. &storageConfig,
  212. ); err != nil {
  213. return nil, FormatError(err)
  214. }
  215. if storageRaw.Type == api.StorageS3 {
  216. s3Config := &api.StorageS3Config{}
  217. if err := json.Unmarshal([]byte(storageConfig), s3Config); err != nil {
  218. return nil, err
  219. }
  220. storageRaw.Config = &api.StorageConfig{
  221. S3Config: s3Config,
  222. }
  223. } else {
  224. return nil, fmt.Errorf("unsupported storage type %s", string(storageRaw.Type))
  225. }
  226. storageRawList = append(storageRawList, &storageRaw)
  227. }
  228. if err := rows.Err(); err != nil {
  229. return nil, FormatError(err)
  230. }
  231. return storageRawList, nil
  232. }
  233. func deleteStorage(ctx context.Context, tx *sql.Tx, delete *api.StorageDelete) error {
  234. where, args := []string{"id = ?"}, []any{delete.ID}
  235. stmt := `DELETE FROM storage WHERE ` + strings.Join(where, " AND ")
  236. result, err := tx.ExecContext(ctx, stmt, args...)
  237. if err != nil {
  238. return FormatError(err)
  239. }
  240. rows, _ := result.RowsAffected()
  241. if rows == 0 {
  242. return &common.Error{Code: common.NotFound, Err: fmt.Errorf("storage not found")}
  243. }
  244. return nil
  245. }