Browse Source

Merge pull request #1639 from bcwaldon/etcdctl-tls

Wire up TLS flags for etcdctl
Brian Waldon 11 years ago
parent
commit
5f6e536be8

+ 12 - 5
client/http.go

@@ -40,6 +40,7 @@ var (
 type SyncableHTTPClient interface {
 type SyncableHTTPClient interface {
 	HTTPClient
 	HTTPClient
 	Sync(context.Context) error
 	Sync(context.Context) error
+	Endpoints() []string
 }
 }
 
 
 type HTTPClient interface {
 type HTTPClient interface {
@@ -65,7 +66,8 @@ func NewHTTPClient(tr CancelableTransport, eps []string) (SyncableHTTPClient, er
 func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterClient, error) {
 func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterClient, error) {
 	c := httpClusterClient{
 	c := httpClusterClient{
 		transport: tr,
 		transport: tr,
-		endpoints: make([]HTTPClient, len(eps)),
+		endpoints: eps,
+		clients:   make([]HTTPClient, len(eps)),
 	}
 	}
 
 
 	for i, ep := range eps {
 	for i, ep := range eps {
@@ -74,7 +76,7 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli
 			return nil, err
 			return nil, err
 		}
 		}
 
 
-		c.endpoints[i] = &redirectFollowingHTTPClient{
+		c.clients[i] = &redirectFollowingHTTPClient{
 			max: DefaultMaxRedirects,
 			max: DefaultMaxRedirects,
 			client: &httpClient{
 			client: &httpClient{
 				transport: tr,
 				transport: tr,
@@ -88,14 +90,15 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli
 
 
 type httpClusterClient struct {
 type httpClusterClient struct {
 	transport CancelableTransport
 	transport CancelableTransport
-	endpoints []HTTPClient
+	endpoints []string
+	clients   []HTTPClient
 }
 }
 
 
 func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.Response, body []byte, err error) {
 func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.Response, body []byte, err error) {
-	if len(c.endpoints) == 0 {
+	if len(c.clients) == 0 {
 		return nil, nil, ErrNoEndpoints
 		return nil, nil, ErrNoEndpoints
 	}
 	}
-	for _, hc := range c.endpoints {
+	for _, hc := range c.clients {
 		resp, body, err = hc.Do(ctx, act)
 		resp, body, err = hc.Do(ctx, act)
 		if err != nil {
 		if err != nil {
 			if err == ErrTimeout || err == ErrCanceled {
 			if err == ErrTimeout || err == ErrCanceled {
@@ -111,6 +114,10 @@ func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.
 	return
 	return
 }
 }
 
 
+func (c *httpClusterClient) Endpoints() []string {
+	return c.endpoints
+}
+
 func (c *httpClusterClient) Sync(ctx context.Context) error {
 func (c *httpClusterClient) Sync(ctx context.Context) error {
 	mAPI := NewMembersAPI(c)
 	mAPI := NewMembersAPI(c)
 	ms, err := mAPI.List(ctx)
 	ms, err := mAPI.List(ctx)

+ 7 - 7
client/http_test.go

@@ -193,7 +193,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// first good response short-circuits Do
 		// first good response short-circuits Do
 		{
 		{
 			client: &httpClusterClient{
 			client: &httpClusterClient{
-				endpoints: []HTTPClient{
+				clients: []HTTPClient{
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
 					&staticHTTPClient{err: fakeErr},
 					&staticHTTPClient{err: fakeErr},
 				},
 				},
@@ -204,7 +204,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// fall through to good endpoint if err is arbitrary
 		// fall through to good endpoint if err is arbitrary
 		{
 		{
 			client: &httpClusterClient{
 			client: &httpClusterClient{
-				endpoints: []HTTPClient{
+				clients: []HTTPClient{
 					&staticHTTPClient{err: fakeErr},
 					&staticHTTPClient{err: fakeErr},
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
 				},
 				},
@@ -215,7 +215,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// ErrTimeout short-circuits Do
 		// ErrTimeout short-circuits Do
 		{
 		{
 			client: &httpClusterClient{
 			client: &httpClusterClient{
-				endpoints: []HTTPClient{
+				clients: []HTTPClient{
 					&staticHTTPClient{err: ErrTimeout},
 					&staticHTTPClient{err: ErrTimeout},
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
 				},
 				},
@@ -226,7 +226,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// ErrCanceled short-circuits Do
 		// ErrCanceled short-circuits Do
 		{
 		{
 			client: &httpClusterClient{
 			client: &httpClusterClient{
-				endpoints: []HTTPClient{
+				clients: []HTTPClient{
 					&staticHTTPClient{err: ErrCanceled},
 					&staticHTTPClient{err: ErrCanceled},
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
 				},
 				},
@@ -237,7 +237,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// return err if there are no endpoints
 		// return err if there are no endpoints
 		{
 		{
 			client: &httpClusterClient{
 			client: &httpClusterClient{
-				endpoints: []HTTPClient{},
+				clients: []HTTPClient{},
 			},
 			},
 			wantErr: ErrNoEndpoints,
 			wantErr: ErrNoEndpoints,
 		},
 		},
@@ -245,7 +245,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// return err if all endpoints return arbitrary errors
 		// return err if all endpoints return arbitrary errors
 		{
 		{
 			client: &httpClusterClient{
 			client: &httpClusterClient{
-				endpoints: []HTTPClient{
+				clients: []HTTPClient{
 					&staticHTTPClient{err: fakeErr},
 					&staticHTTPClient{err: fakeErr},
 					&staticHTTPClient{err: fakeErr},
 					&staticHTTPClient{err: fakeErr},
 				},
 				},
@@ -256,7 +256,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// 500-level errors cause Do to fallthrough to next endpoint
 		// 500-level errors cause Do to fallthrough to next endpoint
 		{
 		{
 			client: &httpClusterClient{
 			client: &httpClusterClient{
-				endpoints: []HTTPClient{
+				clients: []HTTPClient{
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusBadGateway}},
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusBadGateway}},
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
 					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
 				},
 				},

+ 13 - 51
etcdctl/command/handle.go

@@ -20,7 +20,6 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"net/url"
 	"os"
 	"os"
 	"strings"
 	"strings"
 
 
@@ -40,72 +39,35 @@ func dumpCURL(client *etcd.Client) {
 	}
 	}
 }
 }
 
 
-// createHttpPath attaches http scheme to the given address if needed
-func createHttpPath(addr string) (string, error) {
-	u, err := url.Parse(addr)
-	if err != nil {
-		return "", err
-	}
-
-	if u.Scheme == "" {
-		u.Scheme = "http"
-	}
-	return u.String(), nil
-}
-
-func getPeersFlagValue(c *cli.Context) []string {
-	peerstr := c.GlobalString("peers")
-
-	// Use an environment variable if nothing was supplied on the
-	// command line
-	if peerstr == "" {
-		peerstr = os.Getenv("ETCDCTL_PEERS")
-	}
-
-	// If we still don't have peers, use a default
-	if peerstr == "" {
-		peerstr = "127.0.0.1:4001"
-	}
-
-	return strings.Split(peerstr, ",")
-}
-
 // rawhandle wraps the command function handlers and sets up the
 // rawhandle wraps the command function handlers and sets up the
 // environment but performs no output formatting.
 // environment but performs no output formatting.
 func rawhandle(c *cli.Context, fn handlerFunc) (*etcd.Response, error) {
 func rawhandle(c *cli.Context, fn handlerFunc) (*etcd.Response, error) {
-	sync := !c.GlobalBool("no-sync")
-
-	peers := getPeersFlagValue(c)
-
-	// If no sync, create http path for each peer address
-	if !sync {
-		revisedPeers := make([]string, 0)
-		for _, peer := range peers {
-			if revisedPeer, err := createHttpPath(peer); err != nil {
-				fmt.Fprintf(os.Stderr, "Unsupported url %v: %v\n", peer, err)
-			} else {
-				revisedPeers = append(revisedPeers, revisedPeer)
-			}
-		}
-		peers = revisedPeers
+	endpoints, err := getEndpoints(c)
+	if err != nil {
+		return nil, err
+	}
+
+	tr, err := getTransport(c)
+	if err != nil {
+		return nil, err
 	}
 	}
 
 
-	client := etcd.NewClient(peers)
+	client := etcd.NewClient(endpoints)
+	client.SetTransport(tr)
 
 
 	if c.GlobalBool("debug") {
 	if c.GlobalBool("debug") {
 		go dumpCURL(client)
 		go dumpCURL(client)
 	}
 	}
 
 
 	// Sync cluster.
 	// Sync cluster.
-	if sync {
+	if !c.GlobalBool("no-sync") {
 		if ok := client.SyncCluster(); !ok {
 		if ok := client.SyncCluster(); !ok {
-			handleError(FailedToConnectToHost, errors.New("Cannot sync with the cluster using peers "+strings.Join(peers, ", ")))
+			handleError(FailedToConnectToHost, errors.New("cannot sync with the cluster using endpoints "+strings.Join(endpoints, ", ")))
 		}
 		}
 	}
 	}
 
 
 	if c.GlobalBool("debug") {
 	if c.GlobalBool("debug") {
-		fmt.Fprintf(os.Stderr, "Cluster-Peers: %s\n",
-			strings.Join(client.GetCluster(), " "))
+		fmt.Fprintf(os.Stderr, "Cluster-Endpoints: %s\n", strings.Join(client.GetCluster(), ", "))
 	}
 	}
 
 
 	// Execute handler function.
 	// Execute handler function.

+ 15 - 7
etcdctl/command/member_commands.go

@@ -18,7 +18,6 @@ package command
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"net/http"
 	"os"
 	"os"
 	"strings"
 	"strings"
 
 
@@ -52,14 +51,19 @@ func NewMemberCommand() cli.Command {
 }
 }
 
 
 func mustNewMembersAPI(c *cli.Context) client.MembersAPI {
 func mustNewMembersAPI(c *cli.Context) client.MembersAPI {
-	peers := getPeersFlagValue(c)
-	for i, p := range peers {
-		if !strings.HasPrefix(p, "http") && !strings.HasPrefix(p, "https") {
-			peers[i] = fmt.Sprintf("http://%s", p)
-		}
+	eps, err := getEndpoints(c)
+	if err != nil {
+		fmt.Fprintln(os.Stderr, err.Error())
+		os.Exit(1)
 	}
 	}
 
 
-	hc, err := client.NewHTTPClient(&http.Transport{}, peers)
+	tr, err := getTransport(c)
+	if err != nil {
+		fmt.Fprintln(os.Stderr, err.Error())
+		os.Exit(1)
+	}
+
+	hc, err := client.NewHTTPClient(tr, eps)
 	if err != nil {
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err.Error())
 		fmt.Fprintln(os.Stderr, err.Error())
 		os.Exit(1)
 		os.Exit(1)
@@ -75,6 +79,10 @@ func mustNewMembersAPI(c *cli.Context) client.MembersAPI {
 		}
 		}
 	}
 	}
 
 
+	if c.GlobalBool("debug") {
+		fmt.Fprintf(os.Stderr, "Cluster-Endpoints: %s\n", strings.Join(hc.Endpoints(), ", "))
+	}
+
 	return client.NewMembersAPI(hc)
 	return client.NewMembersAPI(hc)
 }
 }
 
 

+ 50 - 0
etcdctl/command/util.go

@@ -20,7 +20,13 @@ import (
 	"errors"
 	"errors"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
+	"net/http"
+	"net/url"
+	"os"
 	"strings"
 	"strings"
+
+	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
+	"github.com/coreos/etcd/pkg/transport"
 )
 )
 
 
 var (
 var (
@@ -49,3 +55,47 @@ func argOrStdin(args []string, stdin io.Reader, i int) (string, error) {
 	}
 	}
 	return string(bytes), nil
 	return string(bytes), nil
 }
 }
+
+func getPeersFlagValue(c *cli.Context) []string {
+	peerstr := c.GlobalString("peers")
+
+	// Use an environment variable if nothing was supplied on the
+	// command line
+	if peerstr == "" {
+		peerstr = os.Getenv("ETCDCTL_PEERS")
+	}
+
+	// If we still don't have peers, use a default
+	if peerstr == "" {
+		peerstr = "127.0.0.1:4001"
+	}
+
+	return strings.Split(peerstr, ",")
+}
+
+func getEndpoints(c *cli.Context) ([]string, error) {
+	eps := getPeersFlagValue(c)
+	for i, ep := range eps {
+		u, err := url.Parse(ep)
+		if err != nil {
+			return nil, err
+		}
+
+		if u.Scheme == "" {
+			u.Scheme = "http"
+		}
+
+		eps[i] = u.String()
+	}
+	return eps, nil
+}
+
+func getTransport(c *cli.Context) (*http.Transport, error) {
+	tls := transport.TLSInfo{
+		CAFile:   c.GlobalString("ca-file"),
+		CertFile: c.GlobalString("cert-file"),
+		KeyFile:  c.GlobalString("key-file"),
+	}
+	return transport.NewTransport(tls)
+
+}

+ 3 - 0
etcdctl/main.go

@@ -35,6 +35,9 @@ func main() {
 		cli.BoolFlag{Name: "no-sync", Usage: "don't synchronize cluster information before sending request"},
 		cli.BoolFlag{Name: "no-sync", Usage: "don't synchronize cluster information before sending request"},
 		cli.StringFlag{Name: "output, o", Value: "simple", Usage: "output response in the given format (`simple` or `json`)"},
 		cli.StringFlag{Name: "output, o", Value: "simple", Usage: "output response in the given format (`simple` or `json`)"},
 		cli.StringFlag{Name: "peers, C", Value: "", Usage: "a comma-delimited list of machine addresses in the cluster (default: \"127.0.0.1:4001\")"},
 		cli.StringFlag{Name: "peers, C", Value: "", Usage: "a comma-delimited list of machine addresses in the cluster (default: \"127.0.0.1:4001\")"},
+		cli.StringFlag{Name: "cert-file", Value: "", Usage: "identify HTTPS client using this SSL certificate file"},
+		cli.StringFlag{Name: "key-file", Value: "", Usage: "identify HTTPS client using this SSL key file"},
+		cli.StringFlag{Name: "ca-file", Value: "", Usage: "verify certificates of HTTPS-enabled servers using this CA bundle"},
 	}
 	}
 	app.Commands = []cli.Command{
 	app.Commands = []cli.Command{
 		command.NewMakeCommand(),
 		command.NewMakeCommand(),

+ 17 - 17
pkg/transport/listener.go

@@ -46,6 +46,11 @@ func NewListener(addr string, info TLSInfo) (net.Listener, error) {
 }
 }
 
 
 func NewTransport(info TLSInfo) (*http.Transport, error) {
 func NewTransport(info TLSInfo) (*http.Transport, error) {
+	cfg, err := info.ClientConfig()
+	if err != nil {
+		return nil, err
+	}
+
 	t := &http.Transport{
 	t := &http.Transport{
 		// timeouts taken from http.DefaultTransport
 		// timeouts taken from http.DefaultTransport
 		Dial: (&net.Dialer{
 		Dial: (&net.Dialer{
@@ -53,14 +58,7 @@ func NewTransport(info TLSInfo) (*http.Transport, error) {
 			KeepAlive: 30 * time.Second,
 			KeepAlive: 30 * time.Second,
 		}).Dial,
 		}).Dial,
 		TLSHandshakeTimeout: 10 * time.Second,
 		TLSHandshakeTimeout: 10 * time.Second,
-	}
-
-	if !info.Empty() {
-		tlsCfg, err := info.ClientConfig()
-		if err != nil {
-			return nil, err
-		}
-		t.TLSClientConfig = tlsCfg
+		TLSClientConfig:     cfg,
 	}
 	}
 
 
 	return t, nil
 	return t, nil
@@ -134,22 +132,24 @@ func (info TLSInfo) ServerConfig() (*tls.Config, error) {
 }
 }
 
 
 // ClientConfig generates a tls.Config object for use by an HTTP client
 // ClientConfig generates a tls.Config object for use by an HTTP client
-func (info TLSInfo) ClientConfig() (*tls.Config, error) {
-	cfg, err := info.baseConfig()
-	if err != nil {
-		return nil, err
+func (info TLSInfo) ClientConfig() (cfg *tls.Config, err error) {
+	if !info.Empty() {
+		cfg, err = info.baseConfig()
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		cfg = &tls.Config{}
 	}
 	}
 
 
 	if info.CAFile != "" {
 	if info.CAFile != "" {
-		cp, err := newCertPool(info.CAFile)
+		cfg.RootCAs, err = newCertPool(info.CAFile)
 		if err != nil {
 		if err != nil {
-			return nil, err
+			return
 		}
 		}
-
-		cfg.RootCAs = cp
 	}
 	}
 
 
-	return cfg, nil
+	return
 }
 }
 
 
 // newCertPool creates x509 certPool with provided CA file
 // newCertPool creates x509 certPool with provided CA file

+ 15 - 27
pkg/transport/listener_test.go

@@ -51,41 +51,31 @@ func TestNewTransportTLSInfo(t *testing.T) {
 	}
 	}
 	defer os.Remove(tmp)
 	defer os.Remove(tmp)
 
 
-	tests := []struct {
-		info                TLSInfo
-		wantTLSClientConfig bool
-	}{
-		{
-			info:                TLSInfo{},
-			wantTLSClientConfig: false,
+	tests := []TLSInfo{
+		TLSInfo{},
+		TLSInfo{
+			CertFile: tmp,
+			KeyFile:  tmp,
 		},
 		},
-		{
-			info: TLSInfo{
-				CertFile: tmp,
-				KeyFile:  tmp,
-			},
-			wantTLSClientConfig: true,
+		TLSInfo{
+			CertFile: tmp,
+			KeyFile:  tmp,
+			CAFile:   tmp,
 		},
 		},
-		{
-			info: TLSInfo{
-				CertFile: tmp,
-				KeyFile:  tmp,
-				CAFile:   tmp,
-			},
-			wantTLSClientConfig: true,
+		TLSInfo{
+			CAFile: tmp,
 		},
 		},
 	}
 	}
 
 
 	for i, tt := range tests {
 	for i, tt := range tests {
-		tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
-		trans, err := NewTransport(tt.info)
+		tt.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
+		trans, err := NewTransport(tt)
 		if err != nil {
 		if err != nil {
 			t.Fatalf("Received unexpected error from NewTransport: %v", err)
 			t.Fatalf("Received unexpected error from NewTransport: %v", err)
 		}
 		}
 
 
-		gotTLSClientConfig := trans.TLSClientConfig != nil
-		if tt.wantTLSClientConfig != gotTLSClientConfig {
-			t.Fatalf("#%d: wantTLSClientConfig=%t but gotTLSClientConfig=%t", i, tt.wantTLSClientConfig, gotTLSClientConfig)
+		if trans.TLSClientConfig == nil {
+			t.Fatalf("#%d: want non-nil TLSClientConfig", i)
 		}
 		}
 	}
 	}
 }
 }
@@ -121,8 +111,6 @@ func TestTLSInfoMissingFields(t *testing.T) {
 	defer os.Remove(tmp)
 	defer os.Remove(tmp)
 
 
 	tests := []TLSInfo{
 	tests := []TLSInfo{
-		TLSInfo{},
-		TLSInfo{CAFile: tmp},
 		TLSInfo{CertFile: tmp},
 		TLSInfo{CertFile: tmp},
 		TLSInfo{KeyFile: tmp},
 		TLSInfo{KeyFile: tmp},
 		TLSInfo{CertFile: tmp, CAFile: tmp},
 		TLSInfo{CertFile: tmp, CAFile: tmp},