123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- package store
- import (
- "context"
- "github.com/pkg/errors"
- "google.golang.org/protobuf/encoding/protojson"
- storepb "github.com/usememos/memos/proto/gen/store"
- )
- type IdentityProvider struct {
- ID int32
- Name string
- Type storepb.IdentityProvider_Type
- IdentifierFilter string
- Config string
- }
// FindIdentityProvider describes the filter passed to the driver when
// listing identity providers.
type FindIdentityProvider struct {
	// ID is an optional filter on the record identifier; it may be nil.
	ID *int32
}
// UpdateIdentityProvider is the raw update request handed to the driver.
//
// The pointer fields are optional. Presumably a nil field means "leave
// unchanged", but that is decided by the driver, which is not visible here.
type UpdateIdentityProvider struct {
	// ID identifies the record to update.
	ID int32
	// Name is the optional new name.
	Name *string
	// IdentifierFilter is the optional new identifier filter.
	IdentifierFilter *string
	// Config is the optional new configuration in its serialized string form.
	Config *string
}
// DeleteIdentityProvider is the delete request handed to the driver.
type DeleteIdentityProvider struct {
	// ID identifies the record to delete.
	ID int32
}
- func (s *Store) CreateIdentityProvider(ctx context.Context, create *storepb.IdentityProvider) (*storepb.IdentityProvider, error) {
- raw, err := convertIdentityProviderToRaw(create)
- if err != nil {
- return nil, err
- }
- identityProviderRaw, err := s.driver.CreateIdentityProvider(ctx, raw)
- if err != nil {
- return nil, err
- }
- identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
- if err != nil {
- return nil, err
- }
- s.idpCache.Store(identityProvider.Id, identityProvider)
- return identityProvider, nil
- }
- func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*storepb.IdentityProvider, error) {
- list, err := s.driver.ListIdentityProviders(ctx, find)
- if err != nil {
- return nil, err
- }
- identityProviders := []*storepb.IdentityProvider{}
- for _, raw := range list {
- identityProvider, err := convertIdentityProviderFromRaw(raw)
- if err != nil {
- return nil, err
- }
- identityProviders = append(identityProviders, identityProvider)
- s.idpCache.Store(identityProvider.Id, identityProvider)
- }
- return identityProviders, nil
- }
- func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*storepb.IdentityProvider, error) {
- if find.ID != nil {
- if cache, ok := s.idpCache.Load(*find.ID); ok {
- identityProvider, ok := cache.(*storepb.IdentityProvider)
- if ok {
- return identityProvider, nil
- }
- }
- }
- list, err := s.ListIdentityProviders(ctx, find)
- if err != nil {
- return nil, err
- }
- if len(list) == 0 {
- return nil, nil
- }
- if len(list) > 1 {
- return nil, errors.Errorf("Found multiple identity providers with ID %d", *find.ID)
- }
- identityProvider := list[0]
- return identityProvider, nil
- }
- type UpdateIdentityProviderV1 struct {
- ID int32
- Type storepb.IdentityProvider_Type
- Name *string
- IdentifierFilter *string
- Config *storepb.IdentityProviderConfig
- }
- func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProviderV1) (*storepb.IdentityProvider, error) {
- updateRaw := &UpdateIdentityProvider{
- ID: update.ID,
- }
- if update.Name != nil {
- updateRaw.Name = update.Name
- }
- if update.IdentifierFilter != nil {
- updateRaw.IdentifierFilter = update.IdentifierFilter
- }
- if update.Config != nil {
- configRaw, err := convertIdentityProviderConfigToRaw(update.Type, update.Config)
- if err != nil {
- return nil, err
- }
- updateRaw.Config = &configRaw
- }
- identityProviderRaw, err := s.driver.UpdateIdentityProvider(ctx, updateRaw)
- if err != nil {
- return nil, err
- }
- identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
- if err != nil {
- return nil, err
- }
- s.idpCache.Store(identityProvider.Id, identityProvider)
- return identityProvider, nil
- }
- func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
- err := s.driver.DeleteIdentityProvider(ctx, delete)
- if err != nil {
- return err
- }
- s.idpCache.Delete(delete.ID)
- return nil
- }
- func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) {
- identityProvider := &storepb.IdentityProvider{
- Id: raw.ID,
- Name: raw.Name,
- Type: raw.Type,
- IdentifierFilter: raw.IdentifierFilter,
- }
- config, err := convertIdentityProviderConfigFromRaw(identityProvider.Type, raw.Config)
- if err != nil {
- return nil, err
- }
- identityProvider.Config = config
- return identityProvider, nil
- }
- func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) {
- raw := &IdentityProvider{
- ID: identityProvider.Id,
- Name: identityProvider.Name,
- Type: identityProvider.Type,
- IdentifierFilter: identityProvider.IdentifierFilter,
- }
- configRaw, err := convertIdentityProviderConfigToRaw(identityProvider.Type, identityProvider.Config)
- if err != nil {
- return nil, err
- }
- raw.Config = configRaw
- return raw, nil
- }
- func convertIdentityProviderConfigFromRaw(identityProviderType storepb.IdentityProvider_Type, raw string) (*storepb.IdentityProviderConfig, error) {
- config := &storepb.IdentityProviderConfig{}
- if identityProviderType == storepb.IdentityProvider_OAUTH2 {
- oauth2Config := &storepb.OAuth2Config{}
- if err := protojsonUnmarshaler.Unmarshal([]byte(raw), oauth2Config); err != nil {
- return nil, errors.Wrap(err, "Failed to unmarshal OAuth2Config")
- }
- config.Config = &storepb.IdentityProviderConfig_Oauth2Config{Oauth2Config: oauth2Config}
- }
- return config, nil
- }
- func convertIdentityProviderConfigToRaw(identityProviderType storepb.IdentityProvider_Type, config *storepb.IdentityProviderConfig) (string, error) {
- raw := ""
- if identityProviderType == storepb.IdentityProvider_OAUTH2 {
- bytes, err := protojson.Marshal(config.GetOauth2Config())
- if err != nil {
- return "", errors.Wrap(err, "Failed to marshal OAuth2Config")
- }
- raw = string(bytes)
- }
- return raw, nil
- }
|