diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/account.go | 4 | ||||
| -rw-r--r-- | internal/db/bundb/admin.go | 7 | ||||
| -rw-r--r-- | internal/db/bundb/bundb.go | 56 | ||||
| -rw-r--r-- | internal/db/bundb/bundbnew_test.go | 3 | ||||
| -rw-r--r-- | internal/db/bundb/domain.go | 3 | ||||
| -rw-r--r-- | internal/db/bundb/instance.go | 10 | 
6 files changed, 33 insertions, 50 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 59292055e..6061676c5 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -25,7 +25,6 @@ import (  	"strings"  	"time" -	"github.com/spf13/viper"  	"github.com/superseriousbusiness/gotosocial/internal/cache"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -133,9 +132,8 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts  			Where("account.username = ?", domain).  			Where("account.domain = ?", domain)  	} else { -		host := viper.GetString(config.Keys.Host)  		q = q. -			Where("account.username = ?", host). +			Where("account.username = ?", config.GetHost()).  			WhereGroup(" AND ", whereEmptyOrNull("domain"))  	} diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index a92834f9c..8b9c7c9a3 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -30,7 +30,6 @@ import (  	"time"  	"github.com/sirupsen/logrus" -	"github.com/spf13/viper"  	"github.com/superseriousbusiness/gotosocial/internal/ap"  	"github.com/superseriousbusiness/gotosocial/internal/config" @@ -178,7 +177,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,  }  func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { -	username := viper.GetString(config.Keys.Host) +	username := config.GetHost()  	q := a.conn.  		NewSelect(). @@ -237,8 +236,8 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {  }  func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { -	protocol := viper.GetString(config.Keys.Protocol) -	host := viper.GetString(config.Keys.Host) +	protocol := config.GetProtocol() +	host := config.GetHost()  	// check if instance entry already exists  	q := a.conn. diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index f458132a1..fef62a55f 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -35,7 +35,6 @@ import (  	"github.com/jackc/pgx/v4"  	"github.com/jackc/pgx/v4/stdlib"  	"github.com/sirupsen/logrus" -	"github.com/spf13/viper"  	"github.com/superseriousbusiness/gotosocial/internal/cache"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -120,7 +119,7 @@ func doMigration(ctx context.Context, db *bun.DB) error {  func NewBunDBService(ctx context.Context) (db.DB, error) {  	var conn *DBConn  	var err error -	dbType := strings.ToLower(viper.GetString(config.Keys.DbType)) +	dbType := strings.ToLower(config.GetDbType())  	switch dbType {  	case dbTypePostgres: @@ -139,7 +138,7 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {  	// add a hook to log queries and the time they take  	// only do this for logging where performance isn't 1st concern -	if logrus.GetLevel() >= logrus.DebugLevel && viper.GetBool(config.Keys.LogDbQueries) { +	if logrus.GetLevel() >= logrus.DebugLevel && config.GetLogDbQueries() {  		conn.DB.AddQueryHook(newDebugQueryHook())  	} @@ -209,9 +208,9 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {  func sqliteConn(ctx context.Context) (*DBConn, error) {  	// validate db address has actually been set -	dbAddress := viper.GetString(config.Keys.DbAddress) +	dbAddress := config.GetDbAddress()  	if dbAddress == "" { -		return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.Keys.DbAddress) +		return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.DbAddressFlag())  	}  	// Drop anything fancy from DB address @@ -282,27 +281,21 @@ func pgConn(ctx context.Context) (*DBConn, error) {  // deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options  // with sensible defaults, or an error if it's not satisfied by the provided config.  func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { -	keys := config.Keys - -	if strings.ToUpper(viper.GetString(keys.DbType)) != db.DBTypePostgres { -		return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, viper.GetString(keys.DbType)) +	if strings.ToUpper(config.GetDbType()) != db.DBTypePostgres { +		return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, config.DbTypeFlag())  	}  	// these are all optional, the db adapter figures out defaults -	port := viper.GetInt(keys.DbPort) -	address := viper.GetString(keys.DbAddress) -	username := viper.GetString(keys.DbUser) -	password := viper.GetString(keys.DbPassword) +	address := config.GetDbAddress()  	// validate database -	database := viper.GetString(keys.DbDatabase) +	database := config.GetDbDatabase()  	if database == "" {  		return nil, errors.New("no database set")  	}  	var tlsConfig *tls.Config -	tlsMode := viper.GetString(keys.DbTLSMode) -	switch tlsMode { +	switch config.GetDbTLSMode() {  	case dbTLSModeDisable, dbTLSModeUnset:  		break // nothing to do  	case dbTLSModeEnable: @@ -313,13 +306,12 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {  	case dbTLSModeRequire:  		tlsConfig = &tls.Config{  			InsecureSkipVerify: false, -			ServerName:         viper.GetString(keys.DbAddress), +			ServerName:         address,  			MinVersion:         tls.VersionTLS12,  		}  	} -	caCertPath := viper.GetString(keys.DbTLSCACert) -	if tlsConfig != nil && caCertPath != "" { +	if certPath := config.GetDbTLSCACert(); tlsConfig != nil && certPath != "" {  		// load the system cert pool first -- we'll append the given CA cert to this  		certPool, err := x509.SystemCertPool()  		if err != nil { @@ -327,24 +319,24 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {  		}  		// open the file itself and make sure there's something in it -		caCertBytes, err := os.ReadFile(caCertPath) +		caCertBytes, err := os.ReadFile(certPath)  		if err != nil { -			return nil, fmt.Errorf("error opening CA certificate at %s: %s", caCertPath, err) +			return nil, fmt.Errorf("error opening CA certificate at %s: %s", certPath, err)  		}  		if len(caCertBytes) == 0 { -			return nil, fmt.Errorf("ca cert at %s was empty", caCertPath) +			return nil, fmt.Errorf("ca cert at %s was empty", certPath)  		}  		// make sure we have a PEM block  		caPem, _ := pem.Decode(caCertBytes)  		if caPem == nil { -			return nil, fmt.Errorf("could not parse cert at %s into PEM", caCertPath) +			return nil, fmt.Errorf("could not parse cert at %s into PEM", certPath)  		}  		// parse the PEM block into the certificate  		caCert, err := x509.ParseCertificate(caPem.Bytes)  		if err != nil { -			return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", caCertPath, err) +			return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err)  		}  		// we're happy, add it to the existing pool and then use this pool in our tls config @@ -356,21 +348,21 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {  	if address != "" {  		cfg.Host = address  	} -	if port > 0 { +	if port := config.GetPort(); port > 0 {  		cfg.Port = uint16(port)  	} -	if username != "" { -		cfg.User = username +	if u := config.GetDbUser(); u != "" { +		cfg.User = u  	} -	if password != "" { -		cfg.Password = password +	if p := config.GetDbPassword(); p != "" { +		cfg.Password = p  	}  	if tlsConfig != nil {  		cfg.TLSConfig = tlsConfig  	}  	cfg.Database = database  	cfg.PreferSimpleProtocol = true -	cfg.RuntimeParams["application_name"] = viper.GetString(keys.ApplicationName) +	cfg.RuntimeParams["application_name"] = config.GetApplicationName()  	return cfg, nil  } @@ -387,8 +379,8 @@ func tweakConnectionValues(sqldb *sql.DB) {  */  func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) { -	protocol := viper.GetString(config.Keys.Protocol) -	host := viper.GetString(config.Keys.Host) +	protocol := config.GetProtocol() +	host := config.GetHost()  	newTags := []*gtsmodel.Tag{}  	for _, t := range tags { diff --git a/internal/db/bundb/bundbnew_test.go b/internal/db/bundb/bundbnew_test.go index 40a05cb50..d5e413a4f 100644 --- a/internal/db/bundb/bundbnew_test.go +++ b/internal/db/bundb/bundbnew_test.go @@ -22,7 +22,6 @@ import (  	"context"  	"testing" -	"github.com/spf13/viper"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db/bundb" @@ -41,7 +40,7 @@ func (suite *BundbNewTestSuite) TestCreateNewDB() {  func (suite *BundbNewTestSuite) TestCreateNewSqliteDBNoAddress() {  	// create a new db with no address specified -	viper.Set(config.Keys.DbAddress, "") +	config.SetDbAddress("")  	db, err := bundb.NewBunDBService(context.Background())  	suite.EqualError(err, "'db-address' was not set when attempting to start sqlite")  	suite.Nil(db) diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 9ddd33b05..ee7fed6a9 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -23,7 +23,6 @@ import (  	"net/url"  	"strings" -	"github.com/spf13/viper"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -35,7 +34,7 @@ type domainDB struct {  }  func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { -	if domain == "" || domain == viper.GetString(config.Keys.Host) { +	if domain == "" || domain == config.GetHost() {  		return false, nil  	} diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 24cc6f1be..d16fac90b 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -22,7 +22,6 @@ import (  	"context"  	"github.com/sirupsen/logrus" -	"github.com/spf13/viper"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -41,8 +40,7 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int  		Where("username != ?", domain).  		Where("? IS NULL", bun.Ident("suspended_at")) -	host := viper.GetString(config.Keys.Host) -	if domain == host { +	if domain == config.GetHost() {  		// if the domain is *this* domain, just count where the domain field is null  		q = q.WhereGroup(" AND ", whereEmptyOrNull("domain"))  	} else { @@ -61,8 +59,7 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (  		NewSelect().  		Model(&[]*gtsmodel.Status{}) -	host := viper.GetString(config.Keys.Host) -	if domain == host { +	if domain == config.GetHost() {  		// if the domain is *this* domain, just count where local is true  		q = q.Where("local = ?", true)  	} else { @@ -83,8 +80,7 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i  		NewSelect().  		Model(&[]*gtsmodel.Instance{}) -	host := viper.GetString(config.Keys.Host) -	if domain == host { +	if domain == config.GetHost() {  		// if the domain is *this* domain, just count other instances it knows about  		// exclude domains that are blocked  		q = q.  | 
