diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/migrate/migrator.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/migrate/migrator.go | 401 |
1 files changed, 401 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/migrate/migrator.go b/vendor/github.com/uptrace/bun/migrate/migrator.go new file mode 100644 index 000000000..f9b4a51c2 --- /dev/null +++ b/vendor/github.com/uptrace/bun/migrate/migrator.go @@ -0,0 +1,401 @@ +package migrate + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "log" + "path/filepath" + "regexp" + "time" + + "github.com/uptrace/bun" +) + +type MigratorOption func(m *Migrator) + +func WithTableName(table string) MigratorOption { + return func(m *Migrator) { + m.table = table + } +} + +func WithLocksTableName(table string) MigratorOption { + return func(m *Migrator) { + m.locksTable = table + } +} + +type Migrator struct { + db *bun.DB + migrations *Migrations + + ms MigrationSlice + + table string + locksTable string +} + +func NewMigrator(db *bun.DB, migrations *Migrations, opts ...MigratorOption) *Migrator { + m := &Migrator{ + db: db, + migrations: migrations, + + ms: migrations.ms, + + table: "bun_migrations", + locksTable: "bun_migration_locks", + } + for _, opt := range opts { + opt(m) + } + return m +} + +func (m *Migrator) DB() *bun.DB { + return m.db +} + +// MigrationsWithStatus returns migrations with status in ascending order. +func (m *Migrator) MigrationsWithStatus(ctx context.Context) (MigrationSlice, error) { + sorted := m.migrations.Sorted() + + applied, err := m.selectAppliedMigrations(ctx) + if err != nil { + return nil, err + } + + appliedMap := migrationMap(applied) + for i := range sorted { + m1 := &sorted[i] + if m2, ok := appliedMap[m1.Name]; ok { + m1.ID = m2.ID + m1.GroupID = m2.GroupID + m1.MigratedAt = m2.MigratedAt + } + } + + return sorted, nil +} + +func (m *Migrator) Init(ctx context.Context) error { + if _, err := m.db.NewCreateTable(). + Model((*Migration)(nil)). + ModelTableExpr(m.table). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + if _, err := m.db.NewCreateTable(). + Model((*migrationLock)(nil)). + ModelTableExpr(m.locksTable). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + return nil +} + +func (m *Migrator) Reset(ctx context.Context) error { + if _, err := m.db.NewDropTable(). + Model((*Migration)(nil)). + ModelTableExpr(m.table). + IfExists(). + Exec(ctx); err != nil { + return err + } + if _, err := m.db.NewDropTable(). + Model((*migrationLock)(nil)). + ModelTableExpr(m.locksTable). + IfExists(). + Exec(ctx); err != nil { + return err + } + return m.Init(ctx) +} + +func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { + cfg := newMigrationConfig(opts) + + if err := m.validate(); err != nil { + return nil, err + } + + if err := m.Lock(ctx); err != nil { + return nil, err + } + defer m.Unlock(ctx) //nolint:errcheck + + migrations, err := m.MigrationsWithStatus(ctx) + if err != nil { + return nil, err + } + + group := &MigrationGroup{ + Migrations: migrations.Unapplied(), + } + if len(group.Migrations) == 0 { + return group, nil + } + group.ID = migrations.LastGroupID() + 1 + + for i := range group.Migrations { + migration := &group.Migrations[i] + migration.GroupID = group.ID + + if !cfg.nop && migration.Up != nil { + if err := migration.Up(ctx, m.db); err != nil { + return nil, err + } + } + + if err := m.MarkApplied(ctx, migration); err != nil { + return nil, err + } + } + + return group, nil +} + +func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { + cfg := newMigrationConfig(opts) + + if err := m.validate(); err != nil { + return nil, err + } + + if err := m.Lock(ctx); err != nil { + return nil, err + } + defer m.Unlock(ctx) //nolint:errcheck + + migrations, err := m.MigrationsWithStatus(ctx) + if err != nil { + return nil, err + } + + lastGroup := migrations.LastGroup() + + for i := len(lastGroup.Migrations) - 1; i >= 0; i-- { + migration := &lastGroup.Migrations[i] + + if !cfg.nop && migration.Down != nil { + if err := migration.Down(ctx, m.db); err != nil { + return nil, err + } + } + + if err := m.MarkUnapplied(ctx, migration); err != nil { + return nil, err + } + } + + return lastGroup, nil +} + +type MigrationStatus struct { + Migrations MigrationSlice + NewMigrations MigrationSlice + LastGroup *MigrationGroup +} + +func (m *Migrator) Status(ctx context.Context) (*MigrationStatus, error) { + log.Printf( + "DEPRECATED: bun: replace Status(ctx) with " + + "MigrationsWithStatus(ctx)") + + migrations, err := m.MigrationsWithStatus(ctx) + if err != nil { + return nil, err + } + return &MigrationStatus{ + Migrations: migrations, + NewMigrations: migrations.Unapplied(), + LastGroup: migrations.LastGroup(), + }, nil +} + +func (m *Migrator) MarkCompleted(ctx context.Context) (*MigrationGroup, error) { + log.Printf( + "DEPRECATED: bun: replace MarkCompleted(ctx) with " + + "Migrate(ctx, migrate.WithNopMigration())") + + return m.Migrate(ctx, WithNopMigration()) +} + +type goMigrationConfig struct { + packageName string +} + +type GoMigrationOption func(cfg *goMigrationConfig) + +func WithPackageName(name string) GoMigrationOption { + return func(cfg *goMigrationConfig) { + cfg.packageName = name + } +} + +// CreateGoMigration creates a Go migration file. +func (m *Migrator) CreateGoMigration( + ctx context.Context, name string, opts ...GoMigrationOption, +) (*MigrationFile, error) { + cfg := &goMigrationConfig{ + packageName: "migrations", + } + for _, opt := range opts { + opt(cfg) + } + + name, err := m.genMigrationName(name) + if err != nil { + return nil, err + } + + fname := name + ".go" + fpath := filepath.Join(m.migrations.getDirectory(), fname) + content := fmt.Sprintf(goTemplate, cfg.packageName) + + if err := ioutil.WriteFile(fpath, []byte(content), 0o644); err != nil { + return nil, err + } + + mf := &MigrationFile{ + Name: fname, + Path: fpath, + Content: content, + } + return mf, nil +} + +// CreateSQLMigrations creates an up and down SQL migration files. +func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) { + name, err := m.genMigrationName(name) + if err != nil { + return nil, err + } + + up, err := m.createSQL(ctx, name+".up.sql") + if err != nil { + return nil, err + } + + down, err := m.createSQL(ctx, name+".down.sql") + if err != nil { + return nil, err + } + + return []*MigrationFile{up, down}, nil +} + +func (m *Migrator) createSQL(ctx context.Context, fname string) (*MigrationFile, error) { + fpath := filepath.Join(m.migrations.getDirectory(), fname) + + if err := ioutil.WriteFile(fpath, []byte(sqlTemplate), 0o644); err != nil { + return nil, err + } + + mf := &MigrationFile{ + Name: fname, + Path: fpath, + Content: goTemplate, + } + return mf, nil +} + +var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`) + +func (m *Migrator) genMigrationName(name string) (string, error) { + const timeFormat = "20060102150405" + + if name == "" { + return "", errors.New("migrate: migration name can't be empty") + } + if !nameRE.MatchString(name) { + return "", fmt.Errorf("migrate: invalid migration name: %q", name) + } + + version := time.Now().UTC().Format(timeFormat) + return fmt.Sprintf("%s_%s", version, name), nil +} + +// MarkApplied marks the migration as applied (applied). +func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error { + _, err := m.db.NewInsert().Model(migration). + ModelTableExpr(m.table). + Exec(ctx) + return err +} + +// MarkUnapplied marks the migration as unapplied (new). +func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) error { + _, err := m.db.NewDelete(). + Model(migration). + ModelTableExpr(m.table). + Where("id = ?", migration.ID). + Exec(ctx) + return err +} + +// selectAppliedMigrations selects applied (applied) migrations in descending order. +func (m *Migrator) selectAppliedMigrations(ctx context.Context) (MigrationSlice, error) { + var ms MigrationSlice + if err := m.db.NewSelect(). + ColumnExpr("*"). + Model(&ms). + ModelTableExpr(m.table). + Scan(ctx); err != nil { + return nil, err + } + return ms, nil +} + +func (m *Migrator) formattedTableName(db *bun.DB) string { + return db.Formatter().FormatQuery(m.table) +} + +func (m *Migrator) validate() error { + if len(m.ms) == 0 { + return errors.New("migrate: there are no any migrations") + } + return nil +} + +//------------------------------------------------------------------------------ + +type migrationLock struct { + ID int64 + TableName string `bun:",unique"` +} + +func (m *Migrator) Lock(ctx context.Context) error { + lock := &migrationLock{ + TableName: m.formattedTableName(m.db), + } + if _, err := m.db.NewInsert(). + Model(lock). + ModelTableExpr(m.locksTable). + Exec(ctx); err != nil { + return fmt.Errorf("migrate: migrations table is already locked (%w)", err) + } + return nil +} + +func (m *Migrator) Unlock(ctx context.Context) error { + tableName := m.formattedTableName(m.db) + _, err := m.db.NewDelete(). + Model((*migrationLock)(nil)). + ModelTableExpr(m.locksTable). + Where("? = ?", bun.Ident("table_name"), tableName). + Exec(ctx) + return err +} + +func migrationMap(ms MigrationSlice) map[string]*Migration { + mp := make(map[string]*Migration) + for i := range ms { + m := &ms[i] + mp[m.Name] = m + } + return mp +} |