diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/config/config.go | 3 | ||||
| -rw-r--r-- | internal/config/defaults.go | 3 | ||||
| -rw-r--r-- | internal/config/flags.go | 4 | ||||
| -rw-r--r-- | internal/config/helpers.gen.go | 50 | ||||
| -rw-r--r-- | internal/config/validate.go | 13 | ||||
| -rw-r--r-- | internal/router/router.go | 21 | 
6 files changed, 94 insertions, 0 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index fdfda8583..a7a36eebf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -114,6 +114,9 @@ type Configuration struct {  	LetsEncryptCertDir      string `name:"letsencrypt-cert-dir" usage:"Directory to store acquired letsencrypt certificates."`  	LetsEncryptEmailAddress string `name:"letsencrypt-email-address" usage:"Email address to use when requesting letsencrypt certs. Will receive updates on cert expiry etc."` +	TLSCertificateChain string `name:"tls-certificate-chain" usage:"Filesystem path to the certificate chain including any intermediate CAs and the TLS public key"` +	TLSCertificateKey   string `name:"tls-certificate-key" usage:"Filesystem path to the TLS private key"` +  	OIDCEnabled          bool     `name:"oidc-enabled" usage:"Enabled OIDC authorization for this instance. If set to true, then the other OIDC flags must also be set."`  	OIDCIdpName          string   `name:"oidc-idp-name" usage:"Name of the OIDC identity provider. Will be shown to the user when logging in."`  	OIDCSkipVerification bool     `name:"oidc-skip-verification" usage:"Skip verification of tokens returned by the OIDC provider. Should only be set to 'true' for testing purposes, never in a production environment!"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index e9dd2b743..7d2427ee7 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -91,6 +91,9 @@ var Defaults = Configuration{  	LetsEncryptCertDir:      "/gotosocial/storage/certs",  	LetsEncryptEmailAddress: "", +	TLSCertificateChain: "", +	TLSCertificateKey:   "", +  	OIDCEnabled:          false,  	OIDCIdpName:          "",  	OIDCSkipVerification: false, diff --git a/internal/config/flags.go b/internal/config/flags.go index 3ef44bf62..5206ee8ae 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -114,6 +114,10 @@ func (s *ConfigState) AddServerFlags(cmd *cobra.Command) {  		cmd.Flags().String(LetsEncryptCertDirFlag(), cfg.LetsEncryptCertDir, fieldtag("LetsEncryptCertDir", "usage"))  		cmd.Flags().String(LetsEncryptEmailAddressFlag(), cfg.LetsEncryptEmailAddress, fieldtag("LetsEncryptEmailAddress", "usage")) +		// Manual TLS +		cmd.Flags().String(TLSCertificateChainFlag(), cfg.TLSCertificateChain, fieldtag("TLSCertificateChain", "usage")) +		cmd.Flags().String(TLSCertificateKeyFlag(), cfg.TLSCertificateKey, fieldtag("TLSCertificateKey", "usage")) +  		// OIDC  		cmd.Flags().Bool(OIDCEnabledFlag(), cfg.OIDCEnabled, fieldtag("OIDCEnabled", "usage"))  		cmd.Flags().String(OIDCIdpNameFlag(), cfg.OIDCIdpName, fieldtag("OIDCIdpName", "usage")) diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index 5ea7b61b6..b021ed617 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -1524,6 +1524,56 @@ func GetLetsEncryptEmailAddress() string { return global.GetLetsEncryptEmailAddr  // SetLetsEncryptEmailAddress safely sets the value for global configuration 'LetsEncryptEmailAddress' field  func SetLetsEncryptEmailAddress(v string) { global.SetLetsEncryptEmailAddress(v) } +// GetTLSCertificateChain safely fetches the Configuration value for state's 'TLSCertificateChain' field +func (st *ConfigState) GetTLSCertificateChain() (v string) { +	st.mutex.Lock() +	v = st.config.TLSCertificateChain +	st.mutex.Unlock() +	return +} + +// SetTLSCertificateChain safely sets the Configuration value for state's 'TLSCertificateChain' field +func (st *ConfigState) SetTLSCertificateChain(v string) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.TLSCertificateChain = v +	st.reloadToViper() +} + +// TLSCertificateChainFlag returns the flag name for the 'TLSCertificateChain' field +func TLSCertificateChainFlag() string { return "tls-certificate-chain" } + +// GetTLSCertificateChain safely fetches the value for global configuration 'TLSCertificateChain' field +func GetTLSCertificateChain() string { return global.GetTLSCertificateChain() } + +// SetTLSCertificateChain safely sets the value for global configuration 'TLSCertificateChain' field +func SetTLSCertificateChain(v string) { global.SetTLSCertificateChain(v) } + +// GetTLSCertificateKey safely fetches the Configuration value for state's 'TLSCertificateKey' field +func (st *ConfigState) GetTLSCertificateKey() (v string) { +	st.mutex.Lock() +	v = st.config.TLSCertificateKey +	st.mutex.Unlock() +	return +} + +// SetTLSCertificateKey safely sets the Configuration value for state's 'TLSCertificateKey' field +func (st *ConfigState) SetTLSCertificateKey(v string) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.TLSCertificateKey = v +	st.reloadToViper() +} + +// TLSCertificateKeyFlag returns the flag name for the 'TLSCertificateKey' field +func TLSCertificateKeyFlag() string { return "tls-certificate-key" } + +// GetTLSCertificateKey safely fetches the value for global configuration 'TLSCertificateKey' field +func GetTLSCertificateKey() string { return global.GetTLSCertificateKey() } + +// SetTLSCertificateKey safely sets the value for global configuration 'TLSCertificateKey' field +func SetTLSCertificateKey(v string) { global.SetTLSCertificateKey(v) } +  // GetOIDCEnabled safely fetches the Configuration value for state's 'OIDCEnabled' field  func (st *ConfigState) GetOIDCEnabled() (v bool) {  	st.mutex.Lock() diff --git a/internal/config/validate.go b/internal/config/validate.go index 866ec1be1..2735a9229 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -67,6 +67,19 @@ func Validate() error {  		errs = append(errs, fmt.Errorf("%s must be set", WebAssetBaseDirFlag()))  	} +	tlsChain := GetTLSCertificateChain() +	tlsKey := GetTLSCertificateKey() +	tlsChainFlag := TLSCertificateChainFlag() +	tlsKeyFlag := TLSCertificateKeyFlag() + +	if GetLetsEncryptEnabled() && (tlsChain != "" || tlsKey != "") { +		errs = append(errs, fmt.Errorf("%s cannot be enabled when %s and/or %s are also set", LetsEncryptEnabledFlag(), tlsChainFlag, tlsKeyFlag)) +	} + +	if (tlsChain != "" && tlsKey == "") || (tlsChain == "" && tlsKey != "") { +		errs = append(errs, fmt.Errorf("%s and %s need to both be set or unset", tlsChainFlag, tlsKeyFlag)) +	} +  	if len(errs) > 0 {  		errStrings := []string{}  		for _, err := range errs { diff --git a/internal/router/router.go b/internal/router/router.go index 0b9b63494..edbf51cbb 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -20,6 +20,7 @@ package router  import (  	"context" +	"crypto/tls"  	"fmt"  	"net"  	"net/http" @@ -78,6 +79,26 @@ func (r *router) Start() {  	// but updated to TLS if LetsEncrypt is enabled.  	listen := r.srv.ListenAndServe +	// During config validation we already checked that both Chain and Key are set +	// so we can forego checking for both here +	if chain := config.GetTLSCertificateChain(); chain != "" { +		pkey := config.GetTLSCertificateKey() +		cer, err := tls.LoadX509KeyPair(chain, pkey) +		if err != nil { +			log.Fatalf( +				nil, +				"tls: failed to load keypair from %s and %s, ensure they are PEM-encoded and can be read by this process: %s", +				chain, pkey, err, +			) +		} +		r.srv.TLSConfig = &tls.Config{ +			MinVersion:   tls.VersionTLS12, +			Certificates: []tls.Certificate{cer}, +		} +		// TLS is enabled, update the listen function +		listen = func() error { return r.srv.ListenAndServeTLS("", "") } +	} +  	if config.GetLetsEncryptEnabled() {  		// LetsEncrypt support is enabled  | 
