diff options
Diffstat (limited to 'internal/db/bundb/status.go')
-rw-r--r-- | internal/db/bundb/status.go | 158 |
1 files changed, 92 insertions, 66 deletions
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 25b773dfa..c6091e2c9 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -20,7 +20,6 @@ package bundb import ( "container/list" "context" - "database/sql" "errors" "time" @@ -96,6 +95,26 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St ) } +func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) { + return s.getStatus( + ctx, + "BoostOfID.AccountID", + func(status *gtsmodel.Status) error { + return s.newStatusQ(status). + Where("status.boost_of_id = ?", boostOfID). + Where("status.account_id = ?", byAccountID). + + // Our old code actually allowed a status to + // be boosted multiple times by the same author, + // so limit our query + order to fetch latest. + Order("status.id DESC"). // our IDs are timestamped + Limit(1). + Scan(ctx) + }, + boostOfID, byAccountID, + ) +} + func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) { // Fetch status from database cache with loader callback status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) { @@ -245,11 +264,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) } } - if err := errs.Combine(); err != nil { - return gtserror.Newf("%w", err) - } - - return nil + return errs.Combine() } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error { @@ -506,25 +521,17 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu } func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - var childIDs []string - - q := s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Column("status.id"). - Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID) - if minID != "" { - q = q.Where("? > ?", bun.Ident("status.id"), minID) - } - - if err := q.Scan(ctx, &childIDs); err != nil { - if err != sql.ErrNoRows { - log.Errorf(ctx, "error getting children for %q: %v", status.ID, err) - } + childIDs, err := s.getStatusReplyIDs(ctx, status.ID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + log.Errorf(ctx, "error getting status %s children: %v", status.ID, err) return } for _, id := range childIDs { + if id <= minID { + continue + } + // Fetch child with ID from database child, err := s.GetStatusByID(ctx, id) if err != nil { @@ -553,48 +560,80 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, } } -func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, error) { - return s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID). - Count(ctx) +func (s *statusDB) GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) { + statusIDs, err := s.getStatusReplyIDs(ctx, statusID) + if err != nil { + return nil, err + } + return s.GetStatusesByIDs(ctx, statusIDs) } -func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, error) { - return s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). - Count(ctx) +func (s *statusDB) CountStatusReplies(ctx context.Context, statusID string) (int, error) { + statusIDs, err := s.getStatusReplyIDs(ctx, statusID) + return len(statusIDs), err } -func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, error) { - return s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). - Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). - Count(ctx) +func (s *statusDB) getStatusReplyIDs(ctx context.Context, statusID string) ([]string, error) { + return s.state.Caches.GTS.InReplyToIDs().Load(statusID, func() ([]string, error) { + var statusIDs []string + + // Status reply IDs not in cache, perform DB query! + if err := s.db. + NewSelect(). + Table("statuses"). + Column("id"). + Where("? = ?", bun.Ident("in_reply_to_id"), statusID). + Order("id DESC"). + Scan(ctx, &statusIDs); err != nil { + return nil, s.db.ProcessError(err) + } + + return statusIDs, nil + }) } -func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) { - q := s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). - Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). - Where("? = ?", bun.Ident("status_fave.account_id"), accountID) +func (s *statusDB) GetStatusBoosts(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) { + statusIDs, err := s.getStatusBoostIDs(ctx, statusID) + if err != nil { + return nil, err + } + return s.GetStatusesByIDs(ctx, statusIDs) +} - return s.db.Exists(ctx, q) +func (s *statusDB) IsStatusBoostedBy(ctx context.Context, statusID string, accountID string) (bool, error) { + boost, err := s.GetStatusBoost( + gtscontext.SetBarebones(ctx), + statusID, + accountID, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return false, err + } + return (boost != nil), nil } -func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) { - q := s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). - Where("? = ?", bun.Ident("status.account_id"), accountID) +func (s *statusDB) CountStatusBoosts(ctx context.Context, statusID string) (int, error) { + statusIDs, err := s.getStatusBoostIDs(ctx, statusID) + return len(statusIDs), err +} - return s.db.Exists(ctx, q) +func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]string, error) { + return s.state.Caches.GTS.BoostOfIDs().Load(statusID, func() ([]string, error) { + var statusIDs []string + + // Status boost IDs not in cache, perform DB query! + if err := s.db. + NewSelect(). + Table("statuses"). + Column("id"). + Where("? = ?", bun.Ident("boost_of_id"), statusID). + Order("id DESC"). + Scan(ctx, &statusIDs); err != nil { + return nil, s.db.ProcessError(err) + } + + return statusIDs, nil + }) } func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) { @@ -616,16 +655,3 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St return s.db.Exists(ctx, q) } - -func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) { - reblogs := []*gtsmodel.Status{} - - q := s. - newStatusQ(&reblogs). - Where("? = ?", bun.Ident("status.boost_of_id"), status.ID) - - if err := q.Scan(ctx); err != nil { - return nil, s.db.ProcessError(err) - } - return reblogs, nil -} |