123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471 |
- package main
- import (
- "context"
- "crypto/rand"
- "crypto/x509"
- "flag"
- "fmt"
- "io"
- "io/ioutil"
- "log"
- "os"
- "strings"
- gosync "sync"
- "time"
- metrics "github.com/rcrowley/go-metrics"
- "github.com/Shopify/sarama"
- "github.com/Shopify/sarama/tools/tls"
- )
// Command-line flags. Every variable is a pointer returned by the flag
// package; the pointed-to value is only meaningful after flag.Parse has run.
var (
	// sync selects the synchronous producer instead of the asynchronous one.
	sync = flag.Bool("sync", false, "Use a synchronous producer.")

	// messageLoad is the total number of messages to produce.
	messageLoad = flag.Int("message-load", 0, "REQUIRED: The number of messages to produce to -topic.")

	// messageSize is the payload size, in bytes, of every generated message.
	messageSize = flag.Int("message-size", 0, "REQUIRED: The approximate size (in bytes) of each message to produce to -topic.")

	// brokers is the comma separated broker address list.
	brokers = flag.String("brokers", "", "REQUIRED: A comma separated list of broker addresses.")

	// securityProtocol chooses between PLAINTEXT and SSL connections.
	securityProtocol = flag.String("security-protocol", "PLAINTEXT", "The name of the security protocol to talk to Kafka (PLAINTEXT, SSL) (default: PLAINTEXT).")

	// tlsRootCACerts is the optional PEM file of trusted root CAs.
	tlsRootCACerts = flag.String("tls-ca-certs", "",
		"The path to a file that contains a set of root certificate authorities in PEM format "+
			"to trust when verifying broker certificates when -security-protocol=SSL "+
			"(leave empty to use the host's root CA set).")

	// tlsClientCert is the optional PEM file holding the client certificate.
	tlsClientCert = flag.String("tls-client-cert", "",
		"The path to a file that contains the client certificate to send to the broker "+
			"in PEM format if client authentication is required when -security-protocol=SSL "+
			"(leave empty to disable client authentication).")

	// tlsClientKey is the PEM file holding the key that matches tlsClientCert.
	tlsClientKey = flag.String("tls-client-key", "",
		"The path to a file that contains the client private key linked to the client certificate "+
			"in PEM format when -security-protocol=SSL (REQUIRED if tls-client-cert is provided).")

	// topic is the topic the messages are produced to.
	topic = flag.String("topic", "", "REQUIRED: The topic to run the performance test on.")

	// partition is the target partition; it defaults to -1.
	partition = flag.Int("partition", -1, "The partition of -topic to run the performance test on.")

	// throughput caps the send rate in messages per second; 0 disables the cap.
	throughput = flag.Int("throughput", 0, "The maximum number of messages to send per second (0 for no limit).")

	// maxOpenRequests is copied into config.Net.MaxOpenRequests.
	maxOpenRequests = flag.Int("max-open-requests", 5, "The maximum number of unacknowledged requests the client will send on a single connection before blocking (default: 5).")

	// maxMessageBytes is copied into config.Producer.MaxMessageBytes.
	maxMessageBytes = flag.Int("max-message-bytes", 1000000, "The max permitted size of a message.")

	// requiredAcks is converted to sarama.RequiredAcks for the producer.
	requiredAcks = flag.Int("required-acks", 1, "The required number of acks needed from the broker (-1: all, 0: none, 1: local).")

	// timeout is copied into config.Producer.Timeout.
	timeout = flag.Duration("timeout", 10*time.Second, "The duration the producer will wait to receive -required-acks.")

	// partitioner names the partitioning scheme resolved by parsePartitioner.
	partitioner = flag.String("partitioner", "roundrobin", "The partitioning scheme to use (hash, manual, random, roundrobin).")

	// compression names the codec resolved by parseCompression.
	compression = flag.String("compression", "none", "The compression method to use (none, gzip, snappy, lz4).")

	// flushFrequency is copied into config.Producer.Flush.Frequency.
	flushFrequency = flag.Duration("flush-frequency", 0, "The best-effort frequency of flushes.")

	// flushBytes is copied into config.Producer.Flush.Bytes.
	flushBytes = flag.Int("flush-bytes", 0, "The best-effort number of bytes needed to trigger a flush.")

	// flushMessages is copied into config.Producer.Flush.Messages.
	flushMessages = flag.Int("flush-messages", 0, "The best-effort number of messages needed to trigger a flush.")

	// flushMaxMessages is copied into config.Producer.Flush.MaxMessages.
	flushMaxMessages = flag.Int("flush-max-messages", 0, "The maximum number of messages the producer will send in a single request.")

	// clientID is copied into config.ClientID.
	clientID = flag.String("client-id", "sarama", "The client ID sent with every request to the brokers.")

	// channelBufferSize is copied into config.ChannelBufferSize.
	channelBufferSize = flag.Int("channel-buffer-size", 256, "The number of events to buffer in internal and external channels.")

	// routines is the number of sending goroutines used by the sync producer.
	routines = flag.Int("routines", 1, "The number of routines to send the messages from (-sync only).")

	// version is the Kafka version string resolved by parseVersion.
	version = flag.String("version", "0.8.2.0", "The assumed version of Kafka.")

	// verbose routes sarama's logger to stderr when set.
	verbose = flag.Bool("verbose", false, "Turn on sarama logging to stderr")
)
- func parseCompression(scheme string) sarama.CompressionCodec {
- switch scheme {
- case "none":
- return sarama.CompressionNone
- case "gzip":
- return sarama.CompressionGZIP
- case "snappy":
- return sarama.CompressionSnappy
- case "lz4":
- return sarama.CompressionLZ4
- default:
- printUsageErrorAndExit(fmt.Sprintf("Unknown -compression: %s", scheme))
- }
- panic("should not happen")
- }
- func parsePartitioner(scheme string, partition int) sarama.PartitionerConstructor {
- if partition < 0 && scheme == "manual" {
- printUsageErrorAndExit("-partition must not be -1 for -partitioning=manual")
- }
- switch scheme {
- case "manual":
- return sarama.NewManualPartitioner
- case "hash":
- return sarama.NewHashPartitioner
- case "random":
- return sarama.NewRandomPartitioner
- case "roundrobin":
- return sarama.NewRoundRobinPartitioner
- default:
- printUsageErrorAndExit(fmt.Sprintf("Unknown -partitioning: %s", scheme))
- }
- panic("should not happen")
- }
- func parseVersion(version string) sarama.KafkaVersion {
- result, err := sarama.ParseKafkaVersion(version)
- if err != nil {
- printUsageErrorAndExit(fmt.Sprintf("unknown -version: %s", version))
- }
- return result
- }
- func generateMessages(topic string, partition, messageLoad, messageSize int) []*sarama.ProducerMessage {
- messages := make([]*sarama.ProducerMessage, messageLoad)
- for i := 0; i < messageLoad; i++ {
- payload := make([]byte, messageSize)
- if _, err := rand.Read(payload); err != nil {
- printErrorAndExit(69, "Failed to generate message payload: %s", err)
- }
- messages[i] = &sarama.ProducerMessage{
- Topic: topic,
- Partition: int32(partition),
- Value: sarama.ByteEncoder(payload),
- }
- }
- return messages
- }
- func main() {
- flag.Parse()
- if *brokers == "" {
- printUsageErrorAndExit("-brokers is required")
- }
- if *topic == "" {
- printUsageErrorAndExit("-topic is required")
- }
- if *messageLoad <= 0 {
- printUsageErrorAndExit("-message-load must be greater than 0")
- }
- if *messageSize <= 0 {
- printUsageErrorAndExit("-message-size must be greater than 0")
- }
- if *routines < 1 || *routines > *messageLoad {
- printUsageErrorAndExit("-routines must be greater than 0 and less than or equal to -message-load")
- }
- if *securityProtocol != "PLAINTEXT" && *securityProtocol != "SSL" {
- printUsageErrorAndExit(fmt.Sprintf("-security-protocol %q is not supported", *securityProtocol))
- }
- if *verbose {
- sarama.Logger = log.New(os.Stderr, "", log.LstdFlags)
- }
- config := sarama.NewConfig()
- config.Net.MaxOpenRequests = *maxOpenRequests
- config.Producer.MaxMessageBytes = *maxMessageBytes
- config.Producer.RequiredAcks = sarama.RequiredAcks(*requiredAcks)
- config.Producer.Timeout = *timeout
- config.Producer.Partitioner = parsePartitioner(*partitioner, *partition)
- config.Producer.Compression = parseCompression(*compression)
- config.Producer.Flush.Frequency = *flushFrequency
- config.Producer.Flush.Bytes = *flushBytes
- config.Producer.Flush.Messages = *flushMessages
- config.Producer.Flush.MaxMessages = *flushMaxMessages
- config.Producer.Return.Successes = true
- config.ClientID = *clientID
- config.ChannelBufferSize = *channelBufferSize
- config.Version = parseVersion(*version)
- if *securityProtocol == "SSL" {
- tlsConfig, err := tls.NewConfig(*tlsClientCert, *tlsClientKey)
- if err != nil {
- printErrorAndExit(69, "failed to load client certificate from: %s and private key from: %s: %v",
- *tlsClientCert, *tlsClientKey, err)
- }
- if *tlsRootCACerts != "" {
- rootCAsBytes, err := ioutil.ReadFile(*tlsRootCACerts)
- if err != nil {
- printErrorAndExit(69, "failed to read root CA certificates: %v", err)
- }
- certPool := x509.NewCertPool()
- if !certPool.AppendCertsFromPEM(rootCAsBytes) {
- printErrorAndExit(69, "failed to load root CA certificates from file: %s", *tlsRootCACerts)
- }
- // Use specific root CA set vs the host's set
- tlsConfig.RootCAs = certPool
- }
- config.Net.TLS.Enable = true
- config.Net.TLS.Config = tlsConfig
- }
- if err := config.Validate(); err != nil {
- printErrorAndExit(69, "Invalid configuration: %s", err)
- }
- // Print out metrics periodically.
- done := make(chan struct{})
- ctx, cancel := context.WithCancel(context.Background())
- go func(ctx context.Context) {
- defer close(done)
- t := time.Tick(5 * time.Second)
- for {
- select {
- case <-t:
- printMetrics(os.Stdout, config.MetricRegistry)
- case <-ctx.Done():
- return
- }
- }
- }(ctx)
- brokers := strings.Split(*brokers, ",")
- if *sync {
- runSyncProducer(*topic, *partition, *messageLoad, *messageSize, *routines,
- config, brokers, *throughput)
- } else {
- runAsyncProducer(*topic, *partition, *messageLoad, *messageSize,
- config, brokers, *throughput)
- }
- cancel()
- <-done
- // Print final metrics.
- printMetrics(os.Stdout, config.MetricRegistry)
- }
- func runAsyncProducer(topic string, partition, messageLoad, messageSize int,
- config *sarama.Config, brokers []string, throughput int) {
- producer, err := sarama.NewAsyncProducer(brokers, config)
- if err != nil {
- printErrorAndExit(69, "Failed to create producer: %s", err)
- }
- defer func() {
- if err := producer.Close(); err != nil {
- printErrorAndExit(69, "Failed to close producer: %s", err)
- }
- }()
- messages := generateMessages(topic, partition, messageLoad, messageSize)
- messagesDone := make(chan struct{})
- go func() {
- for i := 0; i < messageLoad; i++ {
- select {
- case <-producer.Successes():
- case err = <-producer.Errors():
- printErrorAndExit(69, "%s", err)
- }
- }
- messagesDone <- struct{}{}
- }()
- if throughput > 0 {
- ticker := time.NewTicker(time.Second)
- for _, message := range messages {
- for i := 0; i < throughput; i++ {
- producer.Input() <- message
- }
- <-ticker.C
- }
- ticker.Stop()
- } else {
- for _, message := range messages {
- producer.Input() <- message
- }
- }
- <-messagesDone
- close(messagesDone)
- }
- func runSyncProducer(topic string, partition, messageLoad, messageSize, routines int,
- config *sarama.Config, brokers []string, throughput int) {
- producer, err := sarama.NewSyncProducer(brokers, config)
- if err != nil {
- printErrorAndExit(69, "Failed to create producer: %s", err)
- }
- defer func() {
- if err := producer.Close(); err != nil {
- printErrorAndExit(69, "Failed to close producer: %s", err)
- }
- }()
- messages := make([][]*sarama.ProducerMessage, routines)
- for i := 0; i < routines; i++ {
- if i == routines-1 {
- messages[i] = generateMessages(topic, partition, messageLoad/routines+messageLoad%routines, messageSize)
- } else {
- messages[i] = generateMessages(topic, partition, messageLoad/routines, messageSize)
- }
- }
- var wg gosync.WaitGroup
- if throughput > 0 {
- for _, messages := range messages {
- messages := messages
- wg.Add(1)
- go func() {
- ticker := time.NewTicker(time.Second)
- for _, message := range messages {
- for i := 0; i < throughput; i++ {
- _, _, err = producer.SendMessage(message)
- if err != nil {
- printErrorAndExit(69, "Failed to send message: %s", err)
- }
- }
- <-ticker.C
- }
- ticker.Stop()
- wg.Done()
- }()
- }
- } else {
- for _, messages := range messages {
- messages := messages
- wg.Add(1)
- go func() {
- for _, message := range messages {
- _, _, err = producer.SendMessage(message)
- if err != nil {
- printErrorAndExit(69, "Failed to send message: %s", err)
- }
- }
- wg.Done()
- }()
- }
- }
- wg.Wait()
- }
- func printMetrics(w io.Writer, r metrics.Registry) {
- recordSendRateMetric := r.Get("record-send-rate")
- requestLatencyMetric := r.Get("request-latency-in-ms")
- outgoingByteRateMetric := r.Get("outgoing-byte-rate")
- requestsInFlightMetric := r.Get("requests-in-flight")
- if recordSendRateMetric == nil || requestLatencyMetric == nil || outgoingByteRateMetric == nil ||
- requestsInFlightMetric == nil {
- return
- }
- recordSendRate := recordSendRateMetric.(metrics.Meter).Snapshot()
- requestLatency := requestLatencyMetric.(metrics.Histogram).Snapshot()
- requestLatencyPercentiles := requestLatency.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999})
- outgoingByteRate := outgoingByteRateMetric.(metrics.Meter).Snapshot()
- requestsInFlight := requestsInFlightMetric.(metrics.Counter).Count()
- fmt.Fprintf(w, "%d records sent, %.1f records/sec (%.2f MiB/sec ingress, %.2f MiB/sec egress), "+
- "%.1f ms avg latency, %.1f ms stddev, %.1f ms 50th, %.1f ms 75th, "+
- "%.1f ms 95th, %.1f ms 99th, %.1f ms 99.9th, %d total req. in flight\n",
- recordSendRate.Count(),
- recordSendRate.RateMean(),
- recordSendRate.RateMean()*float64(*messageSize)/1024/1024,
- outgoingByteRate.RateMean()/1024/1024,
- requestLatency.Mean(),
- requestLatency.StdDev(),
- requestLatencyPercentiles[0],
- requestLatencyPercentiles[1],
- requestLatencyPercentiles[2],
- requestLatencyPercentiles[3],
- requestLatencyPercentiles[4],
- requestsInFlight,
- )
- }
// printUsageErrorAndExit reports message on stderr, follows it with the list
// of registered command-line flags, and exits the process with status 64.
func printUsageErrorAndExit(message string) {
	fmt.Fprintf(os.Stderr, "ERROR: %s\n\nAvailable command line options:\n", message)
	flag.PrintDefaults()
	os.Exit(64)
}
// printErrorAndExit formats values according to format, writes the result to
// stderr prefixed with "ERROR:" and followed by a blank line, and exits the
// process with the given status code.
func printErrorAndExit(code int, format string, values ...interface{}) {
	msg := fmt.Sprintf(format, values...)
	fmt.Fprintln(os.Stderr, "ERROR:", msg)
	fmt.Fprintln(os.Stderr)
	os.Exit(code)
}
|