shortcut.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strings"
  7. "github.com/usememos/memos/api"
  8. "github.com/usememos/memos/common"
  9. )
  10. // shortcutRaw is the store model for an Shortcut.
  11. // Fields have exactly the same meanings as Shortcut.
  12. type shortcutRaw struct {
  13. ID int
  14. // Standard fields
  15. RowStatus api.RowStatus
  16. CreatorID int
  17. CreatedTs int64
  18. UpdatedTs int64
  19. // Domain specific fields
  20. Title string
  21. Payload string
  22. }
  23. func (raw *shortcutRaw) toShortcut() *api.Shortcut {
  24. return &api.Shortcut{
  25. ID: raw.ID,
  26. RowStatus: raw.RowStatus,
  27. CreatorID: raw.CreatorID,
  28. CreatedTs: raw.CreatedTs,
  29. UpdatedTs: raw.UpdatedTs,
  30. Title: raw.Title,
  31. Payload: raw.Payload,
  32. }
  33. }
  34. func (s *Store) CreateShortcut(ctx context.Context, create *api.ShortcutCreate) (*api.Shortcut, error) {
  35. tx, err := s.db.BeginTx(ctx, nil)
  36. if err != nil {
  37. return nil, FormatError(err)
  38. }
  39. defer tx.Rollback()
  40. shortcutRaw, err := createShortcut(ctx, tx, create)
  41. if err != nil {
  42. return nil, err
  43. }
  44. if err := tx.Commit(); err != nil {
  45. return nil, FormatError(err)
  46. }
  47. s.shortcutCache.Store(shortcutRaw.ID, shortcutRaw)
  48. shortcut := shortcutRaw.toShortcut()
  49. return shortcut, nil
  50. }
  51. func (s *Store) PatchShortcut(ctx context.Context, patch *api.ShortcutPatch) (*api.Shortcut, error) {
  52. tx, err := s.db.BeginTx(ctx, nil)
  53. if err != nil {
  54. return nil, FormatError(err)
  55. }
  56. defer tx.Rollback()
  57. shortcutRaw, err := patchShortcut(ctx, tx, patch)
  58. if err != nil {
  59. return nil, err
  60. }
  61. if err := tx.Commit(); err != nil {
  62. return nil, FormatError(err)
  63. }
  64. s.shortcutCache.Store(shortcutRaw.ID, shortcutRaw)
  65. shortcut := shortcutRaw.toShortcut()
  66. return shortcut, nil
  67. }
  68. func (s *Store) FindShortcutList(ctx context.Context, find *api.ShortcutFind) ([]*api.Shortcut, error) {
  69. tx, err := s.db.BeginTx(ctx, nil)
  70. if err != nil {
  71. return nil, FormatError(err)
  72. }
  73. defer tx.Rollback()
  74. shortcutRawList, err := findShortcutList(ctx, tx, find)
  75. if err != nil {
  76. return nil, err
  77. }
  78. list := []*api.Shortcut{}
  79. for _, raw := range shortcutRawList {
  80. list = append(list, raw.toShortcut())
  81. }
  82. return list, nil
  83. }
  84. func (s *Store) FindShortcut(ctx context.Context, find *api.ShortcutFind) (*api.Shortcut, error) {
  85. if find.ID != nil {
  86. if shortcut, ok := s.shortcutCache.Load(*find.ID); ok {
  87. return shortcut.(*shortcutRaw).toShortcut(), nil
  88. }
  89. }
  90. tx, err := s.db.BeginTx(ctx, nil)
  91. if err != nil {
  92. return nil, FormatError(err)
  93. }
  94. defer tx.Rollback()
  95. list, err := findShortcutList(ctx, tx, find)
  96. if err != nil {
  97. return nil, err
  98. }
  99. if len(list) == 0 {
  100. return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
  101. }
  102. shortcutRaw := list[0]
  103. s.shortcutCache.Store(shortcutRaw.ID, shortcutRaw)
  104. shortcut := shortcutRaw.toShortcut()
  105. return shortcut, nil
  106. }
  107. func (s *Store) DeleteShortcut(ctx context.Context, delete *api.ShortcutDelete) error {
  108. tx, err := s.db.BeginTx(ctx, nil)
  109. if err != nil {
  110. return FormatError(err)
  111. }
  112. defer tx.Rollback()
  113. err = deleteShortcut(ctx, tx, delete)
  114. if err != nil {
  115. return FormatError(err)
  116. }
  117. if err := tx.Commit(); err != nil {
  118. return FormatError(err)
  119. }
  120. s.shortcutCache.Delete(*delete.ID)
  121. return nil
  122. }
  123. func createShortcut(ctx context.Context, tx *sql.Tx, create *api.ShortcutCreate) (*shortcutRaw, error) {
  124. query := `
  125. INSERT INTO shortcut (
  126. title,
  127. payload,
  128. creator_id
  129. )
  130. VALUES (?, ?, ?)
  131. RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status
  132. `
  133. var shortcutRaw shortcutRaw
  134. if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan(
  135. &shortcutRaw.ID,
  136. &shortcutRaw.Title,
  137. &shortcutRaw.Payload,
  138. &shortcutRaw.CreatorID,
  139. &shortcutRaw.CreatedTs,
  140. &shortcutRaw.UpdatedTs,
  141. &shortcutRaw.RowStatus,
  142. ); err != nil {
  143. return nil, FormatError(err)
  144. }
  145. return &shortcutRaw, nil
  146. }
  147. func patchShortcut(ctx context.Context, tx *sql.Tx, patch *api.ShortcutPatch) (*shortcutRaw, error) {
  148. set, args := []string{}, []interface{}{}
  149. if v := patch.UpdatedTs; v != nil {
  150. set, args = append(set, "updated_ts = ?"), append(args, *v)
  151. }
  152. if v := patch.Title; v != nil {
  153. set, args = append(set, "title = ?"), append(args, *v)
  154. }
  155. if v := patch.Payload; v != nil {
  156. set, args = append(set, "payload = ?"), append(args, *v)
  157. }
  158. if v := patch.RowStatus; v != nil {
  159. set, args = append(set, "row_status = ?"), append(args, *v)
  160. }
  161. args = append(args, patch.ID)
  162. query := `
  163. UPDATE shortcut
  164. SET ` + strings.Join(set, ", ") + `
  165. WHERE id = ?
  166. RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status
  167. `
  168. var shortcutRaw shortcutRaw
  169. if err := tx.QueryRowContext(ctx, query, args...).Scan(
  170. &shortcutRaw.ID,
  171. &shortcutRaw.Title,
  172. &shortcutRaw.Payload,
  173. &shortcutRaw.CreatorID,
  174. &shortcutRaw.CreatedTs,
  175. &shortcutRaw.UpdatedTs,
  176. &shortcutRaw.RowStatus,
  177. ); err != nil {
  178. return nil, FormatError(err)
  179. }
  180. return &shortcutRaw, nil
  181. }
  182. func findShortcutList(ctx context.Context, tx *sql.Tx, find *api.ShortcutFind) ([]*shortcutRaw, error) {
  183. where, args := []string{"1 = 1"}, []interface{}{}
  184. if v := find.ID; v != nil {
  185. where, args = append(where, "id = ?"), append(args, *v)
  186. }
  187. if v := find.CreatorID; v != nil {
  188. where, args = append(where, "creator_id = ?"), append(args, *v)
  189. }
  190. if v := find.Title; v != nil {
  191. where, args = append(where, "title = ?"), append(args, *v)
  192. }
  193. rows, err := tx.QueryContext(ctx, `
  194. SELECT
  195. id,
  196. title,
  197. payload,
  198. creator_id,
  199. created_ts,
  200. updated_ts,
  201. row_status
  202. FROM shortcut
  203. WHERE `+strings.Join(where, " AND ")+`
  204. ORDER BY created_ts DESC`,
  205. args...,
  206. )
  207. if err != nil {
  208. return nil, FormatError(err)
  209. }
  210. defer rows.Close()
  211. shortcutRawList := make([]*shortcutRaw, 0)
  212. for rows.Next() {
  213. var shortcutRaw shortcutRaw
  214. if err := rows.Scan(
  215. &shortcutRaw.ID,
  216. &shortcutRaw.Title,
  217. &shortcutRaw.Payload,
  218. &shortcutRaw.CreatorID,
  219. &shortcutRaw.CreatedTs,
  220. &shortcutRaw.UpdatedTs,
  221. &shortcutRaw.RowStatus,
  222. ); err != nil {
  223. return nil, FormatError(err)
  224. }
  225. shortcutRawList = append(shortcutRawList, &shortcutRaw)
  226. }
  227. if err := rows.Err(); err != nil {
  228. return nil, FormatError(err)
  229. }
  230. return shortcutRawList, nil
  231. }
  232. func deleteShortcut(ctx context.Context, tx *sql.Tx, delete *api.ShortcutDelete) error {
  233. where, args := []string{}, []interface{}{}
  234. if v := delete.ID; v != nil {
  235. where, args = append(where, "id = ?"), append(args, *v)
  236. }
  237. if v := delete.CreatorID; v != nil {
  238. where, args = append(where, "creator_id = ?"), append(args, *v)
  239. }
  240. stmt := `DELETE FROM shortcut WHERE ` + strings.Join(where, " AND ")
  241. result, err := tx.ExecContext(ctx, stmt, args...)
  242. if err != nil {
  243. return FormatError(err)
  244. }
  245. rows, _ := result.RowsAffected()
  246. if rows == 0 {
  247. return &common.Error{Code: common.NotFound, Err: fmt.Errorf("shortcut not found")}
  248. }
  249. return nil
  250. }
  251. func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
  252. stmt := `
  253. DELETE FROM
  254. shortcut
  255. WHERE
  256. creator_id NOT IN (
  257. SELECT
  258. id
  259. FROM
  260. user
  261. )`
  262. _, err := tx.ExecContext(ctx, stmt)
  263. if err != nil {
  264. return FormatError(err)
  265. }
  266. return nil
  267. }