diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/dialect/pgdialect')
6 files changed, 660 insertions, 12 deletions
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go new file mode 100644 index 000000000..d20f8c069 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/alter_table.go @@ -0,0 +1,245 @@ +package pgdialect + +import ( + "fmt" + "strings" + + "github.com/uptrace/bun" + "github.com/uptrace/bun/migrate" + "github.com/uptrace/bun/migrate/sqlschema" + "github.com/uptrace/bun/schema" +) + +func (d *Dialect) NewMigrator(db *bun.DB, schemaName string) sqlschema.Migrator { + return &migrator{db: db, schemaName: schemaName, BaseMigrator: sqlschema.NewBaseMigrator(db)} +} + +type migrator struct { + *sqlschema.BaseMigrator + + db *bun.DB + schemaName string +} + +var _ sqlschema.Migrator = (*migrator)(nil) + +func (m *migrator) AppendSQL(b []byte, operation interface{}) (_ []byte, err error) { + fmter := m.db.Formatter() + + // Append ALTER TABLE statement to the enclosed query bytes []byte. + appendAlterTable := func(query []byte, tableName string) []byte { + query = append(query, "ALTER TABLE "...) + query = m.appendFQN(fmter, query, tableName) + return append(query, " "...) + } + + switch change := operation.(type) { + case *migrate.CreateTableOp: + return m.AppendCreateTable(b, change.Model) + case *migrate.DropTableOp: + return m.AppendDropTable(b, m.schemaName, change.TableName) + case *migrate.RenameTableOp: + b, err = m.renameTable(fmter, appendAlterTable(b, change.TableName), change) + case *migrate.RenameColumnOp: + b, err = m.renameColumn(fmter, appendAlterTable(b, change.TableName), change) + case *migrate.AddColumnOp: + b, err = m.addColumn(fmter, appendAlterTable(b, change.TableName), change) + case *migrate.DropColumnOp: + b, err = m.dropColumn(fmter, appendAlterTable(b, change.TableName), change) + case *migrate.AddPrimaryKeyOp: + b, err = m.addPrimaryKey(fmter, appendAlterTable(b, change.TableName), change.PrimaryKey) + case *migrate.ChangePrimaryKeyOp: + b, err = m.changePrimaryKey(fmter, appendAlterTable(b, change.TableName), change) + case *migrate.DropPrimaryKeyOp: + b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName), change.PrimaryKey.Name) + case *migrate.AddUniqueConstraintOp: + b, err = m.addUnique(fmter, appendAlterTable(b, change.TableName), change) + case *migrate.DropUniqueConstraintOp: + b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName), change.Unique.Name) + case *migrate.ChangeColumnTypeOp: + b, err = m.changeColumnType(fmter, appendAlterTable(b, change.TableName), change) + case *migrate.AddForeignKeyOp: + b, err = m.addForeignKey(fmter, appendAlterTable(b, change.TableName()), change) + case *migrate.DropForeignKeyOp: + b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName()), change.ConstraintName) + default: + return nil, fmt.Errorf("append sql: unknown operation %T", change) + } + if err != nil { + return nil, fmt.Errorf("append sql: %w", err) + } + return b, nil +} + +func (m *migrator) appendFQN(fmter schema.Formatter, b []byte, tableName string) []byte { + return fmter.AppendQuery(b, "?.?", bun.Ident(m.schemaName), bun.Ident(tableName)) +} + +func (m *migrator) renameTable(fmter schema.Formatter, b []byte, rename *migrate.RenameTableOp) (_ []byte, err error) { + b = append(b, "RENAME TO "...) + b = fmter.AppendName(b, rename.NewName) + return b, nil +} + +func (m *migrator) renameColumn(fmter schema.Formatter, b []byte, rename *migrate.RenameColumnOp) (_ []byte, err error) { + b = append(b, "RENAME COLUMN "...) + b = fmter.AppendName(b, rename.OldName) + + b = append(b, " TO "...) + b = fmter.AppendName(b, rename.NewName) + + return b, nil +} + +func (m *migrator) addColumn(fmter schema.Formatter, b []byte, add *migrate.AddColumnOp) (_ []byte, err error) { + b = append(b, "ADD COLUMN "...) + b = fmter.AppendName(b, add.ColumnName) + b = append(b, " "...) + + b, err = add.Column.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if add.Column.GetDefaultValue() != "" { + b = append(b, " DEFAULT "...) + b = append(b, add.Column.GetDefaultValue()...) + b = append(b, " "...) + } + + if add.Column.GetIsIdentity() { + b = appendGeneratedAsIdentity(b) + } + + return b, nil +} + +func (m *migrator) dropColumn(fmter schema.Formatter, b []byte, drop *migrate.DropColumnOp) (_ []byte, err error) { + b = append(b, "DROP COLUMN "...) + b = fmter.AppendName(b, drop.ColumnName) + + return b, nil +} + +func (m *migrator) addPrimaryKey(fmter schema.Formatter, b []byte, pk sqlschema.PrimaryKey) (_ []byte, err error) { + b = append(b, "ADD PRIMARY KEY ("...) + b, _ = pk.Columns.AppendQuery(fmter, b) + b = append(b, ")"...) + + return b, nil +} + +func (m *migrator) changePrimaryKey(fmter schema.Formatter, b []byte, change *migrate.ChangePrimaryKeyOp) (_ []byte, err error) { + b, _ = m.dropConstraint(fmter, b, change.Old.Name) + b = append(b, ", "...) + b, _ = m.addPrimaryKey(fmter, b, change.New) + return b, nil +} + +func (m *migrator) addUnique(fmter schema.Formatter, b []byte, change *migrate.AddUniqueConstraintOp) (_ []byte, err error) { + b = append(b, "ADD CONSTRAINT "...) + if change.Unique.Name != "" { + b = fmter.AppendName(b, change.Unique.Name) + } else { + // Default naming scheme for unique constraints in Postgres is <table>_<column>_key + b = fmter.AppendName(b, fmt.Sprintf("%s_%s_key", change.TableName, change.Unique.Columns)) + } + b = append(b, " UNIQUE ("...) + b, _ = change.Unique.Columns.AppendQuery(fmter, b) + b = append(b, ")"...) + + return b, nil +} + +func (m *migrator) dropConstraint(fmter schema.Formatter, b []byte, name string) (_ []byte, err error) { + b = append(b, "DROP CONSTRAINT "...) + b = fmter.AppendName(b, name) + + return b, nil +} + +func (m *migrator) addForeignKey(fmter schema.Formatter, b []byte, add *migrate.AddForeignKeyOp) (_ []byte, err error) { + b = append(b, "ADD CONSTRAINT "...) + + name := add.ConstraintName + if name == "" { + colRef := add.ForeignKey.From + columns := strings.Join(colRef.Column.Split(), "_") + name = fmt.Sprintf("%s_%s_fkey", colRef.TableName, columns) + } + b = fmter.AppendName(b, name) + + b = append(b, " FOREIGN KEY ("...) + if b, err = add.ForeignKey.From.Column.AppendQuery(fmter, b); err != nil { + return b, err + } + b = append(b, ")"...) + + b = append(b, " REFERENCES "...) + b = m.appendFQN(fmter, b, add.ForeignKey.To.TableName) + + b = append(b, " ("...) + if b, err = add.ForeignKey.To.Column.AppendQuery(fmter, b); err != nil { + return b, err + } + b = append(b, ")"...) + + return b, nil +} + +func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *migrate.ChangeColumnTypeOp) (_ []byte, err error) { + // alterColumn never re-assigns err, so there is no need to check for err != nil after calling it + var i int + appendAlterColumn := func() { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, "ALTER COLUMN "...) + b = fmter.AppendName(b, colDef.Column) + i++ + } + + got, want := colDef.From, colDef.To + + inspector := m.db.Dialect().(sqlschema.InspectorDialect) + if !inspector.CompareType(want, got) { + appendAlterColumn() + b = append(b, " SET DATA TYPE "...) + if b, err = want.AppendQuery(fmter, b); err != nil { + return b, err + } + } + + // Column must be declared NOT NULL before identity can be added. + // Although PG can resolve the order of operations itself, we make this explicit in the query. + if want.GetIsNullable() != got.GetIsNullable() { + appendAlterColumn() + if !want.GetIsNullable() { + b = append(b, " SET NOT NULL"...) + } else { + b = append(b, " DROP NOT NULL"...) + } + } + + if want.GetIsIdentity() != got.GetIsIdentity() { + appendAlterColumn() + if !want.GetIsIdentity() { + b = append(b, " DROP IDENTITY"...) + } else { + b = append(b, " ADD"...) + b = appendGeneratedAsIdentity(b) + } + } + + if want.GetDefaultValue() != got.GetDefaultValue() { + appendAlterColumn() + if want.GetDefaultValue() == "" { + b = append(b, " DROP DEFAULT"...) + } else { + b = append(b, " SET DEFAULT "...) + b = append(b, want.GetDefaultValue()...) + } + } + + return b, nil +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go index 46b55659b..b0d9a2331 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "encoding/hex" "fmt" + "math" "reflect" "strconv" "time" @@ -159,7 +160,7 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { case int64: return strconv.AppendInt(b, v, 10) case float64: - return dialect.AppendFloat64(b, v) + return arrayAppendFloat64(b, v) case bool: return dialect.AppendBool(b, v) case []byte: @@ -167,7 +168,10 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { case string: return arrayAppendString(b, v) case time.Time: - return fmter.Dialect().AppendTime(b, v) + b = append(b, '"') + b = appendTime(b, v) + b = append(b, '"') + return b default: err := fmt.Errorf("pgdialect: can't append %T", v) return dialect.AppendError(b, err) @@ -288,7 +292,7 @@ func appendFloat64Slice(b []byte, floats []float64) []byte { b = append(b, '{') for _, n := range floats { - b = dialect.AppendFloat64(b, n) + b = arrayAppendFloat64(b, n) b = append(b, ',') } if len(floats) > 0 { @@ -302,6 +306,19 @@ func appendFloat64Slice(b []byte, floats []float64) []byte { return b } +func arrayAppendFloat64(b []byte, num float64) []byte { + switch { + case math.IsNaN(num): + return append(b, "NaN"...) + case math.IsInf(num, 1): + return append(b, "Infinity"...) + case math.IsInf(num, -1): + return append(b, "-Infinity"...) + default: + return strconv.AppendFloat(b, num, 'f', -1, 64) + } +} + func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { ts := v.Convert(sliceTimeType).Interface().([]time.Time) return appendTimeSlice(fmter, b, ts) @@ -383,6 +400,10 @@ func arrayScanner(typ reflect.Type) schema.ScannerFunc { } } + if src == nil { + return nil + } + b, err := toBytes(src) if err != nil { return err @@ -553,7 +574,7 @@ func scanFloat64SliceValue(dest reflect.Value, src interface{}) error { } func scanFloat64Slice(src interface{}) ([]float64, error) { - if src == -1 { + if src == nil { return nil, nil } @@ -593,7 +614,7 @@ func toBytes(src interface{}) ([]byte, error) { case []byte: return src, nil default: - return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + return nil, fmt.Errorf("pgdialect: got %T, wanted []byte or string", src) } } diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go index 358971f61..040163f98 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go @@ -10,6 +10,7 @@ import ( "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/migrate/sqlschema" "github.com/uptrace/bun/schema" ) @@ -29,6 +30,10 @@ type Dialect struct { features feature.Feature } +var _ schema.Dialect = (*Dialect)(nil) +var _ sqlschema.InspectorDialect = (*Dialect)(nil) +var _ sqlschema.MigratorDialect = (*Dialect)(nil) + func New() *Dialect { d := new(Dialect) d.tables = schema.NewTables(d) @@ -48,7 +53,8 @@ func New() *Dialect { feature.InsertOnConflict | feature.SelectExists | feature.GeneratedIdentity | - feature.CompositeIn + feature.CompositeIn | + feature.DeleteReturning return d } @@ -118,5 +124,10 @@ func (d *Dialect) AppendUint64(b []byte, n uint64) []byte { } func (d *Dialect) AppendSequence(b []byte, _ *schema.Table, _ *schema.Field) []byte { + return appendGeneratedAsIdentity(b) +} + +// appendGeneratedAsIdentity appends GENERATED BY DEFAULT AS IDENTITY to the column definition. +func appendGeneratedAsIdentity(b []byte) []byte { return append(b, " GENERATED BY DEFAULT AS IDENTITY"...) } diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go new file mode 100644 index 000000000..42bbbe84f --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go @@ -0,0 +1,297 @@ +package pgdialect + +import ( + "context" + "strings" + + "github.com/uptrace/bun" + "github.com/uptrace/bun/migrate/sqlschema" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +type ( + Schema = sqlschema.BaseDatabase + Table = sqlschema.BaseTable + Column = sqlschema.BaseColumn +) + +func (d *Dialect) NewInspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector { + return newInspector(db, options...) +} + +type Inspector struct { + sqlschema.InspectorConfig + db *bun.DB +} + +var _ sqlschema.Inspector = (*Inspector)(nil) + +func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector { + i := &Inspector{db: db} + sqlschema.ApplyInspectorOptions(&i.InspectorConfig, options...) + return i +} + +func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) { + dbSchema := Schema{ + Tables: orderedmap.New[string, sqlschema.Table](), + ForeignKeys: make(map[sqlschema.ForeignKey]string), + } + + exclude := in.ExcludeTables + if len(exclude) == 0 { + // Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice. + exclude = []string{""} + } + + var tables []*InformationSchemaTable + if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil { + return dbSchema, err + } + + var fks []*ForeignKey + if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil { + return dbSchema, err + } + dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks)) + + for _, table := range tables { + var columns []*InformationSchemaColumn + if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil { + return dbSchema, err + } + + colDefs := orderedmap.New[string, sqlschema.Column]() + uniqueGroups := make(map[string][]string) + + for _, c := range columns { + def := c.Default + if c.IsSerial || c.IsIdentity { + def = "" + } else if !c.IsDefaultLiteral { + def = strings.ToLower(def) + } + + colDefs.Set(c.Name, &Column{ + Name: c.Name, + SQLType: c.DataType, + VarcharLen: c.VarcharLen, + DefaultValue: def, + IsNullable: c.IsNullable, + IsAutoIncrement: c.IsSerial, + IsIdentity: c.IsIdentity, + }) + + for _, group := range c.UniqueGroups { + uniqueGroups[group] = append(uniqueGroups[group], c.Name) + } + } + + var unique []sqlschema.Unique + for name, columns := range uniqueGroups { + unique = append(unique, sqlschema.Unique{ + Name: name, + Columns: sqlschema.NewColumns(columns...), + }) + } + + var pk *sqlschema.PrimaryKey + if len(table.PrimaryKey.Columns) > 0 { + pk = &sqlschema.PrimaryKey{ + Name: table.PrimaryKey.ConstraintName, + Columns: sqlschema.NewColumns(table.PrimaryKey.Columns...), + } + } + + dbSchema.Tables.Set(table.Name, &Table{ + Schema: table.Schema, + Name: table.Name, + Columns: colDefs, + PrimaryKey: pk, + UniqueConstraints: unique, + }) + } + + for _, fk := range fks { + dbSchema.ForeignKeys[sqlschema.ForeignKey{ + From: sqlschema.NewColumnReference(fk.SourceTable, fk.SourceColumns...), + To: sqlschema.NewColumnReference(fk.TargetTable, fk.TargetColumns...), + }] = fk.ConstraintName + } + return dbSchema, nil +} + +type InformationSchemaTable struct { + Schema string `bun:"table_schema,pk"` + Name string `bun:"table_name,pk"` + PrimaryKey PrimaryKey `bun:"embed:primary_key_"` + + Columns []*InformationSchemaColumn `bun:"rel:has-many,join:table_schema=table_schema,join:table_name=table_name"` +} + +type InformationSchemaColumn struct { + Schema string `bun:"table_schema"` + Table string `bun:"table_name"` + Name string `bun:"column_name"` + DataType string `bun:"data_type"` + VarcharLen int `bun:"varchar_len"` + IsArray bool `bun:"is_array"` + ArrayDims int `bun:"array_dims"` + Default string `bun:"default"` + IsDefaultLiteral bool `bun:"default_is_literal_expr"` + IsIdentity bool `bun:"is_identity"` + IndentityType string `bun:"identity_type"` + IsSerial bool `bun:"is_serial"` + IsNullable bool `bun:"is_nullable"` + UniqueGroups []string `bun:"unique_groups,array"` +} + +type ForeignKey struct { + ConstraintName string `bun:"constraint_name"` + SourceSchema string `bun:"schema_name"` + SourceTable string `bun:"table_name"` + SourceColumns []string `bun:"columns,array"` + TargetSchema string `bun:"target_schema"` + TargetTable string `bun:"target_table"` + TargetColumns []string `bun:"target_columns,array"` +} + +type PrimaryKey struct { + ConstraintName string `bun:"name"` + Columns []string `bun:"columns,array"` +} + +const ( + // sqlInspectTables retrieves all user-defined tables in the selected schema. + // Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results. + sqlInspectTables = ` +SELECT + "t".table_schema, + "t".table_name, + pk.name AS primary_key_name, + pk.columns AS primary_key_columns +FROM information_schema.tables "t" + LEFT JOIN ( + SELECT i.indrelid, "idx".relname AS "name", ARRAY_AGG("a".attname) AS "columns" + FROM pg_index i + JOIN pg_attribute "a" + ON "a".attrelid = i.indrelid + AND "a".attnum = ANY("i".indkey) + AND i.indisprimary + JOIN pg_class "idx" ON i.indexrelid = "idx".oid + GROUP BY 1, 2 + ) pk + ON ("t".table_schema || '.' || "t".table_name)::regclass = pk.indrelid +WHERE table_type = 'BASE TABLE' + AND "t".table_schema = ? + AND "t".table_schema NOT LIKE 'pg_%' + AND "table_name" NOT IN (?) +ORDER BY "t".table_schema, "t".table_name +` + + // sqlInspectColumnsQuery retrieves column definitions for the specified table. + // Unlike sqlInspectTables and sqlInspectSchema, it should be passed to bun.NewRaw + // with additional args for table_schema and table_name. + sqlInspectColumnsQuery = ` +SELECT + "c".table_schema, + "c".table_name, + "c".column_name, + "c".data_type, + "c".character_maximum_length::integer AS varchar_len, + "c".data_type = 'ARRAY' AS is_array, + COALESCE("c".array_dims, 0) AS array_dims, + CASE + WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$') + ELSE "c".column_default + END AS "default", + "c".column_default ~ '^''.*''::.*$' OR "c".column_default ~ '^[0-9\.]+$' AS default_is_literal_expr, + "c".is_identity = 'YES' AS is_identity, + "c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial, + COALESCE("c".identity_type, '') AS identity_type, + "c".is_nullable = 'YES' AS is_nullable, + "c"."unique_groups" AS unique_groups +FROM ( + SELECT + "table_schema", + "table_name", + "column_name", + "c".data_type, + "c".character_maximum_length, + "c".column_default, + "c".is_identity, + "c".is_nullable, + att.array_dims, + att.identity_type, + att."unique_groups", + att."constraint_type" + FROM information_schema.columns "c" + LEFT JOIN ( + SELECT + s.nspname AS "table_schema", + "t".relname AS "table_name", + "c".attname AS "column_name", + "c".attndims AS array_dims, + "c".attidentity AS identity_type, + ARRAY_AGG(con.conname) FILTER (WHERE con.contype = 'u') AS "unique_groups", + ARRAY_AGG(con.contype) AS "constraint_type" + FROM ( + SELECT + conname, + contype, + connamespace, + conrelid, + conrelid AS attrelid, + UNNEST(conkey) AS attnum + FROM pg_constraint + ) con + LEFT JOIN pg_attribute "c" USING (attrelid, attnum) + LEFT JOIN pg_namespace s ON s.oid = con.connamespace + LEFT JOIN pg_class "t" ON "t".oid = con.conrelid + GROUP BY 1, 2, 3, 4, 5 + ) att USING ("table_schema", "table_name", "column_name") + ) "c" +WHERE "table_schema" = ? AND "table_name" = ? +ORDER BY "table_schema", "table_name", "column_name" +` + + // sqlInspectForeignKeys get FK definitions for user-defined tables. + // Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results. + sqlInspectForeignKeys = ` +WITH + "schemas" AS ( + SELECT oid, nspname + FROM pg_namespace + ), + "tables" AS ( + SELECT oid, relnamespace, relname, relkind + FROM pg_class + ), + "columns" AS ( + SELECT attrelid, attname, attnum + FROM pg_attribute + WHERE attisdropped = false + ) +SELECT DISTINCT + co.conname AS "constraint_name", + ss.nspname AS schema_name, + s.relname AS "table_name", + ARRAY_AGG(sc.attname) AS "columns", + ts.nspname AS target_schema, + "t".relname AS target_table, + ARRAY_AGG(tc.attname) AS target_columns +FROM pg_constraint co + LEFT JOIN "tables" s ON s.oid = co.conrelid + LEFT JOIN "schemas" ss ON ss.oid = s.relnamespace + LEFT JOIN "columns" sc ON sc.attrelid = s.oid AND sc.attnum = ANY(co.conkey) + LEFT JOIN "tables" t ON t.oid = co.confrelid + LEFT JOIN "schemas" ts ON ts.oid = "t".relnamespace + LEFT JOIN "columns" tc ON tc.attrelid = "t".oid AND tc.attnum = ANY(co.confkey) +WHERE co.contype = 'f' + AND co.conrelid IN (SELECT oid FROM pg_class WHERE relkind = 'r') + AND ARRAY_POSITION(co.conkey, sc.attnum) = ARRAY_POSITION(co.confkey, tc.attnum) + AND ss.nspname = ? + AND s.relname NOT IN (?) AND "t".relname NOT IN (?) +GROUP BY "constraint_name", "schema_name", "table_name", target_schema, target_table +` +) diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go index fad84209d..bacc00e86 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go @@ -5,18 +5,22 @@ import ( "encoding/json" "net" "reflect" + "strings" "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/migrate/sqlschema" "github.com/uptrace/bun/schema" ) const ( // Date / Time - pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone - pgTypeDate = "DATE" // Date - pgTypeTime = "TIME" // Time without a time zone - pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone - pgTypeInterval = "INTERVAL" // Time Interval + pgTypeTimestamp = "TIMESTAMP" // Timestamp + pgTypeTimestampWithTz = "TIMESTAMP WITH TIME ZONE" // Timestamp with a time zone + pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone (alias) + pgTypeDate = "DATE" // Date + pgTypeTime = "TIME" // Time without a time zone + pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone + pgTypeInterval = "INTERVAL" // Time interval // Network Addresses pgTypeInet = "INET" // IPv4 or IPv6 hosts and networks @@ -28,6 +32,13 @@ const ( pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer + // Character Types + pgTypeChar = "CHAR" // fixed length string (blank padded) + pgTypeCharacter = "CHARACTER" // alias for CHAR + pgTypeText = "TEXT" // variable length string without limit + pgTypeVarchar = "VARCHAR" // variable length string with optional limit + pgTypeCharacterVarying = "CHARACTER VARYING" // alias for VARCHAR + // Binary Data Types pgTypeBytea = "BYTEA" // binary string ) @@ -43,6 +54,10 @@ func (d *Dialect) DefaultVarcharLen() int { return 0 } +func (d *Dialect) DefaultSchema() string { + return "public" +} + func fieldSQLType(field *schema.Field) string { if field.UserSQLType != "" { return field.UserSQLType @@ -103,3 +118,62 @@ func sqlType(typ reflect.Type) string { return sqlType } + +var ( + char = newAliases(pgTypeChar, pgTypeCharacter) + varchar = newAliases(pgTypeVarchar, pgTypeCharacterVarying) + timestampTz = newAliases(sqltype.Timestamp, pgTypeTimestampTz, pgTypeTimestampWithTz) +) + +func (d *Dialect) CompareType(col1, col2 sqlschema.Column) bool { + typ1, typ2 := strings.ToUpper(col1.GetSQLType()), strings.ToUpper(col2.GetSQLType()) + + if typ1 == typ2 { + return checkVarcharLen(col1, col2, d.DefaultVarcharLen()) + } + + switch { + case char.IsAlias(typ1) && char.IsAlias(typ2): + return checkVarcharLen(col1, col2, d.DefaultVarcharLen()) + case varchar.IsAlias(typ1) && varchar.IsAlias(typ2): + return checkVarcharLen(col1, col2, d.DefaultVarcharLen()) + case timestampTz.IsAlias(typ1) && timestampTz.IsAlias(typ2): + return true + } + return false +} + +// checkVarcharLen returns true if columns have the same VarcharLen, or, +// if one specifies no VarcharLen and the other one has the default lenght for pgdialect. +// We assume that the types are otherwise equivalent and that any non-character column +// would have VarcharLen == 0; +func checkVarcharLen(col1, col2 sqlschema.Column, defaultLen int) bool { + vl1, vl2 := col1.GetVarcharLen(), col2.GetVarcharLen() + + if vl1 == vl2 { + return true + } + + if (vl1 == 0 && vl2 == defaultLen) || (vl1 == defaultLen && vl2 == 0) { + return true + } + return false +} + +// typeAlias defines aliases for common data types. It is a lightweight string set implementation. +type typeAlias map[string]struct{} + +// IsAlias checks if typ1 and typ2 are aliases of the same data type. +func (t typeAlias) IsAlias(typ string) bool { + _, ok := t[typ] + return ok +} + +// newAliases creates a set of aliases. +func newAliases(aliases ...string) typeAlias { + types := make(typeAlias) + for _, a := range aliases { + types[a] = struct{}{} + } + return types +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go index c06043647..a4a6a760a 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go @@ -2,5 +2,5 @@ package pgdialect // Version is the current release version. func Version() string { - return "1.2.5" + return "1.2.6" } |