浏览代码

http2: fix Transport's flow control control when writing request bodies

Adapation of Blake's proposed fix in https://golang.org/cl/16463
with a few changes:

-- bug fix (advance the buffer after writing)
-- don't reacquire/release the buffer in the loop. it was done like
   that in case the max frame size changed while writing. Instead, push
   that down into awaitFlowControl since it has to acquire that lock
   anyway. Now it returns between 1 and the lower of how much we read
   and how much we're allowed to write.

This does mean that if we start a request with a max frame size of
32KB, we'll never write larger than 32KB frames until the the next
request (because our scratch buffer we read into is only 32KB), but we
will start writing smaller DATA frames immediately once we see the
peer's SETTINGS frame.

Change-Id: I47fc503062f9602fe448cf7a36fc500e5d6b8ef9
Reviewed-on: https://go-review.googlesource.com/16443
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
Brad Fitzpatrick 10 年之前
父节点
当前提交
24ab552e98
共有 1 个文件被更改,包括 34 次插入25 次删除
  1. 34 25
      http2/transport.go

+ 34 - 25
http2/transport.go

@@ -543,35 +543,40 @@ var errServerResponseBeforeRequestBody = errors.New("http2: server sent response
 func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan struct{}) error {
 func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan struct{}) error {
 	cc := cs.cc
 	cc := cs.cc
 	done := false
 	done := false
-	for !done {
-		buf := cc.frameScratchBuffer()
-
-		taken, err := cs.awaitFlowControl(int32(len(buf)))
-		if err != nil {
-			return err
-		}
+	buf := cc.frameScratchBuffer()
+	defer cc.putFrameScratchBuffer(buf)
 
 
-		n, err := io.ReadFull(body, buf[:taken])
+	for !done {
+		n, err := io.ReadFull(body, buf)
 		if err == io.ErrUnexpectedEOF {
 		if err == io.ErrUnexpectedEOF {
 			done = true
 			done = true
+			err = nil
 		} else if err == io.EOF {
 		} else if err == io.EOF {
 			break
 			break
 		} else if err != nil {
 		} else if err != nil {
 			return err
 			return err
 		}
 		}
 
 
-		cc.wmu.Lock()
-		select {
-		case <-gotResHeaders:
-			err = errServerResponseBeforeRequestBody
-		case <-cs.peerReset:
-			err = cs.resetErr
-		default:
-			err = cc.fr.WriteData(cs.ID, done, buf[:n])
-		}
-		cc.wmu.Unlock()
-		cc.putFrameScratchBuffer(buf)
+		toWrite := buf[:n]
+		for len(toWrite) > 0 && err == nil {
+			var allowed int32
+			allowed, err = cs.awaitFlowControl(int32(len(toWrite)))
+			if err != nil {
+				return err
+			}
 
 
+			cc.wmu.Lock()
+			select {
+			case <-gotResHeaders:
+				err = errServerResponseBeforeRequestBody
+			case <-cs.peerReset:
+				err = cs.resetErr
+			default:
+				err = cc.fr.WriteData(cs.ID, done, toWrite[:allowed])
+				toWrite = toWrite[allowed:]
+			}
+			cc.wmu.Unlock()
+		}
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -591,10 +596,11 @@ func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan st
 	return err
 	return err
 }
 }
 
 
-// awaitFlowControl waits for [1,max] flow control tokens from the server. It
-// returns either the non-zero number of tokens taken or an error if the stream
-// is dead.
-func (cs *clientStream) awaitFlowControl(max int32) (taken int32, err error) {
+// awaitFlowControl waits for [1, min(maxBytes, cc.cs.maxFrameSize)] flow
+// control tokens from the server.
+// It returns either the non-zero number of tokens taken or an error
+// if the stream is dead.
+func (cs *clientStream) awaitFlowControl(maxBytes int32) (taken int32, err error) {
 	cc := cs.cc
 	cc := cs.cc
 	cc.mu.Lock()
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	defer cc.mu.Unlock()
@@ -607,8 +613,11 @@ func (cs *clientStream) awaitFlowControl(max int32) (taken int32, err error) {
 		}
 		}
 		if a := cs.flow.available(); a > 0 {
 		if a := cs.flow.available(); a > 0 {
 			take := a
 			take := a
-			if take > max {
-				take = max
+			if take > maxBytes {
+				take = maxBytes
+			}
+			if take > int32(cc.maxFrameSize) {
+				take = int32(cc.maxFrameSize)
 			}
 			}
 			cs.flow.take(take)
 			cs.flow.take(take)
 			return take, nil
 			return take, nil