memo_relation.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strings"
  7. "github.com/usememos/memos/common"
  8. )
// MemoRelationType is the kind of link recorded in the type column of the
// memo_relation table.
type MemoRelationType string

const (
	// MemoRelationReference is the relation type stored as "REFERENCE".
	MemoRelationReference  MemoRelationType = "REFERENCE"
	// MemoRelationAdditional is the relation type stored as "ADDITIONAL".
	MemoRelationAdditional MemoRelationType = "ADDITIONAL"
)
// MemoRelationMessage mirrors one row of the memo_relation table: a link of
// kind Type between the memo identified by MemoID and the memo identified by
// RelatedMemoID.
type MemoRelationMessage struct {
	// MemoID is the memo_relation.memo_id column.
	MemoID        int
	// RelatedMemoID is the memo_relation.related_memo_id column.
	RelatedMemoID int
	// Type is the memo_relation.type column.
	Type          MemoRelationType
}
// FindMemoRelationMessage filters memo relation lookups. Each non-nil field
// adds an equality condition on the corresponding column, and the conditions
// are combined with AND; nil fields impose no constraint.
type FindMemoRelationMessage struct {
	// MemoID, when set, restricts results to memo_id = *MemoID.
	MemoID        *int
	// RelatedMemoID, when set, restricts results to related_memo_id = *RelatedMemoID.
	RelatedMemoID *int
	// Type, when set, restricts results to type = *Type.
	Type          *MemoRelationType
}
// DeleteMemoRelationMessage selects which memo relations to delete. Each
// non-nil field adds an equality condition on the corresponding column, and
// the conditions are combined with AND. Callers should set at least one field.
type DeleteMemoRelationMessage struct {
	// MemoID, when set, matches rows with memo_id = *MemoID.
	MemoID        *int
	// RelatedMemoID, when set, matches rows with related_memo_id = *RelatedMemoID.
	RelatedMemoID *int
	// Type, when set, matches rows with type = *Type.
	Type          *MemoRelationType
}
  29. func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelationMessage) (*MemoRelationMessage, error) {
  30. tx, err := s.db.BeginTx(ctx, nil)
  31. if err != nil {
  32. return nil, FormatError(err)
  33. }
  34. defer tx.Rollback()
  35. query := `
  36. INSERT INTO memo_relation (
  37. memo_id,
  38. related_memo_id,
  39. type
  40. )
  41. VALUES (?, ?, ?)
  42. ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET
  43. type = EXCLUDED.type
  44. RETURNING memo_id, related_memo_id, type
  45. `
  46. memoRelationMessage := &MemoRelationMessage{}
  47. if err := tx.QueryRowContext(
  48. ctx,
  49. query,
  50. create.MemoID,
  51. create.RelatedMemoID,
  52. create.Type,
  53. ).Scan(
  54. &memoRelationMessage.MemoID,
  55. &memoRelationMessage.RelatedMemoID,
  56. &memoRelationMessage.Type,
  57. ); err != nil {
  58. return nil, FormatError(err)
  59. }
  60. if err := tx.Commit(); err != nil {
  61. return nil, FormatError(err)
  62. }
  63. return memoRelationMessage, nil
  64. }
  65. func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelationMessage) ([]*MemoRelationMessage, error) {
  66. tx, err := s.db.BeginTx(ctx, nil)
  67. if err != nil {
  68. return nil, FormatError(err)
  69. }
  70. defer tx.Rollback()
  71. list, err := listMemoRelations(ctx, tx, find)
  72. if err != nil {
  73. return nil, err
  74. }
  75. return list, nil
  76. }
  77. func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelationMessage) (*MemoRelationMessage, error) {
  78. tx, err := s.db.BeginTx(ctx, nil)
  79. if err != nil {
  80. return nil, FormatError(err)
  81. }
  82. defer tx.Rollback()
  83. list, err := listMemoRelations(ctx, tx, find)
  84. if err != nil {
  85. return nil, err
  86. }
  87. if len(list) == 0 {
  88. return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
  89. }
  90. return list[0], nil
  91. }
  92. func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelationMessage) error {
  93. tx, err := s.db.BeginTx(ctx, nil)
  94. if err != nil {
  95. return FormatError(err)
  96. }
  97. defer tx.Rollback()
  98. where, args := []string{"TRUE"}, []any{}
  99. if delete.MemoID != nil {
  100. where, args = append(where, "memo_id = ?"), append(args, delete.MemoID)
  101. }
  102. if delete.RelatedMemoID != nil {
  103. where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID)
  104. }
  105. if delete.Type != nil {
  106. where, args = append(where, "type = ?"), append(args, delete.Type)
  107. }
  108. query := `
  109. DELETE FROM memo_relation
  110. WHERE ` + strings.Join(where, " AND ")
  111. if _, err := tx.ExecContext(ctx, query, args...); err != nil {
  112. return FormatError(err)
  113. }
  114. if err := tx.Commit(); err != nil {
  115. return FormatError(err)
  116. }
  117. return nil
  118. }
  119. func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelationMessage) ([]*MemoRelationMessage, error) {
  120. where, args := []string{"TRUE"}, []any{}
  121. if find.MemoID != nil {
  122. where, args = append(where, "memo_id = ?"), append(args, find.MemoID)
  123. }
  124. if find.RelatedMemoID != nil {
  125. where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID)
  126. }
  127. if find.Type != nil {
  128. where, args = append(where, "type = ?"), append(args, find.Type)
  129. }
  130. rows, err := tx.QueryContext(ctx, `
  131. SELECT
  132. memo_id,
  133. related_memo_id,
  134. type
  135. FROM memo_relation
  136. WHERE `+strings.Join(where, " AND "), args...)
  137. if err != nil {
  138. return nil, FormatError(err)
  139. }
  140. defer rows.Close()
  141. memoRelationMessages := []*MemoRelationMessage{}
  142. for rows.Next() {
  143. memoRelationMessage := &MemoRelationMessage{}
  144. if err := rows.Scan(
  145. &memoRelationMessage.MemoID,
  146. &memoRelationMessage.RelatedMemoID,
  147. &memoRelationMessage.Type,
  148. ); err != nil {
  149. return nil, FormatError(err)
  150. }
  151. memoRelationMessages = append(memoRelationMessages, memoRelationMessage)
  152. }
  153. if err := rows.Err(); err != nil {
  154. return nil, FormatError(err)
  155. }
  156. return memoRelationMessages, nil
  157. }
  158. func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error {
  159. if _, err := tx.ExecContext(ctx, `
  160. DELETE FROM memo_relation
  161. WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo)
  162. `); err != nil {
  163. return FormatError(err)
  164. }
  165. return nil
  166. }