Browse Source

tools/benchmark: support tls

Anthony Romano 10 years ago
parent
commit
4380617e1a
2 changed files with 17 additions and 1 deletions
  1. 7 0
      tools/benchmark/cmd/root.go
  2. 10 1
      tools/benchmark/cmd/util.go

+ 7 - 0
tools/benchmark/cmd/root.go

@@ -19,6 +19,7 @@ import (
 
 
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/cheggaaa/pb"
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/cheggaaa/pb"
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/spf13/cobra"
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/spf13/cobra"
+	"github.com/coreos/etcd/pkg/transport"
 )
 )
 
 
 // This represents the base command when called without any subcommands
 // This represents the base command when called without any subcommands
@@ -40,6 +41,8 @@ var (
 	results chan result
 	results chan result
 	wg      sync.WaitGroup
 	wg      sync.WaitGroup
 
 
+	tls transport.TLSInfo
+
 	cpuProfPath string
 	cpuProfPath string
 	memProfPath string
 	memProfPath string
 )
 )
@@ -48,4 +51,8 @@ func init() {
 	RootCmd.PersistentFlags().StringVar(&endpoints, "endpoint", "127.0.0.1:2378", "comma-separated gRPC endpoints")
 	RootCmd.PersistentFlags().StringVar(&endpoints, "endpoint", "127.0.0.1:2378", "comma-separated gRPC endpoints")
 	RootCmd.PersistentFlags().UintVar(&totalConns, "conns", 1, "Total number of gRPC connections")
 	RootCmd.PersistentFlags().UintVar(&totalConns, "conns", 1, "Total number of gRPC connections")
 	RootCmd.PersistentFlags().UintVar(&totalClients, "clients", 1, "Total number of gRPC clients")
 	RootCmd.PersistentFlags().UintVar(&totalClients, "clients", 1, "Total number of gRPC clients")
+
+	RootCmd.PersistentFlags().StringVar(&tls.CertFile, "cert", "", "identify HTTPS client using this SSL certificate file")
+	RootCmd.PersistentFlags().StringVar(&tls.KeyFile, "key", "", "identify HTTPS client using this SSL key file")
+	RootCmd.PersistentFlags().StringVar(&tls.CAFile, "cacert", "", "verify certificates of HTTPS-enabled servers using this CA bundle")
 }
 }

+ 10 - 1
tools/benchmark/cmd/util.go

@@ -33,7 +33,16 @@ func mustCreateConn() *clientv3.Client {
 	eps := strings.Split(endpoints, ",")
 	eps := strings.Split(endpoints, ",")
 	endpoint := eps[dialTotal%len(eps)]
 	endpoint := eps[dialTotal%len(eps)]
 	dialTotal++
 	dialTotal++
-	client, err := clientv3.NewFromURL(endpoint)
+	cfgtls := &tls
+	if cfgtls.Empty() {
+		cfgtls = nil
+	}
+	client, err := clientv3.New(
+		clientv3.Config{
+			Endpoints: []string{endpoint},
+			TLS:       cfgtls,
+		},
+	)
 	if err != nil {
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "dial error: %v\n", err)
 		fmt.Fprintf(os.Stderr, "dial error: %v\n", err)
 		os.Exit(1)
 		os.Exit(1)