summaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go')
-rw-r--r--vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go240
1 files changed, 191 insertions, 49 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
index 7efb522a4..14966aa49 100644
--- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
+++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
@@ -1,6 +1,7 @@
package pgconn
import (
+ "container/list"
"context"
"crypto/md5"
"crypto/tls"
@@ -267,12 +268,15 @@ func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*
var pgErr *PgError
if errors.As(err, &pgErr) {
- const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
- const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
- const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
- const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
+ // pgx will try next host even if libpq does not in certain cases (see #2246)
+ // consider change for the next major version
+
+ const ERRCODE_INVALID_PASSWORD = "28P01"
+ const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
+ const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
+
+ // auth failed due to invalid password, db does not exist or user has no permission
if pgErr.Code == ERRCODE_INVALID_PASSWORD ||
- pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil ||
pgErr.Code == ERRCODE_INVALID_CATALOG_NAME ||
pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
return nil, allErrors
@@ -1408,9 +1412,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
- pgConn *PgConn
- ctx context.Context
- pipeline *Pipeline
+ pgConn *PgConn
+ ctx context.Context
rr *ResultReader
@@ -1443,12 +1446,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
mrr.closed = true
- if mrr.pipeline != nil {
- mrr.pipeline.expectedReadyForQueryCount--
- } else {
- mrr.pgConn.contextWatcher.Unwatch()
- mrr.pgConn.unlock()
- }
+ mrr.pgConn.contextWatcher.Unwatch()
+ mrr.pgConn.unlock()
case *pgproto3.ErrorResponse:
mrr.err = ErrorResponseToPgError(msg)
}
@@ -1672,7 +1671,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
case *pgproto3.EmptyQueryResponse:
rr.concludeCommand(CommandTag{}, nil)
case *pgproto3.ErrorResponse:
- rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg))
+ pgErr := ErrorResponseToPgError(msg)
+ if rr.pipeline != nil {
+ rr.pipeline.state.HandleError(pgErr)
+ }
+ rr.concludeCommand(CommandTag{}, pgErr)
}
return msg, nil
@@ -1773,9 +1776,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
if batch.err != nil {
+ pgConn.contextWatcher.Unwatch()
+ multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err)
multiResult.closed = true
- multiResult.err = batch.err
- pgConn.unlock()
+ pgConn.asyncClose()
return multiResult
}
@@ -1783,9 +1787,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
defer pgConn.exitPotentialWriteReadDeadlock()
_, err := pgConn.conn.Write(batch.buf)
if err != nil {
+ pgConn.contextWatcher.Unwatch()
+ multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
multiResult.closed = true
- multiResult.err = err
- pgConn.unlock()
+ pgConn.asyncClose()
return multiResult
}
@@ -1999,9 +2004,7 @@ type Pipeline struct {
conn *PgConn
ctx context.Context
- expectedReadyForQueryCount int
- pendingSync bool
-
+ state pipelineState
err error
closed bool
}
@@ -2012,6 +2015,122 @@ type PipelineSync struct{}
// CloseComplete is returned by GetResults when a CloseComplete message is received.
type CloseComplete struct{}
+type pipelineRequestType int
+
+const (
+ pipelineNil pipelineRequestType = iota
+ pipelinePrepare
+ pipelineQueryParams
+ pipelineQueryPrepared
+ pipelineDeallocate
+ pipelineSyncRequest
+ pipelineFlushRequest
+)
+
+type pipelineRequestEvent struct {
+ RequestType pipelineRequestType
+ WasSentToServer bool
+ BeforeFlushOrSync bool
+}
+
+type pipelineState struct {
+ requestEventQueue list.List
+ lastRequestType pipelineRequestType
+ pgErr *PgError
+ expectedReadyForQueryCount int
+}
+
+func (s *pipelineState) Init() {
+ s.requestEventQueue.Init()
+ s.lastRequestType = pipelineNil
+}
+
+func (s *pipelineState) RegisterSendingToServer() {
+ for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
+ val := elem.Value.(pipelineRequestEvent)
+ if val.WasSentToServer {
+ return
+ }
+ val.WasSentToServer = true
+ elem.Value = val
+ }
+}
+
+func (s *pipelineState) registerFlushingBufferOnServer() {
+ for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
+ val := elem.Value.(pipelineRequestEvent)
+ if val.BeforeFlushOrSync {
+ return
+ }
+ val.BeforeFlushOrSync = true
+ elem.Value = val
+ }
+}
+
+func (s *pipelineState) PushBackRequestType(req pipelineRequestType) {
+ if req == pipelineNil {
+ return
+ }
+
+ if req != pipelineFlushRequest {
+ s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req})
+ }
+ if req == pipelineFlushRequest || req == pipelineSyncRequest {
+ s.registerFlushingBufferOnServer()
+ }
+ s.lastRequestType = req
+
+ if req == pipelineSyncRequest {
+ s.expectedReadyForQueryCount++
+ }
+}
+
+func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType {
+ for {
+ elem := s.requestEventQueue.Front()
+ if elem == nil {
+ return pipelineNil
+ }
+ val := elem.Value.(pipelineRequestEvent)
+ if !(val.WasSentToServer && val.BeforeFlushOrSync) {
+ return pipelineNil
+ }
+
+ s.requestEventQueue.Remove(elem)
+ if val.RequestType == pipelineSyncRequest {
+ s.pgErr = nil
+ }
+ if s.pgErr == nil {
+ return val.RequestType
+ }
+ }
+}
+
+func (s *pipelineState) HandleError(err *PgError) {
+ s.pgErr = err
+}
+
+func (s *pipelineState) HandleReadyForQuery() {
+ s.expectedReadyForQueryCount--
+}
+
+func (s *pipelineState) PendingSync() bool {
+ var notPendingSync bool
+
+ if elem := s.requestEventQueue.Back(); elem != nil {
+ val := elem.Value.(pipelineRequestEvent)
+ notPendingSync = (val.RequestType == pipelineSyncRequest) && val.WasSentToServer
+ } else {
+ notPendingSync = (s.lastRequestType == pipelineSyncRequest) || (s.lastRequestType == pipelineNil)
+ }
+
+ return !notPendingSync
+}
+
+func (s *pipelineState) ExpectedReadyForQuery() int {
+ return s.expectedReadyForQueryCount
+}
+
// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent
// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection
// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except
@@ -2020,16 +2139,21 @@ type CloseComplete struct{}
// Prefer ExecBatch when only sending one group of queries at once.
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
if err := pgConn.lock(); err != nil {
- return &Pipeline{
+ pipeline := &Pipeline{
closed: true,
err: err,
}
+ pipeline.state.Init()
+
+ return pipeline
}
pgConn.pipeline = Pipeline{
conn: pgConn,
ctx: ctx,
}
+ pgConn.pipeline.state.Init()
+
pipeline := &pgConn.pipeline
if ctx != context.Background() {
@@ -2052,10 +2176,10 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
if p.closed {
return
}
- p.pendingSync = true
p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
+ p.state.PushBackRequestType(pipelinePrepare)
}
// SendDeallocate deallocates a prepared statement.
@@ -2063,9 +2187,9 @@ func (p *Pipeline) SendDeallocate(name string) {
if p.closed {
return
}
- p.pendingSync = true
p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
+ p.state.PushBackRequestType(pipelineDeallocate)
}
// SendQueryParams is the pipeline version of *PgConn.QueryParams.
@@ -2073,12 +2197,12 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [
if p.closed {
return
}
- p.pendingSync = true
p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
p.conn.frontend.SendExecute(&pgproto3.Execute{})
+ p.state.PushBackRequestType(pipelineQueryParams)
}
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
@@ -2086,11 +2210,42 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para
if p.closed {
return
}
- p.pendingSync = true
p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
p.conn.frontend.SendExecute(&pgproto3.Execute{})
+ p.state.PushBackRequestType(pipelineQueryPrepared)
+}
+
+// SendFlushRequest sends a request for the server to flush its output buffer.
+//
+// The server flushes its output buffer automatically as a result of Sync being called,
+// or on any request when not in pipeline mode; this function is useful to cause the server
+// to flush its output buffer in pipeline mode without establishing a synchronization point.
+// Note that the request is not itself flushed to the server automatically; use Flush if
+// necessary. This copies the behavior of libpq PQsendFlushRequest.
+func (p *Pipeline) SendFlushRequest() {
+ if p.closed {
+ return
+ }
+
+ p.conn.frontend.Send(&pgproto3.Flush{})
+ p.state.PushBackRequestType(pipelineFlushRequest)
+}
+
+// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message
+// without flushing the send buffer. This serves as the delimiter of an implicit
+// transaction and an error recovery point.
+//
+// Note that the request is not itself flushed to the server automatically; use Flush if
+// necessary. This copies the behavior of libpq PQsendPipelineSync.
+func (p *Pipeline) SendPipelineSync() {
+ if p.closed {
+ return
+ }
+
+ p.conn.frontend.SendSync(&pgproto3.Sync{})
+ p.state.PushBackRequestType(pipelineSyncRequest)
}
// Flush flushes the queued requests without establishing a synchronization point.
@@ -2115,28 +2270,14 @@ func (p *Pipeline) Flush() error {
return err
}
+ p.state.RegisterSendingToServer()
return nil
}
// Sync establishes a synchronization point and flushes the queued requests.
func (p *Pipeline) Sync() error {
- if p.closed {
- if p.err != nil {
- return p.err
- }
- return errors.New("pipeline closed")
- }
-
- p.conn.frontend.SendSync(&pgproto3.Sync{})
- err := p.Flush()
- if err != nil {
- return err
- }
-
- p.pendingSync = false
- p.expectedReadyForQueryCount++
-
- return nil
+ p.SendPipelineSync()
+ return p.Flush()
}
// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or
@@ -2150,7 +2291,7 @@ func (p *Pipeline) GetResults() (results any, err error) {
return nil, errors.New("pipeline closed")
}
- if p.expectedReadyForQueryCount == 0 {
+ if p.state.ExtractFrontRequestType() == pipelineNil {
return nil, nil
}
@@ -2195,13 +2336,13 @@ func (p *Pipeline) getResults() (results any, err error) {
case *pgproto3.CloseComplete:
return &CloseComplete{}, nil
case *pgproto3.ReadyForQuery:
- p.expectedReadyForQueryCount--
+ p.state.HandleReadyForQuery()
return &PipelineSync{}, nil
case *pgproto3.ErrorResponse:
pgErr := ErrorResponseToPgError(msg)
+ p.state.HandleError(pgErr)
return nil, pgErr
}
-
}
}
@@ -2231,6 +2372,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
// These should never happen here. But don't take chances that could lead to a deadlock.
case *pgproto3.ErrorResponse:
pgErr := ErrorResponseToPgError(msg)
+ p.state.HandleError(pgErr)
return nil, pgErr
case *pgproto3.CommandComplete:
p.conn.asyncClose()
@@ -2250,7 +2392,7 @@ func (p *Pipeline) Close() error {
p.closed = true
- if p.pendingSync {
+ if p.state.PendingSync() {
p.conn.asyncClose()
p.err = errors.New("pipeline has unsynced requests")
p.conn.contextWatcher.Unwatch()
@@ -2259,7 +2401,7 @@ func (p *Pipeline) Close() error {
return p.err
}
- for p.expectedReadyForQueryCount > 0 {
+ for p.state.ExpectedReadyForQuery() > 0 {
_, err := p.getResults()
if err != nil {
p.err = err