migrator.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package postgres
  2. import (
  3. "context"
  4. "embed"
  5. "fmt"
  6. "io/fs"
  7. "regexp"
  8. "sort"
  9. "strings"
  10. "github.com/pkg/errors"
  11. "github.com/usememos/memos/server/version"
  12. "github.com/usememos/memos/store"
  13. )
// migrationFS is the embedded "migration" directory. Within this file it is
// read for migration/dev/LATEST__SCHEMA.sql, migration/prod/LATEST__SCHEMA.sql
// and the per-minor-version scripts matched by migration/prod/<minor>/*.sql.
// The go:embed directive must remain immediately above the var declaration.
//
//go:embed migration
var migrationFS embed.FS
  16. const (
  17. latestSchemaFileName = "LATEST__SCHEMA.sql"
  18. )
  19. func (d *DB) Migrate(ctx context.Context) error {
  20. if d.profile.IsDev() {
  21. return d.nonProdMigrate(ctx)
  22. }
  23. return d.prodMigrate(ctx)
  24. }
  25. func (d *DB) nonProdMigrate(ctx context.Context) error {
  26. rows, err := d.db.QueryContext(ctx, "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema';")
  27. if err != nil {
  28. return errors.Errorf("failed to query database tables: %s", err)
  29. }
  30. if rows.Err() != nil {
  31. return errors.Errorf("failed to query database tables: %s", err)
  32. }
  33. defer rows.Close()
  34. var tables []string
  35. for rows.Next() {
  36. var table string
  37. err := rows.Scan(&table)
  38. if err != nil {
  39. return errors.Errorf("failed to scan table name: %s", err)
  40. }
  41. tables = append(tables, table)
  42. }
  43. if len(tables) != 0 {
  44. return nil
  45. }
  46. buf, err := migrationFS.ReadFile("migration/dev/" + latestSchemaFileName)
  47. if err != nil {
  48. return errors.Errorf("failed to read latest schema file: %s", err)
  49. }
  50. stmt := string(buf)
  51. if _, err := d.db.ExecContext(ctx, stmt); err != nil {
  52. return errors.Errorf("failed to exec SQL %s: %s", stmt, err)
  53. }
  54. return nil
  55. }
  56. func (d *DB) prodMigrate(ctx context.Context) error {
  57. currentVersion := version.GetCurrentVersion(d.profile.Mode)
  58. migrationHistoryList, err := d.FindMigrationHistoryList(ctx, &store.FindMigrationHistory{})
  59. // If there is no migration history, we should apply the latest schema.
  60. if err != nil || len(migrationHistoryList) == 0 {
  61. buf, err := migrationFS.ReadFile("migration/prod/" + latestSchemaFileName)
  62. if err != nil {
  63. return errors.Errorf("failed to read latest schema file: %s", err)
  64. }
  65. stmt := string(buf)
  66. if _, err := d.db.ExecContext(ctx, stmt); err != nil {
  67. return errors.Errorf("failed to exec SQL %s: %s", stmt, err)
  68. }
  69. if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
  70. Version: currentVersion,
  71. }); err != nil {
  72. return errors.Wrap(err, "failed to upsert migration history")
  73. }
  74. return nil
  75. }
  76. migrationHistoryVersionList := []string{}
  77. for _, migrationHistory := range migrationHistoryList {
  78. migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
  79. }
  80. sort.Sort(version.SortVersion(migrationHistoryVersionList))
  81. latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
  82. if !version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
  83. return nil
  84. }
  85. println("start migrate")
  86. for _, minorVersion := range getMinorVersionList() {
  87. normalizedVersion := minorVersion + ".0"
  88. if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
  89. println("applying migration for", normalizedVersion)
  90. if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
  91. return errors.Wrap(err, "failed to apply minor version migration")
  92. }
  93. }
  94. }
  95. println("end migrate")
  96. return nil
  97. }
  98. func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
  99. filenames, err := fs.Glob(migrationFS, fmt.Sprintf("migration/prod/%s/*.sql", minorVersion))
  100. if err != nil {
  101. return errors.Wrap(err, "failed to read ddl files")
  102. }
  103. sort.Strings(filenames)
  104. // Loop over all migration files and execute them in order.
  105. for _, filename := range filenames {
  106. buf, err := migrationFS.ReadFile(filename)
  107. if err != nil {
  108. return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename)
  109. }
  110. for _, stmt := range strings.Split(string(buf), ";") {
  111. if strings.TrimSpace(stmt) == "" {
  112. continue
  113. }
  114. if _, err := d.db.ExecContext(ctx, stmt); err != nil {
  115. return errors.Wrapf(err, "migrate error: %s", stmt)
  116. }
  117. }
  118. }
  119. // Upsert the newest version to migration_history.
  120. version := minorVersion + ".0"
  121. if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{Version: version}); err != nil {
  122. return errors.Wrapf(err, "failed to upsert migration history with version: %s", version)
  123. }
  124. return nil
  125. }
// minorDirRegexp matches the slash-separated path of a minor-version migration
// directory directly under migration/prod, such as "migration/prod/0.18". The
// path must be two dot-separated runs of ASCII digits with nothing after them.
var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
  128. func getMinorVersionList() []string {
  129. minorVersionList := []string{}
  130. if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
  131. if err != nil {
  132. return err
  133. }
  134. if file.IsDir() && minorDirRegexp.MatchString(path) {
  135. minorVersionList = append(minorVersionList, file.Name())
  136. }
  137. return nil
  138. }); err != nil {
  139. panic(err)
  140. }
  141. sort.Sort(version.SortVersion(minorVersionList))
  142. return minorVersionList
  143. }