Browse Source

clientv3: add load balancer unix socket test

Joe Betz 7 years ago
parent
commit
ed6bc2b554

+ 9 - 6
clientv3/balancer/balancer_test.go

@@ -42,22 +42,25 @@ func TestRoundRobinBalancedResolvableNoFailover(t *testing.T) {
 		name        string
 		serverCount int
 		reqN        int
+		network     string
 	}{
-		{name: "rrBalanced_1", serverCount: 1, reqN: 5},
-		{name: "rrBalanced_3", serverCount: 3, reqN: 7},
-		{name: "rrBalanced_5", serverCount: 5, reqN: 10},
+		{name: "rrBalanced_1", serverCount: 1, reqN: 5, network: "tcp"},
+		{name: "rrBalanced_1_unix_sockets", serverCount: 1, reqN: 5, network: "unix"},
+		{name: "rrBalanced_3", serverCount: 3, reqN: 7, network: "tcp"},
+		{name: "rrBalanced_5", serverCount: 5, reqN: 10, network: "tcp"},
 	}
 
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
-			ms, err := mockserver.StartMockServers(tc.serverCount)
+			ms, err := mockserver.StartMockServersOnNetwork(tc.serverCount, tc.network)
 			if err != nil {
 				t.Fatalf("failed to start mock servers: %v", err)
 			}
 			defer ms.Stop()
+
 			var resolvedAddrs []resolver.Address
 			for _, svr := range ms.Servers {
-				resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: svr.Address})
+				resolvedAddrs = append(resolvedAddrs, svr.ResolverAddress())
 			}
 
 			rsv := endpoint.EndpointResolver("nofailover")
@@ -68,7 +71,7 @@ func TestRoundRobinBalancedResolvableNoFailover(t *testing.T) {
 				Policy:    picker.RoundrobinBalanced,
 				Name:      genName(),
 				Logger:    zap.NewExample(),
-				Endpoints: []string{fmt.Sprintf("etcd://nofailover/mock.server")},
+				Endpoints: []string{fmt.Sprintf("etcd://nofailover/*")},
 			}
 			rrb, err := New(cfg)
 			if err != nil {

+ 13 - 0
clientv3/balancer/resolver/endpoint/endpoint.go

@@ -103,6 +103,19 @@ func (r *Resolver) InitialAddrs(addrs []resolver.Address) {
 	r.bootstrapAddrs = addrs
 }
 
+func (r *Resolver) InitialEndpoints(eps []string) {
+	r.InitialAddrs(epsToAddrs(eps...))
+}
+
+// TODO: use balancer.epsToAddrs
+func epsToAddrs(eps ...string) (addrs []resolver.Address) {
+	addrs = make([]resolver.Address, 0, len(eps))
+	for _, ep := range eps {
+		addrs = append(addrs, resolver.Address{Addr: ep})
+	}
+	return addrs
+}
+
 // NewAddress updates the addresses of the resolver.
 func (r *Resolver) NewAddress(addrs []resolver.Address) error {
 	if r.cc == nil {

+ 14 - 0
clientv3/balancer/utils.go

@@ -1,3 +1,17 @@
+// Copyright 2018 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 package balancer
 
 import (

+ 14 - 0
clientv3/balancer/utils_test.go

@@ -1,3 +1,17 @@
+// Copyright 2018 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 package balancer
 
 import (

+ 62 - 5
pkg/mock/mockserver/mockserver.go

@@ -17,21 +17,36 @@ package mockserver
 import (
 	"context"
 	"fmt"
+	"io/ioutil"
 	"net"
+	"os"
 	"sync"
 
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
 
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/resolver"
 )
 
 // MockServer provides a mocked out grpc server of the etcdserver interface.
 type MockServer struct {
 	ln         net.Listener
+	Network    string
 	Address    string
 	GrpcServer *grpc.Server
 }
 
+func (ms *MockServer) ResolverAddress() resolver.Address {
+	switch ms.Network {
+	case "unix":
+		return resolver.Address{Addr: fmt.Sprintf("unix://%s", ms.Address)}
+	case "tcp":
+		return resolver.Address{Addr: ms.Address}
+	default:
+		panic("illegal network type: " + ms.Network)
+	}
+}
+
 // MockServers provides a cluster of mocket out gprc servers of the etcdserver interface.
 type MockServers struct {
 	mu      sync.RWMutex
@@ -42,8 +57,50 @@ type MockServers struct {
 // StartMockServers creates the desired count of mock servers
 // and starts them.
 func StartMockServers(count int) (ms *MockServers, err error) {
+	return StartMockServersOnNetwork(count, "tcp")
+}
+
+// StartMockServersOnNetwork creates mock servers on either 'tcp' or 'unix' sockets.
+func StartMockServersOnNetwork(count int, network string) (ms *MockServers, err error) {
+	switch network {
+	case "tcp":
+		return startMockServersTcp(count)
+	case "unix":
+		return startMockServersUnix(count)
+	default:
+		return nil, fmt.Errorf("unsupported network type: %s", network)
+	}
+}
+
+func startMockServersTcp(count int) (ms *MockServers, err error) {
+	addrs := make([]string, 0, count)
+	for i := 0; i < count; i++ {
+		addrs = append(addrs, "localhost:0")
+	}
+	return startMockServers("tcp", addrs)
+}
+
+func startMockServersUnix(count int) (ms *MockServers, err error) {
+	dir := os.TempDir()
+	addrs := make([]string, 0, count)
+	for i := 0; i < count; i++ {
+		f, err := ioutil.TempFile(dir, "etcd-unix-so-")
+		if err != nil {
+			return nil, fmt.Errorf("failed to allocate temp file for unix socket: %v", err)
+		}
+		fn := f.Name()
+		err = os.Remove(fn)
+		if err != nil {
+			return nil, fmt.Errorf("failed to remove temp file before creating unix socket: %v", err)
+		}
+		addrs = append(addrs, fn)
+	}
+	return startMockServers("unix", addrs)
+}
+
+func startMockServers(network string, addrs []string) (ms *MockServers, err error) {
 	ms = &MockServers{
-		Servers: make([]*MockServer, count),
+		Servers: make([]*MockServer, len(addrs)),
 		wg:      sync.WaitGroup{},
 	}
 	defer func() {
@@ -51,12 +108,12 @@ func StartMockServers(count int) (ms *MockServers, err error) {
 			ms.Stop()
 		}
 	}()
-	for idx := 0; idx < count; idx++ {
-		ln, err := net.Listen("tcp", "localhost:0")
+	for idx, addr := range addrs {
+		ln, err := net.Listen(network, addr)
 		if err != nil {
 			return nil, fmt.Errorf("failed to listen %v", err)
 		}
-		ms.Servers[idx] = &MockServer{ln: ln, Address: ln.Addr().String()}
+		ms.Servers[idx] = &MockServer{ln: ln, Network: network, Address: ln.Addr().String()}
 		ms.StartAt(idx)
 	}
 	return ms, nil
@@ -68,7 +125,7 @@ func (ms *MockServers) StartAt(idx int) (err error) {
 	defer ms.mu.Unlock()
 
 	if ms.Servers[idx].ln == nil {
-		ms.Servers[idx].ln, err = net.Listen("tcp", ms.Servers[idx].Address)
+		ms.Servers[idx].ln, err = net.Listen(ms.Servers[idx].Network, ms.Servers[idx].Address)
 		if err != nil {
 			return fmt.Errorf("failed to listen %v", err)
 		}

+ 12 - 1
vendor/google.golang.org/grpc/clientconn.go

@@ -24,6 +24,7 @@ import (
 	"math"
 	"net"
 	"reflect"
+	"regexp"
 	"strings"
 	"sync"
 	"time"
@@ -443,7 +444,17 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
 	if cc.dopts.copts.Dialer == nil {
 		cc.dopts.copts.Dialer = newProxyDialer(
 			func(ctx context.Context, addr string) (net.Conn, error) {
-				return dialContext(ctx, "tcp", addr)
+				network := "tcp"
+				p := regexp.MustCompile(`[a-z]+://`)
+				if p.MatchString(addr) {
+					parts := strings.Split(addr, "://")
+					scheme := parts[0]
+					if scheme == "unix" && len(parts) > 1 && len(parts[1]) > 0 {
+						network = "unix"
+						addr = parts[1]
+					}
+				}
+				return dialContext(ctx, network, addr)
 			},
 		)
 	}