Przeglądaj źródła

hostinfo: fix using default conn and overriding values (#1009)

Only set values in HostInfo.update if they were not set previously.

Fix using default port when we have a port available from an event
frame.

Fix overwriting the supplied ports in a hostinfo
Chris Bannister 8 lat temu
rodzic
commit
f5cee64470
3 zmienionych plików z 108 dodań i 23 usunięć
  1. 3 1
      conn.go
  2. 60 22
      host_source.go
  3. 45 0
      host_source_gen.go

+ 3 - 1
conn.go

@@ -1162,8 +1162,10 @@ func (c *Conn) localHostInfo() (*HostInfo, error) {
 		return nil, err
 	}
 
+	port := c.conn.RemoteAddr().(*net.TCPAddr).Port
+
 	// TODO(zariel): avoid doing this here
-	host, err := c.session.hostInfoFromMap(row)
+	host, err := c.session.hostInfoFromMap(row, port)
 	if err != nil {
 		return nil, err
 	}

+ 60 - 22
host_source.go

@@ -336,27 +336,65 @@ func (h *HostInfo) setPort(port int) *HostInfo {
 }
 
 func (h *HostInfo) update(from *HostInfo) {
+	if h == from {
+		return
+	}
+
 	h.mu.Lock()
 	defer h.mu.Unlock()
 
-	h.peer = from.peer
-	h.broadcastAddress = from.broadcastAddress
-	h.listenAddress = from.listenAddress
-	h.rpcAddress = from.rpcAddress
-	h.preferredIP = from.preferredIP
-	h.connectAddress = from.connectAddress
-	h.port = from.port
-	h.dataCenter = from.dataCenter
-	h.rack = from.rack
-	h.hostId = from.hostId
-	h.workload = from.workload
-	h.graph = from.graph
-	h.dseVersion = from.dseVersion
-	h.partitioner = from.partitioner
-	h.clusterName = from.clusterName
-	h.version = from.version
-	h.state = from.state
-	h.tokens = from.tokens
+	from.mu.RLock()
+	defer from.mu.RUnlock()
+
+	// autogenerated do not update
+	if h.peer == nil {
+		h.peer = from.peer
+	}
+	if h.broadcastAddress == nil {
+		h.broadcastAddress = from.broadcastAddress
+	}
+	if h.listenAddress == nil {
+		h.listenAddress = from.listenAddress
+	}
+	if h.rpcAddress == nil {
+		h.rpcAddress = from.rpcAddress
+	}
+	if h.preferredIP == nil {
+		h.preferredIP = from.preferredIP
+	}
+	if h.connectAddress == nil {
+		h.connectAddress = from.connectAddress
+	}
+	if h.port == 0 {
+		h.port = from.port
+	}
+	if h.dataCenter == "" {
+		h.dataCenter = from.dataCenter
+	}
+	if h.rack == "" {
+		h.rack = from.rack
+	}
+	if h.hostId == "" {
+		h.hostId = from.hostId
+	}
+	if h.workload == "" {
+		h.workload = from.workload
+	}
+	if h.dseVersion == "" {
+		h.dseVersion = from.dseVersion
+	}
+	if h.partitioner == "" {
+		h.partitioner = from.partitioner
+	}
+	if h.clusterName == "" {
+		h.clusterName = from.clusterName
+	}
+	if h.version == (cassVersion{}) {
+		h.version = from.version
+	}
+	if h.tokens == nil {
+		h.tokens = from.tokens
+	}
 }
 
 func (h *HostInfo) IsUp() bool {
@@ -402,13 +440,13 @@ func checkSystemSchema(control *controlConn) (bool, error) {
 
 // Given a map that represents a row from either system.local or system.peers
 // return as much information as we can in *HostInfo
-func (s *Session) hostInfoFromMap(row map[string]interface{}) (*HostInfo, error) {
+func (s *Session) hostInfoFromMap(row map[string]interface{}, port int) (*HostInfo, error) {
 	const assertErrorMsg = "Assertion failed for %s"
 	var ok bool
 
 	// Default to our connected port if the cluster doesn't have port information
 	host := HostInfo{
-		port: s.cfg.Port,
+		port: port,
 	}
 
 	for key, value := range row {
@@ -527,7 +565,7 @@ func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
 
 	for _, row := range rows {
 		// extract all available info about the peer
-		host, err := r.session.hostInfoFromMap(row)
+		host, err := r.session.hostInfoFromMap(row, r.session.cfg.Port)
 		if err != nil {
 			return nil, err
 		} else if !isValidPeer(host) {
@@ -589,7 +627,7 @@ func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
 		}
 
 		for _, row := range rows {
-			h, err := r.session.hostInfoFromMap(row)
+			h, err := r.session.hostInfoFromMap(row, port)
 			if err != nil {
 				return nil, err
 			}

+ 45 - 0
host_source_gen.go

@@ -0,0 +1,45 @@
+// +build genhostinfo
+
+package main
+
+import (
+	"fmt"
+	"reflect"
+	"sync"
+
+	"github.com/gocql/gocql"
+)
+
+func gen(clause, field string) {
+	fmt.Printf("if h.%s == %s {\n", field, clause)
+	fmt.Printf("\th.%s = from.%s\n", field, field)
+	fmt.Println("}")
+}
+
+func main() {
+	t := reflect.ValueOf(&gocql.HostInfo{}).Elem().Type()
+	mu := reflect.TypeOf(sync.RWMutex{})
+
+	for i := 0; i < t.NumField(); i++ {
+		f := t.Field(i)
+		if f.Type == mu {
+			continue
+		}
+
+		switch f.Type.Kind() {
+		case reflect.Slice:
+			gen("nil", f.Name)
+		case reflect.String:
+			gen(`""`, f.Name)
+		case reflect.Int:
+			gen("0", f.Name)
+		case reflect.Struct:
+			gen("("+f.Type.Name()+"{})", f.Name)
+		case reflect.Bool, reflect.Int32:
+			continue
+		default:
+			panic(fmt.Sprintf("unknown field: %s", f))
+		}
+	}
+
+}