configure_transport.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. // Copyright 2015 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build go1.6
  5. package http2
  6. import (
  7. "crypto/tls"
  8. "fmt"
  9. "net/http"
  10. )
  11. func configureTransport(t1 *http.Transport) error {
  12. connPool := new(clientConnPool)
  13. t2 := &Transport{ConnPool: noDialClientConnPool{connPool}}
  14. if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
  15. return err
  16. }
  17. if t1.TLSClientConfig == nil {
  18. t1.TLSClientConfig = new(tls.Config)
  19. }
  20. if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
  21. t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
  22. }
  23. upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
  24. cc, err := t2.NewClientConn(c)
  25. if err != nil {
  26. c.Close()
  27. return erringRoundTripper{err}
  28. }
  29. connPool.addConn(authorityAddr(authority), cc)
  30. return t2
  31. }
  32. if m := t1.TLSNextProto; len(m) == 0 {
  33. t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
  34. "h2": upgradeFn,
  35. }
  36. } else {
  37. m["h2"] = upgradeFn
  38. }
  39. return nil
  40. }
  41. // registerHTTPSProtocol calls Transport.RegisterProtocol but
  42. // convering panics into errors.
  43. func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
  44. defer func() {
  45. if e := recover(); e != nil {
  46. err = fmt.Errorf("%v", e)
  47. }
  48. }()
  49. t.RegisterProtocol("https", rt)
  50. return nil
  51. }
  52. // noDialClientConnPool is an implementation of http2.ClientConnPool
  53. // which never dials. We let the HTTP/1.1 client dial and use its TLS
  54. // connection instead.
  55. type noDialClientConnPool struct{ *clientConnPool }
  56. func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
  57. const doDial = false
  58. return p.getClientConn(req, addr, doDial)
  59. }
  60. // noDialH2RoundTripper is a RoundTripper which only tries to complete the request
  61. // if there's already has a cached connection to the host.
  62. type noDialH2RoundTripper struct{ t *Transport }
  63. func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  64. res, err := rt.t.RoundTrip(req)
  65. if err == ErrNoCachedConn {
  66. return nil, http.ErrSkipAltProtocol
  67. }
  68. return res, err
  69. }