diff options
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/driver/driver.go')
| -rw-r--r-- | vendor/github.com/ncruces/go-sqlite3/driver/driver.go | 107 |
1 files changed, 59 insertions, 48 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go index 27496f6cb..5d2847369 100644 --- a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go +++ b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go @@ -607,14 +607,24 @@ func (r resultRowsAffected) RowsAffected() (int64, error) { type scantype byte const ( - _ANY scantype = iota - _INT scantype = scantype(sqlite3.INTEGER) - _REAL scantype = scantype(sqlite3.FLOAT) - _TEXT scantype = scantype(sqlite3.TEXT) - _BLOB scantype = scantype(sqlite3.BLOB) - _NULL scantype = scantype(sqlite3.NULL) - _BOOL scantype = iota + _ANY scantype = iota + _INT + _REAL + _TEXT + _BLOB + _NULL + _BOOL _TIME + _NOT_NULL +) + +var ( + _ [0]struct{} = [scantype(sqlite3.INTEGER) - _INT]struct{}{} + _ [0]struct{} = [scantype(sqlite3.FLOAT) - _REAL]struct{}{} + _ [0]struct{} = [scantype(sqlite3.TEXT) - _TEXT]struct{}{} + _ [0]struct{} = [scantype(sqlite3.BLOB) - _BLOB]struct{}{} + _ [0]struct{} = [scantype(sqlite3.NULL) - _NULL]struct{}{} + _ [0]struct{} = [_NOT_NULL & (_NOT_NULL - 1)]struct{}{} ) func scanFromDecl(decl string) scantype { @@ -644,8 +654,8 @@ type rows struct { *stmt names []string types []string - nulls []bool scans []scantype + dest []driver.Value } var ( @@ -675,34 +685,36 @@ func (r *rows) Columns() []string { func (r *rows) scanType(index int) scantype { if r.scans == nil { - count := r.Stmt.ColumnCount() + count := len(r.names) scans := make([]scantype, count) for i := range scans { scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i))) } r.scans = scans } - return r.scans[index] + return r.scans[index] &^ _NOT_NULL } func (r *rows) loadColumnMetadata() { - if r.nulls == nil { + if r.types == nil { c := r.Stmt.Conn() - count := r.Stmt.ColumnCount() - nulls := make([]bool, count) + count := len(r.names) types := make([]string, count) scans := make([]scantype, count) - for i := range nulls { + for i := range types { + var notnull bool if col := r.Stmt.ColumnOriginName(i); col != "" { - types[i], _, nulls[i], _, _, _ = c.TableColumnMetadata( + types[i], _, notnull, _, _, _ = c.TableColumnMetadata( r.Stmt.ColumnDatabaseName(i), r.Stmt.ColumnTableName(i), col) types[i] = strings.ToUpper(types[i]) scans[i] = scanFromDecl(types[i]) + if notnull { + scans[i] |= _NOT_NULL + } } } - r.nulls = nulls r.types = types r.scans = scans } @@ -721,15 +733,13 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string { func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { r.loadColumnMetadata() - if r.nulls[index] { - return false, true - } - return true, false + nullable = r.scans[index]&^_NOT_NULL == 0 + return nullable, !nullable } func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) { r.loadColumnMetadata() - scan := r.scans[index] + scan := r.scans[index] &^ _NOT_NULL if r.Stmt.Busy() { // SQLite is dynamically typed and we now have a row. @@ -772,6 +782,7 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) { } func (r *rows) Next(dest []driver.Value) error { + r.dest = nil c := r.Stmt.Conn() if old := c.SetInterrupt(r.ctx); old != r.ctx { defer c.SetInterrupt(old) @@ -790,18 +801,7 @@ func (r *rows) Next(dest []driver.Value) error { } for i := range dest { scan := r.scanType(i) - switch v := dest[i].(type) { - case int64: - if scan == _BOOL { - switch v { - case 1: - dest[i] = true - case 0: - dest[i] = false - } - continue - } - case []byte: + if v, ok := dest[i].([]byte); ok { if len(v) == cap(v) { // a BLOB continue } @@ -816,38 +816,49 @@ func (r *rows) Next(dest []driver.Value) error { } } dest[i] = string(v) - case float64: - break - default: - continue } - if scan == _TIME { + switch scan { + case _TIME: t, err := r.tmRead.Decode(dest[i]) if err == nil { dest[i] = t - continue + } + case _BOOL: + switch dest[i] { + case int64(0): + dest[i] = false + case int64(1): + dest[i] = true } } } + r.dest = dest return nil } -func (r *rows) ScanColumn(dest any, index int) error { +func (r *rows) ScanColumn(dest any, index int) (err error) { // notest // Go 1.26 - var ptr *time.Time + var tm *time.Time + var ok *bool switch d := dest.(type) { case *time.Time: - ptr = d + tm = d case *sql.NullTime: - ptr = &d.Time + tm = &d.Time + ok = &d.Valid case *sql.Null[time.Time]: - ptr = &d.V + tm = &d.V + ok = &d.Valid default: return driver.ErrSkip } - if t := r.Stmt.ColumnTime(index, r.tmRead); !t.IsZero() { - *ptr = t - return nil + value := r.dest[index] + *tm, err = r.tmRead.Decode(value) + if ok != nil { + *ok = err == nil + if value == nil { + return nil + } } - return driver.ErrSkip + return err } |
