Просмотр исходного кода

transport: cleanups, CONTINUATION strictness, track active requests per conn

Brad Fitzpatrick 11 лет назад
Родитель
Сommit
7bb8b7cf75
2 измененных файлов с 60 добавлено и 11 удалено
  1. 4 0
      frame.go
  2. 56 11
      transport.go

+ 4 - 0
frame.go

@@ -1103,3 +1103,7 @@ func readUint32(p []byte) (remain []byte, v uint32, err error) {
 	}
 	}
 	return p[4:], binary.BigEndian.Uint32(p[:4]), nil
 	return p[4:], binary.BigEndian.Uint32(p[:4]), nil
 }
 }
+
+type streamEnder interface {
+	StreamEnded() bool
+}

+ 56 - 11
transport.go

@@ -300,22 +300,61 @@ func (cc *clientConn) streamByID(id uint32) *clientStream {
 func (cc *clientConn) readLoop() {
 func (cc *clientConn) readLoop() {
 	defer close(cc.readerDone)
 	defer close(cc.readerDone)
 
 
+	activeRes := map[uint32]*clientStream{} // keyed by streamID
+	// Close any response bodies if the server closes prematurely.
+	// TODO: also do this if we've written the headers but not
+	// gotten a response yet.
+	defer func() {
+		err := cc.readerErr
+		if err == io.EOF {
+			err = io.ErrUnexpectedEOF
+		}
+		for _, cs := range activeRes {
+			cs.pw.CloseWithError(err)
+		}
+	}()
+
+	// continueStreamID is the stream ID we're waiting for
+	// continuation frames for.
+	var continueStreamID uint32
+
 	for {
 	for {
 		f, err := cc.fr.ReadFrame()
 		f, err := cc.fr.ReadFrame()
 		if err != nil {
 		if err != nil {
 			cc.readerErr = err
 			cc.readerErr = err
-			// TODO: don't log it.
-			log.Printf("ReadFrame: %v", err)
 			return
 			return
 		}
 		}
-		cs := cc.streamByID(f.Header().StreamID)
-
 		log.Printf("Transport received %v: %#v", f.Header(), f)
 		log.Printf("Transport received %v: %#v", f.Header(), f)
+
+		streamID := f.Header().StreamID
+
+		_, isContinue := f.(*ContinuationFrame)
+		if isContinue {
+			if streamID != continueStreamID {
+				cc.readerErr = ConnectionError(ErrCodeProtocol)
+				return
+			}
+		} else if continueStreamID != 0 {
+			// Continue frames need to be adjacent in the stream
+			// and we were in the middle of headers.
+			cc.readerErr = ConnectionError(ErrCodeProtocol)
+			return
+		}
+
+		if streamID%2 == 0 {
+			// Ignore streams pushed from the server for now.
+			// These always have an even stream id.
+			continue
+		}
+		cs := cc.streamByID(streamID)
+		if cs == nil {
+			log.Printf("Received frame for untracked stream ID %d", streamID)
+			continue
+		}
+
 		headersEnded := false
 		headersEnded := false
 		streamEnded := false
 		streamEnded := false
-		if ff, ok := f.(interface {
-			StreamEnded() bool
-		}); ok {
+		if ff, ok := f.(streamEnder); ok {
 			streamEnded = ff.StreamEnded()
 			streamEnded = ff.StreamEnded()
 		}
 		}
 		switch f := f.(type) {
 		switch f := f.(type) {
@@ -324,14 +363,11 @@ func (cc *clientConn) readLoop() {
 				Proto:      "HTTP/2.0",
 				Proto:      "HTTP/2.0",
 				ProtoMajor: 2,
 				ProtoMajor: 2,
 				Header:     make(http.Header),
 				Header:     make(http.Header),
-				Request:    nil, // TODO: set this
-				TLS:        nil, // TODO: set this
 			}
 			}
 			cs.pr, cs.pw = io.Pipe()
 			cs.pr, cs.pw = io.Pipe()
 			cc.hdec.Write(f.HeaderBlockFragment())
 			cc.hdec.Write(f.HeaderBlockFragment())
 			headersEnded = f.HeadersEnded()
 			headersEnded = f.HeadersEnded()
 		case *ContinuationFrame:
 		case *ContinuationFrame:
-			// TODO: verify stream id is the same
 			cc.hdec.Write(f.HeaderBlockFragment())
 			cc.hdec.Write(f.HeaderBlockFragment())
 			headersEnded = f.HeadersEnded()
 			headersEnded = f.HeadersEnded()
 		case *DataFrame:
 		case *DataFrame:
@@ -339,15 +375,24 @@ func (cc *clientConn) readLoop() {
 			cs.pw.Write(f.Data())
 			cs.pw.Write(f.Data())
 		default:
 		default:
 		}
 		}
+		if headersEnded {
+			continueStreamID = 0
+		} else {
+			continueStreamID = streamID
+		}
+
 		if streamEnded {
 		if streamEnded {
 			cs.pw.Close()
 			cs.pw.Close()
+			delete(activeRes, streamID)
 		}
 		}
 		if headersEnded {
 		if headersEnded {
 			if cs == nil {
 			if cs == nil {
 				panic("couldn't find stream") // TODO be graceful
 				panic("couldn't find stream") // TODO be graceful
 			}
 			}
 			cc.nextRes.Body = cs.pr
 			cc.nextRes.Body = cs.pr
-			cs.resc <- resAndError{res: cc.nextRes}
+			res := cc.nextRes
+			activeRes[streamID] = cs
+			cs.resc <- resAndError{res: res}
 		}
 		}
 	}
 	}
 }
 }