Browse Source

Merge pull request #1529 from dnwe/fix-brokerconn-deadlines

fix: ensure consistent use of read/write deadlines
Vlad Gorodetsky 5 years ago
parent
commit
07c15a999b
2 changed files with 86 additions and 57 deletions
  1. 37 57
      broker.go
  2. 49 0
      broker_test.go

+ 37 - 57
broker.go

@@ -671,6 +671,26 @@ func (b *Broker) DescribeLogDirs(request *DescribeLogDirsRequest) (*DescribeLogD
 	return response, nil
 }
 
+// readFull ensures the conn ReadDeadline has been setup before making a
+// call to io.ReadFull
+func (b *Broker) readFull(buf []byte) (n int, err error) {
+	if err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout)); err != nil {
+		return 0, err
+	}
+
+	return io.ReadFull(b.conn, buf)
+}
+
+// write  ensures the conn WriteDeadline has been setup before making a
+// call to conn.Write
+func (b *Broker) write(buf []byte) (n int, err error) {
+	if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
+		return 0, err
+	}
+
+	return b.conn.Write(buf)
+}
+
 func (b *Broker) send(rb protocolBody, promiseResponse bool) (*responsePromise, error) {
 	b.lock.Lock()
 	defer b.lock.Unlock()
@@ -692,14 +712,9 @@ func (b *Broker) send(rb protocolBody, promiseResponse bool) (*responsePromise,
 		return nil, err
 	}
 
-	err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
-	if err != nil {
-		return nil, err
-	}
-
 	requestTime := time.Now()
-	bytes, err := b.conn.Write(buf)
-	b.updateOutgoingCommunicationMetrics(bytes) //TODO: should it be after error check
+	bytes, err := b.write(buf)
+	b.updateOutgoingCommunicationMetrics(bytes)
 	if err != nil {
 		return nil, err
 	}
@@ -806,14 +821,7 @@ func (b *Broker) responseReceiver() {
 			continue
 		}
 
-		err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout))
-		if err != nil {
-			dead = err
-			response.errors <- err
-			continue
-		}
-
-		bytesReadHeader, err := io.ReadFull(b.conn, header)
+		bytesReadHeader, err := b.readFull(header)
 		requestLatency := time.Since(response.requestTime)
 		if err != nil {
 			b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
@@ -840,7 +848,7 @@ func (b *Broker) responseReceiver() {
 		}
 
 		buf := make([]byte, decodedHeader.length-4)
-		bytesReadBody, err := io.ReadFull(b.conn, buf)
+		bytesReadBody, err := b.readFull(buf)
 		b.updateIncomingCommunicationMetrics(bytesReadHeader+bytesReadBody, requestLatency)
 		if err != nil {
 			dead = err
@@ -883,22 +891,17 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
 		return err
 	}
 
-	err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
-	if err != nil {
-		return err
-	}
-
 	requestTime := time.Now()
-	bytes, err := b.conn.Write(buf)
+	bytes, err := b.write(buf)
 	b.updateOutgoingCommunicationMetrics(bytes)
 	if err != nil {
 		Logger.Printf("Failed to send SASL handshake %s: %s\n", b.addr, err.Error())
 		return err
 	}
 	b.correlationID++
-	//wait for the response
+
 	header := make([]byte, 8) // response header
-	_, err = io.ReadFull(b.conn, header)
+	_, err = b.readFull(header)
 	if err != nil {
 		Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error())
 		return err
@@ -906,7 +909,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
 
 	length := binary.BigEndian.Uint32(header[:4])
 	payload := make([]byte, length-4)
-	n, err := io.ReadFull(b.conn, payload)
+	n, err := b.readFull(payload)
 	if err != nil {
 		Logger.Printf("Failed to read SASL handshake payload : %s\n", err.Error())
 		return err
@@ -980,14 +983,8 @@ func (b *Broker) sendAndReceiveV0SASLPlainAuth() error {
 	binary.BigEndian.PutUint32(authBytes, uint32(length))
 	copy(authBytes[4:], []byte("\x00"+b.conf.Net.SASL.User+"\x00"+b.conf.Net.SASL.Password))
 
-	err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
-	if err != nil {
-		Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
-		return err
-	}
-
 	requestTime := time.Now()
-	bytesWritten, err := b.conn.Write(authBytes)
+	bytesWritten, err := b.write(authBytes)
 	b.updateOutgoingCommunicationMetrics(bytesWritten)
 	if err != nil {
 		Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
@@ -995,7 +992,7 @@ func (b *Broker) sendAndReceiveV0SASLPlainAuth() error {
 	}
 
 	header := make([]byte, 4)
-	n, err := io.ReadFull(b.conn, header)
+	n, err := b.readFull(header)
 	b.updateIncomingCommunicationMetrics(n, time.Since(requestTime))
 	// If the credentials are valid, we would get a 4 byte response filled with null characters.
 	// Otherwise, the broker closes the connection and we get an EOF
@@ -1151,16 +1148,12 @@ func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (i
 		return 0, err
 	}
 
-	if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
-		return 0, err
-	}
-
-	return b.conn.Write(buf)
+	return b.write(buf)
 }
 
 func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
 	buf := make([]byte, responseLengthSize+correlationIDSize)
-	_, err := io.ReadFull(b.conn, buf)
+	_, err := b.readFull(buf)
 	if err != nil {
 		return nil, err
 	}
@@ -1176,7 +1169,7 @@ func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, e
 	}
 
 	buf = make([]byte, header.length-correlationIDSize)
-	_, err = io.ReadFull(b.conn, buf)
+	_, err = b.readFull(buf)
 	if err != nil {
 		return nil, err
 	}
@@ -1231,12 +1224,7 @@ func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, erro
 		return 0, err
 	}
 
-	err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
-	if err != nil {
-		Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
-		return 0, err
-	}
-	return b.conn.Write(buf)
+	return b.write(buf)
 }
 
 func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) {
@@ -1250,24 +1238,17 @@ func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlatio
 		return 0, err
 	}
 
-	if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
-		return 0, err
-	}
-
-	return b.conn.Write(buf)
+	return b.write(buf)
 }
 
 func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) {
-
 	buf := make([]byte, responseLengthSize+correlationIDSize)
-
-	bytesRead, err := io.ReadFull(b.conn, buf)
+	bytesRead, err := b.readFull(buf)
 	if err != nil {
 		return bytesRead, err
 	}
 
 	header := responseHeader{}
-
 	err = decode(buf, &header)
 	if err != nil {
 		return bytesRead, err
@@ -1278,8 +1259,7 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl
 	}
 
 	buf = make([]byte, header.length-correlationIDSize)
-
-	c, err := io.ReadFull(b.conn, buf)
+	c, err := b.readFull(buf)
 	bytesRead += c
 	if err != nil {
 		return bytesRead, err

+ 49 - 0
broker_test.go

@@ -493,6 +493,55 @@ func TestSASLPlainAuth(t *testing.T) {
 	}
 }
 
+// TestSASLReadTimeout ensures that the broker connection won't block forever
+// if the remote end never responds after the handshake
+func TestSASLReadTimeout(t *testing.T) {
+	mockBroker := NewMockBroker(t, 0)
+	defer mockBroker.Close()
+
+	mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
+		SetAuthBytes([]byte(`response_payload`))
+
+	mockBroker.SetHandlerByMap(map[string]MockResponse{
+		"SaslAuthenticateRequest": mockSASLAuthResponse,
+	})
+
+	broker := NewBroker(mockBroker.Addr())
+	{
+		broker.requestRate = metrics.NilMeter{}
+		broker.outgoingByteRate = metrics.NilMeter{}
+		broker.incomingByteRate = metrics.NilMeter{}
+		broker.requestSize = metrics.NilHistogram{}
+		broker.responseSize = metrics.NilHistogram{}
+		broker.responseRate = metrics.NilMeter{}
+		broker.requestLatency = metrics.NilHistogram{}
+	}
+
+	conf := NewConfig()
+	{
+		conf.Net.ReadTimeout = time.Millisecond
+		conf.Net.SASL.Mechanism = SASLTypePlaintext
+		conf.Net.SASL.User = "token"
+		conf.Net.SASL.Password = "password"
+		conf.Net.SASL.Version = SASLHandshakeV1
+	}
+
+	broker.conf = conf
+	broker.conf.Version = V1_0_0_0
+	dialer := net.Dialer{}
+
+	conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	broker.conn = conn
+	err = broker.authenticateViaSASL()
+	if err == nil {
+		t.Errorf("should never happen - expected read timeout")
+	}
+}
+
 func TestGSSAPIKerberosAuth_Authorize(t *testing.T) {
 
 	testTable := []struct {