summaryrefslogtreecommitdiff
path: root/vendor/github.com/ncruces/go-sqlite3/txn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/txn.go')
-rw-r--r--vendor/github.com/ncruces/go-sqlite3/txn.go294
1 files changed, 294 insertions, 0 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/txn.go b/vendor/github.com/ncruces/go-sqlite3/txn.go
new file mode 100644
index 000000000..0efbc2d80
--- /dev/null
+++ b/vendor/github.com/ncruces/go-sqlite3/txn.go
@@ -0,0 +1,294 @@
+package sqlite3
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math/rand"
+ "runtime"
+ "strconv"
+ "strings"
+
+ "github.com/ncruces/go-sqlite3/internal/util"
+ "github.com/tetratelabs/wazero/api"
+)
+
+// Txn is an in-progress database transaction.
+//
+// https://sqlite.org/lang_transaction.html
+type Txn struct {
+ c *Conn
+}
+
+// Begin starts a deferred transaction.
+//
+// https://sqlite.org/lang_transaction.html
+func (c *Conn) Begin() Txn {
+ // BEGIN even if interrupted.
+ err := c.txnExecInterrupted(`BEGIN DEFERRED`)
+ if err != nil {
+ panic(err)
+ }
+ return Txn{c}
+}
+
+// BeginImmediate starts an immediate transaction.
+//
+// https://sqlite.org/lang_transaction.html
+func (c *Conn) BeginImmediate() (Txn, error) {
+ err := c.Exec(`BEGIN IMMEDIATE`)
+ if err != nil {
+ return Txn{}, err
+ }
+ return Txn{c}, nil
+}
+
+// BeginExclusive starts an exclusive transaction.
+//
+// https://sqlite.org/lang_transaction.html
+func (c *Conn) BeginExclusive() (Txn, error) {
+ err := c.Exec(`BEGIN EXCLUSIVE`)
+ if err != nil {
+ return Txn{}, err
+ }
+ return Txn{c}, nil
+}
+
+// End calls either [Txn.Commit] or [Txn.Rollback]
+// depending on whether *error points to a nil or non-nil error.
+//
+// This is meant to be deferred:
+//
+// func doWork(db *sqlite3.Conn) (err error) {
+// tx := db.Begin()
+// defer tx.End(&err)
+//
+// // ... do work in the transaction
+// }
+//
+// https://sqlite.org/lang_transaction.html
+func (tx Txn) End(errp *error) {
+ recovered := recover()
+ if recovered != nil {
+ defer panic(recovered)
+ }
+
+ if *errp == nil && recovered == nil {
+ // Success path.
+ if tx.c.GetAutocommit() { // There is nothing to commit.
+ return
+ }
+ *errp = tx.Commit()
+ if *errp == nil {
+ return
+ }
+ // Fall through to the error path.
+ }
+
+ // Error path.
+ if tx.c.GetAutocommit() { // There is nothing to rollback.
+ return
+ }
+ err := tx.Rollback()
+ if err != nil {
+ panic(err)
+ }
+}
+
+// Commit commits the transaction.
+//
+// https://sqlite.org/lang_transaction.html
+func (tx Txn) Commit() error {
+ return tx.c.Exec(`COMMIT`)
+}
+
+// Rollback rolls back the transaction,
+// even if the connection has been interrupted.
+//
+// https://sqlite.org/lang_transaction.html
+func (tx Txn) Rollback() error {
+ return tx.c.txnExecInterrupted(`ROLLBACK`)
+}
+
+// Savepoint is a marker within a transaction
+// that allows for partial rollback.
+//
+// https://sqlite.org/lang_savepoint.html
+type Savepoint struct {
+ c *Conn
+ name string
+}
+
+// Savepoint establishes a new transaction savepoint.
+//
+// https://sqlite.org/lang_savepoint.html
+func (c *Conn) Savepoint() Savepoint {
+ // Names can be reused; this makes catching bugs more likely.
+ name := saveptName() + "_" + strconv.Itoa(int(rand.Int31()))
+
+ err := c.txnExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
+ if err != nil {
+ panic(err)
+ }
+ return Savepoint{c: c, name: name}
+}
+
+func saveptName() (name string) {
+ defer func() {
+ if name == "" {
+ name = "sqlite3.Savepoint"
+ }
+ }()
+
+ var pc [8]uintptr
+ n := runtime.Callers(3, pc[:])
+ if n <= 0 {
+ return ""
+ }
+ frames := runtime.CallersFrames(pc[:n])
+ frame, more := frames.Next()
+ for more && (strings.HasPrefix(frame.Function, "database/sql.") ||
+ strings.HasPrefix(frame.Function, "github.com/ncruces/go-sqlite3/driver.")) {
+ frame, more = frames.Next()
+ }
+ return frame.Function
+}
+
+// Release releases the savepoint rolling back any changes
+// if *error points to a non-nil error.
+//
+// This is meant to be deferred:
+//
+// func doWork(db *sqlite3.Conn) (err error) {
+// savept := db.Savepoint()
+// defer savept.Release(&err)
+//
+// // ... do work in the transaction
+// }
+func (s Savepoint) Release(errp *error) {
+ recovered := recover()
+ if recovered != nil {
+ defer panic(recovered)
+ }
+
+ if *errp == nil && recovered == nil {
+ // Success path.
+ if s.c.GetAutocommit() { // There is nothing to commit.
+ return
+ }
+ *errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
+ if *errp == nil {
+ return
+ }
+ // Fall through to the error path.
+ }
+
+ // Error path.
+ if s.c.GetAutocommit() { // There is nothing to rollback.
+ return
+ }
+ // ROLLBACK and RELEASE even if interrupted.
+ err := s.c.txnExecInterrupted(fmt.Sprintf(`
+ ROLLBACK TO %[1]q;
+ RELEASE %[1]q;
+ `, s.name))
+ if err != nil {
+ panic(err)
+ }
+}
+
+// Rollback rolls the transaction back to the savepoint,
+// even if the connection has been interrupted.
+// Rollback does not release the savepoint.
+//
+// https://sqlite.org/lang_transaction.html
+func (s Savepoint) Rollback() error {
+ // ROLLBACK even if interrupted.
+ return s.c.txnExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
+}
+
+func (c *Conn) txnExecInterrupted(sql string) error {
+ err := c.Exec(sql)
+ if errors.Is(err, INTERRUPT) {
+ old := c.SetInterrupt(context.Background())
+ defer c.SetInterrupt(old)
+ err = c.Exec(sql)
+ }
+ return err
+}
+
+// TxnState starts a deferred transaction.
+//
+// https://sqlite.org/c3ref/txn_state.html
+func (c *Conn) TxnState(schema string) TxnState {
+ var ptr uint32
+ if schema != "" {
+ defer c.arena.mark()()
+ ptr = c.arena.string(schema)
+ }
+ r := c.call("sqlite3_txn_state", uint64(c.handle), uint64(ptr))
+ return TxnState(r)
+}
+
+// CommitHook registers a callback function to be invoked
+// whenever a transaction is committed.
+// Return true to allow the commit operation to continue normally.
+//
+// https://sqlite.org/c3ref/commit_hook.html
+func (c *Conn) CommitHook(cb func() (ok bool)) {
+ var enable uint64
+ if cb != nil {
+ enable = 1
+ }
+ c.call("sqlite3_commit_hook_go", uint64(c.handle), enable)
+ c.commit = cb
+}
+
+// RollbackHook registers a callback function to be invoked
+// whenever a transaction is rolled back.
+//
+// https://sqlite.org/c3ref/commit_hook.html
+func (c *Conn) RollbackHook(cb func()) {
+ var enable uint64
+ if cb != nil {
+ enable = 1
+ }
+ c.call("sqlite3_rollback_hook_go", uint64(c.handle), enable)
+ c.rollback = cb
+}
+
+// UpdateHook registers a callback function to be invoked
+// whenever a row is updated, inserted or deleted in a rowid table.
+//
+// https://sqlite.org/c3ref/update_hook.html
+func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table string, rowid int64)) {
+ var enable uint64
+ if cb != nil {
+ enable = 1
+ }
+ c.call("sqlite3_update_hook_go", uint64(c.handle), enable)
+ c.update = cb
+}
+
+func commitCallback(ctx context.Context, mod api.Module, pDB uint32) (rollback uint32) {
+ if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil {
+ if !c.commit() {
+ rollback = 1
+ }
+ }
+ return rollback
+}
+
+func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) {
+ if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.rollback != nil {
+ c.rollback()
+ }
+}
+
+func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zSchema, zTabName uint32, rowid uint64) {
+ if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.update != nil {
+ schema := util.ReadString(mod, zSchema, _MAX_NAME)
+ table := util.ReadString(mod, zTabName, _MAX_NAME)
+ c.update(action, schema, table, int64(rowid))
+ }
+}