Преглед изворни кода

conn: store host info on the connection (#1134)

Chris Bannister пре 7 година
родитељ
комит
353a2a91ed
3 измењених фајлова са 8 додато и 3 уклоњено
  1. 6 1
      conn.go
  2. 1 1
      control.go
  3. 1 1
      session.go

+ 6 - 1
conn.go

@@ -140,6 +140,7 @@ type Conn struct {
 	addr            string
 	addr            string
 	version         uint8
 	version         uint8
 	currentKeyspace string
 	currentKeyspace string
+	host            *HostInfo
 
 
 	session *Session
 	session *Session
 
 
@@ -150,7 +151,10 @@ type Conn struct {
 }
 }
 
 
 // Connect establishes a connection to a Cassandra node.
 // Connect establishes a connection to a Cassandra node.
-func (s *Session) dial(ip net.IP, port int, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
+func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
+	ip := host.ConnectAddress()
+	port := host.port
+
 	// TODO(zariel): remove these
 	// TODO(zariel): remove these
 	if len(ip) == 0 || ip.IsUnspecified() {
 	if len(ip) == 0 || ip.IsUnspecified() {
 		panic(fmt.Sprintf("host missing connect ip address: %v", ip))
 		panic(fmt.Sprintf("host missing connect ip address: %v", ip))
@@ -196,6 +200,7 @@ func (s *Session) dial(ip net.IP, port int, cfg *ConnConfig, errorHandler ConnEr
 		quit:         make(chan struct{}),
 		quit:         make(chan struct{}),
 		session:      s,
 		session:      s,
 		streams:      streams.New(cfg.ProtoVersion),
 		streams:      streams.New(cfg.ProtoVersion),
+		host:         host,
 	}
 	}
 
 
 	if cfg.Keepalive > 0 {
 	if cfg.Keepalive > 0 {

+ 1 - 1
control.go

@@ -218,7 +218,7 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
 	var err error
 	var err error
 	for _, host := range hosts {
 	for _, host := range hosts {
 		var conn *Conn
 		var conn *Conn
-		conn, err = c.session.dial(host.ConnectAddress(), host.Port(), &connCfg, handler)
+		conn, err = c.session.dial(host, &connCfg, handler)
 		if conn != nil {
 		if conn != nil {
 			conn.Close()
 			conn.Close()
 		}
 		}

+ 1 - 1
session.go

@@ -639,7 +639,7 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
 }
 }
 
 
 func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
 func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
-	return s.dial(host.ConnectAddress(), host.Port(), s.connCfg, errorHandler)
+	return s.dial(host, s.connCfg, errorHandler)
 }
 }
 
 
 // Query represents a CQL statement that can be executed.
 // Query represents a CQL statement that can be executed.