Browse Source

Merge pull request #606 from Zariel/session-local-statment-cache

Session local statment cache
Chris Bannister 10 năm trước cách đây
mục cha
commit
8616d8e8a3
4 tập tin đã thay đổi với 36 bổ sung54 xóa
  1. 13 15
      cassandra_test.go
  2. 5 13
      cluster.go
  3. 16 21
      conn.go
  4. 2 5
      session.go

+ 13 - 15
cassandra_test.go

@@ -1008,9 +1008,9 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	_, conn := session.pool.Pick(nil)
 	flight := new(inflightPrepare)
-	stmtsLRU.Lock()
-	stmtsLRU.lru.Add(conn.addr+stmt, flight)
-	stmtsLRU.Unlock()
+	session.stmtsLRU.Lock()
+	session.stmtsLRU.lru.Add(conn.addr+stmt, flight)
+	session.stmtsLRU.Unlock()
 	flight.info = QueryInfo{
 		Id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		Args: []ColumnInfo{
@@ -1100,9 +1100,7 @@ func TestQueryInfo(t *testing.T) {
 func TestPreparedCacheEviction(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
-	stmtsLRU.Lock()
-	stmtsLRU.Max(4)
-	stmtsLRU.Unlock()
+	session.stmtsLRU.max(4)
 
 	if err := createTable(session, "CREATE TABLE gocql_test.prepcachetest (id int,mod int,PRIMARY KEY (id))"); err != nil {
 		t.Fatalf("failed to create table with error '%v'", err)
@@ -1140,33 +1138,33 @@ func TestPreparedCacheEviction(t *testing.T) {
 		t.Fatalf("insert into prepcachetest failed, error '%v'", err)
 	}
 
-	stmtsLRU.Lock()
+	session.stmtsLRU.Lock()
 
 	//Make sure the cache size is maintained
-	if stmtsLRU.lru.Len() != stmtsLRU.lru.MaxEntries {
-		t.Fatalf("expected cache size of %v, got %v", stmtsLRU.lru.MaxEntries, stmtsLRU.lru.Len())
+	if session.stmtsLRU.lru.Len() != session.stmtsLRU.lru.MaxEntries {
+		t.Fatalf("expected cache size of %v, got %v", session.stmtsLRU.lru.MaxEntries, session.stmtsLRU.lru.Len())
 	}
 
 	//Walk through all the configured hosts and test cache retention and eviction
 	var selFound, insFound, updFound, delFound, selEvict bool
 	for i := range session.cfg.Hosts {
-		_, ok := stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 1")
+		_, ok := session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 1")
 		selFound = selFound || ok
 
-		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testINSERT INTO prepcachetest (id,mod) VALUES (?, ?)")
+		_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testINSERT INTO prepcachetest (id,mod) VALUES (?, ?)")
 		insFound = insFound || ok
 
-		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testUPDATE prepcachetest SET mod = ? WHERE id = ?")
+		_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testUPDATE prepcachetest SET mod = ? WHERE id = ?")
 		updFound = updFound || ok
 
-		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testDELETE FROM prepcachetest WHERE id = ?")
+		_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testDELETE FROM prepcachetest WHERE id = ?")
 		delFound = delFound || ok
 
-		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 0")
+		_, ok = session.stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 0")
 		selEvict = selEvict || !ok
 	}
 
-	stmtsLRU.Unlock()
+	session.stmtsLRU.Unlock()
 
 	if !selEvict {
 		t.Fatalf("expected first select statement to be purged, but statement was found in the cache.")

+ 5 - 13
cluster.go

@@ -14,32 +14,24 @@ import (
 
 const defaultMaxPreparedStmts = 1000
 
-//Package global reference to Prepared Statements LRU
-var stmtsLRU preparedLRU
-
 //preparedLRU is the prepared statement cache
 type preparedLRU struct {
-	sync.Mutex
+	sync.RWMutex
 	lru *lru.Cache
 }
 
 //Max adjusts the maximum size of the cache and cleans up the oldest records if
 //the new max is lower than the previous value. Not concurrency safe.
-func (p *preparedLRU) Max(max int) {
+func (p *preparedLRU) max(max int) {
+	p.Lock()
+	defer p.Unlock()
+
 	for p.lru.Len() > max {
 		p.lru.RemoveOldest()
 	}
 	p.lru.MaxEntries = max
 }
 
-func initStmtsLRU(max int) {
-	if stmtsLRU.lru != nil {
-		stmtsLRU.Max(max)
-	} else {
-		stmtsLRU.lru = lru.New(max)
-	}
-}
-
 // PoolConfig configures the connection pool used by the driver, it defaults to
 // using a round robbin host selection policy and a round robbin connection selection
 // policy for each host.

+ 16 - 21
conn.go

@@ -580,15 +580,10 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
 }
 
 func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*QueryInfo, error) {
-	stmtsLRU.Lock()
-	if stmtsLRU.lru == nil {
-		initStmtsLRU(defaultMaxPreparedStmts)
-	}
-
+	c.session.stmtsLRU.Lock()
 	stmtCacheKey := c.addr + c.currentKeyspace + stmt
-
-	if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
-		stmtsLRU.Unlock()
+	if val, ok := c.session.stmtsLRU.lru.Get(stmtCacheKey); ok {
+		c.session.stmtsLRU.Unlock()
 		flight := val.(*inflightPrepare)
 		flight.wg.Wait()
 		return &flight.info, flight.err
@@ -596,8 +591,8 @@ func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*QueryInfo, error)
 
 	flight := new(inflightPrepare)
 	flight.wg.Add(1)
-	stmtsLRU.lru.Add(stmtCacheKey, flight)
-	stmtsLRU.Unlock()
+	c.session.stmtsLRU.lru.Add(stmtCacheKey, flight)
+	c.session.stmtsLRU.Unlock()
 
 	prep := &writePrepareFrame{
 		statement: stmt,
@@ -641,9 +636,9 @@ func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*QueryInfo, error)
 	flight.wg.Done()
 
 	if flight.err != nil {
-		stmtsLRU.Lock()
-		stmtsLRU.lru.Remove(stmtCacheKey)
-		stmtsLRU.Unlock()
+		c.session.stmtsLRU.Lock()
+		c.session.stmtsLRU.lru.Remove(stmtCacheKey)
+		c.session.stmtsLRU.Unlock()
 	}
 
 	framerPool.Put(framer)
@@ -763,14 +758,14 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		// is not consistent with regards to its schema.
 		return iter
 	case *RequestErrUnprepared:
-		stmtsLRU.Lock()
+		c.session.stmtsLRU.Lock()
 		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
-		if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
-			stmtsLRU.lru.Remove(stmtCacheKey)
-			stmtsLRU.Unlock()
+		if _, ok := c.session.stmtsLRU.lru.Get(stmtCacheKey); ok {
+			c.session.stmtsLRU.lru.Remove(stmtCacheKey)
+			c.session.stmtsLRU.Unlock()
 			return c.executeQuery(qry)
 		}
-		stmtsLRU.Unlock()
+		c.session.stmtsLRU.Unlock()
 		return &Iter{err: x, framer: framer}
 	case error:
 		return &Iter{err: x, framer: framer}
@@ -904,9 +899,9 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 	case *RequestErrUnprepared:
 		stmt, found := stmts[string(x.StatementId)]
 		if found {
-			stmtsLRU.Lock()
-			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
-			stmtsLRU.Unlock()
+			c.session.stmtsLRU.Lock()
+			c.session.stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
+			c.session.stmtsLRU.Unlock()
 		}
 
 		framerPool.Put(framer)

+ 2 - 5
session.go

@@ -39,6 +39,7 @@ type Session struct {
 	trace               Tracer
 	hostSource          *ringDescriber
 	ring                ring
+	stmtsLRU            *preparedLRU
 
 	connCfg *ConnConfig
 
@@ -85,16 +86,12 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		return nil, ErrNoHosts
 	}
 
-	//Adjust the size of the prepared statements cache to match the latest configuration
-	stmtsLRU.Lock()
-	initStmtsLRU(cfg.MaxPreparedStmts)
-	stmtsLRU.Unlock()
-
 	s := &Session{
 		cons:     cfg.Consistency,
 		prefetch: 0.25,
 		cfg:      cfg,
 		pageSize: cfg.PageSize,
+		stmtsLRU: &preparedLRU{lru: lru.New(cfg.MaxPreparedStmts)},
 	}
 
 	connCfg, err := connConfig(s)