diff options
Diffstat (limited to 'internal/db')
28 files changed, 954 insertions, 488 deletions
| diff --git a/internal/db/account.go b/internal/db/account.go index 351d6d01c..ae5eea7c6 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -48,6 +48,11 @@ type Account interface {  	// UpdateAccount updates one account by ID.  	UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) +	// DeleteAccount deletes one account from the database by its ID. +	// DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the +	// account as suspended instead, rather than deleting from the db entirely. +	DeleteAccount(ctx context.Context, id string) Error +  	// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username.  	GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error) diff --git a/internal/db/basic.go b/internal/db/basic.go index 6e5184d31..8990edd5f 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -62,11 +62,11 @@ type Basic interface {  	// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.  	Put(ctx context.Context, i interface{}) Error -	// UpdateByPrimaryKey updates values of i based on its primary key. +	// UpdateByID updates values of i based on its id.  	// If any columns are specified, these will be updated exclusively.  	// Otherwise, the whole model will be updated.  	// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. -	UpdateByPrimaryKey(ctx context.Context, i interface{}, columns ...string) Error +	UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) Error  	// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.  	UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 074804690..c04948fee 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -21,7 +21,6 @@ package bundb  import (  	"context"  	"errors" -	"fmt"  	"strings"  	"time" @@ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac  			return a.cache.GetByID(id)  		},  		func(account *gtsmodel.Account) error { -			return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx) +			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)  		},  	)  } @@ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.  			return a.cache.GetByURI(uri)  		},  		func(account *gtsmodel.Account) error { -			return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx) +			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)  		},  	)  } @@ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.  			return a.cache.GetByURL(url)  		},  		func(account *gtsmodel.Account) error { -			return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx) +			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)  		},  	)  } @@ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str  			q := a.newAccountQ(account)  			if domain != "" { -				q = q.Where("account.username = ?", username) -				q = q.Where("account.domain = ?", domain) +				q = q.Where("? = ?", bun.Ident("account.username"), username) +				q = q.Where("? = ?", bun.Ident("account.domain"), domain)  			} else { -				q = q.Where("account.username = ?", strings.ToLower(username)) -				q = q.Where("account.domain IS NULL") +				q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)) +				q = q.Where("? IS NULL", bun.Ident("account.domain"))  			}  			return q.Scan(ctx) @@ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo  			return a.cache.GetByPubkeyID(id)  		},  		func(account *gtsmodel.Account) error { -			return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx) +			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)  		},  	)  } @@ -169,26 +168,36 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account  	if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {  		// create links between this account and any emojis it uses  		// first clear out any old emoji links -		if _, err := tx.NewDelete(). -			Model(&[]*gtsmodel.AccountToEmoji{}). -			Where("account_id = ?", account.ID). +		if _, err := tx. +			NewDelete(). +			TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). +			Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).  			Exec(ctx); err != nil {  			return err  		}  		// now populate new emoji links  		for _, i := range account.EmojiIDs { -			if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ -				AccountID: account.ID, -				EmojiID:   i, -			}).Exec(ctx); err != nil { +			if _, err := tx. +				NewInsert(). +				Model(>smodel.AccountToEmoji{ +					AccountID: account.ID, +					EmojiID:   i, +				}).Exec(ctx); err != nil {  				return err  			}  		}  		// update the account -		_, err := tx.NewUpdate().Model(account).WherePK().Exec(ctx) -		return err +		if _, err := tx. +			NewUpdate(). +			Model(account). +			Where("? = ?", bun.Ident("account.id"), account.ID). +			Exec(ctx); err != nil { +			return err +		} + +		return nil  	}); err != nil {  		return nil, a.conn.ProcessError(err)  	} @@ -197,6 +206,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account  	return account, nil  } +func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { +	if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { +		// clear out any emoji links +		if _, err := tx. +			NewDelete(). +			TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). +			Where("? = ?", bun.Ident("account_to_emoji.account_id"), id). +			Exec(ctx); err != nil { +			return err +		} + +		// delete the account +		_, err := tx. +			NewUpdate(). +			TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). +			Where("? = ?", bun.Ident("account.id"), id). +			Exec(ctx) +		return err +	}); err != nil { +		return a.conn.ProcessError(err) +	} + +	a.cache.Invalidate(id) +	return nil +} +  func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {  	account := new(gtsmodel.Account) @@ -204,11 +239,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts  	if domain != "" {  		q = q. -			Where("account.username = ?", domain). -			Where("account.domain = ?", domain) +			Where("? = ?", bun.Ident("account.username"), domain). +			Where("? = ?", bun.Ident("account.domain"), domain)  	} else {  		q = q. -			Where("account.username = ?", config.GetHost()). +			Where("? = ?", bun.Ident("account.username"), config.GetHost()).  			WhereGroup(" AND ", whereEmptyOrNull("domain"))  	} @@ -224,10 +259,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)  	q := a.conn.  		NewSelect().  		Model(status). -		Order("id DESC"). -		Limit(1). -		Where("account_id = ?", accountID). -		Column("created_at") +		Column("status.created_at"). +		Where("? = ?", bun.Ident("status.account_id"), accountID). +		Order("status.id DESC"). +		Limit(1)  	if err := q.Scan(ctx); err != nil {  		return time.Time{}, a.conn.ProcessError(err) @@ -240,12 +275,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen  		return errors.New("one media attachment cannot be both header and avatar")  	} -	var headerOrAVI string +	var column bun.Ident  	switch {  	case *mediaAttachment.Avatar: -		headerOrAVI = "avatar" +		column = bun.Ident("account.avatar_media_attachment_id")  	case *mediaAttachment.Header: -		headerOrAVI = "header" +		column = bun.Ident("account.header_media_attachment_id")  	default:  		return errors.New("given media attachment was neither a header nor an avatar")  	} @@ -257,11 +292,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen  		Exec(ctx); err != nil {  		return a.conn.ProcessError(err)  	} +  	if _, err := a.conn.  		NewUpdate(). -		Model(>smodel.Account{}). -		Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). -		Where("id = ?", accountID). +		TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). +		Set("? = ?", column, mediaAttachment.ID). +		Where("? = ?", bun.Ident("account.id"), accountID).  		Exec(ctx); err != nil {  		return a.conn.ProcessError(err)  	} @@ -284,7 +320,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g  	if err := a.conn.  		NewSelect().  		Model(faves). -		Where("account_id = ?", accountID). +		Where("? = ?", bun.Ident("status_fave.account_id"), accountID).  		Scan(ctx); err != nil {  		return nil, a.conn.ProcessError(err)  	} @@ -295,8 +331,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g  func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {  	return a.conn.  		NewSelect(). -		Model(>smodel.Status{}). -		Where("account_id = ?", accountID). +		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +		Where("? = ?", bun.Ident("status.account_id"), accountID).  		Count(ctx)  } @@ -305,12 +341,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li  	q := a.conn.  		NewSelect(). -		Table("statuses"). -		Column("id"). -		Order("id DESC") +		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +		Column("status.id"). +		Order("status.id DESC")  	if accountID != "" { -		q = q.Where("account_id = ?", accountID) +		q = q.Where("? = ?", bun.Ident("status.account_id"), accountID)  	}  	if limit != 0 { @@ -321,27 +357,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li  		// include self-replies (threads)  		whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {  			return q. -				WhereOr("in_reply_to_account_id = ?", accountID). -				WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri")) +				WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID). +				WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri"))  		}  		q = q.WhereGroup(" AND ", whereGroup)  	}  	if excludeReblogs { -		q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")) +		q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id"))  	}  	if maxID != "" { -		q = q.Where("id < ?", maxID) +		q = q.Where("? < ?", bun.Ident("status.id"), maxID)  	}  	if minID != "" { -		q = q.Where("id > ?", minID) +		q = q.Where("? > ?", bun.Ident("status.id"), minID)  	}  	if pinnedOnly { -		q = q.Where("pinned = ?", true) +		q = q.Where("? = ?", bun.Ident("status.pinned"), true)  	}  	if mediaOnly { @@ -352,15 +388,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li  			switch a.conn.Dialect().Name() {  			case dialect.PG:  				return q. -					Where("? IS NOT NULL", bun.Ident("attachments")). -					Where("? != '{}'", bun.Ident("attachments")) +					Where("? IS NOT NULL", bun.Ident("status.attachments")). +					Where("? != '{}'", bun.Ident("status.attachments"))  			case dialect.SQLite:  				return q. -					Where("? IS NOT NULL", bun.Ident("attachments")). -					Where("? != ''", bun.Ident("attachments")). -					Where("? != 'null'", bun.Ident("attachments")). -					Where("? != '{}'", bun.Ident("attachments")). -					Where("? != '[]'", bun.Ident("attachments")) +					Where("? IS NOT NULL", bun.Ident("status.attachments")). +					Where("? != ''", bun.Ident("status.attachments")). +					Where("? != 'null'", bun.Ident("status.attachments")). +					Where("? != '{}'", bun.Ident("status.attachments")). +					Where("? != '[]'", bun.Ident("status.attachments"))  			default:  				log.Panic("db dialect was neither pg nor sqlite")  				return q @@ -369,7 +405,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li  	}  	if publicOnly { -		q = q.Where("visibility = ?", gtsmodel.VisibilityPublic) +		q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic)  	}  	if err := q.Scan(ctx, &statusIDs); err != nil { @@ -384,19 +420,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,  	q := a.conn.  		NewSelect(). -		Table("statuses"). -		Column("id"). -		Where("account_id = ?", accountID). -		WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")). -		WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")). -		Where("visibility = ?", gtsmodel.VisibilityPublic). -		Where("federated = ?", true) +		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +		Column("status.id"). +		Where("? = ?", bun.Ident("status.account_id"), accountID). +		WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). +		WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). +		Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). +		Where("? = ?", bun.Ident("status.federated"), true)  	if maxID != "" { -		q = q.Where("id < ?", maxID) +		q = q.Where("? < ?", bun.Ident("status.id"), maxID)  	} -	q = q.Limit(limit).Order("id DESC") +	q = q.Limit(limit).Order("status.id DESC")  	if err := q.Scan(ctx, &statusIDs); err != nil {  		return nil, a.conn.ProcessError(err) @@ -411,16 +447,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI  	fq := a.conn.  		NewSelect().  		Model(&blocks). -		Where("block.account_id = ?", accountID). +		Where("? = ?", bun.Ident("block.account_id"), accountID).  		Relation("TargetAccount").  		Order("block.id DESC")  	if maxID != "" { -		fq = fq.Where("block.id < ?", maxID) +		fq = fq.Where("? < ?", bun.Ident("block.id"), maxID)  	}  	if sinceID != "" { -		fq = fq.Where("block.id > ?", sinceID) +		fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID)  	}  	if limit > 0 { diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index ad2a217af..72adba487 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -42,6 +42,18 @@ func (suite *AccountTestSuite) TestGetAccountStatuses() {  	suite.Len(statuses, 5)  } +func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() { +	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false, false) +	suite.NoError(err) +	suite.Len(statuses, 5) +} + +func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() { +	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false, true) +	suite.NoError(err) +	suite.Len(statuses, 1) +} +  func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() {  	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, true, false)  	suite.NoError(err) @@ -99,7 +111,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {  	err = dbService.GetConn().  		NewSelect().  		Model(noCache). -		Where("account.id = ?", bun.Ident(testAccount.ID)). +		Where("? = ?", bun.Ident("account.id"), testAccount.ID).  		Relation("AvatarMediaAttachment").  		Relation("HeaderMediaAttachment").  		Relation("Emojis"). @@ -127,7 +139,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {  	err = dbService.GetConn().  		NewSelect().  		Model(noCache). -		Where("account.id = ?", bun.Ident(testAccount.ID)). +		Where("? = ?", bun.Ident("account.id"), testAccount.ID).  		Relation("AvatarMediaAttachment").  		Relation("HeaderMediaAttachment").  		Relation("Emojis"). diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 9fa78eca0..44861a4bb 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -22,7 +22,6 @@ import (  	"context"  	"crypto/rand"  	"crypto/rsa" -	"database/sql"  	"fmt"  	"net"  	"net/mail" @@ -37,21 +36,26 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/id"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/uris" +	"github.com/uptrace/bun"  	"golang.org/x/crypto/bcrypt"  ) +// generate RSA keys of this length +const rsaKeyBits = 2048 +  type adminDB struct { -	conn      *DBConn -	userCache *cache.UserCache +	conn         *DBConn +	userCache    *cache.UserCache +	accountCache *cache.AccountCache  }  func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {  	q := a.conn.  		NewSelect(). -		Model(>smodel.Account{}). -		Where("username = ?", username). -		Where("domain = ?", nil) - +		TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). +		Column("account.id"). +		Where("? = ?", bun.Ident("account.username"), username). +		Where("? IS NULL", bun.Ident("account.domain"))  	return a.conn.NotExists(ctx, q)  } @@ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.  	domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @  	// check if the email domain is blocked -	if err := a.conn. +	emailDomainBlockedQ := a.conn.  		NewSelect(). -		Model(>smodel.EmailDomainBlock{}). -		Where("domain = ?", domain). -		Scan(ctx); err == nil { -		// fail because we found something +		TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")). +		Column("email_domain_block.id"). +		Where("? = ?", bun.Ident("email_domain_block.domain"), domain) +	emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ) +	if err != nil { +		return false, err +	} +	if emailDomainBlocked {  		return false, fmt.Errorf("email domain %s is blocked", domain) -	} else if err != sql.ErrNoRows { -		return false, a.conn.ProcessError(err)  	}  	// check if this email is associated with a user already  	q := a.conn.  		NewSelect(). -		Model(>smodel.User{}). -		Where("email = ?", email). -		WhereOr("unconfirmed_email = ?", email) - +		TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). +		Column("user.id"). +		Where("? = ?", bun.Ident("user.email"), email). +		WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email)  	return a.conn.NotExists(ctx, q)  }  func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { -	key, err := rsa.GenerateKey(rand.Reader, 2048) +	key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)  	if err != nil {  		log.Errorf("error creating new rsa key: %s", err)  		return nil, err @@ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,  	// if something went wrong while creating a user, we might already have an account, so check here first...  	acct := >smodel.Account{} -	q := a.conn.NewSelect(). +	if err := a.conn. +		NewSelect().  		Model(acct). -		Where("username = ?", username). -		WhereGroup(" AND ", whereEmptyOrNull("domain")) +		Where("? = ?", bun.Ident("account.username"), username). +		WhereGroup(" AND ", whereEmptyOrNull("account.domain")). +		Scan(ctx); err != nil { +		err = a.conn.ProcessError(err) +		if err != db.ErrNoEntries { +			log.Errorf("error checking for existing account: %s", err) +			return nil, err +		} -	if err := q.Scan(ctx); err != nil { -		// we just don't have an account yet so create one before we proceed +		// if we have db.ErrNoEntries, we just don't have an +		// account yet so create one before we proceed  		accountURIs := uris.GenerateURIsForAccount(username)  		accountID, err := id.NewRandomULID()  		if err != nil { @@ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,  			FeaturedCollectionURI: accountURIs.CollectionURI,  		} +		// insert the new account!  		if _, err = a.conn.  			NewInsert().  			Model(acct).  			Exec(ctx); err != nil {  			return nil, a.conn.ProcessError(err)  		} +		a.accountCache.Put(acct)  	} +	// we either created or already had an account by now, +	// so proceed with creating a user for that account +  	pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)  	if err != nil {  		return nil, fmt.Errorf("error hashing password: %s", err) @@ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,  		u.Moderator = &moderator  	} +	// insert the user!  	if _, err = a.conn.  		NewInsert().  		Model(u). @@ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {  	q := a.conn.  		NewSelect(). -		Model(>smodel.Account{}). -		Where("username = ?", username). -		WhereGroup(" AND ", whereEmptyOrNull("domain")) +		TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). +		Column("account.id"). +		Where("? = ?", bun.Ident("account.username"), username). +		WhereGroup(" AND ", whereEmptyOrNull("account.domain"))  	exists, err := a.conn.Exists(ctx, q)  	if err != nil { @@ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {  		return nil  	} -	key, err := rsa.GenerateKey(rand.Reader, 2048) +	key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)  	if err != nil {  		log.Errorf("error creating new rsa key: %s", err)  		return err @@ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {  		return a.conn.ProcessError(err)  	} +	a.accountCache.Put(acct)  	log.Infof("instance account %s CREATED with id %s", username, acct.ID)  	return nil  } @@ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {  	// check if instance entry already exists  	q := a.conn.  		NewSelect(). -		Model(>smodel.Instance{}). -		Where("domain = ?", host) +		Column("instance.id"). +		TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). +		Where("? = ?", bun.Ident("instance.domain"), host)  	exists, err := a.conn.Exists(ctx, q)  	if err != nil { diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go index 22041087a..f0a869a9b 100644 --- a/internal/db/bundb/admin_test.go +++ b/internal/db/bundb/admin_test.go @@ -23,6 +23,7 @@ import (  	"testing"  	"github.com/stretchr/testify/suite" +	gtsmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20211113114307_init"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -30,6 +31,44 @@ type AdminTestSuite struct {  	BunDBStandardTestSuite  } +func (suite *AdminTestSuite) TestIsUsernameAvailableNo() { +	available, err := suite.db.IsUsernameAvailable(context.Background(), "the_mighty_zork") +	suite.NoError(err) +	suite.False(available) +} + +func (suite *AdminTestSuite) TestIsUsernameAvailableYes() { +	available, err := suite.db.IsUsernameAvailable(context.Background(), "someone_completely_different") +	suite.NoError(err) +	suite.True(available) +} + +func (suite *AdminTestSuite) TestIsEmailAvailableNo() { +	available, err := suite.db.IsEmailAvailable(context.Background(), "zork@example.org") +	suite.NoError(err) +	suite.False(available) +} + +func (suite *AdminTestSuite) TestIsEmailAvailableYes() { +	available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com") +	suite.NoError(err) +	suite.True(available) +} + +func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() { +	if err := suite.db.Put(context.Background(), >smodel.EmailDomainBlock{ +		ID:                 "01GEEV2R2YC5GRSN96761YJE47", +		Domain:             "somewhere.com", +		CreatedByAccountID: suite.testAccounts["admin_account"].ID, +	}); err != nil { +		suite.FailNow(err.Error()) +	} + +	available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com") +	suite.EqualError(err, "email domain somewhere.com is blocked") +	suite.False(available) +} +  func (suite *AdminTestSuite) TestCreateInstanceAccount() {  	// we need to take an empty db for this...  	testrig.StandardDBTeardown(suite.db) diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index cd80c9330..ef8b35574 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -94,12 +94,12 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface  	return b.conn.ProcessError(err)  } -func (b *basicDB) UpdateByPrimaryKey(ctx context.Context, i interface{}, columns ...string) db.Error { +func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) db.Error {  	q := b.conn.  		NewUpdate().  		Model(i).  		Column(columns...). -		WherePK() +		Where("? = ?", bun.Ident("id"), id)  	_, err := q.Exec(ctx)  	return b.conn.ProcessError(err) @@ -110,7 +110,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,  	updateWhere(q, where) -	q = q.Set("? = ?", bun.Safe(key), value) +	q = q.Set("? = ?", bun.Ident(key), value)  	_, err := q.Exec(ctx)  	return b.conn.ProcessError(err) diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 70a44d4c1..02522e6f7 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -159,17 +159,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {  		return nil, fmt.Errorf("db migration error: %s", err)  	} -	// Create DB structs that require ptrs to each other -	accounts := &accountDB{conn: conn, cache: cache.NewAccountCache()} -	status := &statusDB{conn: conn, cache: cache.NewStatusCache()} -	emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()} -	timeline := &timelineDB{conn: conn} - -	// Setup DB cross-referencing -	accounts.status = status -	status.accounts = accounts -	timeline.status = status +	// Prepare caches required by more than one struct +	userCache := cache.NewUserCache() +	accountCache := cache.NewAccountCache() +	// Prepare other caches  	// Prepare mentions cache  	// TODO: move into internal/cache  	mentionCache := grufcache.New[string, *gtsmodel.Mention]() @@ -182,22 +176,30 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {  	notifCache.SetTTL(time.Minute*5, false)  	notifCache.Start(time.Second * 10) -	// Prepare other caches -	blockCache := cache.NewDomainBlockCache() -	userCache := cache.NewUserCache() +	// Create DB structs that require ptrs to each other +	accounts := &accountDB{conn: conn, cache: accountCache} +	status := &statusDB{conn: conn, cache: cache.NewStatusCache()} +	emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()} +	timeline := &timelineDB{conn: conn} + +	// Setup DB cross-referencing +	accounts.status = status +	status.accounts = accounts +	timeline.status = status  	ps := &DBService{  		Account: accounts,  		Admin: &adminDB{ -			conn:      conn, -			userCache: userCache, +			conn:         conn, +			userCache:    userCache, +			accountCache: accountCache,  		},  		Basic: &basicDB{  			conn: conn,  		},  		Domain: &domainDB{  			conn:  conn, -			cache: blockCache, +			cache: cache.NewDomainBlockCache(),  		},  		Emoji: emoji,  		Instance: &instanceDB{ diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go index 581573056..2af6cf122 100644 --- a/internal/db/bundb/bundb_test.go +++ b/internal/db/bundb/bundb_test.go @@ -40,6 +40,7 @@ type BunDBStandardTestSuite struct {  	testStatuses     map[string]*gtsmodel.Status  	testTags         map[string]*gtsmodel.Tag  	testMentions     map[string]*gtsmodel.Mention +	testFollows      map[string]*gtsmodel.Follow  }  func (suite *BunDBStandardTestSuite) SetupSuite() { @@ -52,6 +53,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {  	suite.testStatuses = testrig.NewTestStatuses()  	suite.testTags = testrig.NewTestTags()  	suite.testMentions = testrig.NewTestMentions() +	suite.testFollows = testrig.NewTestFollows()  }  func (suite *BunDBStandardTestSuite) SetupTest() { diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 9fc4bb276..0a752d3f3 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -28,6 +28,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun"  	"golang.org/x/net/idna"  ) @@ -95,7 +96,7 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel  	q := d.conn.  		NewSelect().  		Model(block). -		Where("domain = ?", domain). +		Where("? = ?", bun.Ident("domain_block.domain"), domain).  		Limit(1)  	// Query database for domain block @@ -126,7 +127,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro  	// Attempt to delete domain block  	if _, err := d.conn.NewDelete().  		Model((*gtsmodel.DomainBlock)(nil)). -		Where("domain = ?", domain). +		Where("? = ?", bun.Ident("domain_block.domain"), domain).  		Exec(ctx); err != nil {  		return d.conn.ProcessError(err)  	} diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 758da0feb..e781e2f00 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -54,12 +54,12 @@ func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Er  	q := e.conn.  		NewSelect(). -		Table("emojis"). -		Column("id"). -		Where("visible_in_picker = true"). -		Where("disabled = false"). -		Where("domain IS NULL"). -		Order("shortcode ASC") +		TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")). +		Column("emoji.id"). +		Where("? = ?", bun.Ident("emoji.visible_in_picker"), true). +		Where("? = ?", bun.Ident("emoji.disabled"), false). +		Where("? IS NULL", bun.Ident("emoji.domain")). +		Order("emoji.shortcode ASC")  	if err := q.Scan(ctx, &emojiIDs); err != nil {  		return nil, e.conn.ProcessError(err) @@ -75,7 +75,7 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji,  			return e.cache.GetByID(id)  		},  		func(emoji *gtsmodel.Emoji) error { -			return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx) +			return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)  		},  	)  } @@ -87,7 +87,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj  			return e.cache.GetByURI(uri)  		},  		func(emoji *gtsmodel.Emoji) error { -			return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx) +			return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)  		},  	)  } @@ -102,11 +102,11 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin  			q := e.newEmojiQ(emoji)  			if domain != "" { -				q = q.Where("emoji.shortcode = ?", shortcode) -				q = q.Where("emoji.domain = ?", domain) +				q = q.Where("? = ?", bun.Ident("emoji.shortcode"), shortcode) +				q = q.Where("? = ?", bun.Ident("emoji.domain"), domain)  			} else { -				q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode)) -				q = q.Where("emoji.domain IS NULL") +				q = q.Where("? = ?", bun.Ident("emoji.shortcode"), strings.ToLower(shortcode)) +				q = q.Where("? IS NULL", bun.Ident("emoji.domain"))  			}  			return q.Scan(ctx) diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index fb6454e2f..604461708 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -24,7 +24,6 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/uptrace/bun"  ) @@ -35,15 +34,16 @@ type instanceDB struct {  func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {  	q := i.conn.  		NewSelect(). -		Model(&[]*gtsmodel.Account{}). -		Where("username != ?", domain). -		Where("? IS NULL", bun.Ident("suspended_at")) +		TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). +		Column("account.id"). +		Where("? != ?", bun.Ident("account.username"), domain). +		Where("? IS NULL", bun.Ident("account.suspended_at")) -	if domain == config.GetHost() { +	if domain == config.GetHost() || domain == config.GetAccountDomain() {  		// if the domain is *this* domain, just count where the domain field is null -		q = q.WhereGroup(" AND ", whereEmptyOrNull("domain")) +		q = q.WhereGroup(" AND ", whereEmptyOrNull("account.domain"))  	} else { -		q = q.Where("domain = ?", domain) +		q = q.Where("? = ?", bun.Ident("account.domain"), domain)  	}  	count, err := q.Count(ctx) @@ -56,15 +56,16 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int  func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {  	q := i.conn.  		NewSelect(). -		Model(&[]*gtsmodel.Status{}) +		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")) -	if domain == config.GetHost() { +	if domain == config.GetHost() || domain == config.GetAccountDomain() {  		// if the domain is *this* domain, just count where local is true -		q = q.Where("local = ?", true) +		q = q.Where("? = ?", bun.Ident("status.local"), true)  	} else {  		// join on the domain of the account -		q = q.Join("JOIN accounts AS account ON account.id = status.account_id"). -			Where("account.domain = ?", domain) +		q = q. +			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("account.id"), bun.Ident("status.account_id")). +			Where("? = ?", bun.Ident("account.domain"), domain)  	}  	count, err := q.Count(ctx) @@ -77,14 +78,14 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (  func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {  	q := i.conn.  		NewSelect(). -		Model(&[]*gtsmodel.Instance{}) +		TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance"))  	if domain == config.GetHost() {  		// if the domain is *this* domain, just count other instances it knows about  		// exclude domains that are blocked  		q = q. -			Where("domain != ?", domain). -			Where("? IS NULL", bun.Ident("suspended_at")) +			Where("? != ?", bun.Ident("instance.domain"), domain). +			Where("? IS NULL", bun.Ident("instance.suspended_at"))  	} else {  		// TODO: implement federated domain counting properly for remote domains  		return 0, nil @@ -103,10 +104,10 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool  	q := i.conn.  		NewSelect().  		Model(&instances). -		Where("domain != ?", config.GetHost()) +		Where("? != ?", bun.Ident("instance.domain"), config.GetHost())  	if !includeSuspended { -		q = q.Where("? IS NULL", bun.Ident("suspended_at")) +		q = q.Where("? IS NULL", bun.Ident("instance.suspended_at"))  	}  	if err := q.Scan(ctx); err != nil { @@ -117,17 +118,15 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool  }  func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { -	log.Debug("GetAccountsForInstance") -  	accounts := []*gtsmodel.Account{}  	q := i.conn.NewSelect().  		Model(&accounts). -		Where("domain = ?", domain). -		Order("id DESC") +		Where("? = ?", bun.Ident("account.domain"), domain). +		Order("account.id DESC")  	if maxID != "" { -		q = q.Where("id < ?", maxID) +		q = q.Where("? < ?", bun.Ident("account.id"), maxID)  	}  	if limit > 0 { diff --git a/internal/db/bundb/instance_test.go b/internal/db/bundb/instance_test.go new file mode 100644 index 000000000..50d118888 --- /dev/null +++ b/internal/db/bundb/instance_test.go @@ -0,0 +1,83 @@ +/* +   GoToSocial +   Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( +	"context" +	"testing" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/config" +) + +type InstanceTestSuite struct { +	BunDBStandardTestSuite +} + +func (suite *InstanceTestSuite) TestCountInstanceUsers() { +	count, err := suite.db.CountInstanceUsers(context.Background(), config.GetHost()) +	suite.NoError(err) +	suite.Equal(4, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() { +	count, err := suite.db.CountInstanceUsers(context.Background(), "fossbros-anonymous.io") +	suite.NoError(err) +	suite.Equal(1, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceStatuses() { +	count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost()) +	suite.NoError(err) +	suite.Equal(16, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() { +	count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io") +	suite.NoError(err) +	suite.Equal(1, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceDomains() { +	count, err := suite.db.CountInstanceDomains(context.Background(), config.GetHost()) +	suite.NoError(err) +	suite.Equal(2, count) +} + +func (suite *InstanceTestSuite) TestGetInstancePeers() { +	peers, err := suite.db.GetInstancePeers(context.Background(), false) +	suite.NoError(err) +	suite.Len(peers, 2) +} + +func (suite *InstanceTestSuite) TestGetInstancePeersIncludeSuspended() { +	peers, err := suite.db.GetInstancePeers(context.Background(), true) +	suite.NoError(err) +	suite.Len(peers, 2) +} + +func (suite *InstanceTestSuite) TestGetInstanceAccounts() { +	accounts, err := suite.db.GetInstanceAccounts(context.Background(), "fossbros-anonymous.io", "", 10) +	suite.NoError(err) +	suite.Len(accounts, 1) +} + +func TestInstanceTestSuite(t *testing.T) { +	suite.Run(t, new(InstanceTestSuite)) +} diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index 71433b901..39e0ad0e3 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -42,7 +42,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M  	attachment := >smodel.MediaAttachment{}  	q := m.newMediaQ(attachment). -		Where("media_attachment.id = ?", id) +		Where("? = ?", bun.Ident("media_attachment.id"), id)  	if err := q.Scan(ctx); err != nil {  		return nil, m.conn.ProcessError(err) @@ -56,10 +56,10 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l  	q := m.conn.  		NewSelect().  		Model(&attachments). -		Where("media_attachment.cached = true"). -		Where("media_attachment.avatar = false"). -		Where("media_attachment.header = false"). -		Where("media_attachment.created_at < ?", olderThan). +		Where("? = ?", bun.Ident("media_attachment.cached"), true). +		Where("? = ?", bun.Ident("media_attachment.avatar"), false). +		Where("? = ?", bun.Ident("media_attachment.header"), false). +		Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).  		WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")).  		Order("media_attachment.created_at DESC") @@ -79,13 +79,13 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit  	q := m.newMediaQ(&attachments).  		WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery {  			return innerQ. -				WhereOr("media_attachment.avatar = true"). -				WhereOr("media_attachment.header = true") +				WhereOr("? = ?", bun.Ident("media_attachment.avatar"), true). +				WhereOr("? = ?", bun.Ident("media_attachment.header"), true)  		}).  		Order("media_attachment.id DESC")  	if maxID != "" { -		q = q.Where("media_attachment.id < ?", maxID) +		q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)  	}  	if limit != 0 { @@ -103,15 +103,15 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim  	attachments := []*gtsmodel.MediaAttachment{}  	q := m.newMediaQ(&attachments). -		Where("media_attachment.cached = true"). -		Where("media_attachment.avatar = false"). -		Where("media_attachment.header = false"). -		Where("media_attachment.created_at < ?", olderThan). -		Where("media_attachment.remote_url IS NULL"). -		Where("media_attachment.status_id IS NULL") +		Where("? = ?", bun.Ident("media_attachment.cached"), true). +		Where("? = ?", bun.Ident("media_attachment.avatar"), false). +		Where("? = ?", bun.Ident("media_attachment.header"), false). +		Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan). +		Where("? IS NULL", bun.Ident("media_attachment.remote_url")). +		Where("? IS NULL", bun.Ident("media_attachment.status_id"))  	if maxID != "" { -		q = q.Where("media_attachment.id < ?", maxID) +		q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)  	}  	if limit != 0 { diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index e2c83ef3f..355078021 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -46,7 +46,7 @@ func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Ment  	mention := gtsmodel.Mention{}  	q := m.newMentionQ(&mention). -		Where("mention.id = ?", id) +		Where("? = ?", bun.Ident("mention.id"), id)  	if err := q.Scan(ctx); err != nil {  		return nil, m.conn.ProcessError(err) diff --git a/internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go b/internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go index 4c4ada594..b0179ec4f 100644 --- a/internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go +++ b/internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go @@ -47,8 +47,8 @@ func init() {  		}  		if _, err := tx.NewDelete(). -			Model(a). -			WherePK(). +			TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). +			Where("? = ?", bun.Ident("media_attachment.id"), a.ID).  			Exec(ctx); err != nil {  			l.Errorf("error deleting attachment with id %s: %s", a.ID, err)  		} else { diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 32523ca24..69e3cf39f 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -25,6 +25,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/uptrace/bun"  )  type notificationDB struct { @@ -44,7 +45,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo  		Relation("OriginAccount").  		Relation("TargetAccount").  		Relation("Status"). -		WherePK() +		Where("? = ?", bun.Ident("notification.id"), id)  	if err := q.Scan(ctx); err != nil {  		return nil, n.conn.ProcessError(err) @@ -67,24 +68,24 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,  	q := n.conn.  		NewSelect(). -		Table("notifications"). -		Column("id") +		TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). +		Column("notification.id")  	if maxID != "" { -		q = q.Where("id < ?", maxID) +		q = q.Where("? < ?", bun.Ident("notification.id"), maxID)  	}  	if sinceID != "" { -		q = q.Where("id > ?", sinceID) +		q = q.Where("? > ?", bun.Ident("notification.id"), sinceID)  	}  	for _, excludeType := range excludeTypes { -		q = q.Where("notification_type != ?", excludeType) +		q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType)  	}  	q = q. -		Where("target_account_id = ?", accountID). -		Order("id DESC") +		Where("? = ?", bun.Ident("notification.target_account_id"), accountID). +		Order("notification.id DESC")  	if limit != 0 {  		q = q.Limit(limit) @@ -116,13 +117,12 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,  func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error {  	if _, err := n.conn.  		NewDelete(). -		Table("notifications"). -		Where("target_account_id = ?", accountID). +		TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). +		Where("? = ?", bun.Ident("notification.target_account_id"), accountID).  		Exec(ctx); err != nil {  		return n.conn.ProcessError(err)  	}  	n.cache.Clear() -  	return nil  } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index ba72a053a..66e48e441 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {  func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {  	q := r.conn.  		NewSelect(). -		Model(>smodel.Block{}). -		ExcludeColumn("id", "created_at", "updated_at", "uri"). -		Limit(1) +		TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). +		Column("block.id")  	if eitherDirection {  		q = q.  			WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {  				return inner. -					Where("account_id = ?", account1). -					Where("target_account_id = ?", account2) +					Where("? = ?", bun.Ident("block.account_id"), account1). +					Where("? = ?", bun.Ident("block.target_account_id"), account2)  			}).  			WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {  				return inner. -					Where("account_id = ?", account2). -					Where("target_account_id = ?", account1) +					Where("? = ?", bun.Ident("block.account_id"), account2). +					Where("? = ?", bun.Ident("block.target_account_id"), account1)  			})  	} else {  		q = q. -			Where("account_id = ?", account1). -			Where("target_account_id = ?", account2) +			Where("? = ?", bun.Ident("block.account_id"), account1). +			Where("? = ?", bun.Ident("block.target_account_id"), account2)  	}  	return r.conn.Exists(ctx, q) @@ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2  	block := >smodel.Block{}  	q := r.newBlockQ(block). -		Where("block.account_id = ?", account1). -		Where("block.target_account_id = ?", account2) +		Where("? = ?", bun.Ident("block.account_id"), account1). +		Where("? = ?", bun.Ident("block.target_account_id"), account2)  	if err := q.Scan(ctx); err != nil {  		return nil, r.conn.ProcessError(err) @@ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  	if err := r.conn.  		NewSelect().  		Model(follow). -		Where("account_id = ?", requestingAccount). -		Where("target_account_id = ?", targetAccount). +		Column("follow.show_reblogs", "follow.notify"). +		Where("? = ?", bun.Ident("follow.account_id"), requestingAccount). +		Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).  		Limit(1).  		Scan(ctx); err != nil { -		if err != sql.ErrNoRows { -			// a proper error -			return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) +		if err := r.conn.ProcessError(err); err != db.ErrNoEntries { +			return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)  		}  		// no follow exists so these are all false  		rel.Following = false @@ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  	}  	// check if the target account follows the requesting account -	count, err := r.conn. +	followedByQ := r.conn.  		NewSelect(). -		Model(>smodel.Follow{}). -		Where("account_id = ?", targetAccount). -		Where("target_account_id = ?", requestingAccount). -		Limit(1). -		Count(ctx) +		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). +		Column("follow.id"). +		Where("? = ?", bun.Ident("follow.account_id"), targetAccount). +		Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount) +	followedBy, err := r.conn.Exists(ctx, followedByQ)  	if err != nil { -		return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) +		return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)  	} -	rel.FollowedBy = count > 0 +	rel.FollowedBy = followedBy -	// check if the requesting account blocks the target account -	count, err = r.conn.NewSelect(). -		Model(>smodel.Block{}). -		Where("account_id = ?", requestingAccount). -		Where("target_account_id = ?", targetAccount). -		Limit(1). -		Count(ctx) +	// check if there's a pending following request from requesting account to target account +	requestedQ := r.conn. +		NewSelect(). +		TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). +		Column("follow_request.id"). +		Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount). +		Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount) +	requested, err := r.conn.Exists(ctx, requestedQ)  	if err != nil { -		return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) +		return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)  	} -	rel.Blocking = count > 0 +	rel.Requested = requested -	// check if the target account blocks the requesting account -	count, err = r.conn. +	// check if the requesting account is blocking the target account +	blockingQ := r.conn.  		NewSelect(). -		Model(>smodel.Block{}). -		Where("account_id = ?", targetAccount). -		Where("target_account_id = ?", requestingAccount). -		Limit(1). -		Count(ctx) +		TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). +		Column("block.id"). +		Where("? = ?", bun.Ident("block.account_id"), requestingAccount). +		Where("? = ?", bun.Ident("block.target_account_id"), targetAccount) +	blocking, err := r.conn.Exists(ctx, blockingQ)  	if err != nil { -		return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) +		return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)  	} -	rel.BlockedBy = count > 0 +	rel.Blocking = blocking -	// check if there's a pending following request from requesting account to target account -	count, err = r.conn. +	// check if the requesting account is blocked by the target account +	blockedByQ := r.conn.  		NewSelect(). -		Model(>smodel.FollowRequest{}). -		Where("account_id = ?", requestingAccount). -		Where("target_account_id = ?", targetAccount). -		Limit(1). -		Count(ctx) +		TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). +		Column("block.id"). +		Where("? = ?", bun.Ident("block.account_id"), targetAccount). +		Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount) +	blockedBy, err := r.conn.Exists(ctx, blockedByQ)  	if err != nil { -		return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) +		return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)  	} -	rel.Requested = count > 0 +	rel.BlockedBy = blockedBy  	return rel, nil  } @@ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode  	q := r.conn.  		NewSelect(). -		Model(>smodel.Follow{}). -		Where("account_id = ?", sourceAccount.ID). -		Where("target_account_id = ?", targetAccount.ID). -		Limit(1) +		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). +		Column("follow.id"). +		Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID). +		Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)  	return r.conn.Exists(ctx, q)  } @@ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g  	q := r.conn.  		NewSelect(). -		Model(>smodel.FollowRequest{}). -		Where("account_id = ?", sourceAccount.ID). -		Where("target_account_id = ?", targetAccount.ID) +		TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). +		Column("follow_request.id"). +		Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID). +		Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)  	return r.conn.Exists(ctx, q)  } @@ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod  }  func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { -	// make sure the original follow request exists -	fr := >smodel.FollowRequest{} -	if err := r.conn. -		NewSelect(). -		Model(fr). -		Where("account_id = ?", originAccountID). -		Where("target_account_id = ?", targetAccountID). -		Scan(ctx); err != nil { -		return nil, r.conn.ProcessError(err) -	} +	var follow *gtsmodel.Follow + +	if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { +		// get original follow request +		followRequest := >smodel.FollowRequest{} +		if err := tx. +			NewSelect(). +			Model(followRequest). +			Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). +			Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). +			Scan(ctx); err != nil { +			return err +		} -	// create a new follow to 'replace' the request with -	follow := >smodel.Follow{ -		ID:              fr.ID, -		AccountID:       originAccountID, -		TargetAccountID: targetAccountID, -		URI:             fr.URI, -	} +		// create a new follow to 'replace' the request with +		follow = >smodel.Follow{ +			ID:              followRequest.ID, +			AccountID:       originAccountID, +			TargetAccountID: targetAccountID, +			URI:             followRequest.URI, +		} -	// if the follow already exists, just update the URI -- we don't need to do anything else -	if _, err := r.conn. -		NewInsert(). -		Model(follow). -		On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI). -		Exec(ctx); err != nil { -		return nil, r.conn.ProcessError(err) -	} +		// if the follow already exists, just update the URI -- we don't need to do anything else +		if _, err := tx. +			NewInsert(). +			Model(follow). +			On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). +			Exec(ctx); err != nil { +			return err +		} + +		// now remove the follow request +		if _, err := tx. +			NewDelete(). +			TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). +			Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). +			Exec(ctx); err != nil { +			return err +		} -	// now remove the follow request -	if _, err := r.conn. -		NewDelete(). -		Model(>smodel.FollowRequest{}). -		Where("account_id = ?", originAccountID). -		Where("target_account_id = ?", targetAccountID). -		Exec(ctx); err != nil { +		return nil +	}); err != nil {  		return nil, r.conn.ProcessError(err)  	} +	// return the new follow  	return follow, nil  }  func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) { -	// first get the follow request out of the database -	fr := >smodel.FollowRequest{} -	if err := r.conn. -		NewSelect(). -		Model(fr). -		Where("account_id = ?", originAccountID). -		Where("target_account_id = ?", targetAccountID). -		Scan(ctx); err != nil { -		return nil, r.conn.ProcessError(err) -	} +	followRequest := >smodel.FollowRequest{} + +	if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { +		// get original follow request +		if err := tx. +			NewSelect(). +			Model(followRequest). +			Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). +			Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). +			Scan(ctx); err != nil { +			return err +		} -	// now delete it from the database by ID -	if _, err := r.conn. -		NewDelete(). -		Model(>smodel.FollowRequest{ID: fr.ID}). -		WherePK(). -		Exec(ctx); err != nil { +		// now delete it from the database by ID +		if _, err := tx. +			NewDelete(). +			TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). +			Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). +			Exec(ctx); err != nil { +			return err +		} + +		return nil +	}); err != nil {  		return nil, r.conn.ProcessError(err)  	}  	// return the deleted follow request -	return fr, nil +	return followRequest, nil  }  func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {  	followRequests := []*gtsmodel.FollowRequest{}  	q := r.newFollowQ(&followRequests). -		Where("target_account_id = ?", accountID). +		Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID).  		Order("follow_request.updated_at DESC")  	if err := q.Scan(ctx); err != nil {  		return nil, r.conn.ProcessError(err)  	} +  	return followRequests, nil  } @@ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string  	follows := []*gtsmodel.Follow{}  	q := r.newFollowQ(&follows). -		Where("account_id = ?", accountID). +		Where("? = ?", bun.Ident("follow.account_id"), accountID).  		Order("follow.updated_at DESC")  	if err := q.Scan(ctx); err != nil {  		return nil, r.conn.ProcessError(err)  	} +  	return follows, nil  }  func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { -	return r.conn. +	q := r.conn.  		NewSelect(). -		Model(&[]*gtsmodel.Follow{}). -		Where("account_id = ?", accountID). -		Count(ctx) +		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) + +	if localOnly { +		q = q. +			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")). +			Where("? = ?", bun.Ident("follow.account_id"), accountID). +			Where("? IS NULL", bun.Ident("account.domain")) +	} else { +		q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID) +	} + +	return q.Count(ctx)  }  func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { @@ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str  		Order("follow.updated_at DESC")  	if localOnly { -		q = q.ColumnExpr("follow.*"). -			Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)"). -			Where("follow.target_account_id = ?", accountID). -			WhereGroup(" AND ", whereEmptyOrNull("a.domain")) +		q = q. +			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). +			Where("? = ?", bun.Ident("follow.target_account_id"), accountID). +			Where("? IS NULL", bun.Ident("account.domain"))  	} else { -		q = q.Where("target_account_id = ?", accountID) +		q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)  	}  	err := q.Scan(ctx) @@ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str  }  func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { -	return r.conn. +	q := r.conn.  		NewSelect(). -		Model(&[]*gtsmodel.Follow{}). -		Where("target_account_id = ?", accountID). -		Count(ctx) +		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) + +	if localOnly { +		q = q. +			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). +			Where("? = ?", bun.Ident("follow.target_account_id"), accountID). +			Where("? IS NULL", bun.Ident("account.domain")) +	} else { +		q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) +	} + +	return q.Count(ctx)  } diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index 34fe85a57..3df16e2f3 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -20,7 +20,6 @@ package bundb_test  import (  	"context" -	"errors"  	"testing"  	"github.com/stretchr/testify/suite" @@ -48,12 +47,14 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {  	suite.False(blocked)  	// have account1 block account2 -	suite.db.Put(ctx, >smodel.Block{ +	if err := suite.db.Put(ctx, >smodel.Block{  		ID:              "01G202BCSXXJZ70BHB5KCAHH8C",  		URI:             "http://localhost:8080/some_block_uri_1",  		AccountID:       account1,  		TargetAccountID: account2, -	}) +	}); err != nil { +		suite.FailNow(err.Error()) +	}  	// account 1 now blocks account 2  	blocked, err = suite.db.IsBlocked(ctx, account1, account2, false) @@ -75,62 +76,242 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {  }  func (suite *RelationshipTestSuite) TestGetBlock() { -	suite.Suite.T().Skip("TODO: implement") +	ctx := context.Background() + +	account1 := suite.testAccounts["local_account_1"].ID +	account2 := suite.testAccounts["local_account_2"].ID + +	if err := suite.db.Put(ctx, >smodel.Block{ +		ID:              "01G202BCSXXJZ70BHB5KCAHH8C", +		URI:             "http://localhost:8080/some_block_uri_1", +		AccountID:       account1, +		TargetAccountID: account2, +	}); err != nil { +		suite.FailNow(err.Error()) +	} + +	block, err := suite.db.GetBlock(ctx, account1, account2) +	suite.NoError(err) +	suite.NotNil(block) +	suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)  }  func (suite *RelationshipTestSuite) TestGetRelationship() { -	suite.Suite.T().Skip("TODO: implement") +	requestingAccount := suite.testAccounts["local_account_1"] +	targetAccount := suite.testAccounts["admin_account"] + +	relationship, err := suite.db.GetRelationship(context.Background(), requestingAccount.ID, targetAccount.ID) +	suite.NoError(err) +	suite.NotNil(relationship) + +	suite.True(relationship.Following) +	suite.True(relationship.ShowingReblogs) +	suite.False(relationship.Notifying) +	suite.True(relationship.FollowedBy) +	suite.False(relationship.Blocking) +	suite.False(relationship.BlockedBy) +	suite.False(relationship.Muting) +	suite.False(relationship.MutingNotifications) +	suite.False(relationship.Requested) +	suite.False(relationship.DomainBlocking) +	suite.False(relationship.Endorsed) +	suite.Empty(relationship.Note) +} + +func (suite *RelationshipTestSuite) TestIsFollowingYes() { +	requestingAccount := suite.testAccounts["local_account_1"] +	targetAccount := suite.testAccounts["admin_account"] +	isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount) +	suite.NoError(err) +	suite.True(isFollowing)  } -func (suite *RelationshipTestSuite) TestIsFollowing() { -	suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestIsFollowingNo() { +	requestingAccount := suite.testAccounts["admin_account"] +	targetAccount := suite.testAccounts["local_account_2"] +	isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount) +	suite.NoError(err) +	suite.False(isFollowing)  }  func (suite *RelationshipTestSuite) TestIsMutualFollowing() { -	suite.Suite.T().Skip("TODO: implement") +	requestingAccount := suite.testAccounts["local_account_1"] +	targetAccount := suite.testAccounts["admin_account"] +	isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount) +	suite.NoError(err) +	suite.True(isMutualFollowing) +} + +func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() { +	requestingAccount := suite.testAccounts["local_account_1"] +	targetAccount := suite.testAccounts["local_account_2"] +	isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount) +	suite.NoError(err) +	suite.True(isMutualFollowing)  } -func (suite *RelationshipTestSuite) AcceptFollowRequest() { -	for _, account := range suite.testAccounts { -		_, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID") -		if err != nil && !errors.Is(err, db.ErrNoEntries) { -			suite.Suite.Fail("error accepting follow request: %v", err) -		} +func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() { +	ctx := context.Background() +	account := suite.testAccounts["admin_account"] +	targetAccount := suite.testAccounts["local_account_2"] + +	followRequest := >smodel.FollowRequest{ +		ID:              "01GEF753FWHCHRDWR0QEHBXM8W", +		URI:             "http://localhost:8080/weeeeeeeeeeeeeeeee", +		AccountID:       account.ID, +		TargetAccountID: targetAccount.ID,  	} + +	if err := suite.db.Put(ctx, followRequest); err != nil { +		suite.FailNow(err.Error()) +	} + +	follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) +	suite.NoError(err) +	suite.NotNil(follow) +	suite.Equal(followRequest.URI, follow.URI)  } -func (suite *RelationshipTestSuite) GetAccountFollowRequests() { -	suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestAcceptFollowRequestNotExisting() { +	ctx := context.Background() +	account := suite.testAccounts["admin_account"] +	targetAccount := suite.testAccounts["local_account_2"] + +	follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) +	suite.ErrorIs(err, db.ErrNoEntries) +	suite.Nil(follow)  } -func (suite *RelationshipTestSuite) GetAccountFollows() { -	suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestAcceptFollowRequestFollowAlreadyExists() { +	ctx := context.Background() +	account := suite.testAccounts["local_account_1"] +	targetAccount := suite.testAccounts["admin_account"] + +	// follow already exists in the db from local_account_1 -> admin_account +	existingFollow := >smodel.Follow{} +	if err := suite.db.GetByID(ctx, suite.testFollows["local_account_1_admin_account"].ID, existingFollow); err != nil { +		suite.FailNow(err.Error()) +	} + +	followRequest := >smodel.FollowRequest{ +		ID:              "01GEF753FWHCHRDWR0QEHBXM8W", +		URI:             "http://localhost:8080/weeeeeeeeeeeeeeeee", +		AccountID:       account.ID, +		TargetAccountID: targetAccount.ID, +	} + +	if err := suite.db.Put(ctx, followRequest); err != nil { +		suite.FailNow(err.Error()) +	} + +	follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) +	suite.NoError(err) +	suite.NotNil(follow) + +	// uri should be equal to value of new/overlapping follow request +	suite.NotEqual(followRequest.URI, existingFollow.URI) +	suite.Equal(followRequest.URI, follow.URI)  } -func (suite *RelationshipTestSuite) CountAccountFollows() { -	suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() { +	ctx := context.Background() +	account := suite.testAccounts["admin_account"] +	targetAccount := suite.testAccounts["local_account_2"] + +	followRequest := >smodel.FollowRequest{ +		ID:              "01GEF753FWHCHRDWR0QEHBXM8W", +		URI:             "http://localhost:8080/weeeeeeeeeeeeeeeee", +		AccountID:       account.ID, +		TargetAccountID: targetAccount.ID, +	} + +	if err := suite.db.Put(ctx, followRequest); err != nil { +		suite.FailNow(err.Error()) +	} + +	rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) +	suite.NoError(err) +	suite.NotNil(rejectedFollowRequest)  } -func (suite *RelationshipTestSuite) GetAccountFollowedBy() { -	// TODO: more comprehensive tests here +func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() { +	ctx := context.Background() +	account := suite.testAccounts["admin_account"] +	targetAccount := suite.testAccounts["local_account_2"] -	for _, account := range suite.testAccounts { -		var err error +	rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) +	suite.ErrorIs(err, db.ErrNoEntries) +	suite.Nil(rejectedFollowRequest) +} -		_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) -		if err != nil { -			suite.Suite.Fail("error checking accounts followed by: %v", err) -		} +func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() { +	ctx := context.Background() +	account := suite.testAccounts["admin_account"] +	targetAccount := suite.testAccounts["local_account_2"] -		_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) -		if err != nil { -			suite.Suite.Fail("error checking localOnly accounts followed by: %v", err) -		} +	followRequest := >smodel.FollowRequest{ +		ID:              "01GEF753FWHCHRDWR0QEHBXM8W", +		URI:             "http://localhost:8080/weeeeeeeeeeeeeeeee", +		AccountID:       account.ID, +		TargetAccountID: targetAccount.ID,  	} + +	if err := suite.db.Put(ctx, followRequest); err != nil { +		suite.FailNow(err.Error()) +	} + +	followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID) +	suite.NoError(err) +	suite.Len(followRequests, 1)  } -func (suite *RelationshipTestSuite) CountAccountFollowedBy() { -	suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestGetAccountFollows() { +	account := suite.testAccounts["local_account_1"] +	follows, err := suite.db.GetAccountFollows(context.Background(), account.ID) +	suite.NoError(err) +	suite.Len(follows, 2) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() { +	account := suite.testAccounts["local_account_1"] +	followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, true) +	suite.NoError(err) +	suite.Equal(2, followsCount) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollows() { +	account := suite.testAccounts["local_account_1"] +	followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, false) +	suite.NoError(err) +	suite.Equal(2, followsCount) +} + +func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() { +	account := suite.testAccounts["local_account_1"] +	follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) +	suite.NoError(err) +	suite.Len(follows, 2) +} + +func (suite *RelationshipTestSuite) TestGetAccountFollowedByLocalOnly() { +	account := suite.testAccounts["local_account_1"] +	follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) +	suite.NoError(err) +	suite.Len(follows, 2) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() { +	account := suite.testAccounts["local_account_1"] +	followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, false) +	suite.NoError(err) +	suite.Equal(2, followsCount) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollowedByLocalOnly() { +	account := suite.testAccounts["local_account_1"] +	followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, true) +	suite.NoError(err) +	suite.Equal(2, followsCount)  }  func TestRelationshipTestSuite(t *testing.T) { diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go index 9138072e1..b9e70a89f 100644 --- a/internal/db/bundb/session.go +++ b/internal/db/bundb/session.go @@ -21,7 +21,6 @@ package bundb  import (  	"context"  	"crypto/rand" -	"errors"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -35,29 +34,22 @@ type sessionDB struct {  func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {  	rss := make([]*gtsmodel.RouterSession, 0, 1) -	_, err := s.conn. +	// get the first router session in the db or... +	if err := s.conn.  		NewSelect().  		Model(&rss).  		Limit(1). -		Order("id DESC"). -		Exec(ctx) -	if err != nil { +		Order("router_session.id DESC"). +		Scan(ctx); err != nil {  		return nil, s.conn.ProcessError(err)  	} +	// ... create a new one  	if len(rss) == 0 { -		// no session created yet, so make one  		return s.createSession(ctx)  	} -	if len(rss) != 1 { -		// we asked for 1 so we should get 1 -		return nil, errors.New("more than 1 router session was returned") -	} - -	// return the one session found -	rs := rss[0] -	return rs, nil +	return rss[0], nil  }  func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { @@ -71,24 +63,23 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession,  		return nil, err  	} -	rid, err := id.NewULID() +	id, err := id.NewULID()  	if err != nil {  		return nil, err  	}  	rs := >smodel.RouterSession{ -		ID:    rid, +		ID:    id,  		Auth:  auth,  		Crypt: crypt,  	} -	q := s.conn. +	if _, err := s.conn.  		NewInsert(). -		Model(rs) - -	_, err = q.Exec(ctx) -	if err != nil { +		Model(rs). +		Exec(ctx); err != nil {  		return nil, s.conn.ProcessError(err)  	} +  	return rs, nil  } diff --git a/internal/db/bundb/session_test.go b/internal/db/bundb/session_test.go index ef508bde8..1e7fde5aa 100644 --- a/internal/db/bundb/session_test.go +++ b/internal/db/bundb/session_test.go @@ -37,14 +37,13 @@ func (suite *SessionTestSuite) TestGetSession() {  	suite.NotEmpty(session.Crypt)  	suite.NotEmpty(session.ID) -	// TODO -- the same session should be returned with consecutive selects -	// right now there's an issue with bytea in bun, so uncomment this when that issue is fixed: https://github.com/uptrace/bun/issues/122 -	// session2, err := suite.db.GetSession(context.Background()) -	// suite.NoError(err) -	// suite.NotNil(session2) -	// suite.Equal(session.Auth, session2.Auth) -	// suite.Equal(session.Crypt, session2.Crypt) -	// suite.Equal(session.ID, session2.ID) +	// the same session should be returned with consecutive selects +	session2, err := suite.db.GetSession(context.Background()) +	suite.NoError(err) +	suite.NotNil(session2) +	suite.Equal(session.Auth, session2.Auth) +	suite.Equal(session.Crypt, session2.Crypt) +	suite.Equal(session.ID, session2.ID)  }  func TestSessionTestSuite(t *testing.T) { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 2d920ee3f..bc72c2849 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -72,7 +72,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat  			return s.cache.GetByID(id)  		},  		func(status *gtsmodel.Status) error { -			return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx) +			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx)  		},  	)  } @@ -84,7 +84,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St  			return s.cache.GetByURI(uri)  		},  		func(status *gtsmodel.Status) error { -			return s.newStatusQ(status).Where("status.uri = ?", uri).Scan(ctx) +			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx)  		},  	)  } @@ -96,7 +96,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St  			return s.cache.GetByURL(url)  		},  		func(status *gtsmodel.Status) error { -			return s.newStatusQ(status).Where("status.url = ?", url).Scan(ctx) +			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx)  		},  	)  } @@ -109,8 +109,7 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta  		status = >smodel.Status{}  		// Not cached! Perform database query -		err := dbQuery(status) -		if err != nil { +		if err := dbQuery(status); err != nil {  			return nil, s.conn.ProcessError(err)  		} @@ -138,24 +137,34 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta  }  func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { -	return s.conn.RunInTx(ctx, func(tx bun.Tx) error { +	err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {  		// create links between this status and any emojis it uses  		for _, i := range status.EmojiIDs { -			if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ -				StatusID: status.ID, -				EmojiID:  i, -			}).Exec(ctx); err != nil { -				return err +			if _, err := tx. +				NewInsert(). +				Model(>smodel.StatusToEmoji{ +					StatusID: status.ID, +					EmojiID:  i, +				}).Exec(ctx); err != nil { +				err = s.conn.errProc(err) +				if !errors.Is(err, db.ErrAlreadyExists) { +					return err +				}  			}  		}  		// create links between this status and any tags it uses  		for _, i := range status.TagIDs { -			if _, err := tx.NewInsert().Model(>smodel.StatusToTag{ -				StatusID: status.ID, -				TagID:    i, -			}).Exec(ctx); err != nil { -				return err +			if _, err := tx. +				NewInsert(). +				Model(>smodel.StatusToTag{ +					StatusID: status.ID, +					TagID:    i, +				}).Exec(ctx); err != nil { +				err = s.conn.errProc(err) +				if !errors.Is(err, db.ErrAlreadyExists) { +					return err +				}  			}  		} @@ -163,27 +172,46 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er  		for _, a := range status.Attachments {  			a.StatusID = status.ID  			a.UpdatedAt = time.Now() -			if _, err := tx.NewUpdate().Model(a). -				Where("id = ?", a.ID). +			if _, err := tx. +				NewUpdate(). +				Model(a). +				Where("? = ?", bun.Ident("media_attachment.id"), a.ID).  				Exec(ctx); err != nil { -				return err +				err = s.conn.errProc(err) +				if !errors.Is(err, db.ErrAlreadyExists) { +					return err +				}  			}  		}  		// Finally, insert the status -		_, err := tx.NewInsert().Model(status).Exec(ctx) -		return err +		if _, err := tx. +			NewInsert(). +			Model(status). +			Exec(ctx); err != nil { +			return err +		} + +		return nil  	}) +	if err != nil { +		return s.conn.ProcessError(err) +	} + +	s.cache.Put(status) +	return nil  }  func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) {  	err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {  		// create links between this status and any emojis it uses  		for _, i := range status.EmojiIDs { -			if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ -				StatusID: status.ID, -				EmojiID:  i, -			}).Exec(ctx); err != nil { +			if _, err := tx. +				NewInsert(). +				Model(>smodel.StatusToEmoji{ +					StatusID: status.ID, +					EmojiID:  i, +				}).Exec(ctx); err != nil {  				err = s.conn.errProc(err)  				if !errors.Is(err, db.ErrAlreadyExists) {  					return err @@ -193,10 +221,12 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*  		// create links between this status and any tags it uses  		for _, i := range status.TagIDs { -			if _, err := tx.NewInsert().Model(>smodel.StatusToTag{ -				StatusID: status.ID, -				TagID:    i, -			}).Exec(ctx); err != nil { +			if _, err := tx. +				NewInsert(). +				Model(>smodel.StatusToTag{ +					StatusID: status.ID, +					TagID:    i, +				}).Exec(ctx); err != nil {  				err = s.conn.errProc(err)  				if !errors.Is(err, db.ErrAlreadyExists) {  					return err @@ -208,23 +238,32 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*  		for _, a := range status.Attachments {  			a.StatusID = status.ID  			a.UpdatedAt = time.Now() -			if _, err := tx.NewUpdate().Model(a). -				Where("id = ?", a.ID). +			if _, err := tx. +				NewUpdate(). +				Model(a). +				Where("? = ?", bun.Ident("media_attachment.id"), a.ID).  				Exec(ctx); err != nil {  				return err  			}  		}  		// Finally, update the status itself -		if _, err := tx.NewUpdate().Model(status).WherePK().Exec(ctx); err != nil { +		if _, err := tx. +			NewUpdate(). +			Model(status). +			Where("? = ?", bun.Ident("status.id"), status.ID). +			Exec(ctx); err != nil {  			return err  		} -		s.cache.Put(status)  		return nil  	}) +	if err != nil { +		return nil, s.conn.ProcessError(err) +	} -	return status, err +	s.cache.Put(status) +	return status, nil  }  func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { @@ -232,8 +271,8 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {  		// delete links between this status and any emojis it uses  		if _, err := tx.  			NewDelete(). -			Model(>smodel.StatusToEmoji{}). -			Where("status_id = ?", id). +			TableExpr("? AS ?", bun.Ident("status_to_emojis"), bun.Ident("status_to_emoji")). +			Where("? = ?", bun.Ident("status_to_emoji.status_id"), id).  			Exec(ctx); err != nil {  			return err  		} @@ -241,8 +280,8 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {  		// delete links between this status and any tags it uses  		if _, err := tx.  			NewDelete(). -			Model(>smodel.StatusToTag{}). -			Where("status_id = ?", id). +			TableExpr("? AS ?", bun.Ident("status_to_tags"), bun.Ident("status_to_tag")). +			Where("? = ?", bun.Ident("status_to_tag.status_id"), id).  			Exec(ctx); err != nil {  			return err  		} @@ -250,17 +289,20 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {  		// delete the status itself  		if _, err := tx.  			NewDelete(). -			Model(>smodel.Status{ID: id}). -			WherePK(). +			TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +			Where("? = ?", bun.Ident("status.id"), id).  			Exec(ctx); err != nil {  			return err  		} -		s.cache.Invalidate(id)  		return nil  	}) +	if err != nil { +		return s.conn.ProcessError(err) +	} -	return s.conn.ProcessError(err) +	s.cache.Invalidate(id) +	return nil  }  func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { @@ -312,11 +354,11 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,  	q := s.conn.  		NewSelect(). -		Table("statuses"). -		Column("id"). -		Where("in_reply_to_id = ?", status.ID) +		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +		Column("status.id"). +		Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID)  	if minID != "" { -		q = q.Where("id > ?", minID) +		q = q.Where("? > ?", bun.Ident("status.id"), minID)  	}  	if err := q.Scan(ctx, &childIDs); err != nil { @@ -356,23 +398,35 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,  }  func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { -	return s.conn.NewSelect().Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx) +	return s.conn. +		NewSelect(). +		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +		Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID). +		Count(ctx)  }  func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { -	return s.conn.NewSelect().Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx) +	return s.conn. +		NewSelect(). +		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +		Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). +		Count(ctx)  }  func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { -	return s.conn.NewSelect().Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx) +	return s.conn. +		NewSelect(). +		TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). +		Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). +		Count(ctx)  }  func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {  	q := s.conn.  		NewSelect(). -		Model(>smodel.StatusFave{}). -		Where("status_id = ?", status.ID). -		Where("account_id = ?", accountID) +		TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). +		Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). +		Where("? = ?", bun.Ident("status_fave.account_id"), accountID)  	return s.conn.Exists(ctx, q)  } @@ -380,9 +434,9 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status,  func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {  	q := s.conn.  		NewSelect(). -		Model(>smodel.Status{}). -		Where("boost_of_id = ?", status.ID). -		Where("account_id = ?", accountID) +		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +		Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). +		Where("? = ?", bun.Ident("status.account_id"), accountID)  	return s.conn.Exists(ctx, q)  } @@ -390,9 +444,9 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta  func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {  	q := s.conn.  		NewSelect(). -		Model(>smodel.StatusMute{}). -		Where("status_id = ?", status.ID). -		Where("account_id = ?", accountID) +		TableExpr("? AS ?", bun.Ident("status_mutes"), bun.Ident("status_mute")). +		Where("? = ?", bun.Ident("status_mute.status_id"), status.ID). +		Where("? = ?", bun.Ident("status_mute.account_id"), accountID)  	return s.conn.Exists(ctx, q)  } @@ -400,9 +454,9 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status,  func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {  	q := s.conn.  		NewSelect(). -		Model(>smodel.StatusBookmark{}). -		Where("status_id = ?", status.ID). -		Where("account_id = ?", accountID) +		TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). +		Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID). +		Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID)  	return s.conn.Exists(ctx, q)  } @@ -410,8 +464,9 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St  func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {  	faves := []*gtsmodel.StatusFave{} -	q := s.newFaveQ(&faves). -		Where("status_id = ?", status.ID) +	q := s. +		newFaveQ(&faves). +		Where("? = ?", bun.Ident("status_fave.status_id"), status.ID)  	if err := q.Scan(ctx); err != nil {  		return nil, s.conn.ProcessError(err) @@ -422,8 +477,9 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)  func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {  	reblogs := []*gtsmodel.Status{} -	q := s.newStatusQ(&reblogs). -		Where("boost_of_id = ?", status.ID) +	q := s. +		newStatusQ(&reblogs). +		Where("? = ?", bun.Ident("status.boost_of_id"), status.ID)  	if err := q.Scan(ctx); err != nil {  		return nil, s.conn.ProcessError(err) diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index a796ebdad..70bc7b845 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -108,14 +108,14 @@ func (suite *StatusTestSuite) TestGetStatusTwice() {  	suite.NoError(err)  	after1 := time.Now()  	duration1 := after1.Sub(before1) -	fmt.Println(duration1.Milliseconds()) +	fmt.Println(duration1.Microseconds())  	before2 := time.Now()  	_, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)  	suite.NoError(err)  	after2 := time.Now()  	duration2 := after2.Sub(before2) -	fmt.Println(duration2.Milliseconds()) +	fmt.Println(duration2.Microseconds())  	// second retrieval should be several orders faster since it will be cached now  	suite.Less(duration2, duration1) diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index d2b3cf07e..d4740dd96 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -34,38 +34,48 @@ type timelineDB struct {  }  func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { +	// Ensure reasonable +	if limit < 0 { +		limit = 0 +	} +  	// Make educated guess for slice size  	statusIDs := make([]string, 0, limit)  	q := t.conn.  		NewSelect(). -		Table("statuses"). - +		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).  		// Select only IDs from table -		Column("statuses.id"). +		Column("status.id").  		// Find out who accountID follows. -		Join("LEFT JOIN follows ON follows.target_account_id = statuses.account_id AND follows.account_id = ?", accountID). +		Join("LEFT JOIN ? AS ? ON ? = ? AND ? = ?", +			bun.Ident("follows"), +			bun.Ident("follow"), +			bun.Ident("follow.target_account_id"), +			bun.Ident("status.account_id"), +			bun.Ident("follow.account_id"), +			accountID).  		// Sort by highest ID (newest) to lowest ID (oldest) -		Order("statuses.id DESC") +		Order("status.id DESC")  	if maxID != "" {  		// return only statuses LOWER (ie., older) than maxID -		q = q.Where("statuses.id < ?", maxID) +		q = q.Where("? < ?", bun.Ident("status.id"), maxID)  	}  	if sinceID != "" {  		// return only statuses HIGHER (ie., newer) than sinceID -		q = q.Where("statuses.id > ?", sinceID) +		q = q.Where("? > ?", bun.Ident("status.id"), sinceID)  	}  	if minID != "" {  		// return only statuses HIGHER (ie., newer) than minID -		q = q.Where("statuses.id > ?", minID) +		q = q.Where("? > ?", bun.Ident("status.id"), minID)  	}  	if local {  		// return only statuses posted by local account havers -		q = q.Where("statuses.local = ?", local) +		q = q.Where("? = ?", bun.Ident("status.local"), local)  	}  	if limit > 0 { @@ -78,13 +88,11 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI  	//  	// This is equivalent to something like WHERE ... AND (... OR ...)  	// See: https://bun.uptrace.dev/guide/queries.html#select -	whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { +	q = q.WhereGroup(" AND ", func(*bun.SelectQuery) *bun.SelectQuery {  		return q. -			WhereOr("follows.account_id = ?", accountID). -			WhereOr("statuses.account_id = ?", accountID) -	} - -	q = q.WhereGroup(" AND ", whereGroup) +			WhereOr("? = ?", bun.Ident("follow.account_id"), accountID). +			WhereOr("? = ?", bun.Ident("status.account_id"), accountID) +	})  	if err := q.Scan(ctx, &statusIDs); err != nil {  		return nil, t.conn.ProcessError(err) @@ -118,28 +126,28 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma  	q := t.conn.  		NewSelect(). -		Table("statuses"). -		Column("statuses.id"). -		Where("statuses.visibility = ?", gtsmodel.VisibilityPublic). -		WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_id")). -		WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_uri")). -		WhereGroup(" AND ", whereEmptyOrNull("statuses.boost_of_id")). -		Order("statuses.id DESC") +		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). +		Column("status.id"). +		Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). +		WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")). +		WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). +		WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). +		Order("status.id DESC")  	if maxID != "" { -		q = q.Where("statuses.id < ?", maxID) +		q = q.Where("? < ?", bun.Ident("status.id"), maxID)  	}  	if sinceID != "" { -		q = q.Where("statuses.id > ?", sinceID) +		q = q.Where("? > ?", bun.Ident("status.id"), sinceID)  	}  	if minID != "" { -		q = q.Where("statuses.id > ?", minID) +		q = q.Where("? > ?", bun.Ident("status.id"), minID)  	}  	if local { -		q = q.Where("statuses.local = ?", local) +		q = q.Where("? = ?", bun.Ident("status.local"), local)  	}  	if limit > 0 { @@ -181,15 +189,15 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max  	fq := t.conn.  		NewSelect().  		Model(&faves). -		Where("account_id = ?", accountID). -		Order("id DESC") +		Where("? = ?", bun.Ident("status_fave.account_id"), accountID). +		Order("status_fave.id DESC")  	if maxID != "" { -		fq = fq.Where("id < ?", maxID) +		fq = fq.Where("? < ?", bun.Ident("status_fave.id"), maxID)  	}  	if minID != "" { -		fq = fq.Where("id > ?", minID) +		fq = fq.Where("? > ?", bun.Ident("status_fave.id"), minID)  	}  	if limit > 0 { diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index 2e991ac93..c14d72056 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -38,6 +38,15 @@ func (suite *TimelineTestSuite) TestGetPublicTimeline() {  	suite.Len(s, 6)  } +func (suite *TimelineTestSuite) TestGetHomeTimeline() { +	viewingAccount := suite.testAccounts["local_account_1"] + +	s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) +	suite.NoError(err) + +	suite.Len(s, 16) +} +  func TestTimelineTestSuite(t *testing.T) {  	suite.Run(t, new(TimelineTestSuite))  } diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index 46f24c4b2..aa2f4c2c8 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -67,7 +67,7 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db  			return u.cache.GetByID(id)  		},  		func(user *gtsmodel.User) error { -			return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx) +			return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx)  		},  	)  } @@ -79,7 +79,7 @@ func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gts  			return u.cache.GetByAccountID(accountID)  		},  		func(user *gtsmodel.User) error { -			return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx) +			return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx)  		},  	)  } @@ -91,7 +91,7 @@ func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string)  			return u.cache.GetByEmail(emailAddress)  		},  		func(user *gtsmodel.User) error { -			return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx) +			return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx)  		},  	)  } @@ -103,7 +103,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationTok  			return u.cache.GetByConfirmationToken(confirmationToken)  		},  		func(user *gtsmodel.User) error { -			return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx) +			return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx)  		},  	)  } @@ -127,7 +127,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..  	if _, err := u.conn.  		NewUpdate().  		Model(user). -		WherePK(). +		Where("? = ?", bun.Ident("user.id"), user.ID).  		Column(columns...).  		Exec(ctx); err != nil {  		return nil, u.conn.ProcessError(err) @@ -140,8 +140,8 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..  func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {  	if _, err := u.conn.  		NewDelete(). -		Model(>smodel.User{ID: userID}). -		WherePK(). +		TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). +		Where("? = ?", bun.Ident("user.id"), userID).  		Exec(ctx); err != nil {  		return u.conn.ProcessError(err)  	} diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index 434d12f32..34f7eb76f 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -85,14 +85,8 @@ func parseWhere(w db.Where) (query string, args []interface{}) {  			return  		} -		if w.CaseInsensitive { -			query = "LOWER(?) != LOWER(?)" -			args = []interface{}{bun.Safe(w.Key), w.Value} -			return -		} -  		query = "? != ?" -		args = []interface{}{bun.Safe(w.Key), w.Value} +		args = []interface{}{bun.Ident(w.Key), w.Value}  		return  	} @@ -102,13 +96,7 @@ func parseWhere(w db.Where) (query string, args []interface{}) {  		return  	} -	if w.CaseInsensitive { -		query = "LOWER(?) = LOWER(?)" -		args = []interface{}{bun.Safe(w.Key), w.Value} -		return -	} -  	query = "? = ?" -	args = []interface{}{bun.Safe(w.Key), w.Value} +	args = []interface{}{bun.Ident(w.Key), w.Value}  	return  } diff --git a/internal/db/params.go b/internal/db/params.go index d1809f1c4..84694d6d3 100644 --- a/internal/db/params.go +++ b/internal/db/params.go @@ -24,9 +24,6 @@ type Where struct {  	Key string  	// The value to match.  	Value interface{} -	// Whether the value (if a string) should be case sensitive or not. -	// Defaults to false. -	CaseInsensitive bool  	// If set, reverse the where.  	// `WHERE k = v` becomes `WHERE k != v`.  	// `WHERE k IS NULL` becomes `WHERE k IS NOT NULL` | 
