|
|
@@ -32,20 +32,28 @@ import (
|
|
|
)
|
|
|
|
|
|
type serverTester struct {
|
|
|
- cc net.Conn // client conn
|
|
|
- t *testing.T
|
|
|
- ts *httptest.Server
|
|
|
- fr *Framer
|
|
|
- logBuf *bytes.Buffer
|
|
|
- sc *serverConn
|
|
|
+ cc net.Conn // client conn
|
|
|
+ t *testing.T
|
|
|
+ ts *httptest.Server
|
|
|
+ fr *Framer
|
|
|
+ logBuf *bytes.Buffer
|
|
|
+ sc *serverConn
|
|
|
+ logFilter []string // substrings to filter out
|
|
|
}
|
|
|
|
|
|
func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
|
|
|
logBuf := new(bytes.Buffer)
|
|
|
ts := httptest.NewUnstartedServer(handler)
|
|
|
ConfigureServer(ts.Config, &Server{})
|
|
|
+
|
|
|
+ st := &serverTester{
|
|
|
+ t: t,
|
|
|
+ ts: ts,
|
|
|
+ logBuf: logBuf,
|
|
|
+ }
|
|
|
+
|
|
|
ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
|
|
|
- ts.Config.ErrorLog = log.New(io.MultiWriter(twriter{t: t}, logBuf), "", log.LstdFlags)
|
|
|
+ ts.Config.ErrorLog = log.New(io.MultiWriter(twriter{t: t, st: st}, logBuf), "", log.LstdFlags)
|
|
|
ts.StartTLS()
|
|
|
|
|
|
if VerboseLogs {
|
|
|
@@ -68,16 +76,19 @@ func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- log.SetOutput(twriter{t})
|
|
|
+ log.SetOutput(twriter{t: t, st: st})
|
|
|
+
|
|
|
+ st.cc = cc
|
|
|
+ st.fr = NewFramer(cc, cc)
|
|
|
+
|
|
|
mu.Lock()
|
|
|
- return &serverTester{
|
|
|
- t: t,
|
|
|
- ts: ts,
|
|
|
- cc: cc,
|
|
|
- fr: NewFramer(cc, cc),
|
|
|
- logBuf: logBuf,
|
|
|
- sc: sc,
|
|
|
- }
|
|
|
+ st.sc = sc
|
|
|
+ mu.Unlock() // unnecessary, but looks weird without.
|
|
|
+ return st
|
|
|
+}
|
|
|
+
|
|
|
+func (st *serverTester) addLogFilter(phrase string) {
|
|
|
+ st.logFilter = append(st.logFilter, phrase)
|
|
|
}
|
|
|
|
|
|
func (st *serverTester) stream(id uint32) *stream {
|
|
|
@@ -304,10 +315,7 @@ func TestServer(t *testing.T) {
|
|
|
|
|
|
st.writePreface()
|
|
|
st.writeInitialSettings()
|
|
|
- st.wantSettings().ForeachSetting(func(s Setting) error {
|
|
|
- t.Logf("Server sent setting %v = %v", s.ID, s.Val)
|
|
|
- return nil
|
|
|
- })
|
|
|
+ st.wantSettings()
|
|
|
st.writeSettingsAck()
|
|
|
st.wantSettingsAck()
|
|
|
|
|
|
@@ -647,7 +655,10 @@ func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) {
|
|
|
func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) {
|
|
|
// 8.1.2.3 Request Pseudo-Header Fields
|
|
|
// "All HTTP/2 requests MUST include exactly one valid value" ...
|
|
|
- testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "GET", ":method", "POST") })
|
|
|
+ testRejectRequest(t, func(st *serverTester) {
|
|
|
+ st.addLogFilter("duplicate pseudo-header")
|
|
|
+ st.bodylessReq1(":method", "GET", ":method", "POST")
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
|
|
|
@@ -658,6 +669,7 @@ func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
|
|
|
// block after a regular header field MUST be treated as
|
|
|
// malformed (Section 8.1.2.6)."
|
|
|
testRejectRequest(t, func(st *serverTester) {
|
|
|
+ st.addLogFilter("pseudo-header after regular header")
|
|
|
var buf bytes.Buffer
|
|
|
enc := hpack.NewEncoder(&buf)
|
|
|
enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
|
|
|
@@ -686,7 +698,10 @@ func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
|
|
|
- testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":unknown_thing", "") })
|
|
|
+ testRejectRequest(t, func(st *serverTester) {
|
|
|
+ st.addLogFilter(`invalid pseudo-header ":unknown_thing"`)
|
|
|
+ st.bodylessReq1(":unknown_thing", "")
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
func testRejectRequest(t *testing.T, send func(*serverTester)) {
|
|
|
@@ -977,8 +992,121 @@ func TestServer_StateTransitions(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
|
|
|
-// TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
|
|
|
+// test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
|
|
|
+func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
|
|
|
+ testServerRejects(t, func(st *serverTester) {
|
|
|
+ st.writeHeaders(HeadersFrameParam{
|
|
|
+ StreamID: 1,
|
|
|
+ BlockFragment: encodeHeader(st.t),
|
|
|
+ EndStream: true,
|
|
|
+ EndHeaders: false,
|
|
|
+ })
|
|
|
+ st.writeHeaders(HeadersFrameParam{ // Not a continuation.
|
|
|
+ StreamID: 3, // different stream.
|
|
|
+ BlockFragment: encodeHeader(st.t),
|
|
|
+ EndStream: true,
|
|
|
+ EndHeaders: true,
|
|
|
+ })
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+// test HEADERS w/o EndHeaders + PING (should get rejected)
|
|
|
+func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
|
|
|
+ testServerRejects(t, func(st *serverTester) {
|
|
|
+ st.writeHeaders(HeadersFrameParam{
|
|
|
+ StreamID: 1,
|
|
|
+ BlockFragment: encodeHeader(st.t),
|
|
|
+ EndStream: true,
|
|
|
+ EndHeaders: false,
|
|
|
+ })
|
|
|
+ if err := st.fr.WritePing(false, [8]byte{}); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+// test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
|
|
|
+func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
|
|
|
+ testServerRejects(t, func(st *serverTester) {
|
|
|
+ st.writeHeaders(HeadersFrameParam{
|
|
|
+ StreamID: 1,
|
|
|
+ BlockFragment: encodeHeader(st.t),
|
|
|
+ EndStream: true,
|
|
|
+ EndHeaders: true,
|
|
|
+ })
|
|
|
+ st.wantHeaders()
|
|
|
+ if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+// test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID
|
|
|
+func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
|
|
|
+ testServerRejects(t, func(st *serverTester) {
|
|
|
+ st.writeHeaders(HeadersFrameParam{
|
|
|
+ StreamID: 1,
|
|
|
+ BlockFragment: encodeHeader(st.t),
|
|
|
+ EndStream: true,
|
|
|
+ EndHeaders: false,
|
|
|
+ })
|
|
|
+ if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+// No HEADERS on stream 0.
|
|
|
+func TestServer_Rejects_Headers0(t *testing.T) {
|
|
|
+ testServerRejects(t, func(st *serverTester) {
|
|
|
+ st.fr.AllowIllegalWrites = true
|
|
|
+ st.writeHeaders(HeadersFrameParam{
|
|
|
+ StreamID: 0,
|
|
|
+ BlockFragment: encodeHeader(st.t),
|
|
|
+ EndStream: true,
|
|
|
+ EndHeaders: true,
|
|
|
+ })
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+// No CONTINUATION on stream 0.
|
|
|
+func TestServer_Rejects_Continuation0(t *testing.T) {
|
|
|
+ testServerRejects(t, func(st *serverTester) {
|
|
|
+ st.fr.AllowIllegalWrites = true
|
|
|
+ if err := st.fr.WriteContinuation(0, true, encodeHeader(t)); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+// testServerRejects tests that the server hangs up with a GOAWAY
|
|
|
+// frame and a server close after the client does something
|
|
|
+// deserving a CONNECTION_ERROR.
|
|
|
+func testServerRejects(t *testing.T, writeReq func(*serverTester)) {
|
|
|
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
|
|
|
+ st.addLogFilter("connection error: PROTOCOL_ERROR")
|
|
|
+ defer st.Close()
|
|
|
+ st.greet()
|
|
|
+ writeReq(st)
|
|
|
+
|
|
|
+ st.wantGoAway()
|
|
|
+ errc := make(chan error, 1)
|
|
|
+ go func() {
|
|
|
+ fr, err := st.fr.ReadFrame()
|
|
|
+ if err == nil {
|
|
|
+ err = fmt.Errorf("got frame of type %T", fr)
|
|
|
+ }
|
|
|
+ errc <- err
|
|
|
+ }()
|
|
|
+ select {
|
|
|
+ case err := <-errc:
|
|
|
+ if err != io.EOF {
|
|
|
+ t.Errorf("ReadFrame = %v; want io.EOF", err)
|
|
|
+ }
|
|
|
+ case <-time.After(2 * time.Second):
|
|
|
+ t.Error("timeout waiting for disconnect")
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
// testServerRequest sets up an idle HTTP/2 connection and lets you
|
|
|
// write a single request with writeReq, and then verify that the
|