Browse Source

concurrency: STM snapshot isolation level

Anthony Romano 9 years ago
parent
commit
8695511153
1 changed files with 69 additions and 36 deletions
  1. 69 36
      clientv3/concurrency/stm.go

+ 69 - 36
clientv3/concurrency/stm.go

@@ -15,6 +15,8 @@
 package concurrency
 package concurrency
 
 
 import (
 import (
+	"math"
+
 	v3 "github.com/coreos/etcd/clientv3"
 	v3 "github.com/coreos/etcd/clientv3"
 	"golang.org/x/net/context"
 	"golang.org/x/net/context"
 )
 )
@@ -82,7 +84,7 @@ func WithPrefetch(keys ...string) stmOption {
 	return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
 	return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
 }
 }
 
 
-// NewSTM initiates a new STM instance.
+// NewSTM initiates a new STM instance, using snapshot isolation by default.
 func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
 func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
 	opts := &stmOptions{ctx: c.Ctx()}
 	opts := &stmOptions{ctx: c.Ctx()}
 	for _, f := range so {
 	for _, f := range so {
@@ -95,22 +97,38 @@ func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnRespon
 			return f(s)
 			return f(s)
 		}
 		}
 	}
 	}
-	var s STM
+	return runSTM(mkSTM(c, opts), apply)
+}
+
+func mkSTM(c *v3.Client, opts *stmOptions) STM {
 	switch opts.iso {
 	switch opts.iso {
+	case Snapshot:
+		s := &stmSerializable{
+			stm:      stm{client: c, ctx: opts.ctx},
+			prefetch: make(map[string]*v3.GetResponse),
+		}
+		s.conflicts = func() []v3.Cmp {
+			return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
+		}
+		return s
 	case Serializable:
 	case Serializable:
-		s = &stmSerializable{
+		s := &stmSerializable{
 			stm:      stm{client: c, ctx: opts.ctx},
 			stm:      stm{client: c, ctx: opts.ctx},
 			prefetch: make(map[string]*v3.GetResponse),
 			prefetch: make(map[string]*v3.GetResponse),
 		}
 		}
+		s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
+		return s
 	case RepeatableReads:
 	case RepeatableReads:
-		s = &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
+		s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
+		s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
+		return s
 	case ReadCommitted:
 	case ReadCommitted:
-		ss := stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
-		s = &stmReadCommitted{ss}
+		s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
+		s.conflicts = func() []v3.Cmp { return nil }
+		return s
 	default:
 	default:
-		panic("unsupported")
+		panic("unsupported stm")
 	}
 	}
-	return runSTM(s, apply)
 }
 }
 
 
 type stmResponse struct {
 type stmResponse struct {
@@ -152,11 +170,13 @@ type stm struct {
 	client *v3.Client
 	client *v3.Client
 	ctx    context.Context
 	ctx    context.Context
 	// rset holds read key values and revisions
 	// rset holds read key values and revisions
-	rset map[string]*v3.GetResponse
+	rset readSet
 	// wset holds overwritten keys and their values
 	// wset holds overwritten keys and their values
 	wset writeSet
 	wset writeSet
 	// getOpts are the opts used for gets
 	// getOpts are the opts used for gets
 	getOpts []v3.OpOption
 	getOpts []v3.OpOption
+	// conflicts computes the current conflicts on the txn
+	conflicts func() []v3.Cmp
 }
 }
 
 
 type stmPut struct {
 type stmPut struct {
@@ -164,6 +184,33 @@ type stmPut struct {
 	op  v3.Op
 	op  v3.Op
 }
 }
 
 
+type readSet map[string]*v3.GetResponse
+
+func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
+	for i, resp := range txnresp.Responses {
+		rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
+	}
+}
+
+func (rs readSet) first() int64 {
+	ret := int64(math.MaxInt64 - 1)
+	for _, resp := range rs {
+		if len(resp.Kvs) > 0 && resp.Kvs[0].ModRevision < ret {
+			ret = resp.Kvs[0].ModRevision
+		}
+	}
+	return ret
+}
+
+// cmps guards the txn from updates to read set
+func (rs readSet) cmps() []v3.Cmp {
+	cmps := make([]v3.Cmp, 0, len(rs))
+	for k, rk := range rs {
+		cmps = append(cmps, isKeyCurrent(k, rk))
+	}
+	return cmps
+}
+
 type writeSet map[string]stmPut
 type writeSet map[string]stmPut
 
 
 func (ws writeSet) get(keys ...string) *stmPut {
 func (ws writeSet) get(keys ...string) *stmPut {
@@ -175,6 +222,15 @@ func (ws writeSet) get(keys ...string) *stmPut {
 	return nil
 	return nil
 }
 }
 
 
+// cmps returns a cmp list testing no writes have happened past rev
+func (ws writeSet) cmps(rev int64) []v3.Cmp {
+	cmps := make([]v3.Cmp, 0, len(ws))
+	for key := range ws {
+		cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
+	}
+	return cmps
+}
+
 // puts is the list of ops for all pending writes
 // puts is the list of ops for all pending writes
 func (ws writeSet) puts() []v3.Op {
 func (ws writeSet) puts() []v3.Op {
 	puts := make([]v3.Op, 0, len(ws))
 	puts := make([]v3.Op, 0, len(ws))
@@ -205,7 +261,7 @@ func (s *stm) Rev(key string) int64 {
 }
 }
 
 
 func (s *stm) commit() *v3.TxnResponse {
 func (s *stm) commit() *v3.TxnResponse {
-	txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.wset.puts()...).Commit()
+	txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
 	if err != nil {
 	if err != nil {
 		panic(stmError{err})
 		panic(stmError{err})
 	}
 	}
@@ -215,15 +271,6 @@ func (s *stm) commit() *v3.TxnResponse {
 	return nil
 	return nil
 }
 }
 
 
-// cmps guards the txn from updates to read set
-func (s *stm) cmps() []v3.Cmp {
-	cmps := make([]v3.Cmp, 0, len(s.rset))
-	for k, rk := range s.rset {
-		cmps = append(cmps, isKeyCurrent(k, rk))
-	}
-	return cmps
-}
-
 func (s *stm) fetch(keys ...string) *v3.GetResponse {
 func (s *stm) fetch(keys ...string) *v3.GetResponse {
 	if len(keys) == 0 {
 	if len(keys) == 0 {
 		return nil
 		return nil
@@ -239,7 +286,7 @@ func (s *stm) fetch(keys ...string) *v3.GetResponse {
 	if err != nil {
 	if err != nil {
 		panic(stmError{err})
 		panic(stmError{err})
 	}
 	}
-	addTxnResp(s.rset, keys, txnresp)
+	s.rset.add(keys, txnresp)
 	return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
 	return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
 }
 }
 
 
@@ -292,7 +339,7 @@ func (s *stmSerializable) gets() ([]string, []v3.Op) {
 
 
 func (s *stmSerializable) commit() *v3.TxnResponse {
 func (s *stmSerializable) commit() *v3.TxnResponse {
 	keys, getops := s.gets()
 	keys, getops := s.gets()
-	txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.wset.puts()...)
+	txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
 	// use Else to prefetch keys in case of conflict to save a round trip
 	// use Else to prefetch keys in case of conflict to save a round trip
 	txnresp, err := txn.Else(getops...).Commit()
 	txnresp, err := txn.Else(getops...).Commit()
 	if err != nil {
 	if err != nil {
@@ -302,26 +349,12 @@ func (s *stmSerializable) commit() *v3.TxnResponse {
 		return txnresp
 		return txnresp
 	}
 	}
 	// load prefetch with Else data
 	// load prefetch with Else data
-	addTxnResp(s.rset, keys, txnresp)
+	s.rset.add(keys, txnresp)
 	s.prefetch = s.rset
 	s.prefetch = s.rset
 	s.getOpts = nil
 	s.getOpts = nil
 	return nil
 	return nil
 }
 }
 
 
-func addTxnResp(rset map[string]*v3.GetResponse, keys []string, txnresp *v3.TxnResponse) {
-	for i, resp := range txnresp.Responses {
-		rset[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
-	}
-}
-
-type stmReadCommitted struct{ stm }
-
-// commit always goes through when read committed
-func (s *stmReadCommitted) commit() *v3.TxnResponse {
-	s.rset = nil
-	return s.stm.commit()
-}
-
 func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
 func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
 	if len(r.Kvs) != 0 {
 	if len(r.Kvs) != 0 {
 		return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
 		return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)