Blob Blame History Raw
From 755dfdc2e16f2f7f5fa6669cb81d0ec3118ec203 Mon Sep 17 00:00:00 2001
From: Damien Ciabrini <damien.ciabrini@gmail.com>
Date: Wed, 16 Nov 2016 13:11:03 +0100
Subject: [PATCH] Add bind_address option (#529)

Allow connecting to the DB from a specific network interface
---
 pymysql/connections.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/pymysql/connections.py b/pymysql/connections.py
index d5e39a1..2884cdc 100644
--- a/pymysql/connections.py
+++ b/pymysql/connections.py
@@ -534,7 +534,8 @@ class Connection(object):
                  compress=None, named_pipe=None, no_delay=None,
                  autocommit=False, db=None, passwd=None, local_infile=False,
                  max_allowed_packet=16*1024*1024, defer_connect=False,
-                 auth_plugin_map={}, read_timeout=None, write_timeout=None):
+                 auth_plugin_map={}, read_timeout=None, write_timeout=None,
+                 bind_address=None):
         """
         Establish a connection to the MySQL database. Accepts several
         arguments:
@@ -544,6 +545,9 @@ class Connection(object):
         password: Password to use.
         database: Database to use, None to not use a particular one.
         port: MySQL port to use, default is usually OK. (default: 3306)
+        bind_address: When the client has multiple network interfaces, specify
+            the interface from which to connect to the host. Argument can be
+            a hostname or an IP address.
         unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
         charset: Charset you want to use.
         sql_mode: Default SQL_MODE to use.
@@ -632,6 +636,7 @@ class Connection(object):
             database = _config("database", database)
             unix_socket = _config("socket", unix_socket)
             port = int(_config("port", port))
+            bind_address = _config("bind-address", bind_address)
             charset = _config("default-character-set", charset)
 
         self.host = host or "localhost"
@@ -640,6 +645,7 @@ class Connection(object):
         self.password = password or ""
         self.db = database
         self.unix_socket = unix_socket
+        self.bind_address = bind_address
         if read_timeout is not None and read_timeout <= 0:
             raise ValueError("read_timeout should be >= 0")
         self._read_timeout = read_timeout
@@ -884,10 +890,14 @@ class Connection(object):
                     self.host_info = "Localhost via UNIX socket"
                     if DEBUG: print('connected using unix_socket')
                 else:
+                    kwargs = {}
+                    if self.bind_address is not None:
+                        kwargs['source_address'] = (self.bind_address, 0)
                     while True:
                         try:
                             sock = socket.create_connection(
-                                (self.host, self.port), self.connect_timeout)
+                                (self.host, self.port), self.connect_timeout,
+                                **kwargs)
                             break
                         except (OSError, IOError) as e:
                             if e.errno == errno.EINTR:
-- 
2.5.5