1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
|
package migrate
import (
"github.com/uptrace/bun/migrate/sqlschema"
)
// changeset is an ordered list of changes to the database schema definition.
// Operations are kept in the order they were added via Add; no dependency
// sorting happens here, so callers must order them before applying.
type changeset struct {
// operations holds every Operation passed to Add, in insertion order.
operations []Operation
}
// Add appends the given operations to the changeset, after any operations
// recorded earlier. Calling Add with no arguments leaves the changeset unchanged.
func (c *changeset) Add(op ...Operation) {
	if len(op) == 0 {
		return // nothing to record
	}
	c.operations = append(c.operations, op...)
}
// diff compares the current database schema (got) with the desired schema (want)
// and returns the operations that would bring got in line with want.
// The returned changeset is unordered: the caller must resolve dependencies
// between operations before applying them.
func diff(got, want sqlschema.Database, opts ...diffOption) *changeset {
	return newDetector(got, want, opts...).detectChanges()
}
// detectChanges walks the target tables (in FromOldest order) and records the
// operations needed to turn the current schema into the target schema:
//   - a table present on both sides under the same name is compared column by
//     column (including type changes) and constraint by constraint;
//   - a target table without a same-named counterpart is matched against current
//     tables that are absent from the target; the first one that canRename
//     accepts is recorded as a rename and removed from the current tables;
//   - any other target table is created from its bun model.
//
// Current tables left without a target counterpart are dropped. Foreign keys are
// compared last, after refMap has absorbed every table/column rename, so renames
// alone do not produce foreign key operations.
func (d *detector) detectChanges() *changeset {
	current := d.current.GetTables()
	target := d.target.GetTables()

	for name, want := range target.FromOldest() {
		// A table with this name exists in the database. We assume that schema objects
		// are never renamed to an already existing name (such cases are unsupported),
		// so only the table definition itself needs checking.
		if have, found := current.Get(name); found {
			d.detectColumnChanges(have, want, true)
			d.detectConstraintChanges(have, want)
			continue
		}

		// Look for a renamed table. Renamed tables are assumed to keep their signature.
		renamed := false
		for oldName, have := range current.FromOldest() {
			// A table that is still wanted under its own name cannot be a rename source.
			if _, stillWanted := target.Get(oldName); stillWanted {
				continue
			}
			if !d.canRename(have, want) {
				continue
			}
			d.changes.Add(&RenameTableOp{
				TableName: have.GetName(),
				NewName:   name,
			})
			d.refMap.RenameTable(have.GetName(), name)

			// Detect renamed columns and PK/UNIQUE changes; column types are not
			// re-checked because canRename already required equivalent signatures.
			d.detectColumnChanges(have, want, false)
			d.detectConstraintChanges(have, want)
			current.Delete(oldName)
			renamed = true
			break
		}
		if renamed {
			continue
		}

		// The table neither exists in the database nor was renamed: create it.
		bunTable := want.(*sqlschema.BunTable)
		d.changes.Add(&CreateTableOp{
			TableName: want.GetName(),
			Model:     bunTable.Model,
		})
	}

	// Drop current tables that have no model and were not consumed by a rename.
	for name, table := range current.FromOldest() {
		if _, keep := target.Get(name); keep {
			continue
		}
		d.changes.Add(&DropTableOp{
			TableName: table.GetName(),
		})
	}

	wantFKs := d.target.GetForeignKeys()
	haveFKs := d.refMap.Deref()

	for fk := range wantFKs {
		if _, exists := haveFKs[fk]; exists {
			continue
		}
		d.changes.Add(&AddForeignKeyOp{
			ForeignKey:     fk,
			ConstraintName: "", // leave empty to let each dialect apply their convention
		})
	}

	for fk, constraint := range haveFKs {
		if _, exists := wantFKs[fk]; exists {
			continue
		}
		d.changes.Add(&DropForeignKeyOp{
			ConstraintName: constraint,
			ForeignKey:     fk,
		})
	}

	return &d.changes
}
// detectColumnChanges finds renamed, added and dropped columns between the current
// and target definitions of one table and records the corresponding operations.
// If checkType is true, it also reports same-named columns whose definition changed.
//
// A target column missing from current is paired with the first current column (in
// FromOldest order) that is absent from the target and has an equivalent definition;
// that pair is recorded as a rename. The matched current column is deleted from
// current's column collection. If current has a primary key, its column list is
// updated too, so the rename does not look like a primary key change later.
func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkType bool) {
	currentColumns := current.GetColumns()
	targetColumns := target.GetColumns()

ChangeRename:
	for tName, tCol := range targetColumns.FromOldest() {
		// This column exists in the database, so it hasn't been renamed, dropped, or added.
		// It must stay in currentColumns: the rename search below relies on it to avoid
		// renaming a column to a name that already exists.
		if cCol, ok := currentColumns.Get(tName); ok {
			if checkType && !d.equalColumns(cCol, tCol) {
				d.changes.Add(&ChangeColumnTypeOp{
					TableName: target.GetName(),
					Column:    tName,
					From:      cCol,
					To:        d.makeTargetColDef(cCol, tCol),
				})
			}
			continue
		}

		// Column tName does not exist in the database -- it's been either renamed or added.
		// Find renamed columns first.
		for cName, cCol := range currentColumns.FromOldest() {
			// Cannot rename if a column with this name already exists or the definitions differ.
			if _, exists := targetColumns.Get(cName); exists || !d.equalColumns(tCol, cCol) {
				continue
			}
			d.changes.Add(&RenameColumnOp{
				TableName: target.GetName(),
				OldName:   cName,
				NewName:   tName,
			})
			d.refMap.RenameColumn(target.GetName(), cName, tName)
			currentColumns.Delete(cName) // no need to check this column again

			// Update primary key definition to avoid superficially recreating the constraint.
			// GetPrimaryKey returns nil for tables without a primary key, so guard the
			// dereference instead of panicking on such tables.
			if pk := current.GetPrimaryKey(); pk != nil {
				pk.Columns.Replace(cName, tName)
			}
			continue ChangeRename
		}

		d.changes.Add(&AddColumnOp{
			TableName:  target.GetName(),
			ColumnName: tName,
			Column:     tCol,
		})
	}

	// Drop columns which do not exist in the target schema and were not renamed.
	for cName, cCol := range currentColumns.FromOldest() {
		if _, keep := targetColumns.Get(cName); !keep {
			d.changes.Add(&DropColumnOp{
				TableName:  target.GetName(),
				ColumnName: cName,
				Column:     cCol,
			})
		}
	}
}
// detectConstraintChanges compares the UNIQUE constraints and the primary key of the
// current and target definitions of one table and records the add/drop/change
// operations needed. UNIQUE constraints are matched with Equals; primary keys are
// compared by their Columns.
func (d *detector) detectConstraintChanges(current, target sqlschema.Table) {
	// UNIQUE constraints declared by the target but missing from the database.
	for _, want := range target.GetUniqueConstraints() {
		present := false
		for _, got := range current.GetUniqueConstraints() {
			if got.Equals(want) {
				present = true
				break
			}
		}
		if !present {
			d.changes.Add(&AddUniqueConstraintOp{
				TableName: target.GetName(),
				Unique:    want,
			})
		}
	}

	// UNIQUE constraints in the database that the target no longer declares.
	for _, got := range current.GetUniqueConstraints() {
		obsolete := true
		for _, want := range target.GetUniqueConstraints() {
			if got.Equals(want) {
				obsolete = false
				break
			}
		}
		if obsolete {
			d.changes.Add(&DropUniqueConstraintOp{
				TableName: target.GetName(),
				Unique:    got,
			})
		}
	}

	// Detect primary key changes.
	newPK := target.GetPrimaryKey()
	oldPK := current.GetPrimaryKey()
	switch {
	case oldPK == nil && newPK == nil:
		// No primary key on either side: nothing to do.
	case newPK == nil:
		d.changes.Add(&DropPrimaryKeyOp{
			TableName:  target.GetName(),
			PrimaryKey: *oldPK,
		})
	case oldPK == nil:
		d.changes.Add(&AddPrimaryKeyOp{
			TableName:  target.GetName(),
			PrimaryKey: *newPK,
		})
	case oldPK.Columns != newPK.Columns:
		d.changes.Add(&ChangePrimaryKeyOp{
			TableName: target.GetName(),
			Old:       *oldPK,
			New:       *newPK,
		})
	}
}
// newDetector builds a detector that diffs got (the current database schema)
// against want (the target schema). Unless an option such as withCompareTypeFunc
// overrides it, two columns are considered to have the same type when both their
// SQL type and VARCHAR length compare equal with '=='.
func newDetector(got, want sqlschema.Database, opts ...diffOption) *detector {
	var cfg detectorConfig
	cfg.cmpType = func(a, b sqlschema.Column) bool {
		sameType := a.GetSQLType() == b.GetSQLType()
		return sameType && a.GetVarcharLen() == b.GetVarcharLen()
	}
	for _, apply := range opts {
		apply(&cfg)
	}

	d := &detector{
		current: got,
		target:  want,
	}
	// Track the current foreign keys so renames can be applied to them in place.
	d.refMap = newRefMap(got.GetForeignKeys())
	d.cmpType = cfg.cmpType
	return d
}
// diffOption customizes the detectorConfig that newDetector builds before it
// constructs a detector (see withCompareTypeFunc).
type diffOption func(*detectorConfig)
// withCompareTypeFunc returns a diffOption that replaces the detector's default
// column-type equivalence check with f.
func withCompareTypeFunc(f CompareTypeFunc) diffOption {
	setCmpType := func(cfg *detectorConfig) {
		cfg.cmpType = f
	}
	return setCmpType
}
// detectorConfig controls how differences in the model states are resolved.
// newDetector fills it with defaults and then applies every diffOption to it.
type detectorConfig struct {
// cmpType reports whether two columns have equivalent SQL types.
cmpType CompareTypeFunc
}
// detector computes the changeset between a current and a target database schema.
// It may modify the passed database schemas: during rename detection it deletes
// matched entries from the table/column collections it iterates and rewrites
// primary key column names. It therefore isn't safe to re-use them afterwards.
type detector struct {
// current state represents the existing database schema.
current sqlschema.Database
// target state represents the database schema defined in bun models.
target sqlschema.Database
// changes accumulates the operations discovered by detectChanges.
changes changeset
// refMap holds the current schema's foreign keys, edited in place as tables and
// columns are renamed so that renames alone do not show up as FK differences.
refMap refMap
// cmpType determines column type equivalence.
// The default installed by newDetector compares SQL type and VARCHAR length with
// the '==' operator, which is inaccurate due to the existence of dialect-specific
// type aliases. The caller should pass a concrete InspectorDialect.EquivalentType
// for robust comparison.
cmpType CompareTypeFunc
}
// canRename reports whether table t1 could have been renamed to t2: both must live
// in the same schema and have equal column signatures under equalColumns.
func (d detector) canRename(t1, t2 sqlschema.Table) bool {
	if t1.GetSchema() != t2.GetSchema() {
		return false
	}
	return equalSignatures(t1, t2, d.equalColumns)
}
// equalColumns reports whether two column definitions are equivalent, ignoring
// their names: the types must match according to cmpType, and the default value,
// nullability, auto-increment and identity flags must all be identical.
func (d detector) equalColumns(col1, col2 sqlschema.Column) bool {
	if !d.cmpType(col1, col2) {
		return false
	}
	if col1.GetDefaultValue() != col2.GetDefaultValue() {
		return false
	}
	if col1.GetIsNullable() != col2.GetIsNullable() {
		return false
	}
	if col1.GetIsAutoIncrement() != col2.GetIsAutoIncrement() {
		return false
	}
	return col1.GetIsIdentity() == col2.GetIsIdentity()
}
// makeTargetColDef returns the column definition that current should be changed into.
// When cmpType considers the two SQL types equivalent, the result keeps current's SQL
// type and VARCHAR length and takes every other attribute from target. This avoids
// unnecessary type-change migrations between aliases of the same type. Otherwise
// target is returned unchanged.
func (d detector) makeTargetColDef(current, target sqlschema.Column) sqlschema.Column {
	if !d.cmpType(current, target) {
		return target
	}
	return &sqlschema.BaseColumn{
		Name:            target.GetName(),
		DefaultValue:    target.GetDefaultValue(),
		IsNullable:      target.GetIsNullable(),
		IsAutoIncrement: target.GetIsAutoIncrement(),
		IsIdentity:      target.GetIsIdentity(),
		SQLType:         current.GetSQLType(),
		VarcharLen:      current.GetVarcharLen(),
	}
}
// CompareTypeFunc reports whether two columns should be treated as equivalent.
// The detector uses it to compare column SQL types (detector.cmpType). The same
// function type is also used for signature comparisons, where the supplied function
// may check more than the type (canRename passes equalColumns).
type CompareTypeFunc func(sqlschema.Column, sqlschema.Column) bool
// equalSignatures reports whether t1 and t2 have the same "signature", i.e. the same
// multiset of column definitions when columns are compared with eq.
func equalSignatures(t1, t2 sqlschema.Table, eq CompareTypeFunc) bool {
	left, right := newSignature(t1, eq), newSignature(t2, eq)
	return left.Equals(right)
}
// signature is a multiset of column definitions, which allows "relation/name-agnostic"
// comparison between tables: two columns fall into the same bucket whenever eq reports
// them as equivalent (whether names matter is therefore up to eq).
type signature struct {
// underlying stores the number of occurrences for each distinct column definition
// (up to eq-equivalence). It helps to account for the fact that a table might
// have multiple equivalent columns.
underlying map[sqlschema.BaseColumn]int
// eq decides whether two columns belong to the same bucket.
eq CompareTypeFunc
}
// newSignature builds the signature of table t, grouping its columns with eq.
func newSignature(t sqlschema.Table, eq CompareTypeFunc) signature {
	var sig signature
	sig.underlying = make(map[sqlschema.BaseColumn]int)
	sig.eq = eq
	sig.scan(t)
	return sig
}
// scan adds every column of t to the signature. A column joins an existing bucket
// when eq considers it equivalent to that bucket's representative; otherwise it
// becomes the representative of a new bucket with count 1.
// Columns are expected to be *sqlschema.BaseColumn values.
func (s *signature) scan(t sqlschema.Table) {
	for _, column := range t.GetColumns().FromOldest() {
		col := *column.(*sqlschema.BaseColumn)
		// Matching through getCount (instead of indexing the map directly) is slower,
		// but it gives type-equivalence rather than exact equality.
		if key, n := s.getCount(col); n > 0 {
			s.underlying[key] = n + 1
		} else {
			s.underlying[col] = 1
		}
	}
}
// getCount looks for a bucket whose representative is equivalent to keyCol under
// s.eq and returns that representative with its count. When no bucket matches,
// it returns keyCol itself and a count of 0.
func (s *signature) getCount(keyCol sqlschema.BaseColumn) (key sqlschema.BaseColumn, count int) {
	for candidate, n := range s.underlying {
		if !s.eq(&candidate, &keyCol) {
			continue
		}
		return candidate, n
	}
	return keyCol, 0
}
// Equals reports whether s and other have the same number of buckets and every
// bucket of s has an equivalent bucket in other (matched via other's eq) with the
// same number of occurrences.
func (s *signature) Equals(other signature) bool {
	if len(s.underlying) != len(other.underlying) {
		return false
	}
	for col, count := range s.underlying {
		_, otherCount := other.getCount(col)
		if otherCount != count {
			return false
		}
	}
	return true
}
// refMap is a utility for tracking superficial changes in foreign keys,
// which do not require any modification in the database.
// Modern SQL dialects automatically update foreign key constraints whenever
// a column or a table is renamed. Detector can use refMap to ignore any
// differences in foreign keys which were caused by a renamed column/table.
// Keys are pointers so each ForeignKey can be edited in place; values are the
// constraint names.
type refMap map[*sqlschema.ForeignKey]string
// newRefMap builds a refMap from a set of foreign keys and their constraint names.
// Every ForeignKey is copied into its own freshly allocated value, so later
// RenameTable/RenameColumn calls edit those copies and never touch fks.
func newRefMap(fks map[sqlschema.ForeignKey]string) refMap {
	rm := make(refMap, len(fks))
	for fk, name := range fks {
		ref := new(sqlschema.ForeignKey)
		*ref = fk
		rm[ref] = name
	}
	return rm
}
// RenameTable updates the table name in all foreign key definitions which depend on it.
// Both ends of each foreign key are checked independently. A self-referencing
// foreign key (From and To on the same table) therefore gets both ends renamed;
// otherwise it would no longer match its target definition and would be dropped and
// re-created for no reason.
func (rm refMap) RenameTable(tableName string, newName string) {
	for fk := range rm {
		if fk.From.TableName == tableName {
			fk.From.TableName = newName
		}
		if fk.To.TableName == tableName {
			fk.To.TableName = newName
		}
	}
}
// RenameColumn replaces column with newName in every foreign key end (From or To)
// that belongs to tableName. Foreign keys on other tables are left untouched.
func (rm refMap) RenameColumn(tableName string, column, newName string) {
	for fk := range rm {
		if from := &fk.From; from.TableName == tableName {
			from.Column.Replace(column, newName)
		}
		if to := &fk.To; to.TableName == tableName {
			to.Column.Replace(column, newName)
		}
	}
}
// Deref returns a new map keyed by copies of the (possibly renamed) ForeignKey
// values, each mapped to its constraint name.
func (rm refMap) Deref() map[sqlschema.ForeignKey]string {
	values := make(map[sqlschema.ForeignKey]string, len(rm))
	for ptr, constraint := range rm {
		values[*ptr] = constraint
	}
	return values
}
|