|
|
@@ -4180,3 +4180,99 @@ func TestNoDialH2RoundTripperType(t *testing.T) {
|
|
|
t.Fatalf("wrong kind %T; want *Transport", v.Interface())
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+type errReader struct {
|
|
|
+ body []byte
|
|
|
+ err error
|
|
|
+}
|
|
|
+
|
|
|
+func (r *errReader) Read(p []byte) (int, error) {
|
|
|
+ if len(r.body) > 0 {
|
|
|
+ n := copy(p, r.body)
|
|
|
+ r.body = r.body[n:]
|
|
|
+ return n, nil
|
|
|
+ }
|
|
|
+ return 0, r.err
|
|
|
+}
|
|
|
+
|
|
|
+func testTransportBodyReadError(t *testing.T, body []byte) {
|
|
|
+ clientDone := make(chan struct{})
|
|
|
+ ct := newClientTester(t)
|
|
|
+ ct.client = func() error {
|
|
|
+ defer ct.cc.(*net.TCPConn).CloseWrite()
|
|
|
+ defer close(clientDone)
|
|
|
+
|
|
|
+ checkNoStreams := func() error {
|
|
|
+ cp, ok := ct.tr.connPool().(*clientConnPool)
|
|
|
+ if !ok {
|
|
|
+ return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
|
|
|
+ }
|
|
|
+ cp.mu.Lock()
|
|
|
+ defer cp.mu.Unlock()
|
|
|
+ conns, ok := cp.conns["dummy.tld:443"]
|
|
|
+ if !ok {
|
|
|
+ return fmt.Errorf("missing connection")
|
|
|
+ }
|
|
|
+ if len(conns) != 1 {
|
|
|
+ return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
|
|
|
+ }
|
|
|
+ if activeStreams(conns[0]) != 0 {
|
|
|
+ return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ bodyReadError := errors.New("body read error")
|
|
|
+ body := &errReader{body, bodyReadError}
|
|
|
+ req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ _, err = ct.tr.RoundTrip(req)
|
|
|
+ if err != bodyReadError {
|
|
|
+ return fmt.Errorf("err = %v; want %v", err, bodyReadError)
|
|
|
+ }
|
|
|
+ if err = checkNoStreams(); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ ct.server = func() error {
|
|
|
+ ct.greet()
|
|
|
+ var receivedBody []byte
|
|
|
+ var resetCount int
|
|
|
+ for {
|
|
|
+ f, err := ct.fr.ReadFrame()
|
|
|
+ if err != nil {
|
|
|
+ select {
|
|
|
+ case <-clientDone:
|
|
|
+ // If the client's done, it
|
|
|
+ // will have reported any
|
|
|
+ // errors on its side.
|
|
|
+ if bytes.Compare(receivedBody, body) != 0 {
|
|
|
+ return fmt.Errorf("body: %v; expected %v", receivedBody, body)
|
|
|
+ }
|
|
|
+ if resetCount != 1 {
|
|
|
+ return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ default:
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ switch f := f.(type) {
|
|
|
+ case *WindowUpdateFrame, *SettingsFrame:
|
|
|
+ case *HeadersFrame:
|
|
|
+ case *DataFrame:
|
|
|
+ receivedBody = append(receivedBody, f.Data()...)
|
|
|
+ case *RSTStreamFrame:
|
|
|
+ resetCount++
|
|
|
+ default:
|
|
|
+ return fmt.Errorf("Unexpected client frame %v", f)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ct.run()
|
|
|
+}
|
|
|
+
|
|
|
+func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
|
|
|
+func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) }
|