|
|
@@ -2745,3 +2745,127 @@ func TestTransportCancelDataResponseRace(t *testing.T) {
|
|
|
t.Errorf("Got = %q; want %q", slurp, msg)
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestTransportRetryAfterGOAWAY(t *testing.T) {
|
|
|
+ t.Skip("to be unskipped by https://go-review.googlesource.com/c/33971/")
|
|
|
+ var dialer struct {
|
|
|
+ sync.Mutex
|
|
|
+ count int
|
|
|
+ }
|
|
|
+ ct1 := make(chan *clientTester)
|
|
|
+ ct2 := make(chan *clientTester)
|
|
|
+
|
|
|
+ ln := newLocalListener(t)
|
|
|
+ defer ln.Close()
|
|
|
+
|
|
|
+ tr := &Transport{
|
|
|
+ TLSClientConfig: tlsConfigInsecure,
|
|
|
+ }
|
|
|
+ tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
|
|
+ dialer.Lock()
|
|
|
+ defer dialer.Unlock()
|
|
|
+ dialer.count++
|
|
|
+ cc, err := net.Dial("tcp", ln.Addr().String())
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("dial error: %v", err)
|
|
|
+ }
|
|
|
+ sc, err := ln.Accept()
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("accept error: %v", err)
|
|
|
+ }
|
|
|
+ ct := &clientTester{
|
|
|
+ t: t,
|
|
|
+ tr: tr,
|
|
|
+ cc: cc,
|
|
|
+ sc: sc,
|
|
|
+ fr: NewFramer(sc, sc),
|
|
|
+ }
|
|
|
+ switch dialer.count {
|
|
|
+ case 1:
|
|
|
+ ct1 <- ct
|
|
|
+ case 2:
|
|
|
+ ct2 <- ct
|
|
|
+ }
|
|
|
+ return cc, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ errs := make(chan error, 3)
|
|
|
+ done := make(chan struct{})
|
|
|
+ defer close(done)
|
|
|
+
|
|
|
+ // Client.
|
|
|
+ go func() {
|
|
|
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
|
|
|
+ res, err := tr.RoundTrip(req)
|
|
|
+ t.Logf("client got %T, %v", res, err)
|
|
|
+ errs <- err
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Server for the first request.
|
|
|
+ go func() {
|
|
|
+ var ct *clientTester
|
|
|
+ select {
|
|
|
+ case ct = <-ct1:
|
|
|
+ case <-done:
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ ct.greet()
|
|
|
+ hf, err := ct.firstHeaders()
|
|
|
+ if err != nil {
|
|
|
+ errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ t.Logf("server1 got %v", hf)
|
|
|
+ if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
|
|
|
+ errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ ct.cc.(*net.TCPConn).Close()
|
|
|
+ errs <- nil
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Server for the second request.
|
|
|
+ go func() {
|
|
|
+ var ct *clientTester
|
|
|
+ select {
|
|
|
+ case ct = <-ct2:
|
|
|
+ case <-done:
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ ct.greet()
|
|
|
+ hf, err := ct.firstHeaders()
|
|
|
+ if err != nil {
|
|
|
+ errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ t.Logf("server2 Got %v", hf)
|
|
|
+
|
|
|
+ var buf bytes.Buffer
|
|
|
+ enc := hpack.NewEncoder(&buf)
|
|
|
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
|
|
|
+ err = ct.fr.WriteHeaders(HeadersFrameParam{
|
|
|
+ StreamID: hf.StreamID,
|
|
|
+ EndHeaders: true,
|
|
|
+ EndStream: false,
|
|
|
+ BlockFragment: buf.Bytes(),
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ errs <- fmt.Errorf("server2 failed writin responseg HEADERS: %v", err)
|
|
|
+ } else {
|
|
|
+ errs <- nil
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ for k := 0; k < 3; k++ {
|
|
|
+ select {
|
|
|
+ case err := <-errs:
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ case <-time.After(1 * time.Second):
|
|
|
+ t.Errorf("timed out")
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|