Browse Source

Make sure all connections are observed

Currently, control connections used for example to discover protocol
version are not observable using connection observer. This change
aims to make sure all connections created by the driver are seen
by the connection observer.

The `dialWithoutObserver` name is intentionally longer than `dial`
so that one must explicitly acknowledge the connection will not be
observed when using the function.
Martin Sucha 6 years ago
parent
commit
0b5041cb38
2 changed files with 28 additions and 17 deletions
  1. 28 2
      conn.go
  2. 0 15
      session.go

+ 28 - 2
conn.go

@@ -159,8 +159,34 @@ type Conn struct {
 	timeouts int64
 }
 
-// Connect establishes a connection to a Cassandra node.
-func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
+// connect establishes a connection to a Cassandra node using session's connection config.
+func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
+	return s.dial(host, s.connCfg, errorHandler)
+}
+
+// dial establishes a connection to a Cassandra node and notifies the session's connectObserver.
+func (s *Session) dial(host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
+	var obs ObservedConnect
+	if s.connectObserver != nil {
+		obs.Host = host
+		obs.Start = time.Now()
+	}
+
+	conn, err := s.dialWithoutObserver(host, connConfig, errorHandler)
+
+	if s.connectObserver != nil {
+		obs.End = time.Now()
+		obs.Err = err
+		s.connectObserver.ObserveConnect(obs)
+	}
+
+	return conn, err
+}
+
+// dialWithoutObserver establishes connection to a Cassandra node.
+//
+// dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead.
+func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
 	ip := host.ConnectAddress()
 	port := host.port
 

+ 0 - 15
session.go

@@ -656,21 +656,6 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
 	return applied, iter, iter.err
 }
 
-func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
-	if s.connectObserver != nil {
-		obs := ObservedConnect{
-			Host:  host,
-			Start: time.Now(),
-		}
-		conn, err := s.dial(host, s.connCfg, errorHandler)
-		obs.End = time.Now()
-		obs.Err = err
-		s.connectObserver.ObserveConnect(obs)
-		return conn, err
-	}
-	return s.dial(host, s.connCfg, errorHandler)
-}
-
 type hostMetrics struct {
 	Attempts     int
 	TotalLatency int64