summaryrefslogtreecommitdiff
path: root/vendor/github.com/ncruces/go-sqlite3/driver
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/driver')
-rw-r--r--vendor/github.com/ncruces/go-sqlite3/driver/driver.go132
-rw-r--r--vendor/github.com/ncruces/go-sqlite3/driver/time.go4
2 files changed, 106 insertions, 30 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go
index 88c4c50db..477e9a940 100644
--- a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go
+++ b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go
@@ -81,6 +81,7 @@ import (
"fmt"
"io"
"net/url"
+ "reflect"
"strings"
"time"
"unsafe"
@@ -107,17 +108,17 @@ func init() {
// The second callback is called before the driver closes a connection.
// The [sqlite3.Conn] can be used to execute queries, register functions, etc.
func Open(dataSourceName string, fn ...func(*sqlite3.Conn) error) (*sql.DB, error) {
- var drv SQLite
if len(fn) > 2 {
return nil, sqlite3.MISUSE
}
+ var init, term func(*sqlite3.Conn) error
if len(fn) > 1 {
- drv.term = fn[1]
+ term = fn[1]
}
if len(fn) > 0 {
- drv.init = fn[0]
+ init = fn[0]
}
- c, err := drv.OpenConnector(dataSourceName)
+ c, err := newConnector(dataSourceName, init, term)
if err != nil {
return nil, err
}
@@ -125,10 +126,7 @@ func Open(dataSourceName string, fn ...func(*sqlite3.Conn) error) (*sql.DB, erro
}
// SQLite implements [database/sql/driver.Driver].
-type SQLite struct {
- init func(*sqlite3.Conn) error
- term func(*sqlite3.Conn) error
-}
+type SQLite struct{}
var (
// Ensure these interfaces are implemented:
@@ -137,7 +135,7 @@ var (
// Open implements [database/sql/driver.Driver].
func (d *SQLite) Open(name string) (driver.Conn, error) {
- c, err := d.newConnector(name)
+ c, err := newConnector(name, nil, nil)
if err != nil {
return nil, err
}
@@ -146,11 +144,11 @@ func (d *SQLite) Open(name string) (driver.Conn, error) {
// OpenConnector implements [database/sql/driver.DriverContext].
func (d *SQLite) OpenConnector(name string) (driver.Connector, error) {
- return d.newConnector(name)
+ return newConnector(name, nil, nil)
}
-func (d *SQLite) newConnector(name string) (*connector, error) {
- c := connector{driver: d, name: name}
+func newConnector(name string, init, term func(*sqlite3.Conn) error) (*connector, error) {
+ c := connector{name: name, init: init, term: term}
var txlock, timefmt string
if strings.HasPrefix(name, "file:") {
@@ -190,7 +188,8 @@ func (d *SQLite) newConnector(name string) (*connector, error) {
}
type connector struct {
- driver *SQLite
+ init func(*sqlite3.Conn) error
+ term func(*sqlite3.Conn) error
name string
txLock string
tmRead sqlite3.TimeFormat
@@ -199,7 +198,7 @@ type connector struct {
}
func (n *connector) Driver() driver.Driver {
- return n.driver
+ return &SQLite{}
}
func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
@@ -228,13 +227,13 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
return nil, err
}
}
- if n.driver.init != nil {
- err = n.driver.init(c.Conn)
+ if n.init != nil {
+ err = n.init(c.Conn)
if err != nil {
return nil, err
}
}
- if n.pragmas || n.driver.init != nil {
+ if n.pragmas || n.init != nil {
s, _, err := c.Conn.Prepare(`PRAGMA query_only`)
if err != nil {
return nil, err
@@ -250,9 +249,9 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
return nil, err
}
}
- if n.driver.term != nil {
+ if n.term != nil {
err = c.Conn.Trace(sqlite3.TRACE_CLOSE, func(sqlite3.TraceEvent, any, any) error {
- return n.driver.term(c.Conn)
+ return n.term(c.Conn)
})
if err != nil {
return nil, err
@@ -288,6 +287,8 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
type Conn interface {
Raw() *sqlite3.Conn
driver.Conn
+ driver.ConnBeginTx
+ driver.ConnPrepareContext
}
type conn struct {
@@ -301,10 +302,8 @@ type conn struct {
var (
// Ensure these interfaces are implemented:
- _ Conn = &conn{}
- _ driver.ConnBeginTx = &conn{}
- _ driver.ConnPrepareContext = &conn{}
- _ driver.ExecerContext = &conn{}
+ _ Conn = &conn{}
+ _ driver.ExecerContext = &conn{}
)
func (c *conn) Raw() *sqlite3.Conn {
@@ -581,8 +580,22 @@ type rows struct {
names []string
types []string
nulls []bool
+ scans []scantype
}
+type scantype byte
+
+const (
+ _ANY scantype = iota
+ _INT scantype = scantype(sqlite3.INTEGER)
+ _REAL scantype = scantype(sqlite3.FLOAT)
+ _TEXT scantype = scantype(sqlite3.TEXT)
+ _BLOB scantype = scantype(sqlite3.BLOB)
+ _NULL scantype = scantype(sqlite3.NULL)
+ _BOOL scantype = iota
+ _TIME
+)
+
var (
// Ensure these interfaces are implemented:
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
@@ -606,21 +619,42 @@ func (r *rows) Columns() []string {
return r.names
}
-func (r *rows) loadTypes() {
+func (r *rows) loadColumnMetadata() {
if r.nulls == nil {
count := r.Stmt.ColumnCount()
nulls := make([]bool, count)
types := make([]string, count)
+ scans := make([]scantype, count)
for i := range nulls {
if col := r.Stmt.ColumnOriginName(i); col != "" {
types[i], _, nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata(
r.Stmt.ColumnDatabaseName(i),
r.Stmt.ColumnTableName(i),
col)
+ types[i] = strings.ToUpper(types[i])
+ // These types are only used before we have rows,
+ // and otherwise as type hints.
+ // The first few ensure STRICT tables are strictly typed.
+ // The other two are type hints for booleans and time.
+ switch types[i] {
+ case "INT", "INTEGER":
+ scans[i] = _INT
+ case "REAL":
+ scans[i] = _REAL
+ case "TEXT":
+ scans[i] = _TEXT
+ case "BLOB":
+ scans[i] = _BLOB
+ case "BOOLEAN":
+ scans[i] = _BOOL
+ case "DATE", "TIME", "DATETIME", "TIMESTAMP":
+ scans[i] = _TIME
+ }
}
}
r.nulls = nulls
r.types = types
+ r.scans = scans
}
}
@@ -637,7 +671,7 @@ func (r *rows) declType(index int) string {
}
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
- r.loadTypes()
+ r.loadColumnMetadata()
decltype := r.types[index]
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
@@ -648,13 +682,57 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
}
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
- r.loadTypes()
+ r.loadColumnMetadata()
if r.nulls[index] {
return false, true
}
return true, false
}
+func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
+ r.loadColumnMetadata()
+ scan := r.scans[index]
+
+ if r.Stmt.Busy() {
+ // SQLite is dynamically typed and we now have a row.
+ // Always use the type of the value itself,
+ // unless the scan type is more specific
+ // and can scan the actual value.
+ val := scantype(r.Stmt.ColumnType(index))
+ useValType := true
+ switch {
+ case scan == _TIME && val != _BLOB && val != _NULL:
+ t := r.Stmt.ColumnTime(index, r.tmRead)
+ useValType = t == time.Time{}
+ case scan == _BOOL && val == _INT:
+ i := r.Stmt.ColumnInt64(index)
+ useValType = i != 0 && i != 1
+ case scan == _BLOB && val == _NULL:
+ useValType = false
+ }
+ if useValType {
+ scan = val
+ }
+ }
+
+ switch scan {
+ case _INT:
+ return reflect.TypeOf(int64(0))
+ case _REAL:
+ return reflect.TypeOf(float64(0))
+ case _TEXT:
+ return reflect.TypeOf("")
+ case _BLOB:
+ return reflect.TypeOf([]byte{})
+ case _BOOL:
+ return reflect.TypeOf(false)
+ case _TIME:
+ return reflect.TypeOf(time.Time{})
+ default:
+ return reflect.TypeOf((*any)(nil)).Elem()
+ }
+}
+
func (r *rows) Next(dest []driver.Value) error {
old := r.Stmt.Conn().SetInterrupt(r.ctx)
defer r.Stmt.Conn().SetInterrupt(old)
@@ -667,7 +745,7 @@ func (r *rows) Next(dest []driver.Value) error {
}
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
- err := r.Stmt.Columns(data)
+ err := r.Stmt.Columns(data...)
for i := range dest {
if t, ok := r.decodeTime(i, dest[i]); ok {
dest[i] = t
diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/time.go b/vendor/github.com/ncruces/go-sqlite3/driver/time.go
index 630a5b10b..b3ebdd263 100644
--- a/vendor/github.com/ncruces/go-sqlite3/driver/time.go
+++ b/vendor/github.com/ncruces/go-sqlite3/driver/time.go
@@ -1,8 +1,6 @@
package driver
-import (
- "time"
-)
+import "time"
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string.