memo_relation_test.go 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. package testserver
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "testing"
  8. "github.com/pkg/errors"
  9. "github.com/stretchr/testify/require"
  10. "github.com/usememos/memos/api"
  11. apiv1 "github.com/usememos/memos/api/v1"
  12. )
  13. func TestMemoRelationServer(t *testing.T) {
  14. ctx := context.Background()
  15. s, err := NewTestingServer(ctx, t)
  16. require.NoError(t, err)
  17. defer s.Shutdown(ctx)
  18. signup := &apiv1.SignUp{
  19. Username: "testuser",
  20. Password: "testpassword",
  21. }
  22. user, err := s.postAuthSignup(signup)
  23. require.NoError(t, err)
  24. require.Equal(t, signup.Username, user.Username)
  25. memo, err := s.postMemoCreate(&api.CreateMemoRequest{
  26. Content: "test memo",
  27. })
  28. require.NoError(t, err)
  29. require.Equal(t, "test memo", memo.Content)
  30. memo2, err := s.postMemoCreate(&api.CreateMemoRequest{
  31. Content: "test memo2",
  32. RelationList: []*api.MemoRelationUpsert{
  33. {
  34. RelatedMemoID: memo.ID,
  35. Type: api.MemoRelationReference,
  36. },
  37. },
  38. })
  39. require.NoError(t, err)
  40. require.Equal(t, "test memo2", memo2.Content)
  41. memoList, err := s.getMemoList()
  42. require.NoError(t, err)
  43. require.Len(t, memoList, 2)
  44. require.Len(t, memo2.RelationList, 1)
  45. err = s.deleteMemoRelation(memo2.ID, memo.ID, api.MemoRelationReference)
  46. require.NoError(t, err)
  47. memo2, err = s.getMemo(memo2.ID)
  48. require.NoError(t, err)
  49. require.Len(t, memo2.RelationList, 0)
  50. memoRelation, err := s.postMemoRelationUpsert(memo2.ID, &api.MemoRelationUpsert{
  51. RelatedMemoID: memo.ID,
  52. Type: api.MemoRelationReference,
  53. })
  54. require.NoError(t, err)
  55. require.Equal(t, memo.ID, memoRelation.RelatedMemoID)
  56. memo2, err = s.getMemo(memo2.ID)
  57. require.NoError(t, err)
  58. require.Len(t, memo2.RelationList, 1)
  59. }
  60. func (s *TestingServer) postMemoRelationUpsert(memoID int, memoRelationUpsert *api.MemoRelationUpsert) (*api.MemoRelation, error) {
  61. rawData, err := json.Marshal(&memoRelationUpsert)
  62. if err != nil {
  63. return nil, errors.Wrap(err, "failed to marshal memo relation upsert")
  64. }
  65. reader := bytes.NewReader(rawData)
  66. body, err := s.post(fmt.Sprintf("/api/memo/%d/relation", memoID), reader, nil)
  67. if err != nil {
  68. return nil, err
  69. }
  70. buf := &bytes.Buffer{}
  71. _, err = buf.ReadFrom(body)
  72. if err != nil {
  73. return nil, errors.Wrap(err, "fail to read response body")
  74. }
  75. type MemoCreateResponse struct {
  76. Data *api.MemoRelation `json:"data"`
  77. }
  78. res := new(MemoCreateResponse)
  79. if err = json.Unmarshal(buf.Bytes(), res); err != nil {
  80. return nil, errors.Wrap(err, "fail to unmarshal post memo relation upsert response")
  81. }
  82. return res.Data, nil
  83. }
  84. func (s *TestingServer) deleteMemoRelation(memoID int, relatedMemoID int, relationType api.MemoRelationType) error {
  85. _, err := s.delete(fmt.Sprintf("/api/memo/%d/relation/%d/type/%s", memoID, relatedMemoID, relationType), nil)
  86. return err
  87. }