diff options
author | 2021-03-15 18:59:38 +0100 | |
---|---|---|
committer | 2021-03-15 18:59:38 +0100 | |
commit | cca676dcb3824e9f9ee21ecf0a2cf8c55b0d4ad9 (patch) | |
tree | 0aef8a75cbd731eb1637fcac6964b9eea1adcb78 /internal | |
parent | bit of experimenting and tidying (diff) | |
download | gotosocial-cca676dcb3824e9f9ee21ecf0a2cf8c55b0d4ad9.tar.xz |
tests
Diffstat (limited to 'internal')
-rw-r--r-- | internal/api/server.go | 18 | ||||
-rw-r--r-- | internal/db/db.go | 2 | ||||
-rw-r--r-- | internal/db/postgres.go | 5 | ||||
-rw-r--r-- | internal/gotosocial/gotosocial.go | 2 | ||||
-rw-r--r-- | internal/oauth/README.md | 2 | ||||
-rw-r--r-- | internal/oauth/oauth.go | 71 | ||||
-rw-r--r-- | internal/oauth/pgclientstore.go | 74 | ||||
-rw-r--r-- | internal/oauth/pgclientstore_test.go | 104 | ||||
-rw-r--r-- | internal/oauth/pgtokenstore.go | 67 |
9 files changed, 256 insertions, 89 deletions
diff --git a/internal/api/server.go b/internal/api/server.go index a27178855..ed622210b 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -28,9 +28,9 @@ import ( type Server interface { AttachHTTPHandler(method string, path string, handler http.HandlerFunc) - AttachGinHandler(method string, path string, handler gin.HandlerFunc) - // AttachMiddleware(handler gin.HandlerFunc) - GetAPIGroup() *gin.RouterGroup + AttachGinHandler(method string, path string, handler gin.HandlerFunc) + // AttachMiddleware(handler gin.HandlerFunc) + GetAPIGroup() *gin.RouterGroup Start() Stop() } @@ -46,20 +46,22 @@ func (s *server) GetAPIGroup() *gin.RouterGroup { } func (s *server) Start() { - // todo: start gracefully - s.engine.Run() + // todo: start gracefully + if err := s.engine.Run(); err != nil { + s.logger.Panicf("server error: %s", err) + } } func (s *server) Stop() { - // todo: shut down gracefully + // todo: shut down gracefully } func (s *server) AttachHTTPHandler(method string, path string, handler http.HandlerFunc) { - s.engine.Handle(method, path, gin.WrapH(handler)) + s.engine.Handle(method, path, gin.WrapH(handler)) } func (s *server) AttachGinHandler(method string, path string, handler gin.HandlerFunc) { - s.engine.Handle(method, path, handler) + s.engine.Handle(method, path, handler) } func New(config *config.Config, logger *logrus.Logger) Server { diff --git a/internal/db/db.go b/internal/db/db.go index 9ed196b2f..03e30b41b 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -24,8 +24,8 @@ import ( "strings" "github.com/go-fed/activity/pub" - "github.com/go-oauth2/oauth2/v4" "github.com/gotosocial/gotosocial/internal/config" + "github.com/gotosocial/oauth2/v4" "github.com/sirupsen/logrus" ) diff --git a/internal/db/postgres.go b/internal/db/postgres.go index 0453b207b..96452d5ae 100644 --- a/internal/db/postgres.go +++ b/internal/db/postgres.go @@ -30,13 +30,12 @@ import ( "github.com/go-fed/activity/streams" "github.com/go-fed/activity/streams/vocab" - "github.com/go-oauth2/oauth2/v4" "github.com/go-pg/pg/extra/pgdebug" "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" "github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/gtsmodel" - "github.com/gotosocial/gotosocial/internal/oauth" + "github.com/gotosocial/oauth2/v4" "github.com/sirupsen/logrus" ) @@ -46,7 +45,7 @@ type postgresService struct { log *logrus.Entry cancel context.CancelFunc locks *sync.Map - tokenStore *oauth.PGTokenStore + tokenStore oauth2.TokenStore } // newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. diff --git a/internal/gotosocial/gotosocial.go b/internal/gotosocial/gotosocial.go index a43af65f9..4409e85de 100644 --- a/internal/gotosocial/gotosocial.go +++ b/internal/gotosocial/gotosocial.go @@ -22,8 +22,8 @@ import ( "context" "github.com/go-fed/activity/pub" - "github.com/gotosocial/gotosocial/internal/cache" "github.com/gotosocial/gotosocial/internal/api" + "github.com/gotosocial/gotosocial/internal/cache" "github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/db" ) diff --git a/internal/oauth/README.md b/internal/oauth/README.md index 5eaef673f..50a9e1274 100644 --- a/internal/oauth/README.md +++ b/internal/oauth/README.md @@ -1,3 +1,3 @@ # oauth -This package provides uses [go-oauth2](https://github.com/go-oauth2/oauth2) to provide [oauth2](https://www.oauth.com/) server functionality to the GoToSocial APIs. +This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) server functionality to the GoToSocial APIs. diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index ac833d1fc..d79db95ed 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -18,47 +18,36 @@ package oauth -type Server struct { - +import ( + "github.com/gotosocial/oauth2/v4" + "github.com/gotosocial/oauth2/v4/errors" + "github.com/gotosocial/oauth2/v4/manage" + "github.com/gotosocial/oauth2/v4/server" + "github.com/sirupsen/logrus" +) + +type API struct { + manager *manage.Manager + server *server.Server } -func main() { -// manager := manage.NewDefaultManager() -// // token memory store -// manager.MustTokenStorage(store.NewMemoryTokenStore()) - -// // client memory store -// clientStore := store.NewClientStore() -// clientStore.Set("000000", &models.Client{ -// ID: "000000", -// Secret: "999999", -// Domain: "http://localhost", -// }) -// manager.MapClientStorage(clientStore) - -// srv := server.NewDefaultServer(manager) -// srv.SetAllowGetAccessRequest(true) -// srv.SetClientInfoHandler(server.ClientFormHandler) - -// srv.SetInternalErrorHandler(func(err error) (re *errors.Response) { -// log.Println("Internal Error:", err.Error()) -// return -// }) - -// srv.SetResponseErrorHandler(func(re *errors.Response) { -// log.Println("Response Error:", re.Error.Error()) -// }) - -// http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { -// err := srv.HandleAuthorizeRequest(w, r) -// if err != nil { -// http.Error(w, err.Error(), http.StatusBadRequest) -// } -// }) - -// http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { -// srv.HandleTokenRequest(w, r) -// }) - -// log.Fatal(http.ListenAndServe(":9096", nil)) +func New(ts oauth2.TokenStore, cs oauth2.ClientStore, log *logrus.Logger) *API { + manager := manage.NewDefaultManager() + manager.MapTokenStorage(ts) + manager.MapClientStorage(cs) + + srv := server.NewDefaultServer(manager) + srv.SetInternalErrorHandler(func(err error) *errors.Response { + log.Errorf("internal oauth error: %s", err) + return nil + }) + + srv.SetResponseErrorHandler(func(re *errors.Response) { + log.Errorf("internal response error: %s", re.Error) + }) + + return &API{ + manager: manager, + server: srv, + } } diff --git a/internal/oauth/pgclientstore.go b/internal/oauth/pgclientstore.go new file mode 100644 index 000000000..dda5fb3d6 --- /dev/null +++ b/internal/oauth/pgclientstore.go @@ -0,0 +1,74 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package oauth + +import ( + "context" + + "github.com/go-pg/pg/v10" + "github.com/gotosocial/oauth2/v4" + "github.com/gotosocial/oauth2/v4/models" +) + +type pgClientStore struct { + conn *pg.DB +} + +func NewPGClientStore(conn *pg.DB) oauth2.ClientStore { + pts := &pgClientStore{ + conn: conn, + } + return pts +} + +func (pcs *pgClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) { + poc := &oauthClient{ + ID: id, + } + if err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Select(); err != nil { + return nil, err + } + return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil +} + +func (pcs *pgClientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { + poc := &oauthClient{ + ID: cli.GetID(), + Secret: cli.GetSecret(), + Domain: cli.GetDomain(), + UserID: cli.GetUserID(), + } + _, err := pcs.conn.WithContext(ctx).Model(poc).OnConflict("(id) DO UPDATE").Insert() + return err +} + +func (pcs *pgClientStore) Delete(ctx context.Context, id string) error { + poc := &oauthClient{ + ID: id, + } + _, err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Delete() + return err +} + +type oauthClient struct { + ID string + Secret string + Domain string + UserID string +} diff --git a/internal/oauth/pgclientstore_test.go b/internal/oauth/pgclientstore_test.go new file mode 100644 index 000000000..3f8de064d --- /dev/null +++ b/internal/oauth/pgclientstore_test.go @@ -0,0 +1,104 @@ +package oauth + +import ( + "context" + "testing" + + "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/gotosocial/oauth2/v4/models" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/suite" +) + +type PgClientStoreTestSuite struct { + suite.Suite + conn *pg.DB + testClientID string + testClientSecret string + testClientDomain string + testClientUserID string +} + +const () + +// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout +func (suite *PgClientStoreTestSuite) SetupSuite() { + suite.testClientID = "test-client-id" + suite.testClientSecret = "test-client-secret" + suite.testClientDomain = "https://example.org" + suite.testClientUserID = "test-client-user-id" +} + +// SetupTest creates a postgres connection and creates the oauth_clients table before each test +func (suite *PgClientStoreTestSuite) SetupTest() { + suite.conn = pg.Connect(&pg.Options{}) + if err := suite.conn.Ping(context.Background()); err != nil { + logrus.Panicf("db connection error: %s", err) + } + if err := suite.conn.Model(&oauthClient{}).CreateTable(&orm.CreateTableOptions{ + IfNotExists: true, + }); err != nil { + logrus.Panicf("db connection error: %s", err) + } +} + +// TearDownTest drops the oauth_clients table and closes the pg connection after each test +func (suite *PgClientStoreTestSuite) TearDownTest() { + if err := suite.conn.Model(&oauthClient{}).DropTable(&orm.DropTableOptions{}); err != nil { + logrus.Panicf("drop table error: %s", err) + } + if err := suite.conn.Close(); err != nil { + logrus.Panicf("error closing db connection: %s", err) + } + suite.conn = nil +} + +func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() { + // set a new client in the store + cs := NewPGClientStore(suite.conn) + if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { + suite.FailNow(err.Error()) + } + + // fetch that client from the store + client, err := cs.GetByID(context.Background(), suite.testClientID) + if err != nil { + suite.FailNow(err.Error()) + } + + // check that the values are the same + suite.NotNil(client) + suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client) +} + +func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() { + // set a new client in the store + cs := NewPGClientStore(suite.conn) + if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { + suite.FailNow(err.Error()) + } + + // fetch the client from the store + client, err := cs.GetByID(context.Background(), suite.testClientID) + if err != nil { + suite.FailNow(err.Error()) + } + + // check that the values are the same + suite.NotNil(client) + suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client) + if err := cs.Delete(context.Background(), suite.testClientID); err != nil { + suite.FailNow(err.Error()) + } + + // try to get the deleted client; we should get an error + deletedClient, err := cs.GetByID(context.Background(), suite.testClientID) + suite.Assert().Nil(deletedClient) + suite.Assert().NotNil(err) + suite.EqualValues("pg: no rows in result set", err.Error()) +} + +func TestPgClientStoreTestSuite(t *testing.T) { + suite.Run(t, new(PgClientStoreTestSuite)) +} diff --git a/internal/oauth/pgtokenstore.go b/internal/oauth/pgtokenstore.go index 26026d292..a927be862 100644 --- a/internal/oauth/pgtokenstore.go +++ b/internal/oauth/pgtokenstore.go @@ -23,14 +23,14 @@ import ( "errors" "time" - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/models" "github.com/go-pg/pg/v10" + "github.com/gotosocial/oauth2/v4" + "github.com/gotosocial/oauth2/v4/models" "github.com/sirupsen/logrus" ) -// PGTokenStore is an implementation of oauth2.TokenStore, which uses Postgres as a storage backend. -type PGTokenStore struct { +// pgTokenStore is an implementation of oauth2.TokenStore, which uses Postgres as a storage backend. +type pgTokenStore struct { oauth2.TokenStore conn *pg.DB log *logrus.Logger @@ -41,13 +41,13 @@ type PGTokenStore struct { // In order to allow tokens to 'expire' (not really a thing in Postgres world), it will also set off a // goroutine that iterates through the tokens in the DB once per minute and deletes any that have expired. func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth2.TokenStore { - pts := &PGTokenStore{ + pts := &pgTokenStore{ conn: conn, log: log, } // set the token store to clean out expired tokens once per minute, or return if we're done - go func(ctx context.Context, pts *PGTokenStore, log *logrus.Logger) { + go func(ctx context.Context, pts *pgTokenStore, log *logrus.Logger) { cleanloop: for { select { @@ -66,10 +66,10 @@ func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth } // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. -func (pts *PGTokenStore) sweep() error { +func (pts *pgTokenStore) sweep() error { // select *all* tokens from the db // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. - var tokens []pgOauthToken + var tokens []oauthToken if err := pts.conn.Model(&tokens).Select(); err != nil { return err } @@ -91,8 +91,8 @@ func (pts *PGTokenStore) sweep() error { } // Create creates and store the new token information. -// For the original implementation, see https://github.com/go-oauth2/oauth2/blob/master/store/token.go#L34 -func (pts *PGTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { +// For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34 +func (pts *pgTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { t, ok := info.(*models.Token) if !ok { return errors.New("info param was not a models.Token") @@ -102,26 +102,26 @@ func (pts *PGTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) erro } // RemoveByCode deletes a token from the DB based on the Code field -func (pts *PGTokenStore) RemoveByCode(ctx context.Context, code string) error { - _, err := pts.conn.Model(&pgOauthToken{}).Where("code = ?", code).Delete() +func (pts *pgTokenStore) RemoveByCode(ctx context.Context, code string) error { + _, err := pts.conn.Model(&oauthToken{}).Where("code = ?", code).Delete() return err } // RemoveByAccess deletes a token from the DB based on the Access field -func (pts *PGTokenStore) RemoveByAccess(ctx context.Context, access string) error { - _, err := pts.conn.Model(&pgOauthToken{}).Where("access = ?", access).Delete() +func (pts *pgTokenStore) RemoveByAccess(ctx context.Context, access string) error { + _, err := pts.conn.Model(&oauthToken{}).Where("access = ?", access).Delete() return err } // RemoveByRefresh deletes a token from the DB based on the Refresh field -func (pts *PGTokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - _, err := pts.conn.Model(&pgOauthToken{}).Where("refresh = ?", refresh).Delete() +func (pts *pgTokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { + _, err := pts.conn.Model(&oauthToken{}).Where("refresh = ?", refresh).Delete() return err } // GetByCode selects a token from the DB based on the Code field -func (pts *PGTokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - pgt := &pgOauthToken{} +func (pts *pgTokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{} if err := pts.conn.Model(pgt).Where("code = ?", code).Select(); err != nil { return nil, err } @@ -129,8 +129,8 @@ func (pts *PGTokenStore) GetByCode(ctx context.Context, code string) (oauth2.Tok } // GetByAccess selects a token from the DB based on the Access field -func (pts *PGTokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { - pgt := &pgOauthToken{} +func (pts *pgTokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{} if err := pts.conn.Model(pgt).Where("access = ?", access).Select(); err != nil { return nil, err } @@ -138,8 +138,8 @@ func (pts *PGTokenStore) GetByAccess(ctx context.Context, access string) (oauth2 } // GetByRefresh selects a token from the DB based on the Refresh field -func (pts *PGTokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - pgt := &pgOauthToken{} +func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{} if err := pts.conn.Model(pgt).Where("refresh = ?", refresh).Select(); err != nil { return nil, err } @@ -150,18 +150,17 @@ func (pts *PGTokenStore) GetByRefresh(ctx context.Context, refresh string) (oaut The following models are basically helpers for the postgres token store implementation, they should only be used internally. */ -// pgOauthToken is a translation of the go-oauth2 token with the ExpiresIn fields replaced with ExpiresAt. +// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt. // -// Explanation for this: go-oauth2 assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined, +// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined, // and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and // then periodically sweep out tokens when that time has passed. // -// Note that this struct does *not* satisfy the token interface shown here: https://github.com/go-oauth2/oauth2/blob/master/model.go#L22 -// and implemented here: https://github.com/go-oauth2/oauth2/blob/master/models/token.go. -// As such, manual translation is always required between pgOauthToken and the go-oauth2 *model.Token. The helper functions oauthTokenToPGToken +// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22 +// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go. +// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken // and pgTokenToOauthToken can be used for that. -type pgOauthToken struct { - tableName struct{} `pg:"oauth_tokens"` +type oauthToken struct { ClientID string UserID string RedirectURI string @@ -179,8 +178,8 @@ type pgOauthToken struct { RefreshExpiresAt time.Time `pg:"type:timestamp"` } -// oauthTokenToPGToken is a lil util function that takes a go-oauth2 token and gives back a token for inserting into postgres -func oauthTokenToPGToken(tkn *models.Token) *pgOauthToken { +// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres +func oauthTokenToPGToken(tkn *models.Token) *oauthToken { now := time.Now() // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's @@ -202,7 +201,7 @@ func oauthTokenToPGToken(tkn *models.Token) *pgOauthToken { rea = now.Add(tkn.RefreshExpiresIn) } - return &pgOauthToken{ + return &oauthToken{ ClientID: tkn.ClientID, UserID: tkn.UserID, RedirectURI: tkn.RedirectURI, @@ -221,8 +220,8 @@ func oauthTokenToPGToken(tkn *models.Token) *pgOauthToken { } } -// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a go-oauth2 token -func pgTokenToOauthToken(pgt *pgOauthToken) *models.Token { +// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token +func pgTokenToOauthToken(pgt *oauthToken) *models.Token { now := time.Now() return &models.Token{ |