Browse Source

Respond to PING frames.

Brad Fitzpatrick 11 years ago
parent
commit
b59345dc02
2 changed files with 84 additions and 4 deletions
  1. 20 1
      http2.go
  2. 64 3
      http2_test.go

+ 20 - 1
http2.go

@@ -386,12 +386,31 @@ func (sc *serverConn) processFrame(f Frame) error {
 		return sc.processHeaders(f)
 		return sc.processHeaders(f)
 	case *ContinuationFrame:
 	case *ContinuationFrame:
 		return sc.processContinuation(f)
 		return sc.processContinuation(f)
+	case *PingFrame:
+		return sc.processPing(f)
 	default:
 	default:
-		log.Printf("Ignoring unknown %v", f.Header)
+		log.Printf("Ignoring unknown frame %#v", f)
 		return nil
 		return nil
 	}
 	}
 }
 }
 
 
+func (sc *serverConn) processPing(f *PingFrame) error {
+	sc.serveG.check()
+	if f.Flags.Has(FlagSettingsAck) {
+		// 6.7 PING: " An endpoint MUST NOT respond to PING frames containing this flag."
+		return nil
+	}
+	if f.StreamID != 0 {
+		// "PING frames are not associated with any individual
+		// stream. If a PING frame is received with a stream
+		// identifier field value other than 0x0, the
+		// recipient MUST respond with a connection error
+		// (Section 5.4.1) of type PROTOCOL_ERROR."
+		return ConnectionError(ErrCodeProtocol)
+	}
+	return sc.framer.WritePing(true, f.Data)
+}
+
 func (sc *serverConn) processSettings(f *SettingsFrame) error {
 func (sc *serverConn) processSettings(f *SettingsFrame) error {
 	sc.serveG.check()
 	sc.serveG.check()
 	f.ForeachSetting(func(s Setting) {
 	f.ForeachSetting(func(s Setting) {

+ 64 - 3
http2_test.go

@@ -131,8 +131,31 @@ func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte)
 	}
 	}
 }
 }
 
 
+func (st *serverTester) readFrame() (Frame, error) {
+	frc := make(chan Frame, 1)
+	errc := make(chan error, 1)
+	go func() {
+		fr, err := st.fr.ReadFrame()
+		if err != nil {
+			errc <- err
+		} else {
+			frc <- fr
+		}
+	}()
+	t := time.NewTimer(2 * time.Second)
+	defer t.Stop()
+	select {
+	case f := <-frc:
+		return f, nil
+	case err := <-errc:
+		return nil, err
+	case <-t.C:
+		return nil, errors.New("timeout waiting for frame")
+	}
+}
+
 func (st *serverTester) wantSettings() *SettingsFrame {
 func (st *serverTester) wantSettings() *SettingsFrame {
-	f, err := st.fr.ReadFrame()
+	f, err := st.readFrame()
 	if err != nil {
 	if err != nil {
 		st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
 		st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
 	}
 	}
@@ -143,8 +166,20 @@ func (st *serverTester) wantSettings() *SettingsFrame {
 	return sf
 	return sf
 }
 }
 
 
+func (st *serverTester) wantPing() *PingFrame {
+	f, err := st.readFrame()
+	if err != nil {
+		st.t.Fatalf("Error while expecting a PING frame: %v", err)
+	}
+	pf, ok := f.(*PingFrame)
+	if !ok {
+		st.t.Fatalf("got a %T; want *PingFrame", f)
+	}
+	return pf
+}
+
 func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
 func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
-	f, err := st.fr.ReadFrame()
+	f, err := st.readFrame()
 	if err != nil {
 	if err != nil {
 		st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
 		st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
 	}
 	}
@@ -161,7 +196,7 @@ func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
 }
 }
 
 
 func (st *serverTester) wantSettingsAck() {
 func (st *serverTester) wantSettingsAck() {
-	f, err := st.fr.ReadFrame()
+	f, err := st.readFrame()
 	if err != nil {
 	if err != nil {
 		st.t.Fatal(err)
 		st.t.Fatal(err)
 	}
 	}
@@ -458,6 +493,32 @@ func testRejectRequest(t *testing.T, send func(*serverTester)) {
 	st.wantRSTStream(1, ErrCodeProtocol)
 	st.wantRSTStream(1, ErrCodeProtocol)
 }
 }
 
 
+func TestServer_Ping(t *testing.T) {
+	st := newServerTester(t, nil)
+	defer st.Close()
+	st.greet()
+
+	// Server should ignore this one, since it has ACK set.
+	ackPingData := [8]byte{1, 2, 4, 8, 16, 32, 64, 128}
+	if err := st.fr.WritePing(true, ackPingData); err != nil {
+		t.Fatal(err)
+	}
+
+	// But the server should reply to this one, since ACK is false.
+	pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+	if err := st.fr.WritePing(false, pingData); err != nil {
+		t.Fatal(err)
+	}
+
+	pf := st.wantPing()
+	if !pf.Flags.Has(FlagPingAck) {
+		t.Error("response ping doesn't have ACK set")
+	}
+	if pf.Data != pingData {
+		t.Errorf("response ping has data %q; want %q", pf.Data, pingData)
+	}
+}
+
 // TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
 // TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
 // TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
 // TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)