Browse Source

Merge pull request #1730 from d1egoaz/diego_interceptors

KIP-42 Add producer and consumer interceptors
Diego Alvarez 4 years ago
parent
commit
ec4f2e8f47
7 changed files with 370 additions and 0 deletions
  1. 4 0
      async_producer.go
  2. 121 0
      async_producer_test.go
  3. 16 0
      config.go
  4. 3 0
      consumer.go
  5. 110 0
      consumer_test.go
  6. 73 0
      functional_producer_test.go
  7. 43 0
      interceptors.go

+ 4 - 0
async_producer.go

@@ -348,6 +348,10 @@ func (p *asyncProducer) dispatcher() {
 			p.inFlight.Add(1)
 		}
 
+		for _, interceptor := range p.conf.Producer.Interceptors {
+			msg.safelyApplyInterceptor(interceptor)
+		}
+
 		version := 1
 		if p.conf.Version.IsAtLeast(V0_11_0_0) {
 			version = 2

+ 121 - 0
async_producer_test.go

@@ -5,6 +5,7 @@ import (
 	"log"
 	"os"
 	"os/signal"
+	"strconv"
 	"sync"
 	"sync/atomic"
 	"testing"
@@ -1230,6 +1231,126 @@ func TestBrokerProducerShutdown(t *testing.T) {
 	mockBroker.Close()
 }
 
+type appendInterceptor struct {
+	i int
+}
+
+func (b *appendInterceptor) OnSend(msg *ProducerMessage) {
+	if b.i < 0 {
+		panic("hey, the interceptor has failed")
+	}
+	v, _ := msg.Value.Encode()
+	msg.Value = StringEncoder(string(v) + strconv.Itoa(b.i))
+	b.i++
+}
+
+func (b *appendInterceptor) OnConsume(msg *ConsumerMessage) {
+	if b.i < 0 {
+		panic("hey, the interceptor has failed")
+	}
+	msg.Value = []byte(string(msg.Value) + strconv.Itoa(b.i))
+	b.i++
+}
+
+func testProducerInterceptor(
+	t *testing.T,
+	interceptors []ProducerInterceptor,
+	expectationFn func(*testing.T, int, *ProducerMessage),
+) {
+	seedBroker := NewMockBroker(t, 1)
+	leader := NewMockBroker(t, 2)
+	metadataLeader := new(MetadataResponse)
+	metadataLeader.AddBroker(leader.Addr(), leader.BrokerID())
+	metadataLeader.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, nil, ErrNoError)
+	seedBroker.Returns(metadataLeader)
+
+	config := NewConfig()
+	config.Producer.Flush.Messages = 10
+	config.Producer.Return.Successes = true
+	config.Producer.Interceptors = interceptors
+	producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for i := 0; i < 10; i++ {
+		producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)}
+	}
+
+	prodSuccess := new(ProduceResponse)
+	prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError)
+	leader.Returns(prodSuccess)
+
+	for i := 0; i < 10; i++ {
+		select {
+		case msg := <-producer.Errors():
+			t.Error(msg.Err)
+		case msg := <-producer.Successes():
+			expectationFn(t, i, msg)
+		}
+	}
+
+	closeProducer(t, producer)
+	leader.Close()
+	seedBroker.Close()
+}
+
+func TestAsyncProducerInterceptors(t *testing.T) {
+	tests := []struct {
+		name          string
+		interceptors  []ProducerInterceptor
+		expectationFn func(*testing.T, int, *ProducerMessage)
+	}{
+		{
+			name:         "intercept messages",
+			interceptors: []ProducerInterceptor{&appendInterceptor{i: 0}},
+			expectationFn: func(t *testing.T, i int, msg *ProducerMessage) {
+				v, _ := msg.Value.Encode()
+				expected := TestMessage + strconv.Itoa(i)
+				if string(v) != expected {
+					t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected)
+				}
+			},
+		},
+		{
+			name:         "interceptor chain",
+			interceptors: []ProducerInterceptor{&appendInterceptor{i: 0}, &appendInterceptor{i: 1000}},
+			expectationFn: func(t *testing.T, i int, msg *ProducerMessage) {
+				v, _ := msg.Value.Encode()
+				expected := TestMessage + strconv.Itoa(i) + strconv.Itoa(i+1000)
+				if string(v) != expected {
+					t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected)
+				}
+			},
+		},
+		{
+			name:         "interceptor chain with one interceptor failing",
+			interceptors: []ProducerInterceptor{&appendInterceptor{i: -1}, &appendInterceptor{i: 1000}},
+			expectationFn: func(t *testing.T, i int, msg *ProducerMessage) {
+				v, _ := msg.Value.Encode()
+				expected := TestMessage + strconv.Itoa(i+1000)
+				if string(v) != expected {
+					t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected)
+				}
+			},
+		},
+		{
+			name:         "interceptor chain with all interceptors failing",
+			interceptors: []ProducerInterceptor{&appendInterceptor{i: -1}, &appendInterceptor{i: -1}},
+			expectationFn: func(t *testing.T, i int, msg *ProducerMessage) {
+				v, _ := msg.Value.Encode()
+				expected := TestMessage
+				if string(v) != expected {
+					t.Errorf("Interceptor should have not changed the value, got %s, expected %s", v, expected)
+				}
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) { testProducerInterceptor(t, tt.interceptors, tt.expectationFn) })
+	}
+}
+
 // This example shows how to use the producer while simultaneously
 // reading the Errors channel to know about any failures.
 func ExampleAsyncProducer_select() {

+ 16 - 0
config.go

@@ -229,6 +229,14 @@ type Config struct {
 			// `Backoff` if set.
 			BackoffFunc func(retries, maxRetries int) time.Duration
 		}
+
+		// Interceptors to be called when the producer dispatcher reads the
+		// message for the first time. Interceptors allows to intercept and
+		// possible mutate the message before they are published to Kafka
+		// cluster. *ProducerMessage modified by the first interceptor's
+		// OnSend() is passed to the second interceptor OnSend(), and so on in
+		// the interceptor chain.
+		Interceptors []ProducerInterceptor
 	}
 
 	// Consumer is the namespace for configuration related to consuming messages,
@@ -391,6 +399,14 @@ type Config struct {
 		// 	- use `ReadUncommitted` (default) to consume and return all messages in message channel
 		//	- use `ReadCommitted` to hide messages that are part of an aborted transaction
 		IsolationLevel IsolationLevel
+
+		// Interceptors to be called just before the record is sent to the
+		// messages channel. Interceptors allows to intercept and possible
+		// mutate the message before they are returned to the client.
+		// *ConsumerMessage modified by the first interceptor's OnConsume() is
+		// passed to the second interceptor OnConsume(), and so on in the
+		// interceptor chain.
+		Interceptors []ConsumerInterceptor
 	}
 
 	// A user-provided string sent with every request to the brokers for logging,

+ 3 - 0
consumer.go

@@ -451,6 +451,9 @@ feederLoop:
 		}
 
 		for i, msg := range msgs {
+			for _, interceptor := range child.conf.Consumer.Interceptors {
+				msg.safelyApplyInterceptor(interceptor)
+			}
 		messageSelect:
 			select {
 			case <-child.dying:

+ 110 - 0
consumer_test.go

@@ -5,6 +5,7 @@ import (
 	"os"
 	"os/signal"
 	"reflect"
+	"strconv"
 	"sync"
 	"sync/atomic"
 	"testing"
@@ -1342,3 +1343,112 @@ func Test_partitionConsumer_parseResponse(t *testing.T) {
 		})
 	}
 }
+
+func testConsumerInterceptor(
+	t *testing.T,
+	interceptors []ConsumerInterceptor,
+	expectationFn func(*testing.T, int, *ConsumerMessage),
+) {
+	// Given
+	broker0 := NewMockBroker(t, 0)
+
+	mockFetchResponse := NewMockFetchResponse(t, 1)
+	for i := 0; i < 10; i++ {
+		mockFetchResponse.SetMessage("my_topic", 0, int64(i), testMsg)
+	}
+
+	broker0.SetHandlerByMap(map[string]MockResponse{
+		"MetadataRequest": NewMockMetadataResponse(t).
+			SetBroker(broker0.Addr(), broker0.BrokerID()).
+			SetLeader("my_topic", 0, broker0.BrokerID()),
+		"OffsetRequest": NewMockOffsetResponse(t).
+			SetOffset("my_topic", 0, OffsetOldest, 0).
+			SetOffset("my_topic", 0, OffsetNewest, 0),
+		"FetchRequest": mockFetchResponse,
+	})
+	config := NewConfig()
+	config.Consumer.Interceptors = interceptors
+	// When
+	master, err := NewConsumer([]string{broker0.Addr()}, config)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	consumer, err := master.ConsumePartition("my_topic", 0, 0)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for i := 0; i < 10; i++ {
+		select {
+		case msg := <-consumer.Messages():
+			expectationFn(t, i, msg)
+		case err := <-consumer.Errors():
+			t.Error(err)
+		}
+	}
+
+	safeClose(t, consumer)
+	safeClose(t, master)
+	broker0.Close()
+}
+
+func TestConsumerInterceptors(t *testing.T) {
+	tests := []struct {
+		name          string
+		interceptors  []ConsumerInterceptor
+		expectationFn func(*testing.T, int, *ConsumerMessage)
+	}{
+		{
+			name:         "intercept messages",
+			interceptors: []ConsumerInterceptor{&appendInterceptor{i: 0}},
+			expectationFn: func(t *testing.T, i int, msg *ConsumerMessage) {
+				ev, _ := testMsg.Encode()
+				expected := string(ev) + strconv.Itoa(i)
+				v := string(msg.Value)
+				if v != expected {
+					t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected)
+				}
+			},
+		},
+		{
+			name:         "interceptor chain",
+			interceptors: []ConsumerInterceptor{&appendInterceptor{i: 0}, &appendInterceptor{i: 1000}},
+			expectationFn: func(t *testing.T, i int, msg *ConsumerMessage) {
+				ev, _ := testMsg.Encode()
+				expected := string(ev) + strconv.Itoa(i) + strconv.Itoa(i+1000)
+				v := string(msg.Value)
+				if v != expected {
+					t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected)
+				}
+			},
+		},
+		{
+			name:         "interceptor chain with one interceptor failing",
+			interceptors: []ConsumerInterceptor{&appendInterceptor{i: -1}, &appendInterceptor{i: 1000}},
+			expectationFn: func(t *testing.T, i int, msg *ConsumerMessage) {
+				ev, _ := testMsg.Encode()
+				expected := string(ev) + strconv.Itoa(i+1000)
+				v := string(msg.Value)
+				if v != expected {
+					t.Errorf("Interceptor should have not changed the value, got %s, expected %s", v, expected)
+				}
+			},
+		},
+		{
+			name:         "interceptor chain with all interceptors failing",
+			interceptors: []ConsumerInterceptor{&appendInterceptor{i: -1}, &appendInterceptor{i: -1}},
+			expectationFn: func(t *testing.T, i int, msg *ConsumerMessage) {
+				ev, _ := testMsg.Encode()
+				expected := string(ev)
+				v := string(msg.Value)
+				if v != expected {
+					t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected)
+				}
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) { testConsumerInterceptor(t, tt.interceptors, tt.expectationFn) })
+	}
+}

+ 73 - 0
functional_producer_test.go

@@ -5,6 +5,7 @@ package sarama
 import (
 	"fmt"
 	"os"
+	"strconv"
 	"strings"
 	"sync"
 	"testing"
@@ -183,6 +184,78 @@ func TestFuncProducingIdempotentWithBrokerFailure(t *testing.T) {
 	}
 }
 
+func TestInterceptors(t *testing.T) {
+	config := NewConfig()
+	setupFunctionalTest(t)
+	defer teardownFunctionalTest(t)
+
+	config.Producer.Return.Successes = true
+	config.Consumer.Return.Errors = true
+	config.Producer.Interceptors = []ProducerInterceptor{&appendInterceptor{i: 0}, &appendInterceptor{i: 100}}
+	config.Consumer.Interceptors = []ConsumerInterceptor{&appendInterceptor{i: 20}}
+
+	client, err := NewClient(FunctionalTestEnv.KafkaBrokerAddrs, config)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	initialOffset, err := client.GetOffset("test.1", 0, OffsetNewest)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	producer, err := NewAsyncProducerFromClient(client)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for i := 0; i < 10; i++ {
+		producer.Input() <- &ProducerMessage{Topic: "test.1", Key: nil, Value: StringEncoder(TestMessage)}
+	}
+
+	for i := 0; i < 10; i++ {
+		select {
+		case msg := <-producer.Errors():
+			t.Error(msg.Err)
+		case msg := <-producer.Successes():
+			v, _ := msg.Value.Encode()
+			expected := TestMessage + strconv.Itoa(i) + strconv.Itoa(i+100)
+			if string(v) != expected {
+				t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected)
+			}
+		}
+	}
+	safeClose(t, producer)
+
+	master, err := NewConsumerFromClient(client)
+	if err != nil {
+		t.Fatal(err)
+	}
+	consumer, err := master.ConsumePartition("test.1", 0, initialOffset)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for i := 0; i < 10; i++ {
+		select {
+		case <-time.After(10 * time.Second):
+			t.Fatal("Not received any more events in the last 10 seconds.")
+		case err := <-consumer.Errors():
+			t.Error(err)
+		case msg := <-consumer.Messages():
+			prodInteExpectation := strconv.Itoa(i) + strconv.Itoa(i+100)
+			consInteExpectation := strconv.Itoa(i + 20)
+			expected := TestMessage + prodInteExpectation + consInteExpectation
+			v := string(msg.Value)
+			if v != expected {
+				t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected)
+			}
+		}
+	}
+	safeClose(t, consumer)
+	safeClose(t, client)
+}
+
 func testProducingMessages(t *testing.T, config *Config) {
 	setupFunctionalTest(t)
 	defer teardownFunctionalTest(t)

+ 43 - 0
interceptors.go

@@ -0,0 +1,43 @@
+package sarama
+
+// ProducerInterceptor allows you to intercept (and possibly mutate) the records
+// received by the producer before they are published to the Kafka cluster.
+// https://cwiki.apache.org/confluence/display/KAFKA/KIP-42%3A+Add+Producer+and+Consumer+Interceptors#KIP42:AddProducerandConsumerInterceptors-Motivation
+type ProducerInterceptor interface {
+
+	// OnSend is called when the producer message is intercepted. Please avoid
+	// modifying the message until it's safe to do so, as this is _not_ a copy
+	// of the message.
+	OnSend(*ProducerMessage)
+}
+
+// ConsumerInterceptor allows you to intercept (and possibly mutate) the records
+// received by the consumer before they are sent to the messages channel.
+// https://cwiki.apache.org/confluence/display/KAFKA/KIP-42%3A+Add+Producer+and+Consumer+Interceptors#KIP42:AddProducerandConsumerInterceptors-Motivation
+type ConsumerInterceptor interface {
+
+	// OnConsume is called when the consumed message is intercepted. Please
+	// avoid modifying the message until it's safe to do so, as this is _not_ a
+	// copy of the message.
+	OnConsume(*ConsumerMessage)
+}
+
+func (msg *ProducerMessage) safelyApplyInterceptor(interceptor ProducerInterceptor) {
+	defer func() {
+		if r := recover(); r != nil {
+			Logger.Printf("Error when calling producer interceptor: %s, %w\n", interceptor, r)
+		}
+	}()
+
+	interceptor.OnSend(msg)
+}
+
+func (msg *ConsumerMessage) safelyApplyInterceptor(interceptor ConsumerInterceptor) {
+	defer func() {
+		if r := recover(); r != nil {
+			Logger.Printf("Error when calling consumer interceptor: %s, %w\n", interceptor, r)
+		}
+	}()
+
+	interceptor.OnConsume(msg)
+}