Browse Source

clientv3: support read conf from file

Xiang Li 9 years ago
parent
commit
802de5f9f8
3 changed files with 246 additions and 21 deletions
  1. 9 21
      clientv3/client.go
  2. 111 0
      clientv3/config.go
  3. 126 0
      clientv3/config_test.go

+ 9 - 21
clientv3/client.go

@@ -15,7 +15,6 @@
 package clientv3
 
 import (
-	"crypto/tls"
 	"errors"
 	"io/ioutil"
 	"log"
@@ -53,26 +52,6 @@ type Client struct {
 	cancel context.CancelFunc
 }
 
-// EndpointDialer is a policy for choosing which endpoint to dial next
-type EndpointDialer func(*Client) (*grpc.ClientConn, error)
-
-type Config struct {
-	// Endpoints is a list of URLs
-	Endpoints []string
-
-	// RetryDialer chooses the next endpoint to use
-	RetryDialer EndpointDialer
-
-	// DialTimeout is the timeout for failing to establish a connection.
-	DialTimeout time.Duration
-
-	// TLS holds the client secure credentials, if any.
-	TLS *tls.Config
-
-	// Logger is the logger used by client library.
-	Logger Logger
-}
-
 // New creates a new etcdv3 client from a given configuration.
 func New(cfg Config) (*Client, error) {
 	if cfg.RetryDialer == nil {
@@ -90,6 +69,15 @@ func NewFromURL(url string) (*Client, error) {
 	return New(Config{Endpoints: []string{url}})
 }
 
+// NewFromConfigFile creates a new etcdv3 client from a configuration file.
+func NewFromConfigFile(path string) (*Client, error) {
+	cfg, err := configFromFile(path)
+	if err != nil {
+		return nil, err
+	}
+	return New(*cfg)
+}
+
 // Close shuts down the client's etcd connections.
 func (c *Client) Close() error {
 	c.mu.Lock()

+ 111 - 0
clientv3/config.go

@@ -0,0 +1,111 @@
+// Copyright 2016 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package clientv3
+
+import (
+	"crypto/tls"
+	"crypto/x509"
+	"io/ioutil"
+	"time"
+
+	"github.com/coreos/etcd/pkg/tlsutil"
+	"github.com/ghodss/yaml"
+	"google.golang.org/grpc"
+)
+
+// EndpointDialer is a policy for choosing which endpoint to dial next
+type EndpointDialer func(*Client) (*grpc.ClientConn, error)
+
+type Config struct {
+	// Endpoints is a list of URLs
+	Endpoints []string
+
+	// RetryDialer chooses the next endpoint to use
+	RetryDialer EndpointDialer
+
+	// DialTimeout is the timeout for failing to establish a connection.
+	DialTimeout time.Duration
+
+	// TLS holds the client secure credentials, if any.
+	TLS *tls.Config
+
+	// Logger is the logger used by client library.
+	Logger Logger
+}
+
+type YamlConfig struct {
+	Endpoints             []string      `json:"endpoints"`
+	DialTimeout           time.Duration `json:"dial-timeout"`
+	InsecureTransport     bool          `json:"insecure-transport"`
+	InsecureSkipTLSVerify bool          `json:"insecure-skip-tls-verify"`
+	Certfile              string        `json:"cert-file"`
+	Keyfile               string        `json:"key-file"`
+	CAfile                string        `json:"ca-file"`
+}
+
+func configFromFile(fpath string) (*Config, error) {
+	b, err := ioutil.ReadFile(fpath)
+	if err != nil {
+		return nil, err
+	}
+
+	yc := &YamlConfig{}
+
+	err = yaml.Unmarshal(b, yc)
+	if err != nil {
+		return nil, err
+	}
+
+	cfg := &Config{
+		Endpoints:   yc.Endpoints,
+		DialTimeout: yc.DialTimeout,
+	}
+
+	if yc.InsecureTransport {
+		cfg.TLS = nil
+		return cfg, nil
+	}
+
+	var (
+		cert *tls.Certificate
+		cp   *x509.CertPool
+	)
+
+	if yc.Certfile != "" && yc.Keyfile != "" {
+		cert, err = tlsutil.NewCert(yc.Certfile, yc.Keyfile, nil)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	if yc.CAfile != "" {
+		cp, err = tlsutil.NewCertPool([]string{yc.CAfile})
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	tlscfg := &tls.Config{
+		MinVersion:         tls.VersionTLS10,
+		InsecureSkipVerify: yc.InsecureSkipTLSVerify,
+		RootCAs:            cp,
+	}
+	if cert != nil {
+		tlscfg.Certificates = []tls.Certificate{*cert}
+	}
+	cfg.TLS = tlscfg
+
+	return cfg, nil
+}

+ 126 - 0
clientv3/config_test.go

@@ -0,0 +1,126 @@
+// Copyright 2016 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package clientv3
+
+import (
+	"io/ioutil"
+	"log"
+	"os"
+	"reflect"
+	"testing"
+
+	"github.com/ghodss/yaml"
+)
+
+var (
+	certPath       = "../integration/fixtures/server.crt"
+	privateKeyPath = "../integration/fixtures/server.key.insecure"
+	caPath         = "../integration/fixtures/ca.crt"
+)
+
+func TestConfigFromFile(t *testing.T) {
+	tests := []struct {
+		ym *YamlConfig
+
+		werr bool
+	}{
+		{
+			&YamlConfig{},
+			false,
+		},
+		{
+			&YamlConfig{
+				InsecureTransport: true,
+			},
+			false,
+		},
+		{
+			&YamlConfig{
+				Keyfile:               privateKeyPath,
+				Certfile:              certPath,
+				CAfile:                caPath,
+				InsecureSkipTLSVerify: true,
+			},
+			false,
+		},
+		{
+			&YamlConfig{
+				Keyfile:  "bad",
+				Certfile: "bad",
+			},
+			true,
+		},
+		{
+			&YamlConfig{
+				Keyfile:  privateKeyPath,
+				Certfile: certPath,
+				CAfile:   "bad",
+			},
+			true,
+		},
+	}
+
+	for i, tt := range tests {
+		tmpfile, err := ioutil.TempFile("", "clientcfg")
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		b, err := yaml.Marshal(tt.ym)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		_, err = tmpfile.Write(b)
+		if err != nil {
+			t.Fatal(err)
+		}
+		err = tmpfile.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		cfg, cerr := configFromFile(tmpfile.Name())
+		if cerr != nil && !tt.werr {
+			t.Errorf("#%d: err = %v, want %v", i, cerr, tt.werr)
+			continue
+		}
+		if cerr != nil {
+			continue
+		}
+
+		if !reflect.DeepEqual(cfg.Endpoints, tt.ym.Endpoints) {
+			t.Errorf("#%d: endpoint = %v, want %v", i, cfg.Endpoints, tt.ym.Endpoints)
+		}
+
+		if tt.ym.InsecureTransport != (cfg.TLS == nil) {
+			t.Errorf("#%d: insecureTransport = %v, want %v", i, cfg.TLS == nil, tt.ym.InsecureTransport)
+		}
+
+		if !tt.ym.InsecureTransport {
+			if tt.ym.Certfile != "" && len(cfg.TLS.Certificates) == 0 {
+				t.Errorf("#%d: failed to load in cert", i)
+			}
+			if tt.ym.CAfile != "" && cfg.TLS.RootCAs == nil {
+				t.Errorf("#%d: failed to load in ca cert", i)
+			}
+			if cfg.TLS.InsecureSkipVerify != tt.ym.InsecureSkipTLSVerify {
+				t.Errorf("#%d: skipTLSVeify = %v, want %v", i, cfg.TLS.InsecureSkipVerify, tt.ym.InsecureSkipTLSVerify)
+			}
+		}
+
+		os.Remove(tmpfile.Name())
+	}
+}