Browse Source

Merge pull request #1538 from slaunay/feature/producer-perf-tls

Support TLS protocol in kafka-producer-performance
Vlad Gorodetsky 5 years ago
parent
commit
bb74e49545
1 changed files with 74 additions and 4 deletions
  1. 74 4
      tools/kafka-producer-performance/main.go

+ 74 - 4
tools/kafka-producer-performance/main.go

@@ -3,15 +3,19 @@ package main
 import (
 	"context"
 	"crypto/rand"
+	"crypto/x509"
 	"flag"
 	"fmt"
 	"io"
+	"io/ioutil"
+	"log"
 	"os"
 	"strings"
 	gosync "sync"
 	"time"
 
 	"github.com/Shopify/sarama"
+	"github.com/Shopify/sarama/tools/tls"
 	metrics "github.com/rcrowley/go-metrics"
 )
 
@@ -36,6 +40,31 @@ var (
 		"",
 		"REQUIRED: A comma separated list of broker addresses.",
 	)
+	securityProtocol = flag.String(
+		"security-protocol",
+		"PLAINTEXT",
+		"The name of the security protocol to talk to Kafka (PLAINTEXT, SSL) (default: PLAINTEXT).",
+	)
+	tlsRootCACerts = flag.String(
+		"tls-ca-certs",
+		"",
+		"The path to a file that contains a set of root certificate authorities in PEM format "+
+			"to trust when verifying broker certificates when -security-protocol=SSL "+
+			"(leave empty to use the host's root CA set).",
+	)
+	tlsClientCert = flag.String(
+		"tls-client-cert",
+		"",
+		"The path to a file that contains the client certificate to send to the broker "+
+			"in PEM format if client authentication is required when -security-protocol=SSL "+
+			"(leave empty to disable client authentication).",
+	)
+	tlsClientKey = flag.String(
+		"tls-client-key",
+		"",
+		"The path to a file that contains the client private key linked to the client certificate "+
+			"in PEM format when -security-protocol=SSL (REQUIRED if tls-client-cert is provided).",
+	)
 	topic = flag.String(
 		"topic",
 		"",
@@ -126,6 +155,11 @@ var (
 		"0.8.2.0",
 		"The assumed version of Kafka.",
 	)
+	verbose = flag.Bool(
+		"verbose",
+		false,
+		"Turn on sarama logging to stderr",
+	)
 )
 
 func parseCompression(scheme string) sarama.CompressionCodec {
@@ -205,6 +239,12 @@ func main() {
 	if *routines < 1 || *routines > *messageLoad {
 		printUsageErrorAndExit("-routines must be greater than 0 and less than or equal to -message-load")
 	}
+	if *securityProtocol != "PLAINTEXT" && *securityProtocol != "SSL" {
+		printUsageErrorAndExit(fmt.Sprintf("-security-protocol %q is not supported", *securityProtocol))
+	}
+	if *verbose {
+		sarama.Logger = log.New(os.Stderr, "", log.LstdFlags)
+	}
 
 	config := sarama.NewConfig()
 
@@ -222,6 +262,30 @@ func main() {
 	config.ChannelBufferSize = *channelBufferSize
 	config.Version = parseVersion(*version)
 
+	if *securityProtocol == "SSL" {
+		tlsConfig, err := tls.NewConfig(*tlsClientCert, *tlsClientKey)
+		if err != nil {
+			printErrorAndExit(69, "failed to load client certificate from: %s and private key from: %s: %v",
+				*tlsClientCert, *tlsClientKey, err)
+		}
+
+		if *tlsRootCACerts != "" {
+			rootCAsBytes, err := ioutil.ReadFile(*tlsRootCACerts)
+			if err != nil {
+				printErrorAndExit(69, "failed to read root CA certificates: %v", err)
+			}
+			certPool := x509.NewCertPool()
+			if !certPool.AppendCertsFromPEM(rootCAsBytes) {
+				printErrorAndExit(69, "failed to load root CA certificates from file: %s", *tlsRootCACerts)
+			}
+			// Use specific root CA set vs the host's set
+			tlsConfig.RootCAs = certPool
+		}
+
+		config.Net.TLS.Enable = true
+		config.Net.TLS.Config = tlsConfig
+	}
+
 	if err := config.Validate(); err != nil {
 		printErrorAndExit(69, "Invalid configuration: %s", err)
 	}
@@ -363,18 +427,24 @@ func runSyncProducer(topic string, partition, messageLoad, messageSize, routines
 }
 
 func printMetrics(w io.Writer, r metrics.Registry) {
-	if r.Get("record-send-rate") == nil || r.Get("request-latency-in-ms") == nil {
+	recordSendRateMetric := r.Get("record-send-rate")
+	requestLatencyMetric := r.Get("request-latency-in-ms")
+	outgoingByteRateMetric := r.Get("outgoing-byte-rate")
+
+	if recordSendRateMetric == nil || requestLatencyMetric == nil || outgoingByteRateMetric == nil {
 		return
 	}
-	recordSendRate := r.Get("record-send-rate").(metrics.Meter).Snapshot()
-	requestLatency := r.Get("request-latency-in-ms").(metrics.Histogram).Snapshot()
+	recordSendRate := recordSendRateMetric.(metrics.Meter).Snapshot()
+	requestLatency := requestLatencyMetric.(metrics.Histogram).Snapshot()
 	requestLatencyPercentiles := requestLatency.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999})
-	fmt.Fprintf(w, "%d records sent, %.1f records/sec (%.2f MB/sec), "+
+	outgoingByteRate := outgoingByteRateMetric.(metrics.Meter).Snapshot()
+	fmt.Fprintf(w, "%d records sent, %.1f records/sec (%.2f MiB/sec ingress, %.2f MiB/sec egress), "+
 		"%.1f ms avg latency, %.1f ms stddev, %.1f ms 50th, %.1f ms 75th, "+
 		"%.1f ms 95th, %.1f ms 99th, %.1f ms 99.9th\n",
 		recordSendRate.Count(),
 		recordSendRate.RateMean(),
 		recordSendRate.RateMean()*float64(*messageSize)/1024/1024,
+		outgoingByteRate.RateMean()/1024/1024,
 		requestLatency.Mean(),
 		requestLatency.StdDev(),
 		requestLatencyPercentiles[0],