瀏覽代碼

dont swallow the error from closeWithError

Try to send the error to the waiting requests so that their
callers will see the correct error before closing the quit
channel.

Ensure that this error is not nil so that exec() will not see
a nil frame.
Chris Bannister 10 年之前
父節點
當前提交
d4cdc5dd42
共有 2 個文件被更改,包括 44 次插入2 次删除
  1. 8 2
      conn.go
  2. 36 0
      conn_test.go

+ 8 - 2
conn.go

@@ -622,15 +622,21 @@ func (c *Conn) closeWithError(err error) {
 		return
 	}
 
-	close(c.quit)
-
 	for id := 0; id < len(c.calls); id++ {
 		req := &c.calls[id]
 		// we need to send the error to all waiting queries, put the state
 		// of this conn into not active so that it can not execute any queries.
 		atomic.StoreInt32(&req.waiting, -1)
+
+		if err != nil {
+			select {
+			case req.resp <- err:
+			default:
+			}
+		}
 	}
 
+	close(c.quit)
 	c.conn.Close()
 
 	if c.started && err != nil {

+ 36 - 0
conn_test.go

@@ -533,6 +533,42 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 	}
 }
 
+func TestQueryTimeoutClose(t *testing.T) {
+	srv := NewTestServer(t, defaultProto)
+	defer srv.Stop()
+
+	cluster := NewCluster(srv.Address)
+	// Set the timeout arbitrarily low so that the query hits the timeout in a
+	// timely manner.
+	cluster.Timeout = 1000 * time.Millisecond
+	cluster.NumConns = 1
+	cluster.NumStreams = 1
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatalf("NewCluster: %v", err)
+	}
+
+	ch := make(chan error)
+	go func() {
+		err := db.Query("timeout").Exec()
+		ch <- err
+	}()
+	// ensure that the above goroutine gets sheduled
+	time.Sleep(50 * time.Millisecond)
+
+	db.Close()
+	select {
+	case err = <-ch:
+	case <-time.After(1 * time.Second):
+		t.Fatal("timedout waiting to get a response once cluster is closed")
+	}
+
+	if err != ErrConnectionClosed {
+		t.Fatalf("expected to get %v got %v", ErrConnectionClosed, err)
+	}
+}
+
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {