
Merge branch 'master' into add-consumer-batchsize-metric

Varun 7 years ago
parent
commit
af26fcd370

+ 3 - 3
.travis.yml

@@ -2,6 +2,7 @@ language: go
 go:
 - 1.10.x
 - 1.11.x
+- 1.12.x
 
 env:
   global:
@@ -11,9 +12,8 @@ env:
   - KAFKA_HOSTNAME=localhost
   - DEBUG=true
   matrix:
-  - KAFKA_VERSION=1.1.1 KAFKA_SCALA_VERSION=2.11
   - KAFKA_VERSION=2.0.1 KAFKA_SCALA_VERSION=2.12
-  - KAFKA_VERSION=2.1.0 KAFKA_SCALA_VERSION=2.12
+  - KAFKA_VERSION=2.1.1 KAFKA_SCALA_VERSION=2.12
 
 before_install:
 - export REPOSITORY_ROOT=${TRAVIS_BUILD_DIR}
@@ -27,7 +27,7 @@ script:
 - make test
 - make vet
 - make errcheck
-- if [[ "$TRAVIS_GO_VERSION" == 1.11* ]]; then make fmt; fi
+- if [[ "$TRAVIS_GO_VERSION" == 1.12* ]]; then make fmt; fi
 
 after_success:
 - bash <(curl -s https://codecov.io/bash)

+ 36 - 0
CHANGELOG.md

@@ -1,5 +1,41 @@
 # Changelog
 
+#### Version 1.21.0 (2019-02-24)
+
+New Features:
+- Add CreateAclRequest, DescribeAclRequest, DeleteAclRequest
+  ([1236](https://github.com/Shopify/sarama/pull/1236)).
+- Add DescribeTopic, DescribeConsumerGroup, ListConsumerGroups, ListConsumerGroupOffsets admin requests
+  ([1178](https://github.com/Shopify/sarama/pull/1178)).
+- Implement SASL/OAUTHBEARER
+  ([1240](https://github.com/Shopify/sarama/pull/1240)).
+
+Improvements:
+- Add Go mod support
+  ([1282](https://github.com/Shopify/sarama/pull/1282)).
+- Add error codes 73-76
+  ([1239](https://github.com/Shopify/sarama/pull/1239)).
+- Add retry backoff function
+  ([1160](https://github.com/Shopify/sarama/pull/1160)).
+- Maintain metadata in the producer even when retries are disabled
+  ([1189](https://github.com/Shopify/sarama/pull/1189)).
+- Include ReplicaAssignment in ListTopics
+  ([1274](https://github.com/Shopify/sarama/pull/1274)).
+- Add producer performance tool
+  ([1222](https://github.com/Shopify/sarama/pull/1222)).
+- Add support for LogAppend timestamps
+  ([1258](https://github.com/Shopify/sarama/pull/1258)).
+
+Bug Fixes:
+- Fix potential deadlock when a heartbeat request fails
+  ([1286](https://github.com/Shopify/sarama/pull/1286)).
+- Fix consuming compacted topic
+  ([1227](https://github.com/Shopify/sarama/pull/1227)).
+- Set correct Kafka version for DescribeConfigsRequest v1
+  ([1277](https://github.com/Shopify/sarama/pull/1277)).
+- Update kafka test version
+  ([1273](https://github.com/Shopify/sarama/pull/1273)).
+
 #### Version 1.20.1 (2019-01-10)
 
 New Features:

+ 2 - 0
Makefile

@@ -1,3 +1,5 @@
+export GO111MODULE=on
+
 default: fmt vet errcheck test
 
 # Taken from https://github.com/codecov/example-go#caveat-multiple-files

+ 1 - 1
README.md

@@ -21,7 +21,7 @@ You might also want to look at the [Frequently Asked Questions](https://github.c
 Sarama provides a "2 releases + 2 months" compatibility guarantee: we support
 the two latest stable releases of Kafka and Go, and we provide a two month
 grace period for older releases. This means we currently officially support
-Go 1.8 through 1.11, and Kafka 1.0 through 2.0, although older releases are
+Go 1.10 through 1.12, and Kafka 2.0 through 2.2, although older releases are
 still likely to work.
 
 Sarama follows semantic versioning and provides API stability via the gopkg.in service.

+ 21 - 5
acl_bindings.go

@@ -1,17 +1,26 @@
 package sarama
 
 type Resource struct {
-	ResourceType AclResourceType
-	ResourceName string
+	ResourceType       AclResourceType
+	ResourceName       string
+	ResoucePatternType AclResourcePatternType
 }
 
-func (r *Resource) encode(pe packetEncoder) error {
+func (r *Resource) encode(pe packetEncoder, version int16) error {
 	pe.putInt8(int8(r.ResourceType))
 
 	if err := pe.putString(r.ResourceName); err != nil {
 		return err
 	}
 
+	if version == 1 {
+		if r.ResoucePatternType == AclPatternUnknown {
+			Logger.Print("Cannot encode an unknown resource pattern type, using Literal instead")
+			r.ResoucePatternType = AclPatternLiteral
+		}
+		pe.putInt8(int8(r.ResoucePatternType))
+	}
+
 	return nil
 }
 
@@ -25,6 +34,13 @@ func (r *Resource) decode(pd packetDecoder, version int16) (err error) {
 	if r.ResourceName, err = pd.getString(); err != nil {
 		return err
 	}
+	if version == 1 {
+		pattern, err := pd.getInt8()
+		if err != nil {
+			return err
+		}
+		r.ResoucePatternType = AclResourcePatternType(pattern)
+	}
 
 	return nil
 }
@@ -80,8 +96,8 @@ type ResourceAcls struct {
 	Acls []*Acl
 }
 
-func (r *ResourceAcls) encode(pe packetEncoder) error {
-	if err := r.Resource.encode(pe); err != nil {
+func (r *ResourceAcls) encode(pe packetEncoder, version int16) error {
+	if err := r.Resource.encode(pe, version); err != nil {
 		return err
 	}
 

+ 12 - 5
acl_create_request.go

@@ -1,6 +1,7 @@
 package sarama
 
 type CreateAclsRequest struct {
+	Version      int16
 	AclCreations []*AclCreation
 }
 
@@ -10,7 +11,7 @@ func (c *CreateAclsRequest) encode(pe packetEncoder) error {
 	}
 
 	for _, aclCreation := range c.AclCreations {
-		if err := aclCreation.encode(pe); err != nil {
+		if err := aclCreation.encode(pe, c.Version); err != nil {
 			return err
 		}
 	}
@@ -19,6 +20,7 @@ func (c *CreateAclsRequest) encode(pe packetEncoder) error {
 }
 
 func (c *CreateAclsRequest) decode(pd packetDecoder, version int16) (err error) {
+	c.Version = version
 	n, err := pd.getArrayLength()
 	if err != nil {
 		return err
@@ -41,11 +43,16 @@ func (d *CreateAclsRequest) key() int16 {
 }
 
 func (d *CreateAclsRequest) version() int16 {
-	return 0
+	return d.Version
 }
 
 func (d *CreateAclsRequest) requiredVersion() KafkaVersion {
-	return V0_11_0_0
+	switch d.Version {
+	case 1:
+		return V2_0_0_0
+	default:
+		return V0_11_0_0
+	}
 }
 
 type AclCreation struct {
@@ -53,8 +60,8 @@ type AclCreation struct {
 	Acl
 }
 
-func (a *AclCreation) encode(pe packetEncoder) error {
-	if err := a.Resource.encode(pe); err != nil {
+func (a *AclCreation) encode(pe packetEncoder, version int16) error {
+	if err := a.Resource.encode(pe, version); err != nil {
 		return err
 	}
 	if err := a.Acl.encode(pe); err != nil {

+ 33 - 1
acl_create_request_test.go

@@ -12,10 +12,21 @@ var (
 		2, // all
 		2, // deny
 	}
+	aclCreateRequestv1 = []byte{
+		0, 0, 0, 1,
+		3, // resource type = group
+		0, 5, 'g', 'r', 'o', 'u', 'p',
+		3, // resource pattern type = literal
+		0, 9, 'p', 'r', 'i', 'n', 'c', 'i', 'p', 'a', 'l',
+		0, 4, 'h', 'o', 's', 't',
+		2, // all
+		2, // deny
+	}
 )
 
-func TestCreateAclsRequest(t *testing.T) {
+func TestCreateAclsRequestv0(t *testing.T) {
 	req := &CreateAclsRequest{
+		Version: 0,
 		AclCreations: []*AclCreation{{
 			Resource: Resource{
 				ResourceType: AclResourceGroup,
@@ -32,3 +43,24 @@ func TestCreateAclsRequest(t *testing.T) {
 
 	testRequest(t, "create request", req, aclCreateRequest)
 }
+
+func TestCreateAclsRequestv1(t *testing.T) {
+	req := &CreateAclsRequest{
+		Version: 1,
+		AclCreations: []*AclCreation{{
+			Resource: Resource{
+				ResourceType:       AclResourceGroup,
+				ResourceName:       "group",
+				ResoucePatternType: AclPatternLiteral,
+			},
+			Acl: Acl{
+				Principal:      "principal",
+				Host:           "host",
+				Operation:      AclOperationAll,
+				PermissionType: AclPermissionDeny,
+			}},
+		},
+	}
+
+	testRequest(t, "create request v1", req, aclCreateRequestv1)
+}

+ 11 - 2
acl_delete_request.go

@@ -1,6 +1,7 @@
 package sarama
 
 type DeleteAclsRequest struct {
+	Version int
 	Filters []*AclFilter
 }
 
@@ -10,6 +11,7 @@ func (d *DeleteAclsRequest) encode(pe packetEncoder) error {
 	}
 
 	for _, filter := range d.Filters {
+		filter.Version = d.Version
 		if err := filter.encode(pe); err != nil {
 			return err
 		}
@@ -19,6 +21,7 @@ func (d *DeleteAclsRequest) encode(pe packetEncoder) error {
 }
 
 func (d *DeleteAclsRequest) decode(pd packetDecoder, version int16) (err error) {
+	d.Version = int(version)
 	n, err := pd.getArrayLength()
 	if err != nil {
 		return err
@@ -27,6 +30,7 @@ func (d *DeleteAclsRequest) decode(pd packetDecoder, version int16) (err error)
 	d.Filters = make([]*AclFilter, n)
 	for i := 0; i < n; i++ {
 		d.Filters[i] = new(AclFilter)
+		d.Filters[i].Version = int(version)
 		if err := d.Filters[i].decode(pd, version); err != nil {
 			return err
 		}
@@ -40,9 +44,14 @@ func (d *DeleteAclsRequest) key() int16 {
 }
 
 func (d *DeleteAclsRequest) version() int16 {
-	return 0
+	return int16(d.Version)
 }
 
 func (d *DeleteAclsRequest) requiredVersion() KafkaVersion {
-	return V0_11_0_0
+	switch d.Version {
+	case 1:
+		return V2_0_0_0
+	default:
+		return V0_11_0_0
+	}
 }

+ 43 - 0
acl_delete_request_test.go

@@ -3,6 +3,28 @@ package sarama
 import "testing"
 
 var (
+	aclDeleteRequestNullsv1 = []byte{
+		0, 0, 0, 1,
+		1,
+		255, 255,
+		1, // Any
+		255, 255,
+		255, 255,
+		11,
+		3,
+	}
+
+	aclDeleteRequestv1 = []byte{
+		0, 0, 0, 1,
+		1, // any
+		0, 6, 'f', 'i', 'l', 't', 'e', 'r',
+		1, // Any Filter
+		0, 9, 'p', 'r', 'i', 'n', 'c', 'i', 'p', 'a', 'l',
+		0, 4, 'h', 'o', 's', 't',
+		4, // write
+		3, // allow
+	}
+
 	aclDeleteRequestNulls = []byte{
 		0, 0, 0, 1,
 		1,
@@ -67,3 +89,24 @@ func TestDeleteAclsRequest(t *testing.T) {
 
 	testRequest(t, "delete request array", req, aclDeleteRequestArray)
 }
+
+func TestDeleteAclsRequestV1(t *testing.T) {
+	req := &DeleteAclsRequest{
+		Version: 1,
+		Filters: []*AclFilter{{
+			ResourceType:              AclResourceAny,
+			Operation:                 AclOperationAlterConfigs,
+			PermissionType:            AclPermissionAllow,
+			ResourcePatternTypeFilter: AclPatternAny,
+		}},
+	}
+
+	testRequest(t, "delete request nulls", req, aclDeleteRequestNullsv1)
+
+	req.Filters[0].ResourceName = nullString("filter")
+	req.Filters[0].Principal = nullString("principal")
+	req.Filters[0].Host = nullString("host")
+	req.Filters[0].Operation = AclOperationWrite
+
+	testRequest(t, "delete request", req, aclDeleteRequestv1)
+}

+ 7 - 6
acl_delete_response.go

@@ -3,6 +3,7 @@ package sarama
 import "time"
 
 type DeleteAclsResponse struct {
+	Version         int16
 	ThrottleTime    time.Duration
 	FilterResponses []*FilterResponse
 }
@@ -15,7 +16,7 @@ func (a *DeleteAclsResponse) encode(pe packetEncoder) error {
 	}
 
 	for _, filterResponse := range a.FilterResponses {
-		if err := filterResponse.encode(pe); err != nil {
+		if err := filterResponse.encode(pe, a.Version); err != nil {
 			return err
 		}
 	}
@@ -51,7 +52,7 @@ func (d *DeleteAclsResponse) key() int16 {
 }
 
 func (d *DeleteAclsResponse) version() int16 {
-	return 0
+	return int16(d.Version)
 }
 
 func (d *DeleteAclsResponse) requiredVersion() KafkaVersion {
@@ -64,7 +65,7 @@ type FilterResponse struct {
 	MatchingAcls []*MatchingAcl
 }
 
-func (f *FilterResponse) encode(pe packetEncoder) error {
+func (f *FilterResponse) encode(pe packetEncoder, version int16) error {
 	pe.putInt16(int16(f.Err))
 	if err := pe.putNullableString(f.ErrMsg); err != nil {
 		return err
@@ -74,7 +75,7 @@ func (f *FilterResponse) encode(pe packetEncoder) error {
 		return err
 	}
 	for _, matchingAcl := range f.MatchingAcls {
-		if err := matchingAcl.encode(pe); err != nil {
+		if err := matchingAcl.encode(pe, version); err != nil {
 			return err
 		}
 	}
@@ -115,13 +116,13 @@ type MatchingAcl struct {
 	Acl
 }
 
-func (m *MatchingAcl) encode(pe packetEncoder) error {
+func (m *MatchingAcl) encode(pe packetEncoder, version int16) error {
 	pe.putInt16(int16(m.Err))
 	if err := pe.putNullableString(m.ErrMsg); err != nil {
 		return err
 	}
 
-	if err := m.Resource.encode(pe); err != nil {
+	if err := m.Resource.encode(pe, version); err != nil {
 		return err
 	}
 

+ 11 - 2
acl_describe_request.go

@@ -1,14 +1,18 @@
 package sarama
 
 type DescribeAclsRequest struct {
+	Version int
 	AclFilter
 }
 
 func (d *DescribeAclsRequest) encode(pe packetEncoder) error {
+	d.AclFilter.Version = d.Version
 	return d.AclFilter.encode(pe)
 }
 
 func (d *DescribeAclsRequest) decode(pd packetDecoder, version int16) (err error) {
+	d.Version = int(version)
+	d.AclFilter.Version = int(version)
 	return d.AclFilter.decode(pd, version)
 }
 
@@ -17,9 +21,14 @@ func (d *DescribeAclsRequest) key() int16 {
 }
 
 func (d *DescribeAclsRequest) version() int16 {
-	return 0
+	return int16(d.Version)
 }
 
 func (d *DescribeAclsRequest) requiredVersion() KafkaVersion {
-	return V0_11_0_0
+	switch d.Version {
+	case 1:
+		return V2_0_0_0
+	default:
+		return V0_11_0_0
+	}
 }

+ 32 - 2
acl_describe_request_test.go

@@ -13,15 +13,24 @@ var (
 		5, // acl operation
 		3, // acl permission type
 	}
+	aclDescribeRequestV1 = []byte{
+		2, // resource type
+		0, 5, 't', 'o', 'p', 'i', 'c',
+		1, // any Type
+		0, 9, 'p', 'r', 'i', 'n', 'c', 'i', 'p', 'a', 'l',
+		0, 4, 'h', 'o', 's', 't',
+		5, // acl operation
+		3, // acl permission type
+	}
 )
 
-func TestAclDescribeRequest(t *testing.T) {
+func TestAclDescribeRequestV0(t *testing.T) {
 	resourcename := "topic"
 	principal := "principal"
 	host := "host"
 
 	req := &DescribeAclsRequest{
-		AclFilter{
+		AclFilter: AclFilter{
 			ResourceType:   AclResourceTopic,
 			ResourceName:   &resourcename,
 			Principal:      &principal,
@@ -33,3 +42,24 @@ func TestAclDescribeRequest(t *testing.T) {
 
 	testRequest(t, "", req, aclDescribeRequest)
 }
+
+func TestAclDescribeRequestV1(t *testing.T) {
+	resourcename := "topic"
+	principal := "principal"
+	host := "host"
+
+	req := &DescribeAclsRequest{
+		Version: 1,
+		AclFilter: AclFilter{
+			ResourceType:              AclResourceTopic,
+			ResourceName:              &resourcename,
+			ResourcePatternTypeFilter: AclPatternAny,
+			Principal:                 &principal,
+			Host:                      &host,
+			Operation:                 AclOperationCreate,
+			PermissionType:            AclPermissionAllow,
+		},
+	}
+
+	testRequest(t, "", req, aclDescribeRequestV1)
+}

+ 9 - 3
acl_describe_response.go

@@ -3,6 +3,7 @@ package sarama
 import "time"
 
 type DescribeAclsResponse struct {
+	Version      int16
 	ThrottleTime time.Duration
 	Err          KError
 	ErrMsg       *string
@@ -22,7 +23,7 @@ func (d *DescribeAclsResponse) encode(pe packetEncoder) error {
 	}
 
 	for _, resourceAcl := range d.ResourceAcls {
-		if err := resourceAcl.encode(pe); err != nil {
+		if err := resourceAcl.encode(pe, d.Version); err != nil {
 			return err
 		}
 	}
@@ -72,9 +73,14 @@ func (d *DescribeAclsResponse) key() int16 {
 }
 
 func (d *DescribeAclsResponse) version() int16 {
-	return 0
+	return int16(d.Version)
 }
 
 func (d *DescribeAclsResponse) requiredVersion() KafkaVersion {
-	return V0_11_0_0
+	switch d.Version {
+	case 1:
+		return V2_0_0_0
+	default:
+		return V0_11_0_0
+	}
 }

+ 23 - 6
acl_filter.go

@@ -1,12 +1,14 @@
 package sarama
 
 type AclFilter struct {
-	ResourceType   AclResourceType
-	ResourceName   *string
-	Principal      *string
-	Host           *string
-	Operation      AclOperation
-	PermissionType AclPermissionType
+	Version                   int
+	ResourceType              AclResourceType
+	ResourceName              *string
+	ResourcePatternTypeFilter AclResourcePatternType
+	Principal                 *string
+	Host                      *string
+	Operation                 AclOperation
+	PermissionType            AclPermissionType
 }
 
 func (a *AclFilter) encode(pe packetEncoder) error {
@@ -14,6 +16,11 @@ func (a *AclFilter) encode(pe packetEncoder) error {
 	if err := pe.putNullableString(a.ResourceName); err != nil {
 		return err
 	}
+
+	if a.Version == 1 {
+		pe.putInt8(int8(a.ResourcePatternTypeFilter))
+	}
+
 	if err := pe.putNullableString(a.Principal); err != nil {
 		return err
 	}
@@ -37,6 +44,16 @@ func (a *AclFilter) decode(pd packetDecoder, version int16) (err error) {
 		return err
 	}
 
+	if a.Version == 1 {
+		pattern, err := pd.getInt8()
+
+		if err != nil {
+			return err
+		}
+
+		a.ResourcePatternTypeFilter = AclResourcePatternType(pattern)
+	}
+
 	if a.Principal, err = pd.getNullableString(); err != nil {
 		return err
 	}

+ 12 - 0
acl_types.go

@@ -40,3 +40,15 @@ const (
 	AclResourceCluster         AclResourceType = 4
 	AclResourceTransactionalID AclResourceType = 5
 )
+
+type AclResourcePatternType int
+
+// ref: https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/apache/kafka/common/resource/PatternType.java
+
+const (
+	AclPatternUnknown AclResourcePatternType = iota
+	AclPatternAny
+	AclPatternMatch
+	AclPatternLiteral
+	AclPatternPrefixed
+)

+ 2 - 2
admin.go

@@ -166,7 +166,7 @@ func (ca *clusterAdmin) CreateTopic(topic string, detail *TopicDetail, validateO
 	}
 
 	if topicErr.Err != ErrNoError {
-		return topicErr.Err
+		return topicErr
 	}
 
 	return nil
@@ -358,7 +358,7 @@ func (ca *clusterAdmin) CreatePartitions(topic string, count int32, assignment [
 	}
 
 	if topicErr.Err != ErrNoError {
-		return topicErr.Err
+		return topicErr
 	}
 
 	return nil
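
With this change, CreateTopic and CreatePartitions return the full *TopicError rather than just its bare error code, so any message the broker attached (as exercised by the new authorization tests below) is preserved in the returned error. A minimal sketch of how a caller might surface that message; the broker address, topic name, and Kafka version are placeholders:

```go
package main

import (
	"log"

	"github.com/Shopify/sarama"
)

func main() {
	config := sarama.NewConfig()
	config.Version = sarama.V1_0_0_0 // CreateTopics needs a recent enough protocol version

	admin, err := sarama.NewClusterAdmin([]string{"localhost:9092"}, config)
	if err != nil {
		log.Fatal(err)
	}
	defer admin.Close()

	err = admin.CreateTopic("my_topic", &sarama.TopicDetail{NumPartitions: 1, ReplicationFactor: 1}, false)
	if err != nil {
		// err now carries any broker-supplied message, e.g.
		// "insufficient permissions to create topic with reserved prefix".
		log.Println("create topic failed:", err)
	}
}
```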

+ 35 - 4
admin_test.go

@@ -2,6 +2,7 @@ package sarama
 
 import (
 	"errors"
+	"strings"
 	"testing"
 )
 
@@ -105,7 +106,7 @@ func TestClusterAdminCreateTopicWithInvalidTopicDetail(t *testing.T) {
 	}
 }
 
-func TestClusterAdminCreateTopicWithDiffVersion(t *testing.T) {
+func TestClusterAdminCreateTopicWithoutAuthorization(t *testing.T) {
 	seedBroker := NewMockBroker(t, 1)
 	defer seedBroker.Close()
 
@@ -118,16 +119,17 @@ func TestClusterAdminCreateTopicWithDiffVersion(t *testing.T) {
 
 	config := NewConfig()
 	config.Version = V0_11_0_0
+
 	admin, err := NewClusterAdmin([]string{seedBroker.Addr()}, config)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	err = admin.CreateTopic("my_topic", &TopicDetail{NumPartitions: 1, ReplicationFactor: 1}, false)
-	if err != ErrInsufficientData {
+	err = admin.CreateTopic("_internal_topic", &TopicDetail{NumPartitions: 1, ReplicationFactor: 1}, false)
+	want := "insufficient permissions to create topic with reserved prefix"
+	if !strings.HasSuffix(err.Error(), want) {
 		t.Fatal(err)
 	}
-
 	err = admin.Close()
 	if err != nil {
 		t.Fatal(err)
@@ -301,6 +303,35 @@ func TestClusterAdminCreatePartitionsWithDiffVersion(t *testing.T) {
 	}
 }
 
+func TestClusterAdminCreatePartitionsWithoutAuthorization(t *testing.T) {
+	seedBroker := NewMockBroker(t, 1)
+	defer seedBroker.Close()
+
+	seedBroker.SetHandlerByMap(map[string]MockResponse{
+		"MetadataRequest": NewMockMetadataResponse(t).
+			SetController(seedBroker.BrokerID()).
+			SetBroker(seedBroker.Addr(), seedBroker.BrokerID()),
+		"CreatePartitionsRequest": NewMockCreatePartitionsResponse(t),
+	})
+
+	config := NewConfig()
+	config.Version = V1_0_0_0
+	admin, err := NewClusterAdmin([]string{seedBroker.Addr()}, config)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = admin.CreatePartitions("_internal_topic", 3, nil, false)
+	want := "insufficient permissions to create partition on topic with reserved prefix"
+	if !strings.HasSuffix(err.Error(), want) {
+		t.Fatal(err)
+	}
+	err = admin.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
 func TestClusterAdminDeleteRecords(t *testing.T) {
 	seedBroker := NewMockBroker(t, 1)
 	defer seedBroker.Close()

+ 67 - 22
async_producer.go

@@ -92,9 +92,8 @@ func newTransactionManager(conf *Config, client Client) (*transactionManager, er
 }
 
 type asyncProducer struct {
-	client    Client
-	conf      *Config
-	ownClient bool
+	client Client
+	conf   *Config
 
 	errors                    chan *ProducerError
 	input, successes, retries chan *ProducerMessage
@@ -113,18 +112,19 @@ func NewAsyncProducer(addrs []string, conf *Config) (AsyncProducer, error) {
 	if err != nil {
 		return nil, err
 	}
-
-	p, err := NewAsyncProducerFromClient(client)
-	if err != nil {
-		return nil, err
-	}
-	p.(*asyncProducer).ownClient = true
-	return p, nil
+	return newAsyncProducer(client)
 }
 
 // NewAsyncProducerFromClient creates a new Producer using the given client. It is still
 // necessary to call Close() on the underlying client when shutting down this producer.
 func NewAsyncProducerFromClient(client Client) (AsyncProducer, error) {
+	// For a client passed in by the caller, ensure we don't
+	// call Close() on it.
+	cli := &nopCloserClient{client}
+	return newAsyncProducer(cli)
+}
+
+func newAsyncProducer(client Client) (AsyncProducer, error) {
 	// Check that we are not dealing with a closed Client before processing any other arguments
 	if client.Closed() {
 		return nil, ErrClosedClient
@@ -483,6 +483,19 @@ func (p *asyncProducer) newPartitionProducer(topic string, partition int32) chan
 	return input
 }
 
+func (pp *partitionProducer) backoff(retries int) {
+	var backoff time.Duration
+	if pp.parent.conf.Producer.Retry.BackoffFunc != nil {
+		maxRetries := pp.parent.conf.Producer.Retry.Max
+		backoff = pp.parent.conf.Producer.Retry.BackoffFunc(retries, maxRetries)
+	} else {
+		backoff = pp.parent.conf.Producer.Retry.Backoff
+	}
+	if backoff > 0 {
+		time.Sleep(backoff)
+	}
+}
+
 func (pp *partitionProducer) dispatch() {
 	// try to prefetch the leader; if this doesn't work, we'll do a proper call to `updateLeader`
 	// on the first message
@@ -493,11 +506,31 @@ func (pp *partitionProducer) dispatch() {
 		pp.brokerProducer.input <- &ProducerMessage{Topic: pp.topic, Partition: pp.partition, flags: syn}
 	}
 
+	defer func() {
+		if pp.brokerProducer != nil {
+			pp.parent.unrefBrokerProducer(pp.leader, pp.brokerProducer)
+		}
+	}()
+
 	for msg := range pp.input {
+
+		if pp.brokerProducer != nil && pp.brokerProducer.abandoned != nil {
+			select {
+			case <-pp.brokerProducer.abandoned:
+				// a message on the abandoned channel means that our current broker selection is out of date
+				Logger.Printf("producer/leader/%s/%d abandoning broker %d\n", pp.topic, pp.partition, pp.leader.ID())
+				pp.parent.unrefBrokerProducer(pp.leader, pp.brokerProducer)
+				pp.brokerProducer = nil
+				time.Sleep(pp.parent.conf.Producer.Retry.Backoff)
+			default:
+				// producer connection is still open.
+			}
+		}
+
 		if msg.retries > pp.highWatermark {
 			// a new, higher, retry level; handle it and then back off
 			pp.newHighWatermark(msg.retries)
-			time.Sleep(pp.parent.conf.Producer.Retry.Backoff)
+			pp.backoff(msg.retries)
 		} else if pp.highWatermark > 0 {
 			// we are retrying something (else highWatermark would be 0) but this message is not a *new* retry level
 			if msg.retries < pp.highWatermark {
@@ -525,7 +558,7 @@ func (pp *partitionProducer) dispatch() {
 		if pp.brokerProducer == nil {
 			if err := pp.updateLeader(); err != nil {
 				pp.parent.returnError(msg, err)
-				time.Sleep(pp.parent.conf.Producer.Retry.Backoff)
+				pp.backoff(msg.retries)
 				continue
 			}
 			Logger.Printf("producer/leader/%s/%d selected broker %d\n", pp.topic, pp.partition, pp.leader.ID())
@@ -533,10 +566,6 @@ func (pp *partitionProducer) dispatch() {
 
 		pp.brokerProducer.input <- msg
 	}
-
-	if pp.brokerProducer != nil {
-		pp.parent.unrefBrokerProducer(pp.leader, pp.brokerProducer)
-	}
 }
 
 func (pp *partitionProducer) newHighWatermark(hwm int) {
@@ -637,6 +666,10 @@ func (p *asyncProducer) newBrokerProducer(broker *Broker) *brokerProducer {
 		close(responses)
 	})
 
+	if p.conf.Producer.Retry.Max <= 0 {
+		bp.abandoned = make(chan struct{})
+	}
+
 	return bp
 }
 
@@ -655,6 +688,7 @@ type brokerProducer struct {
 	input     chan *ProducerMessage
 	output    chan<- *produceSet
 	responses <-chan *brokerProducerResponse
+	abandoned chan struct{}
 
 	buffer     *produceSet
 	timer      <-chan time.Time
@@ -829,9 +863,17 @@ func (bp *brokerProducer) handleSuccess(sent *produceSet, response *ProduceRespo
 		// Retriable errors
 		case ErrInvalidMessage, ErrUnknownTopicOrPartition, ErrLeaderNotAvailable, ErrNotLeaderForPartition,
 			ErrRequestTimedOut, ErrNotEnoughReplicas, ErrNotEnoughReplicasAfterAppend:
-			retryTopics = append(retryTopics, topic)
+			if bp.parent.conf.Producer.Retry.Max <= 0 {
+				bp.parent.abandonBrokerConnection(bp.broker)
+				bp.parent.returnErrors(pSet.msgs, block.Err)
+			} else {
+				retryTopics = append(retryTopics, topic)
+			}
 		// Other non-retriable errors
 		default:
+			if bp.parent.conf.Producer.Retry.Max <= 0 {
+				bp.parent.abandonBrokerConnection(bp.broker)
+			}
 			bp.parent.returnErrors(pSet.msgs, block.Err)
 		}
 	})
@@ -957,11 +999,9 @@ func (p *asyncProducer) shutdown() {
 
 	p.inFlight.Wait()
 
-	if p.ownClient {
-		err := p.client.Close()
-		if err != nil {
-			Logger.Println("producer/shutdown failed to close the embedded client:", err)
-		}
+	err := p.client.Close()
+	if err != nil {
+		Logger.Println("producer/shutdown failed to close the embedded client:", err)
 	}
 
 	close(p.input)
@@ -1048,5 +1088,10 @@ func (p *asyncProducer) abandonBrokerConnection(broker *Broker) {
 	p.brokerLock.Lock()
 	defer p.brokerLock.Unlock()
 
+	bc, ok := p.brokers[broker]
+	if ok && bc.abandoned != nil {
+		close(bc.abandoned)
+	}
+
 	delete(p.brokers, broker)
 }
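
The new backoff helper above consults Producer.Retry.BackoffFunc (added to config.go further down) when it is set, and falls back to the fixed Producer.Retry.Backoff otherwise. A minimal sketch of an exponential backoff with a ceiling; the durations are illustrative, not library defaults, and the guard against a non-positive retries value is just defensive:

```go
package main

import (
	"time"

	"github.com/Shopify/sarama"
)

func main() {
	config := sarama.NewConfig()
	config.Producer.Retry.Max = 5
	// Exponential backoff: 100ms, 200ms, 400ms, ... capped at 2s.
	// Retry.Max bounds how large retries can get, so the shift stays small.
	config.Producer.Retry.BackoffFunc = func(retries, maxRetries int) time.Duration {
		if retries < 1 {
			retries = 1
		}
		backoff := 100 * time.Millisecond << uint(retries-1)
		if backoff > 2*time.Second {
			backoff = 2 * time.Second
		}
		return backoff
	}
	_ = config // pass to sarama.NewAsyncProducer(addrs, config) as usual
}
```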

+ 130 - 0
async_producer_test.go

@@ -6,6 +6,7 @@ import (
 	"os"
 	"os/signal"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -307,6 +308,68 @@ func TestAsyncProducerFailureRetry(t *testing.T) {
 	closeProducer(t, producer)
 }
 
+func TestAsyncProducerRecoveryWithRetriesDisabled(t *testing.T) {
+
+	tt := func(t *testing.T, kErr KError) {
+		seedBroker := NewMockBroker(t, 1)
+		leader1 := NewMockBroker(t, 2)
+		leader2 := NewMockBroker(t, 3)
+
+		metadataLeader1 := new(MetadataResponse)
+		metadataLeader1.AddBroker(leader1.Addr(), leader1.BrokerID())
+		metadataLeader1.AddTopicPartition("my_topic", 0, leader1.BrokerID(), nil, nil, ErrNoError)
+		metadataLeader1.AddTopicPartition("my_topic", 1, leader1.BrokerID(), nil, nil, ErrNoError)
+		seedBroker.Returns(metadataLeader1)
+
+		config := NewConfig()
+		config.Producer.Flush.Messages = 2
+		config.Producer.Return.Successes = true
+		config.Producer.Retry.Max = 0 // disable!
+		config.Producer.Retry.Backoff = 0
+		config.Producer.Partitioner = NewManualPartitioner
+		producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config)
+		if err != nil {
+			t.Fatal(err)
+		}
+		seedBroker.Close()
+
+		producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage), Partition: 0}
+		producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage), Partition: 1}
+		prodNotLeader := new(ProduceResponse)
+		prodNotLeader.AddTopicPartition("my_topic", 0, kErr)
+		prodNotLeader.AddTopicPartition("my_topic", 1, kErr)
+		leader1.Returns(prodNotLeader)
+		expectResults(t, producer, 0, 2)
+
+		producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage), Partition: 0}
+		metadataLeader2 := new(MetadataResponse)
+		metadataLeader2.AddBroker(leader2.Addr(), leader2.BrokerID())
+		metadataLeader2.AddTopicPartition("my_topic", 0, leader2.BrokerID(), nil, nil, ErrNoError)
+		metadataLeader2.AddTopicPartition("my_topic", 1, leader2.BrokerID(), nil, nil, ErrNoError)
+		leader1.Returns(metadataLeader2)
+		leader1.Returns(metadataLeader2)
+
+		producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage), Partition: 1}
+		prodSuccess := new(ProduceResponse)
+		prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError)
+		prodSuccess.AddTopicPartition("my_topic", 1, ErrNoError)
+		leader2.Returns(prodSuccess)
+		expectResults(t, producer, 2, 0)
+
+		leader1.Close()
+		leader2.Close()
+		closeProducer(t, producer)
+	}
+
+	t.Run("retriable error", func(t *testing.T) {
+		tt(t, ErrNotLeaderForPartition)
+	})
+
+	t.Run("non-retriable error", func(t *testing.T) {
+		tt(t, ErrNotController)
+	})
+}
+
 func TestAsyncProducerEncoderFailures(t *testing.T) {
 	seedBroker := NewMockBroker(t, 1)
 	leader := NewMockBroker(t, 2)
@@ -485,6 +548,73 @@ func TestAsyncProducerMultipleRetries(t *testing.T) {
 	closeProducer(t, producer)
 }
 
+func TestAsyncProducerMultipleRetriesWithBackoffFunc(t *testing.T) {
+	seedBroker := NewMockBroker(t, 1)
+	leader1 := NewMockBroker(t, 2)
+	leader2 := NewMockBroker(t, 3)
+
+	metadataLeader1 := new(MetadataResponse)
+	metadataLeader1.AddBroker(leader1.Addr(), leader1.BrokerID())
+	metadataLeader1.AddTopicPartition("my_topic", 0, leader1.BrokerID(), nil, nil, ErrNoError)
+	seedBroker.Returns(metadataLeader1)
+
+	config := NewConfig()
+	config.Producer.Flush.Messages = 1
+	config.Producer.Return.Successes = true
+	config.Producer.Retry.Max = 4
+
+	backoffCalled := make([]int32, config.Producer.Retry.Max+1)
+	config.Producer.Retry.BackoffFunc = func(retries, maxRetries int) time.Duration {
+		atomic.AddInt32(&backoffCalled[retries-1], 1)
+		return 0
+	}
+	producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)}
+	prodNotLeader := new(ProduceResponse)
+	prodNotLeader.AddTopicPartition("my_topic", 0, ErrNotLeaderForPartition)
+
+	prodSuccess := new(ProduceResponse)
+	prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError)
+
+	metadataLeader2 := new(MetadataResponse)
+	metadataLeader2.AddBroker(leader2.Addr(), leader2.BrokerID())
+	metadataLeader2.AddTopicPartition("my_topic", 0, leader2.BrokerID(), nil, nil, ErrNoError)
+
+	leader1.Returns(prodNotLeader)
+	seedBroker.Returns(metadataLeader2)
+	leader2.Returns(prodNotLeader)
+	seedBroker.Returns(metadataLeader1)
+	leader1.Returns(prodNotLeader)
+	seedBroker.Returns(metadataLeader1)
+	leader1.Returns(prodNotLeader)
+	seedBroker.Returns(metadataLeader2)
+	leader2.Returns(prodSuccess)
+
+	expectResults(t, producer, 1, 0)
+
+	producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)}
+	leader2.Returns(prodSuccess)
+	expectResults(t, producer, 1, 0)
+
+	seedBroker.Close()
+	leader1.Close()
+	leader2.Close()
+	closeProducer(t, producer)
+
+	for i := 0; i < config.Producer.Retry.Max; i++ {
+		if atomic.LoadInt32(&backoffCalled[i]) != 1 {
+			t.Errorf("expected one retry attempt #%d", i)
+		}
+	}
+	if atomic.LoadInt32(&backoffCalled[config.Producer.Retry.Max]) != 0 {
+		t.Errorf("expected no retry attempt #%d", config.Producer.Retry.Max)
+	}
+}
+
 func TestAsyncProducerOutOfRetries(t *testing.T) {
 	t.Skip("Enable once bug #294 is fixed.")
 

+ 118 - 5
broker.go

@@ -56,6 +56,10 @@ const (
 	SASLTypeOAuth = "OAUTHBEARER"
 	// SASLTypePlaintext represents the SASL/PLAIN mechanism
 	SASLTypePlaintext = "PLAIN"
+	// SASLTypeSCRAMSHA256 represents the SCRAM-SHA-256 mechanism.
+	SASLTypeSCRAMSHA256 = "SCRAM-SHA-256"
+	// SASLTypeSCRAMSHA512 represents the SCRAM-SHA-512 mechanism.
+	SASLTypeSCRAMSHA512 = "SCRAM-SHA-512"
 	// SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and
 	// server negotiate SASL auth using opaque packets.
 	SASLHandshakeV0 = int16(0)
@@ -92,6 +96,20 @@ type AccessTokenProvider interface {
 	Token() (*AccessToken, error)
 }
 
+// SCRAMClient is an interface to a SCRAM
+// client implementation.
+type SCRAMClient interface {
+	// Begin prepares the client for the SCRAM exchange
+	// with the server with a user name and a password
+	Begin(userName, password, authzID string) error
+	// Step steps client through the SCRAM exchange. It is
+	// called repeatedly until it errors or `Done` returns true.
+	Step(challenge string) (response string, err error)
+	// Done should return true when the SCRAM conversation
+	// is over.
+	Done() bool
+}
+
 type responsePromise struct {
 	requestTime   time.Time
 	correlationID int32
@@ -793,14 +811,19 @@ func (b *Broker) responseReceiver() {
 }
 
 func (b *Broker) authenticateViaSASL() error {
-	if b.conf.Net.SASL.Mechanism == SASLTypeOAuth {
+	switch b.conf.Net.SASL.Mechanism {
+	case SASLTypeOAuth:
 		return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
+	case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
+		return b.sendAndReceiveSASLSCRAMv1()
+	default:
+		return b.sendAndReceiveSASLPlainAuth()
 	}
-	return b.sendAndReceiveSASLPlainAuth()
+
 }
 
-func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) error {
-	rb := &SaslHandshakeRequest{Mechanism: saslType, Version: version}
+func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int16) error {
+	rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version}
 
 	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
 	buf, err := encode(req, b.conf.MetricRegistry)
@@ -846,7 +869,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) err
 		Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
 		return res.Err
 	}
-	Logger.Print("Successful SASL handshake")
+	Logger.Print("Successful SASL handshake. Available mechanisms: ", res.EnabledMechanisms)
 	return nil
 }
 
@@ -949,6 +972,96 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
 	return nil
 }
 
+func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
+	if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV1); err != nil {
+		return err
+	}
+
+	scramClient := b.conf.Net.SASL.SCRAMClient
+	if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
+		return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error())
+	}
+
+	msg, err := scramClient.Step("")
+	if err != nil {
+		return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error())
+
+	}
+
+	for !scramClient.Done() {
+		requestTime := time.Now()
+		correlationID := b.correlationID
+		bytesWritten, err := b.sendSaslAuthenticateRequest(correlationID, []byte(msg))
+		if err != nil {
+			Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
+			return err
+		}
+
+		b.updateOutgoingCommunicationMetrics(bytesWritten)
+		b.correlationID++
+		challenge, err := b.receiveSaslAuthenticateResponse(correlationID)
+		if err != nil {
+			Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
+			return err
+		}
+
+		b.updateIncomingCommunicationMetrics(len(challenge), time.Since(requestTime))
+		msg, err = scramClient.Step(string(challenge))
+		if err != nil {
+			Logger.Println("SASL authentication failed", err)
+			return err
+		}
+	}
+	Logger.Println("SASL authentication succeeded")
+	return nil
+}
+
+func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
+	rb := &SaslAuthenticateRequest{msg}
+	req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
+	buf, err := encode(req, b.conf.MetricRegistry)
+	if err != nil {
+		return 0, err
+	}
+	if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
+		return 0, err
+	}
+	return b.conn.Write(buf)
+}
+
+func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
+	buf := make([]byte, responseLengthSize+correlationIDSize)
+	bytesRead, err := io.ReadFull(b.conn, buf)
+	if err != nil {
+		return nil, err
+	}
+	header := responseHeader{}
+	err = decode(buf, &header)
+	if err != nil {
+		return nil, err
+	}
+	if header.correlationID != correlationID {
+		return nil, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID)
+	}
+	buf = make([]byte, header.length-correlationIDSize)
+	c, err := io.ReadFull(b.conn, buf)
+	bytesRead += c
+	if err != nil {
+		return nil, err
+	}
+	res := &SaslAuthenticateResponse{}
+	if err := versionedDecode(buf, res, 0); err != nil {
+		return nil, err
+	}
+	if err != nil {
+		return nil, err
+	}
+	if res.Err != ErrNoError {
+		return nil, res.Err
+	}
+	return res.SaslAuthBytes, nil
+}
+
 // Build SASL/OAUTHBEARER initial client response as described by RFC-7628
 // https://tools.ietf.org/html/rfc7628
 func buildClientInitialResponse(token *AccessToken) ([]byte, error) {

+ 135 - 6
broker_test.go

@@ -179,16 +179,12 @@ func TestSASLOAuthBearer(t *testing.T) {
 		// mockBroker mocks underlying network logic and broker responses
 		mockBroker := NewMockBroker(t, 0)
 
-		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
-			SetAuthBytes([]byte(`response_payload`))
-
+		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte("response_payload"))
 		if test.mockAuthErr != ErrNoError {
 			mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
 		}
 
-		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
-			SetEnabledMechanisms([]string{SASLTypeOAuth})
-
+		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeOAuth})
 		if test.mockHandshakeErr != ErrNoError {
 			mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
 		}
@@ -248,6 +244,139 @@ func TestSASLOAuthBearer(t *testing.T) {
 	}
 }
 
+// A mock scram client.
+type MockSCRAMClient struct {
+	done bool
+}
+
+func (m *MockSCRAMClient) Begin(userName, password, authzID string) (err error) {
+	return nil
+}
+
+func (m *MockSCRAMClient) Step(challenge string) (response string, err error) {
+	if challenge == "" {
+		return "ping", nil
+	}
+	if challenge == "pong" {
+		m.done = true
+		return "", nil
+	}
+	return "", errors.New("failed to authenticate :(")
+}
+
+func (m *MockSCRAMClient) Done() bool {
+	return m.done
+}
+
+var _ SCRAMClient = &MockSCRAMClient{}
+
+func TestSASLSCRAMSHAXXX(t *testing.T) {
+	testTable := []struct {
+		name               string
+		mockHandshakeErr   KError
+		mockSASLAuthErr    KError
+		expectClientErr    bool
+		scramClient        *MockSCRAMClient
+		scramChallengeResp string
+	}{
+		{
+			name:               "SASL/SCRAMSHAXXX successful authentication",
+			mockHandshakeErr:   ErrNoError,
+			scramClient:        &MockSCRAMClient{},
+			scramChallengeResp: "pong",
+		},
+		{
+			name:               "SASL/SCRAMSHAXXX SCRAM client step error client",
+			mockHandshakeErr:   ErrNoError,
+			mockSASLAuthErr:    ErrNoError,
+			scramClient:        &MockSCRAMClient{},
+			scramChallengeResp: "gong",
+			expectClientErr:    true,
+		},
+		{
+			name:               "SASL/SCRAMSHAXXX server authentication error",
+			mockHandshakeErr:   ErrNoError,
+			mockSASLAuthErr:    ErrSASLAuthenticationFailed,
+			scramClient:        &MockSCRAMClient{},
+			scramChallengeResp: "pong",
+		},
+		{
+			name:               "SASL/SCRAMSHAXXX unsupported SCRAM mechanism",
+			mockHandshakeErr:   ErrUnsupportedSASLMechanism,
+			mockSASLAuthErr:    ErrNoError,
+			scramClient:        &MockSCRAMClient{},
+			scramChallengeResp: "pong",
+		},
+	}
+
+	for i, test := range testTable {
+
+		// mockBroker mocks underlying network logic and broker responses
+		mockBroker := NewMockBroker(t, 0)
+		broker := NewBroker(mockBroker.Addr())
+		// broker executes SASL requests against mockBroker
+		broker.requestRate = metrics.NilMeter{}
+		broker.outgoingByteRate = metrics.NilMeter{}
+		broker.incomingByteRate = metrics.NilMeter{}
+		broker.requestSize = metrics.NilHistogram{}
+		broker.responseSize = metrics.NilHistogram{}
+		broker.responseRate = metrics.NilMeter{}
+		broker.requestLatency = metrics.NilHistogram{}
+
+		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte(test.scramChallengeResp))
+		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512})
+
+		if test.mockSASLAuthErr != ErrNoError {
+			mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockSASLAuthErr)
+		}
+		if test.mockHandshakeErr != ErrNoError {
+			mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
+		}
+
+		mockBroker.SetHandlerByMap(map[string]MockResponse{
+			"SaslAuthenticateRequest": mockSASLAuthResponse,
+			"SaslHandshakeRequest":    mockSASLHandshakeResponse,
+		})
+
+		conf := NewConfig()
+		conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
+		conf.Net.SASL.SCRAMClient = test.scramClient
+
+		broker.conf = conf
+		dialer := net.Dialer{
+			Timeout:   conf.Net.DialTimeout,
+			KeepAlive: conf.Net.KeepAlive,
+			LocalAddr: conf.Net.LocalAddr,
+		}
+
+		conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
+
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		broker.conn = conn
+
+		err = broker.authenticateViaSASL()
+
+		if test.mockSASLAuthErr != ErrNoError {
+			if test.mockSASLAuthErr != err {
+				t.Errorf("[%d]:[%s] Expected %s SASL authentication error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+			}
+		} else if test.mockHandshakeErr != ErrNoError {
+			if test.mockHandshakeErr != err {
+				t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+			}
+		} else if test.expectClientErr && err == nil {
+			t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
+		} else if !test.expectClientErr && err != nil {
+			t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
+		}
+
+		mockBroker.Close()
+	}
+}
+
 func TestBuildClientInitialResponse(t *testing.T) {
 
 	testTable := []struct {

+ 32 - 3
client.go

@@ -710,8 +710,11 @@ func (client *client) refreshMetadata() error {
 func (client *client) tryRefreshMetadata(topics []string, attemptsRemaining int) error {
 	retry := func(err error) error {
 		if attemptsRemaining > 0 {
+			backoff := client.computeBackoff(attemptsRemaining)
 			Logger.Printf("client/metadata retrying after %dms... (%d attempts remaining)\n", client.conf.Metadata.Retry.Backoff/time.Millisecond, attemptsRemaining)
-			time.Sleep(client.conf.Metadata.Retry.Backoff)
+			if backoff > 0 {
+				time.Sleep(backoff)
+			}
 			return client.tryRefreshMetadata(topics, attemptsRemaining-1)
 		}
 		return err
@@ -839,11 +842,22 @@ func (client *client) cachedController() *Broker {
 	return client.brokers[client.controllerID]
 }
 
+func (client *client) computeBackoff(attemptsRemaining int) time.Duration {
+	if client.conf.Metadata.Retry.BackoffFunc != nil {
+		maxRetries := client.conf.Metadata.Retry.Max
+		retries := maxRetries - attemptsRemaining
+		return client.conf.Metadata.Retry.BackoffFunc(retries, maxRetries)
+	} else {
+		return client.conf.Metadata.Retry.Backoff
+	}
+}
+
 func (client *client) getConsumerMetadata(consumerGroup string, attemptsRemaining int) (*FindCoordinatorResponse, error) {
 	retry := func(err error) (*FindCoordinatorResponse, error) {
 		if attemptsRemaining > 0 {
-			Logger.Printf("client/coordinator retrying after %dms... (%d attempts remaining)\n", client.conf.Metadata.Retry.Backoff/time.Millisecond, attemptsRemaining)
-			time.Sleep(client.conf.Metadata.Retry.Backoff)
+			backoff := client.computeBackoff(attemptsRemaining)
+			Logger.Printf("client/coordinator retrying after %dms... (%d attempts remaining)\n", backoff/time.Millisecond, attemptsRemaining)
+			time.Sleep(backoff)
 			return client.getConsumerMetadata(consumerGroup, attemptsRemaining-1)
 		}
 		return nil, err
@@ -897,3 +911,18 @@ func (client *client) getConsumerMetadata(consumerGroup string, attemptsRemainin
 	client.resurrectDeadBrokers()
 	return retry(ErrOutOfBrokers)
 }
+
+// nopCloserClient embeds an existing Client, but disables
+// the Close method (yet all other methods pass
+// through unchanged). This is for use in larger structs
+// where it is undesirable to close the client that was
+// passed in by the caller.
+type nopCloserClient struct {
+	Client
+}
+
+// Close intercepts and purposely does not call the underlying
+// client's Close() method.
+func (ncc *nopCloserClient) Close() error {
+	return nil
+}

+ 38 - 0
client_test.go

@@ -3,6 +3,7 @@ package sarama
 import (
 	"io"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -260,6 +261,43 @@ func TestClientGetOffset(t *testing.T) {
 	safeClose(t, client)
 }
 
+func TestClientReceivingUnknownTopicWithBackoffFunc(t *testing.T) {
+	seedBroker := NewMockBroker(t, 1)
+
+	metadataResponse1 := new(MetadataResponse)
+	seedBroker.Returns(metadataResponse1)
+
+	retryCount := int32(0)
+
+	config := NewConfig()
+	config.Metadata.Retry.Max = 1
+	config.Metadata.Retry.BackoffFunc = func(retries, maxRetries int) time.Duration {
+		atomic.AddInt32(&retryCount, 1)
+		return 0
+	}
+	client, err := NewClient([]string{seedBroker.Addr()}, config)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	metadataUnknownTopic := new(MetadataResponse)
+	metadataUnknownTopic.AddTopic("new_topic", ErrUnknownTopicOrPartition)
+	seedBroker.Returns(metadataUnknownTopic)
+	seedBroker.Returns(metadataUnknownTopic)
+
+	if err := client.RefreshMetadata("new_topic"); err != ErrUnknownTopicOrPartition {
+		t.Error("ErrUnknownTopicOrPartition expected, got", err)
+	}
+
+	safeClose(t, client)
+	seedBroker.Close()
+
+	actualRetryCount := atomic.LoadInt32(&retryCount)
+	if actualRetryCount != 1 {
+		t.Fatalf("Expected BackoffFunc to be called exactly once, but saw %d", actualRetryCount)
+	}
+}
+
 func TestClientReceivingUnknownTopic(t *testing.T) {
 	seedBroker := NewMockBroker(t, 1)
 

+ 39 - 9
config.go

@@ -61,9 +61,14 @@ type Config struct {
 			// (defaults to true). You should only set this to false if you're using
 			// a non-Kafka SASL proxy.
 			Handshake bool
-			//username and password for SASL/PLAIN authentication
+			//username and password for SASL/PLAIN  or SASL/SCRAM authentication
 			User     string
 			Password string
+			// authz id used for SASL/SCRAM authentication
+			SCRAMAuthzID string
+			// SCRAMClient is a user provided implementation of a SCRAM
+			// client used to perform the SCRAM exchange with the server.
+			SCRAMClient SCRAMClient
 			// TokenProvider is a user-defined callback for generating
 			// access tokens for SASL/OAUTHBEARER auth. See the
 			// AccessTokenProvider interface docs for proper implementation
@@ -92,6 +97,10 @@ type Config struct {
 			// How long to wait for leader election to occur before retrying
 			// (default 250ms). Similar to the JVM's `retry.backoff.ms`.
 			Backoff time.Duration
+			// Called to compute backoff time dynamically. Useful for implementing
+			// more sophisticated backoff strategies. This takes precedence over
+			// `Backoff` if set.
+			BackoffFunc func(retries, maxRetries int) time.Duration
 		}
 		// How frequently to refresh the cluster metadata in the background.
 		// Defaults to 10 minutes. Set to 0 to disable. Similar to
@@ -179,6 +188,10 @@ type Config struct {
 			// (default 100ms). Similar to the `retry.backoff.ms` setting of the
 			// JVM producer.
 			Backoff time.Duration
+			// Called to compute backoff time dynamically. Useful for implementing
+			// more sophisticated backoff strategies. This takes precedence over
+			// `Backoff` if set.
+			BackoffFunc func(retries, maxRetries int) time.Duration
 		}
 	}
 
@@ -237,6 +250,10 @@ type Config struct {
 			// How long to wait after a failing to read from a partition before
 			// trying again (default 2s).
 			Backoff time.Duration
+			// Called to compute backoff time dynamically. Useful for implementing
+			// more sophisticated backoff strategies. This takes precedence over
+			// `Backoff` if set.
+			BackoffFunc func(retries int) time.Duration
 		}
 
 		// Fetch is the namespace for controlling how many bytes are retrieved by any
@@ -463,22 +480,35 @@ func (c *Config) Validate() error {
 	case c.Net.KeepAlive < 0:
 		return ConfigurationError("Net.KeepAlive must be >= 0")
 	case c.Net.SASL.Enable:
-		// For backwards compatibility, empty mechanism value defaults to PLAIN
-		isSASLPlain := len(c.Net.SASL.Mechanism) == 0 || c.Net.SASL.Mechanism == SASLTypePlaintext
-		if isSASLPlain {
+		if c.Net.SASL.Mechanism == "" {
+			c.Net.SASL.Mechanism = SASLTypePlaintext
+		}
+
+		switch c.Net.SASL.Mechanism {
+		case SASLTypePlaintext:
 			if c.Net.SASL.User == "" {
 				return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled")
 			}
 			if c.Net.SASL.Password == "" {
 				return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
 			}
-		} else if c.Net.SASL.Mechanism == SASLTypeOAuth {
+		case SASLTypeOAuth:
 			if c.Net.SASL.TokenProvider == nil {
-				return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.User.TokenProvider")
+				return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.TokenProvider")
+			}
+		case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
+			if c.Net.SASL.User == "" {
+				return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled")
+			}
+			if c.Net.SASL.Password == "" {
+				return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
+			}
+			if c.Net.SASL.SCRAMClient == nil {
+				return ConfigurationError("A SCRAMClient instance must be provided to Net.SASL.SCRAMClient")
 			}
-		} else {
-			msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s` and `%s`",
-				SASLTypeOAuth, SASLTypePlaintext)
+		default:
+			msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s`, `%s`, `%s` and `%s`",
+				SASLTypeOAuth, SASLTypePlaintext, SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512)
 			return ConfigurationError(msg)
 		}
 	}
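
Putting the new SASL fields together: selecting SCRAM-SHA-256 or SCRAM-SHA-512 now requires a user, a password, and a SCRAMClient, otherwise Validate() returns one of the errors above. A minimal sketch of the wiring; scramStub and the credentials are placeholders, and a real deployment would back the SCRAMClient with an actual SCRAM library rather than this no-op stub:

```go
package main

import (
	"log"

	"github.com/Shopify/sarama"
)

// scramStub is a placeholder SCRAMClient implementation; it satisfies the
// interface (Begin/Step/Done) but performs no real SCRAM exchange.
type scramStub struct{ done bool }

func (s *scramStub) Begin(userName, password, authzID string) error { return nil }
func (s *scramStub) Step(challenge string) (string, error)          { s.done = true; return "", nil }
func (s *scramStub) Done() bool                                     { return s.done }

func main() {
	config := sarama.NewConfig()
	config.Net.SASL.Enable = true
	config.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA512
	config.Net.SASL.User = "alice"
	config.Net.SASL.Password = "alice-secret"
	config.Net.SASL.SCRAMClient = &scramStub{}

	if err := config.Validate(); err != nil {
		log.Fatal(err)
	}
}
```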

+ 20 - 2
config_test.go

@@ -91,14 +91,32 @@ func TestNetConfigValidates(t *testing.T) {
 				cfg.Net.SASL.Mechanism = "AnIncorrectSASLMechanism"
 				cfg.Net.SASL.TokenProvider = &DummyTokenProvider{}
 			},
-			"The SASL mechanism configuration is invalid. Possible values are `OAUTHBEARER` and `PLAIN`"},
+			"The SASL mechanism configuration is invalid. Possible values are `OAUTHBEARER`, `PLAIN`, `SCRAM-SHA-256` and `SCRAM-SHA-512`"},
 		{"SASL.Mechanism.OAUTHBEARER - Missing token provider",
 			func(cfg *Config) {
 				cfg.Net.SASL.Enable = true
 				cfg.Net.SASL.Mechanism = SASLTypeOAuth
 				cfg.Net.SASL.TokenProvider = nil
 			},
-			"An AccessTokenProvider instance must be provided to Net.SASL.User.TokenProvider"},
+			"An AccessTokenProvider instance must be provided to Net.SASL.TokenProvider"},
+		{"SASL.Mechanism SCRAM-SHA-256 - Missing SCRAM client",
+			func(cfg *Config) {
+				cfg.Net.SASL.Enable = true
+				cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA256
+				cfg.Net.SASL.SCRAMClient = nil
+				cfg.Net.SASL.User = "user"
+				cfg.Net.SASL.Password = "stong_password"
+			},
+			"A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"},
+		{"SASL.Mechanism SCRAM-SHA-512 - Missing SCRAM client",
+			func(cfg *Config) {
+				cfg.Net.SASL.Enable = true
+				cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
+				cfg.Net.SASL.SCRAMClient = nil
+				cfg.Net.SASL.User = "user"
+				cfg.Net.SASL.Password = "stong_password"
+			},
+			"A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"},
 	}
 
 	for i, test := range tests {

+ 51 - 49
consumer.go

@@ -12,13 +12,14 @@ import (
 
 // ConsumerMessage encapsulates a Kafka message returned by the consumer.
 type ConsumerMessage struct {
-	Key, Value     []byte
-	Topic          string
-	Partition      int32
-	Offset         int64
+	Headers        []*RecordHeader // only set if kafka is version 0.11+
 	Timestamp      time.Time       // only set if kafka is version 0.10+, inner message timestamp
 	BlockTimestamp time.Time       // only set if kafka is version 0.10+, outer (compressed) block timestamp
-	Headers        []*RecordHeader // only set if kafka is version 0.11+
+
+	Key, Value []byte
+	Topic      string
+	Partition  int32
+	Offset     int64
 }
 
 // ConsumerError is what is provided to the user when an error occurs.
@@ -45,11 +46,6 @@ func (ce ConsumerErrors) Error() string {
 // Consumer manages PartitionConsumers which process Kafka messages from brokers. You MUST call Close()
 // on a consumer to avoid leaks, it will not be garbage-collected automatically when it passes out of
 // scope.
-//
-// Sarama's Consumer type does not currently support automatic consumer-group rebalancing and offset tracking.
-// For Zookeeper-based tracking (Kafka 0.8.2 and earlier), the https://github.com/wvanbergen/kafka library
-// builds on Sarama to add this support. For Kafka-based tracking (Kafka 0.9 and later), the
-// https://github.com/bsm/sarama-cluster library builds on Sarama to add this support.
 type Consumer interface {
 
 	// Topics returns the set of available topics as retrieved from the cluster
@@ -77,13 +73,11 @@ type Consumer interface {
 }
 
 type consumer struct {
-	client    Client
-	conf      *Config
-	ownClient bool
-
-	lock            sync.Mutex
+	conf            *Config
 	children        map[string]map[int32]*partitionConsumer
 	brokerConsumers map[*Broker]*brokerConsumer
+	client          Client
+	lock            sync.Mutex
 }
 
 // NewConsumer creates a new consumer using the given broker addresses and configuration.
@@ -92,18 +86,19 @@ func NewConsumer(addrs []string, config *Config) (Consumer, error) {
 	if err != nil {
 		return nil, err
 	}
-
-	c, err := NewConsumerFromClient(client)
-	if err != nil {
-		return nil, err
-	}
-	c.(*consumer).ownClient = true
-	return c, nil
+	return newConsumer(client)
 }
 
 // NewConsumerFromClient creates a new consumer using the given client. It is still
 // necessary to call Close() on the underlying client when shutting down this consumer.
 func NewConsumerFromClient(client Client) (Consumer, error) {
+	// For a client passed in by the caller, ensure we don't
+	// call Close() on it.
+	cli := &nopCloserClient{client}
+	return newConsumer(cli)
+}
+
+func newConsumer(client Client) (Consumer, error) {
 	// Check that we are not dealing with a closed Client before processing any other arguments
 	if client.Closed() {
 		return nil, ErrClosedClient
@@ -120,10 +115,7 @@ func NewConsumerFromClient(client Client) (Consumer, error) {
 }
 
 func (c *consumer) Close() error {
-	if c.ownClient {
-		return c.client.Close()
-	}
-	return nil
+	return c.client.Close()
 }
 
 func (c *consumer) Topics() ([]string, error) {
@@ -263,7 +255,7 @@ func (c *consumer) abandonBrokerConsumer(brokerWorker *brokerConsumer) {
 // or a separate goroutine. Check out the Consumer examples to see implementations of these different approaches.
 //
 // To terminate such a for/range loop while the loop is executing, call AsyncClose. This will kick off the process of
-// consumer tear-down & return imediately. Continue to loop, servicing the Messages channel until the teardown process
+// consumer tear-down & return immediately. Continue to loop, servicing the Messages channel until the teardown process
 // AsyncClose initiated closes it (thus terminating the for/range loop). If you've already ceased reading Messages, call
 // Close; this will signal the PartitionConsumer's goroutines to begin shutting down (just like AsyncClose), but will
 // also drain the Messages channel, harvest all errors & return them once cleanup has completed.
@@ -300,22 +292,22 @@ type PartitionConsumer interface {
 
 type partitionConsumer struct {
 	highWaterMarkOffset int64 // must be at the top of the struct because https://golang.org/pkg/sync/atomic/#pkg-note-BUG
-	consumer            *consumer
-	conf                *Config
-	topic               string
-	partition           int32
 
+	consumer *consumer
+	conf     *Config
 	broker   *brokerConsumer
 	messages chan *ConsumerMessage
 	errors   chan *ConsumerError
 	feeder   chan *FetchResponse
 
 	trigger, dying chan none
-	responseResult error
 	closeOnce      sync.Once
-
-	fetchSize int32
-	offset    int64
+	topic          string
+	partition      int32
+	responseResult error
+	fetchSize      int32
+	offset         int64
+	retries        int32
 }
 
 var errTimedOut = errors.New("timed out feeding messages to the user") // not user-facing
@@ -334,12 +326,20 @@ func (child *partitionConsumer) sendError(err error) {
 	}
 }
 
+func (child *partitionConsumer) computeBackoff() time.Duration {
+	if child.conf.Consumer.Retry.BackoffFunc != nil {
+		retries := atomic.AddInt32(&child.retries, 1)
+		return child.conf.Consumer.Retry.BackoffFunc(int(retries))
+	}
+	return child.conf.Consumer.Retry.Backoff
+}
+
 func (child *partitionConsumer) dispatcher() {
 	for range child.trigger {
 		select {
 		case <-child.dying:
 			close(child.trigger)
-		case <-time.After(child.conf.Consumer.Retry.Backoff):
+		case <-time.After(child.computeBackoff()):
 			if child.broker != nil {
 				child.consumer.unrefBrokerConsumer(child.broker)
 				child.broker = nil
@@ -453,6 +453,10 @@ feederLoop:
 	for response := range child.feeder {
 		msgs, child.responseResult = child.parseResponse(response)
 
+		if child.responseResult == nil {
+			atomic.StoreInt32(&child.retries, 0)
+		}
+
 		for i, msg := range msgs {
 		messageSelect:
 			select {
@@ -513,13 +517,13 @@ func (child *partitionConsumer) parseMessages(msgSet *MessageSet) ([]*ConsumerMe
 		}
 	}
 	if len(messages) == 0 {
-		return nil, ErrIncompleteResponse
+		child.offset++
 	}
 	return messages, nil
 }
 
 func (child *partitionConsumer) parseRecords(batch *RecordBatch) ([]*ConsumerMessage, error) {
-	var messages []*ConsumerMessage
+	messages := make([]*ConsumerMessage, 0, len(batch.Records))
 
 	for _, rec := range batch.Records {
 		offset := batch.FirstOffset + rec.OffsetDelta
@@ -542,7 +546,7 @@ func (child *partitionConsumer) parseRecords(batch *RecordBatch) ([]*ConsumerMes
 		child.offset = offset + 1
 	}
 	if len(messages) == 0 {
-		child.offset += 1
+		child.offset++
 	}
 	return messages, nil
 }
@@ -628,15 +632,13 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
 	return messages, nil
 }
 
-// brokerConsumer
-
 type brokerConsumer struct {
 	consumer         *consumer
 	broker           *Broker
 	input            chan *partitionConsumer
 	newSubscriptions chan []*partitionConsumer
-	wait             chan none
 	subscriptions    map[*partitionConsumer]none
+	wait             chan none
 	acks             sync.WaitGroup
 	refs             int
 }
@@ -658,14 +660,14 @@ func (c *consumer) newBrokerConsumer(broker *Broker) *brokerConsumer {
 	return bc
 }
 
+// The subscriptionManager constantly accepts new subscriptions on `input` (even when the main subscriptionConsumer
+// goroutine is in the middle of a network request) and batches them up. The main worker goroutine picks
+// up a batch of new subscriptions between every network request by reading from `newSubscriptions`, so we give
+// it nil if no new subscriptions are available. We also write to `wait` only when new subscriptions are available,
+// so the main goroutine can block waiting for work if it has none.
 func (bc *brokerConsumer) subscriptionManager() {
 	var buffer []*partitionConsumer
 
-	// The subscriptionManager constantly accepts new subscriptions on `input` (even when the main subscriptionConsumer
-	// goroutine is in the middle of a network request) and batches it up. The main worker goroutine picks
-	// up a batch of new subscriptions between every network request by reading from `newSubscriptions`, so we give
-	// it nil if no new subscriptions are available. We also write to `wait` only when new subscriptions is available,
-	// so the main goroutine can block waiting for work if it has none.
 	for {
 		if len(buffer) > 0 {
 			select {
@@ -698,10 +700,10 @@ done:
 	close(bc.newSubscriptions)
 }
 
+// subscriptionConsumer ensures we will get nil right away if no new subscriptions are available
 func (bc *brokerConsumer) subscriptionConsumer() {
 	<-bc.wait // wait for our first piece of work
 
-	// the subscriptionConsumer ensures we will get nil right away if no new subscriptions is available
 	for newSubscriptions := range bc.newSubscriptions {
 		bc.updateSubscriptions(newSubscriptions)
 
@@ -747,8 +749,8 @@ func (bc *brokerConsumer) updateSubscriptions(newSubscriptions []*partitionConsu
 	}
 }
 
+// handleResponses handles the response codes left for us by our subscriptions, and abandons ones that have been closed
 func (bc *brokerConsumer) handleResponses() {
-	// handles the response codes left for us by our subscriptions, and abandons ones that have been closed
 	for child := range bc.subscriptions {
 		result := child.responseResult
 		child.responseResult = nil

+ 18 - 11
consumer_group.go

@@ -52,8 +52,7 @@ type ConsumerGroup interface {
 }
 
 type consumerGroup struct {
-	client    Client
-	ownClient bool
+	client Client
 
 	config   *Config
 	consumer Consumer
@@ -73,20 +72,24 @@ func NewConsumerGroup(addrs []string, groupID string, config *Config) (ConsumerG
 		return nil, err
 	}
 
-	c, err := NewConsumerGroupFromClient(groupID, client)
+	c, err := newConsumerGroup(groupID, client)
 	if err != nil {
 		_ = client.Close()
-		return nil, err
 	}
-
-	c.(*consumerGroup).ownClient = true
-	return c, nil
+	return c, err
 }
 
 // NewConsumerGroupFromClient creates a new consumer group using the given client. It is still
 // necessary to call Close() on the underlying client when shutting down this consumer.
 // PLEASE NOTE: consumer groups can only re-use but not share clients.
 func NewConsumerGroupFromClient(groupID string, client Client) (ConsumerGroup, error) {
+	// For a client passed in by the caller, ensure we don't
+	// call Close() on it.
+	cli := &nopCloserClient{client}
+	return newConsumerGroup(groupID, cli)
+}
+
+func newConsumerGroup(groupID string, client Client) (ConsumerGroup, error) {
 	config := client.Config()
 	if !config.Version.IsAtLeast(V0_10_2_0) {
 		return nil, ConfigurationError("consumer groups require Version to be >= V0_10_2_0")
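
NewConsumerGroupFromClient above wraps the caller-supplied client in a nopCloserClient so that consumerGroup.Close (below) can always call c.client.Close() without tearing down a client it does not own. The wrapper itself is defined elsewhere in this change; a plausible minimal shape, offered as an assumption rather than the verbatim implementation, is an embedded Client whose Close is a no-op:

    // Sketch of a wrapper that keeps the caller's client open.
    type nopCloserClient struct {
        Client
    }

    // Close intentionally does nothing; the caller owns the underlying client.
    func (ncc *nopCloserClient) Close() error {
        return nil
    }
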
@@ -131,10 +134,8 @@ func (c *consumerGroup) Close() (err error) {
 			err = e
 		}
 
-		if c.ownClient {
-			if e := c.client.Close(); e != nil {
-				err = e
-			}
+		if e := c.client.Close(); e != nil {
+			err = e
 		}
 	})
 	return
@@ -657,6 +658,12 @@ func (s *consumerGroupSession) heartbeatLoop() {
 		resp, err := s.parent.heartbeatRequest(coordinator, s.memberID, s.generationID)
 		if err != nil {
 			_ = coordinator.Close()
+
+			if retries <= 0 {
+				s.parent.handleError(err, "", -1)
+				return
+			}
+
 			retries--
 			continue
 		}

+ 34 - 8
consumer_test.go

@@ -5,6 +5,7 @@ import (
 	"os"
 	"os/signal"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -180,9 +181,7 @@ func TestConsumerDuplicate(t *testing.T) {
 	broker0.Close()
 }
 
-// If consumer fails to refresh metadata it keeps retrying with frequency
-// specified by `Config.Consumer.Retry.Backoff`.
-func TestConsumerLeaderRefreshError(t *testing.T) {
+func runConsumerLeaderRefreshErrorTestWithConfig(t *testing.T, config *Config) {
 	// Given
 	broker0 := NewMockBroker(t, 100)
 
@@ -200,11 +199,6 @@ func TestConsumerLeaderRefreshError(t *testing.T) {
 			SetMessage("my_topic", 0, 123, testMsg),
 	})
 
-	config := NewConfig()
-	config.Net.ReadTimeout = 100 * time.Millisecond
-	config.Consumer.Retry.Backoff = 200 * time.Millisecond
-	config.Consumer.Return.Errors = true
-	config.Metadata.Retry.Max = 0
 	c, err := NewConsumer([]string{broker0.Addr()}, config)
 	if err != nil {
 		t.Fatal(err)
@@ -258,6 +252,38 @@ func TestConsumerLeaderRefreshError(t *testing.T) {
 	broker0.Close()
 }
 
+// If the consumer fails to refresh metadata it keeps retrying with the frequency
+// specified by `Config.Consumer.Retry.Backoff`.
+func TestConsumerLeaderRefreshError(t *testing.T) {
+	config := NewConfig()
+	config.Net.ReadTimeout = 100 * time.Millisecond
+	config.Consumer.Retry.Backoff = 200 * time.Millisecond
+	config.Consumer.Return.Errors = true
+	config.Metadata.Retry.Max = 0
+
+	runConsumerLeaderRefreshErrorTestWithConfig(t, config)
+}
+
+func TestConsumerLeaderRefreshErrorWithBackoffFunc(t *testing.T) {
+	var calls int32 = 0
+
+	config := NewConfig()
+	config.Net.ReadTimeout = 100 * time.Millisecond
+	config.Consumer.Retry.BackoffFunc = func(retries int) time.Duration {
+		atomic.AddInt32(&calls, 1)
+		return 200 * time.Millisecond
+	}
+	config.Consumer.Return.Errors = true
+	config.Metadata.Retry.Max = 0
+
+	runConsumerLeaderRefreshErrorTestWithConfig(t, config)
+
+	// we expect at least one call to our backoff function
+	if calls == 0 {
+		t.Fail()
+	}
+}
+
 func TestConsumerInvalidTopic(t *testing.T) {
 	// Given
 	broker0 := NewMockBroker(t, 100)

+ 12 - 1
create_partitions_response.go

@@ -1,6 +1,9 @@
 package sarama
 
-import "time"
+import (
+	"fmt"
+	"time"
+)
 
 type CreatePartitionsResponse struct {
 	ThrottleTime         time.Duration
@@ -69,6 +72,14 @@ type TopicPartitionError struct {
 	ErrMsg *string
 }
 
+func (t *TopicPartitionError) Error() string {
+	text := t.Err.Error()
+	if t.ErrMsg != nil {
+		text = fmt.Sprintf("%s - %s", text, *t.ErrMsg)
+	}
+	return text
+}
+
 func (t *TopicPartitionError) encode(pe packetEncoder) error {
 	pe.putInt16(int16(t.Err))
 

+ 24 - 0
create_partitions_response_test.go

@@ -50,3 +50,27 @@ func TestCreatePartitionsResponse(t *testing.T) {
 		t.Errorf("Decoding error: expected %v but got %v", decodedresp, resp)
 	}
 }
+
+func TestTopicPartitionError(t *testing.T) {
+	// Assert that TopicPartitionError satisfies error interface
+	var err error = &TopicPartitionError{
+		Err: ErrTopicAuthorizationFailed,
+	}
+
+	got := err.Error()
+	want := ErrTopicAuthorizationFailed.Error()
+	if got != want {
+		t.Errorf("TopicPartitionError.Error() = %v; want %v", got, want)
+	}
+
+	msg := "reason why topic authorization failed"
+	err = &TopicPartitionError{
+		Err:    ErrTopicAuthorizationFailed,
+		ErrMsg: &msg,
+	}
+	got = err.Error()
+	want = ErrTopicAuthorizationFailed.Error() + " - " + msg
+	if got != want {
+		t.Errorf("TopicPartitionError.Error() = %v; want %v", got, want)
+	}
+}

+ 12 - 1
create_topics_response.go

@@ -1,6 +1,9 @@
 package sarama
 
-import "time"
+import (
+	"fmt"
+	"time"
+)
 
 type CreateTopicsResponse struct {
 	Version      int16
@@ -83,6 +86,14 @@ type TopicError struct {
 	ErrMsg *string
 }
 
+func (t *TopicError) Error() string {
+	text := t.Err.Error()
+	if t.ErrMsg != nil {
+		text = fmt.Sprintf("%s - %s", text, *t.ErrMsg)
+	}
+	return text
+}
+
 func (t *TopicError) encode(pe packetEncoder, version int16) error {
 	pe.putInt16(int16(t.Err))
 

+ 24 - 0
create_topics_response_test.go

@@ -50,3 +50,27 @@ func TestCreateTopicsResponse(t *testing.T) {
 
 	testResponse(t, "version 2", resp, createTopicsResponseV2)
 }
+
+func TestTopicError(t *testing.T) {
+	// Assert that TopicError satisfies error interface
+	var err error = &TopicError{
+		Err: ErrTopicAuthorizationFailed,
+	}
+
+	got := err.Error()
+	want := ErrTopicAuthorizationFailed.Error()
+	if got != want {
+		t.Errorf("TopicError.Error() = %v; want %v", got, want)
+	}
+
+	msg := "reason why topic authorization failed"
+	err = &TopicError{
+		Err:    ErrTopicAuthorizationFailed,
+		ErrMsg: &msg,
+	}
+	got = err.Error()
+	want = ErrTopicAuthorizationFailed.Error() + " - " + msg
+	if got != want {
+		t.Errorf("TopicError.Error() = %v; want %v", got, want)
+	}
+}

+ 1 - 1
dev.yml

@@ -2,7 +2,7 @@ name: sarama
 
 up:
   - go:
-      version: '1.11'
+      version: '1.12'
 
 commands:
   test:

+ 12 - 0
errors.go

@@ -157,6 +157,10 @@ const (
 	ErrFetchSessionIDNotFound             KError = 70
 	ErrInvalidFetchSessionEpoch           KError = 71
 	ErrListenerNotFound                   KError = 72
+	ErrTopicDeletionDisabled              KError = 73
+	ErrFencedLeaderEpoch                  KError = 74
+	ErrUnknownLeaderEpoch                 KError = 75
+	ErrUnsupportedCompressionType         KError = 76
 )
 
 func (err KError) Error() string {
@@ -311,6 +315,14 @@ func (err KError) Error() string {
 		return "kafka server: The fetch session epoch is invalid."
 	case ErrListenerNotFound:
 		return "kafka server: There is no listener on the leader broker that matches the listener on which metadata request was processed."
+	case ErrTopicDeletionDisabled:
+		return "kafka server: Topic deletion is disabled."
+	case ErrFencedLeaderEpoch:
+		return "kafka server: The leader epoch in the request is older than the epoch on the broker."
+	case ErrUnknownLeaderEpoch:
+		return "kafka server: The leader epoch in the request is newer than the epoch on the broker."
+	case ErrUnsupportedCompressionType:
+		return "kafka server: The requesting client does not support the compression type of given partition."
 	}
 
 	return fmt.Sprintf("Unknown error, how did this happen? Error code = %d", err)
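
Error codes 73 to 76 now decode into the KError values above instead of falling through to the generic "Unknown error" branch. A hedged sketch of reacting to them on the client side (the helper name and its retry policy are illustrative only):

    import "github.com/Shopify/sarama"

    // handleFetchError is a hypothetical helper, not part of this change.
    func handleFetchError(client sarama.Client, topic string, err error) error {
        switch err {
        case sarama.ErrFencedLeaderEpoch, sarama.ErrUnknownLeaderEpoch:
            // our cached view of the partition leader's epoch is stale;
            // refresh metadata before the next attempt
            return client.RefreshMetadata(topic)
        case sarama.ErrUnsupportedCompressionType:
            // not retryable: the broker cannot serve this compression type
            return err
        }
        return err
    }
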

+ 13 - 0
go.mod

@@ -0,0 +1,13 @@
+module github.com/Shopify/sarama
+
+require (
+	github.com/DataDog/zstd v1.3.5
+	github.com/Shopify/toxiproxy v2.1.4+incompatible
+	github.com/davecgh/go-spew v1.1.1
+	github.com/eapache/go-resiliency v1.1.0
+	github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21
+	github.com/eapache/queue v1.1.0
+	github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
+	github.com/pierrec/lz4 v2.0.5+incompatible
+	github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a
+)

+ 18 - 0
go.sum

@@ -0,0 +1,18 @@
+github.com/DataDog/zstd v1.3.5 h1:DtpNbljikUepEPD16hD4LvIcmhnhdLTiW/5pHgbmp14=
+github.com/DataDog/zstd v1.3.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
+github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc=
+github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/eapache/go-resiliency v1.1.0 h1:1NtRmCAqadE2FN4ZcN6g90TP3uk8cg9rn9eNK2197aU=
+github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
+github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw=
+github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
+github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
+github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
+github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
+github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
+github.com/pierrec/lz4 v2.0.5+incompatible h1:2xWsjqPFWcplujydGg4WmhC/6fZqK42wMM8aXeqhl0I=
+github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
+github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a h1:9ZKAASQSHhDYGoxY8uLVpewe1GDZ2vu2Tr/vTdVAkFQ=
+github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=

+ 1 - 1
message.go

@@ -157,7 +157,7 @@ func (m *Message) decode(pd packetDecoder) (err error) {
 	return pd.pop()
 }
 
-// decodes a message set from a previousy encoded bulk-message
+// decodes a message set from a previously encoded bulk-message
 func (m *Message) decodeSet() (err error) {
 	pd := realDecoder{raw: m.Value}
 	m.Set = &MessageSet{}

+ 20 - 1
mockresponses.go

@@ -2,6 +2,7 @@ package sarama
 
 import (
 	"fmt"
+	"strings"
 )
 
 // TestReporter has methods matching go's testing.T to avoid importing
@@ -620,10 +621,20 @@ func NewMockCreateTopicsResponse(t TestReporter) *MockCreateTopicsResponse {
 
 func (mr *MockCreateTopicsResponse) For(reqBody versionedDecoder) encoder {
 	req := reqBody.(*CreateTopicsRequest)
-	res := &CreateTopicsResponse{}
+	res := &CreateTopicsResponse{
+		Version: req.Version,
+	}
 	res.TopicErrors = make(map[string]*TopicError)
 
 	for topic, _ := range req.TopicDetails {
+		if res.Version >= 1 && strings.HasPrefix(topic, "_") {
+			msg := "insufficient permissions to create topic with reserved prefix"
+			res.TopicErrors[topic] = &TopicError{
+				Err:    ErrTopicAuthorizationFailed,
+				ErrMsg: &msg,
+			}
+			continue
+		}
 		res.TopicErrors[topic] = &TopicError{Err: ErrNoError}
 	}
 	return res
@@ -662,6 +673,14 @@ func (mr *MockCreatePartitionsResponse) For(reqBody versionedDecoder) encoder {
 	res.TopicPartitionErrors = make(map[string]*TopicPartitionError)
 
 	for topic, _ := range req.TopicPartitions {
+		if strings.HasPrefix(topic, "_") {
+			msg := "insufficient permissions to create partition on topic with reserved prefix"
+			res.TopicPartitionErrors[topic] = &TopicPartitionError{
+				Err:    ErrTopicAuthorizationFailed,
+				ErrMsg: &msg,
+			}
+			continue
+		}
 		res.TopicPartitionErrors[topic] = &TopicPartitionError{Err: ErrNoError}
 	}
 	return res
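
With this change the mock create-topics and create-partitions handlers report ErrTopicAuthorizationFailed, with an explanatory message, for any topic name starting with "_", which gives tests a ready-made way to exercise the new TopicError and TopicPartitionError paths. A rough fragment of how a package-level test might use it (wiring abbreviated; the topic name is a placeholder):

    // inside a test in package sarama
    broker := NewMockBroker(t, 1)
    defer broker.Close()
    broker.SetHandlerByMap(map[string]MockResponse{
        "CreateTopicsRequest": NewMockCreateTopicsResponse(t),
    })
    // a v1+ CreateTopicsRequest for "_private" sent to broker.Addr() should
    // come back with TopicErrors["_private"].Err == ErrTopicAuthorizationFailed
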

+ 10 - 1
offset_manager.go

@@ -120,6 +120,14 @@ func (om *offsetManager) Close() error {
 	return nil
 }
 
+func (om *offsetManager) computeBackoff(retries int) time.Duration {
+	if om.conf.Metadata.Retry.BackoffFunc != nil {
+		return om.conf.Metadata.Retry.BackoffFunc(retries, om.conf.Metadata.Retry.Max)
+	} else {
+		return om.conf.Metadata.Retry.Backoff
+	}
+}
+
 func (om *offsetManager) fetchInitialOffset(topic string, partition int32, retries int) (int64, string, error) {
 	broker, err := om.coordinator()
 	if err != nil {
@@ -161,10 +169,11 @@ func (om *offsetManager) fetchInitialOffset(topic string, partition int32, retri
 		if retries <= 0 {
 			return 0, "", block.Err
 		}
+		backoff := om.computeBackoff(retries)
 		select {
 		case <-om.closing:
 			return 0, "", block.Err
-		case <-time.After(om.conf.Metadata.Retry.Backoff):
+		case <-time.After(backoff):
 		}
 		return om.fetchInitialOffset(topic, partition, retries-1)
 	default:
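
fetchInitialOffset now routes its retry delay through Metadata.Retry.BackoffFunc when one is configured; note that, unlike the consumer hook, this signature also receives maxRetries, and in this call path retries counts down towards zero. A short configuration sketch with an example policy only:

    import (
        "time"

        "github.com/Shopify/sarama"
    )

    // newOffsetManagerConfig is a hypothetical helper demonstrating the hook.
    func newOffsetManagerConfig() *sarama.Config {
        cfg := sarama.NewConfig()
        cfg.Metadata.Retry.Max = 3
        cfg.Metadata.Retry.BackoffFunc = func(retries, maxRetries int) time.Duration {
            // retries is the remaining budget here, so wait longer as it shrinks
            return time.Duration(maxRetries-retries+1) * 100 * time.Millisecond
        }
        return cfg
    }
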

+ 21 - 2
offset_manager_test.go

@@ -1,15 +1,20 @@
 package sarama
 
 import (
+	"sync/atomic"
 	"testing"
 	"time"
 )
 
-func initOffsetManager(t *testing.T, retention time.Duration) (om OffsetManager,
+func initOffsetManagerWithBackoffFunc(t *testing.T, retention time.Duration,
+	backoffFunc func(retries, maxRetries int) time.Duration) (om OffsetManager,
 	testClient Client, broker, coordinator *MockBroker) {
 
 	config := NewConfig()
 	config.Metadata.Retry.Max = 1
+	if backoffFunc != nil {
+		config.Metadata.Retry.BackoffFunc = backoffFunc
+	}
 	config.Consumer.Offsets.CommitInterval = 1 * time.Millisecond
 	config.Version = V0_9_0_0
 	if retention > 0 {
@@ -45,6 +50,11 @@ func initOffsetManager(t *testing.T, retention time.Duration) (om OffsetManager,
 	return om, testClient, broker, coordinator
 }
 
+func initOffsetManager(t *testing.T, retention time.Duration) (om OffsetManager,
+	testClient Client, broker, coordinator *MockBroker) {
+	return initOffsetManagerWithBackoffFunc(t, retention, nil)
+}
+
 func initPartitionOffsetManager(t *testing.T, om OffsetManager,
 	coordinator *MockBroker, initialOffset int64, metadata string) PartitionOffsetManager {
 
@@ -133,7 +143,12 @@ func TestOffsetManagerFetchInitialFail(t *testing.T) {
 
 // Test fetchInitialOffset retry on ErrOffsetsLoadInProgress
 func TestOffsetManagerFetchInitialLoadInProgress(t *testing.T) {
-	om, testClient, broker, coordinator := initOffsetManager(t, 0)
+	retryCount := int32(0)
+	backoff := func(retries, maxRetries int) time.Duration {
+		atomic.AddInt32(&retryCount, 1)
+		return 0
+	}
+	om, testClient, broker, coordinator := initOffsetManagerWithBackoffFunc(t, 0, backoff)
 
 	// Error on first fetchInitialOffset call
 	responseBlock := OffsetFetchResponseBlock{
@@ -163,6 +178,10 @@ func TestOffsetManagerFetchInitialLoadInProgress(t *testing.T) {
 	safeClose(t, pom)
 	safeClose(t, om)
 	safeClose(t, testClient)
+
+	if atomic.LoadInt32(&retryCount) == 0 {
+		t.Fatal("Expected at least one retry")
+	}
 }
 
 func TestPartitionOffsetManagerInitialOffset(t *testing.T) {

+ 5 - 3
record.go

@@ -10,6 +10,7 @@ const (
 	maximumRecordOverhead = 5*binary.MaxVarintLen32 + binary.MaxVarintLen64 + 1
 )
 
+// RecordHeader stores the key and value for a record header
 type RecordHeader struct {
 	Key   []byte
 	Value []byte
@@ -33,15 +34,16 @@ func (h *RecordHeader) decode(pd packetDecoder) (err error) {
 	return nil
 }
 
+// Record is a Kafka record type
 type Record struct {
+	Headers []*RecordHeader
+
 	Attributes     int8
 	TimestampDelta time.Duration
 	OffsetDelta    int64
 	Key            []byte
 	Value          []byte
-	Headers        []*RecordHeader
-
-	length varintLengthField
+	length         varintLengthField
 }
 
 func (r *Record) encode(pe packetEncoder) error {

+ 3 - 0
response_header.go

@@ -2,6 +2,9 @@ package sarama
 
 import "fmt"
 
+const responseLengthSize = 4
+const correlationIDSize = 4
+
 type responseHeader struct {
 	length        int32
 	correlationID int32

+ 1 - 4
sarama.go

@@ -10,10 +10,7 @@ useful but comes with two caveats: it will generally be less efficient, and the
 depend on the configured value of `Producer.RequiredAcks`. There are configurations where a message acknowledged by the
 SyncProducer can still sometimes be lost.
 
-To consume messages, use the Consumer. Note that Sarama's Consumer implementation does not currently support automatic
-consumer-group rebalancing and offset tracking. For Zookeeper-based tracking (Kafka 0.8.2 and earlier), the
-https://github.com/wvanbergen/kafka library builds on Sarama to add this support. For Kafka-based tracking (Kafka 0.9
-and later), the https://github.com/bsm/sarama-cluster library builds on Sarama to add this support.
+To consume messages, use the Consumer or ConsumerGroup APIs.
 
 For lower-level needs, the Broker and Request/Response objects permit precise control over each connection
 and message sent on the wire; the Client provides higher-level metadata management that is shared between
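
The package documentation now points at the built-in consumer-group support rather than the external libraries. For orientation, a compact sketch of that API; the handler, group, and topic names are placeholders, and error handling plus the usual loop around Consume (to survive rebalances) are trimmed:

    import (
        "context"

        "github.com/Shopify/sarama"
    )

    type exampleHandler struct{}

    func (exampleHandler) Setup(sarama.ConsumerGroupSession) error   { return nil }
    func (exampleHandler) Cleanup(sarama.ConsumerGroupSession) error { return nil }
    func (exampleHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error {
        for msg := range claim.Messages() {
            sess.MarkMessage(msg, "") // record progress so offsets get committed
        }
        return nil
    }

    func consume(brokers []string) error {
        cfg := sarama.NewConfig()
        cfg.Version = sarama.V2_1_0_0 // consumer groups require >= V0_10_2_0
        group, err := sarama.NewConsumerGroup(brokers, "example-group", cfg)
        if err != nil {
            return err
        }
        defer group.Close()
        return group.Consume(context.Background(), []string{"example-topic"}, exampleHandler{})
    }
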

+ 45 - 0
sync_producer_test.go

@@ -177,6 +177,51 @@ func TestSyncProducerToNonExistingTopic(t *testing.T) {
 	broker.Close()
 }
 
+func TestSyncProducerRecoveryWithRetriesDisabled(t *testing.T) {
+	seedBroker := NewMockBroker(t, 1)
+	leader1 := NewMockBroker(t, 2)
+	leader2 := NewMockBroker(t, 3)
+
+	metadataLeader1 := new(MetadataResponse)
+	metadataLeader1.AddBroker(leader1.Addr(), leader1.BrokerID())
+	metadataLeader1.AddTopicPartition("my_topic", 0, leader1.BrokerID(), nil, nil, ErrNoError)
+	seedBroker.Returns(metadataLeader1)
+
+	config := NewConfig()
+	config.Producer.Retry.Max = 0 // disable!
+	config.Producer.Retry.Backoff = 0
+	config.Producer.Return.Successes = true
+	producer, err := NewSyncProducer([]string{seedBroker.Addr()}, config)
+	if err != nil {
+		t.Fatal(err)
+	}
+	seedBroker.Close()
+
+	prodNotLeader := new(ProduceResponse)
+	prodNotLeader.AddTopicPartition("my_topic", 0, ErrNotLeaderForPartition)
+	leader1.Returns(prodNotLeader)
+	_, _, err = producer.SendMessage(&ProducerMessage{Topic: "my_topic", Value: StringEncoder(TestMessage)})
+	if err != ErrNotLeaderForPartition {
+		t.Fatal(err)
+	}
+
+	metadataLeader2 := new(MetadataResponse)
+	metadataLeader2.AddBroker(leader2.Addr(), leader2.BrokerID())
+	metadataLeader2.AddTopicPartition("my_topic", 0, leader2.BrokerID(), nil, nil, ErrNoError)
+	leader1.Returns(metadataLeader2)
+	prodSuccess := new(ProduceResponse)
+	prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError)
+	leader2.Returns(prodSuccess)
+	_, _, err = producer.SendMessage(&ProducerMessage{Topic: "my_topic", Value: StringEncoder(TestMessage)})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	leader1.Close()
+	leader2.Close()
+	safeClose(t, producer)
+}
+
 // This example shows the basic usage pattern of the SyncProducer.
 func ExampleSyncProducer() {
 	producer, err := NewSyncProducer([]string{"localhost:9092"}, nil)

+ 24 - 6
tools/kafka-console-consumer/kafka-console-consumer.go

@@ -11,14 +11,20 @@ import (
 	"sync"
 
 	"github.com/Shopify/sarama"
+	"github.com/Shopify/sarama/tools/tls"
 )
 
 var (
-	brokerList = flag.String("brokers", os.Getenv("KAFKA_PEERS"), "The comma separated list of brokers in the Kafka cluster")
-	topic      = flag.String("topic", "", "REQUIRED: the topic to consume")
-	partitions = flag.String("partitions", "all", "The partitions to consume, can be 'all' or comma-separated numbers")
-	offset     = flag.String("offset", "newest", "The offset to start with. Can be `oldest`, `newest`")
-	verbose    = flag.Bool("verbose", false, "Whether to turn on sarama logging")
+	brokerList    = flag.String("brokers", os.Getenv("KAFKA_PEERS"), "The comma separated list of brokers in the Kafka cluster")
+	topic         = flag.String("topic", "", "REQUIRED: the topic to consume")
+	partitions    = flag.String("partitions", "all", "The partitions to consume, can be 'all' or comma-separated numbers")
+	offset        = flag.String("offset", "newest", "The offset to start with. Can be `oldest`, `newest`")
+	verbose       = flag.Bool("verbose", false, "Whether to turn on sarama logging")
+	tlsEnabled    = flag.Bool("tls-enabled", false, "Whether to enable TLS")
+	tlsSkipVerify = flag.Bool("tls-skip-verify", false, "Whether to skip TLS server cert verification")
+	tlsClientCert = flag.String("tls-client-cert", "", "Client cert for client authentication (use with -tls-enabled and -tls-client-key)")
+	tlsClientKey  = flag.String("tls-client-key", "", "Client key for client authentication (use with -tls-enabled and -tls-client-cert)")
+
 	bufferSize = flag.Int("buffer-size", 256, "The buffer size of the message channel.")
 
 	logger = log.New(os.Stderr, "", log.LstdFlags)
@@ -49,7 +55,19 @@ func main() {
 		printUsageErrorAndExit("-offset should be `oldest` or `newest`")
 	}
 
-	c, err := sarama.NewConsumer(strings.Split(*brokerList, ","), nil)
+	config := sarama.NewConfig()
+	if *tlsEnabled {
+		tlsConfig, err := tls.NewConfig(*tlsClientCert, *tlsClientKey)
+		if err != nil {
+			printErrorAndExit(69, "Failed to create TLS config: %s", err)
+		}
+
+		config.Net.TLS.Enable = true
+		config.Net.TLS.Config = tlsConfig
+		config.Net.TLS.Config.InsecureSkipVerify = *tlsSkipVerify
+	}
+
+	c, err := sarama.NewConsumer(strings.Split(*brokerList, ","), config)
 	if err != nil {
 		printErrorAndExit(69, "Failed to start consumer: %s", err)
 	}

+ 25 - 9
tools/kafka-console-producer/kafka-console-producer.go

@@ -9,19 +9,24 @@ import (
 	"strings"
 
 	"github.com/Shopify/sarama"
+	"github.com/Shopify/sarama/tools/tls"
 	"github.com/rcrowley/go-metrics"
 )
 
 var (
-	brokerList  = flag.String("brokers", os.Getenv("KAFKA_PEERS"), "The comma separated list of brokers in the Kafka cluster. You can also set the KAFKA_PEERS environment variable")
-	topic       = flag.String("topic", "", "REQUIRED: the topic to produce to")
-	key         = flag.String("key", "", "The key of the message to produce. Can be empty.")
-	value       = flag.String("value", "", "REQUIRED: the value of the message to produce. You can also provide the value on stdin.")
-	partitioner = flag.String("partitioner", "", "The partitioning scheme to use. Can be `hash`, `manual`, or `random`")
-	partition   = flag.Int("partition", -1, "The partition to produce to.")
-	verbose     = flag.Bool("verbose", false, "Turn on sarama logging to stderr")
-	showMetrics = flag.Bool("metrics", false, "Output metrics on successful publish to stderr")
-	silent      = flag.Bool("silent", false, "Turn off printing the message's topic, partition, and offset to stdout")
+	brokerList    = flag.String("brokers", os.Getenv("KAFKA_PEERS"), "The comma separated list of brokers in the Kafka cluster. You can also set the KAFKA_PEERS environment variable")
+	topic         = flag.String("topic", "", "REQUIRED: the topic to produce to")
+	key           = flag.String("key", "", "The key of the message to produce. Can be empty.")
+	value         = flag.String("value", "", "REQUIRED: the value of the message to produce. You can also provide the value on stdin.")
+	partitioner   = flag.String("partitioner", "", "The partitioning scheme to use. Can be `hash`, `manual`, or `random`")
+	partition     = flag.Int("partition", -1, "The partition to produce to.")
+	verbose       = flag.Bool("verbose", false, "Turn on sarama logging to stderr")
+	showMetrics   = flag.Bool("metrics", false, "Output metrics on successful publish to stderr")
+	silent        = flag.Bool("silent", false, "Turn off printing the message's topic, partition, and offset to stdout")
+	tlsEnabled    = flag.Bool("tls-enabled", false, "Whether to enable TLS")
+	tlsSkipVerify = flag.Bool("tls-skip-verify", false, "Whether to skip TLS server cert verification")
+	tlsClientCert = flag.String("tls-client-cert", "", "Client cert for client authentication (use with -tls-enabled and -tls-client-key)")
+	tlsClientKey  = flag.String("tls-client-key", "", "Client key for client authentication (use with -tls-enabled and -tls-client-cert)")
 
 	logger = log.New(os.Stderr, "", log.LstdFlags)
 )
@@ -45,6 +50,17 @@ func main() {
 	config.Producer.RequiredAcks = sarama.WaitForAll
 	config.Producer.Return.Successes = true
 
+	if *tlsEnabled {
+		tlsConfig, err := tls.NewConfig(*tlsClientCert, *tlsClientKey)
+		if err != nil {
+			printErrorAndExit(69, "Failed to create TLS config: %s", err)
+		}
+
+		config.Net.TLS.Enable = true
+		config.Net.TLS.Config = tlsConfig
+		config.Net.TLS.Config.InsecureSkipVerify = *tlsSkipVerify
+	}
+
 	switch *partitioner {
 	case "":
 		if *partition >= 0 {

+ 17 - 0
tools/tls/config.go

@@ -0,0 +1,17 @@
+package tls
+
+import "crypto/tls"
+
+func NewConfig(clientCert, clientKey string) (*tls.Config, error) {
+	tlsConfig := tls.Config{}
+
+	if clientCert != "" && clientKey != "" {
+		cert, err := tls.LoadX509KeyPair(clientCert, clientKey)
+		if err != nil {
+			return &tlsConfig, err
+		}
+		tlsConfig.Certificates = []tls.Certificate{cert}
+	}
+
+	return &tlsConfig, nil
+}

+ 3 - 1
utils.go

@@ -159,6 +159,7 @@ var (
 	V2_0_0_0  = newKafkaVersion(2, 0, 0, 0)
 	V2_0_1_0  = newKafkaVersion(2, 0, 1, 0)
 	V2_1_0_0  = newKafkaVersion(2, 1, 0, 0)
+	V2_2_0_0  = newKafkaVersion(2, 2, 0, 0)
 
 	SupportedVersions = []KafkaVersion{
 		V0_8_2_0,
@@ -181,9 +182,10 @@ var (
 		V2_0_0_0,
 		V2_0_1_0,
 		V2_1_0_0,
+		V2_2_0_0,
 	}
 	MinVersion = V0_8_2_0
-	MaxVersion = V2_1_0_0
+	MaxVersion = V2_2_0_0
 )
 
 func ParseKafkaVersion(s string) (KafkaVersion, error) {
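
With V2_2_0_0 registered and MaxVersion raised, a config can opt into Kafka 2.2 behaviour directly or via version-string parsing. A brief sketch, assuming the sarama import (the version string is illustrative):

    cfg := sarama.NewConfig()
    cfg.Version = sarama.V2_2_0_0

    // or, when the broker version comes from external configuration:
    if v, err := sarama.ParseKafkaVersion("2.2.0"); err == nil {
        cfg.Version = v
    }
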

+ 1 - 1
vagrant/install_cluster.sh

@@ -2,7 +2,7 @@
 
 set -ex
 
-TOXIPROXY_VERSION=2.1.3
+TOXIPROXY_VERSION=2.1.4
 
 mkdir -p ${KAFKA_INSTALL_ROOT}
 if [ ! -f ${KAFKA_INSTALL_ROOT}/kafka-${KAFKA_VERSION}.tgz ]; then