diff options
Diffstat (limited to 'internal/db/bundb')
-rw-r--r-- | internal/db/bundb/emoji.go | 20 | ||||
-rw-r--r-- | internal/db/bundb/instance.go | 4 | ||||
-rw-r--r-- | internal/db/bundb/notification.go | 57 | ||||
-rw-r--r-- | internal/db/bundb/relationship_note.go | 48 | ||||
-rw-r--r-- | internal/db/bundb/status_test.go | 45 | ||||
-rw-r--r-- | internal/db/bundb/user.go | 35 |
6 files changed, 184 insertions, 25 deletions
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index a3a19485d..34a08b694 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -26,6 +26,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" @@ -548,6 +549,25 @@ func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gts return emoji, nil } +func (e *emojiDB) PopulateEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error { + var ( + errs = gtserror.NewMultiError(1) + err error + ) + + if emoji.CategoryID != "" && emoji.Category == nil { + emoji.Category, err = e.GetEmojiCategory( + ctx, // these are already barebones + emoji.CategoryID, + ) + if err != nil { + errs.Appendf("error populating emoji category: %w", err) + } + } + + return errs.Combine() +} + func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, error) { if len(emojiIDs) == 0 { return nil, db.ErrNoEntries diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 6fec3f2fe..567a44ee2 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -173,14 +173,14 @@ func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery fun } // Further populate the instance fields where applicable. - if err := i.populateInstance(ctx, instance); err != nil { + if err := i.PopulateInstance(ctx, instance); err != nil { return nil, err } return instance, nil } -func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.Instance) error { +func (i *instanceDB) PopulateInstance(ctx context.Context, instance *gtsmodel.Instance) error { var ( err error errs = gtserror.NewMultiError(2) diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 423fd0be1..7532b9993 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -23,6 +23,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -57,7 +58,7 @@ func (n *notificationDB) GetNotification( originAccountID string, statusID string, ) (*gtsmodel.Notification, error) { - return n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) { + notif, err := n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) { var notif gtsmodel.Notification q := n.db.NewSelect(). @@ -73,6 +74,60 @@ func (n *notificationDB) GetNotification( return ¬if, nil }, notificationType, targetAccountID, originAccountID, statusID) + if err != nil { + return nil, err + } + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return notif, nil + } + + // Further populate the notif fields where applicable. + if err := n.PopulateNotification(ctx, notif); err != nil { + return nil, err + } + + return notif, nil +} + +func (n *notificationDB) PopulateNotification(ctx context.Context, notif *gtsmodel.Notification) error { + var ( + errs = gtserror.NewMultiError(2) + err error + ) + + if notif.TargetAccount == nil { + notif.TargetAccount, err = n.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + notif.TargetAccountID, + ) + if err != nil { + errs.Appendf("error populating notif target account: %w", err) + } + } + + if notif.OriginAccount == nil { + notif.OriginAccount, err = n.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + notif.OriginAccountID, + ) + if err != nil { + errs.Appendf("error populating notif origin account: %w", err) + } + } + + if notif.StatusID != "" && notif.Status == nil { + notif.Status, err = n.state.DB.GetStatusByID( + gtscontext.SetBarebones(ctx), + notif.StatusID, + ) + if err != nil { + errs.Appendf("error populating notif status: %w", err) + } + } + + return errs.Combine() } func (n *notificationDB) GetAccountNotifications( diff --git a/internal/db/bundb/relationship_note.go b/internal/db/bundb/relationship_note.go index 84f0ebeab..f7d15f8b7 100644 --- a/internal/db/bundb/relationship_note.go +++ b/internal/db/bundb/relationship_note.go @@ -19,10 +19,10 @@ package bundb import ( "context" - "fmt" "time" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/uptrace/bun" ) @@ -64,25 +64,43 @@ func (r *relationshipDB) getNote(ctx context.Context, lookup string, dbQuery fun return note, nil } - // Set the note source account - note.Account, err = r.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - note.AccountID, - ) - if err != nil { - return nil, fmt.Errorf("error getting note source account: %w", err) + // Further populate the account fields where applicable. + if err := r.PopulateNote(ctx, note); err != nil { + return nil, err } - // Set the note target account - note.TargetAccount, err = r.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - note.TargetAccountID, + return note, nil +} + +func (r *relationshipDB) PopulateNote(ctx context.Context, note *gtsmodel.AccountNote) error { + var ( + errs = gtserror.NewMultiError(2) + err error ) - if err != nil { - return nil, fmt.Errorf("error getting note target account: %w", err) + + // Ensure note source account set. + if note.Account == nil { + note.Account, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + note.AccountID, + ) + if err != nil { + errs.Appendf("error populating note source account: %w", err) + } } - return note, nil + // Ensure note target account set. + if note.TargetAccount == nil { + note.TargetAccount, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + note.TargetAccountID, + ) + if err != nil { + errs.Appendf("error populating note target account: %w", err) + } + } + + return errs.Combine() } func (r *relationshipDB) PutNote(ctx context.Context, note *gtsmodel.AccountNote) error { diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index c0ff6c0da..2129aa0e8 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -224,6 +224,51 @@ func (suite *StatusTestSuite) TestUpdateStatus() { suite.True(updated.PinnedAt.IsZero()) } +func (suite *StatusTestSuite) TestPutPopulatedStatus() { + ctx := context.Background() + + targetStatus := >smodel.Status{} + *targetStatus = *suite.testStatuses["admin_account_status_1"] + + // Populate fields on the target status. + if err := suite.db.PopulateStatus(ctx, targetStatus); err != nil { + suite.FailNow(err.Error()) + } + + // Delete it from the database. + if err := suite.db.DeleteStatusByID(ctx, targetStatus.ID); err != nil { + suite.FailNow(err.Error()) + } + + // Reinsert the populated version + // so that it becomes cached. + if err := suite.db.PutStatus(ctx, targetStatus); err != nil { + suite.FailNow(err.Error()) + } + + // Update the status owner's + // account with a new bio. + account := >smodel.Account{} + *account = *targetStatus.Account + account.Note = "new note for this test" + if err := suite.db.UpdateAccount(ctx, account, "note"); err != nil { + suite.FailNow(err.Error()) + } + + dbStatus, err := suite.db.GetStatusByID(ctx, targetStatus.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + // Account note should be updated, + // even though we stored this + // status with the old note. + suite.Equal( + "new note for this test", + dbStatus.Account.Note, + ) +} + func TestStatusTestSuite(t *testing.T) { suite.Run(t, new(StatusTestSuite)) } diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index eaa1d8e3d..46b3c568f 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -130,18 +130,39 @@ func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmo return nil, err } - // Fetch the related account model for this user. - user.Account, err = u.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - user.AccountID, - ) - if err != nil { - return nil, gtserror.Newf("error populating user account: %w", err) + if gtscontext.Barebones(ctx) { + // Return without populating. + return user, nil + } + + if err := u.PopulateUser(ctx, user); err != nil { + return nil, err } return user, nil } +// PopulateUser ensures that the user's struct fields are populated. +func (u *userDB) PopulateUser(ctx context.Context, user *gtsmodel.User) error { + var ( + errs = gtserror.NewMultiError(1) + err error + ) + + if user.Account == nil { + // Fetch the related account model for this user. + user.Account, err = u.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + user.AccountID, + ) + if err != nil { + errs.Appendf("error populating user account: %w", err) + } + } + + return errs.Combine() +} + func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { var userIDs []string |