|
|
@@ -0,0 +1,78 @@
|
|
|
+// Copyright 2015 The Go Authors. All rights reserved.
|
|
|
+// Use of this source code is governed by a BSD-style
|
|
|
+// license that can be found in the LICENSE file.
|
|
|
+
|
|
|
+// +build go1.6
|
|
|
+
|
|
|
+package http2
|
|
|
+
|
|
|
+import (
|
|
|
+ "crypto/tls"
|
|
|
+ "fmt"
|
|
|
+ "net/http"
|
|
|
+)
|
|
|
+
|
|
|
+func configureTransport(t1 *http.Transport) error {
|
|
|
+ connPool := new(clientConnPool)
|
|
|
+ t2 := &Transport{ConnPool: noDialClientConnPool{connPool}}
|
|
|
+ if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if t1.TLSClientConfig == nil {
|
|
|
+ t1.TLSClientConfig = new(tls.Config)
|
|
|
+ }
|
|
|
+ if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
|
|
|
+ t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
|
|
|
+ }
|
|
|
+ upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
|
|
|
+ cc, err := t2.NewClientConn(c)
|
|
|
+ if err != nil {
|
|
|
+ c.Close()
|
|
|
+ return erringRoundTripper{err}
|
|
|
+ }
|
|
|
+ connPool.addConn(authorityAddr(authority), cc)
|
|
|
+ return t2
|
|
|
+ }
|
|
|
+ if m := t1.TLSNextProto; len(m) == 0 {
|
|
|
+ t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
|
|
|
+ "h2": upgradeFn,
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ m["h2"] = upgradeFn
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// registerHTTPSProtocol calls Transport.RegisterProtocol but
|
|
|
+// convering panics into errors.
|
|
|
+func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
|
|
|
+ defer func() {
|
|
|
+ if e := recover(); e != nil {
|
|
|
+ err = fmt.Errorf("%v", e)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ t.RegisterProtocol("https", rt)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// noDialClientConnPool is an implementation of http2.ClientConnPool
|
|
|
+// which never dials. We let the HTTP/1.1 client dial and use its TLS
|
|
|
+// connection instead.
|
|
|
+type noDialClientConnPool struct{ *clientConnPool }
|
|
|
+
|
|
|
+func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
|
|
|
+ const doDial = false
|
|
|
+ return p.getClientConn(req, addr, doDial)
|
|
|
+}
|
|
|
+
|
|
|
+// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
|
|
|
+// if there's already has a cached connection to the host.
|
|
|
+type noDialH2RoundTripper struct{ t *Transport }
|
|
|
+
|
|
|
+func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
+ res, err := rt.t.RoundTrip(req)
|
|
|
+ if err == ErrNoCachedConn {
|
|
|
+ return nil, http.ErrSkipAltProtocol
|
|
|
+ }
|
|
|
+ return res, err
|
|
|
+}
|