Browse Source

concurrency: extend STM interface to Get from any of a list of keys

Now possible to fetch multiple keys in a single txn.
Anthony Romano 9 years ago
parent
commit
a81234a25b
1 changed files with 57 additions and 34 deletions
  1. 57 34
      clientv3/concurrency/stm.go

+ 57 - 34
clientv3/concurrency/stm.go

@@ -23,7 +23,7 @@ import (
 type STM interface {
 	// Get returns the value for a key and inserts the key in the txn's read set.
 	// If Get fails, it aborts the transaction with an error, never returning.
-	Get(key string) string
+	Get(key ...string) string
 	// Put adds a value for a key to the write set.
 	Put(key, val string, opts ...v3.OpOption)
 	// Rev returns the revision of a key in the read set.
@@ -58,7 +58,7 @@ type stmError struct{ err error }
 
 type stmOptions struct {
 	iso Isolation
-	ctx       context.Context
+	ctx context.Context
 }
 
 type stmOption func(*stmOptions)
@@ -138,7 +138,7 @@ type stm struct {
 	// rset holds read key values and revisions
 	rset map[string]*v3.GetResponse
 	// wset holds overwritten keys and their values
-	wset map[string]stmPut
+	wset writeSet
 	// getOpts are the opts used for gets
 	getOpts []v3.OpOption
 }
@@ -148,11 +148,31 @@ type stmPut struct {
 	op  v3.Op
 }
 
-func (s *stm) Get(key string) string {
-	if wv, ok := s.wset[key]; ok {
+type writeSet map[string]stmPut
+
+func (ws writeSet) get(keys ...string) *stmPut {
+	for _, key := range keys {
+		if wv, ok := ws[key]; ok {
+			return &wv
+		}
+	}
+	return nil
+}
+
+// puts is the list of ops for all pending writes
+func (ws writeSet) puts() []v3.Op {
+	puts := make([]v3.Op, 0, len(ws))
+	for _, v := range ws {
+		puts = append(puts, v.op)
+	}
+	return puts
+}
+
+func (s *stm) Get(keys ...string) string {
+	if wv := s.wset.get(keys...); wv != nil {
 		return wv.val
 	}
-	return respToValue(s.fetch(key))
+	return respToValue(s.fetch(keys...))
 }
 
 func (s *stm) Put(key, val string, opts ...v3.OpOption) {
@@ -169,7 +189,7 @@ func (s *stm) Rev(key string) int64 {
 }
 
 func (s *stm) commit() *v3.TxnResponse {
-	txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...).Commit()
+	txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.wset.puts()...).Commit()
 	if err != nil {
 		panic(stmError{err})
 	}
@@ -188,25 +208,23 @@ func (s *stm) cmps() []v3.Cmp {
 	return cmps
 }
 
-func (s *stm) fetch(key string) *v3.GetResponse {
-	if resp, ok := s.rset[key]; ok {
-		return resp
+func (s *stm) fetch(keys ...string) *v3.GetResponse {
+	if len(keys) == 0 {
+		return nil
 	}
-	resp, err := s.client.Get(s.ctx, key, s.getOpts...)
+	ops := make([]v3.Op, len(keys))
+	for i, key := range keys {
+		if resp, ok := s.rset[key]; ok {
+			return resp
+		}
+		ops[i] = v3.OpGet(key, s.getOpts...)
+	}
+	txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit()
 	if err != nil {
 		panic(stmError{err})
 	}
-	s.rset[key] = resp
-	return resp
-}
-
-// puts is the list of ops for all pending writes
-func (s *stm) puts() []v3.Op {
-	puts := make([]v3.Op, 0, len(s.wset))
-	for _, v := range s.wset {
-		puts = append(puts, v.op)
-	}
-	return puts
+	addTxnResp(s.rset, keys, txnresp)
+	return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
 }
 
 func (s *stm) reset() {
@@ -219,16 +237,18 @@ type stmSerializable struct {
 	prefetch map[string]*v3.GetResponse
 }
 
-func (s *stmSerializable) Get(key string) string {
-	if wv, ok := s.wset[key]; ok {
+func (s *stmSerializable) Get(keys ...string) string {
+	if wv := s.wset.get(keys...); wv != nil {
 		return wv.val
 	}
 	firstRead := len(s.rset) == 0
-	if resp, ok := s.prefetch[key]; ok {
-		delete(s.prefetch, key)
-		s.rset[key] = resp
+	for _, key := range keys {
+		if resp, ok := s.prefetch[key]; ok {
+			delete(s.prefetch, key)
+			s.rset[key] = resp
+		}
 	}
-	resp := s.stm.fetch(key)
+	resp := s.stm.fetch(keys...)
 	if firstRead {
 		// txn's base revision is defined by the first read
 		s.getOpts = []v3.OpOption{
@@ -256,7 +276,7 @@ func (s *stmSerializable) gets() ([]string, []v3.Op) {
 
 func (s *stmSerializable) commit() *v3.TxnResponse {
 	keys, getops := s.gets()
-	txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...)
+	txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.wset.puts()...)
 	// use Else to prefetch keys in case of conflict to save a round trip
 	txnresp, err := txn.Else(getops...).Commit()
 	if err != nil {
@@ -266,15 +286,18 @@ func (s *stmSerializable) commit() *v3.TxnResponse {
 		return txnresp
 	}
 	// load prefetch with Else data
-	for i := range keys {
-		resp := txnresp.Responses[i].GetResponseRange()
-		s.rset[keys[i]] = (*v3.GetResponse)(resp)
-	}
+	addTxnResp(s.rset, keys, txnresp)
 	s.prefetch = s.rset
 	s.getOpts = nil
 	return nil
 }
 
+func addTxnResp(rset map[string]*v3.GetResponse, keys []string, txnresp *v3.TxnResponse) {
+	for i, resp := range txnresp.Responses {
+		rset[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
+	}
+}
+
 type stmReadCommitted struct{ stm }
 
 // commit always goes through when read committed
@@ -291,7 +314,7 @@ func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
 }
 
 func respToValue(resp *v3.GetResponse) string {
-	if len(resp.Kvs) == 0 {
+	if resp == nil || len(resp.Kvs) == 0 {
 		return ""
 	}
 	return string(resp.Kvs[0].Value)