summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/account.go120
-rw-r--r--internal/db/bundb/account_test.go28
2 files changed, 114 insertions, 34 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index f7e243f47..d8aee80f4 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -29,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
@@ -475,48 +476,41 @@ func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (i
}
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, db.Error) {
- statusIDs := []string{}
+ // Ensure reasonable
+ if limit < 0 {
+ limit = 0
+ }
+
+ // Make educated guess for slice size
+ var (
+ statusIDs = make([]string, 0, limit)
+ frontToBack = true
+ )
q := a.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ // Select only IDs from table
Column("status.id").
- Order("status.id DESC")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("status.account_id"), accountID)
- }
-
- if limit != 0 {
- q = q.Limit(limit)
- }
+ Where("? = ?", bun.Ident("status.account_id"), accountID)
if excludeReplies {
- // include self-replies (threads)
- whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
+ q = q.WhereGroup(" AND ", func(*bun.SelectQuery) *bun.SelectQuery {
return q.
- WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID).
- WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri"))
- }
-
- q = q.WhereGroup(" AND ", whereGroup)
+ // Do include self replies (threads), but
+ // don't include replies to other people.
+ Where("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID).
+ WhereOr("? IS NULL", bun.Ident("status.in_reply_to_uri"))
+ })
}
if excludeReblogs {
- q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id"))
- }
-
- if maxID != "" {
- q = q.Where("? < ?", bun.Ident("status.id"), maxID)
- }
-
- if minID != "" {
- q = q.Where("? > ?", bun.Ident("status.id"), minID)
+ q = q.Where("? IS NULL", bun.Ident("status.boost_of_id"))
}
if mediaOnly {
- // attachments are stored as a json object;
- // this implementation differs between sqlite and postgres,
+ // Attachments are stored as a json object; this
+ // implementation differs between SQLite and Postgres,
// so we have to be thorough to cover all eventualities
q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
switch a.conn.Dialect().Name() {
@@ -542,10 +536,46 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic)
}
+ // return only statuses LOWER (ie., older) than maxID
+ if maxID == "" {
+ maxID = id.Highest
+ }
+ q = q.Where("? < ?", bun.Ident("status.id"), maxID)
+
+ if minID != "" {
+ // return only statuses HIGHER (ie., newer) than minID
+ q = q.Where("? > ?", bun.Ident("status.id"), minID)
+
+ // page up
+ frontToBack = false
+ }
+
+ if limit > 0 {
+ // limit amount of statuses returned
+ q = q.Limit(limit)
+ }
+
+ if frontToBack {
+ // Page down.
+ q = q.Order("status.id DESC")
+ } else {
+ // Page up.
+ q = q.Order("status.id ASC")
+ }
+
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, a.conn.ProcessError(err)
}
+ // If we're paging up, we still want statuses
+ // to be sorted by ID desc, so reverse ids slice.
+ // https://zchee.github.io/golang-wiki/SliceTricks/#reversing
+ if !frontToBack {
+ for l, r := 0, len(statusIDs)-1; l < r; l, r = l+1, r-1 {
+ statusIDs[l], statusIDs[r] = statusIDs[r], statusIDs[l]
+ }
+ }
+
return a.statusesFromIDs(ctx, statusIDs)
}
@@ -568,23 +598,45 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri
}
func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, db.Error) {
- statusIDs := []string{}
+ // Ensure reasonable
+ if limit < 0 {
+ limit = 0
+ }
+
+ // Make educated guess for slice size
+ statusIDs := make([]string, 0, limit)
q := a.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ // Select only IDs from table
Column("status.id").
Where("? = ?", bun.Ident("status.account_id"), accountID).
- WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
- WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
+ // Don't show replies or boosts.
+ Where("? IS NULL", bun.Ident("status.in_reply_to_uri")).
+ Where("? IS NULL", bun.Ident("status.boost_of_id")).
+ // Only Public statuses.
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
+ // Don't show local-only statuses on the web view.
Where("? = ?", bun.Ident("status.federated"), true)
- if maxID != "" {
- q = q.Where("? < ?", bun.Ident("status.id"), maxID)
+ // return only statuses LOWER (ie., older) than maxID
+ if maxID == "" {
+ maxID = id.Highest
+ }
+ q = q.Where("? < ?", bun.Ident("status.id"), maxID)
+
+ if limit > 0 {
+ // limit amount of statuses returned
+ q = q.Limit(limit)
+ }
+
+ if limit > 0 {
+ // limit amount of statuses returned
+ q = q.Limit(limit)
}
- q = q.Limit(limit).Order("status.id DESC")
+ q = q.Order("status.id DESC")
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, a.conn.ProcessError(err)
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
index 2241ab783..bfe6df536 100644
--- a/internal/db/bundb/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -45,6 +45,34 @@ func (suite *AccountTestSuite) TestGetAccountStatuses() {
suite.Len(statuses, 5)
}
+func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {
+ // get the first page
+ statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, "", "", false, false)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.Len(statuses, 2)
+
+ // get the second page
+ statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.Len(statuses, 2)
+
+ // get the third page
+ statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.Len(statuses, 1)
+
+ // try to get the last page (should be empty)
+ statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false)
+ suite.ErrorIs(err, db.ErrNoEntries)
+ suite.Empty(statuses)
+}
+
func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() {
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false)
suite.NoError(err)