Blob Blame History Raw
From 2b83e7ccc12af9fec136e9f4897e1585b3b931aa Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Thu, 24 May 2018 22:29:12 +0200
Subject: [PATCH 1/3] Add support for Python 3.7.

Hopefully for real this time.

This is annoyingly complicated.

Fix #405.

(cherry picked from commit 6f8f1c877744623f0a5df5917a85b97807bfb7e5)
---
 websockets/client.py                   | 24 +++++++----------
 websockets/py35/_test_client_server.py | 37 ++++++++++++++++++++++++++
 websockets/py35/client.py              | 33 +++++++++++++++++++++++
 websockets/py35/server.py              | 22 +++++++++++++++
 websockets/server.py                   | 25 +++++++----------
 websockets/test_client_server.py       |  1 +
 6 files changed, 111 insertions(+), 31 deletions(-)
 create mode 100644 websockets/py35/client.py
 create mode 100644 websockets/py35/server.py

diff --git a/websockets/client.py b/websockets/client.py
index 92f29e9..a86b90f 100644
--- a/websockets/client.py
+++ b/websockets/client.py
@@ -385,15 +385,7 @@ class Connect:
         self._creating_connection = loop.create_connection(
             factory, host, port, **kwds)
 
-    @asyncio.coroutine
-    def __aenter__(self):
-        return (yield from self)
-
-    @asyncio.coroutine
-    def __aexit__(self, exc_type, exc_value, traceback):
-        yield from self.ws_client.close()
-
-    def __await__(self):
+    def __iter__(self):                                     # pragma: no cover
         transport, protocol = yield from self._creating_connection
 
         try:
@@ -410,17 +402,19 @@ class Connect:
         self.ws_client = protocol
         return protocol
 
-    __iter__ = __await__
-
 
-# Disable asynchronous context manager functionality only on Python < 3.5.1
-# because it doesn't exist on Python < 3.5 and asyncio.ensure_future didn't
-# accept arbitrary awaitables in Python 3.5; that was fixed in Python 3.5.1.
+# We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future
+# didn't accept arbitrary awaitables until Python 3.5.1. We don't define
+# __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple.
 if sys.version_info[:3] <= (3, 5, 0):                       # pragma: no cover
     @asyncio.coroutine
     def connect(*args, **kwds):
-        return Connect(*args, **kwds).__await__()
+        return Connect(*args, **kwds).__iter__()
     connect.__doc__ = Connect.__doc__
 
 else:
+    from .py35.client import __aenter__, __aexit__, __await__
+    Connect.__aenter__ = __aenter__
+    Connect.__aexit__ = __aexit__
+    Connect.__await__ = __await__
     connect = Connect
diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py
index 4375248..5360d8d 100644
--- a/websockets/py35/_test_client_server.py
+++ b/websockets/py35/_test_client_server.py
@@ -13,6 +13,43 @@ from ..server import *
 from ..test_client_server import get_server_uri, handler
 
 
+class AsyncAwaitTests(unittest.TestCase):
+
+    def setUp(self):
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+    def tearDown(self):
+        self.loop.close()
+
+    def test_client(self):
+        start_server = serve(handler, 'localhost', 0)
+        server = self.loop.run_until_complete(start_server)
+
+        async def run_client():
+            # Await connect.
+            client = await connect(get_server_uri(server))
+            self.assertEqual(client.state, State.OPEN)
+            await client.close()
+            self.assertEqual(client.state, State.CLOSED)
+
+        self.loop.run_until_complete(run_client())
+
+        server.close()
+        self.loop.run_until_complete(server.wait_closed())
+
+    def test_server(self):
+        async def run_server():
+            # Await serve.
+            server = await serve(handler, 'localhost', 0)
+            self.assertTrue(server.sockets)
+            server.close()
+            await server.wait_closed()
+            self.assertFalse(server.sockets)
+
+        self.loop.run_until_complete(run_server())
+
+
 class ContextManagerTests(unittest.TestCase):
 
     def setUp(self):
diff --git a/websockets/py35/client.py b/websockets/py35/client.py
new file mode 100644
index 0000000..7673ea3
--- /dev/null
+++ b/websockets/py35/client.py
@@ -0,0 +1,33 @@
+async def __aenter__(self):
+    return await self
+
+
+async def __aexit__(self, exc_type, exc_value, traceback):
+    await self.ws_client.close()
+
+
+async def __await_impl__(self):
+    # Duplicated with __iter__ because Python 3.7 requires an async function
+    # (as explained in __await__ below) which Python 3.4 doesn't support.
+    transport, protocol = await self._creating_connection
+
+    try:
+        await protocol.handshake(
+            self._wsuri, origin=self._origin,
+            available_extensions=protocol.available_extensions,
+            available_subprotocols=protocol.available_subprotocols,
+            extra_headers=protocol.extra_headers,
+        )
+    except Exception:
+        await protocol.fail_connection()
+        raise
+
+    self.ws_client = protocol
+    return protocol
+
+
+def __await__(self):
+    # __await__() must return a type that I don't know how to obtain except
+    # by calling __await__() on the return value of an async function.
+    # I'm not finding a better way to take advantage of PEP 492.
+    return __await_impl__(self).__await__()
diff --git a/websockets/py35/server.py b/websockets/py35/server.py
new file mode 100644
index 0000000..41a3675
--- /dev/null
+++ b/websockets/py35/server.py
@@ -0,0 +1,22 @@
+async def __aenter__(self):
+    return await self
+
+
+async def __aexit__(self, exc_type, exc_value, traceback):
+    self.ws_server.close()
+    await self.ws_server.wait_closed()
+
+
+async def __await_impl__(self):
+    # Duplicated with __iter__ because Python 3.7 requires an async function
+    # (as explained in __await__ below) which Python 3.4 doesn't support.
+    server = await self._creating_server
+    self.ws_server.wrap(server)
+    return self.ws_server
+
+
+def __await__(self):
+    # __await__() must return a type that I don't know how to obtain except
+    # by calling __await__() on the return value of an async function.
+    # I'm not finding a better way to take advantage of PEP 492.
+    return __await_impl__(self).__await__()
diff --git a/websockets/server.py b/websockets/server.py
index 8db0482..46c80dc 100644
--- a/websockets/server.py
+++ b/websockets/server.py
@@ -729,22 +729,11 @@ class Serve:
         self._creating_server = creating_server
         self.ws_server = ws_server
 
-    @asyncio.coroutine
-    def __aenter__(self):
-        return (yield from self)
-
-    @asyncio.coroutine
-    def __aexit__(self, exc_type, exc_value, traceback):
-        self.ws_server.close()
-        yield from self.ws_server.wait_closed()
-
-    def __await__(self):
+    def __iter__(self):                                     # pragma: no cover
         server = yield from self._creating_server
         self.ws_server.wrap(server)
         return self.ws_server
 
-    __iter__ = __await__
-
 
 def unix_serve(ws_handler, path, **kwargs):
     """
@@ -761,14 +750,18 @@ def unix_serve(ws_handler, path, **kwargs):
     return serve(ws_handler, path=path, **kwargs)
 
 
-# Disable asynchronous context manager functionality only on Python < 3.5.1
-# because it doesn't exist on Python < 3.5 and asyncio.ensure_future didn't
-# accept arbitrary awaitables in Python 3.5; that was fixed in Python 3.5.1.
+# We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future
+# didn't accept arbitrary awaitables until Python 3.5.1. We don't define
+# __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple.
 if sys.version_info[:3] <= (3, 5, 0):                       # pragma: no cover
     @asyncio.coroutine
     def serve(*args, **kwds):
-        return Serve(*args, **kwds).__await__()
+        return Serve(*args, **kwds).__iter__()
     serve.__doc__ = Serve.__doc__
 
 else:
+    from .py35.server import __aenter__, __aexit__, __await__
+    Serve.__aenter__ = __aenter__
+    Serve.__aexit__ = __aexit__
+    Serve.__await__ = __await__
     serve = Serve
diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py
index 8476913..27a2a71 100644
--- a/websockets/test_client_server.py
+++ b/websockets/test_client_server.py
@@ -1057,6 +1057,7 @@ class ClientServerOriginTests(unittest.TestCase):
 
 
 try:
+    from .py35._test_client_server import AsyncAwaitTests               # noqa
     from .py35._test_client_server import ContextManagerTests           # noqa
 except (SyntaxError, ImportError):                          # pragma: no cover
     pass
-- 
2.18.0