diff options
Diffstat (limited to 'vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go')
-rw-r--r-- | vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go | 47 |
1 files changed, 23 insertions, 24 deletions
diff --git a/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go b/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go index 0aac66ffc..46bf23822 100644 --- a/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go +++ b/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go @@ -155,7 +155,6 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, clientID := r.FormValue("client_id") if !(r.Method == "GET" || r.Method == "POST") || clientID == "" { - fmt.Println(r.Method, clientID, r) return nil, errors.ErrInvalidRequest } @@ -213,9 +212,18 @@ func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) ( } } + tgr := &oauth2.TokenGenerateRequest{ + ClientID: req.ClientID, + UserID: req.UserID, + RedirectURI: req.RedirectURI, + Scope: req.Scope, + AccessTokenExp: req.AccessTokenExp, + Request: req.Request, + } + // check the client allows the authorized scope if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(req.ClientID, req.Scope) + allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { @@ -223,16 +231,9 @@ func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) ( } } - tgr := &oauth2.TokenGenerateRequest{ - ClientID: req.ClientID, - UserID: req.UserID, - RedirectURI: req.RedirectURI, - Scope: req.Scope, - AccessTokenExp: req.AccessTokenExp, - Request: req.Request, - CodeChallenge: req.CodeChallenge, - CodeChallengeMethod: req.CodeChallengeMethod, - } + tgr.CodeChallenge = req.CodeChallenge + tgr.CodeChallengeMethod = req.CodeChallengeMethod + return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) } @@ -312,11 +313,6 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau return "", nil, errors.ErrUnsupportedGrantType } - codeVer := r.FormValue("code_verifier") - if s.Config.ForcePKCE && codeVer == "" { - return "", nil, errors.ErrInvalidRequest - } - clientID, clientSecret, err := s.ClientInfoHandler(r) if err != nil { return "", nil, err @@ -336,7 +332,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau tgr.Code == "" { return "", nil, errors.ErrInvalidRequest } - tgr.CodeVerifier = codeVer + tgr.CodeVerifier = r.FormValue("code_verifier") + if s.Config.ForcePKCE && tgr.CodeVerifier == "" { + return "", nil, errors.ErrInvalidRequest + } case oauth2.PasswordCredentials: tgr.Scope = r.FormValue("scope") username, password := r.FormValue("username"), r.FormValue("password") @@ -374,7 +373,8 @@ func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { } // GetAccessToken access token -func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { +func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, + error) { if allowed := s.CheckGrantType(gt); !allowed { return nil, errors.ErrUnauthorizedClient } @@ -393,8 +393,7 @@ func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *o ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) if err != nil { switch err { - case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, - errors.ErrMissingCodeChallenge, errors.ErrMissingCodeChallenge: + case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge: return nil, errors.ErrInvalidGrant case errors.ErrInvalidClient: return nil, errors.ErrInvalidClient @@ -405,7 +404,7 @@ func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *o return ti, nil case oauth2.PasswordCredentials, oauth2.ClientCredentials: if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr.ClientID, tgr.Scope) + allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { @@ -415,7 +414,7 @@ func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *o return s.Manager.GenerateAccessToken(ctx, gt, tgr) case oauth2.Refreshing: // check scope - if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil { + if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { @@ -424,7 +423,7 @@ func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *o return nil, err } - allowed, err := scopeFn(scope, rti.GetScope()) + allowed, err := scopeFn(tgr, rti.GetScope()) if err != nil { return nil, err } else if !allowed { |