summaryrefslogtreecommitdiff
path: root/internal/oauth
diff options
context:
space:
mode:
authorLibravatar Tobi Smethurst <31960611+tsmethurst@users.noreply.github.com>2021-05-08 14:25:55 +0200
committerLibravatar GitHub <noreply@github.com>2021-05-08 14:25:55 +0200
commit6f5c045284d34ba580d3007f70b97e05d6760527 (patch)
tree7614da22fba906361a918fb3527465b39272ac93 /internal/oauth
parentRevert "make boosts work woo (#12)" (#15) (diff)
downloadgotosocial-6f5c045284d34ba580d3007f70b97e05d6760527.tar.xz
Ap (#14)
Big restructuring and initial work on activitypub
Diffstat (limited to 'internal/oauth')
-rw-r--r--internal/oauth/clientstore.go3
-rw-r--r--internal/oauth/clientstore_test.go13
-rw-r--r--internal/oauth/oauth_test.go2
-rw-r--r--internal/oauth/server.go172
-rw-r--r--internal/oauth/tokenstore_test.go2
-rw-r--r--internal/oauth/util.go86
6 files changed, 137 insertions, 141 deletions
diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go
index 4e678891a..5241cf412 100644
--- a/internal/oauth/clientstore.go
+++ b/internal/oauth/clientstore.go
@@ -30,7 +30,8 @@ type clientStore struct {
db db.DB
}
-func newClientStore(db db.DB) oauth2.ClientStore {
+// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend.
+func NewClientStore(db db.DB) oauth2.ClientStore {
pts := &clientStore{
db: db,
}
diff --git a/internal/oauth/clientstore_test.go b/internal/oauth/clientstore_test.go
index a7028228d..b77163e48 100644
--- a/internal/oauth/clientstore_test.go
+++ b/internal/oauth/clientstore_test.go
@@ -15,7 +15,7 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package oauth
+package oauth_test
import (
"context"
@@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/oauth2/v4/models"
)
@@ -61,7 +62,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
Database: "postgres",
ApplicationName: "gotosocial",
}
- db, err := db.New(context.Background(), c, log)
+ db, err := db.NewPostgresService(context.Background(), c, log)
if err != nil {
logrus.Panicf("error creating database connection: %s", err)
}
@@ -69,7 +70,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
suite.db = db
models := []interface{}{
- &Client{},
+ &oauth.Client{},
}
for _, m := range models {
@@ -82,7 +83,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
func (suite *PgClientStoreTestSuite) TearDownTest() {
models := []interface{}{
- &Client{},
+ &oauth.Client{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {
@@ -97,7 +98,7 @@ func (suite *PgClientStoreTestSuite) TearDownTest() {
func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() {
// set a new client in the store
- cs := newClientStore(suite.db)
+ cs := oauth.NewClientStore(suite.db)
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
suite.FailNow(err.Error())
}
@@ -115,7 +116,7 @@ func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() {
func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
// set a new client in the store
- cs := newClientStore(suite.db)
+ cs := oauth.NewClientStore(suite.db)
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
suite.FailNow(err.Error())
}
diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go
index 594b9b5a9..1b8449619 100644
--- a/internal/oauth/oauth_test.go
+++ b/internal/oauth/oauth_test.go
@@ -16,6 +16,6 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package oauth
+package oauth_test
// TODO: write tests
diff --git a/internal/oauth/server.go b/internal/oauth/server.go
index 1ddf18b03..7877d667e 100644
--- a/internal/oauth/server.go
+++ b/internal/oauth/server.go
@@ -23,10 +23,8 @@ import (
"fmt"
"net/http"
- "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/db/gtsmodel"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/errors"
"github.com/superseriousbusiness/oauth2/v4/manage"
@@ -66,94 +64,53 @@ type s struct {
log *logrus.Logger
}
-// Authed wraps an authorized token, application, user, and account.
-// It is used in the functions GetAuthed and MustAuth.
-// Because the user might *not* be authed, any of the fields in this struct
-// might be nil, so make sure to check that when you're using this struct anywhere.
-type Authed struct {
- Token oauth2.TokenInfo
- Application *gtsmodel.Application
- User *gtsmodel.User
- Account *gtsmodel.Account
-}
-
-// GetAuthed is a convenience function for returning an Authed struct from a gin context.
-// In essence, it tries to extract a token, application, user, and account from the context,
-// and then sets them on a struct for convenience.
-//
-// If any are not present in the context, they will be set to nil on the returned Authed struct.
-//
-// If *ALL* are not present, then nil and an error will be returned.
-//
-// If something goes wrong during parsing, then nil and an error will be returned (consider this not authed).
-func GetAuthed(c *gin.Context) (*Authed, error) {
- ctx := c.Copy()
- a := &Authed{}
- var i interface{}
- var ok bool
+// New returns a new oauth server that implements the Server interface
+func New(database db.DB, log *logrus.Logger) Server {
+ ts := newTokenStore(context.Background(), database, log)
+ cs := NewClientStore(database)
- i, ok = ctx.Get(SessionAuthorizedToken)
- if ok {
- parsed, ok := i.(oauth2.TokenInfo)
- if !ok {
- return nil, errors.New("could not parse token from session context")
- }
- a.Token = parsed
+ manager := manage.NewDefaultManager()
+ manager.MapTokenStorage(ts)
+ manager.MapClientStorage(cs)
+ manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
+ sc := &server.Config{
+ TokenType: "Bearer",
+ // Must follow the spec.
+ AllowGetAccessRequest: false,
+ // Support only the non-implicit flow.
+ AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
+ // Allow:
+ // - Authorization Code (for first & third parties)
+ // - Client Credentials (for applications)
+ AllowedGrantTypes: []oauth2.GrantType{
+ oauth2.AuthorizationCode,
+ oauth2.ClientCredentials,
+ },
+ AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
}
- i, ok = ctx.Get(SessionAuthorizedApplication)
- if ok {
- parsed, ok := i.(*gtsmodel.Application)
- if !ok {
- return nil, errors.New("could not parse application from session context")
- }
- a.Application = parsed
- }
+ srv := server.NewServer(sc, manager)
+ srv.SetInternalErrorHandler(func(err error) *errors.Response {
+ log.Errorf("internal oauth error: %s", err)
+ return nil
+ })
- i, ok = ctx.Get(SessionAuthorizedUser)
- if ok {
- parsed, ok := i.(*gtsmodel.User)
- if !ok {
- return nil, errors.New("could not parse user from session context")
- }
- a.User = parsed
- }
+ srv.SetResponseErrorHandler(func(re *errors.Response) {
+ log.Errorf("internal response error: %s", re.Error)
+ })
- i, ok = ctx.Get(SessionAuthorizedAccount)
- if ok {
- parsed, ok := i.(*gtsmodel.Account)
- if !ok {
- return nil, errors.New("could not parse account from session context")
+ srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
+ userID := r.FormValue("userid")
+ if userID == "" {
+ return "", errors.New("userid was empty")
}
- a.Account = parsed
- }
-
- if a.Token == nil && a.Application == nil && a.User == nil && a.Account == nil {
- return nil, errors.New("not authorized")
- }
-
- return a, nil
-}
-
-// MustAuth is like GetAuthed, but will fail if one of the requirements is not met.
-func MustAuth(c *gin.Context, requireToken bool, requireApp bool, requireUser bool, requireAccount bool) (*Authed, error) {
- a, err := GetAuthed(c)
- if err != nil {
- return nil, err
- }
- if requireToken && a.Token == nil {
- return nil, errors.New("token not supplied")
- }
- if requireApp && a.Application == nil {
- return nil, errors.New("application not supplied")
- }
- if requireUser && a.User == nil {
- return nil, errors.New("user not supplied")
- }
- if requireAccount && a.Account == nil {
- return nil, errors.New("account not supplied")
+ return userID, nil
+ })
+ srv.SetClientInfoHandler(server.ClientFormHandler)
+ return &s{
+ server: srv,
+ log: log,
}
- return a, nil
}
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
@@ -211,52 +168,3 @@ func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, us
s.log.Tracef("obtained user-level access token: %+v", accessToken)
return accessToken, nil
}
-
-// New returns a new oauth server that implements the Server interface
-func New(database db.DB, log *logrus.Logger) Server {
- ts := newTokenStore(context.Background(), database, log)
- cs := newClientStore(database)
-
- manager := manage.NewDefaultManager()
- manager.MapTokenStorage(ts)
- manager.MapClientStorage(cs)
- manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
- sc := &server.Config{
- TokenType: "Bearer",
- // Must follow the spec.
- AllowGetAccessRequest: false,
- // Support only the non-implicit flow.
- AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
- // Allow:
- // - Authorization Code (for first & third parties)
- // - Client Credentials (for applications)
- AllowedGrantTypes: []oauth2.GrantType{
- oauth2.AuthorizationCode,
- oauth2.ClientCredentials,
- },
- AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
- }
-
- srv := server.NewServer(sc, manager)
- srv.SetInternalErrorHandler(func(err error) *errors.Response {
- log.Errorf("internal oauth error: %s", err)
- return nil
- })
-
- srv.SetResponseErrorHandler(func(re *errors.Response) {
- log.Errorf("internal response error: %s", re.Error)
- })
-
- srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
- userID := r.FormValue("userid")
- if userID == "" {
- return "", errors.New("userid was empty")
- }
- return userID, nil
- })
- srv.SetClientInfoHandler(server.ClientFormHandler)
- return &s{
- server: srv,
- log: log,
- }
-}
diff --git a/internal/oauth/tokenstore_test.go b/internal/oauth/tokenstore_test.go
index 594b9b5a9..1b8449619 100644
--- a/internal/oauth/tokenstore_test.go
+++ b/internal/oauth/tokenstore_test.go
@@ -16,6 +16,6 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package oauth
+package oauth_test
// TODO: write tests
diff --git a/internal/oauth/util.go b/internal/oauth/util.go
new file mode 100644
index 000000000..378b81450
--- /dev/null
+++ b/internal/oauth/util.go
@@ -0,0 +1,86 @@
+package oauth
+
+import (
+ "github.com/gin-gonic/gin"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/oauth2/v4"
+ "github.com/superseriousbusiness/oauth2/v4/errors"
+)
+
+// Auth wraps an authorized token, application, user, and account.
+// It is used in the functions GetAuthed and MustAuth.
+// Because the user might *not* be authed, any of the fields in this struct
+// might be nil, so make sure to check that when you're using this struct anywhere.
+type Auth struct {
+ Token oauth2.TokenInfo
+ Application *gtsmodel.Application
+ User *gtsmodel.User
+ Account *gtsmodel.Account
+}
+
+// Authed is a convenience function for returning an Authed struct from a gin context.
+// In essence, it tries to extract a token, application, user, and account from the context,
+// and then sets them on a struct for convenience.
+//
+// If any are not present in the context, they will be set to nil on the returned Authed struct.
+//
+// If *ALL* are not present, then nil and an error will be returned.
+//
+// If something goes wrong during parsing, then nil and an error will be returned (consider this not authed).
+// Authed is like GetAuthed, but will fail if one of the requirements is not met.
+func Authed(c *gin.Context, requireToken bool, requireApp bool, requireUser bool, requireAccount bool) (*Auth, error) {
+ ctx := c.Copy()
+ a := &Auth{}
+ var i interface{}
+ var ok bool
+
+ i, ok = ctx.Get(SessionAuthorizedToken)
+ if ok {
+ parsed, ok := i.(oauth2.TokenInfo)
+ if !ok {
+ return nil, errors.New("could not parse token from session context")
+ }
+ a.Token = parsed
+ }
+
+ i, ok = ctx.Get(SessionAuthorizedApplication)
+ if ok {
+ parsed, ok := i.(*gtsmodel.Application)
+ if !ok {
+ return nil, errors.New("could not parse application from session context")
+ }
+ a.Application = parsed
+ }
+
+ i, ok = ctx.Get(SessionAuthorizedUser)
+ if ok {
+ parsed, ok := i.(*gtsmodel.User)
+ if !ok {
+ return nil, errors.New("could not parse user from session context")
+ }
+ a.User = parsed
+ }
+
+ i, ok = ctx.Get(SessionAuthorizedAccount)
+ if ok {
+ parsed, ok := i.(*gtsmodel.Account)
+ if !ok {
+ return nil, errors.New("could not parse account from session context")
+ }
+ a.Account = parsed
+ }
+
+ if requireToken && a.Token == nil {
+ return nil, errors.New("token not supplied")
+ }
+ if requireApp && a.Application == nil {
+ return nil, errors.New("application not supplied")
+ }
+ if requireUser && a.User == nil {
+ return nil, errors.New("user not supplied")
+ }
+ if requireAccount && a.Account == nil {
+ return nil, errors.New("account not supplied")
+ }
+ return a, nil
+}