oauth2_test.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package oauth2
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "net/url"
  10. "testing"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. "github.com/usememos/memos/plugin/idp"
  14. storepb "github.com/usememos/memos/proto/gen/store"
  15. )
  16. func TestNewIdentityProvider(t *testing.T) {
  17. tests := []struct {
  18. name string
  19. config *storepb.OAuth2Config
  20. containsErr string
  21. }{
  22. {
  23. name: "no tokenUrl",
  24. config: &storepb.OAuth2Config{
  25. ClientId: "test-client-id",
  26. ClientSecret: "test-client-secret",
  27. AuthUrl: "",
  28. TokenUrl: "",
  29. UserInfoUrl: "https://example.com/api/user",
  30. FieldMapping: &storepb.FieldMapping{
  31. Identifier: "login",
  32. },
  33. },
  34. containsErr: `the field "tokenUrl" is empty but required`,
  35. },
  36. {
  37. name: "no userInfoUrl",
  38. config: &storepb.OAuth2Config{
  39. ClientId: "test-client-id",
  40. ClientSecret: "test-client-secret",
  41. AuthUrl: "",
  42. TokenUrl: "https://example.com/token",
  43. UserInfoUrl: "",
  44. FieldMapping: &storepb.FieldMapping{
  45. Identifier: "login",
  46. },
  47. },
  48. containsErr: `the field "userInfoUrl" is empty but required`,
  49. },
  50. {
  51. name: "no field mapping identifier",
  52. config: &storepb.OAuth2Config{
  53. ClientId: "test-client-id",
  54. ClientSecret: "test-client-secret",
  55. AuthUrl: "",
  56. TokenUrl: "https://example.com/token",
  57. UserInfoUrl: "https://example.com/api/user",
  58. FieldMapping: &storepb.FieldMapping{
  59. Identifier: "",
  60. },
  61. },
  62. containsErr: `the field "fieldMapping.identifier" is empty but required`,
  63. },
  64. }
  65. for _, test := range tests {
  66. t.Run(test.name, func(*testing.T) {
  67. _, err := NewIdentityProvider(test.config)
  68. assert.ErrorContains(t, err, test.containsErr)
  69. })
  70. }
  71. }
  72. func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server {
  73. mux := http.NewServeMux()
  74. var rawIDToken string
  75. mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
  76. require.Equal(t, http.MethodPost, r.Method)
  77. body, err := io.ReadAll(r.Body)
  78. require.NoError(t, err)
  79. vals, err := url.ParseQuery(string(body))
  80. require.NoError(t, err)
  81. require.Equal(t, code, vals.Get("code"))
  82. require.Equal(t, "authorization_code", vals.Get("grant_type"))
  83. w.Header().Set("Content-Type", "application/json")
  84. err = json.NewEncoder(w).Encode(map[string]any{
  85. "access_token": accessToken,
  86. "token_type": "Bearer",
  87. "expires_in": 3600,
  88. "id_token": rawIDToken,
  89. })
  90. require.NoError(t, err)
  91. })
  92. mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, _ *http.Request) {
  93. w.Header().Set("Content-Type", "application/json")
  94. _, err := w.Write(userinfo)
  95. require.NoError(t, err)
  96. })
  97. s := httptest.NewServer(mux)
  98. return s
  99. }
  100. func TestIdentityProvider(t *testing.T) {
  101. ctx := context.Background()
  102. const (
  103. testClientID = "test-client-id"
  104. testCode = "test-code"
  105. testAccessToken = "test-access-token"
  106. testSubject = "123456789"
  107. testName = "John Doe"
  108. testEmail = "john.doe@example.com"
  109. )
  110. userInfo, err := json.Marshal(
  111. map[string]any{
  112. "sub": testSubject,
  113. "name": testName,
  114. "email": testEmail,
  115. },
  116. )
  117. require.NoError(t, err)
  118. s := newMockServer(t, testCode, testAccessToken, userInfo)
  119. oauth2, err := NewIdentityProvider(
  120. &storepb.OAuth2Config{
  121. ClientId: testClientID,
  122. ClientSecret: "test-client-secret",
  123. TokenUrl: fmt.Sprintf("%s/oauth2/token", s.URL),
  124. UserInfoUrl: fmt.Sprintf("%s/oauth2/userinfo", s.URL),
  125. FieldMapping: &storepb.FieldMapping{
  126. Identifier: "sub",
  127. DisplayName: "name",
  128. Email: "email",
  129. },
  130. },
  131. )
  132. require.NoError(t, err)
  133. redirectURL := "https://example.com/oauth/callback"
  134. oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode)
  135. require.NoError(t, err)
  136. require.Equal(t, testAccessToken, oauthToken)
  137. userInfoResult, err := oauth2.UserInfo(oauthToken)
  138. require.NoError(t, err)
  139. wantUserInfo := &idp.IdentityProviderUserInfo{
  140. Identifier: testSubject,
  141. DisplayName: testName,
  142. Email: testEmail,
  143. }
  144. assert.Equal(t, wantUserInfo, userInfoResult)
  145. }