diff options
Diffstat (limited to 'internal/db/bundb/drivers.go')
-rw-r--r-- | internal/db/bundb/drivers.go | 346 |
1 files changed, 5 insertions, 341 deletions
diff --git a/internal/db/bundb/drivers.go b/internal/db/bundb/drivers.go index 1811ad533..f39189c9d 100644 --- a/internal/db/bundb/drivers.go +++ b/internal/db/bundb/drivers.go @@ -18,350 +18,14 @@ package bundb import ( - "context" "database/sql" - "database/sql/driver" - "time" - _ "unsafe" // linkname shenanigans - pgx "github.com/jackc/pgx/v5/stdlib" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "modernc.org/sqlite" + "github.com/superseriousbusiness/gotosocial/internal/db/postgres" + "github.com/superseriousbusiness/gotosocial/internal/db/sqlite" ) -var ( - // global SQL driver instances. - postgresDriver = pgx.GetDefaultDriver() - sqliteDriver = getSQLiteDriver() - - // check the postgres connection - // conforms to our conn{} interface. - // (note SQLite doesn't export their - // conn type, and gets checked in - // tests very regularly anywho). - _ conn = (*pgx.Conn)(nil) -) - -//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver -func getSQLiteDriver() *sqlite.Driver - func init() { - sql.Register("pgx-gts", &PostgreSQLDriver{}) - sql.Register("sqlite-gts", &SQLiteDriver{}) -} - -// PostgreSQLDriver is our own wrapper around the -// pgx/stdlib.Driver{} type in order to wrap further -// SQL driver types with our own err processing. -type PostgreSQLDriver struct{} - -func (d *PostgreSQLDriver) Open(name string) (driver.Conn, error) { - c, err := postgresDriver.Open(name) - if err != nil { - return nil, err - } - return &PostgreSQLConn{conn: c.(conn)}, nil -} - -type PostgreSQLConn struct{ conn } - -func (c *PostgreSQLConn) Begin() (driver.Tx, error) { - return c.BeginTx(context.Background(), driver.TxOptions{}) -} - -func (c *PostgreSQLConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - tx, err := c.conn.BeginTx(ctx, opts) - err = processPostgresError(err) - if err != nil { - return nil, err - } - return &PostgreSQLTx{tx}, nil -} - -func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) { - return c.PrepareContext(context.Background(), query) -} - -func (c *PostgreSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - st, err := c.conn.PrepareContext(ctx, query) - err = processPostgresError(err) - if err != nil { - return nil, err - } - return &PostgreSQLStmt{stmt: st.(stmt)}, nil -} - -func (c *PostgreSQLConn) Exec(query string, args []driver.Value) (driver.Result, error) { - return c.ExecContext(context.Background(), query, toNamedValues(args)) -} - -func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - result, err := c.conn.ExecContext(ctx, query, args) - err = processPostgresError(err) - return result, err -} - -func (c *PostgreSQLConn) Query(query string, args []driver.Value) (driver.Rows, error) { - return c.QueryContext(context.Background(), query, toNamedValues(args)) -} - -func (c *PostgreSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - rows, err := c.conn.QueryContext(ctx, query, args) - err = processPostgresError(err) - return rows, err -} - -func (c *PostgreSQLConn) Close() error { - return c.conn.Close() -} - -type PostgreSQLTx struct{ driver.Tx } - -func (tx *PostgreSQLTx) Commit() error { - err := tx.Tx.Commit() - return processPostgresError(err) -} - -func (tx *PostgreSQLTx) Rollback() error { - err := tx.Tx.Rollback() - return processPostgresError(err) -} - -type PostgreSQLStmt struct{ stmt } - -func (stmt *PostgreSQLStmt) Exec(args []driver.Value) (driver.Result, error) { - return stmt.ExecContext(context.Background(), toNamedValues(args)) -} - -func (stmt *PostgreSQLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - res, err := stmt.stmt.ExecContext(ctx, args) - err = processPostgresError(err) - return res, err -} - -func (stmt *PostgreSQLStmt) Query(args []driver.Value) (driver.Rows, error) { - return stmt.QueryContext(context.Background(), toNamedValues(args)) -} - -func (stmt *PostgreSQLStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - rows, err := stmt.stmt.QueryContext(ctx, args) - err = processPostgresError(err) - return rows, err -} - -// SQLiteDriver is our own wrapper around the -// sqlite.Driver{} type in order to wrap further -// SQL driver types with our own functionality, -// e.g. hooks, retries and err processing. -type SQLiteDriver struct{} - -func (d *SQLiteDriver) Open(name string) (driver.Conn, error) { - c, err := sqliteDriver.Open(name) - if err != nil { - return nil, err - } - return &SQLiteConn{conn: c.(conn)}, nil -} - -type SQLiteConn struct{ conn } - -func (c *SQLiteConn) Begin() (driver.Tx, error) { - return c.BeginTx(context.Background(), driver.TxOptions{}) -} - -func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { - err = retryOnBusy(ctx, func() error { - tx, err = c.conn.BeginTx(ctx, opts) - err = processSQLiteError(err) - return err - }) - if err != nil { - return nil, err - } - return &SQLiteTx{Context: ctx, Tx: tx}, nil -} - -func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { - return c.PrepareContext(context.Background(), query) -} - -func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (st driver.Stmt, err error) { - err = retryOnBusy(ctx, func() error { - st, err = c.conn.PrepareContext(ctx, query) - err = processSQLiteError(err) - return err - }) - if err != nil { - return nil, err - } - return &SQLiteStmt{st.(stmt)}, nil -} - -func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { - return c.ExecContext(context.Background(), query, toNamedValues(args)) -} - -func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { - err = retryOnBusy(ctx, func() error { - result, err = c.conn.ExecContext(ctx, query, args) - err = processSQLiteError(err) - return err - }) - return -} - -func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { - return c.QueryContext(context.Background(), query, toNamedValues(args)) -} - -func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { - err = retryOnBusy(ctx, func() error { - rows, err = c.conn.QueryContext(ctx, query, args) - err = processSQLiteError(err) - return err - }) - return -} - -func (c *SQLiteConn) Close() error { - // see: https://www.sqlite.org/pragma.html#pragma_optimize - const onClose = "PRAGMA analysis_limit=1000; PRAGMA optimize;" - _, _ = c.conn.ExecContext(context.Background(), onClose, nil) - return c.conn.Close() -} - -type SQLiteTx struct { - context.Context - driver.Tx -} - -func (tx *SQLiteTx) Commit() (err error) { - err = retryOnBusy(tx.Context, func() error { - err = tx.Tx.Commit() - err = processSQLiteError(err) - return err - }) - return -} - -func (tx *SQLiteTx) Rollback() (err error) { - err = retryOnBusy(tx.Context, func() error { - err = tx.Tx.Rollback() - err = processSQLiteError(err) - return err - }) - return -} - -type SQLiteStmt struct{ stmt } - -func (stmt *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { - return stmt.ExecContext(context.Background(), toNamedValues(args)) -} - -func (stmt *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { - err = retryOnBusy(ctx, func() error { - res, err = stmt.stmt.ExecContext(ctx, args) - err = processSQLiteError(err) - return err - }) - return -} - -func (stmt *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { - return stmt.QueryContext(context.Background(), toNamedValues(args)) -} - -func (stmt *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { - err = retryOnBusy(ctx, func() error { - rows, err = stmt.stmt.QueryContext(ctx, args) - err = processSQLiteError(err) - return err - }) - return -} - -type conn interface { - driver.Conn - driver.ConnPrepareContext - driver.ExecerContext - driver.QueryerContext - driver.ConnBeginTx -} - -type stmt interface { - driver.Stmt - driver.StmtExecContext - driver.StmtQueryContext -} - -// retryOnBusy will retry given function on returned 'errBusy'. -func retryOnBusy(ctx context.Context, fn func() error) error { - if err := fn(); err != errBusy { - return err - } - return retryOnBusySlow(ctx, fn) -} - -// retryOnBusySlow is the outlined form of retryOnBusy, to allow the fast path (i.e. only -// 1 attempt) to be inlined, leaving the slow retry loop to be a separate function call. -func retryOnBusySlow(ctx context.Context, fn func() error) error { - var backoff time.Duration - - for i := 0; ; i++ { - // backoff according to a multiplier of 2ms * 2^2n, - // up to a maximum possible backoff time of 5 minutes. - // - // this works out as the following: - // 4ms - // 16ms - // 64ms - // 256ms - // 1.024s - // 4.096s - // 16.384s - // 1m5.536s - // 4m22.144s - backoff = 2 * time.Millisecond * (1 << (2*i + 1)) - if backoff >= 5*time.Minute { - break - } - - select { - // Context cancelled. - case <-ctx.Done(): - return ctx.Err() - - // Backoff for some time. - case <-time.After(backoff): - } - - // Perform func. - err := fn() - - if err != errBusy { - // May be nil, or may be - // some other error, either - // way return here. - return err - } - } - - return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff) -} - -// toNamedValues converts older driver.Value types to driver.NamedValue types. -func toNamedValues(args []driver.Value) []driver.NamedValue { - if args == nil { - return nil - } - args2 := make([]driver.NamedValue, len(args)) - for i := range args { - args2[i] = driver.NamedValue{ - Ordinal: i + 1, - Value: args[i], - } - } - return args2 + // register our SQL driver implementations. + sql.Register("pgx-gts", &postgres.Driver{}) + sql.Register("sqlite-gts", &sqlite.Driver{}) } |