summaryrefslogtreecommitdiff
path: root/vendor/github.com/ncruces/go-sqlite3/stmt.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/stmt.go')
-rw-r--r--vendor/github.com/ncruces/go-sqlite3/stmt.go134
1 files changed, 104 insertions, 30 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/stmt.go b/vendor/github.com/ncruces/go-sqlite3/stmt.go
index 4e17d1039..1ea726ea1 100644
--- a/vendor/github.com/ncruces/go-sqlite3/stmt.go
+++ b/vendor/github.com/ncruces/go-sqlite3/stmt.go
@@ -106,7 +106,14 @@ func (s *Stmt) Busy() bool {
//
// https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
- s.c.checkInterrupt(s.c.handle)
+ if s.c.interrupt.Err() != nil {
+ s.err = INTERRUPT
+ return false
+ }
+ return s.step()
+}
+
+func (s *Stmt) step() bool {
rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle)))
switch rc {
case _ROW:
@@ -131,7 +138,11 @@ func (s *Stmt) Err() error {
// Exec is a convenience function that repeatedly calls [Stmt.Step] until it returns false,
// then calls [Stmt.Reset] to reset the statement and get any error that occurred.
func (s *Stmt) Exec() error {
- for s.Step() {
+ if s.c.interrupt.Err() != nil {
+ return INTERRUPT
+ }
+ // TODO: implement this in C.
+ for s.step() {
}
return s.Reset()
}
@@ -254,13 +265,15 @@ func (s *Stmt) BindText(param int, value string) error {
// BindRawText binds a []byte to the prepared statement as text.
// The leftmost SQL parameter has an index of 1.
-// Binding a nil slice is the same as calling [Stmt.BindNull].
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindRawText(param int, value []byte) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
+ if len(value) == 0 {
+ return s.BindText(param, "")
+ }
ptr := s.c.newBytes(value)
rc := res_t(s.c.call("sqlite3_bind_text_go",
stk_t(s.handle), stk_t(param),
@@ -270,13 +283,15 @@ func (s *Stmt) BindRawText(param int, value []byte) error {
// BindBlob binds a []byte to the prepared statement.
// The leftmost SQL parameter has an index of 1.
-// Binding a nil slice is the same as calling [Stmt.BindNull].
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBlob(param int, value []byte) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
+ if len(value) == 0 {
+ return s.BindZeroBlob(param, 0)
+ }
ptr := s.c.newBytes(value)
rc := res_t(s.c.call("sqlite3_bind_blob_go",
stk_t(s.handle), stk_t(param),
@@ -560,7 +575,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
func (s *Stmt) ColumnRawText(col int) []byte {
ptr := ptr_t(s.c.call("sqlite3_column_text",
stk_t(s.handle), stk_t(col)))
- return s.columnRawBytes(col, ptr)
+ return s.columnRawBytes(col, ptr, 1)
}
// ColumnRawBlob returns the value of the result column as a []byte.
@@ -572,10 +587,10 @@ func (s *Stmt) ColumnRawText(col int) []byte {
func (s *Stmt) ColumnRawBlob(col int) []byte {
ptr := ptr_t(s.c.call("sqlite3_column_blob",
stk_t(s.handle), stk_t(col)))
- return s.columnRawBytes(col, ptr)
+ return s.columnRawBytes(col, ptr, 0)
}
-func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
+func (s *Stmt) columnRawBytes(col int, ptr ptr_t, nul int32) []byte {
if ptr == 0 {
rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle)))
if rc != _ROW && rc != _DONE {
@@ -586,7 +601,7 @@ func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
n := int32(s.c.call("sqlite3_column_bytes",
stk_t(s.handle), stk_t(col)))
- return util.View(s.c.mod, ptr, int64(n))
+ return util.View(s.c.mod, ptr, int64(n+nul))[:n]
}
// ColumnJSON parses the JSON-encoded value of the result column
@@ -633,21 +648,63 @@ func (s *Stmt) ColumnValue(col int) Value {
// [INTEGER] columns will be retrieved as int64 values,
// [FLOAT] as float64, [NULL] as nil,
// [TEXT] as string, and [BLOB] as []byte.
-// Any []byte are owned by SQLite and may be invalidated by
-// subsequent calls to [Stmt] methods.
func (s *Stmt) Columns(dest ...any) error {
- defer s.c.arena.mark()()
- count := int64(len(dest))
- typePtr := s.c.arena.new(count)
- dataPtr := s.c.arena.new(count * 8)
-
- rc := res_t(s.c.call("sqlite3_columns_go",
- stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
- if err := s.c.error(rc); err != nil {
+ types, ptr, err := s.columns(int64(len(dest)))
+ if err != nil {
return err
}
- types := util.View(s.c.mod, typePtr, count)
+ // Avoid bounds checks on types below.
+ if len(types) != len(dest) {
+ panic(util.AssertErr())
+ }
+
+ for i := range dest {
+ switch types[i] {
+ case byte(INTEGER):
+ dest[i] = util.Read64[int64](s.c.mod, ptr)
+ case byte(FLOAT):
+ dest[i] = util.ReadFloat64(s.c.mod, ptr)
+ case byte(NULL):
+ dest[i] = nil
+ case byte(TEXT):
+ len := util.Read32[int32](s.c.mod, ptr+4)
+ if len != 0 {
+ ptr := util.Read32[ptr_t](s.c.mod, ptr)
+ buf := util.View(s.c.mod, ptr, int64(len))
+ dest[i] = string(buf)
+ } else {
+ dest[i] = ""
+ }
+ case byte(BLOB):
+ len := util.Read32[int32](s.c.mod, ptr+4)
+ if len != 0 {
+ ptr := util.Read32[ptr_t](s.c.mod, ptr)
+ buf := util.View(s.c.mod, ptr, int64(len))
+ tmp, _ := dest[i].([]byte)
+ dest[i] = append(tmp[:0], buf...)
+ } else {
+ dest[i], _ = dest[i].([]byte)
+ }
+ }
+ ptr += 8
+ }
+ return nil
+}
+
+// ColumnsRaw populates result columns into the provided slice.
+// The slice must have [Stmt.ColumnCount] length.
+//
+// [INTEGER] columns will be retrieved as int64 values,
+// [FLOAT] as float64, [NULL] as nil,
+// [TEXT] and [BLOB] as []byte.
+// Any []byte are owned by SQLite and may be invalidated by
+// subsequent calls to [Stmt] methods.
+func (s *Stmt) ColumnsRaw(dest ...any) error {
+ types, ptr, err := s.columns(int64(len(dest)))
+ if err != nil {
+ return err
+ }
// Avoid bounds checks on types below.
if len(types) != len(dest) {
@@ -657,26 +714,43 @@ func (s *Stmt) Columns(dest ...any) error {
for i := range dest {
switch types[i] {
case byte(INTEGER):
- dest[i] = util.Read64[int64](s.c.mod, dataPtr)
+ dest[i] = util.Read64[int64](s.c.mod, ptr)
case byte(FLOAT):
- dest[i] = util.ReadFloat64(s.c.mod, dataPtr)
+ dest[i] = util.ReadFloat64(s.c.mod, ptr)
case byte(NULL):
dest[i] = nil
default:
- ptr := util.Read32[ptr_t](s.c.mod, dataPtr+0)
- if ptr == 0 {
+ len := util.Read32[int32](s.c.mod, ptr+4)
+ if len == 0 && types[i] == byte(BLOB) {
dest[i] = []byte{}
- continue
- }
- len := util.Read32[int32](s.c.mod, dataPtr+4)
- buf := util.View(s.c.mod, ptr, int64(len))
- if types[i] == byte(TEXT) {
- dest[i] = string(buf)
} else {
+ cap := len
+ if types[i] == byte(TEXT) {
+ cap++
+ }
+ ptr := util.Read32[ptr_t](s.c.mod, ptr)
+ buf := util.View(s.c.mod, ptr, int64(cap))[:len]
dest[i] = buf
}
}
- dataPtr += 8
+ ptr += 8
}
return nil
}
+
+func (s *Stmt) columns(count int64) ([]byte, ptr_t, error) {
+ defer s.c.arena.mark()()
+ typePtr := s.c.arena.new(count)
+ dataPtr := s.c.arena.new(count * 8)
+
+ rc := res_t(s.c.call("sqlite3_columns_go",
+ stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
+ if rc == res_t(MISUSE) {
+ return nil, 0, MISUSE
+ }
+ if err := s.c.error(rc); err != nil {
+ return nil, 0, err
+ }
+
+ return util.View(s.c.mod, typePtr, count), dataPtr, nil
+}