transport_test.go 44 KB


  1. // Copyright 2015 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package http2
  5. import (
  6. "bufio"
  7. "bytes"
  8. "crypto/tls"
  9. "errors"
  10. "flag"
  11. "fmt"
  12. "io"
  13. "io/ioutil"
  14. "log"
  15. "math/rand"
  16. "net"
  17. "net/http"
  18. "net/url"
  19. "os"
  20. "reflect"
  21. "sort"
  22. "strconv"
  23. "strings"
  24. "sync"
  25. "sync/atomic"
  26. "testing"
  27. "time"
  28. "golang.org/x/net/http2/hpack"
  29. )
  30. var (
  31. extNet = flag.Bool("extnet", false, "do external network tests")
  32. transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
  33. insecure = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove?
  34. )
  35. var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
  36. func TestTransportExternal(t *testing.T) {
  37. if !*extNet {
  38. t.Skip("skipping external network test")
  39. }
  40. req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
  41. rt := &Transport{TLSClientConfig: tlsConfigInsecure}
  42. res, err := rt.RoundTrip(req)
  43. if err != nil {
  44. t.Fatalf("%v", err)
  45. }
  46. res.Write(os.Stdout)
  47. }
  48. func TestTransport(t *testing.T) {
  49. const body = "sup"
  50. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  51. io.WriteString(w, body)
  52. }, optOnlyServer)
  53. defer st.Close()
  54. tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  55. defer tr.CloseIdleConnections()
  56. req, err := http.NewRequest("GET", st.ts.URL, nil)
  57. if err != nil {
  58. t.Fatal(err)
  59. }
  60. res, err := tr.RoundTrip(req)
  61. if err != nil {
  62. t.Fatal(err)
  63. }
  64. defer res.Body.Close()
  65. t.Logf("Got res: %+v", res)
  66. if g, w := res.StatusCode, 200; g != w {
  67. t.Errorf("StatusCode = %v; want %v", g, w)
  68. }
  69. if g, w := res.Status, "200 OK"; g != w {
  70. t.Errorf("Status = %q; want %q", g, w)
  71. }
  72. wantHeader := http.Header{
  73. "Content-Length": []string{"3"},
  74. "Content-Type": []string{"text/plain; charset=utf-8"},
  75. "Date": []string{"XXX"}, // see cleanDate
  76. }
  77. cleanDate(res)
  78. if !reflect.DeepEqual(res.Header, wantHeader) {
  79. t.Errorf("res Header = %v; want %v", res.Header, wantHeader)
  80. }
  81. if res.Request != req {
  82. t.Errorf("Response.Request = %p; want %p", res.Request, req)
  83. }
  84. if res.TLS == nil {
  85. t.Error("Response.TLS = nil; want non-nil")
  86. }
  87. slurp, err := ioutil.ReadAll(res.Body)
  88. if err != nil {
  89. t.Errorf("Body read: %v", err)
  90. } else if string(slurp) != body {
  91. t.Errorf("Body = %q; want %q", slurp, body)
  92. }
  93. }
  94. func onSameConn(t *testing.T, modReq func(*http.Request)) bool {
  95. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  96. io.WriteString(w, r.RemoteAddr)
  97. }, optOnlyServer)
  98. defer st.Close()
  99. tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  100. defer tr.CloseIdleConnections()
  101. get := func() string {
  102. req, err := http.NewRequest("GET", st.ts.URL, nil)
  103. if err != nil {
  104. t.Fatal(err)
  105. }
  106. modReq(req)
  107. res, err := tr.RoundTrip(req)
  108. if err != nil {
  109. t.Fatal(err)
  110. }
  111. defer res.Body.Close()
  112. slurp, err := ioutil.ReadAll(res.Body)
  113. if err != nil {
  114. t.Fatalf("Body read: %v", err)
  115. }
  116. addr := strings.TrimSpace(string(slurp))
  117. if addr == "" {
  118. t.Fatalf("didn't get an addr in response")
  119. }
  120. return addr
  121. }
  122. first := get()
  123. second := get()
  124. return first == second
  125. }
  126. func TestTransportReusesConns(t *testing.T) {
  127. if !onSameConn(t, func(*http.Request) {}) {
  128. t.Errorf("first and second responses were on different connections")
  129. }
  130. }
  131. func TestTransportReusesConn_RequestClose(t *testing.T) {
  132. if onSameConn(t, func(r *http.Request) { r.Close = true }) {
  133. t.Errorf("first and second responses were not on different connections")
  134. }
  135. }
  136. func TestTransportReusesConn_ConnClose(t *testing.T) {
  137. if onSameConn(t, func(r *http.Request) { r.Header.Set("Connection", "close") }) {
  138. t.Errorf("first and second responses were not on different connections")
  139. }
  140. }
  141. // Tests that the Transport only keeps one pending dial open per destination address.
  142. // https://golang.org/issue/13397
  143. func TestTransportGroupsPendingDials(t *testing.T) {
  144. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  145. io.WriteString(w, r.RemoteAddr)
  146. }, optOnlyServer)
  147. defer st.Close()
  148. tr := &Transport{
  149. TLSClientConfig: tlsConfigInsecure,
  150. }
  151. defer tr.CloseIdleConnections()
  152. var (
  153. mu sync.Mutex
  154. dials = map[string]int{}
  155. )
  156. var wg sync.WaitGroup
  157. for i := 0; i < 10; i++ {
  158. wg.Add(1)
  159. go func() {
  160. defer wg.Done()
  161. req, err := http.NewRequest("GET", st.ts.URL, nil)
  162. if err != nil {
  163. t.Error(err)
  164. return
  165. }
  166. res, err := tr.RoundTrip(req)
  167. if err != nil {
  168. t.Error(err)
  169. return
  170. }
  171. defer res.Body.Close()
  172. slurp, err := ioutil.ReadAll(res.Body)
  173. if err != nil {
  174. t.Errorf("Body read: %v", err)
  175. }
  176. addr := strings.TrimSpace(string(slurp))
  177. if addr == "" {
  178. t.Errorf("didn't get an addr in response")
  179. }
  180. mu.Lock()
  181. dials[addr]++
  182. mu.Unlock()
  183. }()
  184. }
  185. wg.Wait()
  186. if len(dials) != 1 {
  187. t.Errorf("saw %d dials; want 1: %v", len(dials), dials)
  188. }
  189. tr.CloseIdleConnections()
  190. if err := retry(50, 10*time.Millisecond, func() error {
  191. cp, ok := tr.connPool().(*clientConnPool)
  192. if !ok {
  193. return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool())
  194. }
  195. cp.mu.Lock()
  196. defer cp.mu.Unlock()
  197. if len(cp.dialing) != 0 {
  198. return fmt.Errorf("dialing map = %v; want empty", cp.dialing)
  199. }
  200. if len(cp.conns) != 0 {
  201. return fmt.Errorf("conns = %v; want empty", cp.conns)
  202. }
  203. if len(cp.keys) != 0 {
  204. return fmt.Errorf("keys = %v; want empty", cp.keys)
  205. }
  206. return nil
  207. }); err != nil {
  208. t.Errorf("State of pool after CloseIdleConnections: %v", err)
  209. }
  210. }
  211. func retry(tries int, delay time.Duration, fn func() error) error {
  212. var err error
  213. for i := 0; i < tries; i++ {
  214. err = fn()
  215. if err == nil {
  216. return nil
  217. }
  218. time.Sleep(delay)
  219. }
  220. return err
  221. }
  222. func TestTransportAbortClosesPipes(t *testing.T) {
  223. shutdown := make(chan struct{})
  224. st := newServerTester(t,
  225. func(w http.ResponseWriter, r *http.Request) {
  226. w.(http.Flusher).Flush()
  227. <-shutdown
  228. },
  229. optOnlyServer,
  230. )
  231. defer st.Close()
  232. defer close(shutdown) // we must shutdown before st.Close() to avoid hanging
  233. done := make(chan struct{})
  234. requestMade := make(chan struct{})
  235. go func() {
  236. defer close(done)
  237. tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  238. req, err := http.NewRequest("GET", st.ts.URL, nil)
  239. if err != nil {
  240. t.Fatal(err)
  241. }
  242. res, err := tr.RoundTrip(req)
  243. if err != nil {
  244. t.Fatal(err)
  245. }
  246. defer res.Body.Close()
  247. close(requestMade)
  248. _, err = ioutil.ReadAll(res.Body)
  249. if err == nil {
  250. t.Error("expected error from res.Body.Read")
  251. }
  252. }()
  253. <-requestMade
  254. // Now force the serve loop to end, via closing the connection.
  255. st.closeConn()
  256. // deadlock? that's a bug.
  257. select {
  258. case <-done:
  259. case <-time.After(3 * time.Second):
  260. t.Fatal("timeout")
  261. }
  262. }
  263. // TODO: merge this with TestTransportBody to make TestTransportRequest? This
  264. // could be a table-driven test with extra goodies.
  265. func TestTransportPath(t *testing.T) {
  266. gotc := make(chan *url.URL, 1)
  267. st := newServerTester(t,
  268. func(w http.ResponseWriter, r *http.Request) {
  269. gotc <- r.URL
  270. },
  271. optOnlyServer,
  272. )
  273. defer st.Close()
  274. tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  275. defer tr.CloseIdleConnections()
  276. const (
  277. path = "/testpath"
  278. query = "q=1"
  279. )
  280. surl := st.ts.URL + path + "?" + query
  281. req, err := http.NewRequest("POST", surl, nil)
  282. if err != nil {
  283. t.Fatal(err)
  284. }
  285. c := &http.Client{Transport: tr}
  286. res, err := c.Do(req)
  287. if err != nil {
  288. t.Fatal(err)
  289. }
  290. defer res.Body.Close()
  291. got := <-gotc
  292. if got.Path != path {
  293. t.Errorf("Read Path = %q; want %q", got.Path, path)
  294. }
  295. if got.RawQuery != query {
  296. t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query)
  297. }
  298. }
  299. func randString(n int) string {
  300. rnd := rand.New(rand.NewSource(int64(n)))
  301. b := make([]byte, n)
  302. for i := range b {
  303. b[i] = byte(rnd.Intn(256))
  304. }
  305. return string(b)
  306. }
  307. var bodyTests = []struct {
  308. body string
  309. noContentLen bool
  310. }{
  311. {body: "some message"},
  312. {body: "some message", noContentLen: true},
  313. {body: ""},
  314. {body: "", noContentLen: true},
  315. {body: strings.Repeat("a", 1<<20), noContentLen: true},
  316. {body: strings.Repeat("a", 1<<20)},
  317. {body: randString(16<<10 - 1)},
  318. {body: randString(16 << 10)},
  319. {body: randString(16<<10 + 1)},
  320. {body: randString(512<<10 - 1)},
  321. {body: randString(512 << 10)},
  322. {body: randString(512<<10 + 1)},
  323. {body: randString(1<<20 - 1)},
  324. {body: randString(1 << 20)},
  325. {body: randString(1<<20 + 2)},
  326. }
  327. func TestTransportBody(t *testing.T) {
  328. type reqInfo struct {
  329. req *http.Request
  330. slurp []byte
  331. err error
  332. }
  333. gotc := make(chan reqInfo, 1)
  334. st := newServerTester(t,
  335. func(w http.ResponseWriter, r *http.Request) {
  336. slurp, err := ioutil.ReadAll(r.Body)
  337. if err != nil {
  338. gotc <- reqInfo{err: err}
  339. } else {
  340. gotc <- reqInfo{req: r, slurp: slurp}
  341. }
  342. },
  343. optOnlyServer,
  344. )
  345. defer st.Close()
  346. for i, tt := range bodyTests {
  347. tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  348. defer tr.CloseIdleConnections()
  349. var body io.Reader = strings.NewReader(tt.body)
  350. if tt.noContentLen {
  351. body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods
  352. }
  353. req, err := http.NewRequest("POST", st.ts.URL, body)
  354. if err != nil {
  355. t.Fatalf("#%d: %v", i, err)
  356. }
  357. c := &http.Client{Transport: tr}
  358. res, err := c.Do(req)
  359. if err != nil {
  360. t.Fatalf("#%d: %v", i, err)
  361. }
  362. defer res.Body.Close()
  363. ri := <-gotc
  364. if ri.err != nil {
  365. t.Errorf("#%d: read error: %v", i, ri.err)
  366. continue
  367. }
  368. if got := string(ri.slurp); got != tt.body {
  369. t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
  370. }
  371. wantLen := int64(len(tt.body))
  372. if tt.noContentLen && tt.body != "" {
  373. wantLen = -1
  374. }
  375. if ri.req.ContentLength != wantLen {
  376. t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
  377. }
  378. }
  379. }
  380. func shortString(v string) string {
  381. const maxLen = 100
  382. if len(v) <= maxLen {
  383. return v
  384. }
  385. return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:])
  386. }
  387. func TestTransportDialTLS(t *testing.T) {
  388. var mu sync.Mutex // guards following
  389. var gotReq, didDial bool
  390. ts := newServerTester(t,
  391. func(w http.ResponseWriter, r *http.Request) {
  392. mu.Lock()
  393. gotReq = true
  394. mu.Unlock()
  395. },
  396. optOnlyServer,
  397. )
  398. defer ts.Close()
  399. tr := &Transport{
  400. DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
  401. mu.Lock()
  402. didDial = true
  403. mu.Unlock()
  404. cfg.InsecureSkipVerify = true
  405. c, err := tls.Dial(netw, addr, cfg)
  406. if err != nil {
  407. return nil, err
  408. }
  409. return c, c.Handshake()
  410. },
  411. }
  412. defer tr.CloseIdleConnections()
  413. client := &http.Client{Transport: tr}
  414. res, err := client.Get(ts.ts.URL)
  415. if err != nil {
  416. t.Fatal(err)
  417. }
  418. res.Body.Close()
  419. mu.Lock()
  420. if !gotReq {
  421. t.Error("didn't get request")
  422. }
  423. if !didDial {
  424. t.Error("didn't use dial hook")
  425. }
  426. }
  427. func TestConfigureTransport(t *testing.T) {
  428. t1 := &http.Transport{}
  429. err := ConfigureTransport(t1)
  430. if err == errTransportVersion {
  431. t.Skip(err)
  432. }
  433. if err != nil {
  434. t.Fatal(err)
  435. }
  436. if got := fmt.Sprintf("%#v", *t1); !strings.Contains(got, `"h2"`) {
  437. // Laziness, to avoid buildtags.
  438. t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got)
  439. }
  440. wantNextProtos := []string{"h2", "http/1.1"}
  441. if t1.TLSClientConfig == nil {
  442. t.Errorf("nil t1.TLSClientConfig")
  443. } else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) {
  444. t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos)
  445. }
  446. if err := ConfigureTransport(t1); err == nil {
  447. t.Error("unexpected success on second call to ConfigureTransport")
  448. }
  449. // And does it work?
  450. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  451. io.WriteString(w, r.Proto)
  452. }, optOnlyServer)
  453. defer st.Close()
  454. t1.TLSClientConfig.InsecureSkipVerify = true
  455. c := &http.Client{Transport: t1}
  456. res, err := c.Get(st.ts.URL)
  457. if err != nil {
  458. t.Fatal(err)
  459. }
  460. slurp, err := ioutil.ReadAll(res.Body)
  461. if err != nil {
  462. t.Fatal(err)
  463. }
  464. if got, want := string(slurp), "HTTP/2.0"; got != want {
  465. t.Errorf("body = %q; want %q", got, want)
  466. }
  467. }
  468. type capitalizeReader struct {
  469. r io.Reader
  470. }
  471. func (cr capitalizeReader) Read(p []byte) (n int, err error) {
  472. n, err = cr.r.Read(p)
  473. for i, b := range p[:n] {
  474. if b >= 'a' && b <= 'z' {
  475. p[i] = b - ('a' - 'A')
  476. }
  477. }
  478. return
  479. }
  480. type flushWriter struct {
  481. w io.Writer
  482. }
  483. func (fw flushWriter) Write(p []byte) (n int, err error) {
  484. n, err = fw.w.Write(p)
  485. if f, ok := fw.w.(http.Flusher); ok {
  486. f.Flush()
  487. }
  488. return
  489. }
  490. type clientTester struct {
  491. t *testing.T
  492. tr *Transport
  493. sc, cc net.Conn // server and client conn
  494. fr *Framer // server's framer
  495. client func() error
  496. server func() error
  497. }
  498. func newClientTester(t *testing.T) *clientTester {
  499. var dialOnce struct {
  500. sync.Mutex
  501. dialed bool
  502. }
  503. ct := &clientTester{
  504. t: t,
  505. }
  506. ct.tr = &Transport{
  507. TLSClientConfig: tlsConfigInsecure,
  508. DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  509. dialOnce.Lock()
  510. defer dialOnce.Unlock()
  511. if dialOnce.dialed {
  512. return nil, errors.New("only one dial allowed in test mode")
  513. }
  514. dialOnce.dialed = true
  515. return ct.cc, nil
  516. },
  517. }
  518. ln := newLocalListener(t)
  519. cc, err := net.Dial("tcp", ln.Addr().String())
  520. if err != nil {
  521. t.Fatal(err)
  522. }
  523. sc, err := ln.Accept()
  524. if err != nil {
  525. t.Fatal(err)
  526. }
  527. ln.Close()
  528. ct.cc = cc
  529. ct.sc = sc
  530. ct.fr = NewFramer(sc, sc)
  531. return ct
  532. }
  533. func newLocalListener(t *testing.T) net.Listener {
  534. ln, err := net.Listen("tcp4", "127.0.0.1:0")
  535. if err == nil {
  536. return ln
  537. }
  538. ln, err = net.Listen("tcp6", "[::1]:0")
  539. if err != nil {
  540. t.Fatal(err)
  541. }
  542. return ln
  543. }
  544. func (ct *clientTester) greet() {
  545. buf := make([]byte, len(ClientPreface))
  546. _, err := io.ReadFull(ct.sc, buf)
  547. if err != nil {
  548. ct.t.Fatalf("reading client preface: %v", err)
  549. }
  550. f, err := ct.fr.ReadFrame()
  551. if err != nil {
  552. ct.t.Fatalf("Reading client settings frame: %v", err)
  553. }
  554. if sf, ok := f.(*SettingsFrame); !ok {
  555. ct.t.Fatalf("Wanted client settings frame; got %v", f)
  556. _ = sf // stash it away?
  557. }
  558. if err := ct.fr.WriteSettings(); err != nil {
  559. ct.t.Fatal(err)
  560. }
  561. if err := ct.fr.WriteSettingsAck(); err != nil {
  562. ct.t.Fatal(err)
  563. }
  564. }
  565. func (ct *clientTester) cleanup() {
  566. ct.tr.CloseIdleConnections()
  567. }
  568. func (ct *clientTester) run() {
  569. errc := make(chan error, 2)
  570. ct.start("client", errc, ct.client)
  571. ct.start("server", errc, ct.server)
  572. defer ct.cleanup()
  573. for i := 0; i < 2; i++ {
  574. if err := <-errc; err != nil {
  575. ct.t.Error(err)
  576. return
  577. }
  578. }
  579. }
  580. func (ct *clientTester) start(which string, errc chan<- error, fn func() error) {
  581. go func() {
  582. finished := false
  583. var err error
  584. defer func() {
  585. if !finished {
  586. err = fmt.Errorf("%s goroutine didn't finish.", which)
  587. } else if err != nil {
  588. err = fmt.Errorf("%s: %v", which, err)
  589. }
  590. errc <- err
  591. }()
  592. err = fn()
  593. finished = true
  594. }()
  595. }
  596. type countingReader struct {
  597. n *int64
  598. }
  599. func (r countingReader) Read(p []byte) (n int, err error) {
  600. for i := range p {
  601. p[i] = byte(i)
  602. }
  603. atomic.AddInt64(r.n, int64(len(p)))
  604. return len(p), err
  605. }
  606. func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
  607. func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
  608. func testTransportReqBodyAfterResponse(t *testing.T, status int) {
  609. const bodySize = 10 << 20
  610. ct := newClientTester(t)
  611. ct.client = func() error {
  612. var n int64 // atomic
  613. req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize))
  614. if err != nil {
  615. return err
  616. }
  617. res, err := ct.tr.RoundTrip(req)
  618. if err != nil {
  619. return fmt.Errorf("RoundTrip: %v", err)
  620. }
  621. defer res.Body.Close()
  622. if res.StatusCode != status {
  623. return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
  624. }
  625. slurp, err := ioutil.ReadAll(res.Body)
  626. if err != nil {
  627. return fmt.Errorf("Slurp: %v", err)
  628. }
  629. if len(slurp) > 0 {
  630. return fmt.Errorf("unexpected body: %q", slurp)
  631. }
  632. if status == 200 {
  633. if got := atomic.LoadInt64(&n); got != bodySize {
  634. return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
  635. }
  636. } else {
  637. if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize {
  638. return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
  639. }
  640. }
  641. return nil
  642. }
  643. ct.server = func() error {
  644. ct.greet()
  645. var buf bytes.Buffer
  646. enc := hpack.NewEncoder(&buf)
  647. var dataRecv int64
  648. var closed bool
  649. for {
  650. f, err := ct.fr.ReadFrame()
  651. if err != nil {
  652. return err
  653. }
  654. //println(fmt.Sprintf("server got frame: %v", f))
  655. switch f := f.(type) {
  656. case *WindowUpdateFrame, *SettingsFrame:
  657. case *HeadersFrame:
  658. if !f.HeadersEnded() {
  659. return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
  660. }
  661. if f.StreamEnded() {
  662. return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
  663. }
  664. time.Sleep(50 * time.Millisecond) // let client send body
  665. enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
  666. ct.fr.WriteHeaders(HeadersFrameParam{
  667. StreamID: f.StreamID,
  668. EndHeaders: true,
  669. EndStream: false,
  670. BlockFragment: buf.Bytes(),
  671. })
  672. case *DataFrame:
  673. dataLen := len(f.Data())
  674. dataRecv += int64(dataLen)
  675. if dataLen > 0 {
  676. if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
  677. return err
  678. }
  679. if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
  680. return err
  681. }
  682. }
  683. if !closed && ((status != 200 && dataRecv > 0) ||
  684. (status == 200 && dataRecv == bodySize)) {
  685. closed = true
  686. if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil {
  687. return err
  688. }
  689. return nil
  690. }
  691. default:
  692. return fmt.Errorf("Unexpected client frame %v", f)
  693. }
  694. }
  695. return nil
  696. }
  697. ct.run()
  698. }
  699. // See golang.org/issue/13444
  700. func TestTransportFullDuplex(t *testing.T) {
  701. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  702. w.WriteHeader(200) // redundant but for clarity
  703. w.(http.Flusher).Flush()
  704. io.Copy(flushWriter{w}, capitalizeReader{r.Body})
  705. fmt.Fprintf(w, "bye.\n")
  706. }, optOnlyServer)
  707. defer st.Close()
  708. tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  709. defer tr.CloseIdleConnections()
  710. c := &http.Client{Transport: tr}
  711. pr, pw := io.Pipe()
  712. req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr))
  713. if err != nil {
  714. log.Fatal(err)
  715. }
  716. req.ContentLength = -1
  717. res, err := c.Do(req)
  718. if err != nil {
  719. log.Fatal(err)
  720. }
  721. defer res.Body.Close()
  722. if res.StatusCode != 200 {
  723. t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
  724. }
  725. bs := bufio.NewScanner(res.Body)
  726. want := func(v string) {
  727. if !bs.Scan() {
  728. t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
  729. }
  730. }
  731. write := func(v string) {
  732. _, err := io.WriteString(pw, v)
  733. if err != nil {
  734. t.Fatalf("pipe write: %v", err)
  735. }
  736. }
  737. write("foo\n")
  738. want("FOO")
  739. write("bar\n")
  740. want("BAR")
  741. pw.Close()
  742. want("bye.")
  743. if err := bs.Err(); err != nil {
  744. t.Fatal(err)
  745. }
  746. }
  747. func TestTransportConnectRequest(t *testing.T) {
  748. gotc := make(chan *http.Request, 1)
  749. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  750. gotc <- r
  751. }, optOnlyServer)
  752. defer st.Close()
  753. u, err := url.Parse(st.ts.URL)
  754. if err != nil {
  755. t.Fatal(err)
  756. }
  757. tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  758. defer tr.CloseIdleConnections()
  759. c := &http.Client{Transport: tr}
  760. tests := []struct {
  761. req *http.Request
  762. want string
  763. }{
  764. {
  765. req: &http.Request{
  766. Method: "CONNECT",
  767. Header: http.Header{},
  768. URL: u,
  769. },
  770. want: u.Host,
  771. },
  772. {
  773. req: &http.Request{
  774. Method: "CONNECT",
  775. Header: http.Header{},
  776. URL: u,
  777. Host: "example.com:123",
  778. },
  779. want: "example.com:123",
  780. },
  781. }
  782. for i, tt := range tests {
  783. res, err := c.Do(tt.req)
  784. if err != nil {
  785. t.Errorf("%d. RoundTrip = %v", i, err)
  786. continue
  787. }
  788. res.Body.Close()
  789. req := <-gotc
  790. if req.Method != "CONNECT" {
  791. t.Errorf("method = %q; want CONNECT", req.Method)
  792. }
  793. if req.Host != tt.want {
  794. t.Errorf("Host = %q; want %q", req.Host, tt.want)
  795. }
  796. if req.URL.Host != tt.want {
  797. t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
  798. }
  799. }
  800. }
  801. type headerType int
  802. const (
  803. noHeader headerType = iota // omitted
  804. oneHeader
  805. splitHeader // broken into continuation on purpose
  806. )
  807. const (
  808. f0 = noHeader
  809. f1 = oneHeader
  810. f2 = splitHeader
  811. d0 = false
  812. d1 = true
  813. )
// Test all 36 combinations of response frame orders:
// (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers).
  816. // Generated by http://play.golang.org/p/SScqYKJYXd
  817. func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
  818. func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
  819. func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
  820. func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
  821. func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
  822. func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
  823. func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
  824. func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
  825. func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
  826. func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
  827. func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
  828. func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
  829. func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
  830. func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
  831. func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
  832. func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
  833. func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
  834. func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
  835. func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
  836. func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
  837. func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
  838. func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
  839. func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
  840. func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
  841. func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
  842. func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
  843. func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
  844. func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
  845. func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
  846. func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
  847. func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
  848. func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
  849. func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
  850. func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
  851. func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
  852. func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
  853. func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
  854. const reqBody = "some request body"
  855. const resBody = "some response body"
  856. if resHeader == noHeader {
  857. // TODO: test 100-continue followed by immediate
  858. // server stream reset, without headers in the middle?
  859. panic("invalid combination")
  860. }
  861. ct := newClientTester(t)
  862. ct.client = func() error {
  863. req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
  864. if expect100Continue != noHeader {
  865. req.Header.Set("Expect", "100-continue")
  866. }
  867. res, err := ct.tr.RoundTrip(req)
  868. if err != nil {
  869. return fmt.Errorf("RoundTrip: %v", err)
  870. }
  871. defer res.Body.Close()
  872. if res.StatusCode != 200 {
  873. return fmt.Errorf("status code = %v; want 200", res.StatusCode)
  874. }
  875. slurp, err := ioutil.ReadAll(res.Body)
  876. if err != nil {
  877. return fmt.Errorf("Slurp: %v", err)
  878. }
  879. wantBody := resBody
  880. if !withData {
  881. wantBody = ""
  882. }
  883. if string(slurp) != wantBody {
  884. return fmt.Errorf("body = %q; want %q", slurp, wantBody)
  885. }
  886. if trailers == noHeader {
  887. if len(res.Trailer) > 0 {
  888. t.Errorf("Trailer = %v; want none", res.Trailer)
  889. }
  890. } else {
  891. want := http.Header{"Some-Trailer": {"some-value"}}
  892. if !reflect.DeepEqual(res.Trailer, want) {
  893. t.Errorf("Trailer = %v; want %v", res.Trailer, want)
  894. }
  895. }
  896. return nil
  897. }
  898. ct.server = func() error {
  899. ct.greet()
  900. var buf bytes.Buffer
  901. enc := hpack.NewEncoder(&buf)
  902. for {
  903. f, err := ct.fr.ReadFrame()
  904. if err != nil {
  905. return err
  906. }
  907. switch f := f.(type) {
  908. case *WindowUpdateFrame, *SettingsFrame:
  909. case *DataFrame:
  910. // ignore for now.
  911. case *HeadersFrame:
  912. endStream := false
  913. send := func(mode headerType) {
  914. hbf := buf.Bytes()
  915. switch mode {
  916. case oneHeader:
  917. ct.fr.WriteHeaders(HeadersFrameParam{
  918. StreamID: f.StreamID,
  919. EndHeaders: true,
  920. EndStream: endStream,
  921. BlockFragment: hbf,
  922. })
  923. case splitHeader:
  924. if len(hbf) < 2 {
  925. panic("too small")
  926. }
  927. ct.fr.WriteHeaders(HeadersFrameParam{
  928. StreamID: f.StreamID,
  929. EndHeaders: false,
  930. EndStream: endStream,
  931. BlockFragment: hbf[:1],
  932. })
  933. ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
  934. default:
  935. panic("bogus mode")
  936. }
  937. }
  938. if expect100Continue != noHeader {
  939. buf.Reset()
  940. enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
  941. send(expect100Continue)
  942. }
  943. // Response headers (1+ frames; 1 or 2 in this test, but never 0)
  944. {
  945. buf.Reset()
  946. enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  947. enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
  948. enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
  949. if trailers != noHeader {
  950. enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
  951. }
  952. endStream = withData == false && trailers == noHeader
  953. send(resHeader)
  954. }
  955. if withData {
  956. endStream = trailers == noHeader
  957. ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
  958. }
  959. if trailers != noHeader {
  960. endStream = true
  961. buf.Reset()
  962. enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
  963. send(trailers)
  964. }
  965. return nil
  966. }
  967. }
  968. }
  969. ct.run()
  970. }
  971. func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
  972. ct := newClientTester(t)
  973. ct.client = func() error {
  974. req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  975. res, err := ct.tr.RoundTrip(req)
  976. if err != nil {
  977. return fmt.Errorf("RoundTrip: %v", err)
  978. }
  979. defer res.Body.Close()
  980. if res.StatusCode != 200 {
  981. return fmt.Errorf("status code = %v; want 200", res.StatusCode)
  982. }
  983. slurp, err := ioutil.ReadAll(res.Body)
  984. if err != nil {
  985. return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil)
  986. }
  987. if len(slurp) > 0 {
  988. return fmt.Errorf("body = %q; want nothing", slurp)
  989. }
  990. if _, ok := res.Trailer["Some-Trailer"]; !ok {
  991. return fmt.Errorf("expected Some-Trailer")
  992. }
  993. return nil
  994. }
  995. ct.server = func() error {
  996. ct.greet()
  997. var n int
  998. var hf *HeadersFrame
  999. for hf == nil && n < 10 {
  1000. f, err := ct.fr.ReadFrame()
  1001. if err != nil {
  1002. return err
  1003. }
  1004. hf, _ = f.(*HeadersFrame)
  1005. n++
  1006. }
  1007. var buf bytes.Buffer
  1008. enc := hpack.NewEncoder(&buf)
  1009. // send headers without Trailer header
  1010. enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1011. ct.fr.WriteHeaders(HeadersFrameParam{
  1012. StreamID: hf.StreamID,
  1013. EndHeaders: true,
  1014. EndStream: false,
  1015. BlockFragment: buf.Bytes(),
  1016. })
  1017. // send trailers
  1018. buf.Reset()
  1019. enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"})
  1020. ct.fr.WriteHeaders(HeadersFrameParam{
  1021. StreamID: hf.StreamID,
  1022. EndHeaders: true,
  1023. EndStream: true,
  1024. BlockFragment: buf.Bytes(),
  1025. })
  1026. return nil
  1027. }
  1028. ct.run()
  1029. }
  1030. func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
  1031. testTransportInvalidTrailer_Pseudo(t, oneHeader)
  1032. }
  1033. func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
  1034. testTransportInvalidTrailer_Pseudo(t, splitHeader)
  1035. }
  1036. func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
  1037. testInvalidTrailer(t, trailers, errPseudoTrailers, func(enc *hpack.Encoder) {
  1038. enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
  1039. enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
  1040. })
  1041. }
  1042. func TestTransportInvalidTrailer_Capital1(t *testing.T) {
  1043. testTransportInvalidTrailer_Capital(t, oneHeader)
  1044. }
  1045. func TestTransportInvalidTrailer_Capital2(t *testing.T) {
  1046. testTransportInvalidTrailer_Capital(t, splitHeader)
  1047. }
  1048. func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
  1049. testInvalidTrailer(t, trailers, errInvalidHeaderFieldName, func(enc *hpack.Encoder) {
  1050. enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
  1051. enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
  1052. })
  1053. }
  1054. func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
  1055. testInvalidTrailer(t, oneHeader, errInvalidHeaderFieldName, func(enc *hpack.Encoder) {
  1056. enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
  1057. })
  1058. }
  1059. func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
  1060. testInvalidTrailer(t, oneHeader, errInvalidHeaderFieldValue, func(enc *hpack.Encoder) {
  1061. enc.WriteField(hpack.HeaderField{Name: "", Value: "has\nnewline"})
  1062. })
  1063. }
  1064. func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) {
  1065. ct := newClientTester(t)
  1066. ct.client = func() error {
  1067. req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1068. res, err := ct.tr.RoundTrip(req)
  1069. if err != nil {
  1070. return fmt.Errorf("RoundTrip: %v", err)
  1071. }
  1072. defer res.Body.Close()
  1073. if res.StatusCode != 200 {
  1074. return fmt.Errorf("status code = %v; want 200", res.StatusCode)
  1075. }
  1076. slurp, err := ioutil.ReadAll(res.Body)
  1077. if err != wantErr {
  1078. return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, wantErr)
  1079. }
  1080. if len(slurp) > 0 {
  1081. return fmt.Errorf("body = %q; want nothing", slurp)
  1082. }
  1083. return nil
  1084. }
  1085. ct.server = func() error {
  1086. ct.greet()
  1087. var buf bytes.Buffer
  1088. enc := hpack.NewEncoder(&buf)
  1089. for {
  1090. f, err := ct.fr.ReadFrame()
  1091. if err != nil {
  1092. return err
  1093. }
  1094. switch f := f.(type) {
  1095. case *HeadersFrame:
  1096. var endStream bool
  1097. send := func(mode headerType) {
  1098. hbf := buf.Bytes()
  1099. switch mode {
  1100. case oneHeader:
  1101. ct.fr.WriteHeaders(HeadersFrameParam{
  1102. StreamID: f.StreamID,
  1103. EndHeaders: true,
  1104. EndStream: endStream,
  1105. BlockFragment: hbf,
  1106. })
  1107. case splitHeader:
  1108. if len(hbf) < 2 {
  1109. panic("too small")
  1110. }
  1111. ct.fr.WriteHeaders(HeadersFrameParam{
  1112. StreamID: f.StreamID,
  1113. EndHeaders: false,
  1114. EndStream: endStream,
  1115. BlockFragment: hbf[:1],
  1116. })
  1117. ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
  1118. default:
  1119. panic("bogus mode")
  1120. }
  1121. }
  1122. // Response headers (1+ frames; 1 or 2 in this test, but never 0)
  1123. {
  1124. buf.Reset()
  1125. enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1126. enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"})
  1127. endStream = false
  1128. send(oneHeader)
  1129. }
  1130. // Trailers:
  1131. {
  1132. endStream = true
  1133. buf.Reset()
  1134. writeTrailer(enc)
  1135. send(trailers)
  1136. }
  1137. return nil
  1138. }
  1139. }
  1140. }
  1141. ct.run()
  1142. }
  1143. func TestTransportChecksResponseHeaderListSize(t *testing.T) {
  1144. ct := newClientTester(t)
  1145. ct.client = func() error {
  1146. req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1147. res, err := ct.tr.RoundTrip(req)
  1148. if err != errResponseHeaderListSize {
  1149. if res != nil {
  1150. res.Body.Close()
  1151. }
  1152. size := int64(0)
  1153. for k, vv := range res.Header {
  1154. for _, v := range vv {
  1155. size += int64(len(k)) + int64(len(v)) + 32
  1156. }
  1157. }
  1158. return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
  1159. }
  1160. return nil
  1161. }
  1162. ct.server = func() error {
  1163. ct.greet()
  1164. var buf bytes.Buffer
  1165. enc := hpack.NewEncoder(&buf)
  1166. for {
  1167. f, err := ct.fr.ReadFrame()
  1168. if err != nil {
  1169. return err
  1170. }
  1171. switch f := f.(type) {
  1172. case *HeadersFrame:
  1173. enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1174. large := strings.Repeat("a", 1<<10)
  1175. for i := 0; i < 5042; i++ {
  1176. enc.WriteField(hpack.HeaderField{Name: large, Value: large})
  1177. }
  1178. if size, want := buf.Len(), 6329; size != want {
  1179. // Note: this number might change if
  1180. // our hpack implementation
  1181. // changes. That's fine. This is
  1182. // just a sanity check that our
  1183. // response can fit in a single
  1184. // header block fragment frame.
  1185. return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
  1186. }
  1187. ct.fr.WriteHeaders(HeadersFrameParam{
  1188. StreamID: f.StreamID,
  1189. EndHeaders: true,
  1190. EndStream: true,
  1191. BlockFragment: buf.Bytes(),
  1192. })
  1193. return nil
  1194. }
  1195. }
  1196. }
  1197. ct.run()
  1198. }
  1199. // Test that the the Transport returns a typed error from Response.Body.Read calls
  1200. // when the server sends an error. (here we use a panic, since that should generate
  1201. // a stream error, but others like cancel should be similar)
  1202. func TestTransportBodyReadErrorType(t *testing.T) {
  1203. doPanic := make(chan bool, 1)
  1204. st := newServerTester(t,
  1205. func(w http.ResponseWriter, r *http.Request) {
  1206. w.(http.Flusher).Flush() // force headers out
  1207. <-doPanic
  1208. panic("boom")
  1209. },
  1210. optOnlyServer,
  1211. optQuiet,
  1212. )
  1213. defer st.Close()
  1214. tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  1215. defer tr.CloseIdleConnections()
  1216. c := &http.Client{Transport: tr}
  1217. res, err := c.Get(st.ts.URL)
  1218. if err != nil {
  1219. t.Fatal(err)
  1220. }
  1221. defer res.Body.Close()
  1222. doPanic <- true
  1223. buf := make([]byte, 100)
  1224. n, err := res.Body.Read(buf)
  1225. want := StreamError{StreamID: 0x1, Code: 0x2}
  1226. if !reflect.DeepEqual(want, err) {
  1227. t.Errorf("Read = %v, %#v; want error %#v", n, err, want)
  1228. }
  1229. }
  1230. // golang.org/issue/13924
  1231. // This used to fail after many iterations, especially with -race:
  1232. // go test -v -run=TestTransportDoubleCloseOnWriteError -count=500 -race
  1233. func TestTransportDoubleCloseOnWriteError(t *testing.T) {
  1234. var (
  1235. mu sync.Mutex
  1236. conn net.Conn // to close if set
  1237. )
  1238. st := newServerTester(t,
  1239. func(w http.ResponseWriter, r *http.Request) {
  1240. mu.Lock()
  1241. defer mu.Unlock()
  1242. if conn != nil {
  1243. conn.Close()
  1244. }
  1245. },
  1246. optOnlyServer,
  1247. )
  1248. defer st.Close()
  1249. tr := &Transport{
  1250. TLSClientConfig: tlsConfigInsecure,
  1251. DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  1252. tc, err := tls.Dial(network, addr, cfg)
  1253. if err != nil {
  1254. return nil, err
  1255. }
  1256. mu.Lock()
  1257. defer mu.Unlock()
  1258. conn = tc
  1259. return tc, nil
  1260. },
  1261. }
  1262. defer tr.CloseIdleConnections()
  1263. c := &http.Client{Transport: tr}
  1264. c.Get(st.ts.URL)
  1265. }
  1266. // Test that the http1 Transport.DisableKeepAlives option is respected
  1267. // and connections are closed as soon as idle.
  1268. // See golang.org/issue/14008
  1269. func TestTransportDisableKeepAlives(t *testing.T) {
  1270. st := newServerTester(t,
  1271. func(w http.ResponseWriter, r *http.Request) {
  1272. io.WriteString(w, "hi")
  1273. },
  1274. optOnlyServer,
  1275. )
  1276. defer st.Close()
  1277. connClosed := make(chan struct{}) // closed on tls.Conn.Close
  1278. tr := &Transport{
  1279. t1: &http.Transport{
  1280. DisableKeepAlives: true,
  1281. },
  1282. TLSClientConfig: tlsConfigInsecure,
  1283. DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  1284. tc, err := tls.Dial(network, addr, cfg)
  1285. if err != nil {
  1286. return nil, err
  1287. }
  1288. return &noteCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
  1289. },
  1290. }
  1291. c := &http.Client{Transport: tr}
  1292. res, err := c.Get(st.ts.URL)
  1293. if err != nil {
  1294. t.Fatal(err)
  1295. }
  1296. if _, err := ioutil.ReadAll(res.Body); err != nil {
  1297. t.Fatal(err)
  1298. }
  1299. defer res.Body.Close()
  1300. select {
  1301. case <-connClosed:
  1302. case <-time.After(1 * time.Second):
  1303. t.Errorf("timeout")
  1304. }
  1305. }
  1306. // Test concurrent requests with Transport.DisableKeepAlives. We can share connections,
  1307. // but when things are totally idle, it still needs to close.
  1308. func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
  1309. const D = 25 * time.Millisecond
  1310. st := newServerTester(t,
  1311. func(w http.ResponseWriter, r *http.Request) {
  1312. time.Sleep(D)
  1313. io.WriteString(w, "hi")
  1314. },
  1315. optOnlyServer,
  1316. )
  1317. defer st.Close()
  1318. var dials int32
  1319. var conns sync.WaitGroup
  1320. tr := &Transport{
  1321. t1: &http.Transport{
  1322. DisableKeepAlives: true,
  1323. },
  1324. TLSClientConfig: tlsConfigInsecure,
  1325. DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  1326. tc, err := tls.Dial(network, addr, cfg)
  1327. if err != nil {
  1328. return nil, err
  1329. }
  1330. atomic.AddInt32(&dials, 1)
  1331. conns.Add(1)
  1332. return &noteCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
  1333. },
  1334. }
  1335. c := &http.Client{Transport: tr}
  1336. var reqs sync.WaitGroup
  1337. const N = 20
  1338. for i := 0; i < N; i++ {
  1339. reqs.Add(1)
  1340. if i == N-1 {
  1341. // For the final request, try to make all the
  1342. // others close. This isn't verified in the
  1343. // count, other than the Log statement, since
  1344. // it's so timing dependent. This test is
  1345. // really to make sure we don't interrupt a
  1346. // valid request.
  1347. time.Sleep(D * 2)
  1348. }
  1349. go func() {
  1350. defer reqs.Done()
  1351. res, err := c.Get(st.ts.URL)
  1352. if err != nil {
  1353. t.Error(err)
  1354. return
  1355. }
  1356. if _, err := ioutil.ReadAll(res.Body); err != nil {
  1357. t.Error(err)
  1358. return
  1359. }
  1360. res.Body.Close()
  1361. }()
  1362. }
  1363. reqs.Wait()
  1364. conns.Wait()
  1365. t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
  1366. }
  1367. type noteCloseConn struct {
  1368. net.Conn
  1369. onceClose sync.Once
  1370. closefn func()
  1371. }
  1372. func (c *noteCloseConn) Close() error {
  1373. c.onceClose.Do(c.closefn)
  1374. return c.Conn.Close()
  1375. }
  1376. func isTimeout(err error) bool {
  1377. switch err := err.(type) {
  1378. case nil:
  1379. return false
  1380. case *url.Error:
  1381. return isTimeout(err.Err)
  1382. case net.Error:
  1383. return err.Timeout()
  1384. }
  1385. return false
  1386. }
  1387. // Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent.
  1388. func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
  1389. testTransportResponseHeaderTimeout(t, false)
  1390. }
  1391. func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
  1392. testTransportResponseHeaderTimeout(t, true)
  1393. }
  1394. func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
  1395. ct := newClientTester(t)
  1396. ct.tr.t1 = &http.Transport{
  1397. ResponseHeaderTimeout: 5 * time.Millisecond,
  1398. }
  1399. ct.client = func() error {
  1400. c := &http.Client{Transport: ct.tr}
  1401. var err error
  1402. var n int64
  1403. const bodySize = 4 << 20
  1404. if body {
  1405. _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
  1406. } else {
  1407. _, err = c.Get("https://dummy.tld/")
  1408. }
  1409. if !isTimeout(err) {
  1410. t.Errorf("client expected timeout error; got %#v", err)
  1411. }
  1412. if body && n != bodySize {
  1413. t.Errorf("only read %d bytes of body; want %d", n, bodySize)
  1414. }
  1415. return nil
  1416. }
  1417. ct.server = func() error {
  1418. ct.greet()
  1419. for {
  1420. f, err := ct.fr.ReadFrame()
  1421. if err != nil {
  1422. t.Logf("ReadFrame: %v", err)
  1423. return nil
  1424. }
  1425. switch f := f.(type) {
  1426. case *DataFrame:
  1427. dataLen := len(f.Data())
  1428. if dataLen > 0 {
  1429. if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
  1430. return err
  1431. }
  1432. if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
  1433. return err
  1434. }
  1435. }
  1436. case *RSTStreamFrame:
  1437. if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
  1438. return nil
  1439. }
  1440. }
  1441. }
  1442. return nil
  1443. }
  1444. ct.run()
  1445. }
  1446. func TestTransportDisableCompression(t *testing.T) {
  1447. const body = "sup"
  1448. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  1449. want := http.Header{
  1450. "User-Agent": []string{"Go-http-client/2.0"},
  1451. }
  1452. if !reflect.DeepEqual(r.Header, want) {
  1453. t.Errorf("request headers = %v; want %v", r.Header, want)
  1454. }
  1455. }, optOnlyServer)
  1456. defer st.Close()
  1457. tr := &Transport{
  1458. TLSClientConfig: tlsConfigInsecure,
  1459. t1: &http.Transport{
  1460. DisableCompression: true,
  1461. },
  1462. }
  1463. defer tr.CloseIdleConnections()
  1464. req, err := http.NewRequest("GET", st.ts.URL, nil)
  1465. if err != nil {
  1466. t.Fatal(err)
  1467. }
  1468. res, err := tr.RoundTrip(req)
  1469. if err != nil {
  1470. t.Fatal(err)
  1471. }
  1472. defer res.Body.Close()
  1473. }
  1474. // RFC 7540 section 8.1.2.2
  1475. func TestTransportRejectsConnHeaders(t *testing.T) {
  1476. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  1477. var got []string
  1478. for k := range r.Header {
  1479. got = append(got, k)
  1480. }
  1481. sort.Strings(got)
  1482. w.Header().Set("Got-Header", strings.Join(got, ","))
  1483. }, optOnlyServer)
  1484. defer st.Close()
  1485. tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  1486. defer tr.CloseIdleConnections()
  1487. tests := []struct {
  1488. key string
  1489. value []string
  1490. want string
  1491. }{
  1492. {
  1493. key: "Upgrade",
  1494. value: []string{"anything"},
  1495. want: "ERROR: http2: invalid Upgrade request header",
  1496. },
  1497. {
  1498. key: "Connection",
  1499. value: []string{"foo"},
  1500. want: "ERROR: http2: invalid Connection request header",
  1501. },
  1502. {
  1503. key: "Connection",
  1504. value: []string{"close"},
  1505. want: "Accept-Encoding,User-Agent",
  1506. },
  1507. {
  1508. key: "Connection",
  1509. value: []string{"close", "something-else"},
  1510. want: "ERROR: http2: invalid Connection request header",
  1511. },
  1512. {
  1513. key: "Connection",
  1514. value: []string{"keep-alive"},
  1515. want: "Accept-Encoding,User-Agent",
  1516. },
  1517. {
  1518. key: "Proxy-Connection", // just deleted and ignored
  1519. value: []string{"keep-alive"},
  1520. want: "Accept-Encoding,User-Agent",
  1521. },
  1522. {
  1523. key: "Transfer-Encoding",
  1524. value: []string{""},
  1525. want: "Accept-Encoding,User-Agent",
  1526. },
  1527. {
  1528. key: "Transfer-Encoding",
  1529. value: []string{"foo"},
  1530. want: "ERROR: http2: invalid Transfer-Encoding request header",
  1531. },
  1532. {
  1533. key: "Transfer-Encoding",
  1534. value: []string{"chunked"},
  1535. want: "Accept-Encoding,User-Agent",
  1536. },
  1537. {
  1538. key: "Transfer-Encoding",
  1539. value: []string{"chunked", "other"},
  1540. want: "ERROR: http2: invalid Transfer-Encoding request header",
  1541. },
  1542. {
  1543. key: "Content-Length",
  1544. value: []string{"123"},
  1545. want: "Accept-Encoding,User-Agent",
  1546. },
  1547. }
  1548. for _, tt := range tests {
  1549. req, _ := http.NewRequest("GET", st.ts.URL, nil)
  1550. req.Header[tt.key] = tt.value
  1551. res, err := tr.RoundTrip(req)
  1552. var got string
  1553. if err != nil {
  1554. got = fmt.Sprintf("ERROR: %v", err)
  1555. } else {
  1556. got = res.Header.Get("Got-Header")
  1557. res.Body.Close()
  1558. }
  1559. if got != tt.want {
  1560. t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
  1561. }
  1562. }
  1563. }
  1564. // Tests that gzipReader doesn't crash on a second Read call following
  1565. // the first Read call's gzip.NewReader returning an error.
  1566. func TestGzipReader_DoubleReadCrash(t *testing.T) {
  1567. gz := &gzipReader{
  1568. body: ioutil.NopCloser(strings.NewReader("0123456789")),
  1569. }
  1570. var buf [1]byte
  1571. n, err1 := gz.Read(buf[:])
  1572. if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") {
  1573. t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1)
  1574. }
  1575. n, err2 := gz.Read(buf[:])
  1576. if n != 0 || err2 != err1 {
  1577. t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1)
  1578. }
  1579. }