瀏覽代碼

Merge pull request #6560 from gyuho/scheme

clientv3: handle 'https' scheme in endpoint
Gyu-Ho Lee 9 年之前
父節點
當前提交
dd607b5eff
共有 2 個文件被更改,包括 78 次插入6 次删除
  1. 6 6
      clientv3/client.go
  2. 72 0
      clientv3/integration/dial_test.go

+ 6 - 6
clientv3/client.go

@@ -151,14 +151,14 @@ func (cred authTokenCredential) GetRequestMetadata(ctx context.Context, s ...str
 	}, nil
 }
 
-func parseEndpoint(endpoint string) (proto string, host string, scheme bool) {
+func parseEndpoint(endpoint string) (proto string, host string, scheme string) {
 	proto = "tcp"
 	host = endpoint
 	url, uerr := url.Parse(endpoint)
 	if uerr != nil || !strings.Contains(endpoint, "://") {
 		return
 	}
-	scheme = true
+	scheme = url.Scheme
 
 	// strip scheme:// prefix since grpc dials by host
 	host = url.Host
@@ -172,9 +172,9 @@ func parseEndpoint(endpoint string) (proto string, host string, scheme bool) {
 	return
 }
 
-func (c *Client) processCreds(protocol string) (creds *credentials.TransportCredentials) {
+func (c *Client) processCreds(scheme string) (creds *credentials.TransportCredentials) {
 	creds = c.creds
-	switch protocol {
+	switch scheme {
 	case "unix":
 	case "http":
 		creds = nil
@@ -213,8 +213,8 @@ func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) (opts
 	opts = append(opts, grpc.WithDialer(f))
 
 	creds := c.creds
-	if proto, _, scheme := parseEndpoint(endpoint); scheme {
-		creds = c.processCreds(proto)
+	if _, _, scheme := parseEndpoint(endpoint); len(scheme) != 0 {
+		creds = c.processCreds(scheme)
 	}
 	if creds != nil {
 		opts = append(opts, grpc.WithTransportCredentials(*creds))

+ 72 - 0
clientv3/integration/dial_test.go

@@ -15,11 +15,17 @@
 package integration
 
 import (
+	"fmt"
+	"io/ioutil"
 	"math/rand"
+	"net/url"
+	"os"
+	"sync"
 	"testing"
 	"time"
 
 	"github.com/coreos/etcd/clientv3"
+	"github.com/coreos/etcd/embed"
 	"github.com/coreos/etcd/integration"
 	"github.com/coreos/etcd/pkg/testutil"
 	"golang.org/x/net/context"
@@ -58,3 +64,69 @@ func TestDialSetEndpoints(t *testing.T) {
 	}
 	cancel()
 }
+
+var (
+	testMu   sync.Mutex
+	testPort = 31000
+)
+
+// TestDialWithHTTPS ensures that client can handle 'https' scheme in endpoints.
+func TestDialWithHTTPS(t *testing.T) {
+	defer testutil.AfterTest(t)
+
+	testMu.Lock()
+	port := testPort
+	testPort += 10 // to avoid port conflicts
+	testMu.Unlock()
+
+	dir, err := ioutil.TempDir(os.TempDir(), "dial-test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(dir)
+
+	// set up single-node cluster with client auto TLS
+	cfg := embed.NewConfig()
+	cfg.Dir = dir
+
+	cfg.ClientAutoTLS = true
+	clientURL := url.URL{Scheme: "https", Host: fmt.Sprintf("localhost:%d", port)}
+	cfg.LCUrls, cfg.ACUrls = []url.URL{clientURL}, []url.URL{clientURL}
+
+	peerURL := url.URL{Scheme: "http", Host: fmt.Sprintf("localhost:%d", port+1)}
+	cfg.LPUrls, cfg.APUrls = []url.URL{peerURL}, []url.URL{peerURL}
+	cfg.InitialCluster = cfg.Name + "=" + peerURL.String()
+
+	srv, err := embed.StartEtcd(cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	nc := srv.Config() // overwrite config after processing ClientTLSInfo
+	cfg = &nc
+
+	<-srv.Server.ReadyNotify()
+	defer func() {
+		srv.Close()
+		<-srv.Err()
+	}()
+
+	// wait for leader election to finish
+	time.Sleep(500 * time.Millisecond)
+
+	ccfg := clientv3.Config{Endpoints: []string{clientURL.String()}}
+	tcfg, err := cfg.ClientTLSInfo.ClientConfig()
+	if err != nil {
+		t.Fatal(err)
+	}
+	ccfg.TLS = tcfg
+
+	cli, err := clientv3.New(ccfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+
+	if _, err = cli.Get(context.Background(), "foo"); err != nil {
+		t.Fatal(err)
+	}
+}