Browse Source

Debug-only runtime tracking of funcs running on correct goroutines.

Brad Fitzpatrick 11 years ago
parent
commit
6fe7631778
4 changed files with 217 additions and 2 deletions
  1. 160 0
      gotrack.go
  2. 33 0
      gotrack_test.go
  3. 20 1
      http2.go
  4. 4 1
      http2_test.go

+ 160 - 0
gotrack.go

@@ -0,0 +1,160 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
+// Licensed under the same terms as Go itself:
+// https://code.google.com/p/go/source/browse/LICENSE
+
+// Defensive debug-only utility to track that functions run on the
+// goroutine that they're supposed to.
+
+package http2
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"runtime"
+	"strconv"
+	"sync"
+)
+
+var DebugGoroutines = false
+
+type goroutineLock uint64
+
+func newGoroutineLock() goroutineLock {
+	return goroutineLock(curGoroutineID())
+}
+
+func (g goroutineLock) check() {
+	if !DebugGoroutines {
+		return
+	}
+	if curGoroutineID() != uint64(g) {
+		panic("running on the wrong goroutine")
+	}
+}
+
+var goroutineSpace = []byte("goroutine ")
+
+func curGoroutineID() uint64 {
+	bp := littleBuf.Get().(*[]byte)
+	defer littleBuf.Put(bp)
+	b := *bp
+	b = b[:runtime.Stack(b, false)]
+	// Parse the 4707 otu of "goroutine 4707 ["
+	b = bytes.TrimPrefix(b, goroutineSpace)
+	i := bytes.IndexByte(b, ' ')
+	if i < 0 {
+		panic(fmt.Sprintf("No space found in %q", b))
+	}
+	b = b[:i]
+	n, err := parseUintBytes(b, 10, 64)
+	if err != nil {
+		panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err))
+	}
+	return n
+}
+
+var littleBuf = sync.Pool{
+	New: func() interface{} {
+		buf := make([]byte, 64)
+		return &buf
+	},
+}
+
+// parseUintBytes is like strconv.ParseUint, but using a []byte.
+func parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) {
+	var cutoff, maxVal uint64
+
+	if bitSize == 0 {
+		bitSize = int(strconv.IntSize)
+	}
+
+	s0 := s
+	switch {
+	case len(s) < 1:
+		err = strconv.ErrSyntax
+		goto Error
+
+	case 2 <= base && base <= 36:
+		// valid base; nothing to do
+
+	case base == 0:
+		// Look for octal, hex prefix.
+		switch {
+		case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'):
+			base = 16
+			s = s[2:]
+			if len(s) < 1 {
+				err = strconv.ErrSyntax
+				goto Error
+			}
+		case s[0] == '0':
+			base = 8
+		default:
+			base = 10
+		}
+
+	default:
+		err = errors.New("invalid base " + strconv.Itoa(base))
+		goto Error
+	}
+
+	n = 0
+	cutoff = cutoff64(base)
+	maxVal = 1<<uint(bitSize) - 1
+
+	for i := 0; i < len(s); i++ {
+		var v byte
+		d := s[i]
+		switch {
+		case '0' <= d && d <= '9':
+			v = d - '0'
+		case 'a' <= d && d <= 'z':
+			v = d - 'a' + 10
+		case 'A' <= d && d <= 'Z':
+			v = d - 'A' + 10
+		default:
+			n = 0
+			err = strconv.ErrSyntax
+			goto Error
+		}
+		if int(v) >= base {
+			n = 0
+			err = strconv.ErrSyntax
+			goto Error
+		}
+
+		if n >= cutoff {
+			// n*base overflows
+			n = 1<<64 - 1
+			err = strconv.ErrRange
+			goto Error
+		}
+		n *= uint64(base)
+
+		n1 := n + uint64(v)
+		if n1 < n || n1 > maxVal {
+			// n+v overflows
+			n = 1<<64 - 1
+			err = strconv.ErrRange
+			goto Error
+		}
+		n = n1
+	}
+
+	return n, nil
+
+Error:
+	return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err}
+}
+
+// Return the first number n such that n*base >= 1<<64.
+func cutoff64(base int) uint64 {
+	if base < 2 {
+		return 0
+	}
+	return (1<<64-1)/uint64(base) + 1
+}

+ 33 - 0
gotrack_test.go

@@ -0,0 +1,33 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
+// Licensed under the same terms as Go itself:
+// https://code.google.com/p/go/source/browse/LICENSE
+
+package http2
+
+import (
+	"fmt"
+	"strings"
+	"testing"
+)
+
+func TestGoroutineLock(t *testing.T) {
+	DebugGoroutines = true
+	g := newGoroutineLock()
+	g.check()
+
+	sawPanic := make(chan interface{})
+	go func() {
+		defer func() { sawPanic <- recover() }()
+		g.check() // should panic
+	}()
+	e := <-sawPanic
+	if e == nil {
+		t.Fatal("did not see panic from check in other goroutine")
+	}
+	if !strings.Contains(fmt.Sprint(e), "wrong goroutine") {
+		t.Errorf("expected on see panic about running on the wrong goroutine; got %v", e)
+	}
+}

+ 20 - 1
http2.go

@@ -73,6 +73,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 		writeHeaderCh:     make(chan headerWriteReq), // must not be buffered
 		doneServing:       make(chan struct{}),
 		maxWriteFrameSize: initialMaxFrameSize,
+		serveG:            newGoroutineLock(),
 	}
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
 	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
@@ -89,6 +90,7 @@ type frameAndProcessed struct {
 }
 
 type serverConn struct {
+	// Immutable:
 	hs             *http.Server
 	conn           net.Conn
 	handler        http.Handler
@@ -97,6 +99,9 @@ type serverConn struct {
 	readFrameCh    chan frameAndProcessed // written by serverConn.readFrames
 	readFrameErrCh chan error
 	writeHeaderCh  chan headerWriteReq // must not be buffered
+	serveG         goroutineLock       // used to verify funcs are on serve()
+
+	// Everything following is owned by the serve loop; use serveG.check()
 
 	maxStreamID uint32 // max ever seen
 	streams     map[uint32]*stream
@@ -139,6 +144,7 @@ type stream struct {
 }
 
 func (sc *serverConn) state(streamID uint32) streamState {
+	sc.serveG.check()
 	// http://http2.github.io/http2-spec/#rfc.section.5.1
 	if st, ok := sc.streams[streamID]; ok {
 		return st.state
@@ -170,6 +176,7 @@ func (sc *serverConn) logf(format string, args ...interface{}) {
 }
 
 func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
+	sc.serveG.check()
 	switch {
 	case !validHeader(f.Name):
 		sc.invalidHeader = true
@@ -199,6 +206,7 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 }
 
 func (sc *serverConn) canonicalHeader(v string) string {
+	sc.serveG.check()
 	// TODO: use a sync.Pool instead of putting the cache on *serverConn?
 	cv, ok := sc.canonHeader[v]
 	if !ok {
@@ -208,6 +216,8 @@ func (sc *serverConn) canonicalHeader(v string) string {
 	return cv
 }
 
+// readFrames is the loop that reads incoming frames.
+// It's run on its own goroutine.
 func (sc *serverConn) readFrames() {
 	processed := make(chan struct{}, 1)
 	for {
@@ -223,6 +233,7 @@ func (sc *serverConn) readFrames() {
 }
 
 func (sc *serverConn) serve() {
+	sc.serveG.check()
 	defer sc.conn.Close()
 	defer close(sc.doneServing)
 
@@ -316,6 +327,7 @@ func (sc *serverConn) serve() {
 }
 
 func (sc *serverConn) resetStreamInLoop(se StreamError) error {
+	sc.serveG.check()
 	if err := sc.framer.WriteRSTStream(se.streamID, uint32(se.code)); err != nil {
 		return err
 	}
@@ -324,6 +336,8 @@ func (sc *serverConn) resetStreamInLoop(se StreamError) error {
 }
 
 func (sc *serverConn) processFrame(f Frame) error {
+	sc.serveG.check()
+
 	if s := sc.curHeaderStreamID; s != 0 {
 		if cf, ok := f.(*ContinuationFrame); !ok {
 			return ConnectionError(ErrCodeProtocol)
@@ -346,6 +360,7 @@ func (sc *serverConn) processFrame(f Frame) error {
 }
 
 func (sc *serverConn) processSettings(f *SettingsFrame) error {
+	sc.serveG.check()
 	f.ForeachSetting(func(s Setting) {
 		log.Printf("  setting %s = %v", s.ID, s.Val)
 	})
@@ -353,6 +368,7 @@ func (sc *serverConn) processSettings(f *SettingsFrame) error {
 }
 
 func (sc *serverConn) processHeaders(f *HeadersFrame) error {
+	sc.serveG.check()
 	id := f.Header().StreamID
 
 	// http://http2.github.io/http2-spec/#rfc.section.5.1.1
@@ -386,6 +402,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 }
 
 func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
+	sc.serveG.check()
 	id := f.Header().StreamID
 	if sc.curHeaderStreamID != id {
 		return ConnectionError(ErrCodeProtocol)
@@ -394,6 +411,7 @@ func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
 }
 
 func (sc *serverConn) processHeaderBlockFragment(streamID uint32, frag []byte, end bool) error {
+	sc.serveG.check()
 	if _, err := sc.hpackDecoder.Write(frag); err != nil {
 		// TODO: convert to stream error I assume?
 		return err
@@ -423,6 +441,7 @@ func (sc *serverConn) processHeaderBlockFragment(streamID uint32, frag []byte, e
 	return nil
 }
 
+// Run on its own goroutine.
 func (sc *serverConn) startHandler(streamID uint32, bodyOpen bool, method, path, scheme, authority string, reqHeader http.Header) {
 	var tlsState *tls.ConnectionState // make this non-nil if https
 	if scheme == "https" {
@@ -486,8 +505,8 @@ func (sc *serverConn) writeHeader(req headerWriteReq) {
 	sc.writeHeaderCh <- req
 }
 
-// called from serverConn.serve loop.
 func (sc *serverConn) writeHeaderInLoop(req headerWriteReq) error {
+	sc.serveG.check()
 	sc.headerWriteBuf.Reset()
 	// TODO: remove this strconv
 	sc.hpackEncoder.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(req.httpResCode)})

+ 4 - 1
http2_test.go

@@ -29,7 +29,10 @@ import (
 	"github.com/bradfitz/http2/hpack"
 )
 
-func init() { VerboseLogs = true }
+func init() {
+	VerboseLogs = true
+	DebugGoroutines = true
+}
 
 type serverTester struct {
 	cc     net.Conn // client conn