소스 검색

ensure streams are returned in the correct order

When timeouts occur we need to ensure that the stream is not returned
until the server returns the stream.
Chris Bannister 10 년 전
부모
커밋
bca09cc483
1개의 변경된 파일30개의 추가작업 그리고 15개의 파일을 삭제
  1. 30 15
      conn.go

+ 30 - 15
conn.go

@@ -178,7 +178,7 @@ func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn,
 	}
 
 	for i := 0; i < cfg.NumStreams; i++ {
-		c.calls[i].resp = make(chan error, 1)
+		c.calls[i].resp = make(chan error)
 		c.uniq <- i
 	}
 
@@ -334,11 +334,17 @@ func (c *Conn) recv() error {
 		}
 	}
 
+	if !atomic.CompareAndSwapInt32(&call.waiting, 1, 0) {
+		// the waiting thread timed out and is no longer waiting, the stream has
+		// not yet been readded to the chan so it cant be used again,
+		c.releaseStream(head.stream)
+		return nil
+	}
+
 	select {
 	case call.resp <- err:
 	default:
 		c.releaseStream(head.stream)
-		// in case the caller timedout
 	}
 
 	return nil
@@ -346,11 +352,16 @@ func (c *Conn) recv() error {
 
 type callReq struct {
 	// could use a waitgroup but this allows us to do timeouts on the read/send
-	resp   chan error
-	framer *framer
+	resp    chan error
+	framer  *framer
+	waiting int32
 }
 
 func (c *Conn) releaseStream(stream int) {
+	call := &c.calls[stream]
+	framerPool.Put(call.framer)
+	call.framer = nil
+
 	select {
 	case c.uniq <- stream:
 	default:
@@ -376,20 +387,27 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		framer.trace()
 	}
 
+	atomic.StoreInt32(&call.waiting, 1)
+	defer atomic.StoreInt32(&call.waiting, 0)
+
 	err := req.writeFrame(framer, stream)
 	if err != nil {
 		return nil, err
 	}
 
-	err = <-call.resp
-
-	// dont release the stream if detect a timeout as another request can reuse
-	// that stream and get a response for the old request, which we have no
-	// easy way of detecting.
-	defer c.releaseStream(stream)
+	select {
+	case err := <-call.resp:
+		// dont release the stream if detect a timeout as another request can reuse
+		// that stream and get a response for the old request, which we have no
+		// easy way of detecting.
+		defer c.releaseStream(stream)
 
-	if err != nil {
-		return nil, err
+		if err != nil {
+			return nil, err
+		}
+	case <-time.After(c.timeout):
+		c.handleTimeout()
+		return nil, ErrTimeoutNoResponse
 	}
 
 	if v := framer.header.version.version(); v != c.version {
@@ -405,9 +423,6 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		tracer.Trace(framer.traceID)
 	}
 
-	framerPool.Put(framer)
-	call.framer = nil
-
 	return frame, nil
 }