|
|
@@ -2011,3 +2011,77 @@ func TestTransportFlowControl(t *testing.T) {
|
|
|
time.Sleep(1 * time.Millisecond)
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+// golang.org/issue/14627 -- if the server sends a GOAWAY frame, make
|
|
|
+// the Transport remember it and return it back to users (via
|
|
|
+// RoundTrip or request body reads) if needed (e.g. if the server
|
|
|
+// proceeds to close the TCP connection before the client gets its
|
|
|
+// response)
|
|
|
+func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
|
|
|
+ testTransportUsesGoAwayDebugError(t, false)
|
|
|
+}
|
|
|
+
|
|
|
+func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
|
|
|
+ testTransportUsesGoAwayDebugError(t, true)
|
|
|
+}
|
|
|
+
|
|
|
+func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
|
|
|
+ ct := newClientTester(t)
|
|
|
+ clientDone := make(chan struct{})
|
|
|
+
|
|
|
+ const goAwayErrCode = ErrCodeHTTP11Required // arbitrary
|
|
|
+ const goAwayDebugData = "some debug data"
|
|
|
+
|
|
|
+ ct.client = func() error {
|
|
|
+ defer close(clientDone)
|
|
|
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
|
|
|
+ res, err := ct.tr.RoundTrip(req)
|
|
|
+ if failMidBody {
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("unexpected client RoundTrip error: %v", err)
|
|
|
+ }
|
|
|
+ _, err = io.Copy(ioutil.Discard, res.Body)
|
|
|
+ res.Body.Close()
|
|
|
+ }
|
|
|
+ want := GoAwayError{
|
|
|
+ LastStreamID: 0,
|
|
|
+ ErrCode: goAwayErrCode,
|
|
|
+ DebugData: goAwayDebugData,
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(err, want) {
|
|
|
+ t.Errorf("RoundTrip error = %T: %#v, want %T (%#T)", err, err, want, want)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ ct.server = func() error {
|
|
|
+ ct.greet()
|
|
|
+ for {
|
|
|
+ f, err := ct.fr.ReadFrame()
|
|
|
+ if err != nil {
|
|
|
+ t.Logf("ReadFrame: %v", err)
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ hf, ok := f.(*HeadersFrame)
|
|
|
+ if !ok {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if failMidBody {
|
|
|
+ var buf bytes.Buffer
|
|
|
+ enc := hpack.NewEncoder(&buf)
|
|
|
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
|
|
|
+ enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
|
|
|
+ ct.fr.WriteHeaders(HeadersFrameParam{
|
|
|
+ StreamID: hf.StreamID,
|
|
|
+ EndHeaders: true,
|
|
|
+ EndStream: false,
|
|
|
+ BlockFragment: buf.Bytes(),
|
|
|
+ })
|
|
|
+ }
|
|
|
+ ct.fr.WriteGoAway(0, goAwayErrCode, []byte(goAwayDebugData))
|
|
|
+ ct.sc.Close()
|
|
|
+ <-clientDone
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ct.run()
|
|
|
+}
|