diff options
author | 2021-08-25 15:34:33 +0200 | |
---|---|---|
committer | 2021-08-25 15:34:33 +0200 | |
commit | 2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch) | |
tree | 4ddeac479b923db38090aac8bd9209f3646851c1 /internal/oauth/tokenstore.go | |
parent | Manually approves followers (#146) (diff) | |
download | gotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz |
Pg to bun (#148)
* start moving to bun
* changing more stuff
* more
* and yet more
* tests passing
* seems stable now
* more big changes
* small fix
* little fixes
Diffstat (limited to 'internal/oauth/tokenstore.go')
-rw-r--r-- | internal/oauth/tokenstore.go | 136 |
1 files changed, 68 insertions, 68 deletions
diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go index 4fd3183fc..264678ff5 100644 --- a/internal/oauth/tokenstore.go +++ b/internal/oauth/tokenstore.go @@ -43,13 +43,13 @@ type tokenStore struct { // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through // the tokens in the DB once per minute and deletes any that have expired. func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore { - pts := &tokenStore{ + ts := &tokenStore{ db: db, log: log, } // set the token store to clean out expired tokens once per minute, or return if we're done - go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) { + go func(ctx context.Context, ts *tokenStore, log *logrus.Logger) { cleanloop: for { select { @@ -58,32 +58,32 @@ func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2. break cleanloop case <-time.After(1 * time.Minute): log.Trace("sweeping out old oauth entries broom broom") - if err := pts.sweep(); err != nil { + if err := ts.sweep(ctx); err != nil { log.Errorf("error while sweeping oauth entries: %s", err) } } } - }(ctx, pts, log) - return pts + }(ctx, ts, log) + return ts } // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. -func (pts *tokenStore) sweep() error { +func (ts *tokenStore) sweep(ctx context.Context) error { // select *all* tokens from the db // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. tokens := new([]*Token) - if err := pts.db.GetAll(tokens); err != nil { + if err := ts.db.GetAll(ctx, tokens); err != nil { return err } // iterate through and remove expired tokens now := time.Now() - for _, pgt := range *tokens { + for _, dbt := range *tokens { // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: // we only want to check if a token expired before now if the expiry time is *not zero*; // ie., if it's been explicity set. - if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) { - if err := pts.db.DeleteByID(pgt.ID, pgt); err != nil { + if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { + if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil { return err } } @@ -94,92 +94,92 @@ func (pts *tokenStore) sweep() error { // Create creates and store the new token information. // For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34 -func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { +func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { t, ok := info.(*models.Token) if !ok { return errors.New("info param was not a models.Token") } - pgt := TokenToPGToken(t) - if pgt.ID == "" { - pgtID, err := id.NewRandomULID() + dbt := TokenToDBToken(t) + if dbt.ID == "" { + dbtID, err := id.NewRandomULID() if err != nil { return err } - pgt.ID = pgtID + dbt.ID = dbtID } - if err := pts.db.Put(pgt); err != nil { + if err := ts.db.Put(ctx, dbt); err != nil { return fmt.Errorf("error in tokenstore create: %s", err) } return nil } // RemoveByCode deletes a token from the DB based on the Code field -func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error { - return pts.db.DeleteWhere([]db.Where{{Key: "code", Value: code}}, &Token{}) +func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { + return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &Token{}) } // RemoveByAccess deletes a token from the DB based on the Access field -func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { - return pts.db.DeleteWhere([]db.Where{{Key: "access", Value: access}}, &Token{}) +func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { + return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &Token{}) } // RemoveByRefresh deletes a token from the DB based on the Refresh field -func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - return pts.db.DeleteWhere([]db.Where{{Key: "refresh", Value: refresh}}, &Token{}) +func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { + return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &Token{}) } // GetByCode selects a token from the DB based on the Code field -func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { +func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { if code == "" { return nil, nil } - pgt := &Token{ + dbt := &Token{ Code: code, } - if err := pts.db.GetWhere([]db.Where{{Key: "code", Value: code}}, pgt); err != nil { + if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil { return nil, err } - return TokenToOauthToken(pgt), nil + return DBTokenToToken(dbt), nil } // GetByAccess selects a token from the DB based on the Access field -func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { +func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { if access == "" { return nil, nil } - pgt := &Token{ + dbt := &Token{ Access: access, } - if err := pts.db.GetWhere([]db.Where{{Key: "access", Value: access}}, pgt); err != nil { + if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil { return nil, err } - return TokenToOauthToken(pgt), nil + return DBTokenToToken(dbt), nil } // GetByRefresh selects a token from the DB based on the Refresh field -func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { +func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { if refresh == "" { return nil, nil } - pgt := &Token{ + dbt := &Token{ Refresh: refresh, } - if err := pts.db.GetWhere([]db.Where{{Key: "refresh", Value: refresh}}, pgt); err != nil { + if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil { return nil, err } - return TokenToOauthToken(pgt), nil + return DBTokenToToken(dbt), nil } /* - The following models are basically helpers for the postgres token store implementation, they should only be used internally. + The following models are basically helpers for the token store implementation, they should only be used internally. */ // Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt. // // Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined, -// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and +// and tokens with expired TTLs are automatically removed. Since some databases don't have that feature, it's easier to set an expiry time and // then periodically sweep out tokens when that time has passed. // // Note that this struct does *not* satisfy the token interface shown here: https://github.com/superseriousbusiness/oauth2/blob/master/model.go#L22 @@ -187,26 +187,26 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2 // As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken // and pgTokenToOauthToken can be used for that. type Token struct { - ID string `pg:"type:CHAR(26),pk,notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` ClientID string UserID string RedirectURI string Scope string - Code string `pg:"default:'',pk"` + Code string `bun:"default:'',pk"` CodeChallenge string CodeChallengeMethod string - CodeCreateAt time.Time `pg:"type:timestamp"` - CodeExpiresAt time.Time `pg:"type:timestamp"` - Access string `pg:"default:'',pk"` - AccessCreateAt time.Time `pg:"type:timestamp"` - AccessExpiresAt time.Time `pg:"type:timestamp"` - Refresh string `pg:"default:'',pk"` - RefreshCreateAt time.Time `pg:"type:timestamp"` - RefreshExpiresAt time.Time `pg:"type:timestamp"` + CodeCreateAt time.Time `bun:",nullzero"` + CodeExpiresAt time.Time `bun:",nullzero"` + Access string `bun:"default:'',pk"` + AccessCreateAt time.Time `bun:",nullzero"` + AccessExpiresAt time.Time `bun:",nullzero"` + Refresh string `bun:"default:'',pk"` + RefreshCreateAt time.Time `bun:",nullzero"` + RefreshExpiresAt time.Time `bun:",nullzero"` } -// TokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres -func TokenToPGToken(tkn *models.Token) *Token { +// TokenToDBToken is a lil util function that takes a gotosocial token and gives back a token for inserting into a database. +func TokenToDBToken(tkn *models.Token) *Token { now := time.Now() // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's @@ -247,40 +247,40 @@ func TokenToPGToken(tkn *models.Token) *Token { } } -// TokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token -func TokenToOauthToken(pgt *Token) *models.Token { +// DBTokenToToken is a lil util function that takes a database token and gives back a gotosocial token +func DBTokenToToken(dbt *Token) *models.Token { now := time.Now() var codeExpiresIn time.Duration - if !pgt.CodeExpiresAt.IsZero() { - codeExpiresIn = pgt.CodeExpiresAt.Sub(now) + if !dbt.CodeExpiresAt.IsZero() { + codeExpiresIn = dbt.CodeExpiresAt.Sub(now) } var accessExpiresIn time.Duration - if !pgt.AccessExpiresAt.IsZero() { - accessExpiresIn = pgt.AccessExpiresAt.Sub(now) + if !dbt.AccessExpiresAt.IsZero() { + accessExpiresIn = dbt.AccessExpiresAt.Sub(now) } var refreshExpiresIn time.Duration - if !pgt.RefreshExpiresAt.IsZero() { - refreshExpiresIn = pgt.RefreshExpiresAt.Sub(now) + if !dbt.RefreshExpiresAt.IsZero() { + refreshExpiresIn = dbt.RefreshExpiresAt.Sub(now) } return &models.Token{ - ClientID: pgt.ClientID, - UserID: pgt.UserID, - RedirectURI: pgt.RedirectURI, - Scope: pgt.Scope, - Code: pgt.Code, - CodeChallenge: pgt.CodeChallenge, - CodeChallengeMethod: pgt.CodeChallengeMethod, - CodeCreateAt: pgt.CodeCreateAt, + ClientID: dbt.ClientID, + UserID: dbt.UserID, + RedirectURI: dbt.RedirectURI, + Scope: dbt.Scope, + Code: dbt.Code, + CodeChallenge: dbt.CodeChallenge, + CodeChallengeMethod: dbt.CodeChallengeMethod, + CodeCreateAt: dbt.CodeCreateAt, CodeExpiresIn: codeExpiresIn, - Access: pgt.Access, - AccessCreateAt: pgt.AccessCreateAt, + Access: dbt.Access, + AccessCreateAt: dbt.AccessCreateAt, AccessExpiresIn: accessExpiresIn, - Refresh: pgt.Refresh, - RefreshCreateAt: pgt.RefreshCreateAt, + Refresh: dbt.Refresh, + RefreshCreateAt: dbt.RefreshCreateAt, RefreshExpiresIn: refreshExpiresIn, } } |