Browse Source

Merge pull request #7079 from heyitsanthony/stm-prefetch

STM: prefetch and more
Anthony Romano 9 years ago
parent
commit
e0f4dd4cca
3 changed files with 197 additions and 78 deletions
  1. 173 66
      clientv3/concurrency/stm.go
  2. 14 5
      integration/v3_stm_test.go
  3. 10 7
      tools/benchmark/cmd/stm.go

+ 173 - 66
clientv3/concurrency/stm.go

@@ -15,6 +15,8 @@
 package concurrency
 
 import (
+	"math"
+
 	v3 "github.com/coreos/etcd/clientv3"
 	"golang.org/x/net/context"
 )
@@ -23,7 +25,7 @@ import (
 type STM interface {
 	// Get returns the value for a key and inserts the key in the txn's read set.
 	// If Get fails, it aborts the transaction with an error, never returning.
-	Get(key string) string
+	Get(key ...string) string
 	// Put adds a value for a key to the write set.
 	Put(key, val string, opts ...v3.OpOption)
 	// Rev returns the revision of a key in the read set.
@@ -36,30 +38,97 @@ type STM interface {
 	reset()
 }
 
+// Isolation is an enumeration of transactional isolation levels which
+// describes how transactions should interfere and conflict.
+type Isolation int
+
+const (
+	// Snapshot is serializable but also checks writes for conflicts.
+	Snapshot Isolation = iota
+	// Serializable reads within the same transactiona attempt return data
+	// from the at the revision of the first read.
+	Serializable
+	// RepeatableReads reads within the same transaction attempt always
+	// return the same data.
+	RepeatableReads
+	// ReadCommitted reads keys from any committed revision.
+	ReadCommitted
+)
+
 // stmError safely passes STM errors through panic to the STM error channel.
 type stmError struct{ err error }
 
-// NewSTMRepeatable initiates new repeatable read transaction; reads within
-// the same transaction attempt always return the same data.
-func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
-	s := &stm{client: c, ctx: ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
-	return runSTM(s, apply)
+type stmOptions struct {
+	iso      Isolation
+	ctx      context.Context
+	prefetch []string
+}
+
+type stmOption func(*stmOptions)
+
+// WithIsolation specifies the transaction isolation level.
+func WithIsolation(lvl Isolation) stmOption {
+	return func(so *stmOptions) { so.iso = lvl }
+}
+
+// WithAbortContext specifies the context for permanently aborting the transaction.
+func WithAbortContext(ctx context.Context) stmOption {
+	return func(so *stmOptions) { so.ctx = ctx }
+}
+
+// WithPrefetch is a hint to prefetch a list of keys before trying to apply.
+// If an STM transaction will unconditionally fetch a set of keys, prefetching
+// those keys will save the round-trip cost from requesting each key one by one
+// with Get().
+func WithPrefetch(keys ...string) stmOption {
+	return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
 }
 
-// NewSTMSerializable initiates a new serialized transaction; reads within the
-// same transactiona attempt return data from the revision of the first read.
-func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
-	s := &stmSerializable{
-		stm:      stm{client: c, ctx: ctx},
-		prefetch: make(map[string]*v3.GetResponse),
+// NewSTM initiates a new STM instance, using snapshot isolation by default.
+func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
+	opts := &stmOptions{ctx: c.Ctx()}
+	for _, f := range so {
+		f(opts)
+	}
+	if len(opts.prefetch) != 0 {
+		f := apply
+		apply = func(s STM) error {
+			s.Get(opts.prefetch...)
+			return f(s)
+		}
 	}
-	return runSTM(s, apply)
+	return runSTM(mkSTM(c, opts), apply)
 }
 
-// NewSTMReadCommitted initiates a new read committed transaction.
-func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
-	s := &stmReadCommitted{stm{client: c, ctx: ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}}
-	return runSTM(s, apply)
+func mkSTM(c *v3.Client, opts *stmOptions) STM {
+	switch opts.iso {
+	case Snapshot:
+		s := &stmSerializable{
+			stm:      stm{client: c, ctx: opts.ctx},
+			prefetch: make(map[string]*v3.GetResponse),
+		}
+		s.conflicts = func() []v3.Cmp {
+			return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
+		}
+		return s
+	case Serializable:
+		s := &stmSerializable{
+			stm:      stm{client: c, ctx: opts.ctx},
+			prefetch: make(map[string]*v3.GetResponse),
+		}
+		s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
+		return s
+	case RepeatableReads:
+		s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
+		s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
+		return s
+	case ReadCommitted:
+		s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
+		s.conflicts = func() []v3.Cmp { return nil }
+		return s
+	default:
+		panic("unsupported stm")
+	}
 }
 
 type stmResponse struct {
@@ -101,11 +170,13 @@ type stm struct {
 	client *v3.Client
 	ctx    context.Context
 	// rset holds read key values and revisions
-	rset map[string]*v3.GetResponse
+	rset readSet
 	// wset holds overwritten keys and their values
-	wset map[string]stmPut
+	wset writeSet
 	// getOpts are the opts used for gets
 	getOpts []v3.OpOption
+	// conflicts computes the current conflicts on the txn
+	conflicts func() []v3.Cmp
 }
 
 type stmPut struct {
@@ -113,11 +184,67 @@ type stmPut struct {
 	op  v3.Op
 }
 
-func (s *stm) Get(key string) string {
-	if wv, ok := s.wset[key]; ok {
+type readSet map[string]*v3.GetResponse
+
+func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
+	for i, resp := range txnresp.Responses {
+		rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
+	}
+}
+
+func (rs readSet) first() int64 {
+	ret := int64(math.MaxInt64 - 1)
+	for _, resp := range rs {
+		if len(resp.Kvs) > 0 && resp.Kvs[0].ModRevision < ret {
+			ret = resp.Kvs[0].ModRevision
+		}
+	}
+	return ret
+}
+
+// cmps guards the txn from updates to read set
+func (rs readSet) cmps() []v3.Cmp {
+	cmps := make([]v3.Cmp, 0, len(rs))
+	for k, rk := range rs {
+		cmps = append(cmps, isKeyCurrent(k, rk))
+	}
+	return cmps
+}
+
+type writeSet map[string]stmPut
+
+func (ws writeSet) get(keys ...string) *stmPut {
+	for _, key := range keys {
+		if wv, ok := ws[key]; ok {
+			return &wv
+		}
+	}
+	return nil
+}
+
+// cmps returns a cmp list testing no writes have happened past rev
+func (ws writeSet) cmps(rev int64) []v3.Cmp {
+	cmps := make([]v3.Cmp, 0, len(ws))
+	for key := range ws {
+		cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
+	}
+	return cmps
+}
+
+// puts is the list of ops for all pending writes
+func (ws writeSet) puts() []v3.Op {
+	puts := make([]v3.Op, 0, len(ws))
+	for _, v := range ws {
+		puts = append(puts, v.op)
+	}
+	return puts
+}
+
+func (s *stm) Get(keys ...string) string {
+	if wv := s.wset.get(keys...); wv != nil {
 		return wv.val
 	}
-	return respToValue(s.fetch(key))
+	return respToValue(s.fetch(keys...))
 }
 
 func (s *stm) Put(key, val string, opts ...v3.OpOption) {
@@ -134,7 +261,7 @@ func (s *stm) Rev(key string) int64 {
 }
 
 func (s *stm) commit() *v3.TxnResponse {
-	txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...).Commit()
+	txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
 	if err != nil {
 		panic(stmError{err})
 	}
@@ -144,34 +271,23 @@ func (s *stm) commit() *v3.TxnResponse {
 	return nil
 }
 
-// cmps guards the txn from updates to read set
-func (s *stm) cmps() []v3.Cmp {
-	cmps := make([]v3.Cmp, 0, len(s.rset))
-	for k, rk := range s.rset {
-		cmps = append(cmps, isKeyCurrent(k, rk))
+func (s *stm) fetch(keys ...string) *v3.GetResponse {
+	if len(keys) == 0 {
+		return nil
 	}
-	return cmps
-}
-
-func (s *stm) fetch(key string) *v3.GetResponse {
-	if resp, ok := s.rset[key]; ok {
-		return resp
+	ops := make([]v3.Op, len(keys))
+	for i, key := range keys {
+		if resp, ok := s.rset[key]; ok {
+			return resp
+		}
+		ops[i] = v3.OpGet(key, s.getOpts...)
 	}
-	resp, err := s.client.Get(s.ctx, key, s.getOpts...)
+	txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit()
 	if err != nil {
 		panic(stmError{err})
 	}
-	s.rset[key] = resp
-	return resp
-}
-
-// puts is the list of ops for all pending writes
-func (s *stm) puts() []v3.Op {
-	puts := make([]v3.Op, 0, len(s.wset))
-	for _, v := range s.wset {
-		puts = append(puts, v.op)
-	}
-	return puts
+	s.rset.add(keys, txnresp)
+	return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
 }
 
 func (s *stm) reset() {
@@ -184,16 +300,18 @@ type stmSerializable struct {
 	prefetch map[string]*v3.GetResponse
 }
 
-func (s *stmSerializable) Get(key string) string {
-	if wv, ok := s.wset[key]; ok {
+func (s *stmSerializable) Get(keys ...string) string {
+	if wv := s.wset.get(keys...); wv != nil {
 		return wv.val
 	}
 	firstRead := len(s.rset) == 0
-	if resp, ok := s.prefetch[key]; ok {
-		delete(s.prefetch, key)
-		s.rset[key] = resp
+	for _, key := range keys {
+		if resp, ok := s.prefetch[key]; ok {
+			delete(s.prefetch, key)
+			s.rset[key] = resp
+		}
 	}
-	resp := s.stm.fetch(key)
+	resp := s.stm.fetch(keys...)
 	if firstRead {
 		// txn's base revision is defined by the first read
 		s.getOpts = []v3.OpOption{
@@ -221,7 +339,7 @@ func (s *stmSerializable) gets() ([]string, []v3.Op) {
 
 func (s *stmSerializable) commit() *v3.TxnResponse {
 	keys, getops := s.gets()
-	txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...)
+	txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
 	// use Else to prefetch keys in case of conflict to save a round trip
 	txnresp, err := txn.Else(getops...).Commit()
 	if err != nil {
@@ -231,23 +349,12 @@ func (s *stmSerializable) commit() *v3.TxnResponse {
 		return txnresp
 	}
 	// load prefetch with Else data
-	for i := range keys {
-		resp := txnresp.Responses[i].GetResponseRange()
-		s.rset[keys[i]] = (*v3.GetResponse)(resp)
-	}
+	s.rset.add(keys, txnresp)
 	s.prefetch = s.rset
 	s.getOpts = nil
 	return nil
 }
 
-type stmReadCommitted struct{ stm }
-
-// commit always goes through when read committed
-func (s *stmReadCommitted) commit() *v3.TxnResponse {
-	s.rset = nil
-	return s.stm.commit()
-}
-
 func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
 	if len(r.Kvs) != 0 {
 		return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
@@ -256,7 +363,7 @@ func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
 }
 
 func respToValue(resp *v3.GetResponse) string {
-	if len(resp.Kvs) == 0 {
+	if resp == nil || len(resp.Kvs) == 0 {
 		return ""
 	}
 	return string(resp.Kvs[0].Value)

+ 14 - 5
integration/v3_stm_test.go

@@ -63,7 +63,8 @@ func TestSTMConflict(t *testing.T) {
 			return nil
 		}
 		go func() {
-			_, err := concurrency.NewSTMRepeatable(context.TODO(), curEtcdc, applyf)
+			iso := concurrency.WithIsolation(concurrency.RepeatableReads)
+			_, err := concurrency.NewSTM(curEtcdc, applyf, iso)
 			errc <- err
 		}()
 	}
@@ -100,7 +101,9 @@ func TestSTMPutNewKey(t *testing.T) {
 		stm.Put("foo", "bar")
 		return nil
 	}
-	if _, err := concurrency.NewSTMRepeatable(context.TODO(), etcdc, applyf); err != nil {
+
+	iso := concurrency.WithIsolation(concurrency.RepeatableReads)
+	if _, err := concurrency.NewSTM(etcdc, applyf, iso); err != nil {
 		t.Fatalf("error on stm txn (%v)", err)
 	}
 
@@ -126,7 +129,10 @@ func TestSTMAbort(t *testing.T) {
 		stm.Put("foo", "bap")
 		return nil
 	}
-	if _, err := concurrency.NewSTMRepeatable(ctx, etcdc, applyf); err == nil {
+
+	iso := concurrency.WithIsolation(concurrency.RepeatableReads)
+	sctx := concurrency.WithAbortContext(ctx)
+	if _, err := concurrency.NewSTM(etcdc, applyf, iso, sctx); err == nil {
 		t.Fatalf("no error on stm txn")
 	}
 
@@ -186,7 +192,8 @@ func TestSTMSerialize(t *testing.T) {
 			return nil
 		}
 		go func() {
-			_, err := concurrency.NewSTMSerializable(context.TODO(), curEtcdc, applyf)
+			iso := concurrency.WithIsolation(concurrency.Serializable)
+			_, err := concurrency.NewSTM(curEtcdc, applyf, iso)
 			errc <- err
 		}()
 	}
@@ -229,7 +236,9 @@ func TestSTMApplyOnConcurrentDeletion(t *testing.T) {
 		stm.Put("foo2", "bar2")
 		return nil
 	}
-	if _, err := concurrency.NewSTMRepeatable(context.TODO(), etcdc, applyf); err != nil {
+
+	iso := concurrency.WithIsolation(concurrency.RepeatableReads)
+	if _, err := concurrency.NewSTM(etcdc, applyf, iso); err != nil {
 		t.Fatalf("error on stm txn (%v)", err)
 	}
 	if try != 2 {

+ 10 - 7
tools/benchmark/cmd/stm.go

@@ -41,20 +41,21 @@ var stmCmd = &cobra.Command{
 type stmApply func(v3sync.STM) error
 
 var (
-	stmIsolation    string
+	stmIsolation string
+	stmIso       v3sync.Isolation
+
 	stmTotal        int
 	stmKeysPerTxn   int
 	stmKeyCount     int
 	stmValSize      int
 	stmWritePercent int
 	stmMutex        bool
-	mkSTM           func(context.Context, *v3.Client, func(v3sync.STM) error) (*v3.TxnResponse, error)
 )
 
 func init() {
 	RootCmd.AddCommand(stmCmd)
 
-	stmCmd.Flags().StringVar(&stmIsolation, "isolation", "r", "Read Committed (c), Repeatable Reads (r), or Serializable (s)")
+	stmCmd.Flags().StringVar(&stmIsolation, "isolation", "r", "Read Committed (c), Repeatable Reads (r), Serializable (s), or Snapshot (ss)")
 	stmCmd.Flags().IntVar(&stmKeyCount, "keys", 1, "Total unique keys accessible by the benchmark")
 	stmCmd.Flags().IntVar(&stmTotal, "total", 10000, "Total number of completed STM transactions")
 	stmCmd.Flags().IntVar(&stmKeysPerTxn, "keys-per-txn", 1, "Number of keys to access per transaction")
@@ -81,11 +82,13 @@ func stmFunc(cmd *cobra.Command, args []string) {
 
 	switch stmIsolation {
 	case "c":
-		mkSTM = v3sync.NewSTMReadCommitted
+		stmIso = v3sync.ReadCommitted
 	case "r":
-		mkSTM = v3sync.NewSTMRepeatable
+		stmIso = v3sync.RepeatableReads
 	case "s":
-		mkSTM = v3sync.NewSTMSerializable
+		stmIso = v3sync.Serializable
+	case "ss":
+		stmIso = v3sync.Snapshot
 	default:
 		fmt.Fprintln(os.Stderr, cmd.Usage())
 		os.Exit(1)
@@ -155,7 +158,7 @@ func doSTM(client *v3.Client, requests <-chan stmApply, results chan<- report.Re
 		if m != nil {
 			m.Lock(context.TODO())
 		}
-		_, err := mkSTM(context.TODO(), client, applyf)
+		_, err := v3sync.NewSTM(client, applyf, v3sync.WithIsolation(stmIso))
 		if m != nil {
 			m.Unlock(context.TODO())
 		}