summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/account.go21
-rw-r--r--internal/db/bundb/account.go254
-rw-r--r--internal/db/bundb/account_test.go185
-rw-r--r--internal/db/bundb/user.go20
-rw-r--r--internal/db/user.go6
5 files changed, 486 insertions, 0 deletions
diff --git a/internal/db/account.go b/internal/db/account.go
index 45276f41f..7cdf7b57f 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -19,9 +19,11 @@ package db
import (
"context"
+ "net/netip"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
)
// Account contains functions related to account getting/setting/creation.
@@ -56,6 +58,25 @@ type Account interface {
// GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong.
GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
+ // GetAccounts returns accounts
+ // with the given parameters.
+ GetAccounts(
+ ctx context.Context,
+ origin string,
+ status string,
+ mods bool,
+ invitedBy string,
+ username string,
+ displayName string,
+ domain string,
+ email string,
+ ip netip.Addr,
+ page *paging.Page,
+ ) (
+ []*gtsmodel.Account,
+ error,
+ )
+
// PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc).
PopulateAccount(ctx context.Context, account *gtsmodel.Account) error
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 1ecf28e42..45e67c10b 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -20,6 +20,8 @@ package bundb
import (
"context"
"errors"
+ "fmt"
+ "net/netip"
"slices"
"strings"
"time"
@@ -31,6 +33,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
@@ -249,6 +252,257 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
return a.GetAccountByUsernameDomain(ctx, username, domain)
}
+func (a *accountDB) GetAccounts(
+ ctx context.Context,
+ origin string,
+ status string,
+ mods bool,
+ invitedBy string,
+ username string,
+ displayName string,
+ domain string,
+ email string,
+ ip netip.Addr,
+ page *paging.Page,
+) (
+ []*gtsmodel.Account,
+ error,
+) {
+ var (
+ // local users lists,
+ // required for some
+ // limiting parameters.
+ users []*gtsmodel.User
+
+ // lazyLoadUsers only loads the users
+ // slice if it's required by params.
+ lazyLoadUsers = func() (err error) {
+ if users == nil {
+ users, err = a.state.DB.GetAllUsers(gtscontext.SetBarebones(ctx))
+ if err != nil {
+ return fmt.Errorf("error getting users: %w", err)
+ }
+ }
+ return nil
+ }
+
+ // Get paging params.
+ //
+ // Note this may be min_id OR since_id
+ // from the API, this gets handled below
+ // when checking order to reverse slice.
+ minID = page.GetMin()
+ maxID = page.GetMax()
+ limit = page.GetLimit()
+ order = page.GetOrder()
+
+ // Make educated guess for slice size
+ accountIDs = make([]string, 0, limit)
+ accountIDIn []string
+
+ useAccountIDIn bool
+ )
+
+ q := a.db.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ // Select only IDs from table
+ Column("account.id")
+
+ // Return only accounts OLDER
+ // than account with maxID.
+ if maxID != "" {
+ maxIDAcct, err := a.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ maxID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting maxID account %s: %w", maxID, err)
+ }
+
+ q = q.Where("? < ?", bun.Ident("account.created_at"), maxIDAcct.CreatedAt)
+ }
+
+ // Return only accounts NEWER
+ // than account with minID.
+ if minID != "" {
+ minIDAcct, err := a.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ minID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting minID account %s: %w", minID, err)
+ }
+
+ q = q.Where("? > ?", bun.Ident("account.created_at"), minIDAcct.CreatedAt)
+ }
+
+ switch status {
+
+ case "active":
+ // Get only enabled accounts.
+ if err := lazyLoadUsers(); err != nil {
+ return nil, err
+ }
+ for _, user := range users {
+ if !*user.Disabled {
+ accountIDIn = append(accountIDIn, user.AccountID)
+ }
+ }
+ useAccountIDIn = true
+
+ case "pending":
+ // Get only unapproved accounts.
+ if err := lazyLoadUsers(); err != nil {
+ return nil, err
+ }
+ for _, user := range users {
+ if !*user.Approved {
+ accountIDIn = append(accountIDIn, user.AccountID)
+ }
+ }
+ useAccountIDIn = true
+
+ case "disabled":
+ // Get only disabled accounts.
+ if err := lazyLoadUsers(); err != nil {
+ return nil, err
+ }
+ for _, user := range users {
+ if *user.Disabled {
+ accountIDIn = append(accountIDIn, user.AccountID)
+ }
+ }
+ useAccountIDIn = true
+
+ case "silenced":
+ // Get only silenced accounts.
+ q = q.Where("? IS NOT NULL", bun.Ident("account.silenced_at"))
+
+ case "suspended":
+ // Get only suspended accounts.
+ q = q.Where("? IS NOT NULL", bun.Ident("account.suspended_at"))
+ }
+
+ if mods {
+ // Get only mod accounts.
+ if err := lazyLoadUsers(); err != nil {
+ return nil, err
+ }
+ for _, user := range users {
+ if *user.Moderator || *user.Admin {
+ accountIDIn = append(accountIDIn, user.AccountID)
+ }
+ }
+ useAccountIDIn = true
+ }
+
+ // TODO: invitedBy
+
+ if username != "" {
+ q = q.Where("? = ?", bun.Ident("account.username"), username)
+ }
+
+ if displayName != "" {
+ q = q.Where("? = ?", bun.Ident("account.display_name"), displayName)
+ }
+
+ if domain != "" {
+ q = q.Where("? = ?", bun.Ident("account.domain"), domain)
+ }
+
+ if email != "" {
+ if err := lazyLoadUsers(); err != nil {
+ return nil, err
+ }
+ for _, user := range users {
+ if user.Email == email || user.UnconfirmedEmail == email {
+ accountIDIn = append(accountIDIn, user.AccountID)
+ }
+ }
+ useAccountIDIn = true
+ }
+
+ // Use ip if not zero value.
+ if ip.IsValid() {
+ if err := lazyLoadUsers(); err != nil {
+ return nil, err
+ }
+ for _, user := range users {
+ if user.SignUpIP.String() == ip.String() {
+ accountIDIn = append(accountIDIn, user.AccountID)
+ }
+ }
+ useAccountIDIn = true
+ }
+
+ if origin == "local" && !useAccountIDIn {
+ // In the case we're not already limiting
+ // by specific subset of account IDs, just
+ // use existing list of user.AccountIDs
+ // instead of adding WHERE to the query.
+ if err := lazyLoadUsers(); err != nil {
+ return nil, err
+ }
+ for _, user := range users {
+ accountIDIn = append(accountIDIn, user.AccountID)
+ }
+ useAccountIDIn = true
+
+ } else if origin == "remote" {
+ if useAccountIDIn {
+ // useAccountIDIn specifically indicates
+ // a parameter that limits querying to
+ // local accounts, there will be none.
+ return nil, nil
+ }
+
+ // Get only remote accounts.
+ q = q.Where("? IS NOT NULL", bun.Ident("account.domain"))
+ }
+
+ if useAccountIDIn {
+ if len(accountIDIn) == 0 {
+ // There will be no
+ // possible answer.
+ return nil, nil
+ }
+
+ q = q.Where("? IN (?)", bun.Ident("account.id"), bun.In(accountIDIn))
+ }
+
+ if limit > 0 {
+ // Limit amount of
+ // accounts returned.
+ q = q.Limit(limit)
+ }
+
+ if order == paging.OrderAscending {
+ // Page up.
+ q = q.Order("account.created_at ASC")
+ } else {
+ // Page down.
+ q = q.Order("account.created_at DESC")
+ }
+
+ if err := q.Scan(ctx, &accountIDs); err != nil {
+ return nil, err
+ }
+
+ if len(accountIDs) == 0 {
+ return nil, nil
+ }
+
+ // If we're paging up, we still want accounts
+ // to be sorted by createdAt desc, so reverse ids slice.
+ if order == paging.OrderAscending {
+ slices.Reverse(accountIDs)
+ }
+
+ // Return account IDs loaded from cache + db.
+ return a.state.DB.GetAccountsByIDs(ctx, accountIDs)
+}
+
func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) {
// Fetch account from database cache with loader callback
account, err := a.state.Caches.GTS.Account.LoadOne(lookup, func() (*gtsmodel.Account, error) {
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
index 21e04dedc..dd96543b6 100644
--- a/internal/db/bundb/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -23,6 +23,7 @@ import (
"crypto/rsa"
"errors"
"fmt"
+ "net/netip"
"reflect"
"strings"
"testing"
@@ -33,6 +34,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -491,6 +493,189 @@ func (suite *AccountTestSuite) TestPopulateAccountWithUnknownMovedToURI() {
suite.NoError(err)
}
+func (suite *AccountTestSuite) TestGetAccountsAll() {
+ var (
+ ctx = context.Background()
+ origin = ""
+ status = ""
+ mods = false
+ invitedBy = ""
+ username = ""
+ displayName = ""
+ domain = ""
+ email = ""
+ ip netip.Addr
+ page *paging.Page = nil
+ )
+
+ accounts, err := suite.db.GetAccounts(
+ ctx,
+ origin,
+ status,
+ mods,
+ invitedBy,
+ username,
+ displayName,
+ domain,
+ email,
+ ip,
+ page,
+ )
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Len(accounts, 9)
+}
+
+func (suite *AccountTestSuite) TestGetAccountsModsOnly() {
+ var (
+ ctx = context.Background()
+ origin = ""
+ status = ""
+ mods = true
+ invitedBy = ""
+ username = ""
+ displayName = ""
+ domain = ""
+ email = ""
+ ip netip.Addr
+ page = &paging.Page{
+ Limit: 100,
+ }
+ )
+
+ accounts, err := suite.db.GetAccounts(
+ ctx,
+ origin,
+ status,
+ mods,
+ invitedBy,
+ username,
+ displayName,
+ domain,
+ email,
+ ip,
+ page,
+ )
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Len(accounts, 1)
+}
+
+func (suite *AccountTestSuite) TestGetAccountsLocalWithEmail() {
+ var (
+ ctx = context.Background()
+ origin = "local"
+ status = ""
+ mods = false
+ invitedBy = ""
+ username = ""
+ displayName = ""
+ domain = ""
+ email = "tortle.dude@example.org"
+ ip netip.Addr
+ page = &paging.Page{
+ Limit: 100,
+ }
+ )
+
+ accounts, err := suite.db.GetAccounts(
+ ctx,
+ origin,
+ status,
+ mods,
+ invitedBy,
+ username,
+ displayName,
+ domain,
+ email,
+ ip,
+ page,
+ )
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Len(accounts, 1)
+}
+
+func (suite *AccountTestSuite) TestGetAccountsWithIP() {
+ var (
+ ctx = context.Background()
+ origin = ""
+ status = ""
+ mods = false
+ invitedBy = ""
+ username = ""
+ displayName = ""
+ domain = ""
+ email = ""
+ ip = netip.MustParseAddr("199.222.111.89")
+ page = &paging.Page{
+ Limit: 100,
+ }
+ )
+
+ accounts, err := suite.db.GetAccounts(
+ ctx,
+ origin,
+ status,
+ mods,
+ invitedBy,
+ username,
+ displayName,
+ domain,
+ email,
+ ip,
+ page,
+ )
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Len(accounts, 1)
+}
+
+func (suite *AccountTestSuite) TestGetPendingAccounts() {
+ var (
+ ctx = context.Background()
+ origin = ""
+ status = "pending"
+ mods = false
+ invitedBy = ""
+ username = ""
+ displayName = ""
+ domain = ""
+ email = ""
+ ip netip.Addr
+ page = &paging.Page{
+ Limit: 100,
+ }
+ )
+
+ accounts, err := suite.db.GetAccounts(
+ ctx,
+ origin,
+ status,
+ mods,
+ invitedBy,
+ username,
+ displayName,
+ domain,
+ email,
+ ip,
+ page,
+ )
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Len(accounts, 1)
+}
+
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}
diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go
index 2854c0caa..f0221eeb1 100644
--- a/internal/db/bundb/user.go
+++ b/internal/db/bundb/user.go
@@ -230,3 +230,23 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error {
Exec(ctx)
return err
}
+
+func (u *userDB) PutDeniedUser(ctx context.Context, deniedUser *gtsmodel.DeniedUser) error {
+ _, err := u.db.NewInsert().
+ Model(deniedUser).
+ Exec(ctx)
+ return err
+}
+
+func (u *userDB) GetDeniedUserByID(ctx context.Context, id string) (*gtsmodel.DeniedUser, error) {
+ deniedUser := new(gtsmodel.DeniedUser)
+ if err := u.db.
+ NewSelect().
+ Model(deniedUser).
+ Where("? = ?", bun.Ident("denied_user.id"), id).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return deniedUser, nil
+}
diff --git a/internal/db/user.go b/internal/db/user.go
index c762ef2b3..28fa59130 100644
--- a/internal/db/user.go
+++ b/internal/db/user.go
@@ -54,4 +54,10 @@ type User interface {
// DeleteUserByID deletes one user by its ID.
DeleteUserByID(ctx context.Context, userID string) error
+
+ // PutDeniedUser inserts the given deniedUser into the db.
+ PutDeniedUser(ctx context.Context, deniedUser *gtsmodel.DeniedUser) error
+
+ // GetDeniedUserByID returns one denied user with the given ID.
+ GetDeniedUserByID(ctx context.Context, id string) (*gtsmodel.DeniedUser, error)
}