main.go

package main

import (
	"context"
	"crypto/rand"
	"crypto/x509"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"os"
	"strings"
	gosync "sync"
	"time"

	"github.com/Shopify/sarama"
	"github.com/Shopify/sarama/tools/tls"
	metrics "github.com/rcrowley/go-metrics"
)
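// Command-line flags. -brokers, -topic, -message-load and -message-size are
// required; the remaining flags tune the sarama producer configuration.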
var (
	sync = flag.Bool(
		"sync",
		false,
		"Use a synchronous producer.",
	)
	messageLoad = flag.Int(
		"message-load",
		0,
		"REQUIRED: The number of messages to produce to -topic.",
	)
	messageSize = flag.Int(
		"message-size",
		0,
		"REQUIRED: The approximate size (in bytes) of each message to produce to -topic.",
	)
	brokers = flag.String(
		"brokers",
		"",
		"REQUIRED: A comma separated list of broker addresses.",
	)
	securityProtocol = flag.String(
		"security-protocol",
		"PLAINTEXT",
		"The name of the security protocol to talk to Kafka (PLAINTEXT, SSL) (default: PLAINTEXT).",
	)
	tlsRootCACerts = flag.String(
		"tls-ca-certs",
		"",
		"The path to a file that contains a set of root certificate authorities in PEM format "+
			"to trust when verifying broker certificates when -security-protocol=SSL "+
			"(leave empty to use the host's root CA set).",
	)
	tlsClientCert = flag.String(
		"tls-client-cert",
		"",
		"The path to a file that contains the client certificate to send to the broker "+
			"in PEM format if client authentication is required when -security-protocol=SSL "+
			"(leave empty to disable client authentication).",
	)
	tlsClientKey = flag.String(
		"tls-client-key",
		"",
		"The path to a file that contains the client private key linked to the client certificate "+
			"in PEM format when -security-protocol=SSL (REQUIRED if tls-client-cert is provided).",
	)
	topic = flag.String(
		"topic",
		"",
		"REQUIRED: The topic to run the performance test on.",
	)
	partition = flag.Int(
		"partition",
		-1,
		"The partition of -topic to run the performance test on.",
	)
	throughput = flag.Int(
		"throughput",
		0,
		"The maximum number of messages to send per second (0 for no limit).",
	)
	maxMessageBytes = flag.Int(
		"max-message-bytes",
		1000000,
		"The max permitted size of a message.",
	)
	requiredAcks = flag.Int(
		"required-acks",
		1,
		"The required number of acks needed from the broker (-1: all, 0: none, 1: local).",
	)
	timeout = flag.Duration(
		"timeout",
		10*time.Second,
		"The duration the producer will wait to receive -required-acks.",
	)
	partitioner = flag.String(
		"partitioner",
		"roundrobin",
		"The partitioning scheme to use (hash, manual, random, roundrobin).",
	)
	compression = flag.String(
		"compression",
		"none",
		"The compression method to use (none, gzip, snappy, lz4).",
	)
	flushFrequency = flag.Duration(
		"flush-frequency",
		0,
		"The best-effort frequency of flushes.",
	)
	flushBytes = flag.Int(
		"flush-bytes",
		0,
		"The best-effort number of bytes needed to trigger a flush.",
	)
	flushMessages = flag.Int(
		"flush-messages",
		0,
		"The best-effort number of messages needed to trigger a flush.",
	)
	flushMaxMessages = flag.Int(
		"flush-max-messages",
		0,
		"The maximum number of messages the producer will send in a single request.",
	)
	retryMax = flag.Int(
		"retry-max",
		3,
		"The total number of times to retry sending a message.",
	)
	retryBackoff = flag.Duration(
		"retry-backoff",
		100*time.Millisecond,
		"The duration the producer will wait for the cluster to settle between retries.",
	)
	clientID = flag.String(
		"client-id",
		"sarama",
		"The client ID sent with every request to the brokers.",
	)
	channelBufferSize = flag.Int(
		"channel-buffer-size",
		256,
		"The number of events to buffer in internal and external channels.",
	)
	routines = flag.Int(
		"routines",
		1,
		"The number of routines to send the messages from (-sync only).",
	)
	version = flag.String(
		"version",
		"0.8.2.0",
		"The assumed version of Kafka.",
	)
	verbose = flag.Bool(
		"verbose",
		false,
		"Turn on sarama logging to stderr",
	)
)
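// parseCompression maps the -compression flag onto the corresponding
// sarama.CompressionCodec, exiting with a usage error for unknown schemes.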
func parseCompression(scheme string) sarama.CompressionCodec {
	switch scheme {
	case "none":
		return sarama.CompressionNone
	case "gzip":
		return sarama.CompressionGZIP
	case "snappy":
		return sarama.CompressionSnappy
	case "lz4":
		return sarama.CompressionLZ4
	default:
		printUsageErrorAndExit(fmt.Sprintf("Unknown -compression: %s", scheme))
	}
	panic("should not happen")
}
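// parsePartitioner maps the -partitioner flag onto the corresponding
// sarama.PartitionerConstructor. The manual partitioner additionally requires
// an explicit, non-negative -partition.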
func parsePartitioner(scheme string, partition int) sarama.PartitionerConstructor {
	if partition < 0 && scheme == "manual" {
		printUsageErrorAndExit("-partition must not be -1 for -partitioner=manual")
	}
	switch scheme {
	case "manual":
		return sarama.NewManualPartitioner
	case "hash":
		return sarama.NewHashPartitioner
	case "random":
		return sarama.NewRandomPartitioner
	case "roundrobin":
		return sarama.NewRoundRobinPartitioner
	default:
		printUsageErrorAndExit(fmt.Sprintf("Unknown -partitioner: %s", scheme))
	}
	panic("should not happen")
}
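// parseVersion parses the -version flag into a sarama.KafkaVersion, exiting
// with a usage error if the version string is not recognized.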
func parseVersion(version string) sarama.KafkaVersion {
	result, err := sarama.ParseKafkaVersion(version)
	if err != nil {
		printUsageErrorAndExit(fmt.Sprintf("unknown -version: %s", version))
	}
	return result
}
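// generateMessages pre-builds messageLoad producer messages for the given
// topic and partition, each carrying messageSize bytes of random payload.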
func generateMessages(topic string, partition, messageLoad, messageSize int) []*sarama.ProducerMessage {
	messages := make([]*sarama.ProducerMessage, messageLoad)
	for i := 0; i < messageLoad; i++ {
		payload := make([]byte, messageSize)
		if _, err := rand.Read(payload); err != nil {
			printErrorAndExit(69, "Failed to generate message payload: %s", err)
		}
		messages[i] = &sarama.ProducerMessage{
			Topic:     topic,
			Partition: int32(partition),
			Value:     sarama.ByteEncoder(payload),
		}
	}
	return messages
}
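// main validates the flags, builds the sarama configuration (including TLS
// when -security-protocol=SSL), starts a goroutine that prints producer
// metrics every five seconds, and then runs either the synchronous or the
// asynchronous producer benchmark before printing the final metrics.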
func main() {
	flag.Parse()

	if *brokers == "" {
		printUsageErrorAndExit("-brokers is required")
	}
	if *topic == "" {
		printUsageErrorAndExit("-topic is required")
	}
	if *messageLoad <= 0 {
		printUsageErrorAndExit("-message-load must be greater than 0")
	}
	if *messageSize <= 0 {
		printUsageErrorAndExit("-message-size must be greater than 0")
	}
	if *routines < 1 || *routines > *messageLoad {
		printUsageErrorAndExit("-routines must be greater than 0 and less than or equal to -message-load")
	}
	if *securityProtocol != "PLAINTEXT" && *securityProtocol != "SSL" {
		printUsageErrorAndExit(fmt.Sprintf("-security-protocol %q is not supported", *securityProtocol))
	}

	if *verbose {
		sarama.Logger = log.New(os.Stderr, "", log.LstdFlags)
	}

	config := sarama.NewConfig()
	config.Producer.MaxMessageBytes = *maxMessageBytes
	config.Producer.RequiredAcks = sarama.RequiredAcks(*requiredAcks)
	config.Producer.Timeout = *timeout
	config.Producer.Partitioner = parsePartitioner(*partitioner, *partition)
	config.Producer.Compression = parseCompression(*compression)
	config.Producer.Flush.Frequency = *flushFrequency
	config.Producer.Flush.Bytes = *flushBytes
	config.Producer.Flush.Messages = *flushMessages
	config.Producer.Flush.MaxMessages = *flushMaxMessages
	// Wire the retry flags into the producer config so they take effect.
	config.Producer.Retry.Max = *retryMax
	config.Producer.Retry.Backoff = *retryBackoff
	config.Producer.Return.Successes = true
	config.ClientID = *clientID
	config.ChannelBufferSize = *channelBufferSize
	config.Version = parseVersion(*version)

	if *securityProtocol == "SSL" {
		tlsConfig, err := tls.NewConfig(*tlsClientCert, *tlsClientKey)
		if err != nil {
			printErrorAndExit(69, "failed to load client certificate from: %s and private key from: %s: %v",
				*tlsClientCert, *tlsClientKey, err)
		}
		if *tlsRootCACerts != "" {
			rootCAsBytes, err := ioutil.ReadFile(*tlsRootCACerts)
			if err != nil {
				printErrorAndExit(69, "failed to read root CA certificates: %v", err)
			}
			certPool := x509.NewCertPool()
			if !certPool.AppendCertsFromPEM(rootCAsBytes) {
				printErrorAndExit(69, "failed to load root CA certificates from file: %s", *tlsRootCACerts)
			}
			// Use the specific root CA set instead of the host's set.
			tlsConfig.RootCAs = certPool
		}
		config.Net.TLS.Enable = true
		config.Net.TLS.Config = tlsConfig
	}

	if err := config.Validate(); err != nil {
		printErrorAndExit(69, "Invalid configuration: %s", err)
	}

	// Print out metrics periodically.
	done := make(chan struct{})
	ctx, cancel := context.WithCancel(context.Background())
	go func(ctx context.Context) {
		defer close(done)
		t := time.Tick(5 * time.Second)
		for {
			select {
			case <-t:
				printMetrics(os.Stdout, config.MetricRegistry)
			case <-ctx.Done():
				return
			}
		}
	}(ctx)

	brokers := strings.Split(*brokers, ",")
	if *sync {
		runSyncProducer(*topic, *partition, *messageLoad, *messageSize, *routines,
			config, brokers, *throughput)
	} else {
		runAsyncProducer(*topic, *partition, *messageLoad, *messageSize,
			config, brokers, *throughput)
	}
	cancel()
	<-done

	// Print final metrics.
	printMetrics(os.Stdout, config.MetricRegistry)
}
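// runAsyncProducer publishes messageLoad randomly generated messages through
// a sarama.AsyncProducer and waits until every message has been acknowledged
// or an error is returned. When throughput is positive, the send rate is
// capped at roughly throughput messages per second.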
func runAsyncProducer(topic string, partition, messageLoad, messageSize int,
	config *sarama.Config, brokers []string, throughput int) {
	producer, err := sarama.NewAsyncProducer(brokers, config)
	if err != nil {
		printErrorAndExit(69, "Failed to create producer: %s", err)
	}
	defer func() {
		if err := producer.Close(); err != nil {
			printErrorAndExit(69, "Failed to close producer: %s", err)
		}
	}()

	messages := generateMessages(topic, partition, messageLoad, messageSize)

	// Collect one success per message, failing fast on the first error.
	messagesDone := make(chan struct{})
	go func() {
		for i := 0; i < messageLoad; i++ {
			select {
			case <-producer.Successes():
			case err = <-producer.Errors():
				printErrorAndExit(69, "%s", err)
			}
		}
		messagesDone <- struct{}{}
	}()

	if throughput > 0 {
		// Throttle to at most throughput messages per second.
		ticker := time.NewTicker(time.Second)
		for idx, message := range messages {
			producer.Input() <- message
			if (idx+1)%throughput == 0 {
				<-ticker.C
			}
		}
		ticker.Stop()
	} else {
		for _, message := range messages {
			producer.Input() <- message
		}
	}

	<-messagesDone
	close(messagesDone)
}
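// runSyncProducer splits messageLoad messages across the requested number of
// goroutines and publishes them through a single shared sarama.SyncProducer.
// When throughput is positive, each goroutine is throttled to at most
// throughput messages per second.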
func runSyncProducer(topic string, partition, messageLoad, messageSize, routines int,
	config *sarama.Config, brokers []string, throughput int) {
	producer, err := sarama.NewSyncProducer(brokers, config)
	if err != nil {
		printErrorAndExit(69, "Failed to create producer: %s", err)
	}
	defer func() {
		if err := producer.Close(); err != nil {
			printErrorAndExit(69, "Failed to close producer: %s", err)
		}
	}()

	// Split the message load across the routines; the last routine also takes
	// the remainder.
	messages := make([][]*sarama.ProducerMessage, routines)
	for i := 0; i < routines; i++ {
		if i == routines-1 {
			messages[i] = generateMessages(topic, partition, messageLoad/routines+messageLoad%routines, messageSize)
		} else {
			messages[i] = generateMessages(topic, partition, messageLoad/routines, messageSize)
		}
	}

	var wg gosync.WaitGroup
	if throughput > 0 {
		for _, messages := range messages {
			messages := messages
			wg.Add(1)
			go func() {
				defer wg.Done()
				// Throttle each routine to at most throughput messages per second.
				ticker := time.NewTicker(time.Second)
				defer ticker.Stop()
				for idx, message := range messages {
					if _, _, err := producer.SendMessage(message); err != nil {
						printErrorAndExit(69, "Failed to send message: %s", err)
					}
					if (idx+1)%throughput == 0 {
						<-ticker.C
					}
				}
			}()
		}
	} else {
		for _, messages := range messages {
			messages := messages
			wg.Add(1)
			go func() {
				defer wg.Done()
				for _, message := range messages {
					if _, _, err := producer.SendMessage(message); err != nil {
						printErrorAndExit(69, "Failed to send message: %s", err)
					}
				}
			}()
		}
	}
	wg.Wait()
}
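// printMetrics reports the producer throughput and request latency recorded
// in the sarama metrics registry: record send rate, estimated ingress/egress
// in MiB/sec, and latency mean, standard deviation and percentiles.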
func printMetrics(w io.Writer, r metrics.Registry) {
	recordSendRateMetric := r.Get("record-send-rate")
	requestLatencyMetric := r.Get("request-latency-in-ms")
	outgoingByteRateMetric := r.Get("outgoing-byte-rate")

	if recordSendRateMetric == nil || requestLatencyMetric == nil || outgoingByteRateMetric == nil {
		return
	}
	recordSendRate := recordSendRateMetric.(metrics.Meter).Snapshot()
	requestLatency := requestLatencyMetric.(metrics.Histogram).Snapshot()
	requestLatencyPercentiles := requestLatency.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999})
	outgoingByteRate := outgoingByteRateMetric.(metrics.Meter).Snapshot()
	fmt.Fprintf(w, "%d records sent, %.1f records/sec (%.2f MiB/sec ingress, %.2f MiB/sec egress), "+
		"%.1f ms avg latency, %.1f ms stddev, %.1f ms 50th, %.1f ms 75th, "+
		"%.1f ms 95th, %.1f ms 99th, %.1f ms 99.9th\n",
		recordSendRate.Count(),
		recordSendRate.RateMean(),
		recordSendRate.RateMean()*float64(*messageSize)/1024/1024,
		outgoingByteRate.RateMean()/1024/1024,
		requestLatency.Mean(),
		requestLatency.StdDev(),
		requestLatencyPercentiles[0],
		requestLatencyPercentiles[1],
		requestLatencyPercentiles[2],
		requestLatencyPercentiles[3],
		requestLatencyPercentiles[4],
	)
}
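// printUsageErrorAndExit reports a flag-usage error along with the available
// command line options and exits with status 64 (EX_USAGE).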
func printUsageErrorAndExit(message string) {
	fmt.Fprintln(os.Stderr, "ERROR:", message)
	fmt.Fprintln(os.Stderr)
	fmt.Fprintln(os.Stderr, "Available command line options:")
	flag.PrintDefaults()
	os.Exit(64)
}
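// printErrorAndExit reports a runtime error and exits with the given status
// code.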
func printErrorAndExit(code int, format string, values ...interface{}) {
	fmt.Fprintf(os.Stderr, "ERROR: %s\n", fmt.Sprintf(format, values...))
	fmt.Fprintln(os.Stderr)
	os.Exit(code)
}