Browse Source

http2: implement Ping method on ClientConn

This new method sends a PING frame with random payload to the peer and
wait for a PING ack with the same payload.

In order to support cancellation and deadling, the Ping method takes a
context as argument.

Fixes golang/go#15475

Change-Id: I340133a67717af89556837cc531a885d116eba59
Reviewed-on: https://go-review.googlesource.com/29965
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Olivier Poitrey 9 years ago
parent
commit
ed0556cc88
4 changed files with 81 additions and 2 deletions
  1. 5 0
      http2/go17.go
  2. 4 0
      http2/not_go17.go
  3. 51 2
      http2/transport.go
  4. 21 0
      http2/transport_test.go

+ 5 - 0
http2/go17.go

@@ -92,3 +92,8 @@ func requestTrace(req *http.Request) *clientTrace {
 	trace := httptrace.ContextClientTrace(req.Context())
 	trace := httptrace.ContextClientTrace(req.Context())
 	return (*clientTrace)(trace)
 	return (*clientTrace)(trace)
 }
 }
+
+// Ping sends a PING frame to the server and waits for the ack.
+func (cc *ClientConn) Ping(ctx context.Context) error {
+	return cc.ping(ctx)
+}

+ 4 - 0
http2/not_go17.go

@@ -75,3 +75,7 @@ func cloneTLSConfig(c *tls.Config) *tls.Config {
 		CurvePreferences:         c.CurvePreferences,
 		CurvePreferences:         c.CurvePreferences,
 	}
 	}
 }
 }
+
+func (cc *ClientConn) Ping(ctx contextContext) error {
+	return cc.ping(ctx)
+}

+ 51 - 2
http2/transport.go

@@ -10,6 +10,7 @@ import (
 	"bufio"
 	"bufio"
 	"bytes"
 	"bytes"
 	"compress/gzip"
 	"compress/gzip"
+	"crypto/rand"
 	"crypto/tls"
 	"crypto/tls"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
@@ -160,6 +161,7 @@ type ClientConn struct {
 	goAwayDebug     string                   // goAway frame's debug data, retained as a string
 	goAwayDebug     string                   // goAway frame's debug data, retained as a string
 	streams         map[uint32]*clientStream // client-initiated
 	streams         map[uint32]*clientStream // client-initiated
 	nextStreamID    uint32
 	nextStreamID    uint32
+	pings           map[[8]byte]chan struct{} // in flight ping data to notification channel
 	bw              *bufio.Writer
 	bw              *bufio.Writer
 	br              *bufio.Reader
 	br              *bufio.Reader
 	fr              *Framer
 	fr              *Framer
@@ -431,6 +433,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
 		streams:              make(map[uint32]*clientStream),
 		streams:              make(map[uint32]*clientStream),
 		singleUse:            singleUse,
 		singleUse:            singleUse,
 		wantSettingsAck:      true,
 		wantSettingsAck:      true,
+		pings:                make(map[[8]byte]chan struct{}),
 	}
 	}
 	if VerboseLogs {
 	if VerboseLogs {
 		t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
 		t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@@ -1815,10 +1818,56 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
 	return nil
 	return nil
 }
 }
 
 
+// Ping sends a PING frame to the server and waits for the ack.
+// Public implementation is in go17.go and not_go17.go
+func (cc *ClientConn) ping(ctx contextContext) error {
+	c := make(chan struct{})
+	// Generate a random payload
+	var p [8]byte
+	for {
+		if _, err := rand.Read(p[:]); err != nil {
+			return err
+		}
+		cc.mu.Lock()
+		// check for dup before insert
+		if _, found := cc.pings[p]; !found {
+			cc.pings[p] = c
+			cc.mu.Unlock()
+			break
+		}
+		cc.mu.Unlock()
+	}
+	cc.wmu.Lock()
+	if err := cc.fr.WritePing(false, p); err != nil {
+		cc.wmu.Unlock()
+		return err
+	}
+	if err := cc.bw.Flush(); err != nil {
+		cc.wmu.Unlock()
+		return err
+	}
+	cc.wmu.Unlock()
+	select {
+	case <-c:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-cc.readerDone:
+		// connection closed
+		return cc.readerErr
+	}
+}
+
 func (rl *clientConnReadLoop) processPing(f *PingFrame) error {
 func (rl *clientConnReadLoop) processPing(f *PingFrame) error {
 	if f.IsAck() {
 	if f.IsAck() {
-		// 6.7 PING: " An endpoint MUST NOT respond to PING frames
-		// containing this flag."
+		cc := rl.cc
+		cc.mu.Lock()
+		defer cc.mu.Unlock()
+		// If ack, notify listener if any
+		if c, ok := cc.pings[f.Data]; ok {
+			close(c)
+			delete(cc.pings, f.Data)
+		}
 		return nil
 		return nil
 	}
 	}
 	cc := rl.cc
 	cc := rl.cc

+ 21 - 0
http2/transport_test.go

@@ -39,6 +39,13 @@ var (
 
 
 var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
 var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
 
 
+type testContext struct{}
+
+func (testContext) Done() <-chan struct{}                   { return make(chan struct{}) }
+func (testContext) Err() error                              { panic("should not be called") }
+func (testContext) Deadline() (deadline time.Time, ok bool) { return time.Time{}, false }
+func (testContext) Value(key interface{}) interface{}       { return nil }
+
 func TestTransportExternal(t *testing.T) {
 func TestTransportExternal(t *testing.T) {
 	if !*extNet {
 	if !*extNet {
 		t.Skip("skipping external network test")
 		t.Skip("skipping external network test")
@@ -2628,3 +2635,17 @@ func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
 		t.Errorf("Body = %q; want %q", slurp, body)
 		t.Errorf("Body = %q; want %q", slurp, body)
 	}
 	}
 }
 }
+
+func TestClientConnPing(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
+	defer st.Close()
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+	cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if err = cc.Ping(testContext{}); err != nil {
+		t.Fatal(err)
+	}
+}