summaryrefslogtreecommitdiff
path: root/vendor/github.com/ncruces/go-sqlite3/driver/driver.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/driver/driver.go')
-rw-r--r--vendor/github.com/ncruces/go-sqlite3/driver/driver.go107
1 files changed, 59 insertions, 48 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go
index 27496f6cb..5d2847369 100644
--- a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go
+++ b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go
@@ -607,14 +607,24 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
type scantype byte
const (
- _ANY scantype = iota
- _INT scantype = scantype(sqlite3.INTEGER)
- _REAL scantype = scantype(sqlite3.FLOAT)
- _TEXT scantype = scantype(sqlite3.TEXT)
- _BLOB scantype = scantype(sqlite3.BLOB)
- _NULL scantype = scantype(sqlite3.NULL)
- _BOOL scantype = iota
+ _ANY scantype = iota
+ _INT
+ _REAL
+ _TEXT
+ _BLOB
+ _NULL
+ _BOOL
_TIME
+ _NOT_NULL
+)
+
+var (
+ _ [0]struct{} = [scantype(sqlite3.INTEGER) - _INT]struct{}{}
+ _ [0]struct{} = [scantype(sqlite3.FLOAT) - _REAL]struct{}{}
+ _ [0]struct{} = [scantype(sqlite3.TEXT) - _TEXT]struct{}{}
+ _ [0]struct{} = [scantype(sqlite3.BLOB) - _BLOB]struct{}{}
+ _ [0]struct{} = [scantype(sqlite3.NULL) - _NULL]struct{}{}
+ _ [0]struct{} = [_NOT_NULL & (_NOT_NULL - 1)]struct{}{}
)
func scanFromDecl(decl string) scantype {
@@ -644,8 +654,8 @@ type rows struct {
*stmt
names []string
types []string
- nulls []bool
scans []scantype
+ dest []driver.Value
}
var (
@@ -675,34 +685,36 @@ func (r *rows) Columns() []string {
func (r *rows) scanType(index int) scantype {
if r.scans == nil {
- count := r.Stmt.ColumnCount()
+ count := len(r.names)
scans := make([]scantype, count)
for i := range scans {
scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i)))
}
r.scans = scans
}
- return r.scans[index]
+ return r.scans[index] &^ _NOT_NULL
}
func (r *rows) loadColumnMetadata() {
- if r.nulls == nil {
+ if r.types == nil {
c := r.Stmt.Conn()
- count := r.Stmt.ColumnCount()
- nulls := make([]bool, count)
+ count := len(r.names)
types := make([]string, count)
scans := make([]scantype, count)
- for i := range nulls {
+ for i := range types {
+ var notnull bool
if col := r.Stmt.ColumnOriginName(i); col != "" {
- types[i], _, nulls[i], _, _, _ = c.TableColumnMetadata(
+ types[i], _, notnull, _, _, _ = c.TableColumnMetadata(
r.Stmt.ColumnDatabaseName(i),
r.Stmt.ColumnTableName(i),
col)
types[i] = strings.ToUpper(types[i])
scans[i] = scanFromDecl(types[i])
+ if notnull {
+ scans[i] |= _NOT_NULL
+ }
}
}
- r.nulls = nulls
r.types = types
r.scans = scans
}
@@ -721,15 +733,13 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
r.loadColumnMetadata()
- if r.nulls[index] {
- return false, true
- }
- return true, false
+ nullable = r.scans[index]&^_NOT_NULL == 0
+ return nullable, !nullable
}
func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
r.loadColumnMetadata()
- scan := r.scans[index]
+ scan := r.scans[index] &^ _NOT_NULL
if r.Stmt.Busy() {
// SQLite is dynamically typed and we now have a row.
@@ -772,6 +782,7 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
}
func (r *rows) Next(dest []driver.Value) error {
+ r.dest = nil
c := r.Stmt.Conn()
if old := c.SetInterrupt(r.ctx); old != r.ctx {
defer c.SetInterrupt(old)
@@ -790,18 +801,7 @@ func (r *rows) Next(dest []driver.Value) error {
}
for i := range dest {
scan := r.scanType(i)
- switch v := dest[i].(type) {
- case int64:
- if scan == _BOOL {
- switch v {
- case 1:
- dest[i] = true
- case 0:
- dest[i] = false
- }
- continue
- }
- case []byte:
+ if v, ok := dest[i].([]byte); ok {
if len(v) == cap(v) { // a BLOB
continue
}
@@ -816,38 +816,49 @@ func (r *rows) Next(dest []driver.Value) error {
}
}
dest[i] = string(v)
- case float64:
- break
- default:
- continue
}
- if scan == _TIME {
+ switch scan {
+ case _TIME:
t, err := r.tmRead.Decode(dest[i])
if err == nil {
dest[i] = t
- continue
+ }
+ case _BOOL:
+ switch dest[i] {
+ case int64(0):
+ dest[i] = false
+ case int64(1):
+ dest[i] = true
}
}
}
+ r.dest = dest
return nil
}
-func (r *rows) ScanColumn(dest any, index int) error {
+func (r *rows) ScanColumn(dest any, index int) (err error) {
// notest // Go 1.26
- var ptr *time.Time
+ var tm *time.Time
+ var ok *bool
switch d := dest.(type) {
case *time.Time:
- ptr = d
+ tm = d
case *sql.NullTime:
- ptr = &d.Time
+ tm = &d.Time
+ ok = &d.Valid
case *sql.Null[time.Time]:
- ptr = &d.V
+ tm = &d.V
+ ok = &d.Valid
default:
return driver.ErrSkip
}
- if t := r.Stmt.ColumnTime(index, r.tmRead); !t.IsZero() {
- *ptr = t
- return nil
+ value := r.dest[index]
+ *tm, err = r.tmRead.Decode(value)
+ if ok != nil {
+ *ok = err == nil
+ if value == nil {
+ return nil
+ }
}
- return driver.ErrSkip
+ return err
}