summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/migrate/migrator.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/migrate/migrator.go')
-rw-r--r--vendor/github.com/uptrace/bun/migrate/migrator.go401
1 files changed, 401 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/migrate/migrator.go b/vendor/github.com/uptrace/bun/migrate/migrator.go
new file mode 100644
index 000000000..f9b4a51c2
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/migrate/migrator.go
@@ -0,0 +1,401 @@
+package migrate
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "path/filepath"
+ "regexp"
+ "time"
+
+ "github.com/uptrace/bun"
+)
+
+type MigratorOption func(m *Migrator)
+
+func WithTableName(table string) MigratorOption {
+ return func(m *Migrator) {
+ m.table = table
+ }
+}
+
+func WithLocksTableName(table string) MigratorOption {
+ return func(m *Migrator) {
+ m.locksTable = table
+ }
+}
+
+type Migrator struct {
+ db *bun.DB
+ migrations *Migrations
+
+ ms MigrationSlice
+
+ table string
+ locksTable string
+}
+
+func NewMigrator(db *bun.DB, migrations *Migrations, opts ...MigratorOption) *Migrator {
+ m := &Migrator{
+ db: db,
+ migrations: migrations,
+
+ ms: migrations.ms,
+
+ table: "bun_migrations",
+ locksTable: "bun_migration_locks",
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+func (m *Migrator) DB() *bun.DB {
+ return m.db
+}
+
+// MigrationsWithStatus returns migrations with status in ascending order.
+func (m *Migrator) MigrationsWithStatus(ctx context.Context) (MigrationSlice, error) {
+ sorted := m.migrations.Sorted()
+
+ applied, err := m.selectAppliedMigrations(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ appliedMap := migrationMap(applied)
+ for i := range sorted {
+ m1 := &sorted[i]
+ if m2, ok := appliedMap[m1.Name]; ok {
+ m1.ID = m2.ID
+ m1.GroupID = m2.GroupID
+ m1.MigratedAt = m2.MigratedAt
+ }
+ }
+
+ return sorted, nil
+}
+
+func (m *Migrator) Init(ctx context.Context) error {
+ if _, err := m.db.NewCreateTable().
+ Model((*Migration)(nil)).
+ ModelTableExpr(m.table).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+ if _, err := m.db.NewCreateTable().
+ Model((*migrationLock)(nil)).
+ ModelTableExpr(m.locksTable).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (m *Migrator) Reset(ctx context.Context) error {
+ if _, err := m.db.NewDropTable().
+ Model((*Migration)(nil)).
+ ModelTableExpr(m.table).
+ IfExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+ if _, err := m.db.NewDropTable().
+ Model((*migrationLock)(nil)).
+ ModelTableExpr(m.locksTable).
+ IfExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+ return m.Init(ctx)
+}
+
+func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
+ cfg := newMigrationConfig(opts)
+
+ if err := m.validate(); err != nil {
+ return nil, err
+ }
+
+ if err := m.Lock(ctx); err != nil {
+ return nil, err
+ }
+ defer m.Unlock(ctx) //nolint:errcheck
+
+ migrations, err := m.MigrationsWithStatus(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ group := &MigrationGroup{
+ Migrations: migrations.Unapplied(),
+ }
+ if len(group.Migrations) == 0 {
+ return group, nil
+ }
+ group.ID = migrations.LastGroupID() + 1
+
+ for i := range group.Migrations {
+ migration := &group.Migrations[i]
+ migration.GroupID = group.ID
+
+ if !cfg.nop && migration.Up != nil {
+ if err := migration.Up(ctx, m.db); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := m.MarkApplied(ctx, migration); err != nil {
+ return nil, err
+ }
+ }
+
+ return group, nil
+}
+
+func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
+ cfg := newMigrationConfig(opts)
+
+ if err := m.validate(); err != nil {
+ return nil, err
+ }
+
+ if err := m.Lock(ctx); err != nil {
+ return nil, err
+ }
+ defer m.Unlock(ctx) //nolint:errcheck
+
+ migrations, err := m.MigrationsWithStatus(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ lastGroup := migrations.LastGroup()
+
+ for i := len(lastGroup.Migrations) - 1; i >= 0; i-- {
+ migration := &lastGroup.Migrations[i]
+
+ if !cfg.nop && migration.Down != nil {
+ if err := migration.Down(ctx, m.db); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := m.MarkUnapplied(ctx, migration); err != nil {
+ return nil, err
+ }
+ }
+
+ return lastGroup, nil
+}
+
+type MigrationStatus struct {
+ Migrations MigrationSlice
+ NewMigrations MigrationSlice
+ LastGroup *MigrationGroup
+}
+
+func (m *Migrator) Status(ctx context.Context) (*MigrationStatus, error) {
+ log.Printf(
+ "DEPRECATED: bun: replace Status(ctx) with " +
+ "MigrationsWithStatus(ctx)")
+
+ migrations, err := m.MigrationsWithStatus(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return &MigrationStatus{
+ Migrations: migrations,
+ NewMigrations: migrations.Unapplied(),
+ LastGroup: migrations.LastGroup(),
+ }, nil
+}
+
+func (m *Migrator) MarkCompleted(ctx context.Context) (*MigrationGroup, error) {
+ log.Printf(
+ "DEPRECATED: bun: replace MarkCompleted(ctx) with " +
+ "Migrate(ctx, migrate.WithNopMigration())")
+
+ return m.Migrate(ctx, WithNopMigration())
+}
+
+type goMigrationConfig struct {
+ packageName string
+}
+
+type GoMigrationOption func(cfg *goMigrationConfig)
+
+func WithPackageName(name string) GoMigrationOption {
+ return func(cfg *goMigrationConfig) {
+ cfg.packageName = name
+ }
+}
+
+// CreateGoMigration creates a Go migration file.
+func (m *Migrator) CreateGoMigration(
+ ctx context.Context, name string, opts ...GoMigrationOption,
+) (*MigrationFile, error) {
+ cfg := &goMigrationConfig{
+ packageName: "migrations",
+ }
+ for _, opt := range opts {
+ opt(cfg)
+ }
+
+ name, err := m.genMigrationName(name)
+ if err != nil {
+ return nil, err
+ }
+
+ fname := name + ".go"
+ fpath := filepath.Join(m.migrations.getDirectory(), fname)
+ content := fmt.Sprintf(goTemplate, cfg.packageName)
+
+ if err := ioutil.WriteFile(fpath, []byte(content), 0o644); err != nil {
+ return nil, err
+ }
+
+ mf := &MigrationFile{
+ Name: fname,
+ Path: fpath,
+ Content: content,
+ }
+ return mf, nil
+}
+
+// CreateSQLMigrations creates an up and down SQL migration files.
+func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
+ name, err := m.genMigrationName(name)
+ if err != nil {
+ return nil, err
+ }
+
+ up, err := m.createSQL(ctx, name+".up.sql")
+ if err != nil {
+ return nil, err
+ }
+
+ down, err := m.createSQL(ctx, name+".down.sql")
+ if err != nil {
+ return nil, err
+ }
+
+ return []*MigrationFile{up, down}, nil
+}
+
+func (m *Migrator) createSQL(ctx context.Context, fname string) (*MigrationFile, error) {
+ fpath := filepath.Join(m.migrations.getDirectory(), fname)
+
+ if err := ioutil.WriteFile(fpath, []byte(sqlTemplate), 0o644); err != nil {
+ return nil, err
+ }
+
+ mf := &MigrationFile{
+ Name: fname,
+ Path: fpath,
+ Content: goTemplate,
+ }
+ return mf, nil
+}
+
+var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`)
+
+func (m *Migrator) genMigrationName(name string) (string, error) {
+ const timeFormat = "20060102150405"
+
+ if name == "" {
+ return "", errors.New("migrate: migration name can't be empty")
+ }
+ if !nameRE.MatchString(name) {
+ return "", fmt.Errorf("migrate: invalid migration name: %q", name)
+ }
+
+ version := time.Now().UTC().Format(timeFormat)
+ return fmt.Sprintf("%s_%s", version, name), nil
+}
+
+// MarkApplied marks the migration as applied (applied).
+func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error {
+ _, err := m.db.NewInsert().Model(migration).
+ ModelTableExpr(m.table).
+ Exec(ctx)
+ return err
+}
+
+// MarkUnapplied marks the migration as unapplied (new).
+func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) error {
+ _, err := m.db.NewDelete().
+ Model(migration).
+ ModelTableExpr(m.table).
+ Where("id = ?", migration.ID).
+ Exec(ctx)
+ return err
+}
+
+// selectAppliedMigrations selects applied (applied) migrations in descending order.
+func (m *Migrator) selectAppliedMigrations(ctx context.Context) (MigrationSlice, error) {
+ var ms MigrationSlice
+ if err := m.db.NewSelect().
+ ColumnExpr("*").
+ Model(&ms).
+ ModelTableExpr(m.table).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return ms, nil
+}
+
+func (m *Migrator) formattedTableName(db *bun.DB) string {
+ return db.Formatter().FormatQuery(m.table)
+}
+
+func (m *Migrator) validate() error {
+ if len(m.ms) == 0 {
+ return errors.New("migrate: there are no any migrations")
+ }
+ return nil
+}
+
+//------------------------------------------------------------------------------
+
+type migrationLock struct {
+ ID int64
+ TableName string `bun:",unique"`
+}
+
+func (m *Migrator) Lock(ctx context.Context) error {
+ lock := &migrationLock{
+ TableName: m.formattedTableName(m.db),
+ }
+ if _, err := m.db.NewInsert().
+ Model(lock).
+ ModelTableExpr(m.locksTable).
+ Exec(ctx); err != nil {
+ return fmt.Errorf("migrate: migrations table is already locked (%w)", err)
+ }
+ return nil
+}
+
+func (m *Migrator) Unlock(ctx context.Context) error {
+ tableName := m.formattedTableName(m.db)
+ _, err := m.db.NewDelete().
+ Model((*migrationLock)(nil)).
+ ModelTableExpr(m.locksTable).
+ Where("? = ?", bun.Ident("table_name"), tableName).
+ Exec(ctx)
+ return err
+}
+
+func migrationMap(ms MigrationSlice) map[string]*Migration {
+ mp := make(map[string]*Migration)
+ for i := range ms {
+ m := &ms[i]
+ mp[m.Name] = m
+ }
+ return mp
+}