diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/account.go | 254 | ||||
| -rw-r--r-- | internal/db/bundb/account_test.go | 185 | ||||
| -rw-r--r-- | internal/db/bundb/user.go | 20 | 
3 files changed, 459 insertions, 0 deletions
| diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 1ecf28e42..45e67c10b 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -20,6 +20,8 @@ package bundb  import (  	"context"  	"errors" +	"fmt" +	"net/netip"  	"slices"  	"strings"  	"time" @@ -31,6 +33,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/id"  	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun" @@ -249,6 +252,257 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts  	return a.GetAccountByUsernameDomain(ctx, username, domain)  } +func (a *accountDB) GetAccounts( +	ctx context.Context, +	origin string, +	status string, +	mods bool, +	invitedBy string, +	username string, +	displayName string, +	domain string, +	email string, +	ip netip.Addr, +	page *paging.Page, +) ( +	[]*gtsmodel.Account, +	error, +) { +	var ( +		// local users lists, +		// required for some +		// limiting parameters. +		users []*gtsmodel.User + +		// lazyLoadUsers only loads the users +		// slice if it's required by params. +		lazyLoadUsers = func() (err error) { +			if users == nil { +				users, err = a.state.DB.GetAllUsers(gtscontext.SetBarebones(ctx)) +				if err != nil { +					return fmt.Errorf("error getting users: %w", err) +				} +			} +			return nil +		} + +		// Get paging params. +		// +		// Note this may be min_id OR since_id +		// from the API, this gets handled below +		// when checking order to reverse slice. +		minID = page.GetMin() +		maxID = page.GetMax() +		limit = page.GetLimit() +		order = page.GetOrder() + +		// Make educated guess for slice size +		accountIDs  = make([]string, 0, limit) +		accountIDIn []string + +		useAccountIDIn bool +	) + +	q := a.db. +		NewSelect(). +		TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). +		// Select only IDs from table +		Column("account.id") + +	// Return only accounts OLDER +	// than account with maxID. +	if maxID != "" { +		maxIDAcct, err := a.GetAccountByID( +			gtscontext.SetBarebones(ctx), +			maxID, +		) +		if err != nil { +			return nil, fmt.Errorf("error getting maxID account %s: %w", maxID, err) +		} + +		q = q.Where("? < ?", bun.Ident("account.created_at"), maxIDAcct.CreatedAt) +	} + +	// Return only accounts NEWER +	// than account with minID. +	if minID != "" { +		minIDAcct, err := a.GetAccountByID( +			gtscontext.SetBarebones(ctx), +			minID, +		) +		if err != nil { +			return nil, fmt.Errorf("error getting minID account %s: %w", minID, err) +		} + +		q = q.Where("? > ?", bun.Ident("account.created_at"), minIDAcct.CreatedAt) +	} + +	switch status { + +	case "active": +		// Get only enabled accounts. +		if err := lazyLoadUsers(); err != nil { +			return nil, err +		} +		for _, user := range users { +			if !*user.Disabled { +				accountIDIn = append(accountIDIn, user.AccountID) +			} +		} +		useAccountIDIn = true + +	case "pending": +		// Get only unapproved accounts. +		if err := lazyLoadUsers(); err != nil { +			return nil, err +		} +		for _, user := range users { +			if !*user.Approved { +				accountIDIn = append(accountIDIn, user.AccountID) +			} +		} +		useAccountIDIn = true + +	case "disabled": +		// Get only disabled accounts. +		if err := lazyLoadUsers(); err != nil { +			return nil, err +		} +		for _, user := range users { +			if *user.Disabled { +				accountIDIn = append(accountIDIn, user.AccountID) +			} +		} +		useAccountIDIn = true + +	case "silenced": +		// Get only silenced accounts. +		q = q.Where("? IS NOT NULL", bun.Ident("account.silenced_at")) + +	case "suspended": +		// Get only suspended accounts. +		q = q.Where("? IS NOT NULL", bun.Ident("account.suspended_at")) +	} + +	if mods { +		// Get only mod accounts. +		if err := lazyLoadUsers(); err != nil { +			return nil, err +		} +		for _, user := range users { +			if *user.Moderator || *user.Admin { +				accountIDIn = append(accountIDIn, user.AccountID) +			} +		} +		useAccountIDIn = true +	} + +	// TODO: invitedBy + +	if username != "" { +		q = q.Where("? = ?", bun.Ident("account.username"), username) +	} + +	if displayName != "" { +		q = q.Where("? = ?", bun.Ident("account.display_name"), displayName) +	} + +	if domain != "" { +		q = q.Where("? = ?", bun.Ident("account.domain"), domain) +	} + +	if email != "" { +		if err := lazyLoadUsers(); err != nil { +			return nil, err +		} +		for _, user := range users { +			if user.Email == email || user.UnconfirmedEmail == email { +				accountIDIn = append(accountIDIn, user.AccountID) +			} +		} +		useAccountIDIn = true +	} + +	// Use ip if not zero value. +	if ip.IsValid() { +		if err := lazyLoadUsers(); err != nil { +			return nil, err +		} +		for _, user := range users { +			if user.SignUpIP.String() == ip.String() { +				accountIDIn = append(accountIDIn, user.AccountID) +			} +		} +		useAccountIDIn = true +	} + +	if origin == "local" && !useAccountIDIn { +		// In the case we're not already limiting +		// by specific subset of account IDs, just +		// use existing list of user.AccountIDs +		// instead of adding WHERE to the query. +		if err := lazyLoadUsers(); err != nil { +			return nil, err +		} +		for _, user := range users { +			accountIDIn = append(accountIDIn, user.AccountID) +		} +		useAccountIDIn = true + +	} else if origin == "remote" { +		if useAccountIDIn { +			// useAccountIDIn specifically indicates +			// a parameter that limits querying to +			// local accounts, there will be none. +			return nil, nil +		} + +		// Get only remote accounts. +		q = q.Where("? IS NOT NULL", bun.Ident("account.domain")) +	} + +	if useAccountIDIn { +		if len(accountIDIn) == 0 { +			// There will be no +			// possible answer. +			return nil, nil +		} + +		q = q.Where("? IN (?)", bun.Ident("account.id"), bun.In(accountIDIn)) +	} + +	if limit > 0 { +		// Limit amount of +		// accounts returned. +		q = q.Limit(limit) +	} + +	if order == paging.OrderAscending { +		// Page up. +		q = q.Order("account.created_at ASC") +	} else { +		// Page down. +		q = q.Order("account.created_at DESC") +	} + +	if err := q.Scan(ctx, &accountIDs); err != nil { +		return nil, err +	} + +	if len(accountIDs) == 0 { +		return nil, nil +	} + +	// If we're paging up, we still want accounts +	// to be sorted by createdAt desc, so reverse ids slice. +	if order == paging.OrderAscending { +		slices.Reverse(accountIDs) +	} + +	// Return account IDs loaded from cache + db. +	return a.state.DB.GetAccountsByIDs(ctx, accountIDs) +} +  func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) {  	// Fetch account from database cache with loader callback  	account, err := a.state.Caches.GTS.Account.LoadOne(lookup, func() (*gtsmodel.Account, error) { diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 21e04dedc..dd96543b6 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -23,6 +23,7 @@ import (  	"crypto/rsa"  	"errors"  	"fmt" +	"net/netip"  	"reflect"  	"strings"  	"testing" @@ -33,6 +34,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/db/bundb"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -491,6 +493,189 @@ func (suite *AccountTestSuite) TestPopulateAccountWithUnknownMovedToURI() {  	suite.NoError(err)  } +func (suite *AccountTestSuite) TestGetAccountsAll() { +	var ( +		ctx         = context.Background() +		origin      = "" +		status      = "" +		mods        = false +		invitedBy   = "" +		username    = "" +		displayName = "" +		domain      = "" +		email       = "" +		ip          netip.Addr +		page        *paging.Page = nil +	) + +	accounts, err := suite.db.GetAccounts( +		ctx, +		origin, +		status, +		mods, +		invitedBy, +		username, +		displayName, +		domain, +		email, +		ip, +		page, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	suite.Len(accounts, 9) +} + +func (suite *AccountTestSuite) TestGetAccountsModsOnly() { +	var ( +		ctx         = context.Background() +		origin      = "" +		status      = "" +		mods        = true +		invitedBy   = "" +		username    = "" +		displayName = "" +		domain      = "" +		email       = "" +		ip          netip.Addr +		page        = &paging.Page{ +			Limit: 100, +		} +	) + +	accounts, err := suite.db.GetAccounts( +		ctx, +		origin, +		status, +		mods, +		invitedBy, +		username, +		displayName, +		domain, +		email, +		ip, +		page, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	suite.Len(accounts, 1) +} + +func (suite *AccountTestSuite) TestGetAccountsLocalWithEmail() { +	var ( +		ctx         = context.Background() +		origin      = "local" +		status      = "" +		mods        = false +		invitedBy   = "" +		username    = "" +		displayName = "" +		domain      = "" +		email       = "tortle.dude@example.org" +		ip          netip.Addr +		page        = &paging.Page{ +			Limit: 100, +		} +	) + +	accounts, err := suite.db.GetAccounts( +		ctx, +		origin, +		status, +		mods, +		invitedBy, +		username, +		displayName, +		domain, +		email, +		ip, +		page, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	suite.Len(accounts, 1) +} + +func (suite *AccountTestSuite) TestGetAccountsWithIP() { +	var ( +		ctx         = context.Background() +		origin      = "" +		status      = "" +		mods        = false +		invitedBy   = "" +		username    = "" +		displayName = "" +		domain      = "" +		email       = "" +		ip          = netip.MustParseAddr("199.222.111.89") +		page        = &paging.Page{ +			Limit: 100, +		} +	) + +	accounts, err := suite.db.GetAccounts( +		ctx, +		origin, +		status, +		mods, +		invitedBy, +		username, +		displayName, +		domain, +		email, +		ip, +		page, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	suite.Len(accounts, 1) +} + +func (suite *AccountTestSuite) TestGetPendingAccounts() { +	var ( +		ctx         = context.Background() +		origin      = "" +		status      = "pending" +		mods        = false +		invitedBy   = "" +		username    = "" +		displayName = "" +		domain      = "" +		email       = "" +		ip          netip.Addr +		page        = &paging.Page{ +			Limit: 100, +		} +	) + +	accounts, err := suite.db.GetAccounts( +		ctx, +		origin, +		status, +		mods, +		invitedBy, +		username, +		displayName, +		domain, +		email, +		ip, +		page, +	) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	suite.Len(accounts, 1) +} +  func TestAccountTestSuite(t *testing.T) {  	suite.Run(t, new(AccountTestSuite))  } diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index 2854c0caa..f0221eeb1 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -230,3 +230,23 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error {  		Exec(ctx)  	return err  } + +func (u *userDB) PutDeniedUser(ctx context.Context, deniedUser *gtsmodel.DeniedUser) error { +	_, err := u.db.NewInsert(). +		Model(deniedUser). +		Exec(ctx) +	return err +} + +func (u *userDB) GetDeniedUserByID(ctx context.Context, id string) (*gtsmodel.DeniedUser, error) { +	deniedUser := new(gtsmodel.DeniedUser) +	if err := u.db. +		NewSelect(). +		Model(deniedUser). +		Where("? = ?", bun.Ident("denied_user.id"), id). +		Scan(ctx); err != nil { +		return nil, err +	} + +	return deniedUser, nil +} | 
