Browse Source

http2: don't override user's Transport.TLSConfig.ServerName

Fixes golang/go#14501

Change-Id: Ibaa7fb1fff404c62c35bb7c63f4a442e4fc0610d
Reviewed-on: https://go-review.googlesource.com/19918
Reviewed-by: Andrew Gerrand <adg@golang.org>
Brad Fitzpatrick 9 years ago
parent
commit
3f5b0e6e67
2 changed files with 67 additions and 2 deletions
  1. 6 2
      http2/transport.go
  2. 61 0
      http2/transport_test.go

+ 6 - 2
http2/transport.go

@@ -333,8 +333,12 @@ func (t *Transport) newTLSConfig(host string) *tls.Config {
 	if t.TLSClientConfig != nil {
 		*cfg = *t.TLSClientConfig
 	}
-	cfg.NextProtos = []string{NextProtoTLS} // TODO: don't override if already in list
-	cfg.ServerName = host
+	if !strSliceContains(cfg.NextProtos, NextProtoTLS) {
+		cfg.NextProtos = append([]string{NextProtoTLS}, cfg.NextProtos...)
+	}
+	if cfg.ServerName == "" {
+		cfg.ServerName = host
+	}
 	return cfg
 }
 

+ 61 - 0
http2/transport_test.go

@@ -1677,3 +1677,64 @@ func TestGzipReader_DoubleReadCrash(t *testing.T) {
 		t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1)
 	}
 }
+
+func TestTransportNewTLSConfig(t *testing.T) {
+	tests := [...]struct {
+		conf *tls.Config
+		host string
+		want *tls.Config
+	}{
+		// Normal case.
+		0: {
+			conf: nil,
+			host: "foo.com",
+			want: &tls.Config{
+				ServerName: "foo.com",
+				NextProtos: []string{NextProtoTLS},
+			},
+		},
+
+		// User-provided name (bar.com) takes precedence:
+		1: {
+			conf: &tls.Config{
+				ServerName: "bar.com",
+			},
+			host: "foo.com",
+			want: &tls.Config{
+				ServerName: "bar.com",
+				NextProtos: []string{NextProtoTLS},
+			},
+		},
+
+		// NextProto is prepended:
+		2: {
+			conf: &tls.Config{
+				NextProtos: []string{"foo", "bar"},
+			},
+			host: "example.com",
+			want: &tls.Config{
+				ServerName: "example.com",
+				NextProtos: []string{NextProtoTLS, "foo", "bar"},
+			},
+		},
+
+		// NextProto is not duplicated:
+		3: {
+			conf: &tls.Config{
+				NextProtos: []string{"foo", "bar", NextProtoTLS},
+			},
+			host: "example.com",
+			want: &tls.Config{
+				ServerName: "example.com",
+				NextProtos: []string{"foo", "bar", NextProtoTLS},
+			},
+		},
+	}
+	for i, tt := range tests {
+		tr := &Transport{TLSClientConfig: tt.conf}
+		got := tr.newTLSConfig(tt.host)
+		if !reflect.DeepEqual(got, tt.want) {
+			t.Errorf("%d. got %#v; want %#v", i, got, tt.want)
+		}
+	}
+}