浏览代码

Created lru cache using the groupcache/lru package.

Phillip Couto 11 年之前
父节点
当前提交
7d8a91b15d
共有 2 个文件被更改,包括 121 次插入35 次删除
  1. 83 3
      cassandra_test.go
  2. 38 32
      conn.go

+ 83 - 3
cassandra_test.go

@@ -10,6 +10,7 @@ import (
 	"reflect"
 	"reflect"
 	"sort"
 	"sort"
 	"speter.net/go/exp/math/dec/inf"
 	"speter.net/go/exp/math/dec/inf"
+	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"testing"
 	"testing"
@@ -583,9 +584,10 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	}
 	}
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	conn := session.Pool.Pick(nil)
 	conn := session.Pool.Pick(nil)
-	conn.prepMu.Lock()
 	flight := new(inflightPrepare)
 	flight := new(inflightPrepare)
-	conn.prep[stmt] = flight
+	stmtsLRU.mu.Lock()
+	stmtsLRU.lru.Add(conn.addr+stmt, flight)
+	stmtsLRU.mu.Unlock()
 	flight.info = &queryInfo{
 	flight.info = &queryInfo{
 		id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		args: []ColumnInfo{ColumnInfo{
 		args: []ColumnInfo{ColumnInfo{
@@ -604,7 +606,6 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 			},
 			},
 		}},
 		}},
 	}
 	}
-	conn.prepMu.Unlock()
 	return stmt, conn
 	return stmt, conn
 }
 }
 
 
@@ -629,3 +630,82 @@ func TestReprepareBatch(t *testing.T) {
 	}
 	}
 
 
 }
 }
+
+//TestPreparedCacheEviction will make sure that the cache size is maintained
+func TestPreparedCacheEviction(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+	stmtsLRU.mu.Lock()
+	for stmtsLRU.lru.Len() > 10 {
+		stmtsLRU.lru.RemoveOldest()
+	}
+	stmtsLRU.lru.MaxEntries = 10
+	stmtsLRU.mu.Unlock()
+
+	if err := session.Query(`CREATE TABLE prepCacheEvict (
+id int,
+mod int,
+PRIMARY KEY (id)
+)`).Exec(); err != nil {
+		t.Fatal("create table:", err)
+	}
+
+	for i := 0; i < 100; i++ {
+		if err := session.Query(`INSERT INTO prepCacheEvict (id,mod) VALUES (?, ?)`,
+			i, 10000%(i+1)).Exec(); err != nil {
+			t.Fatal("insert:", err)
+		}
+	}
+
+	var id, mod int
+	for i := 0; i < 100; i++ {
+		err := session.Query("SELECT id,mod FROM prepcacheevict WHERE id = "+strconv.FormatInt(int64(i), 10)).Scan(&id, &mod)
+		if err != nil {
+			t.Error("select prepcacheevit:", err)
+			continue
+		}
+	}
+	if stmtsLRU.lru.Len() != stmtsLRU.lru.MaxEntries {
+		t.Errorf("expected cache size of %v, got %v", stmtsLRU.lru.MaxEntries, stmtsLRU.lru.Len())
+	}
+}
+
+//TestPreparedCacheAccuracy will test to make sure cached queries are moving to expected positions within the cache
+func TestPreparedCacheAccuracy(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+	//purge cache
+	stmtsLRU.mu.Lock()
+	for stmtsLRU.lru.Len() > 0 {
+		stmtsLRU.lru.RemoveOldest()
+	}
+	stmtsLRU.lru.MaxEntries = 10
+	stmtsLRU.mu.Unlock()
+
+	if err := session.Query(`CREATE TABLE prepCacheacc (
+id int,
+mod int,
+PRIMARY KEY (id)
+)`).Exec(); err != nil {
+		t.Fatal("create table:", err)
+	}
+
+	for i := 0; i < 100; i++ {
+		if err := session.Query(`INSERT INTO prepCacheacc (id,mod) VALUES (?, ?)`,
+			i, 10000%(i+1)).Exec(); err != nil {
+			t.Fatal("insert:", err)
+		}
+	}
+
+	var id, mod int
+	for i := 0; i < 100; i++ {
+		err := session.Query("SELECT id,mod FROM prepCacheacc WHERE id = ?", i).Scan(&id, &mod)
+		if err != nil {
+			t.Error("select prepCacheacc:", err)
+			continue
+		}
+	}
+	if stmtsLRU.lru.Len() != 2 {
+		t.Errorf("expected cache size of %v, got %v", 2, stmtsLRU.lru.Len())
+	}
+}

+ 38 - 32
conn.go

@@ -6,8 +6,8 @@ package gocql
 
 
 import (
 import (
 	"bufio"
 	"bufio"
-	"bytes"
 	"fmt"
 	"fmt"
+	"github.com/golang/groupcache/lru"
 	"net"
 	"net"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
@@ -18,6 +18,14 @@ const defaultFrameSize = 4096
 const flagResponse = 0x80
 const flagResponse = 0x80
 const maskVersion = 0x7F
 const maskVersion = 0x7F
 
 
+//Package global reference to Prepared Statements LRU
+var stmtsLRU *preparedLRU
+
+//init houses could to initialize components related to connections like LRU for prepared statements
+func init() {
+	stmtsLRU = &preparedLRU{lru: lru.New(10)}
+}
+
 type Authenticator interface {
 type Authenticator interface {
 	Challenge(req []byte) (resp []byte, auth Authenticator, err error)
 	Challenge(req []byte) (resp []byte, auth Authenticator, err error)
 	Success(data []byte) error
 	Success(data []byte) error
@@ -66,9 +74,6 @@ type Conn struct {
 	calls []callReq
 	calls []callReq
 	nwait int32
 	nwait int32
 
 
-	prepMu sync.Mutex
-	prep   map[string]*inflightPrepare
-
 	pool       ConnectionPool
 	pool       ConnectionPool
 	compressor Compressor
 	compressor Compressor
 	auth       Authenticator
 	auth       Authenticator
@@ -98,7 +103,6 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		r:          bufio.NewReader(conn),
 		r:          bufio.NewReader(conn),
 		uniq:       make(chan uint8, cfg.NumStreams),
 		uniq:       make(chan uint8, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
-		prep:       make(map[string]*inflightPrepare),
 		timeout:    cfg.Timeout,
 		timeout:    cfg.Timeout,
 		version:    uint8(cfg.ProtoVersion),
 		version:    uint8(cfg.ProtoVersion),
 		addr:       conn.RemoteAddr().String(),
 		addr:       conn.RemoteAddr().String(),
@@ -313,18 +317,18 @@ func (c *Conn) ping() error {
 }
 }
 
 
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
-	c.prepMu.Lock()
-	flight := c.prep[stmt]
-	if flight != nil {
-		c.prepMu.Unlock()
+	stmtsLRU.mu.Lock()
+	if val, ok := stmtsLRU.lru.Get(c.addr + stmt); ok {
+		flight := val.(*inflightPrepare)
+		stmtsLRU.mu.Unlock()
 		flight.wg.Wait()
 		flight.wg.Wait()
 		return flight.info, flight.err
 		return flight.info, flight.err
 	}
 	}
 
 
-	flight = new(inflightPrepare)
+	flight := new(inflightPrepare)
 	flight.wg.Add(1)
 	flight.wg.Add(1)
-	c.prep[stmt] = flight
-	c.prepMu.Unlock()
+	stmtsLRU.lru.Add(c.addr+stmt, flight)
+	stmtsLRU.mu.Unlock()
 
 
 	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
 	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
 	if err != nil {
 	if err != nil {
@@ -346,9 +350,9 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
 	flight.wg.Done()
 	flight.wg.Done()
 
 
 	if err != nil {
 	if err != nil {
-		c.prepMu.Lock()
-		delete(c.prep, stmt)
-		c.prepMu.Unlock()
+		stmtsLRU.mu.Lock()
+		stmtsLRU.lru.Remove(c.addr + stmt)
+		stmtsLRU.mu.Unlock()
 	}
 	}
 
 
 	return flight.info, flight.err
 	return flight.info, flight.err
@@ -400,13 +404,13 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	case resultKeyspaceFrame:
 	case resultKeyspaceFrame:
 		return &Iter{}
 		return &Iter{}
 	case RequestErrUnprepared:
 	case RequestErrUnprepared:
-		c.prepMu.Lock()
-		if val, ok := c.prep[qry.stmt]; ok && val != nil {
-			delete(c.prep, qry.stmt)
-			c.prepMu.Unlock()
+		stmtsLRU.mu.Lock()
+		if _, ok := stmtsLRU.lru.Get(c.addr + qry.stmt); ok {
+			stmtsLRU.lru.Remove(c.addr + qry.stmt)
+			stmtsLRU.mu.Unlock()
 			return c.executeQuery(qry)
 			return c.executeQuery(qry)
 		}
 		}
-		c.prepMu.Unlock()
+		stmtsLRU.mu.Unlock()
 		return &Iter{err: x}
 		return &Iter{err: x}
 	case error:
 	case error:
 		return &Iter{err: x}
 		return &Iter{err: x}
@@ -468,12 +472,16 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	f.setHeader(c.version, 0, 0, opBatch)
 	f.setHeader(c.version, 0, 0, opBatch)
 	f.writeByte(byte(batch.Type))
 	f.writeByte(byte(batch.Type))
 	f.writeShort(uint16(len(batch.Entries)))
 	f.writeShort(uint16(len(batch.Entries)))
+
+	stmts := make(map[string]string)
+
 	for i := 0; i < len(batch.Entries); i++ {
 	for i := 0; i < len(batch.Entries); i++ {
 		entry := &batch.Entries[i]
 		entry := &batch.Entries[i]
 		var info *queryInfo
 		var info *queryInfo
 		if len(entry.Args) > 0 {
 		if len(entry.Args) > 0 {
 			var err error
 			var err error
 			info, err = c.prepareStatement(entry.Stmt, nil)
 			info, err = c.prepareStatement(entry.Stmt, nil)
+			stmts[string(info.id)] = entry.Stmt
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
@@ -502,19 +510,12 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	case resultVoidFrame:
 	case resultVoidFrame:
 		return nil
 		return nil
 	case RequestErrUnprepared:
 	case RequestErrUnprepared:
-		c.prepMu.Lock()
-		found := false
-		for stmt, flight := range c.prep {
-			if flight == nil || flight.info == nil {
-				continue
-			}
-			if bytes.Equal(flight.info.id, x.StatementId) {
-				found = true
-				delete(c.prep, stmt)
-				break
-			}
+		stmtsLRU.mu.Lock()
+		stmt, found := stmts[string(x.StatementId)]
+		if found {
+			stmtsLRU.lru.Remove(c.addr + stmt)
 		}
 		}
-		c.prepMu.Unlock()
+		stmtsLRU.mu.Unlock()
 		if found {
 		if found {
 			return c.executeBatch(batch)
 			return c.executeBatch(batch)
 		} else {
 		} else {
@@ -626,3 +627,8 @@ type inflightPrepare struct {
 	err  error
 	err  error
 	wg   sync.WaitGroup
 	wg   sync.WaitGroup
 }
 }
+
+type preparedLRU struct {
+	lru *lru.Cache
+	mu  sync.Mutex
+}