
Extract produceSet and add tests

Now that `produceSet` is its own stand-alone structure with well-defined
behaviour, extract it into its own file (pure copy-paste, nothing to review) to
make async_producer.go not quite so huge.

This makes it more obvious that produceSet can/should have unit tests, so add a
few basic ones.
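
As a rough sketch of how the extracted type is meant to be driven (a hypothetical helper, not code from this commit; the real logic lives in async_producer.go), the flow is: check wouldOverflow before add, and flush via buildRequest once readyToFlush reports true.

    package sarama

    // exampleFlushCycle is a hypothetical illustration of the produceSet API
    // introduced in this commit; it is not part of the diff. It buffers
    // incoming messages and turns them into ProduceRequests at the same
    // trigger points the real producer uses.
    func exampleFlushCycle(parent *asyncProducer, incoming []*ProducerMessage) []*ProduceRequest {
    	var requests []*ProduceRequest
    	buf := newProduceSet(parent)

    	for _, msg := range incoming {
    		// If adding msg would push the buffer past a size or count limit,
    		// flush the current buffer first and start a new one.
    		if buf.wouldOverflow(msg) {
    			requests = append(requests, buf.buildRequest())
    			buf = newProduceSet(parent)
    		}
    		if err := buf.add(msg); err != nil {
    			continue // Key or Value failed to encode; the real producer reports this to the caller
    		}
    	}

    	// Flush whatever remains once a configured trigger point is hit.
    	if buf.readyToFlush() {
    		requests = append(requests, buf.buildRequest())
    	}
    	return requests
    }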
Evan Huus · 9 years ago · commit 52c880384b
3 changed files with 295 additions and 157 deletions
  1. async_producer.go  + 0 - 157
  2. produce_set.go  + 158 - 0
  3. produce_set_test.go  + 137 - 0

+ 0 - 157
async_producer.go

@@ -777,163 +777,6 @@ func (p *asyncProducer) retryHandler() {
 	}
 }
 
-// produceSet
-
-type partitionSet struct {
-	msgs        []*ProducerMessage
-	setToSend   *MessageSet
-	bufferBytes int
-}
-
-type produceSet struct {
-	parent *asyncProducer
-	msgs   map[string]map[int32]*partitionSet
-
-	bufferBytes int
-	bufferCount int
-}
-
-func newProduceSet(parent *asyncProducer) *produceSet {
-	return &produceSet{
-		msgs:   make(map[string]map[int32]*partitionSet),
-		parent: parent,
-	}
-}
-
-func (ps *produceSet) add(msg *ProducerMessage) error {
-	var err error
-	var key, val []byte
-
-	if msg.Key != nil {
-		if key, err = msg.Key.Encode(); err != nil {
-			return err
-		}
-	}
-
-	if msg.Value != nil {
-		if val, err = msg.Value.Encode(); err != nil {
-			return err
-		}
-	}
-
-	partitions := ps.msgs[msg.Topic]
-	if partitions == nil {
-		partitions = make(map[int32]*partitionSet)
-		ps.msgs[msg.Topic] = partitions
-	}
-
-	set := partitions[msg.Partition]
-	if set == nil {
-		set = &partitionSet{setToSend: new(MessageSet)}
-		partitions[msg.Partition] = set
-	}
-
-	set.msgs = append(set.msgs, msg)
-	set.setToSend.addMessage(&Message{Codec: CompressionNone, Key: key, Value: val})
-
-	size := producerMessageOverhead + len(key) + len(val)
-	set.bufferBytes += size
-	ps.bufferBytes += size
-	ps.bufferCount++
-
-	return nil
-}
-
-func (ps *produceSet) buildRequest() *ProduceRequest {
-	req := &ProduceRequest{
-		RequiredAcks: ps.parent.conf.Producer.RequiredAcks,
-		Timeout:      int32(ps.parent.conf.Producer.Timeout / time.Millisecond),
-	}
-
-	for topic, partitionSet := range ps.msgs {
-		for partition, set := range partitionSet {
-			if ps.parent.conf.Producer.Compression == CompressionNone {
-				req.AddSet(topic, partition, set.setToSend)
-			} else {
-				// When compression is enabled, the entire set for each partition is compressed
-				// and sent as the payload of a single fake "message" with the appropriate codec
-				// set and no key. When the server sees a message with a compression codec, it
-				// decompresses the payload and treats the result as its message set.
-				payload, err := encode(set.setToSend)
-				if err != nil {
-					Logger.Println(err) // if this happens, it's basically our fault.
-					panic(err)
-				}
-				req.AddMessage(topic, partition, &Message{
-					Codec: ps.parent.conf.Producer.Compression,
-					Key:   nil,
-					Value: payload,
-				})
-			}
-		}
-	}
-
-	return req
-}
-
-func (ps *produceSet) eachPartition(cb func(topic string, partition int32, msgs []*ProducerMessage)) {
-	for topic, partitionSet := range ps.msgs {
-		for partition, set := range partitionSet {
-			cb(topic, partition, set.msgs)
-		}
-	}
-}
-
-func (ps *produceSet) dropPartition(topic string, partition int32) []*ProducerMessage {
-	if ps.msgs[topic] == nil {
-		return nil
-	}
-	set := ps.msgs[topic][partition]
-	if set == nil {
-		return nil
-	}
-	ps.bufferBytes -= set.bufferBytes
-	ps.bufferCount -= len(set.msgs)
-	delete(ps.msgs[topic], partition)
-	return set.msgs
-}
-
-func (ps *produceSet) wouldOverflow(msg *ProducerMessage) bool {
-	switch {
-	// Would we overflow our maximum possible size-on-the-wire? 10KiB is arbitrary overhead for safety.
-	case ps.bufferBytes+msg.byteSize() >= int(MaxRequestSize-(10*1024)):
-		return true
-	// Would we overflow the size-limit of a compressed message-batch for this partition?
-	case ps.parent.conf.Producer.Compression != CompressionNone &&
-		ps.msgs[msg.Topic] != nil && ps.msgs[msg.Topic][msg.Partition] != nil &&
-		ps.msgs[msg.Topic][msg.Partition].bufferBytes+msg.byteSize() >= ps.parent.conf.Producer.MaxMessageBytes:
-		return true
-	// Would we overflow simply in number of messages?
-	case ps.parent.conf.Producer.Flush.MaxMessages > 0 && ps.bufferCount >= ps.parent.conf.Producer.Flush.MaxMessages:
-		return true
-	default:
-		return false
-	}
-}
-
-func (ps *produceSet) readyToFlush() bool {
-	switch {
-	// If we don't have any messages, nothing else matters
-	case ps.empty():
-		return false
-	// If all three config values are 0, we always flush as-fast-as-possible
-	case ps.parent.conf.Producer.Flush.Frequency == 0 && ps.parent.conf.Producer.Flush.Bytes == 0 && ps.parent.conf.Producer.Flush.Messages == 0:
-		return true
-	// If we've passed the message trigger-point
-	case ps.parent.conf.Producer.Flush.Messages > 0 && ps.bufferCount >= ps.parent.conf.Producer.Flush.Messages:
-		return true
-	// If we've passed the byte trigger-point
-	case ps.parent.conf.Producer.Flush.Bytes > 0 && ps.bufferBytes >= ps.parent.conf.Producer.Flush.Bytes:
-		return true
-	default:
-		return false
-	}
-}
-
-func (ps *produceSet) empty() bool {
-	return ps.bufferCount == 0
-}
-
 // utility functions
 
 func (p *asyncProducer) shutdown() {

+ 158 - 0
produce_set.go

@@ -0,0 +1,158 @@
+package sarama
+
+import "time"
+
+type partitionSet struct {
+	msgs        []*ProducerMessage
+	setToSend   *MessageSet
+	bufferBytes int
+}
+
+type produceSet struct {
+	parent *asyncProducer
+	msgs   map[string]map[int32]*partitionSet
+
+	bufferBytes int
+	bufferCount int
+}
+
+func newProduceSet(parent *asyncProducer) *produceSet {
+	return &produceSet{
+		msgs:   make(map[string]map[int32]*partitionSet),
+		parent: parent,
+	}
+}
+
+func (ps *produceSet) add(msg *ProducerMessage) error {
+	var err error
+	var key, val []byte
+
+	if msg.Key != nil {
+		if key, err = msg.Key.Encode(); err != nil {
+			return err
+		}
+	}
+
+	if msg.Value != nil {
+		if val, err = msg.Value.Encode(); err != nil {
+			return err
+		}
+	}
+
+	partitions := ps.msgs[msg.Topic]
+	if partitions == nil {
+		partitions = make(map[int32]*partitionSet)
+		ps.msgs[msg.Topic] = partitions
+	}
+
+	set := partitions[msg.Partition]
+	if set == nil {
+		set = &partitionSet{setToSend: new(MessageSet)}
+		partitions[msg.Partition] = set
+	}
+
+	set.msgs = append(set.msgs, msg)
+	set.setToSend.addMessage(&Message{Codec: CompressionNone, Key: key, Value: val})
+
+	size := producerMessageOverhead + len(key) + len(val)
+	set.bufferBytes += size
+	ps.bufferBytes += size
+	ps.bufferCount++
+
+	return nil
+}
+
+func (ps *produceSet) buildRequest() *ProduceRequest {
+	req := &ProduceRequest{
+		RequiredAcks: ps.parent.conf.Producer.RequiredAcks,
+		Timeout:      int32(ps.parent.conf.Producer.Timeout / time.Millisecond),
+	}
+
+	for topic, partitionSet := range ps.msgs {
+		for partition, set := range partitionSet {
+			if ps.parent.conf.Producer.Compression == CompressionNone {
+				req.AddSet(topic, partition, set.setToSend)
+			} else {
+				// When compression is enabled, the entire set for each partition is compressed
+				// and sent as the payload of a single fake "message" with the appropriate codec
+				// set and no key. When the server sees a message with a compression codec, it
+				// decompresses the payload and treats the result as its message set.
+				payload, err := encode(set.setToSend)
+				if err != nil {
+					Logger.Println(err) // if this happens, it's basically our fault.
+					panic(err)
+				}
+				req.AddMessage(topic, partition, &Message{
+					Codec: ps.parent.conf.Producer.Compression,
+					Key:   nil,
+					Value: payload,
+				})
+			}
+		}
+	}
+
+	return req
+}
+
+func (ps *produceSet) eachPartition(cb func(topic string, partition int32, msgs []*ProducerMessage)) {
+	for topic, partitionSet := range ps.msgs {
+		for partition, set := range partitionSet {
+			cb(topic, partition, set.msgs)
+		}
+	}
+}
+
+func (ps *produceSet) dropPartition(topic string, partition int32) []*ProducerMessage {
+	if ps.msgs[topic] == nil {
+		return nil
+	}
+	set := ps.msgs[topic][partition]
+	if set == nil {
+		return nil
+	}
+	ps.bufferBytes -= set.bufferBytes
+	ps.bufferCount -= len(set.msgs)
+	delete(ps.msgs[topic], partition)
+	return set.msgs
+}
+
+func (ps *produceSet) wouldOverflow(msg *ProducerMessage) bool {
+	switch {
+	// Would we overflow our maximum possible size-on-the-wire? 10KiB is arbitrary overhead for safety.
+	case ps.bufferBytes+msg.byteSize() >= int(MaxRequestSize-(10*1024)):
+		return true
+	// Would we overflow the size-limit of a compressed message-batch for this partition?
+	case ps.parent.conf.Producer.Compression != CompressionNone &&
+		ps.msgs[msg.Topic] != nil && ps.msgs[msg.Topic][msg.Partition] != nil &&
+		ps.msgs[msg.Topic][msg.Partition].bufferBytes+msg.byteSize() >= ps.parent.conf.Producer.MaxMessageBytes:
+		return true
+	// Would we overflow simply in number of messages?
+	case ps.parent.conf.Producer.Flush.MaxMessages > 0 && ps.bufferCount >= ps.parent.conf.Producer.Flush.MaxMessages:
+		return true
+	default:
+		return false
+	}
+}
+
+func (ps *produceSet) readyToFlush() bool {
+	switch {
+	// If we don't have any messages, nothing else matters
+	case ps.empty():
+		return false
+	// If all three config values are 0, we always flush as-fast-as-possible
+	case ps.parent.conf.Producer.Flush.Frequency == 0 && ps.parent.conf.Producer.Flush.Bytes == 0 && ps.parent.conf.Producer.Flush.Messages == 0:
+		return true
+	// If we've passed the message trigger-point
+	case ps.parent.conf.Producer.Flush.Messages > 0 && ps.bufferCount >= ps.parent.conf.Producer.Flush.Messages:
+		return true
+	// If we've passed the byte trigger-point
+	case ps.parent.conf.Producer.Flush.Bytes > 0 && ps.bufferBytes >= ps.parent.conf.Producer.Flush.Bytes:
+		return true
+	default:
+		return false
+	}
+}
+
+func (ps *produceSet) empty() bool {
+	return ps.bufferCount == 0
+}

+ 137 - 0
produce_set_test.go

@@ -0,0 +1,137 @@
+package sarama
+
+import (
+	"testing"
+	"time"
+)
+
+func makeProduceSet() (*asyncProducer, *produceSet) {
+	parent := &asyncProducer{
+		conf: NewConfig(),
+	}
+	return parent, newProduceSet(parent)
+}
+
+func TestProduceSetInitial(t *testing.T) {
+	_, ps := makeProduceSet()
+
+	if !ps.empty() {
+		t.Error("New produceSet should be empty")
+	}
+
+	if ps.readyToFlush() {
+		t.Error("Empty produceSet must never be ready to flush")
+	}
+}
+
+func TestProduceSetAddingMessages(t *testing.T) {
+	parent, ps := makeProduceSet()
+	parent.conf.Producer.Flush.MaxMessages = 1000
+
+	msg := &ProducerMessage{Key: StringEncoder(TestMessage), Value: StringEncoder(TestMessage)}
+	ps.add(msg)
+
+	if ps.empty() {
+		t.Error("set shouldn't be empty when a message is added")
+	}
+
+	if !ps.readyToFlush() {
+		t.Error("by default set should be ready to flush when any message is in place")
+	}
+
+	for i := 0; i < 999; i++ {
+		if ps.wouldOverflow(msg) {
+			t.Error("set shouldn't fill up after only", i+1, "messages")
+		}
+		ps.add(msg)
+	}
+
+	if !ps.wouldOverflow(msg) {
+		t.Error("set should be full after 1000 messages")
+	}
+}
+
+func TestProduceSetPartitionTracking(t *testing.T) {
+	_, ps := makeProduceSet()
+
+	m1 := &ProducerMessage{Topic: "t1", Partition: 0}
+	m2 := &ProducerMessage{Topic: "t1", Partition: 1}
+	m3 := &ProducerMessage{Topic: "t2", Partition: 0}
+	ps.add(m1)
+	ps.add(m2)
+	ps.add(m3)
+
+	seenT1P0 := false
+	seenT1P1 := false
+	seenT2P0 := false
+
+	ps.eachPartition(func(topic string, partition int32, msgs []*ProducerMessage) {
+		if len(msgs) != 1 {
+			t.Error("Wrong message count")
+		}
+
+		if topic == "t1" && partition == 0 {
+			seenT1P0 = true
+		} else if topic == "t1" && partition == 1 {
+			seenT1P1 = true
+		} else if topic == "t2" && partition == 0 {
+			seenT2P0 = true
+		}
+	})
+
+	if !seenT1P0 {
+		t.Error("Didn't see t1p0")
+	}
+	if !seenT1P1 {
+		t.Error("Didn't see t1p1")
+	}
+	if !seenT2P0 {
+		t.Error("Didn't see t2p0")
+	}
+
+	if len(ps.dropPartition("t1", 1)) != 1 {
+		t.Error("Got wrong messages back from dropping partition")
+	}
+
+	if ps.bufferCount != 2 {
+		t.Error("Incorrect buffer count after dropping partition")
+	}
+}
+
+func TestProduceSetRequestBuilding(t *testing.T) {
+	parent, ps := makeProduceSet()
+	parent.conf.Producer.RequiredAcks = WaitForAll
+	parent.conf.Producer.Timeout = 10 * time.Second
+
+	msg := &ProducerMessage{
+		Topic:     "t1",
+		Partition: 0,
+		Key:       StringEncoder(TestMessage),
+		Value:     StringEncoder(TestMessage),
+	}
+	for i := 0; i < 10; i++ {
+		ps.add(msg)
+	}
+	msg.Partition = 1
+	for i := 0; i < 10; i++ {
+		ps.add(msg)
+	}
+	msg.Topic = "t2"
+	for i := 0; i < 10; i++ {
+		ps.add(msg)
+	}
+
+	req := ps.buildRequest()
+
+	if req.RequiredAcks != WaitForAll {
+		t.Error("RequiredAcks not set properly")
+	}
+
+	if req.Timeout != 10000 {
+		t.Error("Timeout not set properly")
+	}
+
+	if len(req.msgSets) != 2 {
+		t.Error("Wrong number of topics in request")
+	}
+}