diff options
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/internal/sanitize')
| -rw-r--r-- | vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmmark.sh | 60 | ||||
| -rw-r--r-- | vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go | 183 |
2 files changed, 216 insertions, 27 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmmark.sh b/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmmark.sh new file mode 100644 index 000000000..ec0f7b03a --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmmark.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +current_branch=$(git rev-parse --abbrev-ref HEAD) +if [ "$current_branch" == "HEAD" ]; then + current_branch=$(git rev-parse HEAD) +fi + +restore_branch() { + echo "Restoring original branch/commit: $current_branch" + git checkout "$current_branch" +} +trap restore_branch EXIT + +# Check if there are uncommitted changes +if ! git diff --quiet || ! git diff --cached --quiet; then + echo "There are uncommitted changes. Please commit or stash them before running this script." + exit 1 +fi + +# Ensure that at least one commit argument is passed +if [ "$#" -lt 1 ]; then + echo "Usage: $0 <commit1> <commit2> ... <commitN>" + exit 1 +fi + +commits=("$@") +benchmarks_dir=benchmarks + +if ! mkdir -p "${benchmarks_dir}"; then + echo "Unable to create dir for benchmarks data" + exit 1 +fi + +# Benchmark results +bench_files=() + +# Run benchmark for each listed commit +for i in "${!commits[@]}"; do + commit="${commits[i]}" + git checkout "$commit" || { + echo "Failed to checkout $commit" + exit 1 + } + + # Sanitized commmit message + commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_') + + # Benchmark data will go there + bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" + + if ! go test -bench=. -count=10 >"$bench_file"; then + echo "Benchmarking failed for commit $commit" + exit 1 + fi + + bench_files+=("$bench_file") +done + +# go install golang.org/x/perf/cmd/benchstat[@latest] +benchstat "${bench_files[@]}" diff --git a/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go b/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go index df58c4484..b516817cb 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go +++ b/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/hex" "fmt" + "slices" "strconv" "strings" + "sync" "time" "unicode/utf8" ) @@ -24,18 +26,33 @@ type Query struct { // https://github.com/jackc/pgx/issues/1380 const replacementcharacterwidth = 3 +const maxBufSize = 16384 // 16 Ki + +var bufPool = &pool[*bytes.Buffer]{ + new: func() *bytes.Buffer { + return &bytes.Buffer{} + }, + reset: func(b *bytes.Buffer) bool { + n := b.Len() + b.Reset() + return n < maxBufSize + }, +} + +var null = []byte("null") + func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) - buf := &bytes.Buffer{} + buf := bufPool.get() + defer bufPool.put(buf) for _, part := range q.Parts { - var str string switch part := part.(type) { case string: - str = part + buf.WriteString(part) case int: argIdx := part - 1 - + var p []byte if argIdx < 0 { return "", fmt.Errorf("first sql argument must be > 0") } @@ -43,34 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) { if argIdx >= len(args) { return "", fmt.Errorf("insufficient arguments") } + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + buf.WriteByte(' ') + arg := args[argIdx] switch arg := arg.(type) { case nil: - str = "null" + p = null case int64: - str = strconv.FormatInt(arg, 10) + p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10) case float64: - str = strconv.FormatFloat(arg, 'f', -1, 64) + p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64) case bool: - str = strconv.FormatBool(arg) + p = strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: - str = QuoteBytes(arg) + p = QuoteBytes(buf.AvailableBuffer(), arg) case string: - str = QuoteString(arg) + p = QuoteString(buf.AvailableBuffer(), arg) case time.Time: - str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + p = arg.Truncate(time.Microsecond). + AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") default: return "", fmt.Errorf("invalid arg type: %T", arg) } argUse[argIdx] = true + buf.Write(p) + // Prevent SQL injection via Line Comment Creation // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p - str = " " + str + " " + buf.WriteByte(' ') default: return "", fmt.Errorf("invalid Part type: %T", part) } - buf.WriteString(str) } for i, used := range argUse { @@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) { } func NewQuery(sql string) (*Query, error) { - l := &sqlLexer{ - src: sql, - stateFn: rawState, + query := &Query{} + query.init(sql) + + return query, nil +} + +var sqlLexerPool = &pool[*sqlLexer]{ + new: func() *sqlLexer { + return &sqlLexer{} + }, + reset: func(sl *sqlLexer) bool { + *sl = sqlLexer{} + return true + }, +} + +func (q *Query) init(sql string) { + parts := q.Parts[:0] + if parts == nil { + // dirty, but fast heuristic to preallocate for ~90% usecases + n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1 + parts = make([]Part, 0, n) } + l := sqlLexerPool.get() + defer sqlLexerPool.put(l) + + l.src = sql + l.stateFn = rawState + l.parts = parts + for l.stateFn != nil { l.stateFn = l.stateFn(l) } - query := &Query{Parts: l.parts} - - return query, nil + q.Parts = l.parts } -func QuoteString(str string) string { - return "'" + strings.ReplaceAll(str, "'", "''") + "'" +func QuoteString(dst []byte, str string) []byte { + const quote = '\'' + + // Preallocate space for the worst case scenario + dst = slices.Grow(dst, len(str)*2+2) + + // Add opening quote + dst = append(dst, quote) + + // Iterate through the string without allocating + for i := 0; i < len(str); i++ { + if str[i] == quote { + dst = append(dst, quote, quote) + } else { + dst = append(dst, str[i]) + } + } + + // Add closing quote + dst = append(dst, quote) + + return dst } -func QuoteBytes(buf []byte) string { - return `'\x` + hex.EncodeToString(buf) + "'" +func QuoteBytes(dst, buf []byte) []byte { + if len(buf) == 0 { + return append(dst, `'\x'`...) + } + + // Calculate required length + requiredLen := 3 + hex.EncodedLen(len(buf)) + 1 + + // Ensure dst has enough capacity + if cap(dst)-len(dst) < requiredLen { + newDst := make([]byte, len(dst), len(dst)+requiredLen) + copy(newDst, dst) + dst = newDst + } + + // Record original length and extend slice + origLen := len(dst) + dst = dst[:origLen+requiredLen] + + // Add prefix + dst[origLen] = '\'' + dst[origLen+1] = '\\' + dst[origLen+2] = 'x' + + // Encode bytes directly into dst + hex.Encode(dst[origLen+3:len(dst)-1], buf) + + // Add suffix + dst[len(dst)-1] = '\'' + + return dst } type sqlLexer struct { @@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn { } } +var queryPool = &pool[*Query]{ + new: func() *Query { + return &Query{} + }, + reset: func(q *Query) bool { + n := len(q.Parts) + q.Parts = q.Parts[:0] + return n < 64 // drop too large queries + }, +} + // SanitizeSQL replaces placeholder values with args. It quotes and escapes args // as necessary. This function is only safe when standard_conforming_strings is // on. func SanitizeSQL(sql string, args ...any) (string, error) { - query, err := NewQuery(sql) - if err != nil { - return "", err - } + query := queryPool.get() + query.init(sql) + defer queryPool.put(query) + return query.Sanitize(args...) } + +type pool[E any] struct { + p sync.Pool + new func() E + reset func(E) bool +} + +func (pool *pool[E]) get() E { + v, ok := pool.p.Get().(E) + if !ok { + v = pool.new() + } + + return v +} + +func (p *pool[E]) put(v E) { + if p.reset(v) { + p.p.Put(v) + } +} |
