Blob Blame History Raw
From 9b8f36d08a5bdffa83019f679a9c9d2ef5ca4302 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sun, 15 Jul 2018 11:07:47 +0200
Subject: [PATCH 3/3] Support yield from connect/serve on Python 3.7.

Fix #435.

(cherry picked from commit 91a376685b1ab7103d3d861ff8b02a1c00f142b1)
---
 websockets/client.py                   |  1 +
 websockets/py35/_test_client_server.py |  3 ++
 websockets/server.py                   |  1 +
 websockets/test_client_server.py       | 41 ++++++++++++++++++++++++++
 4 files changed, 46 insertions(+)

diff --git a/websockets/client.py b/websockets/client.py
index a86b90f..bb3009b 100644
--- a/websockets/client.py
+++ b/websockets/client.py
@@ -385,6 +385,7 @@ class Connect:
         self._creating_connection = loop.create_connection(
             factory, host, port, **kwds)
 
+    @asyncio.coroutine
     def __iter__(self):                                     # pragma: no cover
         transport, protocol = yield from self._creating_connection
 
diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py
index 5360d8d..c656dd3 100644
--- a/websockets/py35/_test_client_server.py
+++ b/websockets/py35/_test_client_server.py
@@ -39,6 +39,7 @@ class AsyncAwaitTests(unittest.TestCase):
         self.loop.run_until_complete(server.wait_closed())
 
     def test_server(self):
+
         async def run_server():
             # Await serve.
             server = await serve(handler, 'localhost', 0)
@@ -83,6 +84,7 @@ class ContextManagerTests(unittest.TestCase):
     @unittest.skipIf(
         sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+')
     def test_server(self):
+
         async def run_server():
             # Use serve as an asynchronous context manager.
             async with serve(handler, 'localhost', 0) as server:
@@ -99,6 +101,7 @@ class ContextManagerTests(unittest.TestCase):
     @unittest.skipUnless(
         hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets')
     def test_unix_server(self):
+
         async def run_server(path):
             async with unix_serve(handler, path) as server:
                 self.assertTrue(server.sockets)
diff --git a/websockets/server.py b/websockets/server.py
index 46c80dc..86fa700 100644
--- a/websockets/server.py
+++ b/websockets/server.py
@@ -729,6 +729,7 @@ class Serve:
         self._creating_server = creating_server
         self.ws_server = ws_server
 
+    @asyncio.coroutine
     def __iter__(self):                                     # pragma: no cover
         server = yield from self._creating_server
         self.ws_server.wrap(server)
diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py
index a3e1e92..6c25784 100644
--- a/websockets/test_client_server.py
+++ b/websockets/test_client_server.py
@@ -24,6 +24,7 @@ from .extensions.permessage_deflate import (
 )
 from .handshake import build_response
 from .http import USER_AGENT, read_response
+from .protocol import State
 from .server import *
 from .test_protocol import MS
 
@@ -1056,6 +1057,46 @@ class ClientServerOriginTests(unittest.TestCase):
         self.loop.run_until_complete(server.wait_closed())
 
 
+class YieldFromTests(unittest.TestCase):
+
+    def setUp(self):
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+    def tearDown(self):
+        self.loop.close()
+
+    def test_client(self):
+        start_server = serve(handler, 'localhost', 0)
+        server = self.loop.run_until_complete(start_server)
+
+        @asyncio.coroutine
+        def run_client():
+            # Yield from connect.
+            client = yield from connect(get_server_uri(server))
+            self.assertEqual(client.state, State.OPEN)
+            yield from client.close()
+            self.assertEqual(client.state, State.CLOSED)
+
+        self.loop.run_until_complete(run_client())
+
+        server.close()
+        self.loop.run_until_complete(server.wait_closed())
+
+    def test_server(self):
+
+        @asyncio.coroutine
+        def run_server():
+            # Yield from serve.
+            server = yield from serve(handler, 'localhost', 0)
+            self.assertTrue(server.sockets)
+            server.close()
+            yield from server.wait_closed()
+            self.assertFalse(server.sockets)
+
+        self.loop.run_until_complete(run_server())
+
+
 if sys.version_info[:2] >= (3, 5):                          # pragma: no cover
     from .py35._test_client_server import AsyncAwaitTests               # noqa
     from .py35._test_client_server import ContextManagerTests           # noqa
-- 
2.18.0