summaryrefslogtreecommitdiff
path: root/vendor/github.com/ncruces/go-sqlite3/sqlite.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/sqlite.go')
-rw-r--r--vendor/github.com/ncruces/go-sqlite3/sqlite.go341
1 files changed, 341 insertions, 0 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/sqlite.go b/vendor/github.com/ncruces/go-sqlite3/sqlite.go
new file mode 100644
index 000000000..61a03652f
--- /dev/null
+++ b/vendor/github.com/ncruces/go-sqlite3/sqlite.go
@@ -0,0 +1,341 @@
+// Package sqlite3 wraps the C SQLite API.
+package sqlite3
+
+import (
+ "context"
+ "math"
+ "math/bits"
+ "os"
+ "sync"
+ "unsafe"
+
+ "github.com/ncruces/go-sqlite3/internal/util"
+ "github.com/ncruces/go-sqlite3/vfs"
+ "github.com/tetratelabs/wazero"
+ "github.com/tetratelabs/wazero/api"
+)
+
+// Configure SQLite Wasm.
+//
+// Importing package embed initializes [Binary]
+// with an appropriate build of SQLite:
+//
+// import _ "github.com/ncruces/go-sqlite3/embed"
+var (
+ Binary []byte // Wasm binary to load.
+ Path string // Path to load the binary from.
+
+ RuntimeConfig wazero.RuntimeConfig
+)
+
+// Initialize decodes and compiles the SQLite Wasm binary.
+// This is called implicitly when the first connection is openned,
+// but is potentially slow, so you may want to call it at a more convenient time.
+func Initialize() error {
+ instance.once.Do(compileSQLite)
+ return instance.err
+}
+
+var instance struct {
+ runtime wazero.Runtime
+ compiled wazero.CompiledModule
+ err error
+ once sync.Once
+}
+
+func compileSQLite() {
+ if RuntimeConfig == nil {
+ RuntimeConfig = wazero.NewRuntimeConfig()
+ }
+
+ ctx := context.Background()
+ instance.runtime = wazero.NewRuntimeWithConfig(ctx, RuntimeConfig)
+
+ env := instance.runtime.NewHostModuleBuilder("env")
+ env = vfs.ExportHostFunctions(env)
+ env = exportCallbacks(env)
+ _, instance.err = env.Instantiate(ctx)
+ if instance.err != nil {
+ return
+ }
+
+ bin := Binary
+ if bin == nil && Path != "" {
+ bin, instance.err = os.ReadFile(Path)
+ if instance.err != nil {
+ return
+ }
+ }
+ if bin == nil {
+ instance.err = util.NoBinaryErr
+ return
+ }
+
+ instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
+}
+
+type sqlite struct {
+ ctx context.Context
+ mod api.Module
+ funcs struct {
+ fn [32]api.Function
+ id [32]*byte
+ mask uint32
+ }
+ stack [8]uint64
+ freer uint32
+}
+
+func instantiateSQLite() (sqlt *sqlite, err error) {
+ if err := Initialize(); err != nil {
+ return nil, err
+ }
+
+ sqlt = new(sqlite)
+ sqlt.ctx = util.NewContext(context.Background())
+
+ sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
+ instance.compiled, wazero.NewModuleConfig().WithName(""))
+ if err != nil {
+ return nil, err
+ }
+
+ global := sqlt.mod.ExportedGlobal("malloc_destructor")
+ if global == nil {
+ return nil, util.BadBinaryErr
+ }
+
+ sqlt.freer = util.ReadUint32(sqlt.mod, uint32(global.Get()))
+ if sqlt.freer == 0 {
+ return nil, util.BadBinaryErr
+ }
+ return sqlt, nil
+}
+
+func (sqlt *sqlite) close() error {
+ return sqlt.mod.Close(sqlt.ctx)
+}
+
+func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
+ if rc == _OK {
+ return nil
+ }
+
+ err := Error{code: rc}
+
+ if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
+ panic(util.OOMErr)
+ }
+
+ if r := sqlt.call("sqlite3_errstr", rc); r != 0 {
+ err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME)
+ }
+
+ if handle != 0 {
+ if r := sqlt.call("sqlite3_errmsg", uint64(handle)); r != 0 {
+ err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH)
+ }
+
+ if sql != nil {
+ if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 {
+ err.sql = sql[0][r:]
+ }
+ }
+ }
+
+ switch err.msg {
+ case err.str, "not an error":
+ err.msg = ""
+ }
+ return &err
+}
+
+func (sqlt *sqlite) getfn(name string) api.Function {
+ c := &sqlt.funcs
+ p := unsafe.StringData(name)
+ for i := range c.id {
+ if c.id[i] == p {
+ c.id[i] = nil
+ c.mask &^= uint32(1) << i
+ return c.fn[i]
+ }
+ }
+ return sqlt.mod.ExportedFunction(name)
+}
+
+func (sqlt *sqlite) putfn(name string, fn api.Function) {
+ c := &sqlt.funcs
+ p := unsafe.StringData(name)
+ i := bits.TrailingZeros32(^c.mask)
+ if i < 32 {
+ c.id[i] = p
+ c.fn[i] = fn
+ c.mask |= uint32(1) << i
+ } else {
+ c.id[0] = p
+ c.fn[0] = fn
+ c.mask = uint32(1)
+ }
+}
+
+func (sqlt *sqlite) call(name string, params ...uint64) uint64 {
+ copy(sqlt.stack[:], params)
+ fn := sqlt.getfn(name)
+ err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
+ if err != nil {
+ panic(err)
+ }
+ sqlt.putfn(name, fn)
+ return sqlt.stack[0]
+}
+
+func (sqlt *sqlite) free(ptr uint32) {
+ if ptr == 0 {
+ return
+ }
+ sqlt.call("free", uint64(ptr))
+}
+
+func (sqlt *sqlite) new(size uint64) uint32 {
+ if size > _MAX_ALLOCATION_SIZE {
+ panic(util.OOMErr)
+ }
+ ptr := uint32(sqlt.call("malloc", size))
+ if ptr == 0 && size != 0 {
+ panic(util.OOMErr)
+ }
+ return ptr
+}
+
+func (sqlt *sqlite) newBytes(b []byte) uint32 {
+ if (*[0]byte)(b) == nil {
+ return 0
+ }
+ ptr := sqlt.new(uint64(len(b)))
+ util.WriteBytes(sqlt.mod, ptr, b)
+ return ptr
+}
+
+func (sqlt *sqlite) newString(s string) uint32 {
+ ptr := sqlt.new(uint64(len(s) + 1))
+ util.WriteString(sqlt.mod, ptr, s)
+ return ptr
+}
+
+func (sqlt *sqlite) newArena(size uint64) arena {
+ // Ensure the arena's size is a multiple of 8.
+ size = (size + 7) &^ 7
+ return arena{
+ sqlt: sqlt,
+ size: uint32(size),
+ base: sqlt.new(size),
+ }
+}
+
+type arena struct {
+ sqlt *sqlite
+ ptrs []uint32
+ base uint32
+ next uint32
+ size uint32
+}
+
+func (a *arena) free() {
+ if a.sqlt == nil {
+ return
+ }
+ for _, ptr := range a.ptrs {
+ a.sqlt.free(ptr)
+ }
+ a.sqlt.free(a.base)
+ a.sqlt = nil
+}
+
+func (a *arena) mark() (reset func()) {
+ ptrs := len(a.ptrs)
+ next := a.next
+ return func() {
+ for _, ptr := range a.ptrs[ptrs:] {
+ a.sqlt.free(ptr)
+ }
+ a.ptrs = a.ptrs[:ptrs]
+ a.next = next
+ }
+}
+
+func (a *arena) new(size uint64) uint32 {
+ // Align the next address, to 4 or 8 bytes.
+ if size&7 != 0 {
+ a.next = (a.next + 3) &^ 3
+ } else {
+ a.next = (a.next + 7) &^ 7
+ }
+ if size <= uint64(a.size-a.next) {
+ ptr := a.base + a.next
+ a.next += uint32(size)
+ return ptr
+ }
+ ptr := a.sqlt.new(size)
+ a.ptrs = append(a.ptrs, ptr)
+ return ptr
+}
+
+func (a *arena) bytes(b []byte) uint32 {
+ if (*[0]byte)(b) == nil {
+ return 0
+ }
+ ptr := a.new(uint64(len(b)))
+ util.WriteBytes(a.sqlt.mod, ptr, b)
+ return ptr
+}
+
+func (a *arena) string(s string) uint32 {
+ ptr := a.new(uint64(len(s) + 1))
+ util.WriteString(a.sqlt.mod, ptr, s)
+ return ptr
+}
+
+func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
+ util.ExportFuncII(env, "go_progress_handler", progressCallback)
+ util.ExportFuncIIII(env, "go_busy_timeout", timeoutCallback)
+ util.ExportFuncIII(env, "go_busy_handler", busyCallback)
+ util.ExportFuncII(env, "go_commit_hook", commitCallback)
+ util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)
+ util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback)
+ util.ExportFuncIIIII(env, "go_wal_hook", walCallback)
+ util.ExportFuncIIIIII(env, "go_autovacuum_pages", autoVacuumCallback)
+ util.ExportFuncIIIIIII(env, "go_authorizer", authorizerCallback)
+ util.ExportFuncVIII(env, "go_log", logCallback)
+ util.ExportFuncVI(env, "go_destroy", destroyCallback)
+ util.ExportFuncVIIII(env, "go_func", funcCallback)
+ util.ExportFuncVIIIII(env, "go_step", stepCallback)
+ util.ExportFuncVIII(env, "go_final", finalCallback)
+ util.ExportFuncVII(env, "go_value", valueCallback)
+ util.ExportFuncVIIII(env, "go_inverse", inverseCallback)
+ util.ExportFuncVIIII(env, "go_collation_needed", collationCallback)
+ util.ExportFuncIIIIII(env, "go_compare", compareCallback)
+ util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(xCreate))
+ util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(xConnect))
+ util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback)
+ util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback)
+ util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback)
+ util.ExportFuncIIIII(env, "go_vtab_update", vtabUpdateCallback)
+ util.ExportFuncIII(env, "go_vtab_rename", vtabRenameCallback)
+ util.ExportFuncIIIII(env, "go_vtab_find_function", vtabFindFuncCallback)
+ util.ExportFuncII(env, "go_vtab_begin", vtabBeginCallback)
+ util.ExportFuncII(env, "go_vtab_sync", vtabSyncCallback)
+ util.ExportFuncII(env, "go_vtab_commit", vtabCommitCallback)
+ util.ExportFuncII(env, "go_vtab_rollback", vtabRollbackCallback)
+ util.ExportFuncIII(env, "go_vtab_savepoint", vtabSavepointCallback)
+ util.ExportFuncIII(env, "go_vtab_release", vtabReleaseCallback)
+ util.ExportFuncIII(env, "go_vtab_rollback_to", vtabRollbackToCallback)
+ util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback)
+ util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback)
+ util.ExportFuncII(env, "go_cur_close", cursorCloseCallback)
+ util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback)
+ util.ExportFuncII(env, "go_cur_next", cursorNextCallback)
+ util.ExportFuncII(env, "go_cur_eof", cursorEOFCallback)
+ util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback)
+ util.ExportFuncIII(env, "go_cur_rowid", cursorRowIDCallback)
+ return env
+}