idp.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. package store
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "strings"
  7. )
  8. type IdentityProviderType string
  9. const (
  10. IdentityProviderOAuth2Type IdentityProviderType = "OAUTH2"
  11. )
  12. func (t IdentityProviderType) String() string {
  13. return string(t)
  14. }
  15. type IdentityProviderConfig struct {
  16. OAuth2Config *IdentityProviderOAuth2Config
  17. }
  18. type IdentityProviderOAuth2Config struct {
  19. ClientID string `json:"clientId"`
  20. ClientSecret string `json:"clientSecret"`
  21. AuthURL string `json:"authUrl"`
  22. TokenURL string `json:"tokenUrl"`
  23. UserInfoURL string `json:"userInfoUrl"`
  24. Scopes []string `json:"scopes"`
  25. FieldMapping *FieldMapping `json:"fieldMapping"`
  26. }
  27. type FieldMapping struct {
  28. Identifier string `json:"identifier"`
  29. DisplayName string `json:"displayName"`
  30. Email string `json:"email"`
  31. }
  32. type IdentityProvider struct {
  33. ID int32
  34. Name string
  35. Type IdentityProviderType
  36. IdentifierFilter string
  37. Config *IdentityProviderConfig
  38. }
  39. type FindIdentityProvider struct {
  40. ID *int32
  41. }
  42. type UpdateIdentityProvider struct {
  43. ID int32
  44. Type IdentityProviderType
  45. Name *string
  46. IdentifierFilter *string
  47. Config *IdentityProviderConfig
  48. }
  49. type DeleteIdentityProvider struct {
  50. ID int32
  51. }
  52. func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) {
  53. var configBytes []byte
  54. if create.Type == IdentityProviderOAuth2Type {
  55. bytes, err := json.Marshal(create.Config.OAuth2Config)
  56. if err != nil {
  57. return nil, err
  58. }
  59. configBytes = bytes
  60. } else {
  61. return nil, fmt.Errorf("unsupported idp type %s", string(create.Type))
  62. }
  63. stmt := `
  64. INSERT INTO idp (
  65. name,
  66. type,
  67. identifier_filter,
  68. config
  69. )
  70. VALUES (?, ?, ?, ?)
  71. RETURNING id
  72. `
  73. if err := s.db.QueryRowContext(
  74. ctx,
  75. stmt,
  76. create.Name,
  77. create.Type,
  78. create.IdentifierFilter,
  79. string(configBytes),
  80. ).Scan(
  81. &create.ID,
  82. ); err != nil {
  83. return nil, err
  84. }
  85. identityProvider := create
  86. s.idpCache.Store(identityProvider.ID, identityProvider)
  87. return identityProvider, nil
  88. }
  89. func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) {
  90. where, args := []string{"1 = 1"}, []any{}
  91. if v := find.ID; v != nil {
  92. where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
  93. }
  94. rows, err := s.db.QueryContext(ctx, `
  95. SELECT
  96. id,
  97. name,
  98. type,
  99. identifier_filter,
  100. config
  101. FROM idp
  102. WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
  103. args...,
  104. )
  105. if err != nil {
  106. return nil, err
  107. }
  108. defer rows.Close()
  109. var identityProviders []*IdentityProvider
  110. for rows.Next() {
  111. var identityProvider IdentityProvider
  112. var identityProviderConfig string
  113. if err := rows.Scan(
  114. &identityProvider.ID,
  115. &identityProvider.Name,
  116. &identityProvider.Type,
  117. &identityProvider.IdentifierFilter,
  118. &identityProviderConfig,
  119. ); err != nil {
  120. return nil, err
  121. }
  122. if identityProvider.Type == IdentityProviderOAuth2Type {
  123. oauth2Config := &IdentityProviderOAuth2Config{}
  124. if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
  125. return nil, err
  126. }
  127. identityProvider.Config = &IdentityProviderConfig{
  128. OAuth2Config: oauth2Config,
  129. }
  130. } else {
  131. return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
  132. }
  133. identityProviders = append(identityProviders, &identityProvider)
  134. }
  135. if err := rows.Err(); err != nil {
  136. return nil, err
  137. }
  138. for _, item := range identityProviders {
  139. s.idpCache.Store(item.ID, item)
  140. }
  141. return identityProviders, nil
  142. }
  143. func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) {
  144. if find.ID != nil {
  145. if cache, ok := s.idpCache.Load(*find.ID); ok {
  146. return cache.(*IdentityProvider), nil
  147. }
  148. }
  149. list, err := s.ListIdentityProviders(ctx, find)
  150. if err != nil {
  151. return nil, err
  152. }
  153. if len(list) == 0 {
  154. return nil, nil
  155. }
  156. identityProvider := list[0]
  157. s.idpCache.Store(identityProvider.ID, identityProvider)
  158. return identityProvider, nil
  159. }
  160. func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) {
  161. set, args := []string{}, []any{}
  162. if v := update.Name; v != nil {
  163. set, args = append(set, "name = ?"), append(args, *v)
  164. }
  165. if v := update.IdentifierFilter; v != nil {
  166. set, args = append(set, "identifier_filter = ?"), append(args, *v)
  167. }
  168. if v := update.Config; v != nil {
  169. var configBytes []byte
  170. if update.Type == IdentityProviderOAuth2Type {
  171. bytes, err := json.Marshal(update.Config.OAuth2Config)
  172. if err != nil {
  173. return nil, err
  174. }
  175. configBytes = bytes
  176. } else {
  177. return nil, fmt.Errorf("unsupported idp type %s", string(update.Type))
  178. }
  179. set, args = append(set, "config = ?"), append(args, string(configBytes))
  180. }
  181. args = append(args, update.ID)
  182. stmt := `
  183. UPDATE idp
  184. SET ` + strings.Join(set, ", ") + `
  185. WHERE id = ?
  186. RETURNING id, name, type, identifier_filter, config
  187. `
  188. var identityProvider IdentityProvider
  189. var identityProviderConfig string
  190. if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
  191. &identityProvider.ID,
  192. &identityProvider.Name,
  193. &identityProvider.Type,
  194. &identityProvider.IdentifierFilter,
  195. &identityProviderConfig,
  196. ); err != nil {
  197. return nil, err
  198. }
  199. if identityProvider.Type == IdentityProviderOAuth2Type {
  200. oauth2Config := &IdentityProviderOAuth2Config{}
  201. if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
  202. return nil, err
  203. }
  204. identityProvider.Config = &IdentityProviderConfig{
  205. OAuth2Config: oauth2Config,
  206. }
  207. } else {
  208. return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
  209. }
  210. s.idpCache.Store(identityProvider.ID, identityProvider)
  211. return &identityProvider, nil
  212. }
  213. func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
  214. where, args := []string{"id = ?"}, []any{delete.ID}
  215. stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
  216. result, err := s.db.ExecContext(ctx, stmt, args...)
  217. if err != nil {
  218. return err
  219. }
  220. if _, err = result.RowsAffected(); err != nil {
  221. return err
  222. }
  223. s.idpCache.Delete(delete.ID)
  224. return nil
  225. }