Browse Source

http2: export Server.ServeConn

Fixes golang/go#12737
Updates golang/go#14141

Change-Id: I552b603b63a7c87d7fcdb4eb09f96ab9fd0ec0aa
Reviewed-on: https://go-review.googlesource.com/19176
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Brad Fitzpatrick 10 năm trước cách đây
mục cha
commit
6ccd6698c6
4 tập tin đã thay đổi với 169 bổ sung16 xóa
  1. 5 0
      http2/http2.go
  2. 57 9
      http2/server.go
  3. 107 4
      http2/server_test.go
  4. 0 3
      http2/transport.go

+ 5 - 0
http2/http2.go

@@ -17,6 +17,7 @@ package http2
 
 import (
 	"bufio"
+	"crypto/tls"
 	"errors"
 	"fmt"
 	"io"
@@ -422,3 +423,7 @@ var isTokenTable = [127]bool{
 	'|':  true,
 	'~':  true,
 }
+
+type connectionStater interface {
+	ConnectionState() tls.ConnectionState
+}

+ 57 - 9
http2/server.go

@@ -195,28 +195,76 @@ func ConfigureServer(s *http.Server, conf *Server) error {
 		if testHookOnConn != nil {
 			testHookOnConn()
 		}
-		conf.handleConn(hs, c, h)
+		conf.ServeConn(c, &ServeConnOpts{
+			Handler:    h,
+			BaseConfig: hs,
+		})
 	}
 	s.TLSNextProto[NextProtoTLS] = protoHandler
 	s.TLSNextProto["h2-14"] = protoHandler // temporary; see above.
 	return nil
 }
 
-func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
+// ServeConnOpts are options for the Server.ServeConn method.
+type ServeConnOpts struct {
+	// BaseConfig optionally sets the base configuration
+	// for values. If nil, defaults are used.
+	BaseConfig *http.Server
+
+	// Handler specifies which handler to use for processing
+	// requests. If nil, BaseConfig.Handler is used. If BaseConfig
+	// or BaseConfig.Handler is nil, http.DefaultServeMux is used.
+	Handler http.Handler
+}
+
+func (o *ServeConnOpts) baseConfig() *http.Server {
+	if o != nil && o.BaseConfig != nil {
+		return o.BaseConfig
+	}
+	return new(http.Server)
+}
+
+func (o *ServeConnOpts) handler() http.Handler {
+	if o != nil {
+		if o.Handler != nil {
+			return o.Handler
+		}
+		if o.BaseConfig != nil && o.BaseConfig.Handler != nil {
+			return o.BaseConfig.Handler
+		}
+	}
+	return http.DefaultServeMux
+}
+
+// ServeConn serves HTTP/2 requests on the provided connection and
+// blocks until the connection is no longer readable.
+//
+// ServeConn starts speaking HTTP/2 assuming that c has not had any
+// reads or writes. It writes its initial settings frame and expects
+// to be able to read the preface and settings frame from the
+// client. If c has a ConnectionState method like a *tls.Conn, the
+// ConnectionState is used to verify the TLS ciphersuite and to set
+// the Request.TLS field in Handlers.
+//
+// ServeConn does not support h2c by itself. Any h2c support must be
+// implemented in terms of providing a suitably-behaving net.Conn.
+//
+// The opts parameter is optional. If nil, default values are used.
+func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
 	sc := &serverConn{
-		srv:              srv,
-		hs:               hs,
+		srv:              s,
+		hs:               opts.baseConfig(),
 		conn:             c,
 		remoteAddrStr:    c.RemoteAddr().String(),
 		bw:               newBufferedWriter(c),
-		handler:          h,
+		handler:          opts.handler(),
 		streams:          make(map[uint32]*stream),
 		readFrameCh:      make(chan readFrameResult),
 		wantWriteFrameCh: make(chan frameWriteMsg, 8),
 		wroteFrameCh:     make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync
 		bodyReadCh:       make(chan bodyReadMsg),         // buffering doesn't matter either way
 		doneServing:      make(chan struct{}),
-		advMaxStreams:    srv.maxConcurrentStreams(),
+		advMaxStreams:    s.maxConcurrentStreams(),
 		writeSched: writeScheduler{
 			maxFrameSize: initialMaxFrameSize,
 		},
@@ -232,10 +280,10 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 	sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen())
 
 	fr := NewFramer(sc.bw, c)
-	fr.SetMaxReadFrameSize(srv.maxReadFrameSize())
+	fr.SetMaxReadFrameSize(s.maxReadFrameSize())
 	sc.framer = fr
 
-	if tc, ok := c.(*tls.Conn); ok {
+	if tc, ok := c.(connectionStater); ok {
 		sc.tlsState = new(tls.ConnectionState)
 		*sc.tlsState = tc.ConnectionState()
 		// 9.2 Use of TLS Features
@@ -265,7 +313,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 			// So for now, do nothing here again.
 		}
 
-		if !srv.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
+		if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
 			// "Endpoints MAY choose to generate a connection error
 			// (Section 5.4.1) of type INADEQUATE_SECURITY if one of
 			// the prohibited cipher suites are negotiated."

+ 107 - 4
http2/server_test.go

@@ -2808,12 +2808,16 @@ func TestIssue53(t *testing.T) {
 		"\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad"
 	s := &http.Server{
 		ErrorLog: log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags),
+		Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+			w.Write([]byte("hello"))
+		}),
+	}
+	s2 := &Server{
+		MaxReadFrameSize:             1 << 16,
+		PermitProhibitedCipherSuites: true,
 	}
-	s2 := &Server{MaxReadFrameSize: 1 << 16, PermitProhibitedCipherSuites: true}
 	c := &issue53Conn{[]byte(data), false, false}
-	s2.handleConn(s, c, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-		w.Write([]byte("hello"))
-	}))
+	s2.ServeConn(c, &ServeConnOpts{BaseConfig: s})
 	if !c.closed {
 		t.Fatal("connection is not closed")
 	}
@@ -2977,3 +2981,102 @@ func TestServerNoDuplicateContentType(t *testing.T) {
 		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
 	}
 }
+
+type connStateConn struct {
+	net.Conn
+	cs tls.ConnectionState
+}
+
+func (c connStateConn) ConnectionState() tls.ConnectionState { return c.cs }
+
+// golang.org/issue/12737 -- handle any net.Conn, not just
+// *tls.Conn.
+func TestServerHandleCustomConn(t *testing.T) {
+	var s Server
+	c1, c2 := net.Pipe()
+	clientDone := make(chan struct{})
+	handlerDone := make(chan struct{})
+	var req *http.Request
+	go func() {
+		defer close(clientDone)
+		defer c2.Close()
+		fr := NewFramer(c2, c2)
+		io.WriteString(c2, ClientPreface)
+		fr.WriteSettings()
+		fr.WriteSettingsAck()
+		f, err := fr.ReadFrame()
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() {
+			t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f))
+			return
+		}
+		f, err = fr.ReadFrame()
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() {
+			t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f))
+			return
+		}
+		var henc hpackEncoder
+		fr.WriteHeaders(HeadersFrameParam{
+			StreamID:      1,
+			BlockFragment: henc.encodeHeaderRaw(t, ":method", "GET", ":path", "/", ":scheme", "https", ":authority", "foo.com"),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+		go io.Copy(ioutil.Discard, c2)
+		<-handlerDone
+	}()
+	const testString = "my custom ConnectionState"
+	fakeConnState := tls.ConnectionState{
+		ServerName: testString,
+		Version:    tls.VersionTLS12,
+	}
+	go s.ServeConn(connStateConn{c1, fakeConnState}, &ServeConnOpts{
+		BaseConfig: &http.Server{
+			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				defer close(handlerDone)
+				req = r
+			}),
+		}})
+	select {
+	case <-clientDone:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for handler")
+	}
+	if req.TLS == nil {
+		t.Fatalf("Request.TLS is nil. Got: %#v", req)
+	}
+	if req.TLS.ServerName != testString {
+		t.Fatalf("Request.TLS = %+v; want ServerName of %q", req.TLS, testString)
+	}
+}
+
+type hpackEncoder struct {
+	enc *hpack.Encoder
+	buf bytes.Buffer
+}
+
+func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte {
+	if len(headers)%2 == 1 {
+		panic("odd number of kv args")
+	}
+	he.buf.Reset()
+	if he.enc == nil {
+		he.enc = hpack.NewEncoder(&he.buf)
+	}
+	for len(headers) > 0 {
+		k, v := headers[0], headers[1]
+		err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v})
+		if err != nil {
+			t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
+		}
+		headers = headers[2:]
+	}
+	return he.buf.Bytes()
+}

+ 0 - 3
http2/transport.go

@@ -406,9 +406,6 @@ func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
 	// henc in response to SETTINGS frames?
 	cc.henc = hpack.NewEncoder(&cc.hbuf)
 
-	type connectionStater interface {
-		ConnectionState() tls.ConnectionState
-	}
 	if cs, ok := c.(connectionStater); ok {
 		state := cs.ConnectionState()
 		cc.tlsState = &state