1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
|
package migrate
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"github.com/uptrace/bun"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
)
// AutoMigratorOption configures an AutoMigrator. Options are applied by NewAutoMigrator
// in the order they are passed.
type AutoMigratorOption func(am *AutoMigrator)
// WithModel brings the given bun models into the scope of migrations.
// Models passed in separate WithModel options accumulate.
func WithModel(models ...interface{}) AutoMigratorOption {
	register := func(am *AutoMigrator) {
		am.includeModels = append(am.includeModels, models...)
	}
	return register
}
// WithExcludeTable makes the AutoMigrator ignore the listed database tables.
// Use it to keep AutoMigrator from dropping tables that exist in the schema
// but are not used by the application.
//
// Do not exclude tables included via WithModel, as BunModelInspector ignores this setting.
func WithExcludeTable(tables ...string) AutoMigratorOption {
	exclude := func(am *AutoMigrator) {
		am.excludeTables = append(am.excludeTables, tables...)
	}
	return exclude
}
// WithSchemaName replaces the default database schema in which objects are migrated.
func WithSchemaName(schemaName string) AutoMigratorOption {
	setSchema := func(am *AutoMigrator) {
		am.schemaName = schemaName
	}
	return setSchema
}
// WithTableNameAuto overrides the default migrations table name.
// The name is stored on the AutoMigrator and also queued as a WithTableName
// option for the Migrator.
func WithTableNameAuto(table string) AutoMigratorOption {
	return func(am *AutoMigrator) {
		migratorOpt := WithTableName(table)
		am.table = table
		am.migratorOpts = append(am.migratorOpts, migratorOpt)
	}
}
// WithLocksTableNameAuto overrides the default migration locks table name.
// The name is stored on the AutoMigrator and also queued as a WithLocksTableName
// option for the Migrator.
func WithLocksTableNameAuto(table string) AutoMigratorOption {
	return func(am *AutoMigrator) {
		migratorOpt := WithLocksTableName(table)
		am.locksTable = table
		am.migratorOpts = append(am.migratorOpts, migratorOpt)
	}
}
// WithMarkAppliedOnSuccessAuto sets the migrator to only mark migrations as applied/unapplied
// when their up/down is successful. The setting is queued as a WithMarkAppliedOnSuccess
// option for the Migrator.
func WithMarkAppliedOnSuccessAuto(enabled bool) AutoMigratorOption {
	configure := func(am *AutoMigrator) {
		migratorOpt := WithMarkAppliedOnSuccess(enabled)
		am.migratorOpts = append(am.migratorOpts, migratorOpt)
	}
	return configure
}
// WithMigrationsDirectoryAuto overrides the default directory for migration files.
// The directory is queued as a WithMigrationsDirectory option for the Migrations collection.
func WithMigrationsDirectoryAuto(directory string) AutoMigratorOption {
	configure := func(am *AutoMigrator) {
		migrationsOpt := WithMigrationsDirectory(directory)
		am.migrationsOpts = append(am.migrationsOpts, migrationsOpt)
	}
	return configure
}
// AutoMigrator performs automated schema migrations.
//
// It is designed to be a drop-in replacement for some Migrator functionality and supports all existing
// configuration options.
// Similarly to Migrator, it has methods to create SQL migrations, write them to a file, and apply them.
// Unlike Migrator, it detects the differences between the state defined by bun models and the current
// database schema automatically.
//
// Usage:
//  1. Generate migrations and apply them at once with AutoMigrator.Migrate().
//  2. Create up- and down-SQL migration files and apply migrations using Migrator.Migrate().
//
// While both methods produce complete, reversible migrations (with entries in the database
// and SQL migration files), prefer creating migrations and applying them separately for
// any non-trivial cases to ensure AutoMigrator detects expected changes correctly.
//
// Limitations:
//   - AutoMigrator only supports a subset of the possible ALTER TABLE modifications.
//   - Some changes are not automatically reversible. For example, you would need to manually
//     add a CREATE TABLE query to the .down migration file to revert a DROP TABLE migration.
//   - Does not validate most dialect-specific constraints. For example, when changing column
//     data type, make sure the data can be auto-casted to the new type.
//   - Due to how the schema-state diff is calculated, it is not possible to rename a table and
//     modify any of its columns' _data type_ in a single run. This will cause the AutoMigrator
//     to drop and re-create the table under a different name; it is better to apply this change in 2 steps.
//     Renaming a table and renaming its columns at the same time is possible.
//   - Renaming table/column to an existing name, i.e. like this [A->B] [B->C], is not possible due to how
//     AutoMigrator distinguishes "rename" and "unchanged" columns.
//
// Dialect must implement both sqlschema.Inspector and sqlschema.Migrator to be used with AutoMigrator.
type AutoMigrator struct {
// db is the database handle; Migrate passes it to the Migrator it creates.
db *bun.DB
// dbInspector creates the current state for the target database.
dbInspector sqlschema.Inspector
// modelInspector creates the desired state based on the model definitions.
modelInspector sqlschema.Inspector
// dbMigrator executes ALTER TABLE queries.
dbMigrator sqlschema.Migrator
table string // Migrations table (excluded from database inspection)
locksTable string // Migration locks table (excluded from database inspection)
// schemaName is the database schema considered for migration.
schemaName string
// includeModels define the migration scope.
includeModels []interface{}
// excludeTables are excluded from database inspection.
excludeTables []string
// diffOpts are passed to diff when the changeset is computed.
diffOpts []diffOption
// migratorOpts are passed to the Migrator constructor.
migratorOpts []MigratorOption
// migrationsOpts are passed to the Migrations constructor.
migrationsOpts []MigrationsOption
}
// NewAutoMigrator creates an AutoMigrator for db and applies opts to it in order.
//
// The migrations table and the migration locks table (default names unless overridden
// by an option) are always appended to the tables excluded from database inspection.
//
// It returns an error if the database inspector or the schema migrator cannot be created,
// or if the dialect does not implement sqlschema.InspectorDialect.
func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, error) {
	am := &AutoMigrator{
		db:         db,
		table:      defaultTable,
		locksTable: defaultLocksTable,
		schemaName: db.Dialect().DefaultSchema(),
	}

	for _, opt := range opts {
		opt(am)
	}
	// Exclude the bookkeeping tables after the options ran, so that overridden names are honored.
	am.excludeTables = append(am.excludeTables, am.table, am.locksTable)

	dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(am.schemaName), sqlschema.WithExcludeTables(am.excludeTables...))
	if err != nil {
		return nil, err
	}
	am.dbInspector = dbInspector

	// Use the comma-ok form: an unsupported dialect must surface as an error, not as a panic.
	inspectorDialect, ok := db.Dialect().(sqlschema.InspectorDialect)
	if !ok {
		return nil, fmt.Errorf("auto migrator: dialect %T does not implement sqlschema.InspectorDialect", db.Dialect())
	}
	am.diffOpts = append(am.diffOpts, withCompareTypeFunc(inspectorDialect.CompareType))

	dbMigrator, err := sqlschema.NewMigrator(db, am.schemaName)
	if err != nil {
		return nil, err
	}
	am.dbMigrator = dbMigrator

	// The desired state is derived from the models registered via WithModel.
	tables := schema.NewTables(db.Dialect())
	tables.Register(am.includeModels...)
	am.modelInspector = sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(am.schemaName))

	return am, nil
}
// plan inspects the current database state and the desired state described by the models,
// computes the difference between them, and orders the resulting operations by their dependencies.
// Inspection errors are returned as is; a dependency resolution error is wrapped with
// "plan migrations".
func (am *AutoMigrator) plan(ctx context.Context) (*changeset, error) {
	current, err := am.dbInspector.Inspect(ctx)
	if err != nil {
		return nil, err
	}

	desired, err := am.modelInspector.Inspect(ctx)
	if err != nil {
		return nil, err
	}

	changes := diff(current, desired, am.diffOpts...)
	if err = changes.ResolveDependencies(); err != nil {
		return nil, fmt.Errorf("plan migrations: %w", err)
	}
	return changes, nil
}
// Migrate writes required changes to a new migration file and runs the migration.
// This will create an entry in the migrations table, making it possible to revert
// the changes with Migrator.Rollback(). MigrationOptions are passed on to Migrator.Migrate().
//
// If no changes are detected, Migrate returns an empty MigrationGroup and a nil error
// without creating or running a Migrator. All other errors are wrapped with "auto migrate".
func (am *AutoMigrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
	migrations, _, err := am.createSQLMigrations(ctx, false)
	// errors.Is keeps the sentinel recognizable even if it is ever wrapped on the way up.
	if errors.Is(err, errNothingToMigrate) {
		return new(MigrationGroup), nil
	}
	if err != nil {
		return nil, fmt.Errorf("auto migrate: %w", err)
	}

	migrator := NewMigrator(am.db, migrations, am.migratorOpts...)
	if err := migrator.Init(ctx); err != nil {
		return nil, fmt.Errorf("auto migrate: %w", err)
	}

	group, err := migrator.Migrate(ctx, opts...)
	if err != nil {
		return nil, fmt.Errorf("auto migrate: %w", err)
	}
	return group, nil
}
// CreateSQLMigrations writes required changes to new up and down migration files
// and returns them (up first, then down).
// Use migrate.Migrator to apply the generated migrations.
//
// If no changes are detected, it returns nil files and a nil error.
func (am *AutoMigrator) CreateSQLMigrations(ctx context.Context) ([]*MigrationFile, error) {
	_, files, err := am.createSQLMigrations(ctx, false)
	// The sentinel is internal and must not leak to the caller.
	if errors.Is(err, errNothingToMigrate) {
		return nil, nil
	}
	return files, err
}
// CreateTxSQLMigrations writes required changes to new up and down migration files
// (named with the ".tx." infix), making sure they will be executed in a transaction when applied,
// and returns them (up first, then down).
// Use migrate.Migrator to apply the generated migrations.
//
// If no changes are detected, it returns nil files and a nil error.
func (am *AutoMigrator) CreateTxSQLMigrations(ctx context.Context) ([]*MigrationFile, error) {
	_, files, err := am.createSQLMigrations(ctx, true)
	// The sentinel is internal and must not leak to the caller.
	if errors.Is(err, errNothingToMigrate) {
		return nil, nil
	}
	return files, err
}
// errNothingToMigrate is a sentinel error signalling that the database is already in the desired state.
// It is internal to the package: exported methods must translate it into a nil error
// instead of returning it to the user.
var errNothingToMigrate = errors.New("nothing to migrate")
// createSQLMigrations plans the changes needed to bring the database in line with the models
// and writes them to a pair of SQL migration files.
//
// When transactional is true the files are named "<name>.tx.up.sql" / "<name>.tx.down.sql",
// otherwise "<name>.up.sql" / "<name>.down.sql". The down file is produced from the reversed changeset.
//
// It returns the Migrations collection holding the single generated migration and the written
// files (up first, then down). It returns errNothingToMigrate when the plan is empty, and an error
// if planning fails, the migration name cannot be generated, or a file cannot be created.
func (am *AutoMigrator) createSQLMigrations(ctx context.Context, transactional bool) (*Migrations, []*MigrationFile, error) {
	changes, err := am.plan(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("create sql migrations: %w", err)
	}

	if changes.Len() == 0 {
		return nil, nil, errNothingToMigrate
	}

	// The name is derived from the schema name; do not continue with an empty name
	// if it is rejected, as that would produce files called ".up.sql" / ".down.sql".
	name, err := genMigrationName(am.schemaName + "_auto")
	if err != nil {
		return nil, nil, fmt.Errorf("create sql migrations: %w", err)
	}

	migrations := NewMigrations(am.migrationsOpts...)
	migrations.Add(Migration{
		Name:    name,
		Up:      changes.Up(am.dbMigrator),
		Down:    changes.Down(am.dbMigrator),
		Comment: "Changes detected by bun.AutoMigrator",
	})

	// Append .tx.up.sql or .up.sql to the migration name, depending on whether it should be transactional.
	infix := "."
	if transactional {
		infix = ".tx."
	}
	fname := func(direction string) string {
		return name + infix + direction + ".sql"
	}

	up, err := am.createSQL(ctx, migrations, fname("up"), changes, transactional)
	if err != nil {
		return nil, nil, fmt.Errorf("create sql migration up: %w", err)
	}

	down, err := am.createSQL(ctx, migrations, fname("down"), changes.GetReverse(), transactional)
	if err != nil {
		return nil, nil, fmt.Errorf("create sql migration down: %w", err)
	}
	return migrations, []*MigrationFile{up, down}, nil
}
// createSQL renders changes as SQL and writes the result to the file fname inside the
// migrations directory, returning a MigrationFile describing what was written.
// For transactional migrations the content is prefixed with "SET statement_timeout = 0;".
// The context argument is unused.
func (am *AutoMigrator) createSQL(_ context.Context, migrations *Migrations, fname string, changes *changeset, transactional bool) (*MigrationFile, error) {
	var sql bytes.Buffer

	if transactional {
		sql.WriteString("SET statement_timeout = 0;")
	}

	if err := changes.WriteTo(&sql, am.dbMigrator); err != nil {
		return nil, err
	}

	data := sql.Bytes()
	target := filepath.Join(migrations.getDirectory(), fname)
	if err := os.WriteFile(target, data, 0o644); err != nil {
		return nil, err
	}

	return &MigrationFile{
		Name:    fname,
		Path:    target,
		Content: string(data),
	}, nil
}
// Len reports the number of operations in the changeset.
func (c *changeset) Len() int {
	count := len(c.operations)
	return count
}
// Func creates a MigrationFunc that applies all operations in the changeset,
// using m to generate the SQL for each of them.
func (c *changeset) Func(m sqlschema.Migrator) MigrationFunc {
	var run MigrationFunc = func(ctx context.Context, db *bun.DB) error {
		return c.apply(ctx, db, m)
	}
	return run
}
// GetReverse builds a new changeset that holds the reverse of every operation,
// listed in the opposite order. The receiver is left unchanged.
func (c *changeset) GetReverse() *changeset {
	reversed := new(changeset)
	for n := len(c.operations); n > 0; n-- {
		op := c.operations[n-1]
		reversed.Add(op.GetReverse())
	}
	return reversed
}
// Up is syntactic sugar for Func: it returns a MigrationFunc that applies
// the operations of the changeset.
func (c *changeset) Up(m sqlschema.Migrator) MigrationFunc {
	up := c.Func(m)
	return up
}
// Down is syntactic sugar: it returns a MigrationFunc that applies the reversed changeset.
// The reversed changeset is built when Down is called, not when the returned function runs.
func (c *changeset) Down(m sqlschema.Migrator) MigrationFunc {
	reversed := c.GetReverse()
	return reversed.Func(m)
}
// apply generates SQL for each operation with m and executes it on db, one statement at a time.
// Comment operations are skipped. It stops at the first failure and returns the error
// wrapped with "apply changes".
func (c *changeset) apply(ctx context.Context, db *bun.DB, m sqlschema.Migrator) error {
	for _, op := range c.operations {
		switch op.(type) {
		case *comment:
			// Comments only matter for generated files; there is nothing to execute.
			continue
		}

		sql, err := m.AppendSQL(internal.MakeQueryBytes(), op)
		if err != nil {
			return fmt.Errorf("apply changes: %w", err)
		}

		if _, err := db.ExecContext(ctx, internal.String(sql)); err != nil {
			return fmt.Errorf("apply changes: %w", err)
		}
	}
	return nil
}
// WriteTo renders the changeset as SQL text and writes it to w with a single Write call.
// Comment operations are emitted as "/* ... */" blocks; every other operation is rendered
// with m and terminated with ";\n". Errors are wrapped with "write changeset".
func (c *changeset) WriteTo(w io.Writer, m sqlschema.Migrator) error {
	out := internal.MakeQueryBytes()

	for _, op := range c.operations {
		note, isComment := op.(*comment)
		if !isComment {
			var err error
			if out, err = m.AppendSQL(out, op); err != nil {
				return fmt.Errorf("write changeset: %w", err)
			}
			out = append(out, ";\n"...)
			continue
		}

		out = append(out, "/*\n"...)
		out = append(out, *note...)
		out = append(out, "\n*/"...)
	}

	if _, err := w.Write(out); err != nil {
		return fmt.Errorf("write changeset: %w", err)
	}
	return nil
}
// ResolveDependencies reorders the operations so that every operation comes before
// the operations that depend on it. An operation declares its dependencies by implementing
// DependsOn(Operation) bool; operations without that method depend on nothing.
//
// The result is deterministic: the same input always produces the same order, and operations
// that are not related by any dependency keep their relative order whenever possible.
//
// It returns an error if a circular dependency is detected, in which case the changeset
// is left unchanged.
func (c *changeset) ResolveDependencies() error {
	if len(c.operations) <= 1 {
		return nil
	}

	const (
		unvisited = iota
		current
		visited
	)

	status := make(map[Operation]int, len(c.operations))
	for _, op := range c.operations {
		status[op] = unvisited
	}

	// finished collects operations in the order their visits complete, i.e. dependents first.
	// It is reversed at the end, which is equivalent to prepending but avoids copying
	// the slice on every step.
	finished := make([]Operation, 0, len(status))

	// visit walks all operations that depend on op (directly or transitively) before
	// marking op itself as finished. Meeting an operation that is still being visited
	// means there is a circular dependency.
	var visit func(op Operation) error
	visit = func(op Operation) error {
		switch status[op] {
		case visited:
			return nil
		case current:
			// TODO: add details (circle) to the error message
			return errors.New("detected circular dependency")
		}

		status[op] = current
		for _, another := range c.operations {
			if another == op {
				continue
			}
			dop, hasDeps := another.(interface {
				DependsOn(Operation) bool
			})
			if !hasDeps || !dop.DependsOn(op) {
				continue
			}
			if err := visit(another); err != nil {
				return err
			}
		}
		status[op] = visited
		finished = append(finished, op)
		return nil
	}

	// Pick the starting points from the slice, not from the status map: map iteration order
	// is random and would make the generated migrations differ between runs.
	// Walking backwards keeps unrelated operations in their original order after the final reversal.
	for i := len(c.operations) - 1; i >= 0; i-- {
		if err := visit(c.operations[i]); err != nil {
			return err
		}
	}

	for i, j := 0, len(finished)-1; i < j; i, j = i+1, j-1 {
		finished[i], finished[j] = finished[j], finished[i]
	}
	c.operations = finished
	return nil
}
|