diff options
Diffstat (limited to 'internal/oauth/server.go')
-rw-r--r-- | internal/oauth/server.go | 172 |
1 files changed, 40 insertions, 132 deletions
diff --git a/internal/oauth/server.go b/internal/oauth/server.go index 1ddf18b03..7877d667e 100644 --- a/internal/oauth/server.go +++ b/internal/oauth/server.go @@ -23,10 +23,8 @@ import ( "fmt" "net/http" - "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/db/gtsmodel" "github.com/superseriousbusiness/oauth2/v4" "github.com/superseriousbusiness/oauth2/v4/errors" "github.com/superseriousbusiness/oauth2/v4/manage" @@ -66,94 +64,53 @@ type s struct { log *logrus.Logger } -// Authed wraps an authorized token, application, user, and account. -// It is used in the functions GetAuthed and MustAuth. -// Because the user might *not* be authed, any of the fields in this struct -// might be nil, so make sure to check that when you're using this struct anywhere. -type Authed struct { - Token oauth2.TokenInfo - Application *gtsmodel.Application - User *gtsmodel.User - Account *gtsmodel.Account -} - -// GetAuthed is a convenience function for returning an Authed struct from a gin context. -// In essence, it tries to extract a token, application, user, and account from the context, -// and then sets them on a struct for convenience. -// -// If any are not present in the context, they will be set to nil on the returned Authed struct. -// -// If *ALL* are not present, then nil and an error will be returned. -// -// If something goes wrong during parsing, then nil and an error will be returned (consider this not authed). -func GetAuthed(c *gin.Context) (*Authed, error) { - ctx := c.Copy() - a := &Authed{} - var i interface{} - var ok bool +// New returns a new oauth server that implements the Server interface +func New(database db.DB, log *logrus.Logger) Server { + ts := newTokenStore(context.Background(), database, log) + cs := NewClientStore(database) - i, ok = ctx.Get(SessionAuthorizedToken) - if ok { - parsed, ok := i.(oauth2.TokenInfo) - if !ok { - return nil, errors.New("could not parse token from session context") - } - a.Token = parsed + manager := manage.NewDefaultManager() + manager.MapTokenStorage(ts) + manager.MapClientStorage(cs) + manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) + sc := &server.Config{ + TokenType: "Bearer", + // Must follow the spec. + AllowGetAccessRequest: false, + // Support only the non-implicit flow. + AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, + // Allow: + // - Authorization Code (for first & third parties) + // - Client Credentials (for applications) + AllowedGrantTypes: []oauth2.GrantType{ + oauth2.AuthorizationCode, + oauth2.ClientCredentials, + }, + AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain}, } - i, ok = ctx.Get(SessionAuthorizedApplication) - if ok { - parsed, ok := i.(*gtsmodel.Application) - if !ok { - return nil, errors.New("could not parse application from session context") - } - a.Application = parsed - } + srv := server.NewServer(sc, manager) + srv.SetInternalErrorHandler(func(err error) *errors.Response { + log.Errorf("internal oauth error: %s", err) + return nil + }) - i, ok = ctx.Get(SessionAuthorizedUser) - if ok { - parsed, ok := i.(*gtsmodel.User) - if !ok { - return nil, errors.New("could not parse user from session context") - } - a.User = parsed - } + srv.SetResponseErrorHandler(func(re *errors.Response) { + log.Errorf("internal response error: %s", re.Error) + }) - i, ok = ctx.Get(SessionAuthorizedAccount) - if ok { - parsed, ok := i.(*gtsmodel.Account) - if !ok { - return nil, errors.New("could not parse account from session context") + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) { + userID := r.FormValue("userid") + if userID == "" { + return "", errors.New("userid was empty") } - a.Account = parsed - } - - if a.Token == nil && a.Application == nil && a.User == nil && a.Account == nil { - return nil, errors.New("not authorized") - } - - return a, nil -} - -// MustAuth is like GetAuthed, but will fail if one of the requirements is not met. -func MustAuth(c *gin.Context, requireToken bool, requireApp bool, requireUser bool, requireAccount bool) (*Authed, error) { - a, err := GetAuthed(c) - if err != nil { - return nil, err - } - if requireToken && a.Token == nil { - return nil, errors.New("token not supplied") - } - if requireApp && a.Application == nil { - return nil, errors.New("application not supplied") - } - if requireUser && a.User == nil { - return nil, errors.New("user not supplied") - } - if requireAccount && a.Account == nil { - return nil, errors.New("account not supplied") + return userID, nil + }) + srv.SetClientInfoHandler(server.ClientFormHandler) + return &s{ + server: srv, + log: log, } - return a, nil } // HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function @@ -211,52 +168,3 @@ func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, us s.log.Tracef("obtained user-level access token: %+v", accessToken) return accessToken, nil } - -// New returns a new oauth server that implements the Server interface -func New(database db.DB, log *logrus.Logger) Server { - ts := newTokenStore(context.Background(), database, log) - cs := newClientStore(database) - - manager := manage.NewDefaultManager() - manager.MapTokenStorage(ts) - manager.MapClientStorage(cs) - manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) - sc := &server.Config{ - TokenType: "Bearer", - // Must follow the spec. - AllowGetAccessRequest: false, - // Support only the non-implicit flow. - AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, - // Allow: - // - Authorization Code (for first & third parties) - // - Client Credentials (for applications) - AllowedGrantTypes: []oauth2.GrantType{ - oauth2.AuthorizationCode, - oauth2.ClientCredentials, - }, - AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain}, - } - - srv := server.NewServer(sc, manager) - srv.SetInternalErrorHandler(func(err error) *errors.Response { - log.Errorf("internal oauth error: %s", err) - return nil - }) - - srv.SetResponseErrorHandler(func(re *errors.Response) { - log.Errorf("internal response error: %s", re.Error) - }) - - srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) { - userID := r.FormValue("userid") - if userID == "" { - return "", errors.New("userid was empty") - } - return userID, nil - }) - srv.SetClientInfoHandler(server.ClientFormHandler) - return &s{ - server: srv, - log: log, - } -} |