Prechádzať zdrojové kódy

Fix a panic establishing sessions using NewSession

Previously, the only way to establish the *first* connection was to use ClusterConfig.CreateSession(). This is due to the global prepared statement cache only being initialised here.
Oliver Beattie 10 rokov pred
rodič
commit
241b2b965e
3 zmenil súbory, kde vykonal 29 pridanie a 22 odobranie
  1. 4 4
      cassandra_test.go
  2. 12 8
      cluster.go
  3. 13 10
      conn.go

+ 4 - 4
cassandra_test.go

@@ -967,9 +967,9 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	conn := session.Pool.Pick(nil)
 	conn := session.Pool.Pick(nil)
 	flight := new(inflightPrepare)
 	flight := new(inflightPrepare)
-	stmtsLRU.mu.Lock()
+	stmtsLRU.Lock()
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Unlock()
 	flight.info = &QueryInfo{
 	flight.info = &QueryInfo{
 		Id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		Id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		Args: []ColumnInfo{ColumnInfo{
 		Args: []ColumnInfo{ColumnInfo{
@@ -1057,9 +1057,9 @@ func TestQueryInfo(t *testing.T) {
 func TestPreparedCacheEviction(t *testing.T) {
 func TestPreparedCacheEviction(t *testing.T) {
 	session := createSession(t)
 	session := createSession(t)
 	defer session.Close()
 	defer session.Close()
-	stmtsLRU.mu.Lock()
+	stmtsLRU.Lock()
 	stmtsLRU.Max(4)
 	stmtsLRU.Max(4)
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Unlock()
 
 
 	if err := createTable(session, "CREATE TABLE prepcachetest (id int,mod int,PRIMARY KEY (id))"); err != nil {
 	if err := createTable(session, "CREATE TABLE prepcachetest (id int,mod int,PRIMARY KEY (id))"); err != nil {
 		t.Fatalf("failed to create table with error '%v'", err)
 		t.Fatalf("failed to create table with error '%v'", err)

+ 12 - 8
cluster.go

@@ -17,8 +17,8 @@ var stmtsLRU preparedLRU
 
 
 //preparedLRU is the prepared statement cache
 //preparedLRU is the prepared statement cache
 type preparedLRU struct {
 type preparedLRU struct {
+	sync.Mutex
 	lru *lru.Cache
 	lru *lru.Cache
-	mu  sync.Mutex
 }
 }
 
 
 //Max adjusts the maximum size of the cache and cleans up the oldest records if
 //Max adjusts the maximum size of the cache and cleans up the oldest records if
@@ -30,6 +30,14 @@ func (p *preparedLRU) Max(max int) {
 	p.lru.MaxEntries = max
 	p.lru.MaxEntries = max
 }
 }
 
 
+func initStmtsLRU(max int) {
+	if stmtsLRU.lru != nil {
+		stmtsLRU.Max(max)
+	} else {
+		stmtsLRU.lru = lru.New(max)
+	}
+}
+
 // To enable periodic node discovery enable DiscoverHosts in ClusterConfig
 // To enable periodic node discovery enable DiscoverHosts in ClusterConfig
 type DiscoveryConfig struct {
 type DiscoveryConfig struct {
 	// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
 	// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
@@ -94,13 +102,9 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 	pool := cfg.ConnPoolType(cfg)
 	pool := cfg.ConnPoolType(cfg)
 
 
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	//Adjust the size of the prepared statements cache to match the latest configuration
-	stmtsLRU.mu.Lock()
-	if stmtsLRU.lru != nil {
-		stmtsLRU.Max(cfg.MaxPreparedStmts)
-	} else {
-		stmtsLRU.lru = lru.New(cfg.MaxPreparedStmts)
-	}
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Lock()
+	initStmtsLRU(cfg.MaxPreparedStmts)
+	stmtsLRU.Unlock()
 
 
 	//See if there are any connections in the pool
 	//See if there are any connections in the pool
 	if pool.Size() > 0 {
 	if pool.Size() > 0 {

+ 13 - 10
conn.go

@@ -364,13 +364,16 @@ func (c *Conn) ping() error {
 }
 }
 
 
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
-	stmtsLRU.mu.Lock()
+	stmtsLRU.Lock()
+	if stmtsLRU.lru == nil {
+		initStmtsLRU(1000)
+	}
 
 
 	stmtCacheKey := c.addr + c.currentKeyspace + stmt
 	stmtCacheKey := c.addr + c.currentKeyspace + stmt
 
 
 	if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 	if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 		flight := val.(*inflightPrepare)
 		flight := val.(*inflightPrepare)
-		stmtsLRU.mu.Unlock()
+		stmtsLRU.Unlock()
 		flight.wg.Wait()
 		flight.wg.Wait()
 		return flight.info, flight.err
 		return flight.info, flight.err
 	}
 	}
@@ -378,7 +381,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 	flight := new(inflightPrepare)
 	flight := new(inflightPrepare)
 	flight.wg.Add(1)
 	flight.wg.Add(1)
 	stmtsLRU.lru.Add(stmtCacheKey, flight)
 	stmtsLRU.lru.Add(stmtCacheKey, flight)
-	stmtsLRU.mu.Unlock()
+	stmtsLRU.Unlock()
 
 
 	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
 	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
 	if err != nil {
 	if err != nil {
@@ -402,9 +405,9 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 	flight.wg.Done()
 	flight.wg.Done()
 
 
 	if err != nil {
 	if err != nil {
-		stmtsLRU.mu.Lock()
+		stmtsLRU.Lock()
 		stmtsLRU.lru.Remove(stmtCacheKey)
 		stmtsLRU.lru.Remove(stmtCacheKey)
-		stmtsLRU.mu.Unlock()
+		stmtsLRU.Unlock()
 	}
 	}
 
 
 	return flight.info, flight.err
 	return flight.info, flight.err
@@ -471,14 +474,14 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	case resultKeyspaceFrame:
 	case resultKeyspaceFrame:
 		return &Iter{}
 		return &Iter{}
 	case RequestErrUnprepared:
 	case RequestErrUnprepared:
-		stmtsLRU.mu.Lock()
+		stmtsLRU.Lock()
 		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
 		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
 		if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 		if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 			stmtsLRU.lru.Remove(stmtCacheKey)
 			stmtsLRU.lru.Remove(stmtCacheKey)
-			stmtsLRU.mu.Unlock()
+			stmtsLRU.Unlock()
 			return c.executeQuery(qry)
 			return c.executeQuery(qry)
 		}
 		}
-		stmtsLRU.mu.Unlock()
+		stmtsLRU.Unlock()
 		return &Iter{err: x}
 		return &Iter{err: x}
 	case error:
 	case error:
 		return &Iter{err: x}
 		return &Iter{err: x}
@@ -602,9 +605,9 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	case RequestErrUnprepared:
 	case RequestErrUnprepared:
 		stmt, found := stmts[string(x.StatementId)]
 		stmt, found := stmts[string(x.StatementId)]
 		if found {
 		if found {
-			stmtsLRU.mu.Lock()
+			stmtsLRU.Lock()
 			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
 			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
-			stmtsLRU.mu.Unlock()
+			stmtsLRU.Unlock()
 		}
 		}
 		if found {
 		if found {
 			return c.executeBatch(batch)
 			return c.executeBatch(batch)