diff options
Diffstat (limited to 'vendor/code.superseriousbusiness.org/oauth2/v4/server')
3 files changed, 105 insertions, 28 deletions
diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/server/handler.go b/vendor/code.superseriousbusiness.org/oauth2/v4/server/handler.go index 745716dc5..808b476d8 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/server/handler.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/server/handler.go @@ -1,7 +1,9 @@ package server import ( + "context" "net/http" + "strings" "time" "code.superseriousbusiness.org/oauth2/v4" @@ -22,7 +24,7 @@ type ( UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) // PasswordAuthorizationHandler get user id from username and password - PasswordAuthorizationHandler func(username, password string) (userID string, err error) + PasswordAuthorizationHandler func(ctx context.Context, clientID, username, password string) (userID string, err error) // RefreshingScopeHandler check the scope of the refreshing token RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) @@ -36,6 +38,9 @@ type ( // InternalErrorHandler internal error handing InternalErrorHandler func(err error) (re *errors.Response) + // PreRedirectErrorHandler is used to override "redirect-on-error" behavior + PreRedirectErrorHandler func(w http.ResponseWriter, req *AuthorizeRequest, err error) error + // AuthorizeScopeHandler set the authorized scope AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) @@ -44,6 +49,15 @@ type ( // ExtensionFieldsHandler in response to the access token with the extension of the field ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) + + // ResponseTokenHandler response token handling + ResponseTokenHandler func(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error + + // Handler to fetch the refresh token from the request + RefreshTokenResolveHandler func(r *http.Request) (string, error) + + // Handler to fetch the access token from the request + AccessTokenResolveHandler func(r *http.Request) (string, bool) ) // ClientFormHandler get client data from form @@ -64,3 +78,44 @@ func ClientBasicHandler(r *http.Request) (string, string, error) { } return username, password, nil } + +func RefreshTokenFormResolveHandler(r *http.Request) (string, error) { + rt := r.FormValue("refresh_token") + if rt == "" { + return "", errors.ErrInvalidRequest + } + + return rt, nil +} + +func RefreshTokenCookieResolveHandler(r *http.Request) (string, error) { + c, err := r.Cookie("refresh_token") + if err != nil { + return "", errors.ErrInvalidRequest + } + + return c.Value, nil +} + +func AccessTokenDefaultResolveHandler(r *http.Request) (string, bool) { + token := "" + auth := r.Header.Get("Authorization") + prefix := "Bearer " + + if auth != "" && strings.HasPrefix(auth, prefix) { + token = auth[len(prefix):] + } else { + token = r.FormValue("access_token") + } + + return token, token != "" +} + +func AccessTokenCookieResolveHandler(r *http.Request) (string, bool) { + c, err := r.Cookie("access_token") + if err != nil { + return "", false + } + + return c.Value, true +} diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go index 91b9effb7..c0a59b755 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/url" - "strings" "time" "code.superseriousbusiness.org/oauth2/v4" @@ -26,14 +25,16 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server { Manager: manager, } - // default handler + // default handlers srv.ClientInfoHandler = ClientBasicHandler + srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler + srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied } - srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { + srv.PasswordAuthorizationHandler = func(ctx context.Context, clientID, username, password string) (string, error) { return "", errors.ErrAccessDenied } return srv @@ -49,18 +50,31 @@ type Server struct { UserAuthorizationHandler UserAuthorizationHandler PasswordAuthorizationHandler PasswordAuthorizationHandler RefreshingValidationHandler RefreshingValidationHandler + PreRedirectErrorHandler PreRedirectErrorHandler RefreshingScopeHandler RefreshingScopeHandler ResponseErrorHandler ResponseErrorHandler InternalErrorHandler InternalErrorHandler ExtensionFieldsHandler ExtensionFieldsHandler AccessTokenExpHandler AccessTokenExpHandler AuthorizeScopeHandler AuthorizeScopeHandler + ResponseTokenHandler ResponseTokenHandler + RefreshTokenResolveHandler RefreshTokenResolveHandler + AccessTokenResolveHandler AccessTokenResolveHandler +} + +func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { + if fn := s.PreRedirectErrorHandler; fn != nil { + return fn(w, req, err) + } + + return s.redirectError(w, req, err) } func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if req == nil { return err } + data, _, _ := s.GetErrorData(err) return s.redirect(w, req, data) } @@ -82,6 +96,9 @@ func (s *Server) tokenError(w http.ResponseWriter, err error) error { } func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { + if fn := s.ResponseTokenHandler; fn != nil { + return fn(w, data, header, statusCode...) + } w.Header().Set("Content-Type", "application/json;charset=UTF-8") w.Header().Set("Cache-Control", "no-store") w.Header().Set("Pragma", "no-cache") @@ -182,7 +199,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, oauth2.CodeChallengePlain, ) } - if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) { + if ccm != "" && !s.CheckCodeChallengeMethod(ccm) { return nil, errors.ErrUnsupportedCodeChallengeMethod } @@ -257,13 +274,13 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) req, err := s.ValidationAuthorizeRequest(r) if err != nil { - return s.redirectError(w, req, err) + return s.handleError(w, req, err) } // user authorization userID, err := s.UserAuthorizationHandler(w, r) if err != nil { - return s.redirectError(w, req, err) + return s.handleError(w, req, err) } else if userID == "" { return nil } @@ -290,7 +307,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) ti, err := s.GetAuthorizeToken(ctx, req) if err != nil { - return s.redirectError(w, req, err) + return s.handleError(w, req, err) } // If the redirect URI is empty, the default domain provided by the client is used. @@ -351,7 +368,7 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau return "", nil, errors.ErrInvalidRequest } - userID, err := s.PasswordAuthorizationHandler(username, password) + userID, err := s.PasswordAuthorizationHandler(r.Context(), clientID, username, password) if err != nil { return "", nil, err } else if userID == "" { @@ -362,10 +379,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau tgr.Scope = r.FormValue("scope") tgr.RedirectURI = r.FormValue("redirect_uri") case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") + tgr.Refresh, err = s.RefreshTokenResolveHandler(r) tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest + if err != nil { + return "", nil, err } } return gt, tgr, nil @@ -564,27 +581,12 @@ func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Head return data, statusCode, re.Header } -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - // ValidationBearerToken validation the bearer tokens // https://tools.ietf.org/html/rfc6750 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { ctx := r.Context() - accessToken, ok := s.BearerAuth(r) + accessToken, ok := s.AccessTokenResolveHandler(r) if !ok { return nil, errors.ErrInvalidAccessToken } diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server_config.go b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server_config.go index 4e8010196..70a8b2c11 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server_config.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server_config.go @@ -69,6 +69,11 @@ func (s *Server) SetInternalErrorHandler(handler InternalErrorHandler) { s.InternalErrorHandler = handler } +// SetPreRedirectErrorHandler sets the PreRedirectErrorHandler in current Server instance +func (s *Server) SetPreRedirectErrorHandler(handler PreRedirectErrorHandler) { + s.PreRedirectErrorHandler = handler +} + // SetExtensionFieldsHandler in response to the access token with the extension of the field func (s *Server) SetExtensionFieldsHandler(handler ExtensionFieldsHandler) { s.ExtensionFieldsHandler = handler @@ -83,3 +88,18 @@ func (s *Server) SetAccessTokenExpHandler(handler AccessTokenExpHandler) { func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) { s.AuthorizeScopeHandler = handler } + +// SetResponseTokenHandler response token handing +func (s *Server) SetResponseTokenHandler(handler ResponseTokenHandler) { + s.ResponseTokenHandler = handler +} + +// SetRefreshTokenResolveHandler refresh token resolver +func (s *Server) SetRefreshTokenResolveHandler(handler RefreshTokenResolveHandler) { + s.RefreshTokenResolveHandler = handler +} + +// SetAccessTokenResolveHandler access token resolver +func (s *Server) SetAccessTokenResolveHandler(handler AccessTokenResolveHandler) { + s.AccessTokenResolveHandler = handler +} |
