Przeglądaj źródła

Merge pull request #450 from Zariel/dont-use-stream0

conn: dont use stream 0 for outgoing requests
Chris Bannister 10 lat temu
rodzic
commit
53ea371a15
3 zmienionych plików z 105 dodań i 6 usunięć
  1. 76 0
      cassandra_test.go
  2. 26 3
      conn.go
  3. 3 3
      conn_test.go

+ 76 - 0
cassandra_test.go

@@ -1982,3 +1982,79 @@ func TestTokenAwareConnPool(t *testing.T) {
 
 	// TODO add verification that the query went to the correct host
 }
+
+type frameWriterFunc func(framer *framer, streamID int) error
+
+func (f frameWriterFunc) writeFrame(framer *framer, streamID int) error {
+	return f(framer, streamID)
+}
+
+func TestStream0(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	var conn *Conn
+	for i := 0; i < 5; i++ {
+		if conn != nil {
+			break
+		}
+
+		conn = session.Pool.Pick(nil)
+	}
+
+	if conn == nil {
+		t.Fatal("no connections available in the pool")
+	}
+
+	writer := frameWriterFunc(func(f *framer, streamID int) error {
+		if streamID == 0 {
+			t.Fatal("should not use stream 0 for requests")
+		}
+		f.writeHeader(0, opError, streamID)
+		f.writeString("i am a bad frame")
+		f.wbuf[0] = 0xFF
+		return f.finishWrite()
+	})
+
+	const expErr = "gocql: error on stream 0:"
+	// need to write out an invalid frame, which we need a connection to do
+	frame, err := conn.exec(writer, nil)
+	if err == nil {
+		t.Fatal("expected to get an error on stream 0")
+	} else if !strings.HasPrefix(err.Error(), expErr) {
+		t.Fatalf("expected to get error prefix %q got %q", expErr, err.Error())
+	} else if frame != nil {
+		t.Fatalf("expected to get nil frame got %+v", frame)
+	}
+}
+
+func TestNegativeStream(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	var conn *Conn
+	for i := 0; i < 5; i++ {
+		if conn != nil {
+			break
+		}
+
+		conn = session.Pool.Pick(nil)
+	}
+
+	if conn == nil {
+		t.Fatal("no connections available in the pool")
+	}
+
+	const stream = -50
+	writer := frameWriterFunc(func(f *framer, streamID int) error {
+		f.writeHeader(0, opOptions, stream)
+		return f.finishWrite()
+	})
+
+	frame, err := conn.exec(writer, nil)
+	if err == nil {
+		t.Fatalf("expected to get an error on stream %d", stream)
+	} else if frame != nil {
+		t.Fatalf("expected to get nil frame got %+v", frame)
+	}
+}

+ 26 - 3
conn.go

@@ -157,8 +157,10 @@ func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn,
 		headerSize = 9
 	}
 
-	if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
+	if cfg.NumStreams <= 0 || cfg.NumStreams >= maxStreams {
 		cfg.NumStreams = maxStreams
+	} else {
+		cfg.NumStreams++
 	}
 
 	c := &Conn{
@@ -180,7 +182,9 @@ func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn,
 		c.setKeepalive(cfg.Keepalive)
 	}
 
-	for i := 0; i < cfg.NumStreams; i++ {
+	// reserve stream 0 incase cassandra returns an error on it without us sending
+	// a request.
+	for i := 1; i < cfg.NumStreams; i++ {
 		c.calls[i].resp = make(chan error)
 		c.uniq <- i
 	}
@@ -363,13 +367,32 @@ func (c *Conn) recv() error {
 
 	if head.stream > len(c.calls) {
 		return fmt.Errorf("gocql: frame header stream is beyond call exepected bounds: %d", head.stream)
-	} else if head.stream < 0 {
+	} else if head.stream == -1 {
 		// TODO: handle cassandra event frames, we shouldnt get any currently
 		_, err := io.CopyN(ioutil.Discard, c, int64(head.length))
 		if err != nil {
 			return err
 		}
 		return nil
+	} else if head.stream <= 0 {
+		// reserved stream that we dont use, probably due to a protocol error
+		// or a bug in Cassandra, this should be an error, parse it and return.
+		framer := newFramer(c, c, c.compressor, c.version)
+		if err := framer.readFrame(&head); err != nil {
+			return err
+		}
+
+		frame, err := framer.parseFrame()
+		if err != nil {
+			return err
+		}
+
+		switch v := frame.(type) {
+		case error:
+			return fmt.Errorf("gocql: error on stream %d: %v", head.stream, v)
+		default:
+			return fmt.Errorf("gocql: received frame on stream %d: %v", head.stream, frame)
+		}
 	}
 
 	call := &c.calls[head.stream]

+ 3 - 3
conn_test.go

@@ -284,7 +284,7 @@ func TestStreams_Protocol1(t *testing.T) {
 	defer db.Close()
 
 	var wg sync.WaitGroup
-	for i := 0; i < db.cfg.NumStreams; i++ {
+	for i := 1; i < db.cfg.NumStreams; i++ {
 		// here were just validating that if we send NumStream request we get
 		// a response for every stream and the lengths for the queries are set
 		// correctly.
@@ -315,7 +315,7 @@ func TestStreams_Protocol2(t *testing.T) {
 	}
 	defer db.Close()
 
-	for i := 0; i < db.cfg.NumStreams; i++ {
+	for i := 1; i < db.cfg.NumStreams; i++ {
 		// the test server processes each conn synchronously
 		// here were just validating that if we send NumStream request we get
 		// a response for every stream and the lengths for the queries are set
@@ -342,7 +342,7 @@ func TestStreams_Protocol3(t *testing.T) {
 	}
 	defer db.Close()
 
-	for i := 0; i < db.cfg.NumStreams; i++ {
+	for i := 1; i < db.cfg.NumStreams; i++ {
 		// the test server processes each conn synchronously
 		// here were just validating that if we send NumStream request we get
 		// a response for every stream and the lengths for the queries are set