diff options
Diffstat (limited to 'vendor/code.superseriousbusiness.org/oauth2/v4/manage/manager.go')
| -rw-r--r-- | vendor/code.superseriousbusiness.org/oauth2/v4/manage/manager.go | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/vendor/code.superseriousbusiness.org/oauth2/v4/manage/manager.go b/vendor/code.superseriousbusiness.org/oauth2/v4/manage/manager.go index db9aba614..0c2d8a48a 100644 --- a/vendor/code.superseriousbusiness.org/oauth2/v4/manage/manager.go +++ b/vendor/code.superseriousbusiness.org/oauth2/v4/manage/manager.go @@ -2,6 +2,7 @@ package manage import ( "context" + "net/url" "time" "code.superseriousbusiness.org/oauth2/v4" @@ -34,6 +35,7 @@ type Manager struct { gtcfg map[oauth2.GrantType]*Config rcfg *RefreshingConfig validateURI ValidateURIHandler + extractExtension ExtractExtensionHandler authorizeGenerate oauth2.AuthorizeGenerate accessGenerate oauth2.AccessGenerate tokenStore oauth2.TokenStore @@ -93,6 +95,11 @@ func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { m.validateURI = handler } +// SetExtractExtensionHandler set the token extension extractor +func (m *Manager) SetExtractExtensionHandler(handler ExtractExtensionHandler) { + m.extractExtension = handler +} + // MapAuthorizeGenerate mapping the authorize code generate interface func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { m.authorizeGenerate = gen @@ -152,6 +159,9 @@ func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, } ti := models.NewToken() + if m.extractExtension != nil { + m.extractExtension(tgr, ti) + } ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) @@ -296,6 +306,12 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, } } + if gt == oauth2.ClientCredentials && cli.IsPublic() == true { + return nil, errors.ErrInvalidClient + } + + var extension url.Values + if gt == oauth2.AuthorizationCode { ti, err := m.getAndDelAuthorizationCode(ctx, tgr) if err != nil { @@ -309,9 +325,16 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, if exp := ti.GetAccessExpiresIn(); exp > 0 { tgr.AccessTokenExp = exp } + if eti, ok := ti.(oauth2.ExtendableTokenInfo); ok { + extension = eti.GetExtension() + } } ti := models.NewToken() + ti.SetExtension(extension) + if m.extractExtension != nil { + m.extractExtension(tgr, ti) + } ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) @@ -360,22 +383,14 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, // RefreshAccessToken refreshing an access token func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) + ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { return nil, err - } else if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { - if !cliPass.VerifyPassword(tgr.ClientSecret) { - return nil, errors.ErrInvalidClient - } - } else if tgr.ClientSecret != cli.GetSecret() { - return nil, errors.ErrInvalidClient } - ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) + cli, err := m.GetClient(ctx, ti.GetClientID()) if err != nil { return nil, err - } else if ti.GetClientID() != tgr.ClientID { - return nil, errors.ErrInvalidRefreshToken } oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() |
