summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2025-03-04 11:01:25 +0100
committerLibravatar GitHub <noreply@github.com>2025-03-04 10:01:25 +0000
commit829143d2636d4c0d274bf2ab4559912f472a2bc4 (patch)
treeb28175fadfbd2d02801337975560e522dd8e129b /internal/db
parent[chore] fixed email template to align with the new "Log in" button + separate... (diff)
downloadgotosocial-829143d2636d4c0d274bf2ab4559912f472a2bc4.tar.xz
[feature] Add token review / delete to backend + settings panel (#3845)
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/application.go4
-rw-r--r--internal/db/bundb/application.go102
2 files changed, 106 insertions, 0 deletions
diff --git a/internal/db/application.go b/internal/db/application.go
index 9f0109d59..a3061f028 100644
--- a/internal/db/application.go
+++ b/internal/db/application.go
@@ -21,6 +21,7 @@ import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
)
type Application interface {
@@ -39,6 +40,9 @@ type Application interface {
// GetAllTokens fetches all client oauth tokens from database.
GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error)
+ // GetAccessTokens allows paging through a user's access (ie., user-level) tokens.
+ GetAccessTokens(ctx context.Context, userID string, page *paging.Page) ([]*gtsmodel.Token, error)
+
// GetTokenByID fetches the client oauth token from database with ID.
GetTokenByID(ctx context.Context, id string) (*gtsmodel.Token, error)
diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go
index d94c984d0..c21221c9f 100644
--- a/internal/db/bundb/application.go
+++ b/internal/db/bundb/application.go
@@ -19,8 +19,11 @@ package bundb
import (
"context"
+ "slices"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices"
"github.com/uptrace/bun"
@@ -139,6 +142,74 @@ func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, er
return tokens, nil
}
+func (a *applicationDB) GetAccessTokens(
+ ctx context.Context,
+ userID string,
+ page *paging.Page,
+) ([]*gtsmodel.Token, error) {
+ var (
+ // Get paging params.
+ minID = page.GetMin()
+ maxID = page.GetMax()
+ limit = page.GetLimit()
+ order = page.GetOrder()
+
+ // Make educated guess for slice size.
+ tokenIDs = make([]string, 0, limit)
+ )
+
+ // Ensure user ID.
+ if userID == "" {
+ return nil, gtserror.New("userID not set")
+ }
+
+ q := a.db.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("tokens"), bun.Ident("token")).
+ Column("token.id").
+ Where("? = ?", bun.Ident("token.user_id"), userID).
+ Where("? != ?", bun.Ident("token.access"), "")
+
+ if maxID != "" {
+ // Return only tokens LOWER (ie., older) than maxID.
+ q = q.Where("? < ?", bun.Ident("token.id"), maxID)
+ }
+
+ if minID != "" {
+ // Return only tokens HIGHER (ie., newer) than minID.
+ q = q.Where("? > ?", bun.Ident("token.id"), minID)
+ }
+
+ if limit > 0 {
+ q = q.Limit(limit)
+ }
+
+ if order == paging.OrderAscending {
+ // Page up.
+ q = q.Order("token.id ASC")
+ } else {
+ // Page down.
+ q = q.Order("token.id DESC")
+ }
+
+ if err := q.Scan(ctx, &tokenIDs); err != nil {
+ return nil, err
+ }
+
+ if len(tokenIDs) == 0 {
+ return nil, nil
+ }
+
+ // If we're paging up, we still want tokens
+ // to be sorted by ID desc (ie., newest to
+ // oldest), so reverse ids slice.
+ if order == paging.OrderAscending {
+ slices.Reverse(tokenIDs)
+ }
+
+ return a.getTokensByIDs(ctx, tokenIDs)
+}
+
func (a *applicationDB) GetTokenByID(ctx context.Context, code string) (*gtsmodel.Token, error) {
return a.getTokenBy(
"ID",
@@ -149,6 +220,37 @@ func (a *applicationDB) GetTokenByID(ctx context.Context, code string) (*gtsmode
)
}
+func (a *applicationDB) getTokensByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Token, error) {
+ tokens, err := a.state.Caches.DB.Token.LoadIDs("ID",
+ ids,
+ func(uncached []string) ([]*gtsmodel.Token, error) {
+ // Preallocate expected length of uncached tokens.
+ tokens := make([]*gtsmodel.Token, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) token IDs.
+ if err := a.db.NewSelect().
+ Model(&tokens).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return tokens, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the tokens by their
+ // IDs to ensure in correct order.
+ getID := func(t *gtsmodel.Token) string { return t.ID }
+ xslices.OrderBy(tokens, ids, getID)
+
+ return tokens, nil
+}
+
func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) {
return a.getTokenBy(
"Code",