summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/account.go3
-rw-r--r--internal/db/bundb/account.go60
-rw-r--r--internal/db/bundb/account_test.go59
-rw-r--r--internal/db/bundb/bundb.go17
-rw-r--r--internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go69
5 files changed, 190 insertions, 18 deletions
diff --git a/internal/db/account.go b/internal/db/account.go
index 5f1336872..351d6d01c 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -42,6 +42,9 @@ type Account interface {
// GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong.
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error)
+ // PutAccount puts one account in the database.
+ PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
+
// UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 2105368d3..074804690 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -45,7 +45,8 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
NewSelect().
Model(account).
Relation("AvatarMediaAttachment").
- Relation("HeaderMediaAttachment")
+ Relation("HeaderMediaAttachment").
+ Relation("Emojis")
}
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
@@ -138,24 +139,61 @@ func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.A
return account, nil
}
+func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
+ if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // create links between this account and any emojis it uses
+ for _, i := range account.EmojiIDs {
+ if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
+ AccountID: account.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // insert the account
+ _, err := tx.NewInsert().Model(account).Exec(ctx)
+ return err
+ }); err != nil {
+ return nil, a.conn.ProcessError(err)
+ }
+
+ a.cache.Put(account)
+ return account, nil
+}
+
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
// Update the account's last-updated
account.UpdatedAt = time.Now()
- // Update the account model in the DB
- _, err := a.conn.
- NewUpdate().
- Model(account).
- WherePK().
- Exec(ctx)
- if err != nil {
+ if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // create links between this account and any emojis it uses
+ // first clear out any old emoji links
+ if _, err := tx.NewDelete().
+ Model(&[]*gtsmodel.AccountToEmoji{}).
+ Where("account_id = ?", account.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // now populate new emoji links
+ for _, i := range account.EmojiIDs {
+ if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
+ AccountID: account.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // update the account
+ _, err := tx.NewUpdate().Model(account).WherePK().Exec(ctx)
+ return err
+ }); err != nil {
return nil, a.conn.ProcessError(err)
}
- // Place updated account in cache
- // (this will replace existing, i.e. invalidating)
a.cache.Put(account)
-
return account, nil
}
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
index 3c19e84d9..1e6dc4436 100644
--- a/internal/db/bundb/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -27,7 +27,9 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
+ "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type AccountTestSuite struct {
@@ -71,17 +73,70 @@ func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
}
func (suite *AccountTestSuite) TestUpdateAccount() {
+ ctx := context.Background()
+
testAccount := suite.testAccounts["local_account_1"]
testAccount.DisplayName = "new display name!"
+ testAccount.EmojiIDs = []string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"}
+
+ _, err := suite.db.UpdateAccount(ctx, testAccount)
+ suite.NoError(err)
+
+ updated, err := suite.db.GetAccountByID(ctx, testAccount.ID)
+ suite.NoError(err)
+ suite.Equal("new display name!", updated.DisplayName)
+ suite.Equal([]string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"}, updated.EmojiIDs)
+ suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
+
+ // get account without cache + make sure it's really in the db as desired
+ dbService, ok := suite.db.(*bundb.DBService)
+ if !ok {
+ panic("db was not *bundb.DBService")
+ }
+
+ noCache := &gtsmodel.Account{}
+ err = dbService.GetConn().
+ NewSelect().
+ Model(noCache).
+ Where("account.id = ?", bun.Ident(testAccount.ID)).
+ Relation("AvatarMediaAttachment").
+ Relation("HeaderMediaAttachment").
+ Relation("Emojis").
+ Scan(ctx)
+
+ suite.NoError(err)
+ suite.Equal("new display name!", noCache.DisplayName)
+ suite.Equal([]string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"}, noCache.EmojiIDs)
+ suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second)
+ suite.NotNil(noCache.AvatarMediaAttachment)
+ suite.NotNil(noCache.HeaderMediaAttachment)
- _, err := suite.db.UpdateAccount(context.Background(), testAccount)
+ // update again to remove emoji associations
+ testAccount.EmojiIDs = []string{}
+
+ _, err = suite.db.UpdateAccount(ctx, testAccount)
suite.NoError(err)
- updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID)
+ updated, err = suite.db.GetAccountByID(ctx, testAccount.ID)
suite.NoError(err)
suite.Equal("new display name!", updated.DisplayName)
+ suite.Empty(updated.EmojiIDs)
suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
+
+ err = dbService.GetConn().
+ NewSelect().
+ Model(noCache).
+ Where("account.id = ?", bun.Ident(testAccount.ID)).
+ Relation("AvatarMediaAttachment").
+ Relation("HeaderMediaAttachment").
+ Relation("Emojis").
+ Scan(ctx)
+
+ suite.NoError(err)
+ suite.Equal("new display name!", noCache.DisplayName)
+ suite.Empty(noCache.EmojiIDs)
+ suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second)
}
func (suite *AccountTestSuite) TestInsertAccountWithDefaults() {
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index b944ae3ea..2fc65364f 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -67,12 +67,13 @@ const (
)
var registerTables = []interface{}{
+ &gtsmodel.AccountToEmoji{},
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
}
-// bunDBService satisfies the DB interface
-type bunDBService struct {
+// DBService satisfies the DB interface
+type DBService struct {
db.Account
db.Admin
db.Basic
@@ -89,6 +90,12 @@ type bunDBService struct {
conn *DBConn
}
+// GetConn returns the underlying bun connection.
+// Should only be used in testing + exceptional circumstance.
+func (dbService *DBService) GetConn() *DBConn {
+ return dbService.conn
+}
+
func doMigration(ctx context.Context, db *bun.DB) error {
migrator := migrate.NewMigrator(db, migrations.Migrations)
@@ -177,7 +184,7 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
// Prepare domain block cache
blockCache := cache.NewDomainBlockCache()
- ps := &bunDBService{
+ ps := &DBService{
Account: accounts,
Admin: &adminDB{
conn: conn,
@@ -399,7 +406,7 @@ func tweakConnectionValues(sqldb *sql.DB) {
CONVERSION FUNCTIONS
*/
-func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) {
+func (dbService *DBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) {
protocol := config.GetProtocol()
host := config.GetHost()
@@ -408,7 +415,7 @@ func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, ori
tag := &gtsmodel.Tag{}
// we can use selectorinsert here to create the new tag if it doesn't exist already
// inserted will be true if this is a new tag we just created
- if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
+ if err := dbService.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
if err == sql.ErrNoRows {
// tag doesn't exist yet so populate it
newID, err := id.NewRandomULID()
diff --git a/internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go b/internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go
new file mode 100644
index 000000000..91468a4c9
--- /dev/null
+++ b/internal/db/bundb/migrations/20220916122701_emojis_in_accounts.go
@@ -0,0 +1,69 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package migrations
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ q := tx.NewAddColumn().Model(&gtsmodel.Account{})
+
+ switch tx.Dialect().Name() {
+ case dialect.PG:
+ q = q.ColumnExpr("? VARCHAR[]", bun.Ident("emojis"))
+ case dialect.SQLite:
+ q = q.ColumnExpr("? VARCHAR", bun.Ident("emojis"))
+ default:
+ log.Panic("db dialect was neither pg nor sqlite")
+ }
+
+ if _, err := q.Exec(ctx); err != nil {
+ return err
+ }
+
+ if _, err := tx.
+ NewCreateTable().
+ Model(&gtsmodel.AccountToEmoji{}).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}