summaryrefslogtreecommitdiff
path: root/vendor/modernc.org/sqlite/sqlite.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/modernc.org/sqlite/sqlite.go')
-rw-r--r--vendor/modernc.org/sqlite/sqlite.go37
1 files changed, 35 insertions, 2 deletions
diff --git a/vendor/modernc.org/sqlite/sqlite.go b/vendor/modernc.org/sqlite/sqlite.go
index ec5b2c0f9..9763713fa 100644
--- a/vendor/modernc.org/sqlite/sqlite.go
+++ b/vendor/modernc.org/sqlite/sqlite.go
@@ -1844,17 +1844,32 @@ func (b *Backup) Finish() error {
}
}
+type ExecQuerierContext interface {
+ driver.ExecerContext
+ driver.QueryerContext
+}
+
+// ConnectionHookFn function type for a connection hook on the Driver. Connection
+// hooks are called after the connection has been set up.
+type ConnectionHookFn func(
+ conn ExecQuerierContext,
+ dsn string,
+) error
+
// Driver implements database/sql/driver.Driver.
type Driver struct {
// user defined functions that are added to every new connection on Open
udfs map[string]*userDefinedFunction
// collations that are added to every new connection on Open
collations map[string]*collation
+ // connection hooks are called after a connection is opened
+ connectionHooks []ConnectionHookFn
}
var d = &Driver{
- udfs: make(map[string]*userDefinedFunction, 0),
- collations: make(map[string]*collation, 0),
+ udfs: make(map[string]*userDefinedFunction, 0),
+ collations: make(map[string]*collation, 0),
+ connectionHooks: make([]ConnectionHookFn, 0),
}
func newDriver() *Driver { return d }
@@ -1909,6 +1924,12 @@ func (d *Driver) Open(name string) (conn driver.Conn, err error) {
return nil, err
}
}
+ for _, connHookFn := range d.connectionHooks {
+ if err = connHookFn(c, name); err != nil {
+ c.Close()
+ return nil, fmt.Errorf("connection hook: %w", err)
+ }
+ }
return c, nil
}
@@ -2063,6 +2084,18 @@ func registerFunction(
return nil
}
+// RegisterConnectionHook registers a function to be called after each connection
+// is opened. This is called after all the connection has been set up.
+func (d *Driver) RegisterConnectionHook(fn ConnectionHookFn) {
+ d.connectionHooks = append(d.connectionHooks, fn)
+}
+
+// RegisterConnectionHook registers a function to be called after each connection
+// is opened. This is called after all the connection has been set up.
+func RegisterConnectionHook(fn ConnectionHookFn) {
+ d.RegisterConnectionHook(fn)
+}
+
func origin(skip int) string {
pc, fn, fl, _ := runtime.Caller(skip)
f := runtime.FuncForPC(pc)