Browse Source

Merge pull request #431 from Zariel/dont-swallow-close-with-error

dont swallow the error from closeWithError
Chris Bannister 10 years ago
parent
commit
8c4377d735
2 changed files with 44 additions and 2 deletions
  1. 8 2
      conn.go
  2. 36 0
      conn_test.go

+ 8 - 2
conn.go

@@ -622,15 +622,21 @@ func (c *Conn) closeWithError(err error) {
 		return
 		return
 	}
 	}
 
 
-	close(c.quit)
-
 	for id := 0; id < len(c.calls); id++ {
 	for id := 0; id < len(c.calls); id++ {
 		req := &c.calls[id]
 		req := &c.calls[id]
 		// we need to send the error to all waiting queries, put the state
 		// we need to send the error to all waiting queries, put the state
 		// of this conn into not active so that it can not execute any queries.
 		// of this conn into not active so that it can not execute any queries.
 		atomic.StoreInt32(&req.waiting, -1)
 		atomic.StoreInt32(&req.waiting, -1)
+
+		if err != nil {
+			select {
+			case req.resp <- err:
+			default:
+			}
+		}
 	}
 	}
 
 
+	close(c.quit)
 	c.conn.Close()
 	c.conn.Close()
 
 
 	if c.started && err != nil {
 	if c.started && err != nil {

+ 36 - 0
conn_test.go

@@ -533,6 +533,42 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestQueryTimeoutClose(t *testing.T) {
+	srv := NewTestServer(t, defaultProto)
+	defer srv.Stop()
+
+	cluster := NewCluster(srv.Address)
+	// Set the timeout arbitrarily low so that the query hits the timeout in a
+	// timely manner.
+	cluster.Timeout = 1000 * time.Millisecond
+	cluster.NumConns = 1
+	cluster.NumStreams = 1
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatalf("NewCluster: %v", err)
+	}
+
+	ch := make(chan error)
+	go func() {
+		err := db.Query("timeout").Exec()
+		ch <- err
+	}()
+	// ensure that the above goroutine gets sheduled
+	time.Sleep(50 * time.Millisecond)
+
+	db.Close()
+	select {
+	case err = <-ch:
+	case <-time.After(1 * time.Second):
+		t.Fatal("timedout waiting to get a response once cluster is closed")
+	}
+
+	if err != ErrConnectionClosed {
+		t.Fatalf("expected to get %v got %v", ErrConnectionClosed, err)
+	}
+}
+
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {