diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/dialect/pgdialect')
4 files changed, 60 insertions, 11 deletions
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go index d20f8c069..4a2cc8864 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go @@ -1,6 +1,7 @@ package pgdialect import ( + "context" "fmt" "strings" @@ -57,6 +58,11 @@ func (m *migrator) AppendSQL(b []byte, operation interface{}) (_ []byte, err err case *migrate.DropUniqueConstraintOp: b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName), change.Unique.Name) case *migrate.ChangeColumnTypeOp: + // If column changes to SERIAL, create sequence first. + // https://gist.github.com/oleglomako/185df689706c5499612a0d54d3ffe856 + if !change.From.GetIsAutoIncrement() && change.To.GetIsAutoIncrement() { + change.To, b, err = m.createDefaultSequence(fmter, b, change) + } b, err = m.changeColumnType(fmter, appendAlterTable(b, change.TableName), change) case *migrate.AddForeignKeyOp: b, err = m.addForeignKey(fmter, appendAlterTable(b, change.TableName()), change) @@ -187,6 +193,39 @@ func (m *migrator) addForeignKey(fmter schema.Formatter, b []byte, add *migrate. return b, nil } +// createDefaultSequence creates a SEQUENCE to back a serial column. +// Having a backing sequence is necessary to change column type to SERIAL. +// The updated Column's default is set to "nextval" of the new sequence. +func (m *migrator) createDefaultSequence(_ schema.Formatter, b []byte, op *migrate.ChangeColumnTypeOp) (_ sqlschema.Column, _ []byte, err error) { + var last int + if err = m.db.NewSelect().Table(op.TableName). + ColumnExpr("MAX(?)", op.Column).Scan(context.TODO(), &last); err != nil { + return nil, b, err + } + seq := op.TableName + "_" + op.Column + "_seq" + fqn := op.TableName + "." + op.Column + + // A sequence that is OWNED BY a table will be dropped + // if the table is dropped with CASCADE action. + b = append(b, "CREATE SEQUENCE "...) + b = append(b, seq...) + b = append(b, " START WITH "...) + b = append(b, fmt.Sprint(last+1)...) // start with next value + b = append(b, " OWNED BY "...) + b = append(b, fqn...) + b = append(b, ";\n"...) + + return &Column{ + Name: op.To.GetName(), + SQLType: op.To.GetSQLType(), + VarcharLen: op.To.GetVarcharLen(), + DefaultValue: fmt.Sprintf("nextval('%s'::regclass)", seq), + IsNullable: op.To.GetIsNullable(), + IsAutoIncrement: op.To.GetIsAutoIncrement(), + IsIdentity: op.To.GetIsIdentity(), + }, b, nil +} + func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *migrate.ChangeColumnTypeOp) (_ []byte, err error) { // alterColumn never re-assigns err, so there is no need to check for err != nil after calling it var i int diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go index 040df439c..ea5269ac2 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go @@ -5,7 +5,6 @@ import ( "strings" "github.com/uptrace/bun" - "github.com/uptrace/bun/internal/ordered" "github.com/uptrace/bun/migrate/sqlschema" ) @@ -34,13 +33,12 @@ func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector { func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) { dbSchema := Schema{ - Tables: ordered.NewMap[string, sqlschema.Table](), ForeignKeys: make(map[sqlschema.ForeignKey]string), } exclude := in.ExcludeTables if len(exclude) == 0 { - // Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice. + // Avoid getting NOT LIKE ALL (ARRAY[NULL]) if bun.In() is called with an empty slice. exclude = []string{""} } @@ -61,7 +59,7 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) { return dbSchema, err } - colDefs := ordered.NewMap[string, sqlschema.Column]() + var colDefs []sqlschema.Column uniqueGroups := make(map[string][]string) for _, c := range columns { @@ -72,7 +70,7 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) { def = strings.ToLower(def) } - colDefs.Store(c.Name, &Column{ + colDefs = append(colDefs, &Column{ Name: c.Name, SQLType: c.DataType, VarcharLen: c.VarcharLen, @@ -103,7 +101,7 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) { } } - dbSchema.Tables.Store(table.Name, &Table{ + dbSchema.Tables = append(dbSchema.Tables, &Table{ Schema: table.Schema, Name: table.Name, Columns: colDefs, @@ -113,10 +111,14 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) { } for _, fk := range fks { - dbSchema.ForeignKeys[sqlschema.ForeignKey{ + dbFK := sqlschema.ForeignKey{ From: sqlschema.NewColumnReference(fk.SourceTable, fk.SourceColumns...), To: sqlschema.NewColumnReference(fk.TargetTable, fk.TargetColumns...), - }] = fk.ConstraintName + } + if _, exclude := in.ExcludeForeignKeys[dbFK]; exclude { + continue + } + dbSchema.ForeignKeys[dbFK] = fk.ConstraintName } return dbSchema, nil } @@ -185,7 +187,7 @@ FROM information_schema.tables "t" WHERE table_type = 'BASE TABLE' AND "t".table_schema = ? AND "t".table_schema NOT LIKE 'pg_%' - AND "table_name" NOT IN (?) + AND "table_name" NOT LIKE ALL (ARRAY[?]) ORDER BY "t".table_schema, "t".table_name ` @@ -291,7 +293,8 @@ WHERE co.contype = 'f' AND co.conrelid IN (SELECT oid FROM pg_class WHERE relkind = 'r') AND ARRAY_POSITION(co.conkey, sc.attnum) = ARRAY_POSITION(co.confkey, tc.attnum) AND ss.nspname = ? - AND s.relname NOT IN (?) AND "t".relname NOT IN (?) + AND s.relname NOT LIKE ALL (ARRAY[?]) + AND "t".relname NOT LIKE ALL (ARRAY[?]) GROUP BY "constraint_name", "schema_name", "table_name", target_schema, target_table ` ) diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go index 121a3d691..5f35a29ec 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go @@ -127,6 +127,9 @@ var ( char = newAliases(pgTypeChar, pgTypeCharacter) varchar = newAliases(pgTypeVarchar, pgTypeCharacterVarying) timestampTz = newAliases(sqltype.Timestamp, pgTypeTimestampTz, pgTypeTimestampWithTz) + bigint = newAliases(sqltype.BigInt, pgTypeBigSerial) + integer = newAliases(sqltype.Integer, pgTypeSerial) + smallint = newAliases(sqltype.SmallInt, pgTypeSmallSerial) ) func (d *Dialect) CompareType(col1, col2 sqlschema.Column) bool { @@ -143,6 +146,10 @@ func (d *Dialect) CompareType(col1, col2 sqlschema.Column) bool { return checkVarcharLen(col1, col2, d.DefaultVarcharLen()) case timestampTz.IsAlias(typ1) && timestampTz.IsAlias(typ2): return true + case bigint.IsAlias(typ1) && bigint.IsAlias(typ2), + integer.IsAlias(typ1) && integer.IsAlias(typ2), + smallint.IsAlias(typ1) && smallint.IsAlias(typ2): + return true } return false } diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go index d646f564f..c774ccc50 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go @@ -2,5 +2,5 @@ package pgdialect // Version is the current release version. func Version() string { - return "1.2.11" + return "1.2.14" } |
