diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/notification.go | 49 | ||||
| -rw-r--r-- | internal/db/bundb/notification_test.go | 4 | ||||
| -rw-r--r-- | internal/db/bundb/status.go | 102 |
3 files changed, 138 insertions, 17 deletions
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 77d4861b2..2f4989c33 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -54,24 +54,28 @@ func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*g func (n *notificationDB) GetNotification( ctx context.Context, - notificationType gtsmodel.NotificationType, - targetAccountID string, - originAccountID string, - statusID string, + notifType gtsmodel.NotificationType, + targetAcctID string, + originAcctID string, + statusOrEditID string, ) (*gtsmodel.Notification, error) { return n.getNotification( ctx, - "NotificationType,TargetAccountID,OriginAccountID,StatusID", + "NotificationType,TargetAccountID,OriginAccountID,StatusOrEditID", func(notif *gtsmodel.Notification) error { - return n.db.NewSelect(). + q := n.db.NewSelect(). Model(notif). - Where("? = ?", bun.Ident("notification_type"), notificationType). - Where("? = ?", bun.Ident("target_account_id"), targetAccountID). - Where("? = ?", bun.Ident("origin_account_id"), originAccountID). - Where("? = ?", bun.Ident("status_id"), statusID). - Scan(ctx) + Where("? = ?", bun.Ident("notification_type"), notifType). + Where("? = ?", bun.Ident("target_account_id"), targetAcctID). + Where("? = ?", bun.Ident("origin_account_id"), originAcctID) + + if statusOrEditID != "" { + q = q.Where("? = ?", bun.Ident("status_id"), statusOrEditID) + } + + return q.Scan(ctx) }, - notificationType, targetAccountID, originAccountID, statusID, + notifType, targetAcctID, originAcctID, statusOrEditID, ) } @@ -176,14 +180,29 @@ func (n *notificationDB) PopulateNotification(ctx context.Context, notif *gtsmod } } - if notif.StatusID != "" && notif.Status == nil { + if notif.StatusOrEditID != "" && notif.Status == nil { + // Try getting status by ID first. notif.Status, err = n.state.DB.GetStatusByID( gtscontext.SetBarebones(ctx), - notif.StatusID, + notif.StatusOrEditID, ) - if err != nil { + if err != nil && !errors.Is(err, db.ErrNoEntries) { + // Only append real db error. It might be an edit ID. errs.Appendf("error populating notif status: %w", err) } + + if notif.Status == nil { + // If it's still not set, try + // getting status by edit ID. + notif.Status, err = n.state.DB.GetStatusByEditID( + gtscontext.SetBarebones(ctx), + notif.StatusOrEditID, + ) + if err != nil { + // Append any error here as it's an issue. + errs.Appendf("error populating notif status: %w", err) + } + } } return errs.Combine() diff --git a/internal/db/bundb/notification_test.go b/internal/db/bundb/notification_test.go index a6dcdd407..b3104f905 100644 --- a/internal/db/bundb/notification_test.go +++ b/internal/db/bundb/notification_test.go @@ -70,7 +70,7 @@ func (suite *NotificationTestSuite) spamNotifs() { CreatedAt: time.Now(), TargetAccountID: targetAccountID, OriginAccountID: originAccountID, - StatusID: statusID, + StatusOrEditID: statusID, Read: util.Ptr(false), } @@ -263,7 +263,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsPertainingToStatusID( } for _, n := range notif { - if n.StatusID == testStatus.ID { + if n.StatusOrEditID == testStatus.ID { suite.FailNowf("", "no notifications with status id %s should remain", testStatus.ID) } } diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index f33362a3d..cf4a2549a 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -732,3 +732,105 @@ func (s *statusDB) GetDirectStatusIDsBatch(ctx context.Context, minID string, ma } return statusIDs, nil } + +func (s *statusDB) GetStatusInteractions( + ctx context.Context, + statusID string, + localOnly bool, +) ([]gtsmodel.Interaction, error) { + // Prepare to get interactions. + interactions := []gtsmodel.Interaction{} + + // Gather faves. + faves, err := s.state.DB.GetStatusFaves(ctx, statusID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, err + } + + for _, fave := range faves { + // Get account at least. + if fave.Account == nil { + fave.Account, err = s.state.DB.GetAccountByID(ctx, fave.AccountID) + if err != nil { + log.Errorf(ctx, "error getting account for fave: %v", err) + continue + } + } + + if localOnly && !fave.Account.IsLocal() { + // Skip not local. + continue + } + + interactions = append(interactions, fave) + } + + // Gather replies. + replies, err := s.state.DB.GetStatusReplies(ctx, statusID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, err + } + + for _, reply := range replies { + // Get account at least. + if reply.Account == nil { + reply.Account, err = s.state.DB.GetAccountByID(ctx, reply.AccountID) + if err != nil { + log.Errorf(ctx, "error getting account for reply: %v", err) + continue + } + } + + if localOnly && !reply.Account.IsLocal() { + // Skip not local. + continue + } + + interactions = append(interactions, reply) + } + + // Gather boosts. + boosts, err := s.state.DB.GetStatusBoosts(ctx, statusID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, err + } + + for _, boost := range boosts { + // Get account at least. + if boost.Account == nil { + boost.Account, err = s.state.DB.GetAccountByID(ctx, boost.AccountID) + if err != nil { + log.Errorf(ctx, "error getting account for boost: %v", err) + continue + } + } + + if localOnly && !boost.Account.IsLocal() { + // Skip not local. + continue + } + + interactions = append(interactions, boost) + } + + if len(interactions) == 0 { + return nil, db.ErrNoEntries + } + + return interactions, nil +} + +func (s *statusDB) GetStatusByEditID( + ctx context.Context, + editID string, +) (*gtsmodel.Status, error) { + edit, err := s.state.DB.GetStatusEditByID( + gtscontext.SetBarebones(ctx), + editID, + ) + if err != nil { + return nil, err + } + + return s.GetStatusByID(ctx, edit.StatusID) +} |
