Browse Source

simplify the prepared cache access

Chris Bannister 9 years ago
parent
commit
45c7cec18e
5 changed files with 125 additions and 96 deletions
  1. 39 40
      cassandra_test.go
  2. 0 32
      cluster.go
  3. 16 22
      conn.go
  4. 6 2
      internal/lru/lru.go
  5. 64 0
      prepared_cache.go

+ 39 - 40
cassandra_test.go

@@ -991,12 +991,14 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	)`); err != nil {
 		t.Fatal("create:", err)
 	}
+
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	_, conn := session.pool.Pick(nil)
+
 	flight := new(inflightPrepare)
-	session.stmtsLRU.Lock()
-	session.stmtsLRU.lru.Add(conn.addr+stmt, flight)
-	session.stmtsLRU.Unlock()
+	key := session.stmtsLRU.keyFor(conn.addr, "", stmt)
+	session.stmtsLRU.add(key, flight)
+
 	flight.preparedStatment = &preparedStatment{
 		id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		request: preparedMetadata{
@@ -1016,10 +1018,11 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 			},
 		},
 	}
+
 	return stmt, conn
 }
 
-func TestMissingSchemaPrepare(t *testing.T) {
+func TestPrepare_MissingSchemaPrepare(t *testing.T) {
 	s := createSession(t)
 	_, conn := s.pool.Pick(nil)
 	defer s.Close()
@@ -1041,7 +1044,7 @@ func TestMissingSchemaPrepare(t *testing.T) {
 	}
 }
 
-func TestReprepareStatement(t *testing.T) {
+func TestPrepare_ReprepareStatement(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement")
@@ -1051,7 +1054,7 @@ func TestReprepareStatement(t *testing.T) {
 	}
 }
 
-func TestReprepareBatch(t *testing.T) {
+func TestPrepare_ReprepareBatch(t *testing.T) {
 	if *flagProto == 1 {
 		t.Skip("atomic batches not supported. Please use Cassandra >= 2.0")
 	}
@@ -1089,8 +1092,9 @@ func TestQueryInfo(t *testing.T) {
 }
 
 //TestPreparedCacheEviction will make sure that the cache size is maintained
-func TestPreparedCacheEviction(t *testing.T) {
+func TestPrepare_PreparedCacheEviction(t *testing.T) {
 	const maxPrepared = 4
+
 	cluster := createCluster()
 	cluster.MaxPreparedStmts = maxPrepared
 	cluster.Events.DisableSchemaEvents = true
@@ -1101,6 +1105,9 @@ func TestPreparedCacheEviction(t *testing.T) {
 	if err := createTable(session, "CREATE TABLE gocql_test.prepcachetest (id int,mod int,PRIMARY KEY (id))"); err != nil {
 		t.Fatalf("failed to create table with error '%v'", err)
 	}
+	// clear the cache
+	session.stmtsLRU.clear()
+
 	//Fill the table
 	for i := 0; i < 2; i++ {
 		if err := session.Query("INSERT INTO prepcachetest (id,mod) VALUES (?, ?)", i, 10000%(i+1)).Exec(); err != nil {
@@ -1134,52 +1141,44 @@ func TestPreparedCacheEviction(t *testing.T) {
 		t.Fatalf("insert into prepcachetest failed, error '%v'", err)
 	}
 
-	session.stmtsLRU.Lock()
+	session.stmtsLRU.mu.Lock()
+	defer session.stmtsLRU.mu.Unlock()
 
 	//Make sure the cache size is maintained
 	if session.stmtsLRU.lru.Len() != session.stmtsLRU.lru.MaxEntries {
 		t.Fatalf("expected cache size of %v, got %v", session.stmtsLRU.lru.MaxEntries, session.stmtsLRU.lru.Len())
 	}
 
-	//Walk through all the configured hosts and test cache retention and eviction
-	var selFound, insFound, updFound, delFound, selEvict bool
-	for i := range session.cfg.Hosts {
-		_, ok := session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 1")
-		selFound = selFound || ok
-
-		_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testINSERT INTO prepcachetest (id,mod) VALUES (?, ?)")
-		insFound = insFound || ok
-
-		_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testUPDATE prepcachetest SET mod = ? WHERE id = ?")
-		updFound = updFound || ok
+	// Walk through all the configured hosts and test cache retention and eviction
+	for _, host := range session.cfg.Hosts {
+		_, ok := session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host+":9042", session.cfg.Keyspace, "SELECT id,mod FROM prepcachetest WHERE id = 0"))
+		if ok {
+			t.Errorf("expected first select to be purged but was in cache for host=%q", host)
+		}
 
-		_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testDELETE FROM prepcachetest WHERE id = ?")
-		delFound = delFound || ok
+		_, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host+":9042", session.cfg.Keyspace, "SELECT id,mod FROM prepcachetest WHERE id = 1"))
+		if !ok {
+			t.Errorf("exepected second select to be in cache for host=%q", host)
+		}
 
-		_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 0")
-		selEvict = selEvict || !ok
-	}
+		_, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host+":9042", session.cfg.Keyspace, "INSERT INTO prepcachetest (id,mod) VALUES (?, ?)"))
+		if !ok {
+			t.Errorf("expected insert to be in cache for host=%q", host)
+		}
 
-	session.stmtsLRU.Unlock()
+		_, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host+":9042", session.cfg.Keyspace, "UPDATE prepcachetest SET mod = ? WHERE id = ?"))
+		if !ok {
+			t.Errorf("expected update to be in cached for host=%q", host)
+		}
 
-	if !selEvict {
-		t.Fatalf("expected first select statement to be purged, but statement was found in the cache.")
-	}
-	if !selFound {
-		t.Fatalf("expected second select statement to be cached, but statement was purged or not prepared.")
-	}
-	if !insFound {
-		t.Fatalf("expected insert statement to be cached, but statement was purged or not prepared.")
-	}
-	if !updFound {
-		t.Fatalf("expected update statement to be cached, but statement was purged or not prepared.")
-	}
-	if !delFound {
-		t.Error("expected delete statement to be cached, but statement was purged or not prepared.")
+		_, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host+":9042", session.cfg.Keyspace, "DELETE FROM prepcachetest WHERE id = ?"))
+		if !ok {
+			t.Errorf("expected delete to be cached for host=%q", host)
+		}
 	}
 }
 
-func TestPreparedCacheKey(t *testing.T) {
+func TestPrepare_PreparedCacheKey(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 

+ 0 - 32
cluster.go

@@ -6,41 +6,9 @@ package gocql
 
 import (
 	"errors"
-	"sync"
 	"time"
-
-	"github.com/gocql/gocql/internal/lru"
 )
 
-const defaultMaxPreparedStmts = 1000
-
-//preparedLRU is the prepared statement cache
-type preparedLRU struct {
-	sync.Mutex
-	lru *lru.Cache
-}
-
-//Max adjusts the maximum size of the cache and cleans up the oldest records if
-//the new max is lower than the previous value. Not concurrency safe.
-func (p *preparedLRU) max(max int) {
-	p.Lock()
-	defer p.Unlock()
-
-	for p.lru.Len() > max {
-		p.lru.RemoveOldest()
-	}
-	p.lru.MaxEntries = max
-}
-
-func (p *preparedLRU) clear() {
-	p.Lock()
-	defer p.Unlock()
-
-	for p.lru.Len() > 0 {
-		p.lru.RemoveOldest()
-	}
-}
-
 // PoolConfig configures the connection pool used by the driver, it defaults to
 // using a round robbin host selection policy and a round robbin connection selection
 // policy for each host.

+ 16 - 22
conn.go

@@ -9,6 +9,7 @@ import (
 	"crypto/tls"
 	"errors"
 	"fmt"
+	"github.com/gocql/gocql/internal/lru"
 	"io"
 	"io/ioutil"
 	"log"
@@ -593,20 +594,19 @@ type inflightPrepare struct {
 }
 
 func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment, error) {
-	c.session.stmtsLRU.Lock()
-	stmtCacheKey := c.addr + c.currentKeyspace + stmt
-	if val, ok := c.session.stmtsLRU.lru.Get(stmtCacheKey); ok {
-		c.session.stmtsLRU.Unlock()
-		flight := val.(*inflightPrepare)
+	stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
+	flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
+		flight := new(inflightPrepare)
+		flight.wg.Add(1)
+		lru.Add(stmtCacheKey, flight)
+		return flight
+	})
+
+	if ok {
 		flight.wg.Wait()
 		return flight.preparedStatment, flight.err
 	}
 
-	flight := new(inflightPrepare)
-	flight.wg.Add(1)
-	c.session.stmtsLRU.lru.Add(stmtCacheKey, flight)
-	c.session.stmtsLRU.Unlock()
-
 	prep := &writePrepareFrame{
 		statement: stmt,
 	}
@@ -650,9 +650,7 @@ func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment,
 	flight.wg.Done()
 
 	if flight.err != nil {
-		c.session.stmtsLRU.Lock()
-		c.session.stmtsLRU.lru.Remove(stmtCacheKey)
-		c.session.stmtsLRU.Unlock()
+		c.session.stmtsLRU.remove(stmtCacheKey)
 	}
 
 	framerPool.Put(framer)
@@ -799,14 +797,11 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		// is not consistent with regards to its schema.
 		return iter
 	case *RequestErrUnprepared:
-		c.session.stmtsLRU.Lock()
-		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
-		if _, ok := c.session.stmtsLRU.lru.Get(stmtCacheKey); ok {
-			c.session.stmtsLRU.lru.Remove(stmtCacheKey)
-			c.session.stmtsLRU.Unlock()
+		stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt)
+		if c.session.stmtsLRU.remove(stmtCacheKey) {
 			return c.executeQuery(qry)
 		}
-		c.session.stmtsLRU.Unlock()
+
 		return &Iter{err: x, framer: framer}
 	case error:
 		return &Iter{err: x, framer: framer}
@@ -945,9 +940,8 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 	case *RequestErrUnprepared:
 		stmt, found := stmts[string(x.StatementId)]
 		if found {
-			c.session.stmtsLRU.Lock()
-			c.session.stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
-			c.session.stmtsLRU.Unlock()
+			key := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
+			c.session.stmtsLRU.remove(key)
 		}
 
 		framerPool.Put(framer)

+ 6 - 2
internal/lru/lru.go

@@ -85,13 +85,17 @@ func (c *Cache) Get(key string) (value interface{}, ok bool) {
 }
 
 // Remove removes the provided key from the cache.
-func (c *Cache) Remove(key string) {
+func (c *Cache) Remove(key string) bool {
 	if c.cache == nil {
-		return
+		return false
 	}
+
 	if ele, hit := c.cache[key]; hit {
 		c.removeElement(ele)
+		return true
 	}
+
+	return false
 }
 
 // RemoveOldest removes the oldest item from the cache.

+ 64 - 0
prepared_cache.go

@@ -0,0 +1,64 @@
+package gocql
+
+import (
+	"github.com/gocql/gocql/internal/lru"
+	"sync"
+)
+
+const defaultMaxPreparedStmts = 1000
+
+// preparedLRU is the prepared statement cache
+type preparedLRU struct {
+	mu  sync.Mutex
+	lru *lru.Cache
+}
+
+// Max adjusts the maximum size of the cache and cleans up the oldest records if
+// the new max is lower than the previous value. Not concurrency safe.
+func (p *preparedLRU) max(max int) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	for p.lru.Len() > max {
+		p.lru.RemoveOldest()
+	}
+	p.lru.MaxEntries = max
+}
+
+func (p *preparedLRU) clear() {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	for p.lru.Len() > 0 {
+		p.lru.RemoveOldest()
+	}
+}
+
+func (p *preparedLRU) add(key string, val *inflightPrepare) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	p.lru.Add(key, val)
+}
+
+func (p *preparedLRU) remove(key string) bool {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	return p.lru.Remove(key)
+}
+
+func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	val, ok := p.lru.Get(key)
+	if ok {
+		return val.(*inflightPrepare), true
+	}
+
+	return fn(p.lru), false
+}
+
+func (p *preparedLRU) keyFor(addr, keyspace, statement string) string {
+	// TODO: maybe use []byte for keys?
+	return addr + keyspace + statement
+}