
Speculative query execution (#1178)

* Rework metrics locking

* Metrics are now split into:
     hostMetrics - per-host counters (attempts, total latency)
     queryMetrics - the map of hostMetrics plus the lock that guards it
* Added functions to perform locked metric updates/reads
* Locking is internal to the metrics only, so it should have no
performance impact (see the sketch below).

Signed-off-by: Alex Lourie <alex@instaclustr.com>
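
For reference, a condensed sketch of the new metrics shape and one of the locked helpers introduced in session.go below (not a standalone file, just the shape of the change):

type hostMetrics struct {
	Attempts     int
	TotalLatency int64
}

// queryMetrics owns the per-host map and the lock that guards it, so
// callers go through helpers instead of locking themselves.
type queryMetrics struct {
	l sync.RWMutex
	m map[string]*hostMetrics
}

// AddAttempts is one of the locked update helpers: it bumps the attempt
// counter for a host while holding the metrics lock.
func (q *Query) AddAttempts(i int, host *HostInfo) {
	hostMetric := q.getHostMetrics(host)
	q.metrics.l.Lock()
	hostMetric.Attempts += i
	q.metrics.l.Unlock()
}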

* Introduce Speculative Policy

* Define the speculative policy
* Add NonSpeculative policy
* Add SimpleSpeculative policy

Signed-off-by: Alex Lourie <alex@instaclustr.com>
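
The policy surface is small; condensed from the policies.go diff below:

// SpeculativeExecutionPolicy decides how many executions to launch in
// addition to the main one, and how long to wait between launches.
type SpeculativeExecutionPolicy interface {
	Attempts() int
	Delay() time.Duration
}

// NonSpeculativeExecution never starts additional executions.
type NonSpeculativeExecution struct{}

func (sp NonSpeculativeExecution) Attempts() int        { return 0 }
func (sp NonSpeculativeExecution) Delay() time.Duration { return 1 } // must be positive to feed a ticker

// SimpleSpeculativeExecution starts NumAttempts extra executions, one
// every TimeoutDelay.
type SimpleSpeculativeExecution struct {
	NumAttempts  int
	TimeoutDelay time.Duration
}

func (sp *SimpleSpeculativeExecution) Attempts() int        { return sp.NumAttempts }
func (sp *SimpleSpeculativeExecution) Delay() time.Duration { return sp.TimeoutDelay }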

* Add IsIdempotent to ExecutableQuery interface

Signed-off-by: Alex Lourie <alex@instaclustr.com>
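
Speculative execution only applies to queries explicitly marked idempotent; anything else is forced back to NonSpeculativeExecution. A minimal usage sketch (the keyspace/table and the already-connected session are assumptions for illustration, not part of this change):

package example

import (
	"time"

	"github.com/gocql/gocql"
)

// queryWithSpeculation assumes an existing session and a statement that
// is safe to execute more than once.
func queryWithSpeculation(session *gocql.Session) (string, error) {
	sp := &gocql.SimpleSpeculativeExecution{NumAttempts: 1, TimeoutDelay: 200 * time.Millisecond}

	q := session.Query(`SELECT value FROM config WHERE key = ?`, "feature_x").
		Idempotent(true). // without this, the executor ignores the speculative policy
		SetSpeculativeExecutionPolicy(sp)

	var value string
	err := q.Scan(&value)
	return value, err
}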

* Implement speculative execution

* Refactor executeQuery to execute main code in a separate goroutine
* Handle speculative/non-speculative cases separately
* Add TestSpeculativeExecution test

Signed-off-by: Alex Lourie <alex@instaclustr.com>
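
The resulting control flow in query_executor.go (full diff further down) reads more easily without the diff markers; a condensed view of the new executeQuery, with comments added:

func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
	// Non-idempotent queries are never speculated on.
	sp := qry.speculativeExecutionPolicy()
	if !qry.IsIdempotent() {
		sp = NonSpeculativeExecution{}
	}

	results := make(chan queryResponse, 1)
	stop := make(chan struct{})
	defer close(stop) // tells in-flight executions a result has been taken
	var specWG sync.WaitGroup

	// Main execution.
	specWG.Add(1)
	go q.run(qry, &specWG, results, stop)

	// Speculative executions are launched in addition to the main one,
	// at most sp.Attempts() of them, one per tick.
	go func() {
		defer func() { specWG.Wait(); close(results) }()
		ticker := time.NewTicker(sp.Delay())
		defer ticker.Stop()
		for i := 0; i < sp.Attempts(); i++ {
			select {
			case <-ticker.C:
				specWG.Add(1)
				go q.run(qry, &specWG, results, stop)
			case <-qry.GetContext().Done():
				return // query cancelled
			case <-stop:
				return // a result already came back
			}
		}
	}()

	// First response wins; a closed channel means every execution ran
	// out of hosts.
	res := <-results
	if res.iter == nil && res.err == nil {
		return nil, ErrNoConnections
	}
	return res.iter, res.err
}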

* Review comments

* Make one code path for all executions
* Simplify the results handling
* Update the tests

Signed-off-by: Alex Lourie <alex@instaclustr.com>

* More review comments

* Metric lock improvements
* Style cleanups

Signed-off-by: Alex Lourie <alex@instaclustr.com>

* Fix Latency calc lock

Signed-off-by: Alex Lourie <alex@instaclustr.com>

* Fix session.go for new metrics

Signed-off-by: Alex Lourie <alex@instaclustr.com>
Alex Lourie, 7 years ago
commit aa46e85d0a
7 changed files with 374 additions and 100 deletions
  1. conn_test.go (+102 -7)
  2. control.go (+1 -2)
  3. policies.go (+25 -0)
  4. policies_test.go (+4 -4)
  5. query_executor.go (+110 -45)
  6. session.go (+126 -37)
  7. session_test.go (+6 -5)

+ 102 - 7
conn_test.go

@@ -14,7 +14,9 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"math/rand"
 	"net"
+	"os"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -44,8 +46,8 @@ func TestApprove(t *testing.T) {
 
 func TestJoinHostPort(t *testing.T) {
 	tests := map[string]string{
-		"127.0.0.1:0":                                 JoinHostPort("127.0.0.1", 0),
-		"127.0.0.1:1":                                 JoinHostPort("127.0.0.1:1", 9142),
+		"127.0.0.1:0": JoinHostPort("127.0.0.1", 0),
+		"127.0.0.1:1": JoinHostPort("127.0.0.1:1", 9142),
 		"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:0": JoinHostPort("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 0),
 		"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1": JoinHostPort("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1", 9142),
 	}
@@ -316,7 +318,7 @@ func TestCancel(t *testing.T) {
 }
 
 type testQueryObserver struct {
-	metrics map[string]*queryMetrics
+	metrics map[string]*hostMetrics
 	verbose bool
 }
 
@@ -329,7 +331,7 @@ func (o *testQueryObserver) ObserveQuery(ctx context.Context, q ObservedQuery) {
 	}
 }
 
-func (o *testQueryObserver) GetMetrics(host *HostInfo) *queryMetrics {
+func (o *testQueryObserver) GetMetrics(host *HostInfo) *hostMetrics {
 	return o.metrics[host.ConnectAddress().String()]
 }
 
@@ -377,6 +379,12 @@ func TestQueryRetry(t *testing.T) {
 }
 
 func TestQueryMultinodeWithMetrics(t *testing.T) {
+	log := &testLogger{}
+	Logger = log
+	defer func() {
+		Logger = &defaultLogger{}
+		os.Stdout.WriteString(log.String())
+	}()
 
 	// Build a 3 node cluster to test host metric mapping
 	var nodes []*TestServer
@@ -401,7 +409,7 @@ func TestQueryMultinodeWithMetrics(t *testing.T) {
 
 	// 1 retry per host
 	rt := &SimpleRetryPolicy{NumRetries: 3}
-	observer := &testQueryObserver{metrics: make(map[string]*queryMetrics), verbose: false}
+	observer := &testQueryObserver{metrics: make(map[string]*hostMetrics), verbose: false}
 	qry := db.Query("kill").RetryPolicy(rt).Observer(observer)
 	if err := qry.Exec(); err == nil {
 		t.Fatalf("expected error")
@@ -409,10 +417,11 @@ func TestQueryMultinodeWithMetrics(t *testing.T) {
 
 	for i, ip := range addresses {
 		host := &HostInfo{connectAddress: net.ParseIP(ip)}
+		queryMetric := qry.getHostMetrics(host)
 		observedMetrics := observer.GetMetrics(host)
 
 		requests := int(atomic.LoadInt64(&nodes[i].nKillReq))
-		hostAttempts := qry.metrics[ip].Attempts
+		hostAttempts := queryMetric.Attempts
 		if requests != hostAttempts {
 			t.Fatalf("expected requests %v to match query attempts %v", requests, hostAttempts)
 		}
@@ -421,7 +430,7 @@ func TestQueryMultinodeWithMetrics(t *testing.T) {
 			t.Fatalf("expected observed attempts %v to match query attempts %v on host %v", observedMetrics.Attempts, hostAttempts, ip)
 		}
 
-		hostLatency := qry.metrics[ip].TotalLatency
+		hostLatency := queryMetric.TotalLatency
 		observedLatency := observedMetrics.TotalLatency
 		if hostLatency != observedLatency {
 			t.Fatalf("expected observed latency %v to match query latency %v on host %v", observedLatency, hostLatency, ip)
@@ -435,6 +444,79 @@ func TestQueryMultinodeWithMetrics(t *testing.T) {
 
 }
 
+type testRetryPolicy struct {
+	NumRetries int
+}
+
+func (t *testRetryPolicy) Attempt(qry RetryableQuery) bool {
+	return qry.Attempts() <= t.NumRetries
+}
+func (t *testRetryPolicy) GetRetryType(err error) RetryType {
+	return Retry
+}
+
+func TestSpeculativeExecution(t *testing.T) {
+	log := &testLogger{}
+	Logger = log
+	defer func() {
+		Logger = &defaultLogger{}
+		os.Stdout.WriteString(log.String())
+	}()
+
+	// Build a 3 node cluster
+	var nodes []*TestServer
+	var addresses = []string{
+		"127.0.0.1",
+		"127.0.0.2",
+		"127.0.0.3",
+	}
+	// Can do with 1 context for all servers
+	ctx := context.Background()
+	for _, ip := range addresses {
+		srv := NewTestServerWithAddress(ip+":0", t, defaultProto, ctx)
+		defer srv.Stop()
+		nodes = append(nodes, srv)
+	}
+
+	db, err := newTestSession(defaultProto, nodes[0].Address, nodes[1].Address, nodes[2].Address)
+	if err != nil {
+		t.Fatalf("NewCluster: %v", err)
+	}
+	defer db.Close()
+
+	// Create a test retry policy, 6 retries will cover 2 executions
+	rt := &testRetryPolicy{NumRetries: 8}
+	// test Speculative policy with 1 additional execution
+	sp := &SimpleSpeculativeExecution{NumAttempts: 1, TimeoutDelay: 200 * time.Millisecond}
+
+	// Build the query
+	qry := db.Query("speculative").RetryPolicy(rt).SetSpeculativeExecutionPolicy(sp).Idempotent(true)
+
+	// Execute the query and close, check that it doesn't error out
+	if err := qry.Exec(); err != nil {
+		t.Errorf("The query failed with '%v'!\n", err)
+	}
+	requests1 := atomic.LoadInt64(&nodes[0].nKillReq)
+	requests2 := atomic.LoadInt64(&nodes[1].nKillReq)
+	requests3 := atomic.LoadInt64(&nodes[2].nKillReq)
+
+	// Spec Attempts == 1, so expecting to see only 1 regular + 1 speculative = 2 nodes attempted
+	if requests1 != 0 && requests2 != 0 && requests3 != 0 {
+		t.Error("error: all 3 nodes were attempted, should have been only 2")
+	}
+
+	// Only the 4th request will generate results, so
+	if requests1 != 4 && requests2 != 4 && requests3 != 4 {
+		t.Error("error: none of 3 nodes was attempted 4 times!")
+	}
+
+	// "speculative" query will succeed on one arbitrary node after 4 attempts, so
+	// expecting to see 4 (on successful node) + not more than 2 (as cancelled on another node) == 6
+	if requests1+requests2+requests3 > 6 {
+		t.Errorf("error: expected to see 6 attempts, got %v\n", requests1+requests2+requests3)
+	}
+}
+
 func TestStreams_Protocol1(t *testing.T) {
 	srv := NewTestServer(t, protoVersion1, context.Background())
 	defer srv.Stop()
@@ -1107,6 +1189,19 @@ func (srv *TestServer) process(f *framer) {
 				}
 			}()
 			return
+		case "speculative":
+			atomic.AddInt64(&srv.nKillReq, 1)
+			if atomic.LoadInt64(&srv.nKillReq) > 3 {
+				f.writeHeader(0, opResult, head.stream)
+				f.writeInt(resultKindVoid)
+				f.writeString("speculative query success on the node " + srv.Address)
+			} else {
+				f.writeHeader(0, opError, head.stream)
+				f.writeInt(0x1001)
+				f.writeString("speculative error")
+				rand.Seed(time.Now().UnixNano())
+				<-time.After(time.Millisecond * 120)
+			}
 		default:
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)

+ 1 - 2
control.go

@@ -453,8 +453,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
 			Logger.Printf("control: error executing %q: %v\n", statement, iter.err)
 		}
 
-		metric := q.getHostMetrics(c.getConn().host)
-		metric.Attempts++
+		q.AddAttempts(1, c.getConn().host)
 		if iter.err == nil || !c.retry.Attempt(q) {
 			break
 		}

+ 25 - 0
policies.go

@@ -5,6 +5,8 @@
 package gocql
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"math"
 	"math/rand"
@@ -130,6 +132,7 @@ type RetryableQuery interface {
 	Attempts() int
 	SetConsistency(c Consistency)
 	GetConsistency() Consistency
+	GetContext() context.Context
 }
 
 type RetryType uint16
@@ -141,6 +144,10 @@ const (
 	Rethrow       RetryType = 0x03 // raise error and stop retrying
 )
 
+// ErrUnknownRetryType is returned if the retry policy returns a retry type
+// unknown to the query executor.
+var ErrUnknownRetryType = errors.New("unknown retry type returned by retry policy")
+
 // RetryPolicy interface is used by gocql to determine if a query can be attempted
 // again after a retryable error has been received. The interface allows gocql
 // users to implement their own logic to determine if a query can be attempted
@@ -852,3 +859,21 @@ func (e *ExponentialReconnectionPolicy) GetInterval(currentRetry int) time.Durat
 func (e *ExponentialReconnectionPolicy) GetMaxRetries() int {
 	return e.MaxRetries
 }
+
+type SpeculativeExecutionPolicy interface {
+	Attempts() int
+	Delay() time.Duration
+}
+
+type NonSpeculativeExecution struct{}
+
+func (sp NonSpeculativeExecution) Attempts() int        { return 0 } // No additional attempts
+func (sp NonSpeculativeExecution) Delay() time.Duration { return 1 } // The delay. Must be positive to be used in a ticker.
+
+type SimpleSpeculativeExecution struct {
+	NumAttempts  int
+	TimeoutDelay time.Duration
+}
+
+func (sp *SimpleSpeculativeExecution) Attempts() int        { return sp.NumAttempts }
+func (sp *SimpleSpeculativeExecution) Delay() time.Duration { return sp.TimeoutDelay }

+ 4 - 4
policies_test.go

@@ -263,9 +263,9 @@ func TestSimpleRetryPolicy(t *testing.T) {
 		{5, false},
 	}
 
-	q.metrics = make(map[string]*queryMetrics)
+	q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
 	for _, c := range cases {
-		q.metrics["127.0.0.1"] = &queryMetrics{Attempts: c.attempts}
+		q.metrics.m["127.0.0.1"] = &hostMetrics{Attempts: c.attempts}
 		if c.allow && !rt.Attempt(q) {
 			t.Fatalf("should allow retry after %d attempts", c.attempts)
 		}
@@ -348,9 +348,9 @@ func TestDowngradingConsistencyRetryPolicy(t *testing.T) {
 		{16, false, reu1, Retry},
 	}
 
-	q.metrics = make(map[string]*queryMetrics)
+	q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
 	for _, c := range cases {
-		q.metrics["127.0.0.1"] = &queryMetrics{Attempts: c.attempts}
+		q.metrics.m["127.0.0.1"] = &hostMetrics{Attempts: c.attempts}
 		if c.retryType != rt.GetRetryType(c.err) {
 			t.Fatalf("retry type should be %v", c.retryType)
 		}

+ 110 - 45
query_executor.go

@@ -1,6 +1,7 @@
 package gocql
 
 import (
+	"sync"
 	"time"
 )
 
@@ -8,9 +9,11 @@ type ExecutableQuery interface {
 	execute(conn *Conn) *Iter
 	attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo)
 	retryPolicy() RetryPolicy
+	speculativeExecutionPolicy() SpeculativeExecutionPolicy
 	GetRoutingKey() ([]byte, error)
 	Keyspace() string
 	Cancel()
+	IsIdempotent() bool
 	RetryableQuery
 }
 
@@ -19,6 +22,11 @@ type queryExecutor struct {
 	policy HostSelectionPolicy
 }
 
+type queryResponse struct {
+	iter *Iter
+	err  error
+}
+
 func (q *queryExecutor) attemptQuery(qry ExecutableQuery, conn *Conn) *Iter {
 	start := time.Now()
 	iter := qry.execute(conn)
@@ -30,12 +38,74 @@ func (q *queryExecutor) attemptQuery(qry ExecutableQuery, conn *Conn) *Iter {
 }
 
 func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
-	rt := qry.retryPolicy()
+
+	// If the query is not marked as idempotent, force the policy
+	// to NonSpeculativeExecution
+	sp := qry.speculativeExecutionPolicy()
+	if !qry.IsIdempotent() {
+		sp = NonSpeculativeExecution{}
+	}
+
+	results := make(chan queryResponse, 1)
+	stop := make(chan struct{})
+	defer close(stop)
+	var specWG sync.WaitGroup
+
+	// Launch the main execution
+	specWG.Add(1)
+	go q.run(qry, &specWG, results, stop)
+
+	// The speculative executions are launched _in addition_ to the main
+	// execution, on a timer. So Speculation{2} would make 3 executions running
+	// in total.
+	go func() {
+		// Handle the closing of the resources. We do it here because it's
+		// right after we finish launching executions. Otherwise clearing the
+		// wait group is complicated.
+		defer func() {
+			specWG.Wait()
+			close(results)
+		}()
+
+		// setup a ticker
+		ticker := time.NewTicker(sp.Delay())
+		defer ticker.Stop()
+
+		for i := 0; i < sp.Attempts(); i++ {
+			select {
+			case <-ticker.C:
+				// Launch the additional execution
+				specWG.Add(1)
+				go q.run(qry, &specWG, results, stop)
+			case <-qry.GetContext().Done():
+				// not starting additional executions
+				return
+			case <-stop:
+				// not starting additional executions
+				return
+			}
+		}
+	}()
+
+	res := <-results
+	if res.iter == nil && res.err == nil {
+		// if we're here, the results channel was closed, so no more hosts
+		return nil, ErrNoConnections
+	}
+	return res.iter, res.err
+}
+
+func (q *queryExecutor) run(qry ExecutableQuery, specWG *sync.WaitGroup, results chan queryResponse, stop chan struct{}) {
+	// Handle the wait group
+	defer specWG.Done()
+
 	hostIter := q.policy.Pick(qry)
+	selectedHost := hostIter()
+	rt := qry.retryPolicy()
 
 	var iter *Iter
-	for hostResponse := hostIter(); hostResponse != nil; hostResponse = hostIter() {
-		host := hostResponse.Info()
+	for selectedHost != nil {
+		host := selectedHost.Info()
 		if host == nil || !host.IsUp() {
 			continue
 		}
@@ -50,51 +120,46 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 			continue
 		}
 
-		iter = q.attemptQuery(qry, conn)
-		// Update host
-		hostResponse.Mark(iter.err)
-
-		if rt == nil {
-			iter.host = host
-			break
-		}
-
-		switch rt.GetRetryType(iter.err) {
-		case Retry:
-			for rt.Attempt(qry) {
-				iter = q.attemptQuery(qry, conn)
-				hostResponse.Mark(iter.err)
-				if iter.err == nil {
-					iter.host = host
-					return iter, nil
-				}
-				if rt.GetRetryType(iter.err) != Retry {
-					break
-				}
-			}
-		case Rethrow:
-			return nil, iter.err
-		case Ignore:
-			return iter, nil
-		case RetryNextHost:
+		select {
+		case <-stop:
+			// stop this execution and return
+			return
 		default:
-		}
-
-		// Exit for loop if the query was successful
-		if iter.err == nil {
-			iter.host = host
-			return iter, nil
-		}
+			// Run the query
+			iter = q.attemptQuery(qry, conn)
+			iter.host = selectedHost.Info()
+			// Update host
+			selectedHost.Mark(iter.err)
+
+			// Exit if the query was successful
+			// or no retry policy defined or retry attempts were reached
+			if iter.err == nil || rt == nil || !rt.Attempt(qry) {
+				results <- queryResponse{iter: iter}
+				return
+			}
 
-		if !rt.Attempt(qry) {
-			// What do here? Should we just return an error here?
-			break
+			// If query is unsuccessful, check the error with RetryPolicy to retry
+			switch rt.GetRetryType(iter.err) {
+			case Retry:
+				// retry on the same host
+				continue
+			case Rethrow:
+				results <- queryResponse{err: iter.err}
+				return
+			case Ignore:
+				results <- queryResponse{iter: iter}
+				return
+			case RetryNextHost:
+				// retry on the next host
+				selectedHost = hostIter()
+				continue
+			default:
+				// Undefined? Return nil and error, this will panic in the requester
+				results <- queryResponse{iter: nil, err: ErrUnknownRetryType}
+				return
+			}
 		}
-	}
 
-	if iter == nil {
-		return nil, ErrNoConnections
 	}
-
-	return iter, nil
+	// All hosts are exhausted, return nothing
 }

+ 126 - 37
session.go

@@ -658,11 +658,16 @@ func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn,
 	return s.dial(host, s.connCfg, errorHandler)
 }
 
-type queryMetrics struct {
+type hostMetrics struct {
 	Attempts     int
 	TotalLatency int64
 }
 
+type queryMetrics struct {
+	l sync.RWMutex
+	m map[string]*hostMetrics
+}
+
 // Query represents a CQL statement that can be executed.
 type Query struct {
 	stmt                  string
@@ -677,6 +682,7 @@ type Query struct {
 	observer              QueryObserver
 	session               *Session
 	rt                    RetryPolicy
+	spec                  SpeculativeExecutionPolicy
 	binding               func(q *QueryInfo) ([]interface{}, error)
 	serialCons            SerialConsistency
 	defaultTimestamp      bool
@@ -685,8 +691,8 @@ type Query struct {
 	context               context.Context
 	cancelQuery           func()
 	idempotent            bool
-	metrics               map[string]*queryMetrics
 	customPayload         map[string][]byte
+	metrics               *queryMetrics
 
 	disableAutoPage bool
 }
@@ -704,23 +710,26 @@ func (q *Query) defaultsFromSession() {
 	q.serialCons = s.cfg.SerialConsistency
 	q.defaultTimestamp = s.cfg.DefaultTimestamp
 	q.idempotent = s.cfg.DefaultIdempotence
-	q.metrics = make(map[string]*queryMetrics)
+	q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
 
 	// Initiate an empty context with a cancel call
 	q.WithContext(context.Background())
 
+	q.spec = &NonSpeculativeExecution{}
 	s.mu.RUnlock()
 }
 
-func (q *Query) getHostMetrics(host *HostInfo) *queryMetrics {
-	hostMetrics, exists := q.metrics[host.ConnectAddress().String()]
+func (q *Query) getHostMetrics(host *HostInfo) *hostMetrics {
+	q.metrics.l.Lock()
+	metrics, exists := q.metrics.m[host.ConnectAddress().String()]
 	if !exists {
 		// if the host is not in the map, it means it's been accessed for the first time
-		hostMetrics = &queryMetrics{Attempts: 0, TotalLatency: 0}
-		q.metrics[host.ConnectAddress().String()] = hostMetrics
+		metrics = &hostMetrics{}
+		q.metrics.m[host.ConnectAddress().String()] = metrics
 	}
+	q.metrics.l.Unlock()
 
-	return hostMetrics
+	return metrics
 }
 
 // Statement returns the statement that was used to generate this query.
@@ -735,27 +744,45 @@ func (q Query) String() string {
 
 //Attempts returns the number of times the query was executed.
 func (q *Query) Attempts() int {
-	attempts := 0
-	for _, metric := range q.metrics {
+	q.metrics.l.Lock()
+	var attempts int
+	for _, metric := range q.metrics.m {
 		attempts += metric.Attempts
 	}
+	q.metrics.l.Unlock()
 	return attempts
 }
 
+func (q *Query) AddAttempts(i int, host *HostInfo) {
+	hostMetric := q.getHostMetrics(host)
+	q.metrics.l.Lock()
+	hostMetric.Attempts += i
+	q.metrics.l.Unlock()
+}
+
 //Latency returns the average amount of nanoseconds per attempt of the query.
 func (q *Query) Latency() int64 {
+	q.metrics.l.Lock()
 	var attempts int
 	var latency int64
-	for _, metric := range q.metrics {
+	for _, metric := range q.metrics.m {
 		attempts += metric.Attempts
 		latency += metric.TotalLatency
 	}
+	q.metrics.l.Unlock()
 	if attempts > 0 {
 		return latency / int64(attempts)
 	}
 	return 0
 }
 
+func (q *Query) AddLatency(l int64, host *HostInfo) {
+	hostMetric := q.getHostMetrics(host)
+	q.metrics.l.Lock()
+	hostMetric.TotalLatency += l
+	q.metrics.l.Unlock()
+}
+
 // Consistency sets the consistency level for this query. If no consistency
 // level have been set, the default consistency level of the cluster
 // is used.
@@ -781,6 +808,10 @@ func (q *Query) CustomPayload(customPayload map[string][]byte) *Query {
 	return q
 }
 
+func (q *Query) GetContext() context.Context {
+	return q.context
+}
+
 // Trace enables tracing of this query. Look at the documentation of the
 // Tracer interface to learn more about tracing.
 func (q *Query) Trace(trace Tracer) *Query {
@@ -851,9 +882,8 @@ func (q *Query) execute(conn *Conn) *Iter {
 }
 
 func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) {
-	hostMetrics := q.getHostMetrics(host)
-	hostMetrics.Attempts++
-	hostMetrics.TotalLatency += end.Sub(start).Nanoseconds()
+	q.AddAttempts(1, host)
+	q.AddLatency(end.Sub(start).Nanoseconds(), host)
 
 	if q.observer != nil {
 		q.observer.ObserveQuery(q.context, ObservedQuery{
@@ -863,7 +893,7 @@ func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host
 			End:       end,
 			Rows:      iter.numRows,
 			Host:      host,
-			Metrics:   hostMetrics,
+			Metrics:   q.getHostMetrics(host),
 			Err:       iter.err,
 		})
 	}
@@ -983,6 +1013,17 @@ func (q *Query) RetryPolicy(r RetryPolicy) *Query {
 	return q
 }
 
+// SetSpeculativeExecutionPolicy sets the execution policy
+func (q *Query) SetSpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Query {
+	q.spec = sp
+	return q
+}
+
+// speculativeExecutionPolicy fetches the policy
+func (q *Query) speculativeExecutionPolicy() SpeculativeExecutionPolicy {
+	return q.spec
+}
+
 func (q *Query) IsIdempotent() bool {
 	return q.idempotent
 }
@@ -1431,6 +1472,7 @@ type Batch struct {
 	Cons                  Consistency
 	CustomPayload         map[string][]byte
 	rt                    RetryPolicy
+	spec                  SpeculativeExecutionPolicy
 	observer              BatchObserver
 	serialCons            SerialConsistency
 	defaultTimestamp      bool
@@ -1438,14 +1480,14 @@ type Batch struct {
 	context               context.Context
 	cancelBatch           func()
 	keyspace              string
-	metrics               map[string]*queryMetrics
+	metrics               *queryMetrics
 }
 
 // NewBatch creates a new batch operation without defaults from the cluster
 //
 // Deprecated: use session.NewBatch instead
 func NewBatch(typ BatchType) *Batch {
-	return &Batch{Type: typ, metrics: make(map[string]*queryMetrics)}
+	return &Batch{Type: typ, metrics: &queryMetrics{m: make(map[string]*hostMetrics)}}
 }
 
 // NewBatch creates a new batch operation using defaults defined in the cluster
@@ -1459,7 +1501,8 @@ func (s *Session) NewBatch(typ BatchType) *Batch {
 		Cons:             s.cons,
 		defaultTimestamp: s.cfg.DefaultTimestamp,
 		keyspace:         s.cfg.Keyspace,
-		metrics:          make(map[string]*queryMetrics),
+		metrics:          &queryMetrics{m: make(map[string]*hostMetrics)},
+		spec:             &NonSpeculativeExecution{},
 	}
 
 	// Initiate an empty context with a cancel call
@@ -1469,15 +1512,17 @@ func (s *Session) NewBatch(typ BatchType) *Batch {
 	return batch
 }
 
-func (b *Batch) getHostMetrics(host *HostInfo) *queryMetrics {
-	hostMetrics, exists := b.metrics[host.ConnectAddress().String()]
+func (b *Batch) getHostMetrics(host *HostInfo) *hostMetrics {
+	b.metrics.l.Lock()
+	metrics, exists := b.metrics.m[host.ConnectAddress().String()]
 	if !exists {
 		// if the host is not in the map, it means it's been accessed for the first time
-		hostMetrics = &queryMetrics{Attempts: 0, TotalLatency: 0}
-		b.metrics[host.ConnectAddress().String()] = hostMetrics
+		metrics = &hostMetrics{}
+		b.metrics.m[host.ConnectAddress().String()] = metrics
 	}
+	b.metrics.l.Unlock()
 
-	return hostMetrics
+	return metrics
 }
 
 // Observer enables batch-level observer on this batch.
@@ -1493,18 +1538,33 @@ func (b *Batch) Keyspace() string {
 
 // Attempts returns the number of attempts made to execute the batch.
 func (b *Batch) Attempts() int {
-	attempts := 0
-	for _, metric := range b.metrics {
+	b.metrics.l.Lock()
+	defer b.metrics.l.Unlock()
+
+	var attempts int
+	for _, metric := range b.metrics.m {
 		attempts += metric.Attempts
 	}
 	return attempts
 }
 
+func (b *Batch) AddAttempts(i int, host *HostInfo) {
+	hostMetric := b.getHostMetrics(host)
+	b.metrics.l.Lock()
+	hostMetric.Attempts += i
+	b.metrics.l.Unlock()
+}
+
 //Latency returns the average number of nanoseconds to execute a single attempt of the batch.
 func (b *Batch) Latency() int64 {
-	attempts := 0
-	var latency int64 = 0
-	for _, metric := range b.metrics {
+	b.metrics.l.Lock()
+	defer b.metrics.l.Unlock()
+
+	var (
+		attempts int
+		latency  int64
+	)
+	for _, metric := range b.metrics.m {
 		attempts += metric.Attempts
 		latency += metric.TotalLatency
 	}
@@ -1514,6 +1574,13 @@ func (b *Batch) Latency() int64 {
 	return 0
 }
 
+func (b *Batch) AddLatency(l int64, host *HostInfo) {
+	hostMetric := b.getHostMetrics(host)
+	b.metrics.l.Lock()
+	hostMetric.TotalLatency += l
+	b.metrics.l.Unlock()
+}
+
 // GetConsistency returns the currently configured consistency level for the batch
 // operation.
 func (b *Batch) GetConsistency() Consistency {
@@ -1526,6 +1593,28 @@ func (b *Batch) SetConsistency(c Consistency) {
 	b.Cons = c
 }
 
+func (b *Batch) GetContext() context.Context {
+	return b.context
+}
+
+func (b *Batch) IsIdempotent() bool {
+	for _, entry := range b.Entries {
+		if !entry.Idempotent {
+			return false
+		}
+	}
+	return true
+}
+
+func (b *Batch) speculativeExecutionPolicy() SpeculativeExecutionPolicy {
+	return b.spec
+}
+
+func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch {
+	b.spec = sp
+	return b
+}
+
 // Query adds the query to the batch operation
 func (b *Batch) Query(stmt string, args ...interface{}) {
 	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
@@ -1601,9 +1690,8 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch {
 }
 
 func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) {
-	hostMetrics := b.getHostMetrics(host)
-	hostMetrics.Attempts++
-	hostMetrics.TotalLatency += end.Sub(start).Nanoseconds()
+	b.AddAttempts(1, host)
+	b.AddLatency(end.Sub(start).Nanoseconds(), host)
 
 	if b.observer == nil {
 		return
@@ -1621,7 +1709,7 @@ func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host
 		End:        end,
 		// Rows not used in batch observations // TODO - might be able to support it when using BatchCAS
 		Host:    host,
-		Metrics: hostMetrics,
+		Metrics: b.getHostMetrics(host),
 		Err:     iter.err,
 	})
 }
@@ -1640,9 +1728,10 @@ const (
 )
 
 type BatchEntry struct {
-	Stmt    string
-	Args    []interface{}
-	binding func(q *QueryInfo) ([]interface{}, error)
+	Stmt       string
+	Args       []interface{}
+	Idempotent bool
+	binding    func(q *QueryInfo) ([]interface{}, error)
 }
 
 type ColumnInfo struct {
@@ -1775,7 +1864,7 @@ type ObservedQuery struct {
 	Host *HostInfo
 
 	// The metrics per this host
-	Metrics *queryMetrics
+	Metrics *hostMetrics
 
 	// Err is the error in the query.
 	// It only tracks network errors or errors of bad cassandra syntax, in particular selects with no match return nil error
@@ -1807,7 +1896,7 @@ type ObservedBatch struct {
 	Err error
 
 	// The metrics per this host
-	Metrics *queryMetrics
+	Metrics *hostMetrics
 }
 
 // BatchObserver is the interface implemented by batch observers / stat collectors.
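
Batches opt into speculative execution only when every entry is idempotent. A small usage sketch of the new BatchEntry.Idempotent field and the batch-level policy setter (keyspace, table and the connected session are assumptions for illustration):

package example

import (
	"time"

	"github.com/gocql/gocql"
)

func runIdempotentBatch(session *gocql.Session) error {
	b := session.NewBatch(gocql.UnloggedBatch)
	b.Entries = append(b.Entries,
		gocql.BatchEntry{Stmt: `UPDATE settings SET val = ? WHERE key = ?`, Args: []interface{}{"on", "feature_x"}, Idempotent: true},
		gocql.BatchEntry{Stmt: `UPDATE settings SET val = ? WHERE key = ?`, Args: []interface{}{"off", "feature_y"}, Idempotent: true},
	)

	// IsIdempotent() is true only because every entry above is marked
	// Idempotent, so the executor may apply the speculative policy.
	b.SpeculativeExecutionPolicy(&gocql.SimpleSpeculativeExecution{
		NumAttempts:  1,
		TimeoutDelay: 100 * time.Millisecond,
	})

	return session.ExecuteBatch(b)
}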

+ 6 - 5
session_test.go

@@ -100,17 +100,17 @@ func TestQueryBasicAPI(t *testing.T) {
 	qry := &Query{}
 
 	// Initialise metrics map
-	qry.metrics = make(map[string]*queryMetrics)
+	qry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
 
 	// Initiate host
 	ip := "127.0.0.1"
 
-	qry.metrics[ip] = &queryMetrics{Attempts: 0, TotalLatency: 0}
+	qry.metrics.m[ip] = &hostMetrics{Attempts: 0, TotalLatency: 0}
 	if qry.Latency() != 0 {
 		t.Fatalf("expected Query.Latency() to return 0, got %v", qry.Latency())
 	}
 
-	qry.metrics[ip] = &queryMetrics{Attempts: 2, TotalLatency: 4}
+	qry.metrics.m[ip] = &hostMetrics{Attempts: 2, TotalLatency: 4}
 	if qry.Attempts() != 2 {
 		t.Fatalf("expected Query.Attempts() to return 2, got %v", qry.Attempts())
 	}
@@ -202,10 +202,11 @@ func TestBatchBasicAPI(t *testing.T) {
 		t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type)
 	}
 
+	b.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
 	ip := "127.0.0.1"
 
 	// Test attempts
-	b.metrics[ip] = &queryMetrics{Attempts: 1}
+	b.metrics.m[ip] = &hostMetrics{Attempts: 1}
 	if b.Attempts() != 1 {
 		t.Fatalf("expected batch.Attempts() to return %v, got %v", 1, b.Attempts())
 	}
@@ -215,7 +216,7 @@ func TestBatchBasicAPI(t *testing.T) {
 		t.Fatalf("expected batch.Latency() to be 0, got %v", b.Latency())
 	}
 
-	b.metrics[ip] = &queryMetrics{Attempts: 1, TotalLatency: 4}
+	b.metrics.m[ip] = &hostMetrics{Attempts: 1, TotalLatency: 4}
 	if b.Latency() != 4 {
 		t.Fatalf("expected batch.Latency() to return %v, got %v", 4, b.Latency())
 	}