123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- package oauth2
- import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "net/url"
- "testing"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "github.com/usememos/memos/plugin/idp"
- "github.com/usememos/memos/store"
- )
- func TestNewIdentityProvider(t *testing.T) {
- tests := []struct {
- name string
- config *store.IdentityProviderOAuth2Config
- containsErr string
- }{
- {
- name: "no tokenUrl",
- config: &store.IdentityProviderOAuth2Config{
- ClientID: "test-client-id",
- ClientSecret: "test-client-secret",
- AuthURL: "",
- TokenURL: "",
- UserInfoURL: "https://example.com/api/user",
- FieldMapping: &store.FieldMapping{
- Identifier: "login",
- },
- },
- containsErr: `the field "tokenUrl" is empty but required`,
- },
- {
- name: "no userInfoUrl",
- config: &store.IdentityProviderOAuth2Config{
- ClientID: "test-client-id",
- ClientSecret: "test-client-secret",
- AuthURL: "",
- TokenURL: "https://example.com/token",
- UserInfoURL: "",
- FieldMapping: &store.FieldMapping{
- Identifier: "login",
- },
- },
- containsErr: `the field "userInfoUrl" is empty but required`,
- },
- {
- name: "no field mapping identifier",
- config: &store.IdentityProviderOAuth2Config{
- ClientID: "test-client-id",
- ClientSecret: "test-client-secret",
- AuthURL: "",
- TokenURL: "https://example.com/token",
- UserInfoURL: "https://example.com/api/user",
- FieldMapping: &store.FieldMapping{
- Identifier: "",
- },
- },
- containsErr: `the field "fieldMapping.identifier" is empty but required`,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- _, err := NewIdentityProvider(test.config)
- assert.ErrorContains(t, err, test.containsErr)
- })
- }
- }
// newMockServer starts an in-process HTTP server that imitates the two OAuth2
// endpoints exercised by the tests in this file:
//
//   - POST /oauth2/token verifies that the form-encoded body carries the
//     expected authorization code and the "authorization_code" grant type,
//     then replies with a JSON token response containing accessToken.
//   - /oauth2/userinfo replies with the raw userinfo bytes as JSON.
//
// Handlers run on goroutines owned by the HTTP server, not on the test
// goroutine, so mismatches are reported with t.Errorf (which is safe to call
// from any goroutine) rather than anything that ends in t.FailNow. A failed
// check also answers with an HTTP error status so the client under test sees
// a clear failure.
//
// The server is shut down automatically through t.Cleanup when the test (or
// subtest) that owns t finishes.
func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server {
	t.Helper()

	mux := http.NewServeMux()

	mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("token endpoint: got method %q, want %q", r.Method, http.MethodPost)
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
			return
		}

		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("token endpoint: reading request body: %v", err)
			http.Error(w, "unreadable body", http.StatusBadRequest)
			return
		}
		vals, err := url.ParseQuery(string(body))
		if err != nil {
			t.Errorf("token endpoint: parsing form body: %v", err)
			http.Error(w, "malformed body", http.StatusBadRequest)
			return
		}

		if got := vals.Get("code"); got != code {
			t.Errorf("token endpoint: got code %q, want %q", got, code)
			http.Error(w, "unexpected code", http.StatusBadRequest)
			return
		}
		if got := vals.Get("grant_type"); got != "authorization_code" {
			t.Errorf("token endpoint: got grant_type %q, want %q", got, "authorization_code")
			http.Error(w, "unexpected grant_type", http.StatusBadRequest)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		err = json.NewEncoder(w).Encode(map[string]any{
			"access_token":  accessToken,
			"token_type":    "Bearer",
			"refresh_token": "test-refresh-token",
			"expires_in":    3600,
			// This mock does not issue an ID token; the key is kept with an
			// empty value so the response shape stays the same.
			"id_token": "",
		})
		if err != nil {
			t.Errorf("token endpoint: encoding response: %v", err)
		}
	})

	mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if _, err := w.Write(userinfo); err != nil {
			t.Errorf("userinfo endpoint: writing response: %v", err)
		}
	})

	s := httptest.NewServer(mux)
	// Release the listener and server goroutines once the owning test ends.
	t.Cleanup(s.Close)
	return s
}
- func TestIdentityProvider(t *testing.T) {
- ctx := context.Background()
- const (
- testClientID = "test-client-id"
- testCode = "test-code"
- testAccessToken = "test-access-token"
- testSubject = "123456789"
- testName = "John Doe"
- testEmail = "john.doe@example.com"
- )
- userInfo, err := json.Marshal(
- map[string]any{
- "sub": testSubject,
- "name": testName,
- "email": testEmail,
- },
- )
- require.NoError(t, err)
- s := newMockServer(t, testCode, testAccessToken, userInfo)
- oauth2, err := NewIdentityProvider(
- &store.IdentityProviderOAuth2Config{
- ClientID: testClientID,
- ClientSecret: "test-client-secret",
- TokenURL: fmt.Sprintf("%s/oauth2/token", s.URL),
- UserInfoURL: fmt.Sprintf("%s/oauth2/userinfo", s.URL),
- FieldMapping: &store.FieldMapping{
- Identifier: "sub",
- DisplayName: "name",
- Email: "email",
- },
- },
- )
- require.NoError(t, err)
- redirectURL := "https://example.com/oauth/callback"
- oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode)
- require.NoError(t, err)
- require.Equal(t, testAccessToken, oauthToken)
- userInfoResult, err := oauth2.UserInfo(oauthToken)
- require.NoError(t, err)
- wantUserInfo := &idp.IdentityProviderUserInfo{
- Identifier: testSubject,
- DisplayName: testName,
- Email: testEmail,
- }
- assert.Equal(t, wantUserInfo, userInfoResult)
- }
|