summaryrefslogtreecommitdiff
path: root/internal/oauth/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/oauth/server.go')
-rw-r--r--internal/oauth/server.go174
1 files changed, 105 insertions, 69 deletions
diff --git a/internal/oauth/server.go b/internal/oauth/server.go
index 8330ee179..c0c3c329c 100644
--- a/internal/oauth/server.go
+++ b/internal/oauth/server.go
@@ -30,7 +30,10 @@ import (
"codeberg.org/superseriousbusiness/oauth2/v4/server"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
)
const (
@@ -60,7 +63,8 @@ const (
HelpfulAdviceGrant = "If you arrived at this error during a sign in/oauth flow, your client is trying to use an unsupported OAuth grant type. Supported grant types are: authorization_code, client_credentials; please reach out to developer of your client"
)
-// Server wraps some oauth2 server functions in an interface, exposing only what is needed
+// Server wraps some oauth2 server functions
+// in an interface, exposing only what is needed.
type Server interface {
HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode)
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode
@@ -69,66 +73,76 @@ type Server interface {
LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error)
}
-// s fulfils the Server interface using the underlying oauth2 server
+// s fulfils the Server interface
+// using the underlying oauth2 server.
type s struct {
server *server.Server
}
// New returns a new oauth server that implements the Server interface
-func New(ctx context.Context, database db.DB) Server {
- ts := newTokenStore(ctx, database)
- cs := NewClientStore(database)
-
+func New(
+ ctx context.Context,
+ state *state.State,
+ validateURIHandler manage.ValidateURIHandler,
+ clientScopeHandler server.ClientScopeHandler,
+ authorizeScopeHandler server.AuthorizeScopeHandler,
+ internalErrorHandler server.InternalErrorHandler,
+ responseErrorHandler server.ResponseErrorHandler,
+ userAuthorizationHandler server.UserAuthorizationHandler,
+) Server {
+ ts := newTokenStore(ctx, state)
+ cs := NewClientStore(state)
+
+ // Set up OAuth2 manager.
manager := manage.NewDefaultManager()
+ manager.SetValidateURIHandler(validateURIHandler)
manager.MapTokenStorage(ts)
manager.MapClientStorage(cs)
- manager.SetAuthorizeCodeTokenCfg(&manage.Config{
- AccessTokenExp: 0, // access tokens don't expire -- they must be revoked
- IsGenerateRefresh: false, // don't use refresh tokens
- })
- sc := &server.Config{
- TokenType: "Bearer",
- // Must follow the spec.
- AllowGetAccessRequest: false,
- // Support only the non-implicit flow.
- AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
- // Allow:
- // - Authorization Code (for first & third parties)
- // - Client Credentials (for applications)
- AllowedGrantTypes: []oauth2.GrantType{
- oauth2.AuthorizationCode,
- oauth2.ClientCredentials,
+ manager.SetAuthorizeCodeTokenCfg(
+ &manage.Config{
+ // Following the Mastodon API,
+ // access tokens don't expire.
+ AccessTokenExp: 0,
+ // Don't use refresh tokens.
+ IsGenerateRefresh: false,
},
- AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
- oauth2.CodeChallengePlain,
- oauth2.CodeChallengeS256,
+ )
+
+ // Set up OAuth2 server.
+ srv := server.NewServer(
+ &server.Config{
+ TokenType: "Bearer",
+ // Must follow the spec.
+ AllowGetAccessRequest: false,
+ // Support only the non-implicit flow.
+ AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
+ // Allow:
+ // - Authorization Code (for first & third parties)
+ // - Client Credentials (for applications)
+ AllowedGrantTypes: []oauth2.GrantType{
+ oauth2.AuthorizationCode,
+ oauth2.ClientCredentials,
+ },
+ AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
+ oauth2.CodeChallengePlain,
+ oauth2.CodeChallengeS256,
+ },
},
- }
-
- srv := server.NewServer(sc, manager)
- srv.SetInternalErrorHandler(func(err error) *oautherr.Response {
- log.Errorf(nil, "internal oauth error: %s", err)
- return nil
- })
-
- srv.SetResponseErrorHandler(func(re *oautherr.Response) {
- log.Errorf(nil, "internal response error: %s", re.Error)
- })
-
- srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
- userID := r.FormValue("userid")
- if userID == "" {
- return "", errors.New("userid was empty")
- }
- return userID, nil
- })
+ manager,
+ )
+ srv.SetAuthorizeScopeHandler(authorizeScopeHandler)
+ srv.SetClientScopeHandler(clientScopeHandler)
+ srv.SetInternalErrorHandler(internalErrorHandler)
+ srv.SetResponseErrorHandler(responseErrorHandler)
+ srv.SetUserAuthorizationHandler(userAuthorizationHandler)
srv.SetClientInfoHandler(server.ClientFormHandler)
- return &s{
- server: srv,
- }
+
+ return &s{srv}
}
-// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
+// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function,
+// providing some custom error handling (with more informative messages),
+// and a slightly different token serialization format.
func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) {
ctx := r.Context()
@@ -142,32 +156,43 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
return nil, gtserror.NewErrorBadRequest(err, help, adv)
}
+ // Get access token + do our own nicer error handling.
ti, err := s.server.GetAccessToken(ctx, gt, tgr)
- if err != nil {
- help := fmt.Sprintf("could not get access token: %s", err)
+ switch {
+ case err == nil:
+ // No problem.
+ break
+
+ case errors.Is(err, oautherr.ErrInvalidScope):
+ help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope)
+ return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
+
+ case errors.Is(err, oautherr.ErrInvalidRedirectURI):
+ help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI)
+ return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
+
+ default:
+ help := fmt.Sprintf("could not get access token: %v", err)
return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
}
+ // Wrangle data a bit.
data := s.server.GetTokenData(ti)
+ // Add created_at for Mastodon API compatibility.
+ data["created_at"] = ti.GetAccessCreateAt().Unix()
+
+ // If expires_in is 0 or less, omit it
+ // from serialization so that clients don't
+ // interpret the token as already expired.
if expiresInI, ok := data["expires_in"]; ok {
- switch expiresIn := expiresInI.(type) {
- case int64:
- // remove this key from the returned map
- // if the value is 0 or less, so that clients
- // don't interpret the token as already expired
- if expiresIn <= 0 {
- delete(data, "expires_in")
- }
- default:
- err := errors.New("expires_in was set on token response, but was not an int64")
- return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice)
+ // This will panic if expiresIn is
+ // not an int64, which is what we want.
+ if expiresInI.(int64) <= 0 {
+ delete(data, "expires_in")
}
}
- // add this for mastodon api compatibility
- data["created_at"] = ti.GetAccessCreateAt().Unix()
-
return data, nil
}
@@ -207,7 +232,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
}
req.UserID = userID
- // specify the scope of authorization
+ // Specify the scope of authorization.
if fn := s.server.AuthorizeScopeHandler; fn != nil {
scope, err := fn(w, r)
if err != nil {
@@ -217,7 +242,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
}
}
- // specify the expiration time of access token
+ // Specify the expiration time of access token.
if fn := s.server.AccessTokenExpHandler; fn != nil {
exp, err := fn(w, r)
if err != nil {
@@ -231,13 +256,24 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
return s.errorOrRedirect(err, w, req)
}
- // If the redirect URI is empty, the default domain provided by the client is used.
+ // If the redirect URI is empty, use the
+ // first of the client's redirect URIs.
if req.RedirectURI == "" {
client, err := s.server.Manager.GetClient(ctx, req.ClientID)
- if err != nil {
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ // Real error.
+ err := gtserror.Newf("db error getting application with client id %s: %w", req.ClientID, err)
+ return gtserror.NewErrorInternalError(err)
+ }
+
+ if util.IsNil(client) {
+ // Application just not found.
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
}
- req.RedirectURI = client.GetDomain()
+
+ // This will panic if client is not a
+ // *gtsmodel.Application, which is what we want.
+ req.RedirectURI = client.(*gtsmodel.Application).RedirectURIs[0]
}
uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti))