server.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package testserver
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "strings"
  9. "testing"
  10. "time"
  11. "github.com/pkg/errors"
  12. // sqlite driver.
  13. _ "modernc.org/sqlite"
  14. "github.com/usememos/memos/api/auth"
  15. "github.com/usememos/memos/server"
  16. "github.com/usememos/memos/server/profile"
  17. "github.com/usememos/memos/store"
  18. "github.com/usememos/memos/store/db"
  19. "github.com/usememos/memos/test"
  20. )
  21. type TestingServer struct {
  22. server *server.Server
  23. client *http.Client
  24. profile *profile.Profile
  25. cookie string
  26. }
  27. func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) {
  28. profile := test.GetTestingProfile(t)
  29. dbDriver, err := db.NewDBDriver(profile)
  30. if err != nil {
  31. return nil, errors.Wrap(err, "failed to create db driver")
  32. }
  33. if err := dbDriver.Migrate(ctx); err != nil {
  34. return nil, errors.Wrap(err, "failed to migrate db")
  35. }
  36. store := store.New(dbDriver, profile)
  37. server, err := server.NewServer(ctx, profile, store)
  38. if err != nil {
  39. return nil, errors.Wrap(err, "failed to create server")
  40. }
  41. s := &TestingServer{
  42. server: server,
  43. client: &http.Client{},
  44. profile: profile,
  45. cookie: "",
  46. }
  47. errChan := make(chan error, 1)
  48. go func() {
  49. if err := s.server.Start(ctx); err != nil {
  50. if err != http.ErrServerClosed {
  51. errChan <- errors.Wrap(err, "failed to run main server")
  52. }
  53. }
  54. }()
  55. if err := s.waitForServerStart(errChan); err != nil {
  56. return nil, errors.Wrap(err, "failed to start server")
  57. }
  58. return s, nil
  59. }
  60. func (s *TestingServer) Shutdown(ctx context.Context) {
  61. s.server.Shutdown(ctx)
  62. }
  63. func (s *TestingServer) waitForServerStart(errChan <-chan error) error {
  64. ticker := time.NewTicker(100 * time.Millisecond)
  65. defer ticker.Stop()
  66. for {
  67. select {
  68. case <-ticker.C:
  69. if s == nil {
  70. continue
  71. }
  72. e := s.server.GetEcho()
  73. if e == nil {
  74. continue
  75. }
  76. addr := e.ListenerAddr()
  77. if addr != nil && strings.Contains(addr.String(), ":") {
  78. return nil // was started
  79. }
  80. case err := <-errChan:
  81. if err == http.ErrServerClosed {
  82. return nil
  83. }
  84. return err
  85. }
  86. }
  87. }
  88. func (s *TestingServer) request(method, uri string, body io.Reader, params, header map[string]string) (io.ReadCloser, error) {
  89. fullURL := fmt.Sprintf("http://localhost:%d%s", s.profile.Port, uri)
  90. req, err := http.NewRequest(method, fullURL, body)
  91. if err != nil {
  92. return nil, errors.Wrapf(err, "fail to create a new %s request(%q)", method, fullURL)
  93. }
  94. for k, v := range header {
  95. req.Header.Set(k, v)
  96. }
  97. q := url.Values{}
  98. for k, v := range params {
  99. q.Add(k, v)
  100. }
  101. if len(q) > 0 {
  102. req.URL.RawQuery = q.Encode()
  103. }
  104. resp, err := s.client.Do(req)
  105. if err != nil {
  106. return nil, errors.Wrapf(err, "fail to send a %s request(%q)", method, fullURL)
  107. }
  108. if resp.StatusCode != http.StatusOK {
  109. body, err := io.ReadAll(resp.Body)
  110. if err != nil {
  111. return nil, errors.Wrap(err, "failed to read http response body")
  112. }
  113. return nil, errors.Errorf("http response error code %v body %q", resp.StatusCode, string(body))
  114. }
  115. if method == "POST" {
  116. if strings.Contains(uri, "/api/v1/auth/login") || strings.Contains(uri, "/api/v1/auth/signup") {
  117. cookie := ""
  118. h := resp.Header.Get("Set-Cookie")
  119. parts := strings.Split(h, "; ")
  120. for _, p := range parts {
  121. if strings.HasPrefix(p, fmt.Sprintf("%s=", auth.AccessTokenCookieName)) {
  122. cookie = p
  123. break
  124. }
  125. }
  126. if cookie == "" {
  127. return nil, errors.New("unable to find access token in the login response headers")
  128. }
  129. s.cookie = cookie
  130. } else if strings.Contains(uri, "/api/v1/auth/signout") {
  131. s.cookie = ""
  132. }
  133. }
  134. return resp.Body, nil
  135. }
  136. // get sends a GET client request.
  137. func (s *TestingServer) get(url string, params map[string]string) (io.ReadCloser, error) {
  138. return s.request("GET", url, nil, params, map[string]string{
  139. "Cookie": s.cookie,
  140. })
  141. }
  142. // post sends a POST client request.
  143. func (s *TestingServer) post(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) {
  144. return s.request("POST", url, body, params, map[string]string{
  145. "Cookie": s.cookie,
  146. })
  147. }
  148. // patch sends a PATCH client request.
  149. func (s *TestingServer) patch(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) {
  150. return s.request("PATCH", url, body, params, map[string]string{
  151. "Cookie": s.cookie,
  152. })
  153. }
  154. // delete sends a DELETE client request.
  155. func (s *TestingServer) delete(url string, params map[string]string) (io.ReadCloser, error) {
  156. return s.request("DELETE", url, nil, params, map[string]string{
  157. "Cookie": s.cookie,
  158. })
  159. }