Bläddra i källkod

Use context for session cancellation

Dimitrij Denissenko 7 år sedan
förälder
incheckning
7170045d9b
3 ändrade filer med 48 tillägg och 37 borttagningar
  1. 42 35
      consumer_group.go
  2. 3 1
      consumer_group_test.go
  3. 3 1
      functional_consumer_group_test.go

+ 42 - 35
consumer_group.go

@@ -1,6 +1,7 @@
 package sarama
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"sort"
@@ -17,13 +18,15 @@ type ConsumerGroup interface {
 	// Consume joins a cluster of consumers for a given list of topics and
 	// starts a blocking consumer session through the ConsumerGroupHandler.
 	//
-	// The session will be cancelled after the first handler exits and/or a server-side
-	// rebalance cycle is initated.
+	// The session will exit and all its handlers stopped either when:
+	// 1. the context is cancelled by the user
+	// 2. first handler exits
+	// 3. a rebalance cycle is initated server-side
 	//
 	// Please note that the handler will be applied to each of the claimed partitions
 	// in separate goroutines and must therefore be thread-safe. You can only run a single
 	// session at a time and must close the previous session before initiating a new one.
-	Consume(topics []string, handler ConsumerGroupHandler) error
+	Consume(ctx context.Context, topics []string, handler ConsumerGroupHandler) error
 
 	// Errors returns a read channel of errors that occurred during the consumer life-cycle.
 	// By default, errors are logged and not returned over this channel.
@@ -136,7 +139,7 @@ func (c *consumerGroup) Close() (err error) {
 }
 
 // Consume implements ConsumerGroup.
-func (c *consumerGroup) Consume(topics []string, handler ConsumerGroupHandler) error {
+func (c *consumerGroup) Consume(ctx context.Context, topics []string, handler ConsumerGroupHandler) error {
 	// Ensure group is not closed
 	select {
 	case <-c.closed:
@@ -149,7 +152,7 @@ func (c *consumerGroup) Consume(topics []string, handler ConsumerGroupHandler) e
 	}
 
 	// Start session
-	sess, err := c.startSession(topics, handler)
+	sess, err := c.startSession(ctx, topics, handler)
 	if err == ErrClosedClient {
 		return ErrClosedConsumerGroup
 	} else if err != nil {
@@ -157,7 +160,7 @@ func (c *consumerGroup) Consume(topics []string, handler ConsumerGroupHandler) e
 	}
 
 	// Wait for session exit signal
-	<-sess.done
+	<-sess.ctx.Done()
 
 	// Gracefully release session claims
 	err = sess.release(true)
@@ -169,7 +172,7 @@ func (c *consumerGroup) Consume(topics []string, handler ConsumerGroupHandler) e
 	return err
 }
 
-func (c *consumerGroup) startSession(topics []string, handler ConsumerGroupHandler) (*consumerGroupSession, error) {
+func (c *consumerGroup) startSession(ctx context.Context, topics []string, handler ConsumerGroupHandler) (*consumerGroupSession, error) {
 	c.lock.Lock()
 	defer c.lock.Unlock()
 
@@ -189,7 +192,7 @@ func (c *consumerGroup) startSession(topics []string, handler ConsumerGroupHandl
 		return nil, err
 	}
 
-	sess, err := c.newSession(coordinator, topics, handler, c.config.Consumer.Group.Rebalance.Retry.Max)
+	sess, err := c.newSession(ctx, coordinator, topics, handler, c.config.Consumer.Group.Rebalance.Retry.Max)
 	if err != nil {
 		return nil, err
 	}
@@ -199,7 +202,7 @@ func (c *consumerGroup) startSession(topics []string, handler ConsumerGroupHandl
 	return sess, nil
 }
 
-func (c *consumerGroup) newSession(coordinator *Broker, topics []string, handler ConsumerGroupHandler, retries int) (*consumerGroupSession, error) {
+func (c *consumerGroup) newSession(ctx context.Context, coordinator *Broker, topics []string, handler ConsumerGroupHandler, retries int) (*consumerGroupSession, error) {
 	select {
 	case <-c.closed:
 		return nil, ErrClosedConsumerGroup
@@ -217,7 +220,7 @@ func (c *consumerGroup) newSession(coordinator *Broker, topics []string, handler
 		c.memberID = join.MemberId
 	case ErrUnknownMemberId: // reset member ID and retry
 		c.memberID = ""
-		return c.newSession(coordinator, topics, handler, retries)
+		return c.newSession(ctx, coordinator, topics, handler, retries)
 	default:
 		return nil, join.Err
 	}
@@ -255,7 +258,7 @@ func (c *consumerGroup) newSession(coordinator *Broker, topics []string, handler
 		case <-time.After(c.config.Consumer.Group.Rebalance.Retry.Backoff):
 		}
 
-		return c.newSession(coordinator, topics, handler, retries-1)
+		return c.newSession(ctx, coordinator, topics, handler, retries-1)
 	default:
 		return nil, sync.Err
 	}
@@ -274,7 +277,7 @@ func (c *consumerGroup) newSession(coordinator *Broker, topics []string, handler
 		}
 	}
 
-	return newConsumerGroupSession(c, claims, join.MemberId, join.GenerationId, handler)
+	return newConsumerGroupSession(c, ctx, claims, join.MemberId, join.GenerationId, handler)
 }
 
 func (c *consumerGroup) joinGroupRequest(coordinator *Broker, topics []string) (*JoinGroupResponse, error) {
@@ -451,9 +454,8 @@ type ConsumerGroupSession interface {
 	// MarkMessage marks a message as consumed.
 	MarkMessage(msg *ConsumerMessage, metadata string)
 
-	// Cancel triggers the end of the session and notifies all consumers
-	// about an upcoming rebalance cycle.
-	Cancel()
+	// Context returns the session context.
+	Context() context.Context
 }
 
 type consumerGroupSession struct {
@@ -462,21 +464,28 @@ type consumerGroupSession struct {
 	generationID int32
 	handler      ConsumerGroupHandler
 
-	claims    map[string][]int32
-	offsets   *offsetManager
-	done      chan none
-	waitGroup sync.WaitGroup
+	claims  map[string][]int32
+	offsets *offsetManager
+	ctx     context.Context
+	cancel  func()
 
-	cancelOnce, releaseOnce sync.Once
+	waitGroup   sync.WaitGroup
+	releaseOnce sync.Once
 }
 
-func newConsumerGroupSession(parent *consumerGroup, claims map[string][]int32, memberID string, generationID int32, handler ConsumerGroupHandler) (*consumerGroupSession, error) {
+func newConsumerGroupSession(parent *consumerGroup, ctx context.Context, claims map[string][]int32, memberID string, generationID int32, handler ConsumerGroupHandler) (*consumerGroupSession, error) {
 	// init offset manager
 	offsets, err := newOffsetManagerFromClient(parent.groupID, memberID, generationID, parent.client)
 	if err != nil {
 		return nil, err
 	}
 
+	// init context
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	ctx, cancel := context.WithCancel(ctx)
+
 	// init session
 	sess := &consumerGroupSession{
 		parent:       parent,
@@ -485,7 +494,8 @@ func newConsumerGroupSession(parent *consumerGroup, claims map[string][]int32, m
 		handler:      handler,
 		offsets:      offsets,
 		claims:       claims,
-		done:         make(chan none),
+		ctx:          ctx,
+		cancel:       cancel,
 	}
 
 	// start heartbeat loop
@@ -526,7 +536,7 @@ func newConsumerGroupSession(parent *consumerGroup, claims map[string][]int32, m
 
 				// cancel the as session as soon as the first
 				// goroutine exits
-				defer sess.Cancel()
+				defer sess.cancel()
 
 				// consume a single topic/partition, blocking
 				sess.consume(topic, partition)
@@ -556,10 +566,14 @@ func (s *consumerGroupSession) MarkMessage(msg *ConsumerMessage, metadata string
 	s.MarkOffset(msg.Topic, msg.Partition, msg.Offset+1, metadata)
 }
 
+func (s *consumerGroupSession) Context() context.Context {
+	return s.ctx
+}
+
 func (s *consumerGroupSession) consume(topic string, partition int32) {
 	// quick exit if rebalance is due
 	select {
-	case <-s.done:
+	case <-s.ctx.Done():
 		return
 	default:
 	}
@@ -586,7 +600,7 @@ func (s *consumerGroupSession) consume(topic string, partition int32) {
 
 	// trigger close when session is done
 	go func() {
-		<-s.done
+		<-s.ctx.Done()
 		claim.AsyncClose()
 	}()
 
@@ -603,15 +617,9 @@ func (s *consumerGroupSession) consume(topic string, partition int32) {
 	return
 }
 
-func (s *consumerGroupSession) Cancel() {
-	s.cancelOnce.Do(func() {
-		close(s.done)
-	})
-}
-
 func (s *consumerGroupSession) release(withCleanup bool) (err error) {
 	// signal release, stop heartbeat
-	s.Cancel()
+	s.cancel()
 
 	// wait for consumers to exit
 	s.waitGroup.Wait()
@@ -635,8 +643,7 @@ func (s *consumerGroupSession) release(withCleanup bool) (err error) {
 
 func (s *consumerGroupSession) heartbeatLoop() {
 	defer s.waitGroup.Done()
-	// trigger the end of the session on exit
-	defer s.Cancel()
+	defer s.cancel() // trigger the end of the session on exit
 
 	heartbeat := time.NewTicker(s.parent.config.Consumer.Group.Heartbeat.Interval)
 	defer heartbeat.Stop()
@@ -652,7 +659,7 @@ func (s *consumerGroupSession) heartbeatLoop() {
 				}
 				return
 			}
-		case <-s.done:
+		case <-s.ctx.Done():
 			return
 		}
 	}

+ 3 - 1
consumer_group_test.go

@@ -1,6 +1,7 @@
 package sarama
 
 import (
+	"context"
 	"fmt"
 )
 
@@ -40,8 +41,9 @@ func ExampleConsumerGroup() {
 	}()
 
 	// Iterate over consumer sessions.
+	ctx := context.Background()
 	for {
-		err := group.Consume([]string{"my-topic"}, exampleConsumerGroupHandler(func(sess ConsumerGroupSession, claim ConsumerGroupClaim) error {
+		err := group.Consume(ctx, []string{"my-topic"}, exampleConsumerGroupHandler(func(sess ConsumerGroupSession, claim ConsumerGroupClaim) error {
 			for msg := range claim.Messages() {
 				fmt.Printf("Message topic:%q partition:%d offset:%d\n", msg.Topic, msg.Partition, msg.Offset)
 				sess.MarkMessage(msg, "")

+ 3 - 1
functional_consumer_group_test.go

@@ -3,6 +3,7 @@
 package sarama
 
 import (
+	"context"
 	"fmt"
 	"log"
 	"reflect"
@@ -395,11 +396,12 @@ func (m *testFuncConsumerGroupMember) loop(topics []string) {
 		}
 	}()
 
+	ctx := context.Background()
 	for {
 		// set state to pre-consume
 		atomic.StoreInt32(&m.state, 1)
 
-		if err := m.Consume(topics, m); err == ErrClosedConsumerGroup {
+		if err := m.Consume(ctx, topics, m); err == ErrClosedConsumerGroup {
 			return
 		} else if err != nil {
 			m.mu.Lock()