Explorar o código

Introduce handling for query argument length mismatches

Ben Hood %!s(int64=11) %!d(string=hai) anos
pai
achega
7fd92f1209
Modificáronse 2 ficheiros con 12 adicións e 2 borrados
  1. 4 2
      cassandra_test.go
  2. 8 0
      conn.go

+ 4 - 2
cassandra_test.go

@@ -317,8 +317,10 @@ func TestWrongQueryArgsLength(t *testing.T) {
 		t.Fatal("create table:", err)
 	}
 
-	if _, err := session.Query(`SELECT * FROM query_args_length WHERE id = ?`, 1, 2).Iter().SliceMap(); err != nil {
-		t.Fatal("select query_args_length:", err)
+	_, err := session.Query(`SELECT * FROM query_args_length WHERE id = ?`, 1, 2).Iter().SliceMap()
+
+	if err == nil || err != ErrQueryArgLength {
+		t.Fatal("'`SELECT * FROM query_args_length WHERE id = ?`, 1, 2' should return an ErrQueryArgLength")
 	}
 }
 

+ 8 - 0
conn.go

@@ -7,6 +7,7 @@ package gocql
 import (
 	"bufio"
 	"bytes"
+	"errors"
 	"fmt"
 	"net"
 	"sync"
@@ -367,6 +368,9 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		if err != nil {
 			return &Iter{err: err}
 		}
+		if len(qry.values) != len(info.args) {
+			return &Iter{err: ErrQueryArgLength}
+		}
 		op.Prepared = info.id
 		op.Values = make([][]byte, len(qry.values))
 		for i := 0; i < len(qry.values); i++ {
@@ -626,3 +630,7 @@ type inflightPrepare struct {
 	err  error
 	wg   sync.WaitGroup
 }
+
+var (
+	ErrQueryArgLength = errors.New("query argument length mismatch")
+)