diff options
Diffstat (limited to 'vendor/code.superseriousbusiness.org/oauth2')
15 files changed, 280 insertions, 148 deletions
diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/.travis.yml b/vendor/code.superseriousbusiness.org/oauth2/v4/.travis.yml new file mode 100644 index 000000000..4180c8dcb --- /dev/null +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/.travis.yml @@ -0,0 +1,13 @@ +language: go +sudo: false +go_import_path: github.com/go-oauth2/oauth2/v4 +go: + - 1.13 +before_install: + - go get -t -v ./... + +script: + - chmod +x ./go.test.sh && ./go.test.sh + +after_success: + - bash <(curl -s https://codecov.io/bash)
\ No newline at end of file diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/README.MD b/vendor/code.superseriousbusiness.org/oauth2/v4/README.MD new file mode 100644 index 000000000..b7145e2f9 --- /dev/null +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/README.MD @@ -0,0 +1,9 @@ +# oauth2 + +Fork of https://github.com/go-oauth2/oauth2 + +You should probably use the original upstream library, this is largely for our own usecase in [GoToSocial](https://codeberg.org/superseriousbusiness/gotosocial), and provides zero compatibility guarantees between versions. + +Versioning is a little complex but attempts to follow upstream as `ssb-v4.5.3-x` where `ssb` = superseriousbusiness, `v4.5.3` is the upstream version, and `x` is our own revision starting at 1 for each fresh upstream release version. + +A copy of the upstream development branch can be found at `upstream-main`, and our main branch will (MUST!) be kept up-to-date with this with regular rebases.
\ No newline at end of file diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/README.md b/vendor/code.superseriousbusiness.org/oauth2/v4/README.md deleted file mode 100644 index 25297aca0..000000000 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Golang OAuth 2.0 Server - -Forked from [go-oauth2](https://github.com/go-oauth2/oauth2). diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/generates/access.go b/vendor/code.superseriousbusiness.org/oauth2/v4/generates/access.go index 972b5dce1..ca66f840a 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/generates/access.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/generates/access.go @@ -1,38 +1,38 @@ -package generates - -import ( - "bytes" - "context" - "encoding/base64" - "strconv" - "strings" - - "code.superseriousbusiness.org/oauth2/v4" - "github.com/google/uuid" -) - -// NewAccessGenerate create to generate the access token instance -func NewAccessGenerate() *AccessGenerate { - return &AccessGenerate{} -} - -// AccessGenerate generate the access token -type AccessGenerate struct { -} - -// Token based on the UUID generated token -func (ag *AccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) { - buf := bytes.NewBufferString(data.Client.GetID()) - buf.WriteString(data.UserID) - buf.WriteString(strconv.FormatInt(data.CreateAt.UnixNano(), 10)) - - access := base64.URLEncoding.EncodeToString([]byte(uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes()).String())) - access = strings.ToUpper(strings.TrimRight(access, "=")) - refresh := "" - if isGenRefresh { - refresh = base64.URLEncoding.EncodeToString([]byte(uuid.NewSHA1(uuid.Must(uuid.NewRandom()), buf.Bytes()).String())) - refresh = strings.ToUpper(strings.TrimRight(refresh, "=")) - } - - return access, refresh, nil -} +package generates
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "strconv"
+ "strings"
+
+ "code.superseriousbusiness.org/oauth2/v4"
+ "github.com/google/uuid"
+)
+
+// NewAccessGenerate create to generate the access token instance
+func NewAccessGenerate() *AccessGenerate {
+ return &AccessGenerate{}
+}
+
+// AccessGenerate generate the access token
+type AccessGenerate struct {
+}
+
+// Token based on the UUID generated token
+func (ag *AccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
+ buf := bytes.NewBufferString(data.Client.GetID())
+ buf.WriteString(data.UserID)
+ buf.WriteString(strconv.FormatInt(data.CreateAt.UnixNano(), 10))
+
+ access := base64.URLEncoding.EncodeToString([]byte(uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes()).String()))
+ access = strings.ToUpper(strings.TrimRight(access, "="))
+ refresh := ""
+ if isGenRefresh {
+ refresh = base64.URLEncoding.EncodeToString([]byte(uuid.NewSHA1(uuid.Must(uuid.NewRandom()), buf.Bytes()).String()))
+ refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
+ }
+
+ return access, refresh, nil
+}
diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/generates/authorize.go b/vendor/code.superseriousbusiness.org/oauth2/v4/generates/authorize.go index 9d8f3fb45..0a4784903 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/generates/authorize.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/generates/authorize.go @@ -1,30 +1,30 @@ -package generates - -import ( - "bytes" - "context" - "encoding/base64" - "strings" - - "code.superseriousbusiness.org/oauth2/v4" - "github.com/google/uuid" -) - -// NewAuthorizeGenerate create to generate the authorize code instance -func NewAuthorizeGenerate() *AuthorizeGenerate { - return &AuthorizeGenerate{} -} - -// AuthorizeGenerate generate the authorize code -type AuthorizeGenerate struct{} - -// Token based on the UUID generated token -func (ag *AuthorizeGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic) (string, error) { - buf := bytes.NewBufferString(data.Client.GetID()) - buf.WriteString(data.UserID) - token := uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes()) - code := base64.URLEncoding.EncodeToString([]byte(token.String())) - code = strings.ToUpper(strings.TrimRight(code, "=")) - - return code, nil -} +package generates
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "strings"
+
+ "code.superseriousbusiness.org/oauth2/v4"
+ "github.com/google/uuid"
+)
+
+// NewAuthorizeGenerate create to generate the authorize code instance
+func NewAuthorizeGenerate() *AuthorizeGenerate {
+ return &AuthorizeGenerate{}
+}
+
+// AuthorizeGenerate generate the authorize code
+type AuthorizeGenerate struct{}
+
+// Token based on the UUID generated token
+func (ag *AuthorizeGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic) (string, error) {
+ buf := bytes.NewBufferString(data.Client.GetID())
+ buf.WriteString(data.UserID)
+ token := uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes())
+ code := base64.URLEncoding.EncodeToString([]byte(token.String()))
+ code = strings.ToUpper(strings.TrimRight(code, "="))
+
+ return code, nil
+}
diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/generates/jwt_access.go b/vendor/code.superseriousbusiness.org/oauth2/v4/generates/jwt_access.go index 57c2950f0..10021812b 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/generates/jwt_access.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/generates/jwt_access.go @@ -8,18 +8,18 @@ import ( "code.superseriousbusiness.org/oauth2/v4" "code.superseriousbusiness.org/oauth2/v4/errors" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) // JWTAccessClaims jwt claims type JWTAccessClaims struct { - jwt.StandardClaims + jwt.RegisteredClaims } // Valid claims verification func (a *JWTAccessClaims) Valid() error { - if time.Unix(a.ExpiresAt, 0).Before(time.Now()) { + if a.ExpiresAt != nil && time.Unix(a.ExpiresAt.Unix(), 0).Before(time.Now()) { return errors.ErrInvalidAccessToken } return nil @@ -44,10 +44,10 @@ type JWTAccessGenerate struct { // Token based on the UUID generated token func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) { claims := &JWTAccessClaims{ - StandardClaims: jwt.StandardClaims{ - Audience: data.Client.GetID(), + RegisteredClaims: jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{data.Client.GetID()}, Subject: data.UserID, - ExpiresAt: data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn()).Unix(), + ExpiresAt: jwt.NewNumericDate(data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn())), }, } @@ -70,6 +70,12 @@ func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasi key = v } else if a.isHs() { key = a.SignedKey + } else if a.isEd() { + v, err := jwt.ParseEdPrivateKeyFromPEM(a.SignedKey) + if err != nil { + return "", "", err + } + key = v } else { return "", "", errors.New("unsupported sign method") } @@ -102,3 +108,7 @@ func (a *JWTAccessGenerate) isRsOrPS() bool { func (a *JWTAccessGenerate) isHs() bool { return strings.HasPrefix(a.SignedMethod.Alg(), "HS") } + +func (a *JWTAccessGenerate) isEd() bool { + return strings.HasPrefix(a.SignedMethod.Alg(), "Ed") +} diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/manage.go b/vendor/code.superseriousbusiness.org/oauth2/v4/manage.go index 5c0bdf871..23f2b3d31 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/manage.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/manage.go @@ -31,7 +31,7 @@ type Manager interface { GenerateAuthToken(ctx context.Context, rt ResponseType, tgr *TokenGenerateRequest) (authToken TokenInfo, err error) // generate the access token - GenerateAccessToken(ctx context.Context, rt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) + GenerateAccessToken(ctx context.Context, gt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) // refreshing an access token RefreshAccessToken(ctx context.Context, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/manage/manager.go b/vendor/code.superseriousbusiness.org/oauth2/v4/manage/manager.go index db9aba614..0c2d8a48a 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/manage/manager.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/manage/manager.go @@ -2,6 +2,7 @@ package manage import ( "context" + "net/url" "time" "code.superseriousbusiness.org/oauth2/v4" @@ -34,6 +35,7 @@ type Manager struct { gtcfg map[oauth2.GrantType]*Config rcfg *RefreshingConfig validateURI ValidateURIHandler + extractExtension ExtractExtensionHandler authorizeGenerate oauth2.AuthorizeGenerate accessGenerate oauth2.AccessGenerate tokenStore oauth2.TokenStore @@ -93,6 +95,11 @@ func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { m.validateURI = handler } +// SetExtractExtensionHandler set the token extension extractor +func (m *Manager) SetExtractExtensionHandler(handler ExtractExtensionHandler) { + m.extractExtension = handler +} + // MapAuthorizeGenerate mapping the authorize code generate interface func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { m.authorizeGenerate = gen @@ -152,6 +159,9 @@ func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, } ti := models.NewToken() + if m.extractExtension != nil { + m.extractExtension(tgr, ti) + } ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) @@ -296,6 +306,12 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, } } + if gt == oauth2.ClientCredentials && cli.IsPublic() == true { + return nil, errors.ErrInvalidClient + } + + var extension url.Values + if gt == oauth2.AuthorizationCode { ti, err := m.getAndDelAuthorizationCode(ctx, tgr) if err != nil { @@ -309,9 +325,16 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, if exp := ti.GetAccessExpiresIn(); exp > 0 { tgr.AccessTokenExp = exp } + if eti, ok := ti.(oauth2.ExtendableTokenInfo); ok { + extension = eti.GetExtension() + } } ti := models.NewToken() + ti.SetExtension(extension) + if m.extractExtension != nil { + m.extractExtension(tgr, ti) + } ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) @@ -360,22 +383,14 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, // RefreshAccessToken refreshing an access token func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) + ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { return nil, err - } else if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { - if !cliPass.VerifyPassword(tgr.ClientSecret) { - return nil, errors.ErrInvalidClient - } - } else if tgr.ClientSecret != cli.GetSecret() { - return nil, errors.ErrInvalidClient } - ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) + cli, err := m.GetClient(ctx, ti.GetClientID()) if err != nil { return nil, err - } else if ti.GetClientID() != tgr.ClientID { - return nil, errors.ErrInvalidRefreshToken } oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/manage/util.go b/vendor/code.superseriousbusiness.org/oauth2/v4/manage/util.go index fc4c4b610..733f880f5 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/manage/util.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/manage/util.go @@ -1,6 +1,7 @@ package manage import ( + "code.superseriousbusiness.org/oauth2/v4" "net/url" "strings" @@ -9,7 +10,8 @@ import ( type ( // ValidateURIHandler validates that redirectURI is contained in baseURI - ValidateURIHandler func(baseURI, redirectURI string) error + ValidateURIHandler func(baseURI, redirectURI string) error + ExtractExtensionHandler func(*oauth2.TokenGenerateRequest, oauth2.ExtendableTokenInfo) ) // DefaultValidateURI validates that redirectURI is contained in baseURI diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/model.go b/vendor/code.superseriousbusiness.org/oauth2/v4/model.go index 121a42d67..9073e40ef 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/model.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/model.go @@ -1,6 +1,7 @@ package oauth2 import ( + "net/url" "time" ) @@ -10,6 +11,7 @@ type ( GetID() string GetSecret() string GetDomain() string + IsPublic() bool GetUserID() string } @@ -56,4 +58,10 @@ type ( GetRefreshExpiresIn() time.Duration SetRefreshExpiresIn(time.Duration) } + + ExtendableTokenInfo interface { + TokenInfo + GetExtension() url.Values + SetExtension(url.Values) + } ) diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/models/client.go b/vendor/code.superseriousbusiness.org/oauth2/v4/models/client.go index 2006b6712..c31ad7fb0 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/models/client.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/models/client.go @@ -1,46 +1,35 @@ package models // Client client model -type Client interface { - GetID() string - GetSecret() string - GetDomain() string - GetUserID() string -} - -func New(id string, secret string, domain string, userID string) Client { - return &simpleClient{ - id: id, - secret: secret, - domain: domain, - userID: userID, - } -} - -// simpleClient is a very simple client model that satisfies the Client interface -type simpleClient struct { - id string - secret string - domain string - userID string +type Client struct { + ID string + Secret string + Domain string + Public bool + UserID string } // GetID client id -func (c *simpleClient) GetID() string { - return c.id +func (c *Client) GetID() string { + return c.ID } // GetSecret client secret -func (c *simpleClient) GetSecret() string { - return c.secret +func (c *Client) GetSecret() string { + return c.Secret } // GetDomain client domain -func (c *simpleClient) GetDomain() string { - return c.domain +func (c *Client) GetDomain() string { + return c.Domain +} + +// IsPublic public +func (c *Client) IsPublic() bool { + return c.Public } // GetUserID user id -func (c *simpleClient) GetUserID() string { - return c.userID +func (c *Client) GetUserID() string { + return c.UserID } diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/models/token.go b/vendor/code.superseriousbusiness.org/oauth2/v4/models/token.go index e14868e51..94756dcd4 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/models/token.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/models/token.go @@ -1,6 +1,7 @@ package models import ( + "net/url" "time" "code.superseriousbusiness.org/oauth2/v4" @@ -8,7 +9,7 @@ import ( // NewToken create to token model instance func NewToken() *Token { - return &Token{} + return &Token{Extension: make(url.Values)} } // Token token model @@ -28,6 +29,7 @@ type Token struct { Refresh string `bson:"Refresh"` RefreshCreateAt time.Time `bson:"RefreshCreateAt"` RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"` + Extension url.Values `bson:"Extension"` } // New create to token model instance @@ -184,3 +186,13 @@ func (t *Token) GetRefreshExpiresIn() time.Duration { func (t *Token) SetRefreshExpiresIn(exp time.Duration) { t.RefreshExpiresIn = exp } + +// GetExtension extension of token +func (t *Token) GetExtension() url.Values { + return t.Extension +} + +// SetExtension set extension of token +func (t *Token) SetExtension(e url.Values) { + t.Extension = e +} diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/server/handler.go b/vendor/code.superseriousbusiness.org/oauth2/v4/server/handler.go index 745716dc5..808b476d8 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/server/handler.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/server/handler.go @@ -1,7 +1,9 @@ package server import ( + "context" "net/http" + "strings" "time" "code.superseriousbusiness.org/oauth2/v4" @@ -22,7 +24,7 @@ type ( UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) // PasswordAuthorizationHandler get user id from username and password - PasswordAuthorizationHandler func(username, password string) (userID string, err error) + PasswordAuthorizationHandler func(ctx context.Context, clientID, username, password string) (userID string, err error) // RefreshingScopeHandler check the scope of the refreshing token RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) @@ -36,6 +38,9 @@ type ( // InternalErrorHandler internal error handing InternalErrorHandler func(err error) (re *errors.Response) + // PreRedirectErrorHandler is used to override "redirect-on-error" behavior + PreRedirectErrorHandler func(w http.ResponseWriter, req *AuthorizeRequest, err error) error + // AuthorizeScopeHandler set the authorized scope AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) @@ -44,6 +49,15 @@ type ( // ExtensionFieldsHandler in response to the access token with the extension of the field ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) + + // ResponseTokenHandler response token handling + ResponseTokenHandler func(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error + + // Handler to fetch the refresh token from the request + RefreshTokenResolveHandler func(r *http.Request) (string, error) + + // Handler to fetch the access token from the request + AccessTokenResolveHandler func(r *http.Request) (string, bool) ) // ClientFormHandler get client data from form @@ -64,3 +78,44 @@ func ClientBasicHandler(r *http.Request) (string, string, error) { } return username, password, nil } + +func RefreshTokenFormResolveHandler(r *http.Request) (string, error) { + rt := r.FormValue("refresh_token") + if rt == "" { + return "", errors.ErrInvalidRequest + } + + return rt, nil +} + +func RefreshTokenCookieResolveHandler(r *http.Request) (string, error) { + c, err := r.Cookie("refresh_token") + if err != nil { + return "", errors.ErrInvalidRequest + } + + return c.Value, nil +} + +func AccessTokenDefaultResolveHandler(r *http.Request) (string, bool) { + token := "" + auth := r.Header.Get("Authorization") + prefix := "Bearer " + + if auth != "" && strings.HasPrefix(auth, prefix) { + token = auth[len(prefix):] + } else { + token = r.FormValue("access_token") + } + + return token, token != "" +} + +func AccessTokenCookieResolveHandler(r *http.Request) (string, bool) { + c, err := r.Cookie("access_token") + if err != nil { + return "", false + } + + return c.Value, true +} diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go index 91b9effb7..c0a59b755 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/url" - "strings" "time" "code.superseriousbusiness.org/oauth2/v4" @@ -26,14 +25,16 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server { Manager: manager, } - // default handler + // default handlers srv.ClientInfoHandler = ClientBasicHandler + srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler + srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied } - srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { + srv.PasswordAuthorizationHandler = func(ctx context.Context, clientID, username, password string) (string, error) { return "", errors.ErrAccessDenied } return srv @@ -49,18 +50,31 @@ type Server struct { UserAuthorizationHandler UserAuthorizationHandler PasswordAuthorizationHandler PasswordAuthorizationHandler RefreshingValidationHandler RefreshingValidationHandler + PreRedirectErrorHandler PreRedirectErrorHandler RefreshingScopeHandler RefreshingScopeHandler ResponseErrorHandler ResponseErrorHandler InternalErrorHandler InternalErrorHandler ExtensionFieldsHandler ExtensionFieldsHandler AccessTokenExpHandler AccessTokenExpHandler AuthorizeScopeHandler AuthorizeScopeHandler + ResponseTokenHandler ResponseTokenHandler + RefreshTokenResolveHandler RefreshTokenResolveHandler + AccessTokenResolveHandler AccessTokenResolveHandler +} + +func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { + if fn := s.PreRedirectErrorHandler; fn != nil { + return fn(w, req, err) + } + + return s.redirectError(w, req, err) } func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if req == nil { return err } + data, _, _ := s.GetErrorData(err) return s.redirect(w, req, data) } @@ -82,6 +96,9 @@ func (s *Server) tokenError(w http.ResponseWriter, err error) error { } func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { + if fn := s.ResponseTokenHandler; fn != nil { + return fn(w, data, header, statusCode...) + } w.Header().Set("Content-Type", "application/json;charset=UTF-8") w.Header().Set("Cache-Control", "no-store") w.Header().Set("Pragma", "no-cache") @@ -182,7 +199,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, oauth2.CodeChallengePlain, ) } - if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) { + if ccm != "" && !s.CheckCodeChallengeMethod(ccm) { return nil, errors.ErrUnsupportedCodeChallengeMethod } @@ -257,13 +274,13 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) req, err := s.ValidationAuthorizeRequest(r) if err != nil { - return s.redirectError(w, req, err) + return s.handleError(w, req, err) } // user authorization userID, err := s.UserAuthorizationHandler(w, r) if err != nil { - return s.redirectError(w, req, err) + return s.handleError(w, req, err) } else if userID == "" { return nil } @@ -290,7 +307,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) ti, err := s.GetAuthorizeToken(ctx, req) if err != nil { - return s.redirectError(w, req, err) + return s.handleError(w, req, err) } // If the redirect URI is empty, the default domain provided by the client is used. @@ -351,7 +368,7 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau return "", nil, errors.ErrInvalidRequest } - userID, err := s.PasswordAuthorizationHandler(username, password) + userID, err := s.PasswordAuthorizationHandler(r.Context(), clientID, username, password) if err != nil { return "", nil, err } else if userID == "" { @@ -362,10 +379,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau tgr.Scope = r.FormValue("scope") tgr.RedirectURI = r.FormValue("redirect_uri") case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") + tgr.Refresh, err = s.RefreshTokenResolveHandler(r) tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest + if err != nil { + return "", nil, err } } return gt, tgr, nil @@ -564,27 +581,12 @@ func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Head return data, statusCode, re.Header } -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - // ValidationBearerToken validation the bearer tokens // https://tools.ietf.org/html/rfc6750 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { ctx := r.Context() - accessToken, ok := s.BearerAuth(r) + accessToken, ok := s.AccessTokenResolveHandler(r) if !ok { return nil, errors.ErrInvalidAccessToken } diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server_config.go b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server_config.go index 4e8010196..70a8b2c11 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/server/server_config.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/server/server_config.go @@ -69,6 +69,11 @@ func (s *Server) SetInternalErrorHandler(handler InternalErrorHandler) { s.InternalErrorHandler = handler } +// SetPreRedirectErrorHandler sets the PreRedirectErrorHandler in current Server instance +func (s *Server) SetPreRedirectErrorHandler(handler PreRedirectErrorHandler) { + s.PreRedirectErrorHandler = handler +} + // SetExtensionFieldsHandler in response to the access token with the extension of the field func (s *Server) SetExtensionFieldsHandler(handler ExtensionFieldsHandler) { s.ExtensionFieldsHandler = handler @@ -83,3 +88,18 @@ func (s *Server) SetAccessTokenExpHandler(handler AccessTokenExpHandler) { func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) { s.AuthorizeScopeHandler = handler } + +// SetResponseTokenHandler response token handing +func (s *Server) SetResponseTokenHandler(handler ResponseTokenHandler) { + s.ResponseTokenHandler = handler +} + +// SetRefreshTokenResolveHandler refresh token resolver +func (s *Server) SetRefreshTokenResolveHandler(handler RefreshTokenResolveHandler) { + s.RefreshTokenResolveHandler = handler +} + +// SetAccessTokenResolveHandler access token resolver +func (s *Server) SetAccessTokenResolveHandler(handler AccessTokenResolveHandler) { + s.AccessTokenResolveHandler = handler +} |
