فهرست منبع

Fix a panic establishing sessions using NewSession

Previously, the only way to establish the *first* connection was to use ClusterConfig.CreateSession(). This is due to the global prepared statement cache only being initialised here.
Oliver Beattie 11 سال پیش
والد
کامیت
241b2b965e
3فایلهای تغییر یافته به همراه29 افزوده شده و 22 حذف شده
  1. 4 4
      cassandra_test.go
  2. 12 8
      cluster.go
  3. 13 10
      conn.go

+ 4 - 4
cassandra_test.go

@@ -967,9 +967,9 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	conn := session.Pool.Pick(nil)
 	flight := new(inflightPrepare)
-	stmtsLRU.mu.Lock()
+	stmtsLRU.Lock()
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Unlock()
 	flight.info = &QueryInfo{
 		Id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		Args: []ColumnInfo{ColumnInfo{
@@ -1057,9 +1057,9 @@ func TestQueryInfo(t *testing.T) {
 func TestPreparedCacheEviction(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
-	stmtsLRU.mu.Lock()
+	stmtsLRU.Lock()
 	stmtsLRU.Max(4)
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Unlock()
 
 	if err := createTable(session, "CREATE TABLE prepcachetest (id int,mod int,PRIMARY KEY (id))"); err != nil {
 		t.Fatalf("failed to create table with error '%v'", err)

+ 12 - 8
cluster.go

@@ -17,8 +17,8 @@ var stmtsLRU preparedLRU
 
 //preparedLRU is the prepared statement cache
 type preparedLRU struct {
+	sync.Mutex
 	lru *lru.Cache
-	mu  sync.Mutex
 }
 
 //Max adjusts the maximum size of the cache and cleans up the oldest records if
@@ -30,6 +30,14 @@ func (p *preparedLRU) Max(max int) {
 	p.lru.MaxEntries = max
 }
 
+func initStmtsLRU(max int) {
+	if stmtsLRU.lru != nil {
+		stmtsLRU.Max(max)
+	} else {
+		stmtsLRU.lru = lru.New(max)
+	}
+}
+
 // To enable periodic node discovery enable DiscoverHosts in ClusterConfig
 type DiscoveryConfig struct {
 	// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
@@ -94,13 +102,9 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 	pool := cfg.ConnPoolType(cfg)
 
 	//Adjust the size of the prepared statements cache to match the latest configuration
-	stmtsLRU.mu.Lock()
-	if stmtsLRU.lru != nil {
-		stmtsLRU.Max(cfg.MaxPreparedStmts)
-	} else {
-		stmtsLRU.lru = lru.New(cfg.MaxPreparedStmts)
-	}
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Lock()
+	initStmtsLRU(cfg.MaxPreparedStmts)
+	stmtsLRU.Unlock()
 
 	//See if there are any connections in the pool
 	if pool.Size() > 0 {

+ 13 - 10
conn.go

@@ -364,13 +364,16 @@ func (c *Conn) ping() error {
 }
 
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
-	stmtsLRU.mu.Lock()
+	stmtsLRU.Lock()
+	if stmtsLRU.lru == nil {
+		initStmtsLRU(1000)
+	}
 
 	stmtCacheKey := c.addr + c.currentKeyspace + stmt
 
 	if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 		flight := val.(*inflightPrepare)
-		stmtsLRU.mu.Unlock()
+		stmtsLRU.Unlock()
 		flight.wg.Wait()
 		return flight.info, flight.err
 	}
@@ -378,7 +381,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 	flight := new(inflightPrepare)
 	flight.wg.Add(1)
 	stmtsLRU.lru.Add(stmtCacheKey, flight)
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Unlock()
 
 	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
 	if err != nil {
@@ -402,9 +405,9 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 	flight.wg.Done()
 
 	if err != nil {
-		stmtsLRU.mu.Lock()
+		stmtsLRU.Lock()
 		stmtsLRU.lru.Remove(stmtCacheKey)
-		stmtsLRU.mu.Unlock()
+		stmtsLRU.Unlock()
 	}
 
 	return flight.info, flight.err
@@ -471,14 +474,14 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	case resultKeyspaceFrame:
 		return &Iter{}
 	case RequestErrUnprepared:
-		stmtsLRU.mu.Lock()
+		stmtsLRU.Lock()
 		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
 		if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 			stmtsLRU.lru.Remove(stmtCacheKey)
-			stmtsLRU.mu.Unlock()
+			stmtsLRU.Unlock()
 			return c.executeQuery(qry)
 		}
-		stmtsLRU.mu.Unlock()
+		stmtsLRU.Unlock()
 		return &Iter{err: x}
 	case error:
 		return &Iter{err: x}
@@ -602,9 +605,9 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	case RequestErrUnprepared:
 		stmt, found := stmts[string(x.StatementId)]
 		if found {
-			stmtsLRU.mu.Lock()
+			stmtsLRU.Lock()
 			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
-			stmtsLRU.mu.Unlock()
+			stmtsLRU.Unlock()
 		}
 		if found {
 			return c.executeBatch(batch)