浏览代码

Merge pull request #144 from phillipCouto/prevent_use_keyspace

Prevent use keyspace
Phillip Couto 11 年之前
父节点
当前提交
466e11f6c4
共有 5 个文件被更改,包括 69 次插入51 次删除
  1. 31 0
      README.md
  2. 22 7
      cassandra_test.go
  3. 10 41
      cluster.go
  4. 0 2
      conn.go
  5. 6 1
      session.go

+ 31 - 0
README.md

@@ -41,6 +41,37 @@ Features
 
 Please visit the [Roadmap](https://github.com/gocql/gocql/wiki/Roadmap) page to see what is on the horizion.
 
+Important Default Keyspace Changes
+----------------------------------
+gocql no longer supports executing "use <keyspace>" statements to simplfy the library. The user still has the
+ability to define the default keyspace for connections but now the keyspace can only be defined before a
+session is created. Queries can still access keyspaces by indicating the keyspace in the query:
+```sql
+SELECT * FROM example2.table;
+```
+
+Example of correct usage:
+```go
+	cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
+	cluster.Keyspace = "example"
+	...
+	session, err := cluster.CreateSession()
+
+```
+Example of incorrect usage:
+```go
+	cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
+	cluster.Keyspace = "example"
+	...
+	session, err := cluster.CreateSession()
+
+	if err = session.Query("use example2").Exec(); err != nil {
+		log.Fatal(err)
+	}
+```
+This will result in an err being returned from the session.Query line as the user is trying to execute a "use"
+statement. 
+
 Example
 -------
 

+ 22 - 7
cassandra_test.go

@@ -33,12 +33,11 @@ func createSession(t *testing.T) *Session {
 		Password: "cassandra",
 	}
 
-	session, err := cluster.CreateSession()
-	if err != nil {
-		t.Fatal("createSession:", err)
-	}
-
 	initOnce.Do(func() {
+		session, err := cluster.CreateSession()
+		if err != nil {
+			t.Fatal("createSession:", err)
+		}
 		// Drop and re-create the keyspace once. Different tests should use their own
 		// individual tables, but can assume that the table does not exist before.
 		if err := session.Query(`DROP KEYSPACE gocql_test`).Exec(); err != nil {
@@ -51,9 +50,11 @@ func createSession(t *testing.T) *Session {
 			}`).Exec(); err != nil {
 			t.Fatal("create keyspace:", err)
 		}
+		session.Close()
 	})
-
-	if err := session.Query(`USE gocql_test`).Exec(); err != nil {
+	cluster.Keyspace = "gocql_test"
+	session, err := cluster.CreateSession()
+	if err != nil {
 		t.Fatal("createSession:", err)
 	}
 
@@ -68,6 +69,20 @@ func TestEmptyHosts(t *testing.T) {
 	}
 }
 
+//TestUseStatementError checks to make sure the correct error is returned when the user tries to execute a use statement.
+func TestUseStatementError(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := session.Query("USE gocql_test").Exec(); err != nil {
+		if err != ErrUseStmt {
+			t.Error("expected ErrUseStmt, got " + err.Error())
+		}
+	} else {
+		t.Error("expected err, got nil.")
+	}
+}
+
 func TestCRUD(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()

+ 10 - 41
cluster.go

@@ -138,35 +138,26 @@ func (c *clusterImpl) connect(addr string) {
 				return
 			}
 		}
-		c.addConn(conn, "")
+		c.addConn(conn)
 		return
 	}
 }
 
-func (c *clusterImpl) changeKeyspace(conn *Conn, keyspace string, connected bool) {
-	if err := conn.UseKeyspace(keyspace); err != nil {
-		conn.Close()
-		if connected {
-			c.removeConn(conn)
-		}
-		go c.connect(conn.Address())
-	}
-	if !connected {
-		c.addConn(conn, keyspace)
-	}
-}
-
-func (c *clusterImpl) addConn(conn *Conn, keyspace string) {
+func (c *clusterImpl) addConn(conn *Conn) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	if c.quit {
 		conn.Close()
 		return
 	}
-	if keyspace != c.keyspace && c.keyspace != "" {
-		// change the keyspace before adding the node to the pool
-		go c.changeKeyspace(conn, c.keyspace, false)
-		return
+	//Set the connection's keyspace if any before adding it to the pool
+	if c.keyspace != "" {
+		if err := conn.UseKeyspace(c.keyspace); err != nil {
+			log.Printf("error setting connection keyspace. %v", err)
+			conn.Close()
+			go c.connect(conn.Address())
+			return
+		}
 	}
 	connPool := c.connPool[conn.Address()]
 	if connPool == nil {
@@ -214,28 +205,6 @@ func (c *clusterImpl) HandleError(conn *Conn, err error, closed bool) {
 	}
 }
 
-func (c *clusterImpl) HandleKeyspace(conn *Conn, keyspace string) {
-	c.mu.Lock()
-	if c.keyspace == keyspace {
-		c.mu.Unlock()
-		return
-	}
-	c.keyspace = keyspace
-	conns := make([]*Conn, 0, len(c.conns))
-	for conn := range c.conns {
-		conns = append(conns, conn)
-	}
-	c.mu.Unlock()
-
-	// change the keyspace of all other connections too
-	for i := 0; i < len(conns); i++ {
-		if conns[i] == conn {
-			continue
-		}
-		c.changeKeyspace(conns[i], keyspace, true)
-	}
-}
-
 func (c *clusterImpl) Pick(qry *Query) *Conn {
 	return c.hostPool.Pick(qry)
 }

+ 0 - 2
conn.go

@@ -23,7 +23,6 @@ const maskVersion = 0x7F
 
 type Cluster interface {
 	HandleError(conn *Conn, err error, closed bool)
-	HandleKeyspace(conn *Conn, keyspace string)
 }
 
 type Authenticator interface {
@@ -427,7 +426,6 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		}
 		return iter
 	case resultKeyspaceFrame:
-		c.cluster.HandleKeyspace(c, x.Keyspace)
 		return &Iter{}
 	case errorFrame:
 		if x.Code == errUnprepared && len(qry.values) > 0 {

+ 6 - 1
session.go

@@ -8,6 +8,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"strings"
 	"sync"
 	"time"
 )
@@ -190,13 +191,16 @@ func (q *Query) RetryPolicy(r RetryPolicy) *Query {
 
 // Exec executes the query without returning any rows.
 func (q *Query) Exec() error {
-	iter := q.session.executeQuery(q)
+	iter := q.Iter()
 	return iter.err
 }
 
 // Iter executes the query and returns an iterator capable of iterating
 // over all results.
 func (q *Query) Iter() *Iter {
+	if strings.Index(strings.ToLower(q.stmt), "use") == 0 {
+		return &Iter{err: ErrUseStmt}
+	}
 	return q.session.executeQuery(q)
 }
 
@@ -467,6 +471,7 @@ var (
 	ErrProtocol     = errors.New("protocol error")
 	ErrUnsupported  = errors.New("feature not supported")
 	ErrTooManyStmts = errors.New("too many statements")
+	ErrUseStmt      = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explaination.")
 )
 
 // BatchSizeMaximum is the maximum number of statements a batch operation can have.