diff options
Diffstat (limited to 'internal/oauth/tokenstore.go')
| -rw-r--r-- | internal/oauth/tokenstore.go | 162 |
1 files changed, 115 insertions, 47 deletions
diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go index df2e419fe..8c6506fa3 100644 --- a/internal/oauth/tokenstore.go +++ b/internal/oauth/tokenstore.go @@ -22,30 +22,32 @@ import ( "errors" "time" + "codeberg.org/gruf/go-mutexes" "codeberg.org/superseriousbusiness/oauth2/v4" "codeberg.org/superseriousbusiness/oauth2/v4/models" - "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" ) // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. type tokenStore struct { oauth2.TokenStore - db db.DB + state *state.State + lastUsedLocks mutexes.MutexMap } // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. // // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through // the tokens in the DB once per minute and deletes any that have expired. -func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore { - ts := &tokenStore{ - db: db, - } +func newTokenStore(ctx context.Context, state *state.State) oauth2.TokenStore { + ts := &tokenStore{state: state} - // set the token store to clean out expired tokens once per minute, or return if we're done + // Set the token store to clean out expired tokens + // once per minute, or return if we're done. go func(ctx context.Context, ts *tokenStore) { cleanloop: for { @@ -64,25 +66,48 @@ func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore { return ts } -// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. +// sweep clears out old tokens that have expired; +// it should be run on a loop about once per minute or so. func (ts *tokenStore) sweep(ctx context.Context) error { - // select *all* tokens from the db - // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. - tokens, err := ts.db.GetAllTokens(ctx) + // Select *all* tokens from the db + // + // TODO: if this becomes expensive + // (ie., there are fucking LOADS of + // tokens) then figure out a better way. + tokens, err := ts.state.DB.GetAllTokens(ctx) if err != nil { return err } - // iterate through and remove expired tokens + // Remove any expired tokens, bearing + // in mind that zero time = no expiry. now := time.Now() - for _, dbt := range tokens { - // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: - // we only want to check if a token expired before now if the expiry time is *not zero*; - // ie., if it's been explicity set. - if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { - if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil { - return err - } + for _, token := range tokens { + var expired bool + + switch { + case !token.CodeExpiresAt.IsZero() && token.CodeExpiresAt.Before(now): + log.Tracef(ctx, "code token %s is expired", token.ID) + expired = true + + case !token.RefreshExpiresAt.IsZero() && token.RefreshExpiresAt.Before(now): + log.Tracef(ctx, "refresh token %s is expired", token.ID) + expired = true + + case !token.AccessExpiresAt.IsZero() && token.AccessExpiresAt.Before(now): + log.Tracef(ctx, "access token %s is expired", token.ID) + expired = true + } + + if !expired { + // Token's + // still good. + continue + } + + if err := ts.state.DB.DeleteTokenByID(ctx, token.ID); err != nil { + err := gtserror.Newf("db error expiring token %s: %w", token.ID, err) + return err } } @@ -90,7 +115,6 @@ func (ts *tokenStore) sweep(ctx context.Context) error { } // Create creates and store the new token information. -// For the original implementation, see https://codeberg.org/superseriousbusiness/oauth2/blob/master/store/token.go#L34 func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { t, ok := info.(*models.Token) if !ok { @@ -99,55 +123,99 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { dbt := TokenToDBToken(t) if dbt.ID == "" { - dbtID, err := id.NewRandomULID() - if err != nil { - return err - } - dbt.ID = dbtID + dbt.ID = id.NewULID() } - return ts.db.PutToken(ctx, dbt) + return ts.state.DB.PutToken(ctx, dbt) } // RemoveByCode deletes a token from the DB based on the Code field func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { - return ts.db.DeleteTokenByCode(ctx, code) + return ts.state.DB.DeleteTokenByCode(ctx, code) } // RemoveByAccess deletes a token from the DB based on the Access field func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { - return ts.db.DeleteTokenByAccess(ctx, access) + return ts.state.DB.DeleteTokenByAccess(ctx, access) } // RemoveByRefresh deletes a token from the DB based on the Refresh field func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - return ts.db.DeleteTokenByRefresh(ctx, refresh) + return ts.state.DB.DeleteTokenByRefresh(ctx, refresh) } -// GetByCode selects a token from the DB based on the Code field -func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - token, err := ts.db.GetTokenByCode(ctx, code) - if err != nil { - return nil, err - } - return DBTokenToToken(token), nil +// GetByCode selects a token from +// the DB based on the Code field +func (ts *tokenStore) GetByCode( + ctx context.Context, + code string, +) (oauth2.TokenInfo, error) { + return ts.getUpdateToken( + ctx, + ts.state.DB.GetTokenByCode, + code, + ) } -// GetByAccess selects a token from the DB based on the Access field -func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { - token, err := ts.db.GetTokenByAccess(ctx, access) - if err != nil { - return nil, err - } - return DBTokenToToken(token), nil +// GetByAccess selects a token from +// the DB based on the Access field. +func (ts *tokenStore) GetByAccess( + ctx context.Context, + access string, +) (oauth2.TokenInfo, error) { + return ts.getUpdateToken( + ctx, + ts.state.DB.GetTokenByAccess, + access, + ) } -// GetByRefresh selects a token from the DB based on the Refresh field -func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - token, err := ts.db.GetTokenByRefresh(ctx, refresh) +// GetByRefresh selects a token from +// the DB based on the Refresh field +func (ts *tokenStore) GetByRefresh( + ctx context.Context, + refresh string, +) (oauth2.TokenInfo, error) { + return ts.getUpdateToken( + ctx, + ts.state.DB.GetTokenByRefresh, + refresh, + ) +} + +// package-internal function for getting a token +// and potentially updating its last_used value. +func (ts *tokenStore) getUpdateToken( + ctx context.Context, + getBy func(context.Context, string) (*gtsmodel.Token, error), + key string, +) (oauth2.TokenInfo, error) { + // Hold a lock to get the token based on + // whatever func + key we've been given. + unlock := ts.lastUsedLocks.Lock(key) + + token, err := getBy(ctx, key) if err != nil { + // Unlock on error. + unlock() return nil, err } + + // If token was last used more than + // an hour ago, update this in the db. + wasLastUsed := token.LastUsed + if now := time.Now(); now.Sub(wasLastUsed) > 1*time.Hour { + token.LastUsed = now + if err := ts.state.DB.UpdateToken(ctx, token, "last_used"); err != nil { + // Unlock on error. + unlock() + err := gtserror.Newf("error updating last_used on token: %w", err) + return nil, err + } + } + + // We're done, unlock. + unlock() return DBTokenToToken(token), nil } |
