diff options
Diffstat (limited to 'internal/db/bundb/timeline.go')
-rw-r--r-- | internal/db/bundb/timeline.go | 136 |
1 files changed, 75 insertions, 61 deletions
diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index ca5922532..3c0d6d7e4 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -20,55 +20,52 @@ package bundb import ( "context" - "database/sql" - "sort" + "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/uptrace/bun" + "golang.org/x/exp/slices" ) type timelineDB struct { - conn *DBConn + conn *DBConn + status *statusDB } func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { - // Ensure reasonable - if limit < 0 { - limit = 0 - } - // Make educated guess for slice size - statuses := make([]*gtsmodel.Status, 0, limit) + statusIDs := make([]string, 0, limit) q := t.conn. NewSelect(). - Model(&statuses) + Table("statuses"). - q = q.ColumnExpr("status.*"). + // Select only IDs from table + Column("statuses.id"). // Find out who accountID follows. - Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id"). + Join("LEFT JOIN follows ON follows.target_account_id = statuses.account_id AND follows.account_id = ?", accountID). // Sort by highest ID (newest) to lowest ID (oldest) - Order("status.id DESC") + Order("statuses.id DESC") if maxID != "" { // return only statuses LOWER (ie., older) than maxID - q = q.Where("status.id < ?", maxID) + q = q.Where("statuses.id < ?", maxID) } if sinceID != "" { // return only statuses HIGHER (ie., newer) than sinceID - q = q.Where("status.id > ?", sinceID) + q = q.Where("statuses.id > ?", sinceID) } if minID != "" { // return only statuses HIGHER (ie., newer) than minID - q = q.Where("status.id > ?", minID) + q = q.Where("statuses.id > ?", minID) } if local { // return only statuses posted by local account havers - q = q.Where("status.local = ?", local) + q = q.Where("statuses.local = ?", local) } if limit > 0 { @@ -83,15 +80,30 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI // See: https://bun.uptrace.dev/guide/queries.html#select whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { return q. - WhereOr("f.account_id = ?", accountID). - WhereOr("status.account_id = ?", accountID) + WhereOr("follows.account_id = ?", accountID). + WhereOr("statuses.account_id = ?", accountID) } q = q.WhereGroup(" AND ", whereGroup) - if err := q.Scan(ctx); err != nil { + if err := q.Scan(ctx, &statusIDs); err != nil { return nil, t.conn.ProcessError(err) } + + statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) + + for _, id := range statusIDs { + // Fetch status from db for ID + status, err := t.status.GetStatusByID(ctx, id) + if err != nil { + logrus.Errorf("GetHomeTimeline: error fetching status %q: %v", id, err) + continue + } + + // Append status to slice + statuses = append(statuses, status) + } + return statuses, nil } @@ -102,40 +114,56 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma } // Make educated guess for slice size - statuses := make([]*gtsmodel.Status, 0, limit) + statusIDs := make([]string, 0, limit) q := t.conn. NewSelect(). - Model(&statuses). - Where("visibility = ?", gtsmodel.VisibilityPublic). - WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_id")). - WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")). - WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")). - Order("status.id DESC") + Table("statuses"). + Column("statuses.id"). + Where("statuses.visibility = ?", gtsmodel.VisibilityPublic). + WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_id")). + WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_uri")). + WhereGroup(" AND ", whereEmptyOrNull("statuses.boost_of_id")). + Order("statuses.id DESC") if maxID != "" { - q = q.Where("status.id < ?", maxID) + q = q.Where("statuses.id < ?", maxID) } if sinceID != "" { - q = q.Where("status.id > ?", sinceID) + q = q.Where("statuses.id > ?", sinceID) } if minID != "" { - q = q.Where("status.id > ?", minID) + q = q.Where("statuses.id > ?", minID) } if local { - q = q.Where("status.local = ?", local) + q = q.Where("statuses.local = ?", local) } if limit > 0 { q = q.Limit(limit) } - if err := q.Scan(ctx); err != nil { + if err := q.Scan(ctx, &statusIDs); err != nil { return nil, t.conn.ProcessError(err) } + + statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) + + for _, id := range statusIDs { + // Fetch status from db for ID + status, err := t.status.GetStatusByID(ctx, id) + if err != nil { + logrus.Errorf("GetPublicTimeline: error fetching status %q: %v", id, err) + continue + } + + // Append status to slice + statuses = append(statuses, status) + } + return statuses, nil } @@ -170,46 +198,32 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max err := fq.Scan(ctx) if err != nil { - if err == sql.ErrNoRows { - return nil, "", "", db.ErrNoEntries - } - return nil, "", "", err + return nil, "", "", t.conn.ProcessError(err) } if len(faves) == 0 { return nil, "", "", db.ErrNoEntries } - // map[statusID]faveID -- we need this to sort statuses by fave ID rather than status ID - statusesFavesMap := make(map[string]string, len(faves)) - statusIDs := make([]string, 0, len(faves)) - for _, f := range faves { - statusesFavesMap[f.StatusID] = f.ID - statusIDs = append(statusIDs, f.StatusID) - } + // Sort by favourite ID rather than status ID + slices.SortFunc(faves, func(a, b *gtsmodel.StatusFave) bool { + return a.ID < b.ID + }) - statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) + statuses := make([]*gtsmodel.Status, 0, len(faves)) - err = t.conn. - NewSelect(). - Model(&statuses). - Where("id IN (?)", bun.In(statusIDs)). - Scan(ctx) - if err != nil { - return nil, "", "", t.conn.ProcessError(err) - } + for _, fave := range faves { + // Fetch status from db for corresponding favourite + status, err := t.status.GetStatusByID(ctx, fave.StatusID) + if err != nil { + logrus.Errorf("GetFavedTimeline: error fetching status for fave %q: %v", fave.ID, err) + continue + } - if len(statuses) == 0 { - return nil, "", "", db.ErrNoEntries + // Append status to slice + statuses = append(statuses, status) } - // arrange statuses by fave ID - sort.Slice(statuses, func(i int, j int) bool { - statusI := statuses[i] - statusJ := statuses[j] - return statusesFavesMap[statusI.ID] < statusesFavesMap[statusJ.ID] - }) - nextMaxID := faves[len(faves)-1].ID prevMinID := faves[0].ID return statuses, nextMaxID, prevMinID, nil |