idp.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package store
  2. import (
  3. "context"
  4. "github.com/pkg/errors"
  5. "google.golang.org/protobuf/encoding/protojson"
  6. storepb "github.com/usememos/memos/proto/gen/store"
  7. )
  8. type IdentityProvider struct {
  9. ID int32
  10. Name string
  11. Type storepb.IdentityProvider_Type
  12. IdentifierFilter string
  13. Config string
  14. }
  15. type FindIdentityProvider struct {
  16. ID *int32
  17. }
  18. type UpdateIdentityProvider struct {
  19. ID int32
  20. Name *string
  21. IdentifierFilter *string
  22. Config *string
  23. }
  24. type DeleteIdentityProvider struct {
  25. ID int32
  26. }
  27. func (s *Store) CreateIdentityProvider(ctx context.Context, create *storepb.IdentityProvider) (*storepb.IdentityProvider, error) {
  28. raw, err := convertIdentityProviderToRaw(create)
  29. if err != nil {
  30. return nil, err
  31. }
  32. identityProviderRaw, err := s.driver.CreateIdentityProvider(ctx, raw)
  33. if err != nil {
  34. return nil, err
  35. }
  36. identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
  37. if err != nil {
  38. return nil, err
  39. }
  40. s.idpCache.Store(identityProvider.Id, identityProvider)
  41. return identityProvider, nil
  42. }
  43. func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*storepb.IdentityProvider, error) {
  44. list, err := s.driver.ListIdentityProviders(ctx, find)
  45. if err != nil {
  46. return nil, err
  47. }
  48. identityProviders := []*storepb.IdentityProvider{}
  49. for _, raw := range list {
  50. identityProvider, err := convertIdentityProviderFromRaw(raw)
  51. if err != nil {
  52. return nil, err
  53. }
  54. identityProviders = append(identityProviders, identityProvider)
  55. s.idpCache.Store(identityProvider.Id, identityProvider)
  56. }
  57. return identityProviders, nil
  58. }
  59. func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*storepb.IdentityProvider, error) {
  60. if find.ID != nil {
  61. if cache, ok := s.idpCache.Load(*find.ID); ok {
  62. identityProvider, ok := cache.(*storepb.IdentityProvider)
  63. if ok {
  64. return identityProvider, nil
  65. }
  66. }
  67. }
  68. list, err := s.ListIdentityProviders(ctx, find)
  69. if err != nil {
  70. return nil, err
  71. }
  72. if len(list) == 0 {
  73. return nil, nil
  74. }
  75. if len(list) > 1 {
  76. return nil, errors.Errorf("Found multiple identity providers with ID %d", *find.ID)
  77. }
  78. identityProvider := list[0]
  79. return identityProvider, nil
  80. }
  81. type UpdateIdentityProviderV1 struct {
  82. ID int32
  83. Type storepb.IdentityProvider_Type
  84. Name *string
  85. IdentifierFilter *string
  86. Config *storepb.IdentityProviderConfig
  87. }
  88. func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProviderV1) (*storepb.IdentityProvider, error) {
  89. updateRaw := &UpdateIdentityProvider{
  90. ID: update.ID,
  91. }
  92. if update.Name != nil {
  93. updateRaw.Name = update.Name
  94. }
  95. if update.IdentifierFilter != nil {
  96. updateRaw.IdentifierFilter = update.IdentifierFilter
  97. }
  98. if update.Config != nil {
  99. configRaw, err := convertIdentityProviderConfigToRaw(update.Type, update.Config)
  100. if err != nil {
  101. return nil, err
  102. }
  103. updateRaw.Config = &configRaw
  104. }
  105. identityProviderRaw, err := s.driver.UpdateIdentityProvider(ctx, updateRaw)
  106. if err != nil {
  107. return nil, err
  108. }
  109. identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
  110. if err != nil {
  111. return nil, err
  112. }
  113. s.idpCache.Store(identityProvider.Id, identityProvider)
  114. return identityProvider, nil
  115. }
  116. func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
  117. err := s.driver.DeleteIdentityProvider(ctx, delete)
  118. if err != nil {
  119. return err
  120. }
  121. s.idpCache.Delete(delete.ID)
  122. return nil
  123. }
  124. func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) {
  125. identityProvider := &storepb.IdentityProvider{
  126. Id: raw.ID,
  127. Name: raw.Name,
  128. Type: raw.Type,
  129. IdentifierFilter: raw.IdentifierFilter,
  130. }
  131. config, err := convertIdentityProviderConfigFromRaw(identityProvider.Type, raw.Config)
  132. if err != nil {
  133. return nil, err
  134. }
  135. identityProvider.Config = config
  136. return identityProvider, nil
  137. }
  138. func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) {
  139. raw := &IdentityProvider{
  140. ID: identityProvider.Id,
  141. Name: identityProvider.Name,
  142. Type: identityProvider.Type,
  143. IdentifierFilter: identityProvider.IdentifierFilter,
  144. }
  145. configRaw, err := convertIdentityProviderConfigToRaw(identityProvider.Type, identityProvider.Config)
  146. if err != nil {
  147. return nil, err
  148. }
  149. raw.Config = configRaw
  150. return raw, nil
  151. }
  152. func convertIdentityProviderConfigFromRaw(identityProviderType storepb.IdentityProvider_Type, raw string) (*storepb.IdentityProviderConfig, error) {
  153. config := &storepb.IdentityProviderConfig{}
  154. if identityProviderType == storepb.IdentityProvider_OAUTH2 {
  155. oauth2Config := &storepb.OAuth2Config{}
  156. if err := protojsonUnmarshaler.Unmarshal([]byte(raw), oauth2Config); err != nil {
  157. return nil, errors.Wrap(err, "Failed to unmarshal OAuth2Config")
  158. }
  159. config.Config = &storepb.IdentityProviderConfig_Oauth2Config{Oauth2Config: oauth2Config}
  160. }
  161. return config, nil
  162. }
  163. func convertIdentityProviderConfigToRaw(identityProviderType storepb.IdentityProvider_Type, config *storepb.IdentityProviderConfig) (string, error) {
  164. raw := ""
  165. if identityProviderType == storepb.IdentityProvider_OAUTH2 {
  166. bytes, err := protojson.Marshal(config.GetOauth2Config())
  167. if err != nil {
  168. return "", errors.Wrap(err, "Failed to marshal OAuth2Config")
  169. }
  170. raw = string(bytes)
  171. }
  172. return raw, nil
  173. }