memo.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/usememos/memos/common"
  10. )
  11. // Visibility is the type of a visibility.
  12. type Visibility string
  13. const (
  14. // Public is the PUBLIC visibility.
  15. Public Visibility = "PUBLIC"
  16. // Protected is the PROTECTED visibility.
  17. Protected Visibility = "PROTECTED"
  18. // Private is the PRIVATE visibility.
  19. Private Visibility = "PRIVATE"
  20. )
  21. func (v Visibility) String() string {
  22. switch v {
  23. case Public:
  24. return "PUBLIC"
  25. case Protected:
  26. return "PROTECTED"
  27. case Private:
  28. return "PRIVATE"
  29. }
  30. return "PRIVATE"
  31. }
// MemoMessage is the store-level representation of a memo row, together with
// fields composed from related tables when memos are read via listMemos.
// CreateMemo only writes CreatorID, CreatedTs, Content and Visibility; the
// database returns ID, CreatedTs, UpdatedTs and RowStatus.
type MemoMessage struct {
	ID int

	// Standard fields
	RowStatus RowStatus
	CreatorID int
	CreatedTs int64 // Unix seconds; CreateMemo uses time.Now().Unix() when this is zero
	UpdatedTs int64

	// Domain specific fields
	Content    string
	Visibility Visibility

	// Composed fields
	Pinned         bool                   // true when the joined memo_organizer.pinned = 1
	ResourceIDList []int                  // resource IDs from memo_resource (GROUP_CONCAT)
	RelationList   []*MemoRelationMessage // relations from memo_relation for this memo
}
// FindMemoMessage holds the filters and paging options accepted by ListMemos
// and GetMemo. A nil pointer or empty slice means "do not filter on this field";
// all filters that are set are combined with AND.
type FindMemoMessage struct {
	ID *int

	// Standard fields
	RowStatus *RowStatus
	CreatorID *int

	// Domain specific fields
	Pinned         *bool        // filter on the memo's pinned state (see listMemos)
	ContentSearch  []string     // each term must match content via LIKE '%term%'
	VisibilityList []Visibility // memo visibility must be one of these

	// Pagination
	Limit            *int
	Offset           *int
	OrderByUpdatedTs bool // order by updated_ts instead of created_ts (after pinned)
}
// UpdateMemoMessage describes a partial update of the memo identified by ID.
// UpdateMemo writes only the fields that are non-nil.
type UpdateMemoMessage struct {
	ID int

	CreatedTs  *int64
	UpdatedTs  *int64
	RowStatus  *RowStatus
	Content    *string
	Visibility *Visibility
}
// DeleteMemoMessage identifies the memo row that DeleteMemo removes.
type DeleteMemoMessage struct {
	ID int
}
  72. func (s *Store) CreateMemo(ctx context.Context, create *MemoMessage) (*MemoMessage, error) {
  73. tx, err := s.db.BeginTx(ctx, nil)
  74. if err != nil {
  75. return nil, FormatError(err)
  76. }
  77. defer tx.Rollback()
  78. if create.CreatedTs == 0 {
  79. create.CreatedTs = time.Now().Unix()
  80. }
  81. query := `
  82. INSERT INTO memo (
  83. creator_id,
  84. created_ts,
  85. content,
  86. visibility
  87. )
  88. VALUES (?, ?, ?, ?)
  89. RETURNING id, created_ts, updated_ts, row_status
  90. `
  91. if err := tx.QueryRowContext(
  92. ctx,
  93. query,
  94. create.CreatorID,
  95. create.CreatedTs,
  96. create.Content,
  97. create.Visibility,
  98. ).Scan(
  99. &create.ID,
  100. &create.CreatedTs,
  101. &create.UpdatedTs,
  102. &create.RowStatus,
  103. ); err != nil {
  104. return nil, FormatError(err)
  105. }
  106. if err := tx.Commit(); err != nil {
  107. return nil, FormatError(err)
  108. }
  109. memoMessage := create
  110. return memoMessage, nil
  111. }
  112. func (s *Store) ListMemos(ctx context.Context, find *FindMemoMessage) ([]*MemoMessage, error) {
  113. tx, err := s.db.BeginTx(ctx, nil)
  114. if err != nil {
  115. return nil, FormatError(err)
  116. }
  117. defer tx.Rollback()
  118. list, err := listMemos(ctx, tx, find)
  119. if err != nil {
  120. return nil, err
  121. }
  122. return list, nil
  123. }
  124. func (s *Store) GetMemo(ctx context.Context, find *FindMemoMessage) (*MemoMessage, error) {
  125. tx, err := s.db.BeginTx(ctx, nil)
  126. if err != nil {
  127. return nil, FormatError(err)
  128. }
  129. defer tx.Rollback()
  130. list, err := listMemos(ctx, tx, find)
  131. if err != nil {
  132. return nil, err
  133. }
  134. if len(list) == 0 {
  135. return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("memo not found")}
  136. }
  137. memoMessage := list[0]
  138. return memoMessage, nil
  139. }
  140. func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemoMessage) error {
  141. tx, err := s.db.BeginTx(ctx, nil)
  142. if err != nil {
  143. return err
  144. }
  145. defer tx.Rollback()
  146. set, args := []string{}, []any{}
  147. if v := update.CreatedTs; v != nil {
  148. set, args = append(set, "created_ts = ?"), append(args, *v)
  149. }
  150. if v := update.UpdatedTs; v != nil {
  151. set, args = append(set, "updated_ts = ?"), append(args, *v)
  152. }
  153. if v := update.RowStatus; v != nil {
  154. set, args = append(set, "row_status = ?"), append(args, *v)
  155. }
  156. if v := update.Content; v != nil {
  157. set, args = append(set, "content = ?"), append(args, *v)
  158. }
  159. if v := update.Visibility; v != nil {
  160. set, args = append(set, "visibility = ?"), append(args, *v)
  161. }
  162. args = append(args, update.ID)
  163. query := `
  164. UPDATE memo
  165. SET ` + strings.Join(set, ", ") + `
  166. WHERE id = ?
  167. `
  168. if _, err := tx.ExecContext(ctx, query, args...); err != nil {
  169. return err
  170. }
  171. err = tx.Commit()
  172. return err
  173. }
  174. func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemoMessage) error {
  175. tx, err := s.db.BeginTx(ctx, nil)
  176. if err != nil {
  177. return FormatError(err)
  178. }
  179. defer tx.Rollback()
  180. where, args := []string{"id = ?"}, []any{delete.ID}
  181. stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ")
  182. result, err := tx.ExecContext(ctx, stmt, args...)
  183. if err != nil {
  184. return FormatError(err)
  185. }
  186. rows, err := result.RowsAffected()
  187. if err != nil {
  188. return err
  189. }
  190. if rows == 0 {
  191. return &common.Error{Code: common.NotFound, Err: fmt.Errorf("idp not found")}
  192. }
  193. if err := s.vacuumImpl(ctx, tx); err != nil {
  194. return err
  195. }
  196. err = tx.Commit()
  197. return err
  198. }
  199. func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]Visibility, error) {
  200. tx, err := s.db.BeginTx(ctx, nil)
  201. if err != nil {
  202. return nil, FormatError(err)
  203. }
  204. defer tx.Rollback()
  205. args := make([]any, 0, len(memoIDs))
  206. list := make([]string, 0, len(memoIDs))
  207. for _, memoID := range memoIDs {
  208. args = append(args, memoID)
  209. list = append(list, "?")
  210. }
  211. where := fmt.Sprintf("id in (%s)", strings.Join(list, ","))
  212. query := `SELECT DISTINCT(visibility) FROM memo WHERE ` + where
  213. rows, err := tx.QueryContext(ctx, query, args...)
  214. if err != nil {
  215. return nil, FormatError(err)
  216. }
  217. defer rows.Close()
  218. visibilityList := make([]Visibility, 0)
  219. for rows.Next() {
  220. var visibility Visibility
  221. if err := rows.Scan(&visibility); err != nil {
  222. return nil, FormatError(err)
  223. }
  224. visibilityList = append(visibilityList, visibility)
  225. }
  226. if err := rows.Err(); err != nil {
  227. return nil, FormatError(err)
  228. }
  229. return visibilityList, nil
  230. }
  231. func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemoMessage) ([]*MemoMessage, error) {
  232. where, args := []string{"1 = 1"}, []any{}
  233. if v := find.ID; v != nil {
  234. where, args = append(where, "memo.id = ?"), append(args, *v)
  235. }
  236. if v := find.CreatorID; v != nil {
  237. where, args = append(where, "memo.creator_id = ?"), append(args, *v)
  238. }
  239. if v := find.RowStatus; v != nil {
  240. where, args = append(where, "memo.row_status = ?"), append(args, *v)
  241. }
  242. if v := find.Pinned; v != nil {
  243. where = append(where, "memo_organizer.pinned = 1")
  244. }
  245. if v := find.ContentSearch; len(v) != 0 {
  246. for _, s := range v {
  247. where, args = append(where, "memo.content LIKE ?"), append(args, "%"+s+"%")
  248. }
  249. }
  250. if v := find.VisibilityList; len(v) != 0 {
  251. list := []string{}
  252. for _, visibility := range v {
  253. list = append(list, fmt.Sprintf("$%d", len(args)+1))
  254. args = append(args, visibility)
  255. }
  256. where = append(where, fmt.Sprintf("memo.visibility in (%s)", strings.Join(list, ",")))
  257. }
  258. orders := []string{"pinned DESC"}
  259. if find.OrderByUpdatedTs {
  260. orders = append(orders, "updated_ts DESC")
  261. } else {
  262. orders = append(orders, "created_ts DESC")
  263. }
  264. orders = append(orders, "id DESC")
  265. query := `
  266. SELECT
  267. memo.id AS id,
  268. memo.creator_id AS creator_id,
  269. memo.created_ts AS created_ts,
  270. memo.updated_ts AS updated_ts,
  271. memo.row_status AS row_status,
  272. memo.content AS content,
  273. memo.visibility AS visibility,
  274. CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END AS pinned,
  275. GROUP_CONCAT(memo_resource.resource_id) AS resource_id_list,
  276. (
  277. SELECT
  278. GROUP_CONCAT(related_memo_id || ':' || type)
  279. FROM
  280. memo_relation
  281. WHERE
  282. memo_relation.memo_id = memo.id
  283. GROUP BY
  284. memo_relation.memo_id
  285. ) AS relation_list
  286. FROM
  287. memo
  288. LEFT JOIN
  289. memo_organizer ON memo.id = memo_organizer.memo_id
  290. LEFT JOIN
  291. memo_resource ON memo.id = memo_resource.memo_id
  292. WHERE ` + strings.Join(where, " AND ") + `
  293. GROUP BY memo.id
  294. ORDER BY ` + strings.Join(orders, ", ") + `
  295. `
  296. if find.Limit != nil {
  297. query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
  298. if find.Offset != nil {
  299. query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
  300. }
  301. }
  302. rows, err := tx.QueryContext(ctx, query, args...)
  303. if err != nil {
  304. return nil, FormatError(err)
  305. }
  306. defer rows.Close()
  307. memoMessageList := make([]*MemoMessage, 0)
  308. for rows.Next() {
  309. var memoMessage MemoMessage
  310. var memoResourceIDList sql.NullString
  311. var memoRelationList sql.NullString
  312. if err := rows.Scan(
  313. &memoMessage.ID,
  314. &memoMessage.CreatorID,
  315. &memoMessage.CreatedTs,
  316. &memoMessage.UpdatedTs,
  317. &memoMessage.RowStatus,
  318. &memoMessage.Content,
  319. &memoMessage.Visibility,
  320. &memoMessage.Pinned,
  321. &memoResourceIDList,
  322. &memoRelationList,
  323. ); err != nil {
  324. return nil, FormatError(err)
  325. }
  326. if memoResourceIDList.Valid {
  327. idStringList := strings.Split(memoResourceIDList.String, ",")
  328. memoMessage.ResourceIDList = make([]int, 0, len(idStringList))
  329. for _, idString := range idStringList {
  330. id, err := strconv.Atoi(idString)
  331. if err != nil {
  332. return nil, FormatError(err)
  333. }
  334. memoMessage.ResourceIDList = append(memoMessage.ResourceIDList, id)
  335. }
  336. }
  337. if memoRelationList.Valid {
  338. memoMessage.RelationList = make([]*MemoRelationMessage, 0)
  339. relatedMemoTypeList := strings.Split(memoRelationList.String, ",")
  340. for _, relatedMemoType := range relatedMemoTypeList {
  341. relatedMemoTypeList := strings.Split(relatedMemoType, ":")
  342. if len(relatedMemoTypeList) != 2 {
  343. return nil, &common.Error{Code: common.Invalid, Err: fmt.Errorf("invalid relation format")}
  344. }
  345. relatedMemoID, err := strconv.Atoi(relatedMemoTypeList[0])
  346. if err != nil {
  347. return nil, FormatError(err)
  348. }
  349. memoMessage.RelationList = append(memoMessage.RelationList, &MemoRelationMessage{
  350. MemoID: memoMessage.ID,
  351. RelatedMemoID: relatedMemoID,
  352. Type: MemoRelationType(relatedMemoTypeList[1]),
  353. })
  354. }
  355. }
  356. memoMessageList = append(memoMessageList, &memoMessage)
  357. }
  358. if err := rows.Err(); err != nil {
  359. return nil, FormatError(err)
  360. }
  361. return memoMessageList, nil
  362. }
  363. func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
  364. stmt := `
  365. DELETE FROM
  366. memo
  367. WHERE
  368. creator_id NOT IN (
  369. SELECT
  370. id
  371. FROM
  372. user
  373. )`
  374. _, err := tx.ExecContext(ctx, stmt)
  375. if err != nil {
  376. return FormatError(err)
  377. }
  378. return nil
  379. }