Ver Fonte

http2: merge duplicate Transport dials

Fixes golang/go#13397
Updates golang/go#6891

Change-Id: I1e4c7bfe60c6abf9a03f2888aa6abc3891c309e7
Reviewed-on: https://go-review.googlesource.com/17134
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Brad Fitzpatrick há 10 anos atrás
pai
commit
195180cfeb
2 ficheiros alterados com 135 adições e 14 exclusões
  1. 54 14
      http2/client_conn_pool.go
  2. 81 0
      http2/transport_test.go

+ 54 - 14
http2/client_conn_pool.go

@@ -19,11 +19,12 @@ type ClientConnPool interface {
 
 type clientConnPool struct {
 	t  *Transport
-	mu sync.Mutex // TODO: switch to RWMutex
+	mu sync.Mutex // TODO: maybe switch to RWMutex
 	// TODO: add support for sharing conns based on cert names
 	// (e.g. share conn for googleapis.com and appspot.com)
-	conns map[string][]*ClientConn // key is host:port
-	keys  map[*ClientConn][]string
+	conns   map[string][]*ClientConn // key is host:port
+	dialing map[string]*dialCall     // currently in-flight dials
+	keys    map[*ClientConn][]string
 }
 
 func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
@@ -38,26 +39,65 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
 			return cc, nil
 		}
 	}
-	p.mu.Unlock()
 	if !dialOnMiss {
+		p.mu.Unlock()
 		return nil, ErrNoCachedConn
 	}
+	call := p.getStartDialLocked(addr)
+	p.mu.Unlock()
+	<-call.done
+	return call.res, call.err
+}
+
+// dialCall is an in-flight Transport dial call to a host.
+type dialCall struct {
+	p    *clientConnPool
+	done chan struct{} // closed when done
+	res  *ClientConn   // valid after done is closed
+	err  error         // valid after done is closed
+}
+
+// requires p.mu is held.
+func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
+	if call, ok := p.dialing[addr]; ok {
+		// A dial is already in-flight. Don't start another.
+		return call
+	}
+	call := &dialCall{p: p, done: make(chan struct{})}
+	if p.dialing == nil {
+		p.dialing = make(map[string]*dialCall)
+	}
+	p.dialing[addr] = call
+	go call.dial(addr)
+	return call
+}
+
+// run in its own goroutine.
+func (c *dialCall) dial(addr string) {
+	c.res, c.err = c.p.t.dialClientConn(addr)
+	close(c.done)
 
-	// TODO(bradfitz): use a singleflight.Group to only lock once per 'key'.
-	// Probably need to vendor it in as github.com/golang/sync/singleflight
-	// though, since the net package already uses it? Also lines up with
-	// sameer, bcmills, et al wanting to open source some sync stuff.
-	cc, err := p.t.dialClientConn(addr)
-	if err != nil {
-		return nil, err
+	c.p.mu.Lock()
+	delete(c.p.dialing, addr)
+	if c.err == nil {
+		c.p.addConnLocked(addr, c.res)
 	}
-	p.addConn(addr, cc)
-	return cc, nil
+	c.p.mu.Unlock()
 }
 
 func (p *clientConnPool) addConn(key string, cc *ClientConn) {
 	p.mu.Lock()
-	defer p.mu.Unlock()
+	p.addConnLocked(key, cc)
+	p.mu.Unlock()
+}
+
+// p.mu must be held
+func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
+	for _, v := range p.conns[key] {
+		if v == cc {
+			return
+		}
+	}
 	if p.conns == nil {
 		p.conns = make(map[string][]*ClientConn)
 	}

+ 81 - 0
http2/transport_test.go

@@ -128,6 +128,87 @@ func TestTransportReusesConns(t *testing.T) {
 	}
 }
 
+// Tests that the Transport only keeps one pending dial open per destination address.
+// https://golang.org/issue/13397
+func TestTransportGroupsPendingDials(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		io.WriteString(w, r.RemoteAddr)
+	}, optOnlyServer)
+	defer st.Close()
+	tr := &Transport{
+		TLSClientConfig: tlsConfigInsecure,
+	}
+	defer tr.CloseIdleConnections()
+	var (
+		mu    sync.Mutex
+		dials = map[string]int{}
+	)
+	var wg sync.WaitGroup
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			req, err := http.NewRequest("GET", st.ts.URL, nil)
+			if err != nil {
+				t.Error(err)
+				return
+			}
+			res, err := tr.RoundTrip(req)
+			if err != nil {
+				t.Error(err)
+				return
+			}
+			defer res.Body.Close()
+			slurp, err := ioutil.ReadAll(res.Body)
+			if err != nil {
+				t.Errorf("Body read: %v", err)
+			}
+			addr := strings.TrimSpace(string(slurp))
+			if addr == "" {
+				t.Errorf("didn't get an addr in response")
+			}
+			mu.Lock()
+			dials[addr]++
+			mu.Unlock()
+		}()
+	}
+	wg.Wait()
+	if len(dials) != 1 {
+		t.Errorf("saw %d dials; want 1: %v", len(dials), dials)
+	}
+	tr.CloseIdleConnections()
+	if err := retry(50, 10*time.Millisecond, func() error {
+		cp, ok := tr.connPool().(*clientConnPool)
+		if !ok {
+			return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool())
+		}
+		if len(cp.dialing) != 0 {
+			return fmt.Errorf("dialing map = %v; want empty", cp.dialing)
+		}
+		if len(cp.conns) != 0 {
+			return fmt.Errorf("conns = %v; want empty", cp.conns)
+		}
+		if len(cp.keys) != 0 {
+			return fmt.Errorf("keys = %v; want empty", cp.keys)
+		}
+		return nil
+	}); err != nil {
+		t.Error("State of pool after CloseIdleConnections: %v", err)
+	}
+}
+
+func retry(tries int, delay time.Duration, fn func() error) error {
+	var err error
+	for i := 0; i < tries; i++ {
+		err = fn()
+		if err == nil {
+			return nil
+		}
+		time.Sleep(delay)
+	}
+	return err
+}
+
 func TestTransportAbortClosesPipes(t *testing.T) {
 	shutdown := make(chan struct{})
 	st := newServerTester(t,