tag.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strings"
  7. "github.com/usememos/memos/api"
  8. "github.com/usememos/memos/common"
  9. )
  10. type tagRaw struct {
  11. Name string
  12. CreatorID int
  13. }
  14. func (raw *tagRaw) toTag() *api.Tag {
  15. return &api.Tag{
  16. Name: raw.Name,
  17. CreatorID: raw.CreatorID,
  18. }
  19. }
  20. func (s *Store) UpsertTag(ctx context.Context, upsert *api.TagUpsert) (*api.Tag, error) {
  21. tx, err := s.db.BeginTx(ctx, nil)
  22. if err != nil {
  23. return nil, FormatError(err)
  24. }
  25. defer tx.Rollback()
  26. tagRaw, err := upsertTag(ctx, tx, upsert)
  27. if err != nil {
  28. return nil, err
  29. }
  30. if err := tx.Commit(); err != nil {
  31. return nil, err
  32. }
  33. tag := tagRaw.toTag()
  34. return tag, nil
  35. }
  36. func (s *Store) FindTagList(ctx context.Context, find *api.TagFind) ([]*api.Tag, error) {
  37. tx, err := s.db.BeginTx(ctx, nil)
  38. if err != nil {
  39. return nil, FormatError(err)
  40. }
  41. defer tx.Rollback()
  42. tagRawList, err := findTagList(ctx, tx, find)
  43. if err != nil {
  44. return nil, err
  45. }
  46. list := []*api.Tag{}
  47. for _, raw := range tagRawList {
  48. list = append(list, raw.toTag())
  49. }
  50. return list, nil
  51. }
  52. func (s *Store) DeleteTag(ctx context.Context, delete *api.TagDelete) error {
  53. tx, err := s.db.BeginTx(ctx, nil)
  54. if err != nil {
  55. return FormatError(err)
  56. }
  57. defer tx.Rollback()
  58. if err := deleteTag(ctx, tx, delete); err != nil {
  59. return FormatError(err)
  60. }
  61. if err := tx.Commit(); err != nil {
  62. return FormatError(err)
  63. }
  64. return nil
  65. }
  66. func upsertTag(ctx context.Context, tx *sql.Tx, upsert *api.TagUpsert) (*tagRaw, error) {
  67. query := `
  68. INSERT INTO tag (
  69. name, creator_id
  70. )
  71. VALUES (?, ?)
  72. ON CONFLICT(name, creator_id) DO UPDATE
  73. SET
  74. name = EXCLUDED.name
  75. RETURNING name, creator_id
  76. `
  77. var tagRaw tagRaw
  78. if err := tx.QueryRowContext(ctx, query, upsert.Name, upsert.CreatorID).Scan(
  79. &tagRaw.Name,
  80. &tagRaw.CreatorID,
  81. ); err != nil {
  82. return nil, FormatError(err)
  83. }
  84. return &tagRaw, nil
  85. }
  86. func findTagList(ctx context.Context, tx *sql.Tx, find *api.TagFind) ([]*tagRaw, error) {
  87. where, args := []string{"creator_id = ?"}, []interface{}{find.CreatorID}
  88. query := `
  89. SELECT
  90. name,
  91. creator_id
  92. FROM tag
  93. WHERE ` + strings.Join(where, " AND ") + `
  94. ORDER BY name ASC
  95. `
  96. rows, err := tx.QueryContext(ctx, query, args...)
  97. if err != nil {
  98. return nil, FormatError(err)
  99. }
  100. defer rows.Close()
  101. tagRawList := make([]*tagRaw, 0)
  102. for rows.Next() {
  103. var tagRaw tagRaw
  104. if err := rows.Scan(
  105. &tagRaw.Name,
  106. &tagRaw.CreatorID,
  107. ); err != nil {
  108. return nil, FormatError(err)
  109. }
  110. tagRawList = append(tagRawList, &tagRaw)
  111. }
  112. if err := rows.Err(); err != nil {
  113. return nil, FormatError(err)
  114. }
  115. return tagRawList, nil
  116. }
  117. func deleteTag(ctx context.Context, tx *sql.Tx, delete *api.TagDelete) error {
  118. where, args := []string{"name = ?", "creator_id = ?"}, []interface{}{delete.Name, delete.CreatorID}
  119. stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ")
  120. result, err := tx.ExecContext(ctx, stmt, args...)
  121. if err != nil {
  122. return FormatError(err)
  123. }
  124. rows, _ := result.RowsAffected()
  125. if rows == 0 {
  126. return &common.Error{Code: common.NotFound, Err: fmt.Errorf("tag not found")}
  127. }
  128. return nil
  129. }
  130. func vacuumTag(ctx context.Context, tx *sql.Tx) error {
  131. stmt := `
  132. DELETE FROM
  133. tag
  134. WHERE
  135. creator_id NOT IN (
  136. SELECT
  137. id
  138. FROM
  139. user
  140. )`
  141. _, err := tx.ExecContext(ctx, stmt)
  142. if err != nil {
  143. return FormatError(err)
  144. }
  145. return nil
  146. }