diff options
Diffstat (limited to 'internal/router')
-rw-r--r-- | internal/router/cors.go | 3 | ||||
-rw-r--r-- | internal/router/router.go | 42 | ||||
-rw-r--r-- | internal/router/session.go | 20 | ||||
-rw-r--r-- | internal/router/session_test.go | 48 | ||||
-rw-r--r-- | internal/router/template.go | 6 |
5 files changed, 69 insertions, 50 deletions
diff --git a/internal/router/cors.go b/internal/router/cors.go index 9f8d379dd..e2ce9ce87 100644 --- a/internal/router/cors.go +++ b/internal/router/cors.go @@ -23,7 +23,6 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" - "github.com/superseriousbusiness/gotosocial/internal/config" ) var corsConfig = cors.Config{ @@ -81,7 +80,7 @@ var corsConfig = cors.Config{ } // useCors attaches the corsConfig above to the given gin engine -func useCors(cfg *config.Config, engine *gin.Engine) error { +func useCors(engine *gin.Engine) error { c := cors.New(corsConfig) engine.Use(c) return nil diff --git a/internal/router/router.go b/internal/router/router.go index aef5c32e4..aa588906f 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -26,6 +26,7 @@ import ( "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" + "github.com/spf13/viper" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "golang.org/x/crypto/acme/autocert" @@ -58,7 +59,6 @@ type Router interface { type router struct { engine *gin.Engine srv *http.Server - config *config.Config certManager *autocert.Manager } @@ -69,10 +69,16 @@ func (r *router) AttachStaticFS(relativePath string, fs http.FileSystem) { // Start starts the router nicely. It will serve two handlers if letsencrypt is enabled, and only the web/API handler if letsencrypt is not enabled. func (r *router) Start() { - if r.config.LetsEncryptConfig.Enabled { + keys := config.Keys + leEnabled := viper.GetBool(keys.LetsEncryptEnabled) + + if leEnabled { + bindAddress := viper.GetString(keys.BindAddress) + lePort := viper.GetInt(keys.LetsEncryptPort) + // serve the http handler on the selected letsencrypt port, for receiving letsencrypt requests and solving their devious riddles go func() { - listen := fmt.Sprintf("%s:%d", r.config.BindAddress, r.config.LetsEncryptConfig.Port) + listen := fmt.Sprintf("%s:%d", bindAddress, lePort) if err := http.ListenAndServe(listen, r.certManager.HTTPHandler(http.HandlerFunc(httpsRedirect))); err != nil && err != http.ErrServerClosed { logrus.Fatalf("listen: %s", err) } @@ -103,7 +109,9 @@ func (r *router) Stop(ctx context.Context) error { // // The given DB is only used in the New function for parsing config values, and is not otherwise // pinned to the router. -func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) { +func New(ctx context.Context, db db.DB) (Router, error) { + keys := config.Keys + gin.SetMode(gin.ReleaseMode) // create the actual engine here -- this is the core request routing handler for gts @@ -116,12 +124,13 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) { engine.MaxMultipartMemory = 8 << 20 // set up IP forwarding via x-forward-* headers. - if err := engine.SetTrustedProxies(cfg.TrustedProxies); err != nil { + trustedProxies := viper.GetStringSlice(keys.TrustedProxies) + if err := engine.SetTrustedProxies(trustedProxies); err != nil { return nil, err } // enable cors on the engine - if err := useCors(cfg, engine); err != nil { + if err := useCors(engine); err != nil { return nil, err } @@ -129,17 +138,19 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) { loadTemplateFunctions(engine) // load templates onto the engine - if err := loadTemplates(cfg, engine); err != nil { + if err := loadTemplates(engine); err != nil { return nil, err } // enable session store middleware on the engine - if err := useSession(ctx, cfg, db, engine); err != nil { + if err := useSession(ctx, db, engine); err != nil { return nil, err } // create the http server here, passing the gin engine as handler - listen := fmt.Sprintf("%s:%d", cfg.BindAddress, cfg.Port) + bindAddress := viper.GetString(keys.BindAddress) + port := viper.GetInt(keys.Port) + listen := fmt.Sprintf("%s:%d", bindAddress, port) s := &http.Server{ Addr: listen, Handler: engine, @@ -151,15 +162,19 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) { // We need to spawn the underlying server slightly differently depending on whether lets encrypt is enabled or not. // In either case, the gin engine will still be used for routing requests. + leEnabled := viper.GetBool(keys.LetsEncryptEnabled) var m *autocert.Manager - if cfg.LetsEncryptConfig.Enabled { + if leEnabled { // le IS enabled, so roll up an autocert manager for handling letsencrypt requests + host := viper.GetString(keys.Host) + leCertDir := viper.GetString(keys.LetsEncryptCertDir) + leEmailAddress := viper.GetString(keys.LetsEncryptEmailAddress) m = &autocert.Manager{ Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(cfg.Host), - Cache: autocert.DirCache(cfg.LetsEncryptConfig.CertDir), - Email: cfg.LetsEncryptConfig.EmailAddress, + HostPolicy: autocert.HostWhitelist(host), + Cache: autocert.DirCache(leCertDir), + Email: leEmailAddress, } s.TLSConfig = m.TLSConfig() } @@ -167,7 +182,6 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) { return &router{ engine: engine, srv: s, - config: cfg, certManager: m, }, nil } diff --git a/internal/router/session.go b/internal/router/session.go index 3276c38aa..1f7b8bcfa 100644 --- a/internal/router/session.go +++ b/internal/router/session.go @@ -28,15 +28,16 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/memstore" "github.com/gin-gonic/gin" + "github.com/spf13/viper" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" ) // sessionOptions returns the standard set of options to use for each session. -func sessionOptions(cfg *config.Config) sessions.Options { +func sessionOptions() sessions.Options { return sessions.Options{ Path: "/", - Domain: cfg.Host, + Domain: viper.GetString(config.Keys.Host), MaxAge: 120, // 2 minutes Secure: true, // only use cookie over https HttpOnly: true, // exclude javascript from inspecting cookie @@ -44,9 +45,12 @@ func sessionOptions(cfg *config.Config) sessions.Options { } } -func sessionName(cfg *config.Config) (string, error) { +// SessionName is a utility function that derives an appropriate session name from the hostname. +func SessionName() (string, error) { // parse the protocol + host - u, err := url.Parse(fmt.Sprintf("%s://%s", cfg.Protocol, cfg.Host)) + protocol := viper.GetString(config.Keys.Protocol) + host := viper.GetString(config.Keys.Host) + u, err := url.Parse(fmt.Sprintf("%s://%s", protocol, host)) if err != nil { return "", err } @@ -54,13 +58,13 @@ func sessionName(cfg *config.Config) (string, error) { // take the hostname without any port attached strippedHostname := u.Hostname() if strippedHostname == "" { - return "", fmt.Errorf("could not derive hostname without port from %s://%s", cfg.Protocol, cfg.Host) + return "", fmt.Errorf("could not derive hostname without port from %s://%s", protocol, host) } return fmt.Sprintf("gotosocial-%s", strippedHostname), nil } -func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, engine *gin.Engine) error { +func useSession(ctx context.Context, sessionDB db.Session, engine *gin.Engine) error { // check if we have a saved router session already rs, err := sessionDB.GetSession(ctx) if err != nil { @@ -71,9 +75,9 @@ func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, e } store := memstore.NewStore(rs.Auth, rs.Crypt) - store.Options(sessionOptions(cfg)) + store.Options(sessionOptions()) - sessionName, err := sessionName(cfg) + sessionName, err := SessionName() if err != nil { return err } diff --git a/internal/router/session_test.go b/internal/router/session_test.go index 7c2d324fd..31beec1ae 100644 --- a/internal/router/session_test.go +++ b/internal/router/session_test.go @@ -16,68 +16,68 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package router +package router_test import ( "testing" + "github.com/spf13/viper" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/router" + "github.com/superseriousbusiness/gotosocial/testrig" ) type SessionTestSuite struct { suite.Suite } +func (suite *SessionTestSuite) SetupTest() { + testrig.InitTestConfig() +} + func (suite *SessionTestSuite) TestDeriveSessionNameLocalhostWithPort() { - cfg := &config.Config{ - Protocol: "http", - Host: "localhost:8080", - } + viper.Set(config.Keys.Protocol, "http") + viper.Set(config.Keys.Host, "localhost:8080") - sessionName, err := sessionName(cfg) + sessionName, err := router.SessionName() suite.NoError(err) suite.Equal("gotosocial-localhost", sessionName) } func (suite *SessionTestSuite) TestDeriveSessionNameLocalhost() { - cfg := &config.Config{ - Protocol: "http", - Host: "localhost", - } + viper.Set(config.Keys.Protocol, "http") + viper.Set(config.Keys.Host, "localhost") - sessionName, err := sessionName(cfg) + sessionName, err := router.SessionName() suite.NoError(err) suite.Equal("gotosocial-localhost", sessionName) } func (suite *SessionTestSuite) TestDeriveSessionNoProtocol() { - cfg := &config.Config{ - Host: "localhost", - } + viper.Set(config.Keys.Protocol, "") + viper.Set(config.Keys.Host, "localhost") - sessionName, err := sessionName(cfg) + sessionName, err := router.SessionName() suite.EqualError(err, "parse \"://localhost\": missing protocol scheme") suite.Equal("", sessionName) } func (suite *SessionTestSuite) TestDeriveSessionNoHost() { - cfg := &config.Config{ - Protocol: "https", - } + viper.Set(config.Keys.Protocol, "https") + viper.Set(config.Keys.Host, "") + viper.Set(config.Keys.Port, 0) - sessionName, err := sessionName(cfg) + sessionName, err := router.SessionName() suite.EqualError(err, "could not derive hostname without port from https://") suite.Equal("", sessionName) } func (suite *SessionTestSuite) TestDeriveSessionOK() { - cfg := &config.Config{ - Protocol: "https", - Host: "example.org", - } + viper.Set(config.Keys.Protocol, "https") + viper.Set(config.Keys.Host, "example.org") - sessionName, err := sessionName(cfg) + sessionName, err := router.SessionName() suite.NoError(err) suite.Equal("gotosocial-example.org", sessionName) } diff --git a/internal/router/template.go b/internal/router/template.go index bf5682628..b0d998208 100644 --- a/internal/router/template.go +++ b/internal/router/template.go @@ -26,18 +26,20 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/spf13/viper" "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/config" ) // loadTemplates loads html templates for use by the given engine -func loadTemplates(cfg *config.Config, engine *gin.Engine) error { +func loadTemplates(engine *gin.Engine) error { cwd, err := os.Getwd() if err != nil { return fmt.Errorf("error getting current working directory: %s", err) } - tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", cfg.TemplateConfig.BaseDir)) + templateBaseDir := viper.GetString(config.Keys.WebTemplateBaseDir) + tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", templateBaseDir)) engine.LoadHTMLGlob(tmPath) return nil |