diff options
Diffstat (limited to 'internal/db/bundb/drivers.go')
-rw-r--r-- | internal/db/bundb/drivers.go | 149 |
1 files changed, 122 insertions, 27 deletions
diff --git a/internal/db/bundb/drivers.go b/internal/db/bundb/drivers.go index 14d84e6fa..a70b598d2 100644 --- a/internal/db/bundb/drivers.go +++ b/internal/db/bundb/drivers.go @@ -36,14 +36,14 @@ var ( sqliteDriver = getSQLiteDriver() ) +//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver +func getSQLiteDriver() *sqlite.Driver + func init() { sql.Register("pgx-gts", &PostgreSQLDriver{}) sql.Register("sqlite-gts", &SQLiteDriver{}) } -//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver -func getSQLiteDriver() *sqlite.Driver - // PostgreSQLDriver is our own wrapper around the // pgx/stdlib.Driver{} type in order to wrap further // SQL driver types with our own err processing. @@ -66,7 +66,10 @@ func (c *PostgreSQLConn) Begin() (driver.Tx, error) { func (c *PostgreSQLConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { tx, err := c.conn.BeginTx(ctx, opts) err = processPostgresError(err) - return tx, err + if err != nil { + return nil, err + } + return &PostgreSQLTx{tx}, nil } func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) { @@ -74,13 +77,16 @@ func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) { } func (c *PostgreSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - stmt, err := c.conn.PrepareContext(ctx, query) + st, err := c.conn.PrepareContext(ctx, query) err = processPostgresError(err) - return stmt, err + if err != nil { + return nil, err + } + return &PostgreSQLStmt{stmt: st.(stmt)}, nil } -func (c *PostgreSQLConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) { - return c.ExecContext(context.Background(), query, args) +func (c *PostgreSQLConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return c.ExecContext(context.Background(), query, toNamedValues(args)) } func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { @@ -89,8 +95,8 @@ func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []d return result, err } -func (c *PostgreSQLConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) { - return c.QueryContext(context.Background(), query, args) +func (c *PostgreSQLConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return c.QueryContext(context.Background(), query, toNamedValues(args)) } func (c *PostgreSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { @@ -115,6 +121,28 @@ func (tx *PostgreSQLTx) Rollback() error { return processPostgresError(err) } +type PostgreSQLStmt struct{ stmt } + +func (stmt *PostgreSQLStmt) Exec(args []driver.Value) (driver.Result, error) { + return stmt.ExecContext(context.Background(), toNamedValues(args)) +} + +func (stmt *PostgreSQLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + res, err := stmt.stmt.ExecContext(ctx, args) + err = processSQLiteError(err) + return res, err +} + +func (stmt *PostgreSQLStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.QueryContext(context.Background(), toNamedValues(args)) +} + +func (stmt *PostgreSQLStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + rows, err := stmt.stmt.QueryContext(ctx, args) + err = processSQLiteError(err) + return rows, err +} + // SQLiteDriver is our own wrapper around the // sqlite.Driver{} type in order to wrap further // SQL driver types with our own functionality, @@ -141,6 +169,9 @@ func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx dri err = processSQLiteError(err) return err }) + if err != nil { + return nil, err + } return &SQLiteTx{Context: ctx, Tx: tx}, nil } @@ -148,17 +179,20 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { return c.PrepareContext(context.Background(), query) } -func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) { +func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (st driver.Stmt, err error) { err = retryOnBusy(ctx, func() error { - stmt, err = c.conn.PrepareContext(ctx, query) + st, err = c.conn.PrepareContext(ctx, query) err = processSQLiteError(err) return err }) - return + if err != nil { + return nil, err + } + return &SQLiteStmt{st.(stmt)}, nil } -func (c *SQLiteConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) { - return c.ExecContext(context.Background(), query, args) +func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return c.ExecContext(context.Background(), query, toNamedValues(args)) } func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { @@ -170,8 +204,8 @@ func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []drive return } -func (c *SQLiteConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) { - return c.QueryContext(context.Background(), query, args) +func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return c.QueryContext(context.Background(), query, toNamedValues(args)) } func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { @@ -213,29 +247,64 @@ func (tx *SQLiteTx) Rollback() (err error) { return } +type SQLiteStmt struct{ stmt } + +func (stmt *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { + return stmt.ExecContext(context.Background(), toNamedValues(args)) +} + +func (stmt *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { + err = retryOnBusy(ctx, func() error { + res, err = stmt.stmt.ExecContext(ctx, args) + err = processSQLiteError(err) + return err + }) + return +} + +func (stmt *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.QueryContext(context.Background(), toNamedValues(args)) +} + +func (stmt *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { + err = retryOnBusy(ctx, func() error { + rows, err = stmt.stmt.QueryContext(ctx, args) + err = processSQLiteError(err) + return err + }) + return +} + type conn interface { driver.Conn driver.ConnPrepareContext + driver.Execer //nolint:staticcheck driver.ExecerContext + driver.Queryer //nolint:staticcheck driver.QueryerContext driver.ConnBeginTx } +type stmt interface { + driver.Stmt + driver.StmtExecContext + driver.StmtQueryContext +} + // retryOnBusy will retry given function on returned 'errBusy'. func retryOnBusy(ctx context.Context, fn func() error) error { + if err := fn(); err != errBusy { + return err + } + return retryOnBusySlow(ctx, fn) +} + +// retryOnBusySlow is the outlined form of retryOnBusy, to allow the fast path (i.e. only +// 1 attempt) to be inlined, leaving the slow retry loop to be a separate function call. +func retryOnBusySlow(ctx context.Context, fn func() error) error { var backoff time.Duration for i := 0; ; i++ { - // Perform func. - err := fn() - - if err != errBusy { - // May be nil, or may be - // some other error, either - // way return here. - return err - } - // backoff according to a multiplier of 2ms * 2^2n, // up to a maximum possible backoff time of 5 minutes. // @@ -257,11 +326,37 @@ func retryOnBusy(ctx context.Context, fn func() error) error { select { // Context cancelled. case <-ctx.Done(): + return ctx.Err() // Backoff for some time. case <-time.After(backoff): } + + // Perform func. + err := fn() + + if err != errBusy { + // May be nil, or may be + // some other error, either + // way return here. + return err + } } return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff) } + +// toNamedValues converts older driver.Value types to driver.NamedValue types. +func toNamedValues(args []driver.Value) []driver.NamedValue { + if args == nil { + return nil + } + args2 := make([]driver.NamedValue, len(args)) + for i := range args { + args2[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: args[i], + } + } + return args2 +} |