
Merge pull request #331 from Shopify/consumer_mock

Consumer mock
Willem van Bergen · 10 years ago
commit ba97c45bd5
8 changed files with 579 additions and 70 deletions
  1. .travis.yml (+1 -1)
  2. consumer.go (+70 -53)
  3. consumer_test.go (+4 -4)
  4. mocks/consumer.go (+268 -0)
  5. mocks/consumer_test.go (+192 -0)
  6. mocks/mocks.go (+12 -2)
  7. mocks/producer_test.go (+9 -3)
  8. mocks/sync_producer_test.go (+23 -7)

+ 1 - 1
.travis.yml

@@ -22,7 +22,7 @@ before_install:
 script:
 - go test -v -race ./...
 - go vet ./...
-- errcheck github.com/Shopify/sarama/
+- errcheck github.com/Shopify/sarama/...
 - if [[ -n $(go fmt ./...) ]]; then echo "Please run go fmt on your code." && exit 1; fi
 
 notifications:

+ 70 - 53
consumer.go

@@ -38,18 +38,28 @@ func (ce ConsumerErrors) Error() string {
 // Consumer manages PartitionConsumers which process Kafka messages from brokers. You MUST call Close()
 // on a consumer to avoid leaks, it will not be garbage-collected automatically when it passes out of
 // scope.
-type Consumer struct {
+type Consumer interface {
+	// ConsumePartition creates a PartitionConsumer on the given topic/partition with the given offset. It will
+	// return an error if this Consumer is already consuming on the given topic/partition. Offset can be a
+	// literal offset, or OffsetNewest or OffsetOldest
+	ConsumePartition(topic string, partition int32, offset int64) (PartitionConsumer, error)
+
+	// Close shuts down the consumer. It must be called after all child PartitionConsumers have already been closed.
+	Close() error
+}
+
+type consumer struct {
 	client    *Client
 	conf      *Config
 	ownClient bool
 
 	lock            sync.Mutex
-	children        map[string]map[int32]*PartitionConsumer
+	children        map[string]map[int32]*partitionConsumer
 	brokerConsumers map[*Broker]*brokerConsumer
 }
 
 // NewConsumer creates a new consumer using the given broker addresses and configuration.
-func NewConsumer(addrs []string, config *Config) (*Consumer, error) {
+func NewConsumer(addrs []string, config *Config) (Consumer, error) {
 	client, err := NewClient(addrs, config)
 	if err != nil {
 		return nil, err
@@ -59,29 +69,28 @@ func NewConsumer(addrs []string, config *Config) (*Consumer, error) {
 	if err != nil {
 		return nil, err
 	}
-	c.ownClient = true
+	c.(*consumer).ownClient = true
 	return c, nil
 }
 
 // NewConsumerFromClient creates a new consumer using the given client.
-func NewConsumerFromClient(client *Client) (*Consumer, error) {
+func NewConsumerFromClient(client *Client) (Consumer, error) {
 	// Check that we are not dealing with a closed Client before processing any other arguments
 	if client.Closed() {
 		return nil, ErrClosedClient
 	}
 
-	c := &Consumer{
+	c := &consumer{
 		client:          client,
 		conf:            client.conf,
-		children:        make(map[string]map[int32]*PartitionConsumer),
+		children:        make(map[string]map[int32]*partitionConsumer),
 		brokerConsumers: make(map[*Broker]*brokerConsumer),
 	}
 
 	return c, nil
 }
 
-// Close shuts down the consumer. It must be called after all child PartitionConsumers have already been closed.
-func (c *Consumer) Close() error {
+func (c *consumer) Close() error {
 	if c.ownClient {
 		return c.client.Close()
 	}
@@ -97,11 +106,8 @@ const (
 	OffsetOldest int64 = -2
 )
 
-// ConsumePartition creates a PartitionConsumer on the given topic/partition with the given offset. It will
-// return an error if this Consumer is already consuming on the given topic/partition. Offset can be a
-// literal offset, or OffsetNewest or OffsetOldest
-func (c *Consumer) ConsumePartition(topic string, partition int32, offset int64) (*PartitionConsumer, error) {
-	child := &PartitionConsumer{
+func (c *consumer) ConsumePartition(topic string, partition int32, offset int64) (PartitionConsumer, error) {
+	child := &partitionConsumer{
 		consumer:  c,
 		conf:      c.conf,
 		topic:     topic,
@@ -135,13 +141,13 @@ func (c *Consumer) ConsumePartition(topic string, partition int32, offset int64)
 	return child, nil
 }
 
-func (c *Consumer) addChild(child *PartitionConsumer) error {
+func (c *consumer) addChild(child *partitionConsumer) error {
 	c.lock.Lock()
 	defer c.lock.Unlock()
 
 	topicChildren := c.children[child.topic]
 	if topicChildren == nil {
-		topicChildren = make(map[int32]*PartitionConsumer)
+		topicChildren = make(map[int32]*partitionConsumer)
 		c.children[child.topic] = topicChildren
 	}
 
@@ -153,14 +159,14 @@ func (c *Consumer) addChild(child *PartitionConsumer) error {
 	return nil
 }
 
-func (c *Consumer) removeChild(child *PartitionConsumer) {
+func (c *consumer) removeChild(child *partitionConsumer) {
 	c.lock.Lock()
 	defer c.lock.Unlock()
 
 	delete(c.children[child.topic], child.partition)
 }
 
-func (c *Consumer) refBrokerConsumer(broker *Broker) *brokerConsumer {
+func (c *consumer) refBrokerConsumer(broker *Broker) *brokerConsumer {
 	c.lock.Lock()
 	defer c.lock.Unlock()
 
@@ -169,10 +175,10 @@ func (c *Consumer) refBrokerConsumer(broker *Broker) *brokerConsumer {
 		brokerWorker = &brokerConsumer{
 			consumer:         c,
 			broker:           broker,
-			input:            make(chan *PartitionConsumer),
-			newSubscriptions: make(chan []*PartitionConsumer),
+			input:            make(chan *partitionConsumer),
+			newSubscriptions: make(chan []*partitionConsumer),
 			wait:             make(chan none),
-			subscriptions:    make(map[*PartitionConsumer]none),
+			subscriptions:    make(map[*partitionConsumer]none),
 			refs:             1,
 		}
 		go withRecover(brokerWorker.subscriptionManager)
@@ -185,7 +191,7 @@ func (c *Consumer) refBrokerConsumer(broker *Broker) *brokerConsumer {
 	return brokerWorker
 }
 
-func (c *Consumer) unrefBrokerConsumer(broker *Broker) {
+func (c *consumer) unrefBrokerConsumer(broker *Broker) {
 	c.lock.Lock()
 	defer c.lock.Unlock()
 
@@ -204,8 +210,33 @@ func (c *Consumer) unrefBrokerConsumer(broker *Broker) {
 // on a consumer to avoid leaks, it will not be garbage-collected automatically when it passes out of
 // scope (this is in addition to calling Close on the underlying consumer's client, which is still necessary).
 // You have to read from both the Messages and Errors channels to prevent the consumer from locking eventually.
-type PartitionConsumer struct {
-	consumer  *Consumer
+type PartitionConsumer interface {
+
+	// AsyncClose initiates a shutdown of the PartitionConsumer. This method will return immediately,
+	// after which you should wait until the 'messages' and 'errors' channels are drained.
+	// It is required to call this function (or Close) before a consumer object passes out of scope,
+	// as it will otherwise leak memory. You must call this before calling Close on the underlying
+	// client.
+	AsyncClose()
+
+	// Close stops the PartitionConsumer from fetching messages. It is required to call this function
+	// (or AsyncClose) before a consumer object passes out of scope, as it will otherwise leak memory. You must
+	// call this before calling Close on the underlying client.
+	Close() error
+
+	// Errors returns the read channel for any errors that occurred while consuming the partition.
+	// You have to read this channel to prevent the consumer from deadlocking. Under no circumstances
+	// will the partition consumer shut down by itself; it will just wait until it is able to continue
+	// consuming messages. If you want to shut down your consumer, you will have to trigger it yourself
+	// by consuming this channel and calling Close or AsyncClose when appropriate.
+	Errors() <-chan *ConsumerError
+
+	// Messages returns the read channel for the messages that are returned by the broker
+	Messages() <-chan *ConsumerMessage
+}
+
+type partitionConsumer struct {
+	consumer  *consumer
 	conf      *Config
 	topic     string
 	partition int32
@@ -219,7 +250,7 @@ type PartitionConsumer struct {
 	offset    int64
 }
 
-func (child *PartitionConsumer) sendError(err error) {
+func (child *partitionConsumer) sendError(err error) {
 	child.errors <- &ConsumerError{
 		Topic:     child.topic,
 		Partition: child.partition,
@@ -227,7 +258,7 @@ func (child *PartitionConsumer) sendError(err error) {
 	}
 }
 
-func (child *PartitionConsumer) dispatcher() {
+func (child *partitionConsumer) dispatcher() {
 	for _ = range child.trigger {
 		select {
 		case <-child.dying:
@@ -260,7 +291,7 @@ func (child *PartitionConsumer) dispatcher() {
 	close(child.errors)
 }
 
-func (child *PartitionConsumer) dispatch() error {
+func (child *partitionConsumer) dispatch() error {
 	if err := child.consumer.client.RefreshTopicMetadata(child.topic); err != nil {
 		return err
 	}
@@ -278,7 +309,7 @@ func (child *PartitionConsumer) dispatch() error {
 	return nil
 }
 
-func (child *PartitionConsumer) chooseStartingOffset(offset int64) (err error) {
+func (child *partitionConsumer) chooseStartingOffset(offset int64) (err error) {
 	var where OffsetTime
 
 	switch offset {
@@ -298,26 +329,15 @@ func (child *PartitionConsumer) chooseStartingOffset(offset int64) (err error) {
 	return err
 }
 
-// Messages returns the read channel for the messages that are returned by the broker
-func (child *PartitionConsumer) Messages() <-chan *ConsumerMessage {
+func (child *partitionConsumer) Messages() <-chan *ConsumerMessage {
 	return child.messages
 }
 
-// Errors returns the read channel for any errors that occurred while consuming the partition.
-// You have to read this channel to prevent the consumer from deadlock. Under no circumstances,
-// the partition consumer will shut down by itself. It will just wait until it is able to continue
-// consuming messages. If you want to shut down your consumer, you will have trigger it yourself
-// by consuming this channel and calling Close or AsyncClose when appropriate.
-func (child *PartitionConsumer) Errors() <-chan *ConsumerError {
+func (child *partitionConsumer) Errors() <-chan *ConsumerError {
 	return child.errors
 }
 
-// AsyncClose initiates a shutdown of the PartitionConsumer. This method will return immediately,
-// after which you should wait until the 'messages' and 'errors' channel are drained.
-// It is required to call this function, or Close before a consumer object passes out of scope,
-// as it will otherwise leak memory.  You must call this before calling Close on the underlying
-// client.
-func (child *PartitionConsumer) AsyncClose() {
+func (child *partitionConsumer) AsyncClose() {
 	// this triggers whatever worker owns this child to abandon it and close its trigger channel, which causes
 	// the dispatcher to exit its loop, which removes it from the consumer then closes its 'messages' and
 	// 'errors' channel (alternatively, if the child is already at the dispatcher for some reason, that will
@@ -325,10 +345,7 @@ func (child *PartitionConsumer) AsyncClose() {
 	close(child.dying)
 }
 
-// Close stops the PartitionConsumer from fetching messages. It is required to call this function
-// (or AsyncClose) before a consumer object passes out of scope, as it will otherwise leak memory. You must
-// call this before calling Close on the underlying client.
-func (child *PartitionConsumer) Close() error {
+func (child *partitionConsumer) Close() error {
 	child.AsyncClose()
 
 	go withRecover(func() {
@@ -351,17 +368,17 @@ func (child *PartitionConsumer) Close() error {
 // brokerConsumer
 
 type brokerConsumer struct {
-	consumer         *Consumer
+	consumer         *consumer
 	broker           *Broker
-	input            chan *PartitionConsumer
-	newSubscriptions chan []*PartitionConsumer
+	input            chan *partitionConsumer
+	newSubscriptions chan []*partitionConsumer
 	wait             chan none
-	subscriptions    map[*PartitionConsumer]none
+	subscriptions    map[*partitionConsumer]none
 	refs             int
 }
 
 func (w *brokerConsumer) subscriptionManager() {
-	var buffer []*PartitionConsumer
+	var buffer []*partitionConsumer
 
 	// The subscriptionManager constantly accepts new subscriptions on `input` (even when the main subscriptionConsumer
 	//  goroutine is in the middle of a network request) and batches it up. The main worker goroutine picks
@@ -436,7 +453,7 @@ func (w *brokerConsumer) subscriptionConsumer() {
 	}
 }
 
-func (w *brokerConsumer) updateSubscriptionCache(newSubscriptions []*PartitionConsumer) {
+func (w *brokerConsumer) updateSubscriptionCache(newSubscriptions []*partitionConsumer) {
 	// take new subscriptions, and abandon subscriptions that have been closed
 	for _, child := range newSubscriptions {
 		w.subscriptions[child] = none{}
@@ -482,7 +499,7 @@ func (w *brokerConsumer) fetchNewMessages() (*FetchResponse, error) {
 	return w.broker.Fetch(request)
 }
 
-func (w *brokerConsumer) handleResponse(child *PartitionConsumer, block *FetchResponseBlock) {
+func (w *brokerConsumer) handleResponse(child *partitionConsumer, block *FetchResponseBlock) {
 	switch block.Err {
 	case ErrNoError:
 		break

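With Consumer and PartitionConsumer now exported as interfaces rather than structs, calling code interacts with them the same way as before. A minimal usage sketch against this API follows; the broker address and topic name are placeholders, not part of this change.

package main

import (
	"log"

	"github.com/Shopify/sarama"
)

func main() {
	// NewConsumer now returns the Consumer interface instead of *Consumer.
	consumer, err := sarama.NewConsumer([]string{"localhost:9092"}, sarama.NewConfig())
	if err != nil {
		log.Fatal(err)
	}
	defer func() {
		// Close must be called after all child PartitionConsumers are closed.
		if err := consumer.Close(); err != nil {
			log.Println(err)
		}
	}()

	// ConsumePartition returns the PartitionConsumer interface.
	pc, err := consumer.ConsumePartition("my_topic", 0, sarama.OffsetOldest)
	if err != nil {
		log.Fatal(err)
	}
	defer func() {
		if err := pc.Close(); err != nil {
			log.Println(err)
		}
	}()

	// Both the Messages and Errors channels must be read to keep the
	// partition consumer from blocking.
	go func() {
		for consumeErr := range pc.Errors() {
			log.Println(consumeErr)
		}
	}()

	// Read a handful of messages; a real application would loop until shutdown.
	for i := 0; i < 10; i++ {
		msg := <-pc.Messages()
		log.Printf("offset %d: %s", msg.Offset, msg.Value)
	}
}
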
+ 4 - 4
consumer_test.go

@@ -82,8 +82,8 @@ func TestConsumerLatestOffset(t *testing.T) {
 	safeClose(t, consumer)
 
 	// we deliver one message, so it should be one higher than we return in the OffsetResponse
-	if consumer.offset != 0x010102 {
-		t.Error("Latest offset not fetched correctly:", consumer.offset)
+	if consumer.(*partitionConsumer).offset != 0x010102 {
+		t.Error("Latest offset not fetched correctly:", consumer.(*partitionConsumer).offset)
 	}
 }
 
@@ -155,14 +155,14 @@ func TestConsumerRebalancingMultiplePartitions(t *testing.T) {
 			t.Error(err)
 		}
 
-		go func(c *PartitionConsumer) {
+		go func(c PartitionConsumer) {
 			for err := range c.Errors() {
 				t.Error(err)
 			}
 		}(consumer)
 
 		wg.Add(1)
-		go func(partition int32, c *PartitionConsumer) {
+		go func(partition int32, c PartitionConsumer) {
 			for i := 0; i < 10; i++ {
 				message := <-consumer.Messages()
 				if message.Offset != int64(i) {

+ 268 - 0
mocks/consumer.go

@@ -0,0 +1,268 @@
+package mocks
+
+import (
+	"sync"
+
+	"github.com/Shopify/sarama"
+)
+
+// Consumer implements sarama's Consumer interface for testing purposes.
+// Before you can start consuming from this consumer, you have to register
+// topic/partitions using ExpectConsumePartition, and set expectations on them.
+type Consumer struct {
+	l                  sync.Mutex
+	t                  ErrorReporter
+	config             *sarama.Config
+	partitionConsumers map[string]map[int32]*PartitionConsumer
+}
+
+// NewConsumer returns a new mock Consumer instance. The t argument should
+// be the *testing.T instance of your test method. An error will be written to it if
+// an expectation is violated. The config argument can be set to nil, in which case a default config is used.
+func NewConsumer(t ErrorReporter, config *sarama.Config) *Consumer {
+	if config == nil {
+		config = sarama.NewConfig()
+	}
+
+	c := &Consumer{
+		t:                  t,
+		config:             config,
+		partitionConsumers: make(map[string]map[int32]*PartitionConsumer),
+	}
+	return c
+}
+
+///////////////////////////////////////////////////
+// Consumer interface implementation
+///////////////////////////////////////////////////
+
+// ConsumePartition implements the ConsumePartition method from the sarama.Consumer interface.
+// Before you can start consuming a partition, you have to set expectations on it using
+// ExpectConsumePartition. You can only consume a partition once per consumer.
+func (c *Consumer) ConsumePartition(topic string, partition int32, offset int64) (sarama.PartitionConsumer, error) {
+	c.l.Lock()
+	defer c.l.Unlock()
+
+	if c.partitionConsumers[topic] == nil || c.partitionConsumers[topic][partition] == nil {
+		c.t.Errorf("No expectations set for %s/%d", topic, partition)
+		return nil, errOutOfExpectations
+	}
+
+	pc := c.partitionConsumers[topic][partition]
+	if pc.consumed {
+		return nil, sarama.ConfigurationError("The topic/partition is already being consumed")
+	}
+
+	if pc.offset != AnyOffset && pc.offset != offset {
+		c.t.Errorf("Unexpected offset when calling ConsumePartition for %s/%d. Expected %d, got %d.", topic, partition, pc.offset, offset)
+	}
+
+	pc.consumed = true
+	go pc.handleExpectations()
+	return pc, nil
+}
+
+// Close implements the Close method from the sarama.Consumer interface. It will close
+// all registered PartitionConsumer instances.
+func (c *Consumer) Close() error {
+	c.l.Lock()
+	defer c.l.Unlock()
+
+	for _, partitions := range c.partitionConsumers {
+		for _, partitionConsumer := range partitions {
+			_ = partitionConsumer.Close()
+		}
+	}
+
+	return nil
+}
+
+///////////////////////////////////////////////////
+// Expectation API
+///////////////////////////////////////////////////
+
+// ExpectConsumePartition will register a topic/partition, so you can set expectations on it.
+// The registered PartitionConsumer will be returned, so you can set expectations
+// on it using method chaining. Once a topic/partition is registered, you are
+// expected to start consuming it using ConsumePartition. If that doesn't happen,
+// an error will be written to the error reporter once the mock consumer is closed.
+func (c *Consumer) ExpectConsumePartition(topic string, partition int32, offset int64) *PartitionConsumer {
+	c.l.Lock()
+	defer c.l.Unlock()
+
+	if c.partitionConsumers[topic] == nil {
+		c.partitionConsumers[topic] = make(map[int32]*PartitionConsumer)
+	}
+
+	if c.partitionConsumers[topic][partition] == nil {
+		c.partitionConsumers[topic][partition] = &PartitionConsumer{
+			t:            c.t,
+			topic:        topic,
+			partition:    partition,
+			offset:       offset,
+			expectations: make(chan *consumerExpectation, 1000),
+			messages:     make(chan *sarama.ConsumerMessage, c.config.ChannelBufferSize),
+			errors:       make(chan *sarama.ConsumerError, c.config.ChannelBufferSize),
+		}
+	}
+
+	return c.partitionConsumers[topic][partition]
+}
+
+///////////////////////////////////////////////////
+// PartitionConsumer mock type
+///////////////////////////////////////////////////
+
+// PartitionConsumer implements sarama's PartitionConsumer interface for testing purposes.
+// It is returned by the mock Consumer's ConsumePartition method, but only if it is
+// registered first using the Consumer's ExpectConsumePartition method. Before consuming the
+// Errors and Messages channels, you should specify what values will be provided on these
+// channels using YieldMessage and YieldError.
+type PartitionConsumer struct {
+	l                       sync.Mutex
+	t                       ErrorReporter
+	topic                   string
+	partition               int32
+	offset                  int64
+	expectations            chan *consumerExpectation
+	messages                chan *sarama.ConsumerMessage
+	errors                  chan *sarama.ConsumerError
+	singleClose             sync.Once
+	consumed                bool
+	errorsShouldBeDrained   bool
+	messagesShouldBeDrained bool
+}
+
+func (pc *PartitionConsumer) handleExpectations() {
+	pc.l.Lock()
+	defer pc.l.Unlock()
+
+	var offset int64
+	for ex := range pc.expectations {
+		if ex.Err != nil {
+			pc.errors <- &sarama.ConsumerError{
+				Topic:     pc.topic,
+				Partition: pc.partition,
+				Err:       ex.Err,
+			}
+		} else {
+			offset++
+
+			ex.Msg.Topic = pc.topic
+			ex.Msg.Partition = pc.partition
+			ex.Msg.Offset = offset
+
+			pc.messages <- ex.Msg
+		}
+	}
+
+	close(pc.messages)
+	close(pc.errors)
+}
+
+///////////////////////////////////////////////////
+// PartitionConsumer interface implementation
+///////////////////////////////////////////////////
+
+// AsyncClose implements the AsyncClose method from the sarama.PartitionConsumer interface.
+func (pc *PartitionConsumer) AsyncClose() {
+	pc.singleClose.Do(func() {
+		close(pc.expectations)
+	})
+}
+
+// Close implements the Close method from the sarama.PartitionConsumer interface. It will
+// verify whether the partition consumer was actually started.
+func (pc *PartitionConsumer) Close() error {
+	if !pc.consumed {
+		pc.t.Errorf("Expectations set on %s/%d, but no partition consumer was started.", pc.topic, pc.partition)
+		return errPartitionConsumerNotStarted
+	}
+
+	if pc.errorsShouldBeDrained && len(pc.errors) > 0 {
+		pc.t.Errorf("Expected the errors channel for %s/%d to be drained on close, but found %d errors.", pc.topic, pc.partition, len(pc.errors))
+	}
+
+	if pc.messagesShouldBeDrained && len(pc.messages) > 0 {
+		pc.t.Errorf("Expected the messages channel for %s/%d to be drained on close, but found %d messages.", pc.topic, pc.partition, len(pc.messages))
+	}
+
+	pc.AsyncClose()
+
+	var (
+		closeErr error
+		wg       sync.WaitGroup
+	)
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+
+		var errs = make(sarama.ConsumerErrors, 0)
+		for err := range pc.errors {
+			errs = append(errs, err)
+		}
+
+		if len(errs) > 0 {
+			closeErr = errs
+		}
+	}()
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for _ = range pc.messages {
+			// drain
+		}
+	}()
+
+	wg.Wait()
+	return closeErr
+}
+
+// Errors implements the Errors method from the sarama.PartitionConsumer interface.
+func (pc *PartitionConsumer) Errors() <-chan *sarama.ConsumerError {
+	return pc.errors
+}
+
+// Messages implements the Messages method from the sarama.PartitionConsumer interface.
+func (pc *PartitionConsumer) Messages() <-chan *sarama.ConsumerMessage {
+	return pc.messages
+}
+
+///////////////////////////////////////////////////
+// Expectation API
+///////////////////////////////////////////////////
+
+// YieldMessage will yield a message on the Messages channel of this partition consumer
+// when it is consumed. By default, the mock consumer will not verify whether this
+// message was consumed from the Messages channel, because there are legitimate
+// reasons for this not to happen. You can call ExpectMessagesDrainedOnClose so it will
+// verify that the channel is empty on close.
+func (pc *PartitionConsumer) YieldMessage(msg *sarama.ConsumerMessage) {
+	pc.expectations <- &consumerExpectation{Msg: msg}
+}
+
+// YieldError will yield an error on the Errors channel of this partition consumer
+// when it is consumed. By default, the mock consumer will not verify whether this error was
+// consumed from the Errors channel, because there are legitimate reasons for this
+// not to happen. You can call ExpectErrorsDrainedOnClose so it will verify that
+// the channel is empty on close.
+func (pc *PartitionConsumer) YieldError(err error) {
+	pc.expectations <- &consumerExpectation{Err: err}
+}
+
+// ExpectMessagesDrainedOnClose sets an expectation on the partition consumer
+// that the messages channel will be fully drained when Close is called. If this
+// expectation is not met, an error is reported to the error reporter.
+func (pc *PartitionConsumer) ExpectMessagesDrainedOnClose() {
+	pc.messagesShouldBeDrained = true
+}
+
+// ExpectErrorsDrainedOnClose sets an expectation on the partition consumer
+// that the errors channel will be fully drained when Close is called. If this
+// expectation is not met, an error is reported to the error reporter.
+func (pc *PartitionConsumer) ExpectErrorsDrainedOnClose() {
+	pc.errorsShouldBeDrained = true
+}

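Because the mock Consumer and PartitionConsumer satisfy the new sarama.Consumer and sarama.PartitionConsumer interfaces, they can be injected into any code that accepts those interfaces. A sketch of a test doing so follows; countMessages is a hypothetical piece of application code, not part of this change.

package example

import (
	"testing"

	"github.com/Shopify/sarama"
	"github.com/Shopify/sarama/mocks"
)

// countMessages stands in for application code that depends on the
// sarama.Consumer interface instead of the concrete consumer type.
func countMessages(c sarama.Consumer, topic string, partition int32, n int) (int, error) {
	pc, err := c.ConsumePartition(topic, partition, sarama.OffsetOldest)
	if err != nil {
		return 0, err
	}
	defer func() { _ = pc.Close() }()

	count := 0
	for i := 0; i < n; i++ {
		<-pc.Messages()
		count++
	}
	return count, nil
}

func TestCountMessages(t *testing.T) {
	consumer := mocks.NewConsumer(t, nil)
	defer func() {
		if err := consumer.Close(); err != nil {
			t.Error(err)
		}
	}()

	// Register the topic/partition and the messages it should yield.
	pcmock := consumer.ExpectConsumePartition("test_topic", 0, sarama.OffsetOldest)
	pcmock.YieldMessage(&sarama.ConsumerMessage{Value: []byte("one")})
	pcmock.YieldMessage(&sarama.ConsumerMessage{Value: []byte("two")})
	pcmock.ExpectMessagesDrainedOnClose()

	// The mock satisfies sarama.Consumer, so it can be passed in directly.
	count, err := countMessages(consumer, "test_topic", 0, 2)
	if err != nil {
		t.Fatal(err)
	}
	if count != 2 {
		t.Errorf("expected 2 messages, got %d", count)
	}
}
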
+ 192 - 0
mocks/consumer_test.go

@@ -0,0 +1,192 @@
+package mocks
+
+import (
+	"testing"
+
+	"github.com/Shopify/sarama"
+)
+
+func TestMockConsumerImplementsConsumerInterface(t *testing.T) {
+	var c interface{} = &Consumer{}
+	if _, ok := c.(sarama.Consumer); !ok {
+		t.Error("The mock consumer should implement the sarama.Consumer interface.")
+	}
+
+	var pc interface{} = &PartitionConsumer{}
+	if _, ok := pc.(sarama.PartitionConsumer); !ok {
+		t.Error("The mock partitionconsumer should implement the sarama.PartitionConsumer interface.")
+	}
+}
+
+func TestConsumerHandlesExpectations(t *testing.T) {
+	consumer := NewConsumer(t, nil)
+	defer func() {
+		if err := consumer.Close(); err != nil {
+			t.Error(err)
+		}
+	}()
+
+	consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest).YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello world")})
+	consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest).YieldError(sarama.ErrOutOfBrokers)
+	consumer.ExpectConsumePartition("test", 1, sarama.OffsetOldest).YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello world again")})
+	consumer.ExpectConsumePartition("other", 0, AnyOffset).YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello other")})
+
+	pc_test0, err := consumer.ConsumePartition("test", 0, sarama.OffsetOldest)
+	if err != nil {
+		t.Fatal(err)
+	}
+	test0_msg := <-pc_test0.Messages()
+	if test0_msg.Topic != "test" || test0_msg.Partition != 0 || string(test0_msg.Value) != "hello world" {
+		t.Error("Message was not as expected:", test0_msg)
+	}
+	test0_err := <-pc_test0.Errors()
+	if test0_err.Err != sarama.ErrOutOfBrokers {
+		t.Error("Expected sarama.ErrOutOfBrokers, found:", test0_err.Err)
+	}
+
+	pc_test1, err := consumer.ConsumePartition("test", 1, sarama.OffsetOldest)
+	if err != nil {
+		t.Fatal(err)
+	}
+	test1_msg := <-pc_test1.Messages()
+	if test1_msg.Topic != "test" || test1_msg.Partition != 1 || string(test1_msg.Value) != "hello world again" {
+		t.Error("Message was not as expected:", test1_msg)
+	}
+
+	pc_other0, err := consumer.ConsumePartition("other", 0, sarama.OffsetNewest)
+	if err != nil {
+		t.Fatal(err)
+	}
+	other0_msg := <-pc_other0.Messages()
+	if other0_msg.Topic != "other" || other0_msg.Partition != 0 || string(other0_msg.Value) != "hello other" {
+		t.Error("Message was not as expected:", other0_msg)
+	}
+}
+
+func TestConsumerReturnsNonconsumedErrorsOnClose(t *testing.T) {
+	consumer := NewConsumer(t, nil)
+	consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest).YieldError(sarama.ErrOutOfBrokers)
+	consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest).YieldError(sarama.ErrOutOfBrokers)
+
+	pc, err := consumer.ConsumePartition("test", 0, sarama.OffsetOldest)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	select {
+	case <-pc.Messages():
+		t.Error("Did not epxect a message on the messages channel.")
+	case err := <-pc.Errors():
+		if err.Err != sarama.ErrOutOfBrokers {
+			t.Error("Expected sarama.ErrOutOfBrokers, found", err)
+		}
+	}
+
+	errs := pc.Close().(sarama.ConsumerErrors)
+	if len(errs) != 1 || errs[0].Err != sarama.ErrOutOfBrokers {
+		t.Error("Expected Close to return the remaining sarama.ErrOutOfBrokers")
+	}
+}
+
+func TestConsumerWithoutExpectationsOnPartition(t *testing.T) {
+	trm := newTestReporterMock()
+	consumer := NewConsumer(trm, nil)
+
+	_, err := consumer.ConsumePartition("test", 1, sarama.OffsetOldest)
+	if err != errOutOfExpectations {
+		t.Error("Expected ConsumePartition to return errOutOfExpectations")
+	}
+
+	if err := consumer.Close(); err != nil {
+		t.Error("No error expected on close, but found:", err)
+	}
+
+	if len(trm.errors) != 1 {
+		t.Errorf("Expected an expectation failure to be set on the error reporter.")
+	}
+}
+
+func TestConsumerWithExpectationsOnUnconsumedPartition(t *testing.T) {
+	trm := newTestReporterMock()
+	consumer := NewConsumer(trm, nil)
+	consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest).YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello world")})
+
+	if err := consumer.Close(); err != nil {
+		t.Error("No error expected on close, but found:", err)
+	}
+
+	if len(trm.errors) != 1 {
+		t.Errorf("Expected an expectation failure to be set on the error reporter.")
+	}
+}
+
+func TestConsumerWithWrongOffsetExpectation(t *testing.T) {
+	trm := newTestReporterMock()
+	consumer := NewConsumer(trm, nil)
+	consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest)
+
+	_, err := consumer.ConsumePartition("test", 0, sarama.OffsetNewest)
+	if err != nil {
+		t.Error("Did not expect error, found:", err)
+	}
+
+	if len(trm.errors) != 1 {
+		t.Errorf("Expected an expectation failure to be set on the error reporter.")
+	}
+
+	if err := consumer.Close(); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestConsumerViolatesMessagesDrainedExpectation(t *testing.T) {
+	trm := newTestReporterMock()
+	consumer := NewConsumer(trm, nil)
+	pcmock := consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest)
+	pcmock.YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello")})
+	pcmock.YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello")})
+	pcmock.ExpectMessagesDrainedOnClose()
+
+	pc, err := consumer.ConsumePartition("test", 0, sarama.OffsetOldest)
+	if err != nil {
+		t.Error(err)
+	}
+
+	// consume first message, not second one
+	<-pc.Messages()
+
+	if err := consumer.Close(); err != nil {
+		t.Error(err)
+	}
+
+	if len(trm.errors) != 1 {
+		t.Errorf("Expected an expectation failure to be set on the error reporter.")
+	}
+}
+
+func TestConsumerMeetsErrorsDrainedExpectation(t *testing.T) {
+	trm := newTestReporterMock()
+	consumer := NewConsumer(trm, nil)
+
+	pcmock := consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest)
+	pcmock.YieldError(sarama.ErrInvalidMessage)
+	pcmock.YieldError(sarama.ErrInvalidMessage)
+	pcmock.ExpectErrorsDrainedOnClose()
+
+	pc, err := consumer.ConsumePartition("test", 0, sarama.OffsetOldest)
+	if err != nil {
+		t.Error(err)
+	}
+
+	// consume the first and second error
+	<-pc.Errors()
+	<-pc.Errors()
+
+	if err := consumer.Close(); err != nil {
+		t.Error(err)
+	}
+
+	if len(trm.errors) != 0 {
+		t.Errorf("Expected ano expectation failures to be set on the error reporter.")
+	}
+}

+ 12 - 2
mocks/mocks.go

@@ -15,6 +15,8 @@ package mocks
 
 import (
 	"errors"
+
+	"github.com/Shopify/sarama"
 )
 
 // A simple interface that includes the testing.T methods we use to report
@@ -24,10 +26,18 @@ type ErrorReporter interface {
 }
 
 var (
-	errProduceSuccess    error = nil
-	errOutOfExpectations       = errors.New("No more expectations set on mock producer")
+	errProduceSuccess              error = nil
+	errOutOfExpectations                 = errors.New("No more expectations set on mock")
+	errPartitionConsumerNotStarted       = errors.New("The partition consumer was never started")
 )
 
+const AnyOffset int64 = -1000
+
 type producerExpectation struct {
 	Result error
 }
+
+type consumerExpectation struct {
+	Err error
+	Msg *sarama.ConsumerMessage
+}

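The ErrorReporter interface only needs an Errorf method, so a *testing.T can be passed to the mock constructors directly, and AnyOffset relaxes the offset check on an expectation. A small sketch under those assumptions; the topic name is a placeholder.

package example

import (
	"testing"

	"github.com/Shopify/sarama"
	"github.com/Shopify/sarama/mocks"
)

func TestAnyOffsetExpectation(t *testing.T) {
	// *testing.T satisfies mocks.ErrorReporter because it provides Errorf.
	var _ mocks.ErrorReporter = t

	consumer := mocks.NewConsumer(t, nil)
	consumer.ExpectConsumePartition("events", 0, mocks.AnyOffset).
		YieldMessage(&sarama.ConsumerMessage{Value: []byte("payload")})

	// Any offset is accepted here because the expectation used AnyOffset.
	pc, err := consumer.ConsumePartition("events", 0, sarama.OffsetNewest)
	if err != nil {
		t.Fatal(err)
	}
	<-pc.Messages()

	if err := consumer.Close(); err != nil {
		t.Error(err)
	}
}
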
+ 9 - 3
mocks/producer_test.go

@@ -55,7 +55,9 @@ func TestProducerReturnsExpectationsToChannels(t *testing.T) {
 		t.Error("Expected message 3 to be returned as error")
 	}
 
-	mp.Close()
+	if err := mp.Close(); err != nil {
+		t.Error(err)
+	}
 }
 
 func TestProducerWithTooFewExpectations(t *testing.T) {
@@ -66,7 +68,9 @@ func TestProducerWithTooFewExpectations(t *testing.T) {
 	mp.Input() <- &sarama.ProducerMessage{Topic: "test"}
 	mp.Input() <- &sarama.ProducerMessage{Topic: "test"}
 
-	mp.Close()
+	if err := mp.Close(); err != nil {
+		t.Error(err)
+	}
 
 	if len(trm.errors) != 1 {
 		t.Error("Expected to report an error")
@@ -80,7 +84,9 @@ func TestProducerWithTooManyExpectations(t *testing.T) {
 	mp.ExpectInputAndFail(sarama.ErrOutOfBrokers)
 
 	mp.Input() <- &sarama.ProducerMessage{Topic: "test"}
-	mp.Close()
+	if err := mp.Close(); err != nil {
+		t.Error(err)
+	}
 
 	if len(trm.errors) != 1 {
 		t.Error("Expected to report an error")

+ 23 - 7
mocks/sync_producer_test.go

@@ -15,7 +15,11 @@ func TestMockSyncProducerImplementsSyncProducerInterface(t *testing.T) {
 
 func TestSyncProducerReturnsExpectationsToSendMessage(t *testing.T) {
 	sp := NewSyncProducer(t, nil)
-	defer sp.Close()
+	defer func() {
+		if err := sp.Close(); err != nil {
+			t.Error(err)
+		}
+	}()
 
 	sp.ExpectSendMessageAndSucceed()
 	sp.ExpectSendMessageAndSucceed()
@@ -47,7 +51,9 @@ func TestSyncProducerReturnsExpectationsToSendMessage(t *testing.T) {
 		t.Errorf("The third message should not have been produced successfully")
 	}
 
-	sp.Close()
+	if err := sp.Close(); err != nil {
+		t.Error(err)
+	}
 }
 
 func TestSyncProducerWithTooManyExpectations(t *testing.T) {
@@ -57,9 +63,13 @@ func TestSyncProducerWithTooManyExpectations(t *testing.T) {
 	sp.ExpectSendMessageAndSucceed()
 	sp.ExpectSendMessageAndFail(sarama.ErrOutOfBrokers)
 
-	sp.SendMessage("test", nil, sarama.StringEncoder("test"))
+	if _, _, err := sp.SendMessage("test", nil, sarama.StringEncoder("test")); err != nil {
+		t.Error("No error expected on first SendMessage call", err)
+	}
 
-	sp.Close()
+	if err := sp.Close(); err != nil {
+		t.Error(err)
+	}
 
 	if len(trm.errors) != 1 {
 		t.Error("Expected to report an error")
@@ -72,10 +82,16 @@ func TestSyncProducerWithTooFewExpectations(t *testing.T) {
 	sp := NewSyncProducer(trm, nil)
 	sp.ExpectSendMessageAndSucceed()
 
-	sp.SendMessage("test", nil, sarama.StringEncoder("test"))
-	sp.SendMessage("test", nil, sarama.StringEncoder("test"))
+	if _, _, err := sp.SendMessage("test", nil, sarama.StringEncoder("test")); err != nil {
+		t.Error("No error expected on first SendMessage call", err)
+	}
+	if _, _, err := sp.SendMessage("test", nil, sarama.StringEncoder("test")); err != errOutOfExpectations {
+		t.Error("errOutOfExpectations expected on second SendMessage call, found:", err)
+	}
 
-	sp.Close()
+	if err := sp.Close(); err != nil {
+		t.Error(err)
+	}
 
 	if len(trm.errors) != 1 {
 		t.Error("Expected to report an error")