diff options
author | 2025-03-09 17:47:56 +0100 | |
---|---|---|
committer | 2025-03-10 01:59:49 +0100 | |
commit | 3ac1ee16f377d31a0fb80c8dae28b6239ac4229e (patch) | |
tree | f61faa581feaaeaba2542b9f2b8234a590684413 /vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go | |
parent | [chore] update URLs to forked source (diff) | |
download | gotosocial-3ac1ee16f377d31a0fb80c8dae28b6239ac4229e.tar.xz |
[chore] remove vendor
Diffstat (limited to 'vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go')
-rw-r--r-- | vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go | 589 |
1 files changed, 0 insertions, 589 deletions
diff --git a/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go b/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go deleted file mode 100644 index 05ca19245..000000000 --- a/vendor/github.com/superseriousbusiness/oauth2/v4/server/server.go +++ /dev/null @@ -1,589 +0,0 @@ -package server - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - "time" - - "github.com/superseriousbusiness/oauth2/v4" - "github.com/superseriousbusiness/oauth2/v4/errors" -) - -// NewDefaultServer create a default authorization server -func NewDefaultServer(manager oauth2.Manager) *Server { - return NewServer(NewConfig(), manager) -} - -// NewServer create authorization server -func NewServer(cfg *Config, manager oauth2.Manager) *Server { - srv := &Server{ - Config: cfg, - Manager: manager, - } - - // default handler - srv.ClientInfoHandler = ClientBasicHandler - - srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { - return "", errors.ErrAccessDenied - } - - srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { - return "", errors.ErrAccessDenied - } - return srv -} - -// Server Provide authorization server -type Server struct { - Config *Config - Manager oauth2.Manager - ClientInfoHandler ClientInfoHandler - ClientAuthorizedHandler ClientAuthorizedHandler - ClientScopeHandler ClientScopeHandler - UserAuthorizationHandler UserAuthorizationHandler - PasswordAuthorizationHandler PasswordAuthorizationHandler - RefreshingValidationHandler RefreshingValidationHandler - RefreshingScopeHandler RefreshingScopeHandler - ResponseErrorHandler ResponseErrorHandler - InternalErrorHandler InternalErrorHandler - ExtensionFieldsHandler ExtensionFieldsHandler - AccessTokenExpHandler AccessTokenExpHandler - AuthorizeScopeHandler AuthorizeScopeHandler -} - -func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { - if req == nil { - return err - } - data, _, _ := s.GetErrorData(err) - return s.redirect(w, req, data) -} - -func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { - uri, err := s.GetRedirectURI(req, data) - if err != nil { - return err - } - - w.Header().Set("Location", uri) - w.WriteHeader(302) - return nil -} - -func (s *Server) tokenError(w http.ResponseWriter, err error) error { - data, statusCode, header := s.GetErrorData(err) - return s.token(w, data, header, statusCode) -} - -func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { - w.Header().Set("Content-Type", "application/json;charset=UTF-8") - w.Header().Set("Cache-Control", "no-store") - w.Header().Set("Pragma", "no-cache") - - for key := range header { - w.Header().Set(key, header.Get(key)) - } - - status := http.StatusOK - if len(statusCode) > 0 && statusCode[0] > 0 { - status = statusCode[0] - } - - w.WriteHeader(status) - return json.NewEncoder(w).Encode(data) -} - -// GetRedirectURI get redirect uri -func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { - u, err := url.Parse(req.RedirectURI) - if err != nil { - return "", err - } - - q := u.Query() - if req.State != "" { - q.Set("state", req.State) - } - - for k, v := range data { - q.Set(k, fmt.Sprint(v)) - } - - switch req.ResponseType { - case oauth2.Code: - u.RawQuery = q.Encode() - case oauth2.Token: - u.RawQuery = "" - fragment, err := url.QueryUnescape(q.Encode()) - if err != nil { - return "", err - } - u.Fragment = fragment - } - - return u.String(), nil -} - -// CheckResponseType check allows response type -func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { - for _, art := range s.Config.AllowedResponseTypes { - if art == rt { - return true - } - } - return false -} - -// CheckCodeChallengeMethod checks for allowed code challenge method -func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool { - for _, c := range s.Config.AllowedCodeChallengeMethods { - if c == ccm { - return true - } - } - return false -} - -// ValidationAuthorizeRequest the authorization request validation -func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { - redirectURI := r.FormValue("redirect_uri") - clientID := r.FormValue("client_id") - if !(r.Method == "GET" || r.Method == "POST") || - clientID == "" { - return nil, errors.ErrInvalidRequest - } - - resType := oauth2.ResponseType(r.FormValue("response_type")) - if resType.String() == "" { - return nil, errors.ErrUnsupportedResponseType - } else if allowed := s.CheckResponseType(resType); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - cc := r.FormValue("code_challenge") - if cc == "" && s.Config.ForcePKCE { - return nil, errors.ErrCodeChallengeRquired - } - if cc != "" && (len(cc) < 43 || len(cc) > 128) { - return nil, errors.ErrInvalidCodeChallengeLen - } - - ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) - // set default - if ccm == "" { - ccm = oauth2.CodeChallengePlain - } - if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) { - return nil, errors.ErrUnsupportedCodeChallengeMethod - } - - req := &AuthorizeRequest{ - RedirectURI: redirectURI, - ResponseType: resType, - ClientID: clientID, - State: r.FormValue("state"), - Scope: r.FormValue("scope"), - Request: r, - CodeChallenge: cc, - CodeChallengeMethod: ccm, - } - return req, nil -} - -// GetAuthorizeToken get authorization token(code) -func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { - // check the client allows the grant type - if fn := s.ClientAuthorizedHandler; fn != nil { - gt := oauth2.AuthorizationCode - if req.ResponseType == oauth2.Token { - gt = oauth2.Implicit - } - - allowed, err := fn(req.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: req.ClientID, - UserID: req.UserID, - RedirectURI: req.RedirectURI, - Scope: req.Scope, - AccessTokenExp: req.AccessTokenExp, - Request: req.Request, - } - - // check the client allows the authorized scope - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - tgr.CodeChallenge = req.CodeChallenge - tgr.CodeChallengeMethod = req.CodeChallengeMethod - - return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) -} - -// GetAuthorizeData get authorization response data -func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { - if rt == oauth2.Code { - return map[string]interface{}{ - "code": ti.GetCode(), - } - } - return s.GetTokenData(ti) -} - -// HandleAuthorizeRequest the authorization request handling -func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - req, err := s.ValidationAuthorizeRequest(r) - if err != nil { - return s.redirectError(w, req, err) - } - - // user authorization - userID, err := s.UserAuthorizationHandler(w, r) - if err != nil { - return s.redirectError(w, req, err) - } else if userID == "" { - return nil - } - req.UserID = userID - - // specify the scope of authorization - if fn := s.AuthorizeScopeHandler; fn != nil { - scope, err := fn(w, r) - if err != nil { - return err - } else if scope != "" { - req.Scope = scope - } - } - - // specify the expiration time of access token - if fn := s.AccessTokenExpHandler; fn != nil { - exp, err := fn(w, r) - if err != nil { - return err - } - req.AccessTokenExp = exp - } - - ti, err := s.GetAuthorizeToken(ctx, req) - if err != nil { - return s.redirectError(w, req, err) - } - - // If the redirect URI is empty, the default domain provided by the client is used. - if req.RedirectURI == "" { - client, err := s.Manager.GetClient(ctx, req.ClientID) - if err != nil { - return err - } - req.RedirectURI = client.GetDomain() - } - - return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) -} - -// ValidationTokenRequest the token request validation -func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { - if v := r.Method; !(v == "POST" || - (s.Config.AllowGetAccessRequest && v == "GET")) { - return "", nil, errors.ErrInvalidRequest - } - - gt := oauth2.GrantType(r.FormValue("grant_type")) - if gt.String() == "" { - return "", nil, errors.ErrUnsupportedGrantType - } - - if !s.CheckGrantType(gt) { - return "", nil, errors.ErrUnsupportedGrantType - } - - clientID, clientSecret, err := s.ClientInfoHandler(r) - if err != nil { - return "", nil, err - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: clientID, - ClientSecret: clientSecret, - Request: r, - } - - switch gt { - case oauth2.AuthorizationCode: - tgr.RedirectURI = r.FormValue("redirect_uri") - tgr.Code = r.FormValue("code") - if tgr.RedirectURI == "" || - tgr.Code == "" { - return "", nil, errors.ErrInvalidRequest - } - tgr.CodeVerifier = r.FormValue("code_verifier") - if s.Config.ForcePKCE && tgr.CodeVerifier == "" { - return "", nil, errors.ErrInvalidRequest - } - case oauth2.PasswordCredentials: - tgr.Scope = r.FormValue("scope") - username, password := r.FormValue("username"), r.FormValue("password") - if username == "" || password == "" { - return "", nil, errors.ErrInvalidRequest - } - - userID, err := s.PasswordAuthorizationHandler(username, password) - if err != nil { - return "", nil, err - } else if userID == "" { - return "", nil, errors.ErrInvalidGrant - } - tgr.UserID = userID - case oauth2.ClientCredentials: - tgr.Scope = r.FormValue("scope") - tgr.RedirectURI = r.FormValue("redirect_uri") - case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") - tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest - } - } - return gt, tgr, nil -} - -// CheckGrantType check allows grant type -func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { - for _, agt := range s.Config.AllowedGrantTypes { - if agt == gt { - return true - } - } - return false -} - -// GetAccessToken access token -func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, - error) { - if allowed := s.CheckGrantType(gt); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - if fn := s.ClientAuthorizedHandler; fn != nil { - allowed, err := fn(tgr.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - switch gt { - case oauth2.AuthorizationCode: - ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) - if err != nil { - switch err { - case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge: - return nil, errors.ErrInvalidGrant - case errors.ErrInvalidClient: - return nil, errors.ErrInvalidClient - default: - return nil, err - } - } - return ti, nil - case oauth2.PasswordCredentials, oauth2.ClientCredentials: - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - return s.Manager.GenerateAccessToken(ctx, gt, tgr) - case oauth2.Refreshing: - // check scope - if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { - rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - - allowed, err := scopeFn(tgr, rti.GetScope()) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - if validationFn := s.RefreshingValidationHandler; validationFn != nil { - rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - allowed, err := validationFn(rti) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - ti, err := s.Manager.RefreshAccessToken(ctx, tgr) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - return ti, nil - } - - return nil, errors.ErrUnsupportedGrantType -} - -// GetTokenData token data -func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { - data := map[string]interface{}{ - "access_token": ti.GetAccess(), - "token_type": s.Config.TokenType, - "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), - } - - if scope := ti.GetScope(); scope != "" { - data["scope"] = scope - } - - if refresh := ti.GetRefresh(); refresh != "" { - data["refresh_token"] = refresh - } - - if fn := s.ExtensionFieldsHandler; fn != nil { - ext := fn(ti) - for k, v := range ext { - if _, ok := data[k]; ok { - continue - } - data[k] = v - } - } - return data -} - -// HandleTokenRequest token request handling -func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - gt, tgr, err := s.ValidationTokenRequest(r) - if err != nil { - return s.tokenError(w, err) - } - - ti, err := s.GetAccessToken(ctx, gt, tgr) - if err != nil { - return s.tokenError(w, err) - } - - return s.token(w, s.GetTokenData(ti), nil) -} - -// GetErrorData get error response data -func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { - var re errors.Response - if v, ok := errors.Descriptions[err]; ok { - re.Error = err - re.Description = v - re.StatusCode = errors.StatusCodes[err] - } else { - if fn := s.InternalErrorHandler; fn != nil { - if v := fn(err); v != nil { - re = *v - } - } - - if re.Error == nil { - re.Error = errors.ErrServerError - re.Description = errors.Descriptions[errors.ErrServerError] - re.StatusCode = errors.StatusCodes[errors.ErrServerError] - } - } - - if fn := s.ResponseErrorHandler; fn != nil { - fn(&re) - } - - data := make(map[string]interface{}) - if err := re.Error; err != nil { - data["error"] = err.Error() - } - - if v := re.ErrorCode; v != 0 { - data["error_code"] = v - } - - if v := re.Description; v != "" { - data["error_description"] = v - } - - if v := re.URI; v != "" { - data["error_uri"] = v - } - - statusCode := http.StatusInternalServerError - if v := re.StatusCode; v > 0 { - statusCode = v - } - - return data, statusCode, re.Header -} - -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - -// ValidationBearerToken validation the bearer tokens -// https://tools.ietf.org/html/rfc6750 -func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { - ctx := r.Context() - - accessToken, ok := s.BearerAuth(r) - if !ok { - return nil, errors.ErrInvalidAccessToken - } - - return s.Manager.LoadAccessToken(ctx, accessToken) -} |