diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/account.go | 115 | ||||
-rw-r--r-- | internal/db/bundb/account_test.go | 74 | ||||
-rw-r--r-- | internal/db/bundb/migrations/20240426122821_pageable_admin_accounts.go | 84 |
3 files changed, 246 insertions, 27 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 2b3c78aff..4e969e0ef 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -252,6 +252,32 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts return a.GetAccountByUsernameDomain(ctx, username, domain) } +// GetAccounts selects accounts using the given parameters. +// Unlike with other functions, the paging for GetAccounts +// is done not by ID, but by a concatenation of `[domain]/@[username]`, +// which allows callers to page through accounts in alphabetical +// order (much more useful for an admin overview of accounts, +// for example, than paging by ID (which is random) or by account +// created at date, which is not particularly interesting). +// +// Generated queries will look something like this +// (SQLite example, maxID was provided so we're paging down): +// +// SELECT "account"."id", (COALESCE("domain", '') || '/@' || "username") AS "domain_username" +// FROM "accounts" AS "account" +// WHERE ("domain_username" > '/@the_mighty_zork') +// ORDER BY "domain_username" ASC +// +// **NOTE ABOUT POSTGRES**: Postgres ordering expressions in +// this function specify COLLATE "C" to ensure that ordering +// is similar to SQLite (which uses BINARY ordering by default). +// This unfortunately means that A-Z > a-z, when ordering but +// that's an acceptable tradeoff for a query like this. +// +// See: +// +// - https://www.postgresql.org/docs/current/collation.html#COLLATION-MANAGING-STANDARD +// - https://sqlite.org/datatype3.html#collation func (a *accountDB) GetAccounts( ctx context.Context, origin string, @@ -269,6 +295,11 @@ func (a *accountDB) GetAccounts( error, ) { var ( + // We have to use different + // syntax for this query + // depending on dialect. + dbDialect = a.db.Dialect().Name() + // local users lists, // required for some // limiting parameters. @@ -287,10 +318,6 @@ func (a *accountDB) GetAccounts( } // Get paging params. - // - // Note this may be min_id OR since_id - // from the API, this gets handled below - // when checking order to reverse slice. minID = page.GetMin() maxID = page.GetMax() limit = page.GetLimit() @@ -309,32 +336,50 @@ func (a *accountDB) GetAccounts( // Select only IDs from table Column("account.id") - // Return only accounts OLDER - // than account with maxID. - if maxID != "" { - maxIDAcct, err := a.GetAccountByID( - gtscontext.SetBarebones(ctx), - maxID, + var subQ *bun.RawQuery + if dbDialect == dialect.SQLite { + // For SQLite we can just select + // our indexed expression once + // as a column alias. + q = q.ColumnExpr( + "(COALESCE(?, ?) || ? || ?) AS ?", + bun.Ident("domain"), "", + "/@", + bun.Ident("username"), + bun.Ident("domain_username"), ) - if err != nil { - return nil, fmt.Errorf("error getting maxID account %s: %w", maxID, err) - } + } else { + // Create a subquery for + // Postgres to reuse. + subQ = a.db.NewRaw( + "(COALESCE(?, ?) || ? || ?) COLLATE ?", + bun.Ident("domain"), "", + "/@", + bun.Ident("username"), + bun.Ident("C"), + ) + } - q = q.Where("? < ?", bun.Ident("account.created_at"), maxIDAcct.CreatedAt) + // Return only accounts with `[domain]/@[username]` + // later in the alphabet (a-z) than provided maxID. + if maxID != "" { + if dbDialect == dialect.SQLite { + // Use aliased column. + q = q.Where("? > ?", bun.Ident("domain_username"), maxID) + } else { + q = q.Where("? > ?", subQ, maxID) + } } - // Return only accounts NEWER - // than account with minID. + // Return only accounts with `[domain]/@[username]` + // earlier in the alphabet (a-z) than provided minID. if minID != "" { - minIDAcct, err := a.GetAccountByID( - gtscontext.SetBarebones(ctx), - minID, - ) - if err != nil { - return nil, fmt.Errorf("error getting minID account %s: %w", minID, err) + if dbDialect == dialect.SQLite { + // Use aliased column. + q = q.Where("? < ?", bun.Ident("domain_username"), minID) + } else { + q = q.Where("? < ?", subQ, minID) } - - q = q.Where("? > ?", bun.Ident("account.created_at"), minIDAcct.CreatedAt) } switch status { @@ -479,13 +524,29 @@ func (a *accountDB) GetAccounts( if order == paging.OrderAscending { // Page up. - q = q.Order("account.created_at ASC") + // It's counterintuitive because it + // says DESC in the query, but we're + // going backwards in the alphabet, + // and a < z in a string comparison. + if dbDialect == dialect.SQLite { + q = q.OrderExpr("? DESC", bun.Ident("domain_username")) + } else { + q = q.OrderExpr("(?) DESC", subQ) + } } else { // Page down. - q = q.Order("account.created_at DESC") + // It's counterintuitive because it + // says ASC in the query, but we're + // going forwards in the alphabet, + // and z > a in a string comparison. + if dbDialect == dialect.SQLite { + q = q.OrderExpr("? ASC", bun.Ident("domain_username")) + } else { + q = q.OrderExpr("? ASC", subQ) + } } - if err := q.Scan(ctx, &accountIDs); err != nil { + if err := q.Scan(ctx, &accountIDs, new([]string)); err != nil { return nil, err } diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index ea211e16f..5ed5d91a1 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -502,6 +502,80 @@ func (suite *AccountTestSuite) TestGetAccountsAll() { suite.Len(accounts, 9) } +func (suite *AccountTestSuite) TestGetAccountsMaxID() { + var ( + ctx = context.Background() + origin = "" + status = "" + mods = false + invitedBy = "" + username = "" + displayName = "" + domain = "" + email = "" + ip netip.Addr + // Get accounts with `[domain]/@[username]` + // later in the alphabet than `/@the_mighty_zork`. + page = &paging.Page{Max: paging.MaxID("/@the_mighty_zork")} + ) + + accounts, err := suite.db.GetAccounts( + ctx, + origin, + status, + mods, + invitedBy, + username, + displayName, + domain, + email, + ip, + page, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Len(accounts, 5) +} + +func (suite *AccountTestSuite) TestGetAccountsMinID() { + var ( + ctx = context.Background() + origin = "" + status = "" + mods = false + invitedBy = "" + username = "" + displayName = "" + domain = "" + email = "" + ip netip.Addr + // Get accounts with `[domain]/@[username]` + // earlier in the alphabet than `/@the_mighty_zork`. + page = &paging.Page{Min: paging.MinID("/@the_mighty_zork")} + ) + + accounts, err := suite.db.GetAccounts( + ctx, + origin, + status, + mods, + invitedBy, + username, + displayName, + domain, + email, + ip, + page, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Len(accounts, 3) +} + func (suite *AccountTestSuite) TestGetAccountsModsOnly() { var ( ctx = context.Background() diff --git a/internal/db/bundb/migrations/20240426122821_pageable_admin_accounts.go b/internal/db/bundb/migrations/20240426122821_pageable_admin_accounts.go new file mode 100644 index 000000000..00465cc85 --- /dev/null +++ b/internal/db/bundb/migrations/20240426122821_pageable_admin_accounts.go @@ -0,0 +1,84 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + log.Info(ctx, "reindexing accounts (accounts_paging_idx); this may take a few minutes, please don't interrupt this migration!") + + q := db.NewCreateIndex(). + TableExpr("accounts"). + Index("accounts_paging_idx"). + IfNotExists() + + switch d := db.Dialect().Name(); d { + case dialect.SQLite: + q = q.ColumnExpr( + "COALESCE(?, ?) || ? || ?", + bun.Ident("domain"), "", + "/@", + bun.Ident("username"), + ) + + // Specify C collation for Postgres to ensure + // alphabetic sort order is similar enough to + // SQLite (which uses BINARY sort by default). + // + // See: + // + // - https://www.postgresql.org/docs/current/collation.html#COLLATION-MANAGING-STANDARD + // - https://sqlite.org/datatype3.html#collation + case dialect.PG: + q = q.ColumnExpr( + "(COALESCE(?, ?) || ? || ?) COLLATE ?", + bun.Ident("domain"), "", + "/@", + bun.Ident("username"), + bun.Ident("C"), + ) + + default: + log.Panicf(ctx, "dialect %s was neither postgres nor sqlite", d) + } + + if _, err := q.Exec(ctx); err != nil { + return err + } + + return nil + + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} |