memo_relation.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "strings"
  6. )
  // MemoRelationType names the kind of link recorded between two memos. Its
  // string value is what gets written to and read from the type column of the
  // memo_relation table.
  type MemoRelationType string

  // Known relation kinds.
  const (
  	// MemoRelationReference is persisted as "REFERENCE".
  	MemoRelationReference = MemoRelationType("REFERENCE")
  	// MemoRelationAdditional is persisted as "ADDITIONAL".
  	MemoRelationAdditional = MemoRelationType("ADDITIONAL")
  )
  12. type MemoRelation struct {
  13. MemoID int
  14. RelatedMemoID int
  15. Type MemoRelationType
  16. }
  17. type FindMemoRelation struct {
  18. MemoID *int
  19. RelatedMemoID *int
  20. Type *MemoRelationType
  21. }
  22. type DeleteMemoRelation struct {
  23. MemoID *int
  24. RelatedMemoID *int
  25. Type *MemoRelationType
  26. }
  27. func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*MemoRelation, error) {
  28. tx, err := s.db.BeginTx(ctx, nil)
  29. if err != nil {
  30. return nil, err
  31. }
  32. defer tx.Rollback()
  33. query := `
  34. INSERT INTO memo_relation (
  35. memo_id,
  36. related_memo_id,
  37. type
  38. )
  39. VALUES (?, ?, ?)
  40. ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET
  41. type = EXCLUDED.type
  42. RETURNING memo_id, related_memo_id, type
  43. `
  44. memoRelation := &MemoRelation{}
  45. if err := tx.QueryRowContext(
  46. ctx,
  47. query,
  48. create.MemoID,
  49. create.RelatedMemoID,
  50. create.Type,
  51. ).Scan(
  52. &memoRelation.MemoID,
  53. &memoRelation.RelatedMemoID,
  54. &memoRelation.Type,
  55. ); err != nil {
  56. return nil, err
  57. }
  58. if err := tx.Commit(); err != nil {
  59. return nil, err
  60. }
  61. return memoRelation, nil
  62. }
  63. func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error) {
  64. tx, err := s.db.BeginTx(ctx, nil)
  65. if err != nil {
  66. return nil, err
  67. }
  68. defer tx.Rollback()
  69. list, err := listMemoRelations(ctx, tx, find)
  70. if err != nil {
  71. return nil, err
  72. }
  73. if err := tx.Commit(); err != nil {
  74. return nil, err
  75. }
  76. return list, nil
  77. }
  78. func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelation) (*MemoRelation, error) {
  79. tx, err := s.db.BeginTx(ctx, nil)
  80. if err != nil {
  81. return nil, err
  82. }
  83. defer tx.Rollback()
  84. list, err := listMemoRelations(ctx, tx, find)
  85. if err != nil {
  86. return nil, err
  87. }
  88. if len(list) == 0 {
  89. return nil, nil
  90. }
  91. if err := tx.Commit(); err != nil {
  92. return nil, err
  93. }
  94. return list[0], nil
  95. }
  96. func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error {
  97. tx, err := s.db.BeginTx(ctx, nil)
  98. if err != nil {
  99. return err
  100. }
  101. defer tx.Rollback()
  102. where, args := []string{"TRUE"}, []any{}
  103. if delete.MemoID != nil {
  104. where, args = append(where, "memo_id = ?"), append(args, delete.MemoID)
  105. }
  106. if delete.RelatedMemoID != nil {
  107. where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID)
  108. }
  109. if delete.Type != nil {
  110. where, args = append(where, "type = ?"), append(args, delete.Type)
  111. }
  112. query := `
  113. DELETE FROM memo_relation
  114. WHERE ` + strings.Join(where, " AND ")
  115. if _, err := tx.ExecContext(ctx, query, args...); err != nil {
  116. return err
  117. }
  118. if err := tx.Commit(); err != nil {
  119. // Prevent lint warning.
  120. return err
  121. }
  122. return nil
  123. }
  124. func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelation) ([]*MemoRelation, error) {
  125. where, args := []string{"TRUE"}, []any{}
  126. if find.MemoID != nil {
  127. where, args = append(where, "memo_id = ?"), append(args, find.MemoID)
  128. }
  129. if find.RelatedMemoID != nil {
  130. where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID)
  131. }
  132. if find.Type != nil {
  133. where, args = append(where, "type = ?"), append(args, find.Type)
  134. }
  135. rows, err := tx.QueryContext(ctx, `
  136. SELECT
  137. memo_id,
  138. related_memo_id,
  139. type
  140. FROM memo_relation
  141. WHERE `+strings.Join(where, " AND "), args...)
  142. if err != nil {
  143. return nil, err
  144. }
  145. defer rows.Close()
  146. memoRelationMessages := []*MemoRelation{}
  147. for rows.Next() {
  148. memoRelationMessage := &MemoRelation{}
  149. if err := rows.Scan(
  150. &memoRelationMessage.MemoID,
  151. &memoRelationMessage.RelatedMemoID,
  152. &memoRelationMessage.Type,
  153. ); err != nil {
  154. return nil, err
  155. }
  156. memoRelationMessages = append(memoRelationMessages, memoRelationMessage)
  157. }
  158. if err := rows.Err(); err != nil {
  159. return nil, err
  160. }
  161. return memoRelationMessages, nil
  162. }
  // vacuumMemoRelations deletes, within tx, every memo_relation row whose
  // memo_id or related_memo_id has no matching id in the memo table. It
  // returns the error from executing the statement, if any; committing is left
  // to the caller that owns tx.
  func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error {
  	const stmt = `
  		DELETE FROM memo_relation
  		WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo)
  	`
  	_, err := tx.ExecContext(ctx, stmt)
  	return err
  }