summaryrefslogtreecommitdiff
path: root/internal/db/bundb/drivers.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/drivers.go')
-rw-r--r--internal/db/bundb/drivers.go149
1 files changed, 122 insertions, 27 deletions
diff --git a/internal/db/bundb/drivers.go b/internal/db/bundb/drivers.go
index 14d84e6fa..a70b598d2 100644
--- a/internal/db/bundb/drivers.go
+++ b/internal/db/bundb/drivers.go
@@ -36,14 +36,14 @@ var (
sqliteDriver = getSQLiteDriver()
)
+//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver
+func getSQLiteDriver() *sqlite.Driver
+
func init() {
sql.Register("pgx-gts", &PostgreSQLDriver{})
sql.Register("sqlite-gts", &SQLiteDriver{})
}
-//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver
-func getSQLiteDriver() *sqlite.Driver
-
// PostgreSQLDriver is our own wrapper around the
// pgx/stdlib.Driver{} type in order to wrap further
// SQL driver types with our own err processing.
@@ -66,7 +66,10 @@ func (c *PostgreSQLConn) Begin() (driver.Tx, error) {
func (c *PostgreSQLConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
tx, err := c.conn.BeginTx(ctx, opts)
err = processPostgresError(err)
- return tx, err
+ if err != nil {
+ return nil, err
+ }
+ return &PostgreSQLTx{tx}, nil
}
func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) {
@@ -74,13 +77,16 @@ func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) {
}
func (c *PostgreSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
- stmt, err := c.conn.PrepareContext(ctx, query)
+ st, err := c.conn.PrepareContext(ctx, query)
err = processPostgresError(err)
- return stmt, err
+ if err != nil {
+ return nil, err
+ }
+ return &PostgreSQLStmt{stmt: st.(stmt)}, nil
}
-func (c *PostgreSQLConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) {
- return c.ExecContext(context.Background(), query, args)
+func (c *PostgreSQLConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ return c.ExecContext(context.Background(), query, toNamedValues(args))
}
func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
@@ -89,8 +95,8 @@ func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []d
return result, err
}
-func (c *PostgreSQLConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) {
- return c.QueryContext(context.Background(), query, args)
+func (c *PostgreSQLConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+ return c.QueryContext(context.Background(), query, toNamedValues(args))
}
func (c *PostgreSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
@@ -115,6 +121,28 @@ func (tx *PostgreSQLTx) Rollback() error {
return processPostgresError(err)
}
+type PostgreSQLStmt struct{ stmt }
+
+func (stmt *PostgreSQLStmt) Exec(args []driver.Value) (driver.Result, error) {
+ return stmt.ExecContext(context.Background(), toNamedValues(args))
+}
+
+func (stmt *PostgreSQLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+ res, err := stmt.stmt.ExecContext(ctx, args)
+ err = processSQLiteError(err)
+ return res, err
+}
+
+func (stmt *PostgreSQLStmt) Query(args []driver.Value) (driver.Rows, error) {
+ return stmt.QueryContext(context.Background(), toNamedValues(args))
+}
+
+func (stmt *PostgreSQLStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+ rows, err := stmt.stmt.QueryContext(ctx, args)
+ err = processSQLiteError(err)
+ return rows, err
+}
+
// SQLiteDriver is our own wrapper around the
// sqlite.Driver{} type in order to wrap further
// SQL driver types with our own functionality,
@@ -141,6 +169,9 @@ func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx dri
err = processSQLiteError(err)
return err
})
+ if err != nil {
+ return nil, err
+ }
return &SQLiteTx{Context: ctx, Tx: tx}, nil
}
@@ -148,17 +179,20 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
-func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
+func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (st driver.Stmt, err error) {
err = retryOnBusy(ctx, func() error {
- stmt, err = c.conn.PrepareContext(ctx, query)
+ st, err = c.conn.PrepareContext(ctx, query)
err = processSQLiteError(err)
return err
})
- return
+ if err != nil {
+ return nil, err
+ }
+ return &SQLiteStmt{st.(stmt)}, nil
}
-func (c *SQLiteConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) {
- return c.ExecContext(context.Background(), query, args)
+func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ return c.ExecContext(context.Background(), query, toNamedValues(args))
}
func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) {
@@ -170,8 +204,8 @@ func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []drive
return
}
-func (c *SQLiteConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) {
- return c.QueryContext(context.Background(), query, args)
+func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+ return c.QueryContext(context.Background(), query, toNamedValues(args))
}
func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
@@ -213,29 +247,64 @@ func (tx *SQLiteTx) Rollback() (err error) {
return
}
+type SQLiteStmt struct{ stmt }
+
+func (stmt *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
+ return stmt.ExecContext(context.Background(), toNamedValues(args))
+}
+
+func (stmt *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) {
+ err = retryOnBusy(ctx, func() error {
+ res, err = stmt.stmt.ExecContext(ctx, args)
+ err = processSQLiteError(err)
+ return err
+ })
+ return
+}
+
+func (stmt *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
+ return stmt.QueryContext(context.Background(), toNamedValues(args))
+}
+
+func (stmt *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
+ err = retryOnBusy(ctx, func() error {
+ rows, err = stmt.stmt.QueryContext(ctx, args)
+ err = processSQLiteError(err)
+ return err
+ })
+ return
+}
+
type conn interface {
driver.Conn
driver.ConnPrepareContext
+ driver.Execer //nolint:staticcheck
driver.ExecerContext
+ driver.Queryer //nolint:staticcheck
driver.QueryerContext
driver.ConnBeginTx
}
+type stmt interface {
+ driver.Stmt
+ driver.StmtExecContext
+ driver.StmtQueryContext
+}
+
// retryOnBusy will retry given function on returned 'errBusy'.
func retryOnBusy(ctx context.Context, fn func() error) error {
+ if err := fn(); err != errBusy {
+ return err
+ }
+ return retryOnBusySlow(ctx, fn)
+}
+
+// retryOnBusySlow is the outlined form of retryOnBusy, to allow the fast path (i.e. only
+// 1 attempt) to be inlined, leaving the slow retry loop to be a separate function call.
+func retryOnBusySlow(ctx context.Context, fn func() error) error {
var backoff time.Duration
for i := 0; ; i++ {
- // Perform func.
- err := fn()
-
- if err != errBusy {
- // May be nil, or may be
- // some other error, either
- // way return here.
- return err
- }
-
// backoff according to a multiplier of 2ms * 2^2n,
// up to a maximum possible backoff time of 5 minutes.
//
@@ -257,11 +326,37 @@ func retryOnBusy(ctx context.Context, fn func() error) error {
select {
// Context cancelled.
case <-ctx.Done():
+ return ctx.Err()
// Backoff for some time.
case <-time.After(backoff):
}
+
+ // Perform func.
+ err := fn()
+
+ if err != errBusy {
+ // May be nil, or may be
+ // some other error, either
+ // way return here.
+ return err
+ }
}
return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff)
}
+
+// toNamedValues converts older driver.Value types to driver.NamedValue types.
+func toNamedValues(args []driver.Value) []driver.NamedValue {
+ if args == nil {
+ return nil
+ }
+ args2 := make([]driver.NamedValue, len(args))
+ for i := range args {
+ args2[i] = driver.NamedValue{
+ Ordinal: i + 1,
+ Value: args[i],
+ }
+ }
+ return args2
+}