summaryrefslogtreecommitdiff
path: root/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/migrations/20250415111056_thread_all_statuses.go')
-rw-r--r--internal/db/bundb/migrations/20250415111056_thread_all_statuses.go580
1 files changed, 580 insertions, 0 deletions
diff --git a/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go b/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go
new file mode 100644
index 000000000..4213da4f2
--- /dev/null
+++ b/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go
@@ -0,0 +1,580 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "reflect"
+ "slices"
+ "strings"
+
+ "code.superseriousbusiness.org/gotosocial/internal/db"
+ newmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/new"
+ oldmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/old"
+ "code.superseriousbusiness.org/gotosocial/internal/gtserror"
+ "code.superseriousbusiness.org/gotosocial/internal/id"
+ "code.superseriousbusiness.org/gotosocial/internal/log"
+ "github.com/uptrace/bun"
+)
+
+// init registers the migration that gives every status a thread ID and then
+// replaces the statuses.thread_id column with the definition generated from
+// the new Status model. The down migration is a no-op.
+func init() {
+	up := func(ctx context.Context, db *bun.DB) error {
+		newType := reflect.TypeOf(&newmodel.Status{})
+
+		// Get the new column definition with not-null thread_id.
+		newColDef, err := getBunColumnDef(db, newType, "ThreadID")
+		if err != nil {
+			return gtserror.Newf("error getting bun column def: %w", err)
+		}
+
+		// Update column def to use '${name}_new'.
+		newColDef = strings.Replace(newColDef,
+			"thread_id", "thread_id_new", 1)
+
+		var sr statusRethreader
+		var total uint64
+		var maxID string
+		var statuses []*oldmodel.Status
+
+		// Start at largest
+		// possible ULID value.
+		maxID = id.Highest
+
+		log.Warn(ctx, "rethreading top-level statuses, this will take a *long* time")
+		for /* TOP LEVEL STATUS LOOP */ {
+
+			// Reset slice.
+			clear(statuses)
+			statuses = statuses[:0]
+
+			// Select top-level statuses.
+			if err := db.NewSelect().
+				Model(&statuses).
+				Column("id", "thread_id").
+
+				// We specifically use in_reply_to_account_id instead of in_reply_to_id as
+				// they should both be set / unset in unison, but we specifically have an
+				// index on in_reply_to_account_id with ID ordering, unlike in_reply_to_id.
+				Where("? IS NULL", bun.Ident("in_reply_to_account_id")).
+				Where("? < ?", bun.Ident("id"), maxID).
+				OrderExpr("? DESC", bun.Ident("id")).
+				Limit(5000).
+				Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
+				return gtserror.Newf("error selecting top level statuses: %w", err)
+			}
+
+			// Reached end of block.
+			if len(statuses) == 0 {
+				break
+			}
+
+			// Set next maxID value from statuses.
+			maxID = statuses[len(statuses)-1].ID
+
+			// Rethread each selected batch of top-level statuses in a transaction.
+			if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+
+				// Rethread each top-level status.
+				for _, status := range statuses {
+					n, err := sr.rethreadStatus(ctx, tx, status)
+					if err != nil {
+						// Only the ID is guaranteed to be populated here:
+						// the batch select above does not load the URI.
+						return gtserror.Newf("error rethreading status %s: %w", status.ID, err)
+					}
+					total += n
+				}
+
+				return nil
+			}); err != nil {
+				return err
+			}
+
+			log.Infof(ctx, "[%d] rethreading statuses (top-level)", total)
+		}
+
+		log.Warn(ctx, "rethreading straggler statuses, this will take a *long* time")
+		for /* STRAGGLER STATUS LOOP */ {
+
+			// Reset slice.
+			clear(statuses)
+			statuses = statuses[:0]
+
+			// Select straggler statuses, i.e. any
+			// status still left without a thread ID.
+			if err := db.NewSelect().
+				Model(&statuses).
+				Column("id", "in_reply_to_id", "thread_id").
+				Where("? IS NULL", bun.Ident("thread_id")).
+
+				// We select in smaller batches for this part
+				// of the migration as there is a chance that
+				// we may be fetching statuses that might be
+				// part of the same thread, i.e. one call to
+				// rethreadStatus() may affect other statuses
+				// later in the slice.
+				Limit(1000).
+				Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
+				return gtserror.Newf("error selecting straggler statuses: %w", err)
+			}
+
+			// Reached end of block.
+			if len(statuses) == 0 {
+				break
+			}
+
+			// Rethread each selected batch of straggler statuses in a transaction.
+			if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+
+				// Rethread each straggler status.
+				for _, status := range statuses {
+					n, err := sr.rethreadStatus(ctx, tx, status)
+					if err != nil {
+						// Only the ID is guaranteed to be populated here:
+						// the batch select above does not load the URI.
+						return gtserror.Newf("error rethreading status %s: %w", status.ID, err)
+					}
+					total += n
+				}
+
+				return nil
+			}); err != nil {
+				return err
+			}
+
+			log.Infof(ctx, "[%d] rethreading statuses (stragglers)", total)
+		}
+
+		// Attempt to merge any sqlite write-ahead-log.
+		if err := doWALCheckpoint(ctx, db); err != nil {
+			return err
+		}
+
+		log.Info(ctx, "dropping old thread_to_statuses table")
+		if _, err := db.NewDropTable().
+			Table("thread_to_statuses").
+			IfExists().
+			Exec(ctx); err != nil {
+			return gtserror.Newf("error dropping old thread_to_statuses table: %w", err)
+		}
+
+		// Run the majority of the thread_id_new -> thread_id migration in a tx.
+		if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+			log.Info(ctx, "creating new statuses thread_id column")
+			if _, err := tx.NewAddColumn().
+				Table("statuses").
+				ColumnExpr(newColDef).
+				Exec(ctx); err != nil {
+				return gtserror.Newf("error creating new thread_id column: %w", err)
+			}
+
+			log.Info(ctx, "setting thread_id_new = thread_id (this may take a while...)")
+			if err := batchUpdateByID(ctx, tx,
+				"statuses", // table
+				"id", // batchByCol
+				"UPDATE ? SET ? = ?", // updateQuery
+				[]any{bun.Ident("statuses"),
+					bun.Ident("thread_id_new"),
+					bun.Ident("thread_id")},
+			); err != nil {
+				return err
+			}
+
+			log.Info(ctx, "dropping old statuses thread_id index")
+			if _, err := tx.NewDropIndex().
+				Index("statuses_thread_id_idx").
+				Exec(ctx); err != nil {
+				return gtserror.Newf("error dropping old thread_id index: %w", err)
+			}
+
+			log.Info(ctx, "dropping old statuses thread_id column")
+			if _, err := tx.NewDropColumn().
+				Table("statuses").
+				Column("thread_id").
+				Exec(ctx); err != nil {
+				return gtserror.Newf("error dropping old thread_id column: %w", err)
+			}
+
+			log.Info(ctx, "renaming thread_id_new to thread_id")
+			if _, err := tx.NewRaw(
+				"ALTER TABLE ? RENAME COLUMN ? TO ?",
+				bun.Ident("statuses"),
+				bun.Ident("thread_id_new"),
+				bun.Ident("thread_id"),
+			).Exec(ctx); err != nil {
+				return gtserror.Newf("error renaming new column: %w", err)
+			}
+
+			return nil
+		}); err != nil {
+			return err
+		}
+
+		// Attempt to merge any sqlite write-ahead-log.
+		if err := doWALCheckpoint(ctx, db); err != nil {
+			return err
+		}
+
+		log.Info(ctx, "creating new statuses thread_id index")
+		if _, err := db.NewCreateIndex().
+			Table("statuses").
+			Index("statuses_thread_id_idx").
+			Column("thread_id").
+			IfNotExists().
+			Exec(ctx); err != nil {
+			return gtserror.Newf("error creating new thread_id index: %w", err)
+		}
+
+		return nil
+	}
+
+	// Nothing is undone when migrating down.
+	down := func(ctx context.Context, db *bun.DB) error {
+		return nil
+	}
+
+	if err := Migrations.Register(up, down); err != nil {
+		panic(err)
+	}
+}
+
+// statusRethreader carries the reusable working state for rethreadStatus()
+// and its helpers, so that slices and the ID set can be recycled from one
+// rethreading operation to the next rather than reallocated.
+type statusRethreader struct {
+	// statusIDs and threadIDs collect the status
+	// IDs and thread IDs gathered through append().
+	// Once collection finishes, every status in
+	// statusIDs is pointed at one chosen thread ID,
+	// and thread related models referencing the
+	// other threadIDs are moved over to it as well.
+	statusIDs, threadIDs []string
+
+	// inReplyToIDs queues the parent status IDs
+	// noted by append() that have yet to be looked
+	// up. getParents() selects each one of them
+	// and then empties the queue again.
+	inReplyToIDs []string
+
+	// statuses is scratch space handed to SELECT
+	// queries as their model slice. Whatever it
+	// holds is only valid until the next query.
+	statuses []*oldmodel.Status
+
+	// seenIDs is the set of IDs that append() has
+	// already recorded. It is consulted to keep
+	// duplicates out of the ID slices above, and
+	// to avoid queueing parents more than needed.
+	seenIDs map[string]struct{}
+
+	// allThreaded stays true only while every
+	// status given to append() carried a thread
+	// ID. Combined with len(threadIDs) it tells
+	// whether the statuses were already threaded
+	// correctly and so need no update at all.
+	allThreaded bool
+}
+
+// rethreadStatus is the entry point of statusRethreader{}: starting from the given status it gathers
+// connected statuses (parents, children, and other statuses sharing a discovered thread ID), then
+// updates all of them to a single thread ID. It returns the number of statuses gathered, or zero
+// if they were found to already share exactly one thread ID and nothing had to be written.
+func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, status *oldmodel.Status) (uint64, error) {
+
+	// Prepare the ID hash set: allocate on
+	// first use, else empty it for reuse.
+	if sr.seenIDs == nil {
+		sr.seenIDs = make(map[string]struct{})
+	} else {
+		clear(sr.seenIDs)
+	}
+
+	// Zero previously held slice elements,
+	// then truncate keeping the capacity.
+	clear(sr.statusIDs)
+	sr.statusIDs = sr.statusIDs[:0]
+	clear(sr.threadIDs)
+	sr.threadIDs = sr.threadIDs[:0]
+	clear(sr.statuses)
+	sr.statuses = sr.statuses[:0]
+
+	// Assume threaded until
+	// append() finds otherwise.
+	sr.allThreaded = true
+
+	// Reload the reply and thread columns of
+	// the passed status, as these may have
+	// been changed between the initial batch
+	// selection and this rethreadStatus() call.
+	refresh := tx.NewSelect().
+		Model(status).
+		Column("in_reply_to_id", "thread_id").
+		Where("? = ?", bun.Ident("id"), status.ID)
+	if err := refresh.Scan(ctx); err != nil {
+		return 0, gtserror.Newf("error selecting status: %w", err)
+	}
+
+	// Cursors into statusIDs and threadIDs
+	// marking where the entries that are new
+	// since the previous loop iteration begin.
+	var (
+		statusCursor int
+		threadCursor int
+	)
+
+	// The given status is always
+	// the first to be tracked.
+	sr.append(status)
+
+	for {
+		// Look up parents for any replies noted since last loop.
+		if err := sr.getParents(ctx, tx); err != nil {
+			return 0, gtserror.Newf("error getting parents: %w", err)
+		}
+
+		// Look up children of statuses added since last loop.
+		if err := sr.getChildren(ctx, tx, statusCursor); err != nil {
+			return 0, gtserror.Newf("error getting children: %w", err)
+		}
+
+		// Without any newly found threads
+		// there are no stragglers to look
+		// for, so gathering is finished.
+		if len(sr.threadIDs) <= threadCursor {
+			break
+		}
+
+		// Move status cursor up to current end.
+		statusCursor = len(sr.statusIDs)
+
+		// Look up stragglers in the threads found since last loop.
+		if err := sr.getStragglers(ctx, tx, threadCursor); err != nil {
+			return 0, gtserror.Newf("error getting stragglers: %w", err)
+		}
+
+		// If that produced neither new statuses nor new parent
+		// IDs to look up, then there is nothing more to gather.
+		if len(sr.statusIDs) <= statusCursor && len(sr.inReplyToIDs) == 0 {
+			break
+		}
+
+		// Move thread cursor up to current end.
+		threadCursor = len(sr.threadIDs)
+	}
+
+	// Everything gathered already has a
+	// thread ID and only one thread was
+	// found: nothing needs to be updated.
+	if sr.allThreaded && len(sr.threadIDs) == 1 {
+		return 0, nil
+	}
+
+	// Order both ID slices
+	// by age, oldest first.
+	slices.Sort(sr.threadIDs)
+	slices.Sort(sr.statusIDs)
+
+	var threadID string
+
+	if len(sr.threadIDs) > 0 {
+		// However many threads were found, the oldest
+		// thread ID is the one kept for all statuses.
+		// The remainder are the IDs being replaced.
+		threadID, sr.threadIDs = sr.threadIDs[0], sr.threadIDs[1:]
+	}
+
+	if threadID == "" {
+		// No existing thread was found at all. Derive a
+		// new thread ID from time of the oldest status ID.
+		createdAt, err := id.TimeFromULID(sr.statusIDs[0])
+		if err != nil {
+			return 0, gtserror.Newf("error parsing status ulid: %w", err)
+		}
+
+		// Generate thread ID from parsed time.
+		threadID = id.NewULIDFromTime(createdAt)
+
+		// The new thread needs
+		// its own table entry.
+		thread := &newmodel.Thread{ID: threadID}
+		if _, err = tx.NewInsert().
+			Model(thread).
+			Exec(ctx); err != nil {
+			return 0, gtserror.Newf("error creating new thread: %w", err)
+		}
+	}
+
+	// Point every gathered status
+	// at the chosen thread ID.
+	updateStatuses := tx.NewUpdate().
+		Table("statuses").
+		Where("? IN (?)", bun.Ident("id"), bun.In(sr.statusIDs)).
+		Set("? = ?", bun.Ident("thread_id"), threadID)
+	if _, err := updateStatuses.Exec(ctx); err != nil {
+		return 0, gtserror.Newf("error updating status thread ids: %w", err)
+	}
+
+	if len(sr.threadIDs) > 0 {
+		// Move thread mutes on any of the replaced
+		// thread IDs over to the chosen thread ID.
+		updateMutes := tx.NewUpdate().
+			Table("thread_mutes").
+			Where("? IN (?)", bun.Ident("thread_id"), bun.In(sr.threadIDs)).
+			Set("? = ?", bun.Ident("thread_id"), threadID)
+		if _, err := updateMutes.Exec(ctx); err != nil {
+			return 0, gtserror.Newf("error updating mute thread ids: %w", err)
+		}
+	}
+
+	// Total number of
+	// statuses threaded.
+	return uint64(len(sr.statusIDs)), nil
+}
+
+// append records the given status in the internal tracking of statusRethreader{},
+// ignoring statuses that were already recorded. For a newly seen status it:
+//   - queues its InReplyToID for the next getParents() call, unless that parent
+//     has itself already been recorded,
+//   - adds its ID to the list of statuses that need updating,
+//   - adds its thread ID, if not seen before, to the list of known thread IDs,
+//   - clears allThreaded if the status has no thread ID.
+func (sr *statusRethreader) append(status *oldmodel.Status) {
+
+	// Check if status already seen before.
+	if _, ok := sr.seenIDs[status.ID]; ok {
+		return
+	}
+
+	if status.InReplyToID != "" {
+		// Status has a parent. Queue the parent ID for
+		// querying only if we have NOT seen that parent
+		// yet; an already seen parent is already tracked.
+		if _, ok := sr.seenIDs[status.InReplyToID]; !ok {
+			sr.inReplyToIDs = append(sr.inReplyToIDs, status.InReplyToID)
+		}
+	}
+
+	// Add status' ID to list of seen status IDs.
+	sr.statusIDs = append(sr.statusIDs, status.ID)
+
+	if status.ThreadID != "" {
+		// Status was threaded, add any unique thread
+		// ID to our list of known status thread IDs.
+		// The thread ID must be marked as seen too,
+		// else every status of a thread would add
+		// the same thread ID to the list again.
+		if _, ok := sr.seenIDs[status.ThreadID]; !ok {
+			sr.seenIDs[status.ThreadID] = struct{}{}
+			sr.threadIDs = append(sr.threadIDs, status.ThreadID)
+		}
+	} else {
+		// Status was not threaded,
+		// we now know not all statuses
+		// found were threaded.
+		sr.allThreaded = false
+	}
+
+	// Add status ID to map of seen IDs.
+	sr.seenIDs[status.ID] = struct{}{}
+}
+
+// getParents selects the status for each parent ID queued in sr.inReplyToIDs
+// and passes every one found to append(). Queued IDs with no matching row are
+// skipped. The queue is emptied before returning successfully.
+func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error {
+	// Iterate by index and re-check the length each
+	// time round, as the append() call below may
+	// itself queue further parent IDs to look up.
+	for i := 0; i < len(sr.inReplyToIDs); i++ {
+
+		// Get next queued parent ID.
+		parentID := sr.inReplyToIDs[i]
+
+		// A parent may have been picked up by other
+		// means since it was queued, skip the query.
+		if _, ok := sr.seenIDs[parentID]; ok {
+			continue
+		}
+
+		// Use a fresh model per query, so that a
+		// missing row cannot be mistaken for the
+		// parent left over from a previous loop.
+		var parent oldmodel.Status
+
+		// Select next parent status.
+		if err := tx.NewSelect().
+			Model(&parent).
+			Column("id", "in_reply_to_id", "thread_id").
+			Where("? = ?", bun.Ident("id"), parentID).
+			Scan(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) {
+			return err
+		}
+
+		// Parent was missing.
+		if parent.ID == "" {
+			continue
+		}
+
+		// Add to slices.
+		sr.append(&parent)
+	}
+
+	// Reset reply slice.
+	clear(sr.inReplyToIDs)
+	sr.inReplyToIDs = sr.inReplyToIDs[:0]
+
+	return nil
+}
+
+// getChildren selects, for every status ID in sr.statusIDs from index idx
+// onwards, the statuses that reply to it and passes each one to append().
+// Children added along the way land in sr.statusIDs themselves, so their
+// own children get selected by the same loop.
+func (sr *statusRethreader) getChildren(ctx context.Context, tx bun.Tx, idx int) error {
+	// Iterate by index and re-check the length each
+	// time round, since append() below grows the
+	// slice being iterated. It skips statuses that
+	// were already recorded, so this terminates.
+	for i := idx; i < len(sr.statusIDs); i++ {
+
+		// Get next status ID.
+		statusID := sr.statusIDs[i]
+
+		// Reset child slice.
+		clear(sr.statuses)
+		sr.statuses = sr.statuses[:0]
+
+		// Select children of ID.
+		if err := tx.NewSelect().
+			Model(&sr.statuses).
+			Column("id", "thread_id").
+			Where("? = ?", bun.Ident("in_reply_to_id"), statusID).
+			Scan(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) {
+			return err
+		}
+
+		// Append child status IDs to slices.
+		for _, child := range sr.statuses {
+			sr.append(child)
+		}
+	}
+
+	return nil
+}
+
+// getStragglers selects statuses whose thread ID is one of sr.threadIDs from
+// index idx onwards, and whose ID is not yet in sr.statusIDs, passing each one
+// to append(). It does nothing when there are no thread IDs at or after idx.
+func (sr *statusRethreader) getStragglers(ctx context.Context, tx bun.Tx, idx int) error {
+	// Check for threads to query.
+	if idx >= len(sr.threadIDs) {
+		return nil
+	}
+
+	// Reset status slice.
+	clear(sr.statuses)
+	sr.statuses = sr.statuses[:0]
+
+	// Select statuses in the given threads
+	// that we are not already tracking.
+	if err := tx.NewSelect().
+		Model(&sr.statuses).
+		Column("id", "thread_id", "in_reply_to_id").
+		Where("? IN (?) AND ? NOT IN (?)",
+			bun.Ident("thread_id"),
+			bun.In(sr.threadIDs[idx:]),
+			bun.Ident("id"),
+			bun.In(sr.statusIDs),
+		).
+		Scan(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) {
+		return err
+	}
+
+	// Append status IDs to slices.
+	for _, status := range sr.statuses {
+		sr.append(status)
+	}
+
+	return nil
+}