Forráskód Böngészése

Merge upstream master

Ben Hood 11 éve
szülő
commit
553692227e
3 módosított fájl, 169 hozzáadás és 60 törlés
  1. 89 3
      cassandra_test.go
  2. 56 25
      cluster.go
  3. 24 32
      conn.go

+ 89 - 3
cassandra_test.go

@@ -10,6 +10,7 @@ import (
 	"reflect"
 	"sort"
 	"speter.net/go/exp/math/dec/inf"
+	"strconv"
 	"strings"
 	"sync"
 	"testing"
@@ -625,9 +626,10 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	}
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	conn := session.Pool.Pick(nil)
-	conn.prepMu.Lock()
 	flight := new(inflightPrepare)
-	conn.prep[stmt] = flight
+	stmtsLRU.mu.Lock()
+	stmtsLRU.lru.Add(conn.addr+stmt, flight)
+	stmtsLRU.mu.Unlock()
 	flight.info = &queryInfo{
 		id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		args: []ColumnInfo{ColumnInfo{
@@ -639,7 +641,6 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 			},
 		}},
 	}
-	conn.prepMu.Unlock()
 	return stmt, conn
 }
 
@@ -664,3 +665,88 @@ func TestReprepareBatch(t *testing.T) {
 	}
 
 }
+
+//TestPreparedCacheEviction will make sure that the cache size is maintained
+func TestPreparedCacheEviction(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+	stmtsLRU.mu.Lock()
+	stmtsLRU.Max(4)
+	stmtsLRU.mu.Unlock()
+
+	if err := session.Query("CREATE TABLE prepcachetest (id int,mod int,PRIMARY KEY (id))").Exec(); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+	//Fill the table
+	for i := 0; i < 2; i++ {
+		if err := session.Query("INSERT INTO prepcachetest (id,mod) VALUES (?, ?)", i, 10000%(i+1)).Exec(); err != nil {
+			t.Fatalf("insert into prepcachetest failed, err '%v'", err)
+		}
+	}
+	//Populate the prepared statement cache with select statements
+	var id, mod int
+	for i := 0; i < 2; i++ {
+		err := session.Query("SELECT id,mod FROM prepcachetest WHERE id = "+strconv.FormatInt(int64(i), 10)).Scan(&id, &mod)
+		if err != nil {
+			t.Fatalf("select from prepcachetest failed, error '%v'", err)
+		}
+	}
+
+	//generate an update statement to test they are prepared
+	err := session.Query("UPDATE prepcachetest SET mod = ? WHERE id = ?", 1, 11).Exec()
+	if err != nil {
+		t.Fatalf("update prepcachetest failed, error '%v'", err)
+	}
+
+	//generate a delete statement to test they are prepared
+	err = session.Query("DELETE FROM prepcachetest WHERE id = ?", 1).Exec()
+	if err != nil {
+		t.Fatalf("delete from prepcachetest failed, error '%v'", err)
+	}
+
+	//generate an insert statement to test they are prepared
+	err = session.Query("INSERT INTO prepcachetest (id,mod) VALUES (?, ?)", 3, 11).Exec()
+	if err != nil {
+		t.Fatalf("insert into prepcachetest failed, error '%v'", err)
+	}
+
+	//Make sure the cache size is maintained
+	if stmtsLRU.lru.Len() != stmtsLRU.lru.MaxEntries {
+		t.Fatalf("expected cache size of %v, got %v", stmtsLRU.lru.MaxEntries, stmtsLRU.lru.Len())
+	}
+
+	//Walk through all the configured hosts and test cache retention and eviction
+	var selFound, insFound, updFound, delFound, selEvict bool
+	for i := range session.cfg.Hosts {
+		_, ok := stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042SELECT id,mod FROM prepcachetest WHERE id = 1")
+		selFound = selFound || ok
+
+		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042INSERT INTO prepcachetest (id,mod) VALUES (?, ?)")
+		insFound = insFound || ok
+
+		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042UPDATE prepcachetest SET mod = ? WHERE id = ?")
+		updFound = updFound || ok
+
+		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042DELETE FROM prepcachetest WHERE id = ?")
+		delFound = delFound || ok
+
+		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042SELECT id,mod FROM prepcachetest WHERE id = 0")
+		selEvict = selEvict || !ok
+
+	}
+	if !selEvict {
+		t.Fatalf("expected first select statement to be purged, but statement was found in the cache.")
+	}
+	if !selFound {
+		t.Fatalf("expected second select statement to be cached, but statement was purged or not prepared.")
+	}
+	if !insFound {
+		t.Fatalf("expected insert statement to be cached, but statement was purged or not prepared.")
+	}
+	if !updFound {
+		t.Fatalf("expected update statement to be cached, but statement was purged or not prepared.")
+	}
+	if !delFound {
+		t.Error("expected delete statement to be cached, but statement was purged or not prepared.")
+	}
+}

+ 56 - 25
cluster.go

@@ -6,44 +6,66 @@ package gocql
 
 import (
 	"errors"
+	"github.com/golang/groupcache/lru"
+	"sync"
 	"time"
 )
 
+//Package global reference to Prepared Statements LRU
+var stmtsLRU preparedLRU
+
+//preparedLRU is the prepared statement cache
+type preparedLRU struct {
+	lru *lru.Cache
+	mu  sync.Mutex
+}
+
+//Max adjusts the maximum size of the cache and cleans up the oldest records if
+//the new max is lower than the previous value. Not concurrency safe.
+func (p *preparedLRU) Max(max int) {
+	for p.lru.Len() > max {
+		p.lru.RemoveOldest()
+	}
+	p.lru.MaxEntries = max
+}
+
 // ClusterConfig is a struct to configure the default cluster implementation
 // of gocql. It has a variety of attributes that can be used to modify the
 // behavior to fit the most common use cases. Applications that require a
 // different setup must implement their own cluster.
 type ClusterConfig struct {
-	Hosts           []string      // addresses for the initial connections
-	CQLVersion      string        // CQL version (default: 3.0.0)
-	ProtoVersion    int           // version of the native protocol (default: 2)
-	Timeout         time.Duration // connection timeout (default: 600ms)
-	DefaultPort     int           // default port (default: 9042)
-	Keyspace        string        // initial keyspace (optional)
-	NumConns        int           // number of connections per host (default: 2)
-	NumStreams      int           // number of streams per connection (default: 128)
-	Consistency     Consistency   // default consistency level (default: Quorum)
-	Compressor      Compressor    // compression algorithm (default: nil)
-	Authenticator   Authenticator // authenticator (default: nil)
-	RetryPolicy     RetryPolicy   // Default retry policy to use for queries (default: 0)
-	SocketKeepalive time.Duration // The keepalive period to use, enabled if > 0 (default: 0)
-	ConnPoolType    NewPoolFunc   // The function used to create the connection pool for the session (default: NewSimplePool)
-	DiscoverHosts   bool          // If set, gocql will attempt to automatically discover other members of the Cassandra cluster (default: false)
+	Hosts            []string      // addresses for the initial connections
+	CQLVersion       string        // CQL version (default: 3.0.0)
+	ProtoVersion     int           // version of the native protocol (default: 2)
+	Timeout          time.Duration // connection timeout (default: 600ms)
+	DefaultPort      int           // default port (default: 9042)
+	Keyspace         string        // initial keyspace (optional)
+	NumConns         int           // number of connections per host (default: 2)
+	NumStreams       int           // number of streams per connection (default: 128)
+	Consistency      Consistency   // default consistency level (default: Quorum)
+	Compressor       Compressor    // compression algorithm (default: nil)
+	Authenticator    Authenticator // authenticator (default: nil)
+	RetryPolicy      RetryPolicy   // Default retry policy to use for queries (default: 0)
+	SocketKeepalive  time.Duration // The keepalive period to use, enabled if > 0 (default: 0)
+	ConnPoolType     NewPoolFunc   // The function used to create the connection pool for the session (default: NewSimplePool)
+	DiscoverHosts    bool          // If set, gocql will attempt to automatically discover other members of the Cassandra cluster (default: false)
+	MaxPreparedStmts int           // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
 }
 
 // NewCluster generates a new config for the default cluster implementation.
 func NewCluster(hosts ...string) *ClusterConfig {
 	cfg := &ClusterConfig{
-		Hosts:         hosts,
-		CQLVersion:    "3.0.0",
-		ProtoVersion:  2,
-		Timeout:       600 * time.Millisecond,
-		DefaultPort:   9042,
-		NumConns:      2,
-		NumStreams:    128,
-		Consistency:   Quorum,
-		ConnPoolType:  NewSimplePool,
-		DiscoverHosts: false,
+		Hosts:            hosts,
+		CQLVersion:       "3.0.0",
+		ProtoVersion:     2,
+		Timeout:          600 * time.Millisecond,
+		DefaultPort:      9042,
+		NumConns:         2,
+		NumStreams:       128,
+		Consistency:      Quorum,
+		ConnPoolType:     NewSimplePool,
+		DiscoverHosts:    false,
+		MaxPreparedStmts: 1000,
 	}
 	return cfg
 }
@@ -58,6 +80,15 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 	}
 	pool := cfg.ConnPoolType(cfg)
 
+	//Adjust the size of the prepared statements cache to match the latest configuration
+	stmtsLRU.mu.Lock()
+	if stmtsLRU.lru != nil {
+		stmtsLRU.Max(cfg.MaxPreparedStmts)
+	} else {
+		stmtsLRU.lru = lru.New(cfg.MaxPreparedStmts)
+	}
+	stmtsLRU.mu.Unlock()
+
 	//See if there are any connections in the pool
 	if pool.Size() > 0 {
 		s := NewSession(pool, *cfg)

+ 24 - 32
conn.go

@@ -6,7 +6,6 @@ package gocql
 
 import (
 	"bufio"
-	"bytes"
 	"errors"
 	"fmt"
 	"net"
@@ -67,9 +66,6 @@ type Conn struct {
 	calls []callReq
 	nwait int32
 
-	prepMu sync.Mutex
-	prep   map[string]*inflightPrepare
-
 	pool       ConnectionPool
 	compressor Compressor
 	auth       Authenticator
@@ -99,7 +95,6 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		r:          bufio.NewReader(conn),
 		uniq:       make(chan uint8, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
-		prep:       make(map[string]*inflightPrepare),
 		timeout:    cfg.Timeout,
 		version:    uint8(cfg.ProtoVersion),
 		addr:       conn.RemoteAddr().String(),
@@ -314,18 +309,18 @@ func (c *Conn) ping() error {
 }
 
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
-	c.prepMu.Lock()
-	flight := c.prep[stmt]
-	if flight != nil {
-		c.prepMu.Unlock()
+	stmtsLRU.mu.Lock()
+	if val, ok := stmtsLRU.lru.Get(c.addr + stmt); ok {
+		flight := val.(*inflightPrepare)
+		stmtsLRU.mu.Unlock()
 		flight.wg.Wait()
 		return flight.info, flight.err
 	}
 
-	flight = new(inflightPrepare)
+	flight := new(inflightPrepare)
 	flight.wg.Add(1)
-	c.prep[stmt] = flight
-	c.prepMu.Unlock()
+	stmtsLRU.lru.Add(c.addr+stmt, flight)
+	stmtsLRU.mu.Unlock()
 
 	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
 	if err != nil {
@@ -347,9 +342,9 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
 	flight.wg.Done()
 
 	if err != nil {
-		c.prepMu.Lock()
-		delete(c.prep, stmt)
-		c.prepMu.Unlock()
+		stmtsLRU.mu.Lock()
+		stmtsLRU.lru.Remove(c.addr + stmt)
+		stmtsLRU.mu.Unlock()
 	}
 
 	return flight.info, flight.err
@@ -404,13 +399,13 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	case resultKeyspaceFrame:
 		return &Iter{}
 	case RequestErrUnprepared:
-		c.prepMu.Lock()
-		if val, ok := c.prep[qry.stmt]; ok && val != nil {
-			delete(c.prep, qry.stmt)
-			c.prepMu.Unlock()
+		stmtsLRU.mu.Lock()
+		if _, ok := stmtsLRU.lru.Get(c.addr + qry.stmt); ok {
+			stmtsLRU.lru.Remove(c.addr + qry.stmt)
+			stmtsLRU.mu.Unlock()
 			return c.executeQuery(qry)
 		}
-		c.prepMu.Unlock()
+		stmtsLRU.mu.Unlock()
 		return &Iter{err: x}
 	case error:
 		return &Iter{err: x}
@@ -472,12 +467,16 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	f.setHeader(c.version, 0, 0, opBatch)
 	f.writeByte(byte(batch.Type))
 	f.writeShort(uint16(len(batch.Entries)))
+
+	stmts := make(map[string]string)
+
 	for i := 0; i < len(batch.Entries); i++ {
 		entry := &batch.Entries[i]
 		var info *queryInfo
 		if len(entry.Args) > 0 {
 			var err error
 			info, err = c.prepareStatement(entry.Stmt, nil)
+			stmts[string(info.id)] = entry.Stmt
 			if err != nil {
 				return err
 			}
@@ -506,19 +505,12 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	case resultVoidFrame:
 		return nil
 	case RequestErrUnprepared:
-		c.prepMu.Lock()
-		found := false
-		for stmt, flight := range c.prep {
-			if flight == nil || flight.info == nil {
-				continue
-			}
-			if bytes.Equal(flight.info.id, x.StatementId) {
-				found = true
-				delete(c.prep, stmt)
-				break
-			}
+		stmt, found := stmts[string(x.StatementId)]
+		if found {
+			stmtsLRU.mu.Lock()
+			stmtsLRU.lru.Remove(c.addr + stmt)
+			stmtsLRU.mu.Unlock()
 		}
-		c.prepMu.Unlock()
 		if found {
 			return c.executeBatch(batch)
 		} else {