summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go')
-rw-r--r--vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go39
1 files changed, 39 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go
index d20f8c069..4a2cc8864 100644
--- a/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go
@@ -1,6 +1,7 @@
package pgdialect
import (
+ "context"
"fmt"
"strings"
@@ -57,6 +58,11 @@ func (m *migrator) AppendSQL(b []byte, operation interface{}) (_ []byte, err err
case *migrate.DropUniqueConstraintOp:
b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName), change.Unique.Name)
case *migrate.ChangeColumnTypeOp:
+ // If column changes to SERIAL, create sequence first.
+ // https://gist.github.com/oleglomako/185df689706c5499612a0d54d3ffe856
+ if !change.From.GetIsAutoIncrement() && change.To.GetIsAutoIncrement() {
+ change.To, b, err = m.createDefaultSequence(fmter, b, change)
+ }
b, err = m.changeColumnType(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.AddForeignKeyOp:
b, err = m.addForeignKey(fmter, appendAlterTable(b, change.TableName()), change)
@@ -187,6 +193,39 @@ func (m *migrator) addForeignKey(fmter schema.Formatter, b []byte, add *migrate.
return b, nil
}
+// createDefaultSequence creates a SEQUENCE to back a serial column.
+// Having a backing sequence is necessary to change column type to SERIAL.
+// The updated Column's default is set to "nextval" of the new sequence.
+func (m *migrator) createDefaultSequence(_ schema.Formatter, b []byte, op *migrate.ChangeColumnTypeOp) (_ sqlschema.Column, _ []byte, err error) {
+ var last int
+ if err = m.db.NewSelect().Table(op.TableName).
+ ColumnExpr("MAX(?)", op.Column).Scan(context.TODO(), &last); err != nil {
+ return nil, b, err
+ }
+ seq := op.TableName + "_" + op.Column + "_seq"
+ fqn := op.TableName + "." + op.Column
+
+ // A sequence that is OWNED BY a table will be dropped
+ // if the table is dropped with CASCADE action.
+ b = append(b, "CREATE SEQUENCE "...)
+ b = append(b, seq...)
+ b = append(b, " START WITH "...)
+ b = append(b, fmt.Sprint(last+1)...) // start with next value
+ b = append(b, " OWNED BY "...)
+ b = append(b, fqn...)
+ b = append(b, ";\n"...)
+
+ return &Column{
+ Name: op.To.GetName(),
+ SQLType: op.To.GetSQLType(),
+ VarcharLen: op.To.GetVarcharLen(),
+ DefaultValue: fmt.Sprintf("nextval('%s'::regclass)", seq),
+ IsNullable: op.To.GetIsNullable(),
+ IsAutoIncrement: op.To.GetIsAutoIncrement(),
+ IsIdentity: op.To.GetIsIdentity(),
+ }, b, nil
+}
+
func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *migrate.ChangeColumnTypeOp) (_ []byte, err error) {
// alterColumn never re-assigns err, so there is no need to check for err != nil after calling it
var i int