idp.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package postgres
  2. import (
  3. "context"
  4. "encoding/json"
  5. "strings"
  6. "github.com/pkg/errors"
  7. "github.com/usememos/memos/store"
  8. )
  9. func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
  10. var configBytes []byte
  11. if create.Type == store.IdentityProviderOAuth2Type {
  12. bytes, err := json.Marshal(create.Config.OAuth2Config)
  13. if err != nil {
  14. return nil, err
  15. }
  16. configBytes = bytes
  17. } else {
  18. return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
  19. }
  20. fields := []string{"name", "type", "identifier_filter", "config"}
  21. args := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)}
  22. stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
  23. if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
  24. return nil, err
  25. }
  26. identityProvider := create
  27. return identityProvider, nil
  28. }
  29. func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
  30. where, args := []string{"1 = 1"}, []any{}
  31. if v := find.ID; v != nil {
  32. where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
  33. }
  34. rows, err := d.db.QueryContext(ctx, `
  35. SELECT
  36. id,
  37. name,
  38. type,
  39. identifier_filter,
  40. config
  41. FROM idp
  42. WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
  43. args...,
  44. )
  45. if err != nil {
  46. return nil, err
  47. }
  48. defer rows.Close()
  49. var identityProviders []*store.IdentityProvider
  50. for rows.Next() {
  51. var identityProvider store.IdentityProvider
  52. var identityProviderConfig string
  53. if err := rows.Scan(
  54. &identityProvider.ID,
  55. &identityProvider.Name,
  56. &identityProvider.Type,
  57. &identityProvider.IdentifierFilter,
  58. &identityProviderConfig,
  59. ); err != nil {
  60. return nil, err
  61. }
  62. if identityProvider.Type == store.IdentityProviderOAuth2Type {
  63. oauth2Config := &store.IdentityProviderOAuth2Config{}
  64. if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
  65. return nil, err
  66. }
  67. identityProvider.Config = &store.IdentityProviderConfig{
  68. OAuth2Config: oauth2Config,
  69. }
  70. } else {
  71. return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
  72. }
  73. identityProviders = append(identityProviders, &identityProvider)
  74. }
  75. if err := rows.Err(); err != nil {
  76. return nil, err
  77. }
  78. return identityProviders, nil
  79. }
  80. func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) {
  81. list, err := d.ListIdentityProviders(ctx, find)
  82. if err != nil {
  83. return nil, err
  84. }
  85. if len(list) == 0 {
  86. return nil, nil
  87. }
  88. return list[0], nil
  89. }
  90. func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
  91. set, args := []string{}, []any{}
  92. if v := update.Name; v != nil {
  93. set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *v)
  94. }
  95. if v := update.IdentifierFilter; v != nil {
  96. set, args = append(set, "identifier_filter = "+placeholder(len(args)+1)), append(args, *v)
  97. }
  98. if v := update.Config; v != nil {
  99. var configBytes []byte
  100. if update.Type == store.IdentityProviderOAuth2Type {
  101. bytes, err := json.Marshal(update.Config.OAuth2Config)
  102. if err != nil {
  103. return nil, err
  104. }
  105. configBytes = bytes
  106. } else {
  107. return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
  108. }
  109. set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, string(configBytes))
  110. }
  111. stmt := `
  112. UPDATE idp
  113. SET ` + strings.Join(set, ", ") + `
  114. WHERE id = ` + placeholder(len(args)+1) + `
  115. RETURNING id, name, type, identifier_filter, config
  116. `
  117. args = append(args, update.ID)
  118. var identityProvider store.IdentityProvider
  119. var identityProviderConfig string
  120. if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
  121. &identityProvider.ID,
  122. &identityProvider.Name,
  123. &identityProvider.Type,
  124. &identityProvider.IdentifierFilter,
  125. &identityProviderConfig,
  126. ); err != nil {
  127. return nil, err
  128. }
  129. if identityProvider.Type == store.IdentityProviderOAuth2Type {
  130. oauth2Config := &store.IdentityProviderOAuth2Config{}
  131. if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
  132. return nil, err
  133. }
  134. identityProvider.Config = &store.IdentityProviderConfig{
  135. OAuth2Config: oauth2Config,
  136. }
  137. } else {
  138. return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
  139. }
  140. return &identityProvider, nil
  141. }
  142. func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
  143. where, args := []string{"id = $1"}, []any{delete.ID}
  144. stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
  145. result, err := d.db.ExecContext(ctx, stmt, args...)
  146. if err != nil {
  147. return err
  148. }
  149. if _, err = result.RowsAffected(); err != nil {
  150. return err
  151. }
  152. return nil
  153. }