Blob Blame History Raw
--- a/paramiko/__init__.py
+++ b/paramiko/__init__.py
@@ -19,12 +19,15 @@
 # flake8: noqa
 import sys
 from paramiko._version import __version__, __version_info__
-from paramiko.transport import SecurityOptions, Transport
+from paramiko.transport import (
+    SecurityOptions,
+    Transport,
+)
 from paramiko.client import (
-    SSHClient,
-    MissingHostKeyPolicy,
     AutoAddPolicy,
+    MissingHostKeyPolicy,
     RejectPolicy,
+    SSHClient,
     WarningPolicy,
 )
 from paramiko.auth_handler import AuthHandler
@@ -43,6 +46,7 @@ from paramiko.ssh_exception import (
     ConfigParseError,
     CouldNotCanonicalize,
     IncompatiblePeer,
+    MessageOrderError,
     PasswordRequiredException,
     ProxyCommandFailure,
     SSHException,
--- a/paramiko/packet.py
+++ b/paramiko/packet.py
@@ -86,6 +86,7 @@ class Packetizer(object):
         self.__need_rekey = False
         self.__init_count = 0
         self.__remainder = bytes()
+        self._initial_kex_done = False
 
         # used for noticing when to re-key:
         self.__sent_bytes = 0
@@ -130,6 +131,12 @@ class Packetizer(object):
     def closed(self):
         return self.__closed
 
+    def reset_seqno_out(self):
+        self.__sequence_number_out = 0
+
+    def reset_seqno_in(self):
+        self.__sequence_number_in = 0
+
     def set_log(self, log):
         """
         Set the Python log object to use for logging.
@@ -425,9 +432,12 @@ class Packetizer(object):
                 out += compute_hmac(
                     self.__mac_key_out, payload, self.__mac_engine_out
                 )[: self.__mac_size_out]
-            self.__sequence_number_out = (
-                self.__sequence_number_out + 1
-            ) & xffffffff
+            next_seq = (self.__sequence_number_out + 1) & xffffffff
+            if next_seq == 0 and not self._initial_kex_done:
+                raise SSHException(
+                    "Sequence number rolled over during initial kex!"
+                )
+            self.__sequence_number_out = next_seq
             self.write_all(out)
 
             self.__sent_bytes += len(out)
@@ -531,7 +541,12 @@ class Packetizer(object):
 
         msg = Message(payload[1:])
         msg.seqno = self.__sequence_number_in
-        self.__sequence_number_in = (self.__sequence_number_in + 1) & xffffffff
+        next_seq = (self.__sequence_number_in + 1) & xffffffff
+        if next_seq == 0 and not self._initial_kex_done:
+            raise SSHException(
+                "Sequence number rolled over during initial kex!"
+            )
+        self.__sequence_number_in = next_seq
 
         # check for rekey
         raw_packet_size = packet_size + self.__mac_size_in + 4
--- a/paramiko/ssh_exception.py
+++ b/paramiko/ssh_exception.py
@@ -235,3 +235,13 @@ class ConfigParseError(SSHException):
     """
 
     pass
+
+
+class MessageOrderError(SSHException):
+    """
+    Out-of-order protocol messages were received, violating "strict kex" mode.
+
+    .. versionadded:: 3.4
+    """
+
+    pass
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -106,11 +106,12 @@ from paramiko.ecdsakey import ECDSAKey
 from paramiko.server import ServerInterface
 from paramiko.sftp_client import SFTPClient
 from paramiko.ssh_exception import (
-    SSHException,
     BadAuthenticationType,
     ChannelException,
     IncompatiblePeer,
+    MessageOrderError,
     ProxyCommandFailure,
+    SSHException,
 )
 from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value
 
@@ -329,6 +330,8 @@ class Transport(threading.Thread, Closin
         gss_deleg_creds=True,
         disabled_algorithms=None,
         server_sig_algs=True,
+        strict_kex=True,
+        packetizer_class=None,
     ):
         """
         Create a new SSH session over an existing socket, or socket-like
@@ -395,6 +398,13 @@ class Transport(threading.Thread, Closin
             Whether to send an extra message to compatible clients, in server
             mode, with a list of supported pubkey algorithms. Default:
             ``True``.
+        :param bool strict_kex:
+            Whether to advertise (and implement, if client also advertises
+            support for) a "strict kex" mode for safer handshaking. Default:
+            ``True``.
+        :param packetizer_class:
+            Which class to use for instantiating the internal packet handler.
+            Default: ``None`` (i.e.: use `Packetizer` as normal).
 
         .. versionchanged:: 1.15
             Added the ``default_window_size`` and ``default_max_packet_size``
@@ -405,10 +415,16 @@ class Transport(threading.Thread, Closin
             Added the ``disabled_algorithms`` kwarg.
         .. versionchanged:: 2.9
             Added the ``server_sig_algs`` kwarg.
+        .. versionchanged:: 3.4
+            Added the ``strict_kex`` kwarg.
+        .. versionchanged:: 3.4
+            Added the ``packetizer_class`` kwarg.
         """
         self.active = False
         self.hostname = None
         self.server_extensions = {}
+        self.advertise_strict_kex = strict_kex
+        self.agreed_on_strict_kex = False
 
         if isinstance(sock, string_types):
             # convert "host:port" into (host, port)
@@ -450,7 +466,7 @@ class Transport(threading.Thread, Closin
         self.sock.settimeout(self._active_check_timeout)
 
         # negotiated crypto parameters
-        self.packetizer = Packetizer(sock)
+        self.packetizer = (packetizer_class or Packetizer)(sock)
         self.local_version = "SSH-" + self._PROTO_ID + "-" + self._CLIENT_ID
         self.remote_version = ""
         self.local_cipher = self.remote_cipher = ""
@@ -524,6 +540,20 @@ class Transport(threading.Thread, Closin
         self.server_accept_cv = threading.Condition(self.lock)
         self.subsystem_table = {}
 
+        # Handler table, now set at init time for easier per-instance
+        # manipulation and subclass twiddling.
+        self._handler_table = {
+            MSG_EXT_INFO: self._parse_ext_info,
+            MSG_NEWKEYS: self._parse_newkeys,
+            MSG_GLOBAL_REQUEST: self._parse_global_request,
+            MSG_REQUEST_SUCCESS: self._parse_request_success,
+            MSG_REQUEST_FAILURE: self._parse_request_failure,
+            MSG_CHANNEL_OPEN_SUCCESS: self._parse_channel_open_success,
+            MSG_CHANNEL_OPEN_FAILURE: self._parse_channel_open_failure,
+            MSG_CHANNEL_OPEN: self._parse_channel_open,
+            MSG_KEXINIT: self._negotiate_keys,
+        }
+
     def _filter_algorithm(self, type_):
         default = getattr(self, "_preferred_{}".format(type_))
         return tuple(
@@ -2067,6 +2097,20 @@ class Transport(threading.Thread, Closin
         # be empty.)
         return reply
 
+    def _enforce_strict_kex(self, ptype):
+        """
+        Conditionally raise `MessageOrderError` during strict initial kex.
+
+        This method should only be called inside code that handles non-KEXINIT
+        messages; it does not interrogate ``ptype`` besides using it to log
+        more accurately.
+        """
+        if self.agreed_on_strict_kex and not self.initial_kex_done:
+            name = MSG_NAMES.get(ptype, f"msg {ptype}")
+            raise MessageOrderError(
+                f"In strict-kex mode, but was sent {name!r}!"
+            )
+
     def run(self):
         # (use the exposed "run" method, because if we specify a thread target
         # of a private method, threading.Thread will keep a reference to it
@@ -2111,16 +2155,21 @@ class Transport(threading.Thread, Closin
                     except NeedRekeyException:
                         continue
                     if ptype == MSG_IGNORE:
+                        self._enforce_strict_kex(ptype)
                         continue
                     elif ptype == MSG_DISCONNECT:
                         self._parse_disconnect(m)
                         break
                     elif ptype == MSG_DEBUG:
+                        self._enforce_strict_kex(ptype)
                         self._parse_debug(m)
                         continue
                     if len(self._expected_packet) > 0:
                         if ptype not in self._expected_packet:
-                            raise SSHException(
+                            exc_class = SSHException
+                            if self.agreed_on_strict_kex:
+                                exc_class = MessageOrderError
+                            raise exc_class(
                                 "Expecting packet from {!r}, got {:d}".format(
                                     self._expected_packet, ptype
                                 )
@@ -2135,7 +2184,7 @@ class Transport(threading.Thread, Closin
                         if error_msg:
                             self._send_message(error_msg)
                         else:
-                            self._handler_table[ptype](self, m)
+                            self._handler_table[ptype](m)
                     elif ptype in self._channel_handler_table:
                         chanid = m.get_int()
                         chan = self._channels.get(chanid)
@@ -2342,12 +2391,18 @@ class Transport(threading.Thread, Closin
             )
         else:
             available_server_keys = self.preferred_keys
-            # Signal support for MSG_EXT_INFO.
+            # Signal support for MSG_EXT_INFO so server will send it to us.
             # NOTE: doing this here handily means we don't even consider this
             # value when agreeing on real kex algo to use (which is a common
             # pitfall when adding this apparently).
             kex_algos.append("ext-info-c")
 
+        # Similar to ext-info, but used in both server modes, so done outside
+        # of above if/else.
+        if self.advertise_strict_kex:
+            which = "s" if self.server_mode else "c"
+            kex_algos.append(f"kex-strict-{which}-v00@openssh.com")
+
         m = Message()
         m.add_byte(cMSG_KEXINIT)
         m.add_bytes(os.urandom(16))
@@ -2388,7 +2443,8 @@ class Transport(threading.Thread, Closin
 
     def _get_latest_kex_init(self):
         return self._really_parse_kex_init(
-            Message(self._latest_kex_init), ignore_first_byte=True
+            Message(self._latest_kex_init),
+            ignore_first_byte=True,
         )
 
     def _parse_kex_init(self, m):
@@ -2427,10 +2483,39 @@ class Transport(threading.Thread, Closin
         self._log(DEBUG, "kex follows: {}".format(kex_follows))
         self._log(DEBUG, "=== Key exchange agreements ===")
 
-        # Strip out ext-info "kex algo"
+        # Record, and strip out, ext-info and/or strict-kex non-algorithms
         self._remote_ext_info = None
-        if kex_algo_list[-1].startswith("ext-info-"):
-            self._remote_ext_info = kex_algo_list.pop()
+        self._remote_strict_kex = None
+        to_pop = []
+        for i, algo in enumerate(kex_algo_list):
+            if algo.startswith("ext-info-"):
+                self._remote_ext_info = algo
+                to_pop.insert(0, i)
+            elif algo.startswith("kex-strict-"):
+                # NOTE: this is what we are expecting from the /remote/ end.
+                which = "c" if self.server_mode else "s"
+                expected = f"kex-strict-{which}-v00@openssh.com"
+                # Set strict mode if agreed.
+                self.agreed_on_strict_kex = (
+                    algo == expected and self.advertise_strict_kex
+                )
+                self._log(
+                    DEBUG, f"Strict kex mode: {self.agreed_on_strict_kex}"
+                )
+                to_pop.insert(0, i)
+        for i in to_pop:
+            kex_algo_list.pop(i)
+
+        # CVE mitigation: expect zeroed-out seqno anytime we are performing kex
+        # init phase, if strict mode was negotiated.
+        if (
+            self.agreed_on_strict_kex
+            and not self.initial_kex_done
+            and m.seqno != 0
+        ):
+            raise MessageOrderError(
+                "In strict-kex mode, but KEXINIT was not the first packet!"
+            )
 
         # as a server, we pick the first item in the client's list that we
         # support.
@@ -2631,6 +2716,13 @@ class Transport(threading.Thread, Closin
         ):
             self._log(DEBUG, "Switching on inbound compression ...")
             self.packetizer.set_inbound_compressor(compress_in())
+        # Reset inbound sequence number if strict mode.
+        if self.agreed_on_strict_kex:
+            self._log(
+                DEBUG,
+                "Resetting inbound seqno after NEWKEYS due to strict mode",
+            )
+            self.packetizer.reset_seqno_in()
 
     def _activate_outbound(self):
         """switch on newly negotiated encryption parameters for
@@ -2638,6 +2730,13 @@ class Transport(threading.Thread, Closin
         m = Message()
         m.add_byte(cMSG_NEWKEYS)
         self._send_message(m)
+        # Reset outbound sequence number if strict mode.
+        if self.agreed_on_strict_kex:
+            self._log(
+                DEBUG,
+                "Resetting outbound seqno after NEWKEYS due to strict mode",
+            )
+            self.packetizer.reset_seqno_out()
         block_size = self._cipher_info[self.local_cipher]["block-size"]
         if self.server_mode:
             IV_out = self._compute_key("B", block_size)
@@ -2728,7 +2827,9 @@ class Transport(threading.Thread, Closin
             self.auth_handler = AuthHandler(self)
         if not self.initial_kex_done:
             # this was the first key exchange
-            self.initial_kex_done = True
+            # (also signal to packetizer as it sometimes wants to know this
+            # status as well, eg when seqnos rollover)
+            self.initial_kex_done = self.packetizer._initial_kex_done = True
         # send an event?
         if self.completion_event is not None:
             self.completion_event.set()
@@ -2982,18 +3083,6 @@ class Transport(threading.Thread, Closin
         finally:
             self.lock.release()
 
-    _handler_table = {
-        MSG_EXT_INFO: _parse_ext_info,
-        MSG_NEWKEYS: _parse_newkeys,
-        MSG_GLOBAL_REQUEST: _parse_global_request,
-        MSG_REQUEST_SUCCESS: _parse_request_success,
-        MSG_REQUEST_FAILURE: _parse_request_failure,
-        MSG_CHANNEL_OPEN_SUCCESS: _parse_channel_open_success,
-        MSG_CHANNEL_OPEN_FAILURE: _parse_channel_open_failure,
-        MSG_CHANNEL_OPEN: _parse_channel_open,
-        MSG_KEXINIT: _negotiate_keys,
-    }
-
     _channel_handler_table = {
         MSG_CHANNEL_SUCCESS: Channel._request_success,
         MSG_CHANNEL_FAILURE: Channel._request_failed,
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -23,12 +23,14 @@ Some unit tests for the ssh2 protocol in
 from __future__ import with_statement
 
 from binascii import hexlify
+import itertools
 from contextlib import contextmanager
 import select
 import socket
 import time
 import threading
 import random
+import sys
 import unittest
 
 try:
@@ -37,14 +39,15 @@ except ImportError:
     from mock import Mock
 
 from paramiko import (
+    AuthenticationException,
     AuthHandler,
     ChannelException,
     DSSKey,
+    IncompatiblePeer,
+    MessageOrderError,
     Packetizer,
     RSAKey,
     SSHException,
-    AuthenticationException,
-    IncompatiblePeer,
     SecurityOptions,
     ServerInterface,
     Transport,
@@ -57,7 +60,11 @@ from paramiko.common import (
     MAX_WINDOW_SIZE,
     MIN_PACKET_SIZE,
     MIN_WINDOW_SIZE,
+    MSG_CHANNEL_OPEN,
+    MSG_DEBUG,
+    MSG_IGNORE,
     MSG_KEXINIT,
+    MSG_UNIMPLEMENTED,
     MSG_USERAUTH_SUCCESS,
     cMSG_CHANNEL_WINDOW_ADJUST,
     cMSG_UNIMPLEMENTED,
@@ -67,6 +74,7 @@ from paramiko.message import Message
 
 from .util import needs_builtin, _support, requires_sha1_signing, slow
 from .loop import LoopSocket
+from pytest import mark, raises
 
 
 LONG_BANNER = """\
@@ -154,6 +162,10 @@ class NullServer(ServerInterface):
         self._tcpip_dest = destination
         return OPEN_SUCCEEDED
 
+# Faux 'packet type' we do not implement and are unlikely ever to (but which is
+# technically "within spec" re RFC 4251
+MSG_FUGGEDABOUTIT = 253
+
 
 class TransportTest(unittest.TestCase):
     def setUp(self):
@@ -1119,6 +1131,16 @@ class TransportTest(unittest.TestCase):
         # Real fix's behavior
         self._expect_unimplemented()
 
+    def test_can_override_packetizer_used(self):
+        class MyPacketizer(Packetizer):
+            pass
+
+        # control case
+        assert Transport(sock=LoopSocket()).packetizer.__class__ is Packetizer
+        # overridden case
+        tweaked = Transport(sock=LoopSocket(), packetizer_class=MyPacketizer)
+        assert tweaked.packetizer.__class__ is MyPacketizer
+
 
 class AlgorithmDisablingTests(unittest.TestCase):
     def test_preferred_lists_default_to_private_attribute_contents(self):
@@ -1202,10 +1224,17 @@ def server(
     connect=None,
     pubkeys=None,
     catch_error=False,
+    transport_factory=None,
+    server_transport_factory=None,
+    defer=False,
+    skip_verify=False,
 ):
     """
     SSH server contextmanager for testing.
 
+    Yields a tuple of ``(tc, ts)`` (client- and server-side `Transport`
+    objects), or ``(tc, ts, err)`` when ``catch_error==True``.
+
     :param hostkey:
         Host key to use for the server; if None, loads
         ``test_rsa.key``.
@@ -1222,6 +1251,17 @@ def server(
     :param catch_error:
         Whether to capture connection errors & yield from contextmanager.
         Necessary for connection_time exception testing.
+    :param transport_factory:
+        Like the same-named param in SSHClient: which Transport class to use.
+    :param server_transport_factory:
+        Like ``transport_factory``, but only impacts the server transport.
+    :param bool defer:
+        Whether to defer authentication during connecting.
+
+        This is really just shorthand for ``connect={}`` which would do roughly
+        the same thing. Also: this implies skip_verify=True automatically!
+    :param bool skip_verify:
+        Whether NOT to do the default "make sure auth passed" check.
     """
     if init is None:
         init = {}
@@ -1230,12 +1270,21 @@ def server(
     if client_init is None:
         client_init = {}
     if connect is None:
-        connect = dict(username="slowdive", password="pygmalion")
+        # No auth at all please
+        if defer:
+            connect = dict()
+        # Default username based auth
+        else:
+            connect = dict(username="slowdive", password="pygmalion")
     socks = LoopSocket()
     sockc = LoopSocket()
     sockc.link(socks)
-    tc = Transport(sockc, **dict(init, **client_init))
-    ts = Transport(socks, **dict(init, **server_init))
+    if transport_factory is None:
+        transport_factory = Transport
+    if server_transport_factory is None:
+        server_transport_factory = transport_factory
+    tc = transport_factory(sockc, **dict(init, **client_init))
+    ts = server_transport_factory(socks, **dict(init, **server_init))
 
     if hostkey is None:
         hostkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
@@ -1354,10 +1403,14 @@ class TestSHA2SignatureKeyExchange(unitt
 
 
 class TestExtInfo(unittest.TestCase):
-    def test_ext_info_handshake(self):
+    def test_ext_info_handshake_exposed_in_client_kexinit(self):
         with server() as (tc, _):
+            # NOTE: this is latest KEXINIT /sent by us/ (Transport retains it)
             kex = tc._get_latest_kex_init()
-            assert kex["kex_algo_list"][-1] == "ext-info-c"
+            # flag in KexAlgorithms list
+            assert "ext-info-c" in kex["kex_algo_list"]
+            # data stored on Transport after hearing back from a compatible
+            # server (such as ourselves in server mode)
             assert tc.server_extensions == {
                 "server-sig-algs": b"ssh-ed25519,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-512,rsa-sha2-256,ssh-rsa,ssh-dss"  # noqa
             }
@@ -1463,3 +1516,187 @@ class TestSHA2SignaturePubkeys(unittest.
         ) as (tc, ts):
             assert tc.is_authenticated()
             assert tc._agreed_pubkey_algorithm == "rsa-sha2-256"
+
+
+class BadSeqPacketizer(Packetizer):
+    def read_message(self):
+        cmd, msg = super().read_message()
+        # Only mess w/ seqno if kexinit.
+        if cmd is MSG_KEXINIT:
+            # NOTE: this is /only/ the copy of the seqno which gets
+            # transmitted up from Packetizer; it's not modifying
+            # Packetizer's own internal seqno. For these tests,
+            # modifying the latter isn't required, and is also harder
+            # to do w/o triggering MAC mismatches.
+            msg.seqno = 17  # arbitrary nonzero int
+        return cmd, msg
+
+
+class TestStrictKex:
+    def test_kex_algos_includes_kex_strict_c(self):
+        with server() as (tc, _):
+            kex = tc._get_latest_kex_init()
+            assert "kex-strict-c-v00@openssh.com" in kex["kex_algo_list"]
+
+    @mark.parametrize(
+        "server_active,client_active",
+        itertools.product([True, False], repeat=2),
+    )
+    def test_mode_agreement(self, server_active, client_active):
+        with server(
+            server_init=dict(strict_kex=server_active),
+            client_init=dict(strict_kex=client_active),
+        ) as (tc, ts):
+            if server_active and client_active:
+                assert tc.agreed_on_strict_kex is True
+                assert ts.agreed_on_strict_kex is True
+            else:
+                assert tc.agreed_on_strict_kex is False
+                assert ts.agreed_on_strict_kex is False
+
+    def test_mode_advertised_by_default(self):
+        # NOTE: no explicit strict_kex overrides...
+        with server() as (tc, ts):
+            assert all(
+                (
+                    tc.advertise_strict_kex,
+                    tc.agreed_on_strict_kex,
+                    ts.advertise_strict_kex,
+                    ts.agreed_on_strict_kex,
+                )
+            )
+
+    @mark.parametrize(
+        "ptype",
+        (
+            # "normal" but definitely out-of-order message
+            MSG_CHANNEL_OPEN,
+            # Normally ignored, but not in this case
+            MSG_IGNORE,
+            # Normally triggers debug parsing, but not in this case
+            MSG_DEBUG,
+            # Normally ignored, but...you get the idea
+            MSG_UNIMPLEMENTED,
+            # Not real, so would normally trigger us /sending/
+            # MSG_UNIMPLEMENTED, but...
+            MSG_FUGGEDABOUTIT,
+        ),
+    )
+    def test_MessageOrderError_non_kex_messages_in_initial_kex(self, ptype):
+        class AttackTransport(Transport):
+            # Easiest apparent spot on server side which is:
+            # - late enough for both ends to have handshook on strict mode
+            # - early enough to be in the window of opportunity for Terrapin
+            # attack; essentially during actual kex, when the engine is
+            # waiting for things like MSG_KEXECDH_REPLY (for eg curve25519).
+            def _negotiate_keys(self, m):
+                self.clear_to_send_lock.acquire()
+                try:
+                    self.clear_to_send.clear()
+                finally:
+                    self.clear_to_send_lock.release()
+                if self.local_kex_init is None:
+                    # remote side wants to renegotiate
+                    self._send_kex_init()
+                self._parse_kex_init(m)
+                # Here, we would normally kick over to kex_engine, but instead
+                # we want the server to send the OOO message.
+                m = Message()
+                m.add_byte(byte_chr(ptype))
+                # rest of packet unnecessary...
+                self._send_message(m)
+
+        with raises(MessageOrderError):
+            with server(server_transport_factory=AttackTransport) as (tc, _):
+                pass  # above should run and except during connect()
+
+    def test_SSHException_raised_on_out_of_order_messages_when_not_strict(
+        self,
+    ):
+        # This is kind of dumb (either situation is still fatal!) but whatever,
+        # may as well be strict with our new strict flag...
+        with raises(SSHException) as info:  # would be true either way, but
+            with server(
+                client_init=dict(strict_kex=False),
+            ) as (tc, _):
+                tc._expect_packet(MSG_KEXINIT)
+                tc.open_session()
+        assert info.type is SSHException  # NOT MessageOrderError!
+
+    def test_error_not_raised_when_kexinit_not_seq_0_but_unstrict(self):
+        with server(
+            client_init=dict(
+                # Disable strict kex
+                strict_kex=False,
+                # Give our clientside a packetizer that sets all kexinit
+                # Message objects to have .seqno==17, which would trigger the
+                # new logic if we'd forgotten to wrap it in strict-kex check
+                packetizer_class=BadSeqPacketizer,
+            ),
+        ):
+            pass  # kexinit happens at connect...
+
+    def test_MessageOrderError_raised_when_kexinit_not_seq_0_and_strict(self):
+        with raises(MessageOrderError):
+            with server(
+                # Give our clientside a packetizer that sets all kexinit
+                # Message objects to have .seqno==17, which should trigger the
+                # new logic (given we are NOT disabling strict-mode)
+                client_init=dict(packetizer_class=BadSeqPacketizer),
+            ):
+                pass  # kexinit happens at connect...
+
+    def test_sequence_numbers_reset_on_newkeys_when_strict(self):
+        with server(defer=True) as (tc, ts):
+            # When in strict mode, these should all be zero or close to it
+            # (post-kexinit, pre-auth).
+            # Server->client will be 1 (EXT_INFO got sent after NEWKEYS)
+            assert tc.packetizer._Packetizer__sequence_number_in == 1
+            assert ts.packetizer._Packetizer__sequence_number_out == 1
+            # Client->server will be 0
+            assert tc.packetizer._Packetizer__sequence_number_out == 0
+            assert ts.packetizer._Packetizer__sequence_number_in == 0
+
+    def test_sequence_numbers_not_reset_on_newkeys_when_not_strict(self):
+        with server(defer=True, client_init=dict(strict_kex=False)) as (
+            tc,
+            ts,
+        ):
+            # When not in strict mode, these will all be ~3-4 or so
+            # (post-kexinit, pre-auth). Not encoding exact values as it will
+            # change anytime we mess with the test harness...
+            assert tc.packetizer._Packetizer__sequence_number_in != 0
+            assert tc.packetizer._Packetizer__sequence_number_out != 0
+            assert ts.packetizer._Packetizer__sequence_number_in != 0
+            assert ts.packetizer._Packetizer__sequence_number_out != 0
+
+    def test_sequence_number_rollover_detected(self):
+        class RolloverTransport(Transport):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                # Induce an about-to-rollover seqno, such that it rolls over
+                # during initial kex.
+                setattr(
+                    self.packetizer,
+                    "_Packetizer__sequence_number_in",
+                    sys.maxsize,
+                )
+                setattr(
+                    self.packetizer,
+                    "_Packetizer__sequence_number_out",
+                    sys.maxsize,
+                )
+
+        with raises(
+            SSHException,
+            match=r"Sequence number rolled over during initial kex!",
+        ):
+            with server(
+                client_init=dict(
+                    # Disable strict kex - this should happen always
+                    strict_kex=False,
+                ),
+                # Transport which tickles its packetizer seqno's
+                transport_factory=RolloverTransport,
+            ):
+                pass  # kexinit happens at connect...