summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/migrate/migration.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/migrate/migration.go')
-rw-r--r--vendor/github.com/uptrace/bun/migrate/migration.go49
1 files changed, 42 insertions, 7 deletions
diff --git a/vendor/github.com/uptrace/bun/migrate/migration.go b/vendor/github.com/uptrace/bun/migrate/migration.go
index 3f4076d2b..4d60a5858 100644
--- a/vendor/github.com/uptrace/bun/migrate/migration.go
+++ b/vendor/github.com/uptrace/bun/migrate/migration.go
@@ -9,6 +9,7 @@ import (
"io/fs"
"sort"
"strings"
+ "text/template"
"time"
"github.com/uptrace/bun"
@@ -23,8 +24,8 @@ type Migration struct {
GroupID int64
MigratedAt time.Time `bun:",notnull,nullzero,default:current_timestamp"`
- Up MigrationFunc `bun:"-"`
- Down MigrationFunc `bun:"-"`
+ Up internalMigrationFunc `bun:"-"`
+ Down internalMigrationFunc `bun:"-"`
}
func (m Migration) String() string {
@@ -35,23 +36,57 @@ func (m Migration) IsApplied() bool {
return m.ID > 0
}
+type internalMigrationFunc func(ctx context.Context, db *bun.DB, templateData any) error
+
type MigrationFunc func(ctx context.Context, db *bun.DB) error
-func NewSQLMigrationFunc(fsys fs.FS, name string) MigrationFunc {
- return func(ctx context.Context, db *bun.DB) error {
+func NewSQLMigrationFunc(fsys fs.FS, name string) internalMigrationFunc {
+ return func(ctx context.Context, db *bun.DB, templateData any) error {
f, err := fsys.Open(name)
if err != nil {
return err
}
isTx := strings.HasSuffix(name, ".tx.up.sql") || strings.HasSuffix(name, ".tx.down.sql")
- return Exec(ctx, db, f, isTx)
+ return Exec(ctx, db, f, templateData, isTx)
+ }
+}
+
+func wrapMigrationFunc(fn MigrationFunc) internalMigrationFunc {
+ return func(ctx context.Context, db *bun.DB, templateData any) error {
+ return fn(ctx, db)
+ }
+}
+
+func renderTemplate(contents []byte, templateData any) (*bytes.Buffer, error) {
+ tmpl, err := template.New("migration").Parse(string(contents))
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse template: %w", err)
+ }
+
+ var rendered bytes.Buffer
+ if err := tmpl.Execute(&rendered, templateData); err != nil {
+ return nil, fmt.Errorf("failed to execute template: %w", err)
}
+
+ return &rendered, nil
}
// Exec reads and executes the SQL migration in the f.
-func Exec(ctx context.Context, db *bun.DB, f io.Reader, isTx bool) error {
- scanner := bufio.NewScanner(f)
+func Exec(ctx context.Context, db *bun.DB, f io.Reader, templateData any, isTx bool) error {
+ contents, err := io.ReadAll(f)
+ if err != nil {
+ return err
+ }
+ var reader io.Reader = bytes.NewReader(contents)
+ if templateData != nil {
+ buf, err := renderTemplate(contents, templateData)
+ if err != nil {
+ return err
+ }
+ reader = buf
+ }
+ scanner := bufio.NewScanner(reader)
var queries []string
var query []byte