Procházet zdrojové kódy

Added check to prevent user from executing a use statement from an active session.

Phillip Couto před 11 roky
rodič
revize
f9a5343db1
2 změnil soubory, kde provedl 28 přidání a 8 odebrání
  1. 22 7
      cassandra_test.go
  2. 6 1
      session.go

+ 22 - 7
cassandra_test.go

@@ -33,12 +33,11 @@ func createSession(t *testing.T) *Session {
 		Password: "cassandra",
 		Password: "cassandra",
 	}
 	}
 
 
-	session, err := cluster.CreateSession()
-	if err != nil {
-		t.Fatal("createSession:", err)
-	}
-
 	initOnce.Do(func() {
 	initOnce.Do(func() {
+		session, err := cluster.CreateSession()
+		if err != nil {
+			t.Fatal("createSession:", err)
+		}
 		// Drop and re-create the keyspace once. Different tests should use their own
 		// Drop and re-create the keyspace once. Different tests should use their own
 		// individual tables, but can assume that the table does not exist before.
 		// individual tables, but can assume that the table does not exist before.
 		if err := session.Query(`DROP KEYSPACE gocql_test`).Exec(); err != nil {
 		if err := session.Query(`DROP KEYSPACE gocql_test`).Exec(); err != nil {
@@ -51,9 +50,11 @@ func createSession(t *testing.T) *Session {
 			}`).Exec(); err != nil {
 			}`).Exec(); err != nil {
 			t.Fatal("create keyspace:", err)
 			t.Fatal("create keyspace:", err)
 		}
 		}
+		session.Close()
 	})
 	})
-
-	if err := session.Query(`USE gocql_test`).Exec(); err != nil {
+	cluster.Keyspace = "gocql_test"
+	session, err := cluster.CreateSession()
+	if err != nil {
 		t.Fatal("createSession:", err)
 		t.Fatal("createSession:", err)
 	}
 	}
 
 
@@ -68,6 +69,20 @@ func TestEmptyHosts(t *testing.T) {
 	}
 	}
 }
 }
 
 
+//TestUseStatementError checks to make sure the correct error is returned when the user tries to execute a use statement.
+func TestUseStatementError(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := session.Query("USE gocql_test").Exec(); err != nil {
+		if err != ErrUseStmt {
+			t.Error("expected ErrUseStmt, got " + err.Error())
+		}
+	} else {
+		t.Error("expected err, got nil.")
+	}
+}
+
 func TestCRUD(t *testing.T) {
 func TestCRUD(t *testing.T) {
 	session := createSession(t)
 	session := createSession(t)
 	defer session.Close()
 	defer session.Close()

+ 6 - 1
session.go

@@ -8,6 +8,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
 )
 )
@@ -190,13 +191,16 @@ func (q *Query) RetryPolicy(r RetryPolicy) *Query {
 
 
 // Exec executes the query without returning any rows.
 // Exec executes the query without returning any rows.
 func (q *Query) Exec() error {
 func (q *Query) Exec() error {
-	iter := q.session.executeQuery(q)
+	iter := q.Iter()
 	return iter.err
 	return iter.err
 }
 }
 
 
 // Iter executes the query and returns an iterator capable of iterating
 // Iter executes the query and returns an iterator capable of iterating
 // over all results.
 // over all results.
 func (q *Query) Iter() *Iter {
 func (q *Query) Iter() *Iter {
+	if strings.Index(strings.ToLower(q.stmt), "use") == 0 {
+		return &Iter{err: ErrUseStmt}
+	}
 	return q.session.executeQuery(q)
 	return q.session.executeQuery(q)
 }
 }
 
 
@@ -467,6 +471,7 @@ var (
 	ErrProtocol     = errors.New("protocol error")
 	ErrProtocol     = errors.New("protocol error")
 	ErrUnsupported  = errors.New("feature not supported")
 	ErrUnsupported  = errors.New("feature not supported")
 	ErrTooManyStmts = errors.New("too many statements")
 	ErrTooManyStmts = errors.New("too many statements")
+	ErrUseStmt      = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explaination.")
 )
 )
 
 
 // BatchSizeMaximum is the maximum number of statements a batch operation can have.
 // BatchSizeMaximum is the maximum number of statements a batch operation can have.