Blob Blame History Raw
diff --git a/src/twisted/conch/ssh/transport.py b/src/twisted/conch/ssh/transport.py
index bd76b0a..a477d27 100644
--- a/src/twisted/conch/ssh/transport.py
+++ b/src/twisted/conch/ssh/transport.py
@@ -677,6 +677,14 @@ class SSHTransportBase(protocol.Protocol):
         """
         self.buf = self.buf + data
         if not self.gotVersion:
+            if len(self.buf) > 4096:
+                self.sendDisconnect(
+                    DISCONNECT_CONNECTION_LOST,
+                    b"Peer version string longer than 4KB. "
+                    b"Preventing a denial of service attack.",
+                )
+                return
+
             if self.buf.find(b'\n', self.buf.find(b'SSH-')) == -1:
                 return
 
diff --git a/src/twisted/conch/test/test_transport.py b/src/twisted/conch/test/test_transport.py
index 98a3515..449dd3f 100644
--- a/src/twisted/conch/test/test_transport.py
+++ b/src/twisted/conch/test/test_transport.py
@@ -522,6 +522,27 @@ class BaseSSHTransportTests(BaseSSHTransportBaseCase, TransportTestCase):
             r')*$')
         self.assertRegex(softwareVersion, softwareVersionRegex)
 
+    def test_dataReceiveVersionNotSentMemoryDOS(self):
+        """
+        When the peer is not sending its SSH version but keeps sending data,
+        the connection is disconnected after 4KB to prevent buffering too
+        much and running our of memory.
+        """
+        sut = MockTransportBase()
+        sut.makeConnection(self.transport)
+
+        # Data can be received over multiple chunks.
+        sut.dataReceived(b"SSH-2-Server-Identifier")
+        sut.dataReceived(b"1234567890" * 406)
+        sut.dataReceived(b"1235678")
+        self.assertFalse(self.transport.disconnecting)
+
+        # Here we are going over the limit.
+        sut.dataReceived(b"1234567")
+        # Once a lot of data is received without an SSH version string,
+        # the transport is disconnected.
+        self.assertTrue(self.transport.disconnecting)
+        self.assertIn(b"Preventing a denial of service attack", self.transport.value())
 
     def test_sendPacketPlain(self):
         """