diff options
Diffstat (limited to 'internal/db/bundb/bundb.go')
-rw-r--r-- | internal/db/bundb/bundb.go | 56 |
1 files changed, 24 insertions, 32 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index f458132a1..fef62a55f 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -35,7 +35,6 @@ import ( "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/stdlib" "github.com/sirupsen/logrus" - "github.com/spf13/viper" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -120,7 +119,7 @@ func doMigration(ctx context.Context, db *bun.DB) error { func NewBunDBService(ctx context.Context) (db.DB, error) { var conn *DBConn var err error - dbType := strings.ToLower(viper.GetString(config.Keys.DbType)) + dbType := strings.ToLower(config.GetDbType()) switch dbType { case dbTypePostgres: @@ -139,7 +138,7 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { // add a hook to log queries and the time they take // only do this for logging where performance isn't 1st concern - if logrus.GetLevel() >= logrus.DebugLevel && viper.GetBool(config.Keys.LogDbQueries) { + if logrus.GetLevel() >= logrus.DebugLevel && config.GetLogDbQueries() { conn.DB.AddQueryHook(newDebugQueryHook()) } @@ -209,9 +208,9 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { func sqliteConn(ctx context.Context) (*DBConn, error) { // validate db address has actually been set - dbAddress := viper.GetString(config.Keys.DbAddress) + dbAddress := config.GetDbAddress() if dbAddress == "" { - return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.Keys.DbAddress) + return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.DbAddressFlag()) } // Drop anything fancy from DB address @@ -282,27 +281,21 @@ func pgConn(ctx context.Context) (*DBConn, error) { // deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options // with sensible defaults, or an error if it's not satisfied by the provided config. func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { - keys := config.Keys - - if strings.ToUpper(viper.GetString(keys.DbType)) != db.DBTypePostgres { - return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, viper.GetString(keys.DbType)) + if strings.ToUpper(config.GetDbType()) != db.DBTypePostgres { + return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, config.DbTypeFlag()) } // these are all optional, the db adapter figures out defaults - port := viper.GetInt(keys.DbPort) - address := viper.GetString(keys.DbAddress) - username := viper.GetString(keys.DbUser) - password := viper.GetString(keys.DbPassword) + address := config.GetDbAddress() // validate database - database := viper.GetString(keys.DbDatabase) + database := config.GetDbDatabase() if database == "" { return nil, errors.New("no database set") } var tlsConfig *tls.Config - tlsMode := viper.GetString(keys.DbTLSMode) - switch tlsMode { + switch config.GetDbTLSMode() { case dbTLSModeDisable, dbTLSModeUnset: break // nothing to do case dbTLSModeEnable: @@ -313,13 +306,12 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { case dbTLSModeRequire: tlsConfig = &tls.Config{ InsecureSkipVerify: false, - ServerName: viper.GetString(keys.DbAddress), + ServerName: address, MinVersion: tls.VersionTLS12, } } - caCertPath := viper.GetString(keys.DbTLSCACert) - if tlsConfig != nil && caCertPath != "" { + if certPath := config.GetDbTLSCACert(); tlsConfig != nil && certPath != "" { // load the system cert pool first -- we'll append the given CA cert to this certPool, err := x509.SystemCertPool() if err != nil { @@ -327,24 +319,24 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { } // open the file itself and make sure there's something in it - caCertBytes, err := os.ReadFile(caCertPath) + caCertBytes, err := os.ReadFile(certPath) if err != nil { - return nil, fmt.Errorf("error opening CA certificate at %s: %s", caCertPath, err) + return nil, fmt.Errorf("error opening CA certificate at %s: %s", certPath, err) } if len(caCertBytes) == 0 { - return nil, fmt.Errorf("ca cert at %s was empty", caCertPath) + return nil, fmt.Errorf("ca cert at %s was empty", certPath) } // make sure we have a PEM block caPem, _ := pem.Decode(caCertBytes) if caPem == nil { - return nil, fmt.Errorf("could not parse cert at %s into PEM", caCertPath) + return nil, fmt.Errorf("could not parse cert at %s into PEM", certPath) } // parse the PEM block into the certificate caCert, err := x509.ParseCertificate(caPem.Bytes) if err != nil { - return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", caCertPath, err) + return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err) } // we're happy, add it to the existing pool and then use this pool in our tls config @@ -356,21 +348,21 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { if address != "" { cfg.Host = address } - if port > 0 { + if port := config.GetPort(); port > 0 { cfg.Port = uint16(port) } - if username != "" { - cfg.User = username + if u := config.GetDbUser(); u != "" { + cfg.User = u } - if password != "" { - cfg.Password = password + if p := config.GetDbPassword(); p != "" { + cfg.Password = p } if tlsConfig != nil { cfg.TLSConfig = tlsConfig } cfg.Database = database cfg.PreferSimpleProtocol = true - cfg.RuntimeParams["application_name"] = viper.GetString(keys.ApplicationName) + cfg.RuntimeParams["application_name"] = config.GetApplicationName() return cfg, nil } @@ -387,8 +379,8 @@ func tweakConnectionValues(sqldb *sql.DB) { */ func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) { - protocol := viper.GetString(config.Keys.Protocol) - host := viper.GetString(config.Keys.Host) + protocol := config.GetProtocol() + host := config.GetHost() newTags := []*gtsmodel.Tag{} for _, t := range tags { |