summaryrefslogtreecommitdiff
path: root/internal/db/bundb/statusfave.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/statusfave.go')
-rw-r--r--internal/db/bundb/statusfave.go216
1 files changed, 106 insertions, 110 deletions
diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go
index 7aff543fd..ab09fb1ba 100644
--- a/internal/db/bundb/statusfave.go
+++ b/internal/db/bundb/statusfave.go
@@ -19,6 +19,7 @@ package bundb
import (
"context"
+ "database/sql"
"errors"
"fmt"
@@ -44,8 +45,14 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, stat
return s.db.
NewSelect().
Model(fave).
- Where("? = ?", bun.Ident("account_id"), accountID).
- Where("? = ?", bun.Ident("status_id"), statusID).
+ Where("status_fave.account_id = ?", accountID).
+ Where("status_fave.status_id = ?", statusID).
+
+ // Our old code actually allowed a status to
+ // be faved multiple times by the same author,
+ // so limit our query + order to fetch latest.
+ Order("status_fave.id DESC"). // our IDs are timestamped
+ Limit(1).
Scan(ctx)
},
accountID,
@@ -89,63 +96,68 @@ func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery
return fave, nil
}
- // Fetch the status fave author account.
- fave.Account, err = s.state.DB.GetAccountByID(
- gtscontext.SetBarebones(ctx),
- fave.AccountID,
- )
- if err != nil {
- return nil, fmt.Errorf("error getting status fave account %q: %w", fave.AccountID, err)
- }
-
- // Fetch the status fave target account.
- fave.TargetAccount, err = s.state.DB.GetAccountByID(
- gtscontext.SetBarebones(ctx),
- fave.TargetAccountID,
- )
- if err != nil {
- return nil, fmt.Errorf("error getting status fave target account %q: %w", fave.TargetAccountID, err)
- }
-
- // Fetch the status fave target status.
- fave.Status, err = s.state.DB.GetStatusByID(
- gtscontext.SetBarebones(ctx),
- fave.StatusID,
- )
- if err != nil {
- return nil, fmt.Errorf("error getting status fave status %q: %w", fave.StatusID, err)
+ // Populate the status favourite model.
+ if err := s.PopulateStatusFave(ctx, fave); err != nil {
+ return nil, fmt.Errorf("error(s) populating status fave: %w", err)
}
return fave, nil
}
-func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) {
- ids := []string{}
-
- if err := s.db.
- NewSelect().
- Table("status_faves").
- Column("id").
- Where("? = ?", bun.Ident("status_id"), statusID).
- Scan(ctx, &ids); err != nil {
- return nil, s.db.ProcessError(err)
+func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) {
+ // Fetch the status fave IDs for status.
+ faveIDs, err := s.getStatusFaveIDs(ctx, statusID)
+ if err != nil {
+ return nil, err
}
- faves := make([]*gtsmodel.StatusFave, 0, len(ids))
+ // Preallocate a slice of expected status fave capacity.
+ faves := make([]*gtsmodel.StatusFave, 0, len(faveIDs))
- for _, id := range ids {
+ for _, id := range faveIDs {
+ // Fetch status fave model for each ID.
fave, err := s.GetStatusFaveByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting status fave %q: %v", id, err)
continue
}
-
faves = append(faves, fave)
}
return faves, nil
}
+func (s *statusFaveDB) IsStatusFavedBy(ctx context.Context, statusID string, accountID string) (bool, error) {
+ fave, err := s.GetStatusFave(ctx, accountID, statusID)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return false, err
+ }
+ return (fave != nil), nil
+}
+
+func (s *statusFaveDB) CountStatusFaves(ctx context.Context, statusID string) (int, error) {
+ faveIDs, err := s.getStatusFaveIDs(ctx, statusID)
+ return len(faveIDs), err
+}
+
+func (s *statusFaveDB) getStatusFaveIDs(ctx context.Context, statusID string) ([]string, error) {
+ return s.state.Caches.GTS.StatusFaveIDs().Load(statusID, func() ([]string, error) {
+ var faveIDs []string
+
+ // Status fave IDs not in cache, perform DB query!
+ if err := s.db.
+ NewSelect().
+ Table("status_faves").
+ Column("id").
+ Where("? = ?", bun.Ident("status_id"), statusID).
+ Scan(ctx, &faveIDs); err != nil {
+ return nil, s.db.ProcessError(err)
+ }
+
+ return faveIDs, nil
+ })
+}
+
func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error {
var (
err error
@@ -203,26 +215,32 @@ func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusF
}
func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) error {
- defer s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+ var statusID string
- // Load fave into cache before attempting a delete,
- // as we need it cached in order to trigger the invalidate
- // callback. This in turn invalidates others.
- _, err := s.GetStatusFaveByID(gtscontext.SetBarebones(ctx), id)
- if err != nil {
- if errors.Is(err, db.ErrNoEntries) {
- // not an issue.
+ // Perform DELETE on status fave,
+ // returning the status ID it was for.
+ if _, err := s.db.NewDelete().
+ Table("status_faves").
+ Where("id = ?", id).
+ Returning("status_id").
+ Exec(ctx, &statusID); err != nil {
+ if err == sql.ErrNoRows {
+ // Not an issue, only due
+ // to us doing a RETURNING.
err = nil
}
- return err
+ return s.db.ProcessError(err)
}
- // Finally delete fave from DB.
- _, err = s.db.NewDelete().
- Table("status_faves").
- Where("? = ?", bun.Ident("id"), id).
- Exec(ctx)
- return s.db.ProcessError(err)
+ if statusID != "" {
+ // Invalidate any cached status faves for this status.
+ s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+
+ // Invalidate any cached status fave IDs for this status.
+ s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID)
+ }
+
+ return nil
}
func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) error {
@@ -230,12 +248,13 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set")
}
- var faveIDs []string
+ var statusIDs []string
- q := s.db.
- NewSelect().
- Column("id").
- Table("status_faves")
+ // Prepare DELETE query returning
+ // the deleted faves for status IDs.
+ q := s.db.NewDelete().
+ Table("status_faves").
+ Returning("status_id")
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID)
@@ -245,69 +264,46 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
q = q.Where("? = ?", bun.Ident("account_id"), originAccountID)
}
- if _, err := q.Exec(ctx, &faveIDs); err != nil {
+ // Execute query, store favourited status IDs.
+ if _, err := q.Exec(ctx, &statusIDs); err != nil {
+ if err == sql.ErrNoRows {
+ // Not an issue, only due
+ // to us doing a RETURNING.
+ err = nil
+ }
return s.db.ProcessError(err)
}
- defer func() {
- // Invalidate all IDs on return.
- for _, id := range faveIDs {
- s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
- }
- }()
+ // Collate (deduplicating) status IDs.
+ statusIDs = collate(func(i int) string {
+ return statusIDs[i]
+ }, len(statusIDs))
- // Load all faves into cache, this *really* isn't great
- // but it is the only way we can ensure we invalidate all
- // related caches correctly (e.g. visibility).
- for _, id := range faveIDs {
- _, err := s.GetStatusFaveByID(ctx, id)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return err
- }
+ for _, id := range statusIDs {
+ // Invalidate any cached status faves for this status.
+ s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+
+ // Invalidate any cached status fave IDs for this status.
+ s.state.Caches.GTS.StatusFaveIDs().Invalidate(id)
}
- // Finally delete all from DB.
- _, err := s.db.NewDelete().
- Table("status_faves").
- Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)).
- Exec(ctx)
- return s.db.ProcessError(err)
+ return nil
}
func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) error {
- // Capture fave IDs in a RETURNING statement.
- var faveIDs []string
-
- q := s.db.
- NewSelect().
- Column("id").
+ // Delete all status faves for status.
+ if _, err := s.db.NewDelete().
Table("status_faves").
- Where("? = ?", bun.Ident("status_id"), statusID)
- if _, err := q.Exec(ctx, &faveIDs); err != nil {
+ Where("status_id = ?", statusID).
+ Exec(ctx); err != nil {
return s.db.ProcessError(err)
}
- defer func() {
- // Invalidate all IDs on return.
- for _, id := range faveIDs {
- s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
- }
- }()
+ // Invalidate any cached status faves for this status.
+ s.state.Caches.GTS.StatusFave().Invalidate("ID", statusID)
- // Load all faves into cache, this *really* isn't great
- // but it is the only way we can ensure we invalidate all
- // related caches correctly (e.g. visibility).
- for _, id := range faveIDs {
- _, err := s.GetStatusFaveByID(ctx, id)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return err
- }
- }
+ // Invalidate any cached status fave IDs for this status.
+ s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID)
- // Finally delete all from DB.
- _, err := s.db.NewDelete().
- Table("status_faves").
- Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)).
- Exec(ctx)
- return s.db.ProcessError(err)
+ return nil
}