summaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/named_args.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/named_args.go')
-rw-r--r--vendor/github.com/jackc/pgx/v5/named_args.go58
1 files changed, 42 insertions, 16 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/named_args.go b/vendor/github.com/jackc/pgx/v5/named_args.go
index 8367fc63a..c88991ee4 100644
--- a/vendor/github.com/jackc/pgx/v5/named_args.go
+++ b/vendor/github.com/jackc/pgx/v5/named_args.go
@@ -2,6 +2,7 @@ package pgx
import (
"context"
+ "fmt"
"strconv"
"strings"
"unicode/utf8"
@@ -21,6 +22,34 @@ type NamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
+ return rewriteQuery(na, sql, false)
+}
+
+// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
+// named arguments that the sql query uses, and no extra arguments.
+type StrictNamedArgs map[string]any
+
+// RewriteQuery implements the QueryRewriter interface.
+func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
+ return rewriteQuery(sna, sql, true)
+}
+
+type namedArg string
+
+type sqlLexer struct {
+ src string
+ start int
+ pos int
+ nested int // multiline comment nesting level.
+ stateFn stateFn
+ parts []any
+
+ nameToOrdinal map[namedArg]int
+}
+
+type stateFn func(*sqlLexer) stateFn
+
+func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
@@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
newArgs = make([]any, len(l.nameToOrdinal))
for name, ordinal := range l.nameToOrdinal {
- newArgs[ordinal-1] = na[string(name)]
+ var found bool
+ newArgs[ordinal-1], found = na[string(name)]
+ if isStrict && !found {
+ return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
+ }
}
- return sb.String(), newArgs, nil
-}
-
-type namedArg string
-
-type sqlLexer struct {
- src string
- start int
- pos int
- nested int // multiline comment nesting level.
- stateFn stateFn
- parts []any
+ if isStrict {
+ for name := range na {
+ if _, found := l.nameToOrdinal[namedArg(name)]; !found {
+ return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
+ }
+ }
+ }
- nameToOrdinal map[namedArg]int
+ return sb.String(), newArgs, nil
}
-type stateFn func(*sqlLexer) stateFn
-
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])