diff options
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/func.go')
| -rw-r--r-- | vendor/github.com/ncruces/go-sqlite3/func.go | 39 |
1 files changed, 25 insertions, 14 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/func.go b/vendor/github.com/ncruces/go-sqlite3/func.go index 6b69368b4..f907fa940 100644 --- a/vendor/github.com/ncruces/go-sqlite3/func.go +++ b/vendor/github.com/ncruces/go-sqlite3/func.go @@ -2,6 +2,7 @@ package sqlite3 import ( "context" + "io" "sync" "github.com/tetratelabs/wazero/api" @@ -85,13 +86,19 @@ func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) + call := "sqlite3_create_aggregate_function_go" if fn != nil { + agg := fn() + if c, ok := agg.(io.Closer); ok { + if err := c.Close(); err != nil { + return err + } + } + if _, ok := agg.(WindowFunction); ok { + call = "sqlite3_create_window_function_go" + } funcPtr = util.AddHandle(c.ctx, fn) } - call := "sqlite3_create_aggregate_function_go" - if _, ok := fn().(WindowFunction); ok { - call = "sqlite3_create_window_function_go" - } rc := res_t(c.call(call, stk_t(c.handle), stk_t(namePtr), stk_t(nArg), stk_t(flag), stk_t(funcPtr))) @@ -168,20 +175,24 @@ func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, n fn.Step(Context{db, pCtx}, args[:nArg]...) } -func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t) { +func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) { db := ctx.Value(connKey{}).(*Conn) fn, handle := callbackAggregate(db, pAgg, pApp) fn.Value(Context{db, pCtx}) - if err := util.DelHandle(ctx, handle); err != nil { - Context{db, pCtx}.ResultError(err) - return // notest - } -} -func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t) { - db := ctx.Value(connKey{}).(*Conn) - fn := util.GetHandle(db.ctx, pAgg).(AggregateFunction) - fn.Value(Context{db, pCtx}) + // Cleanup. + if final != 0 { + var err error + if handle != 0 { + err = util.DelHandle(ctx, handle) + } else if c, ok := fn.(io.Closer); ok { + err = c.Close() + } + if err != nil { + Context{db, pCtx}.ResultError(err) + return // notest + } + } } func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) { |
