diff options
Diffstat (limited to 'internal/db/bundb/migrations/20250415111056_thread_all_statuses.go')
| -rw-r--r-- | internal/db/bundb/migrations/20250415111056_thread_all_statuses.go | 580 |
1 files changed, 580 insertions, 0 deletions
diff --git a/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go b/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go new file mode 100644 index 000000000..4213da4f2 --- /dev/null +++ b/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go @@ -0,0 +1,580 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + "database/sql" + "errors" + "reflect" + "slices" + "strings" + + "code.superseriousbusiness.org/gotosocial/internal/db" + newmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/new" + oldmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/old" + "code.superseriousbusiness.org/gotosocial/internal/gtserror" + "code.superseriousbusiness.org/gotosocial/internal/id" + "code.superseriousbusiness.org/gotosocial/internal/log" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + newType := reflect.TypeOf(&newmodel.Status{}) + + // Get the new column definition with not-null thread_id. + newColDef, err := getBunColumnDef(db, newType, "ThreadID") + if err != nil { + return gtserror.Newf("error getting bun column def: %w", err) + } + + // Update column def to use '${name}_new'. + newColDef = strings.Replace(newColDef, + "thread_id", "thread_id_new", 1) + + var sr statusRethreader + var total uint64 + var maxID string + var statuses []*oldmodel.Status + + // Start at largest + // possible ULID value. + maxID = id.Highest + + log.Warn(ctx, "rethreading top-level statuses, this will take a *long* time") + for /* TOP LEVEL STATUS LOOP */ { + + // Reset slice. + clear(statuses) + statuses = statuses[:0] + + // Select top-level statuses. + if err := db.NewSelect(). + Model(&statuses). + Column("id", "thread_id"). + + // We specifically use in_reply_to_account_id instead of in_reply_to_id as + // they should both be set / unset in unison, but we specifically have an + // index on in_reply_to_account_id with ID ordering, unlike in_reply_to_id. + Where("? IS NULL", bun.Ident("in_reply_to_account_id")). + Where("? < ?", bun.Ident("id"), maxID). + OrderExpr("? DESC", bun.Ident("id")). + Limit(5000). + Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { + return gtserror.Newf("error selecting top level statuses: %w", err) + } + + // Reached end of block. + if len(statuses) == 0 { + break + } + + // Set next maxID value from statuses. + maxID = statuses[len(statuses)-1].ID + + // Rethread each selected batch of top-level statuses in a transaction. + if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + + // Rethread each top-level status. + for _, status := range statuses { + n, err := sr.rethreadStatus(ctx, tx, status) + if err != nil { + return gtserror.Newf("error rethreading status %s: %w", status.URI, err) + } + total += n + } + + return nil + }); err != nil { + return err + } + + log.Infof(ctx, "[%d] rethreading statuses (top-level)", total) + } + + log.Warn(ctx, "rethreading straggler statuses, this will take a *long* time") + for /* STRAGGLER STATUS LOOP */ { + + // Reset slice. + clear(statuses) + statuses = statuses[:0] + + // Select straggler statuses. + if err := db.NewSelect(). + Model(&statuses). + Column("id", "in_reply_to_id", "thread_id"). + Where("? IS NULL", bun.Ident("thread_id")). + + // We select in smaller batches for this part + // of the migration as there is a chance that + // we may be fetching statuses that might be + // part of the same thread, i.e. one call to + // rethreadStatus() may effect other statuses + // later in the slice. + Limit(1000). + Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { + return gtserror.Newf("error selecting straggler statuses: %w", err) + } + + // Reached end of block. + if len(statuses) == 0 { + break + } + + // Rethread each selected batch of straggler statuses in a transaction. + if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + + // Rethread each top-level status. + for _, status := range statuses { + n, err := sr.rethreadStatus(ctx, tx, status) + if err != nil { + return gtserror.Newf("error rethreading status %s: %w", status.URI, err) + } + total += n + } + + return nil + }); err != nil { + return err + } + + log.Infof(ctx, "[%d] rethreading statuses (stragglers)", total) + } + + // Attempt to merge any sqlite write-ahead-log. + if err := doWALCheckpoint(ctx, db); err != nil { + return err + } + + log.Info(ctx, "dropping old thread_to_statuses table") + if _, err := db.NewDropTable(). + Table("thread_to_statuses"). + IfExists(). + Exec(ctx); err != nil { + return gtserror.Newf("error dropping old thread_to_statuses table: %w", err) + } + + // Run the majority of the thread_id_new -> thread_id migration in a tx. + if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + log.Info(ctx, "creating new statuses thread_id column") + if _, err := tx.NewAddColumn(). + Table("statuses"). + ColumnExpr(newColDef). + Exec(ctx); err != nil { + return gtserror.Newf("error creating new thread_id column: %w", err) + } + + log.Info(ctx, "setting thread_id_new = thread_id (this may take a while...)") + if err := batchUpdateByID(ctx, tx, + "statuses", // table + "id", // batchByCol + "UPDATE ? SET ? = ?", // updateQuery + []any{bun.Ident("statuses"), + bun.Ident("thread_id_new"), + bun.Ident("thread_id")}, + ); err != nil { + return err + } + + log.Info(ctx, "dropping old statuses thread_id index") + if _, err := tx.NewDropIndex(). + Index("statuses_thread_id_idx"). + Exec(ctx); err != nil { + return gtserror.Newf("error dropping old thread_id index: %w", err) + } + + log.Info(ctx, "dropping old statuses thread_id column") + if _, err := tx.NewDropColumn(). + Table("statuses"). + Column("thread_id"). + Exec(ctx); err != nil { + return gtserror.Newf("error dropping old thread_id column: %w", err) + } + + log.Info(ctx, "renaming thread_id_new to thread_id") + if _, err := tx.NewRaw( + "ALTER TABLE ? RENAME COLUMN ? TO ?", + bun.Ident("statuses"), + bun.Ident("thread_id_new"), + bun.Ident("thread_id"), + ).Exec(ctx); err != nil { + return gtserror.Newf("error renaming new column: %w", err) + } + + return nil + }); err != nil { + return err + } + + // Attempt to merge any sqlite write-ahead-log. + if err := doWALCheckpoint(ctx, db); err != nil { + return err + } + + log.Info(ctx, "creating new statuses thread_id index") + if _, err := db.NewCreateIndex(). + Table("statuses"). + Index("statuses_thread_id_idx"). + Column("thread_id"). + IfNotExists(). + Exec(ctx); err != nil { + return gtserror.Newf("error creating new thread_id index: %w", err) + } + + return nil + } + + down := func(ctx context.Context, db *bun.DB) error { + return nil + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} + +type statusRethreader struct { + // the unique status and thread IDs + // of all models passed to append(). + // these are later used to update all + // statuses to a single thread ID, and + // update all thread related models to + // use the new updated thread ID. + statusIDs []string + threadIDs []string + + // stores the unseen IDs of status + // InReplyTos newly tracked in append(), + // which is then used for a SELECT query + // in getParents(), then promptly reset. + inReplyToIDs []string + + // statuses simply provides a reusable + // slice of status models for selects. + // its contents are ephemeral. + statuses []*oldmodel.Status + + // seenIDs tracks the unique status and + // thread IDs we have seen, ensuring we + // don't append duplicates to statusIDs + // or threadIDs slices. also helps prevent + // adding duplicate parents to inReplyToIDs. + seenIDs map[string]struct{} + + // allThreaded tracks whether every status + // passed to append() has a thread ID set. + // together with len(threadIDs) this can + // determine if already threaded correctly. + allThreaded bool +} + +// rethreadStatus is the main logic handler for statusRethreader{}. this is what gets called from the migration +// in order to trigger a status rethreading operation for the given status, returning total number rethreaded. +func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, status *oldmodel.Status) (uint64, error) { + + // Zero slice and + // map ptr values. + clear(sr.statusIDs) + clear(sr.threadIDs) + clear(sr.statuses) + clear(sr.seenIDs) + + // Reset slices and values for use. + sr.statusIDs = sr.statusIDs[:0] + sr.threadIDs = sr.threadIDs[:0] + sr.statuses = sr.statuses[:0] + sr.allThreaded = true + + if sr.seenIDs == nil { + // Allocate new hash set for status IDs. + sr.seenIDs = make(map[string]struct{}) + } + + // Ensure the passed status + // has up-to-date information. + // This may have changed from + // the initial batch selection + // to the rethreadStatus() call. + if err := tx.NewSelect(). + Model(status). + Column("in_reply_to_id", "thread_id"). + Where("? = ?", bun.Ident("id"), status.ID). + Scan(ctx); err != nil { + return 0, gtserror.Newf("error selecting status: %w", err) + } + + // status and thread ID cursor + // index values. these are used + // to keep track of newly loaded + // status / thread IDs between + // loop iterations. + var statusIdx int + var threadIdx int + + // Append given status as + // first to our ID slices. + sr.append(status) + + for { + // Fetch parents for newly seen in_reply_tos since last loop. + if err := sr.getParents(ctx, tx); err != nil { + return 0, gtserror.Newf("error getting parents: %w", err) + } + + // Fetch children for newly seen statuses since last loop. + if err := sr.getChildren(ctx, tx, statusIdx); err != nil { + return 0, gtserror.Newf("error getting children: %w", err) + } + + // Check for newly picked-up threads + // to find stragglers for below. Else + // we've reached end of what we can do. + if threadIdx >= len(sr.threadIDs) { + break + } + + // Update status IDs cursor. + statusIdx = len(sr.statusIDs) + + // Fetch any stragglers for newly seen threads since last loop. + if err := sr.getStragglers(ctx, tx, threadIdx); err != nil { + return 0, gtserror.Newf("error getting stragglers: %w", err) + } + + // Check for newly picked-up straggling statuses / replies to + // find parents / children for. Else we've done all we can do. + if statusIdx >= len(sr.statusIDs) && len(sr.inReplyToIDs) == 0 { + break + } + + // Update thread IDs cursor. + threadIdx = len(sr.threadIDs) + } + + // Total number of + // statuses threaded. + total := len(sr.statusIDs) + + // Check for the case where the entire + // batch of statuses is already correctly + // threaded. Then we have nothing to do! + if sr.allThreaded && len(sr.threadIDs) == 1 { + return 0, nil + } + + // Sort all of the threads and + // status IDs by age; old -> new. + slices.Sort(sr.threadIDs) + slices.Sort(sr.statusIDs) + + var threadID string + + if len(sr.threadIDs) > 0 { + // Regardless of whether there ended up being + // multiple threads, we take the oldest value + // thread ID to use for entire batch of them. + threadID = sr.threadIDs[0] + sr.threadIDs = sr.threadIDs[1:] + } + + if threadID == "" { + // None of the previous parents were threaded, we instead + // generate new thread with ID based on oldest creation time. + createdAt, err := id.TimeFromULID(sr.statusIDs[0]) + if err != nil { + return 0, gtserror.Newf("error parsing status ulid: %w", err) + } + + // Generate thread ID from parsed time. + threadID = id.NewULIDFromTime(createdAt) + + // We need to create a + // new thread table entry. + if _, err = tx.NewInsert(). + Model(&newmodel.Thread{ID: threadID}). + Exec(ctx); err != nil { + return 0, gtserror.Newf("error creating new thread: %w", err) + } + } + + // Update all the statuses to + // use determined thread_id. + if _, err := tx.NewUpdate(). + Table("statuses"). + Where("? IN (?)", bun.Ident("id"), bun.In(sr.statusIDs)). + Set("? = ?", bun.Ident("thread_id"), threadID). + Exec(ctx); err != nil { + return 0, gtserror.Newf("error updating status thread ids: %w", err) + } + + if len(sr.threadIDs) > 0 { + // Update any existing thread + // mutes to use latest thread_id. + if _, err := tx.NewUpdate(). + Table("thread_mutes"). + Where("? IN (?)", bun.Ident("thread_id"), bun.In(sr.threadIDs)). + Set("? = ?", bun.Ident("thread_id"), threadID). + Exec(ctx); err != nil { + return 0, gtserror.Newf("error updating mute thread ids: %w", err) + } + } + + return uint64(total), nil +} + +// append will append the given status to the internal tracking of statusRethreader{} for +// potential future operations, checking for uniqueness. it tracks the inReplyToID value +// for the next call to getParents(), it tracks the status ID for list of statuses that +// need updating, the thread ID for the list of thread links and mutes that need updating, +// and whether all the statuses all have a provided thread ID (i.e. allThreaded). +func (sr *statusRethreader) append(status *oldmodel.Status) { + + // Check if status already seen before. + if _, ok := sr.seenIDs[status.ID]; ok { + return + } + + if status.InReplyToID != "" { + // Status has a parent, add any unique parent ID + // to list of reply IDs that need to be queried. + if _, ok := sr.seenIDs[status.InReplyToID]; ok { + sr.inReplyToIDs = append(sr.inReplyToIDs, status.InReplyToID) + } + } + + // Add status' ID to list of seen status IDs. + sr.statusIDs = append(sr.statusIDs, status.ID) + + if status.ThreadID != "" { + // Status was threaded, add any unique thread + // ID to our list of known status thread IDs. + if _, ok := sr.seenIDs[status.ThreadID]; !ok { + sr.threadIDs = append(sr.threadIDs, status.ThreadID) + } + } else { + // Status was not threaded, + // we now know not all statuses + // found were threaded. + sr.allThreaded = false + } + + // Add status ID to map of seen IDs. + sr.seenIDs[status.ID] = struct{}{} +} + +func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error { + var parent oldmodel.Status + + // Iteratively query parent for each stored + // reply ID. Note this is safe to do as slice + // loop since 'seenIDs' prevents duplicates. + for i := 0; i < len(sr.inReplyToIDs); i++ { + + // Get next status ID. + id := sr.statusIDs[i] + + // Select next parent status. + if err := tx.NewSelect(). + Model(&parent). + Column("id", "in_reply_to_id", "thread_id"). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx); err != nil && err != db.ErrNoEntries { + return err + } + + // Parent was missing. + if parent.ID == "" { + continue + } + + // Add to slices. + sr.append(&parent) + } + + // Reset reply slice. + clear(sr.inReplyToIDs) + sr.inReplyToIDs = sr.inReplyToIDs[:0] + + return nil +} + +func (sr *statusRethreader) getChildren(ctx context.Context, tx bun.Tx, idx int) error { + // Iteratively query all children for each + // of fetched parent statuses. Note this is + // safe to do as a slice loop since 'seenIDs' + // ensures it only ever contains unique IDs. + for i := idx; i < len(sr.statusIDs); i++ { + + // Get next status ID. + id := sr.statusIDs[i] + + // Reset child slice. + clear(sr.statuses) + sr.statuses = sr.statuses[:0] + + // Select children of ID. + if err := tx.NewSelect(). + Model(&sr.statuses). + Column("id", "thread_id"). + Where("? = ?", bun.Ident("in_reply_to_id"), id). + Scan(ctx); err != nil && err != db.ErrNoEntries { + return err + } + + // Append child status IDs to slices. + for _, child := range sr.statuses { + sr.append(child) + } + } + + return nil +} + +func (sr *statusRethreader) getStragglers(ctx context.Context, tx bun.Tx, idx int) error { + // Check for threads to query. + if idx >= len(sr.threadIDs) { + return nil + } + + // Reset status slice. + clear(sr.statuses) + sr.statuses = sr.statuses[:0] + + // Select stragglers that + // also have thread IDs. + if err := tx.NewSelect(). + Model(&sr.statuses). + Column("id", "thread_id", "in_reply_to_id"). + Where("? IN (?) AND ? NOT IN (?)", + bun.Ident("thread_id"), + bun.In(sr.threadIDs[idx:]), + bun.Ident("id"), + bun.In(sr.statusIDs), + ). + Scan(ctx); err != nil && err != db.ErrNoEntries { + return err + } + + // Append status IDs to slices. + for _, status := range sr.statuses { + sr.append(status) + } + + return nil +} |
