tag.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strings"
  7. "github.com/usememos/memos/api"
  8. "github.com/usememos/memos/common"
  9. )
  10. type tagRaw struct {
  11. Name string
  12. CreatorID int
  13. }
  14. func (raw *tagRaw) toTag() *api.Tag {
  15. return &api.Tag{
  16. Name: raw.Name,
  17. CreatorID: raw.CreatorID,
  18. }
  19. }
  20. func (s *Store) UpsertTag(ctx context.Context, upsert *api.TagUpsert) (*api.Tag, error) {
  21. tx, err := s.db.BeginTx(ctx, nil)
  22. if err != nil {
  23. return nil, FormatError(err)
  24. }
  25. defer tx.Rollback()
  26. tagRaw, err := upsertTag(ctx, tx, upsert)
  27. if err != nil {
  28. return nil, err
  29. }
  30. if err := tx.Commit(); err != nil {
  31. return nil, err
  32. }
  33. tag := tagRaw.toTag()
  34. return tag, nil
  35. }
  36. func (s *Store) FindTagList(ctx context.Context, find *api.TagFind) ([]*api.Tag, error) {
  37. tx, err := s.db.BeginTx(ctx, nil)
  38. if err != nil {
  39. return nil, FormatError(err)
  40. }
  41. defer tx.Rollback()
  42. tagRawList, err := findTagList(ctx, tx, find)
  43. if err != nil {
  44. return nil, err
  45. }
  46. list := []*api.Tag{}
  47. for _, raw := range tagRawList {
  48. list = append(list, raw.toTag())
  49. }
  50. return list, nil
  51. }
  52. func (s *Store) DeleteTag(ctx context.Context, delete *api.TagDelete) error {
  53. tx, err := s.db.BeginTx(ctx, nil)
  54. if err != nil {
  55. return FormatError(err)
  56. }
  57. defer tx.Rollback()
  58. if err := deleteTag(ctx, tx, delete); err != nil {
  59. return FormatError(err)
  60. }
  61. if err := tx.Commit(); err != nil {
  62. return FormatError(err)
  63. }
  64. return nil
  65. }
  66. func upsertTag(ctx context.Context, tx *sql.Tx, upsert *api.TagUpsert) (*tagRaw, error) {
  67. query := `
  68. INSERT INTO tag (
  69. name, creator_id
  70. )
  71. VALUES (?, ?)
  72. ON CONFLICT(name, creator_id) DO UPDATE
  73. SET
  74. name = EXCLUDED.name
  75. RETURNING name, creator_id
  76. `
  77. var tagRaw tagRaw
  78. if err := tx.QueryRowContext(ctx, query, upsert.Name, upsert.CreatorID).Scan(
  79. &tagRaw.Name,
  80. &tagRaw.CreatorID,
  81. ); err != nil {
  82. return nil, FormatError(err)
  83. }
  84. return &tagRaw, nil
  85. }
  86. func findTagList(ctx context.Context, tx *sql.Tx, find *api.TagFind) ([]*tagRaw, error) {
  87. where, args := []string{"creator_id = ?"}, []interface{}{find.CreatorID}
  88. query := `
  89. SELECT
  90. name,
  91. creator_id
  92. FROM tag
  93. WHERE ` + strings.Join(where, " AND ")
  94. rows, err := tx.QueryContext(ctx, query, args...)
  95. if err != nil {
  96. return nil, FormatError(err)
  97. }
  98. defer rows.Close()
  99. tagRawList := make([]*tagRaw, 0)
  100. for rows.Next() {
  101. var tagRaw tagRaw
  102. if err := rows.Scan(
  103. &tagRaw.Name,
  104. &tagRaw.CreatorID,
  105. ); err != nil {
  106. return nil, FormatError(err)
  107. }
  108. tagRawList = append(tagRawList, &tagRaw)
  109. }
  110. if err := rows.Err(); err != nil {
  111. return nil, FormatError(err)
  112. }
  113. return tagRawList, nil
  114. }
  115. func deleteTag(ctx context.Context, tx *sql.Tx, delete *api.TagDelete) error {
  116. where, args := []string{"name = ?", "creator_id = ?"}, []interface{}{delete.Name, delete.CreatorID}
  117. stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ")
  118. result, err := tx.ExecContext(ctx, stmt, args...)
  119. if err != nil {
  120. return FormatError(err)
  121. }
  122. rows, _ := result.RowsAffected()
  123. if rows == 0 {
  124. return &common.Error{Code: common.NotFound, Err: fmt.Errorf("tag not found")}
  125. }
  126. return nil
  127. }
  128. func vacuumTag(ctx context.Context, tx *sql.Tx) error {
  129. stmt := `
  130. DELETE FROM
  131. tag
  132. WHERE
  133. creator_id NOT IN (
  134. SELECT
  135. id
  136. FROM
  137. user
  138. )`
  139. _, err := tx.ExecContext(ctx, stmt)
  140. if err != nil {
  141. return FormatError(err)
  142. }
  143. return nil
  144. }