summaryrefslogtreecommitdiff
path: root/internal/router
diff options
context:
space:
mode:
Diffstat (limited to 'internal/router')
-rw-r--r--internal/router/cors.go3
-rw-r--r--internal/router/router.go42
-rw-r--r--internal/router/session.go20
-rw-r--r--internal/router/session_test.go48
-rw-r--r--internal/router/template.go6
5 files changed, 69 insertions, 50 deletions
diff --git a/internal/router/cors.go b/internal/router/cors.go
index 9f8d379dd..e2ce9ce87 100644
--- a/internal/router/cors.go
+++ b/internal/router/cors.go
@@ -23,7 +23,6 @@ import (
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
- "github.com/superseriousbusiness/gotosocial/internal/config"
)
var corsConfig = cors.Config{
@@ -81,7 +80,7 @@ var corsConfig = cors.Config{
}
// useCors attaches the corsConfig above to the given gin engine
-func useCors(cfg *config.Config, engine *gin.Engine) error {
+func useCors(engine *gin.Engine) error {
c := cors.New(corsConfig)
engine.Use(c)
return nil
diff --git a/internal/router/router.go b/internal/router/router.go
index aef5c32e4..aa588906f 100644
--- a/internal/router/router.go
+++ b/internal/router/router.go
@@ -26,6 +26,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
+ "github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"golang.org/x/crypto/acme/autocert"
@@ -58,7 +59,6 @@ type Router interface {
type router struct {
engine *gin.Engine
srv *http.Server
- config *config.Config
certManager *autocert.Manager
}
@@ -69,10 +69,16 @@ func (r *router) AttachStaticFS(relativePath string, fs http.FileSystem) {
// Start starts the router nicely. It will serve two handlers if letsencrypt is enabled, and only the web/API handler if letsencrypt is not enabled.
func (r *router) Start() {
- if r.config.LetsEncryptConfig.Enabled {
+ keys := config.Keys
+ leEnabled := viper.GetBool(keys.LetsEncryptEnabled)
+
+ if leEnabled {
+ bindAddress := viper.GetString(keys.BindAddress)
+ lePort := viper.GetInt(keys.LetsEncryptPort)
+
// serve the http handler on the selected letsencrypt port, for receiving letsencrypt requests and solving their devious riddles
go func() {
- listen := fmt.Sprintf("%s:%d", r.config.BindAddress, r.config.LetsEncryptConfig.Port)
+ listen := fmt.Sprintf("%s:%d", bindAddress, lePort)
if err := http.ListenAndServe(listen, r.certManager.HTTPHandler(http.HandlerFunc(httpsRedirect))); err != nil && err != http.ErrServerClosed {
logrus.Fatalf("listen: %s", err)
}
@@ -103,7 +109,9 @@ func (r *router) Stop(ctx context.Context) error {
//
// The given DB is only used in the New function for parsing config values, and is not otherwise
// pinned to the router.
-func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) {
+func New(ctx context.Context, db db.DB) (Router, error) {
+ keys := config.Keys
+
gin.SetMode(gin.ReleaseMode)
// create the actual engine here -- this is the core request routing handler for gts
@@ -116,12 +124,13 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) {
engine.MaxMultipartMemory = 8 << 20
// set up IP forwarding via x-forward-* headers.
- if err := engine.SetTrustedProxies(cfg.TrustedProxies); err != nil {
+ trustedProxies := viper.GetStringSlice(keys.TrustedProxies)
+ if err := engine.SetTrustedProxies(trustedProxies); err != nil {
return nil, err
}
// enable cors on the engine
- if err := useCors(cfg, engine); err != nil {
+ if err := useCors(engine); err != nil {
return nil, err
}
@@ -129,17 +138,19 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) {
loadTemplateFunctions(engine)
// load templates onto the engine
- if err := loadTemplates(cfg, engine); err != nil {
+ if err := loadTemplates(engine); err != nil {
return nil, err
}
// enable session store middleware on the engine
- if err := useSession(ctx, cfg, db, engine); err != nil {
+ if err := useSession(ctx, db, engine); err != nil {
return nil, err
}
// create the http server here, passing the gin engine as handler
- listen := fmt.Sprintf("%s:%d", cfg.BindAddress, cfg.Port)
+ bindAddress := viper.GetString(keys.BindAddress)
+ port := viper.GetInt(keys.Port)
+ listen := fmt.Sprintf("%s:%d", bindAddress, port)
s := &http.Server{
Addr: listen,
Handler: engine,
@@ -151,15 +162,19 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) {
// We need to spawn the underlying server slightly differently depending on whether lets encrypt is enabled or not.
// In either case, the gin engine will still be used for routing requests.
+ leEnabled := viper.GetBool(keys.LetsEncryptEnabled)
var m *autocert.Manager
- if cfg.LetsEncryptConfig.Enabled {
+ if leEnabled {
// le IS enabled, so roll up an autocert manager for handling letsencrypt requests
+ host := viper.GetString(keys.Host)
+ leCertDir := viper.GetString(keys.LetsEncryptCertDir)
+ leEmailAddress := viper.GetString(keys.LetsEncryptEmailAddress)
m = &autocert.Manager{
Prompt: autocert.AcceptTOS,
- HostPolicy: autocert.HostWhitelist(cfg.Host),
- Cache: autocert.DirCache(cfg.LetsEncryptConfig.CertDir),
- Email: cfg.LetsEncryptConfig.EmailAddress,
+ HostPolicy: autocert.HostWhitelist(host),
+ Cache: autocert.DirCache(leCertDir),
+ Email: leEmailAddress,
}
s.TLSConfig = m.TLSConfig()
}
@@ -167,7 +182,6 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) {
return &router{
engine: engine,
srv: s,
- config: cfg,
certManager: m,
}, nil
}
diff --git a/internal/router/session.go b/internal/router/session.go
index 3276c38aa..1f7b8bcfa 100644
--- a/internal/router/session.go
+++ b/internal/router/session.go
@@ -28,15 +28,16 @@ import (
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/memstore"
"github.com/gin-gonic/gin"
+ "github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
// sessionOptions returns the standard set of options to use for each session.
-func sessionOptions(cfg *config.Config) sessions.Options {
+func sessionOptions() sessions.Options {
return sessions.Options{
Path: "/",
- Domain: cfg.Host,
+ Domain: viper.GetString(config.Keys.Host),
MaxAge: 120, // 2 minutes
Secure: true, // only use cookie over https
HttpOnly: true, // exclude javascript from inspecting cookie
@@ -44,9 +45,12 @@ func sessionOptions(cfg *config.Config) sessions.Options {
}
}
-func sessionName(cfg *config.Config) (string, error) {
+// SessionName is a utility function that derives an appropriate session name from the hostname.
+func SessionName() (string, error) {
// parse the protocol + host
- u, err := url.Parse(fmt.Sprintf("%s://%s", cfg.Protocol, cfg.Host))
+ protocol := viper.GetString(config.Keys.Protocol)
+ host := viper.GetString(config.Keys.Host)
+ u, err := url.Parse(fmt.Sprintf("%s://%s", protocol, host))
if err != nil {
return "", err
}
@@ -54,13 +58,13 @@ func sessionName(cfg *config.Config) (string, error) {
// take the hostname without any port attached
strippedHostname := u.Hostname()
if strippedHostname == "" {
- return "", fmt.Errorf("could not derive hostname without port from %s://%s", cfg.Protocol, cfg.Host)
+ return "", fmt.Errorf("could not derive hostname without port from %s://%s", protocol, host)
}
return fmt.Sprintf("gotosocial-%s", strippedHostname), nil
}
-func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, engine *gin.Engine) error {
+func useSession(ctx context.Context, sessionDB db.Session, engine *gin.Engine) error {
// check if we have a saved router session already
rs, err := sessionDB.GetSession(ctx)
if err != nil {
@@ -71,9 +75,9 @@ func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, e
}
store := memstore.NewStore(rs.Auth, rs.Crypt)
- store.Options(sessionOptions(cfg))
+ store.Options(sessionOptions())
- sessionName, err := sessionName(cfg)
+ sessionName, err := SessionName()
if err != nil {
return err
}
diff --git a/internal/router/session_test.go b/internal/router/session_test.go
index 7c2d324fd..31beec1ae 100644
--- a/internal/router/session_test.go
+++ b/internal/router/session_test.go
@@ -16,68 +16,68 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package router
+package router_test
import (
"testing"
+ "github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/router"
+ "github.com/superseriousbusiness/gotosocial/testrig"
)
type SessionTestSuite struct {
suite.Suite
}
+func (suite *SessionTestSuite) SetupTest() {
+ testrig.InitTestConfig()
+}
+
func (suite *SessionTestSuite) TestDeriveSessionNameLocalhostWithPort() {
- cfg := &config.Config{
- Protocol: "http",
- Host: "localhost:8080",
- }
+ viper.Set(config.Keys.Protocol, "http")
+ viper.Set(config.Keys.Host, "localhost:8080")
- sessionName, err := sessionName(cfg)
+ sessionName, err := router.SessionName()
suite.NoError(err)
suite.Equal("gotosocial-localhost", sessionName)
}
func (suite *SessionTestSuite) TestDeriveSessionNameLocalhost() {
- cfg := &config.Config{
- Protocol: "http",
- Host: "localhost",
- }
+ viper.Set(config.Keys.Protocol, "http")
+ viper.Set(config.Keys.Host, "localhost")
- sessionName, err := sessionName(cfg)
+ sessionName, err := router.SessionName()
suite.NoError(err)
suite.Equal("gotosocial-localhost", sessionName)
}
func (suite *SessionTestSuite) TestDeriveSessionNoProtocol() {
- cfg := &config.Config{
- Host: "localhost",
- }
+ viper.Set(config.Keys.Protocol, "")
+ viper.Set(config.Keys.Host, "localhost")
- sessionName, err := sessionName(cfg)
+ sessionName, err := router.SessionName()
suite.EqualError(err, "parse \"://localhost\": missing protocol scheme")
suite.Equal("", sessionName)
}
func (suite *SessionTestSuite) TestDeriveSessionNoHost() {
- cfg := &config.Config{
- Protocol: "https",
- }
+ viper.Set(config.Keys.Protocol, "https")
+ viper.Set(config.Keys.Host, "")
+ viper.Set(config.Keys.Port, 0)
- sessionName, err := sessionName(cfg)
+ sessionName, err := router.SessionName()
suite.EqualError(err, "could not derive hostname without port from https://")
suite.Equal("", sessionName)
}
func (suite *SessionTestSuite) TestDeriveSessionOK() {
- cfg := &config.Config{
- Protocol: "https",
- Host: "example.org",
- }
+ viper.Set(config.Keys.Protocol, "https")
+ viper.Set(config.Keys.Host, "example.org")
- sessionName, err := sessionName(cfg)
+ sessionName, err := router.SessionName()
suite.NoError(err)
suite.Equal("gotosocial-example.org", sessionName)
}
diff --git a/internal/router/template.go b/internal/router/template.go
index bf5682628..b0d998208 100644
--- a/internal/router/template.go
+++ b/internal/router/template.go
@@ -26,18 +26,20 @@ import (
"time"
"github.com/gin-gonic/gin"
+ "github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
// loadTemplates loads html templates for use by the given engine
-func loadTemplates(cfg *config.Config, engine *gin.Engine) error {
+func loadTemplates(engine *gin.Engine) error {
cwd, err := os.Getwd()
if err != nil {
return fmt.Errorf("error getting current working directory: %s", err)
}
- tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", cfg.TemplateConfig.BaseDir))
+ templateBaseDir := viper.GetString(config.Keys.WebTemplateBaseDir)
+ tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", templateBaseDir))
engine.LoadHTMLGlob(tmPath)
return nil