summaryrefslogtreecommitdiff
path: root/internal/db/bundb/drivers.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/drivers.go')
-rw-r--r--internal/db/bundb/drivers.go346
1 files changed, 5 insertions, 341 deletions
diff --git a/internal/db/bundb/drivers.go b/internal/db/bundb/drivers.go
index 1811ad533..f39189c9d 100644
--- a/internal/db/bundb/drivers.go
+++ b/internal/db/bundb/drivers.go
@@ -18,350 +18,14 @@
package bundb
import (
- "context"
"database/sql"
- "database/sql/driver"
- "time"
- _ "unsafe" // linkname shenanigans
- pgx "github.com/jackc/pgx/v5/stdlib"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtserror"
- "modernc.org/sqlite"
+ "github.com/superseriousbusiness/gotosocial/internal/db/postgres"
+ "github.com/superseriousbusiness/gotosocial/internal/db/sqlite"
)
-var (
- // global SQL driver instances.
- postgresDriver = pgx.GetDefaultDriver()
- sqliteDriver = getSQLiteDriver()
-
- // check the postgres connection
- // conforms to our conn{} interface.
- // (note SQLite doesn't export their
- // conn type, and gets checked in
- // tests very regularly anywho).
- _ conn = (*pgx.Conn)(nil)
-)
-
-//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver
-func getSQLiteDriver() *sqlite.Driver
-
func init() {
- sql.Register("pgx-gts", &PostgreSQLDriver{})
- sql.Register("sqlite-gts", &SQLiteDriver{})
-}
-
-// PostgreSQLDriver is our own wrapper around the
-// pgx/stdlib.Driver{} type in order to wrap further
-// SQL driver types with our own err processing.
-type PostgreSQLDriver struct{}
-
-func (d *PostgreSQLDriver) Open(name string) (driver.Conn, error) {
- c, err := postgresDriver.Open(name)
- if err != nil {
- return nil, err
- }
- return &PostgreSQLConn{conn: c.(conn)}, nil
-}
-
-type PostgreSQLConn struct{ conn }
-
-func (c *PostgreSQLConn) Begin() (driver.Tx, error) {
- return c.BeginTx(context.Background(), driver.TxOptions{})
-}
-
-func (c *PostgreSQLConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
- tx, err := c.conn.BeginTx(ctx, opts)
- err = processPostgresError(err)
- if err != nil {
- return nil, err
- }
- return &PostgreSQLTx{tx}, nil
-}
-
-func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) {
- return c.PrepareContext(context.Background(), query)
-}
-
-func (c *PostgreSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
- st, err := c.conn.PrepareContext(ctx, query)
- err = processPostgresError(err)
- if err != nil {
- return nil, err
- }
- return &PostgreSQLStmt{stmt: st.(stmt)}, nil
-}
-
-func (c *PostgreSQLConn) Exec(query string, args []driver.Value) (driver.Result, error) {
- return c.ExecContext(context.Background(), query, toNamedValues(args))
-}
-
-func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
- result, err := c.conn.ExecContext(ctx, query, args)
- err = processPostgresError(err)
- return result, err
-}
-
-func (c *PostgreSQLConn) Query(query string, args []driver.Value) (driver.Rows, error) {
- return c.QueryContext(context.Background(), query, toNamedValues(args))
-}
-
-func (c *PostgreSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
- rows, err := c.conn.QueryContext(ctx, query, args)
- err = processPostgresError(err)
- return rows, err
-}
-
-func (c *PostgreSQLConn) Close() error {
- return c.conn.Close()
-}
-
-type PostgreSQLTx struct{ driver.Tx }
-
-func (tx *PostgreSQLTx) Commit() error {
- err := tx.Tx.Commit()
- return processPostgresError(err)
-}
-
-func (tx *PostgreSQLTx) Rollback() error {
- err := tx.Tx.Rollback()
- return processPostgresError(err)
-}
-
-type PostgreSQLStmt struct{ stmt }
-
-func (stmt *PostgreSQLStmt) Exec(args []driver.Value) (driver.Result, error) {
- return stmt.ExecContext(context.Background(), toNamedValues(args))
-}
-
-func (stmt *PostgreSQLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
- res, err := stmt.stmt.ExecContext(ctx, args)
- err = processPostgresError(err)
- return res, err
-}
-
-func (stmt *PostgreSQLStmt) Query(args []driver.Value) (driver.Rows, error) {
- return stmt.QueryContext(context.Background(), toNamedValues(args))
-}
-
-func (stmt *PostgreSQLStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
- rows, err := stmt.stmt.QueryContext(ctx, args)
- err = processPostgresError(err)
- return rows, err
-}
-
-// SQLiteDriver is our own wrapper around the
-// sqlite.Driver{} type in order to wrap further
-// SQL driver types with our own functionality,
-// e.g. hooks, retries and err processing.
-type SQLiteDriver struct{}
-
-func (d *SQLiteDriver) Open(name string) (driver.Conn, error) {
- c, err := sqliteDriver.Open(name)
- if err != nil {
- return nil, err
- }
- return &SQLiteConn{conn: c.(conn)}, nil
-}
-
-type SQLiteConn struct{ conn }
-
-func (c *SQLiteConn) Begin() (driver.Tx, error) {
- return c.BeginTx(context.Background(), driver.TxOptions{})
-}
-
-func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
- err = retryOnBusy(ctx, func() error {
- tx, err = c.conn.BeginTx(ctx, opts)
- err = processSQLiteError(err)
- return err
- })
- if err != nil {
- return nil, err
- }
- return &SQLiteTx{Context: ctx, Tx: tx}, nil
-}
-
-func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
- return c.PrepareContext(context.Background(), query)
-}
-
-func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (st driver.Stmt, err error) {
- err = retryOnBusy(ctx, func() error {
- st, err = c.conn.PrepareContext(ctx, query)
- err = processSQLiteError(err)
- return err
- })
- if err != nil {
- return nil, err
- }
- return &SQLiteStmt{st.(stmt)}, nil
-}
-
-func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
- return c.ExecContext(context.Background(), query, toNamedValues(args))
-}
-
-func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) {
- err = retryOnBusy(ctx, func() error {
- result, err = c.conn.ExecContext(ctx, query, args)
- err = processSQLiteError(err)
- return err
- })
- return
-}
-
-func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
- return c.QueryContext(context.Background(), query, toNamedValues(args))
-}
-
-func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
- err = retryOnBusy(ctx, func() error {
- rows, err = c.conn.QueryContext(ctx, query, args)
- err = processSQLiteError(err)
- return err
- })
- return
-}
-
-func (c *SQLiteConn) Close() error {
- // see: https://www.sqlite.org/pragma.html#pragma_optimize
- const onClose = "PRAGMA analysis_limit=1000; PRAGMA optimize;"
- _, _ = c.conn.ExecContext(context.Background(), onClose, nil)
- return c.conn.Close()
-}
-
-type SQLiteTx struct {
- context.Context
- driver.Tx
-}
-
-func (tx *SQLiteTx) Commit() (err error) {
- err = retryOnBusy(tx.Context, func() error {
- err = tx.Tx.Commit()
- err = processSQLiteError(err)
- return err
- })
- return
-}
-
-func (tx *SQLiteTx) Rollback() (err error) {
- err = retryOnBusy(tx.Context, func() error {
- err = tx.Tx.Rollback()
- err = processSQLiteError(err)
- return err
- })
- return
-}
-
-type SQLiteStmt struct{ stmt }
-
-func (stmt *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
- return stmt.ExecContext(context.Background(), toNamedValues(args))
-}
-
-func (stmt *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) {
- err = retryOnBusy(ctx, func() error {
- res, err = stmt.stmt.ExecContext(ctx, args)
- err = processSQLiteError(err)
- return err
- })
- return
-}
-
-func (stmt *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
- return stmt.QueryContext(context.Background(), toNamedValues(args))
-}
-
-func (stmt *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
- err = retryOnBusy(ctx, func() error {
- rows, err = stmt.stmt.QueryContext(ctx, args)
- err = processSQLiteError(err)
- return err
- })
- return
-}
-
-type conn interface {
- driver.Conn
- driver.ConnPrepareContext
- driver.ExecerContext
- driver.QueryerContext
- driver.ConnBeginTx
-}
-
-type stmt interface {
- driver.Stmt
- driver.StmtExecContext
- driver.StmtQueryContext
-}
-
-// retryOnBusy will retry given function on returned 'errBusy'.
-func retryOnBusy(ctx context.Context, fn func() error) error {
- if err := fn(); err != errBusy {
- return err
- }
- return retryOnBusySlow(ctx, fn)
-}
-
-// retryOnBusySlow is the outlined form of retryOnBusy, to allow the fast path (i.e. only
-// 1 attempt) to be inlined, leaving the slow retry loop to be a separate function call.
-func retryOnBusySlow(ctx context.Context, fn func() error) error {
- var backoff time.Duration
-
- for i := 0; ; i++ {
- // backoff according to a multiplier of 2ms * 2^2n,
- // up to a maximum possible backoff time of 5 minutes.
- //
- // this works out as the following:
- // 4ms
- // 16ms
- // 64ms
- // 256ms
- // 1.024s
- // 4.096s
- // 16.384s
- // 1m5.536s
- // 4m22.144s
- backoff = 2 * time.Millisecond * (1 << (2*i + 1))
- if backoff >= 5*time.Minute {
- break
- }
-
- select {
- // Context cancelled.
- case <-ctx.Done():
- return ctx.Err()
-
- // Backoff for some time.
- case <-time.After(backoff):
- }
-
- // Perform func.
- err := fn()
-
- if err != errBusy {
- // May be nil, or may be
- // some other error, either
- // way return here.
- return err
- }
- }
-
- return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff)
-}
-
-// toNamedValues converts older driver.Value types to driver.NamedValue types.
-func toNamedValues(args []driver.Value) []driver.NamedValue {
- if args == nil {
- return nil
- }
- args2 := make([]driver.NamedValue, len(args))
- for i := range args {
- args2[i] = driver.NamedValue{
- Ordinal: i + 1,
- Value: args[i],
- }
- }
- return args2
+ // register our SQL driver implementations.
+ sql.Register("pgx-gts", &postgres.Driver{})
+ sql.Register("sqlite-gts", &sqlite.Driver{})
}