|
@@ -19,6 +19,7 @@ import (
|
|
|
"net"
|
|
"net"
|
|
|
"net/http"
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
"net/http/httptest"
|
|
|
|
|
+ "net/http/httptrace"
|
|
|
"net/textproto"
|
|
"net/textproto"
|
|
|
"net/url"
|
|
"net/url"
|
|
|
"os"
|
|
"os"
|
|
@@ -98,6 +99,15 @@ func TestTransportH2c(t *testing.T) {
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
t.Fatal(err)
|
|
|
}
|
|
}
|
|
|
|
|
+ var gotConnCnt int32
|
|
|
|
|
+ trace := &httptrace.ClientTrace{
|
|
|
|
|
+ GotConn: func(connInfo httptrace.GotConnInfo) {
|
|
|
|
|
+ if !connInfo.Reused {
|
|
|
|
|
+ atomic.AddInt32(&gotConnCnt, 1)
|
|
|
|
|
+ }
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
|
|
tr := &Transport{
|
|
tr := &Transport{
|
|
|
AllowHTTP: true,
|
|
AllowHTTP: true,
|
|
|
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
|
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
|
@@ -118,6 +128,9 @@ func TestTransportH2c(t *testing.T) {
|
|
|
if got, want := string(body), "Hello, /foobar, http: true"; got != want {
|
|
if got, want := string(body), "Hello, /foobar, http: true"; got != want {
|
|
|
t.Fatalf("response got %v, want %v", got, want)
|
|
t.Fatalf("response got %v, want %v", got, want)
|
|
|
}
|
|
}
|
|
|
|
|
+ if got, want := gotConnCnt, int32(1); got != want {
|
|
|
|
|
+ t.Errorf("Too many got connections: %d", gotConnCnt)
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func TestTransport(t *testing.T) {
|
|
func TestTransport(t *testing.T) {
|
|
@@ -244,6 +257,14 @@ func TestTransportGroupsPendingDials(t *testing.T) {
|
|
|
mu sync.Mutex
|
|
mu sync.Mutex
|
|
|
dials = map[string]int{}
|
|
dials = map[string]int{}
|
|
|
)
|
|
)
|
|
|
|
|
+ var gotConnCnt int32
|
|
|
|
|
+ trace := &httptrace.ClientTrace{
|
|
|
|
|
+ GotConn: func(connInfo httptrace.GotConnInfo) {
|
|
|
|
|
+ if !connInfo.Reused {
|
|
|
|
|
+ atomic.AddInt32(&gotConnCnt, 1)
|
|
|
|
|
+ }
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
var wg sync.WaitGroup
|
|
var wg sync.WaitGroup
|
|
|
for i := 0; i < 10; i++ {
|
|
for i := 0; i < 10; i++ {
|
|
|
wg.Add(1)
|
|
wg.Add(1)
|
|
@@ -254,6 +275,7 @@ func TestTransportGroupsPendingDials(t *testing.T) {
|
|
|
t.Error(err)
|
|
t.Error(err)
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
|
|
res, err := tr.RoundTrip(req)
|
|
res, err := tr.RoundTrip(req)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
t.Error(err)
|
|
t.Error(err)
|
|
@@ -298,6 +320,9 @@ func TestTransportGroupsPendingDials(t *testing.T) {
|
|
|
}); err != nil {
|
|
}); err != nil {
|
|
|
t.Errorf("State of pool after CloseIdleConnections: %v", err)
|
|
t.Errorf("State of pool after CloseIdleConnections: %v", err)
|
|
|
}
|
|
}
|
|
|
|
|
+ if got, want := gotConnCnt, int32(1); got != want {
|
|
|
|
|
+ t.Errorf("Too many got connections: %d", gotConnCnt)
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func retry(tries int, delay time.Duration, fn func() error) error {
|
|
func retry(tries int, delay time.Duration, fn func() error) error {
|