From fd6637df4aeed721442bff6dfbce9bdd1b5ac7b8 Mon Sep 17 00:00:00 2001 From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Mon, 10 Jun 2024 18:42:41 +0000 Subject: [bugfix] boost and account recursion (#2982) * fix possible infinite recursion if moved accounts are self-referential * adds a defensive check for a boost being a boost of a boost wrapper * add checks on input for a boost of a boost * remove unnecessary check * add protections on account move to prevent move recursion loops * separate status conversion without boost logic into separate function to remove risk of recursion * move boost check to boost function itself * formatting * use error 422 instead of 500 * use gtserror not standard errors package for error creation --- internal/db/account.go | 3 +++ internal/db/bundb/account.go | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+) (limited to 'internal/db') diff --git a/internal/db/account.go b/internal/db/account.go index dec36d2ac..4f02a4d29 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -57,6 +57,9 @@ type Account interface { // GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong. GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error) + // GetAccountByMovedToURI returns any accounts with given moved_to_uri set. + GetAccountsByMovedToURI(ctx context.Context, uri string) ([]*gtsmodel.Account, error) + // GetAccounts returns accounts // with the given parameters. GetAccounts( diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 4e969e0ef..eb5385c70 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -252,6 +252,27 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts return a.GetAccountByUsernameDomain(ctx, username, domain) } +func (a *accountDB) GetAccountsByMovedToURI(ctx context.Context, uri string) ([]*gtsmodel.Account, error) { + var accountIDs []string + + // Find all account IDs with + // given moved_to_uri column. + if err := a.db.NewSelect(). + Table("accounts"). + Column("id"). + Where("? = ?", bun.Ident("moved_to_uri"), uri). + Scan(ctx, &accountIDs); err != nil { + return nil, err + } + + if len(accountIDs) == 0 { + return nil, nil + } + + // Return account models for all found IDs. + return a.GetAccountsByIDs(ctx, accountIDs) +} + // GetAccounts selects accounts using the given parameters. // Unlike with other functions, the paging for GetAccounts // is done not by ID, but by a concatenation of `[domain]/@[username]`, -- cgit v1.2.3