db.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "embed"
  6. "errors"
  7. "fmt"
  8. "io/fs"
  9. "os"
  10. "regexp"
  11. "sort"
  12. "time"
  13. "github.com/usememos/memos/server/profile"
  14. "github.com/usememos/memos/server/version"
  15. )
// migrationFS holds the embedded migration/ tree: the LATEST__SCHEMA.sql file
// for each schema mode (migration/prod, migration/dev) and the per-minor-version
// prod migration scripts under migration/prod/<major>.<minor>/.
//
//go:embed migration
var migrationFS embed.FS

// seedFS holds the embedded seed/*.sql scripts that Open runs via seed() when
// a fresh database is created in "demo" mode.
//
//go:embed seed
var seedFS embed.FS
// DB owns the SQLite connection pool used by the server, together with the
// profile that configures it (DSN, mode, data directory and version are all
// read by Open). DBInstance is nil until Open succeeds.
type DB struct {
	// sqlite db connection instance
	DBInstance *sql.DB
	// profile supplies the DSN and mode; it is set by NewDB and read by Open.
	profile *profile.Profile
}
  25. // NewDB returns a new instance of DB associated with the given datasource name.
  26. func NewDB(profile *profile.Profile) *DB {
  27. db := &DB{
  28. profile: profile,
  29. }
  30. return db
  31. }
  32. func (db *DB) Open(ctx context.Context) (err error) {
  33. // Ensure a DSN is set before attempting to open the database.
  34. if db.profile.DSN == "" {
  35. return fmt.Errorf("dsn required")
  36. }
  37. // Connect to the database without foreign_key.
  38. sqliteDB, err := sql.Open("sqlite", db.profile.DSN+"?cache=private&_foreign_keys=0&_busy_timeout=10000&_journal_mode=WAL")
  39. if err != nil {
  40. return fmt.Errorf("failed to open db with dsn: %s, err: %w", db.profile.DSN, err)
  41. }
  42. db.DBInstance = sqliteDB
  43. if db.profile.Mode == "prod" {
  44. _, err := os.Stat(db.profile.DSN)
  45. if err != nil {
  46. // If db file not exists, we should create a new one with latest schema.
  47. if errors.Is(err, os.ErrNotExist) {
  48. if err := db.applyLatestSchema(ctx); err != nil {
  49. return fmt.Errorf("failed to apply latest schema, err: %w", err)
  50. }
  51. } else {
  52. return fmt.Errorf("failed to get db file stat, err: %w", err)
  53. }
  54. } else {
  55. // If db file exists, we should check if we need to migrate the database.
  56. currentVersion := version.GetCurrentVersion(db.profile.Mode)
  57. migrationHistoryList, err := db.FindMigrationHistoryList(ctx, &MigrationHistoryFind{})
  58. if err != nil {
  59. return fmt.Errorf("failed to find migration history, err: %w", err)
  60. }
  61. if len(migrationHistoryList) == 0 {
  62. _, err := db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{
  63. Version: currentVersion,
  64. })
  65. if err != nil {
  66. return fmt.Errorf("failed to upsert migration history, err: %w", err)
  67. }
  68. return nil
  69. }
  70. migrationHistoryVersionList := []string{}
  71. for _, migrationHistory := range migrationHistoryList {
  72. migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
  73. }
  74. sort.Sort(version.SortVersion(migrationHistoryVersionList))
  75. latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
  76. if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
  77. minorVersionList := getMinorVersionList()
  78. // backup the raw database file before migration
  79. rawBytes, err := os.ReadFile(db.profile.DSN)
  80. if err != nil {
  81. return fmt.Errorf("failed to read raw database file, err: %w", err)
  82. }
  83. backupDBFilePath := fmt.Sprintf("%s/memos_%s_%d_backup.db", db.profile.Data, db.profile.Version, time.Now().Unix())
  84. if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil {
  85. return fmt.Errorf("failed to write raw database file, err: %w", err)
  86. }
  87. println("succeed to copy a backup database file")
  88. println("start migrate")
  89. for _, minorVersion := range minorVersionList {
  90. normalizedVersion := minorVersion + ".0"
  91. if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
  92. println("applying migration for", normalizedVersion)
  93. if err := db.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
  94. return fmt.Errorf("failed to apply minor version migration: %w", err)
  95. }
  96. }
  97. }
  98. println("end migrate")
  99. // remove the created backup db file after migrate succeed
  100. if err := os.Remove(backupDBFilePath); err != nil {
  101. println(fmt.Sprintf("Failed to remove temp database file, err %v", err))
  102. }
  103. }
  104. }
  105. } else {
  106. // In non-prod mode, we should always migrate the database.
  107. if _, err := os.Stat(db.profile.DSN); errors.Is(err, os.ErrNotExist) {
  108. if err := db.applyLatestSchema(ctx); err != nil {
  109. return fmt.Errorf("failed to apply latest schema: %w", err)
  110. }
  111. // In demo mode, we should seed the database.
  112. if db.profile.Mode == "demo" {
  113. if err := db.seed(ctx); err != nil {
  114. return fmt.Errorf("failed to seed: %w", err)
  115. }
  116. }
  117. }
  118. }
  119. return nil
  120. }
  121. const (
  122. latestSchemaFileName = "LATEST__SCHEMA.sql"
  123. )
  124. func (db *DB) applyLatestSchema(ctx context.Context) error {
  125. schemaMode := "dev"
  126. if db.profile.Mode == "prod" {
  127. schemaMode = "prod"
  128. }
  129. latestSchemaPath := fmt.Sprintf("%s/%s/%s", "migration", schemaMode, latestSchemaFileName)
  130. buf, err := migrationFS.ReadFile(latestSchemaPath)
  131. if err != nil {
  132. return fmt.Errorf("failed to read latest schema %q, error %w", latestSchemaPath, err)
  133. }
  134. stmt := string(buf)
  135. if err := db.execute(ctx, stmt); err != nil {
  136. return fmt.Errorf("migrate error: statement:%s err=%w", stmt, err)
  137. }
  138. return nil
  139. }
  140. func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
  141. filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/%s/*.sql", "migration/prod", minorVersion))
  142. if err != nil {
  143. return fmt.Errorf("failed to read ddl files, err: %w", err)
  144. }
  145. sort.Strings(filenames)
  146. migrationStmt := ""
  147. // Loop over all migration files and execute them in order.
  148. for _, filename := range filenames {
  149. buf, err := migrationFS.ReadFile(filename)
  150. if err != nil {
  151. return fmt.Errorf("failed to read minor version migration file, filename=%s err=%w", filename, err)
  152. }
  153. stmt := string(buf)
  154. migrationStmt += stmt
  155. if err := db.execute(ctx, stmt); err != nil {
  156. return fmt.Errorf("migrate error: statement:%s err=%w", stmt, err)
  157. }
  158. }
  159. tx, err := db.DBInstance.Begin()
  160. if err != nil {
  161. return err
  162. }
  163. defer tx.Rollback()
  164. // upsert the newest version to migration_history
  165. version := minorVersion + ".0"
  166. if _, err = upsertMigrationHistory(ctx, tx, &MigrationHistoryUpsert{
  167. Version: version,
  168. }); err != nil {
  169. return fmt.Errorf("failed to upsert migration history with version: %s, err: %w", version, err)
  170. }
  171. return tx.Commit()
  172. }
  173. func (db *DB) seed(ctx context.Context) error {
  174. filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
  175. if err != nil {
  176. return fmt.Errorf("failed to read seed files, err: %w", err)
  177. }
  178. sort.Strings(filenames)
  179. // Loop over all seed files and execute them in order.
  180. for _, filename := range filenames {
  181. buf, err := seedFS.ReadFile(filename)
  182. if err != nil {
  183. return fmt.Errorf("failed to read seed file, filename=%s err=%w", filename, err)
  184. }
  185. stmt := string(buf)
  186. if err := db.execute(ctx, stmt); err != nil {
  187. return fmt.Errorf("seed error: statement:%s err=%w", stmt, err)
  188. }
  189. }
  190. return nil
  191. }
  192. // execute runs a single SQL statement within a transaction.
  193. func (db *DB) execute(ctx context.Context, stmt string) error {
  194. tx, err := db.DBInstance.Begin()
  195. if err != nil {
  196. return err
  197. }
  198. defer tx.Rollback()
  199. if _, err := tx.ExecContext(ctx, stmt); err != nil {
  200. return fmt.Errorf("failed to execute statement, err: %w", err)
  201. }
  202. return tx.Commit()
  203. }
  204. // minorDirRegexp is a regular expression for minor version directory.
  205. var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
  206. func getMinorVersionList() []string {
  207. minorVersionList := []string{}
  208. if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
  209. if err != nil {
  210. return err
  211. }
  212. if file.IsDir() && minorDirRegexp.MatchString(path) {
  213. minorVersionList = append(minorVersionList, file.Name())
  214. }
  215. return nil
  216. }); err != nil {
  217. panic(err)
  218. }
  219. sort.Sort(version.SortVersion(minorVersionList))
  220. return minorVersionList
  221. }