diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/account.go | 3 | ||||
-rw-r--r-- | internal/db/bundb/account.go | 60 | ||||
-rw-r--r-- | internal/db/bundb/account_test.go | 59 | ||||
-rw-r--r-- | internal/db/bundb/bundb.go | 17 | ||||
-rw-r--r-- | internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go | 69 |
5 files changed, 190 insertions, 18 deletions
diff --git a/internal/db/account.go b/internal/db/account.go index 5f1336872..351d6d01c 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -42,6 +42,9 @@ type Account interface { // GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong. GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error) + // PutAccount puts one account in the database. + PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) + // UpdateAccount updates one account by ID. UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 2105368d3..074804690 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -45,7 +45,8 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { NewSelect(). Model(account). Relation("AvatarMediaAttachment"). - Relation("HeaderMediaAttachment") + Relation("HeaderMediaAttachment"). + Relation("Emojis") } func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { @@ -138,24 +139,61 @@ func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.A return account, nil } +func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { + if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this account and any emojis it uses + for _, i := range account.EmojiIDs { + if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ + AccountID: account.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } + } + + // insert the account + _, err := tx.NewInsert().Model(account).Exec(ctx) + return err + }); err != nil { + return nil, a.conn.ProcessError(err) + } + + a.cache.Put(account) + return account, nil +} + func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { // Update the account's last-updated account.UpdatedAt = time.Now() - // Update the account model in the DB - _, err := a.conn. - NewUpdate(). - Model(account). - WherePK(). - Exec(ctx) - if err != nil { + if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this account and any emojis it uses + // first clear out any old emoji links + if _, err := tx.NewDelete(). + Model(&[]*gtsmodel.AccountToEmoji{}). + Where("account_id = ?", account.ID). + Exec(ctx); err != nil { + return err + } + + // now populate new emoji links + for _, i := range account.EmojiIDs { + if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ + AccountID: account.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } + } + + // update the account + _, err := tx.NewUpdate().Model(account).WherePK().Exec(ctx) + return err + }); err != nil { return nil, a.conn.ProcessError(err) } - // Place updated account in cache - // (this will replace existing, i.e. invalidating) a.cache.Put(account) - return account, nil } diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 3c19e84d9..1e6dc4436 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -27,7 +27,9 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/ap" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type AccountTestSuite struct { @@ -71,17 +73,70 @@ func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() { } func (suite *AccountTestSuite) TestUpdateAccount() { + ctx := context.Background() + testAccount := suite.testAccounts["local_account_1"] testAccount.DisplayName = "new display name!" + testAccount.EmojiIDs = []string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"} + + _, err := suite.db.UpdateAccount(ctx, testAccount) + suite.NoError(err) + + updated, err := suite.db.GetAccountByID(ctx, testAccount.ID) + suite.NoError(err) + suite.Equal("new display name!", updated.DisplayName) + suite.Equal([]string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"}, updated.EmojiIDs) + suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second) + + // get account without cache + make sure it's really in the db as desired + dbService, ok := suite.db.(*bundb.DBService) + if !ok { + panic("db was not *bundb.DBService") + } + + noCache := >smodel.Account{} + err = dbService.GetConn(). + NewSelect(). + Model(noCache). + Where("account.id = ?", bun.Ident(testAccount.ID)). + Relation("AvatarMediaAttachment"). + Relation("HeaderMediaAttachment"). + Relation("Emojis"). + Scan(ctx) + + suite.NoError(err) + suite.Equal("new display name!", noCache.DisplayName) + suite.Equal([]string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"}, noCache.EmojiIDs) + suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second) + suite.NotNil(noCache.AvatarMediaAttachment) + suite.NotNil(noCache.HeaderMediaAttachment) - _, err := suite.db.UpdateAccount(context.Background(), testAccount) + // update again to remove emoji associations + testAccount.EmojiIDs = []string{} + + _, err = suite.db.UpdateAccount(ctx, testAccount) suite.NoError(err) - updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID) + updated, err = suite.db.GetAccountByID(ctx, testAccount.ID) suite.NoError(err) suite.Equal("new display name!", updated.DisplayName) + suite.Empty(updated.EmojiIDs) suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second) + + err = dbService.GetConn(). + NewSelect(). + Model(noCache). + Where("account.id = ?", bun.Ident(testAccount.ID)). + Relation("AvatarMediaAttachment"). + Relation("HeaderMediaAttachment"). + Relation("Emojis"). + Scan(ctx) + + suite.NoError(err) + suite.Equal("new display name!", noCache.DisplayName) + suite.Empty(noCache.EmojiIDs) + suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second) } func (suite *AccountTestSuite) TestInsertAccountWithDefaults() { diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index b944ae3ea..2fc65364f 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -67,12 +67,13 @@ const ( ) var registerTables = []interface{}{ + >smodel.AccountToEmoji{}, >smodel.StatusToEmoji{}, >smodel.StatusToTag{}, } -// bunDBService satisfies the DB interface -type bunDBService struct { +// DBService satisfies the DB interface +type DBService struct { db.Account db.Admin db.Basic @@ -89,6 +90,12 @@ type bunDBService struct { conn *DBConn } +// GetConn returns the underlying bun connection. +// Should only be used in testing + exceptional circumstance. +func (dbService *DBService) GetConn() *DBConn { + return dbService.conn +} + func doMigration(ctx context.Context, db *bun.DB) error { migrator := migrate.NewMigrator(db, migrations.Migrations) @@ -177,7 +184,7 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { // Prepare domain block cache blockCache := cache.NewDomainBlockCache() - ps := &bunDBService{ + ps := &DBService{ Account: accounts, Admin: &adminDB{ conn: conn, @@ -399,7 +406,7 @@ func tweakConnectionValues(sqldb *sql.DB) { CONVERSION FUNCTIONS */ -func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) { +func (dbService *DBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) { protocol := config.GetProtocol() host := config.GetHost() @@ -408,7 +415,7 @@ func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, ori tag := >smodel.Tag{} // we can use selectorinsert here to create the new tag if it doesn't exist already // inserted will be true if this is a new tag we just created - if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil { + if err := dbService.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil { if err == sql.ErrNoRows { // tag doesn't exist yet so populate it newID, err := id.NewRandomULID() diff --git a/internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go b/internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go new file mode 100644 index 000000000..91468a4c9 --- /dev/null +++ b/internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go @@ -0,0 +1,69 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package migrations + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + q := tx.NewAddColumn().Model(>smodel.Account{}) + + switch tx.Dialect().Name() { + case dialect.PG: + q = q.ColumnExpr("? VARCHAR[]", bun.Ident("emojis")) + case dialect.SQLite: + q = q.ColumnExpr("? VARCHAR", bun.Ident("emojis")) + default: + log.Panic("db dialect was neither pg nor sqlite") + } + + if _, err := q.Exec(ctx); err != nil { + return err + } + + if _, err := tx. + NewCreateTable(). + Model(>smodel.AccountToEmoji{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} |