idp.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "strings"
  8. )
  9. type IdentityProviderType string
  10. const (
  11. IdentityProviderOAuth2Type IdentityProviderType = "OAUTH2"
  12. )
  13. func (t IdentityProviderType) String() string {
  14. return string(t)
  15. }
  16. type IdentityProviderConfig struct {
  17. OAuth2Config *IdentityProviderOAuth2Config
  18. }
  19. type IdentityProviderOAuth2Config struct {
  20. ClientID string `json:"clientId"`
  21. ClientSecret string `json:"clientSecret"`
  22. AuthURL string `json:"authUrl"`
  23. TokenURL string `json:"tokenUrl"`
  24. UserInfoURL string `json:"userInfoUrl"`
  25. Scopes []string `json:"scopes"`
  26. FieldMapping *FieldMapping `json:"fieldMapping"`
  27. }
  28. type FieldMapping struct {
  29. Identifier string `json:"identifier"`
  30. DisplayName string `json:"displayName"`
  31. Email string `json:"email"`
  32. }
  33. type IdentityProvider struct {
  34. ID int
  35. Name string
  36. Type IdentityProviderType
  37. IdentifierFilter string
  38. Config *IdentityProviderConfig
  39. }
  40. type FindIdentityProvider struct {
  41. ID *int
  42. }
  43. type UpdateIdentityProvider struct {
  44. ID int
  45. Type IdentityProviderType
  46. Name *string
  47. IdentifierFilter *string
  48. Config *IdentityProviderConfig
  49. }
  50. type DeleteIdentityProvider struct {
  51. ID int
  52. }
  53. func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) {
  54. tx, err := s.db.BeginTx(ctx, nil)
  55. if err != nil {
  56. return nil, err
  57. }
  58. defer tx.Rollback()
  59. var configBytes []byte
  60. if create.Type == IdentityProviderOAuth2Type {
  61. configBytes, err = json.Marshal(create.Config.OAuth2Config)
  62. if err != nil {
  63. return nil, err
  64. }
  65. } else {
  66. return nil, fmt.Errorf("unsupported idp type %s", string(create.Type))
  67. }
  68. query := `
  69. INSERT INTO idp (
  70. name,
  71. type,
  72. identifier_filter,
  73. config
  74. )
  75. VALUES (?, ?, ?, ?)
  76. RETURNING id
  77. `
  78. if err := tx.QueryRowContext(
  79. ctx,
  80. query,
  81. create.Name,
  82. create.Type,
  83. create.IdentifierFilter,
  84. string(configBytes),
  85. ).Scan(
  86. &create.ID,
  87. ); err != nil {
  88. return nil, err
  89. }
  90. if err := tx.Commit(); err != nil {
  91. return nil, err
  92. }
  93. identityProvider := create
  94. s.idpCache.Store(identityProvider.ID, identityProvider)
  95. return identityProvider, nil
  96. }
  97. func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) {
  98. tx, err := s.db.BeginTx(ctx, nil)
  99. if err != nil {
  100. return nil, err
  101. }
  102. defer tx.Rollback()
  103. list, err := listIdentityProviders(ctx, tx, find)
  104. if err != nil {
  105. return nil, err
  106. }
  107. if err := tx.Commit(); err != nil {
  108. return nil, err
  109. }
  110. for _, item := range list {
  111. s.idpCache.Store(item.ID, item)
  112. }
  113. return list, nil
  114. }
  115. func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) {
  116. if find.ID != nil {
  117. if cache, ok := s.idpCache.Load(*find.ID); ok {
  118. return cache.(*IdentityProvider), nil
  119. }
  120. }
  121. tx, err := s.db.BeginTx(ctx, nil)
  122. if err != nil {
  123. return nil, err
  124. }
  125. defer tx.Rollback()
  126. list, err := listIdentityProviders(ctx, tx, find)
  127. if err != nil {
  128. return nil, err
  129. }
  130. if len(list) == 0 {
  131. return nil, nil
  132. }
  133. if err := tx.Commit(); err != nil {
  134. return nil, err
  135. }
  136. identityProvider := list[0]
  137. s.idpCache.Store(identityProvider.ID, identityProvider)
  138. return identityProvider, nil
  139. }
  140. func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) {
  141. tx, err := s.db.BeginTx(ctx, nil)
  142. if err != nil {
  143. return nil, err
  144. }
  145. defer tx.Rollback()
  146. set, args := []string{}, []any{}
  147. if v := update.Name; v != nil {
  148. set, args = append(set, "name = ?"), append(args, *v)
  149. }
  150. if v := update.IdentifierFilter; v != nil {
  151. set, args = append(set, "identifier_filter = ?"), append(args, *v)
  152. }
  153. if v := update.Config; v != nil {
  154. var configBytes []byte
  155. if update.Type == IdentityProviderOAuth2Type {
  156. configBytes, err = json.Marshal(update.Config.OAuth2Config)
  157. if err != nil {
  158. return nil, err
  159. }
  160. } else {
  161. return nil, fmt.Errorf("unsupported idp type %s", string(update.Type))
  162. }
  163. set, args = append(set, "config = ?"), append(args, string(configBytes))
  164. }
  165. args = append(args, update.ID)
  166. query := `
  167. UPDATE idp
  168. SET ` + strings.Join(set, ", ") + `
  169. WHERE id = ?
  170. RETURNING id, name, type, identifier_filter, config
  171. `
  172. var identityProvider IdentityProvider
  173. var identityProviderConfig string
  174. if err := tx.QueryRowContext(ctx, query, args...).Scan(
  175. &identityProvider.ID,
  176. &identityProvider.Name,
  177. &identityProvider.Type,
  178. &identityProvider.IdentifierFilter,
  179. &identityProviderConfig,
  180. ); err != nil {
  181. return nil, err
  182. }
  183. if identityProvider.Type == IdentityProviderOAuth2Type {
  184. oauth2Config := &IdentityProviderOAuth2Config{}
  185. if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
  186. return nil, err
  187. }
  188. identityProvider.Config = &IdentityProviderConfig{
  189. OAuth2Config: oauth2Config,
  190. }
  191. } else {
  192. return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
  193. }
  194. if err := tx.Commit(); err != nil {
  195. return nil, err
  196. }
  197. s.idpCache.Store(identityProvider.ID, identityProvider)
  198. return &identityProvider, nil
  199. }
  200. func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
  201. tx, err := s.db.BeginTx(ctx, nil)
  202. if err != nil {
  203. return err
  204. }
  205. defer tx.Rollback()
  206. where, args := []string{"id = ?"}, []any{delete.ID}
  207. stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
  208. result, err := tx.ExecContext(ctx, stmt, args...)
  209. if err != nil {
  210. return err
  211. }
  212. if _, err = result.RowsAffected(); err != nil {
  213. return err
  214. }
  215. if err := tx.Commit(); err != nil {
  216. return err
  217. }
  218. s.idpCache.Delete(delete.ID)
  219. return nil
  220. }
  221. func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProvider) ([]*IdentityProvider, error) {
  222. where, args := []string{"1 = 1"}, []any{}
  223. if v := find.ID; v != nil {
  224. where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
  225. }
  226. rows, err := tx.QueryContext(ctx, `
  227. SELECT
  228. id,
  229. name,
  230. type,
  231. identifier_filter,
  232. config
  233. FROM idp
  234. WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
  235. args...,
  236. )
  237. if err != nil {
  238. return nil, err
  239. }
  240. defer rows.Close()
  241. var identityProviders []*IdentityProvider
  242. for rows.Next() {
  243. var identityProvider IdentityProvider
  244. var identityProviderConfig string
  245. if err := rows.Scan(
  246. &identityProvider.ID,
  247. &identityProvider.Name,
  248. &identityProvider.Type,
  249. &identityProvider.IdentifierFilter,
  250. &identityProviderConfig,
  251. ); err != nil {
  252. return nil, err
  253. }
  254. if identityProvider.Type == IdentityProviderOAuth2Type {
  255. oauth2Config := &IdentityProviderOAuth2Config{}
  256. if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
  257. return nil, err
  258. }
  259. identityProvider.Config = &IdentityProviderConfig{
  260. OAuth2Config: oauth2Config,
  261. }
  262. } else {
  263. return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
  264. }
  265. identityProviders = append(identityProviders, &identityProvider)
  266. }
  267. if err := rows.Err(); err != nil {
  268. return nil, err
  269. }
  270. return identityProviders, nil
  271. }