Browse Source

http2: make Transport retry on server's GOAWAY graceful shutdown

Debugged & wrote with Tom Bergan.

Updates golang/go#18083

Change-Id: I00a1cb748fe9c0f01c5bd4b8d1ac4438b56f1f8c
Reviewed-on: https://go-review.googlesource.com/33971
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Tom Bergan <tombergan@google.com>
Brad Fitzpatrick 9 years ago
parent
commit
8dab929343
4 changed files with 104 additions and 12 deletions
  1. 9 0
      http2/go18.go
  2. 10 1
      http2/not_go18.go
  3. 57 6
      http2/transport.go
  4. 28 5
      http2/transport_test.go

+ 9 - 0
http2/go18.go

@@ -8,6 +8,7 @@ package http2
 
 import (
 	"crypto/tls"
+	"io"
 	"net/http"
 )
 
@@ -39,3 +40,11 @@ func configureServer18(h1 *http.Server, h2 *Server) error {
 func shouldLogPanic(panicValue interface{}) bool {
 	return panicValue != nil && panicValue != http.ErrAbortHandler
 }
+
+func reqGetBody(req *http.Request) func() (io.ReadCloser, error) {
+	return req.GetBody
+}
+
+func reqBodyIsNoBody(body io.ReadCloser) bool {
+	return body == http.NoBody
+}

+ 10 - 1
http2/not_go18.go

@@ -6,7 +6,10 @@
 
 package http2
 
-import "net/http"
+import (
+	"io"
+	"net/http"
+)
 
 func configureServer18(h1 *http.Server, h2 *Server) error {
 	// No IdleTimeout to sync prior to Go 1.8.
@@ -16,3 +19,9 @@ func configureServer18(h1 *http.Server, h2 *Server) error {
 func shouldLogPanic(panicValue interface{}) bool {
 	return panicValue != nil
 }
+
+func reqGetBody(req *http.Request) func() (io.ReadCloser, error) {
+	return nil
+}
+
+func reqBodyIsNoBody(io.ReadCloser) bool { return false }

+ 57 - 6
http2/transport.go

@@ -191,6 +191,7 @@ type clientStream struct {
 	ID            uint32
 	resc          chan resAndError
 	bufPipe       pipe // buffered pipe with the flow-controlled response payload
+	startedWrite  bool // started request body write; guarded by cc.mu
 	requestedGzip bool
 	on100         func() // optional code to run if get a 100 continue response
 
@@ -332,8 +333,10 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
 		}
 		traceGotConn(req, cc)
 		res, err := cc.RoundTrip(req)
-		if shouldRetryRequest(req, err) {
-			continue
+		if err != nil {
+			if req, err = shouldRetryRequest(req, err); err == nil {
+				continue
+			}
 		}
 		if err != nil {
 			t.vlogf("RoundTrip failure: %v", err)
@@ -355,12 +358,41 @@ func (t *Transport) CloseIdleConnections() {
 var (
 	errClientConnClosed   = errors.New("http2: client conn is closed")
 	errClientConnUnusable = errors.New("http2: client conn not usable")
+
+	errClientConnGotGoAway                 = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
+	errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written")
 )
 
-func shouldRetryRequest(req *http.Request, err error) bool {
-	// TODO: retry GET requests (no bodies) more aggressively, if shutdown
-	// before response.
-	return err == errClientConnUnusable
+// shouldRetryRequest is called by RoundTrip when a request fails to get
+// response headers. It is always called with a non-nil error.
+// It returns either a request to retry (either the same request, or a
+// modified clone), or an error if the request can't be replayed.
+func shouldRetryRequest(req *http.Request, err error) (*http.Request, error) {
+	switch err {
+	default:
+		return nil, err
+	case errClientConnUnusable, errClientConnGotGoAway:
+		return req, nil
+	case errClientConnGotGoAwayAfterSomeReqBody:
+		// If the Body is nil (or http.NoBody), it's safe to reuse
+		// this request and its Body.
+		if req.Body == nil || reqBodyIsNoBody(req.Body) {
+			return req, nil
+		}
+		// Otherwise we depend on the Request having its GetBody
+		// func defined.
+		getBody := reqGetBody(req) // Go 1.8: getBody = req.GetBody
+		if getBody == nil {
+			return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error")
+		}
+		body, err := getBody()
+		if err != nil {
+			return nil, err
+		}
+		newReq := *req
+		newReq.Body = body
+		return &newReq, nil
+	}
 }
 
 func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
@@ -513,6 +545,15 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
 	if old != nil && old.ErrCode != ErrCodeNo {
 		cc.goAway.ErrCode = old.ErrCode
 	}
+	last := f.LastStreamID
+	for streamID, cs := range cc.streams {
+		if streamID > last {
+			select {
+			case cs.resc <- resAndError{err: errClientConnGotGoAway}:
+			default:
+			}
+		}
+	}
 }
 
 func (cc *ClientConn) CanTakeNewRequest() bool {
@@ -773,6 +814,13 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			cs.abortRequestBodyWrite(errStopReqBodyWrite)
 		}
 		if re.err != nil {
+			if re.err == errClientConnGotGoAway {
+				cc.mu.Lock()
+				if cs.startedWrite {
+					re.err = errClientConnGotGoAwayAfterSomeReqBody
+				}
+				cc.mu.Unlock()
+			}
 			cc.forgetStreamID(cs.ID)
 			return nil, re.err
 		}
@@ -2013,6 +2061,9 @@ func (t *Transport) getBodyWriterState(cs *clientStream, body io.Reader) (s body
 	resc := make(chan error, 1)
 	s.resc = resc
 	s.fn = func() {
+		cs.cc.mu.Lock()
+		cs.startedWrite = true
+		cs.cc.mu.Unlock()
 		resc <- cs.writeRequestBody(body, cs.req.Body)
 	}
 	s.delay = t.expectContinueTimeout()

+ 28 - 5
http2/transport_test.go

@@ -2747,7 +2747,6 @@ func TestTransportCancelDataResponseRace(t *testing.T) {
 }
 
 func TestTransportRetryAfterGOAWAY(t *testing.T) {
-	t.Skip("to be unskipped by https://go-review.googlesource.com/c/33971/")
 	var dialer struct {
 		sync.Mutex
 		count int
@@ -2765,6 +2764,9 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
 		dialer.Lock()
 		defer dialer.Unlock()
 		dialer.count++
+		if dialer.count == 3 {
+			return nil, errors.New("unexpected number of dials")
+		}
 		cc, err := net.Dial("tcp", ln.Addr().String())
 		if err != nil {
 			return nil, fmt.Errorf("dial error: %v", err)
@@ -2797,10 +2799,20 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
 	go func() {
 		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
 		res, err := tr.RoundTrip(req)
-		t.Logf("client got %T, %v", res, err)
+		if res != nil {
+			res.Body.Close()
+			if got := res.Header.Get("Foo"); got != "bar" {
+				err = fmt.Errorf("foo header = %q; want bar", got)
+			}
+		}
+		if err != nil {
+			err = fmt.Errorf("RoundTrip: %v", err)
+		}
 		errs <- err
 	}()
 
+	connToClose := make(chan io.Closer, 2)
+
 	// Server for the first request.
 	go func() {
 		var ct *clientTester
@@ -2810,6 +2822,7 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
 			return
 		}
 
+		connToClose <- ct.cc
 		ct.greet()
 		hf, err := ct.firstHeaders()
 		if err != nil {
@@ -2821,7 +2834,6 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
 			errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
 			return
 		}
-		ct.cc.(*net.TCPConn).Close()
 		errs <- nil
 	}()
 
@@ -2834,17 +2846,19 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
 			return
 		}
 
+		connToClose <- ct.cc
 		ct.greet()
 		hf, err := ct.firstHeaders()
 		if err != nil {
 			errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
 			return
 		}
-		t.Logf("server2 Got %v", hf)
+		t.Logf("server2 got %v", hf)
 
 		var buf bytes.Buffer
 		enc := hpack.NewEncoder(&buf)
 		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
 		err = ct.fr.WriteHeaders(HeadersFrameParam{
 			StreamID:      hf.StreamID,
 			EndHeaders:    true,
@@ -2852,7 +2866,7 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
 			BlockFragment: buf.Bytes(),
 		})
 		if err != nil {
-			errs <- fmt.Errorf("server2 failed writin responseg HEADERS: %v", err)
+			errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err)
 		} else {
 			errs <- nil
 		}
@@ -2868,4 +2882,13 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
 			t.Errorf("timed out")
 		}
 	}
+
+	for {
+		select {
+		case c := <-connToClose:
+			c.Close()
+		default:
+			return
+		}
+	}
 }