summaryrefslogtreecommitdiff
path: root/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go')
-rw-r--r--vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go56
1 files changed, 29 insertions, 27 deletions
diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go
index 91b9effb7..c0a59b755 100644
--- a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go
+++ b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go
@@ -7,7 +7,6 @@ import (
"fmt"
"net/http"
"net/url"
- "strings"
"time"
"code.superseriousbusiness.org/oauth2/v4"
@@ -26,14 +25,16 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server {
Manager: manager,
}
- // default handler
+ // default handlers
srv.ClientInfoHandler = ClientBasicHandler
+ srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler
+ srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler
srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) {
return "", errors.ErrAccessDenied
}
- srv.PasswordAuthorizationHandler = func(username, password string) (string, error) {
+ srv.PasswordAuthorizationHandler = func(ctx context.Context, clientID, username, password string) (string, error) {
return "", errors.ErrAccessDenied
}
return srv
@@ -49,18 +50,31 @@ type Server struct {
UserAuthorizationHandler UserAuthorizationHandler
PasswordAuthorizationHandler PasswordAuthorizationHandler
RefreshingValidationHandler RefreshingValidationHandler
+ PreRedirectErrorHandler PreRedirectErrorHandler
RefreshingScopeHandler RefreshingScopeHandler
ResponseErrorHandler ResponseErrorHandler
InternalErrorHandler InternalErrorHandler
ExtensionFieldsHandler ExtensionFieldsHandler
AccessTokenExpHandler AccessTokenExpHandler
AuthorizeScopeHandler AuthorizeScopeHandler
+ ResponseTokenHandler ResponseTokenHandler
+ RefreshTokenResolveHandler RefreshTokenResolveHandler
+ AccessTokenResolveHandler AccessTokenResolveHandler
+}
+
+func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
+ if fn := s.PreRedirectErrorHandler; fn != nil {
+ return fn(w, req, err)
+ }
+
+ return s.redirectError(w, req, err)
}
func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
if req == nil {
return err
}
+
data, _, _ := s.GetErrorData(err)
return s.redirect(w, req, data)
}
@@ -82,6 +96,9 @@ func (s *Server) tokenError(w http.ResponseWriter, err error) error {
}
func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error {
+ if fn := s.ResponseTokenHandler; fn != nil {
+ return fn(w, data, header, statusCode...)
+ }
w.Header().Set("Content-Type", "application/json;charset=UTF-8")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
@@ -182,7 +199,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest,
oauth2.CodeChallengePlain,
)
}
- if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) {
+ if ccm != "" && !s.CheckCodeChallengeMethod(ccm) {
return nil, errors.ErrUnsupportedCodeChallengeMethod
}
@@ -257,13 +274,13 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)
req, err := s.ValidationAuthorizeRequest(r)
if err != nil {
- return s.redirectError(w, req, err)
+ return s.handleError(w, req, err)
}
// user authorization
userID, err := s.UserAuthorizationHandler(w, r)
if err != nil {
- return s.redirectError(w, req, err)
+ return s.handleError(w, req, err)
} else if userID == "" {
return nil
}
@@ -290,7 +307,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)
ti, err := s.GetAuthorizeToken(ctx, req)
if err != nil {
- return s.redirectError(w, req, err)
+ return s.handleError(w, req, err)
}
// If the redirect URI is empty, the default domain provided by the client is used.
@@ -351,7 +368,7 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau
return "", nil, errors.ErrInvalidRequest
}
- userID, err := s.PasswordAuthorizationHandler(username, password)
+ userID, err := s.PasswordAuthorizationHandler(r.Context(), clientID, username, password)
if err != nil {
return "", nil, err
} else if userID == "" {
@@ -362,10 +379,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau
tgr.Scope = r.FormValue("scope")
tgr.RedirectURI = r.FormValue("redirect_uri")
case oauth2.Refreshing:
- tgr.Refresh = r.FormValue("refresh_token")
+ tgr.Refresh, err = s.RefreshTokenResolveHandler(r)
tgr.Scope = r.FormValue("scope")
- if tgr.Refresh == "" {
- return "", nil, errors.ErrInvalidRequest
+ if err != nil {
+ return "", nil, err
}
}
return gt, tgr, nil
@@ -564,27 +581,12 @@ func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Head
return data, statusCode, re.Header
}
-// BearerAuth parse bearer token
-func (s *Server) BearerAuth(r *http.Request) (string, bool) {
- auth := r.Header.Get("Authorization")
- prefix := "Bearer "
- token := ""
-
- if auth != "" && strings.HasPrefix(auth, prefix) {
- token = auth[len(prefix):]
- } else {
- token = r.FormValue("access_token")
- }
-
- return token, token != ""
-}
-
// ValidationBearerToken validation the bearer tokens
// https://tools.ietf.org/html/rfc6750
func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
ctx := r.Context()
- accessToken, ok := s.BearerAuth(r)
+ accessToken, ok := s.AccessTokenResolveHandler(r)
if !ok {
return nil, errors.ErrInvalidAccessToken
}