summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/migrate/migrator.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/migrate/migrator.go')
-rw-r--r--vendor/github.com/uptrace/bun/migrate/migrator.go35
1 files changed, 30 insertions, 5 deletions
diff --git a/vendor/github.com/uptrace/bun/migrate/migrator.go b/vendor/github.com/uptrace/bun/migrate/migrator.go
index ddf5485c0..52290b370 100644
--- a/vendor/github.com/uptrace/bun/migrate/migrator.go
+++ b/vendor/github.com/uptrace/bun/migrate/migrator.go
@@ -267,19 +267,39 @@ func (m *Migrator) CreateGoMigration(
return mf, nil
}
-// CreateSQLMigrations creates an up and down SQL migration files.
+// CreateTxSQLMigration creates transactional up and down SQL migration files.
+func (m *Migrator) CreateTxSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
+ name, err := m.genMigrationName(name)
+ if err != nil {
+ return nil, err
+ }
+
+ up, err := m.createSQL(ctx, name+".up.tx.sql", true)
+ if err != nil {
+ return nil, err
+ }
+
+ down, err := m.createSQL(ctx, name+".down.tx.sql", true)
+ if err != nil {
+ return nil, err
+ }
+
+ return []*MigrationFile{up, down}, nil
+}
+
+// CreateSQLMigrations creates up and down SQL migration files.
func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
name, err := m.genMigrationName(name)
if err != nil {
return nil, err
}
- up, err := m.createSQL(ctx, name+".up.sql")
+ up, err := m.createSQL(ctx, name+".up.sql", false)
if err != nil {
return nil, err
}
- down, err := m.createSQL(ctx, name+".down.sql")
+ down, err := m.createSQL(ctx, name+".down.sql", false)
if err != nil {
return nil, err
}
@@ -287,10 +307,15 @@ func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*Mig
return []*MigrationFile{up, down}, nil
}
-func (m *Migrator) createSQL(ctx context.Context, fname string) (*MigrationFile, error) {
+func (m *Migrator) createSQL(ctx context.Context, fname string, transactional bool) (*MigrationFile, error) {
fpath := filepath.Join(m.migrations.getDirectory(), fname)
- if err := os.WriteFile(fpath, []byte(sqlTemplate), 0o644); err != nil {
+ template := sqlTemplate
+ if transactional {
+ template = transactionalSQLTemplate
+ }
+
+ if err := os.WriteFile(fpath, []byte(template), 0o644); err != nil {
return nil, err
}