summaryrefslogtreecommitdiff
path: root/vendor/golang.org/x/crypto/ssh/transport.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/golang.org/x/crypto/ssh/transport.go')
-rw-r--r--vendor/golang.org/x/crypto/ssh/transport.go32
1 files changed, 27 insertions, 5 deletions
diff --git a/vendor/golang.org/x/crypto/ssh/transport.go b/vendor/golang.org/x/crypto/ssh/transport.go
index da015801e..0424d2d37 100644
--- a/vendor/golang.org/x/crypto/ssh/transport.go
+++ b/vendor/golang.org/x/crypto/ssh/transport.go
@@ -49,6 +49,9 @@ type transport struct {
rand io.Reader
isClient bool
io.Closer
+
+ strictMode bool
+ initialKEXDone bool
}
// packetCipher represents a combination of SSH encryption/MAC
@@ -74,6 +77,18 @@ type connectionState struct {
pendingKeyChange chan packetCipher
}
+func (t *transport) setStrictMode() error {
+ if t.reader.seqNum != 1 {
+ return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
+ }
+ t.strictMode = true
+ return nil
+}
+
+func (t *transport) setInitialKEXDone() {
+ t.initialKEXDone = true
+}
+
// prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
@@ -112,11 +127,12 @@ func (t *transport) printPacket(p []byte, write bool) {
// Read and decrypt next packet.
func (t *transport) readPacket() (p []byte, err error) {
for {
- p, err = t.reader.readPacket(t.bufReader)
+ p, err = t.reader.readPacket(t.bufReader, t.strictMode)
if err != nil {
break
}
- if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) {
+ // in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
+ if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
break
}
}
@@ -127,7 +143,7 @@ func (t *transport) readPacket() (p []byte, err error) {
return p, err
}
-func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
+func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
s.seqNum++
if err == nil && len(packet) == 0 {
@@ -140,6 +156,9 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
+ if strictMode {
+ s.seqNum = 0
+ }
default:
return nil, errors.New("ssh: got bogus newkeys message")
}
@@ -170,10 +189,10 @@ func (t *transport) writePacket(packet []byte) error {
if debugTransport {
t.printPacket(packet, true)
}
- return t.writer.writePacket(t.bufWriter, t.rand, packet)
+ return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
}
-func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error {
+func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
@@ -188,6 +207,9 @@ func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
+ if strictMode {
+ s.seqNum = 0
+ }
default:
panic("ssh: no key material for msgNewKeys")
}