summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/application.go54
1 files changed, 35 insertions, 19 deletions
diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go
index 1a600d620..562614e5e 100644
--- a/internal/db/bundb/application.go
+++ b/internal/db/bundb/application.go
@@ -19,8 +19,10 @@ package bundb
import (
"context"
+ "errors"
"slices"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/paging"
@@ -409,8 +411,11 @@ func (a *applicationDB) UpdateToken(ctx context.Context, token *gtsmodel.Token,
}
func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error {
+ var token gtsmodel.Token
+ token.ID = id
+
_, err := a.db.NewDelete().
- Table("tokens").
+ Model(&token).
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
if err != nil {
@@ -418,68 +423,79 @@ func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error {
}
a.state.Caches.DB.Token.Invalidate("ID", id)
+ a.state.Caches.OnInvalidateToken(&token)
return nil
}
func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error {
+ var token gtsmodel.Token
+
_, err := a.db.NewDelete().
- Table("tokens").
+ Model(&token).
Where("? = ?", bun.Ident("code"), code).
+ Returning("?", bun.Ident("id")).
Exec(ctx)
- if err != nil {
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
a.state.Caches.DB.Token.Invalidate("Code", code)
+ a.state.Caches.OnInvalidateToken(&token)
return nil
}
func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error {
+ var token gtsmodel.Token
+
_, err := a.db.NewDelete().
- Table("tokens").
+ Model(&token).
Where("? = ?", bun.Ident("access"), access).
+ Returning("?", bun.Ident("id")).
Exec(ctx)
- if err != nil {
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
a.state.Caches.DB.Token.Invalidate("Access", access)
+ a.state.Caches.OnInvalidateToken(&token)
return nil
}
func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error {
+ var token gtsmodel.Token
+
_, err := a.db.NewDelete().
- Table("tokens").
+ Model(&token).
Where("? = ?", bun.Ident("refresh"), refresh).
+ Returning("?", bun.Ident("id")).
Exec(ctx)
- if err != nil {
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
a.state.Caches.DB.Token.Invalidate("Refresh", refresh)
+ a.state.Caches.OnInvalidateToken(&token)
return nil
}
func (a *applicationDB) DeleteTokensByClientID(ctx context.Context, clientID string) error {
+ var tokens []*gtsmodel.Token
+
// Delete tokens owned by
// clientID and gather token IDs.
- var tokenIDs []string
- if _, err := a.db.
- NewDelete().
- Table("tokens").
+ if _, err := a.db.NewDelete().
+ Model(&tokens).
Where("? = ?", bun.Ident("client_id"), clientID).
- Returning("id").
- Exec(ctx, &tokenIDs); err != nil {
+ Returning("?", bun.Ident("id")).
+ Exec(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
- if len(tokenIDs) == 0 {
- // Nothing was deleted,
- // nothing to invalidate.
- return nil
+ // Invalidate all deleted tokens.
+ for _, token := range tokens {
+ a.state.Caches.DB.Token.Invalidate("ID", token.ID)
+ a.state.Caches.OnInvalidateToken(token)
}
- // Invalidate all deleted tokens.
- a.state.Caches.DB.Token.InvalidateIDs("ID", tokenIDs)
return nil
}