summaryrefslogtreecommitdiff
path: root/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go')
-rw-r--r--vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go47
1 files changed, 23 insertions, 24 deletions
diff --git a/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go b/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go
index 0aac66ffc..46bf23822 100644
--- a/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go
+++ b/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go
@@ -155,7 +155,6 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest,
clientID := r.FormValue("client_id")
if !(r.Method == "GET" || r.Method == "POST") ||
clientID == "" {
- fmt.Println(r.Method, clientID, r)
return nil, errors.ErrInvalidRequest
}
@@ -213,9 +212,18 @@ func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (
}
}
+ tgr := &oauth2.TokenGenerateRequest{
+ ClientID: req.ClientID,
+ UserID: req.UserID,
+ RedirectURI: req.RedirectURI,
+ Scope: req.Scope,
+ AccessTokenExp: req.AccessTokenExp,
+ Request: req.Request,
+ }
+
// check the client allows the authorized scope
if fn := s.ClientScopeHandler; fn != nil {
- allowed, err := fn(req.ClientID, req.Scope)
+ allowed, err := fn(tgr)
if err != nil {
return nil, err
} else if !allowed {
@@ -223,16 +231,9 @@ func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (
}
}
- tgr := &oauth2.TokenGenerateRequest{
- ClientID: req.ClientID,
- UserID: req.UserID,
- RedirectURI: req.RedirectURI,
- Scope: req.Scope,
- AccessTokenExp: req.AccessTokenExp,
- Request: req.Request,
- CodeChallenge: req.CodeChallenge,
- CodeChallengeMethod: req.CodeChallengeMethod,
- }
+ tgr.CodeChallenge = req.CodeChallenge
+ tgr.CodeChallengeMethod = req.CodeChallengeMethod
+
return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr)
}
@@ -312,11 +313,6 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau
return "", nil, errors.ErrUnsupportedGrantType
}
- codeVer := r.FormValue("code_verifier")
- if s.Config.ForcePKCE && codeVer == "" {
- return "", nil, errors.ErrInvalidRequest
- }
-
clientID, clientSecret, err := s.ClientInfoHandler(r)
if err != nil {
return "", nil, err
@@ -336,7 +332,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau
tgr.Code == "" {
return "", nil, errors.ErrInvalidRequest
}
- tgr.CodeVerifier = codeVer
+ tgr.CodeVerifier = r.FormValue("code_verifier")
+ if s.Config.ForcePKCE && tgr.CodeVerifier == "" {
+ return "", nil, errors.ErrInvalidRequest
+ }
case oauth2.PasswordCredentials:
tgr.Scope = r.FormValue("scope")
username, password := r.FormValue("username"), r.FormValue("password")
@@ -374,7 +373,8 @@ func (s *Server) CheckGrantType(gt oauth2.GrantType) bool {
}
// GetAccessToken access token
-func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
+func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo,
+ error) {
if allowed := s.CheckGrantType(gt); !allowed {
return nil, errors.ErrUnauthorizedClient
}
@@ -393,8 +393,7 @@ func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *o
ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr)
if err != nil {
switch err {
- case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge,
- errors.ErrMissingCodeChallenge, errors.ErrMissingCodeChallenge:
+ case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge:
return nil, errors.ErrInvalidGrant
case errors.ErrInvalidClient:
return nil, errors.ErrInvalidClient
@@ -405,7 +404,7 @@ func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *o
return ti, nil
case oauth2.PasswordCredentials, oauth2.ClientCredentials:
if fn := s.ClientScopeHandler; fn != nil {
- allowed, err := fn(tgr.ClientID, tgr.Scope)
+ allowed, err := fn(tgr)
if err != nil {
return nil, err
} else if !allowed {
@@ -415,7 +414,7 @@ func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *o
return s.Manager.GenerateAccessToken(ctx, gt, tgr)
case oauth2.Refreshing:
// check scope
- if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil {
+ if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil {
rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
if err != nil {
if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
@@ -424,7 +423,7 @@ func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *o
return nil, err
}
- allowed, err := scopeFn(scope, rti.GetScope())
+ allowed, err := scopeFn(tgr, rti.GetScope())
if err != nil {
return nil, err
} else if !allowed {