summaryrefslogtreecommitdiff
path: root/vendor/github.com/ncruces/go-sqlite3/func.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/func.go')
-rw-r--r--vendor/github.com/ncruces/go-sqlite3/func.go39
1 files changed, 25 insertions, 14 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/func.go b/vendor/github.com/ncruces/go-sqlite3/func.go
index 6b69368b4..f907fa940 100644
--- a/vendor/github.com/ncruces/go-sqlite3/func.go
+++ b/vendor/github.com/ncruces/go-sqlite3/func.go
@@ -2,6 +2,7 @@ package sqlite3
import (
"context"
+ "io"
"sync"
"github.com/tetratelabs/wazero/api"
@@ -85,13 +86,19 @@ func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
+ call := "sqlite3_create_aggregate_function_go"
if fn != nil {
+ agg := fn()
+ if c, ok := agg.(io.Closer); ok {
+ if err := c.Close(); err != nil {
+ return err
+ }
+ }
+ if _, ok := agg.(WindowFunction); ok {
+ call = "sqlite3_create_window_function_go"
+ }
funcPtr = util.AddHandle(c.ctx, fn)
}
- call := "sqlite3_create_aggregate_function_go"
- if _, ok := fn().(WindowFunction); ok {
- call = "sqlite3_create_window_function_go"
- }
rc := res_t(c.call(call,
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
@@ -168,20 +175,24 @@ func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, n
fn.Step(Context{db, pCtx}, args[:nArg]...)
}
-func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t) {
+func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) {
db := ctx.Value(connKey{}).(*Conn)
fn, handle := callbackAggregate(db, pAgg, pApp)
fn.Value(Context{db, pCtx})
- if err := util.DelHandle(ctx, handle); err != nil {
- Context{db, pCtx}.ResultError(err)
- return // notest
- }
-}
-func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t) {
- db := ctx.Value(connKey{}).(*Conn)
- fn := util.GetHandle(db.ctx, pAgg).(AggregateFunction)
- fn.Value(Context{db, pCtx})
+ // Cleanup.
+ if final != 0 {
+ var err error
+ if handle != 0 {
+ err = util.DelHandle(ctx, handle)
+ } else if c, ok := fn.(io.Closer); ok {
+ err = c.Close()
+ }
+ if err != nil {
+ Context{db, pCtx}.ResultError(err)
+ return // notest
+ }
+ }
}
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) {