configure_transport.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. // Copyright 2015 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build go1.6
  5. package http2
  6. import (
  7. "crypto/tls"
  8. "fmt"
  9. "net/http"
  10. )
  11. func configureTransport(t1 *http.Transport) error {
  12. connPool := new(clientConnPool)
  13. t2 := &Transport{ConnPool: noDialClientConnPool{connPool}}
  14. if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
  15. return err
  16. }
  17. if t1.TLSClientConfig == nil {
  18. t1.TLSClientConfig = new(tls.Config)
  19. }
  20. if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
  21. t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
  22. }
  23. if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
  24. t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
  25. }
  26. upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
  27. cc, err := t2.NewClientConn(c)
  28. if err != nil {
  29. c.Close()
  30. return erringRoundTripper{err}
  31. }
  32. connPool.addConn(authorityAddr(authority), cc)
  33. return t2
  34. }
  35. if m := t1.TLSNextProto; len(m) == 0 {
  36. t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
  37. "h2": upgradeFn,
  38. }
  39. } else {
  40. m["h2"] = upgradeFn
  41. }
  42. return nil
  43. }
  // registerHTTPSProtocol installs rt as t's handler for the "https" scheme
// via Transport.RegisterProtocol. RegisterProtocol panics on failure, for
// example when "https" is already registered. That panic is recovered here
// and returned as an error whose message is the panic value formatted
// with %v. It returns nil when registration succeeds.
func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
	defer func() {
		p := recover()
		if p == nil {
			return
		}
		err = fmt.Errorf("%v", p)
	}()
	t.RegisterProtocol("https", rt)
	return nil
}
  55. // noDialClientConnPool is an implementation of http2.ClientConnPool
  56. // which never dials. We let the HTTP/1.1 client dial and use its TLS
  57. // connection instead.
  58. type noDialClientConnPool struct{ *clientConnPool }
  59. func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
  60. const doDial = false
  61. return p.getClientConn(req, addr, doDial)
  62. }
  63. // noDialH2RoundTripper is a RoundTripper which only tries to complete the request
  64. // if there's already has a cached connection to the host.
  65. type noDialH2RoundTripper struct{ t *Transport }
  66. func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  67. res, err := rt.t.RoundTrip(req)
  68. if err == ErrNoCachedConn {
  69. return nil, http.ErrSkipAltProtocol
  70. }
  71. return res, err
  72. }