Browse Source

add Framer.WriteContinuation

Brad Fitzpatrick 11 years ago
parent
commit
2888626574
2 changed files with 73 additions and 0 deletions
  1. 17 0
      frame.go
  2. 56 0
      frame_test.go

+ 17 - 0
frame.go

@@ -778,6 +778,23 @@ func (f *ContinuationFrame) HeadersEnded() bool {
 	return f.FrameHeader.Flags.Has(FlagContinuationEndHeaders)
 }
 
+// WriteContinuation writes a CONTINUATION frame.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error {
+	if !validStreamID(streamID) && !f.AllowIllegalWrites {
+		return errStreamID
+	}
+	var flags Flags
+	if endHeaders {
+		flags |= FlagContinuationEndHeaders
+	}
+	f.startWrite(FrameContinuation, flags, streamID)
+	f.wbuf = append(f.wbuf, headerBlockFragment...)
+	return f.endWrite()
+}
+
 func readByte(p []byte) (remain []byte, b byte, err error) {
 	if len(p) == 0 {
 		return nil, 0, io.ErrUnexpectedEOF

+ 56 - 0
frame_test.go

@@ -190,3 +190,59 @@ func TestWriteHeaders(t *testing.T) {
 		}
 	}
 }
+
+func TestWriteContinuation(t *testing.T) {
+	const streamID = 42
+	tests := []struct {
+		name string
+		end  bool
+		frag []byte
+
+		wantFrame *ContinuationFrame
+	}{
+		{
+			"not end",
+			false,
+			[]byte("abc"),
+			&ContinuationFrame{
+				FrameHeader: FrameHeader{
+					valid:    true,
+					StreamID: streamID,
+					Type:     FrameContinuation,
+					Length:   uint32(len("abc")),
+				},
+				headerFragBuf: []byte("abc"),
+			},
+		},
+		{
+			"end",
+			true,
+			[]byte("def"),
+			&ContinuationFrame{
+				FrameHeader: FrameHeader{
+					valid:    true,
+					StreamID: streamID,
+					Type:     FrameContinuation,
+					Flags:    FlagContinuationEndHeaders,
+					Length:   uint32(len("def")),
+				},
+				headerFragBuf: []byte("def"),
+			},
+		},
+	}
+	for _, tt := range tests {
+		fr, _ := testFramer()
+		if err := fr.WriteContinuation(streamID, tt.end, tt.frag); err != nil {
+			t.Errorf("test %q: %v", tt.name, err)
+			continue
+		}
+		f, err := fr.ReadFrame()
+		if err != nil {
+			t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
+			continue
+		}
+		if !reflect.DeepEqual(f, tt.wantFrame) {
+			t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
+		}
+	}
+}