Quellcode durchsuchen

Add mod/lock connection monitoring.

Ben Johnson vor 12 Jahren
Ursprung
Commit
f3d438a93f
2 geänderte Dateien mit 76 neuen und 25 gelöschten Zeilen
  1. 70 21
      mod/lock/acquire_handler.go
  2. 6 4
      test.sh

+ 70 - 21
mod/lock/acquire_handler.go

@@ -6,13 +6,22 @@ import (
 	"strconv"
 	"time"
 
+	"github.com/coreos/go-etcd/etcd"
 	"github.com/gorilla/mux"
 )
 
 // acquireHandler attempts to acquire a lock on the given key.
+// The "key" parameter specifies the resource to lock.
+// The "ttl" parameter specifies how long the lock will persist for.
+// The "timeout" parameter specifies how long the request should wait for the lock.
 func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
 	h.client.SyncCluster()
 
+	// Setup connection watcher.
+	closeNotifier, _ := w.(http.CloseNotifier)
+	closeChan := closeNotifier.CloseNotify()
+
+	// Parse "key" and "ttl" query parameters.
 	vars := mux.Vars(req)
 	keypath := path.Join(prefix, vars["key"])
 	ttl, err := strconv.Atoi(req.FormValue("ttl"))
@@ -20,6 +29,16 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
 		http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError)
 		return
 	}
+	
+	// Parse "timeout" parameter.
+	var timeout int
+	if len(req.FormValue("timeout")) == 0 {
+		timeout = -1
+	} else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil {
+		http.Error(w, "invalid timeout: " + err.Error(), http.StatusInternalServerError)
+		return
+	}
+	timeout = timeout + 1
 
 	// Create an incrementing id for the lock.
 	resp, err := h.client.AddChild(keypath, "-", uint64(ttl))
@@ -30,32 +49,31 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
 	indexpath := resp.Key
 
 	// Keep updating TTL to make sure lock request is not expired before acquisition.
-	stopChan := make(chan bool)
-	defer close(stopChan)
-	go func(k string) {
-		stopped := false
-		for {
-			select {
-			case <-time.After(time.Duration(ttl / 2) * time.Second):
-			case <-stopChan:
-				stopped = true
-			}
-			h.client.Update(k, "-", uint64(ttl))
-			if stopped {
-				break
-			}
+	stop := make(chan bool)
+	go h.ttlKeepAlive(indexpath, ttl, stop)
+
+	// Monitor for broken connection.
+	stopWatchChan := make(chan bool)
+	go func() {
+		select {
+		case <-closeChan:
+			stopWatchChan <- true
+		case <-stop:
+			// Stop watching for connection disconnect.
 		}
-	}(indexpath)
+	}()
 
 	// Extract the lock index.
 	index, _ := strconv.Atoi(path.Base(resp.Key))
 
+	// Wait until we successfully get a lock or we get a failure.
+	var success bool
 	for {
 		// Read all indices.
 		resp, err = h.client.GetAll(keypath, true)
 		if err != nil {
 			http.Error(w, "lock children lookup error: " + err.Error(), http.StatusInternalServerError)
-			return
+			break
 		}
 		indices := extractResponseIndices(resp)
 		waitIndex := resp.ModifiedIndex
@@ -63,17 +81,48 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
 
 		// If there is no previous index then we have the lock.
 		if prevIndex == 0 {
+			success = true
 			break
 		}
 
 		// Otherwise watch previous index until it's gone.
-		_, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, nil, nil)
-		if err != nil {
+		_, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, nil, stopWatchChan)
+		if err == etcd.ErrWatchStoppedByUser {
+			break
+		} else if err != nil {
 			http.Error(w, "lock watch error: " + err.Error(), http.StatusInternalServerError)
-			return
+			break
 		}
 	}
 
-	// Write lock index to response body.
-	w.Write([]byte(strconv.Itoa(index)))
+	// Check for connection disconnect before we write the lock index.
+	select {
+	case <-stopWatchChan:
+		success = false
+	default:
+	}
+
+	// Stop the ttl keep-alive.
+	close(stop)
+
+	if success {
+		// Write lock index to response body if we acquire the lock.
+		h.client.Update(indexpath, "-", uint64(ttl))
+		w.Write([]byte(strconv.Itoa(index)))
+	} else {
+		// Make sure key is deleted if we couldn't acquire.
+		h.client.Delete(indexpath)
+	}
+}
+
+// ttlKeepAlive continues to update a key's TTL until the stop channel is closed.
+func (h *handler) ttlKeepAlive(k string, ttl int, stop chan bool) {
+	for {
+		select {
+		case <-time.After(time.Duration(ttl / 2) * time.Second):
+			h.client.Update(k, "-", uint64(ttl))
+		case <-stop:
+			return
+		}
+	}
 }

+ 6 - 4
test.sh

@@ -1,7 +1,9 @@
 #!/bin/sh
 set -e
 
-PKGS="./store ./server ./server/v2/tests ./mod/lock/tests"
+if [ -z "$PKG" ]; then
+    PKG="./store ./server ./server/v2/tests ./mod/lock/tests"
+fi
 
 # Get GOPATH, etc from build
 . ./build
@@ -10,10 +12,10 @@ PKGS="./store ./server ./server/v2/tests ./mod/lock/tests"
 export GOPATH="${PWD}"
 
 # Unit tests
-for PKG in $PKGS
+for i in $PKG
 do
-    go test -i $PKG
-    go test -v $PKG
+    go test -i $i
+    go test -v $i
 done
 
 # Functional tests