
Merge branch 'master' into minor

Vlad Gorodetsky, 7 years ago
Parent
Commit
09c19c9929
19 changed files with 1721 additions and 38 deletions
  1. .travis.yml (+4 -4)
  2. CHANGELOG.md (+16 -0)
  3. Makefile (+1 -1)
  4. README.md (+1 -1)
  5. balance_strategy.go (+129 -0)
  6. balance_strategy_test.go (+102 -0)
  7. client_tls_test.go (+6 -6)
  8. config.go (+82 -7)
  9. config_test.go (+1 -1)
  10. consumer_group.go (+774 -0)
  11. consumer_group_test.go (+58 -0)
  12. dev.yml (+1 -1)
  13. functional_consumer_group_test.go (+418 -0)
  14. message.go (+26 -4)
  15. message_test.go (+44 -0)
  16. metadata_response.go (+11 -0)
  17. mocks/async_producer.go (+1 -2)
  18. offset_manager.go (+35 -11)
  19. record_batch.go (+11 -0)

+ 4 - 4
.travis.yml

@@ -1,8 +1,8 @@
 language: go
 go:
-- 1.8.x
-- 1.9.x
-- 1.10.x
+- 1.9.7
+- 1.10.4
+- 1.11

 env:
   global:
@@ -28,7 +28,7 @@ script:
 - make test
 - make vet
 - make errcheck
-- make fmt
+- if [ "$TRAVIS_GO_VERSION" = "1.11" ]; then make fmt; fi

 after_success:
 - bash <(curl -s https://codecov.io/bash)

+ 16 - 0
CHANGELOG.md

@@ -1,5 +1,21 @@
 # Changelog

+#### Version 1.19.0 (2018-09-27)
+
+New Features:
+ - Implement a higher-level consumer group
+   ([#1099](https://github.com/Shopify/sarama/pull/1099)).
+
+Improvements:
+ - Add support for Go 1.11
+   ([#1176](https://github.com/Shopify/sarama/pull/1176)).
+
+Bug Fixes:
+ - Fix encoding of `MetadataResponse` with version 2 and higher
+   ([#1174](https://github.com/Shopify/sarama/pull/1174)).
+ - Fix race condition in mock async producer
+   ([#1174](https://github.com/Shopify/sarama/pull/1174)).
+
 #### Version 1.18.0 (2018-09-07)

 New Features:

+ 1 - 1
Makefile

@@ -4,7 +4,7 @@ default: fmt vet errcheck test
 test:
 	echo "" > coverage.txt
 	for d in `go list ./... | grep -v vendor`; do \
-		go test -p 1 -v -timeout 90s -race -coverprofile=profile.out -covermode=atomic $$d || exit 1; \
+		go test -p 1 -v -timeout 240s -race -coverprofile=profile.out -covermode=atomic $$d || exit 1; \
 		if [ -f profile.out ]; then \
 			cat profile.out >> coverage.txt; \
 			rm profile.out; \

+ 1 - 1
README.md

@@ -21,7 +21,7 @@ You might also want to look at the [Frequently Asked Questions](https://github.c
 Sarama provides a "2 releases + 2 months" compatibility guarantee: we support
 the two latest stable releases of Kafka and Go, and we provide a two month
 grace period for older releases. This means we currently officially support
-Go 1.8 through 1.10, and Kafka 1.0 through 2.0, although older releases are
+Go 1.8 through 1.11, and Kafka 1.0 through 2.0, although older releases are
 still likely to work.

 Sarama follows semantic versioning and provides API stability via the gopkg.in service.

+ 129 - 0
balance_strategy.go

@@ -0,0 +1,129 @@
+package sarama
+
+import (
+	"math"
+	"sort"
+)
+
+// BalanceStrategyPlan is the result of any BalanceStrategy.Plan attempt.
+// It contains an allocation of topic/partitions by memberID in the form of
+// a `memberID -> topic -> partitions` map.
+type BalanceStrategyPlan map[string]map[string][]int32
+
+// Add assigns a topic with a number of partitions to a member.
+func (p BalanceStrategyPlan) Add(memberID, topic string, partitions ...int32) {
+	if len(partitions) == 0 {
+		return
+	}
+	if _, ok := p[memberID]; !ok {
+		p[memberID] = make(map[string][]int32, 1)
+	}
+	p[memberID][topic] = append(p[memberID][topic], partitions...)
+}
+
+// --------------------------------------------------------------------
+
+// BalanceStrategy is used to balance topics and partitions
+// across members of a consumer group
+type BalanceStrategy interface {
+	// Name uniquely identifies the strategy.
+	Name() string
+
+	// Plan accepts a map of `memberID -> metadata` and a map of `topic -> partitions`
+	// and returns a distribution plan.
+	Plan(members map[string]ConsumerGroupMemberMetadata, topics map[string][]int32) (BalanceStrategyPlan, error)
+}
+
+// --------------------------------------------------------------------
+
+// BalanceStrategyRange is the default and assigns partitions as ranges to consumer group members.
+// Example with one topic T with six partitions (0..5) and two members (M1, M2):
+//   M1: {T: [0, 1, 2]}
+//   M2: {T: [3, 4, 5]}
+var BalanceStrategyRange = &balanceStrategy{
+	name: "range",
+	coreFn: func(plan BalanceStrategyPlan, memberIDs []string, topic string, partitions []int32) {
+		step := float64(len(partitions)) / float64(len(memberIDs))
+
+		for i, memberID := range memberIDs {
+			pos := float64(i)
+			min := int(math.Floor(pos*step + 0.5))
+			max := int(math.Floor((pos+1)*step + 0.5))
+			plan.Add(memberID, topic, partitions[min:max]...)
+		}
+	},
+}
+
+// BalanceStrategyRoundRobin assigns partitions to members in alternating order.
+// Example with topic T with six partitions (0..5) and two members (M1, M2):
+//   M1: {T: [0, 2, 4]}
+//   M2: {T: [1, 3, 5]}
+var BalanceStrategyRoundRobin = &balanceStrategy{
+	name: "roundrobin",
+	coreFn: func(plan BalanceStrategyPlan, memberIDs []string, topic string, partitions []int32) {
+		for i, part := range partitions {
+			memberID := memberIDs[i%len(memberIDs)]
+			plan.Add(memberID, topic, part)
+		}
+	},
+}
+
+// --------------------------------------------------------------------
+
+type balanceStrategy struct {
+	name   string
+	coreFn func(plan BalanceStrategyPlan, memberIDs []string, topic string, partitions []int32)
+}
+
+// Name implements BalanceStrategy.
+func (s *balanceStrategy) Name() string { return s.name }
+
+// Balance implements BalanceStrategy.
+func (s *balanceStrategy) Plan(members map[string]ConsumerGroupMemberMetadata, topics map[string][]int32) (BalanceStrategyPlan, error) {
+	// Build members by topic map
+	mbt := make(map[string][]string)
+	for memberID, meta := range members {
+		for _, topic := range meta.Topics {
+			mbt[topic] = append(mbt[topic], memberID)
+		}
+	}
+
+	// Sort members for each topic
+	for topic, memberIDs := range mbt {
+		sort.Sort(&balanceStrategySortable{
+			topic:     topic,
+			memberIDs: memberIDs,
+		})
+	}
+
+	// Assemble plan
+	plan := make(BalanceStrategyPlan, len(members))
+	for topic, memberIDs := range mbt {
+		s.coreFn(plan, memberIDs, topic, topics[topic])
+	}
+	return plan, nil
+}
+
+type balanceStrategySortable struct {
+	topic     string
+	memberIDs []string
+}
+
+func (p balanceStrategySortable) Len() int { return len(p.memberIDs) }
+func (p balanceStrategySortable) Swap(i, j int) {
+	p.memberIDs[i], p.memberIDs[j] = p.memberIDs[j], p.memberIDs[i]
+}
+func (p balanceStrategySortable) Less(i, j int) bool {
+	return balanceStrategyHashValue(p.topic, p.memberIDs[i]) < balanceStrategyHashValue(p.topic, p.memberIDs[j])
+}
+
+func balanceStrategyHashValue(vv ...string) uint32 {
+	h := uint32(2166136261)
+	for _, s := range vv {
+		for _, c := range s {
+			h ^= uint32(c)
+			h *= 16777619
+		}
+	}
+	return h
+}
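
For illustration, here is a minimal standalone sketch (not part of this diff; the member and topic names are made up) that exercises the new strategy API directly, mirroring the six-partition example in the BalanceStrategyRange doc comment above:

    package main

    import (
    	"fmt"

    	"github.com/Shopify/sarama"
    )

    func main() {
    	// Two members subscribed to one topic with six partitions.
    	members := map[string]sarama.ConsumerGroupMemberMetadata{
    		"M1": {Topics: []string{"T"}},
    		"M2": {Topics: []string{"T"}},
    	}
    	topics := map[string][]int32{"T": {0, 1, 2, 3, 4, 5}}

    	plan, err := sarama.BalanceStrategyRange.Plan(members, topics)
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println(plan) // e.g. map[M1:map[T:[0 1 2]] M2:map[T:[3 4 5]]]
    }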

+ 102 - 0
balance_strategy_test.go

@@ -0,0 +1,102 @@
+package sarama
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestBalanceStrategyRange(t *testing.T) {
+	tests := []struct {
+		members  map[string][]string
+		topics   map[string][]int32
+		expected BalanceStrategyPlan
+	}{
+		{
+			members: map[string][]string{"M1": {"T1", "T2"}, "M2": {"T1", "T2"}},
+			topics:  map[string][]int32{"T1": {0, 1, 2, 3}, "T2": {0, 1, 2, 3}},
+			expected: BalanceStrategyPlan{
+				"M1": map[string][]int32{"T1": {0, 1}, "T2": {2, 3}},
+				"M2": map[string][]int32{"T1": {2, 3}, "T2": {0, 1}},
+			},
+		},
+		{
+			members: map[string][]string{"M1": {"T1", "T2"}, "M2": {"T1", "T2"}},
+			topics:  map[string][]int32{"T1": {0, 1, 2}, "T2": {0, 1, 2}},
+			expected: BalanceStrategyPlan{
+				"M1": map[string][]int32{"T1": {0, 1}, "T2": {2}},
+				"M2": map[string][]int32{"T1": {2}, "T2": {0, 1}},
+			},
+		},
+		{
+			members: map[string][]string{"M1": {"T1"}, "M2": {"T1", "T2"}},
+			topics:  map[string][]int32{"T1": {0, 1}, "T2": {0, 1}},
+			expected: BalanceStrategyPlan{
+				"M1": map[string][]int32{"T1": {0}},
+				"M2": map[string][]int32{"T1": {1}, "T2": {0, 1}},
+			},
+		},
+	}
+
+	strategy := BalanceStrategyRange
+	if strategy.Name() != "range" {
+		t.Errorf("Unexpected strategy name\nexpected: range\nactual: %v", strategy.Name())
+	}
+
+	for _, test := range tests {
+		members := make(map[string]ConsumerGroupMemberMetadata)
+		for memberID, topics := range test.members {
+			members[memberID] = ConsumerGroupMemberMetadata{Topics: topics}
+		}
+
+		actual, err := strategy.Plan(members, test.topics)
+		if err != nil {
+			t.Errorf("Unexpected error %v", err)
+		} else if !reflect.DeepEqual(actual, test.expected) {
+			t.Errorf("Plan does not match expectation\nexpected: %#v\nactual: %#v", test.expected, actual)
+		}
+	}
+}
+
+func TestBalanceStrategyRoundRobin(t *testing.T) {
+	tests := []struct {
+		members  map[string][]string
+		topics   map[string][]int32
+		expected BalanceStrategyPlan
+	}{
+		{
+			members: map[string][]string{"M1": {"T1", "T2"}, "M2": {"T1", "T2"}},
+			topics:  map[string][]int32{"T1": {0, 1, 2, 3}, "T2": {0, 1, 2, 3}},
+			expected: BalanceStrategyPlan{
+				"M1": map[string][]int32{"T1": {0, 2}, "T2": {1, 3}},
+				"M2": map[string][]int32{"T1": {1, 3}, "T2": {0, 2}},
+			},
+		},
+		{
+			members: map[string][]string{"M1": {"T1", "T2"}, "M2": {"T1", "T2"}},
+			topics:  map[string][]int32{"T1": {0, 1, 2}, "T2": {0, 1, 2}},
+			expected: BalanceStrategyPlan{
+				"M1": map[string][]int32{"T1": {0, 2}, "T2": {1}},
+				"M2": map[string][]int32{"T1": {1}, "T2": {0, 2}},
+			},
+		},
+	}
+
+	strategy := BalanceStrategyRoundRobin
+	if strategy.Name() != "roundrobin" {
+		t.Errorf("Unexpected strategy name\nexpected: roundrobin\nactual: %v", strategy.Name())
+	}
+
+	for _, test := range tests {
+		members := make(map[string]ConsumerGroupMemberMetadata)
+		for memberID, topics := range test.members {
+			members[memberID] = ConsumerGroupMemberMetadata{Topics: topics}
+		}
+
+		actual, err := strategy.Plan(members, test.topics)
+		if err != nil {
+			t.Errorf("Unexpected error %v", err)
+		} else if !reflect.DeepEqual(actual, test.expected) {
+			t.Errorf("Plan does not match expectation\nexpected: %#v\nactual: %#v", test.expected, actual)
+		}
+	}
+}

+ 6 - 6
client_tls_test.go

@@ -33,12 +33,12 @@ func TestTLS(t *testing.T) {
 	nva := time.Now().Add(1 * time.Hour)

 	caTemplate := &x509.Certificate{
-		Subject:      pkix.Name{CommonName: "ca"},
-		Issuer:       pkix.Name{CommonName: "ca"},
-		SerialNumber: big.NewInt(0),
-		NotAfter:     nva,
-		NotBefore:    nvb,
-		IsCA:         true,
+		Subject:               pkix.Name{CommonName: "ca"},
+		Issuer:                pkix.Name{CommonName: "ca"},
+		SerialNumber:          big.NewInt(0),
+		NotAfter:              nva,
+		NotBefore:             nvb,
+		IsCA:                  true,
 		BasicConstraintsValid: true,
 		KeyUsage:              x509.KeyUsageCertSign,
 	}

+ 82 - 7
config.go

@@ -173,14 +173,55 @@ type Config struct {

 	// Consumer is the namespace for configuration related to consuming messages,
 	// used by the Consumer.
-	//
-	// Note that Sarama's Consumer type does not currently support automatic
-	// consumer-group rebalancing and offset tracking.  For Zookeeper-based
-	// tracking (Kafka 0.8.2 and earlier), the https://github.com/wvanbergen/kafka
-	// library builds on Sarama to add this support. For Kafka-based tracking
-	// (Kafka 0.9 and later), the https://github.com/bsm/sarama-cluster library
-	// builds on Sarama to add this support.
 	Consumer struct {
+
+		// Group is the namespace for configuring consumer group.
+		Group struct {
+			Session struct {
+				// The timeout used to detect consumer failures when using Kafka's group management facility.
+				// The consumer sends periodic heartbeats to indicate its liveness to the broker.
+				// If no heartbeats are received by the broker before the expiration of this session timeout,
+				// then the broker will remove this consumer from the group and initiate a rebalance.
+				// Note that the value must be in the allowable range as configured in the broker configuration
+				// by `group.min.session.timeout.ms` and `group.max.session.timeout.ms` (default 10s)
+				Timeout time.Duration
+			}
+			Heartbeat struct {
+				// The expected time between heartbeats to the consumer coordinator when using Kafka's group
+				// management facilities. Heartbeats are used to ensure that the consumer's session stays active and
+				// to facilitate rebalancing when new consumers join or leave the group.
+				// The value must be set lower than Consumer.Group.Session.Timeout, but typically should be set no
+				// higher than 1/3 of that value.
+				// It can be adjusted even lower to control the expected time for normal rebalances (default 3s)
+				Interval time.Duration
+			}
+			Rebalance struct {
+				// Strategy for allocating topic partitions to members (default BalanceStrategyRange)
+				Strategy BalanceStrategy
+				// The maximum allowed time for each worker to join the group once a rebalance has begun.
+				// This is basically a limit on the amount of time needed for all tasks to flush any pending
+				// data and commit offsets. If the timeout is exceeded, then the worker will be removed from
+				// the group, which will cause offset commit failures (default 60s).
+				Timeout time.Duration
+
+				Retry struct {
+					// When a new consumer joins a consumer group the set of consumers attempt to "rebalance"
+					// the load to assign partitions to each consumer. If the set of consumers changes while
+					// this assignment is taking place the rebalance will fail and retry. This setting controls
+					// the maximum number of attempts before giving up (default 4).
+					Max int
+					// Backoff time between retries during rebalance (default 2s)
+					Backoff time.Duration
+				}
+			}
+			Member struct {
+				// Custom metadata to include when joining the group. The user data for all joined members
+				// can be retrieved by sending a DescribeGroupRequest to the broker that is the
+				// coordinator for the group.
+				UserData []byte
+			}
+		}
+
 		Retry struct {
 			// How long to wait after a failing to read from a partition before
 			// trying again (default 2s).
@@ -331,6 +372,13 @@ func NewConfig() *Config {
 	c.Consumer.Offsets.Initial = OffsetNewest
 	c.Consumer.Offsets.Retry.Max = 3

+	c.Consumer.Group.Session.Timeout = 10 * time.Second
+	c.Consumer.Group.Heartbeat.Interval = 3 * time.Second
+	c.Consumer.Group.Rebalance.Strategy = BalanceStrategyRange
+	c.Consumer.Group.Rebalance.Timeout = 60 * time.Second
+	c.Consumer.Group.Rebalance.Retry.Max = 4
+	c.Consumer.Group.Rebalance.Retry.Backoff = 2 * time.Second
+
 	c.ClientID = defaultClientID
 	c.ChannelBufferSize = 256
 	c.Version = MinVersion
@@ -378,6 +426,15 @@ func (c *Config) Validate() error {
 	if c.Consumer.Offsets.Retention%time.Millisecond != 0 {
 		Logger.Println("Consumer.Offsets.Retention only supports millisecond precision; nanoseconds will be truncated.")
 	}
+	if c.Consumer.Group.Session.Timeout%time.Millisecond != 0 {
+		Logger.Println("Consumer.Group.Session.Timeout only supports millisecond precision; nanoseconds will be truncated.")
+	}
+	if c.Consumer.Group.Heartbeat.Interval%time.Millisecond != 0 {
+		Logger.Println("Consumer.Group.Heartbeat.Interval only supports millisecond precision; nanoseconds will be truncated.")
+	}
+	if c.Consumer.Group.Rebalance.Timeout%time.Millisecond != 0 {
+		Logger.Println("Consumer.Group.Rebalance.Timeout only supports millisecond precision; nanoseconds will be truncated.")
+	}
 	if c.ClientID == defaultClientID {
 		Logger.Println("ClientID is the default of 'sarama', you should consider setting it to something application-specific.")
 	}
@@ -476,6 +533,24 @@ func (c *Config) Validate() error {
 		return ConfigurationError("Consumer.Offsets.Retry.Max must be >= 0")
 	}

+	// validate the Consumer Group values
+	switch {
+	case c.Consumer.Group.Session.Timeout <= 2*time.Millisecond:
+		return ConfigurationError("Consumer.Group.Session.Timeout must be >= 2ms")
+	case c.Consumer.Group.Heartbeat.Interval < 1*time.Millisecond:
+		return ConfigurationError("Consumer.Group.Heartbeat.Interval must be >= 1ms")
+	case c.Consumer.Group.Heartbeat.Interval >= c.Consumer.Group.Session.Timeout:
+		return ConfigurationError("Consumer.Group.Heartbeat.Interval must be < Consumer.Group.Session.Timeout")
+	case c.Consumer.Group.Rebalance.Strategy == nil:
+		return ConfigurationError("Consumer.Group.Rebalance.Strategy must not be empty")
+	case c.Consumer.Group.Rebalance.Timeout <= time.Millisecond:
+		return ConfigurationError("Consumer.Group.Rebalance.Timeout must be >= 1ms")
+	case c.Consumer.Group.Rebalance.Retry.Max < 0:
+		return ConfigurationError("Consumer.Group.Rebalance.Retry.Max must be >= 0")
+	case c.Consumer.Group.Rebalance.Retry.Backoff < 0:
+		return ConfigurationError("Consumer.Group.Rebalance.Retry.Backoff must be >= 0")
+	}
+
 	// validate misc shared values
 	switch {
 	case c.ChannelBufferSize < 0:
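
As a brief sketch (not part of this diff; the broker version and values shown are arbitrary), the new Consumer.Group options are set on a Config like any other namespace and are checked by Validate():

    package main

    import (
    	"time"

    	"github.com/Shopify/sarama"
    )

    func main() {
    	config := sarama.NewConfig()
    	config.Version = sarama.V0_10_2_0 // consumer groups require Kafka >= 0.10.2

    	// Tune the consumer-group knobs introduced in this change.
    	config.Consumer.Group.Session.Timeout = 20 * time.Second
    	config.Consumer.Group.Heartbeat.Interval = 6 * time.Second // must stay below Session.Timeout
    	config.Consumer.Group.Rebalance.Strategy = sarama.BalanceStrategyRoundRobin
    	config.Consumer.Group.Rebalance.Retry.Max = 4
    	config.Consumer.Group.Rebalance.Retry.Backoff = 2 * time.Second

    	if err := config.Validate(); err != nil {
    		panic(err)
    	}
    }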

+ 1 - 1
config_test.go

@@ -222,7 +222,7 @@ func TestLZ4ConfigValidation(t *testing.T) {
 	config := NewConfig()
 	config.Producer.Compression = CompressionLZ4
 	if err := config.Validate(); string(err.(ConfigurationError)) != "lz4 compression requires Version >= V0_10_0_0" {
-		t.Error("Expected invalid lz4/kakfa version error, got ", err)
+		t.Error("Expected invalid lz4/kafka version error, got ", err)
 	}
 	config.Version = V0_10_0_0
 	if err := config.Validate(); err != nil {

+ 774 - 0
consumer_group.go

@@ -0,0 +1,774 @@
+package sarama
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sort"
+	"sync"
+	"time"
+)
+
+// ErrClosedConsumerGroup is the error returned when a method is called on a consumer group that has been closed.
+var ErrClosedConsumerGroup = errors.New("kafka: tried to use a consumer group that was closed")
+
+// ConsumerGroup is responsible for dividing up processing of topics and partitions
+// over a collection of processes (the members of the consumer group).
+type ConsumerGroup interface {
+	// Consume joins a cluster of consumers for a given list of topics and
+	// starts a blocking ConsumerGroupSession through the ConsumerGroupHandler.
+	//
+	// The life-cycle of a session is represented by the following steps:
+	//
+	// 1. The consumers join the group (as explained in https://kafka.apache.org/documentation/#intro_consumers)
+	//    and are assigned their "fair share" of partitions, aka 'claims'.
+	// 2. Before processing starts, the handler's Setup() hook is called to notify the user
+	//    of the claims and allow any necessary preparation or alteration of state.
+	// 3. For each of the assigned claims the handler's ConsumeClaim() function is then called
+	//    in a separate goroutine which requires it to be thread-safe. Any state must be carefully protected
+	//    from concurrent reads/writes.
+	// 4. The session will persist until one of the ConsumeClaim() functions exits. This can be either when the
+	//    parent context is cancelled or when a server-side rebalance cycle is initiated.
+	// 5. Once all the ConsumeClaim() loops have exited, the handler's Cleanup() hook is called
+	//    to allow the user to perform any final tasks before a rebalance.
+	// 6. Finally, marked offsets are committed one last time before claims are released.
+	//
+	// Please note that once a rebalance is triggered, sessions must be completed within
+	// Config.Consumer.Group.Rebalance.Timeout. This means that ConsumeClaim() functions must exit
+	// as quickly as possible to allow time for Cleanup() and the final offset commit. If the timeout
+	// is exceeded, the consumer will be removed from the group by Kafka, which will cause offset
+	// commit failures.
+	Consume(ctx context.Context, topics []string, handler ConsumerGroupHandler) error
+
+	// Errors returns a read channel of errors that occurred during the consumer life-cycle.
+	// By default, errors are logged and not returned over this channel.
+	// If you want to implement any custom error handling, set your config's
+	// Consumer.Return.Errors setting to true, and read from this channel.
+	Errors() <-chan error
+
+	// Close stops the ConsumerGroup and detaches any running sessions. It is required to call
+	// this function before the object passes out of scope, as it will otherwise leak memory.
+	Close() error
+}
+
+type consumerGroup struct {
+	client    Client
+	ownClient bool
+
+	config   *Config
+	consumer Consumer
+	groupID  string
+	memberID string
+	errors   chan error
+
+	lock      sync.Mutex
+	closed    chan none
+	closeOnce sync.Once
+}
+
+// NewConsumerGroup creates a new consumer group with the given broker addresses and configuration.
+func NewConsumerGroup(addrs []string, groupID string, config *Config) (ConsumerGroup, error) {
+	client, err := NewClient(addrs, config)
+	if err != nil {
+		return nil, err
+	}
+
+	c, err := NewConsumerGroupFromClient(groupID, client)
+	if err != nil {
+		_ = client.Close()
+		return nil, err
+	}
+
+	c.(*consumerGroup).ownClient = true
+	return c, nil
+}
+
+// NewConsumerGroupFromClient creates a new consumer group using the given client. It is still
+// necessary to call Close() on the underlying client when shutting down this consumer.
+// PLEASE NOTE: consumer groups can only re-use but not share clients.
+func NewConsumerGroupFromClient(groupID string, client Client) (ConsumerGroup, error) {
+	config := client.Config()
+	if !config.Version.IsAtLeast(V0_10_2_0) {
+		return nil, ConfigurationError("consumer groups require Version to be >= V0_10_2_0")
+	}
+
+	consumer, err := NewConsumerFromClient(client)
+	if err != nil {
+		return nil, err
+	}
+
+	return &consumerGroup{
+		client:   client,
+		consumer: consumer,
+		config:   config,
+		groupID:  groupID,
+		errors:   make(chan error, config.ChannelBufferSize),
+		closed:   make(chan none),
+	}, nil
+}
+
+// Errors implements ConsumerGroup.
+func (c *consumerGroup) Errors() <-chan error { return c.errors }
+
+// Close implements ConsumerGroup.
+func (c *consumerGroup) Close() (err error) {
+	c.closeOnce.Do(func() {
+		close(c.closed)
+
+		c.lock.Lock()
+		defer c.lock.Unlock()
+
+		// leave group
+		if e := c.leave(); e != nil {
+			err = e
+		}
+
+		// drain errors
+		go func() {
+			close(c.errors)
+		}()
+		for e := range c.errors {
+			err = e
+		}
+
+		if c.ownClient {
+			if e := c.client.Close(); e != nil {
+				err = e
+			}
+		}
+	})
+	return
+}
+
+// Consume implements ConsumerGroup.
+func (c *consumerGroup) Consume(ctx context.Context, topics []string, handler ConsumerGroupHandler) error {
+	// Ensure group is not closed
+	select {
+	case <-c.closed:
+		return ErrClosedConsumerGroup
+	default:
+	}
+
+	c.lock.Lock()
+	defer c.lock.Unlock()
+
+	// Quick exit when no topics are provided
+	if len(topics) == 0 {
+		return fmt.Errorf("no topics provided")
+	}
+
+	// Refresh metadata for requested topics
+	if err := c.client.RefreshMetadata(topics...); err != nil {
+		return err
+	}
+
+	// Get coordinator
+	coordinator, err := c.client.Coordinator(c.groupID)
+	if err != nil {
+		return err
+	}
+
+	// Init session
+	sess, err := c.newSession(ctx, coordinator, topics, handler, c.config.Consumer.Group.Rebalance.Retry.Max)
+	if err == ErrClosedClient {
+		return ErrClosedConsumerGroup
+	} else if err != nil {
+		return err
+	}
+
+	// Wait for session exit signal
+	<-sess.ctx.Done()
+
+	// Gracefully release session claims
+	return sess.release(true)
+}
+
+func (c *consumerGroup) newSession(ctx context.Context, coordinator *Broker, topics []string, handler ConsumerGroupHandler, retries int) (*consumerGroupSession, error) {
+	// Join consumer group
+	join, err := c.joinGroupRequest(coordinator, topics)
+	if err != nil {
+		_ = coordinator.Close()
+		return nil, err
+	}
+	switch join.Err {
+	case ErrNoError:
+		c.memberID = join.MemberId
+	case ErrUnknownMemberId, ErrIllegalGeneration: // reset member ID and retry immediately
+		c.memberID = ""
+		return c.newSession(ctx, coordinator, topics, handler, retries)
+	case ErrRebalanceInProgress: // retry after backoff
+		if retries <= 0 {
+			return nil, join.Err
+		}
+
+		select {
+		case <-c.closed:
+			return nil, ErrClosedConsumerGroup
+		case <-time.After(c.config.Consumer.Group.Rebalance.Retry.Backoff):
+		}
+
+		return c.newSession(ctx, coordinator, topics, handler, retries-1)
+	default:
+		return nil, join.Err
+	}
+
+	// Prepare distribution plan if we joined as the leader
+	var plan BalanceStrategyPlan
+	if join.LeaderId == join.MemberId {
+		members, err := join.GetMembers()
+		if err != nil {
+			return nil, err
+		}
+
+		plan, err = c.balance(members)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	// Sync consumer group
+	sync, err := c.syncGroupRequest(coordinator, plan, join.GenerationId)
+	if err != nil {
+		_ = coordinator.Close()
+		return nil, err
+	}
+	switch sync.Err {
+	case ErrNoError:
+	case ErrUnknownMemberId, ErrIllegalGeneration: // reset member ID and retry immediately
+		c.memberID = ""
+		return c.newSession(ctx, coordinator, topics, handler, retries)
+	case ErrRebalanceInProgress: // retry after backoff
+		if retries <= 0 {
+			return nil, sync.Err
+		}
+
+		select {
+		case <-c.closed:
+			return nil, ErrClosedConsumerGroup
+		case <-time.After(c.config.Consumer.Group.Rebalance.Retry.Backoff):
+		}
+
+		return c.newSession(ctx, coordinator, topics, handler, retries-1)
+	default:
+		return nil, sync.Err
+	}
+
+	// Retrieve and sort claims
+	var claims map[string][]int32
+	if len(sync.MemberAssignment) > 0 {
+		members, err := sync.GetMemberAssignment()
+		if err != nil {
+			return nil, err
+		}
+		claims = members.Topics
+
+		for _, partitions := range claims {
+			sort.Sort(int32Slice(partitions))
+		}
+	}
+
+	return newConsumerGroupSession(c, ctx, claims, join.MemberId, join.GenerationId, handler)
+}
+
+func (c *consumerGroup) joinGroupRequest(coordinator *Broker, topics []string) (*JoinGroupResponse, error) {
+	req := &JoinGroupRequest{
+		GroupId:        c.groupID,
+		MemberId:       c.memberID,
+		SessionTimeout: int32(c.config.Consumer.Group.Session.Timeout / time.Millisecond),
+		ProtocolType:   "consumer",
+	}
+	if c.config.Version.IsAtLeast(V0_10_1_0) {
+		req.Version = 1
+		req.RebalanceTimeout = int32(c.config.Consumer.Group.Rebalance.Timeout / time.Millisecond)
+	}
+
+	meta := &ConsumerGroupMemberMetadata{
+		Topics:   topics,
+		UserData: c.config.Consumer.Group.Member.UserData,
+	}
+	strategy := c.config.Consumer.Group.Rebalance.Strategy
+	if err := req.AddGroupProtocolMetadata(strategy.Name(), meta); err != nil {
+		return nil, err
+	}
+
+	return coordinator.JoinGroup(req)
+}
+
+func (c *consumerGroup) syncGroupRequest(coordinator *Broker, plan BalanceStrategyPlan, generationID int32) (*SyncGroupResponse, error) {
+	req := &SyncGroupRequest{
+		GroupId:      c.groupID,
+		MemberId:     c.memberID,
+		GenerationId: generationID,
+	}
+	for memberID, topics := range plan {
+		err := req.AddGroupAssignmentMember(memberID, &ConsumerGroupMemberAssignment{
+			Topics: topics,
+		})
+		if err != nil {
+			return nil, err
+		}
+	}
+	return coordinator.SyncGroup(req)
+}
+
+func (c *consumerGroup) heartbeatRequest(coordinator *Broker, memberID string, generationID int32) (*HeartbeatResponse, error) {
+	req := &HeartbeatRequest{
+		GroupId:      c.groupID,
+		MemberId:     memberID,
+		GenerationId: generationID,
+	}
+
+	return coordinator.Heartbeat(req)
+}
+
+func (c *consumerGroup) balance(members map[string]ConsumerGroupMemberMetadata) (BalanceStrategyPlan, error) {
+	topics := make(map[string][]int32)
+	for _, meta := range members {
+		for _, topic := range meta.Topics {
+			topics[topic] = nil
+		}
+	}
+
+	for topic := range topics {
+		partitions, err := c.client.Partitions(topic)
+		if err != nil {
+			return nil, err
+		}
+		topics[topic] = partitions
+	}
+
+	strategy := c.config.Consumer.Group.Rebalance.Strategy
+	return strategy.Plan(members, topics)
+}
+
+// Leaves the cluster, called by Close, protected by lock.
+func (c *consumerGroup) leave() error {
+	if c.memberID == "" {
+		return nil
+	}
+
+	coordinator, err := c.client.Coordinator(c.groupID)
+	if err != nil {
+		return err
+	}
+
+	resp, err := coordinator.LeaveGroup(&LeaveGroupRequest{
+		GroupId:  c.groupID,
+		MemberId: c.memberID,
+	})
+	if err != nil {
+		_ = coordinator.Close()
+		return err
+	}
+
+	// Unset memberID
+	c.memberID = ""
+
+	// Check response
+	switch resp.Err {
+	case ErrRebalanceInProgress, ErrUnknownMemberId, ErrNoError:
+		return nil
+	default:
+		return resp.Err
+	}
+}
+
+func (c *consumerGroup) handleError(err error, topic string, partition int32) {
+	select {
+	case <-c.closed:
+		return
+	default:
+	}
+
+	if _, ok := err.(*ConsumerError); !ok && topic != "" && partition > -1 {
+		err = &ConsumerError{
+			Topic:     topic,
+			Partition: partition,
+			Err:       err,
+		}
+	}
+
+	if c.config.Consumer.Return.Errors {
+		select {
+		case c.errors <- err:
+		default:
+		}
+	} else {
+		Logger.Println(err)
+	}
+}
+
+// --------------------------------------------------------------------
+
+// ConsumerGroupSession represents a consumer group member session.
+type ConsumerGroupSession interface {
+	// Claims returns information about the claimed partitions by topic.
+	Claims() map[string][]int32
+
+	// MemberID returns the cluster member ID.
+	MemberID() string
+
+	// GenerationID returns the current generation ID.
+	GenerationID() int32
+
+	// MarkOffset marks the provided offset, alongside a metadata string
+	// that represents the state of the partition consumer at that point in time. The
+	// metadata string can be used by another consumer to restore that state, so it
+	// can resume consumption.
+	//
+	// To follow upstream conventions, you are expected to mark the offset of the
+	// next message to read, not the last message read. Thus, when calling `MarkOffset`
+	// you should typically add one to the offset of the last consumed message.
+	//
+	// Note: calling MarkOffset does not necessarily commit the offset to the backend
+	// store immediately for efficiency reasons, and it may never be committed if
+	// your application crashes. This means that you may end up processing the same
+	// message twice, and your processing should ideally be idempotent.
+	MarkOffset(topic string, partition int32, offset int64, metadata string)
+
+	// ResetOffset resets to the provided offset, alongside a metadata string that
+	// represents the state of the partition consumer at that point in time. Reset
+	// acts as a counterpart to MarkOffset, the difference being that it allows
+	// resetting an offset to an earlier or smaller value, where MarkOffset only
+	// allows incrementing the offset. See MarkOffset for more details.
+	ResetOffset(topic string, partition int32, offset int64, metadata string)
+
+	// MarkMessage marks a message as consumed.
+	MarkMessage(msg *ConsumerMessage, metadata string)
+
+	// Context returns the session context.
+	Context() context.Context
+}
+
+type consumerGroupSession struct {
+	parent       *consumerGroup
+	memberID     string
+	generationID int32
+	handler      ConsumerGroupHandler
+
+	claims  map[string][]int32
+	offsets *offsetManager
+	ctx     context.Context
+	cancel  func()
+
+	waitGroup       sync.WaitGroup
+	releaseOnce     sync.Once
+	hbDying, hbDead chan none
+}
+
+func newConsumerGroupSession(parent *consumerGroup, ctx context.Context, claims map[string][]int32, memberID string, generationID int32, handler ConsumerGroupHandler) (*consumerGroupSession, error) {
+	// init offset manager
+	offsets, err := newOffsetManagerFromClient(parent.groupID, memberID, generationID, parent.client)
+	if err != nil {
+		return nil, err
+	}
+
+	// init context
+	ctx, cancel := context.WithCancel(ctx)
+
+	// init session
+	sess := &consumerGroupSession{
+		parent:       parent,
+		memberID:     memberID,
+		generationID: generationID,
+		handler:      handler,
+		offsets:      offsets,
+		claims:       claims,
+		ctx:          ctx,
+		cancel:       cancel,
+		hbDying:      make(chan none),
+		hbDead:       make(chan none),
+	}
+
+	// start heartbeat loop
+	go sess.heartbeatLoop()
+
+	// create a POM for each claim
+	for topic, partitions := range claims {
+		for _, partition := range partitions {
+			pom, err := offsets.ManagePartition(topic, partition)
+			if err != nil {
+				_ = sess.release(false)
+				return nil, err
+			}
+
+			// handle POM errors
+			go func(topic string, partition int32) {
+				for err := range pom.Errors() {
+					sess.parent.handleError(err, topic, partition)
+				}
+			}(topic, partition)
+		}
+	}
+
+	// perform setup
+	if err := handler.Setup(sess); err != nil {
+		_ = sess.release(true)
+		return nil, err
+	}
+
+	// start consuming
+	for topic, partitions := range claims {
+		for _, partition := range partitions {
+			sess.waitGroup.Add(1)
+
+			go func(topic string, partition int32) {
+				defer sess.waitGroup.Done()
+
+				// cancel the session as soon as the first
+				// goroutine exits
+				defer sess.cancel()
+
+				// consume a single topic/partition, blocking
+				sess.consume(topic, partition)
+			}(topic, partition)
+		}
+	}
+	return sess, nil
+}
+
+func (s *consumerGroupSession) Claims() map[string][]int32 { return s.claims }
+func (s *consumerGroupSession) MemberID() string           { return s.memberID }
+func (s *consumerGroupSession) GenerationID() int32        { return s.generationID }
+
+func (s *consumerGroupSession) MarkOffset(topic string, partition int32, offset int64, metadata string) {
+	if pom := s.offsets.findPOM(topic, partition); pom != nil {
+		pom.MarkOffset(offset, metadata)
+	}
+}
+
+func (s *consumerGroupSession) ResetOffset(topic string, partition int32, offset int64, metadata string) {
+	if pom := s.offsets.findPOM(topic, partition); pom != nil {
+		pom.ResetOffset(offset, metadata)
+	}
+}
+
+func (s *consumerGroupSession) MarkMessage(msg *ConsumerMessage, metadata string) {
+	s.MarkOffset(msg.Topic, msg.Partition, msg.Offset+1, metadata)
+}
+
+func (s *consumerGroupSession) Context() context.Context {
+	return s.ctx
+}
+
+func (s *consumerGroupSession) consume(topic string, partition int32) {
+	// quick exit if rebalance is due
+	select {
+	case <-s.ctx.Done():
+		return
+	case <-s.parent.closed:
+		return
+	default:
+	}
+
+	// get next offset
+	offset := s.parent.config.Consumer.Offsets.Initial
+	if pom := s.offsets.findPOM(topic, partition); pom != nil {
+		offset, _ = pom.NextOffset()
+	}
+
+	// create new claim
+	claim, err := newConsumerGroupClaim(s, topic, partition, offset)
+	if err != nil {
+		s.parent.handleError(err, topic, partition)
+		return
+	}
+
+	// handle errors
+	go func() {
+		for err := range claim.Errors() {
+			s.parent.handleError(err, topic, partition)
+		}
+	}()
+
+	// trigger close when session is done
+	go func() {
+		select {
+		case <-s.ctx.Done():
+		case <-s.parent.closed:
+		}
+		claim.AsyncClose()
+	}()
+
+	// start processing
+	if err := s.handler.ConsumeClaim(s, claim); err != nil {
+		s.parent.handleError(err, topic, partition)
+	}
+
+	// ensure the consumer is closed & drained
+	claim.AsyncClose()
+	for _, err := range claim.waitClosed() {
+		s.parent.handleError(err, topic, partition)
+	}
+}
+
+func (s *consumerGroupSession) release(withCleanup bool) (err error) {
+	// signal release, stop heartbeat
+	s.cancel()
+
+	// wait for consumers to exit
+	s.waitGroup.Wait()
+
+	// perform release
+	s.releaseOnce.Do(func() {
+		if withCleanup {
+			if e := s.handler.Cleanup(s); e != nil {
+				s.parent.handleError(e, "", -1)
+				err = e
+			}
+		}
+
+		if e := s.offsets.Close(); e != nil {
+			err = e
+		}
+
+		close(s.hbDying)
+		<-s.hbDead
+	})
+
+	return
+}
+
+func (s *consumerGroupSession) heartbeatLoop() {
+	defer close(s.hbDead)
+	defer s.cancel() // trigger the end of the session on exit
+
+	pause := time.NewTicker(s.parent.config.Consumer.Group.Heartbeat.Interval)
+	defer pause.Stop()
+
+	retries := s.parent.config.Metadata.Retry.Max
+	for {
+		coordinator, err := s.parent.client.Coordinator(s.parent.groupID)
+		if err != nil {
+			if retries <= 0 {
+				s.parent.handleError(err, "", -1)
+				return
+			}
+
+			select {
+			case <-s.hbDying:
+				return
+			case <-time.After(s.parent.config.Metadata.Retry.Backoff):
+				retries--
+			}
+			continue
+		}
+
+		resp, err := s.parent.heartbeatRequest(coordinator, s.memberID, s.generationID)
+		if err != nil {
+			_ = coordinator.Close()
+			retries--
+			continue
+		}
+
+		switch resp.Err {
+		case ErrNoError:
+			retries = s.parent.config.Metadata.Retry.Max
+		case ErrRebalanceInProgress, ErrUnknownMemberId, ErrIllegalGeneration:
+			return
+		default:
+			s.parent.handleError(resp.Err, "", -1)
+			return
+		}
+
+		select {
+		case <-pause.C:
+		case <-s.hbDying:
+			return
+		}
+	}
+}
+
+// --------------------------------------------------------------------
+
+// ConsumerGroupHandler instances are used to handle individual topic/partition claims.
+// It also provides hooks for your consumer group session life-cycle and allow you to
+// trigger logic before or after the consume loop(s).
+//
+// PLEASE NOTE that handlers are likely to be called from several goroutines concurrently,
+// ensure that all state is safely protected against race conditions.
+type ConsumerGroupHandler interface {
+	// Setup is run at the beginning of a new session, before ConsumeClaim.
+	Setup(ConsumerGroupSession) error
+
+	// Cleanup is run at the end of a session, once all ConsumeClaim goroutines have exited
+	// but before the offsets are committed for the very last time.
+	Cleanup(ConsumerGroupSession) error
+
+	// ConsumeClaim must start a consumer loop of ConsumerGroupClaim's Messages().
+	// Once the Messages() channel is closed, the Handler must finish its processing
+	// loop and exit.
+	ConsumeClaim(ConsumerGroupSession, ConsumerGroupClaim) error
+}
+
+// ConsumerGroupClaim processes Kafka messages from a given topic and partition within a consumer group.
+type ConsumerGroupClaim interface {
+	// Topic returns the consumed topic name.
+	Topic() string
+
+	// Partition returns the consumed partition.
+	Partition() int32
+
+	// InitialOffset returns the initial offset that was used as a starting point for this claim.
+	InitialOffset() int64
+
+	// HighWaterMarkOffset returns the high water mark offset of the partition,
+	// i.e. the offset that will be used for the next message that will be produced.
+	// You can use this to determine how far behind the processing is.
+	HighWaterMarkOffset() int64
+
+	// Messages returns the read channel for the messages that are returned by
+	// the broker. The messages channel will be closed when a new rebalance cycle
+	// is due. You must finish processing and mark offsets within
+	// Config.Consumer.Group.Session.Timeout before the topic/partition is eventually
+	// re-assigned to another group member.
+	Messages() <-chan *ConsumerMessage
+}
+
+type consumerGroupClaim struct {
+	topic     string
+	partition int32
+	offset    int64
+	PartitionConsumer
+}
+
+func newConsumerGroupClaim(sess *consumerGroupSession, topic string, partition int32, offset int64) (*consumerGroupClaim, error) {
+	pcm, err := sess.parent.consumer.ConsumePartition(topic, partition, offset)
+	if err == ErrOffsetOutOfRange {
+		offset = sess.parent.config.Consumer.Offsets.Initial
+		pcm, err = sess.parent.consumer.ConsumePartition(topic, partition, offset)
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	go func() {
+		for err := range pcm.Errors() {
+			sess.parent.handleError(err, topic, partition)
+		}
+	}()
+
+	return &consumerGroupClaim{
+		topic:             topic,
+		partition:         partition,
+		offset:            offset,
+		PartitionConsumer: pcm,
+	}, nil
+}
+
+func (c *consumerGroupClaim) Topic() string        { return c.topic }
+func (c *consumerGroupClaim) Partition() int32     { return c.partition }
+func (c *consumerGroupClaim) InitialOffset() int64 { return c.offset }
+
+// Drains messages and errors, ensures the claim is fully closed.
+func (c *consumerGroupClaim) waitClosed() (errs ConsumerErrors) {
+	go func() {
+		for range c.Messages() {
+		}
+	}()
+
+	for err := range c.Errors() {
+		errs = append(errs, err)
+	}
+	return
+}
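
The ExampleConsumerGroup test in the next file shows the basic consume loop; the following standalone sketch (not part of this diff; broker address, group and topic names are placeholders, and the trivial handler is hypothetical) additionally wires that loop to a cancellable context, so that an interrupt releases the running session and Consume returns cleanly, as described in the Consume documentation above:

    package main

    import (
    	"context"
    	"log"
    	"os"
    	"os/signal"

    	"github.com/Shopify/sarama"
    )

    // noopHandler is a minimal, hypothetical handler used only for this sketch.
    type noopHandler struct{}

    func (noopHandler) Setup(sarama.ConsumerGroupSession) error   { return nil }
    func (noopHandler) Cleanup(sarama.ConsumerGroupSession) error { return nil }
    func (noopHandler) ConsumeClaim(s sarama.ConsumerGroupSession, c sarama.ConsumerGroupClaim) error {
    	for msg := range c.Messages() {
    		s.MarkMessage(msg, "") // marks the offset of the *next* message to read
    	}
    	return nil
    }

    func main() {
    	config := sarama.NewConfig()
    	config.Version = sarama.V0_10_2_0 // consumer groups require Kafka >= 0.10.2

    	group, err := sarama.NewConsumerGroup([]string{"localhost:9092"}, "example-group", config)
    	if err != nil {
    		log.Fatal(err)
    	}
    	defer func() { _ = group.Close() }()

    	// Cancel the context on Ctrl-C; the running session is then released
    	// and the consume loop below ends cleanly.
    	ctx, cancel := context.WithCancel(context.Background())
    	go func() {
    		sig := make(chan os.Signal, 1)
    		signal.Notify(sig, os.Interrupt)
    		<-sig
    		cancel()
    	}()

    	for ctx.Err() == nil {
    		// Consume blocks for the duration of one session (until rebalance or cancellation).
    		if err := group.Consume(ctx, []string{"example-topic"}, noopHandler{}); err != nil {
    			log.Println("consume error:", err)
    			return
    		}
    	}
    }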

+ 58 - 0
consumer_group_test.go

@@ -0,0 +1,58 @@
+package sarama
+
+import (
+	"context"
+	"fmt"
+)
+
+type exampleConsumerGroupHandler struct{}
+
+func (exampleConsumerGroupHandler) Setup(_ ConsumerGroupSession) error   { return nil }
+func (exampleConsumerGroupHandler) Cleanup(_ ConsumerGroupSession) error { return nil }
+func (h exampleConsumerGroupHandler) ConsumeClaim(sess ConsumerGroupSession, claim ConsumerGroupClaim) error {
+	for msg := range claim.Messages() {
+		fmt.Printf("Message topic:%q partition:%d offset:%d\n", msg.Topic, msg.Partition, msg.Offset)
+		sess.MarkMessage(msg, "")
+	}
+	return nil
+}
+
+func ExampleConsumerGroup() {
+	// Init config, specify appropriate version
+	config := NewConfig()
+	config.Version = V1_0_0_0
+	config.Consumer.Return.Errors = true
+
+	// Start with a client
+	client, err := NewClient([]string{"localhost:9092"}, config)
+	if err != nil {
+		panic(err)
+	}
+	defer func() { _ = client.Close() }()
+
+	// Start a new consumer group
+	group, err := NewConsumerGroupFromClient("my-group", client)
+	if err != nil {
+		panic(err)
+	}
+	defer func() { _ = group.Close() }()
+
+	// Track errors
+	go func() {
+		for err := range group.Errors() {
+			fmt.Println("ERROR", err)
+		}
+	}()
+
+	// Iterate over consumer sessions.
+	ctx := context.Background()
+	for {
+		topics := []string{"my-topic"}
+		handler := exampleConsumerGroupHandler{}
+
+		err := group.Consume(ctx, topics, handler)
+		if err != nil {
+			panic(err)
+		}
+	}
+}

+ 1 - 1
dev.yml

@@ -2,7 +2,7 @@ name: sarama

 up:
   - go:
-      version: '1.10'
+      version: '1.11'

 commands:
   test:

+ 418 - 0
functional_consumer_group_test.go

@@ -0,0 +1,418 @@
+// +build go1.9
+
+package sarama
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"reflect"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+func TestFuncConsumerGroupPartitioning(t *testing.T) {
+	checkKafkaVersion(t, "0.10.2")
+	setupFunctionalTest(t)
+	defer teardownFunctionalTest(t)
+
+	groupID := testFuncConsumerGroupID(t)
+
+	// start M1
+	m1 := runTestFuncConsumerGroupMember(t, groupID, "M1", 0, nil)
+	defer m1.Stop()
+	m1.WaitForState(2)
+	m1.WaitForClaims(map[string]int{"test.4": 4})
+	m1.WaitForHandlers(4)
+
+	// start M2
+	m2 := runTestFuncConsumerGroupMember(t, groupID, "M2", 0, nil, "test.1", "test.4")
+	defer m2.Stop()
+	m2.WaitForState(2)
+
+	// assert that claims are shared among both members
+	m1.WaitForClaims(map[string]int{"test.4": 2})
+	m1.WaitForHandlers(2)
+	m2.WaitForClaims(map[string]int{"test.1": 1, "test.4": 2})
+	m2.WaitForHandlers(3)
+
+	// shutdown M1, wait for M2 to take over
+	m1.AssertCleanShutdown()
+	m2.WaitForClaims(map[string]int{"test.1": 1, "test.4": 4})
+	m2.WaitForHandlers(5)
+
+	// shutdown M2
+	m2.AssertCleanShutdown()
+}
+
+func TestFuncConsumerGroupExcessConsumers(t *testing.T) {
+	checkKafkaVersion(t, "0.10.2")
+	setupFunctionalTest(t)
+	defer teardownFunctionalTest(t)
+
+	groupID := testFuncConsumerGroupID(t)
+
+	// start members
+	m1 := runTestFuncConsumerGroupMember(t, groupID, "M1", 0, nil)
+	defer m1.Stop()
+	m2 := runTestFuncConsumerGroupMember(t, groupID, "M2", 0, nil)
+	defer m2.Stop()
+	m3 := runTestFuncConsumerGroupMember(t, groupID, "M3", 0, nil)
+	defer m3.Stop()
+	m4 := runTestFuncConsumerGroupMember(t, groupID, "M4", 0, nil)
+	defer m4.Stop()
+
+	m1.WaitForClaims(map[string]int{"test.4": 1})
+	m2.WaitForClaims(map[string]int{"test.4": 1})
+	m3.WaitForClaims(map[string]int{"test.4": 1})
+	m4.WaitForClaims(map[string]int{"test.4": 1})
+
+	// start M5
+	m5 := runTestFuncConsumerGroupMember(t, groupID, "M5", 0, nil)
+	defer m5.Stop()
+	m5.WaitForState(1)
+	m5.AssertNoErrs()
+
+	// shut down M4 and assert that M5 takes over its claim
+	m4.AssertCleanShutdown()
+	m5.WaitForState(2)
+	m5.WaitForClaims(map[string]int{"test.4": 1})
+
+	// shutdown everything
+	m1.AssertCleanShutdown()
+	m2.AssertCleanShutdown()
+	m3.AssertCleanShutdown()
+	m5.AssertCleanShutdown()
+}
+
+func TestFuncConsumerGroupFuzzy(t *testing.T) {
+	checkKafkaVersion(t, "0.10.2")
+	setupFunctionalTest(t)
+	defer teardownFunctionalTest(t)
+
+	if err := testFuncConsumerGroupFuzzySeed("test.4"); err != nil {
+		t.Fatal(err)
+	}
+
+	groupID := testFuncConsumerGroupID(t)
+	sink := &testFuncConsumerGroupSink{msgs: make(chan testFuncConsumerGroupMessage, 20000)}
+	waitForMessages := func(t *testing.T, n int) {
+		t.Helper()
+
+		for i := 0; i < 600; i++ {
+			if sink.Len() >= n {
+				break
+			}
+			time.Sleep(100 * time.Millisecond)
+		}
+		if sz := sink.Len(); sz < n {
+			log.Fatalf("expected to consume %d messages, but consumed %d", n, sz)
+		}
+	}
+
+	defer runTestFuncConsumerGroupMember(t, groupID, "M1", 1500, sink).Stop()
+	defer runTestFuncConsumerGroupMember(t, groupID, "M2", 3000, sink).Stop()
+	defer runTestFuncConsumerGroupMember(t, groupID, "M3", 1500, sink).Stop()
+	defer runTestFuncConsumerGroupMember(t, groupID, "M4", 200, sink).Stop()
+	defer runTestFuncConsumerGroupMember(t, groupID, "M5", 100, sink).Stop()
+	waitForMessages(t, 3000)
+
+	defer runTestFuncConsumerGroupMember(t, groupID, "M6", 300, sink).Stop()
+	defer runTestFuncConsumerGroupMember(t, groupID, "M7", 400, sink).Stop()
+	defer runTestFuncConsumerGroupMember(t, groupID, "M8", 500, sink).Stop()
+	defer runTestFuncConsumerGroupMember(t, groupID, "M9", 2000, sink).Stop()
+	waitForMessages(t, 8000)
+
+	defer runTestFuncConsumerGroupMember(t, groupID, "M10", 1000, sink).Stop()
+	waitForMessages(t, 10000)
+
+	defer runTestFuncConsumerGroupMember(t, groupID, "M11", 1000, sink).Stop()
+	defer runTestFuncConsumerGroupMember(t, groupID, "M12", 2500, sink).Stop()
+	waitForMessages(t, 12000)
+
+	defer runTestFuncConsumerGroupMember(t, groupID, "M13", 1000, sink).Stop()
+	waitForMessages(t, 15000)
+
+	if umap := sink.Close(); len(umap) != 15000 {
+		dupes := make(map[string][]string)
+		for k, v := range umap {
+			if len(v) > 1 {
+				dupes[k] = v
+			}
+		}
+		t.Fatalf("expected %d unique messages to be consumed but got %d, including %d duplicates:\n%v", 15000, len(umap), len(dupes), dupes)
+	}
+}
+
+// --------------------------------------------------------------------
+
+func testFuncConsumerGroupID(t *testing.T) string {
+	return fmt.Sprintf("sarama.%s%d", t.Name(), time.Now().UnixNano())
+}
+
+func testFuncConsumerGroupFuzzySeed(topic string) error {
+	client, err := NewClient(kafkaBrokers, nil)
+	if err != nil {
+		return err
+	}
+	defer func() { _ = client.Close() }()
+
+	total := int64(0)
+	for pn := int32(0); pn < 4; pn++ {
+		newest, err := client.GetOffset(topic, pn, OffsetNewest)
+		if err != nil {
+			return err
+		}
+		oldest, err := client.GetOffset(topic, pn, OffsetOldest)
+		if err != nil {
+			return err
+		}
+		total = total + newest - oldest
+	}
+	if total >= 21000 {
+		return nil
+	}
+
+	producer, err := NewAsyncProducerFromClient(client)
+	if err != nil {
+		return err
+	}
+	for i := total; i < 21000; i++ {
+		producer.Input() <- &ProducerMessage{Topic: topic, Value: ByteEncoder([]byte("testdata"))}
+	}
+	return producer.Close()
+}
+
+type testFuncConsumerGroupMessage struct {
+	ClientID string
+	*ConsumerMessage
+}
+
+type testFuncConsumerGroupSink struct {
+	msgs  chan testFuncConsumerGroupMessage
+	count int32
+}
+
+func (s *testFuncConsumerGroupSink) Len() int {
+	if s == nil {
+		return -1
+	}
+	return int(atomic.LoadInt32(&s.count))
+}
+
+func (s *testFuncConsumerGroupSink) Push(clientID string, m *ConsumerMessage) {
+	if s != nil {
+		s.msgs <- testFuncConsumerGroupMessage{ClientID: clientID, ConsumerMessage: m}
+		atomic.AddInt32(&s.count, 1)
+	}
+}
+
+func (s *testFuncConsumerGroupSink) Close() map[string][]string {
+	close(s.msgs)
+
+	res := make(map[string][]string)
+	for msg := range s.msgs {
+		key := fmt.Sprintf("%s-%d:%d", msg.Topic, msg.Partition, msg.Offset)
+		res[key] = append(res[key], msg.ClientID)
+	}
+	return res
+}
+
+type testFuncConsumerGroupMember struct {
+	ConsumerGroup
+	clientID    string
+	claims      map[string]int
+	state       int32
+	handlers    int32
+	errs        []error
+	maxMessages int32
+	isCapped    bool
+	sink        *testFuncConsumerGroupSink
+
+	t  *testing.T
+	mu sync.RWMutex
+}
+
+func runTestFuncConsumerGroupMember(t *testing.T, groupID, clientID string, maxMessages int32, sink *testFuncConsumerGroupSink, topics ...string) *testFuncConsumerGroupMember {
+	t.Helper()
+
+	config := NewConfig()
+	config.ClientID = clientID
+	config.Version = V0_10_2_0
+	config.Consumer.Return.Errors = true
+	config.Consumer.Offsets.Initial = OffsetOldest
+	config.Consumer.Group.Rebalance.Timeout = 10 * time.Second
+
+	group, err := NewConsumerGroup(kafkaBrokers, groupID, config)
+	if err != nil {
+		t.Fatal(err)
+		return nil
+	}
+
+	if len(topics) == 0 {
+		topics = []string{"test.4"}
+	}
+
+	member := &testFuncConsumerGroupMember{
+		ConsumerGroup: group,
+		clientID:      clientID,
+		claims:        make(map[string]int),
+		maxMessages:   maxMessages,
+		isCapped:      maxMessages != 0,
+		sink:          sink,
+		t:             t,
+	}
+	go member.loop(topics)
+	return member
+}
+
+func (m *testFuncConsumerGroupMember) AssertCleanShutdown() {
+	m.t.Helper()
+
+	if err := m.Close(); err != nil {
+		m.t.Fatalf("unexpected error on Close(): %v", err)
+	}
+	m.WaitForState(4)
+	m.WaitForHandlers(0)
+	m.AssertNoErrs()
+}
+
+func (m *testFuncConsumerGroupMember) AssertNoErrs() {
+	m.t.Helper()
+
+	var errs []error
+	m.mu.RLock()
+	errs = append(errs, m.errs...)
+	m.mu.RUnlock()
+
+	if len(errs) != 0 {
+		m.t.Fatalf("unexpected consumer errors: %v", errs)
+	}
+}
+
+func (m *testFuncConsumerGroupMember) WaitForState(expected int32) {
+	m.t.Helper()
+
+	m.waitFor("state", expected, func() (interface{}, error) {
+		return atomic.LoadInt32(&m.state), nil
+	})
+}
+
+func (m *testFuncConsumerGroupMember) WaitForHandlers(expected int) {
+	m.t.Helper()
+
+	m.waitFor("handlers", expected, func() (interface{}, error) {
+		return int(atomic.LoadInt32(&m.handlers)), nil
+	})
+}
+
+func (m *testFuncConsumerGroupMember) WaitForClaims(expected map[string]int) {
+	m.t.Helper()
+
+	m.waitFor("claims", expected, func() (interface{}, error) {
+		m.mu.RLock()
+		claims := m.claims
+		m.mu.RUnlock()
+		return claims, nil
+	})
+}
+
+func (m *testFuncConsumerGroupMember) Stop() { _ = m.Close() }
+
+func (m *testFuncConsumerGroupMember) Setup(s ConsumerGroupSession) error {
+	// store claims
+	claims := make(map[string]int)
+	for topic, partitions := range s.Claims() {
+		claims[topic] = len(partitions)
+	}
+	m.mu.Lock()
+	m.claims = claims
+	m.mu.Unlock()
+
+	// enter post-setup state
+	atomic.StoreInt32(&m.state, 2)
+	return nil
+}
+func (m *testFuncConsumerGroupMember) Cleanup(s ConsumerGroupSession) error {
+	// enter post-cleanup state
+	atomic.StoreInt32(&m.state, 3)
+	return nil
+}
+func (m *testFuncConsumerGroupMember) ConsumeClaim(s ConsumerGroupSession, c ConsumerGroupClaim) error {
+	atomic.AddInt32(&m.handlers, 1)
+	defer atomic.AddInt32(&m.handlers, -1)
+
+	for msg := range c.Messages() {
+		if n := atomic.AddInt32(&m.maxMessages, -1); m.isCapped && n < 0 {
+			break
+		}
+		s.MarkMessage(msg, "")
+		m.sink.Push(m.clientID, msg)
+	}
+	return nil
+}
+
+func (m *testFuncConsumerGroupMember) waitFor(kind string, expected interface{}, factory func() (interface{}, error)) {
+	m.t.Helper()
+
+	deadline := time.NewTimer(60 * time.Second)
+	defer deadline.Stop()
+
+	ticker := time.NewTicker(100 * time.Millisecond)
+	defer ticker.Stop()
+
+	var actual interface{}
+	for {
+		var err error
+		if actual, err = factory(); err != nil {
+			m.t.Errorf("failed to retrieve value, expected %s %#v but received error %v", kind, expected, err)
+		}
+
+		if reflect.DeepEqual(expected, actual) {
+			return
+		}
+
+		select {
+		case <-deadline.C:
+			m.t.Fatalf("ttl exceeded, expected %s %#v but got %#v", kind, expected, actual)
+			return
+		case <-ticker.C:
+		}
+	}
+}
+
+func (m *testFuncConsumerGroupMember) loop(topics []string) {
+	defer atomic.StoreInt32(&m.state, 4)
+
+	go func() {
+		for err := range m.Errors() {
+			_ = m.Close()
+
+			m.mu.Lock()
+			m.errs = append(m.errs, err)
+			m.mu.Unlock()
+		}
+	}()
+
+	ctx := context.Background()
+	for {
+		// set state to pre-consume
+		atomic.StoreInt32(&m.state, 1)
+
+		if err := m.Consume(ctx, topics, m); err == ErrClosedConsumerGroup {
+			return
+		} else if err != nil {
+			m.mu.Lock()
+			m.errs = append(m.errs, err)
+			m.mu.Unlock()
+			return
+		}
+
+		// return if capped
+		if n := atomic.LoadInt32(&m.maxMessages); m.isCapped && n < 0 {
+			return
+		}
+	}
+}

+ 26 - 4
message.go

@@ -7,6 +7,7 @@ import (
 	"io/ioutil"
 	"io/ioutil"
 	"time"
 	"time"
 
 
+	"github.com/DataDog/zstd"
 	"github.com/eapache/go-xerial-snappy"
 	"github.com/eapache/go-xerial-snappy"
 	"github.com/pierrec/lz4"
 	"github.com/pierrec/lz4"
 )
 )
@@ -14,14 +15,15 @@ import (
 // CompressionCodec represents the various compression codecs recognized by Kafka in messages.
 type CompressionCodec int8

-// only the last two bits are really used
-const compressionCodecMask int8 = 0x03
+// The lowest 3 bits contain the compression codec used for the message
+const compressionCodecMask int8 = 0x07

 const (
 	CompressionNone   CompressionCodec = 0
 	CompressionGZIP   CompressionCodec = 1
 	CompressionSnappy CompressionCodec = 2
 	CompressionLZ4    CompressionCodec = 3
+	CompressionZSTD   CompressionCodec = 4
 )

 func (cc CompressionCodec) String() string {
@@ -113,7 +115,18 @@ func (m *Message) encode(pe packetEncoder) error {
 			}
 			m.compressedCache = buf.Bytes()
 			payload = m.compressedCache
-
+		case CompressionZSTD:
+			if len(m.Value) == 0 {
+				// Hardcoded empty ZSTD frame, see: https://github.com/DataDog/zstd/issues/41
+				m.compressedCache = []byte{0x28, 0xb5, 0x2f, 0xfd, 0x24, 0x00, 0x01, 0x00, 0x00, 0x99, 0xe9, 0xd8, 0x51}
+			} else {
+				c, err := zstd.CompressLevel(nil, m.Value, m.CompressionLevel)
+				if err != nil {
+					return err
+				}
+				m.compressedCache = c
+			}
+			payload = m.compressedCache
 		default:
 			return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", m.Codec)}
 		}
@@ -207,7 +220,16 @@ func (m *Message) decode(pd packetDecoder) (err error) {
 		if err := m.decodeSet(); err != nil {
 		if err := m.decodeSet(); err != nil {
 			return err
 			return err
 		}
 		}
-
+	case CompressionZSTD:
+		if m.Value == nil {
+			break
+		}
+		if m.Value, err = zstd.Decompress(nil, m.Value); err != nil {
+			return err
+		}
+		if err := m.decodeSet(); err != nil {
+			return err
+		}
 	default:
 	default:
 		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", m.Codec)}
 		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", m.Codec)}
 	}
 	}

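With CompressionZSTD added to the message encoder above, zstd is selected the same way as the other codecs, via the producer configuration. A minimal sketch (not part of this diff), assuming a broker that accepts zstd-compressed produce requests and a hypothetical local bootstrap address and topic:

```go
package main

import (
	"log"

	"github.com/Shopify/sarama"
)

func main() {
	config := sarama.NewConfig()
	config.Producer.Return.Successes = true
	config.Producer.Compression = sarama.CompressionZSTD
	// CompressionLevel is what ends up in m.CompressionLevel and is passed
	// to zstd.CompressLevel in the encode path above.
	config.Producer.CompressionLevel = 3

	producer, err := sarama.NewSyncProducer([]string{"localhost:9092"}, config)
	if err != nil {
		log.Fatalln(err)
	}
	defer producer.Close()

	if _, _, err := producer.SendMessage(&sarama.ProducerMessage{
		Topic: "my-topic",
		Value: sarama.StringEncoder("hello, zstd"),
	}); err != nil {
		log.Fatalln(err)
	}
}
```
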
+ 44 - 0
message_test.go

@@ -52,6 +52,17 @@ var (
 		5, 93, 204, 2, // LZ4 checksum
 	}
 
+	emptyZSTDMessage = []byte{
+		252, 62, 137, 23, // CRC
+		0x01,                          // version byte
+		0x04,                          // attribute flags: zstd
+		0, 0, 1, 88, 141, 205, 89, 56, // timestamp
+		0xFF, 0xFF, 0xFF, 0xFF, // key
+		0x00, 0x00, 0x00, 0x0d, // len
+		// ZSTD data
+		0x28, 0xb5, 0x2f, 0xfd, 0x24, 0x00, 0x01, 0x00, 0x00, 0x99, 0xe9, 0xd8, 0x51,
+	}
+
 	emptyBulkSnappyMessage = []byte{
 		180, 47, 53, 209, //CRC
 		0x00,                   // magic version byte
@@ -86,6 +97,17 @@ var (
 		112, 185, 52, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 121, 87, 72, 224, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 14, 121, 87, 72, 224, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0,
 		71, 129, 23, 111, // LZ4 checksum
 	}
+
+	emptyBulkZSTDMessage = []byte{
+		203, 151, 133, 28, // CRC
+		0x01,                                  // Version
+		0x04,                                  // attribute flags (ZSTD)
+		255, 255, 249, 209, 212, 181, 73, 201, // timestamp
+		0xFF, 0xFF, 0xFF, 0xFF, // key
+		0x00, 0x00, 0x00, 0x26, // len
+		// ZSTD data
+		0x28, 0xb5, 0x2f, 0xfd, 0x24, 0x34, 0xcd, 0x0, 0x0, 0x78, 0x0, 0x0, 0xe, 0x79, 0x57, 0x48, 0xe0, 0x0, 0x0, 0xff, 0xff, 0xff, 0xff, 0x0, 0x1, 0x3, 0x0, 0x3d, 0xbd, 0x0, 0x3b, 0x15, 0x0, 0xb, 0xd2, 0x34, 0xc1, 0x78,
+	}
 )
 
 func TestMessageEncoding(t *testing.T) {
@@ -101,6 +123,12 @@ func TestMessageEncoding(t *testing.T) {
 	message.Timestamp = time.Unix(1479847795, 0)
 	message.Version = 1
 	testEncodable(t, "empty lz4", &message, emptyLZ4Message)
+
+	message.Value = []byte{}
+	message.Codec = CompressionZSTD
+	message.Timestamp = time.Unix(1479847795, 0)
+	message.Version = 1
+	testEncodable(t, "empty zstd", &message, emptyZSTDMessage)
 }
 
 func TestMessageDecoding(t *testing.T) {
@@ -179,6 +207,22 @@ func TestMessageDecodingBulkLZ4(t *testing.T) {
 	}
 }
 
+func TestMessageDecodingBulkZSTD(t *testing.T) {
+	message := Message{}
+	testDecodable(t, "bulk zstd", &message, emptyBulkZSTDMessage)
+	if message.Codec != CompressionZSTD {
+		t.Errorf("Decoding produced codec %d, but expected %d.", message.Codec, CompressionZSTD)
+	}
+	if message.Key != nil {
+		t.Errorf("Decoding produced key %+v, but none was expected.", message.Key)
+	}
+	if message.Set == nil {
+		t.Error("Decoding produced no set, but one was expected.")
+	} else if len(message.Set.Messages) != 2 {
+		t.Errorf("Decoding produced a set with %d messages, but 2 were expected.", len(message.Set.Messages))
+	}
+}
+
 func TestMessageDecodingVersion1(t *testing.T) {
 	message := Message{Version: 1}
 	testDecodable(t, "decoding empty v1 message", &message, emptyV1Message)

+ 11 - 0
metadata_response.go

@@ -207,6 +207,10 @@ func (r *MetadataResponse) decode(pd packetDecoder, version int16) (err error) {
 }
 
 func (r *MetadataResponse) encode(pe packetEncoder) error {
+	if r.Version >= 3 {
+		pe.putInt32(r.ThrottleTimeMs)
+	}
+
 	err := pe.putArrayLength(len(r.Brokers))
 	if err != nil {
 		return err
@@ -218,6 +222,13 @@ func (r *MetadataResponse) encode(pe packetEncoder) error {
 		}
 	}
 
+	if r.Version >= 2 {
+		err := pe.putNullableString(r.ClusterID)
+		if err != nil {
+			return err
+		}
+	}
+
 	if r.Version >= 1 {
 		pe.putInt32(r.ControllerID)
 	}

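The encode fix above makes MetadataResponse write the same version-gated fields that decode already reads: ThrottleTimeMs leads the body from v3 on, and the nullable ClusterID sits between the broker list and ControllerID from v2 on. A rough round-trip sketch from inside the package's own tests; testResponse is assumed here to be the encode/decode helper used by the package's other response tests, and the field values are illustrative only:

```go
clusterID := "test-cluster"
response := &MetadataResponse{
	Version:        3,
	ThrottleTimeMs: 100,
	ClusterID:      &clusterID,
	ControllerID:   1,
}
// Before this change the v2+ fields were simply not written on encode,
// so a round-trip could not have preserved ClusterID or ThrottleTimeMs.
// testResponse is an assumption: the package's round-trip test helper.
testResponse(t, "metadata v3 round-trip", response, nil)
```
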
+ 1 - 2
mocks/async_producer.go

@@ -44,6 +44,7 @@ func NewAsyncProducer(t ErrorReporter, config *sarama.Config) *AsyncProducer {
 		defer func() {
 			close(mp.successes)
 			close(mp.errors)
+			close(mp.closed)
 		}()
 
 		for msg := range mp.input {
@@ -86,8 +87,6 @@ func NewAsyncProducer(t ErrorReporter, config *sarama.Config) *AsyncProducer {
 			mp.t.Errorf("Expected to exhaust all expectations, but %d are left.", len(mp.expectations))
 		}
 		mp.l.Unlock()
-
-		close(mp.closed)
 	}()
 
 	return mp

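The mock now closes mp.closed in the same deferred cleanup that closes the successes and errors channels, so Close() cannot return while the mock goroutine is still checking expectations or tearing down. Typical usage of the mock is unchanged; a minimal sketch with hypothetical test, package, and topic names:

```go
package myapp

import (
	"testing"

	"github.com/Shopify/sarama"
	"github.com/Shopify/sarama/mocks"
)

func TestProduceWithMockAsyncProducer(t *testing.T) {
	config := sarama.NewConfig()
	config.Producer.Return.Successes = true

	mp := mocks.NewAsyncProducer(t, config)
	mp.ExpectInputAndSucceed()

	mp.Input() <- &sarama.ProducerMessage{Topic: "my-topic", Value: sarama.StringEncoder("hi")}
	<-mp.Successes()

	// With close(mp.closed) relocated into the defer above, Close only
	// returns after the expectation check has run and both output
	// channels have been closed.
	if err := mp.Close(); err != nil {
		t.Error(err)
	}
}
```
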
+ 35 - 11
offset_manager.go

@@ -27,11 +27,14 @@ type offsetManager struct {
 	group  string
 	ticker *time.Ticker
 
+	memberID   string
+	generation int32
+
 	broker     *Broker
 	brokerLock sync.RWMutex
 
 	poms     map[string]map[int32]*partitionOffsetManager
-	pomsLock sync.Mutex
+	pomsLock sync.RWMutex
 
 	closeOnce sync.Once
 	closing   chan none
@@ -41,6 +44,10 @@ type offsetManager struct {
 // NewOffsetManagerFromClient creates a new OffsetManager from the given client.
 // It is still necessary to call Close() on the underlying client when finished with the partition manager.
 func NewOffsetManagerFromClient(group string, client Client) (OffsetManager, error) {
+	return newOffsetManagerFromClient(group, "", GroupGenerationUndefined, client)
+}
+
+func newOffsetManagerFromClient(group, memberID string, generation int32, client Client) (*offsetManager, error) {
 	// Check that we are not dealing with a closed Client before processing any other arguments
 	if client.Closed() {
 		return nil, ErrClosedClient
@@ -54,6 +61,9 @@ func NewOffsetManagerFromClient(group string, client Client) (OffsetManager, err
 		ticker: time.NewTicker(conf.Consumer.Offsets.CommitInterval),
 		poms:   make(map[string]map[int32]*partitionOffsetManager),
 
+		memberID:   memberID,
+		generation: generation,
+
 		closing: make(chan none),
 		closed:  make(chan none),
 	}
@@ -245,20 +255,22 @@ func (om *offsetManager) constructRequest() *OffsetCommitRequest {
 		r = &OffsetCommitRequest{
 			Version:                 1,
 			ConsumerGroup:           om.group,
-			ConsumerGroupGeneration: GroupGenerationUndefined,
+			ConsumerID:              om.memberID,
+			ConsumerGroupGeneration: om.generation,
 		}
 	} else {
 		r = &OffsetCommitRequest{
 			Version:                 2,
 			RetentionTime:           int64(om.conf.Consumer.Offsets.Retention / time.Millisecond),
 			ConsumerGroup:           om.group,
-			ConsumerGroupGeneration: GroupGenerationUndefined,
+			ConsumerID:              om.memberID,
+			ConsumerGroupGeneration: om.generation,
 		}
 
 	}
 
-	om.pomsLock.Lock()
-	defer om.pomsLock.Unlock()
+	om.pomsLock.RLock()
+	defer om.pomsLock.RUnlock()
 
 	for _, topicManagers := range om.poms {
 		for _, pom := range topicManagers {
@@ -278,8 +290,8 @@ func (om *offsetManager) constructRequest() *OffsetCommitRequest {
 }
 
 func (om *offsetManager) handleResponse(broker *Broker, req *OffsetCommitRequest, resp *OffsetCommitResponse) {
-	om.pomsLock.Lock()
-	defer om.pomsLock.Unlock()
+	om.pomsLock.RLock()
+	defer om.pomsLock.RUnlock()
 
 	for _, topicManagers := range om.poms {
 		for _, pom := range topicManagers {
@@ -329,8 +341,8 @@ func (om *offsetManager) handleResponse(broker *Broker, req *OffsetCommitRequest
 }
 
 func (om *offsetManager) handleError(err error) {
-	om.pomsLock.Lock()
-	defer om.pomsLock.Unlock()
+	om.pomsLock.RLock()
+	defer om.pomsLock.RUnlock()
 
 	for _, topicManagers := range om.poms {
 		for _, pom := range topicManagers {
@@ -340,8 +352,8 @@ func (om *offsetManager) handleError(err error) {
 }
 
 func (om *offsetManager) asyncClosePOMs() {
-	om.pomsLock.Lock()
-	defer om.pomsLock.Unlock()
+	om.pomsLock.RLock()
+	defer om.pomsLock.RUnlock()
 
 	for _, topicManagers := range om.poms {
 		for _, pom := range topicManagers {
@@ -375,6 +387,18 @@ func (om *offsetManager) releasePOMs(force bool) (remaining int) {
 	return
 }
 
+func (om *offsetManager) findPOM(topic string, partition int32) *partitionOffsetManager {
+	om.pomsLock.RLock()
+	defer om.pomsLock.RUnlock()
+
+	if partitions, ok := om.poms[topic]; ok {
+		if pom, ok := partitions[partition]; ok {
+			return pom
+		}
+	}
+	return nil
+}
+
 // Partition Offset Manager
 
 // PartitionOffsetManager uses Kafka to store and fetch consumed partition offsets. You MUST call Close()

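The offset manager now carries the member ID and generation handed to it by the consumer-group code through the unexported constructor above, while the exported NewOffsetManagerFromClient keeps committing with an undefined member and generation. Standalone usage is unchanged; a minimal sketch, assuming a reachable broker at localhost:9092 and hypothetical group/topic names:

```go
package main

import (
	"log"

	"github.com/Shopify/sarama"
)

func main() {
	client, err := sarama.NewClient([]string{"localhost:9092"}, sarama.NewConfig())
	if err != nil {
		log.Fatalln(err)
	}
	defer client.Close()

	om, err := sarama.NewOffsetManagerFromClient("my-group", client)
	if err != nil {
		log.Fatalln(err)
	}
	defer om.Close()

	pom, err := om.ManagePartition("my-topic", 0)
	if err != nil {
		log.Fatalln(err)
	}
	defer pom.Close()

	// Offsets marked here are flushed by the manager's commit ticker.
	pom.MarkOffset(42, "metadata")
}
```
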
+ 11 - 0
record_batch.go

@@ -7,6 +7,7 @@ import (
 	"io/ioutil"
 	"time"
 
+	"github.com/DataDog/zstd"
 	"github.com/eapache/go-xerial-snappy"
 	"github.com/pierrec/lz4"
 )
@@ -193,6 +194,10 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
 			return err
 		}
+	case CompressionZSTD:
+		if recBuffer, err = zstd.Decompress(nil, recBuffer); err != nil {
+			return err
+		}
 	default:
 		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
 	}
@@ -248,6 +253,12 @@ func (b *RecordBatch) encodeRecords(pe packetEncoder) error {
 			return err
 		}
 		b.compressedRecords = buf.Bytes()
+	case CompressionZSTD:
+		c, err := zstd.CompressLevel(nil, raw, b.CompressionLevel)
+		if err != nil {
+			return err
+		}
+		b.compressedRecords = c
 	default:
 		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
 	}
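Record batches rely on the same two github.com/DataDog/zstd calls as the message codec: CompressLevel on encode and Decompress on decode. A standalone round-trip using just those calls (not part of this diff; payload and level are arbitrary):

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/DataDog/zstd"
)

func main() {
	raw := []byte("some record batch payload")

	// Same call the encodeRecords path uses above.
	compressed, err := zstd.CompressLevel(nil, raw, 3)
	if err != nil {
		panic(err)
	}

	// Same call the decode path uses above.
	decompressed, err := zstd.Decompress(nil, compressed)
	if err != nil {
		panic(err)
	}

	fmt.Println(bytes.Equal(raw, decompressed)) // true
}
```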