diff options
Diffstat (limited to 'vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go')
| -rw-r--r-- | vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go | 56 |
1 files changed, 29 insertions, 27 deletions
diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go index 91b9effb7..c0a59b755 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/url" - "strings" "time" "code.superseriousbusiness.org/oauth2/v4" @@ -26,14 +25,16 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server { Manager: manager, } - // default handler + // default handlers srv.ClientInfoHandler = ClientBasicHandler + srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler + srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied } - srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { + srv.PasswordAuthorizationHandler = func(ctx context.Context, clientID, username, password string) (string, error) { return "", errors.ErrAccessDenied } return srv @@ -49,18 +50,31 @@ type Server struct { UserAuthorizationHandler UserAuthorizationHandler PasswordAuthorizationHandler PasswordAuthorizationHandler RefreshingValidationHandler RefreshingValidationHandler + PreRedirectErrorHandler PreRedirectErrorHandler RefreshingScopeHandler RefreshingScopeHandler ResponseErrorHandler ResponseErrorHandler InternalErrorHandler InternalErrorHandler ExtensionFieldsHandler ExtensionFieldsHandler AccessTokenExpHandler AccessTokenExpHandler AuthorizeScopeHandler AuthorizeScopeHandler + ResponseTokenHandler ResponseTokenHandler + RefreshTokenResolveHandler RefreshTokenResolveHandler + AccessTokenResolveHandler AccessTokenResolveHandler +} + +func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { + if fn := s.PreRedirectErrorHandler; fn != nil { + return fn(w, req, err) + } + + return s.redirectError(w, req, err) } func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if req == nil { return err } + data, _, _ := s.GetErrorData(err) return s.redirect(w, req, data) } @@ -82,6 +96,9 @@ func (s *Server) tokenError(w http.ResponseWriter, err error) error { } func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { + if fn := s.ResponseTokenHandler; fn != nil { + return fn(w, data, header, statusCode...) + } w.Header().Set("Content-Type", "application/json;charset=UTF-8") w.Header().Set("Cache-Control", "no-store") w.Header().Set("Pragma", "no-cache") @@ -182,7 +199,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, oauth2.CodeChallengePlain, ) } - if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) { + if ccm != "" && !s.CheckCodeChallengeMethod(ccm) { return nil, errors.ErrUnsupportedCodeChallengeMethod } @@ -257,13 +274,13 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) req, err := s.ValidationAuthorizeRequest(r) if err != nil { - return s.redirectError(w, req, err) + return s.handleError(w, req, err) } // user authorization userID, err := s.UserAuthorizationHandler(w, r) if err != nil { - return s.redirectError(w, req, err) + return s.handleError(w, req, err) } else if userID == "" { return nil } @@ -290,7 +307,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) ti, err := s.GetAuthorizeToken(ctx, req) if err != nil { - return s.redirectError(w, req, err) + return s.handleError(w, req, err) } // If the redirect URI is empty, the default domain provided by the client is used. @@ -351,7 +368,7 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau return "", nil, errors.ErrInvalidRequest } - userID, err := s.PasswordAuthorizationHandler(username, password) + userID, err := s.PasswordAuthorizationHandler(r.Context(), clientID, username, password) if err != nil { return "", nil, err } else if userID == "" { @@ -362,10 +379,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau tgr.Scope = r.FormValue("scope") tgr.RedirectURI = r.FormValue("redirect_uri") case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") + tgr.Refresh, err = s.RefreshTokenResolveHandler(r) tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest + if err != nil { + return "", nil, err } } return gt, tgr, nil @@ -564,27 +581,12 @@ func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Head return data, statusCode, re.Header } -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - // ValidationBearerToken validation the bearer tokens // https://tools.ietf.org/html/rfc6750 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { ctx := r.Context() - accessToken, ok := s.BearerAuth(r) + accessToken, ok := s.AccessTokenResolveHandler(r) if !ok { return nil, errors.ErrInvalidAccessToken } |
