Browse Source

http2: fix Transport.RoundTrip hang on stream error before headers

If the Transport got a stream error on the response headers, it was
never unblocking the client. Previously, Response.Body reads would be
aborted with the stream error, but RoundTrip itself would never
unblock.

The Transport now also sends a RST_STREAM to the server when we
encounter a stream error.

Also, add a "Cause" field to StreamError with additional detail. The
old code was just returning the detail, without the stream error
header.

Fixes golang/go#16572

Change-Id: Ibecedb5779f17bf98c32787b68eb8a9b850833b3
Reviewed-on: https://go-review.googlesource.com/25402
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Andrew Gerrand <adg@golang.org>
Brad Fitzpatrick 9 năm trước cách đây
mục cha
commit
e2ba55e4e7

+ 8 - 0
http2/errors.go

@@ -64,9 +64,17 @@ func (e ConnectionError) Error() string { return fmt.Sprintf("connection error:
 type StreamError struct {
 type StreamError struct {
 	StreamID uint32
 	StreamID uint32
 	Code     ErrCode
 	Code     ErrCode
+	Cause    error // optional additional detail
+}
+
+func streamError(id uint32, code ErrCode) StreamError {
+	return StreamError{StreamID: id, Code: code}
 }
 }
 
 
 func (e StreamError) Error() string {
 func (e StreamError) Error() string {
+	if e.Cause != nil {
+		return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause)
+	}
 	return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
 	return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
 }
 }
 
 

+ 4 - 4
http2/frame.go

@@ -863,7 +863,7 @@ func parseWindowUpdateFrame(fh FrameHeader, p []byte) (Frame, error) {
 		if fh.StreamID == 0 {
 		if fh.StreamID == 0 {
 			return nil, ConnectionError(ErrCodeProtocol)
 			return nil, ConnectionError(ErrCodeProtocol)
 		}
 		}
-		return nil, StreamError{fh.StreamID, ErrCodeProtocol}
+		return nil, streamError(fh.StreamID, ErrCodeProtocol)
 	}
 	}
 	return &WindowUpdateFrame{
 	return &WindowUpdateFrame{
 		FrameHeader: fh,
 		FrameHeader: fh,
@@ -944,7 +944,7 @@ func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) {
 		}
 		}
 	}
 	}
 	if len(p)-int(padLength) <= 0 {
 	if len(p)-int(padLength) <= 0 {
-		return nil, StreamError{fh.StreamID, ErrCodeProtocol}
+		return nil, streamError(fh.StreamID, ErrCodeProtocol)
 	}
 	}
 	hf.headerFragBuf = p[:len(p)-int(padLength)]
 	hf.headerFragBuf = p[:len(p)-int(padLength)]
 	return hf, nil
 	return hf, nil
@@ -1483,14 +1483,14 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
 		if VerboseLogs {
 		if VerboseLogs {
 			log.Printf("http2: invalid header: %v", invalid)
 			log.Printf("http2: invalid header: %v", invalid)
 		}
 		}
-		return nil, StreamError{mh.StreamID, ErrCodeProtocol}
+		return nil, StreamError{mh.StreamID, ErrCodeProtocol, invalid}
 	}
 	}
 	if err := mh.checkPseudos(); err != nil {
 	if err := mh.checkPseudos(); err != nil {
 		fr.errDetail = err
 		fr.errDetail = err
 		if VerboseLogs {
 		if VerboseLogs {
 			log.Printf("http2: invalid pseudo headers: %v", err)
 			log.Printf("http2: invalid pseudo headers: %v", err)
 		}
 		}
-		return nil, StreamError{mh.StreamID, ErrCodeProtocol}
+		return nil, StreamError{mh.StreamID, ErrCodeProtocol, err}
 	}
 	}
 	return mh, nil
 	return mh, nil
 }
 }

+ 13 - 6
http2/frame_test.go

@@ -992,7 +992,7 @@ func TestMetaFrameHeader(t *testing.T) {
 					":path", "/", // bogus
 					":path", "/", // bogus
 				))
 				))
 			},
 			},
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "pseudo header field after regular",
 			wantErrReason: "pseudo header field after regular",
 		},
 		},
 		7: {
 		7: {
@@ -1003,7 +1003,7 @@ func TestMetaFrameHeader(t *testing.T) {
 					"foo", "bar",
 					"foo", "bar",
 				))
 				))
 			},
 			},
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "invalid pseudo-header \":unknown\"",
 			wantErrReason: "invalid pseudo-header \":unknown\"",
 		},
 		},
 		8: {
 		8: {
@@ -1014,7 +1014,7 @@ func TestMetaFrameHeader(t *testing.T) {
 					":status", "100",
 					":status", "100",
 				))
 				))
 			},
 			},
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "mix of request and response pseudo headers",
 			wantErrReason: "mix of request and response pseudo headers",
 		},
 		},
 		9: {
 		9: {
@@ -1025,7 +1025,7 @@ func TestMetaFrameHeader(t *testing.T) {
 					":method", "POST",
 					":method", "POST",
 				))
 				))
 			},
 			},
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "duplicate pseudo-header \":method\"",
 			wantErrReason: "duplicate pseudo-header \":method\"",
 		},
 		},
 		10: {
 		10: {
@@ -1036,13 +1036,13 @@ func TestMetaFrameHeader(t *testing.T) {
 		11: {
 		11: {
 			name:          "invalid_field_name",
 			name:          "invalid_field_name",
 			w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) },
 			w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) },
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "invalid header field name \"CapitalBad\"",
 			wantErrReason: "invalid header field name \"CapitalBad\"",
 		},
 		},
 		12: {
 		12: {
 			name:          "invalid_field_value",
 			name:          "invalid_field_value",
 			w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) },
 			w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) },
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "invalid header field value \"bad_null\\x00\"",
 			wantErrReason: "invalid header field value \"bad_null\\x00\"",
 		},
 		},
 	}
 	}
@@ -1063,6 +1063,13 @@ func TestMetaFrameHeader(t *testing.T) {
 		got, err = f.ReadFrame()
 		got, err = f.ReadFrame()
 		if err != nil {
 		if err != nil {
 			got = err
 			got = err
+
+			// Ignore the StreamError.Cause field, if it matches the wantErrReason.
+			// The test table above predates the Cause field.
+			if se, ok := err.(StreamError); ok && se.Cause != nil && se.Cause.Error() == tt.wantErrReason {
+				se.Cause = nil
+				got = se
+			}
 		}
 		}
 		if !reflect.DeepEqual(got, tt.want) {
 		if !reflect.DeepEqual(got, tt.want) {
 			if mhg, ok := got.(*MetaHeadersFrame); ok {
 			if mhg, ok := got.(*MetaHeadersFrame); ok {

+ 17 - 17
http2/server.go

@@ -922,7 +922,7 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
 			// state here anyway, after telling the peer
 			// state here anyway, after telling the peer
 			// we're hanging up on them.
 			// we're hanging up on them.
 			st.state = stateHalfClosedLocal // won't last long, but necessary for closeStream via resetStream
 			st.state = stateHalfClosedLocal // won't last long, but necessary for closeStream via resetStream
-			errCancel := StreamError{st.id, ErrCodeCancel}
+			errCancel := streamError(st.id, ErrCodeCancel)
 			sc.resetStream(errCancel)
 			sc.resetStream(errCancel)
 		case stateHalfClosedRemote:
 		case stateHalfClosedRemote:
 			sc.closeStream(st, errHandlerComplete)
 			sc.closeStream(st, errHandlerComplete)
@@ -1133,7 +1133,7 @@ func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
 			return nil
 			return nil
 		}
 		}
 		if !st.flow.add(int32(f.Increment)) {
 		if !st.flow.add(int32(f.Increment)) {
-			return StreamError{f.StreamID, ErrCodeFlowControl}
+			return streamError(f.StreamID, ErrCodeFlowControl)
 		}
 		}
 	default: // connection-level flow control
 	default: // connection-level flow control
 		if !sc.flow.add(int32(f.Increment)) {
 		if !sc.flow.add(int32(f.Increment)) {
@@ -1159,7 +1159,7 @@ func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
 	if st != nil {
 	if st != nil {
 		st.gotReset = true
 		st.gotReset = true
 		st.cancelCtx()
 		st.cancelCtx()
-		sc.closeStream(st, StreamError{f.StreamID, f.ErrCode})
+		sc.closeStream(st, streamError(f.StreamID, f.ErrCode))
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -1299,7 +1299,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
 		// and return any flow control bytes since we're not going
 		// and return any flow control bytes since we're not going
 		// to consume them.
 		// to consume them.
 		if sc.inflow.available() < int32(f.Length) {
 		if sc.inflow.available() < int32(f.Length) {
-			return StreamError{id, ErrCodeFlowControl}
+			return streamError(id, ErrCodeFlowControl)
 		}
 		}
 		// Deduct the flow control from inflow, since we're
 		// Deduct the flow control from inflow, since we're
 		// going to immediately add it back in
 		// going to immediately add it back in
@@ -1308,7 +1308,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
 		sc.inflow.take(int32(f.Length))
 		sc.inflow.take(int32(f.Length))
 		sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
 		sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
 
 
-		return StreamError{id, ErrCodeStreamClosed}
+		return streamError(id, ErrCodeStreamClosed)
 	}
 	}
 	if st.body == nil {
 	if st.body == nil {
 		panic("internal error: should have a body in this state")
 		panic("internal error: should have a body in this state")
@@ -1317,19 +1317,19 @@ func (sc *serverConn) processData(f *DataFrame) error {
 	// Sender sending more than they'd declared?
 	// Sender sending more than they'd declared?
 	if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
 	if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
 		st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
 		st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
-		return StreamError{id, ErrCodeStreamClosed}
+		return streamError(id, ErrCodeStreamClosed)
 	}
 	}
 	if f.Length > 0 {
 	if f.Length > 0 {
 		// Check whether the client has flow control quota.
 		// Check whether the client has flow control quota.
 		if st.inflow.available() < int32(f.Length) {
 		if st.inflow.available() < int32(f.Length) {
-			return StreamError{id, ErrCodeFlowControl}
+			return streamError(id, ErrCodeFlowControl)
 		}
 		}
 		st.inflow.take(int32(f.Length))
 		st.inflow.take(int32(f.Length))
 
 
 		if len(data) > 0 {
 		if len(data) > 0 {
 			wrote, err := st.body.Write(data)
 			wrote, err := st.body.Write(data)
 			if err != nil {
 			if err != nil {
-				return StreamError{id, ErrCodeStreamClosed}
+				return streamError(id, ErrCodeStreamClosed)
 			}
 			}
 			if wrote != len(data) {
 			if wrote != len(data) {
 				panic("internal error: bad Writer")
 				panic("internal error: bad Writer")
@@ -1446,14 +1446,14 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
 		// REFUSED_STREAM."
 		// REFUSED_STREAM."
 		if sc.unackedSettings == 0 {
 		if sc.unackedSettings == 0 {
 			// They should know better.
 			// They should know better.
-			return StreamError{st.id, ErrCodeProtocol}
+			return streamError(st.id, ErrCodeProtocol)
 		}
 		}
 		// Assume it's a network race, where they just haven't
 		// Assume it's a network race, where they just haven't
 		// received our last SETTINGS update. But actually
 		// received our last SETTINGS update. But actually
 		// this can't happen yet, because we don't yet provide
 		// this can't happen yet, because we don't yet provide
 		// a way for users to adjust server parameters at
 		// a way for users to adjust server parameters at
 		// runtime.
 		// runtime.
-		return StreamError{st.id, ErrCodeRefusedStream}
+		return streamError(st.id, ErrCodeRefusedStream)
 	}
 	}
 
 
 	rw, req, err := sc.newWriterAndRequest(st, f)
 	rw, req, err := sc.newWriterAndRequest(st, f)
@@ -1487,11 +1487,11 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
 	}
 	}
 	st.gotTrailerHeader = true
 	st.gotTrailerHeader = true
 	if !f.StreamEnded() {
 	if !f.StreamEnded() {
-		return StreamError{st.id, ErrCodeProtocol}
+		return streamError(st.id, ErrCodeProtocol)
 	}
 	}
 
 
 	if len(f.PseudoFields()) > 0 {
 	if len(f.PseudoFields()) > 0 {
-		return StreamError{st.id, ErrCodeProtocol}
+		return streamError(st.id, ErrCodeProtocol)
 	}
 	}
 	if st.trailer != nil {
 	if st.trailer != nil {
 		for _, hf := range f.RegularFields() {
 		for _, hf := range f.RegularFields() {
@@ -1500,7 +1500,7 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
 				// TODO: send more details to the peer somehow. But http2 has
 				// TODO: send more details to the peer somehow. But http2 has
 				// no way to send debug data at a stream level. Discuss with
 				// no way to send debug data at a stream level. Discuss with
 				// HTTP folk.
 				// HTTP folk.
-				return StreamError{st.id, ErrCodeProtocol}
+				return streamError(st.id, ErrCodeProtocol)
 			}
 			}
 			st.trailer[key] = append(st.trailer[key], hf.Value)
 			st.trailer[key] = append(st.trailer[key], hf.Value)
 		}
 		}
@@ -1561,7 +1561,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
 	isConnect := method == "CONNECT"
 	isConnect := method == "CONNECT"
 	if isConnect {
 	if isConnect {
 		if path != "" || scheme != "" || authority == "" {
 		if path != "" || scheme != "" || authority == "" {
-			return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
+			return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
 		}
 		}
 	} else if method == "" || path == "" ||
 	} else if method == "" || path == "" ||
 		(scheme != "https" && scheme != "http") {
 		(scheme != "https" && scheme != "http") {
@@ -1575,13 +1575,13 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
 		// "All HTTP/2 requests MUST include exactly one valid
 		// "All HTTP/2 requests MUST include exactly one valid
 		// value for the :method, :scheme, and :path
 		// value for the :method, :scheme, and :path
 		// pseudo-header fields"
 		// pseudo-header fields"
-		return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
+		return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
 	}
 	}
 
 
 	bodyOpen := !f.StreamEnded()
 	bodyOpen := !f.StreamEnded()
 	if method == "HEAD" && bodyOpen {
 	if method == "HEAD" && bodyOpen {
 		// HEAD requests can't have bodies
 		// HEAD requests can't have bodies
-		return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
+		return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
 	}
 	}
 	var tlsState *tls.ConnectionState // nil if not scheme https
 	var tlsState *tls.ConnectionState // nil if not scheme https
 
 
@@ -1639,7 +1639,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
 		var err error
 		var err error
 		url_, err = url.ParseRequestURI(path)
 		url_, err = url.ParseRequestURI(path)
 		if err != nil {
 		if err != nil {
-			return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
+			return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
 		}
 		}
 		requestURI = path
 		requestURI = path
 	}
 	}

+ 16 - 22
http2/server_test.go

@@ -55,11 +55,6 @@ type serverTester struct {
 	// writing headers:
 	// writing headers:
 	headerBuf bytes.Buffer
 	headerBuf bytes.Buffer
 	hpackEnc  *hpack.Encoder
 	hpackEnc  *hpack.Encoder
-
-	// reading frames:
-	frc       chan Frame
-	frErrc    chan error
-	readTimer *time.Timer
 }
 }
 
 
 func init() {
 func init() {
@@ -117,8 +112,6 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
 		t:      t,
 		t:      t,
 		ts:     ts,
 		ts:     ts,
 		logBuf: logBuf,
 		logBuf: logBuf,
-		frc:    make(chan Frame, 1),
-		frErrc: make(chan error, 1),
 	}
 	}
 	st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
 	st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
 	st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
 	st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
@@ -365,32 +358,33 @@ func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, p
 	}
 	}
 }
 }
 
 
-func (st *serverTester) readFrame() (Frame, error) {
+func readFrameTimeout(fr *Framer, wait time.Duration) (Frame, error) {
+	ch := make(chan interface{}, 1)
 	go func() {
 	go func() {
-		fr, err := st.fr.ReadFrame()
+		fr, err := fr.ReadFrame()
 		if err != nil {
 		if err != nil {
-			st.frErrc <- err
+			ch <- err
 		} else {
 		} else {
-			st.frc <- fr
+			ch <- fr
 		}
 		}
 	}()
 	}()
-	t := st.readTimer
-	if t == nil {
-		t = time.NewTimer(2 * time.Second)
-		st.readTimer = t
-	}
-	t.Reset(2 * time.Second)
-	defer t.Stop()
+	t := time.NewTimer(wait)
 	select {
 	select {
-	case f := <-st.frc:
-		return f, nil
-	case err := <-st.frErrc:
-		return nil, err
+	case v := <-ch:
+		t.Stop()
+		if fr, ok := v.(Frame); ok {
+			return fr, nil
+		}
+		return nil, v.(error)
 	case <-t.C:
 	case <-t.C:
 		return nil, errors.New("timeout waiting for frame")
 		return nil, errors.New("timeout waiting for frame")
 	}
 	}
 }
 }
 
 
+func (st *serverTester) readFrame() (Frame, error) {
+	return readFrameTimeout(st.fr, 2*time.Second)
+}
+
 func (st *serverTester) wantHeaders() *HeadersFrame {
 func (st *serverTester) wantHeaders() *HeadersFrame {
 	f, err := st.readFrame()
 	f, err := st.readFrame()
 	if err != nil {
 	if err != nil {

+ 11 - 2
http2/transport.go

@@ -1229,7 +1229,11 @@ func (rl *clientConnReadLoop) run() error {
 		}
 		}
 		if se, ok := err.(StreamError); ok {
 		if se, ok := err.(StreamError); ok {
 			if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil {
 			if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil {
-				rl.endStreamError(cs, cc.fr.errDetail)
+				cs.cc.writeStreamReset(cs.ID, se.Code, err)
+				if se.Cause == nil {
+					se.Cause = cc.fr.errDetail
+				}
+				rl.endStreamError(cs, se)
 			}
 			}
 			continue
 			continue
 		} else if err != nil {
 		} else if err != nil {
@@ -1639,6 +1643,11 @@ func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) {
 	if isConnectionCloseRequest(cs.req) {
 	if isConnectionCloseRequest(cs.req) {
 		rl.closeWhenIdle = true
 		rl.closeWhenIdle = true
 	}
 	}
+
+	select {
+	case cs.resc <- resAndError{err: err}:
+	default:
+	}
 }
 }
 
 
 func (cs *clientStream) copyTrailers() {
 func (cs *clientStream) copyTrailers() {
@@ -1740,7 +1749,7 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
 		// which closes this, so there
 		// which closes this, so there
 		// isn't a race.
 		// isn't a race.
 	default:
 	default:
-		err := StreamError{cs.ID, f.ErrCode}
+		err := streamError(cs.ID, f.ErrCode)
 		cs.resetErr = err
 		cs.resetErr = err
 		close(cs.peerReset)
 		close(cs.peerReset)
 		cs.bufPipe.CloseWithError(err)
 		cs.bufPipe.CloseWithError(err)

+ 81 - 2
http2/transport_test.go

@@ -699,6 +699,28 @@ func (ct *clientTester) start(which string, errc chan<- error, fn func() error)
 	}()
 	}()
 }
 }
 
 
+func (ct *clientTester) readFrame() (Frame, error) {
+	return readFrameTimeout(ct.fr, 2*time.Second)
+}
+
+func (ct *clientTester) firstHeaders() (*HeadersFrame, error) {
+	for {
+		f, err := ct.readFrame()
+		if err != nil {
+			return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
+		}
+		switch f.(type) {
+		case *WindowUpdateFrame, *SettingsFrame:
+			continue
+		}
+		hf, ok := f.(*HeadersFrame)
+		if !ok {
+			return nil, fmt.Errorf("Got %T; want HeadersFrame", f)
+		}
+		return hf, nil
+	}
+}
+
 type countingReader struct {
 type countingReader struct {
 	n *int64
 	n *int64
 }
 }
@@ -1224,8 +1246,9 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT
 			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
 			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
 		}
 		}
 		slurp, err := ioutil.ReadAll(res.Body)
 		slurp, err := ioutil.ReadAll(res.Body)
-		if err != wantErr {
-			return fmt.Errorf("res.Body ReadAll error = %q, %#v; want %T of %#v", slurp, err, wantErr, wantErr)
+		se, ok := err.(StreamError)
+		if !ok || se.Cause != wantErr {
+			return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
 		}
 		}
 		if len(slurp) > 0 {
 		if len(slurp) > 0 {
 			return fmt.Errorf("body = %q; want nothing", slurp)
 			return fmt.Errorf("body = %q; want nothing", slurp)
@@ -2278,3 +2301,59 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
 	}
 	}
 	ct.run()
 	ct.run()
 }
 }
+
+// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a
+// StreamError as a result of the response HEADERS
+func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
+	ct := newClientTester(t)
+
+	ct.client = func() error {
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		res, err := ct.tr.RoundTrip(req)
+		if err == nil {
+			res.Body.Close()
+			return errors.New("unexpected successful GET")
+		}
+		want := StreamError{1, ErrCodeProtocol, headerFieldNameError("  content-type")}
+		if !reflect.DeepEqual(want, err) {
+			t.Errorf("RoundTrip error = %#v; want %#v", err, want)
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+
+		hf, err := ct.firstHeaders()
+		if err != nil {
+			return err
+		}
+
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+		enc.WriteField(hpack.HeaderField{Name: "  content-type", Value: "bogus"}) // bogus spaces
+		ct.fr.WriteHeaders(HeadersFrameParam{
+			StreamID:      hf.StreamID,
+			EndHeaders:    true,
+			EndStream:     false,
+			BlockFragment: buf.Bytes(),
+		})
+
+		for {
+			fr, err := ct.readFrame()
+			if err != nil {
+				return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
+			}
+			if _, ok := fr.(*SettingsFrame); ok {
+				continue
+			}
+			if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
+				t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
+			}
+			break
+		}
+
+		return nil
+	}
+	ct.run()
+}