diff options
Diffstat (limited to 'internal/db/bundb/migrations/20250415111056_thread_all_statuses.go')
| -rw-r--r-- | internal/db/bundb/migrations/20250415111056_thread_all_statuses.go | 359 |
1 files changed, 271 insertions, 88 deletions
diff --git a/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go b/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go index fc02d1e40..2a3808120 100644 --- a/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go +++ b/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go @@ -24,13 +24,16 @@ import ( "reflect" "slices" "strings" + "time" "code.superseriousbusiness.org/gotosocial/internal/db" newmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/new" oldmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/old" + "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/util" "code.superseriousbusiness.org/gotosocial/internal/gtserror" "code.superseriousbusiness.org/gotosocial/internal/id" "code.superseriousbusiness.org/gotosocial/internal/log" + "code.superseriousbusiness.org/gotosocial/internal/util/xslices" "github.com/uptrace/bun" ) @@ -49,10 +52,26 @@ func init() { "thread_id", "thread_id_new", 1) var sr statusRethreader - var count int + var updatedTotal int64 var maxID string var statuses []*oldmodel.Status + // Create thread_id_new already + // so we can populate it as we go. + log.Info(ctx, "creating statuses column thread_id_new") + if _, err := db.NewAddColumn(). + Table("statuses"). + ColumnExpr(newColDef). + Exec(ctx); err != nil { + return gtserror.Newf("error adding statuses column thread_id_new: %w", err) + } + + // Try to merge the wal so we're + // not working on the wal file. + if err := doWALCheckpoint(ctx, db); err != nil { + return err + } + // Get a total count of all statuses before migration. total, err := db.NewSelect().Table("statuses").Count(ctx) if err != nil { @@ -63,74 +82,129 @@ func init() { // possible ULID value. maxID = id.Highest - log.Warn(ctx, "rethreading top-level statuses, this will take a *long* time") - for /* TOP LEVEL STATUS LOOP */ { + log.Warnf(ctx, "rethreading %d statuses, this will take a *long* time", total) + + // Open initial transaction. + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + + for i := 1; ; i++ { // Reset slice. clear(statuses) statuses = statuses[:0] + batchStart := time.Now() + // Select top-level statuses. - if err := db.NewSelect(). + if err := tx.NewSelect(). Model(&statuses). - Column("id", "thread_id"). - + Column("id"). // We specifically use in_reply_to_account_id instead of in_reply_to_id as // they should both be set / unset in unison, but we specifically have an // index on in_reply_to_account_id with ID ordering, unlike in_reply_to_id. Where("? IS NULL", bun.Ident("in_reply_to_account_id")). Where("? < ?", bun.Ident("id"), maxID). OrderExpr("? DESC", bun.Ident("id")). - Limit(5000). + Limit(500). Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { return gtserror.Newf("error selecting top level statuses: %w", err) } - // Reached end of block. - if len(statuses) == 0 { + l := len(statuses) + if l == 0 { + // No more statuses! + // + // Transaction will be closed + // after leaving the loop. break + + } else if i%200 == 0 { + // Begin a new transaction every + // 200 batches (~100,000 statuses), + // to avoid massive commits. + + // Close existing transaction. + if err := tx.Commit(); err != nil { + return err + } + + // Try to flush the wal + // to avoid silly wal sizes. + if err := doWALCheckpoint(ctx, db); err != nil { + return err + } + + // Open new transaction. + tx, err = db.BeginTx(ctx, nil) + if err != nil { + return err + } } // Set next maxID value from statuses. maxID = statuses[len(statuses)-1].ID - // Rethread each selected batch of top-level statuses in a transaction. - if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - - // Rethread each top-level status. - for _, status := range statuses { - n, err := sr.rethreadStatus(ctx, tx, status) - if err != nil { - return gtserror.Newf("error rethreading status %s: %w", status.URI, err) - } - count += n + // Rethread using the + // open transaction. + var updatedInBatch int64 + for _, status := range statuses { + n, err := sr.rethreadStatus(ctx, tx, status, false) + if err != nil { + return gtserror.Newf("error rethreading status %s: %w", status.URI, err) } - - return nil - }); err != nil { - return err + updatedInBatch += n + updatedTotal += n } - log.Infof(ctx, "[approx %d of %d] rethreading statuses (top-level)", count, total) + // Show speed for this batch. + timeTaken := time.Since(batchStart).Milliseconds() + msPerRow := float64(timeTaken) / float64(updatedInBatch) + rowsPerMs := float64(1) / float64(msPerRow) + rowsPerSecond := 1000 * rowsPerMs + + // Show percent migrated overall. + totalDone := (float64(updatedTotal) / float64(total)) * 100 + + log.Infof( + ctx, + "[~%.2f%% done; ~%.0f rows/s] migrating threads", + totalDone, rowsPerSecond, + ) } - // Attempt to merge any sqlite write-ahead-log. - if err := doWALCheckpoint(ctx, db); err != nil { + // Close transaction. + if err := tx.Commit(); err != nil { return err } - log.Warn(ctx, "rethreading straggler statuses, this will take a *long* time") - for /* STRAGGLER STATUS LOOP */ { + // Create a partial index on thread_id_new to find stragglers. + // This index will be removed at the end of the migration. + log.Info(ctx, "creating temporary statuses thread_id_new index") + if _, err := db.NewCreateIndex(). + Table("statuses"). + Index("statuses_thread_id_new_idx"). + Column("thread_id_new"). + Where("? = ?", bun.Ident("thread_id_new"), id.Lowest). + Exec(ctx); err != nil { + return gtserror.Newf("error creating new thread_id index: %w", err) + } + + for i := 1; ; i++ { // Reset slice. clear(statuses) statuses = statuses[:0] + batchStart := time.Now() + // Select straggler statuses. if err := db.NewSelect(). Model(&statuses). - Column("id", "in_reply_to_id", "thread_id"). - Where("? IS NULL", bun.Ident("thread_id")). + Column("id"). + Where("? = ?", bun.Ident("thread_id_new"), id.Lowest). // We select in smaller batches for this part // of the migration as there is a chance that @@ -138,7 +212,7 @@ func init() { // part of the same thread, i.e. one call to // rethreadStatus() may effect other statuses // later in the slice. - Limit(1000). + Limit(250). Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { return gtserror.Newf("error selecting straggler statuses: %w", err) } @@ -149,23 +223,35 @@ func init() { } // Rethread each selected batch of straggler statuses in a transaction. + var updatedInBatch int64 if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - - // Rethread each top-level status. for _, status := range statuses { - n, err := sr.rethreadStatus(ctx, tx, status) + n, err := sr.rethreadStatus(ctx, tx, status, true) if err != nil { return gtserror.Newf("error rethreading status %s: %w", status.URI, err) } - count += n + updatedInBatch += n + updatedTotal += n } - return nil }); err != nil { return err } - log.Infof(ctx, "[approx %d of %d] rethreading statuses (stragglers)", count, total) + // Show speed for this batch. + timeTaken := time.Since(batchStart).Milliseconds() + msPerRow := float64(timeTaken) / float64(updatedInBatch) + rowsPerMs := float64(1) / float64(msPerRow) + rowsPerSecond := 1000 * rowsPerMs + + // Show percent migrated overall. + totalDone := (float64(updatedTotal) / float64(total)) * 100 + + log.Infof( + ctx, + "[~%.2f%% done; ~%.0f rows/s] migrating stragglers", + totalDone, rowsPerSecond, + ) } // Attempt to merge any sqlite write-ahead-log. @@ -173,6 +259,13 @@ func init() { return err } + log.Info(ctx, "dropping temporary thread_id_new index") + if _, err := db.NewDropIndex(). + Index("statuses_thread_id_new_idx"). + Exec(ctx); err != nil { + return gtserror.Newf("error dropping temporary thread_id_new index: %w", err) + } + log.Info(ctx, "dropping old thread_to_statuses table") if _, err := db.NewDropTable(). Table("thread_to_statuses"). @@ -180,33 +273,6 @@ func init() { return gtserror.Newf("error dropping old thread_to_statuses table: %w", err) } - log.Info(ctx, "creating new statuses thread_id column") - if _, err := db.NewAddColumn(). - Table("statuses"). - ColumnExpr(newColDef). - Exec(ctx); err != nil { - return gtserror.Newf("error adding new thread_id column: %w", err) - } - - log.Info(ctx, "setting thread_id_new = thread_id (this may take a while...)") - if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - return batchUpdateByID(ctx, tx, - "statuses", // table - "id", // batchByCol - "UPDATE ? SET ? = ?", // updateQuery - []any{bun.Ident("statuses"), - bun.Ident("thread_id_new"), - bun.Ident("thread_id")}, - ) - }); err != nil { - return err - } - - // Attempt to merge any sqlite write-ahead-log. - if err := doWALCheckpoint(ctx, db); err != nil { - return err - } - log.Info(ctx, "dropping old statuses thread_id index") if _, err := db.NewDropIndex(). Index("statuses_thread_id_idx"). @@ -274,6 +340,11 @@ type statusRethreader struct { // its contents are ephemeral. statuses []*oldmodel.Status + // newThreadIDSet is used to track whether + // statuses in statusIDs have already have + // thread_id_new set on them. + newThreadIDSet map[string]struct{} + // seenIDs tracks the unique status and // thread IDs we have seen, ensuring we // don't append duplicates to statusIDs @@ -289,14 +360,15 @@ type statusRethreader struct { } // rethreadStatus is the main logic handler for statusRethreader{}. this is what gets called from the migration -// in order to trigger a status rethreading operation for the given status, returning total number rethreaded. -func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, status *oldmodel.Status) (int, error) { +// in order to trigger a status rethreading operation for the given status, returning total number of rows changed. +func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, status *oldmodel.Status, straggler bool) (int64, error) { // Zero slice and // map ptr values. clear(sr.statusIDs) clear(sr.threadIDs) clear(sr.statuses) + clear(sr.newThreadIDSet) clear(sr.seenIDs) // Reset slices and values for use. @@ -305,6 +377,11 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu sr.statuses = sr.statuses[:0] sr.allThreaded = true + if sr.newThreadIDSet == nil { + // Allocate new hash set for newThreadIDSet. + sr.newThreadIDSet = make(map[string]struct{}) + } + if sr.seenIDs == nil { // Allocate new hash set for status IDs. sr.seenIDs = make(map[string]struct{}) @@ -317,12 +394,22 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu // to the rethreadStatus() call. if err := tx.NewSelect(). Model(status). - Column("in_reply_to_id", "thread_id"). + Column("in_reply_to_id", "thread_id", "thread_id_new"). Where("? = ?", bun.Ident("id"), status.ID). Scan(ctx); err != nil { return 0, gtserror.Newf("error selecting status: %w", err) } + // If we've just threaded this status by setting + // thread_id_new, then by definition anything we + // could find from the entire thread must now be + // threaded, so we can save some database calls + // by skipping iterating up + down from here. + if status.ThreadIDNew != id.Lowest { + log.Debugf(ctx, "skipping just rethreaded status: %s", status.ID) + return 0, nil + } + // status and thread ID cursor // index values. these are used // to keep track of newly loaded @@ -371,14 +458,14 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu threadIdx = len(sr.threadIDs) } - // Total number of - // statuses threaded. - total := len(sr.statusIDs) - // Check for the case where the entire // batch of statuses is already correctly // threaded. Then we have nothing to do! - if sr.allThreaded && len(sr.threadIDs) == 1 { + // + // Skip this check for straggler statuses + // that are part of broken threads. + if !straggler && sr.allThreaded && len(sr.threadIDs) == 1 { + log.Debug(ctx, "skipping just rethreaded thread") return 0, nil } @@ -417,36 +504,120 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu } } - // Update all the statuses to - // use determined thread_id. - if _, err := tx.NewUpdate(). - Table("statuses"). - Where("? IN (?)", bun.Ident("id"), bun.In(sr.statusIDs)). - Set("? = ?", bun.Ident("thread_id"), threadID). - Exec(ctx); err != nil { + var ( + res sql.Result + err error + ) + + if len(sr.statusIDs) == 1 { + + // If we're only updating one status + // we can use a simple update query. + res, err = tx.NewUpdate(). + // Update the status model. + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + // Set the new thread ID, which we can use as + // an indication that we've migrated this batch. + Set("? = ?", bun.Ident("thread_id_new"), threadID). + // While we're here, also set old thread_id, as + // we'll use it for further rethreading purposes. + Set("? = ?", bun.Ident("thread_id"), threadID). + Where("? = ?", bun.Ident("status.id"), sr.statusIDs[0]). + Exec(ctx) + + } else { + + // If we're updating multiple statuses at once, + // build up a common table expression to update + // all statuses in this thread to use threadID. + // + // This ought to be a little more readable than + // using an "IN(*)" query, and PG or SQLite *may* + // be able to optimize it better. + // + // See: + // + // - https://sqlite.org/lang_with.html + // - https://www.postgresql.org/docs/current/queries-with.html + // - https://bun.uptrace.dev/guide/query-update.html#bulk-update + values := make([]*util.Status, 0, len(sr.statusIDs)) + for _, statusID := range sr.statusIDs { + // Filter out statusIDs that have already had + // thread_id_new set, to avoid spurious writes. + if _, set := sr.newThreadIDSet[statusID]; !set { + values = append(values, &util.Status{ + ID: statusID, + }) + } + } + + // Resulting query will look something like this: + // + // WITH "_data" ("id") AS ( + // VALUES + // ('01JR6PZED0DDR2VZHQ8H87ZW98'), + // ('01JR6PZED0J91MJCAFDTCCCG8Q') + // ) + // UPDATE "statuses" AS "status" + // SET + // "thread_id_new" = '01K6MGKX54BBJ3Y1FBPQY45E5P', + // "thread_id" = '01K6MGKX54BBJ3Y1FBPQY45E5P' + // FROM _data + // WHERE "status"."id" = "_data"."id" + res, err = tx.NewUpdate(). + // Update the status model. + Model((*oldmodel.Status)(nil)). + // Provide the CTE values as "_data". + With("_data", tx.NewValues(&values)). + // Include `FROM _data` statement so we can use + // `_data` table in SET and WHERE components. + TableExpr("_data"). + // Set the new thread ID, which we can use as + // an indication that we've migrated this batch. + Set("? = ?", bun.Ident("thread_id_new"), threadID). + // While we're here, also set old thread_id, as + // we'll use it for further rethreading purposes. + Set("? = ?", bun.Ident("thread_id"), threadID). + // "Join" to the CTE on status ID. + Where("? = ?", bun.Ident("status.id"), bun.Ident("_data.id")). + Exec(ctx) + } + + if err != nil { return 0, gtserror.Newf("error updating status thread ids: %w", err) } + rowsAffected, err := res.RowsAffected() + if err != nil { + return 0, gtserror.Newf("error counting rows affected: %w", err) + } + if len(sr.threadIDs) > 0 { // Update any existing thread // mutes to use latest thread_id. + + // Dedupe thread IDs before query + // to avoid ludicrous "IN" clause. + threadIDs := sr.threadIDs + threadIDs = xslices.Deduplicate(threadIDs) if _, err := tx.NewUpdate(). Table("thread_mutes"). - Where("? IN (?)", bun.Ident("thread_id"), bun.In(sr.threadIDs)). + Where("? IN (?)", bun.Ident("thread_id"), bun.In(threadIDs)). Set("? = ?", bun.Ident("thread_id"), threadID). Exec(ctx); err != nil { return 0, gtserror.Newf("error updating mute thread ids: %w", err) } } - return total, nil + return rowsAffected, nil } // append will append the given status to the internal tracking of statusRethreader{} for // potential future operations, checking for uniqueness. it tracks the inReplyToID value // for the next call to getParents(), it tracks the status ID for list of statuses that -// need updating, the thread ID for the list of thread links and mutes that need updating, -// and whether all the statuses all have a provided thread ID (i.e. allThreaded). +// may need updating, whether a new thread ID has been set for each status, the thread ID +// for the list of thread links and mutes that need updating, and whether all the statuses +// all have a provided thread ID (i.e. allThreaded). func (sr *statusRethreader) append(status *oldmodel.Status) { // Check if status already seen before. @@ -479,7 +650,14 @@ func (sr *statusRethreader) append(status *oldmodel.Status) { } // Add status ID to map of seen IDs. - sr.seenIDs[status.ID] = struct{}{} + mark := struct{}{} + sr.seenIDs[status.ID] = mark + + // If new thread ID has already been + // set, add status ID to map of set IDs. + if status.ThreadIDNew != id.Lowest { + sr.newThreadIDSet[status.ID] = mark + } } func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error { @@ -496,7 +674,7 @@ func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error { // Select next parent status. if err := tx.NewSelect(). Model(&parent). - Column("id", "in_reply_to_id", "thread_id"). + Column("id", "in_reply_to_id", "thread_id", "thread_id_new"). Where("? = ?", bun.Ident("id"), id). Scan(ctx); err != nil && err != db.ErrNoEntries { return err @@ -535,7 +713,7 @@ func (sr *statusRethreader) getChildren(ctx context.Context, tx bun.Tx, idx int) // Select children of ID. if err := tx.NewSelect(). Model(&sr.statuses). - Column("id", "thread_id"). + Column("id", "thread_id", "thread_id_new"). Where("? = ?", bun.Ident("in_reply_to_id"), id). Scan(ctx); err != nil && err != db.ErrNoEntries { return err @@ -560,14 +738,19 @@ func (sr *statusRethreader) getStragglers(ctx context.Context, tx bun.Tx, idx in clear(sr.statuses) sr.statuses = sr.statuses[:0] + // Dedupe thread IDs before query + // to avoid ludicrous "IN" clause. + threadIDs := sr.threadIDs[idx:] + threadIDs = xslices.Deduplicate(threadIDs) + // Select stragglers that // also have thread IDs. if err := tx.NewSelect(). Model(&sr.statuses). - Column("id", "thread_id", "in_reply_to_id"). + Column("id", "thread_id", "in_reply_to_id", "thread_id_new"). Where("? IN (?) AND ? NOT IN (?)", bun.Ident("thread_id"), - bun.In(sr.threadIDs[idx:]), + bun.In(threadIDs), bun.Ident("id"), bun.In(sr.statusIDs), ). |
