瀏覽代碼

Merge pull request #192 from relops/verify_query_arg_length

Prevent query argument length mismatches causing a panic
Phillip Couto 11 年之前
父節點
當前提交
20ba163d1e
共有 2 個文件被更改,包括 80 次插入7 次删除
  1. 67 7
      cassandra_test.go
  2. 13 0
      conn.go

+ 67 - 7
cassandra_test.go

@@ -308,6 +308,73 @@ func TestBatchLimit(t *testing.T) {
 
 }
 
+// TestTooManyQueryArgs tests to make sure the library correctly handles the application level bug
+// whereby too many query arguments are passed to a query
+func TestTooManyQueryArgs(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := session.Query(`CREATE TABLE too_many_query_args (id int primary key, value int)`).Exec(); err != nil {
+		t.Fatal("create table:", err)
+	}
+
+	_, err := session.Query(`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2).Iter().SliceMap()
+
+	if err == nil {
+		t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an ErrQueryArgLength")
+	}
+
+	if err != ErrQueryArgLength {
+		t.Fatalf("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an ErrQueryArgLength, but returned: %s", err)
+	}
+
+	batch := session.NewBatch(UnloggedBatch)
+	batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3)
+	err = session.ExecuteBatch(batch)
+
+	if err == nil {
+		t.Fatal("'`INSERT INTO too_many_query_args (id, value) VALUES (?, ?)`, 1, 2, 3' should return an ErrQueryArgLength")
+	}
+
+	if err != ErrQueryArgLength {
+		t.Fatalf("'INSERT INTO too_many_query_args (id, value) VALUES (?, ?)`, 1, 2, 3' should return an ErrQueryArgLength, but returned: %s", err)
+	}
+
+}
+
+// TestNotEnoughQueryArgs tests to make sure the library correctly handles the application level bug
+// whereby not enough query arguments are passed to a query
+func TestNotEnoughQueryArgs(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := session.Query(`CREATE TABLE not_enough_query_args (id int, cluster int, value int, primary key (id, cluster))`).Exec(); err != nil {
+		t.Fatal("create table:", err)
+	}
+
+	_, err := session.Query(`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1).Iter().SliceMap()
+
+	if err == nil {
+		t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an ErrQueryArgLength")
+	}
+
+	if err != ErrQueryArgLength {
+		t.Fatalf("'`SELECT * FROM too_few_query_args WHERE id = ? and cluster = ?`, 1' should return an ErrQueryArgLength, but returned: %s", err)
+	}
+
+	batch := session.NewBatch(UnloggedBatch)
+	batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2)
+	err = session.ExecuteBatch(batch)
+
+	if err == nil {
+		t.Fatal("'`INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)`, 1, 2' should return an ErrQueryArgLength")
+	}
+
+	if err != ErrQueryArgLength {
+		t.Fatalf("'INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)`, 1, 2' should return an ErrQueryArgLength, but returned: %s", err)
+	}
+}
+
 // TestCreateSessionTimeout tests to make sure the CreateSession function timeouts out correctly
 // and prevents an infinite loop of connection retries.
 func TestCreateSessionTimeout(t *testing.T) {
@@ -597,13 +664,6 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 			TypeInfo: &TypeInfo{
 				Type: TypeVarchar,
 			},
-		}, ColumnInfo{
-			Keyspace: "gocql_test",
-			Table:    table,
-			Name:     "bar",
-			TypeInfo: &TypeInfo{
-				Type: TypeInt,
-			},
 		}},
 	}
 	return stmt, conn

+ 13 - 0
conn.go

@@ -6,6 +6,7 @@ package gocql
 
 import (
 	"bufio"
+	"errors"
 	"fmt"
 	"net"
 	"sync"
@@ -362,6 +363,9 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		if err != nil {
 			return &Iter{err: err}
 		}
+		if len(qry.values) != len(info.args) {
+			return &Iter{err: ErrQueryArgLength}
+		}
 		op.Prepared = info.id
 		op.Values = make([][]byte, len(qry.values))
 		for i := 0; i < len(qry.values); i++ {
@@ -472,6 +476,11 @@ func (c *Conn) executeBatch(batch *Batch) error {
 		if len(entry.Args) > 0 {
 			var err error
 			info, err = c.prepareStatement(entry.Stmt, nil)
+
+			if len(entry.Args) != len(info.args) {
+				return ErrQueryArgLength
+			}
+
 			stmts[string(info.id)] = entry.Stmt
 			if err != nil {
 				return err
@@ -618,3 +627,7 @@ type inflightPrepare struct {
 	err  error
 	wg   sync.WaitGroup
 }
+
+var (
+	ErrQueryArgLength = errors.New("query argument length mismatch")
+)