diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/account.go | 60 | ||||
| -rw-r--r-- | internal/db/bundb/account_test.go | 59 | ||||
| -rw-r--r-- | internal/db/bundb/bundb.go | 17 | ||||
| -rw-r--r-- | internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go | 69 | 
4 files changed, 187 insertions, 18 deletions
| diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 2105368d3..074804690 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -45,7 +45,8 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {  		NewSelect().  		Model(account).  		Relation("AvatarMediaAttachment"). -		Relation("HeaderMediaAttachment") +		Relation("HeaderMediaAttachment"). +		Relation("Emojis")  }  func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { @@ -138,24 +139,61 @@ func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.A  	return account, nil  } +func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { +	if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { +		// create links between this account and any emojis it uses +		for _, i := range account.EmojiIDs { +			if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ +				AccountID: account.ID, +				EmojiID:   i, +			}).Exec(ctx); err != nil { +				return err +			} +		} + +		// insert the account +		_, err := tx.NewInsert().Model(account).Exec(ctx) +		return err +	}); err != nil { +		return nil, a.conn.ProcessError(err) +	} + +	a.cache.Put(account) +	return account, nil +} +  func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {  	// Update the account's last-updated  	account.UpdatedAt = time.Now() -	// Update the account model in the DB -	_, err := a.conn. -		NewUpdate(). -		Model(account). -		WherePK(). -		Exec(ctx) -	if err != nil { +	if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { +		// create links between this account and any emojis it uses +		// first clear out any old emoji links +		if _, err := tx.NewDelete(). +			Model(&[]*gtsmodel.AccountToEmoji{}). +			Where("account_id = ?", account.ID). +			Exec(ctx); err != nil { +			return err +		} + +		// now populate new emoji links +		for _, i := range account.EmojiIDs { +			if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ +				AccountID: account.ID, +				EmojiID:   i, +			}).Exec(ctx); err != nil { +				return err +			} +		} + +		// update the account +		_, err := tx.NewUpdate().Model(account).WherePK().Exec(ctx) +		return err +	}); err != nil {  		return nil, a.conn.ProcessError(err)  	} -	// Place updated account in cache -	// (this will replace existing, i.e. invalidating)  	a.cache.Put(account) -  	return account, nil  } diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 3c19e84d9..1e6dc4436 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -27,7 +27,9 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/ap" +	"github.com/superseriousbusiness/gotosocial/internal/db/bundb"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun"  )  type AccountTestSuite struct { @@ -71,17 +73,70 @@ func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {  }  func (suite *AccountTestSuite) TestUpdateAccount() { +	ctx := context.Background() +  	testAccount := suite.testAccounts["local_account_1"]  	testAccount.DisplayName = "new display name!" +	testAccount.EmojiIDs = []string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"} + +	_, err := suite.db.UpdateAccount(ctx, testAccount) +	suite.NoError(err) + +	updated, err := suite.db.GetAccountByID(ctx, testAccount.ID) +	suite.NoError(err) +	suite.Equal("new display name!", updated.DisplayName) +	suite.Equal([]string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"}, updated.EmojiIDs) +	suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second) + +	// get account without cache + make sure it's really in the db as desired +	dbService, ok := suite.db.(*bundb.DBService) +	if !ok { +		panic("db was not *bundb.DBService") +	} + +	noCache := >smodel.Account{} +	err = dbService.GetConn(). +		NewSelect(). +		Model(noCache). +		Where("account.id = ?", bun.Ident(testAccount.ID)). +		Relation("AvatarMediaAttachment"). +		Relation("HeaderMediaAttachment"). +		Relation("Emojis"). +		Scan(ctx) + +	suite.NoError(err) +	suite.Equal("new display name!", noCache.DisplayName) +	suite.Equal([]string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"}, noCache.EmojiIDs) +	suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second) +	suite.NotNil(noCache.AvatarMediaAttachment) +	suite.NotNil(noCache.HeaderMediaAttachment) -	_, err := suite.db.UpdateAccount(context.Background(), testAccount) +	// update again to remove emoji associations +	testAccount.EmojiIDs = []string{} + +	_, err = suite.db.UpdateAccount(ctx, testAccount)  	suite.NoError(err) -	updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID) +	updated, err = suite.db.GetAccountByID(ctx, testAccount.ID)  	suite.NoError(err)  	suite.Equal("new display name!", updated.DisplayName) +	suite.Empty(updated.EmojiIDs)  	suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second) + +	err = dbService.GetConn(). +		NewSelect(). +		Model(noCache). +		Where("account.id = ?", bun.Ident(testAccount.ID)). +		Relation("AvatarMediaAttachment"). +		Relation("HeaderMediaAttachment"). +		Relation("Emojis"). +		Scan(ctx) + +	suite.NoError(err) +	suite.Equal("new display name!", noCache.DisplayName) +	suite.Empty(noCache.EmojiIDs) +	suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second)  }  func (suite *AccountTestSuite) TestInsertAccountWithDefaults() { diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index b944ae3ea..2fc65364f 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -67,12 +67,13 @@ const (  )  var registerTables = []interface{}{ +	>smodel.AccountToEmoji{},  	>smodel.StatusToEmoji{},  	>smodel.StatusToTag{},  } -// bunDBService satisfies the DB interface -type bunDBService struct { +// DBService satisfies the DB interface +type DBService struct {  	db.Account  	db.Admin  	db.Basic @@ -89,6 +90,12 @@ type bunDBService struct {  	conn *DBConn  } +// GetConn returns the underlying bun connection. +// Should only be used in testing + exceptional circumstance. +func (dbService *DBService) GetConn() *DBConn { +	return dbService.conn +} +  func doMigration(ctx context.Context, db *bun.DB) error {  	migrator := migrate.NewMigrator(db, migrations.Migrations) @@ -177,7 +184,7 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {  	// Prepare domain block cache  	blockCache := cache.NewDomainBlockCache() -	ps := &bunDBService{ +	ps := &DBService{  		Account: accounts,  		Admin: &adminDB{  			conn: conn, @@ -399,7 +406,7 @@ func tweakConnectionValues(sqldb *sql.DB) {  	CONVERSION FUNCTIONS  */ -func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) { +func (dbService *DBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) {  	protocol := config.GetProtocol()  	host := config.GetHost() @@ -408,7 +415,7 @@ func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, ori  		tag := >smodel.Tag{}  		// we can use selectorinsert here to create the new tag if it doesn't exist already  		// inserted will be true if this is a new tag we just created -		if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil { +		if err := dbService.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {  			if err == sql.ErrNoRows {  				// tag doesn't exist yet so populate it  				newID, err := id.NewRandomULID() diff --git a/internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go b/internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go new file mode 100644 index 000000000..91468a4c9 --- /dev/null +++ b/internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go @@ -0,0 +1,69 @@ +/* +   GoToSocial +   Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package migrations + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/uptrace/bun" +	"github.com/uptrace/bun/dialect" +) + +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			q := tx.NewAddColumn().Model(>smodel.Account{}) + +			switch tx.Dialect().Name() { +			case dialect.PG: +				q = q.ColumnExpr("? VARCHAR[]", bun.Ident("emojis")) +			case dialect.SQLite: +				q = q.ColumnExpr("? VARCHAR", bun.Ident("emojis")) +			default: +				log.Panic("db dialect was neither pg nor sqlite") +			} + +			if _, err := q.Exec(ctx); err != nil { +				return err +			} + +			if _, err := tx. +				NewCreateTable(). +				Model(>smodel.AccountToEmoji{}). +				IfNotExists(). +				Exec(ctx); err != nil { +				return err +			} + +			return nil +		}) +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} | 
