summaryrefslogtreecommitdiff
path: root/internal/db/bundb/account.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r--internal/db/bundb/account.go147
1 files changed, 122 insertions, 25 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index df73168e2..ccf7aaa46 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -20,11 +20,13 @@ package bundb
import (
"context"
"errors"
+ "fmt"
"strings"
"time"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -37,18 +39,15 @@ type accountDB struct {
state *state.State
}
-func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
- return a.conn.
- NewSelect().
- Model(account)
-}
-
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"ID",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.id"), id).
+ Scan(ctx)
},
id,
)
@@ -59,7 +58,10 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
ctx,
"URI",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.uri"), uri).
+ Scan(ctx)
},
uri,
)
@@ -70,7 +72,10 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
ctx,
"URL",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.url"), url).
+ Scan(ctx)
},
url,
)
@@ -81,7 +86,8 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
ctx,
"Username.Domain",
func(account *gtsmodel.Account) error {
- q := a.newAccountQ(account)
+ q := a.conn.NewSelect().
+ Model(account)
if domain != "" {
q = q.
@@ -105,12 +111,71 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
ctx,
"PublicKeyURI",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.public_key_uri"), id).
+ Scan(ctx)
},
id,
)
}
+func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "InboxURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.inbox_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "OutboxURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.outbox_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "FollowersURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.followers_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "FollowingURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.following_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
var username string
@@ -141,31 +206,56 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(
return nil, err
}
- if account.AvatarMediaAttachmentID != "" {
- // Set the account's related avatar
- account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return account, nil
+ }
+
+ // Further populate the account fields where applicable.
+ if err := a.PopulateAccount(ctx, account); err != nil {
+ return nil, err
+ }
+
+ return account, nil
+}
+
+func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error {
+ var err error
+
+ if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" {
+ // Account avatar attachment is not set, fetch from database.
+ account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(
+ ctx, // these are already barebones
+ account.AvatarMediaAttachmentID,
+ )
if err != nil {
- log.Errorf(ctx, "error getting account %s avatar: %v", account.ID, err)
+ return fmt.Errorf("error populating account avatar: %w", err)
}
}
- if account.HeaderMediaAttachmentID != "" {
- // Set the account's related header
- account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.HeaderMediaAttachmentID)
+ if account.HeaderMediaAttachment == nil && account.HeaderMediaAttachmentID != "" {
+ // Account header attachment is not set, fetch from database.
+ account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(
+ ctx, // these are already barebones
+ account.HeaderMediaAttachmentID,
+ )
if err != nil {
- log.Errorf(ctx, "error getting account %s header: %v", account.ID, err)
+ return fmt.Errorf("error populating account header: %w", err)
}
}
- if len(account.EmojiIDs) > 0 {
- // Set the account's related emojis
- account.Emojis, err = a.state.DB.GetEmojisByIDs(ctx, account.EmojiIDs)
+ if !account.EmojisPopulated() {
+ // Account emojis are out-of-date with IDs, repopulate.
+ account.Emojis, err = a.state.DB.GetEmojisByIDs(
+ ctx, // these are already barebones
+ account.EmojiIDs,
+ )
if err != nil {
- log.Errorf(ctx, "error getting account %s emojis: %v", account.ID, err)
+ return fmt.Errorf("error populating account emojis: %w", err)
}
}
- return account, nil
+ return nil
}
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
@@ -198,7 +288,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
columns = append(columns, "updated_at")
}
- return a.state.Caches.GTS.Account().Store(account, func() error {
+ err := a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@@ -234,6 +324,11 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
return err
})
})
+ if err != nil {
+ return err
+ }
+
+ return nil
}
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
@@ -258,7 +353,9 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
return err
}
+ // Invalidate account from database lookups.
a.state.Caches.GTS.Account().Invalidate("ID", id)
+
return nil
}