Browse Source

Merge pull request #1701 from Shopify/diego_clone-tls-config

Set server name only for the current broker
Diego Alvarez 4 years ago
parent
commit
4d2231eabe
2 changed files with 37 additions and 19 deletions
  1. 18 19
      broker.go
  2. 19 0
      client_tls_test.go

+ 18 - 19
broker.go

@@ -162,29 +162,11 @@ func (b *Broker) Open(conf *Config) error {
 			atomic.StoreInt32(&b.opened, 0)
 			return
 		}
-
 		if conf.Net.TLS.Enable {
-			Logger.Printf("Using tls")
-			cfg := conf.Net.TLS.Config
-			if cfg == nil {
-				cfg = &tls.Config{}
-			}
-			// If no ServerName is set, infer the ServerName
-			// from the hostname we're connecting to.
-			// Gets the hostname as tls.DialWithDialer does it.
-			if cfg.ServerName == "" {
-				colonPos := strings.LastIndex(b.addr, ":")
-				if colonPos == -1 {
-					colonPos = len(b.addr)
-				}
-				hostname := b.addr[:colonPos]
-				cfg.ServerName = hostname
-			}
-			b.conn = tls.Client(b.conn, cfg)
+			b.conn = tls.Client(b.conn, validServerNameTLS(b.addr, conf.Net.TLS.Config))
 		}
 
 		b.conn = newBufConn(b.conn)
-
 		b.conf = conf
 
 		// Create or reuse the global metrics shared between brokers
@@ -1440,3 +1422,20 @@ func (b *Broker) registerCounter(name string) metrics.Counter {
 	b.registeredMetrics = append(b.registeredMetrics, nameForBroker)
 	return metrics.GetOrRegisterCounter(nameForBroker, b.conf.MetricRegistry)
 }
+
+func validServerNameTLS(addr string, cfg *tls.Config) *tls.Config {
+	if cfg == nil {
+		cfg = &tls.Config{}
+	}
+	if cfg.ServerName != "" {
+		return cfg
+	}
+
+	c := cfg.Clone()
+	sn, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		Logger.Println(fmt.Errorf("failed to get ServerName from addr %w", err))
+	}
+	c.ServerName = sn
+	return c
+}

+ 19 - 0
client_tls_test.go

@@ -210,3 +210,22 @@ func doListenerTLSTest(t *testing.T, expectSuccess bool, serverConfig, clientCon
 		}
 	}
 }
+
+func TestSetServerName(t *testing.T) {
+	if validServerNameTLS("kafka-server.domain.com:9093", nil).ServerName != "kafka-server.domain.com" {
+		t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config is nil")
+	}
+
+	if validServerNameTLS("kafka-server.domain.com:9093", &tls.Config{}).ServerName != "kafka-server.domain.com" {
+		t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config ServerName is not provided")
+	}
+
+	c := &tls.Config{ServerName: "kafka-server-other.domain.com"}
+	if validServerNameTLS("", c).ServerName != "kafka-server-other.domain.com" {
+		t.Fatal("Expected kafka-server-other.domain.com as tls.ServerName when tls config ServerName is provided")
+	}
+
+	if validServerNameTLS("host-no-port", nil).ServerName != "" {
+		t.Fatal("Expected empty ServerName as the broker addr is missing the port")
+	}
+}