Quellcode durchsuchen

etcdmain, tcpproxy: srv-priority policy

Adds DNS SRV weighting and priorities to gateway.

Partially addresses #4378
Anthony Romano vor 8 Jahren
Ursprung
Commit
c232814003
5 geänderte Dateien mit 115 neuen und 40 gelöschten Zeilen
  1. +19 -8
      etcdmain/gateway.go
  2. +3 -2
      etcdmain/grpc_proxy.go
  3. +19 -5
      etcdmain/util.go
  4. +71 -24
      proxy/tcpproxy/userspace.go
  5. +3 -1
      proxy/tcpproxy/userspace_test.go

+ 19 - 8
etcdmain/gateway.go

@@ -91,17 +91,28 @@ func stripSchema(eps []string) []string {
 
 	return endpoints
 }
-func startGateway(cmd *cobra.Command, args []string) {
-	endpoints := gatewayEndpoints
 
-	if eps := discoverEndpoints(gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery); len(eps) != 0 {
-		endpoints = eps
+func startGateway(cmd *cobra.Command, args []string) {
+	srvs := discoverEndpoints(gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery)
+	if len(srvs.Endpoints) == 0 {
+		// no endpoints discovered, fall back to provided endpoints
+		srvs.Endpoints = gatewayEndpoints
 	}
-
 	// Strip the schema from the endpoints because we start just a TCP proxy
-	endpoints = stripSchema(endpoints)
+	srvs.Endpoints = stripSchema(srvs.Endpoints)
+	if len(srvs.SRVs) == 0 {
+		for _, ep := range srvs.Endpoints {
+			h, p, err := net.SplitHostPort(ep)
+			if err != nil {
+				plog.Fatalf("error parsing endpoint %q", ep)
+			}
+			var port uint16
+			fmt.Sscanf(p, "%d", &port)
+			srvs.SRVs = append(srvs.SRVs, &net.SRV{Target: h, Port: port})
+		}
+	}
 
-	if len(endpoints) == 0 {
+	if len(srvs.Endpoints) == 0 {
 		plog.Fatalf("no endpoints found")
 	}
 
@@ -113,7 +124,7 @@ func startGateway(cmd *cobra.Command, args []string) {
 
 	tp := tcpproxy.TCPProxy{
 		Listener:        l,
-		Endpoints:       endpoints,
+		Endpoints:       srvs.SRVs,
 		MonitorInterval: getewayRetryDelay,
 	}
 

+ 3 - 2
etcdmain/grpc_proxy.go

@@ -106,8 +106,9 @@ func startGRPCProxy(cmd *cobra.Command, args []string) {
 		os.Exit(1)
 	}
 
-	if eps := discoverEndpoints(grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery); len(eps) != 0 {
-		grpcProxyEndpoints = eps
+	srvs := discoverEndpoints(grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery)
+	if len(srvs.Endpoints) != 0 {
+		grpcProxyEndpoints = srvs.Endpoints
 	}
 
 	l, err := net.Listen("tcp", grpcProxyListenAddr)

+ 19 - 5
etcdmain/util.go

@@ -22,19 +22,19 @@ import (
 	"github.com/coreos/etcd/pkg/transport"
 )
 
-func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string) {
+func discoverEndpoints(dns string, ca string, insecure bool) (s srv.SRVClients) {
 	if dns == "" {
-		return nil
+		return s
 	}
 	srvs, err := srv.GetClient("etcd-client", dns)
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err)
 		os.Exit(1)
 	}
-	endpoints = srvs.Endpoints
+	endpoints := srvs.Endpoints
 	plog.Infof("discovered the cluster %s from %s", endpoints, dns)
 	if insecure {
-		return endpoints
+		return *srvs
 	}
 	// confirm TLS connections are good
 	tlsInfo := transport.TLSInfo{
@@ -47,5 +47,19 @@ func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string
 		plog.Warningf("%v", err)
 	}
 	plog.Infof("using discovered endpoints %v", endpoints)
-	return endpoints
+
+	// map endpoints back to SRVClients struct with SRV data
+	eps := make(map[string]struct{})
+	for _, ep := range endpoints {
+		eps[ep] = struct{}{}
+	}
+	for i := range srvs.Endpoints {
+		if _, ok := eps[srvs.Endpoints[i]]; !ok {
+			continue
+		}
+		s.Endpoints = append(s.Endpoints, srvs.Endpoints[i])
+		s.SRVs = append(s.SRVs, srvs.SRVs[i])
+	}
+
+	return s
 }

+ 71 - 24
proxy/tcpproxy/userspace.go

@@ -15,7 +15,9 @@
 package tcpproxy
 
 import (
+	"fmt"
 	"io"
+	"math/rand"
 	"net"
 	"sync"
 	"time"
@@ -29,6 +31,7 @@ var (
 
 type remote struct {
 	mu       sync.Mutex
+	srv      *net.SRV
 	addr     string
 	inactive bool
 }
@@ -59,14 +62,14 @@ func (r *remote) isActive() bool {
 
 type TCPProxy struct {
 	Listener        net.Listener
-	Endpoints       []string
+	Endpoints       []*net.SRV
 	MonitorInterval time.Duration
 
 	donec chan struct{}
 
-	mu         sync.Mutex // guards the following fields
-	remotes    []*remote
-	nextRemote int
+	mu        sync.Mutex // guards the following fields
+	remotes   []*remote
+	pickCount int // for round robin
 }
 
 func (tp *TCPProxy) Run() error {
@@ -74,11 +77,12 @@ func (tp *TCPProxy) Run() error {
 	if tp.MonitorInterval == 0 {
 		tp.MonitorInterval = 5 * time.Minute
 	}
-	for _, ep := range tp.Endpoints {
-		tp.remotes = append(tp.remotes, &remote{addr: ep})
+	for _, srv := range tp.Endpoints {
+		addr := fmt.Sprintf("%s:%d", srv.Target, srv.Port)
+		tp.remotes = append(tp.remotes, &remote{srv: srv, addr: addr})
 	}
 
-	plog.Printf("ready to proxy client requests to %v", tp.Endpoints)
+	plog.Printf("ready to proxy client requests to %+v", tp.Endpoints)
 	go tp.runMonitor()
 	for {
 		in, err := tp.Listener.Accept()
@@ -90,10 +94,61 @@ func (tp *TCPProxy) Run() error {
 	}
 }
 
-func (tp *TCPProxy) numRemotes() int {
-	tp.mu.Lock()
-	defer tp.mu.Unlock()
-	return len(tp.remotes)
+func (tp *TCPProxy) pick() *remote {
+	var weighted []*remote
+	var unweighted []*remote
+
+	bestPr := uint16(65535)
+	w := 0
+	// find best priority class
+	for _, r := range tp.remotes {
+		switch {
+		case !r.isActive():
+		case r.srv.Priority < bestPr:
+			bestPr = r.srv.Priority
+			w = 0
+			weighted, unweighted = nil, nil
+			unweighted = []*remote{r}
+			fallthrough
+		case r.srv.Priority == bestPr:
+			if r.srv.Weight > 0 {
+				weighted = append(weighted, r)
+				w += int(r.srv.Weight)
+			} else {
+				unweighted = append(unweighted, r)
+			}
+		}
+	}
+	if weighted != nil {
+		if len(unweighted) > 0 && rand.Intn(100) == 1 {
+			// In the presence of records containing weights greater
+			// than 0, records with weight 0 should have a very small
+			// chance of being selected.
+			r := unweighted[tp.pickCount%len(unweighted)]
+			tp.pickCount++
+			return r
+		}
+		// choose a uniform random number between 0 and the sum computed
+		// (inclusive), and select the RR whose running sum value is the
+		// first in the selected order
+		choose := rand.Intn(w)
+		for i := 0; i < len(weighted); i++ {
+			choose -= int(weighted[i].srv.Weight)
+			if choose <= 0 {
+				return weighted[i]
+			}
+		}
+	}
+	if unweighted != nil {
+		for i := 0; i < len(tp.remotes); i++ {
+			picked := tp.remotes[tp.pickCount%len(tp.remotes)]
+			tp.pickCount++
+			if picked.isActive() {
+				return picked
+			}
+		}
+	}
+	return nil
 }
 
 func (tp *TCPProxy) serve(in net.Conn) {
@@ -102,10 +157,12 @@ func (tp *TCPProxy) serve(in net.Conn) {
 		out net.Conn
 	)
 
-	for i := 0; i < tp.numRemotes(); i++ {
+	for {
+		tp.mu.Lock()
 		remote := tp.pick()
-		if !remote.isActive() {
-			continue
+		tp.mu.Unlock()
+		if remote == nil {
+			break
 		}
 		// TODO: add timeout
 		out, err = net.Dial("tcp", remote.addr)
@@ -132,16 +189,6 @@ func (tp *TCPProxy) serve(in net.Conn) {
 	in.Close()
 }
 
-// pick picks a remote in round-robin fashion
-func (tp *TCPProxy) pick() *remote {
-	tp.mu.Lock()
-	defer tp.mu.Unlock()
-
-	picked := tp.remotes[tp.nextRemote]
-	tp.nextRemote = (tp.nextRemote + 1) % len(tp.remotes)
-	return picked
-}
-
 func (tp *TCPProxy) runMonitor() {
 	for {
 		select {

+ 3 - 1
proxy/tcpproxy/userspace_test.go

@@ -42,9 +42,11 @@ func TestUserspaceProxy(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	var port uint16
+	fmt.Sscanf(u.Port(), "%d", &port)
 	p := TCPProxy{
 		Listener:  l,
-		Endpoints: []string{u.Host},
+		Endpoints: []*net.SRV{{Target: u.Hostname(), Port: port}},
 	}
 	go p.Run()
 	defer p.Stop()