diff -ur m2crypto-0.18/M2Crypto/SSL/Connection.py m2crypto/M2Crypto/SSL/Connection.py --- m2crypto-0.18/M2Crypto/SSL/Connection.py 2007-06-15 23:34:05.000000000 +0200 +++ m2crypto/M2Crypto/SSL/Connection.py 2007-07-31 23:30:51.000000000 +0200 @@ -37,9 +37,11 @@ self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self._fileno = self.socket.fileno() - - self.blocking = self.socket.gettimeout() - + + self._timeout = self.socket.gettimeout() + if self._timeout is None: + self._timeout = -1.0 + self.ssl_close_flag = m2.bio_noclose @@ -137,7 +139,7 @@ m2.ssl_set_accept_state(self.ssl) def accept_ssl(self): - return m2.ssl_accept(self.ssl) + return m2.ssl_accept(self.ssl, self._timeout) def accept(self): """Accept an SSL connection. The return value is a pair (ssl, addr) where @@ -159,7 +161,7 @@ m2.ssl_set_connect_state(self.ssl) def connect_ssl(self): - return m2.ssl_connect(self.ssl) + return m2.ssl_connect(self.ssl, self._timeout) def connect(self, addr): self.socket.connect(addr) @@ -186,7 +188,7 @@ return m2.ssl_pending(self.ssl) def _write_bio(self, data): - return m2.ssl_write(self.ssl, data) + return m2.ssl_write(self.ssl, data, self._timeout) def _write_nbio(self, data): return m2.ssl_write_nbio(self.ssl, data) @@ -194,7 +196,7 @@ def _read_bio(self, size=1024): if size <= 0: raise ValueError, 'size <= 0' - return m2.ssl_read(self.ssl, size) + return m2.ssl_read(self.ssl, size, self._timeout) def _read_nbio(self, size=1024): if size <= 0: @@ -202,13 +204,13 @@ return m2.ssl_read_nbio(self.ssl, size) def write(self, data): - if self.blocking: + if self._timeout != 0.0: return self._write_bio(data) return self._write_nbio(data) sendall = send = write def read(self, size=1024): - if self.blocking: + if self._timeout != 0.0: return self._read_bio(size) return self._read_nbio(size) recv = read @@ -216,7 +218,17 @@ def setblocking(self, mode): """Set this connection's underlying socket to _mode_.""" self.socket.setblocking(mode) - self.blocking = mode + if mode: + self._timeout = -1.0 + else: + self._timeout = 0.0 + + def settimeout(self, timeout): + """Set this connection's underlying socket's timeout to _timeout_.""" + self.socket.settimeout(timeout) + self._timeout = timeout + if self._timeout is None: + self._timeout = -1.0 def fileno(self): return self.socket.fileno() @@ -293,15 +305,8 @@ """Set the cipher suites for this connection.""" return m2.ssl_set_cipher_list(self.ssl, cipher_list) - def makefile(self, mode='rb', bufsize='ignored'): - r = 'r' in mode or '+' in mode - w = 'w' in mode or 'a' in mode or '+' in mode - b = 'b' in mode - m2mode = ['', 'r'][r] + ['', 'w'][w] + ['', 'b'][b] - # XXX Need to dup(). - bio = BIO.BIO(self.sslbio, _close_cb=self.close) - m2.bio_do_handshake(bio._ptr()) - return BIO.IOBuffer(bio, m2mode, _pyfree=0) + def makefile(self, mode='rb', bufsize=-1): + return socket._fileobject(self, mode, bufsize) def getsockname(self): return self.socket.getsockname() diff -ur m2crypto-0.18/M2Crypto/SSL/__init__.py m2crypto/M2Crypto/SSL/__init__.py --- m2crypto-0.18/M2Crypto/SSL/__init__.py 2006-03-20 20:26:28.000000000 +0100 +++ m2crypto/M2Crypto/SSL/__init__.py 2007-07-31 23:29:21.000000000 +0200 @@ -2,11 +2,14 @@ Copyright (c) 1999-2004 Ng Pheng Siong. All rights reserved.""" +import socket + # M2Crypto from M2Crypto import m2 class SSLError(Exception): pass -m2.ssl_init(SSLError) +class SSLTimeoutError(SSLError, socket.timeout): pass +m2.ssl_init(SSLError, SSLTimeoutError) # M2Crypto.SSL from Cipher import Cipher, Cipher_Stack diff -ur m2crypto-0.18/SWIG/_ssl.i m2crypto/SWIG/_ssl.i --- m2crypto-0.18/SWIG/_ssl.i 2007-06-05 02:30:11.000000000 +0200 +++ m2crypto/SWIG/_ssl.i 2007-08-01 00:06:34.000000000 +0200 @@ -8,10 +8,13 @@ %{ #include +#include #include #include #include #include +#include +#include %} %apply Pointer NONNULL { SSL_CTX * }; @@ -142,6 +145,11 @@ %rename(ssl_session_get_timeout) SSL_SESSION_get_timeout; extern long SSL_SESSION_get_timeout(CONST SSL_SESSION *); +extern PyObject *ssl_accept(SSL *ssl, double timeout = -1); +extern PyObject *ssl_connect(SSL *ssl, double timeout = -1); +extern PyObject *ssl_read(SSL *ssl, int num, double timeout = -1); +extern int ssl_write(SSL *ssl, PyObject *blob, double timeout = -1); + %constant int ssl_error_none = SSL_ERROR_NONE; %constant int ssl_error_ssl = SSL_ERROR_SSL; %constant int ssl_error_want_read = SSL_ERROR_WANT_READ; @@ -197,14 +205,19 @@ %constant int SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER = SSL_MODE_ENABLE_PARTIAL_WRITE; %constant int SSL_MODE_AUTO_RETRY = SSL_MODE_AUTO_RETRY; +%ignore ssl_handle_error; +%ignore ssl_sleep_with_timeout; %inline %{ static PyObject *_ssl_err; +static PyObject *_ssl_timeout_err; -void ssl_init(PyObject *ssl_err) { +void ssl_init(PyObject *ssl_err, PyObject *ssl_timeout_err) { SSL_library_init(); SSL_load_error_strings(); Py_INCREF(ssl_err); + Py_INCREF(ssl_timeout_err); _ssl_err = ssl_err; + _ssl_timeout_err = ssl_timeout_err; } void ssl_ctx_passphrase_callback(SSL_CTX *ctx, PyObject *pyfunc) { @@ -358,36 +371,130 @@ return ret; } -PyObject *ssl_accept(SSL *ssl) { +static void ssl_handle_error(int ssl_err, int ret) { + int err; + + switch (ssl_err) { + case SSL_ERROR_SSL: + PyErr_SetString(_ssl_err, + ERR_reason_error_string(ERR_get_error())); + break; + case SSL_ERROR_SYSCALL: + err = ERR_get_error(); + if (err) + PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); + else if (ret == 0) + PyErr_SetString(_ssl_err, "unexpected eof"); + else if (ret == -1) + PyErr_SetFromErrno(_ssl_err); + else + assert(0); + break; + default: + PyErr_SetString(_ssl_err, "unexpected SSL error"); + } +} + +static int ssl_sleep_with_timeout(SSL *ssl, const struct timeval *start, + double timeout, int ssl_err) { + struct pollfd fd; + struct timeval tv; + int ms, tmp; + + assert(timeout > 0); + again: + gettimeofday(&tv, NULL); + /* tv >= start */ + if ((timeout + start->tv_sec - tv.tv_sec) > INT_MAX / 1000) + ms = -1; + else { + int fract; + + ms = ((start->tv_sec + (int)timeout) - tv.tv_sec) * 1000; + fract = (start->tv_usec + (timeout - (int)timeout) * 1000000 + - tv.tv_usec + 999) / 1000; + if (ms > 0 && fract > INT_MAX - ms) + ms = -1; + else { + ms += fract; + if (ms <= 0) + goto timeout; + } + } + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + fd.fd = SSL_get_rfd(ssl); + fd.events = POLLIN; + break; + + case SSL_ERROR_WANT_WRITE: + fd.fd = SSL_get_wfd(ssl); + fd.events = POLLOUT; + break; + + case SSL_ERROR_WANT_X509_LOOKUP: + return 0; /* FIXME: is this correct? */ + + default: + assert(0); + } + if (fd.fd == -1) { + PyErr_SetString(_ssl_err, "timeout on a non-FD SSL"); + return -1; + } + Py_BEGIN_ALLOW_THREADS + tmp = poll(&fd, 1, ms); + Py_END_ALLOW_THREADS + switch (tmp) { + case 1: + return 0; + case 0: + goto timeout; + case -1: + if (errno == EINTR) + goto again; + PyErr_SetFromErrno(_ssl_err); + return -1; + } + return 0; + + timeout: + PyErr_SetNone(_ssl_timeout_err); + return -1; +} + +PyObject *ssl_accept(SSL *ssl, double timeout) { PyObject *obj = NULL; - int r, err; + int r, ssl_err; + struct timeval tv; + if (timeout > 0) + gettimeofday(&tv, NULL); + again: Py_BEGIN_ALLOW_THREADS r = SSL_accept(ssl); + ssl_err = SSL_get_error(ssl, r); Py_END_ALLOW_THREADS - switch (SSL_get_error(ssl, r)) { + switch (ssl_err) { case SSL_ERROR_NONE: case SSL_ERROR_ZERO_RETURN: obj = PyInt_FromLong((long)1); break; case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_READ: - obj = PyInt_FromLong((long)0); - break; - case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); + if (timeout <= 0) { + obj = PyInt_FromLong((long)0); + break; + } + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; obj = NULL; break; + case SSL_ERROR_SSL: case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); + ssl_handle_error(ssl_err, r); obj = NULL; break; } @@ -396,36 +503,38 @@ return obj; } -PyObject *ssl_connect(SSL *ssl) { +PyObject *ssl_connect(SSL *ssl, double timeout) { PyObject *obj = NULL; - int r, err; + int r, ssl_err; + struct timeval tv; + if (timeout > 0) + gettimeofday(&tv, NULL); + again: Py_BEGIN_ALLOW_THREADS r = SSL_connect(ssl); + ssl_err = SSL_get_error(ssl, r); Py_END_ALLOW_THREADS - switch (SSL_get_error(ssl, r)) { + switch (ssl_err) { case SSL_ERROR_NONE: case SSL_ERROR_ZERO_RETURN: obj = PyInt_FromLong((long)1); break; case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_READ: - obj = PyInt_FromLong((long)0); - break; - case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); + if (timeout <= 0) { + obj = PyInt_FromLong((long)0); + break; + } + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; obj = NULL; break; + case SSL_ERROR_SSL: case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); + ssl_handle_error(ssl_err, r); obj = NULL; break; } @@ -438,10 +547,11 @@ SSL_set_shutdown(ssl, mode); } -PyObject *ssl_read(SSL *ssl, int num) { +PyObject *ssl_read(SSL *ssl, int num, double timeout) { PyObject *obj = NULL; void *buf; - int r, err; + int r; + struct timeval tv; if (!(buf = PyMem_Malloc(num))) { PyErr_SetString(PyExc_MemoryError, "ssl_read"); @@ -449,37 +559,44 @@ } + if (timeout > 0) + gettimeofday(&tv, NULL); + again: Py_BEGIN_ALLOW_THREADS r = SSL_read(ssl, buf, num); Py_END_ALLOW_THREADS - switch (SSL_get_error(ssl, r)) { - case SSL_ERROR_NONE: - case SSL_ERROR_ZERO_RETURN: - buf = PyMem_Realloc(buf, r); - obj = PyString_FromStringAndSize(buf, r); - break; - case SSL_ERROR_WANT_WRITE: - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_X509_LOOKUP: - Py_INCREF(Py_None); - obj = Py_None; - break; - case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); - obj = NULL; - break; - case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); - obj = NULL; - break; + if (r >= 0) { + buf = PyMem_Realloc(buf, r); + obj = PyString_FromStringAndSize(buf, r); + } else { + int ssl_err; + + ssl_err = SSL_get_error(ssl, r); + switch (ssl_err) { + case SSL_ERROR_NONE: + case SSL_ERROR_ZERO_RETURN: + assert(0); + + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_X509_LOOKUP: + if (timeout <= 0) { + Py_INCREF(Py_None); + obj = Py_None; + break; + } + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; + obj = NULL; + break; + case SSL_ERROR_SSL: + case SSL_ERROR_SYSCALL: + ssl_handle_error(ssl_err, r); + obj = NULL; + break; + } } PyMem_Free(buf); @@ -537,22 +654,26 @@ return obj; } -int ssl_write(SSL *ssl, PyObject *blob) { +int ssl_write(SSL *ssl, PyObject *blob, double timeout) { const void *buf; - int len, r, err, ret; + int len, r, ssl_err, ret; + struct timeval tv; if (m2_PyObject_AsReadBufferInt(blob, &buf, &len) == -1) { return -1; } - + if (timeout > 0) + gettimeofday(&tv, NULL); + again: Py_BEGIN_ALLOW_THREADS r = SSL_write(ssl, buf, len); + ssl_err = SSL_get_error(ssl, r); Py_END_ALLOW_THREADS - switch (SSL_get_error(ssl, r)) { + switch (ssl_err) { case SSL_ERROR_NONE: case SSL_ERROR_ZERO_RETURN: ret = r; @@ -560,20 +681,17 @@ case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_X509_LOOKUP: + if (timeout <= 0) { + ret = -1; + break; + } + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; ret = -1; break; case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); - ret = -1; - break; case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); + ssl_handle_error(ssl_err, r); default: ret = -1; } diff -ur m2crypto-0.18/tests/test_ssl.py m2crypto/tests/test_ssl.py --- m2crypto-0.18/tests/test_ssl.py 2007-07-02 22:25:45.000000000 +0200 +++ m2crypto/tests/test_ssl.py 2007-07-31 23:29:21.000000000 +0200 @@ -887,6 +887,53 @@ class TwistedSSLClientTestCase(BaseSSLClientTestCase): + def test_timeout(self): + pid = self.start_server(self.args) + try: + ctx = SSL.Context() + s = SSL.Connection(ctx) + # Just a really small number so we can timeout + s.settimeout(0.000000000000000000000000000001) + self.assertRaises(SSL.SSLTimeoutError, s.connect, self.srv_addr) + s.close() + finally: + self.stop_server(pid) + + def test_makefile_timeout(self): + # httpslib uses makefile to read the response + pid = self.start_server(self.args) + try: + from M2Crypto import httpslib + c = httpslib.HTTPS(srv_host, srv_port) + c.putrequest('GET', '/') + c.putheader('Accept', 'text/html') + c.putheader('Accept', 'text/plain') + c.endheaders() + c._conn.sock.settimeout(100) + err, msg, headers = c.getreply() + assert err == 200, err + f = c.getfile() + data = f.read() + c.close() + finally: + self.stop_server(pid) + self.failIf(string.find(data, 's_server -quiet -www') == -1) + + def test_makefile_timeout_fires(self): + pid = self.start_server(self.args) + try: + from M2Crypto import httpslib + c = httpslib.HTTPS(srv_host, srv_port) + c.putrequest('GET', '/') + c.putheader('Accept', 'text/html') + c.putheader('Accept', 'text/plain') + c.endheaders() + c._conn.sock.settimeout(0.0000000001) + self.assertRaises(socket.timeout, c.getreply) + c.close() + finally: + self.stop_server(pid) + def test_twisted_wrapper(self): # Test only when twisted and ZopeInterfaces are present try: