diff options
Diffstat (limited to 'vendor/github.com/go-pg/pg/v10/messages.go')
-rw-r--r-- | vendor/github.com/go-pg/pg/v10/messages.go | 1390 |
1 files changed, 1390 insertions, 0 deletions
diff --git a/vendor/github.com/go-pg/pg/v10/messages.go b/vendor/github.com/go-pg/pg/v10/messages.go new file mode 100644 index 000000000..7fb84ba0d --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/messages.go @@ -0,0 +1,1390 @@ +package pg + +import ( + "bufio" + "context" + "crypto/md5" //nolint + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" + + "mellium.im/sasl" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/orm" + "github.com/go-pg/pg/v10/types" +) + +// https://www.postgresql.org/docs/current/protocol-message-formats.html +const ( + commandCompleteMsg = 'C' + errorResponseMsg = 'E' + noticeResponseMsg = 'N' + parameterStatusMsg = 'S' + authenticationOKMsg = 'R' + backendKeyDataMsg = 'K' + noDataMsg = 'n' + passwordMessageMsg = 'p' + terminateMsg = 'X' + + saslInitialResponseMsg = 'p' + authenticationSASLContinueMsg = 'R' + saslResponseMsg = 'p' + authenticationSASLFinalMsg = 'R' + + authenticationOK = 0 + authenticationCleartextPassword = 3 + authenticationMD5Password = 5 + authenticationSASL = 10 + + notificationResponseMsg = 'A' + + describeMsg = 'D' + parameterDescriptionMsg = 't' + + queryMsg = 'Q' + readyForQueryMsg = 'Z' + emptyQueryResponseMsg = 'I' + rowDescriptionMsg = 'T' + dataRowMsg = 'D' + + parseMsg = 'P' + parseCompleteMsg = '1' + + bindMsg = 'B' + bindCompleteMsg = '2' + + executeMsg = 'E' + + syncMsg = 'S' + flushMsg = 'H' + + closeMsg = 'C' + closeCompleteMsg = '3' + + copyInResponseMsg = 'G' + copyOutResponseMsg = 'H' + copyDataMsg = 'd' + copyDoneMsg = 'c' +) + +var errEmptyQuery = internal.Errorf("pg: query is empty") + +func (db *baseDB) startup( + c context.Context, cn *pool.Conn, user, password, database, appName string, +) error { + err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writeStartupMsg(wb, user, database, appName) + return nil + }) + if err != nil { + return err + } + + return cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + for { + typ, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + + switch typ { + case backendKeyDataMsg: + processID, err := readInt32(rd) + if err != nil { + return err + } + secretKey, err := readInt32(rd) + if err != nil { + return err + } + cn.ProcessID = processID + cn.SecretKey = secretKey + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return err + } + case authenticationOKMsg: + err := db.auth(c, cn, rd, user, password) + if err != nil { + return err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + return err + case noticeResponseMsg: + // If we encounter a notice message from the server then we want to try to log it as it might be + // important for the client. If something goes wrong with this we want to fail. At the time of writing + // this the client will fail just encountering a notice during startup. So failing if a bad notice is + // sent is probably better than not failing, especially if we can try to log at least some data from the + // notice. + if err := db.logStartupNotice(rd); err != nil { + return err + } + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf("pg: unknown startup message response: %q", typ) + } + } + }) +} + +func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Config) error { + err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writeSSLMsg(wb) + return nil + }) + if err != nil { + return err + } + + err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + c, err := rd.ReadByte() + if err != nil { + return err + } + if c != 'S' { + return errors.New("pg: SSL is not enabled on the server") + } + return nil + }) + if err != nil { + return err + } + + cn.SetNetConn(tls.Client(cn.NetConn(), tlsConf)) + return nil +} + +func (db *baseDB) auth( + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, +) error { + num, err := readInt32(rd) + if err != nil { + return err + } + + switch num { + case authenticationOK: + return nil + case authenticationCleartextPassword: + return db.authCleartext(c, cn, rd, password) + case authenticationMD5Password: + return db.authMD5(c, cn, rd, user, password) + case authenticationSASL: + return db.authSASL(c, cn, rd, user, password) + default: + return fmt.Errorf("pg: unknown authentication message response: %q", num) + } +} + +// logStartupNotice will handle notice messages during the startup process. It will parse them and log them for the +// client. Notices are not common and only happen if there is something the client should be aware of. So logging should +// not be a problem. +// Notice messages can be seen in startup: https://www.postgresql.org/docs/13/protocol-flow.html +// Information on the notice message format: https://www.postgresql.org/docs/13/protocol-message-formats.html +// Note: This is true for earlier versions of PostgreSQL as well, I've just included the latest versions of the docs. +func (db *baseDB) logStartupNotice( + rd *pool.ReaderContext, +) error { + message := make([]string, 0) + // Notice messages are null byte delimited key-value pairs. Where the keys are one byte. + for { + // Read the key byte. + fieldType, err := rd.ReadByte() + if err != nil { + return err + } + + // If they key byte (the type of field this data is) is 0 then that means we have reached the end of the notice. + // We can break our loop here and throw our message data into the logger. + if fieldType == 0 { + break + } + + // Read until the next null byte to get the data for this field. This does include the null byte at the end of + // fieldValue so we will trim it off down below. + fieldValue, err := readString(rd) + if err != nil { + return err + } + + // Just throw the field type as a string and its value into an array. + // Field types can be seen here: https://www.postgresql.org/docs/13/protocol-error-fields.html + // TODO This is a rare occurrence as is, would it be worth adding something to indicate what the field names + // are? Or is PostgreSQL documentation enough for a user at this point? + message = append(message, fmt.Sprintf("%s: %s", string(fieldType), fieldValue)) + } + + // Tell the client what PostgreSQL told us. Warning because its probably something the client should at the very + // least adjust. + internal.Warn.Printf("notice during startup: %s", strings.Join(message, ", ")) + + return nil +} + +func (db *baseDB) authCleartext( + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, password string, +) error { + err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writePasswordMsg(wb, password) + return nil + }) + if err != nil { + return err + } + return readAuthOK(rd) +} + +func (db *baseDB) authMD5( + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, +) error { + b, err := rd.ReadN(4) + if err != nil { + return err + } + + secret := "md5" + md5s(md5s(password+user)+string(b)) + err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writePasswordMsg(wb, secret) + return nil + }) + if err != nil { + return err + } + + return readAuthOK(rd) +} + +func readAuthOK(rd *pool.ReaderContext) error { + c, _, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case authenticationOKMsg: + c0, err := readInt32(rd) + if err != nil { + return err + } + if c0 != 0 { + return fmt.Errorf("pg: unexpected authentication code: %q", c0) + } + return nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf("pg: unknown password message response: %q", c) + } +} + +func (db *baseDB) authSASL( + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, +) error { + s, err := readString(rd) + if err != nil { + return err + } + if s != "SCRAM-SHA-256" { + return fmt.Errorf("pg: SASL: got %q, wanted %q", s, "SCRAM-SHA-256") + } + + c0, err := rd.ReadByte() + if err != nil { + return err + } + if c0 != 0 { + return fmt.Errorf("pg: SASL: got %q, wanted %q", c0, 0) + } + + creds := sasl.Credentials(func() (Username, Password, Identity []byte) { + return []byte(user), []byte(password), nil + }) + client := sasl.NewClient(sasl.ScramSha256, creds) + + _, resp, err := client.Step(nil) + if err != nil { + return err + } + + err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + wb.StartMessage(saslInitialResponseMsg) + wb.WriteString("SCRAM-SHA-256") + wb.WriteInt32(int32(len(resp))) + _, err := wb.Write(resp) + if err != nil { + return err + } + wb.FinishMessage() + return nil + }) + if err != nil { + return err + } + + typ, n, err := readMessageType(rd) + if err != nil { + return err + } + + switch typ { + case authenticationSASLContinueMsg: + c11, err := readInt32(rd) + if err != nil { + return err + } + if c11 != 11 { + return fmt.Errorf("pg: SASL: got %q, wanted %q", typ, 11) + } + + b, err := rd.ReadN(n - 4) + if err != nil { + return err + } + + _, resp, err = client.Step(b) + if err != nil { + return err + } + + err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + wb.StartMessage(saslResponseMsg) + _, err := wb.Write(resp) + if err != nil { + return err + } + wb.FinishMessage() + return nil + }) + if err != nil { + return err + } + + return readAuthSASLFinal(rd, client) + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf( + "pg: SASL: got %q, wanted %q", typ, authenticationSASLContinueMsg) + } +} + +func readAuthSASLFinal(rd *pool.ReaderContext, client *sasl.Negotiator) error { + c, n, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case authenticationSASLFinalMsg: + c12, err := readInt32(rd) + if err != nil { + return err + } + if c12 != 12 { + return fmt.Errorf("pg: SASL: got %q, wanted %q", c, 12) + } + + b, err := rd.ReadN(n - 4) + if err != nil { + return err + } + + _, _, err = client.Step(b) + if err != nil { + return err + } + + if client.State() != sasl.ValidServerResponse { + return fmt.Errorf("pg: SASL: state=%q, wanted %q", + client.State(), sasl.ValidServerResponse) + } + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf( + "pg: SASL: got %q, wanted %q", c, authenticationSASLFinalMsg) + } + + return readAuthOK(rd) +} + +func md5s(s string) string { + //nolint + h := md5.Sum([]byte(s)) + return hex.EncodeToString(h[:]) +} + +func writeStartupMsg(buf *pool.WriteBuffer, user, database, appName string) { + buf.StartMessage(0) + buf.WriteInt32(196608) + buf.WriteString("user") + buf.WriteString(user) + buf.WriteString("database") + buf.WriteString(database) + if appName != "" { + buf.WriteString("application_name") + buf.WriteString(appName) + } + buf.WriteString("") + buf.FinishMessage() +} + +func writeSSLMsg(buf *pool.WriteBuffer) { + buf.StartMessage(0) + buf.WriteInt32(80877103) + buf.FinishMessage() +} + +func writePasswordMsg(buf *pool.WriteBuffer, password string) { + buf.StartMessage(passwordMessageMsg) + buf.WriteString(password) + buf.FinishMessage() +} + +func writeFlushMsg(buf *pool.WriteBuffer) { + buf.StartMessage(flushMsg) + buf.FinishMessage() +} + +func writeCancelRequestMsg(buf *pool.WriteBuffer, processID, secretKey int32) { + buf.StartMessage(0) + buf.WriteInt32(80877102) + buf.WriteInt32(processID) + buf.WriteInt32(secretKey) + buf.FinishMessage() +} + +func writeQueryMsg( + buf *pool.WriteBuffer, + fmter orm.QueryFormatter, + query interface{}, + params ...interface{}, +) error { + buf.StartMessage(queryMsg) + bytes, err := appendQuery(fmter, buf.Bytes, query, params...) + if err != nil { + return err + } + buf.Bytes = bytes + err = buf.WriteByte(0x0) + if err != nil { + return err + } + buf.FinishMessage() + return nil +} + +func appendQuery(fmter orm.QueryFormatter, dst []byte, query interface{}, params ...interface{}) ([]byte, error) { + switch query := query.(type) { + case orm.QueryAppender: + if v, ok := fmter.(*orm.Formatter); ok { + fmter = v.WithModel(query) + } + return query.AppendQuery(fmter, dst) + case string: + if len(params) > 0 { + model, ok := params[len(params)-1].(orm.TableModel) + if ok { + if v, ok := fmter.(*orm.Formatter); ok { + fmter = v.WithTableModel(model) + params = params[:len(params)-1] + } + } + } + return fmter.FormatQuery(dst, query, params...), nil + default: + return nil, fmt.Errorf("pg: can't append %T", query) + } +} + +func writeSyncMsg(buf *pool.WriteBuffer) { + buf.StartMessage(syncMsg) + buf.FinishMessage() +} + +func writeParseDescribeSyncMsg(buf *pool.WriteBuffer, name, q string) { + buf.StartMessage(parseMsg) + buf.WriteString(name) + buf.WriteString(q) + buf.WriteInt16(0) + buf.FinishMessage() + + buf.StartMessage(describeMsg) + buf.WriteByte('S') //nolint + buf.WriteString(name) + buf.FinishMessage() + + writeSyncMsg(buf) +} + +func readParseDescribeSync(rd *pool.ReaderContext) ([]types.ColumnInfo, error) { + var columns []types.ColumnInfo + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + switch c { + case parseCompleteMsg: + _, err = rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case rowDescriptionMsg: // Response to the DESCRIBE message. + columns, err = readRowDescription(rd, pool.NewColumnAlloc()) + if err != nil { + return nil, err + } + case parameterDescriptionMsg: // Response to the DESCRIBE message. + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case noDataMsg: // Response to the DESCRIBE message. + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return columns, err + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readParseDescribeSync: unexpected message %q", c) + } + } +} + +// Writes BIND, EXECUTE and SYNC messages. +func writeBindExecuteMsg(buf *pool.WriteBuffer, name string, params ...interface{}) error { + buf.StartMessage(bindMsg) + buf.WriteString("") + buf.WriteString(name) + buf.WriteInt16(0) + buf.WriteInt16(int16(len(params))) + for _, param := range params { + buf.StartParam() + bytes := types.Append(buf.Bytes, param, 0) + if bytes != nil { + buf.Bytes = bytes + buf.FinishParam() + } else { + buf.FinishNullParam() + } + } + buf.WriteInt16(0) + buf.FinishMessage() + + buf.StartMessage(executeMsg) + buf.WriteString("") + buf.WriteInt32(0) + buf.FinishMessage() + + writeSyncMsg(buf) + + return nil +} + +func writeCloseMsg(buf *pool.WriteBuffer, name string) { + buf.StartMessage(closeMsg) + buf.WriteByte('S') //nolint + buf.WriteString(name) + buf.FinishMessage() +} + +func readCloseCompleteMsg(rd *pool.ReaderContext) error { + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + switch c { + case closeCompleteMsg: + _, err := rd.ReadN(msgLen) + return err + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return err + } + default: + return fmt.Errorf("pg: readCloseCompleteMsg: unexpected message %q", c) + } + } +} + +func readSimpleQuery(rd *pool.ReaderContext) (*result, error) { + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case commandCompleteMsg: + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case rowDescriptionMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case dataRowMsg: + if _, err := rd.Discard(msgLen); err != nil { + return nil, err + } + res.returned++ + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case emptyQueryResponseMsg: + if firstErr == nil { + firstErr = errEmptyQuery + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readSimpleQuery: unexpected message %q", c) + } + } +} + +func readExtQuery(rd *pool.ReaderContext) (*result, error) { + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case bindCompleteMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case dataRowMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + res.returned++ + case commandCompleteMsg: // Response to the EXECUTE message. + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: // Response to the SYNC message. + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case emptyQueryResponseMsg: + if firstErr == nil { + firstErr = errEmptyQuery + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readExtQuery: unexpected message %q", c) + } + } +} + +func readRowDescription( + rd *pool.ReaderContext, columnAlloc *pool.ColumnAlloc, +) ([]types.ColumnInfo, error) { + numCol, err := readInt16(rd) + if err != nil { + return nil, err + } + + for i := 0; i < int(numCol); i++ { + b, err := rd.ReadSlice(0) + if err != nil { + return nil, err + } + + col := columnAlloc.New(int16(i), b[:len(b)-1]) + + if _, err := rd.ReadN(6); err != nil { + return nil, err + } + + dataType, err := readInt32(rd) + if err != nil { + return nil, err + } + col.DataType = dataType + + if _, err := rd.ReadN(8); err != nil { + return nil, err + } + } + + return columnAlloc.Columns(), nil +} + +func readDataRow( + ctx context.Context, + rd *pool.ReaderContext, + columns []types.ColumnInfo, + scanner orm.ColumnScanner, +) error { + numCol, err := readInt16(rd) + if err != nil { + return err + } + + if h, ok := scanner.(orm.BeforeScanHook); ok { + if err := h.BeforeScan(ctx); err != nil { + return err + } + } + + var firstErr error + + for colIdx := int16(0); colIdx < numCol; colIdx++ { + n, err := readInt32(rd) + if err != nil { + return err + } + + var colRd types.Reader + if int(n) <= rd.Buffered() { + colRd = rd.BytesReader(int(n)) + } else { + rd.SetAvailable(int(n)) + colRd = rd + } + + column := columns[colIdx] + if err := scanner.ScanColumn(column, colRd, int(n)); err != nil && firstErr == nil { + firstErr = internal.Errorf(err.Error()) + } + + if rd == colRd { + if rd.Available() > 0 { + if _, err := rd.Discard(rd.Available()); err != nil && firstErr == nil { + firstErr = err + } + } + rd.SetAvailable(-1) + } + } + + if h, ok := scanner.(orm.AfterScanHook); ok { + if err := h.AfterScan(ctx); err != nil { + return err + } + } + + return firstErr +} + +func newModel(mod interface{}) (orm.Model, error) { + m, err := orm.NewModel(mod) + if err != nil { + return nil, err + } + return m, m.Init() +} + +func readSimpleQueryData( + ctx context.Context, rd *pool.ReaderContext, mod interface{}, +) (*result, error) { + var columns []types.ColumnInfo + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case rowDescriptionMsg: + columns, err = readRowDescription(rd, rd.ColumnAlloc) + if err != nil { + return nil, err + } + + if res.model == nil { + var err error + res.model, err = newModel(mod) + if err != nil { + if firstErr == nil { + firstErr = err + } + res.model = Discard + } + } + case dataRowMsg: + scanner := res.model.NextColumnScanner() + if err := readDataRow(ctx, rd, columns, scanner); err != nil { + if firstErr == nil { + firstErr = err + } + } else if err := res.model.AddColumnScanner(scanner); err != nil { + if firstErr == nil { + firstErr = err + } + } + + res.returned++ + case commandCompleteMsg: + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case emptyQueryResponseMsg: + if firstErr == nil { + firstErr = errEmptyQuery + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readSimpleQueryData: unexpected message %q", c) + } + } +} + +func readExtQueryData( + ctx context.Context, rd *pool.ReaderContext, mod interface{}, columns []types.ColumnInfo, +) (*result, error) { + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case bindCompleteMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case dataRowMsg: + if res.model == nil { + var err error + res.model, err = newModel(mod) + if err != nil { + if firstErr == nil { + firstErr = err + } + res.model = Discard + } + } + + scanner := res.model.NextColumnScanner() + if err := readDataRow(ctx, rd, columns, scanner); err != nil { + if firstErr == nil { + firstErr = err + } + } else if err := res.model.AddColumnScanner(scanner); err != nil { + if firstErr == nil { + firstErr = err + } + } + + res.returned++ + case commandCompleteMsg: // Response to the EXECUTE message. + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: // Response to the SYNC message. + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readExtQueryData: unexpected message %q", c) + } + } +} + +func readCopyInResponse(rd *pool.ReaderContext) error { + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case copyInResponseMsg: + _, err := rd.ReadN(msgLen) + return err + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + if firstErr == nil { + firstErr = e + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return err + } + return firstErr + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return err + } + default: + return fmt.Errorf("pg: readCopyInResponse: unexpected message %q", c) + } + } +} + +func readCopyOutResponse(rd *pool.ReaderContext) error { + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case copyOutResponseMsg: + _, err := rd.ReadN(msgLen) + return err + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + if firstErr == nil { + firstErr = e + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return err + } + return firstErr + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return err + } + default: + return fmt.Errorf("pg: readCopyOutResponse: unexpected message %q", c) + } + } +} + +func readCopyData(rd *pool.ReaderContext, w io.Writer) (*result, error) { + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case copyDataMsg: + for msgLen > 0 { + b, err := rd.ReadN(msgLen) + if err != nil && err != bufio.ErrBufferFull { + return nil, err + } + + _, err = w.Write(b) + if err != nil { + return nil, err + } + + msgLen -= len(b) + } + case copyDoneMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case commandCompleteMsg: + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + return nil, e + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readCopyData: unexpected message %q", c) + } + } +} + +func writeCopyData(buf *pool.WriteBuffer, r io.Reader) error { + buf.StartMessage(copyDataMsg) + _, err := buf.ReadFrom(r) + buf.FinishMessage() + return err +} + +func writeCopyDone(buf *pool.WriteBuffer) { + buf.StartMessage(copyDoneMsg) + buf.FinishMessage() +} + +func readReadyForQuery(rd *pool.ReaderContext) (*result, error) { + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case commandCompleteMsg: + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readReadyForQueryOrError: unexpected message %q", c) + } + } +} + +func readNotification(rd *pool.ReaderContext) (channel, payload string, err error) { + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return "", "", err + } + + switch c { + case commandCompleteMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return "", "", err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return "", "", err + } + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return "", "", err + } + return "", "", e + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return "", "", err + } + case notificationResponseMsg: + _, err := readInt32(rd) + if err != nil { + return "", "", err + } + channel, err = readString(rd) + if err != nil { + return "", "", err + } + payload, err = readString(rd) + if err != nil { + return "", "", err + } + return channel, payload, nil + default: + return "", "", fmt.Errorf("pg: readNotification: unexpected message %q", c) + } + } +} + +var terminateMessage = []byte{terminateMsg, 0, 0, 0, 4} + +func terminateConn(cn *pool.Conn) error { + // Don't use cn.Buf because it is racy with user code. + _, err := cn.NetConn().Write(terminateMessage) + return err +} + +//------------------------------------------------------------------------------ + +func logNotice(rd *pool.ReaderContext, msgLen int) error { + _, err := rd.ReadN(msgLen) + return err +} + +func logParameterStatus(rd *pool.ReaderContext, msgLen int) error { + _, err := rd.ReadN(msgLen) + return err +} + +func readInt16(rd *pool.ReaderContext) (int16, error) { + b, err := rd.ReadN(2) + if err != nil { + return 0, err + } + return int16(binary.BigEndian.Uint16(b)), nil +} + +func readInt32(rd *pool.ReaderContext) (int32, error) { + b, err := rd.ReadN(4) + if err != nil { + return 0, err + } + return int32(binary.BigEndian.Uint32(b)), nil +} + +func readString(rd *pool.ReaderContext) (string, error) { + b, err := rd.ReadSlice(0) + if err != nil { + return "", err + } + return string(b[:len(b)-1]), nil +} + +func readError(rd *pool.ReaderContext) (error, error) { + m := make(map[byte]string) + for { + c, err := rd.ReadByte() + if err != nil { + return nil, err + } + if c == 0 { + break + } + s, err := readString(rd) + if err != nil { + return nil, err + } + m[c] = s + } + return internal.NewPGError(m), nil +} + +func readMessageType(rd *pool.ReaderContext) (byte, int, error) { + c, err := rd.ReadByte() + if err != nil { + return 0, 0, err + } + l, err := readInt32(rd) + if err != nil { + return 0, 0, err + } + return c, int(l) - 4, nil +} |