diff options
author | 2022-10-08 13:50:48 +0200 | |
---|---|---|
committer | 2022-10-08 13:50:48 +0200 | |
commit | aa07750bdb4dacdb1be39d765114915bba3fc29f (patch) | |
tree | 30e9e5052f607f8c8e4f7d518559df8706275e0f /internal/db/bundb/account.go | |
parent | [performance] cache domains after max retries in transport (#884) (diff) | |
download | gotosocial-aa07750bdb4dacdb1be39d765114915bba3fc29f.tar.xz |
[chore] Standardize database queries, use `bun.Ident()` properly (#886)
* use bun.Ident for user queries
* use bun.Ident for account queries
* use bun.Ident for media queries
* add DeleteAccount func
* remove CaseInsensitive in Where+use Ident ipv Safe
* update admin db
* update domain, use ident
* update emoji, use ident
* update instance queries, use bun.Ident
* fix media
* update mentions, use bun ident
* update relationship + tests
* use tableexpr
* add test follows to bun db test suite
* update notifications
* updatebyprimarykey => updatebyid
* fix session
* prefer explicit ID to pk
* fix little fucky wucky
* remove workaround
* use proper db func for attachment selection
* update status db
* add m2m entries in test rig
* fix up timeline
* go fmt
* fix status put issue
* update GetAccountStatuses
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r-- | internal/db/bundb/account.go | 164 |
1 files changed, 100 insertions, 64 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 074804690..c04948fee 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -21,7 +21,6 @@ package bundb import ( "context" "errors" - "fmt" "strings" "time" @@ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac return a.cache.GetByID(id) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx) }, ) } @@ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. return a.cache.GetByURI(uri) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx) }, ) } @@ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. return a.cache.GetByURL(url) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx) }, ) } @@ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str q := a.newAccountQ(account) if domain != "" { - q = q.Where("account.username = ?", username) - q = q.Where("account.domain = ?", domain) + q = q.Where("? = ?", bun.Ident("account.username"), username) + q = q.Where("? = ?", bun.Ident("account.domain"), domain) } else { - q = q.Where("account.username = ?", strings.ToLower(username)) - q = q.Where("account.domain IS NULL") + q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)) + q = q.Where("? IS NULL", bun.Ident("account.domain")) } return q.Scan(ctx) @@ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo return a.cache.GetByPubkeyID(id) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx) }, ) } @@ -169,26 +168,36 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { // create links between this account and any emojis it uses // first clear out any old emoji links - if _, err := tx.NewDelete(). - Model(&[]*gtsmodel.AccountToEmoji{}). - Where("account_id = ?", account.ID). + if _, err := tx. + NewDelete(). + TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). + Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). Exec(ctx); err != nil { return err } // now populate new emoji links for _, i := range account.EmojiIDs { - if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ - AccountID: account.ID, - EmojiID: i, - }).Exec(ctx); err != nil { + if _, err := tx. + NewInsert(). + Model(>smodel.AccountToEmoji{ + AccountID: account.ID, + EmojiID: i, + }).Exec(ctx); err != nil { return err } } // update the account - _, err := tx.NewUpdate().Model(account).WherePK().Exec(ctx) - return err + if _, err := tx. + NewUpdate(). + Model(account). + Where("? = ?", bun.Ident("account.id"), account.ID). + Exec(ctx); err != nil { + return err + } + + return nil }); err != nil { return nil, a.conn.ProcessError(err) } @@ -197,6 +206,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account return account, nil } +func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { + if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // clear out any emoji links + if _, err := tx. + NewDelete(). + TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). + Where("? = ?", bun.Ident("account_to_emoji.account_id"), id). + Exec(ctx); err != nil { + return err + } + + // delete the account + _, err := tx. + NewUpdate(). + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Where("? = ?", bun.Ident("account.id"), id). + Exec(ctx) + return err + }); err != nil { + return a.conn.ProcessError(err) + } + + a.cache.Invalidate(id) + return nil +} + func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { account := new(gtsmodel.Account) @@ -204,11 +239,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts if domain != "" { q = q. - Where("account.username = ?", domain). - Where("account.domain = ?", domain) + Where("? = ?", bun.Ident("account.username"), domain). + Where("? = ?", bun.Ident("account.domain"), domain) } else { q = q. - Where("account.username = ?", config.GetHost()). + Where("? = ?", bun.Ident("account.username"), config.GetHost()). WhereGroup(" AND ", whereEmptyOrNull("domain")) } @@ -224,10 +259,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) q := a.conn. NewSelect(). Model(status). - Order("id DESC"). - Limit(1). - Where("account_id = ?", accountID). - Column("created_at") + Column("status.created_at"). + Where("? = ?", bun.Ident("status.account_id"), accountID). + Order("status.id DESC"). + Limit(1) if err := q.Scan(ctx); err != nil { return time.Time{}, a.conn.ProcessError(err) @@ -240,12 +275,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen return errors.New("one media attachment cannot be both header and avatar") } - var headerOrAVI string + var column bun.Ident switch { case *mediaAttachment.Avatar: - headerOrAVI = "avatar" + column = bun.Ident("account.avatar_media_attachment_id") case *mediaAttachment.Header: - headerOrAVI = "header" + column = bun.Ident("account.header_media_attachment_id") default: return errors.New("given media attachment was neither a header nor an avatar") } @@ -257,11 +292,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen Exec(ctx); err != nil { return a.conn.ProcessError(err) } + if _, err := a.conn. NewUpdate(). - Model(>smodel.Account{}). - Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). - Where("id = ?", accountID). + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Set("? = ?", column, mediaAttachment.ID). + Where("? = ?", bun.Ident("account.id"), accountID). Exec(ctx); err != nil { return a.conn.ProcessError(err) } @@ -284,7 +320,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g if err := a.conn. NewSelect(). Model(faves). - Where("account_id = ?", accountID). + Where("? = ?", bun.Ident("status_fave.account_id"), accountID). Scan(ctx); err != nil { return nil, a.conn.ProcessError(err) } @@ -295,8 +331,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { return a.conn. NewSelect(). - Model(>smodel.Status{}). - Where("account_id = ?", accountID). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Where("? = ?", bun.Ident("status.account_id"), accountID). Count(ctx) } @@ -305,12 +341,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li q := a.conn. NewSelect(). - Table("statuses"). - Column("id"). - Order("id DESC") + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Column("status.id"). + Order("status.id DESC") if accountID != "" { - q = q.Where("account_id = ?", accountID) + q = q.Where("? = ?", bun.Ident("status.account_id"), accountID) } if limit != 0 { @@ -321,27 +357,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li // include self-replies (threads) whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { return q. - WhereOr("in_reply_to_account_id = ?", accountID). - WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri")) + WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID). + WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri")) } q = q.WhereGroup(" AND ", whereGroup) } if excludeReblogs { - q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")) + q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")) } if maxID != "" { - q = q.Where("id < ?", maxID) + q = q.Where("? < ?", bun.Ident("status.id"), maxID) } if minID != "" { - q = q.Where("id > ?", minID) + q = q.Where("? > ?", bun.Ident("status.id"), minID) } if pinnedOnly { - q = q.Where("pinned = ?", true) + q = q.Where("? = ?", bun.Ident("status.pinned"), true) } if mediaOnly { @@ -352,15 +388,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li switch a.conn.Dialect().Name() { case dialect.PG: return q. - Where("? IS NOT NULL", bun.Ident("attachments")). - Where("? != '{}'", bun.Ident("attachments")) + Where("? IS NOT NULL", bun.Ident("status.attachments")). + Where("? != '{}'", bun.Ident("status.attachments")) case dialect.SQLite: return q. - Where("? IS NOT NULL", bun.Ident("attachments")). - Where("? != ''", bun.Ident("attachments")). - Where("? != 'null'", bun.Ident("attachments")). - Where("? != '{}'", bun.Ident("attachments")). - Where("? != '[]'", bun.Ident("attachments")) + Where("? IS NOT NULL", bun.Ident("status.attachments")). + Where("? != ''", bun.Ident("status.attachments")). + Where("? != 'null'", bun.Ident("status.attachments")). + Where("? != '{}'", bun.Ident("status.attachments")). + Where("? != '[]'", bun.Ident("status.attachments")) default: log.Panic("db dialect was neither pg nor sqlite") return q @@ -369,7 +405,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li } if publicOnly { - q = q.Where("visibility = ?", gtsmodel.VisibilityPublic) + q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic) } if err := q.Scan(ctx, &statusIDs); err != nil { @@ -384,19 +420,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, q := a.conn. NewSelect(). - Table("statuses"). - Column("id"). - Where("account_id = ?", accountID). - WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")). - WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")). - Where("visibility = ?", gtsmodel.VisibilityPublic). - Where("federated = ?", true) + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Column("status.id"). + Where("? = ?", bun.Ident("status.account_id"), accountID). + WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). + WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). + Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). + Where("? = ?", bun.Ident("status.federated"), true) if maxID != "" { - q = q.Where("id < ?", maxID) + q = q.Where("? < ?", bun.Ident("status.id"), maxID) } - q = q.Limit(limit).Order("id DESC") + q = q.Limit(limit).Order("status.id DESC") if err := q.Scan(ctx, &statusIDs); err != nil { return nil, a.conn.ProcessError(err) @@ -411,16 +447,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI fq := a.conn. NewSelect(). Model(&blocks). - Where("block.account_id = ?", accountID). + Where("? = ?", bun.Ident("block.account_id"), accountID). Relation("TargetAccount"). Order("block.id DESC") if maxID != "" { - fq = fq.Where("block.id < ?", maxID) + fq = fq.Where("? < ?", bun.Ident("block.id"), maxID) } if sinceID != "" { - fq = fq.Where("block.id > ?", sinceID) + fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID) } if limit > 0 { |