userspace.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. // Copyright 2016 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package tcpproxy
  15. import (
  16. "fmt"
  17. "io"
  18. "math/rand"
  19. "net"
  20. "sync"
  21. "time"
  22. "github.com/coreos/pkg/capnslog"
  23. )
  24. var (
  25. plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "proxy/tcpproxy")
  26. )
  27. type remote struct {
  28. mu sync.Mutex
  29. srv *net.SRV
  30. addr string
  31. inactive bool
  32. }
  33. func (r *remote) inactivate() {
  34. r.mu.Lock()
  35. defer r.mu.Unlock()
  36. r.inactive = true
  37. }
  38. func (r *remote) tryReactivate() error {
  39. conn, err := net.Dial("tcp", r.addr)
  40. if err != nil {
  41. return err
  42. }
  43. conn.Close()
  44. r.mu.Lock()
  45. defer r.mu.Unlock()
  46. r.inactive = false
  47. return nil
  48. }
  49. func (r *remote) isActive() bool {
  50. r.mu.Lock()
  51. defer r.mu.Unlock()
  52. return !r.inactive
  53. }
  54. type TCPProxy struct {
  55. Listener net.Listener
  56. Endpoints []*net.SRV
  57. MonitorInterval time.Duration
  58. donec chan struct{}
  59. mu sync.Mutex // guards the following fields
  60. remotes []*remote
  61. pickCount int // for round robin
  62. }
  63. func (tp *TCPProxy) Run() error {
  64. tp.donec = make(chan struct{})
  65. if tp.MonitorInterval == 0 {
  66. tp.MonitorInterval = 5 * time.Minute
  67. }
  68. for _, srv := range tp.Endpoints {
  69. addr := fmt.Sprintf("%s:%d", srv.Target, srv.Port)
  70. tp.remotes = append(tp.remotes, &remote{srv: srv, addr: addr})
  71. }
  72. eps := []string{}
  73. for _, ep := range tp.Endpoints {
  74. eps = append(eps, fmt.Sprintf("%s:%d", ep.Target, ep.Port))
  75. }
  76. plog.Printf("ready to proxy client requests to %+v", eps)
  77. go tp.runMonitor()
  78. for {
  79. in, err := tp.Listener.Accept()
  80. if err != nil {
  81. return err
  82. }
  83. go tp.serve(in)
  84. }
  85. }
  86. func (tp *TCPProxy) pick() *remote {
  87. var weighted []*remote
  88. var unweighted []*remote
  89. bestPr := uint16(65535)
  90. w := 0
  91. // find best priority class
  92. for _, r := range tp.remotes {
  93. switch {
  94. case !r.isActive():
  95. case r.srv.Priority < bestPr:
  96. bestPr = r.srv.Priority
  97. w = 0
  98. weighted = nil
  99. unweighted = []*remote{r}
  100. fallthrough
  101. case r.srv.Priority == bestPr:
  102. if r.srv.Weight > 0 {
  103. weighted = append(weighted, r)
  104. w += int(r.srv.Weight)
  105. } else {
  106. unweighted = append(unweighted, r)
  107. }
  108. }
  109. }
  110. if weighted != nil {
  111. if len(unweighted) > 0 && rand.Intn(100) == 1 {
  112. // In the presence of records containing weights greater
  113. // than 0, records with weight 0 should have a very small
  114. // chance of being selected.
  115. r := unweighted[tp.pickCount%len(unweighted)]
  116. tp.pickCount++
  117. return r
  118. }
  119. // choose a uniform random number between 0 and the sum computed
  120. // (inclusive), and select the RR whose running sum value is the
  121. // first in the selected order
  122. choose := rand.Intn(w)
  123. for i := 0; i < len(weighted); i++ {
  124. choose -= int(weighted[i].srv.Weight)
  125. if choose <= 0 {
  126. return weighted[i]
  127. }
  128. }
  129. }
  130. if unweighted != nil {
  131. for i := 0; i < len(tp.remotes); i++ {
  132. picked := tp.remotes[tp.pickCount%len(tp.remotes)]
  133. tp.pickCount++
  134. if picked.isActive() {
  135. return picked
  136. }
  137. }
  138. }
  139. return nil
  140. }
  141. func (tp *TCPProxy) serve(in net.Conn) {
  142. var (
  143. err error
  144. out net.Conn
  145. )
  146. for {
  147. tp.mu.Lock()
  148. remote := tp.pick()
  149. tp.mu.Unlock()
  150. if remote == nil {
  151. break
  152. }
  153. // TODO: add timeout
  154. out, err = net.Dial("tcp", remote.addr)
  155. if err == nil {
  156. break
  157. }
  158. remote.inactivate()
  159. plog.Warningf("deactivated endpoint [%s] due to %v for %v", remote.addr, err, tp.MonitorInterval)
  160. }
  161. if out == nil {
  162. in.Close()
  163. return
  164. }
  165. go func() {
  166. io.Copy(in, out)
  167. in.Close()
  168. out.Close()
  169. }()
  170. io.Copy(out, in)
  171. out.Close()
  172. in.Close()
  173. }
  174. func (tp *TCPProxy) runMonitor() {
  175. for {
  176. select {
  177. case <-time.After(tp.MonitorInterval):
  178. tp.mu.Lock()
  179. for _, rem := range tp.remotes {
  180. if rem.isActive() {
  181. continue
  182. }
  183. go func(r *remote) {
  184. if err := r.tryReactivate(); err != nil {
  185. plog.Warningf("failed to activate endpoint [%s] due to %v (stay inactive for another %v)", r.addr, err, tp.MonitorInterval)
  186. } else {
  187. plog.Printf("activated %s", r.addr)
  188. }
  189. }(rem)
  190. }
  191. tp.mu.Unlock()
  192. case <-tp.donec:
  193. return
  194. }
  195. }
  196. }
  197. func (tp *TCPProxy) Stop() {
  198. // graceful shutdown?
  199. // shutdown current connections?
  200. tp.Listener.Close()
  201. close(tp.donec)
  202. }