Переглянути джерело

Refactor recv loop, inline dispatch

Make the recv loop simpler, pass the error back directly from
recv so that we dont close the connection on non fatal errors.

Use a mutex to guard the read/write buffer inside the frame from
the connection, it appears we can start reading a frame into the
framers buffer as soon as we have finished writing the frame out.
Chris Bannister 10 роки тому
батько
коміт
b49a43aff1
1 змінених файлів з 37 додано та 43 видалено
  1. 37 43
      conn.go

+ 37 - 43
conn.go

@@ -256,23 +256,24 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 // open and is therefore usually called in a separate goroutine.
 func (c *Conn) serve() {
 	var (
-		err    error
-		header frameHeader
+		err error
 	)
 
 	for {
-		header, err = c.recv()
+		err = c.recv()
 		if err != nil {
 			break
 		}
-		c.dispatch(header)
 	}
 
 	c.Close()
 	for id := 0; id < len(c.calls); id++ {
 		req := &c.calls[id]
-		if atomic.CompareAndSwapInt32(&req.active, 1, 0) {
-			req.resp <- struct{}{}
+		if atomic.CompareAndSwapInt32(&req.active, 1, -1) {
+			// we need to send the error to all waiting queries, put the state
+			// of this conn into not active so that it can not execute any queries.
+			// Here use -1.
+			req.resp <- err
 			close(req.resp)
 		}
 	}
@@ -282,7 +283,7 @@ func (c *Conn) serve() {
 	}
 }
 
-func (c *Conn) recv() (frameHeader, error) {
+func (c *Conn) recv() error {
 	// not safe for concurrent reads
 
 	// read a full header, ignore timeouts, as this is being ran in a loop
@@ -291,54 +292,34 @@ func (c *Conn) recv() (frameHeader, error) {
 	// were just reading headers over and over and copy bodies
 	head, err := readHeader(c.r, c.headerBuf)
 	if err != nil {
-		return frameHeader{}, err
+		return err
 	}
 
 	call := &c.calls[head.stream]
-
 	call.mu.Lock()
-	log.Printf("readframe stream=%v\n", head.stream)
 	err = call.framer.readFrame(&head)
 	call.mu.Unlock()
 	if err != nil {
-		return frameHeader{}, err
+		return err
 	}
 
-	if head.version.version() != c.version {
-		return frameHeader{}, NewErrProtocol("unexpected protocol version in response: got %d expected %d", head.version.version(), c.version)
+	// the caller went away somehow
+	if atomic.CompareAndSwapInt32(&call.active, 1, 0) {
+		call.resp <- nil
 	}
 
-	return head, nil
-}
-
-func (c *Conn) dispatch(header frameHeader) {
-	id := header.stream
-	if id >= len(c.calls) {
-		// should this panic?
-		return
-	}
-
-	// TODO: replace this with a sparse map
-	call := &c.calls[id]
-
-	call.resp <- struct{}{}
 	atomic.AddInt32(&c.nwait, -1)
+	c.uniq <- head.stream
 
-	c.uniq <- id
+	return nil
 }
 
 type callReq struct {
 	active int32
 	// could use a waitgroup but this allows us to do timeouts on the read/send
-	resp   chan struct{}
+	resp   chan error
 	mu     sync.Mutex
 	framer *framer
-	err    error
-}
-
-type callResp struct {
-	framer *framer
-	err    error
 }
 
 func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
@@ -346,22 +327,31 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 	stream := <-c.uniq
 
 	call := &c.calls[stream]
-	atomic.StoreInt32(&call.active, 1)
-	defer atomic.StoreInt32(&call.active, 0)
+	if !atomic.CompareAndSwapInt32(&call.active, 0, 1) {
+		panic("stream not available")
+	}
+
+	if call.resp == nil {
+		call.resp = make(chan error, 1)
+	}
 
-	// resp is basically a waiting semafore protecting the framer
-	call.resp = make(chan struct{}, 1)
+	// resp is basically a waiting semaphore protecting the framer
 
 	// log.Printf("%v: OUT stream=%d (%T) req=%v\n", c.conn.LocalAddr(), stream, req, req)
 	framer := newFramer(c, c, c.compressor, c.version)
 	defer framerPool.Put(framer)
+
 	call.framer = framer
 
 	if tracer != nil {
 		framer.trace()
 	}
 
-	log.Printf("writing frame stream=%v\n", stream)
+	// there is a race that we can read and write to the same buffer, I dont think
+	// the data will actually corrupt but to be safe and apepase the race detector gods,
+	// guard it.
+	// We could fix this by using seperate read and write buffers, which may end up
+	// being faster and easier to reason about.
 	call.mu.Lock()
 	err := req.writeFrame(framer, stream)
 	call.mu.Unlock()
@@ -369,9 +359,13 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		return nil, err
 	}
 
-	<-call.resp
-	if call.err != nil {
-		return nil, call.err
+	err = <-call.resp
+	if err != nil {
+		return nil, err
+	}
+
+	if v := framer.header.version.version(); v != c.version {
+		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
 	}
 
 	frame, err := framer.parseFrame()