Browse Source

clientv3/concurrency: software transactional memory

Repeatable read and serialized read STM implementations.
Anthony Romano 9 years ago
parent
commit
bc37a32062
3 changed files with 342 additions and 143 deletions
  1. 246 0
      clientv3/concurrency/stm.go
  2. 0 103
      contrib/recipes/stm.go
  3. 96 40
      integration/v3_stm_test.go

+ 246 - 0
clientv3/concurrency/stm.go

@@ -0,0 +1,246 @@
+// Copyright 2016 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package concurrency
+
+import (
+	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
+	v3 "github.com/coreos/etcd/clientv3"
+)
+
+// STM is an interface for software transactional memory.
+type STM interface {
+	// Get returns the value for a key and inserts the key in the txn's read set.
+	// If Get fails, it aborts the transaction with an error, never returning.
+	Get(key string) string
+	// Put adds a value for a key to the write set.
+	Put(key, val string, opts ...v3.OpOption)
+	// Rev returns the revision of a key in the read set.
+	Rev(key string) int64
+	// Del deletes a key.
+	Del(key string)
+
+	// commit attempts to apply the txn's changes to the server.
+	commit() *v3.TxnResponse
+	reset()
+}
+
+// stmError safely passes STM errors through panic to the STM error channel.
+type stmError struct{ err error }
+
+// NewSTMRepeatable initiates new repeatable read transaction; reads within
+// the same transaction attempt always return the same data.
+func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
+	s := &stm{client: c, ctx: ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
+	return runSTM(s, apply)
+}
+
+// NewSTMSerializable initiates a new serialized transaction; reads within the
+// same transactiona attempt return data from the revision of the first read.
+func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
+	s := &stmSerializable{
+		stm:      stm{client: c, ctx: ctx},
+		prefetch: make(map[string]*v3.GetResponse),
+	}
+	return runSTM(s, apply)
+}
+
+type stmResponse struct {
+	resp *v3.TxnResponse
+	err  error
+}
+
+func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
+	outc := make(chan stmResponse, 1)
+	go func() {
+		defer func() {
+			if r := recover(); r != nil {
+				e, ok := r.(stmError)
+				if !ok {
+					// client apply panicked
+					panic(r)
+				}
+				outc <- stmResponse{nil, e.err}
+			}
+		}()
+		var out stmResponse
+		for {
+			s.reset()
+			if out.err = apply(s); out.err != nil {
+				break
+			}
+			if out.resp = s.commit(); out.resp != nil {
+				break
+			}
+		}
+		outc <- out
+	}()
+	r := <-outc
+	return r.resp, r.err
+}
+
+// stm implements repeatable-read software transactional memory over etcd
+type stm struct {
+	client *v3.Client
+	ctx    context.Context
+	// rset holds read key values and revisions
+	rset map[string]*v3.GetResponse
+	// wset holds overwritten keys and their values
+	wset map[string]stmPut
+	// getOpts are the opts used for gets
+	getOpts []v3.OpOption
+}
+
+type stmPut struct {
+	val string
+	op  v3.Op
+}
+
+func (s *stm) Get(key string) string {
+	if wv, ok := s.wset[key]; ok {
+		return wv.val
+	}
+	return respToValue(s.fetch(key))
+}
+
+func (s *stm) Put(key, val string, opts ...v3.OpOption) {
+	s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
+}
+
+func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
+
+func (s *stm) Rev(key string) int64 {
+	if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
+		return resp.Kvs[0].ModRevision
+	}
+	return 0
+}
+
+func (s *stm) commit() *v3.TxnResponse {
+	txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...).Commit()
+	if err != nil {
+		panic(stmError{err})
+	}
+	if txnresp.Succeeded {
+		return txnresp
+	}
+	return nil
+}
+
+// cmps guards the txn from updates to read set
+func (s *stm) cmps() (cmps []v3.Cmp) {
+	for k, rk := range s.rset {
+		cmps = append(cmps, isKeyCurrent(k, rk))
+	}
+	return
+}
+
+func (s *stm) fetch(key string) *v3.GetResponse {
+	if resp, ok := s.rset[key]; ok {
+		return resp
+	}
+	resp, err := s.client.Get(s.ctx, key, s.getOpts...)
+	if err != nil {
+		panic(stmError{err})
+	}
+	s.rset[key] = resp
+	return resp
+}
+
+// puts is the list of ops for all pending writes
+func (s *stm) puts() (puts []v3.Op) {
+	for _, v := range s.wset {
+		puts = append(puts, v.op)
+	}
+	return
+}
+
+func (s *stm) reset() {
+	s.rset = make(map[string]*v3.GetResponse)
+	s.wset = make(map[string]stmPut)
+}
+
+type stmSerializable struct {
+	stm
+	prefetch map[string]*v3.GetResponse
+}
+
+func (s *stmSerializable) Get(key string) string {
+	if wv, ok := s.wset[key]; ok {
+		return wv.val
+	}
+	firstRead := len(s.rset) == 0
+	if resp, ok := s.prefetch[key]; ok {
+		delete(s.prefetch, key)
+		s.rset[key] = resp
+	}
+	resp := s.stm.fetch(key)
+	if firstRead {
+		// txn's base revision is defined by the first read
+		s.getOpts = []v3.OpOption{
+			v3.WithRev(resp.Header.Revision),
+			v3.WithSerializable(),
+		}
+	}
+	return respToValue(resp)
+}
+
+func (s *stmSerializable) Rev(key string) int64 {
+	s.Get(key)
+	return s.stm.Rev(key)
+}
+
+func (s *stmSerializable) gets() (keys []string, ops []v3.Op) {
+	for k := range s.rset {
+		keys = append(keys, k)
+		ops = append(ops, v3.OpGet(k))
+	}
+	return
+}
+
+func (s *stmSerializable) commit() *v3.TxnResponse {
+	keys, getops := s.gets()
+	txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...)
+	// use Else to prefetch keys in case of conflict to save a round trip
+	txnresp, err := txn.Else(getops...).Commit()
+	if err != nil {
+		panic(stmError{err})
+	}
+	if txnresp.Succeeded {
+		return txnresp
+	}
+	// load prefetch with Else data
+	for i := range keys {
+		resp := txnresp.Responses[i].GetResponseRange()
+		s.rset[keys[i]] = (*v3.GetResponse)(resp)
+	}
+	s.prefetch = s.rset
+	s.getOpts = nil
+	return nil
+}
+
+func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
+	rev := r.Header.Revision + 1
+	if len(r.Kvs) != 0 {
+		rev = r.Kvs[0].ModRevision + 1
+	}
+	return v3.Compare(v3.ModifiedRevision(k), "<", rev)
+}
+
+func respToValue(resp *v3.GetResponse) string {
+	if len(resp.Kvs) == 0 {
+		return ""
+	}
+	return string(resp.Kvs[0].Value)
+}

+ 0 - 103
contrib/recipes/stm.go

@@ -1,103 +0,0 @@
-// Copyright 2016 CoreOS, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package recipe
-
-import (
-	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
-	v3 "github.com/coreos/etcd/clientv3"
-)
-
-// STM implements software transactional memory over etcd
-type STM struct {
-	client *v3.Client
-	// rset holds the read key's value and revision of read
-	rset map[string]*RemoteKV
-	// wset holds the write key and its value
-	wset map[string]string
-	// aborted is whether user aborted the txn
-	aborted bool
-	apply   func(*STM) error
-}
-
-// NewSTM creates new transaction loop for a given apply function.
-func NewSTM(client *v3.Client, apply func(*STM) error) <-chan error {
-	s := &STM{client: client, apply: apply}
-	errc := make(chan error, 1)
-	go func() {
-		var err error
-		for {
-			s.clear()
-			if err = apply(s); err != nil || s.aborted {
-				break
-			}
-			if ok, cerr := s.commit(); ok || cerr != nil {
-				err = cerr
-				break
-			}
-		}
-		errc <- err
-	}()
-	return errc
-}
-
-// Abort abandons the apply loop, letting the transaction close without a commit.
-func (s *STM) Abort() { s.aborted = true }
-
-// Get returns the value for a given key, inserting the key into the txn's rset.
-func (s *STM) Get(key string) (string, error) {
-	if wv, ok := s.wset[key]; ok {
-		return wv, nil
-	}
-	if rk, ok := s.rset[key]; ok {
-		return rk.Value(), nil
-	}
-	rk, err := GetRemoteKV(s.client, key)
-	if err != nil {
-		return "", err
-	}
-	// TODO: setup watchers to abort txn early
-	s.rset[key] = rk
-	return rk.Value(), nil
-}
-
-// Put adds a value for a key to the write set.
-func (s *STM) Put(key string, val string) { s.wset[key] = val }
-
-// commit attempts to apply the txn's changes to the server.
-func (s *STM) commit() (ok bool, rr error) {
-	// read set must not change
-	cmps := make([]v3.Cmp, 0, len(s.rset))
-	for k, rk := range s.rset {
-		// use < to support updating keys that don't exist yet
-		cmp := v3.Compare(v3.ModifiedRevision(k), "<", rk.Revision()+1)
-		cmps = append(cmps, cmp)
-	}
-
-	// apply all writes
-	puts := make([]v3.Op, 0, len(s.wset))
-	for k, v := range s.wset {
-		puts = append(puts, v3.OpPut(k, v))
-	}
-	txnresp, err := s.client.Txn(context.TODO()).If(cmps...).Then(puts...).Commit()
-	if err != nil {
-		return false, err
-	}
-	return txnresp.Succeeded, err
-}
-
-func (s *STM) clear() {
-	s.rset = make(map[string]*RemoteKV)
-	s.wset = make(map[string]string)
-}

+ 96 - 40
integration/v3_stm_test.go

@@ -19,7 +19,9 @@ import (
 	"strconv"
 	"testing"
 
-	"github.com/coreos/etcd/contrib/recipes"
+	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
+	v3 "github.com/coreos/etcd/clientv3"
+	"github.com/coreos/etcd/clientv3/concurrency"
 )
 
 // TestSTMConflict tests that conflicts are retried.
@@ -28,33 +30,26 @@ func TestSTMConflict(t *testing.T) {
 	defer clus.Terminate(t)
 
 	etcdc := clus.RandClient()
-	keys := make([]*recipe.RemoteKV, 5)
+	keys := make([]string, 5)
 	for i := 0; i < len(keys); i++ {
-		rk, err := recipe.NewKV(etcdc, fmt.Sprintf("foo-%d", i), "100", 0)
-		if err != nil {
+		keys[i] = fmt.Sprintf("foo-%d", i)
+		if _, err := etcdc.Put(context.TODO(), keys[i], "100"); err != nil {
 			t.Fatalf("could not make key (%v)", err)
 		}
-		keys[i] = rk
 	}
 
-	errc := make([]<-chan error, len(keys))
-	for i, rk := range keys {
+	errc := make(chan error)
+	for i := range keys {
 		curEtcdc := clus.RandClient()
-		srcKey := rk.Key()
-		applyf := func(stm *recipe.STM) error {
-			src, err := stm.Get(srcKey)
-			if err != nil {
-				return err
-			}
+		srcKey := keys[i]
+		applyf := func(stm concurrency.STM) error {
+			src := stm.Get(srcKey)
 			// must be different key to avoid double-adding
 			dstKey := srcKey
 			for dstKey == srcKey {
-				dstKey = keys[rand.Intn(len(keys))].Key()
-			}
-			dst, err := stm.Get(dstKey)
-			if err != nil {
-				return err
+				dstKey = keys[rand.Intn(len(keys))]
 			}
+			dst := stm.Get(dstKey)
 			srcV, _ := strconv.ParseInt(src, 10, 64)
 			dstV, _ := strconv.ParseInt(dst, 10, 64)
 			xfer := int64(rand.Intn(int(srcV)) / 2)
@@ -62,24 +57,27 @@ func TestSTMConflict(t *testing.T) {
 			stm.Put(dstKey, fmt.Sprintf("%d", dstV+xfer))
 			return nil
 		}
-		errc[i] = recipe.NewSTM(curEtcdc, applyf)
+		go func() {
+			_, err := concurrency.NewSTMRepeatable(context.TODO(), curEtcdc, applyf)
+			errc <- err
+		}()
 	}
 
 	// wait for txns
-	for _, ch := range errc {
-		if err := <-ch; err != nil {
+	for range keys {
+		if err := <-errc; err != nil {
 			t.Fatalf("apply failed (%v)", err)
 		}
 	}
 
 	// ensure sum matches initial sum
 	sum := 0
-	for _, oldRK := range keys {
-		rk, err := recipe.GetRemoteKV(etcdc, oldRK.Key())
+	for _, oldkey := range keys {
+		rk, err := etcdc.Get(context.TODO(), oldkey)
 		if err != nil {
-			t.Fatalf("couldn't fetch key %s (%v)", oldRK.Key(), err)
+			t.Fatalf("couldn't fetch key %s (%v)", oldkey, err)
 		}
-		v, _ := strconv.ParseInt(rk.Value(), 10, 64)
+		v, _ := strconv.ParseInt(string(rk.Kvs[0].Value), 10, 64)
 		sum += int(v)
 	}
 	if sum != len(keys)*100 {
@@ -93,21 +91,20 @@ func TestSTMPutNewKey(t *testing.T) {
 	defer clus.Terminate(t)
 
 	etcdc := clus.RandClient()
-	applyf := func(stm *recipe.STM) error {
+	applyf := func(stm concurrency.STM) error {
 		stm.Put("foo", "bar")
 		return nil
 	}
-	errc := recipe.NewSTM(etcdc, applyf)
-	if err := <-errc; err != nil {
+	if _, err := concurrency.NewSTMRepeatable(context.TODO(), etcdc, applyf); err != nil {
 		t.Fatalf("error on stm txn (%v)", err)
 	}
 
-	rk, err := recipe.GetRemoteKV(etcdc, "foo")
+	resp, err := etcdc.Get(context.TODO(), "foo")
 	if err != nil {
 		t.Fatalf("error fetching key (%v)", err)
 	}
-	if rk.Value() != "bar" {
-		t.Fatalf("bad value. got %v, expected bar", rk.Value())
+	if string(resp.Kvs[0].Value) != "bar" {
+		t.Fatalf("bad value. got %+v, expected 'bar' value", resp)
 	}
 }
 
@@ -117,22 +114,81 @@ func TestSTMAbort(t *testing.T) {
 	defer clus.Terminate(t)
 
 	etcdc := clus.RandClient()
-	applyf := func(stm *recipe.STM) error {
-		stm.Put("foo", "baz")
-		stm.Abort()
+	ctx, cancel := context.WithCancel(context.TODO())
+	applyf := func(stm concurrency.STM) error {
 		stm.Put("foo", "baz")
+		cancel()
+		stm.Put("foo", "bap")
 		return nil
 	}
-	errc := recipe.NewSTM(etcdc, applyf)
-	if err := <-errc; err != nil {
-		t.Fatalf("error on stm txn (%v)", err)
+	if _, err := concurrency.NewSTMRepeatable(ctx, etcdc, applyf); err == nil {
+		t.Fatalf("no error on stm txn")
 	}
 
-	rk, err := recipe.GetRemoteKV(etcdc, "foo")
+	resp, err := etcdc.Get(context.TODO(), "foo")
 	if err != nil {
 		t.Fatalf("error fetching key (%v)", err)
 	}
-	if rk.Value() != "" {
-		t.Fatalf("bad value. got %v, expected empty string", rk.Value())
+	if len(resp.Kvs) != 0 {
+		t.Fatalf("bad value. got %+v, expected nothing", resp)
+	}
+}
+
+// TestSTMSerialize tests that serialization is honored when serializable.
+func TestSTMSerialize(t *testing.T) {
+	clus := NewClusterV3(t, &ClusterConfig{Size: 3})
+	defer clus.Terminate(t)
+
+	etcdc := clus.RandClient()
+
+	// set up initial keys
+	keys := make([]string, 5)
+	for i := 0; i < len(keys); i++ {
+		keys[i] = fmt.Sprintf("foo-%d", i)
+	}
+
+	// update keys in full batches
+	updatec := make(chan struct{})
+	go func() {
+		defer close(updatec)
+		for i := 0; i < 5; i++ {
+			s := fmt.Sprintf("%d", i)
+			ops := []v3.Op{}
+			for _, k := range keys {
+				ops = append(ops, v3.OpPut(k, s))
+			}
+			if _, err := etcdc.Txn(context.TODO()).Then(ops...).Commit(); err != nil {
+				t.Fatalf("couldn't put keys (%v)", err)
+			}
+			updatec <- struct{}{}
+		}
+	}()
+
+	// read all keys in txn, make sure all values match
+	errc := make(chan error)
+	for range updatec {
+		curEtcdc := clus.RandClient()
+		applyf := func(stm concurrency.STM) error {
+			vs := []string{}
+			for i := range keys {
+				vs = append(vs, stm.Get(keys[i]))
+			}
+			for i := range vs {
+				if vs[0] != vs[i] {
+					return fmt.Errorf("got vs[%d] = %v, want %v", i, vs[i], vs[0])
+				}
+			}
+			return nil
+		}
+		go func() {
+			_, err := concurrency.NewSTMSerializable(context.TODO(), curEtcdc, applyf)
+			errc <- err
+		}()
+	}
+
+	for i := 0; i < 5; i++ {
+		if err := <-errc; err != nil {
+			t.Error(err)
+		}
 	}
 }