summaryrefslogtreecommitdiff
path: root/internal/db/bundb/bundb.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/bundb.go')
-rw-r--r--internal/db/bundb/bundb.go52
1 files changed, 27 insertions, 25 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 048474782..4ecbec7b9 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -52,13 +52,6 @@ import (
"modernc.org/sqlite"
)
-var registerTables = []interface{}{
- &gtsmodel.AccountToEmoji{},
- &gtsmodel.StatusToEmoji{},
- &gtsmodel.StatusToTag{},
- &gtsmodel.ThreadToStatus{},
-}
-
// DBService satisfies the DB interface
type DBService struct {
db.Account
@@ -88,12 +81,12 @@ type DBService struct {
db.Timeline
db.User
db.Tombstone
- db *DB
+ db *bun.DB
}
// GetDB returns the underlying database connection pool.
// Should only be used in testing + exceptional circumstance.
-func (dbService *DBService) DB() *DB {
+func (dbService *DBService) DB() *bun.DB {
return dbService.db
}
@@ -129,18 +122,18 @@ func doMigration(ctx context.Context, db *bun.DB) error {
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
- var db *DB
+ var db *bun.DB
var err error
t := strings.ToLower(config.GetDbType())
switch t {
case "postgres":
- db, err = pgConn(ctx)
+ db, err = pgConn(ctx, state)
if err != nil {
return nil, err
}
case "sqlite":
- db, err = sqliteConn(ctx)
+ db, err = sqliteConn(ctx, state)
if err != nil {
return nil, err
}
@@ -159,14 +152,19 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
// table registration is needed for many-to-many, see:
// https://bun.uptrace.dev/orm/many-to-many-relation/
- for _, t := range registerTables {
+ for _, t := range []interface{}{
+ &gtsmodel.AccountToEmoji{},
+ &gtsmodel.StatusToEmoji{},
+ &gtsmodel.StatusToTag{},
+ &gtsmodel.ThreadToStatus{},
+ } {
db.RegisterModel(t)
}
// perform any pending database migrations: this includes
// the very first 'migration' on startup which just creates
// necessary tables
- if err := doMigration(ctx, db.bun); err != nil {
+ if err := doMigration(ctx, db); err != nil {
return nil, fmt.Errorf("db migration error: %s", err)
}
@@ -284,13 +282,18 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
return ps, nil
}
-func pgConn(ctx context.Context) (*DB, error) {
+func pgConn(ctx context.Context, state *state.State) (*bun.DB, error) {
opts, err := deriveBunDBPGOptions() //nolint:contextcheck
if err != nil {
- return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
+ return nil, fmt.Errorf("could not create bundb postgres options: %w", err)
}
- sqldb := stdlib.OpenDB(*opts)
+ cfg := stdlib.RegisterConnConfig(opts)
+
+ sqldb, err := sql.Open("pgx-gts", cfg)
+ if err != nil {
+ return nil, fmt.Errorf("could not open postgres db: %w", err)
+ }
// Tune db connections for postgres, see:
// - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql
@@ -299,18 +302,18 @@ func pgConn(ctx context.Context) (*DB, error) {
sqldb.SetMaxIdleConns(2) // assume default 2; if max idle is less than max open, it will be automatically adjusted
sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections
- db := WrapDB(bun.NewDB(sqldb, pgdialect.New()))
+ db := bun.NewDB(sqldb, pgdialect.New())
// ping to check the db is there and listening
if err := db.PingContext(ctx); err != nil {
- return nil, fmt.Errorf("postgres ping: %s", err)
+ return nil, fmt.Errorf("postgres ping: %w", err)
}
log.Info(ctx, "connected to POSTGRES database")
return db, nil
}
-func sqliteConn(ctx context.Context) (*DB, error) {
+func sqliteConn(ctx context.Context, state *state.State) (*bun.DB, error) {
// validate db address has actually been set
address := config.GetDbAddress()
if address == "" {
@@ -321,7 +324,7 @@ func sqliteConn(ctx context.Context) (*DB, error) {
address = buildSQLiteAddress(address)
// Open new DB instance
- sqldb, err := sql.Open("sqlite", address)
+ sqldb, err := sql.Open("sqlite-gts", address)
if err != nil {
if errWithCode, ok := err.(*sqlite.Error); ok {
err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()])
@@ -336,15 +339,14 @@ func sqliteConn(ctx context.Context) (*DB, error) {
sqldb.SetMaxIdleConns(1) // only keep max 1 idle connection around
sqldb.SetConnMaxLifetime(0) // don't kill connections due to age
- // Wrap Bun database conn in our own wrapper
- db := WrapDB(bun.NewDB(sqldb, sqlitedialect.New()))
+ db := bun.NewDB(sqldb, sqlitedialect.New())
// ping to check the db is there and listening
if err := db.PingContext(ctx); err != nil {
if errWithCode, ok := err.(*sqlite.Error); ok {
err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()])
}
- return nil, fmt.Errorf("sqlite ping: %s", err)
+ return nil, fmt.Errorf("sqlite ping: %w", err)
}
log.Infof(ctx, "connected to SQLITE database with address %s", address)
@@ -418,7 +420,7 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
// parse the PEM block into the certificate
caCert, err := x509.ParseCertificate(caPem.Bytes)
if err != nil {
- return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err)
+ return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %w", certPath, err)
}
// we're happy, add it to the existing pool and then use this pool in our tls config