diff options
| author | 2025-03-03 16:03:36 +0100 | |
|---|---|---|
| committer | 2025-03-03 15:03:36 +0000 | |
| commit | 1b37944f8b8eccc2afcfb0f603786209a3b7402d (patch) | |
| tree | 2bc0be27cf0405e16ac3e14efc3b6973eb096b8b /internal/oauth/server.go | |
| parent | bumps go-ffmpreg to v0.6.6 (#3866) (diff) | |
| download | gotosocial-1b37944f8b8eccc2afcfb0f603786209a3b7402d.tar.xz | |
[feature] Refactor tokens, allow multiple app redirect_uris (#3849)
* [feature] Refactor tokens, allow multiple app redirect_uris
* move + tweak handlers a bit
* return error for unset oauth2.ClientStore funcs
* wrap UpdateToken with cache
* panic handling
* cheeky little time optimization
* unlock on error
Diffstat (limited to 'internal/oauth/server.go')
| -rw-r--r-- | internal/oauth/server.go | 174 |
1 files changed, 105 insertions, 69 deletions
diff --git a/internal/oauth/server.go b/internal/oauth/server.go index 8330ee179..c0c3c329c 100644 --- a/internal/oauth/server.go +++ b/internal/oauth/server.go @@ -30,7 +30,10 @@ import ( "codeberg.org/superseriousbusiness/oauth2/v4/server" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" ) const ( @@ -60,7 +63,8 @@ const ( HelpfulAdviceGrant = "If you arrived at this error during a sign in/oauth flow, your client is trying to use an unsupported OAuth grant type. Supported grant types are: authorization_code, client_credentials; please reach out to developer of your client" ) -// Server wraps some oauth2 server functions in an interface, exposing only what is needed +// Server wraps some oauth2 server functions +// in an interface, exposing only what is needed. type Server interface { HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode @@ -69,66 +73,76 @@ type Server interface { LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error) } -// s fulfils the Server interface using the underlying oauth2 server +// s fulfils the Server interface +// using the underlying oauth2 server. type s struct { server *server.Server } // New returns a new oauth server that implements the Server interface -func New(ctx context.Context, database db.DB) Server { - ts := newTokenStore(ctx, database) - cs := NewClientStore(database) - +func New( + ctx context.Context, + state *state.State, + validateURIHandler manage.ValidateURIHandler, + clientScopeHandler server.ClientScopeHandler, + authorizeScopeHandler server.AuthorizeScopeHandler, + internalErrorHandler server.InternalErrorHandler, + responseErrorHandler server.ResponseErrorHandler, + userAuthorizationHandler server.UserAuthorizationHandler, +) Server { + ts := newTokenStore(ctx, state) + cs := NewClientStore(state) + + // Set up OAuth2 manager. manager := manage.NewDefaultManager() + manager.SetValidateURIHandler(validateURIHandler) manager.MapTokenStorage(ts) manager.MapClientStorage(cs) - manager.SetAuthorizeCodeTokenCfg(&manage.Config{ - AccessTokenExp: 0, // access tokens don't expire -- they must be revoked - IsGenerateRefresh: false, // don't use refresh tokens - }) - sc := &server.Config{ - TokenType: "Bearer", - // Must follow the spec. - AllowGetAccessRequest: false, - // Support only the non-implicit flow. - AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, - // Allow: - // - Authorization Code (for first & third parties) - // - Client Credentials (for applications) - AllowedGrantTypes: []oauth2.GrantType{ - oauth2.AuthorizationCode, - oauth2.ClientCredentials, + manager.SetAuthorizeCodeTokenCfg( + &manage.Config{ + // Following the Mastodon API, + // access tokens don't expire. + AccessTokenExp: 0, + // Don't use refresh tokens. + IsGenerateRefresh: false, }, - AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{ - oauth2.CodeChallengePlain, - oauth2.CodeChallengeS256, + ) + + // Set up OAuth2 server. + srv := server.NewServer( + &server.Config{ + TokenType: "Bearer", + // Must follow the spec. + AllowGetAccessRequest: false, + // Support only the non-implicit flow. + AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, + // Allow: + // - Authorization Code (for first & third parties) + // - Client Credentials (for applications) + AllowedGrantTypes: []oauth2.GrantType{ + oauth2.AuthorizationCode, + oauth2.ClientCredentials, + }, + AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{ + oauth2.CodeChallengePlain, + oauth2.CodeChallengeS256, + }, }, - } - - srv := server.NewServer(sc, manager) - srv.SetInternalErrorHandler(func(err error) *oautherr.Response { - log.Errorf(nil, "internal oauth error: %s", err) - return nil - }) - - srv.SetResponseErrorHandler(func(re *oautherr.Response) { - log.Errorf(nil, "internal response error: %s", re.Error) - }) - - srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) { - userID := r.FormValue("userid") - if userID == "" { - return "", errors.New("userid was empty") - } - return userID, nil - }) + manager, + ) + srv.SetAuthorizeScopeHandler(authorizeScopeHandler) + srv.SetClientScopeHandler(clientScopeHandler) + srv.SetInternalErrorHandler(internalErrorHandler) + srv.SetResponseErrorHandler(responseErrorHandler) + srv.SetUserAuthorizationHandler(userAuthorizationHandler) srv.SetClientInfoHandler(server.ClientFormHandler) - return &s{ - server: srv, - } + + return &s{srv} } -// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function +// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function, +// providing some custom error handling (with more informative messages), +// and a slightly different token serialization format. func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) { ctx := r.Context() @@ -142,32 +156,43 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro return nil, gtserror.NewErrorBadRequest(err, help, adv) } + // Get access token + do our own nicer error handling. ti, err := s.server.GetAccessToken(ctx, gt, tgr) - if err != nil { - help := fmt.Sprintf("could not get access token: %s", err) + switch { + case err == nil: + // No problem. + break + + case errors.Is(err, oautherr.ErrInvalidScope): + help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope) + return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice) + + case errors.Is(err, oautherr.ErrInvalidRedirectURI): + help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI) + return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice) + + default: + help := fmt.Sprintf("could not get access token: %v", err) return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice) } + // Wrangle data a bit. data := s.server.GetTokenData(ti) + // Add created_at for Mastodon API compatibility. + data["created_at"] = ti.GetAccessCreateAt().Unix() + + // If expires_in is 0 or less, omit it + // from serialization so that clients don't + // interpret the token as already expired. if expiresInI, ok := data["expires_in"]; ok { - switch expiresIn := expiresInI.(type) { - case int64: - // remove this key from the returned map - // if the value is 0 or less, so that clients - // don't interpret the token as already expired - if expiresIn <= 0 { - delete(data, "expires_in") - } - default: - err := errors.New("expires_in was set on token response, but was not an int64") - return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice) + // This will panic if expiresIn is + // not an int64, which is what we want. + if expiresInI.(int64) <= 0 { + delete(data, "expires_in") } } - // add this for mastodon api compatibility - data["created_at"] = ti.GetAccessCreateAt().Unix() - return data, nil } @@ -207,7 +232,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser } req.UserID = userID - // specify the scope of authorization + // Specify the scope of authorization. if fn := s.server.AuthorizeScopeHandler; fn != nil { scope, err := fn(w, r) if err != nil { @@ -217,7 +242,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser } } - // specify the expiration time of access token + // Specify the expiration time of access token. if fn := s.server.AccessTokenExpHandler; fn != nil { exp, err := fn(w, r) if err != nil { @@ -231,13 +256,24 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser return s.errorOrRedirect(err, w, req) } - // If the redirect URI is empty, the default domain provided by the client is used. + // If the redirect URI is empty, use the + // first of the client's redirect URIs. if req.RedirectURI == "" { client, err := s.server.Manager.GetClient(ctx, req.ClientID) - if err != nil { + if err != nil && !errors.Is(err, db.ErrNoEntries) { + // Real error. + err := gtserror.Newf("db error getting application with client id %s: %w", req.ClientID, err) + return gtserror.NewErrorInternalError(err) + } + + if util.IsNil(client) { + // Application just not found. return gtserror.NewErrorUnauthorized(err, HelpfulAdvice) } - req.RedirectURI = client.GetDomain() + + // This will panic if client is not a + // *gtsmodel.Application, which is what we want. + req.RedirectURI = client.(*gtsmodel.Application).RedirectURIs[0] } uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti)) |
