summaryrefslogtreecommitdiff
path: root/internal/api/auth
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2023-01-02 13:10:50 +0100
committerLibravatar GitHub <noreply@github.com>2023-01-02 12:10:50 +0000
commit941893a774c83802afdc4cc76e1d30c59b6c5585 (patch)
tree6e7296146dedfeac8e83655157270f41e190724b /internal/api/auth
parent[chore]: Bump github.com/abema/go-mp4 from 0.8.0 to 0.9.0 (#1287) (diff)
downloadgotosocial-941893a774c83802afdc4cc76e1d30c59b6c5585.tar.xz
[chore] The Big Middleware and API Refactor (tm) (#1250)
* interim commit: start refactoring middlewares into package under router * another interim commit, this is becoming a big job * another fucking massive interim commit * refactor bookmarks to new style * ambassador, wiz zeze commits you are spoiling uz * she compiles, we're getting there * we're just normal men; we're just innocent men * apiutil * whoopsie * i'm glad noone reads commit msgs haha :blob_sweat: * use that weirdo go-bytesize library for maxMultipartMemory * fix media module paths
Diffstat (limited to 'internal/api/auth')
-rw-r--r--internal/api/auth/auth.go117
-rw-r--r--internal/api/auth/auth_test.go129
-rw-r--r--internal/api/auth/authorize.go335
-rw-r--r--internal/api/auth/authorize_test.go118
-rw-r--r--internal/api/auth/callback.go317
-rw-r--r--internal/api/auth/oob.go113
-rw-r--r--internal/api/auth/signin.go145
-rw-r--r--internal/api/auth/token.go115
-rw-r--r--internal/api/auth/token_test.go215
9 files changed, 1604 insertions, 0 deletions
diff --git a/internal/api/auth/auth.go b/internal/api/auth/auth.go
new file mode 100644
index 000000000..7ce992466
--- /dev/null
+++ b/internal/api/auth/auth.go
@@ -0,0 +1,117 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package auth
+
+import (
+ "net/http"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/oidc"
+ "github.com/superseriousbusiness/gotosocial/internal/processing"
+)
+
+const (
+ /*
+ paths prefixed with 'auth'
+ */
+
+ // AuthSignInPath is the API path for users to sign in through
+ AuthSignInPath = "/sign_in"
+ // AuthCheckYourEmailPath users land here after registering a new account, instructs them to confirm their email
+ AuthCheckYourEmailPath = "/check_your_email"
+ // AuthWaitForApprovalPath users land here after confirming their email
+ // but before an admin approves their account (if such is required)
+ AuthWaitForApprovalPath = "/wait_for_approval"
+ // AuthAccountDisabledPath users land here when their account is suspended by an admin
+ AuthAccountDisabledPath = "/account_disabled"
+ // AuthCallbackPath is the API path for receiving callback tokens from external OIDC providers
+ AuthCallbackPath = "/callback"
+
+ /*
+ paths prefixed with 'oauth'
+ */
+
+ // OauthTokenPath is the API path to use for granting token requests to users with valid credentials
+ OauthTokenPath = "/token" // #nosec G101 else we get a hardcoded credentials warning
+ // OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user)
+ OauthAuthorizePath = "/authorize"
+ // OauthFinalizePath is the API path for completing user registration with additional user details
+ OauthFinalizePath = "/finalize"
+ // OauthOobTokenPath is the path for serving an html representation of an oob token page.
+ OauthOobTokenPath = "/oob" // #nosec G101 else we get a hardcoded credentials warning
+
+ /*
+ params / session keys
+ */
+
+ callbackStateParam = "state"
+ callbackCodeParam = "code"
+ sessionUserID = "userid"
+ sessionClientID = "client_id"
+ sessionRedirectURI = "redirect_uri"
+ sessionForceLogin = "force_login"
+ sessionResponseType = "response_type"
+ sessionScope = "scope"
+ sessionInternalState = "internal_state"
+ sessionClientState = "client_state"
+ sessionClaims = "claims"
+ sessionAppID = "app_id"
+)
+
+type Module struct {
+ db db.DB
+ processor processing.Processor
+ idp oidc.IDP
+}
+
+// New returns an Auth module which provides both 'oauth' and 'auth' endpoints.
+//
+// It is safe to pass a nil idp if oidc is disabled.
+func New(db db.DB, processor processing.Processor, idp oidc.IDP) *Module {
+ return &Module{
+ db: db,
+ processor: processor,
+ idp: idp,
+ }
+}
+
+// RouteAuth routes all paths that should have an 'auth' prefix
+func (m *Module) RouteAuth(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
+ attachHandler(http.MethodGet, AuthSignInPath, m.SignInGETHandler)
+ attachHandler(http.MethodPost, AuthSignInPath, m.SignInPOSTHandler)
+ attachHandler(http.MethodGet, AuthCallbackPath, m.CallbackGETHandler)
+}
+
+// RouteOauth routes all paths that should have an 'oauth' prefix
+func (m *Module) RouteOauth(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
+ attachHandler(http.MethodPost, OauthTokenPath, m.TokenPOSTHandler)
+ attachHandler(http.MethodGet, OauthAuthorizePath, m.AuthorizeGETHandler)
+ attachHandler(http.MethodPost, OauthAuthorizePath, m.AuthorizePOSTHandler)
+ attachHandler(http.MethodPost, OauthFinalizePath, m.FinalizePOSTHandler)
+ attachHandler(http.MethodGet, OauthOobTokenPath, m.OobHandler)
+}
+
+func (m *Module) clearSession(s sessions.Session) {
+ s.Clear()
+ if err := s.Save(); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go
new file mode 100644
index 000000000..cb92850d0
--- /dev/null
+++ b/internal/api/auth/auth_test.go
@@ -0,0 +1,129 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package auth_test
+
+import (
+ "bytes"
+ "fmt"
+ "net/http/httptest"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-contrib/sessions/memstore"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/api/auth"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/email"
+ "github.com/superseriousbusiness/gotosocial/internal/federation"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/media"
+ "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/middleware"
+ "github.com/superseriousbusiness/gotosocial/internal/oidc"
+ "github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/storage"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type AuthStandardTestSuite struct {
+ suite.Suite
+ db db.DB
+ storage *storage.Driver
+ mediaManager media.Manager
+ federator federation.Federator
+ processor processing.Processor
+ emailSender email.Sender
+ idp oidc.IDP
+
+ // standard suite models
+ testTokens map[string]*gtsmodel.Token
+ testClients map[string]*gtsmodel.Client
+ testApplications map[string]*gtsmodel.Application
+ testUsers map[string]*gtsmodel.User
+ testAccounts map[string]*gtsmodel.Account
+
+ // module being tested
+ authModule *auth.Module
+}
+
+const (
+ sessionUserID = "userid"
+ sessionClientID = "client_id"
+)
+
+func (suite *AuthStandardTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+}
+
+func (suite *AuthStandardTestSuite) SetupTest() {
+ testrig.InitTestConfig()
+ testrig.InitTestLog()
+
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+
+ suite.db = testrig.NewTestDB()
+ suite.storage = testrig.NewInMemoryStorage()
+ suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
+ suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)
+ suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.authModule = auth.New(suite.db, suite.processor, suite.idp)
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *AuthStandardTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *AuthStandardTestSuite) newContext(requestMethod string, requestPath string, requestBody []byte, bodyContentType string) (*gin.Context, *httptest.ResponseRecorder) {
+ // create the recorder and gin test context
+ recorder := httptest.NewRecorder()
+ ctx, engine := testrig.CreateGinTestContext(recorder, nil)
+
+ // load templates into the engine
+ testrig.ConfigureTemplatesWithGin(engine, "../../../web/template")
+
+ // create the request
+ protocol := config.GetProtocol()
+ host := config.GetHost()
+ baseURI := fmt.Sprintf("%s://%s", protocol, host)
+ requestURI := fmt.Sprintf("%s/%s", baseURI, requestPath)
+
+ ctx.Request = httptest.NewRequest(requestMethod, requestURI, bytes.NewReader(requestBody)) // the endpoint we're hitting
+ ctx.Request.Header.Set("accept", "text/html")
+
+ if bodyContentType != "" {
+ ctx.Request.Header.Set("Content-Type", bodyContentType)
+ }
+
+ // trigger the session middleware on the context
+ store := memstore.NewStore(make([]byte, 32), make([]byte, 32))
+ store.Options(middleware.SessionOptions())
+ sessionMiddleware := sessions.Sessions("gotosocial-localhost", store)
+ sessionMiddleware(ctx)
+
+ return ctx, recorder
+}
diff --git a/internal/api/auth/authorize.go b/internal/api/auth/authorize.go
new file mode 100644
index 000000000..e504f6be2
--- /dev/null
+++ b/internal/api/auth/authorize.go
@@ -0,0 +1,335 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package auth
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+ apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+)
+
+// AuthorizeGETHandler should be served as GET at https://example.org/oauth/authorize
+// The idea here is to present an oauth authorize page to the user, with a button
+// that they have to click to accept.
+func (m *Module) AuthorizeGETHandler(c *gin.Context) {
+ s := sessions.Default(c)
+
+ if _, err := apiutil.NegotiateAccept(c, apiutil.HTMLAcceptHeaders...); err != nil {
+ apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
+ return
+ }
+
+ // UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow
+ // If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
+ userID, ok := s.Get(sessionUserID).(string)
+ if !ok || userID == "" {
+ form := &apimodel.OAuthAuthorize{}
+ if err := c.ShouldBind(form); err != nil {
+ m.clearSession(s)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ return
+ }
+
+ if errWithCode := saveAuthFormToSession(s, form); errWithCode != nil {
+ m.clearSession(s)
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ c.Redirect(http.StatusSeeOther, "/auth"+AuthSignInPath)
+ return
+ }
+
+ // use session information to validate app, user, and account for this request
+ clientID, ok := s.Get(sessionClientID).(string)
+ if !ok || clientID == "" {
+ m.clearSession(s)
+ err := fmt.Errorf("key %s was not found in session", sessionClientID)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ return
+ }
+
+ app := &gtsmodel.Application{}
+ if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil {
+ m.clearSession(s)
+ safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID)
+ var errWithCode gtserror.WithCode
+ if err == db.ErrNoEntries {
+ errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
+ } else {
+ errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
+ }
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ user, err := m.db.GetUserByID(c.Request.Context(), userID)
+ if err != nil {
+ m.clearSession(s)
+ safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
+ var errWithCode gtserror.WithCode
+ if err == db.ErrNoEntries {
+ errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
+ } else {
+ errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
+ }
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
+ if err != nil {
+ m.clearSession(s)
+ safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID)
+ var errWithCode gtserror.WithCode
+ if err == db.ErrNoEntries {
+ errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
+ } else {
+ errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
+ }
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ if ensureUserIsAuthorizedOrRedirect(c, user, acct) {
+ return
+ }
+
+ // Finally we should also get the redirect and scope of this particular request, as stored in the session.
+ redirect, ok := s.Get(sessionRedirectURI).(string)
+ if !ok || redirect == "" {
+ m.clearSession(s)
+ err := fmt.Errorf("key %s was not found in session", sessionRedirectURI)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ return
+ }
+
+ scope, ok := s.Get(sessionScope).(string)
+ if !ok || scope == "" {
+ m.clearSession(s)
+ err := fmt.Errorf("key %s was not found in session", sessionScope)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ return
+ }
+
+ instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost())
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ // the authorize template will display a form to the user where they can get some information
+ // about the app that's trying to authorize, and the scope of the request.
+ // They can then approve it if it looks OK to them, which will POST to the AuthorizePOSTHandler
+ c.HTML(http.StatusOK, "authorize.tmpl", gin.H{
+ "appname": app.Name,
+ "appwebsite": app.Website,
+ "redirect": redirect,
+ "scope": scope,
+ "user": acct.Username,
+ "instance": instance,
+ })
+}
+
+// AuthorizePOSTHandler should be served as POST at https://example.org/oauth/authorize
+// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
+// so we should proceed with the authentication flow and generate an oauth token for them if we can.
+func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
+ s := sessions.Default(c)
+
+ // We need to retrieve the original form submitted to the authorizeGEThandler, and
+ // recreate it on the request so that it can be used further by the oauth2 library.
+ errs := []string{}
+
+ forceLogin, ok := s.Get(sessionForceLogin).(string)
+ if !ok {
+ forceLogin = "false"
+ }
+
+ responseType, ok := s.Get(sessionResponseType).(string)
+ if !ok || responseType == "" {
+ errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionResponseType))
+ }
+
+ clientID, ok := s.Get(sessionClientID).(string)
+ if !ok || clientID == "" {
+ errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionClientID))
+ }
+
+ redirectURI, ok := s.Get(sessionRedirectURI).(string)
+ if !ok || redirectURI == "" {
+ errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionRedirectURI))
+ }
+
+ scope, ok := s.Get(sessionScope).(string)
+ if !ok {
+ errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionScope))
+ }
+
+ var clientState string
+ if s, ok := s.Get(sessionClientState).(string); ok {
+ clientState = s
+ }
+
+ userID, ok := s.Get(sessionUserID).(string)
+ if !ok {
+ errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionUserID))
+ }
+
+ if len(errs) != 0 {
+ errs = append(errs, oauth.HelpfulAdvice)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during AuthorizePOSTHandler"), errs...), m.processor.InstanceGet)
+ return
+ }
+
+ user, err := m.db.GetUserByID(c.Request.Context(), userID)
+ if err != nil {
+ m.clearSession(s)
+ safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
+ var errWithCode gtserror.WithCode
+ if err == db.ErrNoEntries {
+ errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
+ } else {
+ errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
+ }
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
+ if err != nil {
+ m.clearSession(s)
+ safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID)
+ var errWithCode gtserror.WithCode
+ if err == db.ErrNoEntries {
+ errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
+ } else {
+ errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
+ }
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ if ensureUserIsAuthorizedOrRedirect(c, user, acct) {
+ return
+ }
+
+ if redirectURI != oauth.OOBURI {
+ // we're done with the session now, so just clear it out
+ m.clearSession(s)
+ }
+
+ // we have to set the values on the request form
+ // so that they're picked up by the oauth server
+ c.Request.Form = url.Values{
+ sessionForceLogin: {forceLogin},
+ sessionResponseType: {responseType},
+ sessionClientID: {clientID},
+ sessionRedirectURI: {redirectURI},
+ sessionScope: {scope},
+ sessionUserID: {userID},
+ }
+
+ if clientState != "" {
+ c.Request.Form.Set("state", clientState)
+ }
+
+ if errWithCode := m.processor.OAuthHandleAuthorizeRequest(c.Writer, c.Request); errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ }
+}
+
+// saveAuthFormToSession checks the given OAuthAuthorize form,
+// and stores the values in the form into the session.
+func saveAuthFormToSession(s sessions.Session, form *apimodel.OAuthAuthorize) gtserror.WithCode {
+ if form == nil {
+ err := errors.New("OAuthAuthorize form was nil")
+ return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice)
+ }
+
+ if form.ResponseType == "" {
+ err := errors.New("field response_type was not set on OAuthAuthorize form")
+ return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice)
+ }
+
+ if form.ClientID == "" {
+ err := errors.New("field client_id was not set on OAuthAuthorize form")
+ return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice)
+ }
+
+ if form.RedirectURI == "" {
+ err := errors.New("field redirect_uri was not set on OAuthAuthorize form")
+ return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice)
+ }
+
+ // set default scope to read
+ if form.Scope == "" {
+ form.Scope = "read"
+ }
+
+ // save these values from the form so we can use them elsewhere in the session
+ s.Set(sessionForceLogin, form.ForceLogin)
+ s.Set(sessionResponseType, form.ResponseType)
+ s.Set(sessionClientID, form.ClientID)
+ s.Set(sessionRedirectURI, form.RedirectURI)
+ s.Set(sessionScope, form.Scope)
+ s.Set(sessionInternalState, uuid.NewString())
+ s.Set(sessionClientState, form.State)
+
+ if err := s.Save(); err != nil {
+ err := fmt.Errorf("error saving form values onto session: %s", err)
+ return gtserror.NewErrorInternalError(err, oauth.HelpfulAdvice)
+ }
+
+ return nil
+}
+
+func ensureUserIsAuthorizedOrRedirect(ctx *gin.Context, user *gtsmodel.User, account *gtsmodel.Account) (redirected bool) {
+ if user.ConfirmedAt.IsZero() {
+ ctx.Redirect(http.StatusSeeOther, "/auth"+AuthCheckYourEmailPath)
+ redirected = true
+ return
+ }
+
+ if !*user.Approved {
+ ctx.Redirect(http.StatusSeeOther, "/auth"+AuthWaitForApprovalPath)
+ redirected = true
+ return
+ }
+
+ if *user.Disabled || !account.SuspendedAt.IsZero() {
+ ctx.Redirect(http.StatusSeeOther, "/auth"+AuthAccountDisabledPath)
+ redirected = true
+ return
+ }
+
+ return
+}
diff --git a/internal/api/auth/authorize_test.go b/internal/api/auth/authorize_test.go
new file mode 100644
index 000000000..ff65d041b
--- /dev/null
+++ b/internal/api/auth/authorize_test.go
@@ -0,0 +1,118 @@
+package auth_test
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/api/auth"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type AuthAuthorizeTestSuite struct {
+ AuthStandardTestSuite
+}
+
+type authorizeHandlerTestCase struct {
+ description string
+ mutateUserAccount func(*gtsmodel.User, *gtsmodel.Account) []string
+ expectedStatusCode int
+ expectedLocationHeader string
+}
+
+func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
+ tests := []authorizeHandlerTestCase{
+ {
+ description: "user has their email unconfirmed",
+ mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string {
+ user.ConfirmedAt = time.Time{}
+ return []string{"confirmed_at"}
+ },
+ expectedStatusCode: http.StatusSeeOther,
+ expectedLocationHeader: "/auth" + auth.AuthCheckYourEmailPath,
+ },
+ {
+ description: "user has their email confirmed but is not approved",
+ mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string {
+ user.ConfirmedAt = time.Now()
+ user.Email = user.UnconfirmedEmail
+ return []string{"confirmed_at", "email"}
+ },
+ expectedStatusCode: http.StatusSeeOther,
+ expectedLocationHeader: "/auth" + auth.AuthWaitForApprovalPath,
+ },
+ {
+ description: "user has their email confirmed and is approved, but User entity has been disabled",
+ mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string {
+ user.ConfirmedAt = time.Now()
+ user.Email = user.UnconfirmedEmail
+ user.Approved = testrig.TrueBool()
+ user.Disabled = testrig.TrueBool()
+ return []string{"confirmed_at", "email", "approved", "disabled"}
+ },
+ expectedStatusCode: http.StatusSeeOther,
+ expectedLocationHeader: "/auth" + auth.AuthAccountDisabledPath,
+ },
+ {
+ description: "user has their email confirmed and is approved, but Account entity has been suspended",
+ mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string {
+ user.ConfirmedAt = time.Now()
+ user.Email = user.UnconfirmedEmail
+ user.Approved = testrig.TrueBool()
+ user.Disabled = testrig.FalseBool()
+ account.SuspendedAt = time.Now()
+ return []string{"confirmed_at", "email", "approved", "disabled"}
+ },
+ expectedStatusCode: http.StatusSeeOther,
+ expectedLocationHeader: "/auth" + auth.AuthAccountDisabledPath,
+ },
+ }
+
+ doTest := func(testCase authorizeHandlerTestCase) {
+ ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "")
+
+ user := &gtsmodel.User{}
+ account := &gtsmodel.Account{}
+
+ *user = *suite.testUsers["unconfirmed_account"]
+ *account = *suite.testAccounts["unconfirmed_account"]
+
+ testSession := sessions.Default(ctx)
+ testSession.Set(sessionUserID, user.ID)
+ testSession.Set(sessionClientID, suite.testApplications["application_1"].ClientID)
+ if err := testSession.Save(); err != nil {
+ panic(fmt.Errorf("failed on case %s: %w", testCase.description, err))
+ }
+
+ columns := testCase.mutateUserAccount(user, account)
+
+ testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt)
+
+ err := suite.db.UpdateUser(context.Background(), user, columns...)
+ suite.NoError(err)
+ err = suite.db.UpdateAccount(context.Background(), account)
+ suite.NoError(err)
+
+ // call the handler
+ suite.authModule.AuthorizeGETHandler(ctx)
+
+ // 1. we should have a redirect
+ suite.Equal(testCase.expectedStatusCode, recorder.Code, fmt.Sprintf("failed on case: %s", testCase.description))
+
+ // 2. we should have a redirect to the check your email path, as this user has not confirmed their email yet.
+ suite.Equal(testCase.expectedLocationHeader, recorder.Header().Get("Location"), fmt.Sprintf("failed on case: %s", testCase.description))
+ }
+
+ for _, testCase := range tests {
+ doTest(testCase)
+ }
+}
+
+func TestAccountUpdateTestSuite(t *testing.T) {
+ suite.Run(t, new(AuthAuthorizeTestSuite))
+}
diff --git a/internal/api/auth/callback.go b/internal/api/auth/callback.go
new file mode 100644
index 000000000..d344b5d5f
--- /dev/null
+++ b/internal/api/auth/callback.go
@@ -0,0 +1,317 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package auth
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "strings"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/oidc"
+ "github.com/superseriousbusiness/gotosocial/internal/validate"
+)
+
+// extraInfo wraps a form-submitted username and transmitted name
+type extraInfo struct {
+ Username string `form:"username"`
+ Name string `form:"name"` // note that this is only used for re-rendering the page in case of an error
+}
+
+// CallbackGETHandler parses a token from an external auth provider.
+func (m *Module) CallbackGETHandler(c *gin.Context) {
+ if !config.GetOIDCEnabled() {
+ err := errors.New("oidc is not enabled for this server")
+ apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(err, err.Error()), m.processor.InstanceGet)
+ return
+ }
+
+ s := sessions.Default(c)
+
+ // check the query vs session state parameter to mitigate csrf
+ // https://auth0.com/docs/secure/attack-protection/state-parameters
+
+ returnedInternalState := c.Query(callbackStateParam)
+ if returnedInternalState == "" {
+ m.clearSession(s)
+ err := fmt.Errorf("%s parameter not found on callback query", callbackStateParam)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
+ return
+ }
+
+ savedInternalStateI := s.Get(sessionInternalState)
+ savedInternalState, ok := savedInternalStateI.(string)
+ if !ok {
+ m.clearSession(s)
+ err := fmt.Errorf("key %s was not found in session", sessionInternalState)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
+ return
+ }
+
+ if returnedInternalState != savedInternalState {
+ m.clearSession(s)
+ err := errors.New("mismatch between callback state and saved state")
+ apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
+ return
+ }
+
+ // retrieve stored claims using code
+ code := c.Query(callbackCodeParam)
+ if code == "" {
+ m.clearSession(s)
+ err := fmt.Errorf("%s parameter not found on callback query", callbackCodeParam)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
+ return
+ }
+
+ claims, errWithCode := m.idp.HandleCallback(c.Request.Context(), code)
+ if errWithCode != nil {
+ m.clearSession(s)
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ // We can use the client_id on the session to retrieve
+ // info about the app associated with the client_id
+ clientID, ok := s.Get(sessionClientID).(string)
+ if !ok || clientID == "" {
+ m.clearSession(s)
+ err := fmt.Errorf("key %s was not found in session", sessionClientID)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ return
+ }
+
+ app := &gtsmodel.Application{}
+ if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil {
+ m.clearSession(s)
+ safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID)
+ var errWithCode gtserror.WithCode
+ if err == db.ErrNoEntries {
+ errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
+ } else {
+ errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
+ }
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ user, errWithCode := m.fetchUserForClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID)
+ if errWithCode != nil {
+ m.clearSession(s)
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+ if user == nil {
+ // no user exists yet - let's ask them for their preferred username
+ instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost())
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ // store the claims in the session - that way we know the user is authenticated when processing the form later
+ s.Set(sessionClaims, claims)
+ s.Set(sessionAppID, app.ID)
+ if err := s.Save(); err != nil {
+ m.clearSession(s)
+ apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
+ return
+ }
+ c.HTML(http.StatusOK, "finalize.tmpl", gin.H{
+ "instance": instance,
+ "name": claims.Name,
+ "preferredUsername": claims.PreferredUsername,
+ })
+ return
+ }
+ s.Set(sessionUserID, user.ID)
+ if err := s.Save(); err != nil {
+ m.clearSession(s)
+ apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
+ return
+ }
+ c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath)
+}
+
+// FinalizePOSTHandler registers the user after additional data has been provided
+func (m *Module) FinalizePOSTHandler(c *gin.Context) {
+ s := sessions.Default(c)
+
+ form := &extraInfo{}
+ if err := c.ShouldBind(form); err != nil {
+ m.clearSession(s)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ return
+ }
+
+ // since we have multiple possible validation error, `validationError` is a shorthand for rendering them
+ validationError := func(err error) {
+ instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost())
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+ c.HTML(http.StatusOK, "finalize.tmpl", gin.H{
+ "instance": instance,
+ "name": form.Name,
+ "preferredUsername": form.Username,
+ "error": err,
+ })
+ }
+
+ // check if the username conforms to the spec
+ if err := validate.Username(form.Username); err != nil {
+ validationError(err)
+ return
+ }
+
+ // see if the username is still available
+ usernameAvailable, err := m.db.IsUsernameAvailable(c.Request.Context(), form.Username)
+ if err != nil {
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ return
+ }
+ if !usernameAvailable {
+ validationError(fmt.Errorf("Username %s is already taken", form.Username))
+ return
+ }
+
+ // retrieve the information previously set by the oidc logic
+ appID, ok := s.Get(sessionAppID).(string)
+ if !ok {
+ err := fmt.Errorf("key %s was not found in session", sessionAppID)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ return
+ }
+
+ // retrieve the claims returned by the IDP. Having this present means that we previously already verified these claims
+ claims, ok := s.Get(sessionClaims).(*oidc.Claims)
+ if !ok {
+ err := fmt.Errorf("key %s was not found in session", sessionClaims)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ return
+ }
+
+ // we're now ready to actually create the user
+ user, errWithCode := m.createUserFromOIDC(c.Request.Context(), claims, form, net.IP(c.ClientIP()), appID)
+ if errWithCode != nil {
+ m.clearSession(s)
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+ s.Delete(sessionClaims)
+ s.Delete(sessionAppID)
+ s.Set(sessionUserID, user.ID)
+ if err := s.Save(); err != nil {
+ m.clearSession(s)
+ apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
+ return
+ }
+ c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath)
+}
+
+func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) {
+ if claims.Sub == "" {
+ err := errors.New("no sub claim found - is your provider OIDC compliant?")
+ return nil, gtserror.NewErrorBadRequest(err, err.Error())
+ }
+ user, err := m.db.GetUserByExternalID(ctx, claims.Sub)
+ if err == nil {
+ return user, nil
+ }
+ if err != db.ErrNoEntries {
+ err := fmt.Errorf("error checking database for externalID %s: %s", claims.Sub, err)
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ if !config.GetOIDCLinkExisting() {
+ return nil, nil
+ }
+ // fallback to email if we want to link existing users
+ user, err = m.db.GetUserByEmailAddress(ctx, claims.Email)
+ if err == db.ErrNoEntries {
+ return nil, nil
+ } else if err != nil {
+ err := fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ // at this point we have found a matching user but still need to link the newly received external ID
+
+ user.ExternalID = claims.Sub
+ err = m.db.UpdateUser(ctx, user, "external_id")
+ if err != nil {
+ err := fmt.Errorf("error linking existing user %s: %s", claims.Email, err)
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ return user, nil
+}
+
+func (m *Module) createUserFromOIDC(ctx context.Context, claims *oidc.Claims, extraInfo *extraInfo, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) {
+ // check if the email address is available for use; if it's not there's nothing we can so
+ emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email)
+ if err != nil {
+ return nil, gtserror.NewErrorBadRequest(err)
+ }
+ if !emailAvailable {
+ help := "The email address given to us by your authentication provider already exists in our records and the server administrator has not enabled account migration"
+ return nil, gtserror.NewErrorConflict(fmt.Errorf("email address %s is not available", claims.Email), help)
+ }
+
+ // check if the user is in any recognised admin groups
+ var admin bool
+ for _, g := range claims.Groups {
+ if strings.EqualFold(g, "admin") || strings.EqualFold(g, "admins") {
+ admin = true
+ }
+ }
+
+ // We still need to set *a* password even if it's not a password the user will end up using, so set something random.
+ // We'll just set two uuids on top of each other, which should be long + random enough to baffle any attempts to crack.
+ //
+ // If the user ever wants to log in using gts password rather than oidc flow, they'll have to request a password reset, which is fine
+ password := uuid.NewString() + uuid.NewString()
+
+ // Since this user is created via oidc, which has been set up by the admin, we can assume that the account is already
+ // implicitly approved, and that the email address has already been verified: otherwise, we end up in situations where
+ // the admin first approves the user in OIDC, and then has to approve them again in GoToSocial, which doesn't make sense.
+ //
+ // In other words, if a user logs in via OIDC, they should be able to use their account straight away.
+ //
+ // See: https://github.com/superseriousbusiness/gotosocial/issues/357
+ requireApproval := false
+ emailVerified := true
+
+ // create the user! this will also create an account and store it in the database so we don't need to do that here
+ user, err := m.db.NewSignup(ctx, extraInfo.Username, "", requireApproval, claims.Email, password, ip, "", appID, emailVerified, claims.Sub, admin)
+ if err != nil {
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+
+ return user, nil
+}
diff --git a/internal/api/auth/oob.go b/internal/api/auth/oob.go
new file mode 100644
index 000000000..97f9c0f8c
--- /dev/null
+++ b/internal/api/auth/oob.go
@@ -0,0 +1,113 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package auth
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+)
+
+func (m *Module) OobHandler(c *gin.Context) {
+ host := config.GetHost()
+ instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), host)
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ instanceGet := func(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode) {
+ return instance, nil
+ }
+
+ oobToken := c.Query("code")
+ if oobToken == "" {
+ err := errors.New("no 'code' query value provided in callback redirect")
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice), instanceGet)
+ return
+ }
+
+ s := sessions.Default(c)
+
+ errs := []string{}
+
+ scope, ok := s.Get(sessionScope).(string)
+ if !ok {
+ errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionScope))
+ }
+
+ userID, ok := s.Get(sessionUserID).(string)
+ if !ok {
+ errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionUserID))
+ }
+
+ if len(errs) != 0 {
+ errs = append(errs, oauth.HelpfulAdvice)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during OobHandler"), errs...), m.processor.InstanceGet)
+ return
+ }
+
+ user, err := m.db.GetUserByID(c.Request.Context(), userID)
+ if err != nil {
+ m.clearSession(s)
+ safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
+ var errWithCode gtserror.WithCode
+ if err == db.ErrNoEntries {
+ errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
+ } else {
+ errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
+ }
+ apiutil.ErrorHandler(c, errWithCode, instanceGet)
+ return
+ }
+
+ acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
+ if err != nil {
+ m.clearSession(s)
+ safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID)
+ var errWithCode gtserror.WithCode
+ if err == db.ErrNoEntries {
+ errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
+ } else {
+ errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
+ }
+ apiutil.ErrorHandler(c, errWithCode, instanceGet)
+ return
+ }
+
+ // we're done with the session now, so just clear it out
+ m.clearSession(s)
+
+ c.HTML(http.StatusOK, "oob.tmpl", gin.H{
+ "instance": instance,
+ "user": acct.Username,
+ "oobToken": oobToken,
+ "scope": scope,
+ })
+}
diff --git a/internal/api/auth/signin.go b/internal/api/auth/signin.go
new file mode 100644
index 000000000..bae33a43b
--- /dev/null
+++ b/internal/api/auth/signin.go
@@ -0,0 +1,145 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package auth
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "golang.org/x/crypto/bcrypt"
+)
+
+// login just wraps a form-submitted username (we want an email) and password
+type login struct {
+ Email string `form:"username"`
+ Password string `form:"password"`
+}
+
+// SignInGETHandler should be served at https://example.org/auth/sign_in.
+// The idea is to present a sign in page to the user, where they can enter their username and password.
+// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler.
+// If an idp provider is set, then the user will be redirected to that to do their sign in.
+func (m *Module) SignInGETHandler(c *gin.Context) {
+ if _, err := apiutil.NegotiateAccept(c, apiutil.HTMLAcceptHeaders...); err != nil {
+ apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
+ return
+ }
+
+ if !config.GetOIDCEnabled() {
+ instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost())
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ // no idp provider, use our own funky little sign in page
+ c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{
+ "instance": instance,
+ })
+ return
+ }
+
+ // idp provider is in use, so redirect to it
+ s := sessions.Default(c)
+
+ internalStateI := s.Get(sessionInternalState)
+ internalState, ok := internalStateI.(string)
+ if !ok {
+ m.clearSession(s)
+ err := fmt.Errorf("key %s was not found in session", sessionInternalState)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
+ return
+ }
+
+ c.Redirect(http.StatusSeeOther, m.idp.AuthCodeURL(internalState))
+}
+
+// SignInPOSTHandler should be served at https://example.org/auth/sign_in.
+// The idea is to present a sign in page to the user, where they can enter their username and password.
+// The handler will then redirect to the auth handler served at /auth
+func (m *Module) SignInPOSTHandler(c *gin.Context) {
+ s := sessions.Default(c)
+
+ form := &login{}
+ if err := c.ShouldBind(form); err != nil {
+ m.clearSession(s)
+ apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ return
+ }
+
+ userid, errWithCode := m.ValidatePassword(c.Request.Context(), form.Email, form.Password)
+ if errWithCode != nil {
+ // don't clear session here, so the user can just press back and try again
+ // if they accidentally gave the wrong password or something
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
+ return
+ }
+
+ s.Set(sessionUserID, userid)
+ if err := s.Save(); err != nil {
+ err := fmt.Errorf("error saving user id onto session: %s", err)
+ apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
+ }
+
+ c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath)
+}
+
+// ValidatePassword takes an email address and a password.
+// The goal is to authenticate the password against the one for that email
+// address stored in the database. If OK, we return the userid (a ulid) for that user,
+// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
+func (m *Module) ValidatePassword(ctx context.Context, email string, password string) (string, gtserror.WithCode) {
+ if email == "" || password == "" {
+ err := errors.New("email or password was not provided")
+ return incorrectPassword(err)
+ }
+
+ user, err := m.db.GetUserByEmailAddress(ctx, email)
+ if err != nil {
+ err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
+ return incorrectPassword(err)
+ }
+
+ if user.EncryptedPassword == "" {
+ err := fmt.Errorf("encrypted password for user %s was empty for some reason", user.Email)
+ return incorrectPassword(err)
+ }
+
+ if err := bcrypt.CompareHashAndPassword([]byte(user.EncryptedPassword), []byte(password)); err != nil {
+ err := fmt.Errorf("password hash didn't match for user %s during login attempt: %s", user.Email, err)
+ return incorrectPassword(err)
+ }
+
+ return user.ID, nil
+}
+
+// incorrectPassword wraps the given error in a gtserror.WithCode, and returns
+// only a generic 'safe' error message to the user, to not give any info away.
+func incorrectPassword(err error) (string, gtserror.WithCode) {
+ safeErr := fmt.Errorf("password/email combination was incorrect")
+ return "", gtserror.NewErrorUnauthorized(err, safeErr.Error(), oauth.HelpfulAdvice)
+}
diff --git a/internal/api/auth/token.go b/internal/api/auth/token.go
new file mode 100644
index 000000000..17c4d8d8b
--- /dev/null
+++ b/internal/api/auth/token.go
@@ -0,0 +1,115 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package auth
+
+import (
+ "net/http"
+ "net/url"
+
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+
+ "github.com/gin-gonic/gin"
+)
+
+type tokenRequestForm struct {
+ GrantType *string `form:"grant_type" json:"grant_type" xml:"grant_type"`
+ Code *string `form:"code" json:"code" xml:"code"`
+ RedirectURI *string `form:"redirect_uri" json:"redirect_uri" xml:"redirect_uri"`
+ ClientID *string `form:"client_id" json:"client_id" xml:"client_id"`
+ ClientSecret *string `form:"client_secret" json:"client_secret" xml:"client_secret"`
+ Scope *string `form:"scope" json:"scope" xml:"scope"`
+}
+
+// TokenPOSTHandler should be served as a POST at https://example.org/oauth/token
+// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
+func (m *Module) TokenPOSTHandler(c *gin.Context) {
+ if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
+ apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
+ return
+ }
+
+ help := []string{}
+
+ form := &tokenRequestForm{}
+ if err := c.ShouldBind(form); err != nil {
+ apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), err.Error()))
+ return
+ }
+
+ c.Request.Form = url.Values{}
+
+ var grantType string
+ if form.GrantType != nil {
+ grantType = *form.GrantType
+ c.Request.Form.Set("grant_type", grantType)
+ } else {
+ help = append(help, "grant_type was not set in the token request form, but must be set to authorization_code or client_credentials")
+ }
+
+ if form.ClientID != nil {
+ c.Request.Form.Set("client_id", *form.ClientID)
+ } else {
+ help = append(help, "client_id was not set in the token request form")
+ }
+
+ if form.ClientSecret != nil {
+ c.Request.Form.Set("client_secret", *form.ClientSecret)
+ } else {
+ help = append(help, "client_secret was not set in the token request form")
+ }
+
+ if form.RedirectURI != nil {
+ c.Request.Form.Set("redirect_uri", *form.RedirectURI)
+ } else {
+ help = append(help, "redirect_uri was not set in the token request form")
+ }
+
+ var code string
+ if form.Code != nil {
+ if grantType != "authorization_code" {
+ help = append(help, "a code was provided in the token request form, but grant_type was not set to authorization_code")
+ } else {
+ code = *form.Code
+ c.Request.Form.Set("code", code)
+ }
+ } else if grantType == "authorization_code" {
+ help = append(help, "code was not set in the token request form, but must be set since grant_type is authorization_code")
+ }
+
+ if form.Scope != nil {
+ c.Request.Form.Set("scope", *form.Scope)
+ }
+
+ if len(help) != 0 {
+ apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), help...))
+ return
+ }
+
+ token, errWithCode := m.processor.OAuthHandleTokenRequest(c.Request)
+ if errWithCode != nil {
+ apiutil.OAuthErrorHandler(c, errWithCode)
+ return
+ }
+
+ c.Header("Cache-Control", "no-store")
+ c.Header("Pragma", "no-cache")
+ c.JSON(http.StatusOK, token)
+}
diff --git a/internal/api/auth/token_test.go b/internal/api/auth/token_test.go
new file mode 100644
index 000000000..50bbd6918
--- /dev/null
+++ b/internal/api/auth/token_test.go
@@ -0,0 +1,215 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package auth_test
+
+import (
+ "context"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+ apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type TokenTestSuite struct {
+ AuthStandardTestSuite
+}
+
+func (suite *TokenTestSuite) TestPOSTTokenEmptyForm() {
+ ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", []byte{}, "")
+ ctx.Request.Header.Set("accept", "application/json")
+
+ suite.authModule.TokenPOSTHandler(ctx)
+
+ suite.Equal(http.StatusBadRequest, recorder.Code)
+
+ result := recorder.Result()
+ defer result.Body.Close()
+
+ b, err := ioutil.ReadAll(result.Body)
+ suite.NoError(err)
+
+ suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: grant_type was not set in the token request form, but must be set to authorization_code or client_credentials: client_id was not set in the token request form: client_secret was not set in the token request form: redirect_uri was not set in the token request form"}`, string(b))
+}
+
+func (suite *TokenTestSuite) TestRetrieveClientCredentialsOK() {
+ testClient := suite.testClients["local_account_1"]
+
+ requestBody, w, err := testrig.CreateMultipartFormData(
+ "", "",
+ map[string]string{
+ "grant_type": "client_credentials",
+ "client_id": testClient.ID,
+ "client_secret": testClient.Secret,
+ "redirect_uri": "http://localhost:8080",
+ })
+ if err != nil {
+ panic(err)
+ }
+ bodyBytes := requestBody.Bytes()
+
+ ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", bodyBytes, w.FormDataContentType())
+ ctx.Request.Header.Set("accept", "application/json")
+
+ suite.authModule.TokenPOSTHandler(ctx)
+
+ suite.Equal(http.StatusOK, recorder.Code)
+
+ result := recorder.Result()
+ defer result.Body.Close()
+
+ b, err := ioutil.ReadAll(result.Body)
+ suite.NoError(err)
+
+ t := &apimodel.Token{}
+ err = json.Unmarshal(b, t)
+ suite.NoError(err)
+
+ suite.Equal("Bearer", t.TokenType)
+ suite.NotEmpty(t.AccessToken)
+ suite.NotEmpty(t.CreatedAt)
+ suite.WithinDuration(time.Now(), time.Unix(t.CreatedAt, 0), 1*time.Minute)
+
+ // there should be a token in the database now too
+ dbToken := &gtsmodel.Token{}
+ err = suite.db.GetWhere(context.Background(), []db.Where{{Key: "access", Value: t.AccessToken}}, dbToken)
+ suite.NoError(err)
+ suite.NotNil(dbToken)
+}
+
+func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeOK() {
+ testClient := suite.testClients["local_account_1"]
+ testUserAuthorizationToken := suite.testTokens["local_account_1_user_authorization_token"]
+
+ requestBody, w, err := testrig.CreateMultipartFormData(
+ "", "",
+ map[string]string{
+ "grant_type": "authorization_code",
+ "client_id": testClient.ID,
+ "client_secret": testClient.Secret,
+ "redirect_uri": "http://localhost:8080",
+ "code": testUserAuthorizationToken.Code,
+ })
+ if err != nil {
+ panic(err)
+ }
+ bodyBytes := requestBody.Bytes()
+
+ ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", bodyBytes, w.FormDataContentType())
+ ctx.Request.Header.Set("accept", "application/json")
+
+ suite.authModule.TokenPOSTHandler(ctx)
+
+ suite.Equal(http.StatusOK, recorder.Code)
+
+ result := recorder.Result()
+ defer result.Body.Close()
+
+ b, err := ioutil.ReadAll(result.Body)
+ suite.NoError(err)
+
+ t := &apimodel.Token{}
+ err = json.Unmarshal(b, t)
+ suite.NoError(err)
+
+ suite.Equal("Bearer", t.TokenType)
+ suite.NotEmpty(t.AccessToken)
+ suite.NotEmpty(t.CreatedAt)
+ suite.WithinDuration(time.Now(), time.Unix(t.CreatedAt, 0), 1*time.Minute)
+
+ dbToken := &gtsmodel.Token{}
+ err = suite.db.GetWhere(context.Background(), []db.Where{{Key: "access", Value: t.AccessToken}}, dbToken)
+ suite.NoError(err)
+ suite.NotNil(dbToken)
+}
+
+func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeNoCode() {
+ testClient := suite.testClients["local_account_1"]
+
+ requestBody, w, err := testrig.CreateMultipartFormData(
+ "", "",
+ map[string]string{
+ "grant_type": "authorization_code",
+ "client_id": testClient.ID,
+ "client_secret": testClient.Secret,
+ "redirect_uri": "http://localhost:8080",
+ })
+ if err != nil {
+ panic(err)
+ }
+ bodyBytes := requestBody.Bytes()
+
+ ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", bodyBytes, w.FormDataContentType())
+ ctx.Request.Header.Set("accept", "application/json")
+
+ suite.authModule.TokenPOSTHandler(ctx)
+
+ suite.Equal(http.StatusBadRequest, recorder.Code)
+
+ result := recorder.Result()
+ defer result.Body.Close()
+
+ b, err := ioutil.ReadAll(result.Body)
+ suite.NoError(err)
+
+ suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: code was not set in the token request form, but must be set since grant_type is authorization_code"}`, string(b))
+}
+
+func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeWrongGrantType() {
+ testClient := suite.testClients["local_account_1"]
+
+ requestBody, w, err := testrig.CreateMultipartFormData(
+ "", "",
+ map[string]string{
+ "grant_type": "client_credentials",
+ "client_id": testClient.ID,
+ "client_secret": testClient.Secret,
+ "redirect_uri": "http://localhost:8080",
+ "code": "peepeepoopoo",
+ })
+ if err != nil {
+ panic(err)
+ }
+ bodyBytes := requestBody.Bytes()
+
+ ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", bodyBytes, w.FormDataContentType())
+ ctx.Request.Header.Set("accept", "application/json")
+
+ suite.authModule.TokenPOSTHandler(ctx)
+
+ suite.Equal(http.StatusBadRequest, recorder.Code)
+
+ result := recorder.Result()
+ defer result.Body.Close()
+
+ b, err := ioutil.ReadAll(result.Body)
+ suite.NoError(err)
+
+ suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: a code was provided in the token request form, but grant_type was not set to authorization_code"}`, string(b))
+}
+
+func TestTokenTestSuite(t *testing.T) {
+ suite.Run(t, &TokenTestSuite{})
+}