瀏覽代碼

Refactor recv loop, inline dispatch

Make the recv loop simpler, pass the error back directly from
recv so that we dont close the connection on non fatal errors.

Use a mutex to guard the read/write buffer inside the frame from
the connection, it appears we can start reading a frame into the
framers buffer as soon as we have finished writing the frame out.
Chris Bannister 10 年之前
父節點
當前提交
b49a43aff1
共有 1 個文件被更改,包括 37 次插入43 次删除
  1. 37 43
      conn.go

+ 37 - 43
conn.go

@@ -256,23 +256,24 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 // open and is therefore usually called in a separate goroutine.
 func (c *Conn) serve() {
 	var (
-		err    error
-		header frameHeader
+		err error
 	)
 
 	for {
-		header, err = c.recv()
+		err = c.recv()
 		if err != nil {
 			break
 		}
-		c.dispatch(header)
 	}
 
 	c.Close()
 	for id := 0; id < len(c.calls); id++ {
 		req := &c.calls[id]
-		if atomic.CompareAndSwapInt32(&req.active, 1, 0) {
-			req.resp <- struct{}{}
+		if atomic.CompareAndSwapInt32(&req.active, 1, -1) {
+			// we need to send the error to all waiting queries, put the state
+			// of this conn into not active so that it can not execute any queries.
+			// Here use -1.
+			req.resp <- err
 			close(req.resp)
 		}
 	}
@@ -282,7 +283,7 @@ func (c *Conn) serve() {
 	}
 }
 
-func (c *Conn) recv() (frameHeader, error) {
+func (c *Conn) recv() error {
 	// not safe for concurrent reads
 
 	// read a full header, ignore timeouts, as this is being ran in a loop
@@ -291,54 +292,34 @@ func (c *Conn) recv() (frameHeader, error) {
 	// were just reading headers over and over and copy bodies
 	head, err := readHeader(c.r, c.headerBuf)
 	if err != nil {
-		return frameHeader{}, err
+		return err
 	}
 
 	call := &c.calls[head.stream]
-
 	call.mu.Lock()
-	log.Printf("readframe stream=%v\n", head.stream)
 	err = call.framer.readFrame(&head)
 	call.mu.Unlock()
 	if err != nil {
-		return frameHeader{}, err
+		return err
 	}
 
-	if head.version.version() != c.version {
-		return frameHeader{}, NewErrProtocol("unexpected protocol version in response: got %d expected %d", head.version.version(), c.version)
+	// the caller went away somehow
+	if atomic.CompareAndSwapInt32(&call.active, 1, 0) {
+		call.resp <- nil
 	}
 
-	return head, nil
-}
-
-func (c *Conn) dispatch(header frameHeader) {
-	id := header.stream
-	if id >= len(c.calls) {
-		// should this panic?
-		return
-	}
-
-	// TODO: replace this with a sparse map
-	call := &c.calls[id]
-
-	call.resp <- struct{}{}
 	atomic.AddInt32(&c.nwait, -1)
+	c.uniq <- head.stream
 
-	c.uniq <- id
+	return nil
 }
 
 type callReq struct {
 	active int32
 	// could use a waitgroup but this allows us to do timeouts on the read/send
-	resp   chan struct{}
+	resp   chan error
 	mu     sync.Mutex
 	framer *framer
-	err    error
-}
-
-type callResp struct {
-	framer *framer
-	err    error
 }
 
 func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
@@ -346,22 +327,31 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 	stream := <-c.uniq
 
 	call := &c.calls[stream]
-	atomic.StoreInt32(&call.active, 1)
-	defer atomic.StoreInt32(&call.active, 0)
+	if !atomic.CompareAndSwapInt32(&call.active, 0, 1) {
+		panic("stream not available")
+	}
+
+	if call.resp == nil {
+		call.resp = make(chan error, 1)
+	}
 
-	// resp is basically a waiting semafore protecting the framer
-	call.resp = make(chan struct{}, 1)
+	// resp is basically a waiting semaphore protecting the framer
 
 	// log.Printf("%v: OUT stream=%d (%T) req=%v\n", c.conn.LocalAddr(), stream, req, req)
 	framer := newFramer(c, c, c.compressor, c.version)
 	defer framerPool.Put(framer)
+
 	call.framer = framer
 
 	if tracer != nil {
 		framer.trace()
 	}
 
-	log.Printf("writing frame stream=%v\n", stream)
+	// there is a race that we can read and write to the same buffer, I dont think
+	// the data will actually corrupt but to be safe and apepase the race detector gods,
+	// guard it.
+	// We could fix this by using seperate read and write buffers, which may end up
+	// being faster and easier to reason about.
 	call.mu.Lock()
 	err := req.writeFrame(framer, stream)
 	call.mu.Unlock()
@@ -369,9 +359,13 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		return nil, err
 	}
 
-	<-call.resp
-	if call.err != nil {
-		return nil, call.err
+	err = <-call.resp
+	if err != nil {
+		return nil, err
+	}
+
+	if v := framer.header.version.version(); v != c.version {
+		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
 	}
 
 	frame, err := framer.parseFrame()