浏览代码

http2: swallow io.EOF while reading body and flow fix

This commit fixes two bugs.

The first bug returned io.EOF when a zero bytes were read from the
request body.

The second bug was a hang where the Transport waited for more flow
tokens than initialWindowSize BEFORE sending the first data frame which
never gave the server a chance to send flow tokens, so the client never
got enough to unblock awaitFlowControl. This commit changes
awaitFlowControl to wait for for [1,max] tokens, where max is the length
of the scratch buffer.

Change-Id: Ibbac0a38cd672535917a38330998d3b48d46f5f1
Reviewed-on: https://go-review.googlesource.com/16411
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Blake Mizerany 10 年之前
父节点
当前提交
1b27761f1c
共有 2 个文件被更改,包括 60 次插入32 次删除
  1. 23 13
      http2/transport.go
  2. 37 19
      http2/transport_test.go

+ 23 - 13
http2/transport.go

@@ -545,15 +545,18 @@ func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan st
 	done := false
 	for !done {
 		buf := cc.frameScratchBuffer()
-		n, err := io.ReadFull(body, buf)
-		if err == io.ErrUnexpectedEOF {
-			done = true
-		} else if err != nil {
+
+		taken, err := cs.awaitFlowControl(int32(len(buf)))
+		if err != nil {
 			return err
 		}
 
-		// Await for n flow control tokens.
-		if err := cs.awaitFlowControl(int32(n)); err != nil {
+		n, err := io.ReadFull(body, buf[:taken])
+		if err == io.ErrUnexpectedEOF {
+			done = true
+		} else if err == io.EOF {
+			break
+		} else if err != nil {
 			return err
 		}
 
@@ -567,8 +570,8 @@ func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan st
 			err = cc.fr.WriteData(cs.ID, done, buf[:n])
 		}
 		cc.wmu.Unlock()
-
 		cc.putFrameScratchBuffer(buf)
+
 		if err != nil {
 			return err
 		}
@@ -588,20 +591,27 @@ func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan st
 	return err
 }
 
-func (cs *clientStream) awaitFlowControl(n int32) error {
+// awaitFlowControl waits for [1,max] flow control tokens from the server. It
+// returns either the non-zero number of tokens taken or an error if the stream
+// is dead.
+func (cs *clientStream) awaitFlowControl(max int32) (taken int32, err error) {
 	cc := cs.cc
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	for {
 		if cc.closed {
-			return errClientConnClosed
+			return 0, errClientConnClosed
 		}
 		if err := cs.checkReset(); err != nil {
-			return err
+			return 0, err
 		}
-		if cs.flow.available() >= n {
-			cs.flow.take(n)
-			return nil
+		if a := cs.flow.available(); a > 0 {
+			take := a
+			if take > max {
+				take = max
+			}
+			cs.flow.take(take)
+			return take, nil
 		}
 		cc.cond.Wait()
 	}

+ 37 - 19
http2/transport_test.go

@@ -23,7 +23,7 @@ import (
 var (
 	extNet        = flag.Bool("extnet", false, "do external network tests")
 	transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
-	insecure      = flag.Bool("insecure", false, "insecure TLS dials")
+	insecure      = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove?
 )
 
 var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
@@ -206,6 +206,18 @@ func TestTransportPath(t *testing.T) {
 	}
 }
 
+var bodyTests = []struct {
+	body         string
+	noContentLen bool
+}{
+	{body: "some message"},
+	{body: "some message", noContentLen: true},
+	{body: ""},
+	{body: "", noContentLen: true},
+	{body: strings.Repeat("a", 1<<20), noContentLen: true},
+	{body: strings.Repeat("a", 1<<20)},
+}
+
 func TestTransportBody(t *testing.T) {
 	gotc := make(chan interface{}, 1)
 	st := newServerTester(t,
@@ -222,24 +234,30 @@ func TestTransportBody(t *testing.T) {
 	)
 	defer st.Close()
 
-	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
-	defer tr.CloseIdleConnections()
-	const body = "Some message"
-	req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(body))
-	if err != nil {
-		t.Fatal(err)
-	}
-	c := &http.Client{Transport: tr}
-	res, err := c.Do(req)
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer res.Body.Close()
-	got := <-gotc
-	if err, ok := got.(error); ok {
-		t.Fatal(err)
-	} else if got.(string) != body {
-		t.Errorf("Read body = %q; want %q", got, body)
+	for i, tt := range bodyTests {
+		tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+		defer tr.CloseIdleConnections()
+
+		var body io.Reader = strings.NewReader(tt.body)
+		if tt.noContentLen {
+			body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods
+		}
+		req, err := http.NewRequest("POST", st.ts.URL, body)
+		if err != nil {
+			t.Fatalf("#%d: %v", i, err)
+		}
+		c := &http.Client{Transport: tr}
+		res, err := c.Do(req)
+		if err != nil {
+			t.Fatalf("#%d: %v", i, err)
+		}
+		defer res.Body.Close()
+		got := <-gotc
+		if err, ok := got.(error); ok {
+			t.Fatalf("#%d: %v", i, err)
+		} else if got.(string) != tt.body {
+			t.Errorf("#%d: Read body = %q; want %q", i, got, tt.body)
+		}
 	}
 }