diff options
Diffstat (limited to 'internal/db/bundb/application.go')
| -rw-r--r-- | internal/db/bundb/application.go | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go index d94c984d0..c21221c9f 100644 --- a/internal/db/bundb/application.go +++ b/internal/db/bundb/application.go @@ -19,8 +19,11 @@ package bundb import ( "context" + "slices" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/uptrace/bun" @@ -139,6 +142,74 @@ func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, er return tokens, nil } +func (a *applicationDB) GetAccessTokens( + ctx context.Context, + userID string, + page *paging.Page, +) ([]*gtsmodel.Token, error) { + var ( + // Get paging params. + minID = page.GetMin() + maxID = page.GetMax() + limit = page.GetLimit() + order = page.GetOrder() + + // Make educated guess for slice size. + tokenIDs = make([]string, 0, limit) + ) + + // Ensure user ID. + if userID == "" { + return nil, gtserror.New("userID not set") + } + + q := a.db. + NewSelect(). + TableExpr("? AS ?", bun.Ident("tokens"), bun.Ident("token")). + Column("token.id"). + Where("? = ?", bun.Ident("token.user_id"), userID). + Where("? != ?", bun.Ident("token.access"), "") + + if maxID != "" { + // Return only tokens LOWER (ie., older) than maxID. + q = q.Where("? < ?", bun.Ident("token.id"), maxID) + } + + if minID != "" { + // Return only tokens HIGHER (ie., newer) than minID. + q = q.Where("? > ?", bun.Ident("token.id"), minID) + } + + if limit > 0 { + q = q.Limit(limit) + } + + if order == paging.OrderAscending { + // Page up. + q = q.Order("token.id ASC") + } else { + // Page down. + q = q.Order("token.id DESC") + } + + if err := q.Scan(ctx, &tokenIDs); err != nil { + return nil, err + } + + if len(tokenIDs) == 0 { + return nil, nil + } + + // If we're paging up, we still want tokens + // to be sorted by ID desc (ie., newest to + // oldest), so reverse ids slice. + if order == paging.OrderAscending { + slices.Reverse(tokenIDs) + } + + return a.getTokensByIDs(ctx, tokenIDs) +} + func (a *applicationDB) GetTokenByID(ctx context.Context, code string) (*gtsmodel.Token, error) { return a.getTokenBy( "ID", @@ -149,6 +220,37 @@ func (a *applicationDB) GetTokenByID(ctx context.Context, code string) (*gtsmode ) } +func (a *applicationDB) getTokensByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Token, error) { + tokens, err := a.state.Caches.DB.Token.LoadIDs("ID", + ids, + func(uncached []string) ([]*gtsmodel.Token, error) { + // Preallocate expected length of uncached tokens. + tokens := make([]*gtsmodel.Token, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) token IDs. + if err := a.db.NewSelect(). + Model(&tokens). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return tokens, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the tokens by their + // IDs to ensure in correct order. + getID := func(t *gtsmodel.Token) string { return t.ID } + xslices.OrderBy(tokens, ids, getID) + + return tokens, nil +} + func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) { return a.getTokenBy( "Code", |
