Browse Source

Merge pull request #5199 from heyitsanthony/safe-lock-retry

clientv3/concurrency: use session lease id for mutex keys
Anthony Romano 9 years ago
parent
commit
08f6c0775a
2 changed files with 79 additions and 6 deletions
  1. 18 3
      clientv3/concurrency/mutex.go
  2. 61 3
      integration/v3_lock_test.go

+ 18 - 3
clientv3/concurrency/mutex.go

@@ -15,6 +15,7 @@
 package concurrency
 
 import (
+	"fmt"
 	"sync"
 
 	v3 "github.com/coreos/etcd/clientv3"
@@ -37,12 +38,26 @@ func NewMutex(client *v3.Client, pfx string) *Mutex {
 // Lock locks the mutex with a cancellable context. If the context is cancelled
 // while trying to acquire the lock, the mutex tries to clean its stale lock entry.
 func (m *Mutex) Lock(ctx context.Context) error {
-	s, err := NewSession(m.client)
+	s, serr := NewSession(m.client)
+	if serr != nil {
+		return serr
+	}
+
+	m.myKey = fmt.Sprintf("%s/%x", m.pfx, s.Lease())
+	cmp := v3.Compare(v3.CreateRevision(m.myKey), "=", 0)
+	// put self in lock waiters via myKey; oldest waiter holds lock
+	put := v3.OpPut(m.myKey, "", v3.WithLease(s.Lease()))
+	// reuse key in case this session already holds the lock
+	get := v3.OpGet(m.myKey)
+	resp, err := m.client.Txn(ctx).If(cmp).Then(put).Else(get).Commit()
 	if err != nil {
 		return err
 	}
-	// put self in lock waiters via myKey; oldest waiter holds lock
-	m.myKey, m.myRev, err = NewUniqueKey(ctx, m.client, m.pfx, v3.WithLease(s.Lease()))
+	m.myRev = resp.Header.Revision
+	if !resp.Succeeded {
+		m.myRev = resp.Responses[0].GetResponseRange().Kvs[0].CreateRevision
+	}
+
 	// wait for deletion revisions prior to myKey
 	err = waitDeletes(ctx, m.client, m.pfx, v3.WithPrefix(), v3.WithRev(m.myRev-1))
 	// release lock key if cancelled

+ 61 - 3
integration/v3_lock_test.go

@@ -16,6 +16,7 @@ package integration
 
 import (
 	"math/rand"
+	"sync"
 	"testing"
 	"time"
 
@@ -28,18 +29,24 @@ import (
 func TestMutexSingleNode(t *testing.T) {
 	clus := NewClusterV3(t, &ClusterConfig{Size: 3})
 	defer clus.Terminate(t)
-	testMutex(t, 5, func() *clientv3.Client { return clus.clients[0] })
+
+	var clients []*clientv3.Client
+	testMutex(t, 5, makeSingleNodeClients(t, clus.cluster, &clients))
+	closeClients(t, clients)
 }
 
 func TestMutexMultiNode(t *testing.T) {
 	clus := NewClusterV3(t, &ClusterConfig{Size: 3})
 	defer clus.Terminate(t)
-	testMutex(t, 5, func() *clientv3.Client { return clus.RandClient() })
+
+	var clients []*clientv3.Client
+	testMutex(t, 5, makeMultiNodeClients(t, clus.cluster, &clients))
+	closeClients(t, clients)
 }
 
 func testMutex(t *testing.T, waiters int, chooseClient func() *clientv3.Client) {
 	// stream lock acquisitions
-	lockedC := make(chan *concurrency.Mutex, 1)
+	lockedC := make(chan *concurrency.Mutex)
 	for i := 0; i < waiters; i++ {
 		go func() {
 			m := concurrency.NewMutex(chooseClient(), "test-mutex")
@@ -69,6 +76,22 @@ func testMutex(t *testing.T, waiters int, chooseClient func() *clientv3.Client)
 	}
 }
 
+// TestMutexSessionRelock ensures that acquiring the same lock with the same
+// session will not result in deadlock.
+func TestMutexSessionRelock(t *testing.T) {
+	clus := NewClusterV3(t, &ClusterConfig{Size: 3})
+	defer clus.Terminate(t)
+	cli := clus.RandClient()
+	m := concurrency.NewMutex(cli, "test-mutex")
+	if err := m.Lock(context.TODO()); err != nil {
+		t.Fatal(err)
+	}
+	m2 := concurrency.NewMutex(cli, "test-mutex")
+	if err := m2.Lock(context.TODO()); err != nil {
+		t.Fatal(err)
+	}
+}
+
 func BenchmarkMutex4Waiters(b *testing.B) {
 	// XXX switch tests to use TB interface
 	clus := NewClusterV3(nil, &ClusterConfig{Size: 3})
@@ -137,3 +160,38 @@ func testRWMutex(t *testing.T, waiters int, chooseClient func() *clientv3.Client
 		}
 	}
 }
+
+func makeClients(t *testing.T, clients *[]*clientv3.Client, choose func() *member) func() *clientv3.Client {
+	var mu sync.Mutex
+	*clients = nil
+	return func() *clientv3.Client {
+		cli, err := NewClientV3(choose())
+		if err != nil {
+			t.Fatal(err)
+		}
+		mu.Lock()
+		*clients = append(*clients, cli)
+		mu.Unlock()
+		return cli
+	}
+}
+
+func makeSingleNodeClients(t *testing.T, clus *cluster, clients *[]*clientv3.Client) func() *clientv3.Client {
+	return makeClients(t, clients, func() *member {
+		return clus.Members[0]
+	})
+}
+
+func makeMultiNodeClients(t *testing.T, clus *cluster, clients *[]*clientv3.Client) func() *clientv3.Client {
+	return makeClients(t, clients, func() *member {
+		return clus.Members[rand.Intn(len(clus.Members))]
+	})
+}
+
+func closeClients(t *testing.T, clients []*clientv3.Client) {
+	for _, cli := range clients {
+		if err := cli.Close(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}