oauth2_test.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package oauth2
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "net/url"
  10. "testing"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. "github.com/usememos/memos/plugin/idp"
  14. "github.com/usememos/memos/store"
  15. )
  16. func TestNewIdentityProvider(t *testing.T) {
  17. tests := []struct {
  18. name string
  19. config *store.IdentityProviderOAuth2Config
  20. containsErr string
  21. }{
  22. {
  23. name: "no tokenUrl",
  24. config: &store.IdentityProviderOAuth2Config{
  25. ClientID: "test-client-id",
  26. ClientSecret: "test-client-secret",
  27. AuthURL: "",
  28. TokenURL: "",
  29. UserInfoURL: "https://example.com/api/user",
  30. FieldMapping: &store.FieldMapping{
  31. Identifier: "login",
  32. },
  33. },
  34. containsErr: `the field "tokenUrl" is empty but required`,
  35. },
  36. {
  37. name: "no userInfoUrl",
  38. config: &store.IdentityProviderOAuth2Config{
  39. ClientID: "test-client-id",
  40. ClientSecret: "test-client-secret",
  41. AuthURL: "",
  42. TokenURL: "https://example.com/token",
  43. UserInfoURL: "",
  44. FieldMapping: &store.FieldMapping{
  45. Identifier: "login",
  46. },
  47. },
  48. containsErr: `the field "userInfoUrl" is empty but required`,
  49. },
  50. {
  51. name: "no field mapping identifier",
  52. config: &store.IdentityProviderOAuth2Config{
  53. ClientID: "test-client-id",
  54. ClientSecret: "test-client-secret",
  55. AuthURL: "",
  56. TokenURL: "https://example.com/token",
  57. UserInfoURL: "https://example.com/api/user",
  58. FieldMapping: &store.FieldMapping{
  59. Identifier: "",
  60. },
  61. },
  62. containsErr: `the field "fieldMapping.identifier" is empty but required`,
  63. },
  64. }
  65. for _, test := range tests {
  66. t.Run(test.name, func(t *testing.T) {
  67. _, err := NewIdentityProvider(test.config)
  68. assert.ErrorContains(t, err, test.containsErr)
  69. })
  70. }
  71. }
  72. func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server {
  73. mux := http.NewServeMux()
  74. var rawIDToken string
  75. mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
  76. require.Equal(t, http.MethodPost, r.Method)
  77. body, err := io.ReadAll(r.Body)
  78. require.NoError(t, err)
  79. vals, err := url.ParseQuery(string(body))
  80. require.NoError(t, err)
  81. require.Equal(t, code, vals.Get("code"))
  82. require.Equal(t, "authorization_code", vals.Get("grant_type"))
  83. w.Header().Set("Content-Type", "application/json")
  84. err = json.NewEncoder(w).Encode(map[string]any{
  85. "access_token": accessToken,
  86. "token_type": "Bearer",
  87. "refresh_token": "test-refresh-token",
  88. "expires_in": 3600,
  89. "id_token": rawIDToken,
  90. })
  91. require.NoError(t, err)
  92. })
  93. mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, r *http.Request) {
  94. w.Header().Set("Content-Type", "application/json")
  95. _, err := w.Write(userinfo)
  96. require.NoError(t, err)
  97. })
  98. s := httptest.NewServer(mux)
  99. return s
  100. }
  101. func TestIdentityProvider(t *testing.T) {
  102. ctx := context.Background()
  103. const (
  104. testClientID = "test-client-id"
  105. testCode = "test-code"
  106. testAccessToken = "test-access-token"
  107. testSubject = "123456789"
  108. testName = "John Doe"
  109. testEmail = "john.doe@example.com"
  110. )
  111. userInfo, err := json.Marshal(
  112. map[string]any{
  113. "sub": testSubject,
  114. "name": testName,
  115. "email": testEmail,
  116. },
  117. )
  118. require.NoError(t, err)
  119. s := newMockServer(t, testCode, testAccessToken, userInfo)
  120. oauth2, err := NewIdentityProvider(
  121. &store.IdentityProviderOAuth2Config{
  122. ClientID: testClientID,
  123. ClientSecret: "test-client-secret",
  124. TokenURL: fmt.Sprintf("%s/oauth2/token", s.URL),
  125. UserInfoURL: fmt.Sprintf("%s/oauth2/userinfo", s.URL),
  126. FieldMapping: &store.FieldMapping{
  127. Identifier: "sub",
  128. DisplayName: "name",
  129. Email: "email",
  130. },
  131. },
  132. )
  133. require.NoError(t, err)
  134. redirectURL := "https://example.com/oauth/callback"
  135. oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode)
  136. require.NoError(t, err)
  137. require.Equal(t, testAccessToken, oauthToken)
  138. userInfoResult, err := oauth2.UserInfo(oauthToken)
  139. require.NoError(t, err)
  140. wantUserInfo := &idp.IdentityProviderUserInfo{
  141. Identifier: testSubject,
  142. DisplayName: testName,
  143. Email: testEmail,
  144. }
  145. assert.Equal(t, wantUserInfo, userInfoResult)
  146. }