diff options
Diffstat (limited to 'vendor')
23 files changed, 1041 insertions, 1232 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md index ec90631cd..cc090da4c 100644 --- a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md +++ b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md @@ -1,3 +1,31 @@ +# 5.4.1 (June 18, 2023) + +* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov) +* Add TxOptions.BeginQuery to allow overriding the default BEGIN query + +# 5.4.0 (June 14, 2023) + +* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues. +* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov) +* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic +* CancelRequest: don't try to read the reply (Nicola Murino) +* Fix: correctly handle bool type aliases (Wichert Akkerman) +* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr() +* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones) +* Add BeforeClose to pgxpool.Pool (Evan Cordell) +* Fix: various hstore fixes and optimizations (Evan Jones) +* Fix: RowToStructByPos with embedded unexported struct +* Support different bool string representations (Lev Zakharov) +* Fix: error when using BatchResults.Exec on a select that returns an error after some rows. +* Fix: pipelineBatchResults.Exec() not returning error from ResultReader +* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using +    a callback. +* Fix: scanning a table type into a struct +* Fix: scan array of record to pointer to slice of struct +* Fix: handle null for json (Cemre Mengu) +* Batch Query callback is called even when there is an error +* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P) +  # 5.3.1 (February 27, 2023)  * Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka) diff --git a/vendor/github.com/jackc/pgx/v5/README.md b/vendor/github.com/jackc/pgx/v5/README.md index 29d9521c6..14327f2c6 100644 --- a/vendor/github.com/jackc/pgx/v5/README.md +++ b/vendor/github.com/jackc/pgx/v5/README.md @@ -132,13 +132,24 @@ These adapters can be used with the tracelog package.  * [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)  * [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap)  * [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog) +* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog)  ## 3rd Party Libraries with PGX Support +### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock) + +pgxmock is a mock library implementing pgx interfaces.  +pgxmock has one and only purpose - to simulate pgx behavior in tests, without needing a real database connection.  +  ### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)  Library for scanning data from a database into Go structs and more. +### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql) + +A carefully designed SQL client for making using SQL easier, +more productive, and less error-prone on Golang. +  ### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)  Adds GSSAPI / Kerberos authentication support. diff --git a/vendor/github.com/jackc/pgx/v5/batch.go b/vendor/github.com/jackc/pgx/v5/batch.go index af62039f8..8f6ea4f0d 100644 --- a/vendor/github.com/jackc/pgx/v5/batch.go +++ b/vendor/github.com/jackc/pgx/v5/batch.go @@ -21,13 +21,10 @@ type batchItemFunc func(br BatchResults) error  // Query sets fn to be called when the response to qq is received.  func (qq *QueuedQuery) Query(fn func(rows Rows) error) {  	qq.fn = func(br BatchResults) error { -		rows, err := br.Query() -		if err != nil { -			return err -		} +		rows, _ := br.Query()  		defer rows.Close() -		err = fn(rows) +		err := fn(rows)  		if err != nil {  			return err  		} @@ -142,7 +139,10 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {  	}  	commandTag, err := br.mrr.ResultReader().Close() -	br.err = err +	if err != nil { +		br.err = err +		br.mrr.Close() +	}  	if br.conn.batchTracer != nil {  		br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ @@ -228,7 +228,7 @@ func (br *batchResults) Close() error {  	for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {  		if br.b.queuedQueries[br.qqIdx].fn != nil {  			err := br.b.queuedQueries[br.qqIdx].fn(br) -			if err != nil && br.err == nil { +			if err != nil {  				br.err = err  			}  		} else { @@ -290,7 +290,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {  	results, err := br.pipeline.GetResults()  	if err != nil {  		br.err = err -		return pgconn.CommandTag{}, err +		return pgconn.CommandTag{}, br.err  	}  	var commandTag pgconn.CommandTag  	switch results := results.(type) { @@ -309,7 +309,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {  		})  	} -	return commandTag, err +	return commandTag, br.err  }  // Query reads the results from the next query in the batch as if the query has been sent with Query. @@ -384,24 +384,20 @@ func (br *pipelineBatchResults) Close() error {  		}  	}() -	if br.err != nil { -		return br.err -	} - -	if br.lastRows != nil && br.lastRows.err != nil { +	if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {  		br.err = br.lastRows.err  		return br.err  	}  	if br.closed { -		return nil +		return br.err  	}  	// Read and run fn for all remaining items  	for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {  		if br.b.queuedQueries[br.qqIdx].fn != nil {  			err := br.b.queuedQueries[br.qqIdx].fn(br) -			if err != nil && br.err == nil { +			if err != nil {  				br.err = err  			}  		} else { diff --git a/vendor/github.com/jackc/pgx/v5/conn.go b/vendor/github.com/jackc/pgx/v5/conn.go index 92b6f3e4a..a609d1002 100644 --- a/vendor/github.com/jackc/pgx/v5/conn.go +++ b/vendor/github.com/jackc/pgx/v5/conn.go @@ -178,7 +178,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con  		case "simple_protocol":  			defaultQueryExecMode = QueryExecModeSimpleProtocol  		default: -			return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err) +			return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s)  		}  	} @@ -382,11 +382,9 @@ func quoteIdentifier(s string) string {  	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`  } -// Ping executes an empty sql statement against the *Conn -// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned. +// Ping delegates to the underlying *pgconn.PgConn.Ping.  func (c *Conn) Ping(ctx context.Context) error { -	_, err := c.Exec(ctx, ";") -	return err +	return c.pgConn.Ping(ctx)  }  // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the @@ -585,8 +583,10 @@ const (  	QueryExecModeCacheDescribe  	// Get the statement description on every execution. This uses the extended protocol. Queries require two round trips -	// to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even -	// when the the database schema is modified concurrently. +	// to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the +	// statement description on the first round trip and then uses it to execute the query on the second round trip. This +	// may cause problems with connection poolers that switch the underlying connection between round trips. It is safe +	// even when the the database schema is modified concurrently.  	QueryExecModeDescribeExec  	// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol @@ -648,6 +648,9 @@ type QueryRewriter interface {  // returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It  // is allowed to ignore the error returned from Query and handle it in Rows.  // +// It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not +// return an error. +//  // It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be  // collected before processing rather than processed while receiving each row. This avoids the possibility of the  // application processing rows from a query that the server rejected. The CollectRows function is useful here. @@ -975,7 +978,7 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR  func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {  	if c.statementCache == nil { -		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache} +		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true}  	}  	distinctNewQueries := []*pgconn.StatementDescription{} @@ -1007,7 +1010,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc  func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {  	if c.descriptionCache == nil { -		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache} +		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true}  	}  	distinctNewQueries := []*pgconn.StatementDescription{} @@ -1074,18 +1077,18 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d  		err := pipeline.Sync()  		if err != nil { -			return &pipelineBatchResults{ctx: ctx, conn: c, err: err} +			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}  		}  		for _, sd := range distinctNewQueries {  			results, err := pipeline.GetResults()  			if err != nil { -				return &pipelineBatchResults{ctx: ctx, conn: c, err: err} +				return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}  			}  			resultSD, ok := results.(*pgconn.StatementDescription)  			if !ok { -				return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)} +				return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}  			}  			// Fill in the previously empty / pending statement descriptions. @@ -1095,12 +1098,12 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d  		results, err := pipeline.GetResults()  		if err != nil { -			return &pipelineBatchResults{ctx: ctx, conn: c, err: err} +			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}  		}  		_, ok := results.(*pgconn.PipelineSync)  		if !ok { -			return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)} +			return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}  		}  	} @@ -1117,7 +1120,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d  		if err != nil {  			// we wrap the error so we the user can understand which query failed inside the batch  			err = fmt.Errorf("error building query %s: %w", bi.query, err) -			return &pipelineBatchResults{ctx: ctx, conn: c, err: err} +			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}  		}  		if bi.sd.Name == "" { @@ -1129,7 +1132,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d  	err := pipeline.Sync()  	if err != nil { -		return &pipelineBatchResults{ctx: ctx, conn: c, err: err} +		return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}  	}  	return &pipelineBatchResults{ @@ -1282,7 +1285,9 @@ func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.Com  	var fieldOID uint32  	rows, _ := c.Query(ctx, `select attname, atttypid  from pg_attribute -where attrelid=$1 and not attisdropped +where attrelid=$1 +	and not attisdropped +	and attnum > 0  order by attnum`,  		typrelid,  	) @@ -1324,6 +1329,7 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error  	for _, sd := range invalidatedStatements {  		pipeline.SendDeallocate(sd.Name) +		delete(c.preparedStatements, sd.Name)  	}  	err := pipeline.Sync() diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go deleted file mode 100644 index 4bf25481c..000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go +++ /dev/null @@ -1,70 +0,0 @@ -package nbconn - -import ( -	"sync" -) - -const minBufferQueueLen = 8 - -type bufferQueue struct { -	lock  sync.Mutex -	queue []*[]byte -	r, w  int -} - -func (bq *bufferQueue) pushBack(buf *[]byte) { -	bq.lock.Lock() -	defer bq.lock.Unlock() - -	if bq.w >= len(bq.queue) { -		bq.growQueue() -	} -	bq.queue[bq.w] = buf -	bq.w++ -} - -func (bq *bufferQueue) pushFront(buf *[]byte) { -	bq.lock.Lock() -	defer bq.lock.Unlock() - -	if bq.w >= len(bq.queue) { -		bq.growQueue() -	} -	copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) -	bq.queue[bq.r] = buf -	bq.w++ -} - -func (bq *bufferQueue) popFront() *[]byte { -	bq.lock.Lock() -	defer bq.lock.Unlock() - -	if bq.r == bq.w { -		return nil -	} - -	buf := bq.queue[bq.r] -	bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. -	bq.r++ - -	if bq.r == bq.w { -		bq.r = 0 -		bq.w = 0 -		if len(bq.queue) > minBufferQueueLen { -			bq.queue = make([]*[]byte, minBufferQueueLen) -		} -	} - -	return buf -} - -func (bq *bufferQueue) growQueue() { -	desiredLen := (len(bq.queue) + 1) * 3 / 2 -	if desiredLen < minBufferQueueLen { -		desiredLen = minBufferQueueLen -	} - -	newQueue := make([]*[]byte, desiredLen) -	copy(newQueue, bq.queue) -	bq.queue = newQueue -} diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go deleted file mode 100644 index 7a38383f0..000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go +++ /dev/null @@ -1,520 +0,0 @@ -// Package nbconn implements a non-blocking net.Conn wrapper. -// -// It is designed to solve three problems. -// -// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all -// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. -// -// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. -// -// The third is to efficiently check if a connection has been closed via a non-blocking read. -package nbconn - -import ( -	"crypto/tls" -	"errors" -	"net" -	"os" -	"sync" -	"sync/atomic" -	"syscall" -	"time" - -	"github.com/jackc/pgx/v5/internal/iobufpool" -) - -var errClosed = errors.New("closed") -var ErrWouldBlock = new(wouldBlockError) - -const fakeNonblockingWriteWaitDuration = 100 * time.Millisecond -const minNonblockingReadWaitDuration = time.Microsecond -const maxNonblockingReadWaitDuration = 100 * time.Millisecond - -// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read -// mode. -var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) - -// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to -// ignore all future calls. -var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) - -// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. -type wouldBlockError struct{} - -func (*wouldBlockError) Error() string { -	return "would block" -} - -func (*wouldBlockError) Timeout() bool   { return true } -func (*wouldBlockError) Temporary() bool { return true } - -// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to -// the underlying connection. -type Conn interface { -	net.Conn - -	// Flush flushes any buffered writes. -	Flush() error - -	// BufferReadUntilBlock reads and buffers any successfully read bytes until the read would block. -	BufferReadUntilBlock() error -} - -// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. -type NetConn struct { -	// 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit -	// architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288 and -	// https://github.com/jackc/pgx/issues/1307. Only access with atomics -	closed int64 // 0 = not closed, 1 = closed - -	conn    net.Conn -	rawConn syscall.RawConn - -	readQueue  bufferQueue -	writeQueue bufferQueue - -	readFlushLock sync.Mutex -	// non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the -	// callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. -	nonblockWriteFunc func(fd uintptr) (done bool) -	nonblockWriteBuf  []byte -	nonblockWriteErr  error -	nonblockWriteN    int - -	// non-blocking reads with syscall.RawConn are done with a callback function. By using these fields instead of the -	// callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. -	nonblockReadFunc func(fd uintptr) (done bool) -	nonblockReadBuf  []byte -	nonblockReadErr  error -	nonblockReadN    int - -	readDeadlineLock                sync.Mutex -	readDeadline                    time.Time -	readNonblocking                 bool -	fakeNonBlockingShortReadCount   int -	fakeNonblockingReadWaitDuration time.Duration - -	writeDeadlineLock sync.Mutex -	writeDeadline     time.Time -} - -func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { -	nc := &NetConn{ -		conn:                            conn, -		fakeNonblockingReadWaitDuration: maxNonblockingReadWaitDuration, -	} - -	if !fakeNonBlockingIO { -		if sc, ok := conn.(syscall.Conn); ok { -			if rawConn, err := sc.SyscallConn(); err == nil { -				nc.rawConn = rawConn -			} -		} -	} - -	return nc -} - -// Read implements io.Reader. -func (c *NetConn) Read(b []byte) (n int, err error) { -	if c.isClosed() { -		return 0, errClosed -	} - -	c.readFlushLock.Lock() -	defer c.readFlushLock.Unlock() - -	err = c.flush() -	if err != nil { -		return 0, err -	} - -	for n < len(b) { -		buf := c.readQueue.popFront() -		if buf == nil { -			break -		} -		copiedN := copy(b[n:], *buf) -		if copiedN < len(*buf) { -			*buf = (*buf)[copiedN:] -			c.readQueue.pushFront(buf) -		} else { -			iobufpool.Put(buf) -		} -		n += copiedN -	} - -	// If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to -	// Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block. -	if n > 0 { -		return n, nil -	} - -	var readNonblocking bool -	c.readDeadlineLock.Lock() -	readNonblocking = c.readNonblocking -	c.readDeadlineLock.Unlock() - -	var readN int -	if readNonblocking { -		readN, err = c.nonblockingRead(b[n:]) -	} else { -		readN, err = c.conn.Read(b[n:]) -	} -	n += readN -	return n, err -} - -// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is -// closed. Call Flush to actually write to the underlying connection. -func (c *NetConn) Write(b []byte) (n int, err error) { -	if c.isClosed() { -		return 0, errClosed -	} - -	buf := iobufpool.Get(len(b)) -	copy(*buf, b) -	c.writeQueue.pushBack(buf) -	return len(b), nil -} - -func (c *NetConn) Close() (err error) { -	swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) -	if !swapped { -		return errClosed -	} - -	defer func() { -		closeErr := c.conn.Close() -		if err == nil { -			err = closeErr -		} -	}() - -	c.readFlushLock.Lock() -	defer c.readFlushLock.Unlock() -	err = c.flush() -	if err != nil { -		return err -	} - -	return nil -} - -func (c *NetConn) LocalAddr() net.Addr { -	return c.conn.LocalAddr() -} - -func (c *NetConn) RemoteAddr() net.Addr { -	return c.conn.RemoteAddr() -} - -// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). -func (c *NetConn) SetDeadline(t time.Time) error { -	err := c.SetReadDeadline(t) -	if err != nil { -		return err -	} -	return c.SetWriteDeadline(t) -} - -// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. -func (c *NetConn) SetReadDeadline(t time.Time) error { -	if c.isClosed() { -		return errClosed -	} - -	c.readDeadlineLock.Lock() -	defer c.readDeadlineLock.Unlock() -	if c.readDeadline == disableSetDeadlineDeadline { -		return nil -	} -	if t == disableSetDeadlineDeadline { -		c.readDeadline = t -		return nil -	} - -	if t == NonBlockingDeadline { -		c.readNonblocking = true -		t = time.Time{} -	} else { -		c.readNonblocking = false -	} - -	c.readDeadline = t - -	return c.conn.SetReadDeadline(t) -} - -func (c *NetConn) SetWriteDeadline(t time.Time) error { -	if c.isClosed() { -		return errClosed -	} - -	c.writeDeadlineLock.Lock() -	defer c.writeDeadlineLock.Unlock() -	if c.writeDeadline == disableSetDeadlineDeadline { -		return nil -	} -	if t == disableSetDeadlineDeadline { -		c.writeDeadline = t -		return nil -	} - -	c.writeDeadline = t - -	return c.conn.SetWriteDeadline(t) -} - -func (c *NetConn) Flush() error { -	if c.isClosed() { -		return errClosed -	} - -	c.readFlushLock.Lock() -	defer c.readFlushLock.Unlock() -	return c.flush() -} - -// flush does the actual work of flushing the writeQueue. readFlushLock must already be held. -func (c *NetConn) flush() error { -	var stopChan chan struct{} -	var errChan chan error - -	defer func() { -		if stopChan != nil { -			select { -			case stopChan <- struct{}{}: -			case <-errChan: -			} -		} -	}() - -	for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { -		remainingBuf := *buf -		for len(remainingBuf) > 0 { -			n, err := c.nonblockingWrite(remainingBuf) -			remainingBuf = remainingBuf[n:] -			if err != nil { -				if !errors.Is(err, ErrWouldBlock) { -					*buf = (*buf)[:len(remainingBuf)] -					copy(*buf, remainingBuf) -					c.writeQueue.pushFront(buf) -					return err -				} - -				// Writing was blocked. Reading might unblock it. -				if stopChan == nil { -					stopChan, errChan = c.bufferNonblockingRead() -				} - -				select { -				case err := <-errChan: -					stopChan = nil -					return err -				default: -				} - -			} -		} -		iobufpool.Put(buf) -	} - -	return nil -} - -func (c *NetConn) BufferReadUntilBlock() error { -	for { -		buf := iobufpool.Get(8 * 1024) -		n, err := c.nonblockingRead(*buf) -		if n > 0 { -			*buf = (*buf)[:n] -			c.readQueue.pushBack(buf) -		} else if n == 0 { -			iobufpool.Put(buf) -		} - -		if err != nil { -			if errors.Is(err, ErrWouldBlock) { -				return nil -			} else { -				return err -			} -		} -	} -} - -func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { -	stopChan = make(chan struct{}) -	errChan = make(chan error, 1) - -	go func() { -		for { -			err := c.BufferReadUntilBlock() -			if err != nil { -				errChan <- err -				return -			} - -			select { -			case <-stopChan: -				return -			default: -			} -		} -	}() - -	return stopChan, errChan -} - -func (c *NetConn) isClosed() bool { -	closed := atomic.LoadInt64(&c.closed) -	return closed == 1 -} - -func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { -	if c.rawConn == nil { -		return c.fakeNonblockingWrite(b) -	} else { -		return c.realNonblockingWrite(b) -	} -} - -func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { -	c.writeDeadlineLock.Lock() -	defer c.writeDeadlineLock.Unlock() - -	deadline := time.Now().Add(fakeNonblockingWriteWaitDuration) -	if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { -		err = c.conn.SetWriteDeadline(deadline) -		if err != nil { -			return 0, err -		} -		defer func() { -			// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. -			c.conn.SetWriteDeadline(c.writeDeadline) - -			if err != nil { -				if errors.Is(err, os.ErrDeadlineExceeded) { -					err = ErrWouldBlock -				} -			} -		}() -	} - -	return c.conn.Write(b) -} - -func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { -	if c.rawConn == nil { -		return c.fakeNonblockingRead(b) -	} else { -		return c.realNonblockingRead(b) -	} -} - -func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { -	c.readDeadlineLock.Lock() -	defer c.readDeadlineLock.Unlock() - -	// The first 5 reads only read 1 byte at a time. This should give us 4 chances to read when we are sure the bytes are -	// already in Go or the OS's receive buffer. -	if c.fakeNonBlockingShortReadCount < 5 && len(b) > 0 && c.fakeNonblockingReadWaitDuration < minNonblockingReadWaitDuration { -		b = b[:1] -	} - -	startTime := time.Now() -	deadline := startTime.Add(c.fakeNonblockingReadWaitDuration) -	if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { -		err = c.conn.SetReadDeadline(deadline) -		if err != nil { -			return 0, err -		} -		defer func() { -			// If the read was successful and the wait duration is not already the minimum -			if err == nil && c.fakeNonblockingReadWaitDuration > minNonblockingReadWaitDuration { -				endTime := time.Now() - -				if n > 0 && c.fakeNonBlockingShortReadCount < 5 { -					c.fakeNonBlockingShortReadCount++ -				} - -				// The wait duration should be 2x the fastest read that has occurred. This should give reasonable assurance that -				// a Read deadline will not block a read before it has a chance to read data already in Go or the OS's receive -				// buffer. -				proposedWait := endTime.Sub(startTime) * 2 -				if proposedWait < minNonblockingReadWaitDuration { -					proposedWait = minNonblockingReadWaitDuration -				} -				if proposedWait < c.fakeNonblockingReadWaitDuration { -					c.fakeNonblockingReadWaitDuration = proposedWait -				} -			} - -			// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. -			c.conn.SetReadDeadline(c.readDeadline) - -			if err != nil { -				if errors.Is(err, os.ErrDeadlineExceeded) { -					err = ErrWouldBlock -				} -			} -		}() -	} - -	return c.conn.Read(b) -} - -// syscall.Conn is interface - -// TLSClient establishes a TLS connection as a client over conn using config. -// -// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby -// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the -// *TLSConn is returned. -func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { -	tc := tls.Client(conn, config) -	err := tc.Handshake() -	if err != nil { -		return nil, err -	} - -	// Ensure last written part of Handshake is actually sent. -	err = conn.Flush() -	if err != nil { -		return nil, err -	} - -	return &TLSConn{ -		tlsConn: tc, -		nbConn:  conn, -	}, nil -} - -// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a -// tls.Conn. -type TLSConn struct { -	tlsConn *tls.Conn -	nbConn  *NetConn -} - -func (tc *TLSConn) Read(b []byte) (n int, err error)  { return tc.tlsConn.Read(b) } -func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } -func (tc *TLSConn) BufferReadUntilBlock() error       { return tc.nbConn.BufferReadUntilBlock() } -func (tc *TLSConn) Flush() error                      { return tc.nbConn.Flush() } -func (tc *TLSConn) LocalAddr() net.Addr               { return tc.tlsConn.LocalAddr() } -func (tc *TLSConn) RemoteAddr() net.Addr              { return tc.tlsConn.RemoteAddr() } - -func (tc *TLSConn) Close() error { -	// tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then -	// sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our -	// own 5 second deadline then make all set deadlines no-op. -	tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) -	tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) - -	return tc.tlsConn.Close() -} - -func (tc *TLSConn) SetDeadline(t time.Time) error      { return tc.tlsConn.SetDeadline(t) } -func (tc *TLSConn) SetReadDeadline(t time.Time) error  { return tc.tlsConn.SetReadDeadline(t) } -func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go deleted file mode 100644 index 4915c6219..000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !unix - -package nbconn - -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { -	return c.fakeNonblockingWrite(b) -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { -	return c.fakeNonblockingRead(b) -} diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go deleted file mode 100644 index e93372f25..000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go +++ /dev/null @@ -1,81 +0,0 @@ -//go:build unix - -package nbconn - -import ( -	"errors" -	"io" -	"syscall" -) - -// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { -	if c.nonblockWriteFunc == nil { -		c.nonblockWriteFunc = func(fd uintptr) (done bool) { -			c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) -			return true -		} -	} -	c.nonblockWriteBuf = b -	c.nonblockWriteN = 0 -	c.nonblockWriteErr = nil - -	err = c.rawConn.Write(c.nonblockWriteFunc) -	n = c.nonblockWriteN -	c.nonblockWriteBuf = nil // ensure that no reference to b is kept. -	if err == nil && c.nonblockWriteErr != nil { -		if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { -			err = ErrWouldBlock -		} else { -			err = c.nonblockWriteErr -		} -	} -	if err != nil { -		// n may be -1 when an error occurs. -		if n < 0 { -			n = 0 -		} - -		return n, err -	} - -	return n, nil -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { -	if c.nonblockReadFunc == nil { -		c.nonblockReadFunc = func(fd uintptr) (done bool) { -			c.nonblockReadN, c.nonblockReadErr = syscall.Read(int(fd), c.nonblockReadBuf) -			return true -		} -	} -	c.nonblockReadBuf = b -	c.nonblockReadN = 0 -	c.nonblockReadErr = nil - -	err = c.rawConn.Read(c.nonblockReadFunc) -	n = c.nonblockReadN -	c.nonblockReadBuf = nil // ensure that no reference to b is kept. -	if err == nil && c.nonblockReadErr != nil { -		if errors.Is(c.nonblockReadErr, syscall.EWOULDBLOCK) { -			err = ErrWouldBlock -		} else { -			err = c.nonblockReadErr -		} -	} -	if err != nil { -		// n may be -1 when an error occurs. -		if n < 0 { -			n = 0 -		} - -		return n, err -	} - -	// syscall read did not return an error and 0 bytes were read means EOF. -	if n == 0 { -		return 0, io.EOF -	} - -	return n, nil -} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go index 6ca9e3379..8c4b2de3c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go @@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {  		Data:          sc.clientFirstMessage(),  	}  	c.frontend.Send(saslInitialResponse) -	err = c.frontend.Flush() +	err = c.flushWithPotentialWriteReadDeadlock()  	if err != nil {  		return err  	} @@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {  		Data: []byte(sc.clientFinalMessage()),  	}  	c.frontend.Send(saslResponse) -	err = c.frontend.Flush() +	err = c.flushWithPotentialWriteReadDeadlock()  	if err != nil {  		return err  	} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go b/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go new file mode 100644 index 000000000..aa1a3d39c --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go @@ -0,0 +1,132 @@ +// Package bgreader provides a io.Reader that can optionally buffer reads in the background. +package bgreader + +import ( +	"io" +	"sync" + +	"github.com/jackc/pgx/v5/internal/iobufpool" +) + +const ( +	bgReaderStatusStopped = iota +	bgReaderStatusRunning +	bgReaderStatusStopping +) + +// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use. +type BGReader struct { +	r io.Reader + +	cond           *sync.Cond +	bgReaderStatus int32 +	readResults    []readResult +} + +type readResult struct { +	buf *[]byte +	err error +} + +// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background +// reader will stop automatically when the underlying reader returns an error. +func (r *BGReader) Start() { +	r.cond.L.Lock() +	defer r.cond.L.Unlock() + +	switch r.bgReaderStatus { +	case bgReaderStatusStopped: +		r.bgReaderStatus = bgReaderStatusRunning +		go r.bgRead() +	case bgReaderStatusRunning: +		// no-op +	case bgReaderStatusStopping: +		r.bgReaderStatus = bgReaderStatusRunning +	} +} + +// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the +// background reader is not running. +func (r *BGReader) Stop() { +	r.cond.L.Lock() +	defer r.cond.L.Unlock() + +	switch r.bgReaderStatus { +	case bgReaderStatusStopped: +		// no-op +	case bgReaderStatusRunning: +		r.bgReaderStatus = bgReaderStatusStopping +	case bgReaderStatusStopping: +		// no-op +	} +} + +func (r *BGReader) bgRead() { +	keepReading := true +	for keepReading { +		buf := iobufpool.Get(8192) +		n, err := r.r.Read(*buf) +		*buf = (*buf)[:n] + +		r.cond.L.Lock() +		r.readResults = append(r.readResults, readResult{buf: buf, err: err}) +		if r.bgReaderStatus == bgReaderStatusStopping || err != nil { +			r.bgReaderStatus = bgReaderStatusStopped +			keepReading = false +		} +		r.cond.L.Unlock() +		r.cond.Broadcast() +	} +} + +// Read implements the io.Reader interface. +func (r *BGReader) Read(p []byte) (int, error) { +	r.cond.L.Lock() +	defer r.cond.L.Unlock() + +	if len(r.readResults) > 0 { +		return r.readFromReadResults(p) +	} + +	// There are no unread background read results and the background reader is stopped. +	if r.bgReaderStatus == bgReaderStatusStopped { +		return r.r.Read(p) +	} + +	// Wait for results from the background reader +	for len(r.readResults) == 0 { +		r.cond.Wait() +	} +	return r.readFromReadResults(p) +} + +// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held. +func (r *BGReader) readFromReadResults(p []byte) (int, error) { +	buf := r.readResults[0].buf +	var err error + +	n := copy(p, *buf) +	if n == len(*buf) { +		err = r.readResults[0].err +		iobufpool.Put(buf) +		if len(r.readResults) == 1 { +			r.readResults = nil +		} else { +			r.readResults = r.readResults[1:] +		} +	} else { +		*buf = (*buf)[n:] +		r.readResults[0].buf = buf +	} + +	return n, err +} + +func New(r io.Reader) *BGReader { +	return &BGReader{ +		r: r, +		cond: &sync.Cond{ +			L: &sync.Mutex{}, +		}, +	} +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go index 969675fd2..3c1af3477 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go @@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error {  			Data: nextData,  		}  		c.frontend.Send(gssResponse) -		err = c.frontend.Flush() +		err = c.flushWithPotentialWriteReadDeadlock()  		if err != nil {  			return err  		} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 8656ea518..9f84605fe 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -13,11 +13,12 @@ import (  	"net"  	"strconv"  	"strings" +	"sync"  	"time"  	"github.com/jackc/pgx/v5/internal/iobufpool" -	"github.com/jackc/pgx/v5/internal/nbconn"  	"github.com/jackc/pgx/v5/internal/pgio" +	"github.com/jackc/pgx/v5/pgconn/internal/bgreader"  	"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"  	"github.com/jackc/pgx/v5/pgproto3"  ) @@ -65,17 +66,24 @@ type NotificationHandler func(*PgConn, *Notification)  // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.  type PgConn struct { -	conn              nbconn.Conn       // the non-blocking wrapper for the underlying TCP or unix domain socket connection +	conn              net.Conn  	pid               uint32            // backend pid  	secretKey         uint32            // key to use to send a cancel query message to the server  	parameterStatuses map[string]string // parameters that have been reported by the server  	txStatus          byte  	frontend          *pgproto3.Frontend +	bgReader          *bgreader.BGReader +	slowWriteTimer    *time.Timer  	config *Config  	status byte // One of connStatus* constants +	bufferingReceive    bool +	bufferingReceiveMux sync.Mutex +	bufferingReceiveMsg pgproto3.BackendMessage +	bufferingReceiveErr error +  	peekedMsg pgproto3.BackendMessage  	// Reusable / preallocated resources @@ -266,14 +274,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig  	if err != nil {  		return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}  	} -	nbNetConn := nbconn.NewNetConn(netConn, false) -	pgConn.conn = nbNetConn -	pgConn.contextWatcher = newContextWatcher(nbNetConn) +	pgConn.conn = netConn +	pgConn.contextWatcher = newContextWatcher(netConn)  	pgConn.contextWatcher.Watch(ctx)  	if fallbackConfig.TLSConfig != nil { -		nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig) +		nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)  		pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.  		if err != nil {  			netConn.Close() @@ -289,7 +296,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig  	pgConn.parameterStatuses = make(map[string]string)  	pgConn.status = connStatusConnecting -	pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) +	pgConn.bgReader = bgreader.New(pgConn.conn) +	pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) +	pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)  	startupMsg := pgproto3.StartupMessage{  		ProtocolVersion: pgproto3.ProtocolVersionNumber, @@ -307,9 +316,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig  	}  	pgConn.frontend.Send(&startupMsg) -	if err := pgConn.frontend.Flush(); err != nil { +	if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {  		pgConn.conn.Close() -		return nil, &connectError{config: config, msg: "failed to write startup message", err: err} +		return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}  	}  	for { @@ -392,7 +401,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {  	)  } -func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) { +func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {  	err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})  	if err != nil {  		return nil, err @@ -407,17 +416,12 @@ func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, err  		return nil, errors.New("server refused TLS connection")  	} -	tlsConn, err := nbconn.TLSClient(conn, tlsConfig) -	if err != nil { -		return nil, err -	} - -	return tlsConn, nil +	return tls.Client(conn, tlsConfig), nil  }  func (pgConn *PgConn) txPasswordMessage(password string) (err error) {  	pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) -	return pgConn.frontend.Flush() +	return pgConn.flushWithPotentialWriteReadDeadlock()  }  func hexMD5(s string) string { @@ -426,6 +430,24 @@ func hexMD5(s string) string {  	return hex.EncodeToString(hash.Sum(nil))  } +func (pgConn *PgConn) signalMessage() chan struct{} { +	if pgConn.bufferingReceive { +		panic("BUG: signalMessage when already in progress") +	} + +	pgConn.bufferingReceive = true +	pgConn.bufferingReceiveMux.Lock() + +	ch := make(chan struct{}) +	go func() { +		pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() +		pgConn.bufferingReceiveMux.Unlock() +		close(ch) +	}() + +	return ch +} +  // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the  // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages  // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger @@ -465,13 +487,25 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {  		return pgConn.peekedMsg, nil  	} -	msg, err := pgConn.frontend.Receive() - -	if err != nil { -		if errors.Is(err, nbconn.ErrWouldBlock) { -			return nil, err +	var msg pgproto3.BackendMessage +	var err error +	if pgConn.bufferingReceive { +		pgConn.bufferingReceiveMux.Lock() +		msg = pgConn.bufferingReceiveMsg +		err = pgConn.bufferingReceiveErr +		pgConn.bufferingReceiveMux.Unlock() +		pgConn.bufferingReceive = false + +		// If a timeout error happened in the background try the read again. +		var netErr net.Error +		if errors.As(err, &netErr) && netErr.Timeout() { +			msg, err = pgConn.frontend.Receive()  		} +	} else { +		msg, err = pgConn.frontend.Receive() +	} +	if err != nil {  		// Close on anything other than timeout error - everything else is fatal  		var netErr net.Error  		isNetErr := errors.As(err, &netErr) @@ -582,7 +616,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error {  	//  	// See https://github.com/jackc/pgx/issues/637  	pgConn.frontend.Send(&pgproto3.Terminate{}) -	pgConn.frontend.Flush() +	pgConn.flushWithPotentialWriteReadDeadlock()  	return pgConn.conn.Close()  } @@ -609,7 +643,7 @@ func (pgConn *PgConn) asyncClose() {  		pgConn.conn.SetDeadline(deadline)  		pgConn.frontend.Send(&pgproto3.Terminate{}) -		pgConn.frontend.Flush() +		pgConn.flushWithPotentialWriteReadDeadlock()  	}()  } @@ -784,7 +818,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [  	pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})  	pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})  	pgConn.frontend.SendSync(&pgproto3.Sync{}) -	err := pgConn.frontend.Flush() +	err := pgConn.flushWithPotentialWriteReadDeadlock()  	if err != nil {  		pgConn.asyncClose()  		return nil, err @@ -857,9 +891,28 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {  	// the connection config. This is important in high availability configurations where fallback connections may be  	// specified or DNS may be used to load balance.  	serverAddr := pgConn.conn.RemoteAddr() -	cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) +	var serverNetwork string +	var serverAddress string +	if serverAddr.Network() == "unix" { +		// for unix sockets, RemoteAddr() calls getpeername() which returns the name the +		// server passed to bind(). For Postgres, this is always a relative path "./.s.PGSQL.5432" +		// so connecting to it will fail. Fall back to the config's value +		serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port) +	} else { +		serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String() +	} +	cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress)  	if err != nil { -		return err +		// In case of unix sockets, RemoteAddr() returns only the file part of the path. If the +		// first connect failed, try the config. +		if serverAddr.Network() != "unix" { +			return err +		} +		serverNetwork, serverAddr := NetworkAddress(pgConn.config.Host, pgConn.config.Port) +		cancelConn, err = pgConn.config.DialFunc(ctx, serverNetwork, serverAddr) +		if err != nil { +			return err +		}  	}  	defer cancelConn.Close() @@ -877,17 +930,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {  	binary.BigEndian.PutUint32(buf[4:8], 80877102)  	binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid))  	binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) +	// Postgres will process the request and close the connection +	// so when don't need to read the reply +	// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10  	_, err = cancelConn.Write(buf) -	if err != nil { -		return err -	} - -	_, err = cancelConn.Read(buf) -	if err != io.EOF { -		return err -	} - -	return nil +	return err  }  // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not @@ -953,7 +1000,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {  	}  	pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) -	err := pgConn.frontend.Flush() +	err := pgConn.flushWithPotentialWriteReadDeadlock()  	if err != nil {  		pgConn.asyncClose()  		pgConn.contextWatcher.Unwatch() @@ -1064,7 +1111,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {  	pgConn.frontend.SendExecute(&pgproto3.Execute{})  	pgConn.frontend.SendSync(&pgproto3.Sync{}) -	err := pgConn.frontend.Flush() +	err := pgConn.flushWithPotentialWriteReadDeadlock()  	if err != nil {  		pgConn.asyncClose()  		result.concludeCommand(CommandTag{}, err) @@ -1097,7 +1144,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm  	// Send copy to command  	pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) -	err := pgConn.frontend.Flush() +	err := pgConn.flushWithPotentialWriteReadDeadlock()  	if err != nil {  		pgConn.asyncClose()  		pgConn.unlock() @@ -1153,85 +1200,91 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co  		defer pgConn.contextWatcher.Unwatch()  	} -	// Send copy to command +	// Send copy from query  	pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) -	err := pgConn.frontend.Flush() -	if err != nil { -		pgConn.asyncClose() -		return CommandTag{}, err -	} - -	err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) +	err := pgConn.flushWithPotentialWriteReadDeadlock()  	if err != nil {  		pgConn.asyncClose()  		return CommandTag{}, err  	} -	nonblocking := true -	defer func() { -		if nonblocking { -			pgConn.conn.SetReadDeadline(time.Time{}) -		} -	}() -	buf := iobufpool.Get(65536) -	defer iobufpool.Put(buf) -	(*buf)[0] = 'd' +	// Send copy data +	abortCopyChan := make(chan struct{}) +	copyErrChan := make(chan error, 1) +	signalMessageChan := pgConn.signalMessage() +	var wg sync.WaitGroup +	wg.Add(1) -	var readErr, pgErr error -	for pgErr == nil { -		// Read chunk from r. -		var n int -		n, readErr = r.Read((*buf)[5:cap(*buf)]) +	go func() { +		defer wg.Done() +		buf := iobufpool.Get(65536) +		defer iobufpool.Put(buf) +		(*buf)[0] = 'd' -		// Send chunk to PostgreSQL. -		if n > 0 { -			*buf = (*buf)[0 : n+5] -			pgio.SetInt32((*buf)[1:], int32(n+4)) +		for { +			n, readErr := r.Read((*buf)[5:cap(*buf)]) +			if n > 0 { +				*buf = (*buf)[0 : n+5] +				pgio.SetInt32((*buf)[1:], int32(n+4)) + +				writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) +				if writeErr != nil { +					// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not +					// setting pgConn.status or closing pgConn.cleanupDone for the same reason. +					pgConn.conn.Close() -			writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) -			if writeErr != nil { -				pgConn.asyncClose() -				return CommandTag{}, err +					copyErrChan <- writeErr +					return +				} +			} +			if readErr != nil { +				copyErrChan <- readErr +				return  			} -		} -		// Abort loop if there was a read error. -		if readErr != nil { -			break +			select { +			case <-abortCopyChan: +				return +			default: +			}  		} +	}() -		// Read messages until error or none available. -		for pgErr == nil { -			msg, err := pgConn.receiveMessage() -			if err != nil { -				if errors.Is(err, nbconn.ErrWouldBlock) { -					break -				} -				pgConn.asyncClose() +	var pgErr error +	var copyErr error +	for copyErr == nil && pgErr == nil { +		select { +		case copyErr = <-copyErrChan: +		case <-signalMessageChan: +			// If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with +			// the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an +			// error is found then forcibly close the connection without sending the Terminate message. +			if err := pgConn.bufferingReceiveErr; err != nil { +				pgConn.status = connStatusClosed +				pgConn.conn.Close() +				close(pgConn.cleanupDone)  				return CommandTag{}, normalizeTimeoutError(ctx, err)  			} +			msg, _ := pgConn.receiveMessage()  			switch msg := msg.(type) {  			case *pgproto3.ErrorResponse:  				pgErr = ErrorResponseToPgError(msg) -				break +			default: +				signalMessageChan = pgConn.signalMessage()  			}  		}  	} +	close(abortCopyChan) +	// Make sure io goroutine finishes before writing. +	wg.Wait() -	err = pgConn.conn.SetReadDeadline(time.Time{}) -	if err != nil { -		pgConn.asyncClose() -		return CommandTag{}, err -	} -	nonblocking = false - -	if readErr == io.EOF || pgErr != nil { +	if copyErr == io.EOF || pgErr != nil {  		pgConn.frontend.Send(&pgproto3.CopyDone{})  	} else { -		pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) +		pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})  	} -	err = pgConn.frontend.Flush() +	err = pgConn.flushWithPotentialWriteReadDeadlock()  	if err != nil {  		pgConn.asyncClose()  		return CommandTag{}, err @@ -1426,7 +1479,8 @@ func (rr *ResultReader) NextRow() bool {  }  // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the ResultReader is closed. +// the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was +// encountered.)  func (rr *ResultReader) FieldDescriptions() []FieldDescription {  	return rr.fieldDescriptions  } @@ -1592,7 +1646,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR  	batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) +	pgConn.enterPotentialWriteReadDeadlock()  	_, err := pgConn.conn.Write(batch.buf) +	pgConn.exitPotentialWriteReadDeadlock()  	if err != nil {  		multiResult.closed = true  		multiResult.err = err @@ -1620,29 +1676,72 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) {  	return strings.Replace(s, "'", "''", -1), nil  } -// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and -// buffering until the read would block or an error occurs. This can be used to check if the server has closed the -// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails +// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by doing a read +// with a very short deadline. This can be useful because a TCP connection can be broken such that a write will appear +// to succeed even though it will never actually reach the server. Reading immediately before a write will detect this +// condition. If this is done immediately before sending a query it reduces the chances a query will be sent that fails  // without the client knowing whether the server received it or not. +// +// Deprecated: CheckConn is deprecated in favor of Ping. CheckConn cannot detect all types of broken connections where +// the write would still appear to succeed. Prefer Ping unless on a high latency connection.  func (pgConn *PgConn) CheckConn() error { -	err := pgConn.conn.BufferReadUntilBlock() -	if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { -		return err +	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) +	defer cancel() + +	_, err := pgConn.ReceiveMessage(ctx) +	if err != nil { +		if !Timeout(err) { +			return err +		}  	} +  	return nil  } +// Ping pings the server. This can be useful because a TCP connection can be broken such that a write will appear to +// succeed even though it will never actually reach the server. Pinging immediately before sending a query reduces the +// chances a query will be sent that fails without the client knowing whether the server received it or not. +func (pgConn *PgConn) Ping(ctx context.Context) error { +	return pgConn.Exec(ctx, "-- ping").Close() +} +  // makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory.  func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {  	return CommandTag{s: string(buf)}  } +// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously +// blocked writing to us. +func (pgConn *PgConn) enterPotentialWriteReadDeadlock() { +	// The time to wait is somewhat arbitrary. A Write should only take as long as the syscall and memcpy to the OS +	// outbound network buffer unless the buffer is full (which potentially is a block). It needs to be long enough for +	// the normal case, but short enough not to kill performance if a block occurs. +	// +	// In addition, on Windows the default timer resolution is 15.6ms. So setting the timer to less than that is +	// ineffective. +	pgConn.slowWriteTimer.Reset(15 * time.Millisecond) +} + +// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. +func (pgConn *PgConn) exitPotentialWriteReadDeadlock() { +	if !pgConn.slowWriteTimer.Reset(time.Duration(math.MaxInt64)) { +		pgConn.slowWriteTimer.Stop() +	} +} + +func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { +	pgConn.enterPotentialWriteReadDeadlock() +	err := pgConn.frontend.Flush() +	pgConn.exitPotentialWriteReadDeadlock() +	return err +} +  // HijackedConn is the result of hijacking a connection.  //  // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning  // compatibility.  type HijackedConn struct { -	Conn              nbconn.Conn       // the non-blocking wrapper of the underlying TCP or unix domain socket connection +	Conn              net.Conn  	PID               uint32            // backend pid  	SecretKey         uint32            // key to use to send a cancel query message to the server  	ParameterStatuses map[string]string // parameters that have been reported by the server @@ -1695,6 +1794,8 @@ func Construct(hc *HijackedConn) (*PgConn, error) {  	}  	pgConn.contextWatcher = newContextWatcher(pgConn.conn) +	pgConn.bgReader = bgreader.New(pgConn.conn) +	pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)  	return pgConn, nil  } @@ -1817,7 +1918,7 @@ func (p *Pipeline) Flush() error {  		return errors.New("pipeline closed")  	} -	err := p.conn.frontend.Flush() +	err := p.conn.flushWithPotentialWriteReadDeadlock()  	if err != nil {  		err = normalizeTimeoutError(p.ctx, err) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array.go b/vendor/github.com/jackc/pgx/v5/pgtype/array.go index 0fa4c129b..7dfee389e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array.go @@ -363,12 +363,13 @@ func quoteArrayElement(src string) string {  }  func isSpace(ch byte) bool { -	// see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224 -	return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' +	// see array_isspace: +	// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c +	return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v' || ch == '\f'  }  func quoteArrayElementIfNeeded(src string) string { -	if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { +	if src == "" || (len(src) == 4 && strings.EqualFold(src, "null")) || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) {  		return quoteArrayElement(src)  	}  	return src diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go index e7be27e2d..71caffa74 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go @@ -1,10 +1,12 @@  package pgtype  import ( +	"bytes"  	"database/sql/driver"  	"encoding/json"  	"fmt"  	"strconv" +	"strings"  )  type BoolScanner interface { @@ -264,8 +266,8 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error {  		return fmt.Errorf("cannot scan NULL into %T", dst)  	} -	if len(src) != 1 { -		return fmt.Errorf("invalid length for bool: %v", len(src)) +	if len(src) == 0 { +		return fmt.Errorf("cannot scan empty string into %T", dst)  	}  	p, ok := (dst).(*bool) @@ -273,7 +275,12 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error {  		return ErrScanTargetTypeChanged  	} -	*p = src[0] == 't' +	v, err := planTextToBool(src) +	if err != nil { +		return err +	} + +	*p = v  	return nil  } @@ -309,9 +316,28 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error {  		return s.ScanBool(Bool{})  	} -	if len(src) != 1 { -		return fmt.Errorf("invalid length for bool: %v", len(src)) +	if len(src) == 0 { +		return fmt.Errorf("cannot scan empty string into %T", dst)  	} -	return s.ScanBool(Bool{Bool: src[0] == 't', Valid: true}) +	v, err := planTextToBool(src) +	if err != nil { +		return err +	} + +	return s.ScanBool(Bool{Bool: v, Valid: true}) +} + +// https://www.postgresql.org/docs/11/datatype-boolean.html +func planTextToBool(src []byte) (bool, error) { +	s := string(bytes.ToLower(bytes.TrimSpace(src))) + +	switch { +	case strings.HasPrefix("true", s), strings.HasPrefix("yes", s), s == "on", s == "1": +		return true, nil +	case strings.HasPrefix("false", s), strings.HasPrefix("no", s), strings.HasPrefix("off", s), s == "0": +		return false, nil +	default: +		return false, fmt.Errorf("unknown boolean string representation %q", src) +	}  } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go index 8a2afbe1e..7fddeaa8a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go @@ -64,6 +64,9 @@ func underlyingNumberType(val any) (any, bool) {  	case reflect.String:  		convVal := refVal.String()  		return convVal, reflect.TypeOf(convVal) != refVal.Type() +	case reflect.Bool: +		convVal := refVal.Bool() +		return convVal, reflect.TypeOf(convVal) != refVal.Type()  	}  	return nil, false @@ -262,7 +265,7 @@ func int64AssignTo(srcVal int64, srcValid bool, dst any) error {  			*v = uint8(srcVal)  		case *uint16:  			if srcVal < 0 { -				return fmt.Errorf("%d is less than zero for uint32", srcVal) +				return fmt.Errorf("%d is less than zero for uint16", srcVal)  			} else if srcVal > math.MaxUint16 {  				return fmt.Errorf("%d is greater than maximum value for uint16", srcVal)  			} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go index 4743643e5..9befabd05 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go @@ -1,14 +1,11 @@  package pgtype  import ( -	"bytes"  	"database/sql/driver"  	"encoding/binary"  	"errors"  	"fmt"  	"strings" -	"unicode" -	"unicode/utf8"  	"github.com/jackc/pgx/v5/internal/pgio"  ) @@ -43,7 +40,7 @@ func (h *Hstore) Scan(src any) error {  	switch src := src.(type) {  	case string: -		return scanPlanTextAnyToHstoreScanner{}.Scan([]byte(src), h) +		return scanPlanTextAnyToHstoreScanner{}.scanString(src, h)  	}  	return fmt.Errorf("cannot scan %T", src) @@ -137,13 +134,20 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e  			buf = append(buf, ',')  		} -		buf = append(buf, quoteHstoreElementIfNeeded(k)...) +		// unconditionally quote hstore keys/values like Postgres does +		// this avoids a Mac OS X Postgres hstore parsing bug: +		// https://www.postgresql.org/message-id/CA%2BHWA9awUW0%2BRV_gO9r1ABZwGoZxPztcJxPy8vMFSTbTfi4jig%40mail.gmail.com +		buf = append(buf, '"') +		buf = append(buf, quoteArrayReplacer.Replace(k)...) +		buf = append(buf, '"')  		buf = append(buf, "=>"...)  		if v == nil {  			buf = append(buf, "NULL"...)  		} else { -			buf = append(buf, quoteHstoreElementIfNeeded(*v)...) +			buf = append(buf, '"') +			buf = append(buf, quoteArrayReplacer.Replace(*v)...) +			buf = append(buf, '"')  		}  	} @@ -174,25 +178,28 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {  	scanner := (dst).(HstoreScanner)  	if src == nil { -		return scanner.ScanHstore(Hstore{}) +		return scanner.ScanHstore(Hstore(nil))  	}  	rp := 0 -	if len(src[rp:]) < 4 { +	const uint32Len = 4 +	if len(src[rp:]) < uint32Len {  		return fmt.Errorf("hstore incomplete %v", src)  	}  	pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) -	rp += 4 +	rp += uint32Len  	hstore := make(Hstore, pairCount) +	// one allocation for all *string, rather than one per string, just like text parsing +	valueStrings := make([]string, pairCount)  	for i := 0; i < pairCount; i++ { -		if len(src[rp:]) < 4 { +		if len(src[rp:]) < uint32Len {  			return fmt.Errorf("hstore incomplete %v", src)  		}  		keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) -		rp += 4 +		rp += uint32Len  		if len(src[rp:]) < keyLen {  			return fmt.Errorf("hstore incomplete %v", src) @@ -200,26 +207,17 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {  		key := string(src[rp : rp+keyLen])  		rp += keyLen -		if len(src[rp:]) < 4 { +		if len(src[rp:]) < uint32Len {  			return fmt.Errorf("hstore incomplete %v", src)  		}  		valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))  		rp += 4 -		var valueBuf []byte  		if valueLen >= 0 { -			valueBuf = src[rp : rp+valueLen] +			valueStrings[i] = string(src[rp : rp+valueLen])  			rp += valueLen -		} - -		var value Text -		err := scanPlanTextAnyToTextScanner{}.Scan(valueBuf, &value) -		if err != nil { -			return err -		} -		if value.Valid { -			hstore[key] = &value.String +			hstore[key] = &valueStrings[i]  		} else {  			hstore[key] = nil  		} @@ -230,28 +228,22 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {  type scanPlanTextAnyToHstoreScanner struct{} -func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { +func (s scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error {  	scanner := (dst).(HstoreScanner)  	if src == nil { -		return scanner.ScanHstore(Hstore{}) +		return scanner.ScanHstore(Hstore(nil))  	} +	return s.scanString(string(src), scanner) +} -	keys, values, err := parseHstore(string(src)) +// scanString does not return nil hstore values because string cannot be nil. +func (scanPlanTextAnyToHstoreScanner) scanString(src string, scanner HstoreScanner) error { +	hstore, err := parseHstore(src)  	if err != nil {  		return err  	} - -	m := make(Hstore, len(keys)) -	for i := range keys { -		if values[i].Valid { -			m[keys[i]] = &values[i].String -		} else { -			m[keys[i]] = nil -		} -	} - -	return scanner.ScanHstore(m) +	return scanner.ScanHstore(hstore)  }  func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { @@ -271,191 +263,217 @@ func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (  	return hstore, nil  } -var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) +type hstoreParser struct { +	str           string +	pos           int +	nextBackslash int +} + +func newHSP(in string) *hstoreParser { +	return &hstoreParser{ +		pos:           0, +		str:           in, +		nextBackslash: strings.IndexByte(in, '\\'), +	} +} -func quoteHstoreElement(src string) string { -	return `"` + quoteArrayReplacer.Replace(src) + `"` +func (p *hstoreParser) atEnd() bool { +	return p.pos >= len(p.str)  } -func quoteHstoreElementIfNeeded(src string) string { -	if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { -		return quoteArrayElement(src) +// consume returns the next byte of the string, or end if the string is done. +func (p *hstoreParser) consume() (b byte, end bool) { +	if p.pos >= len(p.str) { +		return 0, true  	} -	return src +	b = p.str[p.pos] +	p.pos++ +	return b, false  } -const ( -	hsPre = iota -	hsKey -	hsSep -	hsVal -	hsNul -	hsNext -) +func unexpectedByteErr(actualB byte, expectedB byte) error { +	return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) +} -type hstoreParser struct { -	str string -	pos int +// consumeExpectedByte consumes expectedB from the string, or returns an error. +func (p *hstoreParser) consumeExpectedByte(expectedB byte) error { +	nextB, end := p.consume() +	if end { +		return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB) +	} +	if nextB != expectedB { +		return unexpectedByteErr(nextB, expectedB) +	} +	return nil  } -func newHSP(in string) *hstoreParser { -	return &hstoreParser{ -		pos: 0, -		str: in, +// consumeExpected2 consumes two expected bytes or returns an error. +// This was a bit faster than using a string argument (better inlining? Not sure). +func (p *hstoreParser) consumeExpected2(one byte, two byte) error { +	if p.pos+2 > len(p.str) { +		return errors.New("unexpected end of string") +	} +	if p.str[p.pos] != one { +		return unexpectedByteErr(p.str[p.pos], one)  	} +	if p.str[p.pos+1] != two { +		return unexpectedByteErr(p.str[p.pos+1], two) +	} +	p.pos += 2 +	return nil  } -func (p *hstoreParser) Consume() (r rune, end bool) { -	if p.pos >= len(p.str) { -		end = true -		return +var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`) + +// consumeDoubleQuoted consumes a double-quoted string from p. The double quote must have been +// parsed already. This copies the string from the backing string so it can be garbage collected. +func (p *hstoreParser) consumeDoubleQuoted() (string, error) { +	// fast path: assume most keys/values do not contain escapes +	nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"') +	if nextDoubleQuote == -1 { +		return "", errEOSInQuoted +	} +	nextDoubleQuote += p.pos +	if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote { +		// clone the string from the source string to ensure it can be garbage collected separately +		// TODO: use strings.Clone on Go 1.20; this could get optimized away +		s := strings.Clone(p.str[p.pos:nextDoubleQuote]) +		p.pos = nextDoubleQuote + 1 +		return s, nil  	} -	r, w := utf8.DecodeRuneInString(p.str[p.pos:]) -	p.pos += w -	return + +	// slow path: string contains escapes +	s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash) +	p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\') +	if p.nextBackslash != -1 { +		p.nextBackslash += p.pos +	} +	return s, err  } -func (p *hstoreParser) Peek() (r rune, end bool) { -	if p.pos >= len(p.str) { -		end = true -		return +// consumeDoubleQuotedWithEscapes consumes a double-quoted string containing escapes, starting +// at p.pos, and with the first backslash at firstBackslash. This copies the string so it can be +// garbage collected separately. +func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) { +	// copy the prefix that does not contain backslashes +	var builder strings.Builder +	builder.WriteString(p.str[p.pos:firstBackslash]) + +	// skip to the backslash +	p.pos = firstBackslash + +	// copy bytes until the end, unescaping backslashes +	for { +		nextB, end := p.consume() +		if end { +			return "", errEOSInQuoted +		} else if nextB == '"' { +			break +		} else if nextB == '\\' { +			// escape: skip the backslash and copy the char +			nextB, end = p.consume() +			if end { +				return "", errEOSInQuoted +			} +			if !(nextB == '\\' || nextB == '"') { +				return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB) +			} +			builder.WriteByte(nextB) +		} else { +			// normal byte: copy it +			builder.WriteByte(nextB) +		}  	} -	r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) -	return +	return builder.String(), nil +} + +// consumePairSeparator consumes the Hstore pair separator ", " or returns an error. +func (p *hstoreParser) consumePairSeparator() error { +	return p.consumeExpected2(',', ' ')  } -// parseHstore parses the string representation of an hstore column (the same -// you would get from an ordinary SELECT) into two slices of keys and values. it -// is used internally in the default parsing of hstores. -func parseHstore(s string) (k []string, v []Text, err error) { -	if s == "" { -		return +// consumeKVSeparator consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeKVSeparator() error { +	return p.consumeExpected2('=', '>') +} + +// consumeDoubleQuotedOrNull consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeDoubleQuotedOrNull() (Text, error) { +	// peek at the next byte +	if p.atEnd() { +		return Text{}, errors.New("found end instead of value") +	} +	next := p.str[p.pos] +	if next == 'N' { +		// must be the exact string NULL: use consumeExpected2 twice +		err := p.consumeExpected2('N', 'U') +		if err != nil { +			return Text{}, err +		} +		err = p.consumeExpected2('L', 'L') +		if err != nil { +			return Text{}, err +		} +		return Text{String: "", Valid: false}, nil +	} else if next != '"' { +		return Text{}, unexpectedByteErr(next, '"')  	} -	buf := bytes.Buffer{} -	keys := []string{} -	values := []Text{} -	p := newHSP(s) +	// skip the double quote +	p.pos += 1 +	s, err := p.consumeDoubleQuoted() +	if err != nil { +		return Text{}, err +	} +	return Text{String: s, Valid: true}, nil +} -	r, end := p.Consume() -	state := hsPre +func parseHstore(s string) (Hstore, error) { +	p := newHSP(s) -	for !end { -		switch state { -		case hsPre: -			if r == '"' { -				state = hsKey -			} else { -				err = errors.New("String does not begin with \"") -			} -		case hsKey: -			switch r { -			case '"': //End of the key -				keys = append(keys, buf.String()) -				buf = bytes.Buffer{} -				state = hsSep -			case '\\': //Potential escaped character -				n, end := p.Consume() -				switch { -				case end: -					err = errors.New("Found EOS in key, expecting character or \"") -				case n == '"', n == '\\': -					buf.WriteRune(n) -				default: -					buf.WriteRune(r) -					buf.WriteRune(n) -				} -			default: //Any other character -				buf.WriteRune(r) -			} -		case hsSep: -			if r == '=' { -				r, end = p.Consume() -				switch { -				case end: -					err = errors.New("Found EOS after '=', expecting '>'") -				case r == '>': -					r, end = p.Consume() -					switch { -					case end: -						err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") -					case r == '"': -						state = hsVal -					case r == 'N': -						state = hsNul -					default: -						err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) -					} -				default: -					err = fmt.Errorf("Invalid character after '=', expecting '>'") -				} -			} else { -				err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) -			} -		case hsVal: -			switch r { -			case '"': //End of the value -				values = append(values, Text{String: buf.String(), Valid: true}) -				buf = bytes.Buffer{} -				state = hsNext -			case '\\': //Potential escaped character -				n, end := p.Consume() -				switch { -				case end: -					err = errors.New("Found EOS in key, expecting character or \"") -				case n == '"', n == '\\': -					buf.WriteRune(n) -				default: -					buf.WriteRune(r) -					buf.WriteRune(n) -				} -			default: //Any other character -				buf.WriteRune(r) -			} -		case hsNul: -			nulBuf := make([]rune, 3) -			nulBuf[0] = r -			for i := 1; i < 3; i++ { -				r, end = p.Consume() -				if end { -					err = errors.New("Found EOS in NULL value") -					return -				} -				nulBuf[i] = r -			} -			if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { -				values = append(values, Text{}) -				state = hsNext -			} else { -				err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) -			} -		case hsNext: -			if r == ',' { -				r, end = p.Consume() -				switch { -				case end: -					err = errors.New("Found EOS after ',', expcting space") -				case (unicode.IsSpace(r)): -					r, end = p.Consume() -					state = hsKey -				default: -					err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) -				} -			} else { -				err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) +	// This is an over-estimate of the number of key/value pairs. Use '>' because I am guessing it +	// is less likely to occur in keys/values than '=' or ','. +	numPairsEstimate := strings.Count(s, ">") +	// makes one allocation of strings for the entire Hstore, rather than one allocation per value. +	valueStrings := make([]string, 0, numPairsEstimate) +	result := make(Hstore, numPairsEstimate) +	first := true +	for !p.atEnd() { +		if !first { +			err := p.consumePairSeparator() +			if err != nil { +				return nil, err  			} +		} else { +			first = false  		} +		err := p.consumeExpectedByte('"')  		if err != nil { -			return +			return nil, err +		} + +		key, err := p.consumeDoubleQuoted() +		if err != nil { +			return nil, err +		} + +		err = p.consumeKVSeparator() +		if err != nil { +			return nil, err +		} + +		value, err := p.consumeDoubleQuotedOrNull() +		if err != nil { +			return nil, err +		} +		if value.Valid { +			valueStrings = append(valueStrings, value.String) +			result[key] = &valueStrings[len(valueStrings)-1] +		} else { +			result[key] = nil  		} -		r, end = p.Consume() -	} -	if state != hsNext { -		err = errors.New("Improperly formatted hstore") -		return  	} -	k = keys -	v = values -	return + +	return result, nil  } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/json.go b/vendor/github.com/jackc/pgx/v5/pgtype/json.go index 69861bf88..753f24103 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/json.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/json.go @@ -150,7 +150,7 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {  		if dstValue.Kind() == reflect.Ptr {  			el := dstValue.Elem()  			switch el.Kind() { -			case reflect.Ptr, reflect.Slice, reflect.Map: +			case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface:  				el.Set(reflect.Zero(el.Type()))  				return nil  			} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go index 83b349cee..b9cd7b410 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go @@ -147,7 +147,7 @@ const (  	BinaryFormatCode = 1  ) -// A Codec converts between Go and PostgreSQL values. +// A Codec converts between Go and PostgreSQL values. A Codec must not be mutated after it is registered with a Map.  type Codec interface {  	// FormatSupported returns true if the format is supported.  	FormatSupported(int16) bool @@ -178,6 +178,7 @@ func (e *nullAssignmentError) Error() string {  	return fmt.Sprintf("cannot assign NULL to %T", e.dst)  } +// Type represents a PostgreSQL data type. It must not be mutated after it is registered with a Map.  type Type struct {  	Codec Codec  	Name  string @@ -211,7 +212,9 @@ type Map struct {  }  func NewMap() *Map { -	m := &Map{ +	defaultMapInitOnce.Do(initDefaultMap) + +	return &Map{  		oidToType:         make(map[uint32]*Type),  		nameToType:        make(map[string]*Type),  		reflectTypeToName: make(map[reflect.Type]string), @@ -240,184 +243,9 @@ func NewMap() *Map {  			TryWrapPtrArrayScanPlan,  		},  	} - -	// Base types -	m.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) -	m.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) -	m.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) -	m.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) -	m.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) -	m.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) -	m.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) -	m.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) -	m.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) -	m.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) -	m.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) -	m.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) -	m.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) -	m.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) -	m.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) -	m.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) -	m.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) -	m.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) -	m.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) -	m.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) -	m.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) -	m.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) -	m.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) -	m.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) -	m.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) -	m.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) -	m.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) -	m.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) -	m.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) -	m.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) -	m.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) -	m.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) -	m.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) -	m.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) -	m.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) -	m.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) -	m.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) -	m.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) -	m.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) -	m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) -	m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - -	// Range types -	m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) -	m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) -	m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) -	m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}}) -	m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) -	m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) - -	// Multirange types -	m.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[DaterangeOID]}}) -	m.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int4rangeOID]}}) -	m.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int8rangeOID]}}) -	m.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[NumrangeOID]}}) -	m.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TsrangeOID]}}) -	m.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TstzrangeOID]}}) - -	// Array types -	m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) -	m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) -	m.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoolOID]}}) -	m.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoxOID]}}) -	m.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BPCharOID]}}) -	m.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ByteaOID]}}) -	m.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[QCharOID]}}) -	m.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDOID]}}) -	m.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDROID]}}) -	m.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CircleOID]}}) -	m.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DateOID]}}) -	m.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DaterangeOID]}}) -	m.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float4OID]}}) -	m.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float8OID]}}) -	m.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[InetOID]}}) -	m.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int2OID]}}) -	m.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4OID]}}) -	m.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4rangeOID]}}) -	m.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8OID]}}) -	m.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8rangeOID]}}) -	m.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[IntervalOID]}}) -	m.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONOID]}}) -	m.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONBOID]}}) -	m.RegisterType(&Type{Name: "_jsonpath", OID: JSONPathArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONPathOID]}}) -	m.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LineOID]}}) -	m.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LsegOID]}}) -	m.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[MacaddrOID]}}) -	m.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NameOID]}}) -	m.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumericOID]}}) -	m.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumrangeOID]}}) -	m.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[OIDOID]}}) -	m.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PathOID]}}) -	m.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PointOID]}}) -	m.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PolygonOID]}}) -	m.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[RecordOID]}}) -	m.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TextOID]}}) -	m.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TIDOID]}}) -	m.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimeOID]}}) -	m.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestampOID]}}) -	m.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestamptzOID]}}) -	m.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TsrangeOID]}}) -	m.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TstzrangeOID]}}) -	m.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[UUIDOID]}}) -	m.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarbitOID]}}) -	m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}}) -	m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}}) - -	// Integer types that directly map to a PostgreSQL type -	registerDefaultPgTypeVariants[int16](m, "int2") -	registerDefaultPgTypeVariants[int32](m, "int4") -	registerDefaultPgTypeVariants[int64](m, "int8") - -	// Integer types that do not have a direct match to a PostgreSQL type -	registerDefaultPgTypeVariants[int8](m, "int8") -	registerDefaultPgTypeVariants[int](m, "int8") -	registerDefaultPgTypeVariants[uint8](m, "int8") -	registerDefaultPgTypeVariants[uint16](m, "int8") -	registerDefaultPgTypeVariants[uint32](m, "int8") -	registerDefaultPgTypeVariants[uint64](m, "numeric") -	registerDefaultPgTypeVariants[uint](m, "numeric") - -	registerDefaultPgTypeVariants[float32](m, "float4") -	registerDefaultPgTypeVariants[float64](m, "float8") - -	registerDefaultPgTypeVariants[bool](m, "bool") -	registerDefaultPgTypeVariants[time.Time](m, "timestamptz") -	registerDefaultPgTypeVariants[time.Duration](m, "interval") -	registerDefaultPgTypeVariants[string](m, "text") -	registerDefaultPgTypeVariants[[]byte](m, "bytea") - -	registerDefaultPgTypeVariants[net.IP](m, "inet") -	registerDefaultPgTypeVariants[net.IPNet](m, "cidr") -	registerDefaultPgTypeVariants[netip.Addr](m, "inet") -	registerDefaultPgTypeVariants[netip.Prefix](m, "cidr") - -	// pgtype provided structs -	registerDefaultPgTypeVariants[Bits](m, "varbit") -	registerDefaultPgTypeVariants[Bool](m, "bool") -	registerDefaultPgTypeVariants[Box](m, "box") -	registerDefaultPgTypeVariants[Circle](m, "circle") -	registerDefaultPgTypeVariants[Date](m, "date") -	registerDefaultPgTypeVariants[Range[Date]](m, "daterange") -	registerDefaultPgTypeVariants[Multirange[Range[Date]]](m, "datemultirange") -	registerDefaultPgTypeVariants[Float4](m, "float4") -	registerDefaultPgTypeVariants[Float8](m, "float8") -	registerDefaultPgTypeVariants[Range[Float8]](m, "numrange")                  // There is no PostgreSQL builtin float8range so map it to numrange. -	registerDefaultPgTypeVariants[Multirange[Range[Float8]]](m, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. -	registerDefaultPgTypeVariants[Int2](m, "int2") -	registerDefaultPgTypeVariants[Int4](m, "int4") -	registerDefaultPgTypeVariants[Range[Int4]](m, "int4range") -	registerDefaultPgTypeVariants[Multirange[Range[Int4]]](m, "int4multirange") -	registerDefaultPgTypeVariants[Int8](m, "int8") -	registerDefaultPgTypeVariants[Range[Int8]](m, "int8range") -	registerDefaultPgTypeVariants[Multirange[Range[Int8]]](m, "int8multirange") -	registerDefaultPgTypeVariants[Interval](m, "interval") -	registerDefaultPgTypeVariants[Line](m, "line") -	registerDefaultPgTypeVariants[Lseg](m, "lseg") -	registerDefaultPgTypeVariants[Numeric](m, "numeric") -	registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange") -	registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](m, "nummultirange") -	registerDefaultPgTypeVariants[Path](m, "path") -	registerDefaultPgTypeVariants[Point](m, "point") -	registerDefaultPgTypeVariants[Polygon](m, "polygon") -	registerDefaultPgTypeVariants[TID](m, "tid") -	registerDefaultPgTypeVariants[Text](m, "text") -	registerDefaultPgTypeVariants[Time](m, "time") -	registerDefaultPgTypeVariants[Timestamp](m, "timestamp") -	registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz") -	registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange") -	registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](m, "tsmultirange") -	registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange") -	registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](m, "tstzmultirange") -	registerDefaultPgTypeVariants[UUID](m, "uuid") - -	return m  } +// RegisterType registers a data type with the Map. t must not be mutated after it is registered.  func (m *Map) RegisterType(t *Type) {  	m.oidToType[t.OID] = t  	m.nameToType[t.Name] = t @@ -449,13 +277,22 @@ func (m *Map) RegisterDefaultPgType(value any, name string) {  	}  } +// TypeForOID returns the Type registered for the given OID. The returned Type must not be mutated.  func (m *Map) TypeForOID(oid uint32) (*Type, bool) { -	dt, ok := m.oidToType[oid] +	if dt, ok := m.oidToType[oid]; ok { +		return dt, true +	} + +	dt, ok := defaultMap.oidToType[oid]  	return dt, ok  } +// TypeForName returns the Type registered for the given name. The returned Type must not be mutated.  func (m *Map) TypeForName(name string) (*Type, bool) { -	dt, ok := m.nameToType[name] +	if dt, ok := m.nameToType[name]; ok { +		return dt, true +	} +	dt, ok := defaultMap.nameToType[name]  	return dt, ok  } @@ -463,30 +300,39 @@ func (m *Map) buildReflectTypeToType() {  	m.reflectTypeToType = make(map[reflect.Type]*Type)  	for reflectType, name := range m.reflectTypeToName { -		if dt, ok := m.nameToType[name]; ok { +		if dt, ok := m.TypeForName(name); ok {  			m.reflectTypeToType[reflectType] = dt  		}  	}  }  // TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode -// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. +// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type.  The returned Type +// must not be mutated.  func (m *Map) TypeForValue(v any) (*Type, bool) {  	if m.reflectTypeToType == nil {  		m.buildReflectTypeToType()  	} -	dt, ok := m.reflectTypeToType[reflect.TypeOf(v)] +	if dt, ok := m.reflectTypeToType[reflect.TypeOf(v)]; ok { +		return dt, true +	} + +	dt, ok := defaultMap.reflectTypeToType[reflect.TypeOf(v)]  	return dt, ok  }  // FormatCodeForOID returns the preferred format code for type oid. If the type is not registered it returns the text  // format code.  func (m *Map) FormatCodeForOID(oid uint32) int16 { -	fc, ok := m.oidToFormatCode[oid] -	if ok { +	if fc, ok := m.oidToFormatCode[oid]; ok { +		return fc +	} + +	if fc, ok := defaultMap.oidToFormatCode[oid]; ok {  		return fc  	} +  	return TextFormatCode  } @@ -587,6 +433,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error {  				return plan.Scan(src, dst)  			}  		} +		for oid := range defaultMap.oidToType { +			if _, ok := plan.m.oidToType[oid]; !ok { +				plan := plan.m.planScan(oid, plan.formatCode, dst) +				if _, ok := plan.(*scanPlanFail); !ok { +					return plan.Scan(src, dst) +				} +			} +		}  	}  	var format string @@ -600,7 +454,7 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error {  	}  	var dataTypeName string -	if t, ok := plan.m.oidToType[plan.oid]; ok { +	if t, ok := plan.m.TypeForOID(plan.oid); ok {  		dataTypeName = t.Name  	} else {  		dataTypeName = "unknown type" @@ -666,6 +520,7 @@ var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]refl  	reflect.Float32: reflect.TypeOf(new(float32)),  	reflect.Float64: reflect.TypeOf(new(float64)),  	reflect.String:  reflect.TypeOf(new(string)), +	reflect.Bool:    reflect.TypeOf(new(bool)),  }  type underlyingTypeScanPlan struct { @@ -1089,15 +944,16 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa  		return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true  	} -	targetValue := reflect.ValueOf(target) -	if targetValue.Kind() != reflect.Ptr { +	targetType := reflect.TypeOf(target) +	if targetType.Kind() != reflect.Ptr {  		return nil, nil, false  	} -	targetElemValue := targetValue.Elem() +	targetElemType := targetType.Elem() -	if targetElemValue.Kind() == reflect.Slice { -		return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: targetElemValue}, true +	if targetElemType.Kind() == reflect.Slice { +		slice := reflect.New(targetElemType).Elem() +		return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: slice}, true  	}  	return nil, nil, false  } @@ -1198,6 +1054,10 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan {  }  func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { +	if target == nil { +		return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} +	} +  	if _, ok := target.(*UndecodedBytes); ok {  		return scanPlanAnyToUndecodedBytes{}  	} @@ -1514,6 +1374,7 @@ var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{  	reflect.Float32: reflect.TypeOf(float32(0)),  	reflect.Float64: reflect.TypeOf(float64(0)),  	reflect.String:  reflect.TypeOf(""), +	reflect.Bool:    reflect.TypeOf(false),  }  type underlyingTypeEncodePlan struct { @@ -2039,13 +1900,13 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error)  	}  	var dataTypeName string -	if t, ok := m.oidToType[oid]; ok { +	if t, ok := m.TypeForOID(oid); ok {  		dataTypeName = t.Name  	} else {  		dataTypeName = "unknown type"  	} -	return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %s", value, format, dataTypeName, oid, err) +	return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %w", value, format, dataTypeName, oid, err)  }  // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go new file mode 100644 index 000000000..58f4b92c7 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go @@ -0,0 +1,223 @@ +package pgtype + +import ( +	"net" +	"net/netip" +	"reflect" +	"sync" +	"time" +) + +var ( +	// defaultMap contains default mappings between PostgreSQL server types and Go type handling logic. +	defaultMap         *Map +	defaultMapInitOnce = sync.Once{} +) + +func initDefaultMap() { +	defaultMap = &Map{ +		oidToType:         make(map[uint32]*Type), +		nameToType:        make(map[string]*Type), +		reflectTypeToName: make(map[reflect.Type]string), +		oidToFormatCode:   make(map[uint32]int16), + +		memoizedScanPlans:   make(map[uint32]map[reflect.Type][2]ScanPlan), +		memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), + +		TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ +			TryWrapDerefPointerEncodePlan, +			TryWrapBuiltinTypeEncodePlan, +			TryWrapFindUnderlyingTypeEncodePlan, +			TryWrapStructEncodePlan, +			TryWrapSliceEncodePlan, +			TryWrapMultiDimSliceEncodePlan, +			TryWrapArrayEncodePlan, +		}, + +		TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ +			TryPointerPointerScanPlan, +			TryWrapBuiltinTypeScanPlan, +			TryFindUnderlyingTypeScanPlan, +			TryWrapStructScanPlan, +			TryWrapPtrSliceScanPlan, +			TryWrapPtrMultiDimSliceScanPlan, +			TryWrapPtrArrayScanPlan, +		}, +	} + +	// Base types +	defaultMap.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) +	defaultMap.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) +	defaultMap.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) +	defaultMap.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) +	defaultMap.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) +	defaultMap.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) +	defaultMap.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) +	defaultMap.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) +	defaultMap.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) +	defaultMap.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) +	defaultMap.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) +	defaultMap.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) +	defaultMap.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) +	defaultMap.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) +	defaultMap.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) +	defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) +	defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) +	defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) +	defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) +	defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) +	defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) +	defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) +	defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) +	defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) +	defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) +	defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) +	defaultMap.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) +	defaultMap.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) +	defaultMap.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) +	defaultMap.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) +	defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) +	defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) +	defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) +	defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) +	defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) +	defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) +	defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) +	defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) +	defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) +	defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) +	defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + +	// Range types +	defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}}) +	defaultMap.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int4OID]}}) +	defaultMap.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int8OID]}}) +	defaultMap.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[NumericOID]}}) +	defaultMap.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) +	defaultMap.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) + +	// Multirange types +	defaultMap.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[DaterangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[NumrangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TsrangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}}) + +	// Array types +	defaultMap.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ACLItemOID]}}) +	defaultMap.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BitOID]}}) +	defaultMap.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoolOID]}}) +	defaultMap.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoxOID]}}) +	defaultMap.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BPCharOID]}}) +	defaultMap.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ByteaOID]}}) +	defaultMap.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[QCharOID]}}) +	defaultMap.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDOID]}}) +	defaultMap.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDROID]}}) +	defaultMap.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CircleOID]}}) +	defaultMap.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DateOID]}}) +	defaultMap.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DaterangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float4OID]}}) +	defaultMap.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float8OID]}}) +	defaultMap.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[InetOID]}}) +	defaultMap.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int2OID]}}) +	defaultMap.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4OID]}}) +	defaultMap.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8OID]}}) +	defaultMap.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[IntervalOID]}}) +	defaultMap.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONOID]}}) +	defaultMap.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONBOID]}}) +	defaultMap.RegisterType(&Type{Name: "_jsonpath", OID: JSONPathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONPathOID]}}) +	defaultMap.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LineOID]}}) +	defaultMap.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LsegOID]}}) +	defaultMap.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[MacaddrOID]}}) +	defaultMap.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NameOID]}}) +	defaultMap.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumericOID]}}) +	defaultMap.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumrangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[OIDOID]}}) +	defaultMap.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PathOID]}}) +	defaultMap.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PointOID]}}) +	defaultMap.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PolygonOID]}}) +	defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}}) +	defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}}) +	defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}}) +	defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}}) +	defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) +	defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) +	defaultMap.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TsrangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}}) +	defaultMap.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[UUIDOID]}}) +	defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}}) +	defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}}) +	defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}}) + +	// Integer types that directly map to a PostgreSQL type +	registerDefaultPgTypeVariants[int16](defaultMap, "int2") +	registerDefaultPgTypeVariants[int32](defaultMap, "int4") +	registerDefaultPgTypeVariants[int64](defaultMap, "int8") + +	// Integer types that do not have a direct match to a PostgreSQL type +	registerDefaultPgTypeVariants[int8](defaultMap, "int8") +	registerDefaultPgTypeVariants[int](defaultMap, "int8") +	registerDefaultPgTypeVariants[uint8](defaultMap, "int8") +	registerDefaultPgTypeVariants[uint16](defaultMap, "int8") +	registerDefaultPgTypeVariants[uint32](defaultMap, "int8") +	registerDefaultPgTypeVariants[uint64](defaultMap, "numeric") +	registerDefaultPgTypeVariants[uint](defaultMap, "numeric") + +	registerDefaultPgTypeVariants[float32](defaultMap, "float4") +	registerDefaultPgTypeVariants[float64](defaultMap, "float8") + +	registerDefaultPgTypeVariants[bool](defaultMap, "bool") +	registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz") +	registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval") +	registerDefaultPgTypeVariants[string](defaultMap, "text") +	registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea") + +	registerDefaultPgTypeVariants[net.IP](defaultMap, "inet") +	registerDefaultPgTypeVariants[net.IPNet](defaultMap, "cidr") +	registerDefaultPgTypeVariants[netip.Addr](defaultMap, "inet") +	registerDefaultPgTypeVariants[netip.Prefix](defaultMap, "cidr") + +	// pgtype provided structs +	registerDefaultPgTypeVariants[Bits](defaultMap, "varbit") +	registerDefaultPgTypeVariants[Bool](defaultMap, "bool") +	registerDefaultPgTypeVariants[Box](defaultMap, "box") +	registerDefaultPgTypeVariants[Circle](defaultMap, "circle") +	registerDefaultPgTypeVariants[Date](defaultMap, "date") +	registerDefaultPgTypeVariants[Range[Date]](defaultMap, "daterange") +	registerDefaultPgTypeVariants[Multirange[Range[Date]]](defaultMap, "datemultirange") +	registerDefaultPgTypeVariants[Float4](defaultMap, "float4") +	registerDefaultPgTypeVariants[Float8](defaultMap, "float8") +	registerDefaultPgTypeVariants[Range[Float8]](defaultMap, "numrange")                  // There is no PostgreSQL builtin float8range so map it to numrange. +	registerDefaultPgTypeVariants[Multirange[Range[Float8]]](defaultMap, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. +	registerDefaultPgTypeVariants[Int2](defaultMap, "int2") +	registerDefaultPgTypeVariants[Int4](defaultMap, "int4") +	registerDefaultPgTypeVariants[Range[Int4]](defaultMap, "int4range") +	registerDefaultPgTypeVariants[Multirange[Range[Int4]]](defaultMap, "int4multirange") +	registerDefaultPgTypeVariants[Int8](defaultMap, "int8") +	registerDefaultPgTypeVariants[Range[Int8]](defaultMap, "int8range") +	registerDefaultPgTypeVariants[Multirange[Range[Int8]]](defaultMap, "int8multirange") +	registerDefaultPgTypeVariants[Interval](defaultMap, "interval") +	registerDefaultPgTypeVariants[Line](defaultMap, "line") +	registerDefaultPgTypeVariants[Lseg](defaultMap, "lseg") +	registerDefaultPgTypeVariants[Numeric](defaultMap, "numeric") +	registerDefaultPgTypeVariants[Range[Numeric]](defaultMap, "numrange") +	registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](defaultMap, "nummultirange") +	registerDefaultPgTypeVariants[Path](defaultMap, "path") +	registerDefaultPgTypeVariants[Point](defaultMap, "point") +	registerDefaultPgTypeVariants[Polygon](defaultMap, "polygon") +	registerDefaultPgTypeVariants[TID](defaultMap, "tid") +	registerDefaultPgTypeVariants[Text](defaultMap, "text") +	registerDefaultPgTypeVariants[Time](defaultMap, "time") +	registerDefaultPgTypeVariants[Timestamp](defaultMap, "timestamp") +	registerDefaultPgTypeVariants[Timestamptz](defaultMap, "timestamptz") +	registerDefaultPgTypeVariants[Range[Timestamp]](defaultMap, "tsrange") +	registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange") +	registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange") +	registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange") +	registerDefaultPgTypeVariants[UUID](defaultMap, "uuid") + +	defaultMap.buildReflectTypeToType() +} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go index 9f3de2c59..35d739566 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go @@ -3,6 +3,7 @@ package pgtype  import (  	"database/sql/driver"  	"encoding/binary" +	"encoding/json"  	"fmt"  	"strings"  	"time" @@ -66,6 +67,55 @@ func (ts Timestamp) Value() (driver.Value, error) {  	return ts.Time, nil  } +func (ts Timestamp) MarshalJSON() ([]byte, error) { +	if !ts.Valid { +		return []byte("null"), nil +	} + +	var s string + +	switch ts.InfinityModifier { +	case Finite: +		s = ts.Time.Format(time.RFC3339Nano) +	case Infinity: +		s = "infinity" +	case NegativeInfinity: +		s = "-infinity" +	} + +	return json.Marshal(s) +} + +func (ts *Timestamp) UnmarshalJSON(b []byte) error { +	var s *string +	err := json.Unmarshal(b, &s) +	if err != nil { +		return err +	} + +	if s == nil { +		*ts = Timestamp{} +		return nil +	} + +	switch *s { +	case "infinity": +		*ts = Timestamp{Valid: true, InfinityModifier: Infinity} +	case "-infinity": +		*ts = Timestamp{Valid: true, InfinityModifier: -Infinity} +	default: +		// PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz +		tim, err := time.Parse(time.RFC3339Nano, *s) +		if err != nil { +			return err +		} + +		*ts = Timestamp{Time: tim, Valid: true} +	} + +	return nil +} +  type TimestampCodec struct{}  func (TimestampCodec) FormatSupported(format int16) bool { diff --git a/vendor/github.com/jackc/pgx/v5/rows.go b/vendor/github.com/jackc/pgx/v5/rows.go index ffe739b02..cdd72a25f 100644 --- a/vendor/github.com/jackc/pgx/v5/rows.go +++ b/vendor/github.com/jackc/pgx/v5/rows.go @@ -28,12 +28,16 @@ type Rows interface {  	// to call Close after rows is already closed.  	Close() -	// Err returns any error that occurred while reading. +	// Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by +	// calling Close or by Next returning false). If it is called early it may return nil even if there was an error +	// executing the query.  	Err() error  	// CommandTag returns the command tag from this query. It is only available after Rows is closed.  	CommandTag() pgconn.CommandTag +	// FieldDescriptions returns the field descriptions of the columns. It may return nil. In particular this can occur +	// when there was an error executing the query.  	FieldDescriptions() []pgconn.FieldDescription  	// Next prepares the next row for reading. It returns true if there is another @@ -533,13 +537,11 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val  	for i := 0; i < dstElemType.NumField(); i++ {  		sf := dstElemType.Field(i) -		if sf.PkgPath == "" { -			// Handle anonymous struct embedding, but do not try to handle embedded pointers. -			if sf.Anonymous && sf.Type.Kind() == reflect.Struct { -				scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) -			} else { -				scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) -			} +		// Handle anonymous struct embedding, but do not try to handle embedded pointers. +		if sf.Anonymous && sf.Type.Kind() == reflect.Struct { +			scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) +		} else if sf.PkgPath == "" { +			scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface())  		}  	} @@ -565,8 +567,28 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {  	return &value, err  } +// RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public +// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database +// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. +func RowToStructByNameLax[T any](row CollectableRow) (T, error) { +	var value T +	err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) +	return value, err +} + +// RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or +// equal number of named public fields as row has fields. The row and T fields will by matched by name. The match is +// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" +// then the field will be ignored. +func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { +	var value T +	err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) +	return &value, err +} +  type namedStructRowScanner struct {  	ptrToStruct any +	lax         bool  }  func (rs *namedStructRowScanner) ScanRow(rows Rows) error { @@ -578,7 +600,6 @@ func (rs *namedStructRowScanner) ScanRow(rows Rows) error {  	dstElemValue := dstValue.Elem()  	scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) -  	if err != nil {  		return err  	} @@ -638,7 +659,13 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s  				colName = sf.Name  			}  			fpos := fieldPosByName(fldDescs, colName) -			if fpos == -1 || fpos >= len(scanTargets) { +			if fpos == -1 { +				if rs.lax { +					continue +				} +				return nil, fmt.Errorf("cannot find field %s in returned row", colName) +			} +			if fpos >= len(scanTargets) && !rs.lax {  				return nil, fmt.Errorf("cannot find field %s in returned row", colName)  			}  			scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() diff --git a/vendor/github.com/jackc/pgx/v5/tx.go b/vendor/github.com/jackc/pgx/v5/tx.go index e57142a61..575c17a71 100644 --- a/vendor/github.com/jackc/pgx/v5/tx.go +++ b/vendor/github.com/jackc/pgx/v5/tx.go @@ -44,6 +44,10 @@ type TxOptions struct {  	IsoLevel       TxIsoLevel  	AccessMode     TxAccessMode  	DeferrableMode TxDeferrableMode + +	// BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax +	// such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings. +	BeginQuery string  }  var emptyTxOptions TxOptions @@ -53,6 +57,10 @@ func (txOptions TxOptions) beginSQL() string {  		return "begin"  	} +	if txOptions.BeginQuery != "" { +		return txOptions.BeginQuery +	} +  	var buf strings.Builder  	buf.Grow(64) // 64 - maximum length of string with available options  	buf.WriteString("begin") diff --git a/vendor/modules.txt b/vendor/modules.txt index 64dfaaf59..411a11143 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -334,16 +334,16 @@ github.com/jackc/pgproto3/v2  # github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a  ## explicit; go 1.14  github.com/jackc/pgservicefile -# github.com/jackc/pgx/v5 v5.3.1 +# github.com/jackc/pgx/v5 v5.4.1  ## explicit; go 1.19  github.com/jackc/pgx/v5  github.com/jackc/pgx/v5/internal/anynil  github.com/jackc/pgx/v5/internal/iobufpool -github.com/jackc/pgx/v5/internal/nbconn  github.com/jackc/pgx/v5/internal/pgio  github.com/jackc/pgx/v5/internal/sanitize  github.com/jackc/pgx/v5/internal/stmtcache  github.com/jackc/pgx/v5/pgconn +github.com/jackc/pgx/v5/pgconn/internal/bgreader  github.com/jackc/pgx/v5/pgconn/internal/ctxwatch  github.com/jackc/pgx/v5/pgproto3  github.com/jackc/pgx/v5/pgtype  | 
