oauth2.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. // Package oauth2 is the plugin for OAuth2 Identity Provider.
  2. package oauth2
  3. import (
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "github.com/pkg/errors"
  10. "golang.org/x/oauth2"
  11. "github.com/usememos/memos/plugin/idp"
  12. storepb "github.com/usememos/memos/proto/gen/store"
  13. )
  14. // IdentityProvider represents an OAuth2 Identity Provider.
  15. type IdentityProvider struct {
  16. config *storepb.OAuth2Config
  17. }
  18. // NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration.
  19. func NewIdentityProvider(config *storepb.OAuth2Config) (*IdentityProvider, error) {
  20. for v, field := range map[string]string{
  21. config.ClientId: "clientId",
  22. config.ClientSecret: "clientSecret",
  23. config.TokenUrl: "tokenUrl",
  24. config.UserInfoUrl: "userInfoUrl",
  25. config.FieldMapping.Identifier: "fieldMapping.identifier",
  26. } {
  27. if v == "" {
  28. return nil, errors.Errorf(`the field "%s" is empty but required`, field)
  29. }
  30. }
  31. return &IdentityProvider{
  32. config: config,
  33. }, nil
  34. }
  35. // ExchangeToken returns the exchanged OAuth2 token using the given authorization code.
  36. func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) {
  37. conf := &oauth2.Config{
  38. ClientID: p.config.ClientId,
  39. ClientSecret: p.config.ClientSecret,
  40. RedirectURL: redirectURL,
  41. Scopes: p.config.Scopes,
  42. Endpoint: oauth2.Endpoint{
  43. AuthURL: p.config.AuthUrl,
  44. TokenURL: p.config.TokenUrl,
  45. AuthStyle: oauth2.AuthStyleInParams,
  46. },
  47. }
  48. token, err := conf.Exchange(ctx, code)
  49. if err != nil {
  50. return "", errors.Wrap(err, "failed to exchange access token")
  51. }
  52. accessToken, ok := token.Extra("access_token").(string)
  53. if !ok {
  54. return "", errors.New(`missing "access_token" from authorization response`)
  55. }
  56. return accessToken, nil
  57. }
  58. // UserInfo returns the parsed user information using the given OAuth2 token.
  59. func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) {
  60. client := &http.Client{}
  61. req, err := http.NewRequest(http.MethodGet, p.config.UserInfoUrl, nil)
  62. if err != nil {
  63. return nil, errors.Wrap(err, "failed to new http request")
  64. }
  65. req.Header.Set("Content-Type", "application/json")
  66. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
  67. resp, err := client.Do(req)
  68. if err != nil {
  69. return nil, errors.Wrap(err, "failed to get user information")
  70. }
  71. body, err := io.ReadAll(resp.Body)
  72. if err != nil {
  73. return nil, errors.Wrap(err, "failed to read response body")
  74. }
  75. defer resp.Body.Close()
  76. var claims map[string]any
  77. err = json.Unmarshal(body, &claims)
  78. if err != nil {
  79. return nil, errors.Wrap(err, "failed to unmarshal response body")
  80. }
  81. userInfo := &idp.IdentityProviderUserInfo{}
  82. if v, ok := claims[p.config.FieldMapping.Identifier].(string); ok {
  83. userInfo.Identifier = v
  84. }
  85. if userInfo.Identifier == "" {
  86. return nil, errors.Errorf("the field %q is not found in claims or has empty value", p.config.FieldMapping.Identifier)
  87. }
  88. // Best effort to map optional fields
  89. if p.config.FieldMapping.DisplayName != "" {
  90. if v, ok := claims[p.config.FieldMapping.DisplayName].(string); ok {
  91. userInfo.DisplayName = v
  92. }
  93. }
  94. if userInfo.DisplayName == "" {
  95. userInfo.DisplayName = userInfo.Identifier
  96. }
  97. if p.config.FieldMapping.Email != "" {
  98. if v, ok := claims[p.config.FieldMapping.Email].(string); ok {
  99. userInfo.Email = v
  100. }
  101. }
  102. return userInfo, nil
  103. }