diff options
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r-- | internal/db/bundb/account.go | 147 |
1 files changed, 122 insertions, 25 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index df73168e2..ccf7aaa46 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -20,11 +20,13 @@ package bundb import ( "context" "errors" + "fmt" "strings" "time" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" @@ -37,18 +39,15 @@ type accountDB struct { state *state.State } -func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { - return a.conn. - NewSelect(). - Model(account) -} - func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, "ID", func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx) + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.id"), id). + Scan(ctx) }, id, ) @@ -59,7 +58,10 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. ctx, "URI", func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx) + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.uri"), uri). + Scan(ctx) }, uri, ) @@ -70,7 +72,10 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. ctx, "URL", func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx) + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.url"), url). + Scan(ctx) }, url, ) @@ -81,7 +86,8 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str ctx, "Username.Domain", func(account *gtsmodel.Account) error { - q := a.newAccountQ(account) + q := a.conn.NewSelect(). + Model(account) if domain != "" { q = q. @@ -105,12 +111,71 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo ctx, "PublicKeyURI", func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx) + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.public_key_uri"), id). + Scan(ctx) }, id, ) } +func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + "InboxURI", + func(account *gtsmodel.Account) error { + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.inbox_uri"), uri). + Scan(ctx) + }, + uri, + ) +} + +func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + "OutboxURI", + func(account *gtsmodel.Account) error { + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.outbox_uri"), uri). + Scan(ctx) + }, + uri, + ) +} + +func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + "FollowersURI", + func(account *gtsmodel.Account) error { + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.followers_uri"), uri). + Scan(ctx) + }, + uri, + ) +} + +func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + "FollowingURI", + func(account *gtsmodel.Account) error { + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.following_uri"), uri). + Scan(ctx) + }, + uri, + ) +} + func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { var username string @@ -141,31 +206,56 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func( return nil, err } - if account.AvatarMediaAttachmentID != "" { - // Set the account's related avatar - account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return account, nil + } + + // Further populate the account fields where applicable. + if err := a.PopulateAccount(ctx, account); err != nil { + return nil, err + } + + return account, nil +} + +func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error { + var err error + + if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" { + // Account avatar attachment is not set, fetch from database. + account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID( + ctx, // these are already barebones + account.AvatarMediaAttachmentID, + ) if err != nil { - log.Errorf(ctx, "error getting account %s avatar: %v", account.ID, err) + return fmt.Errorf("error populating account avatar: %w", err) } } - if account.HeaderMediaAttachmentID != "" { - // Set the account's related header - account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.HeaderMediaAttachmentID) + if account.HeaderMediaAttachment == nil && account.HeaderMediaAttachmentID != "" { + // Account header attachment is not set, fetch from database. + account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID( + ctx, // these are already barebones + account.HeaderMediaAttachmentID, + ) if err != nil { - log.Errorf(ctx, "error getting account %s header: %v", account.ID, err) + return fmt.Errorf("error populating account header: %w", err) } } - if len(account.EmojiIDs) > 0 { - // Set the account's related emojis - account.Emojis, err = a.state.DB.GetEmojisByIDs(ctx, account.EmojiIDs) + if !account.EmojisPopulated() { + // Account emojis are out-of-date with IDs, repopulate. + account.Emojis, err = a.state.DB.GetEmojisByIDs( + ctx, // these are already barebones + account.EmojiIDs, + ) if err != nil { - log.Errorf(ctx, "error getting account %s emojis: %v", account.ID, err) + return fmt.Errorf("error populating account emojis: %w", err) } } - return account, nil + return nil } func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error { @@ -198,7 +288,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account columns = append(columns, "updated_at") } - return a.state.Caches.GTS.Account().Store(account, func() error { + err := a.state.Caches.GTS.Account().Store(account, func() error { // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // @@ -234,6 +324,11 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account return err }) }) + if err != nil { + return err + } + + return nil } func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { @@ -258,7 +353,9 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { return err } + // Invalidate account from database lookups. a.state.Caches.GTS.Account().Invalidate("ID", id) + return nil } |