summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/config/config.go64
-rw-r--r--internal/config/db.go33
-rw-r--r--internal/config/default.go14
-rw-r--r--internal/db/pg/pg.go52
4 files changed, 126 insertions, 37 deletions
diff --git a/internal/config/config.go b/internal/config/config.go
index 28bbc8542..323b7de81 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -165,6 +165,14 @@ func (c *Config) ParseCLIFlags(f KeyedFlags, version string) error {
c.DBConfig.Database = f.String(fn.DbDatabase)
}
+ if c.DBConfig.TLSMode == DBTLSModeUnset || f.IsSet(fn.DbTLSMode) {
+ c.DBConfig.TLSMode = DBTLSMode(f.String(fn.DbTLSMode))
+ }
+
+ if c.DBConfig.TLSCACert == "" || f.IsSet(fn.DbTLSCACert) {
+ c.DBConfig.TLSCACert = f.String(fn.DbTLSCACert)
+ }
+
// template flags
if c.TemplateConfig.BaseDir == "" || f.IsSet(fn.TemplateBaseDir) {
c.TemplateConfig.BaseDir = f.String(fn.TemplateBaseDir)
@@ -284,12 +292,14 @@ type Flags struct {
Host string
Protocol string
- DbType string
- DbAddress string
- DbPort string
- DbUser string
- DbPassword string
- DbDatabase string
+ DbType string
+ DbAddress string
+ DbPort string
+ DbUser string
+ DbPassword string
+ DbDatabase string
+ DbTLSMode string
+ DbTLSCACert string
TemplateBaseDir string
AssetBaseDir string
@@ -329,12 +339,14 @@ type Defaults struct {
Protocol string
SoftwareVersion string
- DbType string
- DbAddress string
- DbPort int
- DbUser string
- DbPassword string
- DbDatabase string
+ DbType string
+ DbAddress string
+ DbPort int
+ DbUser string
+ DbPassword string
+ DbDatabase string
+ DBTlsMode string
+ DBTlsCACert string
TemplateBaseDir string
AssetBaseDir string
@@ -375,12 +387,14 @@ func GetFlagNames() Flags {
Host: "host",
Protocol: "protocol",
- DbType: "db-type",
- DbAddress: "db-address",
- DbPort: "db-port",
- DbUser: "db-user",
- DbPassword: "db-password",
- DbDatabase: "db-database",
+ DbType: "db-type",
+ DbAddress: "db-address",
+ DbPort: "db-port",
+ DbUser: "db-user",
+ DbPassword: "db-password",
+ DbDatabase: "db-database",
+ DbTLSMode: "db-tls-mode",
+ DbTLSCACert: "db-tls-ca-cert",
TemplateBaseDir: "template-basedir",
AssetBaseDir: "asset-basedir",
@@ -422,12 +436,14 @@ func GetEnvNames() Flags {
Host: "GTS_HOST",
Protocol: "GTS_PROTOCOL",
- DbType: "GTS_DB_TYPE",
- DbAddress: "GTS_DB_ADDRESS",
- DbPort: "GTS_DB_PORT",
- DbUser: "GTS_DB_USER",
- DbPassword: "GTS_DB_PASSWORD",
- DbDatabase: "GTS_DB_DATABASE",
+ DbType: "GTS_DB_TYPE",
+ DbAddress: "GTS_DB_ADDRESS",
+ DbPort: "GTS_DB_PORT",
+ DbUser: "GTS_DB_USER",
+ DbPassword: "GTS_DB_PASSWORD",
+ DbDatabase: "GTS_DB_DATABASE",
+ DbTLSMode: "GTS_DB_TLS_MODE",
+ DbTLSCACert: "GTS_DB_CA_CERT",
TemplateBaseDir: "GTS_TEMPLATE_BASEDIR",
AssetBaseDir: "GTS_ASSET_BASEDIR",
diff --git a/internal/config/db.go b/internal/config/db.go
index fbde6fe82..7ea71a8b6 100644
--- a/internal/config/db.go
+++ b/internal/config/db.go
@@ -20,11 +20,30 @@ package config
// DBConfig provides configuration options for the database connection
type DBConfig struct {
- Type string `yaml:"type"`
- Address string `yaml:"address"`
- Port int `yaml:"port"`
- User string `yaml:"user"`
- Password string `yaml:"password"`
- Database string `yaml:"database"`
- ApplicationName string `yaml:"applicationName"`
+ Type string `yaml:"type"`
+ Address string `yaml:"address"`
+ Port int `yaml:"port"`
+ User string `yaml:"user"`
+ Password string `yaml:"password"`
+ Database string `yaml:"database"`
+ ApplicationName string `yaml:"applicationName"`
+ TLSMode DBTLSMode `yaml:"tlsMode"`
+ TLSCACert string `yaml:"tlsCACert"`
}
+
+// DBTLSMode describes a mode of connecting to a database with or without TLS.
+type DBTLSMode string
+
+// DBTLSModeDisable does not attempt to make a TLS connection to the database.
+var DBTLSModeDisable DBTLSMode = "disable"
+
+// DBTLSModeEnable attempts to make a TLS connection to the database, but doesn't fail if
+// the certificate passed by the database isn't verified.
+var DBTLSModeEnable DBTLSMode = "enable"
+
+// DBTLSModeRequire attempts to make a TLS connection to the database, and requires
+// that the certificate presented by the database is valid.
+var DBTLSModeRequire DBTLSMode = "require"
+
+// DBTLSModeUnset means that the TLS mode has not been set.
+var DBTLSModeUnset DBTLSMode = ""
diff --git a/internal/config/default.go b/internal/config/default.go
index 40df4c57e..7a030beb5 100644
--- a/internal/config/default.go
+++ b/internal/config/default.go
@@ -120,12 +120,14 @@ func GetDefaults() Defaults {
Host: "",
Protocol: "https",
- DbType: "postgres",
- DbAddress: "localhost",
- DbPort: 5432,
- DbUser: "postgres",
- DbPassword: "postgres",
- DbDatabase: "postgres",
+ DbType: "postgres",
+ DbAddress: "localhost",
+ DbPort: 5432,
+ DbUser: "postgres",
+ DbPassword: "postgres",
+ DbDatabase: "postgres",
+ DBTlsMode: "disable",
+ DBTlsCACert: "",
TemplateBaseDir: "./web/template/",
AssetBaseDir: "./web/assets/",
diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go
index ad75cef15..5301f0410 100644
--- a/internal/db/pg/pg.go
+++ b/internal/db/pg/pg.go
@@ -22,10 +22,14 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/pem"
"errors"
"fmt"
"net"
"net/mail"
+ "os"
"strings"
"time"
@@ -133,6 +137,53 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
return nil, errors.New("no database set")
}
+ var tlsConfig *tls.Config
+ switch c.DBConfig.TLSMode {
+ case config.DBTLSModeDisable, config.DBTLSModeUnset:
+ break // nothing to do
+ case config.DBTLSModeEnable:
+ tlsConfig = &tls.Config{
+ InsecureSkipVerify: true,
+ }
+ case config.DBTLSModeRequire:
+ tlsConfig = &tls.Config{
+ InsecureSkipVerify: false,
+ }
+ }
+
+ if tlsConfig != nil && c.DBConfig.TLSCACert != "" {
+ // load the system cert pool first -- we'll append the given CA cert to this
+ certPool, err := x509.SystemCertPool()
+ if err != nil {
+ return nil, fmt.Errorf("error fetching system CA cert pool: %s", err)
+ }
+
+ // open the file itself and make sure there's something in it
+ caCertBytes, err := os.ReadFile(c.DBConfig.TLSCACert)
+ if err != nil {
+ return nil, fmt.Errorf("error opening CA certificate at %s: %s", c.DBConfig.TLSCACert, err)
+ }
+ if len(caCertBytes) == 0 {
+ return nil, fmt.Errorf("ca cert at %s was empty", c.DBConfig.TLSCACert)
+ }
+
+ // make sure we have a PEM block
+ caPem, _ := pem.Decode(caCertBytes)
+ if caPem == nil {
+ return nil, fmt.Errorf("could not parse cert at %s into PEM", c.DBConfig.TLSCACert)
+ }
+
+ // parse the PEM block into the certificate
+ caCert, err := x509.ParseCertificate(caPem.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", c.DBConfig.TLSCACert, err)
+ }
+
+ // we're happy, add it to the existing pool and then use this pool in our tls config
+ certPool.AddCert(caCert)
+ tlsConfig.RootCAs = certPool
+ }
+
// We can rely on the pg library we're using to set
// sensible defaults for everything we don't set here.
options := &pg.Options{
@@ -141,6 +192,7 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
Password: c.DBConfig.Password,
Database: c.DBConfig.Database,
ApplicationName: c.ApplicationName,
+ TLSConfig: tlsConfig,
}
return options, nil