summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/dialect/pgdialect/range.go')
-rw-r--r--vendor/github.com/uptrace/bun/dialect/pgdialect/range.go240
1 files changed, 240 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go
new file mode 100644
index 000000000..b942a068e
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go
@@ -0,0 +1,240 @@
+package pgdialect
+
+import (
+ "bytes"
+ "database/sql"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "time"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/internal/parser"
+ "github.com/uptrace/bun/schema"
+)
+
+type MultiRange[T any] []Range[T]
+
+type Range[T any] struct {
+ Lower, Upper T
+ LowerBound, UpperBound RangeBound
+}
+
+type RangeBound byte
+
+const (
+ RangeBoundInclusiveLeft RangeBound = '['
+ RangeBoundInclusiveRight RangeBound = ']'
+ RangeBoundExclusiveLeft RangeBound = '('
+ RangeBoundExclusiveRight RangeBound = ')'
+)
+
+func NewRange[T any](lower, upper T) Range[T] {
+ return Range[T]{
+ Lower: lower,
+ Upper: upper,
+ LowerBound: RangeBoundInclusiveLeft,
+ UpperBound: RangeBoundExclusiveRight,
+ }
+}
+
+var _ sql.Scanner = (*Range[any])(nil)
+
+func (r *Range[T]) Scan(anySrc any) (err error) {
+ src := anySrc.([]byte)
+
+ if len(src) == 0 {
+ return io.ErrUnexpectedEOF
+ }
+ r.LowerBound = RangeBound(src[0])
+ src = src[1:]
+
+ src, err = scanElem(&r.Lower, src)
+ if err != nil {
+ return err
+ }
+
+ if len(src) == 0 {
+ return io.ErrUnexpectedEOF
+ }
+ if ch := src[0]; ch != ',' {
+ return fmt.Errorf("got %q, wanted %q", ch, ',')
+ }
+ src = src[1:]
+
+ src, err = scanElem(&r.Upper, src)
+ if err != nil {
+ return err
+ }
+
+ if len(src) == 0 {
+ return io.ErrUnexpectedEOF
+ }
+ r.UpperBound = RangeBound(src[0])
+ src = src[1:]
+
+ if len(src) > 0 {
+ return fmt.Errorf("unread data: %q", src)
+ }
+ return nil
+}
+
+var _ schema.QueryAppender = (*Range[any])(nil)
+
+func (r *Range[T]) AppendQuery(fmt schema.Formatter, buf []byte) ([]byte, error) {
+ buf = append(buf, byte(r.LowerBound))
+ buf = appendElem(buf, r.Lower)
+ buf = append(buf, ',')
+ buf = appendElem(buf, r.Upper)
+ buf = append(buf, byte(r.UpperBound))
+ return buf, nil
+}
+
+func appendElem(buf []byte, val any) []byte {
+ switch val := val.(type) {
+ case time.Time:
+ buf = append(buf, '"')
+ buf = appendTime(buf, val)
+ buf = append(buf, '"')
+ return buf
+ default:
+ panic(fmt.Errorf("unsupported range type: %T", val))
+ }
+}
+
+func scanElem(ptr any, src []byte) ([]byte, error) {
+ switch ptr := ptr.(type) {
+ case *time.Time:
+ src, str, err := readStringLiteral(src)
+ if err != nil {
+ return nil, err
+ }
+
+ tm, err := internal.ParseTime(internal.String(str))
+ if err != nil {
+ return nil, err
+ }
+ *ptr = tm
+
+ return src, nil
+ default:
+ panic(fmt.Errorf("unsupported range type: %T", ptr))
+ }
+}
+
+func readStringLiteral(src []byte) ([]byte, []byte, error) {
+ p := newParser(src)
+
+ if err := p.Skip('"'); err != nil {
+ return nil, nil, err
+ }
+
+ str, err := p.ReadSubstring('"')
+ if err != nil {
+ return nil, nil, err
+ }
+
+ src = p.Remaining()
+ return src, str, nil
+}
+
+//------------------------------------------------------------------------------
+
+type pgparser struct {
+ parser.Parser
+ buf []byte
+}
+
+func newParser(b []byte) *pgparser {
+ p := new(pgparser)
+ p.Reset(b)
+ return p
+}
+
+func (p *pgparser) ReadLiteral(ch byte) []byte {
+ p.Unread()
+ lit, _ := p.ReadSep(',')
+ return lit
+}
+
+func (p *pgparser) ReadUnescapedSubstring(ch byte) ([]byte, error) {
+ return p.readSubstring(ch, false)
+}
+
+func (p *pgparser) ReadSubstring(ch byte) ([]byte, error) {
+ return p.readSubstring(ch, true)
+}
+
+func (p *pgparser) readSubstring(ch byte, escaped bool) ([]byte, error) {
+ ch, err := p.ReadByte()
+ if err != nil {
+ return nil, err
+ }
+
+ p.buf = p.buf[:0]
+ for {
+ if ch == '"' {
+ break
+ }
+
+ next, err := p.ReadByte()
+ if err != nil {
+ return nil, err
+ }
+
+ if ch == '\\' {
+ switch next {
+ case '\\', '"':
+ p.buf = append(p.buf, next)
+
+ ch, err = p.ReadByte()
+ if err != nil {
+ return nil, err
+ }
+ default:
+ p.buf = append(p.buf, '\\')
+ ch = next
+ }
+ continue
+ }
+
+ if escaped && ch == '\'' && next == '\'' {
+ p.buf = append(p.buf, next)
+ ch, err = p.ReadByte()
+ if err != nil {
+ return nil, err
+ }
+ continue
+ }
+
+ p.buf = append(p.buf, ch)
+ ch = next
+ }
+
+ if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 {
+ data := p.buf[2:]
+ buf := make([]byte, hex.DecodedLen(len(data)))
+ n, err := hex.Decode(buf, data)
+ if err != nil {
+ return nil, err
+ }
+ return buf[:n], nil
+ }
+
+ return p.buf, nil
+}
+
+func (p *pgparser) ReadRange(ch byte) ([]byte, error) {
+ p.buf = p.buf[:0]
+ p.buf = append(p.buf, ch)
+
+ for p.Valid() {
+ ch = p.Read()
+ p.buf = append(p.buf, ch)
+ if ch == ']' || ch == ')' {
+ break
+ }
+ }
+
+ return p.buf, nil
+}