Browse Source

refactoring cleanup

Brad Fitzpatrick 11 years ago
parent
commit
bc7d81f131
1 changed files with 74 additions and 66 deletions
  1. 74 66
      http2.go

+ 74 - 66
http2.go

@@ -110,29 +110,28 @@ type serverConn struct {
 	flow           *flow               // the connection-wide one
 
 	// Everything following is owned by the serve loop; use serveG.check()
-
-	maxStreamID uint32 // max ever seen
-	streams     map[uint32]*stream
-
+	maxStreamID       uint32 // max ever seen
+	streams           map[uint32]*stream
 	maxWriteFrameSize uint32 // TODO: update this when settings come in
 	initialWindowSize int32
+	canonHeader       map[string]string // http2-lower-case -> Go-Canonical-Case
 	sentGoAway        bool
+	req               requestParam // non-zero while reading request headers
+	headerWriteBuf    bytes.Buffer // used to write response headers
+}
 
-	// State related to parsing current headers:
+// requestParam is the state of the next request, initialized over
+// potentially several frames HEADERS + zero or more CONTINUATION
+// frames.
+type requestParam struct {
+	// stream is non-nil if we're reading (HEADER or CONTINUATION)
+	// frames for a request (but not DATA).
+	stream            *stream
 	header            http.Header
-	canonHeader       map[string]string // http2-lower-case -> Go-Canonical-Case
 	method, path      string
 	scheme, authority string
 	sawRegularHeader  bool // saw a non-pseudo header already
-	invalidHeader     bool
-
-	// curHeaderStreamID and curStream are non-zero if we're in
-	// the middle of parsing headers that span multiple frames.
-	curHeaderStreamID uint32
-	curStream         *stream
-
-	// State related to writing current headers:
-	headerWriteBuf bytes.Buffer // written/accessed from serve goroutine
+	invalidHeader     bool // an invalid header was seen
 }
 
 type streamState int
@@ -202,23 +201,23 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 	sc.serveG.check()
 	switch {
 	case !validHeader(f.Name):
-		sc.invalidHeader = true
+		sc.req.invalidHeader = true
 	case strings.HasPrefix(f.Name, ":"):
-		if sc.sawRegularHeader {
+		if sc.req.sawRegularHeader {
 			sc.logf("pseudo-header after regular header")
-			sc.invalidHeader = true
+			sc.req.invalidHeader = true
 			return
 		}
 		var dst *string
 		switch f.Name {
 		case ":method":
-			dst = &sc.method
+			dst = &sc.req.method
 		case ":path":
-			dst = &sc.path
+			dst = &sc.req.path
 		case ":scheme":
-			dst = &sc.scheme
+			dst = &sc.req.scheme
 		case ":authority":
-			dst = &sc.authority
+			dst = &sc.req.authority
 		default:
 			// 8.1.2.1 Pseudo-Header Fields
 			// "Endpoints MUST treat a request or response
@@ -226,25 +225,25 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 			// pseudo-header fields as malformed (Section
 			// 8.1.2.6)."
 			sc.logf("invalid pseudo-header %q", f.Name)
-			sc.invalidHeader = true
+			sc.req.invalidHeader = true
 			return
 		}
 		if *dst != "" {
 			sc.logf("duplicate pseudo-header %q sent", f.Name)
-			sc.invalidHeader = true
+			sc.req.invalidHeader = true
 			return
 		}
 		*dst = f.Value
 	case f.Name == "cookie":
-		sc.sawRegularHeader = true
-		if s, ok := sc.header["Cookie"]; ok && len(s) == 1 {
+		sc.req.sawRegularHeader = true
+		if s, ok := sc.req.header["Cookie"]; ok && len(s) == 1 {
 			s[0] = s[0] + "; " + f.Value
 		} else {
-			sc.header.Add("Cookie", f.Value)
+			sc.req.header.Add("Cookie", f.Value)
 		}
 	default:
-		sc.sawRegularHeader = true
-		sc.header.Add(sc.canonicalHeader(f.Name), f.Value)
+		sc.req.sawRegularHeader = true
+		sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value)
 	}
 }
 
@@ -389,10 +388,19 @@ func (sc *serverConn) resetStreamInLoop(se StreamError) error {
 	return nil
 }
 
+func (sc *serverConn) curHeaderStreamID() uint32 {
+	sc.serveG.check()
+	st := sc.req.stream
+	if st == nil {
+		return 0
+	}
+	return st.id
+}
+
 func (sc *serverConn) processFrame(f Frame) error {
 	sc.serveG.check()
 
-	if s := sc.curHeaderStreamID; s != 0 {
+	if s := sc.curHeaderStreamID(); s != 0 {
 		if cf, ok := f.(*ContinuationFrame); !ok {
 			return ConnectionError(ErrCodeProtocol)
 		} else if cf.Header().StreamID != s {
@@ -516,7 +524,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 		return nil
 	}
 	// http://http2.github.io/http2-spec/#rfc.section.5.1.1
-	if id%2 != 1 || id <= sc.maxStreamID {
+	if id%2 != 1 || id <= sc.maxStreamID || sc.req.stream != nil {
 		// Streams initiated by a client MUST use odd-numbered
 		// stream identifiers. [...] The identifier of a newly
 		// established stream MUST be numerically greater than all
@@ -529,7 +537,6 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	if id > sc.maxStreamID {
 		sc.maxStreamID = id
 	}
-
 	st := &stream{
 		id:    id,
 		state: stateOpen,
@@ -539,23 +546,17 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 		st.state = stateHalfClosedRemote
 	}
 	sc.streams[id] = st
-
-	sc.header = make(http.Header)
-	sc.method = ""
-	sc.path = ""
-	sc.scheme = ""
-	sc.authority = ""
-	sc.invalidHeader = false
-	sc.sawRegularHeader = false
-	sc.curHeaderStreamID = id
-	sc.curStream = st
+	sc.req = requestParam{
+		stream: st,
+		header: make(http.Header),
+	}
 	return sc.processHeaderBlockFragment(id, f.HeaderBlockFragment(), f.HeadersEnded())
 }
 
 func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
 	sc.serveG.check()
 	id := f.Header().StreamID
-	if sc.curHeaderStreamID != id {
+	if sc.curHeaderStreamID() != id {
 		return ConnectionError(ErrCodeProtocol)
 	}
 	return sc.processHeaderBlockFragment(id, f.HeaderBlockFragment(), f.HeadersEnded())
@@ -574,8 +575,20 @@ func (sc *serverConn) processHeaderBlockFragment(streamID uint32, frag []byte, e
 		// TODO: convert to stream error I assume?
 		return err
 	}
-	if sc.invalidHeader || sc.method == "" || sc.path == "" ||
-		(sc.scheme != "https" && sc.scheme != "http") {
+	rw, req, err := sc.newWriterAndRequest()
+	sc.req = requestParam{}
+	if err != nil {
+		return err
+	}
+	go sc.runHandler(rw, req)
+	return nil
+}
+
+func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, error) {
+	sc.serveG.check()
+	rp := &sc.req
+	if rp.invalidHeader || rp.method == "" || rp.path == "" ||
+		(rp.scheme != "https" && rp.scheme != "http") {
 		// See 8.1.2.6 Malformed Requests and Responses:
 		//
 		// Malformed requests or responses that are detected
@@ -586,34 +599,24 @@ func (sc *serverConn) processHeaderBlockFragment(streamID uint32, frag []byte, e
 		// "All HTTP/2 requests MUST include exactly one valid
 		// value for the :method, :scheme, and :path
 		// pseudo-header fields"
-		return StreamError{streamID, ErrCodeProtocol}
+		return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
 	}
-	curStream := sc.curStream
-	sc.curHeaderStreamID = 0
-	sc.curStream = nil
-
-	// TODO: transition streamID state
-	go sc.startHandler(curStream.id, curStream.state == stateOpen, sc.method, sc.path, sc.scheme, sc.authority, sc.header)
-
-	return nil
-}
-
-// Run on its own goroutine.
-func (sc *serverConn) startHandler(streamID uint32, bodyOpen bool, method, path, scheme, authority string, reqHeader http.Header) {
 	var tlsState *tls.ConnectionState // make this non-nil if https
-	if scheme == "https" {
+	if rp.scheme == "https" {
 		// TODO: get from sc's ConnectionState
 		tlsState = &tls.ConnectionState{}
 	}
+	authority := rp.authority
 	if authority == "" {
-		authority = reqHeader.Get("Host")
+		authority = rp.header.Get("Host")
 	}
+	bodyOpen := rp.stream.state == stateOpen
 	req := &http.Request{
-		Method:     method,
+		Method:     rp.method,
 		URL:        &url.URL{},
 		RemoteAddr: sc.conn.RemoteAddr().String(),
-		Header:     reqHeader,
-		RequestURI: path,
+		Header:     rp.header,
+		RequestURI: rp.path,
 		Proto:      "HTTP/2.0",
 		ProtoMajor: 2,
 		ProtoMinor: 0,
@@ -621,12 +624,12 @@ func (sc *serverConn) startHandler(streamID uint32, bodyOpen bool, method, path,
 		Host:       authority,
 		Body: &requestBody{
 			sc:       sc,
-			streamID: streamID,
+			streamID: rp.stream.id,
 			hasBody:  bodyOpen,
 		},
 	}
 	if bodyOpen {
-		if vv, ok := reqHeader["Content-Length"]; ok {
+		if vv, ok := rp.header["Content-Length"]; ok {
 			req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
 		} else {
 			req.ContentLength = -1
@@ -634,8 +637,13 @@ func (sc *serverConn) startHandler(streamID uint32, bodyOpen bool, method, path,
 	}
 	rw := &responseWriter{
 		sc:       sc,
-		streamID: streamID,
+		streamID: rp.stream.id,
 	}
+	return rw, req, nil
+}
+
+// Run on its own goroutine.
+func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request) {
 	defer rw.handlerDone()
 	// TODO: catch panics like net/http.Server
 	sc.handler.ServeHTTP(rw, req)