diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/account.go | 120 | ||||
| -rw-r--r-- | internal/db/bundb/account_test.go | 28 | 
2 files changed, 114 insertions, 34 deletions
| diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index f7e243f47..d8aee80f4 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -29,6 +29,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/id"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/util" @@ -475,48 +476,41 @@ func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (i  }  func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, db.Error) { -	statusIDs := []string{} +	// Ensure reasonable +	if limit < 0 { +		limit = 0 +	} + +	// Make educated guess for slice size +	var ( +		statusIDs   = make([]string, 0, limit) +		frontToBack = true +	)  	q := a.conn.  		NewSelect().  		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +		// Select only IDs from table  		Column("status.id"). -		Order("status.id DESC") - -	if accountID != "" { -		q = q.Where("? = ?", bun.Ident("status.account_id"), accountID) -	} - -	if limit != 0 { -		q = q.Limit(limit) -	} +		Where("? = ?", bun.Ident("status.account_id"), accountID)  	if excludeReplies { -		// include self-replies (threads) -		whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { +		q = q.WhereGroup(" AND ", func(*bun.SelectQuery) *bun.SelectQuery {  			return q. -				WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID). -				WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri")) -		} - -		q = q.WhereGroup(" AND ", whereGroup) +				// Do include self replies (threads), but +				// don't include replies to other people. +				Where("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID). +				WhereOr("? IS NULL", bun.Ident("status.in_reply_to_uri")) +		})  	}  	if excludeReblogs { -		q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")) -	} - -	if maxID != "" { -		q = q.Where("? < ?", bun.Ident("status.id"), maxID) -	} - -	if minID != "" { -		q = q.Where("? > ?", bun.Ident("status.id"), minID) +		q = q.Where("? IS NULL", bun.Ident("status.boost_of_id"))  	}  	if mediaOnly { -		// attachments are stored as a json object; -		// this implementation differs between sqlite and postgres, +		// Attachments are stored as a json object; this +		// implementation differs between SQLite and Postgres,  		// so we have to be thorough to cover all eventualities  		q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {  			switch a.conn.Dialect().Name() { @@ -542,10 +536,46 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li  		q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic)  	} +	// return only statuses LOWER (ie., older) than maxID +	if maxID == "" { +		maxID = id.Highest +	} +	q = q.Where("? < ?", bun.Ident("status.id"), maxID) + +	if minID != "" { +		// return only statuses HIGHER (ie., newer) than minID +		q = q.Where("? > ?", bun.Ident("status.id"), minID) + +		// page up +		frontToBack = false +	} + +	if limit > 0 { +		// limit amount of statuses returned +		q = q.Limit(limit) +	} + +	if frontToBack { +		// Page down. +		q = q.Order("status.id DESC") +	} else { +		// Page up. +		q = q.Order("status.id ASC") +	} +  	if err := q.Scan(ctx, &statusIDs); err != nil {  		return nil, a.conn.ProcessError(err)  	} +	// If we're paging up, we still want statuses +	// to be sorted by ID desc, so reverse ids slice. +	// https://zchee.github.io/golang-wiki/SliceTricks/#reversing +	if !frontToBack { +		for l, r := 0, len(statusIDs)-1; l < r; l, r = l+1, r-1 { +			statusIDs[l], statusIDs[r] = statusIDs[r], statusIDs[l] +		} +	} +  	return a.statusesFromIDs(ctx, statusIDs)  } @@ -568,23 +598,45 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri  }  func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, db.Error) { -	statusIDs := []string{} +	// Ensure reasonable +	if limit < 0 { +		limit = 0 +	} + +	// Make educated guess for slice size +	statusIDs := make([]string, 0, limit)  	q := a.conn.  		NewSelect().  		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +		// Select only IDs from table  		Column("status.id").  		Where("? = ?", bun.Ident("status.account_id"), accountID). -		WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). -		WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). +		// Don't show replies or boosts. +		Where("? IS NULL", bun.Ident("status.in_reply_to_uri")). +		Where("? IS NULL", bun.Ident("status.boost_of_id")). +		// Only Public statuses.  		Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). +		// Don't show local-only statuses on the web view.  		Where("? = ?", bun.Ident("status.federated"), true) -	if maxID != "" { -		q = q.Where("? < ?", bun.Ident("status.id"), maxID) +	// return only statuses LOWER (ie., older) than maxID +	if maxID == "" { +		maxID = id.Highest +	} +	q = q.Where("? < ?", bun.Ident("status.id"), maxID) + +	if limit > 0 { +		// limit amount of statuses returned +		q = q.Limit(limit) +	} + +	if limit > 0 { +		// limit amount of statuses returned +		q = q.Limit(limit)  	} -	q = q.Limit(limit).Order("status.id DESC") +	q = q.Order("status.id DESC")  	if err := q.Scan(ctx, &statusIDs); err != nil {  		return nil, a.conn.ProcessError(err) diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 2241ab783..bfe6df536 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -45,6 +45,34 @@ func (suite *AccountTestSuite) TestGetAccountStatuses() {  	suite.Len(statuses, 5)  } +func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() { +	// get the first page +	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, "", "", false, false) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	suite.Len(statuses, 2) + +	// get the second page +	statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	suite.Len(statuses, 2) + +	// get the third page +	statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	suite.Len(statuses, 1) + +	// try to get the last page (should be empty) +	statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false) +	suite.ErrorIs(err, db.ErrNoEntries) +	suite.Empty(statuses) +} +  func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() {  	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false)  	suite.NoError(err) | 
