Browse Source

Allow Host or :authority in requests. More tests.

Brad Fitzpatrick 11 years ago
parent
commit
3302cb09fe
2 changed files with 54 additions and 0 deletions
  1. 3 0
      http2.go
  2. 51 0
      http2_test.go

+ 3 - 0
http2.go

@@ -386,6 +386,9 @@ func (sc *serverConn) startHandler(streamID uint32, bodyOpen bool, method, path,
 		// TODO: get from sc's ConnectionState
 		// TODO: get from sc's ConnectionState
 		tlsState = &tls.ConnectionState{}
 		tlsState = &tls.ConnectionState{}
 	}
 	}
+	if authority == "" {
+		authority = reqHeader.Get("Host")
+	}
 	req := &http.Request{
 	req := &http.Request{
 		Method:     method,
 		Method:     method,
 		URL:        &url.URL{},
 		URL:        &url.URL{},

+ 51 - 0
http2_test.go

@@ -12,6 +12,7 @@ import (
 	"crypto/tls"
 	"crypto/tls"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"io"
 	"log"
 	"log"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
@@ -196,6 +197,53 @@ func TestServer_Request_Get(t *testing.T) {
 		if !reflect.DeepEqual(r.Header, wantHeader) {
 		if !reflect.DeepEqual(r.Header, wantHeader) {
 			t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
 			t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
 		}
 		}
+		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
+			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
+		}
+	})
+}
+
+// Using a Host header, instead of :authority
+func TestServer_Request_Get_Host(t *testing.T) {
+	const host = "example.com"
+	testServerRequest(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID: 1, // clients send odd numbers
+			BlockFragment: encodeHeader(t,
+				":method", "GET",
+				":path", "/",
+				":scheme", "https",
+				"host", host,
+			),
+			EndStream:  true,
+			EndHeaders: true,
+		})
+	}, func(r *http.Request) {
+		if r.Host != host {
+			t.Errorf("Host = %q; want %q", r.Host, host)
+		}
+	})
+}
+
+// Using an :authority pseudo-header, instead of Host
+func TestServer_Request_Get_Authority(t *testing.T) {
+	const host = "example.com"
+	testServerRequest(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID: 1, // clients send odd numbers
+			BlockFragment: encodeHeader(t,
+				":method", "GET",
+				":path", "/",
+				":scheme", "https",
+				":authority", host,
+			),
+			EndStream:  true,
+			EndHeaders: true,
+		})
+	}, func(r *http.Request) {
+		if r.Host != host {
+			t.Errorf("Host = %q; want %q", r.Host, host)
+		}
 	})
 	})
 }
 }
 
 
@@ -205,6 +253,9 @@ func TestServer_Request_Get(t *testing.T) {
 func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) {
 func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) {
 	gotReq := make(chan bool, 1)
 	gotReq := make(chan bool, 1)
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		if r.Body == nil {
+			t.Fatal("nil Body")
+		}
 		checkReq(r)
 		checkReq(r)
 		gotReq <- true
 		gotReq <- true
 	})
 	})