diff options
| -rw-r--r-- | internal/router/session.go | 30 | ||||
| -rw-r--r-- | internal/router/session_test.go | 87 | ||||
| -rw-r--r-- | internal/router/template.go | 18 | 
3 files changed, 131 insertions, 4 deletions
diff --git a/internal/router/session.go b/internal/router/session.go index a42f04bfb..3276c38aa 100644 --- a/internal/router/session.go +++ b/internal/router/session.go @@ -23,6 +23,7 @@ import (  	"errors"  	"fmt"  	"net/http" +	"net/url"  	"github.com/gin-contrib/sessions"  	"github.com/gin-contrib/sessions/memstore" @@ -31,8 +32,8 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/db"  ) -// SessionOptions returns the standard set of options to use for each session. -func SessionOptions(cfg *config.Config) sessions.Options { +// sessionOptions returns the standard set of options to use for each session. +func sessionOptions(cfg *config.Config) sessions.Options {  	return sessions.Options{  		Path:     "/",  		Domain:   cfg.Host, @@ -43,6 +44,22 @@ func SessionOptions(cfg *config.Config) sessions.Options {  	}  } +func sessionName(cfg *config.Config) (string, error) { +	// parse the protocol + host +	u, err := url.Parse(fmt.Sprintf("%s://%s", cfg.Protocol, cfg.Host)) +	if err != nil { +		return "", err +	} + +	// take the hostname without any port attached +	strippedHostname := u.Hostname() +	if strippedHostname == "" { +		return "", fmt.Errorf("could not derive hostname without port from %s://%s", cfg.Protocol, cfg.Host) +	} + +	return fmt.Sprintf("gotosocial-%s", strippedHostname), nil +} +  func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, engine *gin.Engine) error {  	// check if we have a saved router session already  	rs, err := sessionDB.GetSession(ctx) @@ -54,8 +71,13 @@ func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, e  	}  	store := memstore.NewStore(rs.Auth, rs.Crypt) -	store.Options(SessionOptions(cfg)) -	sessionName := fmt.Sprintf("gotosocial-%s", cfg.Host) +	store.Options(sessionOptions(cfg)) + +	sessionName, err := sessionName(cfg) +	if err != nil { +		return err +	} +  	engine.Use(sessions.Sessions(sessionName, store))  	return nil  } diff --git a/internal/router/session_test.go b/internal/router/session_test.go new file mode 100644 index 000000000..7c2d324fd --- /dev/null +++ b/internal/router/session_test.go @@ -0,0 +1,87 @@ +/* +   GoToSocial +   Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package router + +import ( +	"testing" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/config" +) + +type SessionTestSuite struct { +	suite.Suite +} + +func (suite *SessionTestSuite) TestDeriveSessionNameLocalhostWithPort() { +	cfg := &config.Config{ +		Protocol: "http", +		Host:     "localhost:8080", +	} + +	sessionName, err := sessionName(cfg) +	suite.NoError(err) +	suite.Equal("gotosocial-localhost", sessionName) +} + +func (suite *SessionTestSuite) TestDeriveSessionNameLocalhost() { +	cfg := &config.Config{ +		Protocol: "http", +		Host:     "localhost", +	} + +	sessionName, err := sessionName(cfg) +	suite.NoError(err) +	suite.Equal("gotosocial-localhost", sessionName) +} + +func (suite *SessionTestSuite) TestDeriveSessionNoProtocol() { +	cfg := &config.Config{ +		Host: "localhost", +	} + +	sessionName, err := sessionName(cfg) +	suite.EqualError(err, "parse \"://localhost\": missing protocol scheme") +	suite.Equal("", sessionName) +} + +func (suite *SessionTestSuite) TestDeriveSessionNoHost() { +	cfg := &config.Config{ +		Protocol: "https", +	} + +	sessionName, err := sessionName(cfg) +	suite.EqualError(err, "could not derive hostname without port from https://") +	suite.Equal("", sessionName) +} + +func (suite *SessionTestSuite) TestDeriveSessionOK() { +	cfg := &config.Config{ +		Protocol: "https", +		Host:     "example.org", +	} + +	sessionName, err := sessionName(cfg) +	suite.NoError(err) +	suite.Equal("gotosocial-example.org", sessionName) +} + +func TestSessionTestSuite(t *testing.T) { +	suite.Run(t, &SessionTestSuite{}) +} diff --git a/internal/router/template.go b/internal/router/template.go index 787ade799..e7bdc3edf 100644 --- a/internal/router/template.go +++ b/internal/router/template.go @@ -1,3 +1,21 @@ +/* +   GoToSocial +   Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ +  package router  import (  | 
