Browse Source

Merge pull request #7689 from mitake/bench-leader

benchmark: a new flag --target-leader for targetting a leader endpoint
Hitoshi Mitake 8 years ago
parent
commit
e7e7451213
2 changed files with 48 additions and 3 deletions
  1. 4 0
      tools/benchmark/cmd/root.go
  2. 44 3
      tools/benchmark/cmd/util.go

+ 4 - 0
tools/benchmark/cmd/root.go

@@ -52,6 +52,8 @@ var (
 	user string
 
 	dialTimeout time.Duration
+
+	targetLeader bool
 )
 
 func init() {
@@ -67,4 +69,6 @@ func init() {
 
 	RootCmd.PersistentFlags().StringVar(&user, "user", "", "specify username and password in username:password format")
 	RootCmd.PersistentFlags().DurationVar(&dialTimeout, "dial-timeout", 0, "dial timeout for client connections")
+
+	RootCmd.PersistentFlags().BoolVar(&targetLeader, "target-leader", false, "connect only to the leader node")
 }

+ 44 - 3
tools/benchmark/cmd/util.go

@@ -23,19 +23,53 @@ import (
 
 	"github.com/coreos/etcd/clientv3"
 	"github.com/coreos/etcd/pkg/report"
+	"golang.org/x/net/context"
 )
 
 var (
 	// dialTotal counts the number of mustCreateConn calls so that endpoint
 	// connections can be handed out in round-robin order
 	dialTotal int
+
+	// leaderEps is a cache for holding endpoints of a leader node
+	leaderEps []string
 )
 
+func mustFindLeaderEndpoints(c *clientv3.Client) {
+	resp, lerr := c.MemberList(context.TODO())
+	if lerr != nil {
+		fmt.Fprintf(os.Stderr, "failed to get a member list: %s\n", lerr)
+		os.Exit(1)
+	}
+
+	leaderId := uint64(0)
+	for _, ep := range c.Endpoints() {
+		resp, serr := c.Status(context.TODO(), ep)
+		if serr == nil {
+			leaderId = resp.Leader
+			break
+		}
+	}
+
+	for _, m := range resp.Members {
+		if m.ID == leaderId {
+			leaderEps = m.ClientURLs
+			return
+		}
+	}
+
+	fmt.Fprintf(os.Stderr, "failed to find a leader endpoint\n")
+	os.Exit(1)
+}
+
 func mustCreateConn() *clientv3.Client {
-	endpoint := endpoints[dialTotal%len(endpoints)]
-	dialTotal++
+	connEndpoints := leaderEps
+	if len(connEndpoints) == 0 {
+		connEndpoints = []string{endpoints[dialTotal%len(endpoints)]}
+		dialTotal++
+	}
 	cfg := clientv3.Config{
-		Endpoints:   []string{endpoint},
+		Endpoints:   connEndpoints,
 		DialTimeout: dialTimeout,
 	}
 	if !tls.Empty() {
@@ -59,12 +93,19 @@ func mustCreateConn() *clientv3.Client {
 	}
 
 	client, err := clientv3.New(cfg)
+	if targetLeader && len(leaderEps) == 0 {
+		mustFindLeaderEndpoints(client)
+		client.Close()
+		return mustCreateConn()
+	}
+
 	clientv3.SetLogger(log.New(os.Stderr, "grpc", 0))
 
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "dial error: %v\n", err)
 		os.Exit(1)
 	}
+
 	return client
 }