summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2022-10-08 13:50:48 +0200
committerLibravatar GitHub <noreply@github.com>2022-10-08 13:50:48 +0200
commitaa07750bdb4dacdb1be39d765114915bba3fc29f (patch)
tree30e9e5052f607f8c8e4f7d518559df8706275e0f
parent[performance] cache domains after max retries in transport (#884) (diff)
downloadgotosocial-aa07750bdb4dacdb1be39d765114915bba3fc29f.tar.xz
[chore] Standardize database queries, use `bun.Ident()` properly (#886)
* use bun.Ident for user queries * use bun.Ident for account queries * use bun.Ident for media queries * add DeleteAccount func * remove CaseInsensitive in Where+use Ident ipv Safe * update admin db * update domain, use ident * update emoji, use ident * update instance queries, use bun.Ident * fix media * update mentions, use bun ident * update relationship + tests * use tableexpr * add test follows to bun db test suite * update notifications * updatebyprimarykey => updatebyid * fix session * prefer explicit ID to pk * fix little fucky wucky * remove workaround * use proper db func for attachment selection * update status db * add m2m entries in test rig * fix up timeline * go fmt * fix status put issue * update GetAccountStatuses
-rw-r--r--cmd/gotosocial/action/admin/account/account.go2
-rw-r--r--internal/cache/account.go5
-rw-r--r--internal/db/account.go5
-rw-r--r--internal/db/basic.go4
-rw-r--r--internal/db/bundb/account.go164
-rw-r--r--internal/db/bundb/account_test.go16
-rw-r--r--internal/db/bundb/admin.go82
-rw-r--r--internal/db/bundb/admin_test.go39
-rw-r--r--internal/db/bundb/basic.go6
-rw-r--r--internal/db/bundb/bundb.go34
-rw-r--r--internal/db/bundb/bundb_test.go2
-rw-r--r--internal/db/bundb/domain.go5
-rw-r--r--internal/db/bundb/emoji.go24
-rw-r--r--internal/db/bundb/instance.go43
-rw-r--r--internal/db/bundb/instance_test.go83
-rw-r--r--internal/db/bundb/media.go30
-rw-r--r--internal/db/bundb/mention.go2
-rw-r--r--internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go4
-rw-r--r--internal/db/bundb/notification.go22
-rw-r--r--internal/db/bundb/relationship.go280
-rw-r--r--internal/db/bundb/relationship_test.go249
-rw-r--r--internal/db/bundb/session.go33
-rw-r--r--internal/db/bundb/session_test.go15
-rw-r--r--internal/db/bundb/status.go186
-rw-r--r--internal/db/bundb/status_test.go4
-rw-r--r--internal/db/bundb/timeline.go68
-rw-r--r--internal/db/bundb/timeline_test.go9
-rw-r--r--internal/db/bundb/user.go14
-rw-r--r--internal/db/bundb/util.go16
-rw-r--r--internal/db/params.go3
-rw-r--r--internal/media/processingmedia.go2
-rw-r--r--internal/media/prunemeta_test.go8
-rw-r--r--internal/media/pruneremote.go2
-rw-r--r--internal/processing/admin/createdomainblock.go8
-rw-r--r--internal/processing/admin/deletedomainblock.go4
-rw-r--r--internal/processing/instance.go2
-rw-r--r--internal/processing/media/getfile_test.go6
-rw-r--r--internal/processing/media/unattach.go2
-rw-r--r--internal/processing/media/update.go2
-rw-r--r--internal/processing/status/util.go33
-rw-r--r--internal/processing/user/changepassword.go2
-rw-r--r--internal/processing/user/emailconfirm.go4
-rw-r--r--internal/processing/user/emailconfirm_test.go4
-rw-r--r--testrig/db.go14
-rw-r--r--testrig/testmodels.go18
45 files changed, 1032 insertions, 528 deletions
diff --git a/cmd/gotosocial/action/admin/account/account.go b/cmd/gotosocial/action/admin/account/account.go
index f2cce57b5..422e4bfef 100644
--- a/cmd/gotosocial/action/admin/account/account.go
+++ b/cmd/gotosocial/action/admin/account/account.go
@@ -101,7 +101,7 @@ var Confirm action.GTSAction = func(ctx context.Context) error {
u.Email = u.UnconfirmedEmail
u.ConfirmedAt = time.Now()
u.UpdatedAt = time.Now()
- if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
+ if err := dbConn.UpdateByID(ctx, u, u.ID, updatingColumns...); err != nil {
return err
}
diff --git a/internal/cache/account.go b/internal/cache/account.go
index 7e23c3194..12675b6b9 100644
--- a/internal/cache/account.go
+++ b/internal/cache/account.go
@@ -101,6 +101,11 @@ func (c *AccountCache) Put(account *gtsmodel.Account) {
c.cache.Set(account.ID, copyAccount(account))
}
+// Invalidate removes (invalidates) one account from the cache by its ID.
+func (c *AccountCache) Invalidate(id string) {
+ c.cache.Invalidate(id)
+}
+
// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects.
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
// this should be a relatively cheap process
diff --git a/internal/db/account.go b/internal/db/account.go
index 351d6d01c..ae5eea7c6 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -48,6 +48,11 @@ type Account interface {
// UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
+ // DeleteAccount deletes one account from the database by its ID.
+ // DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the
+ // account as suspended instead, rather than deleting from the db entirely.
+ DeleteAccount(ctx context.Context, id string) Error
+
// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username.
GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error)
diff --git a/internal/db/basic.go b/internal/db/basic.go
index 6e5184d31..8990edd5f 100644
--- a/internal/db/basic.go
+++ b/internal/db/basic.go
@@ -62,11 +62,11 @@ type Basic interface {
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(ctx context.Context, i interface{}) Error
- // UpdateByPrimaryKey updates values of i based on its primary key.
+ // UpdateByID updates values of i based on its id.
// If any columns are specified, these will be updated exclusively.
// Otherwise, the whole model will be updated.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- UpdateByPrimaryKey(ctx context.Context, i interface{}, columns ...string) Error
+ UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) Error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 074804690..c04948fee 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -21,7 +21,6 @@ package bundb
import (
"context"
"errors"
- "fmt"
"strings"
"time"
@@ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
return a.cache.GetByID(id)
},
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx)
+ return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
},
)
}
@@ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
return a.cache.GetByURI(uri)
},
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx)
+ return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
},
)
}
@@ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
return a.cache.GetByURL(url)
},
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx)
+ return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
},
)
}
@@ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
q := a.newAccountQ(account)
if domain != "" {
- q = q.Where("account.username = ?", username)
- q = q.Where("account.domain = ?", domain)
+ q = q.Where("? = ?", bun.Ident("account.username"), username)
+ q = q.Where("? = ?", bun.Ident("account.domain"), domain)
} else {
- q = q.Where("account.username = ?", strings.ToLower(username))
- q = q.Where("account.domain IS NULL")
+ q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username))
+ q = q.Where("? IS NULL", bun.Ident("account.domain"))
}
return q.Scan(ctx)
@@ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
return a.cache.GetByPubkeyID(id)
},
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx)
+ return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
},
)
}
@@ -169,26 +168,36 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this account and any emojis it uses
// first clear out any old emoji links
- if _, err := tx.NewDelete().
- Model(&[]*gtsmodel.AccountToEmoji{}).
- Where("account_id = ?", account.ID).
+ if _, err := tx.
+ NewDelete().
+ TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
+ Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).
Exec(ctx); err != nil {
return err
}
// now populate new emoji links
for _, i := range account.EmojiIDs {
- if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
- AccountID: account.ID,
- EmojiID: i,
- }).Exec(ctx); err != nil {
+ if _, err := tx.
+ NewInsert().
+ Model(&gtsmodel.AccountToEmoji{
+ AccountID: account.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
return err
}
}
// update the account
- _, err := tx.NewUpdate().Model(account).WherePK().Exec(ctx)
- return err
+ if _, err := tx.
+ NewUpdate().
+ Model(account).
+ Where("? = ?", bun.Ident("account.id"), account.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
}); err != nil {
return nil, a.conn.ProcessError(err)
}
@@ -197,6 +206,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
return account, nil
}
+func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
+ if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // clear out any emoji links
+ if _, err := tx.
+ NewDelete().
+ TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
+ Where("? = ?", bun.Ident("account_to_emoji.account_id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // delete the account
+ _, err := tx.
+ NewUpdate().
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Where("? = ?", bun.Ident("account.id"), id).
+ Exec(ctx)
+ return err
+ }); err != nil {
+ return a.conn.ProcessError(err)
+ }
+
+ a.cache.Invalidate(id)
+ return nil
+}
+
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
@@ -204,11 +239,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
if domain != "" {
q = q.
- Where("account.username = ?", domain).
- Where("account.domain = ?", domain)
+ Where("? = ?", bun.Ident("account.username"), domain).
+ Where("? = ?", bun.Ident("account.domain"), domain)
} else {
q = q.
- Where("account.username = ?", config.GetHost()).
+ Where("? = ?", bun.Ident("account.username"), config.GetHost()).
WhereGroup(" AND ", whereEmptyOrNull("domain"))
}
@@ -224,10 +259,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
q := a.conn.
NewSelect().
Model(status).
- Order("id DESC").
- Limit(1).
- Where("account_id = ?", accountID).
- Column("created_at")
+ Column("status.created_at").
+ Where("? = ?", bun.Ident("status.account_id"), accountID).
+ Order("status.id DESC").
+ Limit(1)
if err := q.Scan(ctx); err != nil {
return time.Time{}, a.conn.ProcessError(err)
@@ -240,12 +275,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
return errors.New("one media attachment cannot be both header and avatar")
}
- var headerOrAVI string
+ var column bun.Ident
switch {
case *mediaAttachment.Avatar:
- headerOrAVI = "avatar"
+ column = bun.Ident("account.avatar_media_attachment_id")
case *mediaAttachment.Header:
- headerOrAVI = "header"
+ column = bun.Ident("account.header_media_attachment_id")
default:
return errors.New("given media attachment was neither a header nor an avatar")
}
@@ -257,11 +292,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
Exec(ctx); err != nil {
return a.conn.ProcessError(err)
}
+
if _, err := a.conn.
NewUpdate().
- Model(&gtsmodel.Account{}).
- Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
- Where("id = ?", accountID).
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Set("? = ?", column, mediaAttachment.ID).
+ Where("? = ?", bun.Ident("account.id"), accountID).
Exec(ctx); err != nil {
return a.conn.ProcessError(err)
}
@@ -284,7 +320,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
if err := a.conn.
NewSelect().
Model(faves).
- Where("account_id = ?", accountID).
+ Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
Scan(ctx); err != nil {
return nil, a.conn.ProcessError(err)
}
@@ -295,8 +331,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
return a.conn.
NewSelect().
- Model(&gtsmodel.Status{}).
- Where("account_id = ?", accountID).
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Where("? = ?", bun.Ident("status.account_id"), accountID).
Count(ctx)
}
@@ -305,12 +341,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
q := a.conn.
NewSelect().
- Table("statuses").
- Column("id").
- Order("id DESC")
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Column("status.id").
+ Order("status.id DESC")
if accountID != "" {
- q = q.Where("account_id = ?", accountID)
+ q = q.Where("? = ?", bun.Ident("status.account_id"), accountID)
}
if limit != 0 {
@@ -321,27 +357,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
// include self-replies (threads)
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
return q.
- WhereOr("in_reply_to_account_id = ?", accountID).
- WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri"))
+ WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID).
+ WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri"))
}
q = q.WhereGroup(" AND ", whereGroup)
}
if excludeReblogs {
- q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id"))
+ q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id"))
}
if maxID != "" {
- q = q.Where("id < ?", maxID)
+ q = q.Where("? < ?", bun.Ident("status.id"), maxID)
}
if minID != "" {
- q = q.Where("id > ?", minID)
+ q = q.Where("? > ?", bun.Ident("status.id"), minID)
}
if pinnedOnly {
- q = q.Where("pinned = ?", true)
+ q = q.Where("? = ?", bun.Ident("status.pinned"), true)
}
if mediaOnly {
@@ -352,15 +388,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
switch a.conn.Dialect().Name() {
case dialect.PG:
return q.
- Where("? IS NOT NULL", bun.Ident("attachments")).
- Where("? != '{}'", bun.Ident("attachments"))
+ Where("? IS NOT NULL", bun.Ident("status.attachments")).
+ Where("? != '{}'", bun.Ident("status.attachments"))
case dialect.SQLite:
return q.
- Where("? IS NOT NULL", bun.Ident("attachments")).
- Where("? != ''", bun.Ident("attachments")).
- Where("? != 'null'", bun.Ident("attachments")).
- Where("? != '{}'", bun.Ident("attachments")).
- Where("? != '[]'", bun.Ident("attachments"))
+ Where("? IS NOT NULL", bun.Ident("status.attachments")).
+ Where("? != ''", bun.Ident("status.attachments")).
+ Where("? != 'null'", bun.Ident("status.attachments")).
+ Where("? != '{}'", bun.Ident("status.attachments")).
+ Where("? != '[]'", bun.Ident("status.attachments"))
default:
log.Panic("db dialect was neither pg nor sqlite")
return q
@@ -369,7 +405,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
}
if publicOnly {
- q = q.Where("visibility = ?", gtsmodel.VisibilityPublic)
+ q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic)
}
if err := q.Scan(ctx, &statusIDs); err != nil {
@@ -384,19 +420,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
q := a.conn.
NewSelect().
- Table("statuses").
- Column("id").
- Where("account_id = ?", accountID).
- WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")).
- WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")).
- Where("visibility = ?", gtsmodel.VisibilityPublic).
- Where("federated = ?", true)
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Column("status.id").
+ Where("? = ?", bun.Ident("status.account_id"), accountID).
+ WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
+ WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
+ Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
+ Where("? = ?", bun.Ident("status.federated"), true)
if maxID != "" {
- q = q.Where("id < ?", maxID)
+ q = q.Where("? < ?", bun.Ident("status.id"), maxID)
}
- q = q.Limit(limit).Order("id DESC")
+ q = q.Limit(limit).Order("status.id DESC")
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, a.conn.ProcessError(err)
@@ -411,16 +447,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
fq := a.conn.
NewSelect().
Model(&blocks).
- Where("block.account_id = ?", accountID).
+ Where("? = ?", bun.Ident("block.account_id"), accountID).
Relation("TargetAccount").
Order("block.id DESC")
if maxID != "" {
- fq = fq.Where("block.id < ?", maxID)
+ fq = fq.Where("? < ?", bun.Ident("block.id"), maxID)
}
if sinceID != "" {
- fq = fq.Where("block.id > ?", sinceID)
+ fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID)
}
if limit > 0 {
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
index ad2a217af..72adba487 100644
--- a/internal/db/bundb/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -42,6 +42,18 @@ func (suite *AccountTestSuite) TestGetAccountStatuses() {
suite.Len(statuses, 5)
}
+func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() {
+ statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false, false)
+ suite.NoError(err)
+ suite.Len(statuses, 5)
+}
+
+func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() {
+ statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false, true)
+ suite.NoError(err)
+ suite.Len(statuses, 1)
+}
+
func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() {
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, true, false)
suite.NoError(err)
@@ -99,7 +111,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
err = dbService.GetConn().
NewSelect().
Model(noCache).
- Where("account.id = ?", bun.Ident(testAccount.ID)).
+ Where("? = ?", bun.Ident("account.id"), testAccount.ID).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment").
Relation("Emojis").
@@ -127,7 +139,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
err = dbService.GetConn().
NewSelect().
Model(noCache).
- Where("account.id = ?", bun.Ident(testAccount.ID)).
+ Where("? = ?", bun.Ident("account.id"), testAccount.ID).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment").
Relation("Emojis").
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go
index 9fa78eca0..44861a4bb 100644
--- a/internal/db/bundb/admin.go
+++ b/internal/db/bundb/admin.go
@@ -22,7 +22,6 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
- "database/sql"
"fmt"
"net"
"net/mail"
@@ -37,21 +36,26 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/uris"
+ "github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
+// generate RSA keys of this length
+const rsaKeyBits = 2048
+
type adminDB struct {
- conn *DBConn
- userCache *cache.UserCache
+ conn *DBConn
+ userCache *cache.UserCache
+ accountCache *cache.AccountCache
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
q := a.conn.
NewSelect().
- Model(&gtsmodel.Account{}).
- Where("username = ?", username).
- Where("domain = ?", nil)
-
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Column("account.id").
+ Where("? = ?", bun.Ident("account.username"), username).
+ Where("? IS NULL", bun.Ident("account.domain"))
return a.conn.NotExists(ctx, q)
}
@@ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
- if err := a.conn.
+ emailDomainBlockedQ := a.conn.
NewSelect().
- Model(&gtsmodel.EmailDomainBlock{}).
- Where("domain = ?", domain).
- Scan(ctx); err == nil {
- // fail because we found something
+ TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")).
+ Column("email_domain_block.id").
+ Where("? = ?", bun.Ident("email_domain_block.domain"), domain)
+ emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ)
+ if err != nil {
+ return false, err
+ }
+ if emailDomainBlocked {
return false, fmt.Errorf("email domain %s is blocked", domain)
- } else if err != sql.ErrNoRows {
- return false, a.conn.ProcessError(err)
}
// check if this email is associated with a user already
q := a.conn.
NewSelect().
- Model(&gtsmodel.User{}).
- Where("email = ?", email).
- WhereOr("unconfirmed_email = ?", email)
-
+ TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
+ Column("user.id").
+ Where("? = ?", bun.Ident("user.email"), email).
+ WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email)
return a.conn.NotExists(ctx, q)
}
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
- key, err := rsa.GenerateKey(rand.Reader, 2048)
+ key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
if err != nil {
log.Errorf("error creating new rsa key: %s", err)
return nil, err
@@ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
// if something went wrong while creating a user, we might already have an account, so check here first...
acct := &gtsmodel.Account{}
- q := a.conn.NewSelect().
+ if err := a.conn.
+ NewSelect().
Model(acct).
- Where("username = ?", username).
- WhereGroup(" AND ", whereEmptyOrNull("domain"))
+ Where("? = ?", bun.Ident("account.username"), username).
+ WhereGroup(" AND ", whereEmptyOrNull("account.domain")).
+ Scan(ctx); err != nil {
+ err = a.conn.ProcessError(err)
+ if err != db.ErrNoEntries {
+ log.Errorf("error checking for existing account: %s", err)
+ return nil, err
+ }
- if err := q.Scan(ctx); err != nil {
- // we just don't have an account yet so create one before we proceed
+ // if we have db.ErrNoEntries, we just don't have an
+ // account yet so create one before we proceed
accountURIs := uris.GenerateURIsForAccount(username)
accountID, err := id.NewRandomULID()
if err != nil {
@@ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
FeaturedCollectionURI: accountURIs.CollectionURI,
}
+ // insert the new account!
if _, err = a.conn.
NewInsert().
Model(acct).
Exec(ctx); err != nil {
return nil, a.conn.ProcessError(err)
}
+ a.accountCache.Put(acct)
}
+ // we either created or already had an account by now,
+ // so proceed with creating a user for that account
+
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
@@ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
u.Moderator = &moderator
}
+ // insert the user!
if _, err = a.conn.
NewInsert().
Model(u).
@@ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
q := a.conn.
NewSelect().
- Model(&gtsmodel.Account{}).
- Where("username = ?", username).
- WhereGroup(" AND ", whereEmptyOrNull("domain"))
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Column("account.id").
+ Where("? = ?", bun.Ident("account.username"), username).
+ WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
exists, err := a.conn.Exists(ctx, q)
if err != nil {
@@ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
return nil
}
- key, err := rsa.GenerateKey(rand.Reader, 2048)
+ key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
if err != nil {
log.Errorf("error creating new rsa key: %s", err)
return err
@@ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
return a.conn.ProcessError(err)
}
+ a.accountCache.Put(acct)
log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil
}
@@ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
// check if instance entry already exists
q := a.conn.
NewSelect().
- Model(&gtsmodel.Instance{}).
- Where("domain = ?", host)
+ Column("instance.id").
+ TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
+ Where("? = ?", bun.Ident("instance.domain"), host)
exists, err := a.conn.Exists(ctx, q)
if err != nil {
diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go
index 22041087a..f0a869a9b 100644
--- a/internal/db/bundb/admin_test.go
+++ b/internal/db/bundb/admin_test.go
@@ -23,6 +23,7 @@ import (
"testing"
"github.com/stretchr/testify/suite"
+ gtsmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20211113114307_init"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -30,6 +31,44 @@ type AdminTestSuite struct {
BunDBStandardTestSuite
}
+func (suite *AdminTestSuite) TestIsUsernameAvailableNo() {
+ available, err := suite.db.IsUsernameAvailable(context.Background(), "the_mighty_zork")
+ suite.NoError(err)
+ suite.False(available)
+}
+
+func (suite *AdminTestSuite) TestIsUsernameAvailableYes() {
+ available, err := suite.db.IsUsernameAvailable(context.Background(), "someone_completely_different")
+ suite.NoError(err)
+ suite.True(available)
+}
+
+func (suite *AdminTestSuite) TestIsEmailAvailableNo() {
+ available, err := suite.db.IsEmailAvailable(context.Background(), "zork@example.org")
+ suite.NoError(err)
+ suite.False(available)
+}
+
+func (suite *AdminTestSuite) TestIsEmailAvailableYes() {
+ available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com")
+ suite.NoError(err)
+ suite.True(available)
+}
+
+func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {
+ if err := suite.db.Put(context.Background(), &gtsmodel.EmailDomainBlock{
+ ID: "01GEEV2R2YC5GRSN96761YJE47",
+ Domain: "somewhere.com",
+ CreatedByAccountID: suite.testAccounts["admin_account"].ID,
+ }); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com")
+ suite.EqualError(err, "email domain somewhere.com is blocked")
+ suite.False(available)
+}
+
func (suite *AdminTestSuite) TestCreateInstanceAccount() {
// we need to take an empty db for this...
testrig.StandardDBTeardown(suite.db)
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go
index cd80c9330..ef8b35574 100644
--- a/internal/db/bundb/basic.go
+++ b/internal/db/bundb/basic.go
@@ -94,12 +94,12 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface
return b.conn.ProcessError(err)
}
-func (b *basicDB) UpdateByPrimaryKey(ctx context.Context, i interface{}, columns ...string) db.Error {
+func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) db.Error {
q := b.conn.
NewUpdate().
Model(i).
Column(columns...).
- WherePK()
+ Where("? = ?", bun.Ident("id"), id)
_, err := q.Exec(ctx)
return b.conn.ProcessError(err)
@@ -110,7 +110,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
updateWhere(q, where)
- q = q.Set("? = ?", bun.Safe(key), value)
+ q = q.Set("? = ?", bun.Ident(key), value)
_, err := q.Exec(ctx)
return b.conn.ProcessError(err)
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 70a44d4c1..02522e6f7 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -159,17 +159,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
return nil, fmt.Errorf("db migration error: %s", err)
}
- // Create DB structs that require ptrs to each other
- accounts := &accountDB{conn: conn, cache: cache.NewAccountCache()}
- status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
- emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
- timeline := &timelineDB{conn: conn}
-
- // Setup DB cross-referencing
- accounts.status = status
- status.accounts = accounts
- timeline.status = status
+ // Prepare caches required by more than one struct
+ userCache := cache.NewUserCache()
+ accountCache := cache.NewAccountCache()
+ // Prepare other caches
// Prepare mentions cache
// TODO: move into internal/cache
mentionCache := grufcache.New[string, *gtsmodel.Mention]()
@@ -182,22 +176,30 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
notifCache.SetTTL(time.Minute*5, false)
notifCache.Start(time.Second * 10)
- // Prepare other caches
- blockCache := cache.NewDomainBlockCache()
- userCache := cache.NewUserCache()
+ // Create DB structs that require ptrs to each other
+ accounts := &accountDB{conn: conn, cache: accountCache}
+ status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
+ emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
+ timeline := &timelineDB{conn: conn}
+
+ // Setup DB cross-referencing
+ accounts.status = status
+ status.accounts = accounts
+ timeline.status = status
ps := &DBService{
Account: accounts,
Admin: &adminDB{
- conn: conn,
- userCache: userCache,
+ conn: conn,
+ userCache: userCache,
+ accountCache: accountCache,
},
Basic: &basicDB{
conn: conn,
},
Domain: &domainDB{
conn: conn,
- cache: blockCache,
+ cache: cache.NewDomainBlockCache(),
},
Emoji: emoji,
Instance: &instanceDB{
diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go
index 581573056..2af6cf122 100644
--- a/internal/db/bundb/bundb_test.go
+++ b/internal/db/bundb/bundb_test.go
@@ -40,6 +40,7 @@ type BunDBStandardTestSuite struct {
testStatuses map[string]*gtsmodel.Status
testTags map[string]*gtsmodel.Tag
testMentions map[string]*gtsmodel.Mention
+ testFollows map[string]*gtsmodel.Follow
}
func (suite *BunDBStandardTestSuite) SetupSuite() {
@@ -52,6 +53,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
+ suite.testFollows = testrig.NewTestFollows()
}
func (suite *BunDBStandardTestSuite) SetupTest() {
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index 9fc4bb276..0a752d3f3 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -28,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
"golang.org/x/net/idna"
)
@@ -95,7 +96,7 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
q := d.conn.
NewSelect().
Model(block).
- Where("domain = ?", domain).
+ Where("? = ?", bun.Ident("domain_block.domain"), domain).
Limit(1)
// Query database for domain block
@@ -126,7 +127,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
// Attempt to delete domain block
if _, err := d.conn.NewDelete().
Model((*gtsmodel.DomainBlock)(nil)).
- Where("domain = ?", domain).
+ Where("? = ?", bun.Ident("domain_block.domain"), domain).
Exec(ctx); err != nil {
return d.conn.ProcessError(err)
}
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go
index 758da0feb..e781e2f00 100644
--- a/internal/db/bundb/emoji.go
+++ b/internal/db/bundb/emoji.go
@@ -54,12 +54,12 @@ func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Er
q := e.conn.
NewSelect().
- Table("emojis").
- Column("id").
- Where("visible_in_picker = true").
- Where("disabled = false").
- Where("domain IS NULL").
- Order("shortcode ASC")
+ TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")).
+ Column("emoji.id").
+ Where("? = ?", bun.Ident("emoji.visible_in_picker"), true).
+ Where("? = ?", bun.Ident("emoji.disabled"), false).
+ Where("? IS NULL", bun.Ident("emoji.domain")).
+ Order("emoji.shortcode ASC")
if err := q.Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err)
@@ -75,7 +75,7 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji,
return e.cache.GetByID(id)
},
func(emoji *gtsmodel.Emoji) error {
- return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx)
+ return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
},
)
}
@@ -87,7 +87,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
return e.cache.GetByURI(uri)
},
func(emoji *gtsmodel.Emoji) error {
- return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx)
+ return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
},
)
}
@@ -102,11 +102,11 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin
q := e.newEmojiQ(emoji)
if domain != "" {
- q = q.Where("emoji.shortcode = ?", shortcode)
- q = q.Where("emoji.domain = ?", domain)
+ q = q.Where("? = ?", bun.Ident("emoji.shortcode"), shortcode)
+ q = q.Where("? = ?", bun.Ident("emoji.domain"), domain)
} else {
- q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode))
- q = q.Where("emoji.domain IS NULL")
+ q = q.Where("? = ?", bun.Ident("emoji.shortcode"), strings.ToLower(shortcode))
+ q = q.Where("? IS NULL", bun.Ident("emoji.domain"))
}
return q.Scan(ctx)
diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go
index fb6454e2f..604461708 100644
--- a/internal/db/bundb/instance.go
+++ b/internal/db/bundb/instance.go
@@ -24,7 +24,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
@@ -35,15 +34,16 @@ type instanceDB struct {
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
NewSelect().
- Model(&[]*gtsmodel.Account{}).
- Where("username != ?", domain).
- Where("? IS NULL", bun.Ident("suspended_at"))
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Column("account.id").
+ Where("? != ?", bun.Ident("account.username"), domain).
+ Where("? IS NULL", bun.Ident("account.suspended_at"))
- if domain == config.GetHost() {
+ if domain == config.GetHost() || domain == config.GetAccountDomain() {
// if the domain is *this* domain, just count where the domain field is null
- q = q.WhereGroup(" AND ", whereEmptyOrNull("domain"))
+ q = q.WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
} else {
- q = q.Where("domain = ?", domain)
+ q = q.Where("? = ?", bun.Ident("account.domain"), domain)
}
count, err := q.Count(ctx)
@@ -56,15 +56,16 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
NewSelect().
- Model(&[]*gtsmodel.Status{})
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status"))
- if domain == config.GetHost() {
+ if domain == config.GetHost() || domain == config.GetAccountDomain() {
// if the domain is *this* domain, just count where local is true
- q = q.Where("local = ?", true)
+ q = q.Where("? = ?", bun.Ident("status.local"), true)
} else {
// join on the domain of the account
- q = q.Join("JOIN accounts AS account ON account.id = status.account_id").
- Where("account.domain = ?", domain)
+ q = q.
+ Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("account.id"), bun.Ident("status.account_id")).
+ Where("? = ?", bun.Ident("account.domain"), domain)
}
count, err := q.Count(ctx)
@@ -77,14 +78,14 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
NewSelect().
- Model(&[]*gtsmodel.Instance{})
+ TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance"))
if domain == config.GetHost() {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
q = q.
- Where("domain != ?", domain).
- Where("? IS NULL", bun.Ident("suspended_at"))
+ Where("? != ?", bun.Ident("instance.domain"), domain).
+ Where("? IS NULL", bun.Ident("instance.suspended_at"))
} else {
// TODO: implement federated domain counting properly for remote domains
return 0, nil
@@ -103,10 +104,10 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
q := i.conn.
NewSelect().
Model(&instances).
- Where("domain != ?", config.GetHost())
+ Where("? != ?", bun.Ident("instance.domain"), config.GetHost())
if !includeSuspended {
- q = q.Where("? IS NULL", bun.Ident("suspended_at"))
+ q = q.Where("? IS NULL", bun.Ident("instance.suspended_at"))
}
if err := q.Scan(ctx); err != nil {
@@ -117,17 +118,15 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
}
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
- log.Debug("GetAccountsForInstance")
-
accounts := []*gtsmodel.Account{}
q := i.conn.NewSelect().
Model(&accounts).
- Where("domain = ?", domain).
- Order("id DESC")
+ Where("? = ?", bun.Ident("account.domain"), domain).
+ Order("account.id DESC")
if maxID != "" {
- q = q.Where("id < ?", maxID)
+ q = q.Where("? < ?", bun.Ident("account.id"), maxID)
}
if limit > 0 {
diff --git a/internal/db/bundb/instance_test.go b/internal/db/bundb/instance_test.go
new file mode 100644
index 000000000..50d118888
--- /dev/null
+++ b/internal/db/bundb/instance_test.go
@@ -0,0 +1,83 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+)
+
+type InstanceTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *InstanceTestSuite) TestCountInstanceUsers() {
+ count, err := suite.db.CountInstanceUsers(context.Background(), config.GetHost())
+ suite.NoError(err)
+ suite.Equal(4, count)
+}
+
+func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() {
+ count, err := suite.db.CountInstanceUsers(context.Background(), "fossbros-anonymous.io")
+ suite.NoError(err)
+ suite.Equal(1, count)
+}
+
+func (suite *InstanceTestSuite) TestCountInstanceStatuses() {
+ count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost())
+ suite.NoError(err)
+ suite.Equal(16, count)
+}
+
+func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() {
+ count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io")
+ suite.NoError(err)
+ suite.Equal(1, count)
+}
+
+func (suite *InstanceTestSuite) TestCountInstanceDomains() {
+ count, err := suite.db.CountInstanceDomains(context.Background(), config.GetHost())
+ suite.NoError(err)
+ suite.Equal(2, count)
+}
+
+func (suite *InstanceTestSuite) TestGetInstancePeers() {
+ peers, err := suite.db.GetInstancePeers(context.Background(), false)
+ suite.NoError(err)
+ suite.Len(peers, 2)
+}
+
+func (suite *InstanceTestSuite) TestGetInstancePeersIncludeSuspended() {
+ peers, err := suite.db.GetInstancePeers(context.Background(), true)
+ suite.NoError(err)
+ suite.Len(peers, 2)
+}
+
+func (suite *InstanceTestSuite) TestGetInstanceAccounts() {
+ accounts, err := suite.db.GetInstanceAccounts(context.Background(), "fossbros-anonymous.io", "", 10)
+ suite.NoError(err)
+ suite.Len(accounts, 1)
+}
+
+func TestInstanceTestSuite(t *testing.T) {
+ suite.Run(t, new(InstanceTestSuite))
+}
diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go
index 71433b901..39e0ad0e3 100644
--- a/internal/db/bundb/media.go
+++ b/internal/db/bundb/media.go
@@ -42,7 +42,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
attachment := &gtsmodel.MediaAttachment{}
q := m.newMediaQ(attachment).
- Where("media_attachment.id = ?", id)
+ Where("? = ?", bun.Ident("media_attachment.id"), id)
if err := q.Scan(ctx); err != nil {
return nil, m.conn.ProcessError(err)
@@ -56,10 +56,10 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l
q := m.conn.
NewSelect().
Model(&attachments).
- Where("media_attachment.cached = true").
- Where("media_attachment.avatar = false").
- Where("media_attachment.header = false").
- Where("media_attachment.created_at < ?", olderThan).
+ Where("? = ?", bun.Ident("media_attachment.cached"), true).
+ Where("? = ?", bun.Ident("media_attachment.avatar"), false).
+ Where("? = ?", bun.Ident("media_attachment.header"), false).
+ Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).
WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")).
Order("media_attachment.created_at DESC")
@@ -79,13 +79,13 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
q := m.newMediaQ(&attachments).
WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery {
return innerQ.
- WhereOr("media_attachment.avatar = true").
- WhereOr("media_attachment.header = true")
+ WhereOr("? = ?", bun.Ident("media_attachment.avatar"), true).
+ WhereOr("? = ?", bun.Ident("media_attachment.header"), true)
}).
Order("media_attachment.id DESC")
if maxID != "" {
- q = q.Where("media_attachment.id < ?", maxID)
+ q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)
}
if limit != 0 {
@@ -103,15 +103,15 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
attachments := []*gtsmodel.MediaAttachment{}
q := m.newMediaQ(&attachments).
- Where("media_attachment.cached = true").
- Where("media_attachment.avatar = false").
- Where("media_attachment.header = false").
- Where("media_attachment.created_at < ?", olderThan).
- Where("media_attachment.remote_url IS NULL").
- Where("media_attachment.status_id IS NULL")
+ Where("? = ?", bun.Ident("media_attachment.cached"), true).
+ Where("? = ?", bun.Ident("media_attachment.avatar"), false).
+ Where("? = ?", bun.Ident("media_attachment.header"), false).
+ Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).
+ Where("? IS NULL", bun.Ident("media_attachment.remote_url")).
+ Where("? IS NULL", bun.Ident("media_attachment.status_id"))
if maxID != "" {
- q = q.Where("media_attachment.id < ?", maxID)
+ q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)
}
if limit != 0 {
diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go
index e2c83ef3f..355078021 100644
--- a/internal/db/bundb/mention.go
+++ b/internal/db/bundb/mention.go
@@ -46,7 +46,7 @@ func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Ment
mention := gtsmodel.Mention{}
q := m.newMentionQ(&mention).
- Where("mention.id = ?", id)
+ Where("? = ?", bun.Ident("mention.id"), id)
if err := q.Scan(ctx); err != nil {
return nil, m.conn.ProcessError(err)
diff --git a/internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go b/internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go
index 4c4ada594..b0179ec4f 100644
--- a/internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go
+++ b/internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go
@@ -47,8 +47,8 @@ func init() {
}
if _, err := tx.NewDelete().
- Model(a).
- WherePK().
+ TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
+ Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
Exec(ctx); err != nil {
l.Errorf("error deleting attachment with id %s: %s", a.ID, err)
} else {
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go
index 32523ca24..69e3cf39f 100644
--- a/internal/db/bundb/notification.go
+++ b/internal/db/bundb/notification.go
@@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/uptrace/bun"
)
type notificationDB struct {
@@ -44,7 +45,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo
Relation("OriginAccount").
Relation("TargetAccount").
Relation("Status").
- WherePK()
+ Where("? = ?", bun.Ident("notification.id"), id)
if err := q.Scan(ctx); err != nil {
return nil, n.conn.ProcessError(err)
@@ -67,24 +68,24 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
q := n.conn.
NewSelect().
- Table("notifications").
- Column("id")
+ TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
+ Column("notification.id")
if maxID != "" {
- q = q.Where("id < ?", maxID)
+ q = q.Where("? < ?", bun.Ident("notification.id"), maxID)
}
if sinceID != "" {
- q = q.Where("id > ?", sinceID)
+ q = q.Where("? > ?", bun.Ident("notification.id"), sinceID)
}
for _, excludeType := range excludeTypes {
- q = q.Where("notification_type != ?", excludeType)
+ q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType)
}
q = q.
- Where("target_account_id = ?", accountID).
- Order("id DESC")
+ Where("? = ?", bun.Ident("notification.target_account_id"), accountID).
+ Order("notification.id DESC")
if limit != 0 {
q = q.Limit(limit)
@@ -116,13 +117,12 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error {
if _, err := n.conn.
NewDelete().
- Table("notifications").
- Where("target_account_id = ?", accountID).
+ TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
+ Where("? = ?", bun.Ident("notification.target_account_id"), accountID).
Exec(ctx); err != nil {
return n.conn.ProcessError(err)
}
n.cache.Clear()
-
return nil
}
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index ba72a053a..66e48e441 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
q := r.conn.
NewSelect().
- Model(&gtsmodel.Block{}).
- ExcludeColumn("id", "created_at", "updated_at", "uri").
- Limit(1)
+ TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
+ Column("block.id")
if eitherDirection {
q = q.
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
return inner.
- Where("account_id = ?", account1).
- Where("target_account_id = ?", account2)
+ Where("? = ?", bun.Ident("block.account_id"), account1).
+ Where("? = ?", bun.Ident("block.target_account_id"), account2)
}).
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
return inner.
- Where("account_id = ?", account2).
- Where("target_account_id = ?", account1)
+ Where("? = ?", bun.Ident("block.account_id"), account2).
+ Where("? = ?", bun.Ident("block.target_account_id"), account1)
})
} else {
q = q.
- Where("account_id = ?", account1).
- Where("target_account_id = ?", account2)
+ Where("? = ?", bun.Ident("block.account_id"), account1).
+ Where("? = ?", bun.Ident("block.target_account_id"), account2)
}
return r.conn.Exists(ctx, q)
@@ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
block := &gtsmodel.Block{}
q := r.newBlockQ(block).
- Where("block.account_id = ?", account1).
- Where("block.target_account_id = ?", account2)
+ Where("? = ?", bun.Ident("block.account_id"), account1).
+ Where("? = ?", bun.Ident("block.target_account_id"), account2)
if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err)
@@ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
if err := r.conn.
NewSelect().
Model(follow).
- Where("account_id = ?", requestingAccount).
- Where("target_account_id = ?", targetAccount).
+ Column("follow.show_reblogs", "follow.notify").
+ Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
+ Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
Limit(1).
Scan(ctx); err != nil {
- if err != sql.ErrNoRows {
- // a proper error
- return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
+ if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
+ return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
}
// no follow exists so these are all false
rel.Following = false
@@ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
}
// check if the target account follows the requesting account
- count, err := r.conn.
+ followedByQ := r.conn.
NewSelect().
- Model(&gtsmodel.Follow{}).
- Where("account_id = ?", targetAccount).
- Where("target_account_id = ?", requestingAccount).
- Limit(1).
- Count(ctx)
+ TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
+ Column("follow.id").
+ Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
+ Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
+ followedBy, err := r.conn.Exists(ctx, followedByQ)
if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
}
- rel.FollowedBy = count > 0
+ rel.FollowedBy = followedBy
- // check if the requesting account blocks the target account
- count, err = r.conn.NewSelect().
- Model(&gtsmodel.Block{}).
- Where("account_id = ?", requestingAccount).
- Where("target_account_id = ?", targetAccount).
- Limit(1).
- Count(ctx)
+ // check if there's a pending following request from requesting account to target account
+ requestedQ := r.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
+ Column("follow_request.id").
+ Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
+ Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
+ requested, err := r.conn.Exists(ctx, requestedQ)
if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
}
- rel.Blocking = count > 0
+ rel.Requested = requested
- // check if the target account blocks the requesting account
- count, err = r.conn.
+ // check if the requesting account is blocking the target account
+ blockingQ := r.conn.
NewSelect().
- Model(&gtsmodel.Block{}).
- Where("account_id = ?", targetAccount).
- Where("target_account_id = ?", requestingAccount).
- Limit(1).
- Count(ctx)
+ TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
+ Column("block.id").
+ Where("? = ?", bun.Ident("block.account_id"), requestingAccount).
+ Where("? = ?", bun.Ident("block.target_account_id"), targetAccount)
+ blocking, err := r.conn.Exists(ctx, blockingQ)
if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
}
- rel.BlockedBy = count > 0
+ rel.Blocking = blocking
- // check if there's a pending following request from requesting account to target account
- count, err = r.conn.
+ // check if the requesting account is blocked by the target account
+ blockedByQ := r.conn.
NewSelect().
- Model(&gtsmodel.FollowRequest{}).
- Where("account_id = ?", requestingAccount).
- Where("target_account_id = ?", targetAccount).
- Limit(1).
- Count(ctx)
+ TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
+ Column("block.id").
+ Where("? = ?", bun.Ident("block.account_id"), targetAccount).
+ Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount)
+ blockedBy, err := r.conn.Exists(ctx, blockedByQ)
if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
}
- rel.Requested = count > 0
+ rel.BlockedBy = blockedBy
return rel, nil
}
@@ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
q := r.conn.
NewSelect().
- Model(&gtsmodel.Follow{}).
- Where("account_id = ?", sourceAccount.ID).
- Where("target_account_id = ?", targetAccount.ID).
- Limit(1)
+ TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
+ Column("follow.id").
+ Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
+ Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
return r.conn.Exists(ctx, q)
}
@@ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
q := r.conn.
NewSelect().
- Model(&gtsmodel.FollowRequest{}).
- Where("account_id = ?", sourceAccount.ID).
- Where("target_account_id = ?", targetAccount.ID)
+ TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
+ Column("follow_request.id").
+ Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
+ Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
return r.conn.Exists(ctx, q)
}
@@ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
}
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
- // make sure the original follow request exists
- fr := &gtsmodel.FollowRequest{}
- if err := r.conn.
- NewSelect().
- Model(fr).
- Where("account_id = ?", originAccountID).
- Where("target_account_id = ?", targetAccountID).
- Scan(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
+ var follow *gtsmodel.Follow
+
+ if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // get original follow request
+ followRequest := &gtsmodel.FollowRequest{}
+ if err := tx.
+ NewSelect().
+ Model(followRequest).
+ Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
+ Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
+ Scan(ctx); err != nil {
+ return err
+ }
- // create a new follow to 'replace' the request with
- follow := &gtsmodel.Follow{
- ID: fr.ID,
- AccountID: originAccountID,
- TargetAccountID: targetAccountID,
- URI: fr.URI,
- }
+ // create a new follow to 'replace' the request with
+ follow = &gtsmodel.Follow{
+ ID: followRequest.ID,
+ AccountID: originAccountID,
+ TargetAccountID: targetAccountID,
+ URI: followRequest.URI,
+ }
- // if the follow already exists, just update the URI -- we don't need to do anything else
- if _, err := r.conn.
- NewInsert().
- Model(follow).
- On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI).
- Exec(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
+ // if the follow already exists, just update the URI -- we don't need to do anything else
+ if _, err := tx.
+ NewInsert().
+ Model(follow).
+ On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // now remove the follow request
+ if _, err := tx.
+ NewDelete().
+ TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
+ Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
- // now remove the follow request
- if _, err := r.conn.
- NewDelete().
- Model(&gtsmodel.FollowRequest{}).
- Where("account_id = ?", originAccountID).
- Where("target_account_id = ?", targetAccountID).
- Exec(ctx); err != nil {
+ return nil
+ }); err != nil {
return nil, r.conn.ProcessError(err)
}
+ // return the new follow
return follow, nil
}
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
- // first get the follow request out of the database
- fr := &gtsmodel.FollowRequest{}
- if err := r.conn.
- NewSelect().
- Model(fr).
- Where("account_id = ?", originAccountID).
- Where("target_account_id = ?", targetAccountID).
- Scan(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
+ followRequest := &gtsmodel.FollowRequest{}
+
+ if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // get original follow request
+ if err := tx.
+ NewSelect().
+ Model(followRequest).
+ Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
+ Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
+ Scan(ctx); err != nil {
+ return err
+ }
- // now delete it from the database by ID
- if _, err := r.conn.
- NewDelete().
- Model(&gtsmodel.FollowRequest{ID: fr.ID}).
- WherePK().
- Exec(ctx); err != nil {
+ // now delete it from the database by ID
+ if _, err := tx.
+ NewDelete().
+ TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
+ Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+ }); err != nil {
return nil, r.conn.ProcessError(err)
}
// return the deleted follow request
- return fr, nil
+ return followRequest, nil
}
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
followRequests := []*gtsmodel.FollowRequest{}
q := r.newFollowQ(&followRequests).
- Where("target_account_id = ?", accountID).
+ Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID).
Order("follow_request.updated_at DESC")
if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
+
return followRequests, nil
}
@@ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
follows := []*gtsmodel.Follow{}
q := r.newFollowQ(&follows).
- Where("account_id = ?", accountID).
+ Where("? = ?", bun.Ident("follow.account_id"), accountID).
Order("follow.updated_at DESC")
if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
+
return follows, nil
}
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
- return r.conn.
+ q := r.conn.
NewSelect().
- Model(&[]*gtsmodel.Follow{}).
- Where("account_id = ?", accountID).
- Count(ctx)
+ TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
+
+ if localOnly {
+ q = q.
+ Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")).
+ Where("? = ?", bun.Ident("follow.account_id"), accountID).
+ Where("? IS NULL", bun.Ident("account.domain"))
+ } else {
+ q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
+ }
+
+ return q.Count(ctx)
}
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
@@ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
Order("follow.updated_at DESC")
if localOnly {
- q = q.ColumnExpr("follow.*").
- Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)").
- Where("follow.target_account_id = ?", accountID).
- WhereGroup(" AND ", whereEmptyOrNull("a.domain"))
+ q = q.
+ Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
+ Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
+ Where("? IS NULL", bun.Ident("account.domain"))
} else {
- q = q.Where("target_account_id = ?", accountID)
+ q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
}
err := q.Scan(ctx)
@@ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
}
func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
- return r.conn.
+ q := r.conn.
NewSelect().
- Model(&[]*gtsmodel.Follow{}).
- Where("target_account_id = ?", accountID).
- Count(ctx)
+ TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
+
+ if localOnly {
+ q = q.
+ Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
+ Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
+ Where("? IS NULL", bun.Ident("account.domain"))
+ } else {
+ q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
+ }
+
+ return q.Count(ctx)
}
diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go
index 34fe85a57..3df16e2f3 100644
--- a/internal/db/bundb/relationship_test.go
+++ b/internal/db/bundb/relationship_test.go
@@ -20,7 +20,6 @@ package bundb_test
import (
"context"
- "errors"
"testing"
"github.com/stretchr/testify/suite"
@@ -48,12 +47,14 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
suite.False(blocked)
// have account1 block account2
- suite.db.Put(ctx, &gtsmodel.Block{
+ if err := suite.db.Put(ctx, &gtsmodel.Block{
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
URI: "http://localhost:8080/some_block_uri_1",
AccountID: account1,
TargetAccountID: account2,
- })
+ }); err != nil {
+ suite.FailNow(err.Error())
+ }
// account 1 now blocks account 2
blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
@@ -75,62 +76,242 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
}
func (suite *RelationshipTestSuite) TestGetBlock() {
- suite.Suite.T().Skip("TODO: implement")
+ ctx := context.Background()
+
+ account1 := suite.testAccounts["local_account_1"].ID
+ account2 := suite.testAccounts["local_account_2"].ID
+
+ if err := suite.db.Put(ctx, &gtsmodel.Block{
+ ID: "01G202BCSXXJZ70BHB5KCAHH8C",
+ URI: "http://localhost:8080/some_block_uri_1",
+ AccountID: account1,
+ TargetAccountID: account2,
+ }); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ block, err := suite.db.GetBlock(ctx, account1, account2)
+ suite.NoError(err)
+ suite.NotNil(block)
+ suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
}
func (suite *RelationshipTestSuite) TestGetRelationship() {
- suite.Suite.T().Skip("TODO: implement")
+ requestingAccount := suite.testAccounts["local_account_1"]
+ targetAccount := suite.testAccounts["admin_account"]
+
+ relationship, err := suite.db.GetRelationship(context.Background(), requestingAccount.ID, targetAccount.ID)
+ suite.NoError(err)
+ suite.NotNil(relationship)
+
+ suite.True(relationship.Following)
+ suite.True(relationship.ShowingReblogs)
+ suite.False(relationship.Notifying)
+ suite.True(relationship.FollowedBy)
+ suite.False(relationship.Blocking)
+ suite.False(relationship.BlockedBy)
+ suite.False(relationship.Muting)
+ suite.False(relationship.MutingNotifications)
+ suite.False(relationship.Requested)
+ suite.False(relationship.DomainBlocking)
+ suite.False(relationship.Endorsed)
+ suite.Empty(relationship.Note)
+}
+
+func (suite *RelationshipTestSuite) TestIsFollowingYes() {
+ requestingAccount := suite.testAccounts["local_account_1"]
+ targetAccount := suite.testAccounts["admin_account"]
+ isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
+ suite.NoError(err)
+ suite.True(isFollowing)
}
-func (suite *RelationshipTestSuite) TestIsFollowing() {
- suite.Suite.T().Skip("TODO: implement")
+func (suite *RelationshipTestSuite) TestIsFollowingNo() {
+ requestingAccount := suite.testAccounts["admin_account"]
+ targetAccount := suite.testAccounts["local_account_2"]
+ isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
+ suite.NoError(err)
+ suite.False(isFollowing)
}
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
- suite.Suite.T().Skip("TODO: implement")
+ requestingAccount := suite.testAccounts["local_account_1"]
+ targetAccount := suite.testAccounts["admin_account"]
+ isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
+ suite.NoError(err)
+ suite.True(isMutualFollowing)
+}
+
+func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() {
+ requestingAccount := suite.testAccounts["local_account_1"]
+ targetAccount := suite.testAccounts["local_account_2"]
+ isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
+ suite.NoError(err)
+ suite.True(isMutualFollowing)
}
-func (suite *RelationshipTestSuite) AcceptFollowRequest() {
- for _, account := range suite.testAccounts {
- _, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID")
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- suite.Suite.Fail("error accepting follow request: %v", err)
- }
+func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() {
+ ctx := context.Background()
+ account := suite.testAccounts["admin_account"]
+ targetAccount := suite.testAccounts["local_account_2"]
+
+ followRequest := &gtsmodel.FollowRequest{
+ ID: "01GEF753FWHCHRDWR0QEHBXM8W",
+ URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
+ AccountID: account.ID,
+ TargetAccountID: targetAccount.ID,
}
+
+ if err := suite.db.Put(ctx, followRequest); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
+ suite.NoError(err)
+ suite.NotNil(follow)
+ suite.Equal(followRequest.URI, follow.URI)
}
-func (suite *RelationshipTestSuite) GetAccountFollowRequests() {
- suite.Suite.T().Skip("TODO: implement")
+func (suite *RelationshipTestSuite) TestAcceptFollowRequestNotExisting() {
+ ctx := context.Background()
+ account := suite.testAccounts["admin_account"]
+ targetAccount := suite.testAccounts["local_account_2"]
+
+ follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
+ suite.ErrorIs(err, db.ErrNoEntries)
+ suite.Nil(follow)
}
-func (suite *RelationshipTestSuite) GetAccountFollows() {
- suite.Suite.T().Skip("TODO: implement")
+func (suite *RelationshipTestSuite) TestAcceptFollowRequestFollowAlreadyExists() {
+ ctx := context.Background()
+ account := suite.testAccounts["local_account_1"]
+ targetAccount := suite.testAccounts["admin_account"]
+
+ // follow already exists in the db from local_account_1 -> admin_account
+ existingFollow := &gtsmodel.Follow{}
+ if err := suite.db.GetByID(ctx, suite.testFollows["local_account_1_admin_account"].ID, existingFollow); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ followRequest := &gtsmodel.FollowRequest{
+ ID: "01GEF753FWHCHRDWR0QEHBXM8W",
+ URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
+ AccountID: account.ID,
+ TargetAccountID: targetAccount.ID,
+ }
+
+ if err := suite.db.Put(ctx, followRequest); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
+ suite.NoError(err)
+ suite.NotNil(follow)
+
+ // uri should be equal to value of new/overlapping follow request
+ suite.NotEqual(followRequest.URI, existingFollow.URI)
+ suite.Equal(followRequest.URI, follow.URI)
}
-func (suite *RelationshipTestSuite) CountAccountFollows() {
- suite.Suite.T().Skip("TODO: implement")
+func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
+ ctx := context.Background()
+ account := suite.testAccounts["admin_account"]
+ targetAccount := suite.testAccounts["local_account_2"]
+
+ followRequest := &gtsmodel.FollowRequest{
+ ID: "01GEF753FWHCHRDWR0QEHBXM8W",
+ URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
+ AccountID: account.ID,
+ TargetAccountID: targetAccount.ID,
+ }
+
+ if err := suite.db.Put(ctx, followRequest); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
+ suite.NoError(err)
+ suite.NotNil(rejectedFollowRequest)
}
-func (suite *RelationshipTestSuite) GetAccountFollowedBy() {
- // TODO: more comprehensive tests here
+func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() {
+ ctx := context.Background()
+ account := suite.testAccounts["admin_account"]
+ targetAccount := suite.testAccounts["local_account_2"]
- for _, account := range suite.testAccounts {
- var err error
+ rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
+ suite.ErrorIs(err, db.ErrNoEntries)
+ suite.Nil(rejectedFollowRequest)
+}
- _, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
- if err != nil {
- suite.Suite.Fail("error checking accounts followed by: %v", err)
- }
+func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
+ ctx := context.Background()
+ account := suite.testAccounts["admin_account"]
+ targetAccount := suite.testAccounts["local_account_2"]
- _, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
- if err != nil {
- suite.Suite.Fail("error checking localOnly accounts followed by: %v", err)
- }
+ followRequest := &gtsmodel.FollowRequest{
+ ID: "01GEF753FWHCHRDWR0QEHBXM8W",
+ URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
+ AccountID: account.ID,
+ TargetAccountID: targetAccount.ID,
}
+
+ if err := suite.db.Put(ctx, followRequest); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
+ suite.NoError(err)
+ suite.Len(followRequests, 1)
}
-func (suite *RelationshipTestSuite) CountAccountFollowedBy() {
- suite.Suite.T().Skip("TODO: implement")
+func (suite *RelationshipTestSuite) TestGetAccountFollows() {
+ account := suite.testAccounts["local_account_1"]
+ follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
+ suite.NoError(err)
+ suite.Len(follows, 2)
+}
+
+func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
+ account := suite.testAccounts["local_account_1"]
+ followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, true)
+ suite.NoError(err)
+ suite.Equal(2, followsCount)
+}
+
+func (suite *RelationshipTestSuite) TestCountAccountFollows() {
+ account := suite.testAccounts["local_account_1"]
+ followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, false)
+ suite.NoError(err)
+ suite.Equal(2, followsCount)
+}
+
+func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() {
+ account := suite.testAccounts["local_account_1"]
+ follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
+ suite.NoError(err)
+ suite.Len(follows, 2)
+}
+
+func (suite *RelationshipTestSuite) TestGetAccountFollowedByLocalOnly() {
+ account := suite.testAccounts["local_account_1"]
+ follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
+ suite.NoError(err)
+ suite.Len(follows, 2)
+}
+
+func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() {
+ account := suite.testAccounts["local_account_1"]
+ followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, false)
+ suite.NoError(err)
+ suite.Equal(2, followsCount)
+}
+
+func (suite *RelationshipTestSuite) TestCountAccountFollowedByLocalOnly() {
+ account := suite.testAccounts["local_account_1"]
+ followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, true)
+ suite.NoError(err)
+ suite.Equal(2, followsCount)
}
func TestRelationshipTestSuite(t *testing.T) {
diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go
index 9138072e1..b9e70a89f 100644
--- a/internal/db/bundb/session.go
+++ b/internal/db/bundb/session.go
@@ -21,7 +21,6 @@ package bundb
import (
"context"
"crypto/rand"
- "errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -35,29 +34,22 @@ type sessionDB struct {
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
rss := make([]*gtsmodel.RouterSession, 0, 1)
- _, err := s.conn.
+ // get the first router session in the db or...
+ if err := s.conn.
NewSelect().
Model(&rss).
Limit(1).
- Order("id DESC").
- Exec(ctx)
- if err != nil {
+ Order("router_session.id DESC").
+ Scan(ctx); err != nil {
return nil, s.conn.ProcessError(err)
}
+ // ... create a new one
if len(rss) == 0 {
- // no session created yet, so make one
return s.createSession(ctx)
}
- if len(rss) != 1 {
- // we asked for 1 so we should get 1
- return nil, errors.New("more than 1 router session was returned")
- }
-
- // return the one session found
- rs := rss[0]
- return rs, nil
+ return rss[0], nil
}
func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
@@ -71,24 +63,23 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession,
return nil, err
}
- rid, err := id.NewULID()
+ id, err := id.NewULID()
if err != nil {
return nil, err
}
rs := &gtsmodel.RouterSession{
- ID: rid,
+ ID: id,
Auth: auth,
Crypt: crypt,
}
- q := s.conn.
+ if _, err := s.conn.
NewInsert().
- Model(rs)
-
- _, err = q.Exec(ctx)
- if err != nil {
+ Model(rs).
+ Exec(ctx); err != nil {
return nil, s.conn.ProcessError(err)
}
+
return rs, nil
}
diff --git a/internal/db/bundb/session_test.go b/internal/db/bundb/session_test.go
index ef508bde8..1e7fde5aa 100644
--- a/internal/db/bundb/session_test.go
+++ b/internal/db/bundb/session_test.go
@@ -37,14 +37,13 @@ func (suite *SessionTestSuite) TestGetSession() {
suite.NotEmpty(session.Crypt)
suite.NotEmpty(session.ID)
- // TODO -- the same session should be returned with consecutive selects
- // right now there's an issue with bytea in bun, so uncomment this when that issue is fixed: https://github.com/uptrace/bun/issues/122
- // session2, err := suite.db.GetSession(context.Background())
- // suite.NoError(err)
- // suite.NotNil(session2)
- // suite.Equal(session.Auth, session2.Auth)
- // suite.Equal(session.Crypt, session2.Crypt)
- // suite.Equal(session.ID, session2.ID)
+ // the same session should be returned with consecutive selects
+ session2, err := suite.db.GetSession(context.Background())
+ suite.NoError(err)
+ suite.NotNil(session2)
+ suite.Equal(session.Auth, session2.Auth)
+ suite.Equal(session.Crypt, session2.Crypt)
+ suite.Equal(session.ID, session2.ID)
}
func TestSessionTestSuite(t *testing.T) {
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index 2d920ee3f..bc72c2849 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -72,7 +72,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat
return s.cache.GetByID(id)
},
func(status *gtsmodel.Status) error {
- return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx)
+ return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx)
},
)
}
@@ -84,7 +84,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
return s.cache.GetByURI(uri)
},
func(status *gtsmodel.Status) error {
- return s.newStatusQ(status).Where("status.uri = ?", uri).Scan(ctx)
+ return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx)
},
)
}
@@ -96,7 +96,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St
return s.cache.GetByURL(url)
},
func(status *gtsmodel.Status) error {
- return s.newStatusQ(status).Where("status.url = ?", url).Scan(ctx)
+ return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx)
},
)
}
@@ -109,8 +109,7 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta
status = &gtsmodel.Status{}
// Not cached! Perform database query
- err := dbQuery(status)
- if err != nil {
+ if err := dbQuery(status); err != nil {
return nil, s.conn.ProcessError(err)
}
@@ -138,24 +137,34 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
- return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
- if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
- StatusID: status.ID,
- EmojiID: i,
- }).Exec(ctx); err != nil {
- return err
+ if _, err := tx.
+ NewInsert().
+ Model(&gtsmodel.StatusToEmoji{
+ StatusID: status.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ err = s.conn.errProc(err)
+ if !errors.Is(err, db.ErrAlreadyExists) {
+ return err
+ }
}
}
// create links between this status and any tags it uses
for _, i := range status.TagIDs {
- if _, err := tx.NewInsert().Model(&gtsmodel.StatusToTag{
- StatusID: status.ID,
- TagID: i,
- }).Exec(ctx); err != nil {
- return err
+ if _, err := tx.
+ NewInsert().
+ Model(&gtsmodel.StatusToTag{
+ StatusID: status.ID,
+ TagID: i,
+ }).Exec(ctx); err != nil {
+ err = s.conn.errProc(err)
+ if !errors.Is(err, db.ErrAlreadyExists) {
+ return err
+ }
}
}
@@ -163,27 +172,46 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
for _, a := range status.Attachments {
a.StatusID = status.ID
a.UpdatedAt = time.Now()
- if _, err := tx.NewUpdate().Model(a).
- Where("id = ?", a.ID).
+ if _, err := tx.
+ NewUpdate().
+ Model(a).
+ Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
Exec(ctx); err != nil {
- return err
+ err = s.conn.errProc(err)
+ if !errors.Is(err, db.ErrAlreadyExists) {
+ return err
+ }
}
}
// Finally, insert the status
- _, err := tx.NewInsert().Model(status).Exec(ctx)
- return err
+ if _, err := tx.
+ NewInsert().
+ Model(status).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
})
+ if err != nil {
+ return s.conn.ProcessError(err)
+ }
+
+ s.cache.Put(status)
+ return nil
}
func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) {
err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
- if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
- StatusID: status.ID,
- EmojiID: i,
- }).Exec(ctx); err != nil {
+ if _, err := tx.
+ NewInsert().
+ Model(&gtsmodel.StatusToEmoji{
+ StatusID: status.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
err = s.conn.errProc(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err
@@ -193,10 +221,12 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*
// create links between this status and any tags it uses
for _, i := range status.TagIDs {
- if _, err := tx.NewInsert().Model(&gtsmodel.StatusToTag{
- StatusID: status.ID,
- TagID: i,
- }).Exec(ctx); err != nil {
+ if _, err := tx.
+ NewInsert().
+ Model(&gtsmodel.StatusToTag{
+ StatusID: status.ID,
+ TagID: i,
+ }).Exec(ctx); err != nil {
err = s.conn.errProc(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err
@@ -208,23 +238,32 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*
for _, a := range status.Attachments {
a.StatusID = status.ID
a.UpdatedAt = time.Now()
- if _, err := tx.NewUpdate().Model(a).
- Where("id = ?", a.ID).
+ if _, err := tx.
+ NewUpdate().
+ Model(a).
+ Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
Exec(ctx); err != nil {
return err
}
}
// Finally, update the status itself
- if _, err := tx.NewUpdate().Model(status).WherePK().Exec(ctx); err != nil {
+ if _, err := tx.
+ NewUpdate().
+ Model(status).
+ Where("? = ?", bun.Ident("status.id"), status.ID).
+ Exec(ctx); err != nil {
return err
}
- s.cache.Put(status)
return nil
})
+ if err != nil {
+ return nil, s.conn.ProcessError(err)
+ }
- return status, err
+ s.cache.Put(status)
+ return status, nil
}
func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
@@ -232,8 +271,8 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
// delete links between this status and any emojis it uses
if _, err := tx.
NewDelete().
- Model(&gtsmodel.StatusToEmoji{}).
- Where("status_id = ?", id).
+ TableExpr("? AS ?", bun.Ident("status_to_emojis"), bun.Ident("status_to_emoji")).
+ Where("? = ?", bun.Ident("status_to_emoji.status_id"), id).
Exec(ctx); err != nil {
return err
}
@@ -241,8 +280,8 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
// delete links between this status and any tags it uses
if _, err := tx.
NewDelete().
- Model(&gtsmodel.StatusToTag{}).
- Where("status_id = ?", id).
+ TableExpr("? AS ?", bun.Ident("status_to_tags"), bun.Ident("status_to_tag")).
+ Where("? = ?", bun.Ident("status_to_tag.status_id"), id).
Exec(ctx); err != nil {
return err
}
@@ -250,17 +289,20 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
// delete the status itself
if _, err := tx.
NewDelete().
- Model(&gtsmodel.Status{ID: id}).
- WherePK().
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Where("? = ?", bun.Ident("status.id"), id).
Exec(ctx); err != nil {
return err
}
- s.cache.Invalidate(id)
return nil
})
+ if err != nil {
+ return s.conn.ProcessError(err)
+ }
- return s.conn.ProcessError(err)
+ s.cache.Invalidate(id)
+ return nil
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
@@ -312,11 +354,11 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
q := s.conn.
NewSelect().
- Table("statuses").
- Column("id").
- Where("in_reply_to_id = ?", status.ID)
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Column("status.id").
+ Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID)
if minID != "" {
- q = q.Where("id > ?", minID)
+ q = q.Where("? > ?", bun.Ident("status.id"), minID)
}
if err := q.Scan(ctx, &childIDs); err != nil {
@@ -356,23 +398,35 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
}
func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
- return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx)
+ return s.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID).
+ Count(ctx)
}
func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
- return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx)
+ return s.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Where("? = ?", bun.Ident("status.boost_of_id"), status.ID).
+ Count(ctx)
}
func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
- return s.conn.NewSelect().Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx)
+ return s.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
+ Where("? = ?", bun.Ident("status_fave.status_id"), status.ID).
+ Count(ctx)
}
func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
- Model(&gtsmodel.StatusFave{}).
- Where("status_id = ?", status.ID).
- Where("account_id = ?", accountID)
+ TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
+ Where("? = ?", bun.Ident("status_fave.status_id"), status.ID).
+ Where("? = ?", bun.Ident("status_fave.account_id"), accountID)
return s.conn.Exists(ctx, q)
}
@@ -380,9 +434,9 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status,
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
- Model(&gtsmodel.Status{}).
- Where("boost_of_id = ?", status.ID).
- Where("account_id = ?", accountID)
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Where("? = ?", bun.Ident("status.boost_of_id"), status.ID).
+ Where("? = ?", bun.Ident("status.account_id"), accountID)
return s.conn.Exists(ctx, q)
}
@@ -390,9 +444,9 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
- Model(&gtsmodel.StatusMute{}).
- Where("status_id = ?", status.ID).
- Where("account_id = ?", accountID)
+ TableExpr("? AS ?", bun.Ident("status_mutes"), bun.Ident("status_mute")).
+ Where("? = ?", bun.Ident("status_mute.status_id"), status.ID).
+ Where("? = ?", bun.Ident("status_mute.account_id"), accountID)
return s.conn.Exists(ctx, q)
}
@@ -400,9 +454,9 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status,
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
- Model(&gtsmodel.StatusBookmark{}).
- Where("status_id = ?", status.ID).
- Where("account_id = ?", accountID)
+ TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
+ Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID).
+ Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID)
return s.conn.Exists(ctx, q)
}
@@ -410,8 +464,9 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
- q := s.newFaveQ(&faves).
- Where("status_id = ?", status.ID)
+ q := s.
+ newFaveQ(&faves).
+ Where("? = ?", bun.Ident("status_fave.status_id"), status.ID)
if err := q.Scan(ctx); err != nil {
return nil, s.conn.ProcessError(err)
@@ -422,8 +477,9 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
reblogs := []*gtsmodel.Status{}
- q := s.newStatusQ(&reblogs).
- Where("boost_of_id = ?", status.ID)
+ q := s.
+ newStatusQ(&reblogs).
+ Where("? = ?", bun.Ident("status.boost_of_id"), status.ID)
if err := q.Scan(ctx); err != nil {
return nil, s.conn.ProcessError(err)
diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go
index a796ebdad..70bc7b845 100644
--- a/internal/db/bundb/status_test.go
+++ b/internal/db/bundb/status_test.go
@@ -108,14 +108,14 @@ func (suite *StatusTestSuite) TestGetStatusTwice() {
suite.NoError(err)
after1 := time.Now()
duration1 := after1.Sub(before1)
- fmt.Println(duration1.Milliseconds())
+ fmt.Println(duration1.Microseconds())
before2 := time.Now()
_, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after2 := time.Now()
duration2 := after2.Sub(before2)
- fmt.Println(duration2.Milliseconds())
+ fmt.Println(duration2.Microseconds())
// second retrieval should be several orders faster since it will be cached now
suite.Less(duration2, duration1)
diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go
index d2b3cf07e..d4740dd96 100644
--- a/internal/db/bundb/timeline.go
+++ b/internal/db/bundb/timeline.go
@@ -34,38 +34,48 @@ type timelineDB struct {
}
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
+ // Ensure reasonable
+ if limit < 0 {
+ limit = 0
+ }
+
// Make educated guess for slice size
statusIDs := make([]string, 0, limit)
q := t.conn.
NewSelect().
- Table("statuses").
-
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table
- Column("statuses.id").
+ Column("status.id").
// Find out who accountID follows.
- Join("LEFT JOIN follows ON follows.target_account_id = statuses.account_id AND follows.account_id = ?", accountID).
+ Join("LEFT JOIN ? AS ? ON ? = ? AND ? = ?",
+ bun.Ident("follows"),
+ bun.Ident("follow"),
+ bun.Ident("follow.target_account_id"),
+ bun.Ident("status.account_id"),
+ bun.Ident("follow.account_id"),
+ accountID).
// Sort by highest ID (newest) to lowest ID (oldest)
- Order("statuses.id DESC")
+ Order("status.id DESC")
if maxID != "" {
// return only statuses LOWER (ie., older) than maxID
- q = q.Where("statuses.id < ?", maxID)
+ q = q.Where("? < ?", bun.Ident("status.id"), maxID)
}
if sinceID != "" {
// return only statuses HIGHER (ie., newer) than sinceID
- q = q.Where("statuses.id > ?", sinceID)
+ q = q.Where("? > ?", bun.Ident("status.id"), sinceID)
}
if minID != "" {
// return only statuses HIGHER (ie., newer) than minID
- q = q.Where("statuses.id > ?", minID)
+ q = q.Where("? > ?", bun.Ident("status.id"), minID)
}
if local {
// return only statuses posted by local account havers
- q = q.Where("statuses.local = ?", local)
+ q = q.Where("? = ?", bun.Ident("status.local"), local)
}
if limit > 0 {
@@ -78,13 +88,11 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
//
// This is equivalent to something like WHERE ... AND (... OR ...)
// See: https://bun.uptrace.dev/guide/queries.html#select
- whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
+ q = q.WhereGroup(" AND ", func(*bun.SelectQuery) *bun.SelectQuery {
return q.
- WhereOr("follows.account_id = ?", accountID).
- WhereOr("statuses.account_id = ?", accountID)
- }
-
- q = q.WhereGroup(" AND ", whereGroup)
+ WhereOr("? = ?", bun.Ident("follow.account_id"), accountID).
+ WhereOr("? = ?", bun.Ident("status.account_id"), accountID)
+ })
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, t.conn.ProcessError(err)
@@ -118,28 +126,28 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma
q := t.conn.
NewSelect().
- Table("statuses").
- Column("statuses.id").
- Where("statuses.visibility = ?", gtsmodel.VisibilityPublic).
- WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_id")).
- WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_uri")).
- WhereGroup(" AND ", whereEmptyOrNull("statuses.boost_of_id")).
- Order("statuses.id DESC")
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Column("status.id").
+ Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
+ WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")).
+ WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
+ WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
+ Order("status.id DESC")
if maxID != "" {
- q = q.Where("statuses.id < ?", maxID)
+ q = q.Where("? < ?", bun.Ident("status.id"), maxID)
}
if sinceID != "" {
- q = q.Where("statuses.id > ?", sinceID)
+ q = q.Where("? > ?", bun.Ident("status.id"), sinceID)
}
if minID != "" {
- q = q.Where("statuses.id > ?", minID)
+ q = q.Where("? > ?", bun.Ident("status.id"), minID)
}
if local {
- q = q.Where("statuses.local = ?", local)
+ q = q.Where("? = ?", bun.Ident("status.local"), local)
}
if limit > 0 {
@@ -181,15 +189,15 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
fq := t.conn.
NewSelect().
Model(&faves).
- Where("account_id = ?", accountID).
- Order("id DESC")
+ Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
+ Order("status_fave.id DESC")
if maxID != "" {
- fq = fq.Where("id < ?", maxID)
+ fq = fq.Where("? < ?", bun.Ident("status_fave.id"), maxID)
}
if minID != "" {
- fq = fq.Where("id > ?", minID)
+ fq = fq.Where("? > ?", bun.Ident("status_fave.id"), minID)
}
if limit > 0 {
diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go
index 2e991ac93..c14d72056 100644
--- a/internal/db/bundb/timeline_test.go
+++ b/internal/db/bundb/timeline_test.go
@@ -38,6 +38,15 @@ func (suite *TimelineTestSuite) TestGetPublicTimeline() {
suite.Len(s, 6)
}
+func (suite *TimelineTestSuite) TestGetHomeTimeline() {
+ viewingAccount := suite.testAccounts["local_account_1"]
+
+ s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false)
+ suite.NoError(err)
+
+ suite.Len(s, 16)
+}
+
func TestTimelineTestSuite(t *testing.T) {
suite.Run(t, new(TimelineTestSuite))
}
diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go
index 46f24c4b2..aa2f4c2c8 100644
--- a/internal/db/bundb/user.go
+++ b/internal/db/bundb/user.go
@@ -67,7 +67,7 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db
return u.cache.GetByID(id)
},
func(user *gtsmodel.User) error {
- return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx)
+ return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx)
},
)
}
@@ -79,7 +79,7 @@ func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gts
return u.cache.GetByAccountID(accountID)
},
func(user *gtsmodel.User) error {
- return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx)
+ return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx)
},
)
}
@@ -91,7 +91,7 @@ func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string)
return u.cache.GetByEmail(emailAddress)
},
func(user *gtsmodel.User) error {
- return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx)
+ return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx)
},
)
}
@@ -103,7 +103,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationTok
return u.cache.GetByConfirmationToken(confirmationToken)
},
func(user *gtsmodel.User) error {
- return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx)
+ return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx)
},
)
}
@@ -127,7 +127,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
if _, err := u.conn.
NewUpdate().
Model(user).
- WherePK().
+ Where("? = ?", bun.Ident("user.id"), user.ID).
Column(columns...).
Exec(ctx); err != nil {
return nil, u.conn.ProcessError(err)
@@ -140,8 +140,8 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
if _, err := u.conn.
NewDelete().
- Model(&gtsmodel.User{ID: userID}).
- WherePK().
+ TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
+ Where("? = ?", bun.Ident("user.id"), userID).
Exec(ctx); err != nil {
return u.conn.ProcessError(err)
}
diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go
index 434d12f32..34f7eb76f 100644
--- a/internal/db/bundb/util.go
+++ b/internal/db/bundb/util.go
@@ -85,14 +85,8 @@ func parseWhere(w db.Where) (query string, args []interface{}) {
return
}
- if w.CaseInsensitive {
- query = "LOWER(?) != LOWER(?)"
- args = []interface{}{bun.Safe(w.Key), w.Value}
- return
- }
-
query = "? != ?"
- args = []interface{}{bun.Safe(w.Key), w.Value}
+ args = []interface{}{bun.Ident(w.Key), w.Value}
return
}
@@ -102,13 +96,7 @@ func parseWhere(w db.Where) (query string, args []interface{}) {
return
}
- if w.CaseInsensitive {
- query = "LOWER(?) = LOWER(?)"
- args = []interface{}{bun.Safe(w.Key), w.Value}
- return
- }
-
query = "? = ?"
- args = []interface{}{bun.Safe(w.Key), w.Value}
+ args = []interface{}{bun.Ident(w.Key), w.Value}
return
}
diff --git a/internal/db/params.go b/internal/db/params.go
index d1809f1c4..84694d6d3 100644
--- a/internal/db/params.go
+++ b/internal/db/params.go
@@ -24,9 +24,6 @@ type Where struct {
Key string
// The value to match.
Value interface{}
- // Whether the value (if a string) should be case sensitive or not.
- // Defaults to false.
- CaseInsensitive bool
// If set, reverse the where.
// `WHERE k = v` becomes `WHERE k != v`.
// `WHERE k IS NULL` becomes `WHERE k IS NOT NULL`
diff --git a/internal/media/processingmedia.go b/internal/media/processingmedia.go
index c22bddfeb..1f5a58b9f 100644
--- a/internal/media/processingmedia.go
+++ b/internal/media/processingmedia.go
@@ -101,7 +101,7 @@ func (p *ProcessingMedia) LoadAttachment(ctx context.Context) (*gtsmodel.MediaAt
if !p.insertedInDB {
if p.recache {
// if it's a recache we should only need to update
- if err := p.database.UpdateByPrimaryKey(ctx, p.attachment); err != nil {
+ if err := p.database.UpdateByID(ctx, p.attachment, p.attachment.ID); err != nil {
return nil, err
}
} else {
diff --git a/internal/media/prunemeta_test.go b/internal/media/prunemeta_test.go
index 32a3b9a5c..d95a1f3ed 100644
--- a/internal/media/prunemeta_test.go
+++ b/internal/media/prunemeta_test.go
@@ -40,7 +40,7 @@ func (suite *PruneMetaTestSuite) TestPruneMeta() {
zork := suite.testAccounts["local_account_1"]
zork.AvatarMediaAttachmentID = ""
zork.HeaderMediaAttachmentID = ""
- if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
+ if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
panic(err)
}
@@ -72,7 +72,7 @@ func (suite *PruneMetaTestSuite) TestPruneMetaTwice() {
zork := suite.testAccounts["local_account_1"]
zork.AvatarMediaAttachmentID = ""
zork.HeaderMediaAttachmentID = ""
- if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
+ if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
panic(err)
}
@@ -95,14 +95,14 @@ func (suite *PruneMetaTestSuite) TestPruneMetaMultipleAccounts() {
zork := suite.testAccounts["local_account_1"]
zork.AvatarMediaAttachmentID = ""
zork.HeaderMediaAttachmentID = ""
- if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
+ if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
panic(err)
}
// set zork's unused header as belonging to turtle
turtle := suite.testAccounts["local_account_1"]
zorkOldHeader.AccountID = turtle.ID
- if err := suite.db.UpdateByPrimaryKey(ctx, zorkOldHeader, "account_id"); err != nil {
+ if err := suite.db.UpdateByID(ctx, zorkOldHeader, zorkOldHeader.ID, "account_id"); err != nil {
panic(err)
}
diff --git a/internal/media/pruneremote.go b/internal/media/pruneremote.go
index 43ce53cdc..19a9642d7 100644
--- a/internal/media/pruneremote.go
+++ b/internal/media/pruneremote.go
@@ -90,7 +90,7 @@ func (m *manager) pruneOneRemote(ctx context.Context, attachment *gtsmodel.Media
// update the attachment to reflect that we no longer have it cached
if changed {
- return m.db.UpdateByPrimaryKey(ctx, attachment, "updated_at", "cached")
+ return m.db.UpdateByID(ctx, attachment, attachment.ID, "updated_at", "cached")
}
return nil
diff --git a/internal/processing/admin/createdomainblock.go b/internal/processing/admin/createdomainblock.go
index c39a91989..0dd60ea16 100644
--- a/internal/processing/admin/createdomainblock.go
+++ b/internal/processing/admin/createdomainblock.go
@@ -128,15 +128,17 @@ func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account
instance.ContactAccountUsername = ""
instance.ContactAccountID = ""
instance.Version = ""
- if err := p.db.UpdateByPrimaryKey(ctx, instance, updatingColumns...); err != nil {
+ if err := p.db.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err)
}
l.Debug("domainBlockProcessSideEffects: instance entry updated")
}
// if we have an instance account for this instance, delete it
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, &gtsmodel.Account{}); err != nil {
- l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err)
+ if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
+ if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil {
+ l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err)
+ }
}
// delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines)
diff --git a/internal/processing/admin/deletedomainblock.go b/internal/processing/admin/deletedomainblock.go
index b65954fe5..8637c173e 100644
--- a/internal/processing/admin/deletedomainblock.go
+++ b/internal/processing/admin/deletedomainblock.go
@@ -55,14 +55,14 @@ func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
// remove the domain block reference from the instance, if we have an entry for it
i := &gtsmodel.Instance{}
if err := p.db.GetWhere(ctx, []db.Where{
- {Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true},
+ {Key: "domain", Value: domainBlock.Domain},
{Key: "domain_block_id", Value: id},
}, i); err == nil {
updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"}
i.SuspendedAt = time.Time{}
i.DomainBlockID = ""
i.UpdatedAt = time.Now()
- if err := p.db.UpdateByPrimaryKey(ctx, i, updatingColumns...); err != nil {
+ if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err))
}
}
diff --git a/internal/processing/instance.go b/internal/processing/instance.go
index 32a4de6f0..2d74fe181 100644
--- a/internal/processing/instance.go
+++ b/internal/processing/instance.go
@@ -224,7 +224,7 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
}
}
- if err := p.db.UpdateByPrimaryKey(ctx, i, updatingColumns...); err != nil {
+ if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err))
}
diff --git a/internal/processing/media/getfile_test.go b/internal/processing/media/getfile_test.go
index a8a3568e7..6e5271607 100644
--- a/internal/processing/media/getfile_test.go
+++ b/internal/processing/media/getfile_test.go
@@ -69,7 +69,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncached() {
// uncache the file from local
testAttachment := suite.testAttachments["remote_account_1_status_1_attachment_1"]
testAttachment.Cached = testrig.FalseBool()
- err := suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached")
+ err := suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached")
suite.NoError(err)
err = suite.storage.Delete(ctx, testAttachment.File.Path)
suite.NoError(err)
@@ -124,7 +124,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncachedInterrupted() {
// uncache the file from local
testAttachment := suite.testAttachments["remote_account_1_status_1_attachment_1"]
testAttachment.Cached = testrig.FalseBool()
- err := suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached")
+ err := suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached")
suite.NoError(err)
err = suite.storage.Delete(ctx, testAttachment.File.Path)
suite.NoError(err)
@@ -179,7 +179,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileThumbnailUncached() {
// uncache the file from local
testAttachment.Cached = testrig.FalseBool()
- err = suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached")
+ err = suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached")
suite.NoError(err)
err = suite.storage.Delete(ctx, testAttachment.File.Path)
suite.NoError(err)
diff --git a/internal/processing/media/unattach.go b/internal/processing/media/unattach.go
index 5ef8f81f4..d0f34eba1 100644
--- a/internal/processing/media/unattach.go
+++ b/internal/processing/media/unattach.go
@@ -47,7 +47,7 @@ func (p *processor) Unattach(ctx context.Context, account *gtsmodel.Account, med
attachment.UpdatedAt = time.Now()
attachment.StatusID = ""
- if err := p.db.UpdateByPrimaryKey(ctx, attachment, updatingColumns...); err != nil {
+ if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err))
}
diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go
index b8177eeb4..e0833a511 100644
--- a/internal/processing/media/update.go
+++ b/internal/processing/media/update.go
@@ -61,7 +61,7 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, media
updatingColumns = append(updatingColumns, "focus_x", "focus_y")
}
- if err := p.db.UpdateByPrimaryKey(ctx, attachment, updatingColumns...); err != nil {
+ if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err))
}
diff --git a/internal/processing/status/util.go b/internal/processing/status/util.go
index 7617894cc..c8b30e2ca 100644
--- a/internal/processing/status/util.go
+++ b/internal/processing/status/util.go
@@ -162,27 +162,28 @@ func (p *processor) ProcessMediaIDs(ctx context.Context, form *apimodel.Advanced
return nil
}
- gtsMediaAttachments := []*gtsmodel.MediaAttachment{}
- attachments := []string{}
+ attachments := []*gtsmodel.MediaAttachment{}
+ attachmentIDs := []string{}
for _, mediaID := range form.MediaIDs {
- // check these attachments exist
- a := &gtsmodel.MediaAttachment{}
- if err := p.db.GetByID(ctx, mediaID, a); err != nil {
- return fmt.Errorf("invalid media type or media not found for media id %s", mediaID)
+ attachment, err := p.db.GetAttachmentByID(ctx, mediaID)
+ if err != nil {
+ return fmt.Errorf("ProcessMediaIDs: invalid media type or media not found for media id %s", mediaID)
}
- // check they belong to the requesting account id
- if a.AccountID != thisAccountID {
- return fmt.Errorf("media with id %s does not belong to account %s", mediaID, thisAccountID)
+
+ if attachment.AccountID != thisAccountID {
+ return fmt.Errorf("ProcessMediaIDs: media with id %s does not belong to account %s", mediaID, thisAccountID)
}
- // check they're not already used in a status
- if a.StatusID != "" || a.ScheduledStatusID != "" {
- return fmt.Errorf("media with id %s is already attached to a status", mediaID)
+
+ if attachment.StatusID != "" || attachment.ScheduledStatusID != "" {
+ return fmt.Errorf("ProcessMediaIDs: media with id %s is already attached to a status", mediaID)
}
- gtsMediaAttachments = append(gtsMediaAttachments, a)
- attachments = append(attachments, a.ID)
+
+ attachments = append(attachments, attachment)
+ attachmentIDs = append(attachmentIDs, attachment.ID)
}
- status.Attachments = gtsMediaAttachments
- status.AttachmentIDs = attachments
+
+ status.Attachments = attachments
+ status.AttachmentIDs = attachmentIDs
return nil
}
diff --git a/internal/processing/user/changepassword.go b/internal/processing/user/changepassword.go
index ddfec6898..856e92bdc 100644
--- a/internal/processing/user/changepassword.go
+++ b/internal/processing/user/changepassword.go
@@ -45,7 +45,7 @@ func (p *processor) ChangePassword(ctx context.Context, user *gtsmodel.User, old
user.EncryptedPassword = string(newPasswordHash)
user.UpdatedAt = time.Now()
- if err := p.db.UpdateByPrimaryKey(ctx, user, "encrypted_password", "updated_at"); err != nil {
+ if err := p.db.UpdateByID(ctx, user, user.ID, "encrypted_password", "updated_at"); err != nil {
return gtserror.NewErrorInternalError(err, "database error")
}
diff --git a/internal/processing/user/emailconfirm.go b/internal/processing/user/emailconfirm.go
index 5a68383b8..82124007c 100644
--- a/internal/processing/user/emailconfirm.go
+++ b/internal/processing/user/emailconfirm.go
@@ -77,7 +77,7 @@ func (p *processor) SendConfirmEmail(ctx context.Context, user *gtsmodel.User, u
user.LastEmailedAt = time.Now()
user.UpdatedAt = time.Now()
- if err := p.db.UpdateByPrimaryKey(ctx, user, updatingColumns...); err != nil {
+ if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
return fmt.Errorf("SendConfirmEmail: error updating user entry after email sent: %s", err)
}
@@ -126,7 +126,7 @@ func (p *processor) ConfirmEmail(ctx context.Context, token string) (*gtsmodel.U
user.ConfirmationToken = ""
user.UpdatedAt = time.Now()
- if err := p.db.UpdateByPrimaryKey(ctx, user, updatingColumns...); err != nil {
+ if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/user/emailconfirm_test.go b/internal/processing/user/emailconfirm_test.go
index 87aff9756..4e31a3646 100644
--- a/internal/processing/user/emailconfirm_test.go
+++ b/internal/processing/user/emailconfirm_test.go
@@ -74,7 +74,7 @@ func (suite *EmailConfirmTestSuite) TestConfirmEmail() {
user.ConfirmationSentAt = time.Now().Add(-5 * time.Minute)
user.ConfirmationToken = "1d1aa44b-afa4-49c8-ac4b-eceb61715cc6"
- err := suite.db.UpdateByPrimaryKey(ctx, user, updatingColumns...)
+ err := suite.db.UpdateByID(ctx, user, user.ID, updatingColumns...)
suite.NoError(err)
// confirm with the token set above
@@ -102,7 +102,7 @@ func (suite *EmailConfirmTestSuite) TestConfirmEmailOldToken() {
user.ConfirmationSentAt = time.Now().Add(-192 * time.Hour)
user.ConfirmationToken = "1d1aa44b-afa4-49c8-ac4b-eceb61715cc6"
- err := suite.db.UpdateByPrimaryKey(ctx, user, updatingColumns...)
+ err := suite.db.UpdateByID(ctx, user, user.ID, updatingColumns...)
suite.NoError(err)
// confirm with the token set above
diff --git a/testrig/db.go b/testrig/db.go
index 72446e2bc..88237d7d0 100644
--- a/testrig/db.go
+++ b/testrig/db.go
@@ -187,7 +187,7 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) {
}
for _, v := range NewTestStatuses() {
- if err := db.PutStatus(ctx, v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
log.Panic(err)
}
}
@@ -198,12 +198,24 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) {
}
}
+ for _, v := range NewTestStatusToEmojis() {
+ if err := db.Put(ctx, v); err != nil {
+ log.Panic(err)
+ }
+ }
+
for _, v := range NewTestTags() {
if err := db.Put(ctx, v); err != nil {
log.Panic(err)
}
}
+ for _, v := range NewTestStatusToTags() {
+ if err := db.Put(ctx, v); err != nil {
+ log.Panic(err)
+ }
+ }
+
for _, v := range NewTestMentions() {
if err := db.Put(ctx, v); err != nil {
log.Panic(err)
diff --git a/testrig/testmodels.go b/testrig/testmodels.go
index f53022fd8..054f14323 100644
--- a/testrig/testmodels.go
+++ b/testrig/testmodels.go
@@ -977,6 +977,15 @@ func NewTestEmojis() map[string]*gtsmodel.Emoji {
}
}
+func NewTestStatusToEmojis() map[string]*gtsmodel.StatusToEmoji {
+ return map[string]*gtsmodel.StatusToEmoji{
+ "admin_account_status_1_rainbow": {
+ StatusID: "01F8MH75CBF9JFX4ZAD54N0W0R",
+ EmojiID: "01F8MH9H8E4VG3KDYJR9EGPXCQ",
+ },
+ }
+}
+
func NewTestInstances() map[string]*gtsmodel.Instance {
return map[string]*gtsmodel.Instance{
"localhost:8080": {
@@ -1540,6 +1549,15 @@ func NewTestTags() map[string]*gtsmodel.Tag {
}
}
+func NewTestStatusToTags() map[string]*gtsmodel.StatusToTag {
+ return map[string]*gtsmodel.StatusToTag{
+ "admin_account_status_1_welcome": {
+ StatusID: "01F8MH75CBF9JFX4ZAD54N0W0R",
+ TagID: "01F8MHA1A2NF9MJ3WCCQ3K8BSZ",
+ },
+ }
+}
+
// NewTestMentions returns a map of gts model mentions keyed by their name.
func NewTestMentions() map[string]*gtsmodel.Mention {
return map[string]*gtsmodel.Mention{