瀏覽代碼

http2: test for retry after GOAWAY

Updates #golang/go#18083

Change-Id: If0ebbbb9b5644a4cec94c89e28939fcaeb4de321
Reviewed-on: https://go-review.googlesource.com/33972
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Tom Bergan 9 年之前
父節點
當前提交
4a1c855a05
共有 1 個文件被更改,包括 124 次插入0 次删除
  1. 124 0
      http2/transport_test.go

+ 124 - 0
http2/transport_test.go

@@ -2745,3 +2745,127 @@ func TestTransportCancelDataResponseRace(t *testing.T) {
 		t.Errorf("Got = %q; want %q", slurp, msg)
 	}
 }
+
+func TestTransportRetryAfterGOAWAY(t *testing.T) {
+	t.Skip("to be unskipped by https://go-review.googlesource.com/c/33971/")
+	var dialer struct {
+		sync.Mutex
+		count int
+	}
+	ct1 := make(chan *clientTester)
+	ct2 := make(chan *clientTester)
+
+	ln := newLocalListener(t)
+	defer ln.Close()
+
+	tr := &Transport{
+		TLSClientConfig: tlsConfigInsecure,
+	}
+	tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+		dialer.Lock()
+		defer dialer.Unlock()
+		dialer.count++
+		cc, err := net.Dial("tcp", ln.Addr().String())
+		if err != nil {
+			return nil, fmt.Errorf("dial error: %v", err)
+		}
+		sc, err := ln.Accept()
+		if err != nil {
+			return nil, fmt.Errorf("accept error: %v", err)
+		}
+		ct := &clientTester{
+			t:  t,
+			tr: tr,
+			cc: cc,
+			sc: sc,
+			fr: NewFramer(sc, sc),
+		}
+		switch dialer.count {
+		case 1:
+			ct1 <- ct
+		case 2:
+			ct2 <- ct
+		}
+		return cc, nil
+	}
+
+	errs := make(chan error, 3)
+	done := make(chan struct{})
+	defer close(done)
+
+	// Client.
+	go func() {
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		res, err := tr.RoundTrip(req)
+		t.Logf("client got %T, %v", res, err)
+		errs <- err
+	}()
+
+	// Server for the first request.
+	go func() {
+		var ct *clientTester
+		select {
+		case ct = <-ct1:
+		case <-done:
+			return
+		}
+
+		ct.greet()
+		hf, err := ct.firstHeaders()
+		if err != nil {
+			errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err)
+			return
+		}
+		t.Logf("server1 got %v", hf)
+		if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
+			errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
+			return
+		}
+		ct.cc.(*net.TCPConn).Close()
+		errs <- nil
+	}()
+
+	// Server for the second request.
+	go func() {
+		var ct *clientTester
+		select {
+		case ct = <-ct2:
+		case <-done:
+			return
+		}
+
+		ct.greet()
+		hf, err := ct.firstHeaders()
+		if err != nil {
+			errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
+			return
+		}
+		t.Logf("server2 Got %v", hf)
+
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+		err = ct.fr.WriteHeaders(HeadersFrameParam{
+			StreamID:      hf.StreamID,
+			EndHeaders:    true,
+			EndStream:     false,
+			BlockFragment: buf.Bytes(),
+		})
+		if err != nil {
+			errs <- fmt.Errorf("server2 failed writin responseg HEADERS: %v", err)
+		} else {
+			errs <- nil
+		}
+	}()
+
+	for k := 0; k < 3; k++ {
+		select {
+		case err := <-errs:
+			if err != nil {
+				t.Error(err)
+			}
+		case <-time.After(1 * time.Second):
+			t.Errorf("timed out")
+		}
+	}
+}