package generates

import (
	"context"
	"encoding/base64"
	"strings"
	"time"

	"code.superseriousbusiness.org/oauth2/v4"
	"code.superseriousbusiness.org/oauth2/v4/errors"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
)

// JWTAccessClaims is the claim set embedded in generated JWT access tokens.
// It carries only the registered claims (audience, subject, expiry, ...).
type JWTAccessClaims struct {
	jwt.RegisteredClaims
}

// Valid reports whether the claims are still usable. It returns
// errors.ErrInvalidAccessToken when ExpiresAt is set and, truncated to whole
// seconds, lies before the current time; otherwise it returns nil.
//
// NOTE(review): golang-jwt v5 appears to run custom claim checks through a
// Validate method rather than Valid — confirm this method is still reached by
// the parsing path that is expected to enforce it.
func (a *JWTAccessClaims) Valid() error {
	if a.ExpiresAt != nil && time.Unix(a.ExpiresAt.Unix(), 0).Before(time.Now()) {
		return errors.ErrInvalidAccessToken
	}
	return nil
}

// NewJWTAccessGenerate returns a JWTAccessGenerate that signs tokens with the
// given key and signing method. A non-empty kid is written to the "kid"
// header of every token it produces.
func NewJWTAccessGenerate(kid string, key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
	return &JWTAccessGenerate{
		SignedKeyID:  kid,
		SignedKey:    key,
		SignedMethod: method,
	}
}

// JWTAccessGenerate generates signed JWT access tokens.
type JWTAccessGenerate struct {
	// SignedKeyID is written to the token's "kid" header when non-empty.
	SignedKeyID string
	// SignedKey is the signing key: PEM-encoded for ES*, RS*, PS* and Ed*
	// methods, and used as-is (raw bytes) for HS* methods.
	SignedKey []byte
	// SignedMethod selects the signing algorithm.
	SignedMethod jwt.SigningMethod
}

// Token builds and signs a JWT access token for data and, when isGenRefresh
// is true, also derives a refresh token.
//
// The access token's audience is the client ID, its subject is the user ID
// and its expiry is the access creation time plus the access lifetime, all
// taken from data. The refresh token is a SHA-1 UUID of the access token in a
// freshly generated random namespace, base64url-encoded without padding and
// upper-cased.
//
// It returns (access, refresh, nil) on success; refresh is empty when
// isGenRefresh is false. On any failure both tokens are empty and the error
// is non-nil: missing generation data, a nil or unsupported signing method,
// an unparsable key, a signing failure, or a failure to read randomness for
// the refresh token.
func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
	// Fail with an error rather than a nil-dereference panic on incomplete input.
	if data == nil || data.Client == nil || data.TokenInfo == nil {
		return "", "", errors.New("incomplete token generation data")
	}
	if a.SignedMethod == nil {
		return "", "", errors.New("unsupported sign method")
	}

	expiresAt := data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn())
	claims := &JWTAccessClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			Audience:  jwt.ClaimStrings{data.Client.GetID()},
			Subject:   data.UserID,
			ExpiresAt: jwt.NewNumericDate(expiresAt),
		},
	}

	token := jwt.NewWithClaims(a.SignedMethod, claims)
	if a.SignedKeyID != "" {
		token.Header["kid"] = a.SignedKeyID
	}

	key, err := a.signingKey()
	if err != nil {
		return "", "", err
	}

	access, err := token.SignedString(key)
	if err != nil {
		return "", "", err
	}
	if !isGenRefresh {
		return access, "", nil
	}

	// Propagate a randomness failure to the caller; uuid.Must would panic here.
	namespace, err := uuid.NewRandom()
	if err != nil {
		return "", "", err
	}
	id := uuid.NewSHA1(namespace, []byte(access)).String()
	refresh := base64.URLEncoding.EncodeToString([]byte(id))
	refresh = strings.ToUpper(strings.TrimRight(refresh, "="))

	return access, refresh, nil
}

// signingKey converts SignedKey into the key value expected by SignedMethod:
// a parsed private key for ES*, RS*/PS* and Ed* methods, the raw bytes for
// HS* methods. Any other algorithm yields an "unsupported sign method" error.
func (a *JWTAccessGenerate) signingKey() (any, error) {
	switch {
	case a.isEs():
		key, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
		if err != nil {
			return nil, err
		}
		return key, nil
	case a.isRsOrPS():
		key, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
		if err != nil {
			return nil, err
		}
		return key, nil
	case a.isHs():
		return a.SignedKey, nil
	case a.isEd():
		key, err := jwt.ParseEdPrivateKeyFromPEM(a.SignedKey)
		if err != nil {
			return nil, err
		}
		return key, nil
	default:
		return nil, errors.New("unsupported sign method")
	}
}

// isEs reports whether the signing algorithm name starts with "ES".
func (a *JWTAccessGenerate) isEs() bool {
	return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
}

// isRsOrPS reports whether the signing algorithm name starts with "RS" or "PS".
func (a *JWTAccessGenerate) isRsOrPS() bool {
	isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
	isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
	return isRs || isPs
}

// isHs reports whether the signing algorithm name starts with "HS".
func (a *JWTAccessGenerate) isHs() bool {
	return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
}

// isEd reports whether the signing algorithm name starts with "Ed".
func (a *JWTAccessGenerate) isEd() bool {
	return strings.HasPrefix(a.SignedMethod.Alg(), "Ed")
}