Browse Source

Merge pull request #2414 from xiang90/fix_max_host

pkg/transport: set the maxIdleConnsPerHost to -1
Kelsey Hightower 10 years ago
parent
commit
a3b3fc5e87
2 changed files with 40 additions and 1 deletions
  1. 3 0
      pkg/transport/timeout_transport.go
  2. 37 1
      pkg/transport/timeout_transport_test.go

+ 3 - 0
pkg/transport/timeout_transport.go

@@ -28,6 +28,9 @@ func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.
 	if err != nil {
 		return nil, err
 	}
+	// the timeouted connection will tiemout soon after it is idle.
+	// it should not be put back to http transport as an idle connection for future usage.
+	tr.MaxIdleConnsPerHost = -1
 	tr.Dial = (&rwTimeoutDialer{
 		Dialer: net.Dialer{
 			Timeout:   dialtimeoutd,

+ 37 - 1
pkg/transport/timeout_transport_test.go

@@ -15,6 +15,8 @@
 package transport
 
 import (
+	"bytes"
+	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -28,7 +30,12 @@ func TestNewTimeoutTransport(t *testing.T) {
 	if err != nil {
 		t.Fatalf("unexpected NewTimeoutTransport error: %v", err)
 	}
-	srv := httptest.NewServer(http.NotFoundHandler())
+
+	remoteAddr := func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte(r.RemoteAddr))
+	}
+	srv := httptest.NewServer(http.HandlerFunc(remoteAddr))
+
 	defer srv.Close()
 	conn, err := tr.Dial("tcp", srv.Listener.Addr().String())
 	if err != nil {
@@ -46,4 +53,33 @@ func TestNewTimeoutTransport(t *testing.T) {
 	if tconn.wtimeoutd != time.Hour {
 		t.Errorf("write timeout = %s, want %s", tconn.wtimeoutd, time.Hour)
 	}
+
+	// ensure not reuse timeout connection
+	req, err := http.NewRequest("GET", srv.URL, nil)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	resp, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	addr0, err := ioutil.ReadAll(resp.Body)
+	resp.Body.Close()
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+
+	resp, err = tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	addr1, err := ioutil.ReadAll(resp.Body)
+	resp.Body.Close()
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+
+	if bytes.Equal(addr0, addr1) {
+		t.Errorf("addr0 = %s addr1= %s, want not equal", string(addr0), string(addr1))
+	}
 }