summaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v4/internal/sanitize
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/v4/internal/sanitize')
-rw-r--r--vendor/github.com/jackc/pgx/v4/internal/sanitize/sanitize.go304
1 files changed, 304 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/v4/internal/sanitize/sanitize.go b/vendor/github.com/jackc/pgx/v4/internal/sanitize/sanitize.go
new file mode 100644
index 000000000..2dba3b810
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v4/internal/sanitize/sanitize.go
@@ -0,0 +1,304 @@
+package sanitize
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+ "unicode/utf8"
+)
+
+// Part is either a string or an int. A string is raw SQL. An int is a
+// argument placeholder.
+type Part interface{}
+
+type Query struct {
+ Parts []Part
+}
+
+func (q *Query) Sanitize(args ...interface{}) (string, error) {
+ argUse := make([]bool, len(args))
+ buf := &bytes.Buffer{}
+
+ for _, part := range q.Parts {
+ var str string
+ switch part := part.(type) {
+ case string:
+ str = part
+ case int:
+ argIdx := part - 1
+ if argIdx >= len(args) {
+ return "", fmt.Errorf("insufficient arguments")
+ }
+ arg := args[argIdx]
+ switch arg := arg.(type) {
+ case nil:
+ str = "null"
+ case int64:
+ str = strconv.FormatInt(arg, 10)
+ case float64:
+ str = strconv.FormatFloat(arg, 'f', -1, 64)
+ case bool:
+ str = strconv.FormatBool(arg)
+ case []byte:
+ str = QuoteBytes(arg)
+ case string:
+ str = QuoteString(arg)
+ case time.Time:
+ str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
+ default:
+ return "", fmt.Errorf("invalid arg type: %T", arg)
+ }
+ argUse[argIdx] = true
+ default:
+ return "", fmt.Errorf("invalid Part type: %T", part)
+ }
+ buf.WriteString(str)
+ }
+
+ for i, used := range argUse {
+ if !used {
+ return "", fmt.Errorf("unused argument: %d", i)
+ }
+ }
+ return buf.String(), nil
+}
+
+func NewQuery(sql string) (*Query, error) {
+ l := &sqlLexer{
+ src: sql,
+ stateFn: rawState,
+ }
+
+ for l.stateFn != nil {
+ l.stateFn = l.stateFn(l)
+ }
+
+ query := &Query{Parts: l.parts}
+
+ return query, nil
+}
+
+func QuoteString(str string) string {
+ return "'" + strings.ReplaceAll(str, "'", "''") + "'"
+}
+
+func QuoteBytes(buf []byte) string {
+ return `'\x` + hex.EncodeToString(buf) + "'"
+}
+
+type sqlLexer struct {
+ src string
+ start int
+ pos int
+ nested int // multiline comment nesting level.
+ stateFn stateFn
+ parts []Part
+}
+
+type stateFn func(*sqlLexer) stateFn
+
+func rawState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case 'e', 'E':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune == '\'' {
+ l.pos += width
+ return escapeStringState
+ }
+ case '\'':
+ return singleQuoteState
+ case '"':
+ return doubleQuoteState
+ case '$':
+ nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
+ if '0' <= nextRune && nextRune <= '9' {
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos-width])
+ }
+ l.start = l.pos
+ return placeholderState
+ }
+ case '-':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune == '-' {
+ l.pos += width
+ return oneLineCommentState
+ }
+ case '/':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune == '*' {
+ l.pos += width
+ return multilineCommentState
+ }
+ case utf8.RuneError:
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+}
+
+func singleQuoteState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case '\'':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune != '\'' {
+ return rawState
+ }
+ l.pos += width
+ case utf8.RuneError:
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+}
+
+func doubleQuoteState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case '"':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune != '"' {
+ return rawState
+ }
+ l.pos += width
+ case utf8.RuneError:
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+}
+
+// placeholderState consumes a placeholder value. The $ must have already has
+// already been consumed. The first rune must be a digit.
+func placeholderState(l *sqlLexer) stateFn {
+ num := 0
+
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ if '0' <= r && r <= '9' {
+ num *= 10
+ num += int(r - '0')
+ } else {
+ l.parts = append(l.parts, num)
+ l.pos -= width
+ l.start = l.pos
+ return rawState
+ }
+ }
+}
+
+func escapeStringState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case '\\':
+ _, width = utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+ case '\'':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune != '\'' {
+ return rawState
+ }
+ l.pos += width
+ case utf8.RuneError:
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+}
+
+func oneLineCommentState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case '\\':
+ _, width = utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+ case '\n':
+ return rawState
+ case utf8.RuneError:
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+}
+
+func multilineCommentState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case '/':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune == '*' {
+ l.pos += width
+ l.nested++
+ }
+ case '*':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune != '/' {
+ continue
+ }
+
+ l.pos += width
+ if l.nested == 0 {
+ return rawState
+ }
+ l.nested--
+
+ case utf8.RuneError:
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+}
+
+// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
+// as necessary. This function is only safe when standard_conforming_strings is
+// on.
+func SanitizeSQL(sql string, args ...interface{}) (string, error) {
+ query, err := NewQuery(sql)
+ if err != nil {
+ return "", err
+ }
+ return query.Sanitize(args...)
+}