Kaynağa Gözat

Merge pull request #448 from Zariel/exec-panic

conn: ensure that call.framer is not nil
Chris Bannister 10 yıl önce
ebeveyn
işleme
6927712315
2 değiştirilmiş dosya ile 56 ekleme ve 5 silme
  1. 8 5
      conn.go
  2. 48 0
      conn_test.go

+ 8 - 5
conn.go

@@ -444,11 +444,6 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 
 
 	select {
 	select {
 	case err := <-call.resp:
 	case err := <-call.resp:
-		// dont release the stream if detect a timeout as another request can reuse
-		// that stream and get a response for the old request, which we have no
-		// easy way of detecting.
-		defer c.releaseStream(stream)
-
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -460,6 +455,14 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		return nil, ErrConnectionClosed
 		return nil, ErrConnectionClosed
 	}
 	}
 
 
+	// dont release the stream if detect a timeout as another request can reuse
+	// that stream and get a response for the old request, which we have no
+	// easy way of detecting.
+	//
+	// Ensure that the stream is not released if there are potentially outstanding
+	// requests on the stream to prevent nil pointer dereferences in recv().
+	defer c.releaseStream(stream)
+
 	if v := framer.header.version.version(); v != c.version {
 	if v := framer.header.version.version(); v != c.version {
 		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
 		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
 	}
 	}

+ 48 - 0
conn_test.go

@@ -569,6 +569,54 @@ func TestQueryTimeoutClose(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestExecPanic(t *testing.T) {
+	t.Skip("test can cause unrelated failures, skipping until it can be fixed.")
+	srv := NewTestServer(t, defaultProto)
+	defer srv.Stop()
+
+	cluster := NewCluster(srv.Address)
+	// Set the timeout arbitrarily low so that the query hits the timeout in a
+	// timely manner.
+	cluster.Timeout = 5 * time.Millisecond
+	cluster.NumConns = 1
+	// cluster.NumStreams = 1
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	streams := db.cfg.NumStreams
+
+	wg := &sync.WaitGroup{}
+	wg.Add(streams)
+	for i := 0; i < streams; i++ {
+		go func() {
+			defer wg.Done()
+			q := db.Query("void")
+			for {
+				if err := q.Exec(); err != nil {
+					return
+				}
+			}
+		}()
+	}
+
+	wg.Add(1)
+
+	go func() {
+		defer wg.Done()
+		for i := 0; i < int(TimeoutLimit); i++ {
+			db.Query("timeout").Exec()
+		}
+	}()
+
+	wg.Wait()
+
+	time.Sleep(500 * time.Millisecond)
+}
+
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {