diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/report.go | 45 | ||||
| -rw-r--r-- | internal/db/bundb/report_test.go | 101 | 
2 files changed, 135 insertions, 11 deletions
| diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go index 486bf09f0..f99f0b5cc 100644 --- a/internal/db/bundb/report.go +++ b/internal/db/bundb/report.go @@ -20,6 +20,7 @@ package bundb  import (  	"context"  	"errors" +	"slices"  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -27,6 +28,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/uptrace/bun"  ) @@ -51,14 +53,23 @@ func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Repo  	)  } -func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, error) { -	reportIDs := []string{} +func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, page *paging.Page) ([]*gtsmodel.Report, error) { +	var ( +		// Get paging params. +		minID = page.GetMin() +		maxID = page.GetMax() +		limit = page.GetLimit() +		order = page.GetOrder() + +		// Make educated guess for slice size +		reportIDs = make([]string, 0, limit) +	)  	q := r.db.  		NewSelect().  		TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")). -		Column("report.id"). -		Order("report.id DESC") +		// Select only IDs from table. +		Column("report.id")  	if resolved != nil {  		i := bun.Ident("report.action_taken_by_account_id") @@ -77,22 +88,32 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str  		q = q.Where("? = ?", bun.Ident("report.target_account_id"), targetAccountID)  	} +	// Return only reports with id +	// lower than provided maxID.  	if maxID != "" {  		q = q.Where("? < ?", bun.Ident("report.id"), maxID)  	} -	if sinceID != "" { -		q = q.Where("? > ?", bun.Ident("report.id"), minID) -	} - +	// Return only reports with id +	// greater than provided minID.  	if minID != "" {  		q = q.Where("? > ?", bun.Ident("report.id"), minID)  	} -	if limit != 0 { +	if limit > 0 { +		// Limit amount of +		// reports returned.  		q = q.Limit(limit)  	} +	if order == paging.OrderAscending { +		// Page up. +		q = q.OrderExpr("? ASC", bun.Ident("report.id")) +	} else { +		// Page down. +		q = q.OrderExpr("? DESC", bun.Ident("report.id")) +	} +  	if err := q.Scan(ctx, &reportIDs); err != nil {  		return nil, err  	} @@ -102,6 +123,12 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str  		return nil, db.ErrNoEntries  	} +	// If we're paging up, we still want reports +	// to be sorted by ID desc, so reverse ids slice. +	if order == paging.OrderAscending { +		slices.Reverse(reportIDs) +	} +  	// Allocate return slice (will be at most len reportIDs)  	reports := make([]*gtsmodel.Report, 0, len(reportIDs))  	for _, id := range reportIDs { diff --git a/internal/db/bundb/report_test.go b/internal/db/bundb/report_test.go index 594b0b7aa..1a488c729 100644 --- a/internal/db/bundb/report_test.go +++ b/internal/db/bundb/report_test.go @@ -24,6 +24,8 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/id" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -61,14 +63,109 @@ func (suite *ReportTestSuite) TestGetReportByURI() {  }  func (suite *ReportTestSuite) TestGetAllReports() { -	reports, err := suite.db.GetReports(context.Background(), nil, "", "", "", "", "", 0) +	reports, err := suite.db.GetReports( +		context.Background(), +		nil, +		"", +		"", +		&paging.Page{}, +	)  	suite.NoError(err)  	suite.NotEmpty(reports)  } +func (suite *ReportTestSuite) TestReportPagingDown() { +	// Get one from the top. +	reports1, err := suite.db.GetReports( +		context.Background(), +		nil, +		"", +		"", +		&paging.Page{ +			Limit: 1, +		}, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	if l := len(reports1); l != 1 { +		suite.FailNowf("", "expected reports len 1, got %d", l) +	} +	id1 := reports1[0].ID + +	// Use this one to page down. +	reports2, err := suite.db.GetReports( +		context.Background(), +		nil, +		"", +		"", +		&paging.Page{ +			Limit: 1, +			Max:   paging.MaxID(id1), +		}, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	if l := len(reports2); l != 1 { +		suite.FailNowf("", "expected reports len 1, got %d", l) +	} +	id2 := reports2[0].ID + +	suite.Greater(id1, id2) +} + +func (suite *ReportTestSuite) TestReportPagingUp() { +	// Get one from the bottom. +	reports1, err := suite.db.GetReports( +		context.Background(), +		nil, +		"", +		"", +		&paging.Page{ +			Limit: 1, +			Min:   paging.MinID(id.Lowest), +		}, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	if l := len(reports1); l != 1 { +		suite.FailNowf("", "expected reports len 1, got %d", l) +	} +	id1 := reports1[0].ID + +	// Use this one to page up. +	reports2, err := suite.db.GetReports( +		context.Background(), +		nil, +		"", +		"", +		&paging.Page{ +			Limit: 1, +			Min:   paging.MinID(id1), +		}, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	if l := len(reports2); l != 1 { +		suite.FailNowf("", "expected reports len 1, got %d", l) +	} +	id2 := reports2[0].ID + +	suite.Less(id1, id2) +} +  func (suite *ReportTestSuite) TestGetAllReportsByAccountID() {  	accountID := suite.testAccounts["local_account_2"].ID -	reports, err := suite.db.GetReports(context.Background(), nil, accountID, "", "", "", "", 0) +	reports, err := suite.db.GetReports( +		context.Background(), +		nil, +		accountID, +		"", +		&paging.Page{}, +	)  	suite.NoError(err)  	suite.NotEmpty(reports)  	for _, r := range reports { | 
