diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go | 297 |
1 files changed, 297 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go new file mode 100644 index 000000000..42bbbe84f --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go @@ -0,0 +1,297 @@ +package pgdialect + +import ( + "context" + "strings" + + "github.com/uptrace/bun" + "github.com/uptrace/bun/migrate/sqlschema" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +type ( + Schema = sqlschema.BaseDatabase + Table = sqlschema.BaseTable + Column = sqlschema.BaseColumn +) + +func (d *Dialect) NewInspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector { + return newInspector(db, options...) +} + +type Inspector struct { + sqlschema.InspectorConfig + db *bun.DB +} + +var _ sqlschema.Inspector = (*Inspector)(nil) + +func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector { + i := &Inspector{db: db} + sqlschema.ApplyInspectorOptions(&i.InspectorConfig, options...) + return i +} + +func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) { + dbSchema := Schema{ + Tables: orderedmap.New[string, sqlschema.Table](), + ForeignKeys: make(map[sqlschema.ForeignKey]string), + } + + exclude := in.ExcludeTables + if len(exclude) == 0 { + // Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice. + exclude = []string{""} + } + + var tables []*InformationSchemaTable + if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil { + return dbSchema, err + } + + var fks []*ForeignKey + if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil { + return dbSchema, err + } + dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks)) + + for _, table := range tables { + var columns []*InformationSchemaColumn + if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil { + return dbSchema, err + } + + colDefs := orderedmap.New[string, sqlschema.Column]() + uniqueGroups := make(map[string][]string) + + for _, c := range columns { + def := c.Default + if c.IsSerial || c.IsIdentity { + def = "" + } else if !c.IsDefaultLiteral { + def = strings.ToLower(def) + } + + colDefs.Set(c.Name, &Column{ + Name: c.Name, + SQLType: c.DataType, + VarcharLen: c.VarcharLen, + DefaultValue: def, + IsNullable: c.IsNullable, + IsAutoIncrement: c.IsSerial, + IsIdentity: c.IsIdentity, + }) + + for _, group := range c.UniqueGroups { + uniqueGroups[group] = append(uniqueGroups[group], c.Name) + } + } + + var unique []sqlschema.Unique + for name, columns := range uniqueGroups { + unique = append(unique, sqlschema.Unique{ + Name: name, + Columns: sqlschema.NewColumns(columns...), + }) + } + + var pk *sqlschema.PrimaryKey + if len(table.PrimaryKey.Columns) > 0 { + pk = &sqlschema.PrimaryKey{ + Name: table.PrimaryKey.ConstraintName, + Columns: sqlschema.NewColumns(table.PrimaryKey.Columns...), + } + } + + dbSchema.Tables.Set(table.Name, &Table{ + Schema: table.Schema, + Name: table.Name, + Columns: colDefs, + PrimaryKey: pk, + UniqueConstraints: unique, + }) + } + + for _, fk := range fks { + dbSchema.ForeignKeys[sqlschema.ForeignKey{ + From: sqlschema.NewColumnReference(fk.SourceTable, fk.SourceColumns...), + To: sqlschema.NewColumnReference(fk.TargetTable, fk.TargetColumns...), + }] = fk.ConstraintName + } + return dbSchema, nil +} + +type InformationSchemaTable struct { + Schema string `bun:"table_schema,pk"` + Name string `bun:"table_name,pk"` + PrimaryKey PrimaryKey `bun:"embed:primary_key_"` + + Columns []*InformationSchemaColumn `bun:"rel:has-many,join:table_schema=table_schema,join:table_name=table_name"` +} + +type InformationSchemaColumn struct { + Schema string `bun:"table_schema"` + Table string `bun:"table_name"` + Name string `bun:"column_name"` + DataType string `bun:"data_type"` + VarcharLen int `bun:"varchar_len"` + IsArray bool `bun:"is_array"` + ArrayDims int `bun:"array_dims"` + Default string `bun:"default"` + IsDefaultLiteral bool `bun:"default_is_literal_expr"` + IsIdentity bool `bun:"is_identity"` + IndentityType string `bun:"identity_type"` + IsSerial bool `bun:"is_serial"` + IsNullable bool `bun:"is_nullable"` + UniqueGroups []string `bun:"unique_groups,array"` +} + +type ForeignKey struct { + ConstraintName string `bun:"constraint_name"` + SourceSchema string `bun:"schema_name"` + SourceTable string `bun:"table_name"` + SourceColumns []string `bun:"columns,array"` + TargetSchema string `bun:"target_schema"` + TargetTable string `bun:"target_table"` + TargetColumns []string `bun:"target_columns,array"` +} + +type PrimaryKey struct { + ConstraintName string `bun:"name"` + Columns []string `bun:"columns,array"` +} + +const ( + // sqlInspectTables retrieves all user-defined tables in the selected schema. + // Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results. + sqlInspectTables = ` +SELECT + "t".table_schema, + "t".table_name, + pk.name AS primary_key_name, + pk.columns AS primary_key_columns +FROM information_schema.tables "t" + LEFT JOIN ( + SELECT i.indrelid, "idx".relname AS "name", ARRAY_AGG("a".attname) AS "columns" + FROM pg_index i + JOIN pg_attribute "a" + ON "a".attrelid = i.indrelid + AND "a".attnum = ANY("i".indkey) + AND i.indisprimary + JOIN pg_class "idx" ON i.indexrelid = "idx".oid + GROUP BY 1, 2 + ) pk + ON ("t".table_schema || '.' || "t".table_name)::regclass = pk.indrelid +WHERE table_type = 'BASE TABLE' + AND "t".table_schema = ? + AND "t".table_schema NOT LIKE 'pg_%' + AND "table_name" NOT IN (?) +ORDER BY "t".table_schema, "t".table_name +` + + // sqlInspectColumnsQuery retrieves column definitions for the specified table. + // Unlike sqlInspectTables and sqlInspectSchema, it should be passed to bun.NewRaw + // with additional args for table_schema and table_name. + sqlInspectColumnsQuery = ` +SELECT + "c".table_schema, + "c".table_name, + "c".column_name, + "c".data_type, + "c".character_maximum_length::integer AS varchar_len, + "c".data_type = 'ARRAY' AS is_array, + COALESCE("c".array_dims, 0) AS array_dims, + CASE + WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$') + ELSE "c".column_default + END AS "default", + "c".column_default ~ '^''.*''::.*$' OR "c".column_default ~ '^[0-9\.]+$' AS default_is_literal_expr, + "c".is_identity = 'YES' AS is_identity, + "c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial, + COALESCE("c".identity_type, '') AS identity_type, + "c".is_nullable = 'YES' AS is_nullable, + "c"."unique_groups" AS unique_groups +FROM ( + SELECT + "table_schema", + "table_name", + "column_name", + "c".data_type, + "c".character_maximum_length, + "c".column_default, + "c".is_identity, + "c".is_nullable, + att.array_dims, + att.identity_type, + att."unique_groups", + att."constraint_type" + FROM information_schema.columns "c" + LEFT JOIN ( + SELECT + s.nspname AS "table_schema", + "t".relname AS "table_name", + "c".attname AS "column_name", + "c".attndims AS array_dims, + "c".attidentity AS identity_type, + ARRAY_AGG(con.conname) FILTER (WHERE con.contype = 'u') AS "unique_groups", + ARRAY_AGG(con.contype) AS "constraint_type" + FROM ( + SELECT + conname, + contype, + connamespace, + conrelid, + conrelid AS attrelid, + UNNEST(conkey) AS attnum + FROM pg_constraint + ) con + LEFT JOIN pg_attribute "c" USING (attrelid, attnum) + LEFT JOIN pg_namespace s ON s.oid = con.connamespace + LEFT JOIN pg_class "t" ON "t".oid = con.conrelid + GROUP BY 1, 2, 3, 4, 5 + ) att USING ("table_schema", "table_name", "column_name") + ) "c" +WHERE "table_schema" = ? AND "table_name" = ? +ORDER BY "table_schema", "table_name", "column_name" +` + + // sqlInspectForeignKeys get FK definitions for user-defined tables. + // Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results. + sqlInspectForeignKeys = ` +WITH + "schemas" AS ( + SELECT oid, nspname + FROM pg_namespace + ), + "tables" AS ( + SELECT oid, relnamespace, relname, relkind + FROM pg_class + ), + "columns" AS ( + SELECT attrelid, attname, attnum + FROM pg_attribute + WHERE attisdropped = false + ) +SELECT DISTINCT + co.conname AS "constraint_name", + ss.nspname AS schema_name, + s.relname AS "table_name", + ARRAY_AGG(sc.attname) AS "columns", + ts.nspname AS target_schema, + "t".relname AS target_table, + ARRAY_AGG(tc.attname) AS target_columns +FROM pg_constraint co + LEFT JOIN "tables" s ON s.oid = co.conrelid + LEFT JOIN "schemas" ss ON ss.oid = s.relnamespace + LEFT JOIN "columns" sc ON sc.attrelid = s.oid AND sc.attnum = ANY(co.conkey) + LEFT JOIN "tables" t ON t.oid = co.confrelid + LEFT JOIN "schemas" ts ON ts.oid = "t".relnamespace + LEFT JOIN "columns" tc ON tc.attrelid = "t".oid AND tc.attnum = ANY(co.confkey) +WHERE co.contype = 'f' + AND co.conrelid IN (SELECT oid FROM pg_class WHERE relkind = 'r') + AND ARRAY_POSITION(co.conkey, sc.attnum) = ARRAY_POSITION(co.confkey, tc.attnum) + AND ss.nspname = ? + AND s.relname NOT IN (?) AND "t".relname NOT IN (?) +GROUP BY "constraint_name", "schema_name", "table_name", target_schema, target_table +` +) |