acquire_handler.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package v2
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "path"
  7. "strconv"
  8. "time"
  9. "github.com/coreos/go-etcd/etcd"
  10. "github.com/gorilla/mux"
  11. )
  12. // acquireHandler attempts to acquire a lock on the given key.
  13. // The "key" parameter specifies the resource to lock.
  14. // The "value" parameter specifies a value to associate with the lock.
  15. // The "ttl" parameter specifies how long the lock will persist for.
  16. // The "timeout" parameter specifies how long the request should wait for the lock.
  17. func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
  18. h.client.SyncCluster()
  19. // Setup connection watcher.
  20. closeNotifier, _ := w.(http.CloseNotifier)
  21. closeChan := closeNotifier.CloseNotify()
  22. stopChan := make(chan bool)
  23. // Parse the lock "key".
  24. vars := mux.Vars(req)
  25. keypath := path.Join(prefix, vars["key"])
  26. value := req.FormValue("value")
  27. // Parse "timeout" parameter.
  28. var timeout int
  29. var err error
  30. if req.FormValue("timeout") == "" {
  31. timeout = -1
  32. } else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil {
  33. http.Error(w, "invalid timeout: " + req.FormValue("timeout"), http.StatusInternalServerError)
  34. return
  35. }
  36. timeout = timeout + 1
  37. // Parse TTL.
  38. ttl, err := strconv.Atoi(req.FormValue("ttl"))
  39. if err != nil {
  40. http.Error(w, "invalid ttl: " + req.FormValue("ttl"), http.StatusInternalServerError)
  41. return
  42. }
  43. // If node exists then just watch it. Otherwise create the node and watch it.
  44. index := h.findExistingNode(keypath, value)
  45. if index > 0 {
  46. err = h.watch(keypath, index, nil)
  47. } else {
  48. index, err = h.createNode(keypath, value, ttl, closeChan, stopChan)
  49. }
  50. // Stop all goroutines.
  51. close(stopChan)
  52. // Write response.
  53. if err != nil {
  54. http.Error(w, err.Error(), http.StatusInternalServerError)
  55. } else {
  56. w.Write([]byte(strconv.Itoa(index)))
  57. }
  58. }
  59. // createNode creates a new lock node and watches it until it is acquired or acquisition fails.
  60. func (h *handler) createNode(keypath string, value string, ttl int, closeChan <- chan bool, stopChan chan bool) (int, error) {
  61. // Default the value to "-" if it is blank.
  62. if len(value) == 0 {
  63. value = "-"
  64. }
  65. // Create an incrementing id for the lock.
  66. resp, err := h.client.AddChild(keypath, value, uint64(ttl))
  67. if err != nil {
  68. return 0, errors.New("acquire lock index error: " + err.Error())
  69. }
  70. indexpath := resp.Node.Key
  71. index, _ := strconv.Atoi(path.Base(indexpath))
  72. // Keep updating TTL to make sure lock request is not expired before acquisition.
  73. go h.ttlKeepAlive(indexpath, value, ttl, stopChan)
  74. // Watch until we acquire or fail.
  75. err = h.watch(keypath, index, closeChan)
  76. // Check for connection disconnect before we write the lock index.
  77. if err != nil {
  78. select {
  79. case <-closeChan:
  80. err = errors.New("acquire lock error: user interrupted")
  81. default:
  82. }
  83. }
  84. // Update TTL one last time if acquired. Otherwise delete.
  85. if err == nil {
  86. h.client.Update(indexpath, value, uint64(ttl))
  87. } else {
  88. h.client.Delete(indexpath, false)
  89. }
  90. return index, err
  91. }
  92. // findExistingNode search for a node on the lock with the given value.
  93. func (h *handler) findExistingNode(keypath string, value string) int {
  94. if len(value) > 0 {
  95. resp, err := h.client.Get(keypath, true, true)
  96. if err == nil {
  97. nodes := lockNodes{resp.Node.Nodes}
  98. if node := nodes.FindByValue(value); node != nil {
  99. index, _ := strconv.Atoi(path.Base(node.Key))
  100. return index
  101. }
  102. }
  103. }
  104. return 0
  105. }
  106. // ttlKeepAlive continues to update a key's TTL until the stop channel is closed.
  107. func (h *handler) ttlKeepAlive(k string, value string, ttl int, stopChan chan bool) {
  108. for {
  109. select {
  110. case <-time.After(time.Duration(ttl / 2) * time.Second):
  111. h.client.Update(k, value, uint64(ttl))
  112. case <-stopChan:
  113. return
  114. }
  115. }
  116. }
  117. // watch continuously waits for a given lock index to be acquired or until lock fails.
  118. // Returns a boolean indicating success.
  119. func (h *handler) watch(keypath string, index int, closeChan <- chan bool) error {
  120. // Wrap close chan so we can pass it to Client.Watch().
  121. stopWatchChan := make(chan bool)
  122. go func() {
  123. select {
  124. case <- closeChan:
  125. stopWatchChan <- true
  126. case <- stopWatchChan:
  127. }
  128. }()
  129. defer close(stopWatchChan)
  130. for {
  131. // Read all nodes for the lock.
  132. resp, err := h.client.Get(keypath, true, true)
  133. if err != nil {
  134. return fmt.Errorf("lock watch lookup error: %s", err.Error())
  135. }
  136. waitIndex := resp.Node.ModifiedIndex
  137. nodes := lockNodes{resp.Node.Nodes}
  138. prevIndex := nodes.PrevIndex(index)
  139. // If there is no previous index then we have the lock.
  140. if prevIndex == 0 {
  141. return nil
  142. }
  143. // Watch previous index until it's gone.
  144. _, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, false, nil, stopWatchChan)
  145. if err == etcd.ErrWatchStoppedByUser {
  146. return fmt.Errorf("lock watch closed")
  147. } else if err != nil {
  148. return fmt.Errorf("lock watch error:%s", err.Error())
  149. }
  150. }
  151. }