diff options
Diffstat (limited to 'internal/oauth')
| -rw-r--r-- | internal/oauth/clientstore.go | 43 | ||||
| -rw-r--r-- | internal/oauth/clientstore_test.go | 77 | ||||
| -rw-r--r-- | internal/oauth/handlers/handlers.go | 153 | ||||
| -rw-r--r-- | internal/oauth/oauth_test.go | 20 | ||||
| -rw-r--r-- | internal/oauth/server.go | 174 | ||||
| -rw-r--r-- | internal/oauth/tokenstore.go | 162 |
6 files changed, 408 insertions, 221 deletions
diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go index af48edac3..17de0e342 100644 --- a/internal/oauth/clientstore.go +++ b/internal/oauth/clientstore.go @@ -21,45 +21,30 @@ import ( "context" "codeberg.org/superseriousbusiness/oauth2/v4" - "codeberg.org/superseriousbusiness/oauth2/v4/models" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "codeberg.org/superseriousbusiness/oauth2/v4/errors" + "github.com/superseriousbusiness/gotosocial/internal/state" ) type clientStore struct { - db db.DB + state *state.State } -// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend. -func NewClientStore(db db.DB) oauth2.ClientStore { - pts := &clientStore{ - db: db, - } - return pts +// NewClientStore returns a minimal implementation of +// oauth2.ClientStore interface, using state as storage. +// +// Only GetByID is implemented, Set and Delete are stubs. +func NewClientStore(state *state.State) oauth2.ClientStore { + return &clientStore{state: state} } func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { - client, err := cs.db.GetClientByID(ctx, clientID) - if err != nil { - return nil, err - } - return models.New( - client.ID, - client.Secret, - client.Domain, - client.UserID, - ), nil + return cs.state.DB.GetApplicationByClientID(ctx, clientID) } -func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { - return cs.db.PutClient(ctx, >smodel.Client{ - ID: cli.GetID(), - Secret: cli.GetSecret(), - Domain: cli.GetDomain(), - UserID: cli.GetUserID(), - }) +func (cs *clientStore) Set(_ context.Context, _ string, _ oauth2.ClientInfo) error { + return errors.New("func oauth2.ClientStore.Set not implemented") } -func (cs *clientStore) Delete(ctx context.Context, id string) error { - return cs.db.DeleteClientByID(ctx, id) +func (cs *clientStore) Delete(_ context.Context, _ string) error { + return errors.New("func oauth2.ClientStore.Delete not implemented") } diff --git a/internal/oauth/clientstore_test.go b/internal/oauth/clientstore_test.go index 59b0ec1d3..c6621186a 100644 --- a/internal/oauth/clientstore_test.go +++ b/internal/oauth/clientstore_test.go @@ -21,93 +21,58 @@ import ( "context" "testing" - "codeberg.org/superseriousbusiness/oauth2/v4/models" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/admin" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/testrig" ) -type PgClientStoreTestSuite struct { +type ClientStoreTestSuite struct { suite.Suite db db.DB state state.State - testClientID string - testClientSecret string - testClientDomain string - testClientUserID string + testApplications map[string]*gtsmodel.Application } -// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout -func (suite *PgClientStoreTestSuite) SetupSuite() { - suite.testClientID = "01FCVB74EW6YBYAEY7QG9CQQF6" - suite.testClientSecret = "4cc87402-259b-4a35-9485-2c8bf54f3763" - suite.testClientDomain = "https://example.org" - suite.testClientUserID = "01FEGYXKVCDB731QF9MVFXA4F5" +func (suite *ClientStoreTestSuite) SetupSuite() { + suite.testApplications = testrig.NewTestApplications() } -// SetupTest creates a postgres connection and creates the oauth_clients table before each test -func (suite *PgClientStoreTestSuite) SetupTest() { +func (suite *ClientStoreTestSuite) SetupTest() { suite.state.Caches.Init() - testrig.InitTestLog() testrig.InitTestConfig() + testrig.InitTestLog() suite.db = testrig.NewTestDB(&suite.state) suite.state.DB = suite.db suite.state.AdminActions = admin.New(suite.state.DB, &suite.state.Workers) testrig.StandardDBSetup(suite.db, nil) } -// TearDownTest drops the oauth_clients table and closes the pg connection after each test -func (suite *PgClientStoreTestSuite) TearDownTest() { +func (suite *ClientStoreTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) } -func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() { - // set a new client in the store - cs := oauth.NewClientStore(suite.db) - if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { - suite.FailNow(err.Error()) - } - - // fetch that client from the store - client, err := cs.GetByID(context.Background(), suite.testClientID) - if err != nil { - suite.FailNow(err.Error()) - } - - // check that the values are the same - suite.NotNil(client) - suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client) -} - -func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() { - // set a new client in the store - cs := oauth.NewClientStore(suite.db) - if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { - suite.FailNow(err.Error()) - } +func (suite *ClientStoreTestSuite) TestClientStoreGet() { + testApp := suite.testApplications["application_1"] + cs := oauth.NewClientStore(&suite.state) - // fetch the client from the store - client, err := cs.GetByID(context.Background(), suite.testClientID) + // Fetch clientInfo from the store. + clientInfo, err := cs.GetByID(context.Background(), testApp.ClientID) if err != nil { suite.FailNow(err.Error()) } - // check that the values are the same - suite.NotNil(client) - suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client) - if err := cs.Delete(context.Background(), suite.testClientID); err != nil { - suite.FailNow(err.Error()) - } - - // try to get the deleted client; we should get an error - deletedClient, err := cs.GetByID(context.Background(), suite.testClientID) - suite.Assert().Nil(deletedClient) - suite.Assert().EqualValues(db.ErrNoEntries, err) + // Check expected values. + suite.NotNil(clientInfo) + suite.Equal(testApp.ClientID, clientInfo.GetID()) + suite.Equal(testApp.ClientSecret, clientInfo.GetSecret()) + suite.Equal(testApp.RedirectURIs[0], clientInfo.GetDomain()) + suite.Equal(testApp.ManagedByUserID, clientInfo.GetUserID()) } -func TestPgClientStoreTestSuite(t *testing.T) { - suite.Run(t, new(PgClientStoreTestSuite)) +func TestClientStoreTestSuite(t *testing.T) { + suite.Run(t, new(ClientStoreTestSuite)) } diff --git a/internal/oauth/handlers/handlers.go b/internal/oauth/handlers/handlers.go new file mode 100644 index 000000000..f0af007f0 --- /dev/null +++ b/internal/oauth/handlers/handlers.go @@ -0,0 +1,153 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package handlers + +import ( + "context" + "errors" + "net/http" + "net/url" + "slices" + "strings" + + "codeberg.org/superseriousbusiness/oauth2/v4" + oautherr "codeberg.org/superseriousbusiness/oauth2/v4/errors" + "codeberg.org/superseriousbusiness/oauth2/v4/manage" + "codeberg.org/superseriousbusiness/oauth2/v4/server" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" +) + +// GetClientScopeHandler returns a handler for testing scope on a TokenGenerateRequest. +func GetClientScopeHandler(ctx context.Context, state *state.State) server.ClientScopeHandler { + return func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) { + application, err := state.DB.GetApplicationByClientID( + gtscontext.SetBarebones(ctx), + tgr.ClientID, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + log.Errorf(ctx, "database error getting application: %v", err) + return false, err + } + + if application == nil { + err := gtserror.Newf("no application found with client id %s", tgr.ClientID) + return false, err + } + + // Normalize scope. + if strings.TrimSpace(tgr.Scope) == "" { + tgr.Scope = "read" + } + + // Make sure requested scopes are all + // within scopes permitted by application. + hasScopes := strings.Split(application.Scopes, " ") + wantsScopes := strings.Split(tgr.Scope, " ") + for _, wantsScope := range wantsScopes { + thisOK := slices.ContainsFunc( + hasScopes, + func(hasScope string) bool { + has := apiutil.Scope(hasScope) + wants := apiutil.Scope(wantsScope) + return has.Permits(wants) + }, + ) + + if !thisOK { + // Requested unpermitted + // scope for this app. + return false, nil + } + } + + // All OK. + return true, nil + } +} + +func GetValidateURIHandler(ctx context.Context) manage.ValidateURIHandler { + return func(hasRedirects string, wantsRedirect string) error { + // Normalize the wantsRedirect URI + // string by parsing + reserializing. + wantsRedirectURI, err := url.Parse(wantsRedirect) + if err != nil { + return err + } + wantsRedirect = wantsRedirectURI.String() + + // Redirect URIs are given to us as + // a list of URIs, newline-separated. + // + // They're already normalized on input so + // we don't need to parse + reserialize them. + // + // Ensure that one of them matches. + if slices.ContainsFunc( + strings.Split(hasRedirects, "\n"), + func(hasRedirect string) bool { + // Want an exact match. + // See: https://www.oauth.com/oauth2-servers/redirect-uris/redirect-uri-validation/ + return wantsRedirect == hasRedirect + }, + ) { + return nil + } + + return oautherr.ErrInvalidRedirectURI + } +} + +func GetAuthorizeScopeHandler() server.AuthorizeScopeHandler { + return func(_ http.ResponseWriter, r *http.Request) (string, error) { + // Use provided scope or + // fall back to default "read". + scope := r.FormValue("scope") + if strings.TrimSpace(scope) == "" { + scope = "read" + } + return scope, nil + } +} + +func GetInternalErrorHandler(ctx context.Context) server.InternalErrorHandler { + return func(err error) *oautherr.Response { + log.Errorf(ctx, "internal oauth error: %v", err) + return nil + } +} + +func GetResponseErrorHandler(ctx context.Context) server.ResponseErrorHandler { + return func(re *oautherr.Response) { + log.Errorf(ctx, "internal response error: %v", re.Error) + } +} + +func GetUserAuthorizationHandler() server.UserAuthorizationHandler { + return func(w http.ResponseWriter, r *http.Request) (string, error) { + userID := r.FormValue("userid") + if userID == "" { + return "", errors.New("userid was empty") + } + return userID, nil + } +} diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go deleted file mode 100644 index 2b76024f7..000000000 --- a/internal/oauth/oauth_test.go +++ /dev/null @@ -1,20 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see <http://www.gnu.org/licenses/>. - -package oauth_test - -// TODO: write tests diff --git a/internal/oauth/server.go b/internal/oauth/server.go index 8330ee179..c0c3c329c 100644 --- a/internal/oauth/server.go +++ b/internal/oauth/server.go @@ -30,7 +30,10 @@ import ( "codeberg.org/superseriousbusiness/oauth2/v4/server" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" ) const ( @@ -60,7 +63,8 @@ const ( HelpfulAdviceGrant = "If you arrived at this error during a sign in/oauth flow, your client is trying to use an unsupported OAuth grant type. Supported grant types are: authorization_code, client_credentials; please reach out to developer of your client" ) -// Server wraps some oauth2 server functions in an interface, exposing only what is needed +// Server wraps some oauth2 server functions +// in an interface, exposing only what is needed. type Server interface { HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode @@ -69,66 +73,76 @@ type Server interface { LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error) } -// s fulfils the Server interface using the underlying oauth2 server +// s fulfils the Server interface +// using the underlying oauth2 server. type s struct { server *server.Server } // New returns a new oauth server that implements the Server interface -func New(ctx context.Context, database db.DB) Server { - ts := newTokenStore(ctx, database) - cs := NewClientStore(database) - +func New( + ctx context.Context, + state *state.State, + validateURIHandler manage.ValidateURIHandler, + clientScopeHandler server.ClientScopeHandler, + authorizeScopeHandler server.AuthorizeScopeHandler, + internalErrorHandler server.InternalErrorHandler, + responseErrorHandler server.ResponseErrorHandler, + userAuthorizationHandler server.UserAuthorizationHandler, +) Server { + ts := newTokenStore(ctx, state) + cs := NewClientStore(state) + + // Set up OAuth2 manager. manager := manage.NewDefaultManager() + manager.SetValidateURIHandler(validateURIHandler) manager.MapTokenStorage(ts) manager.MapClientStorage(cs) - manager.SetAuthorizeCodeTokenCfg(&manage.Config{ - AccessTokenExp: 0, // access tokens don't expire -- they must be revoked - IsGenerateRefresh: false, // don't use refresh tokens - }) - sc := &server.Config{ - TokenType: "Bearer", - // Must follow the spec. - AllowGetAccessRequest: false, - // Support only the non-implicit flow. - AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, - // Allow: - // - Authorization Code (for first & third parties) - // - Client Credentials (for applications) - AllowedGrantTypes: []oauth2.GrantType{ - oauth2.AuthorizationCode, - oauth2.ClientCredentials, + manager.SetAuthorizeCodeTokenCfg( + &manage.Config{ + // Following the Mastodon API, + // access tokens don't expire. + AccessTokenExp: 0, + // Don't use refresh tokens. + IsGenerateRefresh: false, }, - AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{ - oauth2.CodeChallengePlain, - oauth2.CodeChallengeS256, + ) + + // Set up OAuth2 server. + srv := server.NewServer( + &server.Config{ + TokenType: "Bearer", + // Must follow the spec. + AllowGetAccessRequest: false, + // Support only the non-implicit flow. + AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, + // Allow: + // - Authorization Code (for first & third parties) + // - Client Credentials (for applications) + AllowedGrantTypes: []oauth2.GrantType{ + oauth2.AuthorizationCode, + oauth2.ClientCredentials, + }, + AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{ + oauth2.CodeChallengePlain, + oauth2.CodeChallengeS256, + }, }, - } - - srv := server.NewServer(sc, manager) - srv.SetInternalErrorHandler(func(err error) *oautherr.Response { - log.Errorf(nil, "internal oauth error: %s", err) - return nil - }) - - srv.SetResponseErrorHandler(func(re *oautherr.Response) { - log.Errorf(nil, "internal response error: %s", re.Error) - }) - - srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) { - userID := r.FormValue("userid") - if userID == "" { - return "", errors.New("userid was empty") - } - return userID, nil - }) + manager, + ) + srv.SetAuthorizeScopeHandler(authorizeScopeHandler) + srv.SetClientScopeHandler(clientScopeHandler) + srv.SetInternalErrorHandler(internalErrorHandler) + srv.SetResponseErrorHandler(responseErrorHandler) + srv.SetUserAuthorizationHandler(userAuthorizationHandler) srv.SetClientInfoHandler(server.ClientFormHandler) - return &s{ - server: srv, - } + + return &s{srv} } -// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function +// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function, +// providing some custom error handling (with more informative messages), +// and a slightly different token serialization format. func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) { ctx := r.Context() @@ -142,32 +156,43 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro return nil, gtserror.NewErrorBadRequest(err, help, adv) } + // Get access token + do our own nicer error handling. ti, err := s.server.GetAccessToken(ctx, gt, tgr) - if err != nil { - help := fmt.Sprintf("could not get access token: %s", err) + switch { + case err == nil: + // No problem. + break + + case errors.Is(err, oautherr.ErrInvalidScope): + help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope) + return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice) + + case errors.Is(err, oautherr.ErrInvalidRedirectURI): + help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI) + return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice) + + default: + help := fmt.Sprintf("could not get access token: %v", err) return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice) } + // Wrangle data a bit. data := s.server.GetTokenData(ti) + // Add created_at for Mastodon API compatibility. + data["created_at"] = ti.GetAccessCreateAt().Unix() + + // If expires_in is 0 or less, omit it + // from serialization so that clients don't + // interpret the token as already expired. if expiresInI, ok := data["expires_in"]; ok { - switch expiresIn := expiresInI.(type) { - case int64: - // remove this key from the returned map - // if the value is 0 or less, so that clients - // don't interpret the token as already expired - if expiresIn <= 0 { - delete(data, "expires_in") - } - default: - err := errors.New("expires_in was set on token response, but was not an int64") - return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice) + // This will panic if expiresIn is + // not an int64, which is what we want. + if expiresInI.(int64) <= 0 { + delete(data, "expires_in") } } - // add this for mastodon api compatibility - data["created_at"] = ti.GetAccessCreateAt().Unix() - return data, nil } @@ -207,7 +232,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser } req.UserID = userID - // specify the scope of authorization + // Specify the scope of authorization. if fn := s.server.AuthorizeScopeHandler; fn != nil { scope, err := fn(w, r) if err != nil { @@ -217,7 +242,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser } } - // specify the expiration time of access token + // Specify the expiration time of access token. if fn := s.server.AccessTokenExpHandler; fn != nil { exp, err := fn(w, r) if err != nil { @@ -231,13 +256,24 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser return s.errorOrRedirect(err, w, req) } - // If the redirect URI is empty, the default domain provided by the client is used. + // If the redirect URI is empty, use the + // first of the client's redirect URIs. if req.RedirectURI == "" { client, err := s.server.Manager.GetClient(ctx, req.ClientID) - if err != nil { + if err != nil && !errors.Is(err, db.ErrNoEntries) { + // Real error. + err := gtserror.Newf("db error getting application with client id %s: %w", req.ClientID, err) + return gtserror.NewErrorInternalError(err) + } + + if util.IsNil(client) { + // Application just not found. return gtserror.NewErrorUnauthorized(err, HelpfulAdvice) } - req.RedirectURI = client.GetDomain() + + // This will panic if client is not a + // *gtsmodel.Application, which is what we want. + req.RedirectURI = client.(*gtsmodel.Application).RedirectURIs[0] } uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti)) diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go index df2e419fe..8c6506fa3 100644 --- a/internal/oauth/tokenstore.go +++ b/internal/oauth/tokenstore.go @@ -22,30 +22,32 @@ import ( "errors" "time" + "codeberg.org/gruf/go-mutexes" "codeberg.org/superseriousbusiness/oauth2/v4" "codeberg.org/superseriousbusiness/oauth2/v4/models" - "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" ) // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. type tokenStore struct { oauth2.TokenStore - db db.DB + state *state.State + lastUsedLocks mutexes.MutexMap } // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. // // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through // the tokens in the DB once per minute and deletes any that have expired. -func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore { - ts := &tokenStore{ - db: db, - } +func newTokenStore(ctx context.Context, state *state.State) oauth2.TokenStore { + ts := &tokenStore{state: state} - // set the token store to clean out expired tokens once per minute, or return if we're done + // Set the token store to clean out expired tokens + // once per minute, or return if we're done. go func(ctx context.Context, ts *tokenStore) { cleanloop: for { @@ -64,25 +66,48 @@ func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore { return ts } -// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. +// sweep clears out old tokens that have expired; +// it should be run on a loop about once per minute or so. func (ts *tokenStore) sweep(ctx context.Context) error { - // select *all* tokens from the db - // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. - tokens, err := ts.db.GetAllTokens(ctx) + // Select *all* tokens from the db + // + // TODO: if this becomes expensive + // (ie., there are fucking LOADS of + // tokens) then figure out a better way. + tokens, err := ts.state.DB.GetAllTokens(ctx) if err != nil { return err } - // iterate through and remove expired tokens + // Remove any expired tokens, bearing + // in mind that zero time = no expiry. now := time.Now() - for _, dbt := range tokens { - // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: - // we only want to check if a token expired before now if the expiry time is *not zero*; - // ie., if it's been explicity set. - if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { - if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil { - return err - } + for _, token := range tokens { + var expired bool + + switch { + case !token.CodeExpiresAt.IsZero() && token.CodeExpiresAt.Before(now): + log.Tracef(ctx, "code token %s is expired", token.ID) + expired = true + + case !token.RefreshExpiresAt.IsZero() && token.RefreshExpiresAt.Before(now): + log.Tracef(ctx, "refresh token %s is expired", token.ID) + expired = true + + case !token.AccessExpiresAt.IsZero() && token.AccessExpiresAt.Before(now): + log.Tracef(ctx, "access token %s is expired", token.ID) + expired = true + } + + if !expired { + // Token's + // still good. + continue + } + + if err := ts.state.DB.DeleteTokenByID(ctx, token.ID); err != nil { + err := gtserror.Newf("db error expiring token %s: %w", token.ID, err) + return err } } @@ -90,7 +115,6 @@ func (ts *tokenStore) sweep(ctx context.Context) error { } // Create creates and store the new token information. -// For the original implementation, see https://codeberg.org/superseriousbusiness/oauth2/blob/master/store/token.go#L34 func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { t, ok := info.(*models.Token) if !ok { @@ -99,55 +123,99 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { dbt := TokenToDBToken(t) if dbt.ID == "" { - dbtID, err := id.NewRandomULID() - if err != nil { - return err - } - dbt.ID = dbtID + dbt.ID = id.NewULID() } - return ts.db.PutToken(ctx, dbt) + return ts.state.DB.PutToken(ctx, dbt) } // RemoveByCode deletes a token from the DB based on the Code field func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { - return ts.db.DeleteTokenByCode(ctx, code) + return ts.state.DB.DeleteTokenByCode(ctx, code) } // RemoveByAccess deletes a token from the DB based on the Access field func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { - return ts.db.DeleteTokenByAccess(ctx, access) + return ts.state.DB.DeleteTokenByAccess(ctx, access) } // RemoveByRefresh deletes a token from the DB based on the Refresh field func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - return ts.db.DeleteTokenByRefresh(ctx, refresh) + return ts.state.DB.DeleteTokenByRefresh(ctx, refresh) } -// GetByCode selects a token from the DB based on the Code field -func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - token, err := ts.db.GetTokenByCode(ctx, code) - if err != nil { - return nil, err - } - return DBTokenToToken(token), nil +// GetByCode selects a token from +// the DB based on the Code field +func (ts *tokenStore) GetByCode( + ctx context.Context, + code string, +) (oauth2.TokenInfo, error) { + return ts.getUpdateToken( + ctx, + ts.state.DB.GetTokenByCode, + code, + ) } -// GetByAccess selects a token from the DB based on the Access field -func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { - token, err := ts.db.GetTokenByAccess(ctx, access) - if err != nil { - return nil, err - } - return DBTokenToToken(token), nil +// GetByAccess selects a token from +// the DB based on the Access field. +func (ts *tokenStore) GetByAccess( + ctx context.Context, + access string, +) (oauth2.TokenInfo, error) { + return ts.getUpdateToken( + ctx, + ts.state.DB.GetTokenByAccess, + access, + ) } -// GetByRefresh selects a token from the DB based on the Refresh field -func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - token, err := ts.db.GetTokenByRefresh(ctx, refresh) +// GetByRefresh selects a token from +// the DB based on the Refresh field +func (ts *tokenStore) GetByRefresh( + ctx context.Context, + refresh string, +) (oauth2.TokenInfo, error) { + return ts.getUpdateToken( + ctx, + ts.state.DB.GetTokenByRefresh, + refresh, + ) +} + +// package-internal function for getting a token +// and potentially updating its last_used value. +func (ts *tokenStore) getUpdateToken( + ctx context.Context, + getBy func(context.Context, string) (*gtsmodel.Token, error), + key string, +) (oauth2.TokenInfo, error) { + // Hold a lock to get the token based on + // whatever func + key we've been given. + unlock := ts.lastUsedLocks.Lock(key) + + token, err := getBy(ctx, key) if err != nil { + // Unlock on error. + unlock() return nil, err } + + // If token was last used more than + // an hour ago, update this in the db. + wasLastUsed := token.LastUsed + if now := time.Now(); now.Sub(wasLastUsed) > 1*time.Hour { + token.LastUsed = now + if err := ts.state.DB.UpdateToken(ctx, token, "last_used"); err != nil { + // Unlock on error. + unlock() + err := gtserror.Newf("error updating last_used on token: %w", err) + return nil, err + } + } + + // We're done, unlock. + unlock() return DBTokenToToken(token), nil } |
