summaryrefslogtreecommitdiff
path: root/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go')
-rw-r--r--vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go585
1 files changed, 585 insertions, 0 deletions
diff --git a/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go b/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go
new file mode 100644
index 000000000..0aac66ffc
--- /dev/null
+++ b/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go
@@ -0,0 +1,585 @@
+package server
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/superseriousbusiness/oauth2/v4"
+ "github.com/superseriousbusiness/oauth2/v4/errors"
+)
+
+// NewDefaultServer create a default authorization server
+func NewDefaultServer(manager oauth2.Manager) *Server {
+ return NewServer(NewConfig(), manager)
+}
+
+// NewServer create authorization server
+func NewServer(cfg *Config, manager oauth2.Manager) *Server {
+ srv := &Server{
+ Config: cfg,
+ Manager: manager,
+ }
+
+ // default handler
+ srv.ClientInfoHandler = ClientBasicHandler
+
+ srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) {
+ return "", errors.ErrAccessDenied
+ }
+
+ srv.PasswordAuthorizationHandler = func(username, password string) (string, error) {
+ return "", errors.ErrAccessDenied
+ }
+ return srv
+}
+
+// Server Provide authorization server
+type Server struct {
+ Config *Config
+ Manager oauth2.Manager
+ ClientInfoHandler ClientInfoHandler
+ ClientAuthorizedHandler ClientAuthorizedHandler
+ ClientScopeHandler ClientScopeHandler
+ UserAuthorizationHandler UserAuthorizationHandler
+ PasswordAuthorizationHandler PasswordAuthorizationHandler
+ RefreshingValidationHandler RefreshingValidationHandler
+ RefreshingScopeHandler RefreshingScopeHandler
+ ResponseErrorHandler ResponseErrorHandler
+ InternalErrorHandler InternalErrorHandler
+ ExtensionFieldsHandler ExtensionFieldsHandler
+ AccessTokenExpHandler AccessTokenExpHandler
+ AuthorizeScopeHandler AuthorizeScopeHandler
+}
+
+func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
+ if req == nil {
+ return err
+ }
+ data, _, _ := s.GetErrorData(err)
+ return s.redirect(w, req, data)
+}
+
+func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error {
+ uri, err := s.GetRedirectURI(req, data)
+ if err != nil {
+ return err
+ }
+
+ w.Header().Set("Location", uri)
+ w.WriteHeader(302)
+ return nil
+}
+
+func (s *Server) tokenError(w http.ResponseWriter, err error) error {
+ data, statusCode, header := s.GetErrorData(err)
+ return s.token(w, data, header, statusCode)
+}
+
+func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error {
+ w.Header().Set("Content-Type", "application/json;charset=UTF-8")
+ w.Header().Set("Cache-Control", "no-store")
+ w.Header().Set("Pragma", "no-cache")
+
+ for key := range header {
+ w.Header().Set(key, header.Get(key))
+ }
+
+ status := http.StatusOK
+ if len(statusCode) > 0 && statusCode[0] > 0 {
+ status = statusCode[0]
+ }
+
+ w.WriteHeader(status)
+ return json.NewEncoder(w).Encode(data)
+}
+
+// GetRedirectURI get redirect uri
+func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) {
+ u, err := url.Parse(req.RedirectURI)
+ if err != nil {
+ return "", err
+ }
+
+ q := u.Query()
+ if req.State != "" {
+ q.Set("state", req.State)
+ }
+
+ for k, v := range data {
+ q.Set(k, fmt.Sprint(v))
+ }
+
+ switch req.ResponseType {
+ case oauth2.Code:
+ u.RawQuery = q.Encode()
+ case oauth2.Token:
+ u.RawQuery = ""
+ fragment, err := url.QueryUnescape(q.Encode())
+ if err != nil {
+ return "", err
+ }
+ u.Fragment = fragment
+ }
+
+ return u.String(), nil
+}
+
+// CheckResponseType check allows response type
+func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool {
+ for _, art := range s.Config.AllowedResponseTypes {
+ if art == rt {
+ return true
+ }
+ }
+ return false
+}
+
+// CheckCodeChallengeMethod checks for allowed code challenge method
+func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool {
+ for _, c := range s.Config.AllowedCodeChallengeMethods {
+ if c == ccm {
+ return true
+ }
+ }
+ return false
+}
+
+// ValidationAuthorizeRequest the authorization request validation
+func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) {
+ redirectURI := r.FormValue("redirect_uri")
+ clientID := r.FormValue("client_id")
+ if !(r.Method == "GET" || r.Method == "POST") ||
+ clientID == "" {
+ fmt.Println(r.Method, clientID, r)
+ return nil, errors.ErrInvalidRequest
+ }
+
+ resType := oauth2.ResponseType(r.FormValue("response_type"))
+ if resType.String() == "" {
+ return nil, errors.ErrUnsupportedResponseType
+ } else if allowed := s.CheckResponseType(resType); !allowed {
+ return nil, errors.ErrUnauthorizedClient
+ }
+
+ cc := r.FormValue("code_challenge")
+ if cc == "" && s.Config.ForcePKCE {
+ return nil, errors.ErrCodeChallengeRquired
+ }
+ if cc != "" && (len(cc) < 43 || len(cc) > 128) {
+ return nil, errors.ErrInvalidCodeChallengeLen
+ }
+
+ ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method"))
+ // set default
+ if ccm == "" {
+ ccm = oauth2.CodeChallengePlain
+ }
+ if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) {
+ return nil, errors.ErrUnsupportedCodeChallengeMethod
+ }
+
+ req := &AuthorizeRequest{
+ RedirectURI: redirectURI,
+ ResponseType: resType,
+ ClientID: clientID,
+ State: r.FormValue("state"),
+ Scope: r.FormValue("scope"),
+ Request: r,
+ CodeChallenge: cc,
+ CodeChallengeMethod: ccm,
+ }
+ return req, nil
+}
+
+// GetAuthorizeToken get authorization token(code)
+func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) {
+ // check the client allows the grant type
+ if fn := s.ClientAuthorizedHandler; fn != nil {
+ gt := oauth2.AuthorizationCode
+ if req.ResponseType == oauth2.Token {
+ gt = oauth2.Implicit
+ }
+
+ allowed, err := fn(req.ClientID, gt)
+ if err != nil {
+ return nil, err
+ } else if !allowed {
+ return nil, errors.ErrUnauthorizedClient
+ }
+ }
+
+ // check the client allows the authorized scope
+ if fn := s.ClientScopeHandler; fn != nil {
+ allowed, err := fn(req.ClientID, req.Scope)
+ if err != nil {
+ return nil, err
+ } else if !allowed {
+ return nil, errors.ErrInvalidScope
+ }
+ }
+
+ tgr := &oauth2.TokenGenerateRequest{
+ ClientID: req.ClientID,
+ UserID: req.UserID,
+ RedirectURI: req.RedirectURI,
+ Scope: req.Scope,
+ AccessTokenExp: req.AccessTokenExp,
+ Request: req.Request,
+ CodeChallenge: req.CodeChallenge,
+ CodeChallengeMethod: req.CodeChallengeMethod,
+ }
+ return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr)
+}
+
+// GetAuthorizeData get authorization response data
+func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} {
+ if rt == oauth2.Code {
+ return map[string]interface{}{
+ "code": ti.GetCode(),
+ }
+ }
+ return s.GetTokenData(ti)
+}
+
+// HandleAuthorizeRequest the authorization request handling
+func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
+ ctx := r.Context()
+
+ req, err := s.ValidationAuthorizeRequest(r)
+ if err != nil {
+ return s.redirectError(w, req, err)
+ }
+
+ // user authorization
+ userID, err := s.UserAuthorizationHandler(w, r)
+ if err != nil {
+ return s.redirectError(w, req, err)
+ } else if userID == "" {
+ return nil
+ }
+ req.UserID = userID
+
+ // specify the scope of authorization
+ if fn := s.AuthorizeScopeHandler; fn != nil {
+ scope, err := fn(w, r)
+ if err != nil {
+ return err
+ } else if scope != "" {
+ req.Scope = scope
+ }
+ }
+
+ // specify the expiration time of access token
+ if fn := s.AccessTokenExpHandler; fn != nil {
+ exp, err := fn(w, r)
+ if err != nil {
+ return err
+ }
+ req.AccessTokenExp = exp
+ }
+
+ ti, err := s.GetAuthorizeToken(ctx, req)
+ if err != nil {
+ return s.redirectError(w, req, err)
+ }
+
+ // If the redirect URI is empty, the default domain provided by the client is used.
+ if req.RedirectURI == "" {
+ client, err := s.Manager.GetClient(ctx, req.ClientID)
+ if err != nil {
+ return err
+ }
+ req.RedirectURI = client.GetDomain()
+ }
+
+ return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti))
+}
+
+// ValidationTokenRequest the token request validation
+func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) {
+ if v := r.Method; !(v == "POST" ||
+ (s.Config.AllowGetAccessRequest && v == "GET")) {
+ return "", nil, errors.ErrInvalidRequest
+ }
+
+ gt := oauth2.GrantType(r.FormValue("grant_type"))
+ if gt.String() == "" {
+ return "", nil, errors.ErrUnsupportedGrantType
+ }
+
+ codeVer := r.FormValue("code_verifier")
+ if s.Config.ForcePKCE && codeVer == "" {
+ return "", nil, errors.ErrInvalidRequest
+ }
+
+ clientID, clientSecret, err := s.ClientInfoHandler(r)
+ if err != nil {
+ return "", nil, err
+ }
+
+ tgr := &oauth2.TokenGenerateRequest{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ Request: r,
+ }
+
+ switch gt {
+ case oauth2.AuthorizationCode:
+ tgr.RedirectURI = r.FormValue("redirect_uri")
+ tgr.Code = r.FormValue("code")
+ if tgr.RedirectURI == "" ||
+ tgr.Code == "" {
+ return "", nil, errors.ErrInvalidRequest
+ }
+ tgr.CodeVerifier = codeVer
+ case oauth2.PasswordCredentials:
+ tgr.Scope = r.FormValue("scope")
+ username, password := r.FormValue("username"), r.FormValue("password")
+ if username == "" || password == "" {
+ return "", nil, errors.ErrInvalidRequest
+ }
+
+ userID, err := s.PasswordAuthorizationHandler(username, password)
+ if err != nil {
+ return "", nil, err
+ } else if userID == "" {
+ return "", nil, errors.ErrInvalidGrant
+ }
+ tgr.UserID = userID
+ case oauth2.ClientCredentials:
+ tgr.Scope = r.FormValue("scope")
+ case oauth2.Refreshing:
+ tgr.Refresh = r.FormValue("refresh_token")
+ tgr.Scope = r.FormValue("scope")
+ if tgr.Refresh == "" {
+ return "", nil, errors.ErrInvalidRequest
+ }
+ }
+ return gt, tgr, nil
+}
+
+// CheckGrantType check allows grant type
+func (s *Server) CheckGrantType(gt oauth2.GrantType) bool {
+ for _, agt := range s.Config.AllowedGrantTypes {
+ if agt == gt {
+ return true
+ }
+ }
+ return false
+}
+
+// GetAccessToken access token
+func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
+ if allowed := s.CheckGrantType(gt); !allowed {
+ return nil, errors.ErrUnauthorizedClient
+ }
+
+ if fn := s.ClientAuthorizedHandler; fn != nil {
+ allowed, err := fn(tgr.ClientID, gt)
+ if err != nil {
+ return nil, err
+ } else if !allowed {
+ return nil, errors.ErrUnauthorizedClient
+ }
+ }
+
+ switch gt {
+ case oauth2.AuthorizationCode:
+ ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr)
+ if err != nil {
+ switch err {
+ case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge,
+ errors.ErrMissingCodeChallenge, errors.ErrMissingCodeChallenge:
+ return nil, errors.ErrInvalidGrant
+ case errors.ErrInvalidClient:
+ return nil, errors.ErrInvalidClient
+ default:
+ return nil, err
+ }
+ }
+ return ti, nil
+ case oauth2.PasswordCredentials, oauth2.ClientCredentials:
+ if fn := s.ClientScopeHandler; fn != nil {
+ allowed, err := fn(tgr.ClientID, tgr.Scope)
+ if err != nil {
+ return nil, err
+ } else if !allowed {
+ return nil, errors.ErrInvalidScope
+ }
+ }
+ return s.Manager.GenerateAccessToken(ctx, gt, tgr)
+ case oauth2.Refreshing:
+ // check scope
+ if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil {
+ rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
+ if err != nil {
+ if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
+ return nil, errors.ErrInvalidGrant
+ }
+ return nil, err
+ }
+
+ allowed, err := scopeFn(scope, rti.GetScope())
+ if err != nil {
+ return nil, err
+ } else if !allowed {
+ return nil, errors.ErrInvalidScope
+ }
+ }
+
+ if validationFn := s.RefreshingValidationHandler; validationFn != nil {
+ rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
+ if err != nil {
+ if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
+ return nil, errors.ErrInvalidGrant
+ }
+ return nil, err
+ }
+ allowed, err := validationFn(rti)
+ if err != nil {
+ return nil, err
+ } else if !allowed {
+ return nil, errors.ErrInvalidScope
+ }
+ }
+
+ ti, err := s.Manager.RefreshAccessToken(ctx, tgr)
+ if err != nil {
+ if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
+ return nil, errors.ErrInvalidGrant
+ }
+ return nil, err
+ }
+ return ti, nil
+ }
+
+ return nil, errors.ErrUnsupportedGrantType
+}
+
+// GetTokenData token data
+func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} {
+ data := map[string]interface{}{
+ "access_token": ti.GetAccess(),
+ "token_type": s.Config.TokenType,
+ "expires_in": int64(ti.GetAccessExpiresIn() / time.Second),
+ }
+
+ if scope := ti.GetScope(); scope != "" {
+ data["scope"] = scope
+ }
+
+ if refresh := ti.GetRefresh(); refresh != "" {
+ data["refresh_token"] = refresh
+ }
+
+ if fn := s.ExtensionFieldsHandler; fn != nil {
+ ext := fn(ti)
+ for k, v := range ext {
+ if _, ok := data[k]; ok {
+ continue
+ }
+ data[k] = v
+ }
+ }
+ return data
+}
+
+// HandleTokenRequest token request handling
+func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
+ ctx := r.Context()
+
+ gt, tgr, err := s.ValidationTokenRequest(r)
+ if err != nil {
+ return s.tokenError(w, err)
+ }
+
+ ti, err := s.GetAccessToken(ctx, gt, tgr)
+ if err != nil {
+ return s.tokenError(w, err)
+ }
+
+ return s.token(w, s.GetTokenData(ti), nil)
+}
+
+// GetErrorData get error response data
+func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) {
+ var re errors.Response
+ if v, ok := errors.Descriptions[err]; ok {
+ re.Error = err
+ re.Description = v
+ re.StatusCode = errors.StatusCodes[err]
+ } else {
+ if fn := s.InternalErrorHandler; fn != nil {
+ if v := fn(err); v != nil {
+ re = *v
+ }
+ }
+
+ if re.Error == nil {
+ re.Error = errors.ErrServerError
+ re.Description = errors.Descriptions[errors.ErrServerError]
+ re.StatusCode = errors.StatusCodes[errors.ErrServerError]
+ }
+ }
+
+ if fn := s.ResponseErrorHandler; fn != nil {
+ fn(&re)
+ }
+
+ data := make(map[string]interface{})
+ if err := re.Error; err != nil {
+ data["error"] = err.Error()
+ }
+
+ if v := re.ErrorCode; v != 0 {
+ data["error_code"] = v
+ }
+
+ if v := re.Description; v != "" {
+ data["error_description"] = v
+ }
+
+ if v := re.URI; v != "" {
+ data["error_uri"] = v
+ }
+
+ statusCode := http.StatusInternalServerError
+ if v := re.StatusCode; v > 0 {
+ statusCode = v
+ }
+
+ return data, statusCode, re.Header
+}
+
+// BearerAuth parse bearer token
+func (s *Server) BearerAuth(r *http.Request) (string, bool) {
+ auth := r.Header.Get("Authorization")
+ prefix := "Bearer "
+ token := ""
+
+ if auth != "" && strings.HasPrefix(auth, prefix) {
+ token = auth[len(prefix):]
+ } else {
+ token = r.FormValue("access_token")
+ }
+
+ return token, token != ""
+}
+
+// ValidationBearerToken validation the bearer tokens
+// https://tools.ietf.org/html/rfc6750
+func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
+ ctx := r.Context()
+
+ accessToken, ok := s.BearerAuth(r)
+ if !ok {
+ return nil, errors.ErrInvalidAccessToken
+ }
+
+ return s.Manager.LoadAccessToken(ctx, accessToken)
+}