|
|
@@ -15,6 +15,8 @@
|
|
|
package concurrency
|
|
|
|
|
|
import (
|
|
|
+ "math"
|
|
|
+
|
|
|
v3 "github.com/coreos/etcd/clientv3"
|
|
|
"golang.org/x/net/context"
|
|
|
)
|
|
|
@@ -82,7 +84,7 @@ func WithPrefetch(keys ...string) stmOption {
|
|
|
return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
|
|
|
}
|
|
|
|
|
|
-// NewSTM initiates a new STM instance.
|
|
|
+// NewSTM initiates a new STM instance, using snapshot isolation by default.
|
|
|
func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
|
|
|
opts := &stmOptions{ctx: c.Ctx()}
|
|
|
for _, f := range so {
|
|
|
@@ -95,22 +97,38 @@ func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnRespon
|
|
|
return f(s)
|
|
|
}
|
|
|
}
|
|
|
- var s STM
|
|
|
+ return runSTM(mkSTM(c, opts), apply)
|
|
|
+}
|
|
|
+
|
|
|
+func mkSTM(c *v3.Client, opts *stmOptions) STM {
|
|
|
switch opts.iso {
|
|
|
+ case Snapshot:
|
|
|
+ s := &stmSerializable{
|
|
|
+ stm: stm{client: c, ctx: opts.ctx},
|
|
|
+ prefetch: make(map[string]*v3.GetResponse),
|
|
|
+ }
|
|
|
+ s.conflicts = func() []v3.Cmp {
|
|
|
+ return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
|
|
|
+ }
|
|
|
+ return s
|
|
|
case Serializable:
|
|
|
- s = &stmSerializable{
|
|
|
+ s := &stmSerializable{
|
|
|
stm: stm{client: c, ctx: opts.ctx},
|
|
|
prefetch: make(map[string]*v3.GetResponse),
|
|
|
}
|
|
|
+ s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
|
|
|
+ return s
|
|
|
case RepeatableReads:
|
|
|
- s = &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
|
|
|
+ s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
|
|
|
+ s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
|
|
|
+ return s
|
|
|
case ReadCommitted:
|
|
|
- ss := stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
|
|
|
- s = &stmReadCommitted{ss}
|
|
|
+ s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
|
|
|
+ s.conflicts = func() []v3.Cmp { return nil }
|
|
|
+ return s
|
|
|
default:
|
|
|
- panic("unsupported")
|
|
|
+ panic("unsupported stm")
|
|
|
}
|
|
|
- return runSTM(s, apply)
|
|
|
}
|
|
|
|
|
|
type stmResponse struct {
|
|
|
@@ -152,11 +170,13 @@ type stm struct {
|
|
|
client *v3.Client
|
|
|
ctx context.Context
|
|
|
// rset holds read key values and revisions
|
|
|
- rset map[string]*v3.GetResponse
|
|
|
+ rset readSet
|
|
|
// wset holds overwritten keys and their values
|
|
|
wset writeSet
|
|
|
// getOpts are the opts used for gets
|
|
|
getOpts []v3.OpOption
|
|
|
+ // conflicts computes the current conflicts on the txn
|
|
|
+ conflicts func() []v3.Cmp
|
|
|
}
|
|
|
|
|
|
type stmPut struct {
|
|
|
@@ -164,6 +184,33 @@ type stmPut struct {
|
|
|
op v3.Op
|
|
|
}
|
|
|
|
|
|
+type readSet map[string]*v3.GetResponse
|
|
|
+
|
|
|
+func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
|
|
|
+ for i, resp := range txnresp.Responses {
|
|
|
+ rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (rs readSet) first() int64 {
|
|
|
+ ret := int64(math.MaxInt64 - 1)
|
|
|
+ for _, resp := range rs {
|
|
|
+ if len(resp.Kvs) > 0 && resp.Kvs[0].ModRevision < ret {
|
|
|
+ ret = resp.Kvs[0].ModRevision
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return ret
|
|
|
+}
|
|
|
+
|
|
|
+// cmps guards the txn from updates to read set
|
|
|
+func (rs readSet) cmps() []v3.Cmp {
|
|
|
+ cmps := make([]v3.Cmp, 0, len(rs))
|
|
|
+ for k, rk := range rs {
|
|
|
+ cmps = append(cmps, isKeyCurrent(k, rk))
|
|
|
+ }
|
|
|
+ return cmps
|
|
|
+}
|
|
|
+
|
|
|
type writeSet map[string]stmPut
|
|
|
|
|
|
func (ws writeSet) get(keys ...string) *stmPut {
|
|
|
@@ -175,6 +222,15 @@ func (ws writeSet) get(keys ...string) *stmPut {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+// cmps returns a cmp list testing no writes have happened past rev
|
|
|
+func (ws writeSet) cmps(rev int64) []v3.Cmp {
|
|
|
+ cmps := make([]v3.Cmp, 0, len(ws))
|
|
|
+ for key := range ws {
|
|
|
+ cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
|
|
|
+ }
|
|
|
+ return cmps
|
|
|
+}
|
|
|
+
|
|
|
// puts is the list of ops for all pending writes
|
|
|
func (ws writeSet) puts() []v3.Op {
|
|
|
puts := make([]v3.Op, 0, len(ws))
|
|
|
@@ -205,7 +261,7 @@ func (s *stm) Rev(key string) int64 {
|
|
|
}
|
|
|
|
|
|
func (s *stm) commit() *v3.TxnResponse {
|
|
|
- txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.wset.puts()...).Commit()
|
|
|
+ txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
|
|
|
if err != nil {
|
|
|
panic(stmError{err})
|
|
|
}
|
|
|
@@ -215,15 +271,6 @@ func (s *stm) commit() *v3.TxnResponse {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-// cmps guards the txn from updates to read set
|
|
|
-func (s *stm) cmps() []v3.Cmp {
|
|
|
- cmps := make([]v3.Cmp, 0, len(s.rset))
|
|
|
- for k, rk := range s.rset {
|
|
|
- cmps = append(cmps, isKeyCurrent(k, rk))
|
|
|
- }
|
|
|
- return cmps
|
|
|
-}
|
|
|
-
|
|
|
func (s *stm) fetch(keys ...string) *v3.GetResponse {
|
|
|
if len(keys) == 0 {
|
|
|
return nil
|
|
|
@@ -239,7 +286,7 @@ func (s *stm) fetch(keys ...string) *v3.GetResponse {
|
|
|
if err != nil {
|
|
|
panic(stmError{err})
|
|
|
}
|
|
|
- addTxnResp(s.rset, keys, txnresp)
|
|
|
+ s.rset.add(keys, txnresp)
|
|
|
return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
|
|
|
}
|
|
|
|
|
|
@@ -292,7 +339,7 @@ func (s *stmSerializable) gets() ([]string, []v3.Op) {
|
|
|
|
|
|
func (s *stmSerializable) commit() *v3.TxnResponse {
|
|
|
keys, getops := s.gets()
|
|
|
- txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.wset.puts()...)
|
|
|
+ txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
|
|
|
// use Else to prefetch keys in case of conflict to save a round trip
|
|
|
txnresp, err := txn.Else(getops...).Commit()
|
|
|
if err != nil {
|
|
|
@@ -302,26 +349,12 @@ func (s *stmSerializable) commit() *v3.TxnResponse {
|
|
|
return txnresp
|
|
|
}
|
|
|
// load prefetch with Else data
|
|
|
- addTxnResp(s.rset, keys, txnresp)
|
|
|
+ s.rset.add(keys, txnresp)
|
|
|
s.prefetch = s.rset
|
|
|
s.getOpts = nil
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func addTxnResp(rset map[string]*v3.GetResponse, keys []string, txnresp *v3.TxnResponse) {
|
|
|
- for i, resp := range txnresp.Responses {
|
|
|
- rset[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-type stmReadCommitted struct{ stm }
|
|
|
-
|
|
|
-// commit always goes through when read committed
|
|
|
-func (s *stmReadCommitted) commit() *v3.TxnResponse {
|
|
|
- s.rset = nil
|
|
|
- return s.stm.commit()
|
|
|
-}
|
|
|
-
|
|
|
func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
|
|
|
if len(r.Kvs) != 0 {
|
|
|
return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
|