diff options
Diffstat (limited to 'internal/oauth/pgtokenstore.go')
-rw-r--r-- | internal/oauth/pgtokenstore.go | 67 |
1 files changed, 33 insertions, 34 deletions
diff --git a/internal/oauth/pgtokenstore.go b/internal/oauth/pgtokenstore.go index 26026d292..a927be862 100644 --- a/internal/oauth/pgtokenstore.go +++ b/internal/oauth/pgtokenstore.go @@ -23,14 +23,14 @@ import ( "errors" "time" - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/models" "github.com/go-pg/pg/v10" + "github.com/gotosocial/oauth2/v4" + "github.com/gotosocial/oauth2/v4/models" "github.com/sirupsen/logrus" ) -// PGTokenStore is an implementation of oauth2.TokenStore, which uses Postgres as a storage backend. -type PGTokenStore struct { +// pgTokenStore is an implementation of oauth2.TokenStore, which uses Postgres as a storage backend. +type pgTokenStore struct { oauth2.TokenStore conn *pg.DB log *logrus.Logger @@ -41,13 +41,13 @@ type PGTokenStore struct { // In order to allow tokens to 'expire' (not really a thing in Postgres world), it will also set off a // goroutine that iterates through the tokens in the DB once per minute and deletes any that have expired. func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth2.TokenStore { - pts := &PGTokenStore{ + pts := &pgTokenStore{ conn: conn, log: log, } // set the token store to clean out expired tokens once per minute, or return if we're done - go func(ctx context.Context, pts *PGTokenStore, log *logrus.Logger) { + go func(ctx context.Context, pts *pgTokenStore, log *logrus.Logger) { cleanloop: for { select { @@ -66,10 +66,10 @@ func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth } // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. -func (pts *PGTokenStore) sweep() error { +func (pts *pgTokenStore) sweep() error { // select *all* tokens from the db // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. - var tokens []pgOauthToken + var tokens []oauthToken if err := pts.conn.Model(&tokens).Select(); err != nil { return err } @@ -91,8 +91,8 @@ func (pts *PGTokenStore) sweep() error { } // Create creates and store the new token information. -// For the original implementation, see https://github.com/go-oauth2/oauth2/blob/master/store/token.go#L34 -func (pts *PGTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { +// For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34 +func (pts *pgTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { t, ok := info.(*models.Token) if !ok { return errors.New("info param was not a models.Token") @@ -102,26 +102,26 @@ func (pts *PGTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) erro } // RemoveByCode deletes a token from the DB based on the Code field -func (pts *PGTokenStore) RemoveByCode(ctx context.Context, code string) error { - _, err := pts.conn.Model(&pgOauthToken{}).Where("code = ?", code).Delete() +func (pts *pgTokenStore) RemoveByCode(ctx context.Context, code string) error { + _, err := pts.conn.Model(&oauthToken{}).Where("code = ?", code).Delete() return err } // RemoveByAccess deletes a token from the DB based on the Access field -func (pts *PGTokenStore) RemoveByAccess(ctx context.Context, access string) error { - _, err := pts.conn.Model(&pgOauthToken{}).Where("access = ?", access).Delete() +func (pts *pgTokenStore) RemoveByAccess(ctx context.Context, access string) error { + _, err := pts.conn.Model(&oauthToken{}).Where("access = ?", access).Delete() return err } // RemoveByRefresh deletes a token from the DB based on the Refresh field -func (pts *PGTokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - _, err := pts.conn.Model(&pgOauthToken{}).Where("refresh = ?", refresh).Delete() +func (pts *pgTokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { + _, err := pts.conn.Model(&oauthToken{}).Where("refresh = ?", refresh).Delete() return err } // GetByCode selects a token from the DB based on the Code field -func (pts *PGTokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - pgt := &pgOauthToken{} +func (pts *pgTokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{} if err := pts.conn.Model(pgt).Where("code = ?", code).Select(); err != nil { return nil, err } @@ -129,8 +129,8 @@ func (pts *PGTokenStore) GetByCode(ctx context.Context, code string) (oauth2.Tok } // GetByAccess selects a token from the DB based on the Access field -func (pts *PGTokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { - pgt := &pgOauthToken{} +func (pts *pgTokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{} if err := pts.conn.Model(pgt).Where("access = ?", access).Select(); err != nil { return nil, err } @@ -138,8 +138,8 @@ func (pts *PGTokenStore) GetByAccess(ctx context.Context, access string) (oauth2 } // GetByRefresh selects a token from the DB based on the Refresh field -func (pts *PGTokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - pgt := &pgOauthToken{} +func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{} if err := pts.conn.Model(pgt).Where("refresh = ?", refresh).Select(); err != nil { return nil, err } @@ -150,18 +150,17 @@ func (pts *PGTokenStore) GetByRefresh(ctx context.Context, refresh string) (oaut The following models are basically helpers for the postgres token store implementation, they should only be used internally. */ -// pgOauthToken is a translation of the go-oauth2 token with the ExpiresIn fields replaced with ExpiresAt. +// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt. // -// Explanation for this: go-oauth2 assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined, +// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined, // and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and // then periodically sweep out tokens when that time has passed. // -// Note that this struct does *not* satisfy the token interface shown here: https://github.com/go-oauth2/oauth2/blob/master/model.go#L22 -// and implemented here: https://github.com/go-oauth2/oauth2/blob/master/models/token.go. -// As such, manual translation is always required between pgOauthToken and the go-oauth2 *model.Token. The helper functions oauthTokenToPGToken +// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22 +// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go. +// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken // and pgTokenToOauthToken can be used for that. -type pgOauthToken struct { - tableName struct{} `pg:"oauth_tokens"` +type oauthToken struct { ClientID string UserID string RedirectURI string @@ -179,8 +178,8 @@ type pgOauthToken struct { RefreshExpiresAt time.Time `pg:"type:timestamp"` } -// oauthTokenToPGToken is a lil util function that takes a go-oauth2 token and gives back a token for inserting into postgres -func oauthTokenToPGToken(tkn *models.Token) *pgOauthToken { +// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres +func oauthTokenToPGToken(tkn *models.Token) *oauthToken { now := time.Now() // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's @@ -202,7 +201,7 @@ func oauthTokenToPGToken(tkn *models.Token) *pgOauthToken { rea = now.Add(tkn.RefreshExpiresIn) } - return &pgOauthToken{ + return &oauthToken{ ClientID: tkn.ClientID, UserID: tkn.UserID, RedirectURI: tkn.RedirectURI, @@ -221,8 +220,8 @@ func oauthTokenToPGToken(tkn *models.Token) *pgOauthToken { } } -// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a go-oauth2 token -func pgTokenToOauthToken(pgt *pgOauthToken) *models.Token { +// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token +func pgTokenToOauthToken(pgt *oauthToken) *models.Token { now := time.Now() return &models.Token{ |