浏览代码

Retry messages on shutdown

Evan Huus 10 年之前
父节点
当前提交
42123dd0dd
共有 1 个文件被更改,包括 48 次插入63 次删除
  1. 48 63
      async_producer.go

+ 48 - 63
async_producer.go

@@ -54,6 +54,7 @@ type asyncProducer struct {
 
 	errors                    chan *ProducerError
 	input, successes, retries chan *ProducerMessage
+	inFlight                  sync.WaitGroup
 
 	brokers    map[*Broker]chan *ProducerMessage
 	brokerRefs map[chan *ProducerMessage]int
@@ -105,8 +106,6 @@ type flagSet int8
 
 const (
 	chaser   flagSet = 1 << iota // message is last in a group that failed
-	ref                          // add a reference to a singleton channel
-	unref                        // remove a reference from a singleton channel
 	shutdown                     // start the shutdown process
 )
 
@@ -193,9 +192,7 @@ func (p *asyncProducer) Close() error {
 }
 
 func (p *asyncProducer) AsyncClose() {
-	go withRecover(func() {
-		p.input <- &ProducerMessage{flags: shutdown}
-	})
+	go withRecover(p.shutdown)
 }
 
 ///////////////////////////////////////////
@@ -209,6 +206,7 @@ func (p *asyncProducer) AsyncClose() {
 // dispatches messages by topic
 func (p *asyncProducer) topicDispatcher() {
 	handlers := make(map[string]chan *ProducerMessage)
+	shuttingDown := false
 
 	for msg := range p.input {
 		if msg == nil {
@@ -217,8 +215,14 @@ func (p *asyncProducer) topicDispatcher() {
 		}
 
 		if msg.flags&shutdown != 0 {
-			Logger.Println("Producer shutting down.")
-			break
+			shuttingDown = true
+			continue
+		} else if msg.retries == 0 {
+			p.inFlight.Add(1)
+			if shuttingDown {
+				p.returnError(msg, ErrShuttingDown)
+				continue
+			}
 		}
 
 		if (p.conf.Producer.Compression == CompressionNone && msg.Value != nil && msg.Value.Length() > p.conf.Producer.MaxMessageBytes) ||
@@ -230,7 +234,6 @@ func (p *asyncProducer) topicDispatcher() {
 
 		handler := handlers[msg.Topic]
 		if handler == nil {
-			p.retries <- &ProducerMessage{flags: ref}
 			newHandler := make(chan *ProducerMessage, p.conf.ChannelBufferSize)
 			topic := msg.Topic // block local because go's closure semantics suck
 			go withRecover(func() { p.partitionDispatcher(topic, newHandler) })
@@ -244,21 +247,6 @@ func (p *asyncProducer) topicDispatcher() {
 	for _, handler := range handlers {
 		close(handler)
 	}
-
-	p.retries <- &ProducerMessage{flags: shutdown}
-
-	for msg := range p.input {
-		p.returnError(msg, ErrShuttingDown)
-	}
-
-	if p.ownClient {
-		err := p.client.Close()
-		if err != nil {
-			Logger.Println("producer/shutdown failed to close the embedded client:", err)
-		}
-	}
-	close(p.errors)
-	close(p.successes)
 }
 
 // one per topic
@@ -281,7 +269,6 @@ func (p *asyncProducer) partitionDispatcher(topic string, input chan *ProducerMe
 
 		handler := handlers[msg.Partition]
 		if handler == nil {
-			p.retries <- &ProducerMessage{flags: ref}
 			newHandler := make(chan *ProducerMessage, p.conf.ChannelBufferSize)
 			topic := msg.Topic         // block local because go's closure semantics suck
 			partition := msg.Partition // block local because go's closure semantics suck
@@ -296,7 +283,6 @@ func (p *asyncProducer) partitionDispatcher(topic string, input chan *ProducerMe
 	for _, handler := range handlers {
 		close(handler)
 	}
-	p.retries <- &ProducerMessage{flags: unref}
 }
 
 // one per partition per topic
@@ -344,6 +330,7 @@ func (p *asyncProducer) leaderDispatcher(topic string, partition int32, input ch
 			highWatermark = msg.retries
 			Logger.Printf("producer/leader/%s/%d state change to [retrying-%d]\n", topic, partition, highWatermark)
 			retryState[msg.retries].expectChaser = true
+			p.inFlight.Add(1) // we're generating a chaser message; track it so we don't shut down while it's still inflight
 			output <- &ProducerMessage{Topic: topic, Partition: partition, flags: chaser, retries: msg.retries - 1}
 			Logger.Printf("producer/leader/%s/%d abandoning broker %d\n", topic, partition, leader.ID())
 			p.unrefBrokerProducer(leader, output)
@@ -355,6 +342,7 @@ func (p *asyncProducer) leaderDispatcher(topic string, partition int32, input ch
 				// in fact this message is not even the current retry level, so buffer it for now (unless it's a just a chaser)
 				if msg.flags&chaser == chaser {
 					retryState[msg.retries].expectChaser = false
+					p.inFlight.Done() // this chaser is now handled and will be garbage collected
 				} else {
 					retryState[msg.retries].buf = append(retryState[msg.retries].buf, msg)
 				}
@@ -392,6 +380,7 @@ func (p *asyncProducer) leaderDispatcher(topic string, partition int32, input ch
 					}
 
 				}
+				p.inFlight.Done() // this chaser is now handled and will be garbage collected
 				continue
 			}
 		}
@@ -414,7 +403,6 @@ func (p *asyncProducer) leaderDispatcher(topic string, partition int32, input ch
 	if output != nil {
 		p.unrefBrokerProducer(leader, output)
 	}
-	p.retries <- &ProducerMessage{flags: unref}
 }
 
 // one per broker
@@ -543,9 +531,7 @@ func (p *asyncProducer) flusher(broker *Broker, input chan []*ProducerMessage) {
 
 		if response == nil {
 			// this only happens when RequiredAcks is NoResponse, so we have to assume success
-			if p.conf.Producer.Return.Successes {
-				p.returnSuccesses(batch)
-			}
+			p.returnSuccesses(batch)
 			continue
 		}
 
@@ -563,12 +549,10 @@ func (p *asyncProducer) flusher(broker *Broker, input chan []*ProducerMessage) {
 				switch block.Err {
 				case ErrNoError:
 					// All the messages for this topic-partition were delivered successfully!
-					if p.conf.Producer.Return.Successes {
-						for i := range msgs {
-							msgs[i].Offset = block.Offset + int64(i)
-						}
-						p.returnSuccesses(msgs)
+					for i := range msgs {
+						msgs[i].Offset = block.Offset + int64(i)
 					}
+					p.returnSuccesses(msgs)
 				case ErrUnknownTopicOrPartition, ErrNotLeaderForPartition, ErrLeaderNotAvailable,
 					ErrRequestTimedOut, ErrNotEnoughReplicas, ErrNotEnoughReplicasAfterAppend:
 					Logger.Printf("producer/flusher/%d state change to [retrying] on %s/%d because %v\n",
@@ -585,19 +569,14 @@ func (p *asyncProducer) flusher(broker *Broker, input chan []*ProducerMessage) {
 		}
 	}
 	Logger.Printf("producer/flusher/%d shut down\n", broker.ID())
-	p.retries <- &ProducerMessage{flags: unref}
 }
 
 // singleton
 // effectively a "bridge" between the flushers and the topicDispatcher in order to avoid deadlock
 // based on https://godoc.org/github.com/eapache/channels#InfiniteChannel
 func (p *asyncProducer) retryHandler() {
-	var (
-		msg          *ProducerMessage
-		buf          = queue.New()
-		refs         = 0
-		shuttingDown = false
-	)
+	var msg *ProducerMessage
+	buf := queue.New()
 
 	for {
 		if buf.Length() == 0 {
@@ -611,29 +590,12 @@ func (p *asyncProducer) retryHandler() {
 			}
 		}
 
-		if msg.flags&ref != 0 {
-			refs++
-		} else if msg.flags&unref != 0 {
-			refs--
-			if refs == 0 && shuttingDown {
-				break
-			}
-		} else if msg.flags&shutdown != 0 {
-			shuttingDown = true
-			if refs == 0 {
-				break
-			}
-		} else {
-			buf.Add(msg)
+		if msg == nil {
+			return
 		}
-	}
 
-	close(p.retries)
-	for buf.Length() != 0 {
-		p.input <- buf.Peek().(*ProducerMessage)
-		buf.Remove()
+		buf.Add(msg)
 	}
-	close(p.input)
 }
 
 ///////////////////////////////////////////
@@ -641,6 +603,25 @@ func (p *asyncProducer) retryHandler() {
 
 // utility functions
 
+func (p *asyncProducer) shutdown() {
+	Logger.Println("Producer shutting down.")
+	p.input <- &ProducerMessage{flags: shutdown}
+
+	p.inFlight.Wait()
+
+	if p.ownClient {
+		err := p.client.Close()
+		if err != nil {
+			Logger.Println("producer/shutdown failed to close the embedded client:", err)
+		}
+	}
+
+	close(p.input)
+	close(p.retries)
+	close(p.errors)
+	close(p.successes)
+}
+
 func (p *asyncProducer) assignPartition(partitioner Partitioner, msg *ProducerMessage) error {
 	var partitions []int32
 	var err error
@@ -745,6 +726,7 @@ func (p *asyncProducer) returnError(msg *ProducerMessage, err error) {
 	} else {
 		Logger.Println(pErr)
 	}
+	p.inFlight.Done()
 }
 
 func (p *asyncProducer) returnErrors(batch []*ProducerMessage, err error) {
@@ -757,10 +739,14 @@ func (p *asyncProducer) returnErrors(batch []*ProducerMessage, err error) {
 
 func (p *asyncProducer) returnSuccesses(batch []*ProducerMessage) {
 	for _, msg := range batch {
-		if msg != nil {
+		if msg == nil {
+			continue
+		}
+		if p.conf.Producer.Return.Successes {
 			msg.flags = 0
 			p.successes <- msg
 		}
+		p.inFlight.Done()
 	}
 }
 
@@ -785,7 +771,6 @@ func (p *asyncProducer) getBrokerProducer(broker *Broker) chan *ProducerMessage
 	bp := p.brokers[broker]
 
 	if bp == nil {
-		p.retries <- &ProducerMessage{flags: ref}
 		bp = make(chan *ProducerMessage)
 		p.brokers[broker] = bp
 		p.brokerRefs[bp] = 0