瀏覽代碼

pkg/transport: set the maxIdleConnsPerHost to -1

for transport that are using timeout connections, we set the
maxIdleConnsPerHost to -1. The default transport does not clear
the timeout for the connections it sets to be idle. So the connections
with timeout cannot be reused.
Xiang Li 11 年之前
父節點
當前提交
e50d43fd32
共有 2 個文件被更改,包括 40 次插入1 次删除
  1. 3 0
      pkg/transport/timeout_transport.go
  2. 37 1
      pkg/transport/timeout_transport_test.go

+ 3 - 0
pkg/transport/timeout_transport.go

@@ -28,6 +28,9 @@ func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.
 	if err != nil {
 		return nil, err
 	}
+	// the timeouted connection will tiemout soon after it is idle.
+	// it should not be put back to http transport as an idle connection for future usage.
+	tr.MaxIdleConnsPerHost = -1
 	tr.Dial = (&rwTimeoutDialer{
 		Dialer: net.Dialer{
 			Timeout:   dialtimeoutd,

+ 37 - 1
pkg/transport/timeout_transport_test.go

@@ -15,6 +15,8 @@
 package transport
 
 import (
+	"bytes"
+	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -28,7 +30,12 @@ func TestNewTimeoutTransport(t *testing.T) {
 	if err != nil {
 		t.Fatalf("unexpected NewTimeoutTransport error: %v", err)
 	}
-	srv := httptest.NewServer(http.NotFoundHandler())
+
+	remoteAddr := func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte(r.RemoteAddr))
+	}
+	srv := httptest.NewServer(http.HandlerFunc(remoteAddr))
+
 	defer srv.Close()
 	conn, err := tr.Dial("tcp", srv.Listener.Addr().String())
 	if err != nil {
@@ -46,4 +53,33 @@ func TestNewTimeoutTransport(t *testing.T) {
 	if tconn.wtimeoutd != time.Hour {
 		t.Errorf("write timeout = %s, want %s", tconn.wtimeoutd, time.Hour)
 	}
+
+	// ensure not reuse timeout connection
+	req, err := http.NewRequest("GET", srv.URL, nil)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	resp, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	addr0, err := ioutil.ReadAll(resp.Body)
+	resp.Body.Close()
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+
+	resp, err = tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	addr1, err := ioutil.ReadAll(resp.Body)
+	resp.Body.Close()
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+
+	if bytes.Equal(addr0, addr1) {
+		t.Errorf("addr0 = %s addr1= %s, want not equal", string(addr0), string(addr1))
+	}
 }