summaryrefslogtreecommitdiff
path: root/vendor/modernc.org/sqlite/pre_update_hook.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/modernc.org/sqlite/pre_update_hook.go')
-rw-r--r--vendor/modernc.org/sqlite/pre_update_hook.go227
1 files changed, 227 insertions, 0 deletions
diff --git a/vendor/modernc.org/sqlite/pre_update_hook.go b/vendor/modernc.org/sqlite/pre_update_hook.go
new file mode 100644
index 000000000..9a00fe5ac
--- /dev/null
+++ b/vendor/modernc.org/sqlite/pre_update_hook.go
@@ -0,0 +1,227 @@
+package sqlite
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "unsafe"
+
+ "modernc.org/libc"
+ "modernc.org/libc/sys/types"
+ sqlite3 "modernc.org/sqlite/lib"
+)
+
+var (
+ xPreUpdateHandlers = struct {
+ mu sync.RWMutex
+ m map[uintptr]func(SQLitePreUpdateData)
+ }{
+ m: make(map[uintptr]func(SQLitePreUpdateData)),
+ }
+ xCommitHandlers = struct {
+ mu sync.RWMutex
+ m map[uintptr]CommitHookFn
+ }{
+ m: make(map[uintptr]CommitHookFn),
+ }
+ xRollbackHandlers = struct {
+ mu sync.RWMutex
+ m map[uintptr]RollbackHookFn
+ }{
+ m: make(map[uintptr]RollbackHookFn),
+ }
+)
+
+type PreUpdateHookFn func(SQLitePreUpdateData)
+
+func (c *conn) RegisterPreUpdateHook(callback PreUpdateHookFn) {
+
+ if callback == nil {
+ xPreUpdateHandlers.mu.Lock()
+ delete(xPreUpdateHandlers.m, c.db)
+ xPreUpdateHandlers.mu.Unlock()
+ sqlite3.Xsqlite3_preupdate_hook(c.tls, c.db, uintptr(unsafe.Pointer(nil)), uintptr(unsafe.Pointer(nil)))
+ return
+ }
+ xPreUpdateHandlers.mu.Lock()
+ xPreUpdateHandlers.m[c.db] = callback
+ xPreUpdateHandlers.mu.Unlock()
+
+ sqlite3.Xsqlite3_preupdate_hook(c.tls, c.db, cFuncPointer(preUpdateHookTrampoline), c.db)
+}
+
+type CommitHookFn func() int32
+
+func (c *conn) RegisterCommitHook(callback CommitHookFn) {
+ if callback == nil {
+ xCommitHandlers.mu.Lock()
+ delete(xCommitHandlers.m, c.db)
+ xCommitHandlers.mu.Unlock()
+ sqlite3.Xsqlite3_commit_hook(c.tls, c.db, uintptr(unsafe.Pointer(nil)), uintptr(unsafe.Pointer(nil)))
+ return
+ }
+ xCommitHandlers.mu.Lock()
+ xCommitHandlers.m[c.db] = callback
+ xCommitHandlers.mu.Unlock()
+ sqlite3.Xsqlite3_commit_hook(c.tls, c.db, cFuncPointer(commitHookTrampoline), c.db)
+}
+
+type RollbackHookFn func()
+
+func (c *conn) RegisterRollbackHook(callback RollbackHookFn) {
+ if callback == nil {
+ xRollbackHandlers.mu.Lock()
+ delete(xRollbackHandlers.m, c.db)
+ xRollbackHandlers.mu.Unlock()
+ sqlite3.Xsqlite3_rollback_hook(c.tls, c.db, uintptr(unsafe.Pointer(nil)), uintptr(unsafe.Pointer(nil)))
+ return
+ }
+ xRollbackHandlers.mu.Lock()
+ xRollbackHandlers.m[c.db] = callback
+ xRollbackHandlers.mu.Unlock()
+ sqlite3.Xsqlite3_rollback_hook(c.tls, c.db, cFuncPointer(rollbackHookTrampoline), c.db)
+}
+
+type SQLitePreUpdateData struct {
+ tls *libc.TLS
+ pCsr uintptr
+ Op int32
+ DatabaseName string
+ TableName string
+ OldRowID int64
+ NewRowID int64
+}
+
+// Depth returns the source path of the write, see sqlite3_preupdate_depth()
+func (d *SQLitePreUpdateData) Depth() int {
+ return int(sqlite3.Xsqlite3_preupdate_depth(d.tls, d.pCsr))
+}
+
+// Count returns the number of columns in the row
+func (d *SQLitePreUpdateData) Count() int {
+ return int(sqlite3.Xsqlite3_preupdate_count(d.tls, d.pCsr))
+}
+
+func (d *SQLitePreUpdateData) row(dest []any, new bool) error {
+ count := d.Count()
+ ppValue, err := mallocValue(d.tls)
+ if err != nil {
+ return err
+ }
+ defer libc.Xfree(d.tls, ppValue)
+
+ for i := 0; i < count && i < len(dest); i++ {
+ val, err := d.value(ppValue, i, new)
+ if err != nil {
+ return err
+ }
+ err = convertAssign(&dest[i], val)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Old populates dest with the row data to be replaced. This works similar to
+// database/sql's Rows.Scan()
+func (d *SQLitePreUpdateData) Old(dest ...any) error {
+ if d.Op == sqlite3.SQLITE_INSERT {
+ return errors.New("there is no old row for INSERT operations")
+ }
+ return d.row(dest, false)
+}
+
+// New populates dest with the replacement row data. This works similar to
+// database/sql's Rows.Scan()
+func (d *SQLitePreUpdateData) New(dest ...any) error {
+ if d.Op == sqlite3.SQLITE_DELETE {
+ return errors.New("there is no new row for DELETE operations")
+ }
+ return d.row(dest, true)
+}
+
+const ptrValSize = types.Size_t(unsafe.Sizeof(&sqlite3.Sqlite3_value{}))
+
+func mallocValue(tls *libc.TLS) (uintptr, error) {
+ p := libc.Xmalloc(tls, ptrValSize)
+ if p == 0 {
+ return 0, fmt.Errorf("out of memory")
+ }
+ return p, nil
+}
+
+func (d *SQLitePreUpdateData) value(ppValue uintptr, i int, new bool) (any, error) {
+ var src any
+ if new {
+ sqlite3.Xsqlite3_preupdate_new(d.tls, d.pCsr, int32(i), ppValue)
+ } else {
+ sqlite3.Xsqlite3_preupdate_old(d.tls, d.pCsr, int32(i), ppValue)
+ }
+ ptrValue := *(*uintptr)(unsafe.Pointer(ppValue))
+ switch sqlite3.Xsqlite3_value_type(d.tls, ptrValue) {
+ case sqlite3.SQLITE_INTEGER:
+ src = int64(sqlite3.Xsqlite3_value_int64(d.tls, ptrValue))
+ case sqlite3.SQLITE_FLOAT:
+ src = float64(sqlite3.Xsqlite3_value_double(d.tls, ptrValue))
+ case sqlite3.SQLITE_BLOB:
+ size := sqlite3.Xsqlite3_value_bytes(d.tls, ptrValue)
+ blobPtr := sqlite3.Xsqlite3_value_blob(d.tls, ptrValue)
+
+ var v []byte
+ if size != 0 {
+ v = make([]byte, size)
+ copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size])
+ }
+ src = v
+ case sqlite3.SQLITE_TEXT:
+ src = libc.GoString(sqlite3.Xsqlite3_value_text(d.tls, ptrValue))
+ case sqlite3.SQLITE_NULL:
+ src = nil
+ }
+ return src, nil
+}
+
+func preUpdateHookTrampoline(tls *libc.TLS, handle uintptr, pCsr uintptr, op int32, zDb uintptr, pTab uintptr, iKey1 int64, iReg int32, iBlobWrite int32) {
+ xPreUpdateHandlers.mu.RLock()
+ xPreUpdateHandler := xPreUpdateHandlers.m[handle]
+ xPreUpdateHandlers.mu.RUnlock()
+
+ if xPreUpdateHandler == nil {
+ return
+ }
+ data := SQLitePreUpdateData{
+ tls: tls,
+ pCsr: pCsr,
+ Op: op,
+ DatabaseName: libc.GoString(zDb),
+ TableName: libc.GoString(pTab),
+ OldRowID: iKey1,
+ NewRowID: int64(iReg),
+ }
+ xPreUpdateHandler(data)
+}
+
+func commitHookTrampoline(tls *libc.TLS, handle uintptr, pCsr uintptr) int32 {
+ xCommitHandlers.mu.RLock()
+ xCommitHandler := xCommitHandlers.m[handle]
+ xCommitHandlers.mu.RUnlock()
+
+ if xCommitHandler == nil {
+ return 0
+ }
+
+ return xCommitHandler()
+}
+
+func rollbackHookTrampoline(tls *libc.TLS, handle uintptr, pCsr uintptr) {
+ xRollbackHandlers.mu.RLock()
+ xRollbackHandler := xRollbackHandlers.m[handle]
+ xRollbackHandlers.mu.RUnlock()
+
+ if xRollbackHandler == nil {
+ return
+ }
+
+ xRollbackHandler()
+}