Browse Source

Fix Request's ContentLength and Headers in handlers + more tests

Brad Fitzpatrick 11 years ago
parent
commit
137b013472
2 changed files with 118 additions and 5 deletions
  1. 7 4
      http2.go
  2. 111 1
      http2_test.go

+ 7 - 4
http2.go

@@ -390,6 +390,7 @@ func (sc *serverConn) startHandler(streamID uint32, bodyOpen bool, method, path,
 		Method:     method,
 		Method:     method,
 		URL:        &url.URL{},
 		URL:        &url.URL{},
 		RemoteAddr: sc.conn.RemoteAddr().String(),
 		RemoteAddr: sc.conn.RemoteAddr().String(),
+		Header:     reqHeader,
 		RequestURI: path,
 		RequestURI: path,
 		Proto:      "HTTP/2.0",
 		Proto:      "HTTP/2.0",
 		ProtoMajor: 2,
 		ProtoMajor: 2,
@@ -402,10 +403,12 @@ func (sc *serverConn) startHandler(streamID uint32, bodyOpen bool, method, path,
 			hasBody:  bodyOpen,
 			hasBody:  bodyOpen,
 		},
 		},
 	}
 	}
-	if vv, ok := reqHeader["Content-Length"]; ok {
-		req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
-	} else {
-		req.ContentLength = -1
+	if bodyOpen {
+		if vv, ok := reqHeader["Content-Length"]; ok {
+			req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
+		} else {
+			req.ContentLength = -1
+		}
 	}
 	}
 	rw := &responseWriter{
 	rw := &responseWriter{
 		sc:       sc,
 		sc:       sc,

+ 111 - 1
http2_test.go

@@ -8,6 +8,7 @@
 package http2
 package http2
 
 
 import (
 import (
+	"bytes"
 	"crypto/tls"
 	"crypto/tls"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
@@ -17,11 +18,14 @@ import (
 	"net/http/httptest"
 	"net/http/httptest"
 	"os"
 	"os"
 	"os/exec"
 	"os/exec"
+	"reflect"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync/atomic"
 	"sync/atomic"
 	"testing"
 	"testing"
 	"time"
 	"time"
+
+	"github.com/bradfitz/http2/hpack"
 )
 )
 
 
 type serverTester struct {
 type serverTester struct {
@@ -82,6 +86,12 @@ func (st *serverTester) writeSettingsAck() {
 	}
 	}
 }
 }
 
 
+func (st *serverTester) writeHeaders(p HeadersFrameParam) {
+	if err := st.fr.WriteHeaders(p); err != nil {
+		st.t.Fatalf("Error writing HEADERS: %v", err)
+	}
+}
+
 func (st *serverTester) wantSettings() *SettingsFrame {
 func (st *serverTester) wantSettings() *SettingsFrame {
 	f, err := st.fr.ReadFrame()
 	f, err := st.fr.ReadFrame()
 	if err != nil {
 	if err != nil {
@@ -110,8 +120,11 @@ func (st *serverTester) wantSettingsAck() {
 }
 }
 
 
 func TestServer(t *testing.T) {
 func TestServer(t *testing.T) {
+	gotReq := make(chan bool, 1)
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Foo", "Bar")
 		w.Header().Set("Foo", "Bar")
+		t.Logf("GOT REQUEST %#v", r)
+		gotReq <- true
 	})
 	})
 	defer st.Close()
 	defer st.Close()
 
 
@@ -129,7 +142,88 @@ func TestServer(t *testing.T) {
 	st.writeSettingsAck()
 	st.writeSettingsAck()
 	st.wantSettingsAck()
 	st.wantSettingsAck()
 
 
-	// TODO: send a request
+	st.writeHeaders(HeadersFrameParam{
+		StreamID: 1, // clients send odd numbers
+		BlockFragment: encodeHeader(t,
+			":method", "GET",
+			":path", "/",
+			":scheme", "https",
+		),
+		EndStream:  true, // no DATA frames
+		EndHeaders: true,
+	})
+
+	select {
+	case <-gotReq:
+	case <-time.After(2 * time.Second):
+		t.Error("timeout waiting for request")
+	}
+}
+
+func TestServer_Request_Get(t *testing.T) {
+	testServerRequest(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID: 1, // clients send odd numbers
+			BlockFragment: encodeHeader(t,
+				":method", "GET",
+				":path", "/",
+				":scheme", "https",
+				"foo-bar", "some-value",
+			),
+			EndStream:  true, // no DATA frames
+			EndHeaders: true,
+		})
+	}, func(r *http.Request) {
+		t.Logf("GOT %#v", r)
+		if r.Method != "GET" {
+			t.Errorf("Method = %q; want GET", r.Method)
+		}
+		if r.ContentLength != 0 {
+			t.Errorf("ContentLength = %v; want 0", r.ContentLength)
+		}
+		if r.Close {
+			t.Error("Close = true; want false")
+		}
+		if !strings.Contains(r.RemoteAddr, ":") {
+			t.Errorf("RemoteAddr = %q; want something with a colon", r.RemoteAddr)
+		}
+		if r.Proto != "HTTP/2.0" || r.ProtoMajor != 2 || r.ProtoMinor != 0 {
+			t.Errorf("Proto = %q Major=%v,Minor=%v; want HTTP/2.0", r.Proto, r.ProtoMajor, r.ProtoMinor)
+		}
+		wantHeader := http.Header{
+			"Foo-Bar": []string{"some-value"},
+		}
+		if !reflect.DeepEqual(r.Header, wantHeader) {
+			t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
+		}
+	})
+}
+
+// testServerRequest sets up an idle HTTP/2 connection and lets you
+// write a single request with writeReq, and then verify that the
+// *http.Request is built correctly in checkReq.
+func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) {
+	gotReq := make(chan bool, 1)
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		checkReq(r)
+		gotReq <- true
+	})
+	defer st.Close()
+
+	st.writePreface()
+	st.writeInitialSettings()
+	st.wantSettings()
+	st.writeSettingsAck()
+	st.wantSettingsAck()
+
+	writeReq(st)
+
+	select {
+	case <-gotReq:
+	case <-time.After(2 * time.Second):
+		t.Error("timeout waiting for request")
+	}
+
 }
 }
 
 
 func TestServerWithCurl(t *testing.T) {
 func TestServerWithCurl(t *testing.T) {
@@ -237,3 +331,19 @@ func (w twriter) Write(p []byte) (n int, err error) {
 	w.t.Logf("%s", p)
 	w.t.Logf("%s", p)
 	return len(p), nil
 	return len(p), nil
 }
 }
+
+func encodeHeader(t *testing.T, kv ...string) []byte {
+	if len(kv)%2 == 1 {
+		panic("odd number of kv args")
+	}
+	var buf bytes.Buffer
+	enc := hpack.NewEncoder(&buf)
+	for len(kv) > 0 {
+		k, v := kv[0], kv[1]
+		kv = kv[2:]
+		if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
+			t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
+		}
+	}
+	return buf.Bytes()
+}