
Merge pull request #1152 from mimaison/idempotent_producer

Added support for Idempotent Producer
Vlad Gorodetsky 7 years ago
parent
commit
f21e149e59
9 changed files with 618 additions and 62 deletions
  1. async_producer.go (+161 -47)
  2. async_producer_test.go (+253 -0)
  3. client.go (+24 -1)
  4. config.go (+18 -0)
  5. config_test.go (+26 -0)
  6. functional_consumer_test.go (+21 -3)
  7. produce_response.go (+7 -1)
  8. produce_set.go (+14 -7)
  9. produce_set_test.go (+94 -3)

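For context: the feature is switched on with the new Config.Producer.Idempotent flag added in this change. A minimal usage sketch, assembled from the configuration the new tests use (broker address and topic are placeholders; the import path is the one in use at the time of this merge):

package main

import (
	"log"

	"github.com/Shopify/sarama"
)

func main() {
	config := sarama.NewConfig()
	// Requirements enforced by the new Config.Validate() checks:
	config.Version = sarama.V0_11_0_0 // idempotence needs Kafka >= 0.11
	config.Producer.Idempotent = true
	config.Producer.RequiredAcks = sarama.WaitForAll
	config.Net.MaxOpenRequests = 1
	config.Producer.Retry.Max = 5 // must be >= 1
	config.Producer.Return.Successes = true

	producer, err := sarama.NewAsyncProducer([]string{"localhost:9092"}, config)
	if err != nil {
		log.Fatalln(err)
	}
	defer producer.AsyncClose()

	producer.Input() <- &sarama.ProducerMessage{Topic: "my_topic", Value: sarama.StringEncoder("hello")}
	select {
	case <-producer.Successes():
	case err := <-producer.Errors():
		log.Println(err)
	}
}
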
+ 161 - 47
async_producer.go

@@ -47,6 +47,50 @@ type AsyncProducer interface {
 	Errors() <-chan *ProducerError
 }
 
+// transactionManager keeps the state necessary to ensure idempotent production
+type transactionManager struct {
+	producerID      int64
+	producerEpoch   int16
+	sequenceNumbers map[string]int32
+	mutex           sync.Mutex
+}
+
+const (
+	noProducerID    = -1
+	noProducerEpoch = -1
+)
+
+func (t *transactionManager) getAndIncrementSequenceNumber(topic string, partition int32) int32 {
+	key := fmt.Sprintf("%s-%d", topic, partition)
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	sequence := t.sequenceNumbers[key]
+	t.sequenceNumbers[key] = sequence + 1
+	return sequence
+}
+
+func newTransactionManager(conf *Config, client Client) (*transactionManager, error) {
+	txnmgr := &transactionManager{
+		producerID:    noProducerID,
+		producerEpoch: noProducerEpoch,
+	}
+
+	if conf.Producer.Idempotent {
+		initProducerIDResponse, err := client.InitProducerID()
+		if err != nil {
+			return nil, err
+		}
+		txnmgr.producerID = initProducerIDResponse.ProducerID
+		txnmgr.producerEpoch = initProducerIDResponse.ProducerEpoch
+		txnmgr.sequenceNumbers = make(map[string]int32)
+		txnmgr.mutex = sync.Mutex{}
+
+		Logger.Printf("Obtained a ProducerId: %d and ProducerEpoch: %d\n", txnmgr.producerID, txnmgr.producerEpoch)
+	}
+
+	return txnmgr, nil
+}
+
 type asyncProducer struct {
 	client    Client
 	conf      *Config
@@ -56,9 +100,11 @@ type asyncProducer struct {
 	input, successes, retries chan *ProducerMessage
 	inFlight                  sync.WaitGroup
 
-	brokers    map[*Broker]chan<- *ProducerMessage
-	brokerRefs map[chan<- *ProducerMessage]int
+	brokers    map[*Broker]*brokerProducer
+	brokerRefs map[*brokerProducer]int
 	brokerLock sync.Mutex
+
+	txnmgr *transactionManager
 }
 
 // NewAsyncProducer creates a new AsyncProducer using the given broker addresses and configuration.
@@ -84,6 +130,11 @@ func NewAsyncProducerFromClient(client Client) (AsyncProducer, error) {
 		return nil, ErrClosedClient
 	}
 
+	txnmgr, err := newTransactionManager(client.Config(), client)
+	if err != nil {
+		return nil, err
+	}
+
 	p := &asyncProducer{
 		client:     client,
 		conf:       client.Config(),
@@ -91,8 +142,9 @@ func NewAsyncProducerFromClient(client Client) (AsyncProducer, error) {
 		input:      make(chan *ProducerMessage),
 		successes:  make(chan *ProducerMessage),
 		retries:    make(chan *ProducerMessage),
-		brokers:    make(map[*Broker]chan<- *ProducerMessage),
-		brokerRefs: make(map[chan<- *ProducerMessage]int),
+		brokers:    make(map[*Broker]*brokerProducer),
+		brokerRefs: make(map[*brokerProducer]int),
+		txnmgr:     txnmgr,
 	}
 
 	// launch our singleton dispatchers
@@ -145,9 +197,10 @@ type ProducerMessage struct {
 	// least version 0.10.0.
 	Timestamp time.Time
 
-	retries     int
-	flags       flagSet
-	expectation chan *ProducerError
+	retries        int
+	flags          flagSet
+	expectation    chan *ProducerError
+	sequenceNumber int32
 }
 
 const producerMessageOverhead = 26 // the metadata overhead of CRC, flags, etc.
@@ -328,6 +381,10 @@ func (tp *topicProducer) dispatch() {
 				continue
 			}
 		}
+		// All messages being retried (sent or not) have already had their retry count updated
+		if tp.parent.conf.Producer.Idempotent && msg.retries == 0 {
+			msg.sequenceNumber = tp.parent.txnmgr.getAndIncrementSequenceNumber(msg.Topic, msg.Partition)
+		}
 
 		handler := tp.handlers[msg.Partition]
 		if handler == nil {
@@ -394,9 +451,9 @@ type partitionProducer struct {
 	partition int32
 	input     <-chan *ProducerMessage
 
-	leader  *Broker
-	breaker *breaker.Breaker
-	output  chan<- *ProducerMessage
+	leader         *Broker
+	breaker        *breaker.Breaker
+	brokerProducer *brokerProducer
 
 	// highWatermark tracks the "current" retry level, which is the only one where we actually let messages through,
 	// all other messages get buffered in retryState[msg.retries].buf to preserve ordering
@@ -431,9 +488,9 @@ func (pp *partitionProducer) dispatch() {
 	// on the first message
 	pp.leader, _ = pp.parent.client.Leader(pp.topic, pp.partition)
 	if pp.leader != nil {
-		pp.output = pp.parent.getBrokerProducer(pp.leader)
+		pp.brokerProducer = pp.parent.getBrokerProducer(pp.leader)
 		pp.parent.inFlight.Add(1) // we're generating a syn message; track it so we don't shut down while it's still inflight
-		pp.output <- &ProducerMessage{Topic: pp.topic, Partition: pp.partition, flags: syn}
+		pp.brokerProducer.input <- &ProducerMessage{Topic: pp.topic, Partition: pp.partition, flags: syn}
 	}
 
 	for msg := range pp.input {
@@ -465,7 +522,7 @@ func (pp *partitionProducer) dispatch() {
 		// if we made it this far then the current msg contains real data, and can be sent to the next goroutine
 		// without breaking any of our ordering guarantees
 
-		if pp.output == nil {
+		if pp.brokerProducer == nil {
 			if err := pp.updateLeader(); err != nil {
 				pp.parent.returnError(msg, err)
 				time.Sleep(pp.parent.conf.Producer.Retry.Backoff)
@@ -474,11 +531,11 @@ func (pp *partitionProducer) dispatch() {
 			Logger.Printf("producer/leader/%s/%d selected broker %d\n", pp.topic, pp.partition, pp.leader.ID())
 			Logger.Printf("producer/leader/%s/%d selected broker %d\n", pp.topic, pp.partition, pp.leader.ID())
 		}
 		}
 
 
-		pp.output <- msg
+		pp.brokerProducer.input <- msg
 	}
 
-	if pp.output != nil {
-		pp.parent.unrefBrokerProducer(pp.leader, pp.output)
+	if pp.brokerProducer != nil {
+		pp.parent.unrefBrokerProducer(pp.leader, pp.brokerProducer)
 	}
 }
 
@@ -490,12 +547,12 @@ func (pp *partitionProducer) newHighWatermark(hwm int) {
 	// back to us and we can safely flush the backlog (otherwise we risk re-ordering messages)
 	pp.retryState[pp.highWatermark].expectChaser = true
 	pp.parent.inFlight.Add(1) // we're generating a fin message; track it so we don't shut down while it's still inflight
-	pp.output <- &ProducerMessage{Topic: pp.topic, Partition: pp.partition, flags: fin, retries: pp.highWatermark - 1}
+	pp.brokerProducer.input <- &ProducerMessage{Topic: pp.topic, Partition: pp.partition, flags: fin, retries: pp.highWatermark - 1}
 
 	// a new HWM means that our current broker selection is out of date
 	Logger.Printf("producer/leader/%s/%d abandoning broker %d\n", pp.topic, pp.partition, pp.leader.ID())
-	pp.parent.unrefBrokerProducer(pp.leader, pp.output)
-	pp.output = nil
+	pp.parent.unrefBrokerProducer(pp.leader, pp.brokerProducer)
+	pp.brokerProducer = nil
 }
 
 func (pp *partitionProducer) flushRetryBuffers() {
@@ -503,7 +560,7 @@ func (pp *partitionProducer) flushRetryBuffers() {
 	for {
 		pp.highWatermark--
 
-		if pp.output == nil {
+		if pp.brokerProducer == nil {
 			if err := pp.updateLeader(); err != nil {
 				pp.parent.returnErrors(pp.retryState[pp.highWatermark].buf, err)
 				goto flushDone
@@ -512,7 +569,7 @@ func (pp *partitionProducer) flushRetryBuffers() {
 		}
 
 		for _, msg := range pp.retryState[pp.highWatermark].buf {
-			pp.output <- msg
+			pp.brokerProducer.input <- msg
 		}
 
 	flushDone:
@@ -537,16 +594,16 @@ func (pp *partitionProducer) updateLeader() error {
 			return err
 		}
 
-		pp.output = pp.parent.getBrokerProducer(pp.leader)
+		pp.brokerProducer = pp.parent.getBrokerProducer(pp.leader)
 		pp.parent.inFlight.Add(1) // we're generating a syn message; track it so we don't shut down while it's still inflight
-		pp.output <- &ProducerMessage{Topic: pp.topic, Partition: pp.partition, flags: syn}
+		pp.brokerProducer.input <- &ProducerMessage{Topic: pp.topic, Partition: pp.partition, flags: syn}
 
 		return nil
 	})
 }
 
 // one per broker; also constructs an associated flusher
-func (p *asyncProducer) newBrokerProducer(broker *Broker) chan<- *ProducerMessage {
+func (p *asyncProducer) newBrokerProducer(broker *Broker) *brokerProducer {
 	var (
 		input     = make(chan *ProducerMessage)
 		bridge    = make(chan *produceSet)
@@ -580,7 +637,7 @@ func (p *asyncProducer) newBrokerProducer(broker *Broker) chan<- *ProducerMessag
 		close(responses)
 	})
 
-	return input
+	return bp
 }
 
 type brokerProducerResponse struct {
@@ -595,7 +652,7 @@ type brokerProducer struct {
 	parent *asyncProducer
 	broker *Broker
 
-	input     <-chan *ProducerMessage
+	input     chan *ProducerMessage
 	output    chan<- *produceSet
 	responses <-chan *brokerProducerResponse
 
@@ -740,16 +797,17 @@ func (bp *brokerProducer) handleResponse(response *brokerProducerResponse) {
 func (bp *brokerProducer) handleSuccess(sent *produceSet, response *ProduceResponse) {
 	// we iterate through the blocks in the request set, not the response, so that we notice
 	// if the response is missing a block completely
-	sent.eachPartition(func(topic string, partition int32, msgs []*ProducerMessage) {
+	var retryTopics []string
+	sent.eachPartition(func(topic string, partition int32, pSet *partitionSet) {
 		if response == nil {
 			// this only happens when RequiredAcks is NoResponse, so we have to assume success
-			bp.parent.returnSuccesses(msgs)
+			bp.parent.returnSuccesses(pSet.msgs)
 			return
 		}
 
 		block := response.GetBlock(topic, partition)
 		if block == nil {
-			bp.parent.returnErrors(msgs, ErrIncompleteResponse)
+			bp.parent.returnErrors(pSet.msgs, ErrIncompleteResponse)
 			return
 		}
 
@@ -757,45 +815,101 @@ func (bp *brokerProducer) handleSuccess(sent *produceSet, response *ProduceRespo
 		// Success
 		case ErrNoError:
 			if bp.parent.conf.Version.IsAtLeast(V0_10_0_0) && !block.Timestamp.IsZero() {
-				for _, msg := range msgs {
+				for _, msg := range pSet.msgs {
 					msg.Timestamp = block.Timestamp
 				}
 			}
-			for i, msg := range msgs {
+			for i, msg := range pSet.msgs {
 				msg.Offset = block.Offset + int64(i)
 			}
-			bp.parent.returnSuccesses(msgs)
+			bp.parent.returnSuccesses(pSet.msgs)
+		// Duplicate
+		case ErrDuplicateSequenceNumber:
+			bp.parent.returnSuccesses(pSet.msgs)
 		// Retriable errors
 		case ErrInvalidMessage, ErrUnknownTopicOrPartition, ErrLeaderNotAvailable, ErrNotLeaderForPartition,
 			ErrRequestTimedOut, ErrNotEnoughReplicas, ErrNotEnoughReplicasAfterAppend:
-			Logger.Printf("producer/broker/%d state change to [retrying] on %s/%d because %v\n",
-				bp.broker.ID(), topic, partition, block.Err)
-			bp.currentRetries[topic][partition] = block.Err
-			bp.parent.retryMessages(msgs, block.Err)
-			bp.parent.retryMessages(bp.buffer.dropPartition(topic, partition), block.Err)
+			retryTopics = append(retryTopics, topic)
 		// Other non-retriable errors
 		default:
-			bp.parent.returnErrors(msgs, block.Err)
+			bp.parent.returnErrors(pSet.msgs, block.Err)
 		}
 	})
+
+	if len(retryTopics) > 0 {
+		err := bp.parent.client.RefreshMetadata(retryTopics...)
+		if err != nil {
+			Logger.Printf("Failed refreshing metadata because of %v\n", err)
+		}
+
+		sent.eachPartition(func(topic string, partition int32, pSet *partitionSet) {
+			block := response.GetBlock(topic, partition)
+			if block == nil {
+				// handled in the previous "eachPartition" loop
+				return
+			}
+
+			switch block.Err {
+			case ErrInvalidMessage, ErrUnknownTopicOrPartition, ErrLeaderNotAvailable, ErrNotLeaderForPartition,
+				ErrRequestTimedOut, ErrNotEnoughReplicas, ErrNotEnoughReplicasAfterAppend:
+				Logger.Printf("producer/broker/%d state change to [retrying] on %s/%d because %v\n",
+					bp.broker.ID(), topic, partition, block.Err)
+				if bp.currentRetries[topic] == nil {
+					bp.currentRetries[topic] = make(map[int32]error)
+				}
+				bp.currentRetries[topic][partition] = block.Err
+				// dropping the following messages has the side effect of incrementing their retry count
+				bp.parent.retryMessages(bp.buffer.dropPartition(topic, partition), block.Err)
+				bp.parent.retryBatch(topic, partition, pSet, block.Err)
+			}
+		})
+	}
+}
+
+func (p *asyncProducer) retryBatch(topic string, partition int32, pSet *partitionSet, kerr KError) {
+	Logger.Printf("Retrying batch for %v-%d because of %s\n", topic, partition, kerr)
+	produceSet := newProduceSet(p)
+	produceSet.msgs[topic] = make(map[int32]*partitionSet)
+	produceSet.msgs[topic][partition] = pSet
+	produceSet.bufferBytes += pSet.bufferBytes
+	produceSet.bufferCount += len(pSet.msgs)
+	for _, msg := range pSet.msgs {
+		if msg.retries >= p.conf.Producer.Retry.Max {
+			p.returnError(msg, kerr)
+			return
+		}
+		msg.retries++
+	}
+
+	// it's expected that a metadata refresh has been requested prior to calling retryBatch
+	leader, err := p.client.Leader(topic, partition)
+	if err != nil {
+		Logger.Printf("Failed retrying batch for %v-%d because of %v while looking up for new leader\n", topic, partition, err)
+		for _, msg := range pSet.msgs {
+			p.returnError(msg, kerr)
+		}
+		return
+	}
+	bp := p.getBrokerProducer(leader)
+	bp.output <- produceSet
 }
 }
 
 func (bp *brokerProducer) handleError(sent *produceSet, err error) {
 	switch err.(type) {
 	case PacketEncodingError:
-			bp.parent.returnErrors(msgs, err)
+		sent.eachPartition(func(topic string, partition int32, pSet *partitionSet) {
+			bp.parent.returnErrors(pSet.msgs, err)
 		})
 		})
 	default:
 		Logger.Printf("producer/broker/%d state change to [closing] because %s\n", bp.broker.ID(), err)
 		bp.parent.abandonBrokerConnection(bp.broker)
 		_ = bp.broker.Close()
 		bp.closing = err
-			bp.parent.retryMessages(msgs, err)
+		sent.eachPartition(func(topic string, partition int32, pSet *partitionSet) {
+			bp.parent.retryMessages(pSet.msgs, err)
 		})
 		})
-			bp.parent.retryMessages(msgs, err)
+		bp.buffer.eachPartition(func(topic string, partition int32, pSet *partitionSet) {
+			bp.parent.retryMessages(pSet.msgs, err)
 		})
 		})
 		bp.rollOver()
 	}
 	}
 	}
 }
-func (p *asyncProducer) getBrokerProducer(broker *Broker) chan<- *ProducerMessage {
+func (p *asyncProducer) getBrokerProducer(broker *Broker) *brokerProducer {
 	p.brokerLock.Lock()
 	p.brokerLock.Lock()
 	defer p.brokerLock.Unlock()
@@ -909,13 +1023,13 @@ func (p *asyncProducer) getBrokerProducer(broker *Broker) chan<- *ProducerMessag
 	return bp
 	return bp
 }
-func (p *asyncProducer) unrefBrokerProducer(broker *Broker, bp chan<- *ProducerMessage) {
+func (p *asyncProducer) unrefBrokerProducer(broker *Broker, bp *brokerProducer) {
 	p.brokerLock.Lock()
 	p.brokerLock.Lock()
 	defer p.brokerLock.Unlock()
 
 	p.brokerRefs[bp]--
 	if p.brokerRefs[bp] == 0 {
+		close(bp.input)
 		delete(p.brokerRefs, bp)
 		delete(p.brokerRefs, bp)
 		if p.brokers[broker] == bp {
 		if p.brokers[broker] == bp {

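The per topic/partition sequence counter added above is the heart of the change: the first attempt of every message gets the next sequence for its partition, and retried batches keep the sequence they were first assigned. A tiny illustrative sketch, written as hypothetical in-package test code (transactionManager is unexported, so this would have to live in a _test.go file inside the sarama package):

func TestSequenceNumberSketch(t *testing.T) {
	txnmgr := &transactionManager{sequenceNumbers: map[string]int32{}}
	if got := txnmgr.getAndIncrementSequenceNumber("my_topic", 0); got != 0 {
		t.Errorf("want 0, got %d", got)
	}
	if got := txnmgr.getAndIncrementSequenceNumber("my_topic", 0); got != 1 {
		t.Errorf("want 1, got %d", got)
	}
	// Each topic/partition pair keeps an independent counter.
	if got := txnmgr.getAndIncrementSequenceNumber("my_topic", 1); got != 0 {
		t.Errorf("want 0, got %d", got)
	}
}
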
+ 253 - 0
async_producer_test.go

@@ -300,6 +300,7 @@ func TestAsyncProducerFailureRetry(t *testing.T) {
 	for i := 0; i < 10; i++ {
 		producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)}
 	}
+	leader2.Returns(metadataLeader2)
 	leader2.Returns(prodSuccess)
 	expectResults(t, producer, 10, 0)
 
@@ -459,6 +460,7 @@ func TestAsyncProducerMultipleRetries(t *testing.T) {
 	metadataLeader2 := new(MetadataResponse)
 	metadataLeader2.AddBroker(leader2.Addr(), leader2.BrokerID())
 	metadataLeader2.AddTopicPartition("my_topic", 0, leader2.BrokerID(), nil, nil, ErrNoError)
+
 	seedBroker.Returns(metadataLeader2)
 	leader2.Returns(prodNotLeader)
 	seedBroker.Returns(metadataLeader1)
@@ -466,6 +468,7 @@ func TestAsyncProducerMultipleRetries(t *testing.T) {
 	seedBroker.Returns(metadataLeader1)
 	leader1.Returns(prodNotLeader)
 	seedBroker.Returns(metadataLeader2)
+	seedBroker.Returns(metadataLeader2)
 
 	prodSuccess := new(ProduceResponse)
 	prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError)
@@ -651,6 +654,7 @@ func TestAsyncProducerFlusherRetryCondition(t *testing.T) {
 
 	// succeed this time
 	expectResults(t, producer, 5, 0)
+	seedBroker.Returns(metadataResponse)
 
 	// put five more through
 	for i := 0; i < 5; i++ {
@@ -753,6 +757,255 @@ func TestAsyncProducerNoReturns(t *testing.T) {
 	leader.Close()
 }
 
+func TestAsyncProducerIdempotentGoldenPath(t *testing.T) {
+	broker := NewMockBroker(t, 1)
+
+	metadataResponse := &MetadataResponse{
+		Version:      1,
+		ControllerID: 1,
+	}
+	metadataResponse.AddBroker(broker.Addr(), broker.BrokerID())
+	metadataResponse.AddTopicPartition("my_topic", 0, broker.BrokerID(), nil, nil, ErrNoError)
+	broker.Returns(metadataResponse)
+
+	initProducerID := &InitProducerIDResponse{
+		ThrottleTime:  0,
+		ProducerID:    1000,
+		ProducerEpoch: 1,
+	}
+	broker.Returns(initProducerID)
+
+	config := NewConfig()
+	config.Producer.Flush.Messages = 10
+	config.Producer.Return.Successes = true
+	config.Producer.Retry.Max = 4
+	config.Producer.RequiredAcks = WaitForAll
+	config.Producer.Retry.Backoff = 0
+	config.Producer.Idempotent = true
+	config.Net.MaxOpenRequests = 1
+	config.Version = V0_11_0_0
+	producer, err := NewAsyncProducer([]string{broker.Addr()}, config)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for i := 0; i < 10; i++ {
+		producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)}
+	}
+
+	prodSuccess := &ProduceResponse{
+		Version:      3,
+		ThrottleTime: 0,
+	}
+	prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError)
+	broker.Returns(prodSuccess)
+	expectResults(t, producer, 10, 0)
+
+	broker.Close()
+	closeProducer(t, producer)
+}
+
+func TestAsyncProducerIdempotentRetryCheckBatch(t *testing.T) {
+	//Logger = log.New(os.Stderr, "", log.LstdFlags)
+	tests := []struct {
+		name           string
+		failAfterWrite bool
+	}{
+		{"FailAfterWrite", true},
+		{"FailBeforeWrite", false},
+	}
+
+	for _, test := range tests {
+		broker := NewMockBroker(t, 1)
+
+		metadataResponse := &MetadataResponse{
+			Version:      1,
+			ControllerID: 1,
+		}
+		metadataResponse.AddBroker(broker.Addr(), broker.BrokerID())
+		metadataResponse.AddTopicPartition("my_topic", 0, broker.BrokerID(), nil, nil, ErrNoError)
+
+		initProducerIDResponse := &InitProducerIDResponse{
+			ThrottleTime:  0,
+			ProducerID:    1000,
+			ProducerEpoch: 1,
+		}
+
+		prodNotLeaderResponse := &ProduceResponse{
+			Version:      3,
+			ThrottleTime: 0,
+		}
+		prodNotLeaderResponse.AddTopicPartition("my_topic", 0, ErrNotEnoughReplicas)
+
+		prodDuplicate := &ProduceResponse{
+			Version:      3,
+			ThrottleTime: 0,
+		}
+		prodDuplicate.AddTopicPartition("my_topic", 0, ErrDuplicateSequenceNumber)
+
+		prodOutOfSeq := &ProduceResponse{
+			Version:      3,
+			ThrottleTime: 0,
+		}
+		prodOutOfSeq.AddTopicPartition("my_topic", 0, ErrOutOfOrderSequenceNumber)
+
+		prodSuccessResponse := &ProduceResponse{
+			Version:      3,
+			ThrottleTime: 0,
+		}
+		prodSuccessResponse.AddTopicPartition("my_topic", 0, ErrNoError)
+
+		prodCounter := 0
+		lastBatchFirstSeq := -1
+		lastBatchSize := -1
+		lastSequenceWrittenToDisk := -1
+		handlerFailBeforeWrite := func(req *request) (res encoder) {
+			switch req.body.key() {
+			case 3:
+				return metadataResponse
+			case 22:
+				return initProducerIDResponse
+			case 0:
+				prodCounter++
+
+				preq := req.body.(*ProduceRequest)
+				batch := preq.records["my_topic"][0].RecordBatch
+				batchFirstSeq := int(batch.FirstSequence)
+				batchSize := len(batch.Records)
+
+				if lastSequenceWrittenToDisk == batchFirstSeq-1 { //in sequence append
+
+					if lastBatchFirstSeq == batchFirstSeq { //is a batch retry
+						if lastBatchSize == batchSize { //good retry
+							// mock write to disk
+							lastSequenceWrittenToDisk = batchFirstSeq + batchSize - 1
+							return prodSuccessResponse
+						}
+						t.Errorf("[%s] Retried Batch firstSeq=%d with different size old=%d new=%d", test.name, batchFirstSeq, lastBatchSize, batchSize)
+						return prodOutOfSeq
+					} else { // not a retry
+						// save batch just received for future check
+						lastBatchFirstSeq = batchFirstSeq
+						lastBatchSize = batchSize
+
+						if prodCounter%2 == 1 {
+							if test.failAfterWrite {
+								// mock write to disk
+								lastSequenceWrittenToDisk = batchFirstSeq + batchSize - 1
+							}
+							return prodNotLeaderResponse
+						}
+						// mock write to disk
+						lastSequenceWrittenToDisk = batchFirstSeq + batchSize - 1
+						return prodSuccessResponse
+					}
+				} else {
+					if lastBatchFirstSeq == batchFirstSeq && lastBatchSize == batchSize { // is a good batch retry
+						if lastSequenceWrittenToDisk == (batchFirstSeq + batchSize - 1) { // we already have the messages
+							return prodDuplicate
+						}
+						// mock write to disk
+						lastSequenceWrittenToDisk = batchFirstSeq + batchSize - 1
+						return prodSuccessResponse
+					} else { //out of sequence / bad retried batch
+						if lastBatchFirstSeq == batchFirstSeq && lastBatchSize != batchSize {
+							t.Errorf("[%s] Retried Batch firstSeq=%d with different size old=%d new=%d", test.name, batchFirstSeq, lastBatchSize, batchSize)
+						} else if lastSequenceWrittenToDisk+1 != batchFirstSeq {
+							t.Errorf("[%s] Out of sequence message lastSequence=%d new batch starts at=%d", test.name, lastSequenceWrittenToDisk, batchFirstSeq)
+						} else {
+							t.Errorf("[%s] Unexpected error", test.name)
+						}
+
+						return prodOutOfSeq
+					}
+				}
+
+			}
+			return nil
+		}
+
+		config := NewConfig()
+		config.Version = V0_11_0_0
+		config.Producer.Idempotent = true
+		config.Net.MaxOpenRequests = 1
+		config.Producer.RequiredAcks = WaitForAll
+		config.Producer.Return.Successes = true
+		config.Producer.Flush.Frequency = 50 * time.Millisecond
+		config.Producer.Retry.Backoff = 100 * time.Millisecond
+
+		broker.setHandler(handlerFailBeforeWrite)
+		producer, err := NewAsyncProducer([]string{broker.Addr()}, config)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		for i := 0; i < 3; i++ {
+			producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)}
+		}
+
+		go func() {
+			for i := 0; i < 7; i++ {
+				producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder("goroutine")}
+				time.Sleep(100 * time.Millisecond)
+			}
+		}()
+
+		expectResults(t, producer, 10, 0)
+
+		broker.Close()
+		closeProducer(t, producer)
+	}
+}
+
+func TestAsyncProducerIdempotentErrorOnOutOfSeq(t *testing.T) {
+	broker := NewMockBroker(t, 1)
+
+	metadataResponse := &MetadataResponse{
+		Version:      1,
+		ControllerID: 1,
+	}
+	metadataResponse.AddBroker(broker.Addr(), broker.BrokerID())
+	metadataResponse.AddTopicPartition("my_topic", 0, broker.BrokerID(), nil, nil, ErrNoError)
+	broker.Returns(metadataResponse)
+
+	initProducerID := &InitProducerIDResponse{
+		ThrottleTime:  0,
+		ProducerID:    1000,
+		ProducerEpoch: 1,
+	}
+	broker.Returns(initProducerID)
+
+	config := NewConfig()
+	config.Producer.Flush.Messages = 10
+	config.Producer.Return.Successes = true
+	config.Producer.Retry.Max = 400000
+	config.Producer.RequiredAcks = WaitForAll
+	config.Producer.Retry.Backoff = 0
+	config.Producer.Idempotent = true
+	config.Net.MaxOpenRequests = 1
+	config.Version = V0_11_0_0
+
+	producer, err := NewAsyncProducer([]string{broker.Addr()}, config)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for i := 0; i < 10; i++ {
+		producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)}
+	}
+
+	prodOutOfSeq := &ProduceResponse{
+		Version:      3,
+		ThrottleTime: 0,
+	}
+	prodOutOfSeq.AddTopicPartition("my_topic", 0, ErrOutOfOrderSequenceNumber)
+	broker.Returns(prodOutOfSeq)
+	expectResults(t, producer, 0, 10)
+
+	broker.Close()
+	closeProducer(t, producer)
+}
+
 // This example shows how to use the producer while simultaneously
 // reading the Errors channel to know about any failures.
 func ExampleAsyncProducer_select() {

+ 24 - 1
client.go

@@ -67,6 +67,9 @@ type Client interface {
 	// in local cache. This function only works on Kafka 0.8.2 and higher.
 	RefreshCoordinator(consumerGroup string) error
 
+	// InitProducerID retrieves information required for Idempotent Producer
+	InitProducerID() (*InitProducerIDResponse, error)
+
 	// Close shuts down all broker connections managed by this client. It is required
 	// to call this function before a client object passes out of scope, as it will
 	// otherwise leak memory. You must close any Producers or Consumers using a client
@@ -183,6 +186,26 @@ func (client *client) Brokers() []*Broker {
 	return brokers
 }
 
+func (client *client) InitProducerID() (*InitProducerIDResponse, error) {
+	var err error
+	for broker := client.any(); broker != nil; broker = client.any() {
+
+		req := &InitProducerIDRequest{}
+
+		response, err := broker.InitProducerID(req)
+		switch err.(type) {
+		case nil:
+			return response, nil
+		default:
+			// some error, remove that broker and try again
+			Logger.Printf("Client got error from broker %d when issuing InitProducerID : %v\n", broker.ID(), err)
+			_ = broker.Close()
+			client.deregisterBroker(broker)
+		}
+	}
+	return nil, err
+}
+
 func (client *client) Close() error {
 	if client.Closed() {
 		// Chances are this is being called from a defer() and the error will go unobserved
@@ -723,7 +746,7 @@ func (client *client) tryRefreshMetadata(topics []string, attemptsRemaining int)
 			return err
 		default:
 			// some other error, remove that broker and try again
-			Logger.Println("client/metadata got error from broker while fetching metadata:", err)
+			Logger.Printf("client/metadata got error from broker %d while fetching metadata: %v\n", broker.ID(), err)
 			_ = broker.Close()
 			client.deregisterBroker(broker)
 		}

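A brief, hypothetical caller sketch for the Client.InitProducerID method added to the interface above (broker address is a placeholder):

package main

import (
	"log"

	"github.com/Shopify/sarama"
)

func main() {
	config := sarama.NewConfig()
	config.Version = sarama.V0_11_0_0 // InitProducerID needs Kafka >= 0.11

	client, err := sarama.NewClient([]string{"localhost:9092"}, config)
	if err != nil {
		log.Fatalln(err)
	}
	defer client.Close()

	// Ask any live broker for a producer ID and epoch.
	resp, err := client.InitProducerID()
	if err != nil {
		log.Fatalln(err)
	}
	log.Printf("ProducerID=%d ProducerEpoch=%d", resp.ProducerID, resp.ProducerEpoch)
}
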
+ 18 - 0
config.go

@@ -124,6 +124,9 @@ type Config struct {
 		// (defaults to hashing the message key). Similar to the `partitioner.class`
 		// setting for the JVM producer.
 		Partitioner PartitionerConstructor
+		// If enabled, the producer will ensure that exactly one copy of each message is
+		// written.
+		Idempotent bool
 
 		// Return specifies what channels will be populated. If they are set to true,
 		// you must read from the respective channels to prevent deadlock. If,
@@ -511,6 +514,21 @@ func (c *Config) Validate() error {
 		}
 	}
 
+	if c.Producer.Idempotent {
+		if !c.Version.IsAtLeast(V0_11_0_0) {
+			return ConfigurationError("Idempotent producer requires Version >= V0_11_0_0")
+		}
+		if c.Producer.Retry.Max == 0 {
+			return ConfigurationError("Idempotent producer requires Producer.Retry.Max >= 1")
+		}
+		if c.Producer.RequiredAcks != WaitForAll {
+			return ConfigurationError("Idempotent producer requires Producer.RequiredAcks to be WaitForAll")
+		}
+		if c.Net.MaxOpenRequests > 1 {
+			return ConfigurationError("Idempotent producer requires Net.MaxOpenRequests to be 1")
+		}
+	}
+
 	// validate the Consumer values
 	switch {
 	case c.Consumer.Fetch.Min <= 0:

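The new Validate() checks fail fast on incompatible settings. A small sketch of what a misconfigured producer sees, assuming the library defaults of WaitForLocal acks and 5 max open requests:

package main

import (
	"fmt"

	"github.com/Shopify/sarama"
)

func main() {
	config := sarama.NewConfig()
	config.Version = sarama.V0_11_0_0
	config.Producer.Idempotent = true
	// RequiredAcks and Net.MaxOpenRequests are left at their defaults, so:
	if err := config.Validate(); err != nil {
		fmt.Println(err)
		// prints: Idempotent producer requires Producer.RequiredAcks to be WaitForAll
	}
}
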
+ 26 - 0
config_test.go

@@ -207,6 +207,32 @@ func TestProducerConfigValidates(t *testing.T) {
 				cfg.Producer.Retry.Backoff = -1
 			},
 			"Producer.Retry.Backoff must be >= 0"},
+		{"Idempotent Version",
+			func(cfg *Config) {
+				cfg.Producer.Idempotent = true
+				cfg.Version = V0_10_0_0
+			},
+			"Idempotent producer requires Version >= V0_11_0_0"},
+		{"Idempotent with Producer.Retry.Max",
+			func(cfg *Config) {
+				cfg.Version = V0_11_0_0
+				cfg.Producer.Idempotent = true
+				cfg.Producer.Retry.Max = 0
+			},
+			"Idempotent producer requires Producer.Retry.Max >= 1"},
+		{"Idempotent with Producer.RequiredAcks",
+			func(cfg *Config) {
+				cfg.Version = V0_11_0_0
+				cfg.Producer.Idempotent = true
+			},
+			"Idempotent producer requires Producer.RequiredAcks to be WaitForAll"},
+		{"Idempotent with Net.MaxOpenRequests",
+			func(cfg *Config) {
+				cfg.Version = V0_11_0_0
+				cfg.Producer.Idempotent = true
+				cfg.Producer.RequiredAcks = WaitForAll
+			},
+			"Idempotent producer requires Net.MaxOpenRequests to be 1"},
 	}
 
 	for i, test := range tests {

+ 21 - 3
functional_consumer_test.go

@@ -81,7 +81,7 @@ func TestVersionMatrix(t *testing.T) {
 	// protocol versions and compressions for the except of LZ4.
 	testVersions := versionRange(V0_8_2_0)
 	allCodecsButLZ4 := []CompressionCodec{CompressionNone, CompressionGZIP, CompressionSnappy}
-	producedMessages := produceMsgs(t, testVersions, allCodecsButLZ4, 17, 100)
+	producedMessages := produceMsgs(t, testVersions, allCodecsButLZ4, 17, 100, false)
 
 	// When/Then
 	consumeMsgs(t, testVersions, producedMessages)
@@ -98,7 +98,20 @@ func TestVersionMatrixLZ4(t *testing.T) {
 	// and all possible compressions.
 	testVersions := versionRange(V0_10_0_0)
 	allCodecs := []CompressionCodec{CompressionNone, CompressionGZIP, CompressionSnappy, CompressionLZ4}
-	producedMessages := produceMsgs(t, testVersions, allCodecs, 17, 100)
+	producedMessages := produceMsgs(t, testVersions, allCodecs, 17, 100, false)
+
+	// When/Then
+	consumeMsgs(t, testVersions, producedMessages)
+}
+
+func TestVersionMatrixIdempotent(t *testing.T) {
+	setupFunctionalTest(t)
+	defer teardownFunctionalTest(t)
+
+	// Produce lot's of message with all possible combinations of supported
+	// protocol versions starting with v0.11 (first where idempotent was supported)
+	testVersions := versionRange(V0_11_0_0)
+	producedMessages := produceMsgs(t, testVersions, []CompressionCodec{CompressionNone}, 17, 100, true)
 
 
 	// When/Then
 	// When/Then
 	consumeMsgs(t, testVersions, producedMessages)
 	consumeMsgs(t, testVersions, producedMessages)
@@ -133,7 +146,7 @@ func versionRange(lower KafkaVersion) []KafkaVersion {
 	return versions
 	return versions
 }
-func produceMsgs(t *testing.T, clientVersions []KafkaVersion, codecs []CompressionCodec, flush int, countPerVerCodec int) []*ProducerMessage {
+func produceMsgs(t *testing.T, clientVersions []KafkaVersion, codecs []CompressionCodec, flush int, countPerVerCodec int, idempotent bool) []*ProducerMessage {
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
 	var producedMessagesMu sync.Mutex
 	var producedMessages []*ProducerMessage
 			prodCfg.Producer.Return.Errors = true
 			prodCfg.Producer.Return.Errors = true
 			prodCfg.Producer.Flush.MaxMessages = flush
 			prodCfg.Producer.Compression = codec
+			if idempotent {
+				prodCfg.Producer.RequiredAcks = WaitForAll
+				prodCfg.Net.MaxOpenRequests = 1
+			}
 
 
 			p, err := NewSyncProducer(kafkaBrokers, prodCfg)
 			p, err := NewSyncProducer(kafkaBrokers, prodCfg)
 			if err != nil {
 			if err != nil {

+ 7 - 1
produce_response.go

@@ -179,5 +179,11 @@ func (r *ProduceResponse) AddTopicPartition(topic string, partition int32, err K
 		byTopic = make(map[int32]*ProduceResponseBlock)
 		byTopic = make(map[int32]*ProduceResponseBlock)
 		r.Blocks[topic] = byTopic
 	}
+	block := &ProduceResponseBlock{
+		Err: err,
+	}
+	if r.Version >= 2 {
+		block.Timestamp = time.Now()
+	}
+	byTopic[partition] = block
 }

+ 14 - 7
produce_set.go

@@ -2,6 +2,7 @@ package sarama
 
 import (
 	"encoding/binary"
+	"errors"
 	"time"
 	"time"
 )
 )
 
 
@@ -61,9 +62,13 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
 			batch := &RecordBatch{
 				FirstTimestamp:   timestamp,
 				Version:          2,
-				ProducerID:       -1, /* No producer id */
 				Codec:            ps.parent.conf.Producer.Compression,
 				CompressionLevel: ps.parent.conf.Producer.CompressionLevel,
+				ProducerID:       ps.parent.txnmgr.producerID,
+				ProducerEpoch:    ps.parent.txnmgr.producerEpoch,
+			}
+			if ps.parent.conf.Producer.Idempotent {
+				batch.FirstSequence = msg.sequenceNumber
 			}
 			}
 			set = &partitionSet{recordsToSend: newDefaultRecords(batch)}
 			size = recordBatchOverhead
 		}
 		}
 		partitions[msg.Partition] = set
 	}
 	set.msgs = append(set.msgs, msg)
 	set.msgs = append(set.msgs, msg)
 	if ps.parent.conf.Version.IsAtLeast(V0_11_0_0) {
 	if ps.parent.conf.Version.IsAtLeast(V0_11_0_0) {
+			return errors.New("Assertion failed: Message out of sequence added to a batch")
+		}
 		// We are being conservative here to avoid having to prep encode the record
 		// We are being conservative here to avoid having to prep encode the record
 		size += maximumRecordOverhead
 		rec := &Record{
 		req.Version = 3
 		req.Version = 3
 	}
-	for topic, partitionSet := range ps.msgs {
-		for partition, set := range partitionSet {
+	for topic, partitionSets := range ps.msgs {
+		for partition, set := range partitionSets {
 			if req.Version >= 3 {
 			if req.Version >= 3 {
 				// If the API version we're hitting is 3 or greater, we need to calculate
 				// offsets for each record in the batch relative to FirstOffset.
 						record.OffsetDelta = int64(i)
 						record.OffsetDelta = int64(i)
 					}
 				}
 				req.AddBatch(topic, partition, rb)
 				req.AddBatch(topic, partition, rb)
 				continue
 			}
 	return req
 	return req
 }
-func (ps *produceSet) eachPartition(cb func(topic string, partition int32, msgs []*ProducerMessage)) {
-func (ps *produceSet) eachPartition(cb func(topic string, partition int32, msgs []*ProducerMessage)) {
+func (ps *produceSet) eachPartition(cb func(topic string, partition int32, pSet *partitionSet)) {
 	for topic, partitionSet := range ps.msgs {
 		for partition, set := range partitionSet {
-			cb(topic, partition, set.msgs)
+			cb(topic, partition, set)
 		}
 	}
 }

+ 94 - 3
produce_set_test.go

@@ -7,8 +7,11 @@ import (
 )
 
 func makeProduceSet() (*asyncProducer, *produceSet) {
+	conf := NewConfig()
+	txnmgr, _ := newTransactionManager(conf, nil)
 	parent := &asyncProducer{
-		conf: NewConfig(),
+		conf:   conf,
+		txnmgr: txnmgr,
 	}
 	return parent, newProduceSet(parent)
 }
@@ -72,8 +75,8 @@ func TestProduceSetPartitionTracking(t *testing.T) {
 	seenT1P1 := false
 	seenT2P0 := false
 
-	ps.eachPartition(func(topic string, partition int32, msgs []*ProducerMessage) {
-		if len(msgs) != 1 {
+	ps.eachPartition(func(topic string, partition int32, pSet *partitionSet) {
+		if len(pSet.msgs) != 1 {
 			t.Error("Wrong message count")
 			t.Error("Wrong message count")
 		}
 		}
 
 
@@ -253,3 +256,91 @@ func TestProduceSetV3RequestBuilding(t *testing.T) {
 		}
 	}
 }
+
+func TestProduceSetIdempotentRequestBuilding(t *testing.T) {
+	const pID = 1000
+	const pEpoch = 1234
+
+	config := NewConfig()
+	config.Producer.RequiredAcks = WaitForAll
+	config.Producer.Idempotent = true
+	config.Version = V0_11_0_0
+
+	parent := &asyncProducer{
+		conf: config,
+		txnmgr: &transactionManager{
+			producerID:    pID,
+			producerEpoch: pEpoch,
+		},
+	}
+	ps := newProduceSet(parent)
+
+	now := time.Now()
+	msg := &ProducerMessage{
+		Topic:     "t1",
+		Partition: 0,
+		Key:       StringEncoder(TestMessage),
+		Value:     StringEncoder(TestMessage),
+		Headers: []RecordHeader{
+			RecordHeader{
+				Key:   []byte("header-1"),
+				Value: []byte("value-1"),
+			},
+			RecordHeader{
+				Key:   []byte("header-2"),
+				Value: []byte("value-2"),
+			},
+			RecordHeader{
+				Key:   []byte("header-3"),
+				Value: []byte("value-3"),
+			},
+		},
+		Timestamp:      now,
+		sequenceNumber: 123,
+	}
+	for i := 0; i < 10; i++ {
+		safeAddMessage(t, ps, msg)
+		msg.Timestamp = msg.Timestamp.Add(time.Second)
+	}
+
+	req := ps.buildRequest()
+
+	if req.Version != 3 {
+		t.Error("Wrong request version")
+	}
+
+	batch := req.records["t1"][0].RecordBatch
+	if batch.FirstTimestamp != now {
+		t.Errorf("Wrong first timestamp: %v", batch.FirstTimestamp)
+	}
+	if batch.ProducerID != pID {
+		t.Errorf("Wrong producerID: %v", batch.ProducerID)
+	}
+	if batch.ProducerEpoch != pEpoch {
+		t.Errorf("Wrong producerEpoch: %v", batch.ProducerEpoch)
+	}
+	if batch.FirstSequence != 123 {
+		t.Errorf("Wrong first sequence: %v", batch.FirstSequence)
+	}
+	for i := 0; i < 10; i++ {
+		rec := batch.Records[i]
+		if rec.TimestampDelta != time.Duration(i)*time.Second {
+			t.Errorf("Wrong timestamp delta: %v", rec.TimestampDelta)
+		}
+
+		if rec.OffsetDelta != int64(i) {
+			t.Errorf("Wrong relative inner offset, expected %d, got %d", i, rec.OffsetDelta)
+		}
+
+		for j, h := range batch.Records[i].Headers {
+			exp := fmt.Sprintf("header-%d", j+1)
+			if string(h.Key) != exp {
+				t.Errorf("Wrong header key, expected %v, got %v", exp, h.Key)
+			}
+			exp = fmt.Sprintf("value-%d", j+1)
+			if string(h.Value) != exp {
+				t.Errorf("Wrong header value, expected %v, got %v", exp, h.Value)
+			}
+		}
+	}
+}