diff options
Diffstat (limited to 'internal/db/bundb/notification.go')
-rw-r--r-- | internal/db/bundb/notification.go | 62 |
1 files changed, 59 insertions, 3 deletions
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 627dc1783..b1e7f45ff 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -19,6 +19,7 @@ package bundb import ( "context" + "errors" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -104,15 +105,70 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, return notifs, nil } -func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error { +func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.Error { if _, err := n.conn. NewDelete(). TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). - Where("? = ?", bun.Ident("notification.target_account_id"), accountID). + Where("? = ?", bun.Ident("notification.id"), id). Exec(ctx); err != nil { return n.conn.ProcessError(err) } - n.state.Caches.GTS.Notification().Clear() + n.state.Caches.GTS.Notification().Invalidate("ID", id) + return nil +} + +func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) db.Error { + if targetAccountID == "" && originAccountID == "" { + return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set") + } + + // Capture notification IDs in a RETURNING statement. + ids := []string{} + + q := n.conn. + NewDelete(). + TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). + Returning("?", bun.Ident("id")) + + if targetAccountID != "" { + q = q.Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID) + } + + if originAccountID != "" { + q = q.Where("? = ?", bun.Ident("notification.origin_account_id"), originAccountID) + } + + if _, err := q.Exec(ctx, &ids); err != nil { + return n.conn.ProcessError(err) + } + + // Invalidate each returned ID. + for _, id := range ids { + n.state.Caches.GTS.Notification().Invalidate("ID", id) + } + + return nil +} + +func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error { + // Capture notification IDs in a RETURNING statement. + ids := []string{} + + q := n.conn. + NewDelete(). + TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). + Where("? = ?", bun.Ident("notification.status_id"), statusID). + Returning("?", bun.Ident("id")) + + if _, err := q.Exec(ctx, &ids); err != nil { + return n.conn.ProcessError(err) + } + + // Invalidate each returned ID. + for _, id := range ids { + n.state.Caches.GTS.Notification().Invalidate("ID", id) + } + return nil } |