diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/account.go | 115 | ||||
| -rw-r--r-- | internal/db/bundb/account_test.go | 74 | ||||
| -rw-r--r-- | internal/db/bundb/migrations/20240426122821_pageable_admin_accounts.go | 84 | 
3 files changed, 246 insertions, 27 deletions
| diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 2b3c78aff..4e969e0ef 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -252,6 +252,32 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts  	return a.GetAccountByUsernameDomain(ctx, username, domain)  } +// GetAccounts selects accounts using the given parameters. +// Unlike with other functions, the paging for GetAccounts +// is done not by ID, but by a concatenation of `[domain]/@[username]`, +// which allows callers to page through accounts in alphabetical +// order (much more useful for an admin overview of accounts, +// for example, than paging by ID (which is random) or by account +// created at date, which is not particularly interesting). +// +// Generated queries will look something like this +// (SQLite example, maxID was provided so we're paging down): +// +//	SELECT "account"."id", (COALESCE("domain", '') || '/@' || "username") AS "domain_username" +//	FROM "accounts" AS "account" +//	WHERE ("domain_username" > '/@the_mighty_zork') +//	ORDER BY "domain_username" ASC +// +// **NOTE ABOUT POSTGRES**: Postgres ordering expressions in +// this function specify COLLATE "C" to ensure that ordering +// is similar to SQLite (which uses BINARY ordering by default). +// This unfortunately means that A-Z > a-z, when ordering but +// that's an acceptable tradeoff for a query like this. +// +// See: +// +//   - https://www.postgresql.org/docs/current/collation.html#COLLATION-MANAGING-STANDARD +//   - https://sqlite.org/datatype3.html#collation  func (a *accountDB) GetAccounts(  	ctx context.Context,  	origin string, @@ -269,6 +295,11 @@ func (a *accountDB) GetAccounts(  	error,  ) {  	var ( +		// We have to use different +		// syntax for this query +		// depending on dialect. +		dbDialect = a.db.Dialect().Name() +  		// local users lists,  		// required for some  		// limiting parameters. @@ -287,10 +318,6 @@ func (a *accountDB) GetAccounts(  		}  		// Get paging params. -		// -		// Note this may be min_id OR since_id -		// from the API, this gets handled below -		// when checking order to reverse slice.  		minID = page.GetMin()  		maxID = page.GetMax()  		limit = page.GetLimit() @@ -309,32 +336,50 @@ func (a *accountDB) GetAccounts(  		// Select only IDs from table  		Column("account.id") -	// Return only accounts OLDER -	// than account with maxID. -	if maxID != "" { -		maxIDAcct, err := a.GetAccountByID( -			gtscontext.SetBarebones(ctx), -			maxID, +	var subQ *bun.RawQuery +	if dbDialect == dialect.SQLite { +		// For SQLite we can just select +		// our indexed expression once +		// as a column alias. +		q = q.ColumnExpr( +			"(COALESCE(?, ?) || ? || ?) AS ?", +			bun.Ident("domain"), "", +			"/@", +			bun.Ident("username"), +			bun.Ident("domain_username"),  		) -		if err != nil { -			return nil, fmt.Errorf("error getting maxID account %s: %w", maxID, err) -		} +	} else { +		// Create a subquery for +		// Postgres to reuse. +		subQ = a.db.NewRaw( +			"(COALESCE(?, ?) || ? || ?) COLLATE ?", +			bun.Ident("domain"), "", +			"/@", +			bun.Ident("username"), +			bun.Ident("C"), +		) +	} -		q = q.Where("? < ?", bun.Ident("account.created_at"), maxIDAcct.CreatedAt) +	// Return only accounts with `[domain]/@[username]` +	// later in the alphabet (a-z) than provided maxID. +	if maxID != "" { +		if dbDialect == dialect.SQLite { +			// Use aliased column. +			q = q.Where("? > ?", bun.Ident("domain_username"), maxID) +		} else { +			q = q.Where("? > ?", subQ, maxID) +		}  	} -	// Return only accounts NEWER -	// than account with minID. +	// Return only accounts with `[domain]/@[username]` +	// earlier in the alphabet (a-z) than provided minID.  	if minID != "" { -		minIDAcct, err := a.GetAccountByID( -			gtscontext.SetBarebones(ctx), -			minID, -		) -		if err != nil { -			return nil, fmt.Errorf("error getting minID account %s: %w", minID, err) +		if dbDialect == dialect.SQLite { +			// Use aliased column. +			q = q.Where("? < ?", bun.Ident("domain_username"), minID) +		} else { +			q = q.Where("? < ?", subQ, minID)  		} - -		q = q.Where("? > ?", bun.Ident("account.created_at"), minIDAcct.CreatedAt)  	}  	switch status { @@ -479,13 +524,29 @@ func (a *accountDB) GetAccounts(  	if order == paging.OrderAscending {  		// Page up. -		q = q.Order("account.created_at ASC") +		// It's counterintuitive because it +		// says DESC in the query, but we're +		// going backwards in the alphabet, +		// and a < z in a string comparison. +		if dbDialect == dialect.SQLite { +			q = q.OrderExpr("? DESC", bun.Ident("domain_username")) +		} else { +			q = q.OrderExpr("(?) DESC", subQ) +		}  	} else {  		// Page down. -		q = q.Order("account.created_at DESC") +		// It's counterintuitive because it +		// says ASC in the query, but we're +		// going forwards in the alphabet, +		// and z > a in a string comparison. +		if dbDialect == dialect.SQLite { +			q = q.OrderExpr("? ASC", bun.Ident("domain_username")) +		} else { +			q = q.OrderExpr("? ASC", subQ) +		}  	} -	if err := q.Scan(ctx, &accountIDs); err != nil { +	if err := q.Scan(ctx, &accountIDs, new([]string)); err != nil {  		return nil, err  	} diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index ea211e16f..5ed5d91a1 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -502,6 +502,80 @@ func (suite *AccountTestSuite) TestGetAccountsAll() {  	suite.Len(accounts, 9)  } +func (suite *AccountTestSuite) TestGetAccountsMaxID() { +	var ( +		ctx         = context.Background() +		origin      = "" +		status      = "" +		mods        = false +		invitedBy   = "" +		username    = "" +		displayName = "" +		domain      = "" +		email       = "" +		ip          netip.Addr +		// Get accounts with `[domain]/@[username]` +		// later in the alphabet than `/@the_mighty_zork`. +		page = &paging.Page{Max: paging.MaxID("/@the_mighty_zork")} +	) + +	accounts, err := suite.db.GetAccounts( +		ctx, +		origin, +		status, +		mods, +		invitedBy, +		username, +		displayName, +		domain, +		email, +		ip, +		page, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	suite.Len(accounts, 5) +} + +func (suite *AccountTestSuite) TestGetAccountsMinID() { +	var ( +		ctx         = context.Background() +		origin      = "" +		status      = "" +		mods        = false +		invitedBy   = "" +		username    = "" +		displayName = "" +		domain      = "" +		email       = "" +		ip          netip.Addr +		// Get accounts with `[domain]/@[username]` +		// earlier in the alphabet than `/@the_mighty_zork`. +		page = &paging.Page{Min: paging.MinID("/@the_mighty_zork")} +	) + +	accounts, err := suite.db.GetAccounts( +		ctx, +		origin, +		status, +		mods, +		invitedBy, +		username, +		displayName, +		domain, +		email, +		ip, +		page, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	suite.Len(accounts, 3) +} +  func (suite *AccountTestSuite) TestGetAccountsModsOnly() {  	var (  		ctx         = context.Background() diff --git a/internal/db/bundb/migrations/20240426122821_pageable_admin_accounts.go b/internal/db/bundb/migrations/20240426122821_pageable_admin_accounts.go new file mode 100644 index 000000000..00465cc85 --- /dev/null +++ b/internal/db/bundb/migrations/20240426122821_pageable_admin_accounts.go @@ -0,0 +1,84 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/uptrace/bun" +	"github.com/uptrace/bun/dialect" +) + +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		log.Info(ctx, "reindexing accounts (accounts_paging_idx); this may take a few minutes, please don't interrupt this migration!") + +		q := db.NewCreateIndex(). +			TableExpr("accounts"). +			Index("accounts_paging_idx"). +			IfNotExists() + +		switch d := db.Dialect().Name(); d { +		case dialect.SQLite: +			q = q.ColumnExpr( +				"COALESCE(?, ?) || ? || ?", +				bun.Ident("domain"), "", +				"/@", +				bun.Ident("username"), +			) + +		// Specify C collation for Postgres to ensure +		// alphabetic sort order is similar enough to +		// SQLite (which uses BINARY sort by default). +		// +		// See: +		// +		//   - https://www.postgresql.org/docs/current/collation.html#COLLATION-MANAGING-STANDARD +		//   - https://sqlite.org/datatype3.html#collation +		case dialect.PG: +			q = q.ColumnExpr( +				"(COALESCE(?, ?) || ? || ?) COLLATE ?", +				bun.Ident("domain"), "", +				"/@", +				bun.Ident("username"), +				bun.Ident("C"), +			) + +		default: +			log.Panicf(ctx, "dialect %s was neither postgres nor sqlite", d) +		} + +		if _, err := q.Exec(ctx); err != nil { +			return err +		} + +		return nil + +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} | 
