Ver código fonte

correctly handle negative streams, differentiate -1

Chris Bannister 10 anos atrás
pai
commit
9be46a1f88
2 arquivos alterados com 35 adições e 4 exclusões
  1. 31 0
      cassandra_test.go
  2. 4 4
      conn.go

+ 31 - 0
cassandra_test.go

@@ -2027,3 +2027,34 @@ func TestStream0(t *testing.T) {
 		t.Fatalf("expected to get nil frame got %+v", frame)
 	}
 }
+
+func TestNegativeStream(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	var conn *Conn
+	for i := 0; i < 5; i++ {
+		if conn != nil {
+			break
+		}
+
+		conn = session.Pool.Pick(nil)
+	}
+
+	if conn == nil {
+		t.Fatal("no connections available in the pool")
+	}
+
+	const stream = -50
+	writer := frameWriterFunc(func(f *framer, streamID int) error {
+		f.writeHeader(0, opOptions, stream)
+		return f.finishWrite()
+	})
+
+	frame, err := conn.exec(writer, nil)
+	if err == nil {
+		t.Fatalf("expected to get an error on stream %d", stream)
+	} else if frame != nil {
+		t.Fatalf("expected to get nil frame got %+v", frame)
+	}
+}

+ 4 - 4
conn.go

@@ -367,14 +367,14 @@ func (c *Conn) recv() error {
 
 	if head.stream > len(c.calls) {
 		return fmt.Errorf("gocql: frame header stream is beyond call exepected bounds: %d", head.stream)
-	} else if head.stream < 0 {
+	} else if head.stream == -1 {
 		// TODO: handle cassandra event frames, we shouldnt get any currently
 		_, err := io.CopyN(ioutil.Discard, c, int64(head.length))
 		if err != nil {
 			return err
 		}
 		return nil
-	} else if head.stream == 0 {
+	} else if head.stream <= 0 {
 		// reserved stream that we dont use, probably due to a protocol error
 		// or a bug in Cassandra, this should be an error, parse it and return.
 		framer := newFramer(c, c, c.compressor, c.version)
@@ -389,9 +389,9 @@ func (c *Conn) recv() error {
 
 		switch v := frame.(type) {
 		case error:
-			return fmt.Errorf("gocql: error on stream 0: %v", v)
+			return fmt.Errorf("gocql: error on stream %d: %v", head.stream, v)
 		default:
-			return fmt.Errorf("gocql: received frame on stream 0: %v", frame)
+			return fmt.Errorf("gocql: received frame on stream %d: %v", head.stream, frame)
 		}
 	}