Browse Source

pkg/transport: set the maxIdleConnsPerHost to -1

for transport that are using timeout connections, we set the
maxIdleConnsPerHost to -1. The default transport does not clear
the timeout for the connections it sets to be idle. So the connections
with timeout cannot be reused.
Xiang Li 10 years ago
parent
commit
e50d43fd32
2 changed files with 40 additions and 1 deletions
  1. 3 0
      pkg/transport/timeout_transport.go
  2. 37 1
      pkg/transport/timeout_transport_test.go

+ 3 - 0
pkg/transport/timeout_transport.go

@@ -28,6 +28,9 @@ func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.
 	if err != nil {
 		return nil, err
 	}
+	// the timeouted connection will tiemout soon after it is idle.
+	// it should not be put back to http transport as an idle connection for future usage.
+	tr.MaxIdleConnsPerHost = -1
 	tr.Dial = (&rwTimeoutDialer{
 		Dialer: net.Dialer{
 			Timeout:   dialtimeoutd,

+ 37 - 1
pkg/transport/timeout_transport_test.go

@@ -15,6 +15,8 @@
 package transport
 
 import (
+	"bytes"
+	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -28,7 +30,12 @@ func TestNewTimeoutTransport(t *testing.T) {
 	if err != nil {
 		t.Fatalf("unexpected NewTimeoutTransport error: %v", err)
 	}
-	srv := httptest.NewServer(http.NotFoundHandler())
+
+	remoteAddr := func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte(r.RemoteAddr))
+	}
+	srv := httptest.NewServer(http.HandlerFunc(remoteAddr))
+
 	defer srv.Close()
 	conn, err := tr.Dial("tcp", srv.Listener.Addr().String())
 	if err != nil {
@@ -46,4 +53,33 @@ func TestNewTimeoutTransport(t *testing.T) {
 	if tconn.wtimeoutd != time.Hour {
 		t.Errorf("write timeout = %s, want %s", tconn.wtimeoutd, time.Hour)
 	}
+
+	// ensure not reuse timeout connection
+	req, err := http.NewRequest("GET", srv.URL, nil)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	resp, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	addr0, err := ioutil.ReadAll(resp.Body)
+	resp.Body.Close()
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+
+	resp, err = tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	addr1, err := ioutil.ReadAll(resp.Body)
+	resp.Body.Close()
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+
+	if bytes.Equal(addr0, addr1) {
+		t.Errorf("addr0 = %s addr1= %s, want not equal", string(addr0), string(addr1))
+	}
 }