diff options
Diffstat (limited to 'internal/db/bundb/bundb.go')
-rw-r--r-- | internal/db/bundb/bundb.go | 52 |
1 files changed, 27 insertions, 25 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 048474782..4ecbec7b9 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -52,13 +52,6 @@ import ( "modernc.org/sqlite" ) -var registerTables = []interface{}{ - >smodel.AccountToEmoji{}, - >smodel.StatusToEmoji{}, - >smodel.StatusToTag{}, - >smodel.ThreadToStatus{}, -} - // DBService satisfies the DB interface type DBService struct { db.Account @@ -88,12 +81,12 @@ type DBService struct { db.Timeline db.User db.Tombstone - db *DB + db *bun.DB } // GetDB returns the underlying database connection pool. // Should only be used in testing + exceptional circumstance. -func (dbService *DBService) DB() *DB { +func (dbService *DBService) DB() *bun.DB { return dbService.db } @@ -129,18 +122,18 @@ func doMigration(ctx context.Context, db *bun.DB) error { // NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. // Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { - var db *DB + var db *bun.DB var err error t := strings.ToLower(config.GetDbType()) switch t { case "postgres": - db, err = pgConn(ctx) + db, err = pgConn(ctx, state) if err != nil { return nil, err } case "sqlite": - db, err = sqliteConn(ctx) + db, err = sqliteConn(ctx, state) if err != nil { return nil, err } @@ -159,14 +152,19 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { // table registration is needed for many-to-many, see: // https://bun.uptrace.dev/orm/many-to-many-relation/ - for _, t := range registerTables { + for _, t := range []interface{}{ + >smodel.AccountToEmoji{}, + >smodel.StatusToEmoji{}, + >smodel.StatusToTag{}, + >smodel.ThreadToStatus{}, + } { db.RegisterModel(t) } // perform any pending database migrations: this includes // the very first 'migration' on startup which just creates // necessary tables - if err := doMigration(ctx, db.bun); err != nil { + if err := doMigration(ctx, db); err != nil { return nil, fmt.Errorf("db migration error: %s", err) } @@ -284,13 +282,18 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { return ps, nil } -func pgConn(ctx context.Context) (*DB, error) { +func pgConn(ctx context.Context, state *state.State) (*bun.DB, error) { opts, err := deriveBunDBPGOptions() //nolint:contextcheck if err != nil { - return nil, fmt.Errorf("could not create bundb postgres options: %s", err) + return nil, fmt.Errorf("could not create bundb postgres options: %w", err) } - sqldb := stdlib.OpenDB(*opts) + cfg := stdlib.RegisterConnConfig(opts) + + sqldb, err := sql.Open("pgx-gts", cfg) + if err != nil { + return nil, fmt.Errorf("could not open postgres db: %w", err) + } // Tune db connections for postgres, see: // - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql @@ -299,18 +302,18 @@ func pgConn(ctx context.Context) (*DB, error) { sqldb.SetMaxIdleConns(2) // assume default 2; if max idle is less than max open, it will be automatically adjusted sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections - db := WrapDB(bun.NewDB(sqldb, pgdialect.New())) + db := bun.NewDB(sqldb, pgdialect.New()) // ping to check the db is there and listening if err := db.PingContext(ctx); err != nil { - return nil, fmt.Errorf("postgres ping: %s", err) + return nil, fmt.Errorf("postgres ping: %w", err) } log.Info(ctx, "connected to POSTGRES database") return db, nil } -func sqliteConn(ctx context.Context) (*DB, error) { +func sqliteConn(ctx context.Context, state *state.State) (*bun.DB, error) { // validate db address has actually been set address := config.GetDbAddress() if address == "" { @@ -321,7 +324,7 @@ func sqliteConn(ctx context.Context) (*DB, error) { address = buildSQLiteAddress(address) // Open new DB instance - sqldb, err := sql.Open("sqlite", address) + sqldb, err := sql.Open("sqlite-gts", address) if err != nil { if errWithCode, ok := err.(*sqlite.Error); ok { err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) @@ -336,15 +339,14 @@ func sqliteConn(ctx context.Context) (*DB, error) { sqldb.SetMaxIdleConns(1) // only keep max 1 idle connection around sqldb.SetConnMaxLifetime(0) // don't kill connections due to age - // Wrap Bun database conn in our own wrapper - db := WrapDB(bun.NewDB(sqldb, sqlitedialect.New())) + db := bun.NewDB(sqldb, sqlitedialect.New()) // ping to check the db is there and listening if err := db.PingContext(ctx); err != nil { if errWithCode, ok := err.(*sqlite.Error); ok { err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) } - return nil, fmt.Errorf("sqlite ping: %s", err) + return nil, fmt.Errorf("sqlite ping: %w", err) } log.Infof(ctx, "connected to SQLITE database with address %s", address) @@ -418,7 +420,7 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { // parse the PEM block into the certificate caCert, err := x509.ParseCertificate(caPem.Bytes) if err != nil { - return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err) + return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %w", certPath, err) } // we're happy, add it to the existing pool and then use this pool in our tls config |