idp.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "strings"
  8. )
  9. type IdentityProviderType string
  10. const (
  11. IdentityProviderOAuth2 IdentityProviderType = "OAUTH2"
  12. )
  13. type IdentityProviderConfig struct {
  14. OAuth2Config *IdentityProviderOAuth2Config
  15. }
  16. type IdentityProviderOAuth2Config struct {
  17. ClientID string `json:"clientId"`
  18. ClientSecret string `json:"clientSecret"`
  19. AuthURL string `json:"authUrl"`
  20. TokenURL string `json:"tokenUrl"`
  21. UserInfoURL string `json:"userInfoUrl"`
  22. Scopes []string `json:"scopes"`
  23. FieldMapping *FieldMapping `json:"fieldMapping"`
  24. }
  25. type FieldMapping struct {
  26. Identifier string `json:"identifier"`
  27. DisplayName string `json:"displayName"`
  28. Email string `json:"email"`
  29. }
  30. type IdentityProvider struct {
  31. ID int
  32. Name string
  33. Type IdentityProviderType
  34. IdentifierFilter string
  35. Config *IdentityProviderConfig
  36. }
  37. type FindIdentityProvider struct {
  38. ID *int
  39. }
  40. type UpdateIdentityProvider struct {
  41. ID int
  42. Type IdentityProviderType
  43. Name *string
  44. IdentifierFilter *string
  45. Config *IdentityProviderConfig
  46. }
  47. type DeleteIdentityProvider struct {
  48. ID int
  49. }
  50. func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) {
  51. tx, err := s.db.BeginTx(ctx, nil)
  52. if err != nil {
  53. return nil, err
  54. }
  55. defer tx.Rollback()
  56. var configBytes []byte
  57. if create.Type == IdentityProviderOAuth2 {
  58. configBytes, err = json.Marshal(create.Config.OAuth2Config)
  59. if err != nil {
  60. return nil, err
  61. }
  62. } else {
  63. return nil, fmt.Errorf("unsupported idp type %s", string(create.Type))
  64. }
  65. query := `
  66. INSERT INTO idp (
  67. name,
  68. type,
  69. identifier_filter,
  70. config
  71. )
  72. VALUES (?, ?, ?, ?)
  73. RETURNING id
  74. `
  75. if err := tx.QueryRowContext(
  76. ctx,
  77. query,
  78. create.Name,
  79. create.Type,
  80. create.IdentifierFilter,
  81. string(configBytes),
  82. ).Scan(
  83. &create.ID,
  84. ); err != nil {
  85. return nil, err
  86. }
  87. if err := tx.Commit(); err != nil {
  88. return nil, err
  89. }
  90. identityProvider := create
  91. s.idpCache.Store(identityProvider.ID, identityProvider)
  92. return identityProvider, nil
  93. }
  94. func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) {
  95. tx, err := s.db.BeginTx(ctx, nil)
  96. if err != nil {
  97. return nil, err
  98. }
  99. defer tx.Rollback()
  100. list, err := listIdentityProviders(ctx, tx, find)
  101. if err != nil {
  102. return nil, err
  103. }
  104. for _, item := range list {
  105. s.idpCache.Store(item.ID, item)
  106. }
  107. return list, nil
  108. }
  109. func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) {
  110. if find.ID != nil {
  111. if cache, ok := s.idpCache.Load(*find.ID); ok {
  112. return cache.(*IdentityProvider), nil
  113. }
  114. }
  115. tx, err := s.db.BeginTx(ctx, nil)
  116. if err != nil {
  117. return nil, err
  118. }
  119. defer tx.Rollback()
  120. list, err := listIdentityProviders(ctx, tx, find)
  121. if err != nil {
  122. return nil, err
  123. }
  124. if len(list) == 0 {
  125. return nil, nil
  126. }
  127. identityProvider := list[0]
  128. s.idpCache.Store(identityProvider.ID, identityProvider)
  129. return identityProvider, nil
  130. }
  131. func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) {
  132. tx, err := s.db.BeginTx(ctx, nil)
  133. if err != nil {
  134. return nil, err
  135. }
  136. defer tx.Rollback()
  137. set, args := []string{}, []any{}
  138. if v := update.Name; v != nil {
  139. set, args = append(set, "name = ?"), append(args, *v)
  140. }
  141. if v := update.IdentifierFilter; v != nil {
  142. set, args = append(set, "identifier_filter = ?"), append(args, *v)
  143. }
  144. if v := update.Config; v != nil {
  145. var configBytes []byte
  146. if update.Type == IdentityProviderOAuth2 {
  147. configBytes, err = json.Marshal(update.Config.OAuth2Config)
  148. if err != nil {
  149. return nil, err
  150. }
  151. } else {
  152. return nil, fmt.Errorf("unsupported idp type %s", string(update.Type))
  153. }
  154. set, args = append(set, "config = ?"), append(args, string(configBytes))
  155. }
  156. args = append(args, update.ID)
  157. query := `
  158. UPDATE idp
  159. SET ` + strings.Join(set, ", ") + `
  160. WHERE id = ?
  161. RETURNING id, name, type, identifier_filter, config
  162. `
  163. var identityProvider IdentityProvider
  164. var identityProviderConfig string
  165. if err := tx.QueryRowContext(ctx, query, args...).Scan(
  166. &identityProvider.ID,
  167. &identityProvider.Name,
  168. &identityProvider.Type,
  169. &identityProvider.IdentifierFilter,
  170. &identityProviderConfig,
  171. ); err != nil {
  172. return nil, err
  173. }
  174. if identityProvider.Type == IdentityProviderOAuth2 {
  175. oauth2Config := &IdentityProviderOAuth2Config{}
  176. if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
  177. return nil, err
  178. }
  179. identityProvider.Config = &IdentityProviderConfig{
  180. OAuth2Config: oauth2Config,
  181. }
  182. } else {
  183. return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
  184. }
  185. if err := tx.Commit(); err != nil {
  186. return nil, err
  187. }
  188. s.idpCache.Store(identityProvider.ID, identityProvider)
  189. return &identityProvider, nil
  190. }
  191. func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
  192. tx, err := s.db.BeginTx(ctx, nil)
  193. if err != nil {
  194. return err
  195. }
  196. defer tx.Rollback()
  197. where, args := []string{"id = ?"}, []any{delete.ID}
  198. stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
  199. result, err := tx.ExecContext(ctx, stmt, args...)
  200. if err != nil {
  201. return err
  202. }
  203. if _, err = result.RowsAffected(); err != nil {
  204. return err
  205. }
  206. if err := tx.Commit(); err != nil {
  207. return err
  208. }
  209. s.idpCache.Delete(delete.ID)
  210. return nil
  211. }
  212. func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProvider) ([]*IdentityProvider, error) {
  213. where, args := []string{"TRUE"}, []any{}
  214. if v := find.ID; v != nil {
  215. where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
  216. }
  217. rows, err := tx.QueryContext(ctx, `
  218. SELECT
  219. id,
  220. name,
  221. type,
  222. identifier_filter,
  223. config
  224. FROM idp
  225. WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
  226. args...,
  227. )
  228. if err != nil {
  229. return nil, err
  230. }
  231. defer rows.Close()
  232. var identityProviders []*IdentityProvider
  233. for rows.Next() {
  234. var identityProvider IdentityProvider
  235. var identityProviderConfig string
  236. if err := rows.Scan(
  237. &identityProvider.ID,
  238. &identityProvider.Name,
  239. &identityProvider.Type,
  240. &identityProvider.IdentifierFilter,
  241. &identityProviderConfig,
  242. ); err != nil {
  243. return nil, err
  244. }
  245. if identityProvider.Type == IdentityProviderOAuth2 {
  246. oauth2Config := &IdentityProviderOAuth2Config{}
  247. if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
  248. return nil, err
  249. }
  250. identityProvider.Config = &IdentityProviderConfig{
  251. OAuth2Config: oauth2Config,
  252. }
  253. } else {
  254. return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
  255. }
  256. identityProviders = append(identityProviders, &identityProvider)
  257. }
  258. if err := rows.Err(); err != nil {
  259. return nil, err
  260. }
  261. return identityProviders, nil
  262. }