acquire_handler.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. package v2
  2. import (
  3. "net/http"
  4. "path"
  5. "strconv"
  6. "time"
  7. "github.com/coreos/go-etcd/etcd"
  8. "github.com/gorilla/mux"
  9. )
  10. // acquireHandler attempts to acquire a lock on the given key.
  11. // The "key" parameter specifies the resource to lock.
  12. // The "ttl" parameter specifies how long the lock will persist for.
  13. // The "timeout" parameter specifies how long the request should wait for the lock.
  14. func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
  15. h.client.SyncCluster()
  16. // Setup connection watcher.
  17. closeNotifier, _ := w.(http.CloseNotifier)
  18. closeChan := closeNotifier.CloseNotify()
  19. // Parse "key" and "ttl" query parameters.
  20. vars := mux.Vars(req)
  21. keypath := path.Join(prefix, vars["key"])
  22. ttl, err := strconv.Atoi(req.FormValue("ttl"))
  23. if err != nil {
  24. http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError)
  25. return
  26. }
  27. // Parse "timeout" parameter.
  28. var timeout int
  29. if len(req.FormValue("timeout")) == 0 {
  30. timeout = -1
  31. } else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil {
  32. http.Error(w, "invalid timeout: " + err.Error(), http.StatusInternalServerError)
  33. return
  34. }
  35. timeout = timeout + 1
  36. // Create an incrementing id for the lock.
  37. resp, err := h.client.AddChild(keypath, "-", uint64(ttl))
  38. if err != nil {
  39. http.Error(w, "add lock index error: " + err.Error(), http.StatusInternalServerError)
  40. return
  41. }
  42. indexpath := resp.Key
  43. // Keep updating TTL to make sure lock request is not expired before acquisition.
  44. stop := make(chan bool)
  45. go h.ttlKeepAlive(indexpath, ttl, stop)
  46. // Monitor for broken connection.
  47. stopWatchChan := make(chan bool)
  48. go func() {
  49. select {
  50. case <-closeChan:
  51. stopWatchChan <- true
  52. case <-stop:
  53. // Stop watching for connection disconnect.
  54. }
  55. }()
  56. // Extract the lock index.
  57. index, _ := strconv.Atoi(path.Base(resp.Key))
  58. // Wait until we successfully get a lock or we get a failure.
  59. var success bool
  60. for {
  61. // Read all indices.
  62. resp, err = h.client.GetAll(keypath, true)
  63. if err != nil {
  64. http.Error(w, "lock children lookup error: " + err.Error(), http.StatusInternalServerError)
  65. break
  66. }
  67. indices := extractResponseIndices(resp)
  68. waitIndex := resp.ModifiedIndex
  69. prevIndex := findPrevIndex(indices, index)
  70. // If there is no previous index then we have the lock.
  71. if prevIndex == 0 {
  72. success = true
  73. break
  74. }
  75. // Otherwise watch previous index until it's gone.
  76. _, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, nil, stopWatchChan)
  77. if err == etcd.ErrWatchStoppedByUser {
  78. break
  79. } else if err != nil {
  80. http.Error(w, "lock watch error: " + err.Error(), http.StatusInternalServerError)
  81. break
  82. }
  83. }
  84. // Check for connection disconnect before we write the lock index.
  85. select {
  86. case <-stopWatchChan:
  87. success = false
  88. default:
  89. }
  90. // Stop the ttl keep-alive.
  91. close(stop)
  92. if success {
  93. // Write lock index to response body if we acquire the lock.
  94. h.client.Update(indexpath, "-", uint64(ttl))
  95. w.Write([]byte(strconv.Itoa(index)))
  96. } else {
  97. // Make sure key is deleted if we couldn't acquire.
  98. h.client.Delete(indexpath)
  99. }
  100. }
  101. // ttlKeepAlive continues to update a key's TTL until the stop channel is closed.
  102. func (h *handler) ttlKeepAlive(k string, ttl int, stop chan bool) {
  103. for {
  104. select {
  105. case <-time.After(time.Duration(ttl / 2) * time.Second):
  106. h.client.Update(k, "-", uint64(ttl))
  107. case <-stop:
  108. return
  109. }
  110. }
  111. }