diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/config/config.go | 20 | ||||
| -rw-r--r-- | internal/config/defaults.go | 20 | ||||
| -rw-r--r-- | internal/config/flags.go | 4 | ||||
| -rw-r--r-- | internal/config/helpers.gen.go | 100 | ||||
| -rw-r--r-- | internal/db/bundb/bundb.go | 126 | 
5 files changed, 206 insertions, 64 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index ec8675f2d..c28cfe419 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -58,14 +58,18 @@ type Configuration struct {  	TrustedProxies  []string `name:"trusted-proxies" usage:"Proxies to trust when parsing x-forwarded headers into real IPs."`  	SoftwareVersion string   `name:"software-version" usage:""` -	DbType      string `name:"db-type" usage:"Database type: eg., postgres"` -	DbAddress   string `name:"db-address" usage:"Database ipv4 address, hostname, or filename"` -	DbPort      int    `name:"db-port" usage:"Database port"` -	DbUser      string `name:"db-user" usage:"Database username"` -	DbPassword  string `name:"db-password" usage:"Database password"` -	DbDatabase  string `name:"db-database" usage:"Database name"` -	DbTLSMode   string `name:"db-tls-mode" usage:"Database tls mode"` -	DbTLSCACert string `name:"db-tls-ca-cert" usage:"Path to CA cert for db tls connection"` +	DbType              string        `name:"db-type" usage:"Database type: eg., postgres"` +	DbAddress           string        `name:"db-address" usage:"Database ipv4 address, hostname, or filename"` +	DbPort              int           `name:"db-port" usage:"Database port"` +	DbUser              string        `name:"db-user" usage:"Database username"` +	DbPassword          string        `name:"db-password" usage:"Database password"` +	DbDatabase          string        `name:"db-database" usage:"Database name"` +	DbTLSMode           string        `name:"db-tls-mode" usage:"Database tls mode"` +	DbTLSCACert         string        `name:"db-tls-ca-cert" usage:"Path to CA cert for db tls connection"` +	DbSqliteJournalMode string        `name:"db-sqlite-journal-mode" usage:"Sqlite only: see https://www.sqlite.org/pragma.html#pragma_journal_mode"` +	DbSqliteSynchronous string        `name:"db-sqlite-synchronous" usage:"Sqlite only: see https://www.sqlite.org/pragma.html#pragma_synchronous"` +	DbSqliteCacheSize   bytesize.Size `name:"db-sqlite-cache-size" usage:"Sqlite only: see https://www.sqlite.org/pragma.html#pragma_cache_size"` +	DbSqliteBusyTimeout time.Duration `name:"db-sqlite-busy-timeout" usage:"Sqlite only: see https://www.sqlite.org/pragma.html#pragma_busy_timeout"`  	WebTemplateBaseDir string `name:"web-template-base-dir" usage:"Basedir for html templating files for rendering pages and composing emails."`  	WebAssetBaseDir    string `name:"web-asset-base-dir" usage:"Directory to serve static assets from, accessible at example.org/assets/"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 4d61bec05..31f282113 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -40,14 +40,18 @@ var Defaults = Configuration{  	Port:            8080,  	TrustedProxies:  []string{"127.0.0.1/32", "::1"}, // localhost -	DbType:      "postgres", -	DbAddress:   "", -	DbPort:      5432, -	DbUser:      "", -	DbPassword:  "", -	DbDatabase:  "gotosocial", -	DbTLSMode:   "disable", -	DbTLSCACert: "", +	DbType:              "postgres", +	DbAddress:           "", +	DbPort:              5432, +	DbUser:              "", +	DbPassword:          "", +	DbDatabase:          "gotosocial", +	DbTLSMode:           "disable", +	DbTLSCACert:         "", +	DbSqliteJournalMode: "WAL", +	DbSqliteSynchronous: "NORMAL", +	DbSqliteCacheSize:   64 * bytesize.MiB, +	DbSqliteBusyTimeout: time.Second * 30,  	WebTemplateBaseDir: "./web/template/",  	WebAssetBaseDir:    "./web/assets/", diff --git a/internal/config/flags.go b/internal/config/flags.go index e3d1b20da..a0fde3eed 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -51,6 +51,10 @@ func (s *ConfigState) AddGlobalFlags(cmd *cobra.Command) {  		cmd.PersistentFlags().String(DbDatabaseFlag(), cfg.DbDatabase, fieldtag("DbDatabase", "usage"))  		cmd.PersistentFlags().String(DbTLSModeFlag(), cfg.DbTLSMode, fieldtag("DbTLSMode", "usage"))  		cmd.PersistentFlags().String(DbTLSCACertFlag(), cfg.DbTLSCACert, fieldtag("DbTLSCACert", "usage")) +		cmd.PersistentFlags().String(DbSqliteJournalModeFlag(), cfg.DbSqliteJournalMode, fieldtag("DbSqliteJournalMode", "usage")) +		cmd.PersistentFlags().String(DbSqliteSynchronousFlag(), cfg.DbSqliteSynchronous, fieldtag("DbSqliteSynchronous", "usage")) +		cmd.PersistentFlags().Uint64(DbSqliteCacheSizeFlag(), uint64(cfg.DbSqliteCacheSize), fieldtag("DbSqliteCacheSize", "usage")) +		cmd.PersistentFlags().Duration(DbSqliteBusyTimeoutFlag(), cfg.DbSqliteBusyTimeout, fieldtag("DbSqliteBusyTimeout", "usage"))  	})  } diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index f340360b2..1da2ff42c 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -524,6 +524,106 @@ func GetDbTLSCACert() string { return global.GetDbTLSCACert() }  // SetDbTLSCACert safely sets the value for global configuration 'DbTLSCACert' field  func SetDbTLSCACert(v string) { global.SetDbTLSCACert(v) } +// GetDbSqliteJournalMode safely fetches the Configuration value for state's 'DbSqliteJournalMode' field +func (st *ConfigState) GetDbSqliteJournalMode() (v string) { +	st.mutex.Lock() +	v = st.config.DbSqliteJournalMode +	st.mutex.Unlock() +	return +} + +// SetDbSqliteJournalMode safely sets the Configuration value for state's 'DbSqliteJournalMode' field +func (st *ConfigState) SetDbSqliteJournalMode(v string) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.DbSqliteJournalMode = v +	st.reloadToViper() +} + +// DbSqliteJournalModeFlag returns the flag name for the 'DbSqliteJournalMode' field +func DbSqliteJournalModeFlag() string { return "db-sqlite-journal-mode" } + +// GetDbSqliteJournalMode safely fetches the value for global configuration 'DbSqliteJournalMode' field +func GetDbSqliteJournalMode() string { return global.GetDbSqliteJournalMode() } + +// SetDbSqliteJournalMode safely sets the value for global configuration 'DbSqliteJournalMode' field +func SetDbSqliteJournalMode(v string) { global.SetDbSqliteJournalMode(v) } + +// GetDbSqliteSynchronous safely fetches the Configuration value for state's 'DbSqliteSynchronous' field +func (st *ConfigState) GetDbSqliteSynchronous() (v string) { +	st.mutex.Lock() +	v = st.config.DbSqliteSynchronous +	st.mutex.Unlock() +	return +} + +// SetDbSqliteSynchronous safely sets the Configuration value for state's 'DbSqliteSynchronous' field +func (st *ConfigState) SetDbSqliteSynchronous(v string) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.DbSqliteSynchronous = v +	st.reloadToViper() +} + +// DbSqliteSynchronousFlag returns the flag name for the 'DbSqliteSynchronous' field +func DbSqliteSynchronousFlag() string { return "db-sqlite-synchronous" } + +// GetDbSqliteSynchronous safely fetches the value for global configuration 'DbSqliteSynchronous' field +func GetDbSqliteSynchronous() string { return global.GetDbSqliteSynchronous() } + +// SetDbSqliteSynchronous safely sets the value for global configuration 'DbSqliteSynchronous' field +func SetDbSqliteSynchronous(v string) { global.SetDbSqliteSynchronous(v) } + +// GetDbSqliteCacheSize safely fetches the Configuration value for state's 'DbSqliteCacheSize' field +func (st *ConfigState) GetDbSqliteCacheSize() (v bytesize.Size) { +	st.mutex.Lock() +	v = st.config.DbSqliteCacheSize +	st.mutex.Unlock() +	return +} + +// SetDbSqliteCacheSize safely sets the Configuration value for state's 'DbSqliteCacheSize' field +func (st *ConfigState) SetDbSqliteCacheSize(v bytesize.Size) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.DbSqliteCacheSize = v +	st.reloadToViper() +} + +// DbSqliteCacheSizeFlag returns the flag name for the 'DbSqliteCacheSize' field +func DbSqliteCacheSizeFlag() string { return "db-sqlite-cache-size" } + +// GetDbSqliteCacheSize safely fetches the value for global configuration 'DbSqliteCacheSize' field +func GetDbSqliteCacheSize() bytesize.Size { return global.GetDbSqliteCacheSize() } + +// SetDbSqliteCacheSize safely sets the value for global configuration 'DbSqliteCacheSize' field +func SetDbSqliteCacheSize(v bytesize.Size) { global.SetDbSqliteCacheSize(v) } + +// GetDbSqliteBusyTimeout safely fetches the Configuration value for state's 'DbSqliteBusyTimeout' field +func (st *ConfigState) GetDbSqliteBusyTimeout() (v time.Duration) { +	st.mutex.Lock() +	v = st.config.DbSqliteBusyTimeout +	st.mutex.Unlock() +	return +} + +// SetDbSqliteBusyTimeout safely sets the Configuration value for state's 'DbSqliteBusyTimeout' field +func (st *ConfigState) SetDbSqliteBusyTimeout(v time.Duration) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.DbSqliteBusyTimeout = v +	st.reloadToViper() +} + +// DbSqliteBusyTimeoutFlag returns the flag name for the 'DbSqliteBusyTimeout' field +func DbSqliteBusyTimeoutFlag() string { return "db-sqlite-busy-timeout" } + +// GetDbSqliteBusyTimeout safely fetches the value for global configuration 'DbSqliteBusyTimeout' field +func GetDbSqliteBusyTimeout() time.Duration { return global.GetDbSqliteBusyTimeout() } + +// SetDbSqliteBusyTimeout safely sets the value for global configuration 'DbSqliteBusyTimeout' field +func SetDbSqliteBusyTimeout(v time.Duration) { global.SetDbSqliteBusyTimeout(v) } +  // GetWebTemplateBaseDir safely fetches the Configuration value for state's 'WebTemplateBaseDir' field  func (st *ConfigState) GetWebTemplateBaseDir() (v string) {  	st.mutex.Lock() diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 1225b2bb0..b6a07bdc6 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -28,9 +28,11 @@ import (  	"fmt"  	"os"  	"runtime" +	"strconv"  	"strings"  	"time" +	"codeberg.org/gruf/go-bytesize"  	"github.com/google/uuid"  	"github.com/jackc/pgx/v4"  	"github.com/jackc/pgx/v4/stdlib" @@ -49,22 +51,6 @@ import (  	"modernc.org/sqlite"  ) -const ( -	dbTypePostgres = "postgres" -	dbTypeSqlite   = "sqlite" - -	// dbTLSModeDisable does not attempt to make a TLS connection to the database. -	dbTLSModeDisable = "disable" -	// dbTLSModeEnable attempts to make a TLS connection to the database, but doesn't fail if -	// the certificate passed by the database isn't verified. -	dbTLSModeEnable = "enable" -	// dbTLSModeRequire attempts to make a TLS connection to the database, and requires -	// that the certificate presented by the database is valid. -	dbTLSModeRequire = "require" -	// dbTLSModeUnset means that the TLS mode has not been set. -	dbTLSModeUnset = "" -) -  var registerTables = []interface{}{  	>smodel.AccountToEmoji{},  	>smodel.StatusToEmoji{}, @@ -127,26 +113,34 @@ func doMigration(ctx context.Context, db *bun.DB) error {  func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  	var conn *DBConn  	var err error -	dbType := strings.ToLower(config.GetDbType()) +	t := strings.ToLower(config.GetDbType()) -	switch dbType { -	case dbTypePostgres: +	switch t { +	case "postgres":  		conn, err = pgConn(ctx)  		if err != nil {  			return nil, err  		} -	case dbTypeSqlite: +	case "sqlite":  		conn, err = sqliteConn(ctx)  		if err != nil {  			return nil, err  		}  	default: -		return nil, fmt.Errorf("database type %s not supported for bundb", dbType) +		return nil, fmt.Errorf("database type %s not supported for bundb", t)  	}  	// Add database query hook  	conn.DB.AddQueryHook(queryHook{}) +	// execute sqlite pragmas *after* adding database hook; +	// this allows the pragma queries to be logged +	if t == "sqlite" { +		if err := sqlitePragmas(ctx, conn); err != nil { +			return nil, err +		} +	} +  	// table registration is needed for many-to-many, see:  	// https://bun.uptrace.dev/orm/many-to-many-relation/  	for _, t := range registerTables { @@ -230,29 +224,29 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  func sqliteConn(ctx context.Context) (*DBConn, error) {  	// validate db address has actually been set -	dbAddress := config.GetDbAddress() -	if dbAddress == "" { +	address := config.GetDbAddress() +	if address == "" {  		return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.DbAddressFlag())  	}  	// Drop anything fancy from DB address -	dbAddress = strings.Split(dbAddress, "?")[0] -	dbAddress = strings.TrimPrefix(dbAddress, "file:") +	address = strings.Split(address, "?")[0] +	address = strings.TrimPrefix(address, "file:")  	// Append our own SQLite preferences -	dbAddress = "file:" + dbAddress + "?cache=shared" +	address = "file:" + address  	var inMem bool -	if dbAddress == "file::memory:?cache=shared" { -		dbAddress = fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid.NewString()) -		log.Infof("using in-memory database address " + dbAddress) +	if address == "file::memory:" { +		address = fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid.NewString()) +		log.Infof("using in-memory database address " + address)  		log.Warn("sqlite in-memory database should only be used for debugging")  		inMem = true  	}  	// Open new DB instance -	sqldb, err := sql.Open("sqlite", dbAddress) +	sqldb, err := sql.Open("sqlite", address)  	if err != nil {  		if errWithCode, ok := err.(*sqlite.Error); ok {  			err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) @@ -260,8 +254,6 @@ func sqliteConn(ctx context.Context) (*DBConn, error) {  		return nil, fmt.Errorf("could not open sqlite db: %s", err)  	} -	tweakConnectionValues(sqldb) -  	if inMem {  		// don't close connections on disconnect -- otherwise  		// the SQLite database will be deleted when there @@ -269,6 +261,7 @@ func sqliteConn(ctx context.Context) (*DBConn, error) {  		sqldb.SetConnMaxLifetime(0)  	} +	// Wrap Bun database conn in our own wrapper  	conn := WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()))  	// ping to check the db is there and listening @@ -278,11 +271,56 @@ func sqliteConn(ctx context.Context) (*DBConn, error) {  		}  		return nil, fmt.Errorf("sqlite ping: %s", err)  	} -  	log.Info("connected to SQLITE database") +  	return conn, nil  } +func sqlitePragmas(ctx context.Context, conn *DBConn) error { +	var pragmas [][]string +	if mode := config.GetDbSqliteJournalMode(); mode != "" { +		// Set the user provided SQLite journal mode +		pragmas = append(pragmas, []string{"journal_mode", mode}) +	} + +	if mode := config.GetDbSqliteSynchronous(); mode != "" { +		// Set the user provided SQLite synchronous mode +		pragmas = append(pragmas, []string{"synchronous", mode}) +	} + +	if size := config.GetDbSqliteCacheSize(); size > 0 { +		// Set the user provided SQLite cache size (in kibibytes) +		// Prepend a '-' character to this to indicate to sqlite +		// that we're giving kibibytes rather than num pages. +		// https://www.sqlite.org/pragma.html#pragma_cache_size +		s := "-" + strconv.FormatUint(uint64(size/bytesize.KiB), 10) +		pragmas = append(pragmas, []string{"cache_size", s}) +	} + +	if timeout := config.GetDbSqliteBusyTimeout(); timeout > 0 { +		t := strconv.FormatInt(timeout.Milliseconds(), 10) +		pragmas = append(pragmas, []string{"busy_timeout", t}) +	} + +	for _, p := range pragmas { +		pk := p[0] +		pv := p[1] + +		if _, err := conn.DB.ExecContext(ctx, "PRAGMA ?=?", bun.Ident(pk), bun.Safe(pv)); err != nil { +			return fmt.Errorf("error executing sqlite pragma %s: %w", pk, err) +		} + +		var res string +		if err := conn.DB.NewRaw("PRAGMA ?", bun.Ident(pk)).Scan(ctx, &res); err != nil { +			return fmt.Errorf("error scanning sqlite pragma %s: %w", pv, err) +		} + +		log.Infof("sqlite pragma %s set to %s", pk, res) +	} + +	return nil +} +  func pgConn(ctx context.Context) (*DBConn, error) {  	opts, err := deriveBunDBPGOptions() //nolint:contextcheck  	if err != nil { @@ -291,7 +329,10 @@ func pgConn(ctx context.Context) (*DBConn, error) {  	sqldb := stdlib.OpenDB(*opts) -	tweakConnectionValues(sqldb) +	// https://bun.uptrace.dev/postgres/running-bun-in-production.html#database-sql +	maxOpenConns := 4 * runtime.GOMAXPROCS(0) +	sqldb.SetMaxOpenConns(maxOpenConns) +	sqldb.SetMaxIdleConns(maxOpenConns)  	conn := WrapDBConn(bun.NewDB(sqldb, pgdialect.New())) @@ -311,10 +352,6 @@ func pgConn(ctx context.Context) (*DBConn, error) {  // deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options  // with sensible defaults, or an error if it's not satisfied by the provided config.  func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { -	if strings.ToUpper(config.GetDbType()) != db.DBTypePostgres { -		return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, config.DbTypeFlag()) -	} -  	// these are all optional, the db adapter figures out defaults  	address := config.GetDbAddress() @@ -326,14 +363,14 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {  	var tlsConfig *tls.Config  	switch config.GetDbTLSMode() { -	case dbTLSModeDisable, dbTLSModeUnset: +	case "", "disable":  		break // nothing to do -	case dbTLSModeEnable: +	case "enable":  		/* #nosec G402 */  		tlsConfig = &tls.Config{  			InsecureSkipVerify: true,  		} -	case dbTLSModeRequire: +	case "require":  		tlsConfig = &tls.Config{  			InsecureSkipVerify: false,  			ServerName:         address, @@ -397,13 +434,6 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {  	return cfg, nil  } -// https://bun.uptrace.dev/postgres/running-bun-in-production.html#database-sql -func tweakConnectionValues(sqldb *sql.DB) { -	maxOpenConns := 4 * runtime.GOMAXPROCS(0) -	sqldb.SetMaxOpenConns(maxOpenConns) -	sqldb.SetMaxIdleConns(maxOpenConns) -} -  /*  	CONVERSION FUNCTIONS  */  | 
