
Refactor mod/lock.

Ben Johnson, 12 years ago
commit 4bec461db1

+ 112 - 66
mod/lock/v2/acquire_handler.go

@@ -1,6 +1,8 @@
 package v2
 
 import (
+	"errors"
+	"fmt"
 	"net/http"
 	"path"
 	"strconv"
@@ -12,6 +14,7 @@ import (
 
 // acquireHandler attempts to acquire a lock on the given key.
 // The "key" parameter specifies the resource to lock.
+// The "value" parameter specifies a value to associate with the lock.
 // The "ttl" parameter specifies how long the lock will persist for.
 // The "timeout" parameter specifies how long the request should wait for the lock.
 func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
@@ -20,109 +23,152 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
 	// Setup connection watcher.
 	closeNotifier, _ := w.(http.CloseNotifier)
 	closeChan := closeNotifier.CloseNotify()
+	stopChan := make(chan bool)
 
-	// Parse "key" and "ttl" query parameters.
+	// Parse the lock "key".
 	vars := mux.Vars(req)
 	keypath := path.Join(prefix, vars["key"])
-	ttl, err := strconv.Atoi(req.FormValue("ttl"))
-	if err != nil {
-		http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError)
-		return
-	}
-	
+	value := req.FormValue("value")
+
 	// Parse "timeout" parameter.
 	var timeout int
-	if len(req.FormValue("timeout")) == 0 {
+	var err error
+	if req.FormValue("timeout") == "" {
 		timeout = -1
 	} else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil {
-		http.Error(w, "invalid timeout: " + err.Error(), http.StatusInternalServerError)
+		http.Error(w, "invalid timeout: " + req.FormValue("timeout"), http.StatusInternalServerError)
 		return
 	}
 	timeout = timeout + 1
 
-	// Create an incrementing id for the lock.
-	resp, err := h.client.AddChild(keypath, "-", uint64(ttl))
+	// Parse TTL.
+	ttl, err := strconv.Atoi(req.FormValue("ttl"))
 	if err != nil {
-		http.Error(w, "add lock index error: " + err.Error(), http.StatusInternalServerError)
+		http.Error(w, "invalid ttl: " + req.FormValue("ttl"), http.StatusInternalServerError)
 		return
 	}
+
+	// If node exists then just watch it. Otherwise create the node and watch it.
+	index := h.findExistingNode(keypath, value)
+	if index > 0 {
+		err = h.watch(keypath, index, nil)
+	} else {
+		index, err = h.createNode(keypath, value, ttl, closeChan, stopChan)
+	}
+
+	// Stop all goroutines.
+	close(stopChan)
+
+	// Write response.
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+	} else {
+		w.Write([]byte(strconv.Itoa(index)))
+	}
+}
+
+// createNode creates a new lock node and watches it until it is acquired or acquisition fails.
+func (h *handler) createNode(keypath string, value string, ttl int, closeChan <- chan bool, stopChan chan bool) (int, error) {
+	// Default the value to "-" if it is blank.
+	if len(value) == 0 {
+		value = "-"
+	}
+
+	// Create an incrementing id for the lock.
+	resp, err := h.client.AddChild(keypath, value, uint64(ttl))
+	if err != nil {
+		return 0, errors.New("acquire lock index error: " + err.Error())
+	}
 	indexpath := resp.Node.Key
+	index, _ := strconv.Atoi(path.Base(indexpath))
 
 	// Keep updating TTL to make sure lock request is not expired before acquisition.
-	stop := make(chan bool)
-	go h.ttlKeepAlive(indexpath, ttl, stop)
+	go h.ttlKeepAlive(indexpath, value, ttl, stopChan)
+
+	// Watch until we acquire or fail.
+	err = h.watch(keypath, index, closeChan)
+
+	// Check for connection disconnect before we write the lock index.
+	if err != nil {
+		select {
+		case <-closeChan:
+			err = errors.New("acquire lock error: user interrupted")
+		default:
+		}
+	}
+
+	// Update TTL one last time if acquired. Otherwise delete.
+	if err == nil {
+		h.client.Update(indexpath, value, uint64(ttl))
+	} else {
+		h.client.Delete(indexpath, false)
+	}
 
-	// Monitor for broken connection.
+	return index, err
+}
+
+// findExistingNode searches for a node on the lock with the given value.
+func (h *handler) findExistingNode(keypath string, value string) int {
+	if len(value) > 0 {
+		resp, err := h.client.Get(keypath, true, true)
+		if err == nil {
+			nodes := lockNodes{resp.Node.Nodes}
+			if node := nodes.FindByValue(value); node != nil {
+				index, _ := strconv.Atoi(path.Base(node.Key))
+				return index
+			}
+		}
+	}
+	return 0
+}
+
+// ttlKeepAlive continues to update a key's TTL until the stop channel is closed.
+func (h *handler) ttlKeepAlive(k string, value string, ttl int, stopChan chan bool) {
+	for {
+		select {
+		case <-time.After(time.Duration(ttl / 2) * time.Second):
+			h.client.Update(k, value, uint64(ttl))
+		case <-stopChan:
+			return
+		}
+	}
+}
+
+// watch continuously waits for a given lock index to be acquired or until the lock attempt fails.
+// Returns nil once the lock is acquired, or an error if the watch fails or is stopped.
+func (h *handler) watch(keypath string, index int, closeChan <- chan bool) error {
+	// Wrap close chan so we can pass it to Client.Watch().
 	stopWatchChan := make(chan bool)
 	go func() {
 		select {
-		case <-closeChan:
+		case <- closeChan:
 			stopWatchChan <- true
-		case <-stop:
-			// Stop watching for connection disconnect.
+		case <- stopWatchChan:
 		}
 	}()
+	defer close(stopWatchChan)
 
-	// Extract the lock index.
-	index, _ := strconv.Atoi(path.Base(resp.Node.Key))
-
-	// Wait until we successfully get a lock or we get a failure.
-	var success bool
 	for {
-		// Read all indices.
-		resp, err = h.client.Get(keypath, true, true)
+		// Read all nodes for the lock.
+		resp, err := h.client.Get(keypath, true, true)
 		if err != nil {
-			http.Error(w, "lock children lookup error: " + err.Error(), http.StatusInternalServerError)
-			break
+			return fmt.Errorf("lock watch lookup error: %s", err.Error())
 		}
-		indices := extractResponseIndices(resp)
 		waitIndex := resp.Node.ModifiedIndex
-		prevIndex := findPrevIndex(indices, index)
+		nodes := lockNodes{resp.Node.Nodes}
+		prevIndex := nodes.PrevIndex(index)
 
 		// If there is no previous index then we have the lock.
 		if prevIndex == 0 {
-			success = true
-			break
+			return nil
 		}
 
-		// Otherwise watch previous index until it's gone.
+		// Watch previous index until it's gone.
 		_, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, false, nil, stopWatchChan)
 		if err == etcd.ErrWatchStoppedByUser {
-			break
+			return fmt.Errorf("lock watch closed")
 		} else if err != nil {
-			http.Error(w, "lock watch error: " + err.Error(), http.StatusInternalServerError)
-			break
-		}
-	}
-
-	// Check for connection disconnect before we write the lock index.
-	select {
-	case <-stopWatchChan:
-		success = false
-	default:
-	}
-
-	// Stop the ttl keep-alive.
-	close(stop)
-
-	if success {
-		// Write lock index to response body if we acquire the lock.
-		h.client.Update(indexpath, "-", uint64(ttl))
-		w.Write([]byte(strconv.Itoa(index)))
-	} else {
-		// Make sure key is deleted if we couldn't acquire.
-		h.client.Delete(indexpath, false)
-	}
-}
-
-// ttlKeepAlive continues to update a key's TTL until the stop channel is closed.
-func (h *handler) ttlKeepAlive(k string, ttl int, stop chan bool) {
-	for {
-		select {
-		case <-time.After(time.Duration(ttl / 2) * time.Second):
-			h.client.Update(k, "-", uint64(ttl))
-		case <-stop:
-			return
+			return fmt.Errorf("lock watch error:%s", err.Error())
 		}
 	}
 }
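
For reference, acquiring a lock through this handler is a single POST against the lock path, exactly as the updated tests below do. A minimal client-side sketch in Go; the server address, the key "foo", the value "worker-1", and the TTL are illustrative assumptions:

package main

import (
	"fmt"
	"io/ioutil"
	"net/http"
)

func main() {
	// POST /mod/v2/lock/<key>?value=...&ttl=... acquires the lock, blocking
	// until it is held or acquisition fails. Address and parameters are
	// assumptions for illustration.
	url := "http://127.0.0.1:4001/mod/v2/lock/foo?value=worker-1&ttl=10"

	resp, err := http.Post(url, "application/x-www-form-urlencoded", nil)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// On success the body is the lock index (e.g. "2"); clients keep it to
	// renew or release the lock later, or they can release by value instead.
	index, _ := ioutil.ReadAll(resp.Body)
	fmt.Printf("acquired lock index %s\n", index)
}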

+ 19 - 6
mod/lock/v2/get_index_handler.go

@@ -3,28 +3,41 @@ package v2
 import (
 	"net/http"
 	"path"
-	"strconv"
 
 	"github.com/gorilla/mux"
 )
 
 // getIndexHandler retrieves the current lock index.
+// The "field" parameter specifies whether to read the lock "index" or the lock "value".
 func (h *handler) getIndexHandler(w http.ResponseWriter, req *http.Request) {
 	h.client.SyncCluster()
 
 	vars := mux.Vars(req)
 	keypath := path.Join(prefix, vars["key"])
+	field := req.FormValue("field")
+	if len(field) == 0 {
+		field = "value"
+	}
 
 	// Read all indices.
 	resp, err := h.client.Get(keypath, true, true)
 	if err != nil {
-		http.Error(w, "lock children lookup error: " + err.Error(), http.StatusInternalServerError)
+		http.Error(w, "read lock error: " + err.Error(), http.StatusInternalServerError)
 		return
 	}
+	nodes := lockNodes{resp.Node.Nodes}
+
+	// Write out the requested field.
+	if node := nodes.First(); node != nil {
+		switch field {
+		case "index":
+			w.Write([]byte(path.Base(node.Key)))
+
+		case "value":
+			w.Write([]byte(node.Value))
 
-	// Write out the index of the last one to the response body.
-	indices := extractResponseIndices(resp)
-	if len(indices) > 0 {
-		w.Write([]byte(strconv.Itoa(indices[0])))
+		default:
+			http.Error(w, "read lock error: invalid field: " + field, http.StatusInternalServerError)
+		}
 	}
 }
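
The new "field" parameter can be exercised with two plain GETs; a small sketch, again with an assumed local server and an illustrative key:

package main

import (
	"fmt"
	"io/ioutil"
	"net/http"
)

// get issues a GET request and returns the response body as a string.
func get(url string) string {
	resp, err := http.Get(url)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := ioutil.ReadAll(resp.Body)
	return string(body)
}

func main() {
	// The default field is "value"; pass field=index to read the lock index.
	// Address and key are assumptions for illustration.
	base := "http://127.0.0.1:4001/mod/v2/lock/foo"
	fmt.Println("value:", get(base))
	fmt.Println("index:", get(base+"?field=index"))
}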

+ 2 - 30
mod/lock/v2/handler.go

@@ -2,9 +2,6 @@ package v2
 
 import (
 	"net/http"
-	"path"
-	"strconv"
-	"sort"
 
 	"github.com/gorilla/mux"
 	"github.com/coreos/go-etcd/etcd"
@@ -27,32 +24,7 @@ func NewHandler(addr string) (http.Handler) {
 	h.StrictSlash(false)
 	h.HandleFunc("/{key:.*}", h.getIndexHandler).Methods("GET")
 	h.HandleFunc("/{key:.*}", h.acquireHandler).Methods("POST")
-	h.HandleFunc("/{key_with_index:.*}", h.renewLockHandler).Methods("PUT")
-	h.HandleFunc("/{key_with_index:.*}", h.releaseLockHandler).Methods("DELETE")
+	h.HandleFunc("/{key:.*}", h.renewLockHandler).Methods("PUT")
+	h.HandleFunc("/{key:.*}", h.releaseLockHandler).Methods("DELETE")
 	return h
 }
-
-
-// extractResponseIndices extracts a sorted list of indicies from a response.
-func extractResponseIndices(resp *etcd.Response) []int {
-	var indices []int
-	for _, node := range resp.Node.Nodes {
-		if index, _ := strconv.Atoi(path.Base(node.Key)); index > 0 {
-			indices = append(indices, index)
-		}
-	}
-	sort.Ints(indices)
-	return indices
-}
-
-// findPrevIndex retrieves the previous index before the given index.
-func findPrevIndex(indices []int, idx int) int {
-	var prevIndex int
-	for _, index := range indices {
-		if index == idx {
-			break
-		}
-		prevIndex = index
-	}
-	return prevIndex
-}

+ 57 - 0
mod/lock/v2/lock_nodes.go

@@ -0,0 +1,57 @@
+package v2
+
+import (
+	"path"
+	"sort"
+	"strconv"
+
+	"github.com/coreos/go-etcd/etcd"
+)
+
+// lockNodes is a wrapper for go-etcd's Nodes to allow for sorting by numeric key.
+type lockNodes struct {
+	etcd.Nodes
+}
+
+// Less compares two nodes by their numeric key so the set can be sorted.
+func (s lockNodes) Less(i, j int) bool {
+	a, _ := strconv.Atoi(path.Base(s.Nodes[i].Key))
+	b, _ := strconv.Atoi(path.Base(s.Nodes[j].Key))
+	return a < b
+}
+
+// Retrieves the first node in the set of lock nodes.
+func (s lockNodes) First() *etcd.Node {
+	sort.Sort(s)
+	if len(s.Nodes) > 0 {
+		return &s.Nodes[0]
+	}
+	return nil
+}
+
+// Retrieves the first node with a given value.
+func (s lockNodes) FindByValue(value string) *etcd.Node {
+	sort.Sort(s)
+
+	for _, node := range s.Nodes {
+		if node.Value == value {
+			return &node
+		}
+	}
+	return nil
+}
+
+// Retrieves the index that occurs before a given index.
+func (s lockNodes) PrevIndex(index int) int {
+	sort.Sort(s)
+
+	var prevIndex int
+	for _, node := range s.Nodes {
+		idx, _ := strconv.Atoi(path.Base(node.Key))
+		if index == idx {
+			return prevIndex
+		}
+		prevIndex = idx
+	}
+	return 0
+}
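
A rough sketch of how these helpers behave, written as if it lived inside package v2 (the keys and values are invented, and it assumes etcd.Node exposes the Key and Value string fields used above):

package v2

import (
	"fmt"

	"github.com/coreos/go-etcd/etcd"
)

// exampleLockNodes is a hypothetical illustration, not part of the module.
func exampleLockNodes() {
	// Three waiters on the same lock, deliberately out of numeric key order.
	nodes := lockNodes{etcd.Nodes{
		{Key: "/mod/v2/lock/foo/4", Value: "b"},
		{Key: "/mod/v2/lock/foo/2", Value: "a"},
		{Key: "/mod/v2/lock/foo/6", Value: "c"},
	}}

	fmt.Println(nodes.First().Key)          // lowest index holds the lock: .../foo/2
	fmt.Println(nodes.PrevIndex(6))         // index to watch before 6 can acquire: 4
	fmt.Println(nodes.FindByValue("a").Key) // look a waiter up by its value: .../foo/2
}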

+ 30 - 3
mod/lock/v2/release_handler.go

@@ -12,12 +12,39 @@ func (h *handler) releaseLockHandler(w http.ResponseWriter, req *http.Request) {
 	h.client.SyncCluster()
 
 	vars := mux.Vars(req)
-	keypath := path.Join(prefix, vars["key_with_index"])
+	keypath := path.Join(prefix, vars["key"])
+
+	// Read index and value parameters.
+	index := req.FormValue("index")
+	value := req.FormValue("value")
+	if len(index) == 0 && len(value) == 0 {
+		http.Error(w, "release lock error: index or value required", http.StatusInternalServerError)
+		return
+	} else if len(index) != 0 && len(value) != 0 {
+		http.Error(w, "release lock error: index and value cannot both be specified", http.StatusInternalServerError)
+		return
+	}
+
+	// Look up index by value if index is missing.
+	if len(index) == 0 {
+		resp, err := h.client.Get(keypath, true, true)
+		if err != nil {
+			http.Error(w, "release lock index error: " + err.Error(), http.StatusInternalServerError)
+			return
+		}
+		nodes := lockNodes{resp.Node.Nodes}
+		node := nodes.FindByValue(value)
+		if node == nil {
+			http.Error(w, "release lock error: cannot find: " + value, http.StatusInternalServerError)
+			return
+		}
+		index = path.Base(node.Key)
+	}
 
 	// Delete the lock.
-	_, err := h.client.Delete(keypath, false)
+	_, err := h.client.Delete(path.Join(keypath, index), false)
 	if err != nil {
-		http.Error(w, "delete lock index error: " + err.Error(), http.StatusInternalServerError)
+		http.Error(w, "release lock error: " + err.Error(), http.StatusInternalServerError)
 		return
 	}
 }
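
Releasing now targets the lock key itself with either an "index" or a "value" query parameter rather than a key-with-index path. A minimal sketch with net/http; the address, key, and index are illustrative assumptions:

package main

import "net/http"

func main() {
	// Release by index; passing ?value=worker-1 instead of ?index=2 also works,
	// but index and value cannot both be given. Address and parameters are
	// assumptions for illustration.
	url := "http://127.0.0.1:4001/mod/v2/lock/foo?index=2"

	req, err := http.NewRequest("DELETE", url, nil)
	if err != nil {
		panic(err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
}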

+ 40 - 3
mod/lock/v2/renew_handler.go

@@ -13,18 +13,55 @@ import (
 func (h *handler) renewLockHandler(w http.ResponseWriter, req *http.Request) {
 	h.client.SyncCluster()
 
+	// Read the lock path.
 	vars := mux.Vars(req)
-	keypath := path.Join(prefix, vars["key_with_index"])
+	keypath := path.Join(prefix, vars["key"])
+
+	// Parse new TTL parameter.
 	ttl, err := strconv.Atoi(req.FormValue("ttl"))
 	if err != nil {
 		http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError)
 		return
 	}
 
+	// Read and set defaults for index and value.
+	index := req.FormValue("index")
+	value := req.FormValue("value")
+	if len(index) == 0 && len(value) == 0 {
+		// The index or value is required.
+		http.Error(w, "renew lock error: index or value required", http.StatusInternalServerError)
+		return
+	}
+
+	if len(index) == 0 {
+		// If index is not specified then look it up by value.
+		resp, err := h.client.Get(keypath, true, true)
+		if err != nil {
+			http.Error(w, "renew lock index error: " + err.Error(), http.StatusInternalServerError)
+			return
+		}
+		nodes := lockNodes{resp.Node.Nodes}
+		node := nodes.FindByValue(value)
+		if node == nil {
+			http.Error(w, "renew lock error: cannot find: " + value, http.StatusInternalServerError)
+			return
+		}
+		index = path.Base(node.Key)
+
+	} else if len(value) == 0 {
+		// If value is not specified then default it to the previous value.
+		resp, err := h.client.Get(path.Join(keypath, index), true, false)
+		if err != nil {
+			http.Error(w, "renew lock value error: " + err.Error(), http.StatusInternalServerError)
+			return
+		}
+		value = resp.Node.Value
+	}
+
 	// Renew the lock, if it exists.
-	_, err = h.client.Update(keypath, "-", uint64(ttl))
+	_, err = h.client.Update(path.Join(keypath, index), value, uint64(ttl))
 	if err != nil {
-		http.Error(w, "renew lock index error: " + err.Error(), http.StatusInternalServerError)
+		http.Error(w, "renew lock error: " + err.Error(), http.StatusInternalServerError)
 		return
 	}
 }
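
Renewal follows the same pattern: a PUT to the lock key with the new "ttl" plus either the "index" or the "value" of the held lock. A sketch under the same assumptions:

package main

import "net/http"

func main() {
	// Extend the lock held under index 2 on "foo" for another 10 seconds;
	// ?value=worker-1 may be supplied instead of ?index=2. Address and
	// parameters are assumptions for illustration.
	url := "http://127.0.0.1:4001/mod/v2/lock/foo?index=2&ttl=10"

	req, err := http.NewRequest("PUT", url, nil)
	if err != nil {
		panic(err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
}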

+ 52 - 16
mod/lock/v2/tests/handler_test.go → mod/lock/v2/tests/mod_lock_test.go

@@ -14,7 +14,7 @@ import (
 func TestModLockAcquireAndRelease(t *testing.T) {
 	tests.RunServer(func(s *server.Server) {
 		// Acquire lock.
-		body, err := testAcquireLock(s, "foo", 10)
+		body, err := testAcquireLock(s, "foo", "", 10)
 		assert.NoError(t, err)
 		assert.Equal(t, body, "2")
 
@@ -24,7 +24,7 @@ func TestModLockAcquireAndRelease(t *testing.T) {
 		assert.Equal(t, body, "2")
 
 		// Release lock.
-		body, err = testReleaseLock(s, "foo", 2)
+		body, err = testReleaseLock(s, "foo", "2", "")
 		assert.NoError(t, err)
 		assert.Equal(t, body, "")
 
@@ -42,7 +42,7 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
 
 		// Acquire lock #1.
 		go func() {
-			body, err := testAcquireLock(s, "foo", 10)
+			body, err := testAcquireLock(s, "foo", "", 10)
 			assert.NoError(t, err)
 			assert.Equal(t, body, "2")
 			c <- true
@@ -50,11 +50,13 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
 		<- c
 
 		// Acquire lock #2.
+		waiting := true
 		go func() {
 			c <- true
-			body, err := testAcquireLock(s, "foo", 10)
+			body, err := testAcquireLock(s, "foo", "", 10)
 			assert.NoError(t, err)
 			assert.Equal(t, body, "4")
+			waiting = false
 		}()
 		<- c
 
@@ -65,8 +67,11 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
 		assert.NoError(t, err)
 		assert.Equal(t, body, "2")
 
+		// Check that we are still waiting for lock #2.
+		assert.Equal(t, waiting, true)
+
 		// Release lock #1.
-		body, err = testReleaseLock(s, "foo", 2)
+		body, err = testReleaseLock(s, "foo", "2", "")
 		assert.NoError(t, err)
 
 		// Check that we have lock #2.
@@ -75,7 +80,7 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
 		assert.Equal(t, body, "4")
 
 		// Release lock #2.
-		body, err = testReleaseLock(s, "foo", 4)
+		body, err = testReleaseLock(s, "foo", "4", "")
 		assert.NoError(t, err)
 
 		// Check that we have no lock.
@@ -92,7 +97,7 @@ func TestModLockExpireAndRelease(t *testing.T) {
 
 		// Acquire lock #1.
 		go func() {
-			body, err := testAcquireLock(s, "foo", 2)
+			body, err := testAcquireLock(s, "foo", "", 2)
 			assert.NoError(t, err)
 			assert.Equal(t, body, "2")
 			c <- true
@@ -102,7 +107,7 @@ func TestModLockExpireAndRelease(t *testing.T) {
 		// Acquire lock #2.
 		go func() {
 			c <- true
-			body, err := testAcquireLock(s, "foo", 10)
+			body, err := testAcquireLock(s, "foo", "", 10)
 			assert.NoError(t, err)
 			assert.Equal(t, body, "4")
 		}()
@@ -129,7 +134,7 @@ func TestModLockExpireAndRelease(t *testing.T) {
 func TestModLockRenew(t *testing.T) {
 	tests.RunServer(func(s *server.Server) {
 		// Acquire lock.
-		body, err := testAcquireLock(s, "foo", 3)
+		body, err := testAcquireLock(s, "foo", "", 3)
 		assert.NoError(t, err)
 		assert.Equal(t, body, "2")
 
@@ -141,7 +146,7 @@ func TestModLockRenew(t *testing.T) {
 		assert.Equal(t, body, "2")
 
 		// Renew lock.
-		body, err = testRenewLock(s, "foo", 2, 3)
+		body, err = testRenewLock(s, "foo", "2", "", 3)
 		assert.NoError(t, err)
 		assert.Equal(t, body, "")
 
@@ -161,28 +166,59 @@ func TestModLockRenew(t *testing.T) {
 	})
 }
 
+// Ensure that a lock can be acquired with a value and released by value.
+func TestModLockAcquireAndReleaseByValue(t *testing.T) {
+	tests.RunServer(func(s *server.Server) {
+		// Acquire lock.
+		body, err := testAcquireLock(s, "foo", "XXX", 10)
+		assert.NoError(t, err)
+		assert.Equal(t, body, "2")
+
+		// Check that we have the lock.
+		body, err = testGetLockValue(s, "foo")
+		assert.NoError(t, err)
+		assert.Equal(t, body, "XXX")
+
+		// Release lock.
+		body, err = testReleaseLock(s, "foo", "", "XXX")
+		assert.NoError(t, err)
+		assert.Equal(t, body, "")
+
+		// Check that we released the lock.
+		body, err = testGetLockValue(s, "foo")
+		assert.NoError(t, err)
+		assert.Equal(t, body, "")
+	})
+}
+
 
 
-func testAcquireLock(s *server.Server, key string, ttl int) (string, error) {
-	resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?ttl=%d", s.URL(), key, ttl), nil)
+func testAcquireLock(s *server.Server, key string, value string, ttl int) (string, error) {
+	resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?value=%s&ttl=%d", s.URL(), key, value, ttl), nil)
 	ret := tests.ReadBody(resp)
 	return string(ret), err
 }
 
 func testGetLockIndex(s *server.Server, key string) (string, error) {
+	resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s?field=index", s.URL(), key))
+	ret := tests.ReadBody(resp)
+	return string(ret), err
+}
+
+func testGetLockValue(s *server.Server, key string) (string, error) {
 	resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s", s.URL(), key))
 	ret := tests.ReadBody(resp)
 	return string(ret), err
 }
 
-func testReleaseLock(s *server.Server, key string, index int) (string, error) {
-	resp, err := tests.DeleteForm(fmt.Sprintf("%s/mod/v2/lock/%s/%d", s.URL(), key, index), nil)
+func testReleaseLock(s *server.Server, key string, index string, value string) (string, error) {
+	resp, err := tests.DeleteForm(fmt.Sprintf("%s/mod/v2/lock/%s?index=%s&value=%s", s.URL(), key, index, value), nil)
 	ret := tests.ReadBody(resp)
 	return string(ret), err
 }
 
-func testRenewLock(s *server.Server, key string, index int, ttl int) (string, error) {
-	resp, err := tests.PutForm(fmt.Sprintf("%s/mod/v2/lock/%s/%d?ttl=%d", s.URL(), key, index, ttl), nil)
+func testRenewLock(s *server.Server, key string, index string, value string, ttl int) (string, error) {
+	resp, err := tests.PutForm(fmt.Sprintf("%s/mod/v2/lock/%s?index=%s&value=%s&ttl=%d", s.URL(), key, index, value, ttl), nil)
 	ret := tests.ReadBody(resp)
 	return string(ret), err
 }