memo_relation_test.go 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package testserver
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "testing"
  8. "github.com/pkg/errors"
  9. "github.com/stretchr/testify/require"
  10. apiv1 "github.com/usememos/memos/api/v1"
  11. )
  12. func TestMemoRelationServer(t *testing.T) {
  13. ctx := context.Background()
  14. s, err := NewTestingServer(ctx, t)
  15. require.NoError(t, err)
  16. defer s.Shutdown(ctx)
  17. signup := &apiv1.SignUp{
  18. Username: "testuser",
  19. Password: "testpassword",
  20. }
  21. user, err := s.postAuthSignUp(signup)
  22. require.NoError(t, err)
  23. require.Equal(t, signup.Username, user.Username)
  24. memo, err := s.postMemoCreate(&apiv1.CreateMemoRequest{
  25. Content: "test memo",
  26. })
  27. require.NoError(t, err)
  28. require.Equal(t, "test memo", memo.Content)
  29. memo2, err := s.postMemoCreate(&apiv1.CreateMemoRequest{
  30. Content: "test memo2",
  31. RelationList: []*apiv1.UpsertMemoRelationRequest{
  32. {
  33. RelatedMemoID: memo.ID,
  34. Type: apiv1.MemoRelationReference,
  35. },
  36. },
  37. })
  38. require.NoError(t, err)
  39. require.Equal(t, "test memo2", memo2.Content)
  40. memoList, err := s.getMemoList()
  41. require.NoError(t, err)
  42. require.Len(t, memoList, 2)
  43. require.Len(t, memo2.RelationList, 1)
  44. err = s.deleteMemoRelation(memo2.ID, memo.ID, apiv1.MemoRelationReference)
  45. require.NoError(t, err)
  46. memo2, err = s.getMemo(memo2.ID)
  47. require.NoError(t, err)
  48. require.Len(t, memo2.RelationList, 0)
  49. memoRelation, err := s.postMemoRelationUpsert(memo2.ID, &apiv1.UpsertMemoRelationRequest{
  50. RelatedMemoID: memo.ID,
  51. Type: apiv1.MemoRelationReference,
  52. })
  53. require.NoError(t, err)
  54. require.Equal(t, memo.ID, memoRelation.RelatedMemoID)
  55. memo2, err = s.getMemo(memo2.ID)
  56. require.NoError(t, err)
  57. require.Len(t, memo2.RelationList, 1)
  58. }
  59. func (s *TestingServer) postMemoRelationUpsert(memoID int32, memoRelationUpsert *apiv1.UpsertMemoRelationRequest) (*apiv1.MemoRelation, error) {
  60. rawData, err := json.Marshal(&memoRelationUpsert)
  61. if err != nil {
  62. return nil, errors.Wrap(err, "failed to marshal memo relation upsert")
  63. }
  64. reader := bytes.NewReader(rawData)
  65. body, err := s.post(fmt.Sprintf("/api/v1/memo/%d/relation", memoID), reader, nil)
  66. if err != nil {
  67. return nil, err
  68. }
  69. buf := &bytes.Buffer{}
  70. _, err = buf.ReadFrom(body)
  71. if err != nil {
  72. return nil, errors.Wrap(err, "fail to read response body")
  73. }
  74. memoRelation := &apiv1.MemoRelation{}
  75. if err = json.Unmarshal(buf.Bytes(), memoRelation); err != nil {
  76. return nil, errors.Wrap(err, "fail to unmarshal post memo relation upsert response")
  77. }
  78. return memoRelation, nil
  79. }
  80. func (s *TestingServer) deleteMemoRelation(memoID int32, relatedMemoID int32, relationType apiv1.MemoRelationType) error {
  81. _, err := s.delete(fmt.Sprintf("/api/v1/memo/%d/relation/%d/type/%s", memoID, relatedMemoID, relationType), nil)
  82. return err
  83. }