idp_service.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. package v1
  2. import (
  3. "context"
  4. "fmt"
  5. "google.golang.org/grpc/codes"
  6. "google.golang.org/grpc/status"
  7. "google.golang.org/protobuf/types/known/emptypb"
  8. v1pb "github.com/usememos/memos/proto/gen/api/v1"
  9. storepb "github.com/usememos/memos/proto/gen/store"
  10. "github.com/usememos/memos/store"
  11. )
  12. func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb.CreateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
  13. currentUser, err := s.GetCurrentUser(ctx)
  14. if err != nil {
  15. return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
  16. }
  17. if currentUser.Role != store.RoleHost {
  18. return nil, status.Errorf(codes.PermissionDenied, "permission denied")
  19. }
  20. identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
  21. if err != nil {
  22. return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
  23. }
  24. return convertIdentityProviderFromStore(identityProvider), nil
  25. }
  26. func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListIdentityProvidersRequest) (*v1pb.ListIdentityProvidersResponse, error) {
  27. identityProviders, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
  28. if err != nil {
  29. return nil, status.Errorf(codes.Internal, "failed to list identity providers, error: %+v", err)
  30. }
  31. response := &v1pb.ListIdentityProvidersResponse{
  32. IdentityProviders: []*v1pb.IdentityProvider{},
  33. }
  34. for _, identityProvider := range identityProviders {
  35. response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
  36. }
  37. return response, nil
  38. }
  39. func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
  40. id, err := ExtractIdentityProviderIDFromName(request.Name)
  41. if err != nil {
  42. return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
  43. }
  44. identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
  45. ID: &id,
  46. })
  47. if err != nil {
  48. return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
  49. }
  50. if identityProvider == nil {
  51. return nil, status.Errorf(codes.NotFound, "identity provider not found")
  52. }
  53. return convertIdentityProviderFromStore(identityProvider), nil
  54. }
  55. func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
  56. if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
  57. return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
  58. }
  59. id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
  60. if err != nil {
  61. return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
  62. }
  63. update := &store.UpdateIdentityProviderV1{
  64. ID: id,
  65. Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
  66. }
  67. for _, field := range request.UpdateMask.Paths {
  68. switch field {
  69. case "title":
  70. update.Name = &request.IdentityProvider.Title
  71. case "identifier_filter":
  72. update.IdentifierFilter = &request.IdentityProvider.IdentifierFilter
  73. case "config":
  74. update.Config = convertIdentityProviderConfigToStore(request.IdentityProvider.Type, request.IdentityProvider.Config)
  75. }
  76. }
  77. identityProvider, err := s.Store.UpdateIdentityProvider(ctx, update)
  78. if err != nil {
  79. return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
  80. }
  81. return convertIdentityProviderFromStore(identityProvider), nil
  82. }
  83. func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
  84. id, err := ExtractIdentityProviderIDFromName(request.Name)
  85. if err != nil {
  86. return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
  87. }
  88. if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
  89. return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
  90. }
  91. return &emptypb.Empty{}, nil
  92. }
  93. func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
  94. temp := &v1pb.IdentityProvider{
  95. Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
  96. Title: identityProvider.Name,
  97. IdentifierFilter: identityProvider.IdentifierFilter,
  98. Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
  99. }
  100. if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
  101. oauth2Config := identityProvider.Config.GetOauth2Config()
  102. temp.Config = &v1pb.IdentityProviderConfig{
  103. Config: &v1pb.IdentityProviderConfig_Oauth2Config{
  104. Oauth2Config: &v1pb.OAuth2Config{
  105. ClientId: oauth2Config.ClientId,
  106. ClientSecret: oauth2Config.ClientSecret,
  107. AuthUrl: oauth2Config.AuthUrl,
  108. TokenUrl: oauth2Config.TokenUrl,
  109. UserInfoUrl: oauth2Config.UserInfoUrl,
  110. Scopes: oauth2Config.Scopes,
  111. FieldMapping: &v1pb.FieldMapping{
  112. Identifier: oauth2Config.FieldMapping.Identifier,
  113. DisplayName: oauth2Config.FieldMapping.DisplayName,
  114. Email: oauth2Config.FieldMapping.Email,
  115. },
  116. },
  117. },
  118. }
  119. }
  120. return temp
  121. }
  122. func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
  123. id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
  124. temp := &storepb.IdentityProvider{
  125. Id: id,
  126. Name: identityProvider.Title,
  127. IdentifierFilter: identityProvider.IdentifierFilter,
  128. Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
  129. Config: convertIdentityProviderConfigToStore(identityProvider.Type, identityProvider.Config),
  130. }
  131. return temp
  132. }
  133. func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProvider_Type, config *v1pb.IdentityProviderConfig) *storepb.IdentityProviderConfig {
  134. if identityProviderType == v1pb.IdentityProvider_OAUTH2 {
  135. oauth2Config := config.GetOauth2Config()
  136. return &storepb.IdentityProviderConfig{
  137. Config: &storepb.IdentityProviderConfig_Oauth2Config{
  138. Oauth2Config: &storepb.OAuth2Config{
  139. ClientId: oauth2Config.ClientId,
  140. ClientSecret: oauth2Config.ClientSecret,
  141. AuthUrl: oauth2Config.AuthUrl,
  142. TokenUrl: oauth2Config.TokenUrl,
  143. UserInfoUrl: oauth2Config.UserInfoUrl,
  144. Scopes: oauth2Config.Scopes,
  145. FieldMapping: &storepb.FieldMapping{
  146. Identifier: oauth2Config.FieldMapping.Identifier,
  147. DisplayName: oauth2Config.FieldMapping.DisplayName,
  148. Email: oauth2Config.FieldMapping.Email,
  149. },
  150. },
  151. },
  152. }
  153. }
  154. return nil
  155. }