summaryrefslogtreecommitdiff
path: root/internal/db/bundb/migrations/util.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/migrations/util.go')
-rw-r--r--internal/db/bundb/migrations/util.go198
1 files changed, 198 insertions, 0 deletions
diff --git a/internal/db/bundb/migrations/util.go b/internal/db/bundb/migrations/util.go
index 47de09e23..edf7c1d05 100644
--- a/internal/db/bundb/migrations/util.go
+++ b/internal/db/bundb/migrations/util.go
@@ -19,11 +19,209 @@ package migrations
import (
"context"
+ "errors"
+ "fmt"
+ "reflect"
+ "strconv"
+ "strings"
+ "codeberg.org/gruf/go-byteutil"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/dialect/sqltype"
+ "github.com/uptrace/bun/schema"
)
+// convertEnums performs a transaction that converts
+// a table's column of our old-style enums (strings) to
+// more performant and space-saving integer types.
+func convertEnums[OldType ~string, NewType ~int16](
+ ctx context.Context,
+ tx bun.Tx,
+ table string,
+ column string,
+ mapping map[OldType]NewType,
+ defaultValue *NewType,
+) error {
+ if len(mapping) == 0 {
+ return errors.New("empty mapping")
+ }
+
+ // Generate new column name.
+ newColumn := column + "_new"
+
+ log.Infof(ctx, "converting %s.%s enums; "+
+ "this may take a while, please don't interrupt!",
+ table, column,
+ )
+
+ // Ensure a default value.
+ if defaultValue == nil {
+ var zero NewType
+ defaultValue = &zero
+ }
+
+ // Add new column to database.
+ if _, err := tx.NewAddColumn().
+ Table(table).
+ ColumnExpr("? SMALLINT NOT NULL DEFAULT ?",
+ bun.Ident(newColumn),
+ *defaultValue).
+ Exec(ctx); err != nil {
+ return gtserror.Newf("error adding new column: %w", err)
+ }
+
+ // Get a count of all in table.
+ total, err := tx.NewSelect().
+ Table(table).
+ Count(ctx)
+ if err != nil {
+ return gtserror.Newf("error selecting total count: %w", err)
+ }
+
+ var updated int
+ for old, new := range mapping {
+
+ // Update old to new values.
+ res, err := tx.NewUpdate().
+ Table(table).
+ Where("? = ?", bun.Ident(column), old).
+ Set("? = ?", bun.Ident(newColumn), new).
+ Exec(ctx)
+ if err != nil {
+ return gtserror.Newf("error updating old column values: %w", err)
+ }
+
+ // Count number items updated.
+ n, _ := res.RowsAffected()
+ updated += int(n)
+ }
+
+ // Check total updated.
+ if total != updated {
+ log.Warnf(ctx, "total=%d does not match updated=%d", total, updated)
+ }
+
+ // Drop the old column from table.
+ if _, err := tx.NewDropColumn().
+ Table(table).
+ ColumnExpr("?", bun.Ident(column)).
+ Exec(ctx); err != nil {
+ return gtserror.Newf("error dropping old column: %w", err)
+ }
+
+ // Rename new to old name.
+ if _, err := tx.NewRaw(
+ "ALTER TABLE ? RENAME COLUMN ? TO ?",
+ bun.Ident(table),
+ bun.Ident(newColumn),
+ bun.Ident(column),
+ ).Exec(ctx); err != nil {
+ return gtserror.Newf("error renaming new column: %w", err)
+ }
+
+ return nil
+}
+
+// getBunColumnDef generates a column definition string for the SQL table represented by
+// Go type, with the SQL column represented by the given Go field name. This ensures when
+// adding a new column for table by migration that it will end up as bun would create it.
+//
+// NOTE: this function must stay in sync with (*bun.CreateTableQuery{}).AppendQuery(),
+// specifically where it loops over table fields appending each column definition.
+func getBunColumnDef(db bun.IDB, rtype reflect.Type, fieldName string) (string, error) {
+ d := db.Dialect()
+ f := d.Features()
+
+ // Get bun schema definitions for Go type and its field.
+ field, table, err := getModelField(db, rtype, fieldName)
+ if err != nil {
+ return "", err
+ }
+
+ // Start with reasonable buf.
+ buf := make([]byte, 0, 64)
+
+ // Start with the SQL column name.
+ buf = append(buf, field.SQLName...)
+ buf = append(buf, " "...)
+
+ // Append the SQL
+ // type information.
+ switch {
+
+ // Most of the time these two will match, but for the cases where DiscoveredSQLType is dialect-specific,
+ // e.g. pgdialect would change sqltype.SmallInt to pgTypeSmallSerial for columns that have `bun:",autoincrement"`
+ case !strings.EqualFold(field.CreateTableSQLType, field.DiscoveredSQLType):
+ buf = append(buf, field.CreateTableSQLType...)
+
+ // For all common SQL types except VARCHAR, both UserDefinedSQLType and DiscoveredSQLType specify the correct type,
+ // and we needn't modify it. For VARCHAR columns, we will stop to check if a valid length has been set in .Varchar(int).
+ case !strings.EqualFold(field.CreateTableSQLType, sqltype.VarChar):
+ buf = append(buf, field.CreateTableSQLType...)
+
+ // All else falls back
+ // to a default varchar.
+ default:
+ if d.Name() == dialect.Oracle {
+ buf = append(buf, "VARCHAR2"...)
+ } else {
+ buf = append(buf, sqltype.VarChar...)
+ }
+ buf = append(buf, "("...)
+ buf = strconv.AppendInt(buf, int64(d.DefaultVarcharLen()), 10)
+ buf = append(buf, ")"...)
+ }
+
+ // Append not null definition if field requires.
+ if field.NotNull && d.Name() != dialect.Oracle {
+ buf = append(buf, " NOT NULL"...)
+ }
+
+ // Append autoincrement definition if field requires.
+ if field.Identity && f.Has(feature.GeneratedIdentity) ||
+ (field.AutoIncrement && (f.Has(feature.AutoIncrement) || f.Has(feature.Identity))) {
+ buf = d.AppendSequence(buf, table, field)
+ }
+
+ // Append any default value.
+ if field.SQLDefault != "" {
+ buf = append(buf, " DEFAULT "...)
+ buf = append(buf, field.SQLDefault...)
+ }
+
+ return byteutil.B2S(buf), nil
+}
+
+// getModelField returns the uptrace/bun schema details for given Go type and field name.
+func getModelField(db bun.IDB, rtype reflect.Type, fieldName string) (*schema.Field, *schema.Table, error) {
+
+ // Get the associated table for Go type.
+ table := db.Dialect().Tables().Get(rtype)
+ if table == nil {
+ return nil, nil, fmt.Errorf("no table found for type: %s", rtype)
+ }
+
+ var field *schema.Field
+
+ // Look for field matching Go name.
+ for i := range table.Fields {
+ if table.Fields[i].GoName == fieldName {
+ field = table.Fields[i]
+ break
+ }
+ }
+
+ if field == nil {
+ return nil, nil, fmt.Errorf("no bun field found on %s with name: %s", rtype, fieldName)
+ }
+
+ return field, table, nil
+}
+
// doesColumnExist safely checks whether given column exists on table, handling both SQLite and PostgreSQL appropriately.
func doesColumnExist(ctx context.Context, tx bun.Tx, table, col string) (bool, error) {
var n int