Просмотр исходного кода

pkg/transport: set the maxIdleConnsPerHost to -1

for transport that are using timeout connections, we set the
maxIdleConnsPerHost to -1. The default transport does not clear
the timeout for the connections it sets to be idle. So the connections
with timeout cannot be reused.
Xiang Li 11 лет назад
Родитель
Сommit
e50d43fd32
2 измененных файлов с 40 добавлено и 1 удалено
  1. 3 0
      pkg/transport/timeout_transport.go
  2. 37 1
      pkg/transport/timeout_transport_test.go

+ 3 - 0
pkg/transport/timeout_transport.go

@@ -28,6 +28,9 @@ func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.
 	if err != nil {
 		return nil, err
 	}
+	// the timeouted connection will tiemout soon after it is idle.
+	// it should not be put back to http transport as an idle connection for future usage.
+	tr.MaxIdleConnsPerHost = -1
 	tr.Dial = (&rwTimeoutDialer{
 		Dialer: net.Dialer{
 			Timeout:   dialtimeoutd,

+ 37 - 1
pkg/transport/timeout_transport_test.go

@@ -15,6 +15,8 @@
 package transport
 
 import (
+	"bytes"
+	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -28,7 +30,12 @@ func TestNewTimeoutTransport(t *testing.T) {
 	if err != nil {
 		t.Fatalf("unexpected NewTimeoutTransport error: %v", err)
 	}
-	srv := httptest.NewServer(http.NotFoundHandler())
+
+	remoteAddr := func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte(r.RemoteAddr))
+	}
+	srv := httptest.NewServer(http.HandlerFunc(remoteAddr))
+
 	defer srv.Close()
 	conn, err := tr.Dial("tcp", srv.Listener.Addr().String())
 	if err != nil {
@@ -46,4 +53,33 @@ func TestNewTimeoutTransport(t *testing.T) {
 	if tconn.wtimeoutd != time.Hour {
 		t.Errorf("write timeout = %s, want %s", tconn.wtimeoutd, time.Hour)
 	}
+
+	// ensure not reuse timeout connection
+	req, err := http.NewRequest("GET", srv.URL, nil)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	resp, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	addr0, err := ioutil.ReadAll(resp.Body)
+	resp.Body.Close()
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+
+	resp, err = tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+	addr1, err := ioutil.ReadAll(resp.Body)
+	resp.Body.Close()
+	if err != nil {
+		t.Fatalf("unexpected err %v", err)
+	}
+
+	if bytes.Equal(addr0, addr1) {
+		t.Errorf("addr0 = %s addr1= %s, want not equal", string(addr0), string(addr1))
+	}
 }