Browse Source

Merge pull request #504 from Zariel/ring-discovery-proxy

host_source: use system.local rpc_address
Chris Bannister 10 years ago
parent
commit
9935df5271
4 changed files with 259 additions and 48 deletions
  1. cassandra_test.go (+141, -0)
  2. control.go (+18, -6)
  3. host_source.go (+78, -37)
  4. session.go (+22, -5)
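
For context, the behaviour added here is exercised by the new TestDiscoverViaProxy: the driver is pointed at a single (proxy) address, discovery is enabled, and the ring is then resolved from system.local / system.peers rather than from the dialed address. Below is a minimal sketch of that setup using the config fields that appear in the test (DiscoverHosts, NumConns, Discovery.Sleep); the host name is a placeholder and the snippet is illustrative, not part of the diff.

package main

import (
	"log"
	"time"

	"github.com/gocql/gocql"
)

func main() {
	// A single proxy address is given as the only initial host; with
	// DiscoverHosts enabled the driver queries system.local / system.peers
	// behind it and fills the pool with the real ring members.
	cluster := gocql.NewCluster("cassandra-proxy.example.com") // placeholder address
	cluster.DiscoverHosts = true
	cluster.NumConns = 1
	cluster.Discovery.Sleep = 100 * time.Millisecond // ring refresh interval

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()
}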

+ 141 - 0
cassandra_test.go

@@ -6,6 +6,7 @@ import (
 	"bytes"
 	"flag"
 	"fmt"
+	"io"
 	"log"
 	"math"
 	"math/big"
@@ -2288,3 +2289,143 @@ func TestUDF(t *testing.T) {
 	t.Fatal(err)
 	}
 }
+
+func TestDiscoverViaProxy(t *testing.T) {
+	// This (complicated) test verifies that when the driver is given an initial host
+	// that is in fact a proxy, it discovers the rest of the ring behind the proxy
+	// and does not store the proxy's address as a host in its connection pool.
+	// See https://github.com/gocql/gocql/issues/481
+	proxy, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatalf("unable to create proxy listener: %v", err)
+	}
+
+	var (
+		wg         sync.WaitGroup
+		mu         sync.Mutex
+		proxyConns []net.Conn
+		closed     bool
+	)
+
+	go func(wg *sync.WaitGroup) {
+		cassandraAddr := JoinHostPort(clusterHosts[0], 9042)
+
+		cassandra := func() (net.Conn, error) {
+			return net.Dial("tcp", cassandraAddr)
+		}
+
+		proxyFn := func(wg *sync.WaitGroup, from, to net.Conn) {
+			defer wg.Done()
+
+			_, err := io.Copy(to, from)
+			if err != nil {
+				mu.Lock()
+				if !closed {
+					t.Error(err)
+				}
+				mu.Unlock()
+			}
+		}
+
+		// handle dials cassandra and then proxies requests and responses. It waits
+		// for both the read and write side of the TCP connection to close before
+		// returning.
+		handle := func(conn net.Conn) error {
+			defer conn.Close()
+
+			cass, err := cassandra()
+			if err != nil {
+				return err
+			}
+
+			mu.Lock()
+			proxyConns = append(proxyConns, cass)
+			mu.Unlock()
+
+			defer cass.Close()
+
+			var wg sync.WaitGroup
+			wg.Add(1)
+			go proxyFn(&wg, conn, cass)
+
+			wg.Add(1)
+			go proxyFn(&wg, cass, conn)
+
+			wg.Wait()
+
+			return nil
+		}
+
+		for {
+			// proxy just accepts connections and then proxies them to cassandra;
+			// it runs until it is closed.
+			conn, err := proxy.Accept()
+			if err != nil {
+				mu.Lock()
+				if !closed {
+					t.Error(err)
+				}
+				mu.Unlock()
+				return
+			}
+
+			mu.Lock()
+			proxyConns = append(proxyConns, conn)
+			mu.Unlock()
+
+			wg.Add(1)
+			go func(conn net.Conn) {
+				defer wg.Done()
+
+				if err := handle(conn); err != nil {
+					t.Error(err)
+					return
+				}
+			}(conn)
+		}
+	}(&wg)
+
+	defer wg.Wait()
+
+	proxyAddr := proxy.Addr().String()
+
+	cluster := createCluster()
+	cluster.DiscoverHosts = true
+	cluster.NumConns = 1
+	cluster.Discovery.Sleep = 100 * time.Millisecond
+	// initial host is the proxy address
+	cluster.Hosts = []string{proxyAddr}
+
+	session := createSessionFromCluster(cluster, t)
+	defer session.Close()
+
+	if !session.hostSource.localHasRpcAddr {
+		t.Skip("Target cluster does not have rpc_address in system.local.")
+		goto close
+	}
+
+	// we shouldn't need this, but to be safe
+	time.Sleep(1 * time.Second)
+
+	session.pool.mu.RLock()
+	for _, host := range clusterHosts {
+		if _, ok := session.pool.hostConnPools[host]; !ok {
+			t.Errorf("missing host in pool after discovery: %q", host)
+		}
+	}
+	session.pool.mu.RUnlock()
+
+close:
+	if err := proxy.Close(); err != nil {
+		t.Log(err)
+	}
+
+	mu.Lock()
+	closed = true
+	for _, conn := range proxyConns {
+		if err := conn.Close(); err != nil {
+			t.Log(err)
+		}
+	}
+	mu.Unlock()
+}

+ 18 - 6
control.go

@@ -26,7 +26,6 @@ func createControlConn(session *Session) *controlConn {
 	}
 
 	control.conn.Store((*Conn)(nil))
-	control.reconnect()
 	go control.heartBeat()
 
 	return control
@@ -55,14 +54,14 @@ func (c *controlConn) heartBeat() {
 		}
 
 	reconn:
-		c.reconnect()
-		time.Sleep(5 * time.Second)
+		c.reconnect(true)
+		// time.Sleep(5 * time.Second)
 		continue
 
 	}
 }
 
-func (c *controlConn) reconnect() {
+func (c *controlConn) reconnect(refreshring bool) {
 	if !atomic.CompareAndSwapUint64(&c.connecting, 0, 1) {
 		return
 	}
@@ -101,6 +100,10 @@ func (c *controlConn) reconnect() {
 	if oldConn != nil {
 		oldConn.Close()
 	}
+
+	if refreshring {
+		c.session.hostSource.refreshRing()
+	}
 }
 
 func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
@@ -113,7 +116,7 @@ func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
 		return
 	}
 
-	c.reconnect()
+	c.reconnect(true)
 }
 
 func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
@@ -146,7 +149,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
 
 			connectAttempts++
 
-			c.reconnect()
+			c.reconnect(false)
 			continue
 		}
 
@@ -212,6 +215,15 @@ func (c *controlConn) awaitSchemaAgreement() (err error) {
 	// not exported
 	return errors.New("gocql: cluster schema versions not consistent")
 }
+
+func (c *controlConn) addr() string {
+	conn := c.conn.Load().(*Conn)
+	if conn == nil {
+		return ""
+	}
+	return conn.addr
+}
+
 func (c *controlConn) close() {
 	// TODO: handle more gracefully
 	close(c.quit)

+ 78 - 37
host_source.go

@@ -1,6 +1,7 @@
 package gocql
 
 import (
+	"fmt"
 	"log"
 	"net"
 	"time"
@@ -14,6 +15,10 @@ type HostInfo struct {
 	Tokens     []string
 }
 
+func (h HostInfo) String() string {
+	return fmt.Sprintf("[hostinfo peer=%q data_centre=%q rack=%q host_id=%q num_tokens=%d]", h.Peer, h.DataCenter, h.Rack, h.HostId, len(h.Tokens))
+}
+
 // Polls system.peers at a specific interval to find new hosts
 type ringDescriber struct {
 	dcFilter        string
@@ -22,46 +27,78 @@ type ringDescriber struct {
 	prevPartitioner string
 	session         *Session
 	closeChan       chan bool
+	// indicates that we can use system.local to get the connection's remote address
+	localHasRpcAddr bool
+}
+
+func checkSystemLocal(control *controlConn) (bool, error) {
+	iter := control.query("SELECT rpc_address FROM system.local")
+	if err := iter.err; err != nil {
+		if errf, ok := err.(*errorFrame); ok {
+			if errf.code == errSyntax {
+				return false, nil
+			}
+		}
+
+		return false, err
+	}
+
+	return true, nil
 }
 
 func (r *ringDescriber) GetHosts() (hosts []HostInfo, partitioner string, err error) {
 	// we need conn to be the same because we need to query system.peers and system.local
 	// on the same node to get the whole cluster
 
-	iter := r.session.control.query("SELECT data_center, rack, host_id, tokens, partitioner FROM system.local")
-	if iter == nil {
-		return r.prevHosts, r.prevPartitioner, nil
-	}
+	const (
+		legacyLocalQuery = "SELECT data_center, rack, host_id, tokens, partitioner FROM system.local"
+		// only supported in 2.2.0, 2.1.6, 2.0.16
+		localQuery = "SELECT rpc_address, data_center, rack, host_id, tokens, partitioner FROM system.local"
+	)
+
+	var localHost HostInfo
+	if r.localHasRpcAddr {
+		iter := r.session.control.query(localQuery)
+		if iter == nil {
+			return r.prevHosts, r.prevPartitioner, nil
+		}
 
-	conn := r.session.pool.Pick(nil)
-	if conn == nil {
-		return r.prevHosts, r.prevPartitioner, nil
-	}
+		iter.Scan(&localHost.Peer, &localHost.DataCenter, &localHost.Rack,
+			&localHost.HostId, &localHost.Tokens, &partitioner)
 
-	host := HostInfo{}
-	iter.Scan(&host.DataCenter, &host.Rack, &host.HostId, &host.Tokens, &partitioner)
+		if err = iter.Close(); err != nil {
+			return nil, "", err
+		}
+	} else {
+		iter := r.session.control.query(legacyLocalQuery)
+		if iter == nil {
+			return r.prevHosts, r.prevPartitioner, nil
+		}
 
-	if err = iter.Close(); err != nil {
-		return nil, "", err
-	}
+		iter.Scan(&localHost.DataCenter, &localHost.Rack, &localHost.HostId, &localHost.Tokens, &partitioner)
 
-	addr, _, err := net.SplitHostPort(conn.Address())
-	if err != nil {
-		// this should not happen, ever, as this is the address that was dialed by conn, here
-		// a panic makes sense, please report a bug if it occurs.
-		panic(err)
-	}
+		if err = iter.Close(); err != nil {
+			return nil, "", err
+		}
 
-	host.Peer = addr
+		addr, _, err := net.SplitHostPort(r.session.control.addr())
+		if err != nil {
+			// this should not happen, ever, as this is the address that was dialed by conn, here
+			// a panic makes sense, please report a bug if it occurs.
+			panic(err)
+		}
+
+		localHost.Peer = addr
+	}
 
-	hosts = []HostInfo{host}
+	hosts = []HostInfo{localHost}
 
-	iter = r.session.control.query("SELECT peer, data_center, rack, host_id, tokens FROM system.peers")
+	iter := r.session.control.query("SELECT peer, data_center, rack, host_id, tokens FROM system.peers")
 	if iter == nil {
 		return r.prevHosts, r.prevPartitioner, nil
 	}
 
-	host = HostInfo{}
+	host := HostInfo{}
 	for iter.Scan(&host.Peer, &host.DataCenter, &host.Rack, &host.HostId, &host.Tokens) {
 		if r.matchFilter(&host) {
 			hosts = append(hosts, host)
@@ -92,28 +129,32 @@ func (r *ringDescriber) matchFilter(host *HostInfo) bool {
 	return true
 }
 
-func (h *ringDescriber) run(sleep time.Duration) {
+func (r *ringDescriber) refreshRing() {
+	// if we have 0 hosts this will return the previous list of hosts to
+	// attempt to reconnect to the cluster, otherwise we would never find
+	// downed hosts again. We could possibly have an optimisation to only
+	// try to add new hosts if GetHosts didn't error and the hosts didn't change.
+	hosts, partitioner, err := r.GetHosts()
+	if err != nil {
+		log.Println("RingDescriber: unable to get ring topology:", err)
+		return
+	}
+
+	r.session.pool.SetHosts(hosts)
+	r.session.pool.SetPartitioner(partitioner)
+}
+
+func (r *ringDescriber) run(sleep time.Duration) {
 	if sleep == 0 {
 		sleep = 30 * time.Second
 	}
 
 	for {
-		// if we have 0 hosts this will return the previous list of hosts to
-		// attempt to reconnect to the cluster otherwise we would never find
-		// downed hosts again, could possibly have an optimisation to only
-		// try to add new hosts if GetHosts didnt error and the hosts didnt change.
-		hosts, partitioner, err := h.GetHosts()
-		if err != nil {
-			log.Println("RingDescriber: unable to get ring topology:", err)
-			continue
-		}
-
-		h.session.pool.SetHosts(hosts)
-		h.session.pool.SetPartitioner(partitioner)
+		r.refreshRing()
 
 		select {
 		case <-time.After(sleep):
-		case <-h.closeChan:
+		case <-r.closeChan:
 			return
 		}
 	}

+ 22 - 5
session.go

@@ -10,6 +10,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log"
 	"strings"
 	"sync"
 	"time"
@@ -80,7 +81,7 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 	}
 	s.pool = pool
 
-	//See if there are any connections in the pool
+	// See if there are any connections in the pool
 	if pool.Size() == 0 {
 		s.Close()
 		return nil, ErrNoConnectionsStarted
@@ -88,10 +89,8 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 
 	s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
 
-	if !cfg.disableControlConn {
-		s.control = createControlConn(s)
-	}
-
+	// I think it might be a good idea to simplify this and make it always discover
+	// hosts, maybe with more filters.
 	if cfg.DiscoverHosts {
 		s.hostSource = &ringDescriber{
 			session:    s,
@@ -99,7 +98,25 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 			rackFilter: cfg.Discovery.RackFilter,
 			closeChan:  make(chan bool),
 		}
+	}
+
+	if !cfg.disableControlConn {
+		s.control = createControlConn(s)
+		s.control.reconnect(false)
 
 
+		// need to setup host source to check for rpc_address in system.local
+		localHasRPCAddr, err := checkSystemLocal(s.control)
+		if err != nil {
+			log.Printf("gocql: unable to verify if system.local table contains rpc_address, falling back to connection address: %v", err)
+		}
+
+		if cfg.DiscoverHosts {
+			s.hostSource.localHasRpcAddr = localHasRPCAddr
+		}
+	}
+
+	if cfg.DiscoverHosts {
+		s.hostSource.refreshRing()
 		go s.hostSource.run(cfg.Discovery.Sleep)
 		go s.hostSource.run(cfg.Discovery.Sleep)
 	}