diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/config/config.go | 64 | ||||
-rw-r--r-- | internal/config/db.go | 33 | ||||
-rw-r--r-- | internal/config/default.go | 14 | ||||
-rw-r--r-- | internal/db/pg/pg.go | 52 |
4 files changed, 126 insertions, 37 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index 28bbc8542..323b7de81 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -165,6 +165,14 @@ func (c *Config) ParseCLIFlags(f KeyedFlags, version string) error { c.DBConfig.Database = f.String(fn.DbDatabase) } + if c.DBConfig.TLSMode == DBTLSModeUnset || f.IsSet(fn.DbTLSMode) { + c.DBConfig.TLSMode = DBTLSMode(f.String(fn.DbTLSMode)) + } + + if c.DBConfig.TLSCACert == "" || f.IsSet(fn.DbTLSCACert) { + c.DBConfig.TLSCACert = f.String(fn.DbTLSCACert) + } + // template flags if c.TemplateConfig.BaseDir == "" || f.IsSet(fn.TemplateBaseDir) { c.TemplateConfig.BaseDir = f.String(fn.TemplateBaseDir) @@ -284,12 +292,14 @@ type Flags struct { Host string Protocol string - DbType string - DbAddress string - DbPort string - DbUser string - DbPassword string - DbDatabase string + DbType string + DbAddress string + DbPort string + DbUser string + DbPassword string + DbDatabase string + DbTLSMode string + DbTLSCACert string TemplateBaseDir string AssetBaseDir string @@ -329,12 +339,14 @@ type Defaults struct { Protocol string SoftwareVersion string - DbType string - DbAddress string - DbPort int - DbUser string - DbPassword string - DbDatabase string + DbType string + DbAddress string + DbPort int + DbUser string + DbPassword string + DbDatabase string + DBTlsMode string + DBTlsCACert string TemplateBaseDir string AssetBaseDir string @@ -375,12 +387,14 @@ func GetFlagNames() Flags { Host: "host", Protocol: "protocol", - DbType: "db-type", - DbAddress: "db-address", - DbPort: "db-port", - DbUser: "db-user", - DbPassword: "db-password", - DbDatabase: "db-database", + DbType: "db-type", + DbAddress: "db-address", + DbPort: "db-port", + DbUser: "db-user", + DbPassword: "db-password", + DbDatabase: "db-database", + DbTLSMode: "db-tls-mode", + DbTLSCACert: "db-tls-ca-cert", TemplateBaseDir: "template-basedir", AssetBaseDir: "asset-basedir", @@ -422,12 +436,14 @@ func GetEnvNames() Flags { Host: "GTS_HOST", Protocol: "GTS_PROTOCOL", - DbType: "GTS_DB_TYPE", - DbAddress: "GTS_DB_ADDRESS", - DbPort: "GTS_DB_PORT", - DbUser: "GTS_DB_USER", - DbPassword: "GTS_DB_PASSWORD", - DbDatabase: "GTS_DB_DATABASE", + DbType: "GTS_DB_TYPE", + DbAddress: "GTS_DB_ADDRESS", + DbPort: "GTS_DB_PORT", + DbUser: "GTS_DB_USER", + DbPassword: "GTS_DB_PASSWORD", + DbDatabase: "GTS_DB_DATABASE", + DbTLSMode: "GTS_DB_TLS_MODE", + DbTLSCACert: "GTS_DB_CA_CERT", TemplateBaseDir: "GTS_TEMPLATE_BASEDIR", AssetBaseDir: "GTS_ASSET_BASEDIR", diff --git a/internal/config/db.go b/internal/config/db.go index fbde6fe82..7ea71a8b6 100644 --- a/internal/config/db.go +++ b/internal/config/db.go @@ -20,11 +20,30 @@ package config // DBConfig provides configuration options for the database connection type DBConfig struct { - Type string `yaml:"type"` - Address string `yaml:"address"` - Port int `yaml:"port"` - User string `yaml:"user"` - Password string `yaml:"password"` - Database string `yaml:"database"` - ApplicationName string `yaml:"applicationName"` + Type string `yaml:"type"` + Address string `yaml:"address"` + Port int `yaml:"port"` + User string `yaml:"user"` + Password string `yaml:"password"` + Database string `yaml:"database"` + ApplicationName string `yaml:"applicationName"` + TLSMode DBTLSMode `yaml:"tlsMode"` + TLSCACert string `yaml:"tlsCACert"` } + +// DBTLSMode describes a mode of connecting to a database with or without TLS. +type DBTLSMode string + +// DBTLSModeDisable does not attempt to make a TLS connection to the database. +var DBTLSModeDisable DBTLSMode = "disable" + +// DBTLSModeEnable attempts to make a TLS connection to the database, but doesn't fail if +// the certificate passed by the database isn't verified. +var DBTLSModeEnable DBTLSMode = "enable" + +// DBTLSModeRequire attempts to make a TLS connection to the database, and requires +// that the certificate presented by the database is valid. +var DBTLSModeRequire DBTLSMode = "require" + +// DBTLSModeUnset means that the TLS mode has not been set. +var DBTLSModeUnset DBTLSMode = "" diff --git a/internal/config/default.go b/internal/config/default.go index 40df4c57e..7a030beb5 100644 --- a/internal/config/default.go +++ b/internal/config/default.go @@ -120,12 +120,14 @@ func GetDefaults() Defaults { Host: "", Protocol: "https", - DbType: "postgres", - DbAddress: "localhost", - DbPort: 5432, - DbUser: "postgres", - DbPassword: "postgres", - DbDatabase: "postgres", + DbType: "postgres", + DbAddress: "localhost", + DbPort: 5432, + DbUser: "postgres", + DbPassword: "postgres", + DbDatabase: "postgres", + DBTlsMode: "disable", + DBTlsCACert: "", TemplateBaseDir: "./web/template/", AssetBaseDir: "./web/assets/", diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go index ad75cef15..5301f0410 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/pg/pg.go @@ -22,10 +22,14 @@ import ( "context" "crypto/rand" "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" "errors" "fmt" "net" "net/mail" + "os" "strings" "time" @@ -133,6 +137,53 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { return nil, errors.New("no database set") } + var tlsConfig *tls.Config + switch c.DBConfig.TLSMode { + case config.DBTLSModeDisable, config.DBTLSModeUnset: + break // nothing to do + case config.DBTLSModeEnable: + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + case config.DBTLSModeRequire: + tlsConfig = &tls.Config{ + InsecureSkipVerify: false, + } + } + + if tlsConfig != nil && c.DBConfig.TLSCACert != "" { + // load the system cert pool first -- we'll append the given CA cert to this + certPool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("error fetching system CA cert pool: %s", err) + } + + // open the file itself and make sure there's something in it + caCertBytes, err := os.ReadFile(c.DBConfig.TLSCACert) + if err != nil { + return nil, fmt.Errorf("error opening CA certificate at %s: %s", c.DBConfig.TLSCACert, err) + } + if len(caCertBytes) == 0 { + return nil, fmt.Errorf("ca cert at %s was empty", c.DBConfig.TLSCACert) + } + + // make sure we have a PEM block + caPem, _ := pem.Decode(caCertBytes) + if caPem == nil { + return nil, fmt.Errorf("could not parse cert at %s into PEM", c.DBConfig.TLSCACert) + } + + // parse the PEM block into the certificate + caCert, err := x509.ParseCertificate(caPem.Bytes) + if err != nil { + return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", c.DBConfig.TLSCACert, err) + } + + // we're happy, add it to the existing pool and then use this pool in our tls config + certPool.AddCert(caCert) + tlsConfig.RootCAs = certPool + } + // We can rely on the pg library we're using to set // sensible defaults for everything we don't set here. options := &pg.Options{ @@ -141,6 +192,7 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { Password: c.DBConfig.Password, Database: c.DBConfig.Database, ApplicationName: c.ApplicationName, + TLSConfig: tlsConfig, } return options, nil |