diff options
Diffstat (limited to 'internal/db/bundb/notification.go')
-rw-r--r-- | internal/db/bundb/notification.go | 148 |
1 files changed, 104 insertions, 44 deletions
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 7532b9993..ed34222fb 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -20,6 +20,7 @@ package bundb import ( "context" "errors" + "slices" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -28,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -37,18 +39,17 @@ type notificationDB struct { } func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) { - return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) { - var notif gtsmodel.Notification - - q := n.db.NewSelect(). - Model(¬if). - Where("? = ?", bun.Ident("notification.id"), id) - if err := q.Scan(ctx); err != nil { - return nil, err - } - - return ¬if, nil - }, id) + return n.getNotification( + ctx, + "ID", + func(notif *gtsmodel.Notification) error { + return n.db.NewSelect(). + Model(notif). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + }, + id, + ) } func (n *notificationDB) GetNotification( @@ -58,42 +59,113 @@ func (n *notificationDB) GetNotification( originAccountID string, statusID string, ) (*gtsmodel.Notification, error) { - notif, err := n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) { - var notif gtsmodel.Notification + return n.getNotification( + ctx, + "NotificationType,TargetAccountID,OriginAccountID,StatusID", + func(notif *gtsmodel.Notification) error { + return n.db.NewSelect(). + Model(notif). + Where("? = ?", bun.Ident("notification_type"), notificationType). + Where("? = ?", bun.Ident("target_account_id"), targetAccountID). + Where("? = ?", bun.Ident("origin_account_id"), originAccountID). + Where("? = ?", bun.Ident("status_id"), statusID). + Scan(ctx) + }, + notificationType, targetAccountID, originAccountID, statusID, + ) +} - q := n.db.NewSelect(). - Model(¬if). - Where("? = ?", bun.Ident("notification_type"), notificationType). - Where("? = ?", bun.Ident("target_account_id"), targetAccountID). - Where("? = ?", bun.Ident("origin_account_id"), originAccountID). - Where("? = ?", bun.Ident("status_id"), statusID) +func (n *notificationDB) getNotification(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Notification) error, keyParts ...any) (*gtsmodel.Notification, error) { + // Fetch notification from cache with loader callback + notif, err := n.state.Caches.GTS.Notification.LoadOne(lookup, func() (*gtsmodel.Notification, error) { + var notif gtsmodel.Notification - if err := q.Scan(ctx); err != nil { + // Not cached! Perform database query + if err := dbQuery(¬if); err != nil { return nil, err } return ¬if, nil - }, notificationType, targetAccountID, originAccountID, statusID) + }, keyParts...) if err != nil { return nil, err } if gtscontext.Barebones(ctx) { - // no need to fully populate. + // Only a barebones model was requested. return notif, nil } - // Further populate the notif fields where applicable. - if err := n.PopulateNotification(ctx, notif); err != nil { + if err := n.state.DB.PopulateNotification(ctx, notif); err != nil { return nil, err } return notif, nil } +func (n *notificationDB) GetNotificationsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Notification, error) { + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all notif IDs via cache loader callbacks. + notifs, err := n.state.Caches.GTS.Notification.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached notification loader function. + func() ([]*gtsmodel.Notification, error) { + // Preallocate expected length of uncached notifications. + notifs := make([]*gtsmodel.Notification, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := n.db.NewSelect(). + Model(¬ifs). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return notifs, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the notifs by their + // IDs to ensure in correct order. + getID := func(n *gtsmodel.Notification) string { return n.ID } + util.OrderBy(notifs, ids, getID) + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return notifs, nil + } + + // Populate all loaded notifs, removing those we fail to + // populate (removes needing so many nil checks everywhere). + notifs = slices.DeleteFunc(notifs, func(notif *gtsmodel.Notification) bool { + if err := n.PopulateNotification(ctx, notif); err != nil { + log.Errorf(ctx, "error populating notif %s: %v", notif.ID, err) + return true + } + return false + }) + + return notifs, nil +} + func (n *notificationDB) PopulateNotification(ctx context.Context, notif *gtsmodel.Notification) error { var ( - errs = gtserror.NewMultiError(2) + errs gtserror.MultiError err error ) @@ -211,31 +283,19 @@ func (n *notificationDB) GetAccountNotifications( } } - notifs := make([]*gtsmodel.Notification, 0, len(notifIDs)) - for _, id := range notifIDs { - // Attempt fetch from DB - notif, err := n.GetNotificationByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching notification %q: %v", id, err) - continue - } - - // Append notification - notifs = append(notifs, notif) - } - - return notifs, nil + // Fetch notification models by their IDs. + return n.GetNotificationsByIDs(ctx, notifIDs) } func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error { - return n.state.Caches.GTS.Notification().Store(notif, func() error { + return n.state.Caches.GTS.Notification.Store(notif, func() error { _, err := n.db.NewInsert().Model(notif).Exec(ctx) return err }) } func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error { - defer n.state.Caches.GTS.Notification().Invalidate("ID", id) + defer n.state.Caches.GTS.Notification.Invalidate("ID", id) // Load notif into cache before attempting a delete, // as we need it cached in order to trigger the invalidate @@ -288,7 +348,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string defer func() { // Invalidate all IDs on return. for _, id := range notifIDs { - n.state.Caches.GTS.Notification().Invalidate("ID", id) + n.state.Caches.GTS.Notification.Invalidate("ID", id) } }() @@ -326,7 +386,7 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu defer func() { // Invalidate all IDs on return. for _, id := range notifIDs { - n.state.Caches.GTS.Notification().Invalidate("ID", id) + n.state.Caches.GTS.Notification.Invalidate("ID", id) } }() |