idp.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. package store
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "strings"
  7. "github.com/pkg/errors"
  8. )
  // IdentityProviderType names the kind of an identity provider; it is the
// value written to the idp table's type column.
type IdentityProviderType string

// IdentityProviderOAuth2Type is the only IdentityProviderType whose config
// this store knows how to encode and decode.
const IdentityProviderOAuth2Type IdentityProviderType = "OAUTH2"

// String returns t as a plain string (e.g. "OAUTH2").
func (t IdentityProviderType) String() string { return string(t) }

type (
	// IdentityProviderConfig holds the type-specific settings of a provider.
	// OAuth2Config is currently its only variant.
	IdentityProviderConfig struct {
		OAuth2Config *IdentityProviderOAuth2Config
	}

	// IdentityProviderOAuth2Config is the OAuth2 configuration; it is
	// marshalled to JSON (using the tags below) for the idp config column.
	IdentityProviderOAuth2Config struct {
		ClientID     string        `json:"clientId"`
		ClientSecret string        `json:"clientSecret"`
		AuthURL      string        `json:"authUrl"`
		TokenURL     string        `json:"tokenUrl"`
		UserInfoURL  string        `json:"userInfoUrl"`
		Scopes       []string      `json:"scopes"`
		FieldMapping *FieldMapping `json:"fieldMapping"`
	}

	// FieldMapping names the source fields used for the identifier, display
	// name and email — presumably keys of the provider's user-info response;
	// confirm against the consumer of this type.
	FieldMapping struct {
		Identifier  string `json:"identifier"`
		DisplayName string `json:"displayName"`
		Email       string `json:"email"`
	}

	// IdentityProvider is one row of the idp table together with its
	// decoded Config.
	IdentityProvider struct {
		ID               int32
		Name             string
		Type             IdentityProviderType
		IdentifierFilter string
		Config           *IdentityProviderConfig
	}

	// FindIdentityProvider filters lookups; a nil ID matches every provider.
	FindIdentityProvider struct {
		ID *int32
	}

	// UpdateIdentityProvider describes a partial update of the provider with
	// the given ID. Nil pointer fields are left unchanged; Type selects how a
	// non-nil Config is encoded.
	UpdateIdentityProvider struct {
		ID               int32
		Type             IdentityProviderType
		Name             *string
		IdentifierFilter *string
		Config           *IdentityProviderConfig
	}

	// DeleteIdentityProvider identifies the provider row to delete.
	DeleteIdentityProvider struct {
		ID int32
	}
)
  // CreateIdentityProvider inserts create into the idp table and returns it
// with ID set to the id returned by the INSERT. The returned pointer is create
// itself, and that same pointer is stored in the cache.
//
// Only OAUTH2 providers are supported; their OAuth2Config is stored as JSON in
// the config column. An error is returned for any other type, or when the
// OAuth2 config is missing.
func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) {
	var configBytes []byte
	switch create.Type {
	case IdentityProviderOAuth2Type:
		// Without this guard a nil Config panics on dereference, and a nil
		// OAuth2Config would be persisted as the JSON literal null.
		if create.Config == nil || create.Config.OAuth2Config == nil {
			return nil, errors.Errorf("missing oauth2 config for idp type %s", string(create.Type))
		}
		bytes, err := json.Marshal(create.Config.OAuth2Config)
		if err != nil {
			return nil, err
		}
		configBytes = bytes
	default:
		return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
	}

	stmt := `
		INSERT INTO idp (
			name,
			type,
			identifier_filter,
			config
		)
		VALUES (?, ?, ?, ?)
		RETURNING id
	`
	if err := s.db.QueryRowContext(
		ctx,
		stmt,
		create.Name,
		create.Type,
		create.IdentifierFilter,
		string(configBytes),
	).Scan(&create.ID); err != nil {
		return nil, err
	}

	s.idpCache.Store(create.ID, create)
	return create, nil
}
  // ListIdentityProviders returns the identity providers matching find,
// ordered by id ascending, and caches every returned provider under its id.
//
// A nil find, or a find with a nil ID, matches all rows. A row whose type is
// not OAUTH2, or whose config column does not decode into
// IdentityProviderOAuth2Config, makes the whole call fail. The result is nil
// when no row matches.
func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) {
	where, args := []string{"1 = 1"}, []any{}
	if find != nil && find.ID != nil {
		// "?" matches the placeholder style of every other statement in this
		// store, all of which run against the same s.db.
		where, args = append(where, "id = ?"), append(args, *find.ID)
	}
	query := `
		SELECT
			id,
			name,
			type,
			identifier_filter,
			config
		FROM idp
		WHERE ` + strings.Join(where, " AND ") + ` ORDER BY id ASC`
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var identityProviders []*IdentityProvider
	for rows.Next() {
		// Declared inside the loop so each appended pointer is distinct.
		var identityProvider IdentityProvider
		var identityProviderConfig string
		if err := rows.Scan(
			&identityProvider.ID,
			&identityProvider.Name,
			&identityProvider.Type,
			&identityProvider.IdentifierFilter,
			&identityProviderConfig,
		); err != nil {
			return nil, err
		}
		if identityProvider.Type != IdentityProviderOAuth2Type {
			return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
		}
		oauth2Config := &IdentityProviderOAuth2Config{}
		if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
			return nil, fmt.Errorf("decoding config of idp %d: %w", identityProvider.ID, err)
		}
		identityProvider.Config = &IdentityProviderConfig{OAuth2Config: oauth2Config}
		identityProviders = append(identityProviders, &identityProvider)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Cache only after the full result set was read successfully.
	for _, item := range identityProviders {
		s.idpCache.Store(item.ID, item)
	}
	return identityProviders, nil
}
  // GetIdentityProvider returns the identity provider matching find, or
// (nil, nil) when none matches.
//
// When find.ID is set and the cache holds a *IdentityProvider for that id, the
// cached pointer is returned without querying the database. Otherwise the
// first result of ListIdentityProviders (lowest id) is returned and cached. A
// nil find is treated like an empty filter.
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) {
	if find == nil {
		find = &FindIdentityProvider{}
	}
	if find.ID != nil {
		if cached, ok := s.idpCache.Load(*find.ID); ok {
			// comma-ok: a cache entry of another dynamic type (for example an
			// IdentityProvider value instead of a pointer) falls through to the
			// database instead of panicking.
			if identityProvider, ok := cached.(*IdentityProvider); ok {
				return identityProvider, nil
			}
		}
	}

	list, err := s.ListIdentityProviders(ctx, find)
	if err != nil {
		return nil, err
	}
	if len(list) == 0 {
		return nil, nil
	}
	identityProvider := list[0]
	s.idpCache.Store(identityProvider.ID, identityProvider)
	return identityProvider, nil
}
  // UpdateIdentityProvider applies the non-nil fields of update to the idp row
// whose id is update.ID and returns the row as it reads after the update.
//
// Name and IdentifierFilter are written when set. Config is written when set
// and is encoded according to update.Type; only OAUTH2 is supported, and its
// OAuth2Config must be non-nil. The updated provider is cached under its id.
//
// An error is returned when update carries no field to change, because an
// empty SET clause is not valid SQL.
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) {
	set, args := []string{}, []any{}
	if v := update.Name; v != nil {
		set, args = append(set, "name = ?"), append(args, *v)
	}
	if v := update.IdentifierFilter; v != nil {
		set, args = append(set, "identifier_filter = ?"), append(args, *v)
	}
	if v := update.Config; v != nil {
		var configBytes []byte
		switch update.Type {
		case IdentityProviderOAuth2Type:
			if v.OAuth2Config == nil {
				return nil, errors.Errorf("missing oauth2 config for idp type %s", string(update.Type))
			}
			bytes, err := json.Marshal(v.OAuth2Config)
			if err != nil {
				return nil, err
			}
			configBytes = bytes
		default:
			return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
		}
		set, args = append(set, "config = ?"), append(args, string(configBytes))
	}
	if len(set) == 0 {
		return nil, errors.Errorf("no fields to update for idp %d", update.ID)
	}
	args = append(args, update.ID)

	stmt := `
		UPDATE idp
		SET ` + strings.Join(set, ", ") + `
		WHERE id = ?
		RETURNING id, name, type, identifier_filter, config
	`
	var identityProvider IdentityProvider
	var identityProviderConfig string
	if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
		&identityProvider.ID,
		&identityProvider.Name,
		&identityProvider.Type,
		&identityProvider.IdentifierFilter,
		&identityProviderConfig,
	); err != nil {
		return nil, err
	}
	if identityProvider.Type != IdentityProviderOAuth2Type {
		return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
	}
	oauth2Config := &IdentityProviderOAuth2Config{}
	if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
		return nil, err
	}
	identityProvider.Config = &IdentityProviderConfig{OAuth2Config: oauth2Config}

	// Cache a pointer, as CreateIdentityProvider and ListIdentityProviders do;
	// GetIdentityProvider reads cache entries as *IdentityProvider.
	s.idpCache.Store(identityProvider.ID, &identityProvider)
	return &identityProvider, nil
}
  214. func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
  215. where, args := []string{"id = ?"}, []any{delete.ID}
  216. stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
  217. result, err := s.db.ExecContext(ctx, stmt, args...)
  218. if err != nil {
  219. return err
  220. }
  221. if _, err = result.RowsAffected(); err != nil {
  222. return err
  223. }
  224. s.idpCache.Delete(delete.ID)
  225. return nil
  226. }