diff options
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/named_args.go')
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/named_args.go | 58 |
1 files changed, 42 insertions, 16 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/named_args.go b/vendor/github.com/jackc/pgx/v5/named_args.go index 8367fc63a..c88991ee4 100644 --- a/vendor/github.com/jackc/pgx/v5/named_args.go +++ b/vendor/github.com/jackc/pgx/v5/named_args.go @@ -2,6 +2,7 @@ package pgx import ( "context" + "fmt" "strconv" "strings" "unicode/utf8" @@ -21,6 +22,34 @@ type NamedArgs map[string]any // RewriteQuery implements the QueryRewriter interface. func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(na, sql, false) +} + +// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all +// named arguments that the sql query uses, and no extra arguments. +type StrictNamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(sna, sql, true) +} + +type namedArg string + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []any + + nameToOrdinal map[namedArg]int +} + +type stateFn func(*sqlLexer) stateFn + +func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) { l := &sqlLexer{ src: sql, stateFn: rawState, @@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar newArgs = make([]any, len(l.nameToOrdinal)) for name, ordinal := range l.nameToOrdinal { - newArgs[ordinal-1] = na[string(name)] + var found bool + newArgs[ordinal-1], found = na[string(name)] + if isStrict && !found { + return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name) + } } - return sb.String(), newArgs, nil -} - -type namedArg string - -type sqlLexer struct { - src string - start int - pos int - nested int // multiline comment nesting level. - stateFn stateFn - parts []any + if isStrict { + for name := range na { + if _, found := l.nameToOrdinal[namedArg(name)]; !found { + return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name) + } + } + } - nameToOrdinal map[namedArg]int + return sb.String(), newArgs, nil } -type stateFn func(*sqlLexer) stateFn - func rawState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) |