Procházet zdrojové kódy

Merge pull request #311 from obeattie/fix-panic

Fix a panic establishing sessions using NewSession
Ben Hood před 10 roky
rodič
revize
5604309060
4 změnil soubory, kde provedl 33 přidání a 23 odebrání
  1. 1 0
      AUTHORS
  2. 4 4
      cassandra_test.go
  3. 15 9
      cluster.go
  4. 13 10
      conn.go

+ 1 - 0
AUTHORS

@@ -42,3 +42,4 @@ Zach Marcantel <zmarcantel@gmail.com>
 James Maloney <jamessagan@gmail.com>
 James Maloney <jamessagan@gmail.com>
 Ashwin Purohit <purohit@gmail.com>
 Ashwin Purohit <purohit@gmail.com>
 Dan Kinder <dkinder.is.me@gmail.com>
 Dan Kinder <dkinder.is.me@gmail.com>
+Oliver Beattie <oliver@obeattie.com>

+ 4 - 4
cassandra_test.go

@@ -967,9 +967,9 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	conn := session.Pool.Pick(nil)
 	conn := session.Pool.Pick(nil)
 	flight := new(inflightPrepare)
 	flight := new(inflightPrepare)
-	stmtsLRU.mu.Lock()
+	stmtsLRU.Lock()
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Unlock()
 	flight.info = &QueryInfo{
 	flight.info = &QueryInfo{
 		Id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		Id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		Args: []ColumnInfo{ColumnInfo{
 		Args: []ColumnInfo{ColumnInfo{
@@ -1057,9 +1057,9 @@ func TestQueryInfo(t *testing.T) {
 func TestPreparedCacheEviction(t *testing.T) {
 func TestPreparedCacheEviction(t *testing.T) {
 	session := createSession(t)
 	session := createSession(t)
 	defer session.Close()
 	defer session.Close()
-	stmtsLRU.mu.Lock()
+	stmtsLRU.Lock()
 	stmtsLRU.Max(4)
 	stmtsLRU.Max(4)
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Unlock()
 
 
 	if err := createTable(session, "CREATE TABLE prepcachetest (id int,mod int,PRIMARY KEY (id))"); err != nil {
 	if err := createTable(session, "CREATE TABLE prepcachetest (id int,mod int,PRIMARY KEY (id))"); err != nil {
 		t.Fatalf("failed to create table with error '%v'", err)
 		t.Fatalf("failed to create table with error '%v'", err)

+ 15 - 9
cluster.go

@@ -12,13 +12,15 @@ import (
 	"github.com/golang/groupcache/lru"
 	"github.com/golang/groupcache/lru"
 )
 )
 
 
+const defaultMaxPreparedStmts = 1000
+
 //Package global reference to Prepared Statements LRU
 //Package global reference to Prepared Statements LRU
 var stmtsLRU preparedLRU
 var stmtsLRU preparedLRU
 
 
 //preparedLRU is the prepared statement cache
 //preparedLRU is the prepared statement cache
 type preparedLRU struct {
 type preparedLRU struct {
+	sync.Mutex
 	lru *lru.Cache
 	lru *lru.Cache
-	mu  sync.Mutex
 }
 }
 
 
 //Max adjusts the maximum size of the cache and cleans up the oldest records if
 //Max adjusts the maximum size of the cache and cleans up the oldest records if
@@ -30,6 +32,14 @@ func (p *preparedLRU) Max(max int) {
 	p.lru.MaxEntries = max
 	p.lru.MaxEntries = max
 }
 }
 
 
+func initStmtsLRU(max int) {
+	if stmtsLRU.lru != nil {
+		stmtsLRU.Max(max)
+	} else {
+		stmtsLRU.lru = lru.New(max)
+	}
+}
+
 // To enable periodic node discovery enable DiscoverHosts in ClusterConfig
 // To enable periodic node discovery enable DiscoverHosts in ClusterConfig
 type DiscoveryConfig struct {
 type DiscoveryConfig struct {
 	// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
 	// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
@@ -79,7 +89,7 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		Consistency:      Quorum,
 		Consistency:      Quorum,
 		ConnPoolType:     NewSimplePool,
 		ConnPoolType:     NewSimplePool,
 		DiscoverHosts:    false,
 		DiscoverHosts:    false,
-		MaxPreparedStmts: 1000,
+		MaxPreparedStmts: defaultMaxPreparedStmts,
 	}
 	}
 	return cfg
 	return cfg
 }
 }
@@ -95,13 +105,9 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 	pool := cfg.ConnPoolType(cfg)
 	pool := cfg.ConnPoolType(cfg)
 
 
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	//Adjust the size of the prepared statements cache to match the latest configuration
-	stmtsLRU.mu.Lock()
-	if stmtsLRU.lru != nil {
-		stmtsLRU.Max(cfg.MaxPreparedStmts)
-	} else {
-		stmtsLRU.lru = lru.New(cfg.MaxPreparedStmts)
-	}
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Lock()
+	initStmtsLRU(cfg.MaxPreparedStmts)
+	stmtsLRU.Unlock()
 
 
 	//See if there are any connections in the pool
 	//See if there are any connections in the pool
 	if pool.Size() > 0 {
 	if pool.Size() > 0 {

+ 13 - 10
conn.go

@@ -364,13 +364,16 @@ func (c *Conn) ping() error {
 }
 }
 
 
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
-	stmtsLRU.mu.Lock()
+	stmtsLRU.Lock()
+	if stmtsLRU.lru == nil {
+		initStmtsLRU(defaultMaxPreparedStmts)
+	}
 
 
 	stmtCacheKey := c.addr + c.currentKeyspace + stmt
 	stmtCacheKey := c.addr + c.currentKeyspace + stmt
 
 
 	if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 	if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 		flight := val.(*inflightPrepare)
 		flight := val.(*inflightPrepare)
-		stmtsLRU.mu.Unlock()
+		stmtsLRU.Unlock()
 		flight.wg.Wait()
 		flight.wg.Wait()
 		return flight.info, flight.err
 		return flight.info, flight.err
 	}
 	}
@@ -378,7 +381,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 	flight := new(inflightPrepare)
 	flight := new(inflightPrepare)
 	flight.wg.Add(1)
 	flight.wg.Add(1)
 	stmtsLRU.lru.Add(stmtCacheKey, flight)
 	stmtsLRU.lru.Add(stmtCacheKey, flight)
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Unlock()
 
 
 	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
 	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
 	if err != nil {
 	if err != nil {
@@ -402,9 +405,9 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 	flight.wg.Done()
 	flight.wg.Done()
 
 
 	if err != nil {
 	if err != nil {
-		stmtsLRU.mu.Lock()
+		stmtsLRU.Lock()
 		stmtsLRU.lru.Remove(stmtCacheKey)
 		stmtsLRU.lru.Remove(stmtCacheKey)
-		stmtsLRU.mu.Unlock()
+		stmtsLRU.Unlock()
 	}
 	}
 
 
 	return flight.info, flight.err
 	return flight.info, flight.err
@@ -471,14 +474,14 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	case resultKeyspaceFrame:
 	case resultKeyspaceFrame:
 		return &Iter{}
 		return &Iter{}
 	case RequestErrUnprepared:
 	case RequestErrUnprepared:
-		stmtsLRU.mu.Lock()
+		stmtsLRU.Lock()
 		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
 		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
 		if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 		if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 			stmtsLRU.lru.Remove(stmtCacheKey)
 			stmtsLRU.lru.Remove(stmtCacheKey)
-			stmtsLRU.mu.Unlock()
+			stmtsLRU.Unlock()
 			return c.executeQuery(qry)
 			return c.executeQuery(qry)
 		}
 		}
-		stmtsLRU.mu.Unlock()
+		stmtsLRU.Unlock()
 		return &Iter{err: x}
 		return &Iter{err: x}
 	case error:
 	case error:
 		return &Iter{err: x}
 		return &Iter{err: x}
@@ -602,9 +605,9 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	case RequestErrUnprepared:
 	case RequestErrUnprepared:
 		stmt, found := stmts[string(x.StatementId)]
 		stmt, found := stmts[string(x.StatementId)]
 		if found {
 		if found {
-			stmtsLRU.mu.Lock()
+			stmtsLRU.Lock()
 			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
 			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
-			stmtsLRU.mu.Unlock()
+			stmtsLRU.Unlock()
 		}
 		}
 		if found {
 		if found {
 			return c.executeBatch(batch)
 			return c.executeBatch(batch)