diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/dialect/pgdialect/range.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/dialect/pgdialect/range.go | 132 |
1 files changed, 15 insertions, 117 deletions
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go index b942a068e..936ad5521 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go @@ -1,15 +1,12 @@ package pgdialect import ( - "bytes" "database/sql" - "encoding/hex" "fmt" "io" "time" "github.com/uptrace/bun/internal" - "github.com/uptrace/bun/internal/parser" "github.com/uptrace/bun/schema" ) @@ -41,7 +38,10 @@ func NewRange[T any](lower, upper T) Range[T] { var _ sql.Scanner = (*Range[any])(nil) func (r *Range[T]) Scan(anySrc any) (err error) { - src := anySrc.([]byte) + src, ok := anySrc.([]byte) + if !ok { + return fmt.Errorf("pgdialect: Range can't scan %T", anySrc) + } if len(src) == 0 { return io.ErrUnexpectedEOF @@ -90,18 +90,6 @@ func (r *Range[T]) AppendQuery(fmt schema.Formatter, buf []byte) ([]byte, error) return buf, nil } -func appendElem(buf []byte, val any) []byte { - switch val := val.(type) { - case time.Time: - buf = append(buf, '"') - buf = appendTime(buf, val) - buf = append(buf, '"') - return buf - default: - panic(fmt.Errorf("unsupported range type: %T", val)) - } -} - func scanElem(ptr any, src []byte) ([]byte, error) { switch ptr := ptr.(type) { case *time.Time: @@ -117,6 +105,17 @@ func scanElem(ptr any, src []byte) ([]byte, error) { *ptr = tm return src, nil + + case sql.Scanner: + src, str, err := readStringLiteral(src) + if err != nil { + return nil, err + } + if err := ptr.Scan(str); err != nil { + return nil, err + } + return src, nil + default: panic(fmt.Errorf("unsupported range type: %T", ptr)) } @@ -137,104 +136,3 @@ func readStringLiteral(src []byte) ([]byte, []byte, error) { src = p.Remaining() return src, str, nil } - -//------------------------------------------------------------------------------ - -type pgparser struct { - parser.Parser - buf []byte -} - -func newParser(b []byte) *pgparser { - p := new(pgparser) - p.Reset(b) - return p -} - -func (p *pgparser) ReadLiteral(ch byte) []byte { - p.Unread() - lit, _ := p.ReadSep(',') - return lit -} - -func (p *pgparser) ReadUnescapedSubstring(ch byte) ([]byte, error) { - return p.readSubstring(ch, false) -} - -func (p *pgparser) ReadSubstring(ch byte) ([]byte, error) { - return p.readSubstring(ch, true) -} - -func (p *pgparser) readSubstring(ch byte, escaped bool) ([]byte, error) { - ch, err := p.ReadByte() - if err != nil { - return nil, err - } - - p.buf = p.buf[:0] - for { - if ch == '"' { - break - } - - next, err := p.ReadByte() - if err != nil { - return nil, err - } - - if ch == '\\' { - switch next { - case '\\', '"': - p.buf = append(p.buf, next) - - ch, err = p.ReadByte() - if err != nil { - return nil, err - } - default: - p.buf = append(p.buf, '\\') - ch = next - } - continue - } - - if escaped && ch == '\'' && next == '\'' { - p.buf = append(p.buf, next) - ch, err = p.ReadByte() - if err != nil { - return nil, err - } - continue - } - - p.buf = append(p.buf, ch) - ch = next - } - - if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 { - data := p.buf[2:] - buf := make([]byte, hex.DecodedLen(len(data))) - n, err := hex.Decode(buf, data) - if err != nil { - return nil, err - } - return buf[:n], nil - } - - return p.buf, nil -} - -func (p *pgparser) ReadRange(ch byte) ([]byte, error) { - p.buf = p.buf[:0] - p.buf = append(p.buf, ch) - - for p.Valid() { - ch = p.Read() - p.buf = append(p.buf, ch) - if ch == ']' || ch == ')' { - break - } - } - - return p.buf, nil -} |