Przeglądaj źródła

Broker cleanup, make it really concurrency-safe

Evan Huus 12 lat temu
rodzic
commit
fd2d940dbf
2 zmienionych plików z 46 dodań i 53 usunięć
  1. 44 52
      broker.go
  2. 2 1
      metadata_cache.go

+ 44 - 52
broker.go

@@ -2,8 +2,8 @@ package kafka
 
 import (
 	"io"
-	"math"
 	"net"
+	"sync"
 )
 
 type Broker struct {
@@ -12,11 +12,11 @@ type Broker struct {
 	port int32
 
 	correlation_id int32
+	conn           net.Conn
+	lock           sync.Mutex
 
-	conn net.Conn
-
-	requests  chan requestToSend
 	responses chan responsePromise
+	done      chan bool
 }
 
 type responsePromise struct {
@@ -25,25 +25,18 @@ type responsePromise struct {
 	errors         chan error
 }
 
-type requestToSend struct {
-	// we cheat and use the responsePromise channels to avoid creating more than necessary
-	response       responsePromise
-	expectResponse bool
-}
-
-func NewBroker(host string, port int32) (b *Broker, err error) {
-	b = new(Broker)
+func NewBroker(host string, port int32) *Broker {
+	b := new(Broker)
 	b.id = -1 // don't know it yet
 	b.host = &host
 	b.port = port
-	err = b.Connect()
-	if err != nil {
-		return nil, err
-	}
-	return b, nil
+	return b
 }
 
 func (b *Broker) Connect() (err error) {
+	b.lock.Lock()
+	defer b.lock.Unlock()
+
 	addr, err := net.ResolveIPAddr("ip", *b.host)
 	if err != nil {
 		return err
@@ -54,20 +47,29 @@ func (b *Broker) Connect() (err error) {
 		return err
 	}
 
-	b.requests = make(chan requestToSend)
-	b.responses = make(chan responsePromise)
+	b.done = make(chan bool)
+
+	// permit a few outstanding requests before we block waiting for responses
+	b.responses = make(chan responsePromise, 4)
 
-	go b.sendRequestLoop()
-	go b.rcvResponseLoop()
+	go b.responseReceiver()
 
 	return nil
 }
 
 func (b *Broker) Close() error {
-	close(b.requests)
+	b.lock.Lock()
+	defer b.lock.Unlock()
+
 	close(b.responses)
+	<-b.done
+
+	err := b.conn.Close()
 
-	return b.conn.Close()
+	b.conn = nil
+	b.responses = nil
+
+	return err
 }
 
 func (b *Broker) RequestMetadata(clientID *string, request *MetadataRequest) (*MetadataResponse, error) {
@@ -98,29 +100,35 @@ func (b *Broker) Produce(clientID *string, request *ProduceRequest) (*ProduceRes
 }
 
 func (b *Broker) sendAndReceive(clientID *string, req requestEncoder, res decoder) error {
+	b.lock.Lock()
+	defer b.lock.Unlock()
+
 	fullRequest := request{b.correlation_id, clientID, req}
-	packet, err := encode(&fullRequest)
+	buf, err := encode(&fullRequest)
+	if err != nil {
+		return err
+	}
+
+	_, err = b.conn.Write(buf)
 	if err != nil {
 		return err
 	}
 
-	sendRequest := requestToSend{responsePromise{b.correlation_id, make(chan []byte), make(chan error)}, res != nil}
+	if res == nil {
+		return nil
+	}
 
-	b.requests <- sendRequest
-	sendRequest.response.packets <- packet // we cheat to avoid poofing up more channels than necessary
+	promise := responsePromise{b.correlation_id, make(chan []byte), make(chan error)}
+	b.responses <- promise
 	b.correlation_id++
 
 	select {
-	case buf := <-sendRequest.response.packets:
+	case buf := <-promise.packets:
 		err = decode(buf, res)
-	case err = <-sendRequest.response.errors:
-	}
-
-	if err != nil {
-		return err
+	case err = <-promise.errors:
 	}
 
-	return nil
+	return err
 }
 
 func (b *Broker) decode(pd packetDecoder) (err error) {
@@ -138,28 +146,11 @@ func (b *Broker) decode(pd packetDecoder) (err error) {
 	if err != nil {
 		return err
 	}
-	if b.port > math.MaxUint16 {
-		return DecodingError("Broker port > 65536")
-	}
 
 	return nil
 }
 
-func (b *Broker) sendRequestLoop() {
-	for request := range b.requests {
-		buf := <-request.response.packets
-		_, err := b.conn.Write(buf)
-		if err != nil {
-			request.response.errors <- err
-			continue
-		}
-		if request.expectResponse {
-			b.responses <- request.response
-		}
-	}
-}
-
-func (b *Broker) rcvResponseLoop() {
+func (b *Broker) responseReceiver() {
 	header := make([]byte, 8)
 	for response := range b.responses {
 		_, err := io.ReadFull(b.conn, header)
@@ -188,4 +179,5 @@ func (b *Broker) rcvResponseLoop() {
 
 		response.packets <- buf
 	}
+	close(b.done)
 }

+ 2 - 1
metadata_cache.go

@@ -15,7 +15,8 @@ type metadataCache struct {
 func newMetadataCache(client *Client, host string, port int32) (*metadataCache, error) {
 	mc := new(metadataCache)
 
-	starter, err := NewBroker(host, port)
+	starter := NewBroker(host, port)
+	err := starter.Connect()
 	if err != nil {
 		return nil, err
 	}