server.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. package testserver
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "strings"
  9. "testing"
  10. "time"
  11. "github.com/pkg/errors"
  12. "github.com/usememos/memos/server"
  13. "github.com/usememos/memos/server/profile"
  14. "github.com/usememos/memos/store/db"
  15. "github.com/usememos/memos/test"
  16. // sqlite3 driver.
  17. _ "github.com/mattn/go-sqlite3"
  18. )
  19. type TestingServer struct {
  20. server *server.Server
  21. client *http.Client
  22. profile *profile.Profile
  23. cookie string
  24. }
  // NewTestingServer builds a TestingServer from the testing profile of t.
//
// It opens the database for that profile, creates a server.Server from the
// same profile, and starts it in a background goroutine. It then waits for
// waitForServerStart to report that the server is listening, or that
// startup failed. If startup fails, the server is shut down before the error
// is returned. On success the caller is responsible for calling Shutdown.
func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) {
	// Locals are deliberately not named profile/db/server so they do not
	// shadow the packages of the same name.
	testProfile := test.GetTestingProfile(t)
	testDB := db.NewDB(testProfile)
	if err := testDB.Open(ctx); err != nil {
		return nil, errors.Wrap(err, "failed to open db")
	}

	srv, err := server.NewServer(ctx, testProfile)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create server")
	}

	s := &TestingServer{
		server:  srv,
		client:  &http.Client{},
		profile: testProfile,
		cookie:  "",
	}

	// Buffered so the goroutine can always hand off its single error and
	// exit, even if nobody is receiving any more.
	errChan := make(chan error, 1)
	go func() {
		// ErrServerClosed is the normal result of a Shutdown, not a failure.
		if err := s.server.Start(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
			errChan <- errors.Wrap(err, "failed to run main server")
		}
	}()

	if err := s.waitForServerStart(errChan); err != nil {
		// Release whatever the partially started server acquired before
		// reporting the failure.
		s.server.Shutdown(ctx)
		return nil, errors.Wrap(err, "failed to start server")
	}
	return s, nil
}
  // Shutdown stops the server under test by delegating to the wrapped
// server.Server's Shutdown with the given ctx.
func (s *TestingServer) Shutdown(ctx context.Context) {
	srv := s.server
	srv.Shutdown(ctx)
}
  57. func (s *TestingServer) waitForServerStart(errChan <-chan error) error {
  58. ticker := time.NewTicker(100 * time.Millisecond)
  59. defer ticker.Stop()
  60. for {
  61. select {
  62. case <-ticker.C:
  63. if s == nil {
  64. continue
  65. }
  66. e := s.server.GetEcho()
  67. if e == nil {
  68. continue
  69. }
  70. addr := e.ListenerAddr()
  71. if addr != nil && strings.Contains(addr.String(), ":") {
  72. return nil // was started
  73. }
  74. case err := <-errChan:
  75. if err == http.ErrServerClosed {
  76. return nil
  77. }
  78. return err
  79. }
  80. }
  81. }
  // request sends an HTTP request to the local test server. On a 200 OK it
// returns the response body, which the caller must close.
//
// The target URL is http://localhost:<profile.Port> followed by uri. Every
// entry of header is set on the request. params are added to the query
// string, extending any query already present in uri rather than replacing
// it. Any non-200 status is turned into an error carrying the status code and
// the response body.
//
// After a successful POST to a login or signup endpoint, the
// "access-token=..." pair is taken from the Set-Cookie response headers and
// stored in s.cookie. A POST to the logout endpoint clears s.cookie. The
// response body is closed on every error path.
func (s *TestingServer) request(method, uri string, body io.Reader, params, header map[string]string) (io.ReadCloser, error) {
	fullURL := fmt.Sprintf("http://localhost:%d%s", s.profile.Port, uri)
	req, err := http.NewRequest(method, fullURL, body)
	if err != nil {
		return nil, errors.Wrapf(err, "fail to create a new %s request(%q)", method, fullURL)
	}
	for k, v := range header {
		req.Header.Set(k, v)
	}
	if len(params) > 0 {
		// Start from the query already in uri so params extend it instead of
		// silently discarding it.
		q, err := url.ParseQuery(req.URL.RawQuery)
		if err != nil {
			return nil, errors.Wrapf(err, "fail to parse the query of %s request(%q)", method, fullURL)
		}
		for k, v := range params {
			q.Add(k, v)
		}
		req.URL.RawQuery = q.Encode()
	}

	resp, err := s.client.Do(req)
	if err != nil {
		return nil, errors.Wrapf(err, "fail to send a %s request(%q)", method, fullURL)
	}
	if resp.StatusCode != http.StatusOK {
		// The body is not handed back to the caller here, so close it.
		defer resp.Body.Close()
		respBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, errors.Wrap(err, "failed to read http response body")
		}
		return nil, errors.Errorf("http response error code %v body %q", resp.StatusCode, string(respBody))
	}

	if method == http.MethodPost {
		switch {
		case strings.Contains(uri, "/api/auth/login"), strings.Contains(uri, "/api/auth/signup"):
			cookie := ""
			// A response can carry several Set-Cookie headers, and Header.Get
			// only returns the first, so scan all of them.
		scan:
			for _, line := range resp.Header.Values("Set-Cookie") {
				for _, part := range strings.Split(line, "; ") {
					if strings.HasPrefix(part, "access-token=") {
						cookie = part
						break scan
					}
				}
			}
			if cookie == "" {
				resp.Body.Close()
				return nil, errors.Errorf("unable to find access token in the login response headers")
			}
			s.cookie = cookie
		case strings.Contains(uri, "/api/auth/logout"):
			s.cookie = ""
		}
	}
	return resp.Body, nil
}
  // get issues a GET for url (a path on the test server) with the given query
// params. The current value of s.cookie is sent as the Cookie header. On
// success the caller must close the returned body.
func (s *TestingServer) get(url string, params map[string]string) (io.ReadCloser, error) {
	headers := map[string]string{"Cookie": s.cookie}
	return s.request(http.MethodGet, url, nil, params, headers)
}
  // post issues a POST for url (a path on the test server) with the given body
// and query params. The current value of s.cookie is sent as the Cookie
// header. On success the caller must close the returned body.
func (s *TestingServer) post(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) {
	headers := map[string]string{"Cookie": s.cookie}
	return s.request(http.MethodPost, url, body, params, headers)
}
  // patch issues a PATCH for url (a path on the test server) with the given
// body and query params. The current value of s.cookie is sent as the Cookie
// header. On success the caller must close the returned body.
func (s *TestingServer) patch(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) {
	headers := map[string]string{"Cookie": s.cookie}
	return s.request(http.MethodPatch, url, body, params, headers)
}
  // delete issues a DELETE for url (a path on the test server) with the given
// query params. The current value of s.cookie is sent as the Cookie header.
// On success the caller must close the returned body.
func (s *TestingServer) delete(url string, params map[string]string) (io.ReadCloser, error) {
	headers := map[string]string{"Cookie": s.cookie}
	return s.request(http.MethodDelete, url, nil, params, headers)
}