summaryrefslogtreecommitdiff
path: root/internal/db/bundb/bundb.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/bundb.go')
-rw-r--r--internal/db/bundb/bundb.go56
1 files changed, 24 insertions, 32 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index f458132a1..fef62a55f 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -35,7 +35,6 @@ import (
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/sirupsen/logrus"
- "github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -120,7 +119,7 @@ func doMigration(ctx context.Context, db *bun.DB) error {
func NewBunDBService(ctx context.Context) (db.DB, error) {
var conn *DBConn
var err error
- dbType := strings.ToLower(viper.GetString(config.Keys.DbType))
+ dbType := strings.ToLower(config.GetDbType())
switch dbType {
case dbTypePostgres:
@@ -139,7 +138,7 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
// add a hook to log queries and the time they take
// only do this for logging where performance isn't 1st concern
- if logrus.GetLevel() >= logrus.DebugLevel && viper.GetBool(config.Keys.LogDbQueries) {
+ if logrus.GetLevel() >= logrus.DebugLevel && config.GetLogDbQueries() {
conn.DB.AddQueryHook(newDebugQueryHook())
}
@@ -209,9 +208,9 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
func sqliteConn(ctx context.Context) (*DBConn, error) {
// validate db address has actually been set
- dbAddress := viper.GetString(config.Keys.DbAddress)
+ dbAddress := config.GetDbAddress()
if dbAddress == "" {
- return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.Keys.DbAddress)
+ return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.DbAddressFlag())
}
// Drop anything fancy from DB address
@@ -282,27 +281,21 @@ func pgConn(ctx context.Context) (*DBConn, error) {
// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
- keys := config.Keys
-
- if strings.ToUpper(viper.GetString(keys.DbType)) != db.DBTypePostgres {
- return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, viper.GetString(keys.DbType))
+ if strings.ToUpper(config.GetDbType()) != db.DBTypePostgres {
+ return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, config.DbTypeFlag())
}
// these are all optional, the db adapter figures out defaults
- port := viper.GetInt(keys.DbPort)
- address := viper.GetString(keys.DbAddress)
- username := viper.GetString(keys.DbUser)
- password := viper.GetString(keys.DbPassword)
+ address := config.GetDbAddress()
// validate database
- database := viper.GetString(keys.DbDatabase)
+ database := config.GetDbDatabase()
if database == "" {
return nil, errors.New("no database set")
}
var tlsConfig *tls.Config
- tlsMode := viper.GetString(keys.DbTLSMode)
- switch tlsMode {
+ switch config.GetDbTLSMode() {
case dbTLSModeDisable, dbTLSModeUnset:
break // nothing to do
case dbTLSModeEnable:
@@ -313,13 +306,12 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
case dbTLSModeRequire:
tlsConfig = &tls.Config{
InsecureSkipVerify: false,
- ServerName: viper.GetString(keys.DbAddress),
+ ServerName: address,
MinVersion: tls.VersionTLS12,
}
}
- caCertPath := viper.GetString(keys.DbTLSCACert)
- if tlsConfig != nil && caCertPath != "" {
+ if certPath := config.GetDbTLSCACert(); tlsConfig != nil && certPath != "" {
// load the system cert pool first -- we'll append the given CA cert to this
certPool, err := x509.SystemCertPool()
if err != nil {
@@ -327,24 +319,24 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
}
// open the file itself and make sure there's something in it
- caCertBytes, err := os.ReadFile(caCertPath)
+ caCertBytes, err := os.ReadFile(certPath)
if err != nil {
- return nil, fmt.Errorf("error opening CA certificate at %s: %s", caCertPath, err)
+ return nil, fmt.Errorf("error opening CA certificate at %s: %s", certPath, err)
}
if len(caCertBytes) == 0 {
- return nil, fmt.Errorf("ca cert at %s was empty", caCertPath)
+ return nil, fmt.Errorf("ca cert at %s was empty", certPath)
}
// make sure we have a PEM block
caPem, _ := pem.Decode(caCertBytes)
if caPem == nil {
- return nil, fmt.Errorf("could not parse cert at %s into PEM", caCertPath)
+ return nil, fmt.Errorf("could not parse cert at %s into PEM", certPath)
}
// parse the PEM block into the certificate
caCert, err := x509.ParseCertificate(caPem.Bytes)
if err != nil {
- return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", caCertPath, err)
+ return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err)
}
// we're happy, add it to the existing pool and then use this pool in our tls config
@@ -356,21 +348,21 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
if address != "" {
cfg.Host = address
}
- if port > 0 {
+ if port := config.GetPort(); port > 0 {
cfg.Port = uint16(port)
}
- if username != "" {
- cfg.User = username
+ if u := config.GetDbUser(); u != "" {
+ cfg.User = u
}
- if password != "" {
- cfg.Password = password
+ if p := config.GetDbPassword(); p != "" {
+ cfg.Password = p
}
if tlsConfig != nil {
cfg.TLSConfig = tlsConfig
}
cfg.Database = database
cfg.PreferSimpleProtocol = true
- cfg.RuntimeParams["application_name"] = viper.GetString(keys.ApplicationName)
+ cfg.RuntimeParams["application_name"] = config.GetApplicationName()
return cfg, nil
}
@@ -387,8 +379,8 @@ func tweakConnectionValues(sqldb *sql.DB) {
*/
func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) {
- protocol := viper.GetString(config.Keys.Protocol)
- host := viper.GetString(config.Keys.Host)
+ protocol := config.GetProtocol()
+ host := config.GetHost()
newTags := []*gtsmodel.Tag{}
for _, t := range tags {