server.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. package testserver
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "strings"
  9. "testing"
  10. "time"
  11. "github.com/pkg/errors"
  12. "github.com/usememos/memos/server"
  13. "github.com/usememos/memos/server/profile"
  14. "github.com/usememos/memos/store"
  15. "github.com/usememos/memos/store/db"
  16. "github.com/usememos/memos/test"
  17. // sqlite driver.
  18. _ "modernc.org/sqlite"
  19. )
  20. type TestingServer struct {
  21. server *server.Server
  22. client *http.Client
  23. profile *profile.Profile
  24. cookie string
  25. }
  26. func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) {
  27. profile := test.GetTestingProfile(t)
  28. db := db.NewDB(profile)
  29. if err := db.Open(ctx); err != nil {
  30. return nil, errors.Wrap(err, "failed to open db")
  31. }
  32. store := store.New(db.DBInstance, profile)
  33. server, err := server.NewServer(ctx, profile, store)
  34. if err != nil {
  35. return nil, errors.Wrap(err, "failed to create server")
  36. }
  37. s := &TestingServer{
  38. server: server,
  39. client: &http.Client{},
  40. profile: profile,
  41. cookie: "",
  42. }
  43. errChan := make(chan error, 1)
  44. go func() {
  45. if err := s.server.Start(ctx); err != nil {
  46. if err != http.ErrServerClosed {
  47. errChan <- errors.Wrap(err, "failed to run main server")
  48. }
  49. }
  50. }()
  51. if err := s.waitForServerStart(errChan); err != nil {
  52. return nil, errors.Wrap(err, "failed to start server")
  53. }
  54. return s, nil
  55. }
  56. func (s *TestingServer) Shutdown(ctx context.Context) {
  57. s.server.Shutdown(ctx)
  58. }
  59. func (s *TestingServer) waitForServerStart(errChan <-chan error) error {
  60. ticker := time.NewTicker(100 * time.Millisecond)
  61. defer ticker.Stop()
  62. for {
  63. select {
  64. case <-ticker.C:
  65. if s == nil {
  66. continue
  67. }
  68. e := s.server.GetEcho()
  69. if e == nil {
  70. continue
  71. }
  72. addr := e.ListenerAddr()
  73. if addr != nil && strings.Contains(addr.String(), ":") {
  74. return nil // was started
  75. }
  76. case err := <-errChan:
  77. if err == http.ErrServerClosed {
  78. return nil
  79. }
  80. return err
  81. }
  82. }
  83. }
  84. func (s *TestingServer) request(method, uri string, body io.Reader, params, header map[string]string) (io.ReadCloser, error) {
  85. fullURL := fmt.Sprintf("http://localhost:%d%s", s.profile.Port, uri)
  86. req, err := http.NewRequest(method, fullURL, body)
  87. if err != nil {
  88. return nil, errors.Wrapf(err, "fail to create a new %s request(%q)", method, fullURL)
  89. }
  90. for k, v := range header {
  91. req.Header.Set(k, v)
  92. }
  93. q := url.Values{}
  94. for k, v := range params {
  95. q.Add(k, v)
  96. }
  97. if len(q) > 0 {
  98. req.URL.RawQuery = q.Encode()
  99. }
  100. resp, err := s.client.Do(req)
  101. if err != nil {
  102. return nil, errors.Wrapf(err, "fail to send a %s request(%q)", method, fullURL)
  103. }
  104. if resp.StatusCode != http.StatusOK {
  105. body, err := io.ReadAll(resp.Body)
  106. if err != nil {
  107. return nil, errors.Wrap(err, "failed to read http response body")
  108. }
  109. return nil, errors.Errorf("http response error code %v body %q", resp.StatusCode, string(body))
  110. }
  111. if method == "POST" {
  112. if strings.Contains(uri, "/api/v1/auth/login") || strings.Contains(uri, "/api/v1/auth/signup") {
  113. cookie := ""
  114. h := resp.Header.Get("Set-Cookie")
  115. parts := strings.Split(h, "; ")
  116. for _, p := range parts {
  117. if strings.HasPrefix(p, "access-token=") {
  118. cookie = p
  119. break
  120. }
  121. }
  122. if cookie == "" {
  123. return nil, errors.Errorf("unable to find access token in the login response headers")
  124. }
  125. s.cookie = cookie
  126. } else if strings.Contains(uri, "/api/v1/auth/logout") {
  127. s.cookie = ""
  128. }
  129. }
  130. return resp.Body, nil
  131. }
  132. // get sends a GET client request.
  133. func (s *TestingServer) get(url string, params map[string]string) (io.ReadCloser, error) {
  134. return s.request("GET", url, nil, params, map[string]string{
  135. "Cookie": s.cookie,
  136. })
  137. }
  138. // post sends a POST client request.
  139. func (s *TestingServer) post(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) {
  140. return s.request("POST", url, body, params, map[string]string{
  141. "Cookie": s.cookie,
  142. })
  143. }
  144. // patch sends a PATCH client request.
  145. func (s *TestingServer) patch(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) {
  146. return s.request("PATCH", url, body, params, map[string]string{
  147. "Cookie": s.cookie,
  148. })
  149. }
  150. // delete sends a DELETE client request.
  151. func (s *TestingServer) delete(url string, params map[string]string) (io.ReadCloser, error) {
  152. return s.request("DELETE", url, nil, params, map[string]string{
  153. "Cookie": s.cookie,
  154. })
  155. }