idp.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package mysql
  2. import (
  3. "context"
  4. "strings"
  5. "github.com/pkg/errors"
  6. storepb "github.com/usememos/memos/proto/gen/store"
  7. "github.com/usememos/memos/store"
  8. )
  9. func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
  10. placeholders := []string{"?", "?", "?", "?"}
  11. fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
  12. args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
  13. stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")"
  14. result, err := d.db.ExecContext(ctx, stmt, args...)
  15. if err != nil {
  16. return nil, err
  17. }
  18. id, err := result.LastInsertId()
  19. if err != nil {
  20. return nil, err
  21. }
  22. create.ID = int32(id)
  23. return create, nil
  24. }
  25. func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
  26. where, args := []string{"1 = 1"}, []any{}
  27. if v := find.ID; v != nil {
  28. where, args = append(where, "`id` = ?"), append(args, *v)
  29. }
  30. rows, err := d.db.QueryContext(ctx, "SELECT `id`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC",
  31. args...,
  32. )
  33. if err != nil {
  34. return nil, err
  35. }
  36. defer rows.Close()
  37. var identityProviders []*store.IdentityProvider
  38. for rows.Next() {
  39. var identityProvider store.IdentityProvider
  40. var typeString string
  41. if err := rows.Scan(
  42. &identityProvider.ID,
  43. &identityProvider.Name,
  44. &typeString,
  45. &identityProvider.IdentifierFilter,
  46. &identityProvider.Config,
  47. ); err != nil {
  48. return nil, err
  49. }
  50. identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
  51. identityProviders = append(identityProviders, &identityProvider)
  52. }
  53. if err := rows.Err(); err != nil {
  54. return nil, err
  55. }
  56. return identityProviders, nil
  57. }
  58. func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) {
  59. list, err := d.ListIdentityProviders(ctx, find)
  60. if err != nil {
  61. return nil, err
  62. }
  63. if len(list) == 0 {
  64. return nil, nil
  65. }
  66. identityProvider := list[0]
  67. return identityProvider, nil
  68. }
  69. func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
  70. set, args := []string{}, []any{}
  71. if v := update.Name; v != nil {
  72. set, args = append(set, "`name` = ?"), append(args, *v)
  73. }
  74. if v := update.IdentifierFilter; v != nil {
  75. set, args = append(set, "`identifier_filter` = ?"), append(args, *v)
  76. }
  77. if v := update.Config; v != nil {
  78. set, args = append(set, "`config` = ?"), append(args, *v)
  79. }
  80. args = append(args, update.ID)
  81. stmt := "UPDATE `idp` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
  82. _, err := d.db.ExecContext(ctx, stmt, args...)
  83. if err != nil {
  84. return nil, err
  85. }
  86. identityProvider, err := d.GetIdentityProvider(ctx, &store.FindIdentityProvider{
  87. ID: &update.ID,
  88. })
  89. if err != nil {
  90. return nil, err
  91. }
  92. if identityProvider == nil {
  93. return nil, errors.Errorf("idp %d not found", update.ID)
  94. }
  95. return identityProvider, nil
  96. }
  97. func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
  98. where, args := []string{"`id` = ?"}, []any{delete.ID}
  99. stmt := "DELETE FROM `idp` WHERE " + strings.Join(where, " AND ")
  100. result, err := d.db.ExecContext(ctx, stmt, args...)
  101. if err != nil {
  102. return err
  103. }
  104. if _, err = result.RowsAffected(); err != nil {
  105. return err
  106. }
  107. return nil
  108. }