summaryrefslogtreecommitdiff
path: root/vendor/github.com/ncruces/go-sqlite3/func.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/ncruces/go-sqlite3/func.go')
-rw-r--r--vendor/github.com/ncruces/go-sqlite3/func.go153
1 files changed, 114 insertions, 39 deletions
diff --git a/vendor/github.com/ncruces/go-sqlite3/func.go b/vendor/github.com/ncruces/go-sqlite3/func.go
index f907fa940..16b43056d 100644
--- a/vendor/github.com/ncruces/go-sqlite3/func.go
+++ b/vendor/github.com/ncruces/go-sqlite3/func.go
@@ -3,7 +3,9 @@ package sqlite3
import (
"context"
"io"
+ "iter"
"sync"
+ "sync/atomic"
"github.com/tetratelabs/wazero/api"
@@ -45,7 +47,7 @@ func (c Conn) AnyCollationNeeded() error {
// CreateCollation defines a new collating sequence.
//
// https://sqlite.org/c3ref/create_collation.html
-func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
+func (c *Conn) CreateCollation(name string, fn CollatingFunction) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
@@ -57,6 +59,10 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
return c.error(rc)
}
+// Collating function is the type of a collation callback.
+// Implementations must not retain a or b.
+type CollatingFunction func(a, b []byte) int
+
// CreateFunction defines a new scalar SQL function.
//
// https://sqlite.org/c3ref/create_function.html
@@ -77,34 +83,67 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala
// Implementations must not retain arg.
type ScalarFunction func(ctx Context, arg ...Value)
+// CreateAggregateFunction defines a new aggregate SQL function.
+//
+// https://sqlite.org/c3ref/create_function.html
+func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error {
+ var funcPtr ptr_t
+ defer c.arena.mark()()
+ namePtr := c.arena.string(name)
+ if fn != nil {
+ funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
+ var a aggregateFunc
+ coro := func(yieldCoro func(struct{}) bool) {
+ seq := func(yieldSeq func([]Value) bool) {
+ for yieldSeq(a.arg) {
+ if !yieldCoro(struct{}{}) {
+ break
+ }
+ }
+ }
+ fn(&a.ctx, seq)
+ }
+ a.next, a.stop = iter.Pull(coro)
+ return &a
+ }))
+ }
+ rc := res_t(c.call("sqlite3_create_aggregate_function_go",
+ stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
+ stk_t(flag), stk_t(funcPtr)))
+ return c.error(rc)
+}
+
+// AggregateSeqFunction is the type of an aggregate SQL function.
+// Implementations must not retain the slices yielded by seq.
+type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value])
+
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
-// If fn returns a [WindowFunction], then an aggregate window function is created.
+// If fn returns a [WindowFunction], an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://sqlite.org/c3ref/create_function.html
-func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
+func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
- call := "sqlite3_create_aggregate_function_go"
if fn != nil {
- agg := fn()
- if c, ok := agg.(io.Closer); ok {
- if err := c.Close(); err != nil {
- return err
+ funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
+ agg := fn()
+ if win, ok := agg.(WindowFunction); ok {
+ return win
}
- }
- if _, ok := agg.(WindowFunction); ok {
- call = "sqlite3_create_window_function_go"
- }
- funcPtr = util.AddHandle(c.ctx, fn)
+ return windowFunc{agg, name}
+ }))
}
- rc := res_t(c.call(call,
+ rc := res_t(c.call("sqlite3_create_window_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
+// AggregateConstructor is a an [AggregateFunction] constructor.
+type AggregateConstructor func() AggregateFunction
+
// AggregateFunction is the interface an aggregate function should implement.
//
// https://sqlite.org/appfunc.html
@@ -153,26 +192,24 @@ func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTe
}
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 {
- fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
+ fn := util.GetHandle(ctx, pApp).(CollatingFunction)
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2))))
}
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) {
- args := getFuncArgs()
- defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
+ args := callbackArgs(db, nArg, pArg)
+ defer returnArgs(args)
fn := util.GetHandle(db.ctx, pApp).(ScalarFunction)
- callbackArgs(db, args[:nArg], pArg)
- fn(Context{db, pCtx}, args[:nArg]...)
+ fn(Context{db, pCtx}, *args...)
}
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) {
- args := getFuncArgs()
- defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
- callbackArgs(db, args[:nArg], pArg)
+ args := callbackArgs(db, nArg, pArg)
+ defer returnArgs(args)
fn, _ := callbackAggregate(db, pAgg, pApp)
- fn.Step(Context{db, pCtx}, args[:nArg]...)
+ fn.Step(Context{db, pCtx}, *args...)
}
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) {
@@ -196,12 +233,11 @@ func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t,
}
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) {
- args := getFuncArgs()
- defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
- callbackArgs(db, args[:nArg], pArg)
+ args := callbackArgs(db, nArg, pArg)
+ defer returnArgs(args)
fn := util.GetHandle(db.ctx, pAgg).(WindowFunction)
- fn.Inverse(Context{db, pCtx}, args[:nArg]...)
+ fn.Inverse(Context{db, pCtx}, *args...)
}
func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
@@ -211,7 +247,7 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
}
// We need to create the aggregate.
- fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
+ fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)()
if pAgg != 0 {
handle := util.AddHandle(db.ctx, fn)
util.Write32(db.mod, pAgg, handle)
@@ -220,25 +256,64 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
return fn, 0
}
-func callbackArgs(db *Conn, arg []Value, pArg ptr_t) {
- for i := range arg {
- arg[i] = Value{
+var (
+ valueArgsPool sync.Pool
+ valueArgsLen atomic.Int32
+)
+
+func callbackArgs(db *Conn, nArg int32, pArg ptr_t) *[]Value {
+ arg, ok := valueArgsPool.Get().(*[]Value)
+ if !ok || cap(*arg) < int(nArg) {
+ max := valueArgsLen.Or(nArg) | nArg
+ lst := make([]Value, max)
+ arg = &lst
+ }
+ lst := (*arg)[:nArg]
+ for i := range lst {
+ lst[i] = Value{
c: db,
handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen),
}
}
+ *arg = lst
+ return arg
}
-var funcArgsPool sync.Pool
+func returnArgs(p *[]Value) {
+ valueArgsPool.Put(p)
+}
-func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
- funcArgsPool.Put(p)
+type aggregateFunc struct {
+ next func() (struct{}, bool)
+ stop func()
+ ctx Context
+ arg []Value
}
-func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
- if p := funcArgsPool.Get(); p == nil {
- return new([_MAX_FUNCTION_ARG]Value)
- } else {
- return p.(*[_MAX_FUNCTION_ARG]Value)
+func (a *aggregateFunc) Step(ctx Context, arg ...Value) {
+ a.ctx = ctx
+ a.arg = append(a.arg[:0], arg...)
+ if _, more := a.next(); !more {
+ a.stop()
}
}
+
+func (a *aggregateFunc) Value(ctx Context) {
+ a.ctx = ctx
+ a.stop()
+}
+
+func (a *aggregateFunc) Close() error {
+ a.stop()
+ return nil
+}
+
+type windowFunc struct {
+ AggregateFunction
+ name string
+}
+
+func (w windowFunc) Inverse(ctx Context, arg ...Value) {
+ // Implementing inverse allows certain queries that don't really need it to succeed.
+ ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function"))
+}