Blob Blame History Raw
From 7000f339c18b58f00dae129bd4a8ba097329db67 Mon Sep 17 00:00:00 2001
From: Jan Chaloupka <jchaloup@redhat.com>
Date: Thu, 29 Oct 2015 15:30:04 +0100
Subject: [PATCH] backport refactoring of TLS connection

---
 pkg/registry/generic/rest/proxy.go |  74 +-------------------------
 pkg/util/http.go                   |  58 ++++++++++++++++++++
 pkg/util/proxy/dial.go             | 106 +++++++++++++++++++++++++++++++++++++
 3 files changed, 165 insertions(+), 73 deletions(-)
 create mode 100644 pkg/util/proxy/dial.go

diff --git a/pkg/registry/generic/rest/proxy.go b/pkg/registry/generic/rest/proxy.go
index bd28ed7..3f91db4 100644
--- a/pkg/registry/generic/rest/proxy.go
+++ b/pkg/registry/generic/rest/proxy.go
@@ -17,10 +17,7 @@ limitations under the License.
 package rest
 
 import (
-	"crypto/tls"
-	"fmt"
 	"io"
-	"net"
 	"net/http"
 	"net/http/httputil"
 	"net/url"
@@ -33,7 +30,6 @@ import (
 	"k8s.io/kubernetes/pkg/util/proxy"
 
 	"github.com/golang/glog"
-	"k8s.io/kubernetes/third_party/golang/netutil"
 )
 
 // UpgradeAwareProxyHandler is a handler for proxy requests that may require an upgrade
@@ -122,7 +118,7 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R
 		return false
 	}
 
-	backendConn, err := h.dialURL()
+	backendConn, err := proxy.DialURL(h.Location, h.Transport)
 	if err != nil {
 		h.err = err
 		return true
@@ -171,74 +167,6 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R
 	return true
 }
 
-func (h *UpgradeAwareProxyHandler) dialURL() (net.Conn, error) {
-	dialAddr := netutil.CanonicalAddr(h.Location)
-
-	var dialer func(network, addr string) (net.Conn, error)
-	if httpTransport, ok := h.Transport.(*http.Transport); ok && httpTransport.Dial != nil {
-		dialer = httpTransport.Dial
-	}
-
-	switch h.Location.Scheme {
-	case "http":
-		if dialer != nil {
-			return dialer("tcp", dialAddr)
-		}
-		return net.Dial("tcp", dialAddr)
-	case "https":
-		// TODO: this TLS logic can probably be cleaned up; it's messy in an attempt
-		// to preserve behavior that we don't know for sure is exercised.
-
-		// Get the tls config from the transport if we recognize it
-		var tlsConfig *tls.Config
-		var tlsConn *tls.Conn
-		var err error
-		if h.Transport != nil {
-			httpTransport, ok := h.Transport.(*http.Transport)
-			if ok {
-				tlsConfig = httpTransport.TLSClientConfig
-			}
-		}
-		if dialer != nil {
-			// We have a dialer; use it to open the connection, then
-			// create a tls client using the connection.
-			netConn, err := dialer("tcp", dialAddr)
-			if err != nil {
-				return nil, err
-			}
-			// tls.Client requires non-nil config
-			if tlsConfig == nil {
-				glog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify")
-				tlsConfig = &tls.Config{
-					InsecureSkipVerify: true,
-				}
-			}
-			tlsConn = tls.Client(netConn, tlsConfig)
-			if err := tlsConn.Handshake(); err != nil {
-				return nil, err
-			}
-
-		} else {
-			// Dial
-			tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
-			if err != nil {
-				return nil, err
-			}
-		}
-
-		// Verify
-		host, _, _ := net.SplitHostPort(dialAddr)
-		if err := tlsConn.VerifyHostname(host); err != nil {
-			tlsConn.Close()
-			return nil, err
-		}
-
-		return tlsConn, nil
-	default:
-		return nil, fmt.Errorf("unknown scheme: %s", h.Location.Scheme)
-	}
-}
-
 func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.RoundTripper {
 	scheme := url.Scheme
 	host := url.Host
diff --git a/pkg/util/http.go b/pkg/util/http.go
index eca9aff..8f35ce4 100644
--- a/pkg/util/http.go
+++ b/pkg/util/http.go
@@ -17,7 +17,11 @@ limitations under the License.
 package util
 
 import (
+	"crypto/tls"
+	"fmt"
 	"io"
+	"net"
+	"net/http"
 	"net/url"
 	"strings"
 )
@@ -44,3 +48,57 @@ func IsProbableEOF(err error) bool {
 	}
 	return false
 }
+
+var defaultTransport = http.DefaultTransport.(*http.Transport)
+
+// SetTransportDefaults applies the defaults from http.DefaultTransport
+// for the Proxy, Dial, and TLSHandshakeTimeout fields if unset
+func SetTransportDefaults(t *http.Transport) *http.Transport {
+	if t.Proxy == nil {
+		t.Proxy = defaultTransport.Proxy
+	}
+	if t.Dial == nil {
+		t.Dial = defaultTransport.Dial
+	}
+	if t.TLSHandshakeTimeout == 0 {
+		t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout
+	}
+	return t
+}
+
+type RoundTripperWrapper interface {
+	http.RoundTripper
+	WrappedRoundTripper() http.RoundTripper
+}
+
+type DialFunc func(net, addr string) (net.Conn, error)
+
+func Dialer(transport http.RoundTripper) (DialFunc, error) {
+	if transport == nil {
+		return nil, nil
+	}
+
+	switch transport := transport.(type) {
+	case *http.Transport:
+		return transport.Dial, nil
+	case RoundTripperWrapper:
+		return Dialer(transport.WrappedRoundTripper())
+	default:
+		return nil, fmt.Errorf("unknown transport type: %v", transport)
+	}
+}
+
+func TLSClientConfig(transport http.RoundTripper) (*tls.Config, error) {
+	if transport == nil {
+		return nil, nil
+	}
+
+	switch transport := transport.(type) {
+	case *http.Transport:
+		return transport.TLSClientConfig, nil
+	case RoundTripperWrapper:
+		return TLSClientConfig(transport.WrappedRoundTripper())
+	default:
+		return nil, fmt.Errorf("unknown transport type: %v", transport)
+	}
+}
diff --git a/pkg/util/proxy/dial.go b/pkg/util/proxy/dial.go
new file mode 100644
index 0000000..07982b7
--- /dev/null
+++ b/pkg/util/proxy/dial.go
@@ -0,0 +1,106 @@
+/*
+Copyright 2015 The Kubernetes Authors All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package proxy
+
+import (
+	"crypto/tls"
+	"fmt"
+	"net"
+	"net/http"
+	"net/url"
+
+	"github.com/golang/glog"
+
+	"k8s.io/kubernetes/pkg/util"
+	"k8s.io/kubernetes/third_party/golang/netutil"
+)
+
+func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
+	dialAddr := netutil.CanonicalAddr(url)
+
+	dialer, _ := util.Dialer(transport)
+
+	switch url.Scheme {
+	case "http":
+		if dialer != nil {
+			return dialer("tcp", dialAddr)
+		}
+		return net.Dial("tcp", dialAddr)
+	case "https":
+		// Get the tls config from the transport if we recognize it
+		var tlsConfig *tls.Config
+		var tlsConn *tls.Conn
+		var err error
+		tlsConfig, _ = util.TLSClientConfig(transport)
+
+		if dialer != nil {
+			// We have a dialer; use it to open the connection, then
+			// create a tls client using the connection.
+			netConn, err := dialer("tcp", dialAddr)
+			if err != nil {
+				return nil, err
+			}
+			if tlsConfig == nil {
+				// tls.Client requires non-nil config
+				glog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify")
+				// tls.Handshake() requires ServerName or InsecureSkipVerify
+				tlsConfig = &tls.Config{
+					InsecureSkipVerify: true,
+				}
+			} else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
+				// tls.Handshake() requires ServerName or InsecureSkipVerify
+				// infer the ServerName from the hostname we're connecting to.
+				inferredHost := dialAddr
+				if host, _, err := net.SplitHostPort(dialAddr); err == nil {
+					inferredHost = host
+				}
+				// Make a copy to avoid polluting the provided config
+				tlsConfigCopy := *tlsConfig
+				tlsConfigCopy.ServerName = inferredHost
+				tlsConfig = &tlsConfigCopy
+			}
+			tlsConn = tls.Client(netConn, tlsConfig)
+			if err := tlsConn.Handshake(); err != nil {
+				netConn.Close()
+				return nil, err
+			}
+
+		} else {
+			// Dial
+			tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		// Return if we were configured to skip validation
+		if tlsConfig != nil && tlsConfig.InsecureSkipVerify {
+			return tlsConn, nil
+		}
+
+		// Verify
+		host, _, _ := net.SplitHostPort(dialAddr)
+		if err := tlsConn.VerifyHostname(host); err != nil {
+			tlsConn.Close()
+			return nil, err
+		}
+
+		return tlsConn, nil
+	default:
+		return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme)
+	}
+}
-- 
1.9.3