summaryrefslogtreecommitdiff
path: root/internal/oauth
diff options
context:
space:
mode:
Diffstat (limited to 'internal/oauth')
-rw-r--r--internal/oauth/clientstore.go43
-rw-r--r--internal/oauth/clientstore_test.go77
-rw-r--r--internal/oauth/handlers/handlers.go153
-rw-r--r--internal/oauth/oauth_test.go20
-rw-r--r--internal/oauth/server.go174
-rw-r--r--internal/oauth/tokenstore.go162
6 files changed, 408 insertions, 221 deletions
diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go
index af48edac3..17de0e342 100644
--- a/internal/oauth/clientstore.go
+++ b/internal/oauth/clientstore.go
@@ -21,45 +21,30 @@ import (
"context"
"codeberg.org/superseriousbusiness/oauth2/v4"
- "codeberg.org/superseriousbusiness/oauth2/v4/models"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "codeberg.org/superseriousbusiness/oauth2/v4/errors"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
)
type clientStore struct {
- db db.DB
+ state *state.State
}
-// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend.
-func NewClientStore(db db.DB) oauth2.ClientStore {
- pts := &clientStore{
- db: db,
- }
- return pts
+// NewClientStore returns a minimal implementation of
+// oauth2.ClientStore interface, using state as storage.
+//
+// Only GetByID is implemented, Set and Delete are stubs.
+func NewClientStore(state *state.State) oauth2.ClientStore {
+ return &clientStore{state: state}
}
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
- client, err := cs.db.GetClientByID(ctx, clientID)
- if err != nil {
- return nil, err
- }
- return models.New(
- client.ID,
- client.Secret,
- client.Domain,
- client.UserID,
- ), nil
+ return cs.state.DB.GetApplicationByClientID(ctx, clientID)
}
-func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
- return cs.db.PutClient(ctx, &gtsmodel.Client{
- ID: cli.GetID(),
- Secret: cli.GetSecret(),
- Domain: cli.GetDomain(),
- UserID: cli.GetUserID(),
- })
+func (cs *clientStore) Set(_ context.Context, _ string, _ oauth2.ClientInfo) error {
+ return errors.New("func oauth2.ClientStore.Set not implemented")
}
-func (cs *clientStore) Delete(ctx context.Context, id string) error {
- return cs.db.DeleteClientByID(ctx, id)
+func (cs *clientStore) Delete(_ context.Context, _ string) error {
+ return errors.New("func oauth2.ClientStore.Delete not implemented")
}
diff --git a/internal/oauth/clientstore_test.go b/internal/oauth/clientstore_test.go
index 59b0ec1d3..c6621186a 100644
--- a/internal/oauth/clientstore_test.go
+++ b/internal/oauth/clientstore_test.go
@@ -21,93 +21,58 @@ import (
"context"
"testing"
- "codeberg.org/superseriousbusiness/oauth2/v4/models"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/admin"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig"
)
-type PgClientStoreTestSuite struct {
+type ClientStoreTestSuite struct {
suite.Suite
db db.DB
state state.State
- testClientID string
- testClientSecret string
- testClientDomain string
- testClientUserID string
+ testApplications map[string]*gtsmodel.Application
}
-// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
-func (suite *PgClientStoreTestSuite) SetupSuite() {
- suite.testClientID = "01FCVB74EW6YBYAEY7QG9CQQF6"
- suite.testClientSecret = "4cc87402-259b-4a35-9485-2c8bf54f3763"
- suite.testClientDomain = "https://example.org"
- suite.testClientUserID = "01FEGYXKVCDB731QF9MVFXA4F5"
+func (suite *ClientStoreTestSuite) SetupSuite() {
+ suite.testApplications = testrig.NewTestApplications()
}
-// SetupTest creates a postgres connection and creates the oauth_clients table before each test
-func (suite *PgClientStoreTestSuite) SetupTest() {
+func (suite *ClientStoreTestSuite) SetupTest() {
suite.state.Caches.Init()
- testrig.InitTestLog()
testrig.InitTestConfig()
+ testrig.InitTestLog()
suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.state.AdminActions = admin.New(suite.state.DB, &suite.state.Workers)
testrig.StandardDBSetup(suite.db, nil)
}
-// TearDownTest drops the oauth_clients table and closes the pg connection after each test
-func (suite *PgClientStoreTestSuite) TearDownTest() {
+func (suite *ClientStoreTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
-func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() {
- // set a new client in the store
- cs := oauth.NewClientStore(suite.db)
- if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
- suite.FailNow(err.Error())
- }
-
- // fetch that client from the store
- client, err := cs.GetByID(context.Background(), suite.testClientID)
- if err != nil {
- suite.FailNow(err.Error())
- }
-
- // check that the values are the same
- suite.NotNil(client)
- suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
-}
-
-func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
- // set a new client in the store
- cs := oauth.NewClientStore(suite.db)
- if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
- suite.FailNow(err.Error())
- }
+func (suite *ClientStoreTestSuite) TestClientStoreGet() {
+ testApp := suite.testApplications["application_1"]
+ cs := oauth.NewClientStore(&suite.state)
- // fetch the client from the store
- client, err := cs.GetByID(context.Background(), suite.testClientID)
+ // Fetch clientInfo from the store.
+ clientInfo, err := cs.GetByID(context.Background(), testApp.ClientID)
if err != nil {
suite.FailNow(err.Error())
}
- // check that the values are the same
- suite.NotNil(client)
- suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
- if err := cs.Delete(context.Background(), suite.testClientID); err != nil {
- suite.FailNow(err.Error())
- }
-
- // try to get the deleted client; we should get an error
- deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
- suite.Assert().Nil(deletedClient)
- suite.Assert().EqualValues(db.ErrNoEntries, err)
+ // Check expected values.
+ suite.NotNil(clientInfo)
+ suite.Equal(testApp.ClientID, clientInfo.GetID())
+ suite.Equal(testApp.ClientSecret, clientInfo.GetSecret())
+ suite.Equal(testApp.RedirectURIs[0], clientInfo.GetDomain())
+ suite.Equal(testApp.ManagedByUserID, clientInfo.GetUserID())
}
-func TestPgClientStoreTestSuite(t *testing.T) {
- suite.Run(t, new(PgClientStoreTestSuite))
+func TestClientStoreTestSuite(t *testing.T) {
+ suite.Run(t, new(ClientStoreTestSuite))
}
diff --git a/internal/oauth/handlers/handlers.go b/internal/oauth/handlers/handlers.go
new file mode 100644
index 000000000..f0af007f0
--- /dev/null
+++ b/internal/oauth/handlers/handlers.go
@@ -0,0 +1,153 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package handlers
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/url"
+ "slices"
+ "strings"
+
+ "codeberg.org/superseriousbusiness/oauth2/v4"
+ oautherr "codeberg.org/superseriousbusiness/oauth2/v4/errors"
+ "codeberg.org/superseriousbusiness/oauth2/v4/manage"
+ "codeberg.org/superseriousbusiness/oauth2/v4/server"
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+)
+
+// GetClientScopeHandler returns a handler for testing scope on a TokenGenerateRequest.
+func GetClientScopeHandler(ctx context.Context, state *state.State) server.ClientScopeHandler {
+ return func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) {
+ application, err := state.DB.GetApplicationByClientID(
+ gtscontext.SetBarebones(ctx),
+ tgr.ClientID,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ log.Errorf(ctx, "database error getting application: %v", err)
+ return false, err
+ }
+
+ if application == nil {
+ err := gtserror.Newf("no application found with client id %s", tgr.ClientID)
+ return false, err
+ }
+
+ // Normalize scope.
+ if strings.TrimSpace(tgr.Scope) == "" {
+ tgr.Scope = "read"
+ }
+
+ // Make sure requested scopes are all
+ // within scopes permitted by application.
+ hasScopes := strings.Split(application.Scopes, " ")
+ wantsScopes := strings.Split(tgr.Scope, " ")
+ for _, wantsScope := range wantsScopes {
+ thisOK := slices.ContainsFunc(
+ hasScopes,
+ func(hasScope string) bool {
+ has := apiutil.Scope(hasScope)
+ wants := apiutil.Scope(wantsScope)
+ return has.Permits(wants)
+ },
+ )
+
+ if !thisOK {
+ // Requested unpermitted
+ // scope for this app.
+ return false, nil
+ }
+ }
+
+ // All OK.
+ return true, nil
+ }
+}
+
+func GetValidateURIHandler(ctx context.Context) manage.ValidateURIHandler {
+ return func(hasRedirects string, wantsRedirect string) error {
+ // Normalize the wantsRedirect URI
+ // string by parsing + reserializing.
+ wantsRedirectURI, err := url.Parse(wantsRedirect)
+ if err != nil {
+ return err
+ }
+ wantsRedirect = wantsRedirectURI.String()
+
+ // Redirect URIs are given to us as
+ // a list of URIs, newline-separated.
+ //
+ // They're already normalized on input so
+ // we don't need to parse + reserialize them.
+ //
+ // Ensure that one of them matches.
+ if slices.ContainsFunc(
+ strings.Split(hasRedirects, "\n"),
+ func(hasRedirect string) bool {
+ // Want an exact match.
+ // See: https://www.oauth.com/oauth2-servers/redirect-uris/redirect-uri-validation/
+ return wantsRedirect == hasRedirect
+ },
+ ) {
+ return nil
+ }
+
+ return oautherr.ErrInvalidRedirectURI
+ }
+}
+
+func GetAuthorizeScopeHandler() server.AuthorizeScopeHandler {
+ return func(_ http.ResponseWriter, r *http.Request) (string, error) {
+ // Use provided scope or
+ // fall back to default "read".
+ scope := r.FormValue("scope")
+ if strings.TrimSpace(scope) == "" {
+ scope = "read"
+ }
+ return scope, nil
+ }
+}
+
+func GetInternalErrorHandler(ctx context.Context) server.InternalErrorHandler {
+ return func(err error) *oautherr.Response {
+ log.Errorf(ctx, "internal oauth error: %v", err)
+ return nil
+ }
+}
+
+func GetResponseErrorHandler(ctx context.Context) server.ResponseErrorHandler {
+ return func(re *oautherr.Response) {
+ log.Errorf(ctx, "internal response error: %v", re.Error)
+ }
+}
+
+func GetUserAuthorizationHandler() server.UserAuthorizationHandler {
+ return func(w http.ResponseWriter, r *http.Request) (string, error) {
+ userID := r.FormValue("userid")
+ if userID == "" {
+ return "", errors.New("userid was empty")
+ }
+ return userID, nil
+ }
+}
diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go
deleted file mode 100644
index 2b76024f7..000000000
--- a/internal/oauth/oauth_test.go
+++ /dev/null
@@ -1,20 +0,0 @@
-// GoToSocial
-// Copyright (C) GoToSocial Authors admin@gotosocial.org
-// SPDX-License-Identifier: AGPL-3.0-or-later
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-package oauth_test
-
-// TODO: write tests
diff --git a/internal/oauth/server.go b/internal/oauth/server.go
index 8330ee179..c0c3c329c 100644
--- a/internal/oauth/server.go
+++ b/internal/oauth/server.go
@@ -30,7 +30,10 @@ import (
"codeberg.org/superseriousbusiness/oauth2/v4/server"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
)
const (
@@ -60,7 +63,8 @@ const (
HelpfulAdviceGrant = "If you arrived at this error during a sign in/oauth flow, your client is trying to use an unsupported OAuth grant type. Supported grant types are: authorization_code, client_credentials; please reach out to developer of your client"
)
-// Server wraps some oauth2 server functions in an interface, exposing only what is needed
+// Server wraps some oauth2 server functions
+// in an interface, exposing only what is needed.
type Server interface {
HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode)
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode
@@ -69,66 +73,76 @@ type Server interface {
LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error)
}
-// s fulfils the Server interface using the underlying oauth2 server
+// s fulfils the Server interface
+// using the underlying oauth2 server.
type s struct {
server *server.Server
}
// New returns a new oauth server that implements the Server interface
-func New(ctx context.Context, database db.DB) Server {
- ts := newTokenStore(ctx, database)
- cs := NewClientStore(database)
-
+func New(
+ ctx context.Context,
+ state *state.State,
+ validateURIHandler manage.ValidateURIHandler,
+ clientScopeHandler server.ClientScopeHandler,
+ authorizeScopeHandler server.AuthorizeScopeHandler,
+ internalErrorHandler server.InternalErrorHandler,
+ responseErrorHandler server.ResponseErrorHandler,
+ userAuthorizationHandler server.UserAuthorizationHandler,
+) Server {
+ ts := newTokenStore(ctx, state)
+ cs := NewClientStore(state)
+
+ // Set up OAuth2 manager.
manager := manage.NewDefaultManager()
+ manager.SetValidateURIHandler(validateURIHandler)
manager.MapTokenStorage(ts)
manager.MapClientStorage(cs)
- manager.SetAuthorizeCodeTokenCfg(&manage.Config{
- AccessTokenExp: 0, // access tokens don't expire -- they must be revoked
- IsGenerateRefresh: false, // don't use refresh tokens
- })
- sc := &server.Config{
- TokenType: "Bearer",
- // Must follow the spec.
- AllowGetAccessRequest: false,
- // Support only the non-implicit flow.
- AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
- // Allow:
- // - Authorization Code (for first & third parties)
- // - Client Credentials (for applications)
- AllowedGrantTypes: []oauth2.GrantType{
- oauth2.AuthorizationCode,
- oauth2.ClientCredentials,
+ manager.SetAuthorizeCodeTokenCfg(
+ &manage.Config{
+ // Following the Mastodon API,
+ // access tokens don't expire.
+ AccessTokenExp: 0,
+ // Don't use refresh tokens.
+ IsGenerateRefresh: false,
},
- AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
- oauth2.CodeChallengePlain,
- oauth2.CodeChallengeS256,
+ )
+
+ // Set up OAuth2 server.
+ srv := server.NewServer(
+ &server.Config{
+ TokenType: "Bearer",
+ // Must follow the spec.
+ AllowGetAccessRequest: false,
+ // Support only the non-implicit flow.
+ AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
+ // Allow:
+ // - Authorization Code (for first & third parties)
+ // - Client Credentials (for applications)
+ AllowedGrantTypes: []oauth2.GrantType{
+ oauth2.AuthorizationCode,
+ oauth2.ClientCredentials,
+ },
+ AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
+ oauth2.CodeChallengePlain,
+ oauth2.CodeChallengeS256,
+ },
},
- }
-
- srv := server.NewServer(sc, manager)
- srv.SetInternalErrorHandler(func(err error) *oautherr.Response {
- log.Errorf(nil, "internal oauth error: %s", err)
- return nil
- })
-
- srv.SetResponseErrorHandler(func(re *oautherr.Response) {
- log.Errorf(nil, "internal response error: %s", re.Error)
- })
-
- srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
- userID := r.FormValue("userid")
- if userID == "" {
- return "", errors.New("userid was empty")
- }
- return userID, nil
- })
+ manager,
+ )
+ srv.SetAuthorizeScopeHandler(authorizeScopeHandler)
+ srv.SetClientScopeHandler(clientScopeHandler)
+ srv.SetInternalErrorHandler(internalErrorHandler)
+ srv.SetResponseErrorHandler(responseErrorHandler)
+ srv.SetUserAuthorizationHandler(userAuthorizationHandler)
srv.SetClientInfoHandler(server.ClientFormHandler)
- return &s{
- server: srv,
- }
+
+ return &s{srv}
}
-// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
+// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function,
+// providing some custom error handling (with more informative messages),
+// and a slightly different token serialization format.
func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) {
ctx := r.Context()
@@ -142,32 +156,43 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
return nil, gtserror.NewErrorBadRequest(err, help, adv)
}
+ // Get access token + do our own nicer error handling.
ti, err := s.server.GetAccessToken(ctx, gt, tgr)
- if err != nil {
- help := fmt.Sprintf("could not get access token: %s", err)
+ switch {
+ case err == nil:
+ // No problem.
+ break
+
+ case errors.Is(err, oautherr.ErrInvalidScope):
+ help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope)
+ return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
+
+ case errors.Is(err, oautherr.ErrInvalidRedirectURI):
+ help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI)
+ return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
+
+ default:
+ help := fmt.Sprintf("could not get access token: %v", err)
return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
}
+ // Wrangle data a bit.
data := s.server.GetTokenData(ti)
+ // Add created_at for Mastodon API compatibility.
+ data["created_at"] = ti.GetAccessCreateAt().Unix()
+
+ // If expires_in is 0 or less, omit it
+ // from serialization so that clients don't
+ // interpret the token as already expired.
if expiresInI, ok := data["expires_in"]; ok {
- switch expiresIn := expiresInI.(type) {
- case int64:
- // remove this key from the returned map
- // if the value is 0 or less, so that clients
- // don't interpret the token as already expired
- if expiresIn <= 0 {
- delete(data, "expires_in")
- }
- default:
- err := errors.New("expires_in was set on token response, but was not an int64")
- return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice)
+ // This will panic if expiresIn is
+ // not an int64, which is what we want.
+ if expiresInI.(int64) <= 0 {
+ delete(data, "expires_in")
}
}
- // add this for mastodon api compatibility
- data["created_at"] = ti.GetAccessCreateAt().Unix()
-
return data, nil
}
@@ -207,7 +232,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
}
req.UserID = userID
- // specify the scope of authorization
+ // Specify the scope of authorization.
if fn := s.server.AuthorizeScopeHandler; fn != nil {
scope, err := fn(w, r)
if err != nil {
@@ -217,7 +242,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
}
}
- // specify the expiration time of access token
+ // Specify the expiration time of access token.
if fn := s.server.AccessTokenExpHandler; fn != nil {
exp, err := fn(w, r)
if err != nil {
@@ -231,13 +256,24 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
return s.errorOrRedirect(err, w, req)
}
- // If the redirect URI is empty, the default domain provided by the client is used.
+ // If the redirect URI is empty, use the
+ // first of the client's redirect URIs.
if req.RedirectURI == "" {
client, err := s.server.Manager.GetClient(ctx, req.ClientID)
- if err != nil {
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ // Real error.
+ err := gtserror.Newf("db error getting application with client id %s: %w", req.ClientID, err)
+ return gtserror.NewErrorInternalError(err)
+ }
+
+ if util.IsNil(client) {
+ // Application just not found.
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
}
- req.RedirectURI = client.GetDomain()
+
+ // This will panic if client is not a
+ // *gtsmodel.Application, which is what we want.
+ req.RedirectURI = client.(*gtsmodel.Application).RedirectURIs[0]
}
uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti))
diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go
index df2e419fe..8c6506fa3 100644
--- a/internal/oauth/tokenstore.go
+++ b/internal/oauth/tokenstore.go
@@ -22,30 +22,32 @@ import (
"errors"
"time"
+ "codeberg.org/gruf/go-mutexes"
"codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/models"
- "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
)
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
type tokenStore struct {
oauth2.TokenStore
- db db.DB
+ state *state.State
+ lastUsedLocks mutexes.MutexMap
}
// newTokenStore returns a token store that satisfies the oauth2.TokenStore interface.
//
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
-func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
- ts := &tokenStore{
- db: db,
- }
+func newTokenStore(ctx context.Context, state *state.State) oauth2.TokenStore {
+ ts := &tokenStore{state: state}
- // set the token store to clean out expired tokens once per minute, or return if we're done
+ // Set the token store to clean out expired tokens
+ // once per minute, or return if we're done.
go func(ctx context.Context, ts *tokenStore) {
cleanloop:
for {
@@ -64,25 +66,48 @@ func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
return ts
}
-// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
+// sweep clears out old tokens that have expired;
+// it should be run on a loop about once per minute or so.
func (ts *tokenStore) sweep(ctx context.Context) error {
- // select *all* tokens from the db
- // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
- tokens, err := ts.db.GetAllTokens(ctx)
+ // Select *all* tokens from the db
+ //
+ // TODO: if this becomes expensive
+ // (ie., there are fucking LOADS of
+ // tokens) then figure out a better way.
+ tokens, err := ts.state.DB.GetAllTokens(ctx)
if err != nil {
return err
}
- // iterate through and remove expired tokens
+ // Remove any expired tokens, bearing
+ // in mind that zero time = no expiry.
now := time.Now()
- for _, dbt := range tokens {
- // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
- // we only want to check if a token expired before now if the expiry time is *not zero*;
- // ie., if it's been explicity set.
- if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
- if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil {
- return err
- }
+ for _, token := range tokens {
+ var expired bool
+
+ switch {
+ case !token.CodeExpiresAt.IsZero() && token.CodeExpiresAt.Before(now):
+ log.Tracef(ctx, "code token %s is expired", token.ID)
+ expired = true
+
+ case !token.RefreshExpiresAt.IsZero() && token.RefreshExpiresAt.Before(now):
+ log.Tracef(ctx, "refresh token %s is expired", token.ID)
+ expired = true
+
+ case !token.AccessExpiresAt.IsZero() && token.AccessExpiresAt.Before(now):
+ log.Tracef(ctx, "access token %s is expired", token.ID)
+ expired = true
+ }
+
+ if !expired {
+ // Token's
+ // still good.
+ continue
+ }
+
+ if err := ts.state.DB.DeleteTokenByID(ctx, token.ID); err != nil {
+ err := gtserror.Newf("db error expiring token %s: %w", token.ID, err)
+ return err
}
}
@@ -90,7 +115,6 @@ func (ts *tokenStore) sweep(ctx context.Context) error {
}
// Create creates and store the new token information.
-// For the original implementation, see https://codeberg.org/superseriousbusiness/oauth2/blob/master/store/token.go#L34
func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
t, ok := info.(*models.Token)
if !ok {
@@ -99,55 +123,99 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
dbt := TokenToDBToken(t)
if dbt.ID == "" {
- dbtID, err := id.NewRandomULID()
- if err != nil {
- return err
- }
- dbt.ID = dbtID
+ dbt.ID = id.NewULID()
}
- return ts.db.PutToken(ctx, dbt)
+ return ts.state.DB.PutToken(ctx, dbt)
}
// RemoveByCode deletes a token from the DB based on the Code field
func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
- return ts.db.DeleteTokenByCode(ctx, code)
+ return ts.state.DB.DeleteTokenByCode(ctx, code)
}
// RemoveByAccess deletes a token from the DB based on the Access field
func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
- return ts.db.DeleteTokenByAccess(ctx, access)
+ return ts.state.DB.DeleteTokenByAccess(ctx, access)
}
// RemoveByRefresh deletes a token from the DB based on the Refresh field
func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
- return ts.db.DeleteTokenByRefresh(ctx, refresh)
+ return ts.state.DB.DeleteTokenByRefresh(ctx, refresh)
}
-// GetByCode selects a token from the DB based on the Code field
-func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
- token, err := ts.db.GetTokenByCode(ctx, code)
- if err != nil {
- return nil, err
- }
- return DBTokenToToken(token), nil
+// GetByCode selects a token from
+// the DB based on the Code field
+func (ts *tokenStore) GetByCode(
+ ctx context.Context,
+ code string,
+) (oauth2.TokenInfo, error) {
+ return ts.getUpdateToken(
+ ctx,
+ ts.state.DB.GetTokenByCode,
+ code,
+ )
}
-// GetByAccess selects a token from the DB based on the Access field
-func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
- token, err := ts.db.GetTokenByAccess(ctx, access)
- if err != nil {
- return nil, err
- }
- return DBTokenToToken(token), nil
+// GetByAccess selects a token from
+// the DB based on the Access field.
+func (ts *tokenStore) GetByAccess(
+ ctx context.Context,
+ access string,
+) (oauth2.TokenInfo, error) {
+ return ts.getUpdateToken(
+ ctx,
+ ts.state.DB.GetTokenByAccess,
+ access,
+ )
}
-// GetByRefresh selects a token from the DB based on the Refresh field
-func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
- token, err := ts.db.GetTokenByRefresh(ctx, refresh)
+// GetByRefresh selects a token from
+// the DB based on the Refresh field
+func (ts *tokenStore) GetByRefresh(
+ ctx context.Context,
+ refresh string,
+) (oauth2.TokenInfo, error) {
+ return ts.getUpdateToken(
+ ctx,
+ ts.state.DB.GetTokenByRefresh,
+ refresh,
+ )
+}
+
+// package-internal function for getting a token
+// and potentially updating its last_used value.
+func (ts *tokenStore) getUpdateToken(
+ ctx context.Context,
+ getBy func(context.Context, string) (*gtsmodel.Token, error),
+ key string,
+) (oauth2.TokenInfo, error) {
+ // Hold a lock to get the token based on
+ // whatever func + key we've been given.
+ unlock := ts.lastUsedLocks.Lock(key)
+
+ token, err := getBy(ctx, key)
if err != nil {
+ // Unlock on error.
+ unlock()
return nil, err
}
+
+ // If token was last used more than
+ // an hour ago, update this in the db.
+ wasLastUsed := token.LastUsed
+ if now := time.Now(); now.Sub(wasLastUsed) > 1*time.Hour {
+ token.LastUsed = now
+ if err := ts.state.DB.UpdateToken(ctx, token, "last_used"); err != nil {
+ // Unlock on error.
+ unlock()
+ err := gtserror.Newf("error updating last_used on token: %w", err)
+ return nil, err
+ }
+ }
+
+ // We're done, unlock.
+ unlock()
return DBTokenToToken(token), nil
}