summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/account.go4
-rw-r--r--internal/db/bundb/admin.go7
-rw-r--r--internal/db/bundb/bundb.go56
-rw-r--r--internal/db/bundb/bundbnew_test.go3
-rw-r--r--internal/db/bundb/domain.go3
-rw-r--r--internal/db/bundb/instance.go10
6 files changed, 33 insertions, 50 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 59292055e..6061676c5 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -25,7 +25,6 @@ import (
"strings"
"time"
- "github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -133,9 +132,8 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
Where("account.username = ?", domain).
Where("account.domain = ?", domain)
} else {
- host := viper.GetString(config.Keys.Host)
q = q.
- Where("account.username = ?", host).
+ Where("account.username = ?", config.GetHost()).
WhereGroup(" AND ", whereEmptyOrNull("domain"))
}
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go
index a92834f9c..8b9c7c9a3 100644
--- a/internal/db/bundb/admin.go
+++ b/internal/db/bundb/admin.go
@@ -30,7 +30,6 @@ import (
"time"
"github.com/sirupsen/logrus"
- "github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/config"
@@ -178,7 +177,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
}
func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
- username := viper.GetString(config.Keys.Host)
+ username := config.GetHost()
q := a.conn.
NewSelect().
@@ -237,8 +236,8 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
}
func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
- protocol := viper.GetString(config.Keys.Protocol)
- host := viper.GetString(config.Keys.Host)
+ protocol := config.GetProtocol()
+ host := config.GetHost()
// check if instance entry already exists
q := a.conn.
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index f458132a1..fef62a55f 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -35,7 +35,6 @@ import (
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/sirupsen/logrus"
- "github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -120,7 +119,7 @@ func doMigration(ctx context.Context, db *bun.DB) error {
func NewBunDBService(ctx context.Context) (db.DB, error) {
var conn *DBConn
var err error
- dbType := strings.ToLower(viper.GetString(config.Keys.DbType))
+ dbType := strings.ToLower(config.GetDbType())
switch dbType {
case dbTypePostgres:
@@ -139,7 +138,7 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
// add a hook to log queries and the time they take
// only do this for logging where performance isn't 1st concern
- if logrus.GetLevel() >= logrus.DebugLevel && viper.GetBool(config.Keys.LogDbQueries) {
+ if logrus.GetLevel() >= logrus.DebugLevel && config.GetLogDbQueries() {
conn.DB.AddQueryHook(newDebugQueryHook())
}
@@ -209,9 +208,9 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
func sqliteConn(ctx context.Context) (*DBConn, error) {
// validate db address has actually been set
- dbAddress := viper.GetString(config.Keys.DbAddress)
+ dbAddress := config.GetDbAddress()
if dbAddress == "" {
- return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.Keys.DbAddress)
+ return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.DbAddressFlag())
}
// Drop anything fancy from DB address
@@ -282,27 +281,21 @@ func pgConn(ctx context.Context) (*DBConn, error) {
// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
- keys := config.Keys
-
- if strings.ToUpper(viper.GetString(keys.DbType)) != db.DBTypePostgres {
- return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, viper.GetString(keys.DbType))
+ if strings.ToUpper(config.GetDbType()) != db.DBTypePostgres {
+ return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, config.DbTypeFlag())
}
// these are all optional, the db adapter figures out defaults
- port := viper.GetInt(keys.DbPort)
- address := viper.GetString(keys.DbAddress)
- username := viper.GetString(keys.DbUser)
- password := viper.GetString(keys.DbPassword)
+ address := config.GetDbAddress()
// validate database
- database := viper.GetString(keys.DbDatabase)
+ database := config.GetDbDatabase()
if database == "" {
return nil, errors.New("no database set")
}
var tlsConfig *tls.Config
- tlsMode := viper.GetString(keys.DbTLSMode)
- switch tlsMode {
+ switch config.GetDbTLSMode() {
case dbTLSModeDisable, dbTLSModeUnset:
break // nothing to do
case dbTLSModeEnable:
@@ -313,13 +306,12 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
case dbTLSModeRequire:
tlsConfig = &tls.Config{
InsecureSkipVerify: false,
- ServerName: viper.GetString(keys.DbAddress),
+ ServerName: address,
MinVersion: tls.VersionTLS12,
}
}
- caCertPath := viper.GetString(keys.DbTLSCACert)
- if tlsConfig != nil && caCertPath != "" {
+ if certPath := config.GetDbTLSCACert(); tlsConfig != nil && certPath != "" {
// load the system cert pool first -- we'll append the given CA cert to this
certPool, err := x509.SystemCertPool()
if err != nil {
@@ -327,24 +319,24 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
}
// open the file itself and make sure there's something in it
- caCertBytes, err := os.ReadFile(caCertPath)
+ caCertBytes, err := os.ReadFile(certPath)
if err != nil {
- return nil, fmt.Errorf("error opening CA certificate at %s: %s", caCertPath, err)
+ return nil, fmt.Errorf("error opening CA certificate at %s: %s", certPath, err)
}
if len(caCertBytes) == 0 {
- return nil, fmt.Errorf("ca cert at %s was empty", caCertPath)
+ return nil, fmt.Errorf("ca cert at %s was empty", certPath)
}
// make sure we have a PEM block
caPem, _ := pem.Decode(caCertBytes)
if caPem == nil {
- return nil, fmt.Errorf("could not parse cert at %s into PEM", caCertPath)
+ return nil, fmt.Errorf("could not parse cert at %s into PEM", certPath)
}
// parse the PEM block into the certificate
caCert, err := x509.ParseCertificate(caPem.Bytes)
if err != nil {
- return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", caCertPath, err)
+ return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err)
}
// we're happy, add it to the existing pool and then use this pool in our tls config
@@ -356,21 +348,21 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
if address != "" {
cfg.Host = address
}
- if port > 0 {
+ if port := config.GetPort(); port > 0 {
cfg.Port = uint16(port)
}
- if username != "" {
- cfg.User = username
+ if u := config.GetDbUser(); u != "" {
+ cfg.User = u
}
- if password != "" {
- cfg.Password = password
+ if p := config.GetDbPassword(); p != "" {
+ cfg.Password = p
}
if tlsConfig != nil {
cfg.TLSConfig = tlsConfig
}
cfg.Database = database
cfg.PreferSimpleProtocol = true
- cfg.RuntimeParams["application_name"] = viper.GetString(keys.ApplicationName)
+ cfg.RuntimeParams["application_name"] = config.GetApplicationName()
return cfg, nil
}
@@ -387,8 +379,8 @@ func tweakConnectionValues(sqldb *sql.DB) {
*/
func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) {
- protocol := viper.GetString(config.Keys.Protocol)
- host := viper.GetString(config.Keys.Host)
+ protocol := config.GetProtocol()
+ host := config.GetHost()
newTags := []*gtsmodel.Tag{}
for _, t := range tags {
diff --git a/internal/db/bundb/bundbnew_test.go b/internal/db/bundb/bundbnew_test.go
index 40a05cb50..d5e413a4f 100644
--- a/internal/db/bundb/bundbnew_test.go
+++ b/internal/db/bundb/bundbnew_test.go
@@ -22,7 +22,6 @@ import (
"context"
"testing"
- "github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
@@ -41,7 +40,7 @@ func (suite *BundbNewTestSuite) TestCreateNewDB() {
func (suite *BundbNewTestSuite) TestCreateNewSqliteDBNoAddress() {
// create a new db with no address specified
- viper.Set(config.Keys.DbAddress, "")
+ config.SetDbAddress("")
db, err := bundb.NewBunDBService(context.Background())
suite.EqualError(err, "'db-address' was not set when attempting to start sqlite")
suite.Nil(db)
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index 9ddd33b05..ee7fed6a9 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -23,7 +23,6 @@ import (
"net/url"
"strings"
- "github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -35,7 +34,7 @@ type domainDB struct {
}
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
- if domain == "" || domain == viper.GetString(config.Keys.Host) {
+ if domain == "" || domain == config.GetHost() {
return false, nil
}
diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go
index 24cc6f1be..d16fac90b 100644
--- a/internal/db/bundb/instance.go
+++ b/internal/db/bundb/instance.go
@@ -22,7 +22,6 @@ import (
"context"
"github.com/sirupsen/logrus"
- "github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -41,8 +40,7 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
Where("username != ?", domain).
Where("? IS NULL", bun.Ident("suspended_at"))
- host := viper.GetString(config.Keys.Host)
- if domain == host {
+ if domain == config.GetHost() {
// if the domain is *this* domain, just count where the domain field is null
q = q.WhereGroup(" AND ", whereEmptyOrNull("domain"))
} else {
@@ -61,8 +59,7 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
NewSelect().
Model(&[]*gtsmodel.Status{})
- host := viper.GetString(config.Keys.Host)
- if domain == host {
+ if domain == config.GetHost() {
// if the domain is *this* domain, just count where local is true
q = q.Where("local = ?", true)
} else {
@@ -83,8 +80,7 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
NewSelect().
Model(&[]*gtsmodel.Instance{})
- host := viper.GetString(config.Keys.Host)
- if domain == host {
+ if domain == config.GetHost() {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
q = q.