peer_hub.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. /*
  2. Copyright 2014 CoreOS Inc.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package etcd
  14. import (
  15. "encoding/json"
  16. "errors"
  17. "fmt"
  18. "io/ioutil"
  19. "net/http"
  20. "net/url"
  21. "path"
  22. "sync"
  23. "github.com/coreos/etcd/raft"
  24. )
  25. var (
  26. errUnknownPeer = errors.New("unknown peer")
  27. )
  28. type peerGetter interface {
  29. peer(id int64) (*peer, error)
  30. }
  31. type peerHub struct {
  32. mu sync.RWMutex
  33. stopped bool
  34. seeds map[string]bool
  35. peers map[int64]*peer
  36. c *http.Client
  37. }
  38. func newPeerHub(seeds []string, c *http.Client) *peerHub {
  39. h := &peerHub{
  40. peers: make(map[int64]*peer),
  41. seeds: make(map[string]bool),
  42. c: c,
  43. }
  44. for _, seed := range seeds {
  45. h.seeds[seed] = true
  46. }
  47. return h
  48. }
  49. func (h *peerHub) getSeeds() map[string]bool {
  50. h.mu.RLock()
  51. defer h.mu.RUnlock()
  52. s := make(map[string]bool)
  53. for k, v := range h.seeds {
  54. s[k] = v
  55. }
  56. return s
  57. }
  58. func (h *peerHub) stop() {
  59. h.mu.Lock()
  60. defer h.mu.Unlock()
  61. h.stopped = true
  62. for _, p := range h.peers {
  63. p.stop()
  64. }
  65. tr := h.c.Transport.(*http.Transport)
  66. tr.CloseIdleConnections()
  67. }
  68. func (h *peerHub) peer(id int64) (*peer, error) {
  69. h.mu.Lock()
  70. defer h.mu.Unlock()
  71. if h.stopped {
  72. return nil, fmt.Errorf("peerHub stopped")
  73. }
  74. if p, ok := h.peers[id]; ok {
  75. return p, nil
  76. }
  77. return nil, fmt.Errorf("peer %d not found", id)
  78. }
  79. func (h *peerHub) add(id int64, rawurl string) (*peer, error) {
  80. u, err := url.Parse(rawurl)
  81. if err != nil {
  82. return nil, err
  83. }
  84. u.Path = raftPrefix
  85. h.mu.Lock()
  86. defer h.mu.Unlock()
  87. if h.stopped {
  88. return nil, fmt.Errorf("peerHub stopped")
  89. }
  90. h.peers[id] = newPeer(u.String(), h.c)
  91. return h.peers[id], nil
  92. }
  93. func (h *peerHub) send(msg raft.Message) error {
  94. if p, err := h.fetch(msg.To); err == nil {
  95. data, err := json.Marshal(msg)
  96. if err != nil {
  97. return err
  98. }
  99. return p.send(data)
  100. }
  101. return errUnknownPeer
  102. }
  103. func (h *peerHub) fetch(nodeId int64) (*peer, error) {
  104. if p, err := h.peer(nodeId); err == nil {
  105. return p, nil
  106. }
  107. for seed := range h.seeds {
  108. if p, err := h.seedFetch(seed, nodeId); err == nil {
  109. return p, nil
  110. }
  111. }
  112. return nil, fmt.Errorf("cannot fetch the address of node %d", nodeId)
  113. }
  114. func (h *peerHub) seedFetch(seedurl string, id int64) (*peer, error) {
  115. u, err := url.Parse(seedurl)
  116. if err != nil {
  117. return nil, fmt.Errorf("cannot parse the url of the given seed")
  118. }
  119. u.Path = path.Join("/raft/cfg", fmt.Sprint(id))
  120. resp, err := h.c.Get(u.String())
  121. if err != nil {
  122. return nil, fmt.Errorf("cannot reach %v", u)
  123. }
  124. defer resp.Body.Close()
  125. if resp.StatusCode != http.StatusOK {
  126. return nil, fmt.Errorf("cannot find node %d via %s", id, seedurl)
  127. }
  128. b, err := ioutil.ReadAll(resp.Body)
  129. if err != nil {
  130. return nil, fmt.Errorf("cannot reach %v", u)
  131. }
  132. return h.add(id, string(b))
  133. }