diff options
Diffstat (limited to 'vendor/modernc.org/sqlite/sqlite.go')
-rw-r--r-- | vendor/modernc.org/sqlite/sqlite.go | 37 |
1 files changed, 35 insertions, 2 deletions
diff --git a/vendor/modernc.org/sqlite/sqlite.go b/vendor/modernc.org/sqlite/sqlite.go index ec5b2c0f9..9763713fa 100644 --- a/vendor/modernc.org/sqlite/sqlite.go +++ b/vendor/modernc.org/sqlite/sqlite.go @@ -1844,17 +1844,32 @@ func (b *Backup) Finish() error { } } +type ExecQuerierContext interface { + driver.ExecerContext + driver.QueryerContext +} + +// ConnectionHookFn function type for a connection hook on the Driver. Connection +// hooks are called after the connection has been set up. +type ConnectionHookFn func( + conn ExecQuerierContext, + dsn string, +) error + // Driver implements database/sql/driver.Driver. type Driver struct { // user defined functions that are added to every new connection on Open udfs map[string]*userDefinedFunction // collations that are added to every new connection on Open collations map[string]*collation + // connection hooks are called after a connection is opened + connectionHooks []ConnectionHookFn } var d = &Driver{ - udfs: make(map[string]*userDefinedFunction, 0), - collations: make(map[string]*collation, 0), + udfs: make(map[string]*userDefinedFunction, 0), + collations: make(map[string]*collation, 0), + connectionHooks: make([]ConnectionHookFn, 0), } func newDriver() *Driver { return d } @@ -1909,6 +1924,12 @@ func (d *Driver) Open(name string) (conn driver.Conn, err error) { return nil, err } } + for _, connHookFn := range d.connectionHooks { + if err = connHookFn(c, name); err != nil { + c.Close() + return nil, fmt.Errorf("connection hook: %w", err) + } + } return c, nil } @@ -2063,6 +2084,18 @@ func registerFunction( return nil } +// RegisterConnectionHook registers a function to be called after each connection +// is opened. This is called after all the connection has been set up. +func (d *Driver) RegisterConnectionHook(fn ConnectionHookFn) { + d.connectionHooks = append(d.connectionHooks, fn) +} + +// RegisterConnectionHook registers a function to be called after each connection +// is opened. This is called after all the connection has been set up. +func RegisterConnectionHook(fn ConnectionHookFn) { + d.RegisterConnectionHook(fn) +} + func origin(skip int) string { pc, fn, fl, _ := runtime.Caller(skip) f := runtime.FuncForPC(pc) |