acquire_handler.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package v2
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "path"
  7. "strconv"
  8. "time"
  9. "github.com/gorilla/mux"
  10. "github.com/coreos/go-etcd/etcd"
  11. etcdErr "github.com/coreos/etcd/error"
  12. )
  13. // acquireHandler attempts to acquire a lock on the given key.
  14. // The "key" parameter specifies the resource to lock.
  15. // The "value" parameter specifies a value to associate with the lock.
  16. // The "ttl" parameter specifies how long the lock will persist for.
  17. // The "timeout" parameter specifies how long the request should wait for the lock.
  18. func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) error {
  19. h.client.SyncCluster()
  20. // Setup connection watcher.
  21. closeNotifier, _ := w.(http.CloseNotifier)
  22. closeChan := closeNotifier.CloseNotify()
  23. stopChan := make(chan bool)
  24. // Parse the lock "key".
  25. vars := mux.Vars(req)
  26. keypath := path.Join(prefix, vars["key"])
  27. value := req.FormValue("value")
  28. // Parse "timeout" parameter.
  29. var timeout int
  30. var err error
  31. if req.FormValue("timeout") == "" {
  32. timeout = -1
  33. } else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil {
  34. return etcdErr.NewError(etcdErr.EcodeTimeoutNaN, "Acquire", 0)
  35. }
  36. timeout = timeout + 1
  37. // Parse TTL.
  38. ttl, err := strconv.Atoi(req.FormValue("ttl"))
  39. if err != nil {
  40. return etcdErr.NewError(etcdErr.EcodeTTLNaN, "Acquire", 0)
  41. }
  42. // If node exists then just watch it. Otherwise create the node and watch it.
  43. node, index, pos := h.findExistingNode(keypath, value)
  44. if index > 0 {
  45. if pos == 0 {
  46. // If lock is already acquired then update the TTL.
  47. h.client.Update(node.Key, node.Value, uint64(ttl))
  48. } else {
  49. // Otherwise watch until it becomes acquired (or errors).
  50. err = h.watch(keypath, index, nil)
  51. }
  52. } else {
  53. index, err = h.createNode(keypath, value, ttl, closeChan, stopChan)
  54. }
  55. // Stop all goroutines.
  56. close(stopChan)
  57. // Check for an error.
  58. if err != nil {
  59. return err
  60. }
  61. // Write response.
  62. w.Write([]byte(strconv.Itoa(index)))
  63. return nil
  64. }
  65. // createNode creates a new lock node and watches it until it is acquired or acquisition fails.
  66. func (h *handler) createNode(keypath string, value string, ttl int, closeChan <- chan bool, stopChan chan bool) (int, error) {
  67. // Default the value to "-" if it is blank.
  68. if len(value) == 0 {
  69. value = "-"
  70. }
  71. // Create an incrementing id for the lock.
  72. resp, err := h.client.AddChild(keypath, value, uint64(ttl))
  73. if err != nil {
  74. return 0, err
  75. }
  76. indexpath := resp.Node.Key
  77. index, _ := strconv.Atoi(path.Base(indexpath))
  78. // Keep updating TTL to make sure lock request is not expired before acquisition.
  79. go h.ttlKeepAlive(indexpath, value, ttl, stopChan)
  80. // Watch until we acquire or fail.
  81. err = h.watch(keypath, index, closeChan)
  82. // Check for connection disconnect before we write the lock index.
  83. if err != nil {
  84. select {
  85. case <-closeChan:
  86. err = errors.New("user interrupted")
  87. default:
  88. }
  89. }
  90. // Update TTL one last time if acquired. Otherwise delete.
  91. if err == nil {
  92. h.client.Update(indexpath, value, uint64(ttl))
  93. } else {
  94. h.client.Delete(indexpath, false)
  95. }
  96. return index, err
  97. }
  98. // findExistingNode search for a node on the lock with the given value.
  99. func (h *handler) findExistingNode(keypath string, value string) (*etcd.Node, int, int) {
  100. if len(value) > 0 {
  101. resp, err := h.client.Get(keypath, true, true)
  102. if err == nil {
  103. nodes := lockNodes{resp.Node.Nodes}
  104. if node, pos := nodes.FindByValue(value); node != nil {
  105. index, _ := strconv.Atoi(path.Base(node.Key))
  106. return node, index, pos
  107. }
  108. }
  109. }
  110. return nil, 0, 0
  111. }
  112. // ttlKeepAlive continues to update a key's TTL until the stop channel is closed.
  113. func (h *handler) ttlKeepAlive(k string, value string, ttl int, stopChan chan bool) {
  114. for {
  115. select {
  116. case <-time.After(time.Duration(ttl / 2) * time.Second):
  117. h.client.Update(k, value, uint64(ttl))
  118. case <-stopChan:
  119. return
  120. }
  121. }
  122. }
  123. // watch continuously waits for a given lock index to be acquired or until lock fails.
  124. // Returns a boolean indicating success.
  125. func (h *handler) watch(keypath string, index int, closeChan <- chan bool) error {
  126. // Wrap close chan so we can pass it to Client.Watch().
  127. stopWatchChan := make(chan bool)
  128. go func() {
  129. select {
  130. case <- closeChan:
  131. stopWatchChan <- true
  132. case <- stopWatchChan:
  133. }
  134. }()
  135. defer close(stopWatchChan)
  136. for {
  137. // Read all nodes for the lock.
  138. resp, err := h.client.Get(keypath, true, true)
  139. if err != nil {
  140. return fmt.Errorf("lock watch lookup error: %s", err.Error())
  141. }
  142. waitIndex := resp.Node.ModifiedIndex
  143. nodes := lockNodes{resp.Node.Nodes}
  144. prevIndex := nodes.PrevIndex(index)
  145. // If there is no previous index then we have the lock.
  146. if prevIndex == 0 {
  147. return nil
  148. }
  149. // Watch previous index until it's gone.
  150. _, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, false, nil, stopWatchChan)
  151. if err == etcd.ErrWatchStoppedByUser {
  152. return fmt.Errorf("lock watch closed")
  153. } else if err != nil {
  154. return fmt.Errorf("lock watch error: %s", err.Error())
  155. }
  156. }
  157. }