diff options
Diffstat (limited to 'vendor/github.com/superseriousbusiness/oauth2/v4/manage/manager.go')
-rw-r--r-- | vendor/github.com/superseriousbusiness/oauth2/v4/manage/manager.go | 1004 |
1 files changed, 504 insertions, 500 deletions
diff --git a/vendor/github.com/superseriousbusiness/oauth2/v4/manage/manager.go b/vendor/github.com/superseriousbusiness/oauth2/v4/manage/manager.go index 41811b5ba..063816780 100644 --- a/vendor/github.com/superseriousbusiness/oauth2/v4/manage/manager.go +++ b/vendor/github.com/superseriousbusiness/oauth2/v4/manage/manager.go @@ -1,500 +1,504 @@ -package manage - -import ( - "context" - "time" - - "github.com/superseriousbusiness/oauth2/v4" - "github.com/superseriousbusiness/oauth2/v4/errors" - "github.com/superseriousbusiness/oauth2/v4/generates" - "github.com/superseriousbusiness/oauth2/v4/models" -) - -// NewDefaultManager create to default authorization management instance -func NewDefaultManager() *Manager { - m := NewManager() - // default implementation - m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) - m.MapAccessGenerate(generates.NewAccessGenerate()) - - return m -} - -// NewManager create to authorization management instance -func NewManager() *Manager { - return &Manager{ - gtcfg: make(map[oauth2.GrantType]*Config), - validateURI: DefaultValidateURI, - } -} - -// Manager provide authorization management -type Manager struct { - codeExp time.Duration - gtcfg map[oauth2.GrantType]*Config - rcfg *RefreshingConfig - validateURI ValidateURIHandler - authorizeGenerate oauth2.AuthorizeGenerate - accessGenerate oauth2.AccessGenerate - tokenStore oauth2.TokenStore - clientStore oauth2.ClientStore -} - -// get grant type config -func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { - if c, ok := m.gtcfg[gt]; ok && c != nil { - return c - } - switch gt { - case oauth2.AuthorizationCode: - return DefaultAuthorizeCodeTokenCfg - case oauth2.Implicit: - return DefaultImplicitTokenCfg - case oauth2.PasswordCredentials: - return DefaultPasswordTokenCfg - case oauth2.ClientCredentials: - return DefaultClientTokenCfg - } - return &Config{} -} - -// SetAuthorizeCodeExp set the authorization code expiration time -func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { - m.codeExp = exp -} - -// SetAuthorizeCodeTokenCfg set the authorization code grant token config -func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { - m.gtcfg[oauth2.AuthorizationCode] = cfg -} - -// SetImplicitTokenCfg set the implicit grant token config -func (m *Manager) SetImplicitTokenCfg(cfg *Config) { - m.gtcfg[oauth2.Implicit] = cfg -} - -// SetPasswordTokenCfg set the password grant token config -func (m *Manager) SetPasswordTokenCfg(cfg *Config) { - m.gtcfg[oauth2.PasswordCredentials] = cfg -} - -// SetClientTokenCfg set the client grant token config -func (m *Manager) SetClientTokenCfg(cfg *Config) { - m.gtcfg[oauth2.ClientCredentials] = cfg -} - -// SetRefreshTokenCfg set the refreshing token config -func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { - m.rcfg = cfg -} - -// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI -func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { - m.validateURI = handler -} - -// MapAuthorizeGenerate mapping the authorize code generate interface -func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { - m.authorizeGenerate = gen -} - -// MapAccessGenerate mapping the access token generate interface -func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { - m.accessGenerate = gen -} - -// MapClientStorage mapping the client store interface -func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { - m.clientStore = stor -} - -// MustClientStorage mandatory mapping the client store interface -func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { - if err != nil { - panic(err.Error()) - } - m.clientStore = stor -} - -// MapTokenStorage mapping the token store interface -func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { - m.tokenStore = stor -} - -// MustTokenStorage mandatory mapping the token store interface -func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { - if err != nil { - panic(err) - } - m.tokenStore = stor -} - -// GetClient get the client information -func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { - cli, err = m.clientStore.GetByID(ctx, clientID) - if err != nil { - return - } else if cli == nil { - err = errors.ErrInvalidClient - } - return -} - -// GenerateAuthToken generate the authorization token(code) -func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } else if tgr.RedirectURI != "" { - if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { - return nil, err - } - } - - ti := models.NewToken() - ti.SetClientID(tgr.ClientID) - ti.SetUserID(tgr.UserID) - ti.SetRedirectURI(tgr.RedirectURI) - ti.SetScope(tgr.Scope) - - createAt := time.Now() - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: tgr.UserID, - CreateAt: createAt, - TokenInfo: ti, - Request: tgr.Request, - } - switch rt { - case oauth2.Code: - codeExp := m.codeExp - if codeExp == 0 { - codeExp = DefaultCodeExp - } - ti.SetCodeCreateAt(createAt) - ti.SetCodeExpiresIn(codeExp) - if exp := tgr.AccessTokenExp; exp > 0 { - ti.SetAccessExpiresIn(exp) - } - if tgr.CodeChallenge != "" { - ti.SetCodeChallenge(tgr.CodeChallenge) - ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod) - } - - tv, err := m.authorizeGenerate.Token(ctx, td) - if err != nil { - return nil, err - } - ti.SetCode(tv) - case oauth2.Token: - // set access token expires - icfg := m.grantConfig(oauth2.Implicit) - aexp := icfg.AccessTokenExp - if exp := tgr.AccessTokenExp; exp > 0 { - aexp = exp - } - ti.SetAccessCreateAt(createAt) - ti.SetAccessExpiresIn(aexp) - - if icfg.IsGenerateRefresh { - ti.SetRefreshCreateAt(createAt) - ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) - } - - tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - ti.SetAccess(tv) - - if rv != "" { - ti.SetRefresh(rv) - } - } - - err = m.tokenStore.Create(ctx, ti) - if err != nil { - return nil, err - } - return ti, nil -} - -// get authorization code data -func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - ti, err := m.tokenStore.GetByCode(ctx, code) - if err != nil { - return nil, err - } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { - err = errors.ErrInvalidAuthorizeCode - return nil, errors.ErrInvalidAuthorizeCode - } - return ti, nil -} - -// delete authorization code data -func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { - return m.tokenStore.RemoveByCode(ctx, code) -} - -// get and delete authorization code data -func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - code := tgr.Code - ti, err := m.getAuthorizationCode(ctx, code) - if err != nil { - return nil, err - } else if ti.GetClientID() != tgr.ClientID { - return nil, errors.ErrInvalidAuthorizeCode - } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { - return nil, errors.ErrInvalidAuthorizeCode - } - - err = m.delAuthorizationCode(ctx, code) - if err != nil { - return nil, err - } - return ti, nil -} - -func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error { - cc := ti.GetCodeChallenge() - // early return - if cc == "" && ver == "" { - return nil - } - if cc == "" { - return errors.ErrMissingCodeVerifier - } - if ver == "" { - return errors.ErrMissingCodeVerifier - } - ccm := ti.GetCodeChallengeMethod() - if ccm.String() == "" { - ccm = oauth2.CodeChallengePlain - } - if !ccm.Validate(cc, ver) { - return errors.ErrInvalidCodeChallenge - } - return nil -} - -// GenerateAccessToken generate the access token -func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } - if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { - if !cliPass.VerifyPassword(tgr.ClientSecret) { - return nil, errors.ErrInvalidClient - } - } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() { - return nil, errors.ErrInvalidClient - } - if tgr.RedirectURI != "" { - if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { - return nil, err - } - } - - if gt == oauth2.AuthorizationCode { - ti, err := m.getAndDelAuthorizationCode(ctx, tgr) - if err != nil { - return nil, err - } - if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil { - return nil, err - } - tgr.UserID = ti.GetUserID() - tgr.Scope = ti.GetScope() - if exp := ti.GetAccessExpiresIn(); exp > 0 { - tgr.AccessTokenExp = exp - } - } - - ti := models.NewToken() - ti.SetClientID(tgr.ClientID) - ti.SetUserID(tgr.UserID) - ti.SetRedirectURI(tgr.RedirectURI) - ti.SetScope(tgr.Scope) - - createAt := time.Now() - ti.SetAccessCreateAt(createAt) - - // set access token expires - gcfg := m.grantConfig(gt) - aexp := gcfg.AccessTokenExp - if exp := tgr.AccessTokenExp; exp > 0 { - aexp = exp - } - ti.SetAccessExpiresIn(aexp) - if gcfg.IsGenerateRefresh { - ti.SetRefreshCreateAt(createAt) - ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) - } - - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: tgr.UserID, - CreateAt: createAt, - TokenInfo: ti, - Request: tgr.Request, - } - - av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - ti.SetAccess(av) - - if rv != "" { - ti.SetRefresh(rv) - } - - err = m.tokenStore.Create(ctx, ti) - if err != nil { - return nil, err - } - - return ti, nil -} - -// RefreshAccessToken refreshing an access token -func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } else if tgr.ClientSecret != cli.GetSecret() { - return nil, errors.ErrInvalidClient - } - - ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - return nil, err - } else if ti.GetClientID() != tgr.ClientID { - return nil, errors.ErrInvalidRefreshToken - } - - oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() - - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: ti.GetUserID(), - CreateAt: time.Now(), - TokenInfo: ti, - Request: tgr.Request, - } - - rcfg := DefaultRefreshTokenCfg - if v := m.rcfg; v != nil { - rcfg = v - } - - ti.SetAccessCreateAt(td.CreateAt) - if v := rcfg.AccessTokenExp; v > 0 { - ti.SetAccessExpiresIn(v) - } - - if v := rcfg.RefreshTokenExp; v > 0 { - ti.SetRefreshExpiresIn(v) - } - - if rcfg.IsResetRefreshTime { - ti.SetRefreshCreateAt(td.CreateAt) - } - - if scope := tgr.Scope; scope != "" { - ti.SetScope(scope) - } - - tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - - ti.SetAccess(tv) - if rv != "" { - ti.SetRefresh(rv) - } - - if err := m.tokenStore.Create(ctx, ti); err != nil { - return nil, err - } - - if rcfg.IsRemoveAccess { - // remove the old access token - if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { - return nil, err - } - } - - if rcfg.IsRemoveRefreshing && rv != "" { - // remove the old refresh token - if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { - return nil, err - } - } - - if rv == "" { - ti.SetRefresh("") - ti.SetRefreshCreateAt(time.Now()) - ti.SetRefreshExpiresIn(0) - } - - return ti, nil -} - -// RemoveAccessToken use the access token to delete the token information -func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { - if access == "" { - return errors.ErrInvalidAccessToken - } - return m.tokenStore.RemoveByAccess(ctx, access) -} - -// RemoveRefreshToken use the refresh token to delete the token information -func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { - if refresh == "" { - return errors.ErrInvalidAccessToken - } - return m.tokenStore.RemoveByRefresh(ctx, refresh) -} - -// LoadAccessToken according to the access token for corresponding token information -func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { - if access == "" { - return nil, errors.ErrInvalidAccessToken - } - - ct := time.Now() - ti, err := m.tokenStore.GetByAccess(ctx, access) - if err != nil { - return nil, err - } else if ti == nil || ti.GetAccess() != access { - return nil, errors.ErrInvalidAccessToken - } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { - return nil, errors.ErrExpiredRefreshToken - } else if ti.GetAccessExpiresIn() != 0 && - ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { - return nil, errors.ErrExpiredAccessToken - } - return ti, nil -} - -// LoadRefreshToken according to the refresh token for corresponding token information -func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - if refresh == "" { - return nil, errors.ErrInvalidRefreshToken - } - - ti, err := m.tokenStore.GetByRefresh(ctx, refresh) - if err != nil { - return nil, err - } else if ti == nil || ti.GetRefresh() != refresh { - return nil, errors.ErrInvalidRefreshToken - } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { - return nil, errors.ErrExpiredRefreshToken - } - return ti, nil -} +package manage
+
+import (
+ "context"
+ "time"
+
+ "github.com/superseriousbusiness/oauth2/v4"
+ "github.com/superseriousbusiness/oauth2/v4/errors"
+ "github.com/superseriousbusiness/oauth2/v4/generates"
+ "github.com/superseriousbusiness/oauth2/v4/models"
+)
+
+// NewDefaultManager create to default authorization management instance
+func NewDefaultManager() *Manager {
+ m := NewManager()
+ // default implementation
+ m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
+ m.MapAccessGenerate(generates.NewAccessGenerate())
+
+ return m
+}
+
+// NewManager create to authorization management instance
+func NewManager() *Manager {
+ return &Manager{
+ gtcfg: make(map[oauth2.GrantType]*Config),
+ validateURI: DefaultValidateURI,
+ }
+}
+
+// Manager provide authorization management
+type Manager struct {
+ codeExp time.Duration
+ gtcfg map[oauth2.GrantType]*Config
+ rcfg *RefreshingConfig
+ validateURI ValidateURIHandler
+ authorizeGenerate oauth2.AuthorizeGenerate
+ accessGenerate oauth2.AccessGenerate
+ tokenStore oauth2.TokenStore
+ clientStore oauth2.ClientStore
+}
+
+// get grant type config
+func (m *Manager) grantConfig(gt oauth2.GrantType) *Config {
+ if c, ok := m.gtcfg[gt]; ok && c != nil {
+ return c
+ }
+ switch gt {
+ case oauth2.AuthorizationCode:
+ return DefaultAuthorizeCodeTokenCfg
+ case oauth2.Implicit:
+ return DefaultImplicitTokenCfg
+ case oauth2.PasswordCredentials:
+ return DefaultPasswordTokenCfg
+ case oauth2.ClientCredentials:
+ return DefaultClientTokenCfg
+ }
+ return &Config{}
+}
+
+// SetAuthorizeCodeExp set the authorization code expiration time
+func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) {
+ m.codeExp = exp
+}
+
+// SetAuthorizeCodeTokenCfg set the authorization code grant token config
+func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) {
+ m.gtcfg[oauth2.AuthorizationCode] = cfg
+}
+
+// SetImplicitTokenCfg set the implicit grant token config
+func (m *Manager) SetImplicitTokenCfg(cfg *Config) {
+ m.gtcfg[oauth2.Implicit] = cfg
+}
+
+// SetPasswordTokenCfg set the password grant token config
+func (m *Manager) SetPasswordTokenCfg(cfg *Config) {
+ m.gtcfg[oauth2.PasswordCredentials] = cfg
+}
+
+// SetClientTokenCfg set the client grant token config
+func (m *Manager) SetClientTokenCfg(cfg *Config) {
+ m.gtcfg[oauth2.ClientCredentials] = cfg
+}
+
+// SetRefreshTokenCfg set the refreshing token config
+func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) {
+ m.rcfg = cfg
+}
+
+// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI
+func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) {
+ m.validateURI = handler
+}
+
+// MapAuthorizeGenerate mapping the authorize code generate interface
+func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) {
+ m.authorizeGenerate = gen
+}
+
+// MapAccessGenerate mapping the access token generate interface
+func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) {
+ m.accessGenerate = gen
+}
+
+// MapClientStorage mapping the client store interface
+func (m *Manager) MapClientStorage(stor oauth2.ClientStore) {
+ m.clientStore = stor
+}
+
+// MustClientStorage mandatory mapping the client store interface
+func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) {
+ if err != nil {
+ panic(err.Error())
+ }
+ m.clientStore = stor
+}
+
+// MapTokenStorage mapping the token store interface
+func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) {
+ m.tokenStore = stor
+}
+
+// MustTokenStorage mandatory mapping the token store interface
+func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) {
+ if err != nil {
+ panic(err)
+ }
+ m.tokenStore = stor
+}
+
+// GetClient get the client information
+func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) {
+ cli, err = m.clientStore.GetByID(ctx, clientID)
+ if err != nil {
+ return
+ } else if cli == nil {
+ err = errors.ErrInvalidClient
+ }
+ return
+}
+
+// GenerateAuthToken generate the authorization token(code)
+func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
+ cli, err := m.GetClient(ctx, tgr.ClientID)
+ if err != nil {
+ return nil, err
+ } else if tgr.RedirectURI != "" {
+ if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
+ return nil, err
+ }
+ }
+
+ ti := models.NewToken()
+ ti.SetClientID(tgr.ClientID)
+ ti.SetUserID(tgr.UserID)
+ ti.SetRedirectURI(tgr.RedirectURI)
+ ti.SetScope(tgr.Scope)
+
+ createAt := time.Now()
+ td := &oauth2.GenerateBasic{
+ Client: cli,
+ UserID: tgr.UserID,
+ CreateAt: createAt,
+ TokenInfo: ti,
+ Request: tgr.Request,
+ }
+ switch rt {
+ case oauth2.Code:
+ codeExp := m.codeExp
+ if codeExp == 0 {
+ codeExp = DefaultCodeExp
+ }
+ ti.SetCodeCreateAt(createAt)
+ ti.SetCodeExpiresIn(codeExp)
+ if exp := tgr.AccessTokenExp; exp > 0 {
+ ti.SetAccessExpiresIn(exp)
+ }
+ if tgr.CodeChallenge != "" {
+ ti.SetCodeChallenge(tgr.CodeChallenge)
+ ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod)
+ }
+
+ tv, err := m.authorizeGenerate.Token(ctx, td)
+ if err != nil {
+ return nil, err
+ }
+ ti.SetCode(tv)
+ case oauth2.Token:
+ // set access token expires
+ icfg := m.grantConfig(oauth2.Implicit)
+ aexp := icfg.AccessTokenExp
+ if exp := tgr.AccessTokenExp; exp > 0 {
+ aexp = exp
+ }
+ ti.SetAccessCreateAt(createAt)
+ ti.SetAccessExpiresIn(aexp)
+
+ if icfg.IsGenerateRefresh {
+ ti.SetRefreshCreateAt(createAt)
+ ti.SetRefreshExpiresIn(icfg.RefreshTokenExp)
+ }
+
+ tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh)
+ if err != nil {
+ return nil, err
+ }
+ ti.SetAccess(tv)
+
+ if rv != "" {
+ ti.SetRefresh(rv)
+ }
+ }
+
+ err = m.tokenStore.Create(ctx, ti)
+ if err != nil {
+ return nil, err
+ }
+ return ti, nil
+}
+
+// get authorization code data
+func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
+ ti, err := m.tokenStore.GetByCode(ctx, code)
+ if err != nil {
+ return nil, err
+ } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) {
+ err = errors.ErrInvalidAuthorizeCode
+ return nil, errors.ErrInvalidAuthorizeCode
+ }
+ return ti, nil
+}
+
+// delete authorization code data
+func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error {
+ return m.tokenStore.RemoveByCode(ctx, code)
+}
+
+// get and delete authorization code data
+func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
+ code := tgr.Code
+ ti, err := m.getAuthorizationCode(ctx, code)
+ if err != nil {
+ return nil, err
+ } else if ti.GetClientID() != tgr.ClientID {
+ return nil, errors.ErrInvalidAuthorizeCode
+ } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI {
+ return nil, errors.ErrInvalidAuthorizeCode
+ }
+
+ err = m.delAuthorizationCode(ctx, code)
+ if err != nil {
+ return nil, err
+ }
+ return ti, nil
+}
+
+func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error {
+ cc := ti.GetCodeChallenge()
+ // early return
+ if cc == "" && ver == "" {
+ return nil
+ }
+ if cc == "" {
+ return errors.ErrMissingCodeVerifier
+ }
+ if ver == "" {
+ return errors.ErrMissingCodeVerifier
+ }
+ ccm := ti.GetCodeChallengeMethod()
+ if ccm.String() == "" {
+ ccm = oauth2.CodeChallengePlain
+ }
+ if !ccm.Validate(cc, ver) {
+ return errors.ErrInvalidCodeChallenge
+ }
+ return nil
+}
+
+// GenerateAccessToken generate the access token
+func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
+ cli, err := m.GetClient(ctx, tgr.ClientID)
+ if err != nil {
+ return nil, err
+ }
+ if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
+ if !cliPass.VerifyPassword(tgr.ClientSecret) {
+ return nil, errors.ErrInvalidClient
+ }
+ } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() {
+ return nil, errors.ErrInvalidClient
+ }
+ if tgr.RedirectURI != "" {
+ if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
+ return nil, err
+ }
+ }
+
+ if gt == oauth2.AuthorizationCode {
+ ti, err := m.getAndDelAuthorizationCode(ctx, tgr)
+ if err != nil {
+ return nil, err
+ }
+ if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil {
+ return nil, err
+ }
+ tgr.UserID = ti.GetUserID()
+ tgr.Scope = ti.GetScope()
+ if exp := ti.GetAccessExpiresIn(); exp > 0 {
+ tgr.AccessTokenExp = exp
+ }
+ }
+
+ ti := models.NewToken()
+ ti.SetClientID(tgr.ClientID)
+ ti.SetUserID(tgr.UserID)
+ ti.SetRedirectURI(tgr.RedirectURI)
+ ti.SetScope(tgr.Scope)
+
+ createAt := time.Now()
+ ti.SetAccessCreateAt(createAt)
+
+ // set access token expires
+ gcfg := m.grantConfig(gt)
+ aexp := gcfg.AccessTokenExp
+ if exp := tgr.AccessTokenExp; exp > 0 {
+ aexp = exp
+ }
+ ti.SetAccessExpiresIn(aexp)
+ if gcfg.IsGenerateRefresh {
+ ti.SetRefreshCreateAt(createAt)
+ ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp)
+ }
+
+ td := &oauth2.GenerateBasic{
+ Client: cli,
+ UserID: tgr.UserID,
+ CreateAt: createAt,
+ TokenInfo: ti,
+ Request: tgr.Request,
+ }
+
+ av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh)
+ if err != nil {
+ return nil, err
+ }
+ ti.SetAccess(av)
+
+ if rv != "" {
+ ti.SetRefresh(rv)
+ }
+
+ err = m.tokenStore.Create(ctx, ti)
+ if err != nil {
+ return nil, err
+ }
+
+ return ti, nil
+}
+
+// RefreshAccessToken refreshing an access token
+func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
+ cli, err := m.GetClient(ctx, tgr.ClientID)
+ if err != nil {
+ return nil, err
+ } else if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
+ if !cliPass.VerifyPassword(tgr.ClientSecret) {
+ return nil, errors.ErrInvalidClient
+ }
+ } else if tgr.ClientSecret != cli.GetSecret() {
+ return nil, errors.ErrInvalidClient
+ }
+
+ ti, err := m.LoadRefreshToken(ctx, tgr.Refresh)
+ if err != nil {
+ return nil, err
+ } else if ti.GetClientID() != tgr.ClientID {
+ return nil, errors.ErrInvalidRefreshToken
+ }
+
+ oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
+
+ td := &oauth2.GenerateBasic{
+ Client: cli,
+ UserID: ti.GetUserID(),
+ CreateAt: time.Now(),
+ TokenInfo: ti,
+ Request: tgr.Request,
+ }
+
+ rcfg := DefaultRefreshTokenCfg
+ if v := m.rcfg; v != nil {
+ rcfg = v
+ }
+
+ ti.SetAccessCreateAt(td.CreateAt)
+ if v := rcfg.AccessTokenExp; v > 0 {
+ ti.SetAccessExpiresIn(v)
+ }
+
+ if v := rcfg.RefreshTokenExp; v > 0 {
+ ti.SetRefreshExpiresIn(v)
+ }
+
+ if rcfg.IsResetRefreshTime {
+ ti.SetRefreshCreateAt(td.CreateAt)
+ }
+
+ if scope := tgr.Scope; scope != "" {
+ ti.SetScope(scope)
+ }
+
+ tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh)
+ if err != nil {
+ return nil, err
+ }
+
+ ti.SetAccess(tv)
+ if rv != "" {
+ ti.SetRefresh(rv)
+ }
+
+ if err := m.tokenStore.Create(ctx, ti); err != nil {
+ return nil, err
+ }
+
+ if rcfg.IsRemoveAccess {
+ // remove the old access token
+ if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil {
+ return nil, err
+ }
+ }
+
+ if rcfg.IsRemoveRefreshing && rv != "" {
+ // remove the old refresh token
+ if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil {
+ return nil, err
+ }
+ }
+
+ if rv == "" {
+ ti.SetRefresh("")
+ ti.SetRefreshCreateAt(time.Now())
+ ti.SetRefreshExpiresIn(0)
+ }
+
+ return ti, nil
+}
+
+// RemoveAccessToken use the access token to delete the token information
+func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error {
+ if access == "" {
+ return errors.ErrInvalidAccessToken
+ }
+ return m.tokenStore.RemoveByAccess(ctx, access)
+}
+
+// RemoveRefreshToken use the refresh token to delete the token information
+func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error {
+ if refresh == "" {
+ return errors.ErrInvalidAccessToken
+ }
+ return m.tokenStore.RemoveByRefresh(ctx, refresh)
+}
+
+// LoadAccessToken according to the access token for corresponding token information
+func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) {
+ if access == "" {
+ return nil, errors.ErrInvalidAccessToken
+ }
+
+ ct := time.Now()
+ ti, err := m.tokenStore.GetByAccess(ctx, access)
+ if err != nil {
+ return nil, err
+ } else if ti == nil || ti.GetAccess() != access {
+ return nil, errors.ErrInvalidAccessToken
+ } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
+ ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
+ return nil, errors.ErrExpiredRefreshToken
+ } else if ti.GetAccessExpiresIn() != 0 &&
+ ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
+ return nil, errors.ErrExpiredAccessToken
+ }
+ return ti, nil
+}
+
+// LoadRefreshToken according to the refresh token for corresponding token information
+func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
+ if refresh == "" {
+ return nil, errors.ErrInvalidRefreshToken
+ }
+
+ ti, err := m.tokenStore.GetByRefresh(ctx, refresh)
+ if err != nil {
+ return nil, err
+ } else if ti == nil || ti.GetRefresh() != refresh {
+ return nil, errors.ErrInvalidRefreshToken
+ } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
+ ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
+ return nil, errors.ErrExpiredRefreshToken
+ }
+ return ti, nil
+}
|