diff options
Diffstat (limited to 'internal/db/bundb/statusfave.go')
-rw-r--r-- | internal/db/bundb/statusfave.go | 216 |
1 files changed, 106 insertions, 110 deletions
diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go index 7aff543fd..ab09fb1ba 100644 --- a/internal/db/bundb/statusfave.go +++ b/internal/db/bundb/statusfave.go @@ -19,6 +19,7 @@ package bundb import ( "context" + "database/sql" "errors" "fmt" @@ -44,8 +45,14 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, stat return s.db. NewSelect(). Model(fave). - Where("? = ?", bun.Ident("account_id"), accountID). - Where("? = ?", bun.Ident("status_id"), statusID). + Where("status_fave.account_id = ?", accountID). + Where("status_fave.status_id = ?", statusID). + + // Our old code actually allowed a status to + // be faved multiple times by the same author, + // so limit our query + order to fetch latest. + Order("status_fave.id DESC"). // our IDs are timestamped + Limit(1). Scan(ctx) }, accountID, @@ -89,63 +96,68 @@ func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery return fave, nil } - // Fetch the status fave author account. - fave.Account, err = s.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - fave.AccountID, - ) - if err != nil { - return nil, fmt.Errorf("error getting status fave account %q: %w", fave.AccountID, err) - } - - // Fetch the status fave target account. - fave.TargetAccount, err = s.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - fave.TargetAccountID, - ) - if err != nil { - return nil, fmt.Errorf("error getting status fave target account %q: %w", fave.TargetAccountID, err) - } - - // Fetch the status fave target status. - fave.Status, err = s.state.DB.GetStatusByID( - gtscontext.SetBarebones(ctx), - fave.StatusID, - ) - if err != nil { - return nil, fmt.Errorf("error getting status fave status %q: %w", fave.StatusID, err) + // Populate the status favourite model. + if err := s.PopulateStatusFave(ctx, fave); err != nil { + return nil, fmt.Errorf("error(s) populating status fave: %w", err) } return fave, nil } -func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) { - ids := []string{} - - if err := s.db. - NewSelect(). - Table("status_faves"). - Column("id"). - Where("? = ?", bun.Ident("status_id"), statusID). - Scan(ctx, &ids); err != nil { - return nil, s.db.ProcessError(err) +func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) { + // Fetch the status fave IDs for status. + faveIDs, err := s.getStatusFaveIDs(ctx, statusID) + if err != nil { + return nil, err } - faves := make([]*gtsmodel.StatusFave, 0, len(ids)) + // Preallocate a slice of expected status fave capacity. + faves := make([]*gtsmodel.StatusFave, 0, len(faveIDs)) - for _, id := range ids { + for _, id := range faveIDs { + // Fetch status fave model for each ID. fave, err := s.GetStatusFaveByID(ctx, id) if err != nil { log.Errorf(ctx, "error getting status fave %q: %v", id, err) continue } - faves = append(faves, fave) } return faves, nil } +func (s *statusFaveDB) IsStatusFavedBy(ctx context.Context, statusID string, accountID string) (bool, error) { + fave, err := s.GetStatusFave(ctx, accountID, statusID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return false, err + } + return (fave != nil), nil +} + +func (s *statusFaveDB) CountStatusFaves(ctx context.Context, statusID string) (int, error) { + faveIDs, err := s.getStatusFaveIDs(ctx, statusID) + return len(faveIDs), err +} + +func (s *statusFaveDB) getStatusFaveIDs(ctx context.Context, statusID string) ([]string, error) { + return s.state.Caches.GTS.StatusFaveIDs().Load(statusID, func() ([]string, error) { + var faveIDs []string + + // Status fave IDs not in cache, perform DB query! + if err := s.db. + NewSelect(). + Table("status_faves"). + Column("id"). + Where("? = ?", bun.Ident("status_id"), statusID). + Scan(ctx, &faveIDs); err != nil { + return nil, s.db.ProcessError(err) + } + + return faveIDs, nil + }) +} + func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error { var ( err error @@ -203,26 +215,32 @@ func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusF } func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) error { - defer s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + var statusID string - // Load fave into cache before attempting a delete, - // as we need it cached in order to trigger the invalidate - // callback. This in turn invalidates others. - _, err := s.GetStatusFaveByID(gtscontext.SetBarebones(ctx), id) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - // not an issue. + // Perform DELETE on status fave, + // returning the status ID it was for. + if _, err := s.db.NewDelete(). + Table("status_faves"). + Where("id = ?", id). + Returning("status_id"). + Exec(ctx, &statusID); err != nil { + if err == sql.ErrNoRows { + // Not an issue, only due + // to us doing a RETURNING. err = nil } - return err + return s.db.ProcessError(err) } - // Finally delete fave from DB. - _, err = s.db.NewDelete(). - Table("status_faves"). - Where("? = ?", bun.Ident("id"), id). - Exec(ctx) - return s.db.ProcessError(err) + if statusID != "" { + // Invalidate any cached status faves for this status. + s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + + // Invalidate any cached status fave IDs for this status. + s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID) + } + + return nil } func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) error { @@ -230,12 +248,13 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set") } - var faveIDs []string + var statusIDs []string - q := s.db. - NewSelect(). - Column("id"). - Table("status_faves") + // Prepare DELETE query returning + // the deleted faves for status IDs. + q := s.db.NewDelete(). + Table("status_faves"). + Returning("status_id") if targetAccountID != "" { q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID) @@ -245,69 +264,46 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st q = q.Where("? = ?", bun.Ident("account_id"), originAccountID) } - if _, err := q.Exec(ctx, &faveIDs); err != nil { + // Execute query, store favourited status IDs. + if _, err := q.Exec(ctx, &statusIDs); err != nil { + if err == sql.ErrNoRows { + // Not an issue, only due + // to us doing a RETURNING. + err = nil + } return s.db.ProcessError(err) } - defer func() { - // Invalidate all IDs on return. - for _, id := range faveIDs { - s.state.Caches.GTS.StatusFave().Invalidate("ID", id) - } - }() + // Collate (deduplicating) status IDs. + statusIDs = collate(func(i int) string { + return statusIDs[i] + }, len(statusIDs)) - // Load all faves into cache, this *really* isn't great - // but it is the only way we can ensure we invalidate all - // related caches correctly (e.g. visibility). - for _, id := range faveIDs { - _, err := s.GetStatusFaveByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } + for _, id := range statusIDs { + // Invalidate any cached status faves for this status. + s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + + // Invalidate any cached status fave IDs for this status. + s.state.Caches.GTS.StatusFaveIDs().Invalidate(id) } - // Finally delete all from DB. - _, err := s.db.NewDelete(). - Table("status_faves"). - Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)). - Exec(ctx) - return s.db.ProcessError(err) + return nil } func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) error { - // Capture fave IDs in a RETURNING statement. - var faveIDs []string - - q := s.db. - NewSelect(). - Column("id"). + // Delete all status faves for status. + if _, err := s.db.NewDelete(). Table("status_faves"). - Where("? = ?", bun.Ident("status_id"), statusID) - if _, err := q.Exec(ctx, &faveIDs); err != nil { + Where("status_id = ?", statusID). + Exec(ctx); err != nil { return s.db.ProcessError(err) } - defer func() { - // Invalidate all IDs on return. - for _, id := range faveIDs { - s.state.Caches.GTS.StatusFave().Invalidate("ID", id) - } - }() + // Invalidate any cached status faves for this status. + s.state.Caches.GTS.StatusFave().Invalidate("ID", statusID) - // Load all faves into cache, this *really* isn't great - // but it is the only way we can ensure we invalidate all - // related caches correctly (e.g. visibility). - for _, id := range faveIDs { - _, err := s.GetStatusFaveByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } - } + // Invalidate any cached status fave IDs for this status. + s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID) - // Finally delete all from DB. - _, err := s.db.NewDelete(). - Table("status_faves"). - Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)). - Exec(ctx) - return s.db.ProcessError(err) + return nil } |