|
|
@@ -300,22 +300,61 @@ func (cc *clientConn) streamByID(id uint32) *clientStream {
|
|
|
func (cc *clientConn) readLoop() {
|
|
|
defer close(cc.readerDone)
|
|
|
|
|
|
+ activeRes := map[uint32]*clientStream{} // keyed by streamID
|
|
|
+ // Close any response bodies if the server closes prematurely.
|
|
|
+ // TODO: also do this if we've written the headers but not
|
|
|
+ // gotten a response yet.
|
|
|
+ defer func() {
|
|
|
+ err := cc.readerErr
|
|
|
+ if err == io.EOF {
|
|
|
+ err = io.ErrUnexpectedEOF
|
|
|
+ }
|
|
|
+ for _, cs := range activeRes {
|
|
|
+ cs.pw.CloseWithError(err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // continueStreamID is the stream ID we're waiting for
|
|
|
+ // continuation frames for.
|
|
|
+ var continueStreamID uint32
|
|
|
+
|
|
|
for {
|
|
|
f, err := cc.fr.ReadFrame()
|
|
|
if err != nil {
|
|
|
cc.readerErr = err
|
|
|
- // TODO: don't log it.
|
|
|
- log.Printf("ReadFrame: %v", err)
|
|
|
return
|
|
|
}
|
|
|
- cs := cc.streamByID(f.Header().StreamID)
|
|
|
-
|
|
|
log.Printf("Transport received %v: %#v", f.Header(), f)
|
|
|
+
|
|
|
+ streamID := f.Header().StreamID
|
|
|
+
|
|
|
+ _, isContinue := f.(*ContinuationFrame)
|
|
|
+ if isContinue {
|
|
|
+ if streamID != continueStreamID {
|
|
|
+ cc.readerErr = ConnectionError(ErrCodeProtocol)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ } else if continueStreamID != 0 {
|
|
|
+ // Continue frames need to be adjacent in the stream
|
|
|
+ // and we were in the middle of headers.
|
|
|
+ cc.readerErr = ConnectionError(ErrCodeProtocol)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if streamID%2 == 0 {
|
|
|
+ // Ignore streams pushed from the server for now.
|
|
|
+ // These always have an even stream id.
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ cs := cc.streamByID(streamID)
|
|
|
+ if cs == nil {
|
|
|
+ log.Printf("Received frame for untracked stream ID %d", streamID)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
headersEnded := false
|
|
|
streamEnded := false
|
|
|
- if ff, ok := f.(interface {
|
|
|
- StreamEnded() bool
|
|
|
- }); ok {
|
|
|
+ if ff, ok := f.(streamEnder); ok {
|
|
|
streamEnded = ff.StreamEnded()
|
|
|
}
|
|
|
switch f := f.(type) {
|
|
|
@@ -324,14 +363,11 @@ func (cc *clientConn) readLoop() {
|
|
|
Proto: "HTTP/2.0",
|
|
|
ProtoMajor: 2,
|
|
|
Header: make(http.Header),
|
|
|
- Request: nil, // TODO: set this
|
|
|
- TLS: nil, // TODO: set this
|
|
|
}
|
|
|
cs.pr, cs.pw = io.Pipe()
|
|
|
cc.hdec.Write(f.HeaderBlockFragment())
|
|
|
headersEnded = f.HeadersEnded()
|
|
|
case *ContinuationFrame:
|
|
|
- // TODO: verify stream id is the same
|
|
|
cc.hdec.Write(f.HeaderBlockFragment())
|
|
|
headersEnded = f.HeadersEnded()
|
|
|
case *DataFrame:
|
|
|
@@ -339,15 +375,24 @@ func (cc *clientConn) readLoop() {
|
|
|
cs.pw.Write(f.Data())
|
|
|
default:
|
|
|
}
|
|
|
+ if headersEnded {
|
|
|
+ continueStreamID = 0
|
|
|
+ } else {
|
|
|
+ continueStreamID = streamID
|
|
|
+ }
|
|
|
+
|
|
|
if streamEnded {
|
|
|
cs.pw.Close()
|
|
|
+ delete(activeRes, streamID)
|
|
|
}
|
|
|
if headersEnded {
|
|
|
if cs == nil {
|
|
|
panic("couldn't find stream") // TODO be graceful
|
|
|
}
|
|
|
cc.nextRes.Body = cs.pr
|
|
|
- cs.resc <- resAndError{res: cc.nextRes}
|
|
|
+ res := cc.nextRes
|
|
|
+ activeRes[streamID] = cs
|
|
|
+ cs.resc <- resAndError{res: res}
|
|
|
}
|
|
|
}
|
|
|
}
|