diff options
Diffstat (limited to 'vendor/modernc.org/sqlite/sqlite.go')
-rw-r--r-- | vendor/modernc.org/sqlite/sqlite.go | 306 |
1 files changed, 258 insertions, 48 deletions
diff --git a/vendor/modernc.org/sqlite/sqlite.go b/vendor/modernc.org/sqlite/sqlite.go index a484a2d0e..7fce68eeb 100644 --- a/vendor/modernc.org/sqlite/sqlite.go +++ b/vendor/modernc.org/sqlite/sqlite.go @@ -491,20 +491,7 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res var pstmt uintptr var done int32 if ctx != nil && ctx.Done() != nil { - donech := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - atomic.AddInt32(&done, 1) - s.c.interrupt(s.c.db) - case <-donech: - } - }() - - defer func() { - close(donech) - }() + defer interruptOnDone(ctx, s.c, &done)() } for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && atomic.LoadInt32(&done) == 0; { @@ -588,20 +575,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro var pstmt uintptr var done int32 if ctx != nil && ctx.Done() != nil { - donech := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - atomic.AddInt32(&done, 1) - s.c.interrupt(s.c.db) - case <-donech: - } - }() - - defer func() { - close(donech) - }() + defer interruptOnDone(ctx, s.c, &done)() } var allocs []uintptr @@ -718,19 +692,7 @@ func (t *tx) exec(ctx context.Context, sql string) (err error) { //TODO use t.conn.ExecContext() instead if ctx != nil && ctx.Done() != nil { - donech := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - t.c.interrupt(t.c.db) - case <-donech: - } - }() - - defer func() { - close(donech) - }() + defer interruptOnDone(ctx, t.c, nil)() } if rc := sqlite3.Xsqlite3_exec(t.c.tls, t.c.db, psql, 0, 0, 0); rc != sqlite3.SQLITE_OK { @@ -740,6 +702,43 @@ func (t *tx) exec(ctx context.Context, sql string) (err error) { return nil } +// interruptOnDone sets up a goroutine to interrupt the provided db when the +// context is canceled, and returns a function the caller must defer so it +// doesn't interrupt after the caller finishes. +func interruptOnDone( + ctx context.Context, + c *conn, + done *int32, +) func() { + if done == nil { + var d int32 + done = &d + } + + donech := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + // don't call interrupt if we were already done: it indicates that this + // call to exec is no longer running and we would be interrupting + // nothing, or even possibly an unrelated later call to exec. + if atomic.AddInt32(done, 1) == 1 { + c.interrupt(c.db) + } + case <-donech: + } + }() + + // the caller is expected to defer this function + return func() { + // set the done flag so that a context cancellation right after the caller + // returns doesn't trigger a call to interrupt for some other statement. + atomic.AddInt32(done, 1) + close(donech) + } +} + type conn struct { db uintptr // *sqlite3.Xsqlite3 tls *libc.TLS @@ -1091,14 +1090,18 @@ func (c *conn) bindText(pstmt uintptr, idx1 int, value string) (uintptr, error) // int sqlite3_bind_blob(sqlite3_stmt*, int, const void*, int n, void(*)(void*)); func (c *conn) bindBlob(pstmt uintptr, idx1 int, value []byte) (uintptr, error) { + if len(value) == 0 { + if rc := sqlite3.Xsqlite3_bind_zeroblob(c.tls, pstmt, int32(idx1), 0); rc != sqlite3.SQLITE_OK { + return 0, c.errstr(rc) + } + return 0, nil + } + p, err := c.malloc(len(value)) if err != nil { return 0, err } - - if len(value) != 0 { - copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value) - } + copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value) if rc := sqlite3.Xsqlite3_bind_blob(c.tls, pstmt, int32(idx1), p, int32(len(value)), 0); rc != sqlite3.SQLITE_OK { c.free(p) return 0, c.errstr(rc) @@ -1308,6 +1311,7 @@ func (c *conn) Close() error { c.db = 0 } + if c.tls != nil { c.tls.Close() c.tls = nil @@ -1324,6 +1328,32 @@ func (c *conn) closeV2(db uintptr) error { return nil } +type userDefinedFunction struct { + zFuncName uintptr + nArg int32 + eTextRep int32 + xFunc func(*libc.TLS, uintptr, int32, uintptr) + + freeOnce sync.Once +} + +func (c *conn) createFunctionInternal(fun *userDefinedFunction) error { + if rc := sqlite3.Xsqlite3_create_function( + c.tls, + c.db, + fun.zFuncName, + fun.nArg, + fun.eTextRep, + 0, + *(*uintptr)(unsafe.Pointer(&fun.xFunc)), + 0, + 0, + ); rc != sqlite3.SQLITE_OK { + return c.errstr(rc) + } + return nil +} + // Execer is an optional interface that may be implemented by a Conn. // // If a Conn does not implement Execer, the sql package's DB.Exec will first @@ -1389,9 +1419,14 @@ func (c *conn) query(ctx context.Context, query string, args []driver.NamedValue } // Driver implements database/sql/driver.Driver. -type Driver struct{} +type Driver struct { + // user defined functions that are added to every new connection on Open + udfs map[string]*userDefinedFunction +} + +var d = &Driver{udfs: make(map[string]*userDefinedFunction)} -func newDriver() *Driver { return &Driver{} } +func newDriver() *Driver { return d } // Open returns a new connection to the database. The name is a string in a // driver-specific format. @@ -1423,5 +1458,180 @@ func newDriver() *Driver { return &Driver{} } // available at // https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions func (d *Driver) Open(name string) (driver.Conn, error) { - return newConn(name) + c, err := newConn(name) + if err != nil { + return nil, err + } + + for _, udf := range d.udfs { + if err = c.createFunctionInternal(udf); err != nil { + c.Close() + return nil, err + } + } + return c, nil +} + +// FunctionContext represents the context user defined functions execute in. +// Fields and/or methods of this type may get addedd in the future. +type FunctionContext struct{} + +const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{}) + +// RegisterScalarFunction registers a scalar function named zFuncName with nArg +// arguments. Passing -1 for nArg indicates the function is variadic. +// +// The new function will be available to all new connections opened after +// executing RegisterScalarFunction. +func RegisterScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) error { + return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8, xFunc) +} + +// MustRegisterScalarFunction is like RegisterScalarFunction but panics on +// error. +func MustRegisterScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) { + if err := RegisterScalarFunction(zFuncName, nArg, xFunc); err != nil { + panic(err) + } +} + +// MustRegisterDeterministicScalarFunction is like +// RegisterDeterministicScalarFunction but panics on error. +func MustRegisterDeterministicScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) { + if err := RegisterDeterministicScalarFunction(zFuncName, nArg, xFunc); err != nil { + panic(err) + } +} + +// RegisterDeterministicScalarFunction registers a deterministic scalar +// function named zFuncName with nArg arguments. Passing -1 for nArg indicates +// the function is variadic. A deterministic function means that the function +// always gives the same output when the input parameters are the same. +// +// The new function will be available to all new connections opened after +// executing RegisterDeterministicScalarFunction. +func RegisterDeterministicScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) error { + return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8|sqlite3.SQLITE_DETERMINISTIC, xFunc) +} + +func registerScalarFunction( + zFuncName string, + nArg int32, + eTextRep int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) error { + + if _, ok := d.udfs[zFuncName]; ok { + return fmt.Errorf("a function named %q is already registered", zFuncName) + } + + // dont free, functions registered on the driver live as long as the program + name, err := libc.CString(zFuncName) + if err != nil { + return err + } + + udf := &userDefinedFunction{ + zFuncName: name, + nArg: nArg, + eTextRep: eTextRep, + xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { + setErrorResult := func(res error) { + errmsg, cerr := libc.CString(res.Error()) + if cerr != nil { + panic(cerr) + } + defer libc.Xfree(tls, errmsg) + sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1) + sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR) + } + + args := make([]driver.Value, argc) + for i := int32(0); i < argc; i++ { + valPtr := *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize)) + + switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType { + case sqlite3.SQLITE_TEXT: + args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr)) + case sqlite3.SQLITE_INTEGER: + args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr) + case sqlite3.SQLITE_FLOAT: + args[i] = sqlite3.Xsqlite3_value_double(tls, valPtr) + case sqlite3.SQLITE_NULL: + args[i] = nil + case sqlite3.SQLITE_BLOB: + size := sqlite3.Xsqlite3_value_bytes(tls, valPtr) + blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr) + v := make([]byte, size) + copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size]) + args[i] = v + default: + panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType)) + } + } + + res, err := xFunc(&FunctionContext{}, args) + if err != nil { + setErrorResult(err) + return + } + + switch resTyped := res.(type) { + case nil: + sqlite3.Xsqlite3_result_null(tls, ctx) + case int64: + sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped) + case float64: + sqlite3.Xsqlite3_result_double(tls, ctx, resTyped) + case bool: + sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(resTyped)) + case time.Time: + sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped.Unix()) + case string: + size := int32(len(resTyped)) + cstr, err := libc.CString(resTyped) + if err != nil { + panic(err) + } + defer libc.Xfree(tls, cstr) + sqlite3.Xsqlite3_result_text(tls, ctx, cstr, size, sqlite3.SQLITE_TRANSIENT) + case []byte: + size := int32(len(resTyped)) + if size == 0 { + sqlite3.Xsqlite3_result_zeroblob(tls, ctx, 0) + return + } + p := libc.Xmalloc(tls, types.Size_t(size)) + if p == 0 { + panic(fmt.Sprintf("unable to allocate space for blob: %d", size)) + } + defer libc.Xfree(tls, p) + copy((*libc.RawMem)(unsafe.Pointer(p))[:size:size], resTyped) + + sqlite3.Xsqlite3_result_blob(tls, ctx, p, size, sqlite3.SQLITE_TRANSIENT) + default: + setErrorResult(fmt.Errorf("function did not return a valid driver.Value: %T", resTyped)) + return + } + }, + } + d.udfs[zFuncName] = udf + + return nil } |