diff options
author | 2024-12-12 19:44:53 +0000 | |
---|---|---|
committer | 2024-12-12 19:44:53 +0000 | |
commit | fb12bbb10b228fddf40ebd0e463d5afcd9299ebb (patch) | |
tree | c9239fe362ea38e71ab4363649ceff605cb4ab89 /vendor/github.com/ncruces/go-sqlite3/driver/driver.go | |
parent | Bump nanoid from 3.3.7 to 5.0.9 in /web/source (#3615) (diff) | |
download | gotosocial-fb12bbb10b228fddf40ebd0e463d5afcd9299ebb.tar.xz |
bump ncruces/go-sqlite3 to v0.21.0 (#3621)
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/driver/driver.go')
-rw-r--r-- | vendor/github.com/ncruces/go-sqlite3/driver/driver.go | 132 |
1 files changed, 105 insertions, 27 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go index 88c4c50db..477e9a940 100644 --- a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go +++ b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go @@ -81,6 +81,7 @@ import ( "fmt" "io" "net/url" + "reflect" "strings" "time" "unsafe" @@ -107,17 +108,17 @@ func init() { // The second callback is called before the driver closes a connection. // The [sqlite3.Conn] can be used to execute queries, register functions, etc. func Open(dataSourceName string, fn ...func(*sqlite3.Conn) error) (*sql.DB, error) { - var drv SQLite if len(fn) > 2 { return nil, sqlite3.MISUSE } + var init, term func(*sqlite3.Conn) error if len(fn) > 1 { - drv.term = fn[1] + term = fn[1] } if len(fn) > 0 { - drv.init = fn[0] + init = fn[0] } - c, err := drv.OpenConnector(dataSourceName) + c, err := newConnector(dataSourceName, init, term) if err != nil { return nil, err } @@ -125,10 +126,7 @@ func Open(dataSourceName string, fn ...func(*sqlite3.Conn) error) (*sql.DB, erro } // SQLite implements [database/sql/driver.Driver]. -type SQLite struct { - init func(*sqlite3.Conn) error - term func(*sqlite3.Conn) error -} +type SQLite struct{} var ( // Ensure these interfaces are implemented: @@ -137,7 +135,7 @@ var ( // Open implements [database/sql/driver.Driver]. func (d *SQLite) Open(name string) (driver.Conn, error) { - c, err := d.newConnector(name) + c, err := newConnector(name, nil, nil) if err != nil { return nil, err } @@ -146,11 +144,11 @@ func (d *SQLite) Open(name string) (driver.Conn, error) { // OpenConnector implements [database/sql/driver.DriverContext]. func (d *SQLite) OpenConnector(name string) (driver.Connector, error) { - return d.newConnector(name) + return newConnector(name, nil, nil) } -func (d *SQLite) newConnector(name string) (*connector, error) { - c := connector{driver: d, name: name} +func newConnector(name string, init, term func(*sqlite3.Conn) error) (*connector, error) { + c := connector{name: name, init: init, term: term} var txlock, timefmt string if strings.HasPrefix(name, "file:") { @@ -190,7 +188,8 @@ func (d *SQLite) newConnector(name string) (*connector, error) { } type connector struct { - driver *SQLite + init func(*sqlite3.Conn) error + term func(*sqlite3.Conn) error name string txLock string tmRead sqlite3.TimeFormat @@ -199,7 +198,7 @@ type connector struct { } func (n *connector) Driver() driver.Driver { - return n.driver + return &SQLite{} } func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { @@ -228,13 +227,13 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { return nil, err } } - if n.driver.init != nil { - err = n.driver.init(c.Conn) + if n.init != nil { + err = n.init(c.Conn) if err != nil { return nil, err } } - if n.pragmas || n.driver.init != nil { + if n.pragmas || n.init != nil { s, _, err := c.Conn.Prepare(`PRAGMA query_only`) if err != nil { return nil, err @@ -250,9 +249,9 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { return nil, err } } - if n.driver.term != nil { + if n.term != nil { err = c.Conn.Trace(sqlite3.TRACE_CLOSE, func(sqlite3.TraceEvent, any, any) error { - return n.driver.term(c.Conn) + return n.term(c.Conn) }) if err != nil { return nil, err @@ -288,6 +287,8 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { type Conn interface { Raw() *sqlite3.Conn driver.Conn + driver.ConnBeginTx + driver.ConnPrepareContext } type conn struct { @@ -301,10 +302,8 @@ type conn struct { var ( // Ensure these interfaces are implemented: - _ Conn = &conn{} - _ driver.ConnBeginTx = &conn{} - _ driver.ConnPrepareContext = &conn{} - _ driver.ExecerContext = &conn{} + _ Conn = &conn{} + _ driver.ExecerContext = &conn{} ) func (c *conn) Raw() *sqlite3.Conn { @@ -581,8 +580,22 @@ type rows struct { names []string types []string nulls []bool + scans []scantype } +type scantype byte + +const ( + _ANY scantype = iota + _INT scantype = scantype(sqlite3.INTEGER) + _REAL scantype = scantype(sqlite3.FLOAT) + _TEXT scantype = scantype(sqlite3.TEXT) + _BLOB scantype = scantype(sqlite3.BLOB) + _NULL scantype = scantype(sqlite3.NULL) + _BOOL scantype = iota + _TIME +) + var ( // Ensure these interfaces are implemented: _ driver.RowsColumnTypeDatabaseTypeName = &rows{} @@ -606,21 +619,42 @@ func (r *rows) Columns() []string { return r.names } -func (r *rows) loadTypes() { +func (r *rows) loadColumnMetadata() { if r.nulls == nil { count := r.Stmt.ColumnCount() nulls := make([]bool, count) types := make([]string, count) + scans := make([]scantype, count) for i := range nulls { if col := r.Stmt.ColumnOriginName(i); col != "" { types[i], _, nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata( r.Stmt.ColumnDatabaseName(i), r.Stmt.ColumnTableName(i), col) + types[i] = strings.ToUpper(types[i]) + // These types are only used before we have rows, + // and otherwise as type hints. + // The first few ensure STRICT tables are strictly typed. + // The other two are type hints for booleans and time. + switch types[i] { + case "INT", "INTEGER": + scans[i] = _INT + case "REAL": + scans[i] = _REAL + case "TEXT": + scans[i] = _TEXT + case "BLOB": + scans[i] = _BLOB + case "BOOLEAN": + scans[i] = _BOOL + case "DATE", "TIME", "DATETIME", "TIMESTAMP": + scans[i] = _TIME + } } } r.nulls = nulls r.types = types + r.scans = scans } } @@ -637,7 +671,7 @@ func (r *rows) declType(index int) string { } func (r *rows) ColumnTypeDatabaseTypeName(index int) string { - r.loadTypes() + r.loadColumnMetadata() decltype := r.types[index] if len := len(decltype); len > 0 && decltype[len-1] == ')' { if i := strings.LastIndexByte(decltype, '('); i >= 0 { @@ -648,13 +682,57 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string { } func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { - r.loadTypes() + r.loadColumnMetadata() if r.nulls[index] { return false, true } return true, false } +func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) { + r.loadColumnMetadata() + scan := r.scans[index] + + if r.Stmt.Busy() { + // SQLite is dynamically typed and we now have a row. + // Always use the type of the value itself, + // unless the scan type is more specific + // and can scan the actual value. + val := scantype(r.Stmt.ColumnType(index)) + useValType := true + switch { + case scan == _TIME && val != _BLOB && val != _NULL: + t := r.Stmt.ColumnTime(index, r.tmRead) + useValType = t == time.Time{} + case scan == _BOOL && val == _INT: + i := r.Stmt.ColumnInt64(index) + useValType = i != 0 && i != 1 + case scan == _BLOB && val == _NULL: + useValType = false + } + if useValType { + scan = val + } + } + + switch scan { + case _INT: + return reflect.TypeOf(int64(0)) + case _REAL: + return reflect.TypeOf(float64(0)) + case _TEXT: + return reflect.TypeOf("") + case _BLOB: + return reflect.TypeOf([]byte{}) + case _BOOL: + return reflect.TypeOf(false) + case _TIME: + return reflect.TypeOf(time.Time{}) + default: + return reflect.TypeOf((*any)(nil)).Elem() + } +} + func (r *rows) Next(dest []driver.Value) error { old := r.Stmt.Conn().SetInterrupt(r.ctx) defer r.Stmt.Conn().SetInterrupt(old) @@ -667,7 +745,7 @@ func (r *rows) Next(dest []driver.Value) error { } data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest)) - err := r.Stmt.Columns(data) + err := r.Stmt.Columns(data...) for i := range dest { if t, ok := r.decodeTime(i, dest[i]); ok { dest[i] = t |