diff options
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/driver/driver.go')
| -rw-r--r-- | vendor/github.com/ncruces/go-sqlite3/driver/driver.go | 43 |
1 files changed, 27 insertions, 16 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go index 9250cf39d..f473220c0 100644 --- a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go +++ b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go @@ -241,8 +241,9 @@ func (n *connector) Connect(ctx context.Context) (ret driver.Conn, err error) { } }() - old := c.Conn.SetInterrupt(ctx) - defer c.Conn.SetInterrupt(old) + if old := c.Conn.SetInterrupt(ctx); old != ctx { + defer c.Conn.SetInterrupt(old) + } if !n.pragmas { err = c.Conn.BusyTimeout(time.Minute) @@ -362,8 +363,9 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e c.txReset = `; PRAGMA query_only=` + string(c.readOnly) } - old := c.Conn.SetInterrupt(ctx) - defer c.Conn.SetInterrupt(old) + if old := c.Conn.SetInterrupt(ctx); old != ctx { + defer c.Conn.SetInterrupt(old) + } err := c.Conn.Exec(txBegin) if err != nil { @@ -382,8 +384,10 @@ func (c *conn) Commit() error { func (c *conn) Rollback() error { // ROLLBACK even if interrupted. - old := c.Conn.SetInterrupt(context.Background()) - defer c.Conn.SetInterrupt(old) + ctx := context.Background() + if old := c.Conn.SetInterrupt(ctx); old != ctx { + defer c.Conn.SetInterrupt(old) + } return c.Conn.Exec(`ROLLBACK` + c.txReset) } @@ -393,8 +397,9 @@ func (c *conn) Prepare(query string) (driver.Stmt, error) { } func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - old := c.Conn.SetInterrupt(ctx) - defer c.Conn.SetInterrupt(old) + if old := c.Conn.SetInterrupt(ctx); old != ctx { + defer c.Conn.SetInterrupt(old) + } s, tail, err := c.Conn.Prepare(query) if err != nil { @@ -419,8 +424,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name return resultRowsAffected(0), nil } - old := c.Conn.SetInterrupt(ctx) - defer c.Conn.SetInterrupt(old) + if old := c.Conn.SetInterrupt(ctx); old != ctx { + defer c.Conn.SetInterrupt(old) + } err := c.Conn.Exec(query) if err != nil { @@ -483,8 +489,10 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive return nil, err } - old := s.Stmt.Conn().SetInterrupt(ctx) - defer s.Stmt.Conn().SetInterrupt(old) + c := s.Stmt.Conn() + if old := c.SetInterrupt(ctx); old != ctx { + defer c.SetInterrupt(old) + } err = errors.Join( s.Stmt.Exec(), @@ -493,7 +501,7 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive return nil, err } - return newResult(s.Stmt.Conn()), nil + return newResult(c), nil } func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { @@ -678,13 +686,14 @@ func (r *rows) scanType(index int) scantype { func (r *rows) loadColumnMetadata() { if r.nulls == nil { + c := r.Stmt.Conn() count := r.Stmt.ColumnCount() nulls := make([]bool, count) types := make([]string, count) scans := make([]scantype, count) for i := range nulls { if col := r.Stmt.ColumnOriginName(i); col != "" { - types[i], _, nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata( + types[i], _, nulls[i], _, _, _ = c.TableColumnMetadata( r.Stmt.ColumnDatabaseName(i), r.Stmt.ColumnTableName(i), col) @@ -762,8 +771,10 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) { } func (r *rows) Next(dest []driver.Value) error { - old := r.Stmt.Conn().SetInterrupt(r.ctx) - defer r.Stmt.Conn().SetInterrupt(old) + c := r.Stmt.Conn() + if old := c.SetInterrupt(r.ctx); old != r.ctx { + defer c.SetInterrupt(old) + } if !r.Stmt.Step() { if err := r.Stmt.Err(); err != nil { |
