|
|
@@ -3309,6 +3309,166 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
|
|
|
req.Header = http.Header{}
|
|
|
}
|
|
|
|
|
|
+func TestTransportCloseAfterLostPing(t *testing.T) {
|
|
|
+ clientDone := make(chan struct{})
|
|
|
+ ct := newClientTester(t)
|
|
|
+ ct.tr.PingTimeout = 1 * time.Second
|
|
|
+ ct.tr.ReadIdleTimeout = 1 * time.Second
|
|
|
+ ct.client = func() error {
|
|
|
+ defer ct.cc.(*net.TCPConn).CloseWrite()
|
|
|
+ defer close(clientDone)
|
|
|
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
|
|
|
+ _, err := ct.tr.RoundTrip(req)
|
|
|
+ if err == nil || !strings.Contains(err.Error(), "client connection lost") {
|
|
|
+ return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ ct.server = func() error {
|
|
|
+ ct.greet()
|
|
|
+ <-clientDone
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ ct.run()
|
|
|
+}
|
|
|
+
|
|
|
+func TestTransportPingWhenReading(t *testing.T) {
|
|
|
+ testCases := []struct {
|
|
|
+ name string
|
|
|
+ readIdleTimeout time.Duration
|
|
|
+ serverResponseInterval time.Duration
|
|
|
+ expectedPingCount int
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "two pings in each serverResponseInterval",
|
|
|
+ readIdleTimeout: 400 * time.Millisecond,
|
|
|
+ serverResponseInterval: 1000 * time.Millisecond,
|
|
|
+ expectedPingCount: 4,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "one ping in each serverResponseInterval",
|
|
|
+ readIdleTimeout: 700 * time.Millisecond,
|
|
|
+ serverResponseInterval: 1000 * time.Millisecond,
|
|
|
+ expectedPingCount: 2,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "zero ping in each serverResponseInterval",
|
|
|
+ readIdleTimeout: 1000 * time.Millisecond,
|
|
|
+ serverResponseInterval: 500 * time.Millisecond,
|
|
|
+ expectedPingCount: 0,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "0 readIdleTimeout means no ping",
|
|
|
+ readIdleTimeout: 0 * time.Millisecond,
|
|
|
+ serverResponseInterval: 500 * time.Millisecond,
|
|
|
+ expectedPingCount: 0,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tc := range testCases {
|
|
|
+ tc := tc // capture range variable
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ testTransportPingWhenReading(t, tc.readIdleTimeout, tc.serverResponseInterval, tc.expectedPingCount)
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration, expectedPingCount int) {
|
|
|
+ var pingCount int
|
|
|
+ clientDone := make(chan struct{})
|
|
|
+ ct := newClientTester(t)
|
|
|
+ ct.tr.PingTimeout = 10 * time.Millisecond
|
|
|
+ ct.tr.ReadIdleTimeout = readIdleTimeout
|
|
|
+ // guards the ct.fr.Write
|
|
|
+ var wmu sync.Mutex
|
|
|
+
|
|
|
+ ct.client = func() error {
|
|
|
+ defer ct.cc.(*net.TCPConn).CloseWrite()
|
|
|
+ defer close(clientDone)
|
|
|
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
|
|
|
+ res, err := ct.tr.RoundTrip(req)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("RoundTrip: %v", err)
|
|
|
+ }
|
|
|
+ defer res.Body.Close()
|
|
|
+ if res.StatusCode != 200 {
|
|
|
+ return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
|
|
|
+ }
|
|
|
+ _, err = ioutil.ReadAll(res.Body)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ ct.server = func() error {
|
|
|
+ ct.greet()
|
|
|
+ var buf bytes.Buffer
|
|
|
+ enc := hpack.NewEncoder(&buf)
|
|
|
+ for {
|
|
|
+ f, err := ct.fr.ReadFrame()
|
|
|
+ if err != nil {
|
|
|
+ select {
|
|
|
+ case <-clientDone:
|
|
|
+ // If the client's done, it
|
|
|
+ // will have reported any
|
|
|
+ // errors on its side.
|
|
|
+ return nil
|
|
|
+ default:
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ switch f := f.(type) {
|
|
|
+ case *WindowUpdateFrame, *SettingsFrame:
|
|
|
+ case *HeadersFrame:
|
|
|
+ if !f.HeadersEnded() {
|
|
|
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
|
|
|
+ }
|
|
|
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
|
|
|
+ ct.fr.WriteHeaders(HeadersFrameParam{
|
|
|
+ StreamID: f.StreamID,
|
|
|
+ EndHeaders: true,
|
|
|
+ EndStream: false,
|
|
|
+ BlockFragment: buf.Bytes(),
|
|
|
+ })
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ for i := 0; i < 2; i++ {
|
|
|
+ wmu.Lock()
|
|
|
+ if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
|
|
|
+ wmu.Unlock()
|
|
|
+ t.Error(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ wmu.Unlock()
|
|
|
+ time.Sleep(serverResponseInterval)
|
|
|
+ }
|
|
|
+ wmu.Lock()
|
|
|
+ if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server data frame")); err != nil {
|
|
|
+ wmu.Unlock()
|
|
|
+ t.Error(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ wmu.Unlock()
|
|
|
+ }()
|
|
|
+ case *PingFrame:
|
|
|
+ pingCount++
|
|
|
+ wmu.Lock()
|
|
|
+ if err := ct.fr.WritePing(true, f.Data); err != nil {
|
|
|
+ wmu.Unlock()
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ wmu.Unlock()
|
|
|
+ default:
|
|
|
+ return fmt.Errorf("Unexpected client frame %v", f)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ct.run()
|
|
|
+ if e, a := expectedPingCount, pingCount; e != a {
|
|
|
+ t.Errorf("expected receiving %d pings, got %d pings", e, a)
|
|
|
+
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func TestTransportRetryAfterGOAWAY(t *testing.T) {
|
|
|
var dialer struct {
|
|
|
sync.Mutex
|