浏览代码

Merge pull request #993 from Zariel/improve-initial-session-setup

session: improve host setup in session
Chris Bannister 8 年之前
父节点
当前提交
bfab62f32a
共有 6 个文件被更改,包括 108 次插入86 次删除
  1. 4 0
      cluster.go
  2. 2 3
      conn.go
  3. 40 27
      events.go
  4. 44 14
      host_source.go
  5. 18 4
      session.go
  6. 0 38
      session_connect_test.go

+ 4 - 0
cluster.go

@@ -168,6 +168,10 @@ func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, i
 	return newAddr, newPort
 }
 
+func (cfg *ClusterConfig) filterHost(host *HostInfo) bool {
+	return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host))
+}
+
 var (
 	ErrNoHosts              = errors.New("no hosts provided")
 	ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")

+ 2 - 3
conn.go

@@ -168,8 +168,7 @@ func (s *Session) dial(ip net.IP, port int, cfg *ConnConfig, errorHandler ConnEr
 	}
 
 	// TODO(zariel): handle ipv6 zone
-	translatedPeer, translatedPort := s.cfg.translateAddressPort(ip, port)
-	addr := (&net.TCPAddr{IP: translatedPeer, Port: translatedPort}).String()
+	addr := (&net.TCPAddr{IP: ip, Port: port}).String()
 
 	if cfg.tlsConfig != nil {
 		// the TLS config is safe to be reused by connections but it must not
@@ -1164,7 +1163,7 @@ func (c *Conn) localHostInfo() (*HostInfo, error) {
 	}
 
 	// TODO(zariel): avoid doing this here
-	host, err := hostInfoFromMap(row, c.session.cfg.Port)
+	host, err := c.session.hostInfoFromMap(row)
 	if err != nil {
 		return nil, err
 	}

+ 40 - 27
events.go

@@ -173,7 +173,23 @@ func (s *Session) handleNodeEvent(frames []frame) {
 	}
 }
 
+func (s *Session) addNewNode(host *HostInfo) {
+	if s.cfg.filterHost(host) {
+		return
+	}
+
+	host.setState(NodeUp)
+	s.pool.addHost(host)
+	s.policy.AddHost(host)
+}
+
 func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
+	if gocqlDebug {
+		Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
+	}
+
+	ip, port = s.cfg.translateAddressPort(ip, port)
+
 	// Get host info and apply any filters to the host
 	hostInfo, err := s.hostSource.getHostInfo(ip, port)
 	if err != nil {
@@ -189,14 +205,9 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 	}
 
 	// should this handle token moving?
-	if existing, ok := s.ring.addHostIfMissing(hostInfo); ok {
-		existing.update(hostInfo)
-		hostInfo = existing
-	}
+	hostInfo = s.ring.addOrUpdate(hostInfo)
 
-	s.pool.addHost(hostInfo)
-	s.policy.AddHost(hostInfo)
-	hostInfo.setState(NodeUp)
+	s.addNewNode(hostInfo)
 
 	if s.control != nil && !s.cfg.IgnorePeerAddr {
 		// TODO(zariel): debounce ring refresh
@@ -205,6 +216,12 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 }
 
 func (s *Session) handleRemovedNode(ip net.IP, port int) {
+	if gocqlDebug {
+		Logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port)
+	}
+
+	ip, port = s.cfg.translateAddressPort(ip, port)
+
 	// we remove all nodes but only add ones which pass the filter
 	host := s.ring.getHost(ip)
 	if host == nil {
@@ -225,34 +242,30 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	}
 }
 
-func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
+func (s *Session) handleNodeUp(eventIp net.IP, eventPort int, waitForBinary bool) {
 	if gocqlDebug {
-		Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", ip.String(), port)
+		Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
 	}
 
-	host := s.ring.getHost(ip)
-	if host != nil {
-		// If we receive a node up event and user has asked us to ignore the peer address use
-		// the address provide by the event instead the address provide by peer the table.
-		if s.cfg.IgnorePeerAddr && !host.ConnectAddress().Equal(ip) {
-			host.SetConnectAddress(ip)
-		}
+	ip, _ := s.cfg.translateAddressPort(eventIp, eventPort)
 
-		if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
-			return
-		}
-
-		if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary {
-			time.Sleep(t)
-		}
+	host := s.ring.getHost(ip)
+	if host == nil {
+		// TODO(zariel): avoid the need to translate twice in this
+		// case
+		s.handleNewNode(eventIp, eventPort, waitForBinary)
+		return
+	}
 
-		s.pool.hostUp(host)
-		s.policy.HostUp(host)
-		host.setState(NodeUp)
+	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
 		return
 	}
 
-	s.handleNewNode(ip, port, waitForBinary)
+	if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary {
+		time.Sleep(t)
+	}
+
+	s.addNewNode(host)
 }
 
 func (s *Session) handleNodeDown(ip net.IP, port int) {

+ 44 - 14
host_source.go

@@ -183,6 +183,7 @@ func (h *HostInfo) ConnectAddress() net.IP {
 }
 
 func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
+	// TODO(zariel): should this not be exported?
 	h.mu.Lock()
 	defer h.mu.Unlock()
 	h.connectAddress = address
@@ -338,10 +339,24 @@ func (h *HostInfo) update(from *HostInfo) {
 	h.mu.Lock()
 	defer h.mu.Unlock()
 
-	h.tokens = from.tokens
-	h.version = from.version
-	h.hostId = from.hostId
+	h.peer = from.peer
+	h.broadcastAddress = from.broadcastAddress
+	h.listenAddress = from.listenAddress
+	h.rpcAddress = from.rpcAddress
+	h.preferredIP = from.preferredIP
+	h.connectAddress = from.connectAddress
+	h.port = from.port
 	h.dataCenter = from.dataCenter
+	h.rack = from.rack
+	h.hostId = from.hostId
+	h.workload = from.workload
+	h.graph = from.graph
+	h.dseVersion = from.dseVersion
+	h.partitioner = from.partitioner
+	h.clusterName = from.clusterName
+	h.version = from.version
+	h.state = from.state
+	h.tokens = from.tokens
 }
 
 func (h *HostInfo) IsUp() bool {
@@ -387,13 +402,13 @@ func checkSystemSchema(control *controlConn) (bool, error) {
 
 // Given a map that represents a row from either system.local or system.peers
 // return as much information as we can in *HostInfo
-func hostInfoFromMap(row map[string]interface{}, defaultPort int) (*HostInfo, error) {
+func (s *Session) hostInfoFromMap(row map[string]interface{}) (*HostInfo, error) {
 	const assertErrorMsg = "Assertion failed for %s"
 	var ok bool
 
 	// Default to our connected port if the cluster doesn't have port information
 	host := HostInfo{
-		port: defaultPort,
+		port: s.cfg.Port,
 	}
 
 	for key, value := range row {
@@ -485,6 +500,10 @@ func hostInfoFromMap(row map[string]interface{}, defaultPort int) (*HostInfo, er
 		// Not sure what the port field will be called until the JIRA issue is complete
 	}
 
+	ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port)
+	host.connectAddress = ip
+	host.port = port
+
 	return &host, nil
 }
 
@@ -508,7 +527,7 @@ func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
 
 	for _, row := range rows {
 		// extract all available info about the peer
-		host, err := hostInfoFromMap(row, r.session.cfg.Port)
+		host, err := r.session.hostInfoFromMap(row)
 		if err != nil {
 			return nil, err
 		} else if !isValidPeer(host) {
@@ -560,24 +579,35 @@ func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
 			return nil
 		}
 
-		return ch.conn.query("SELECT * FROM system.peers WHERE peer=?", ip)
+		return ch.conn.query("SELECT * FROM system.peers")
 	})
 
 	if iter != nil {
-		row, err := iter.rowMap()
+		rows, err := iter.SliceMap()
 		if err != nil {
 			return nil, err
 		}
 
-		host, err = hostInfoFromMap(row, port)
-		if err != nil {
-			return nil, err
+		for _, row := range rows {
+			h, err := r.session.hostInfoFromMap(row)
+			if err != nil {
+				return nil, err
+			}
+
+			if host.ConnectAddress().Equal(ip) {
+				host = h
+				break
+			}
+		}
+
+		if host == nil {
+			return nil, errors.New("host not found in peers table")
 		}
-	} else if host == nil {
-		return nil, errors.New("unable to fetch host info: invalid control connection")
 	}
 
-	if host.invalidConnectAddr() {
+	if host == nil {
+		return nil, errors.New("unable to fetch host info: invalid control connection")
+	} else if host.invalidConnectAddr() {
 		return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", ip, host)
 	}
 

+ 18 - 4
session.go

@@ -161,6 +161,16 @@ func (s *Session) init() error {
 		return err
 	}
 
+	allHosts := hosts
+	hosts = hosts[:0]
+	hostMap := make(map[string]*HostInfo, len(allHosts))
+	for _, host := range allHosts {
+		if !s.cfg.filterHost(host) {
+			hosts = append(hosts, host)
+			hostMap[host.ConnectAddress().String()] = host
+		}
+	}
+
 	if !s.cfg.disableControlConn {
 		s.control = createControlConn(s)
 		if s.cfg.ProtoVersion == 0 {
@@ -182,17 +192,20 @@ func (s *Session) init() error {
 
 		if !s.cfg.DisableInitialHostLookup {
 			var partitioner string
-			hosts, partitioner, err = s.hostSource.GetHosts()
+			newHosts, partitioner, err := s.hostSource.GetHosts()
 			if err != nil {
 				return err
 			}
 			s.policy.SetPartitioner(partitioner)
+			for _, host := range newHosts {
+				hostMap[host.ConnectAddress().String()] = host
+			}
 		}
 	}
 
-	for _, host := range hosts {
+	for _, host := range hostMap {
 		host = s.ring.addOrUpdate(host)
-		s.handleNodeUp(host.ConnectAddress(), host.Port(), false)
+		s.addNewNode(host)
 	}
 
 	// TODO(zariel): we probably dont need this any more as we verify that we
@@ -210,7 +223,8 @@ func (s *Session) init() error {
 		newer, _ := checkSystemSchema(s.control)
 		s.useSystemSchema = newer
 	} else {
-		s.useSystemSchema = hosts[0].Version().Major >= 3
+		host := s.ring.rrHost()
+		s.useSystemSchema = host.Version().Major >= 3
 	}
 
 	if s.pool.Size() == 0 {

+ 0 - 38
session_connect_test.go

@@ -91,41 +91,3 @@ func assertConnectionEventually(t *testing.T, wait time.Duration, srvr *OneConnT
 		}
 	}
 }
-
-func TestSession_connect_WithNoTranslator(t *testing.T) {
-	srvr, err := NewOneConnTestServer()
-	assertNil(t, "error when creating tcp server", err)
-	defer srvr.Close()
-
-	session := createTestSession()
-	defer session.Close()
-
-	go srvr.Serve()
-
-	session.connect(&HostInfo{
-		connectAddress: srvr.Addr,
-		port:           srvr.Port,
-	}, testConnErrorHandler(t))
-
-	assertConnectionEventually(t, 500*time.Millisecond, srvr)
-}
-
-func TestSession_connect_WithTranslator(t *testing.T) {
-	srvr, err := NewOneConnTestServer()
-	assertNil(t, "error when creating tcp server", err)
-	defer srvr.Close()
-
-	session := createTestSession()
-	defer session.Close()
-	session.cfg.AddressTranslator = staticAddressTranslator(srvr.Addr, srvr.Port)
-
-	go srvr.Serve()
-
-	// the provided address will be translated
-	session.connect(&HostInfo{
-		connectAddress: net.ParseIP("10.10.10.10"),
-		port:           5432,
-	}, testConnErrorHandler(t))
-
-	assertConnectionEventually(t, 500*time.Millisecond, srvr)
-}