Browse Source

events: support HostFilter to filter host events

Integrate HostFilter to filter incoming host events.
Chris Bannister 10 years ago
parent
commit
8e6cb50242
7 changed files with 198 additions and 8 deletions
  1. 5 0
      cluster.go
  2. 7 0
      connectionpool.go
  3. 28 4
      events.go
  4. 57 1
      events_ccm_test.go
  5. 7 1
      filters.go
  6. 90 0
      filters_test.go
  7. 4 2
      session.go

+ 5 - 0
cluster.go

@@ -121,6 +121,11 @@ type ClusterConfig struct {
 	// receiving a schema change frame. (deault: 60s)
 	// receiving a schema change frame. (deault: 60s)
 	MaxWaitSchemaAgreement time.Duration
 	MaxWaitSchemaAgreement time.Duration
 
 
+	// HostFilter will filter all incoming events for host, any which dont pass
+	// the filter will be ignored. If set will take precedence over any options set
+	// via Discovery
+	HostFilter HostFilter
+
 	// internal config for testing
 	// internal config for testing
 	disableControlConn bool
 	disableControlConn bool
 }
 }

+ 7 - 0
connectionpool.go

@@ -190,6 +190,13 @@ func (p *policyConnPool) Size() int {
 	return count
 	return count
 }
 }
 
 
+func (p *policyConnPool) getPool(addr string) (pool *hostConnPool, ok bool) {
+	p.mu.RLock()
+	pool, ok = p.hostConnPools[addr]
+	p.mu.RUnlock()
+	return
+}
+
 func (p *policyConnPool) Pick(qry *Query) (SelectedHost, *Conn) {
 func (p *policyConnPool) Pick(qry *Query) (SelectedHost, *Conn) {
 	nextHost := p.hostPolicy.Pick(qry)
 	nextHost := p.hostPolicy.Pick(qry)
 
 

+ 28 - 4
events.go

@@ -167,8 +167,12 @@ func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
 		hostInfo = &HostInfo{peer: host.String(), port: port, state: NodeUp}
 		hostInfo = &HostInfo{peer: host.String(), port: port, state: NodeUp}
 	}
 	}
 
 
-	// TODO: remove this when the host selection policy is more sophisticated
-	if !s.cfg.Discovery.matchFilter(hostInfo) {
+	if s.cfg.HostFilter != nil {
+		if !s.cfg.HostFilter.Accept(hostInfo) {
+			return
+		}
+	} else if !s.cfg.Discovery.matchFilter(hostInfo) {
+		// TODO: remove this when the host selection policy is more sophisticated
 		return
 		return
 	}
 	}
 
 
@@ -192,6 +196,16 @@ func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
 func (s *Session) handleRemovedNode(ip net.IP, port int) {
 func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	// we remove all nodes but only add ones which pass the filter
 	// we remove all nodes but only add ones which pass the filter
 	addr := ip.String()
 	addr := ip.String()
+
+	host := s.ring.getHost(addr)
+	if host == nil {
+		host = &HostInfo{peer: addr}
+	}
+
+	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
+		return
+	}
+
 	s.pool.removeHost(addr)
 	s.pool.removeHost(addr)
 	s.ring.removeHost(addr)
 	s.ring.removeHost(addr)
 
 
@@ -202,8 +216,12 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 	addr := ip.String()
 	addr := ip.String()
 	host := s.ring.getHost(addr)
 	host := s.ring.getHost(addr)
 	if host != nil {
 	if host != nil {
-		// TODO: remove this when the host selection policy is more sophisticated
-		if !s.cfg.Discovery.matchFilter(host) {
+		if s.cfg.HostFilter != nil {
+			if !s.cfg.HostFilter.Accept(host) {
+				return
+			}
+		} else if !s.cfg.Discovery.matchFilter(host) {
+			// TODO: remove this when the host selection policy is more sophisticated
 			return
 			return
 		}
 		}
 
 
@@ -224,6 +242,12 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
 	host := s.ring.getHost(addr)
 	host := s.ring.getHost(addr)
 	if host != nil {
 	if host != nil {
 		host.setState(NodeDown)
 		host.setState(NodeDown)
+	} else {
+		host = &HostInfo{peer: addr}
+	}
+
+	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
+		return
 	}
 	}
 
 
 	s.pool.hostDown(addr)
 	s.pool.hostDown(addr)

+ 57 - 1
events_ccm_test.go

@@ -3,10 +3,11 @@
 package gocql
 package gocql
 
 
 import (
 import (
-	"github.com/gocql/gocql/ccm_test"
 	"log"
 	"log"
 	"testing"
 	"testing"
 	"time"
 	"time"
+
+	"github.com/gocql/gocql/internal/ccm"
 )
 )
 
 
 func TestEventDiscovery(t *testing.T) {
 func TestEventDiscovery(t *testing.T) {
@@ -166,3 +167,58 @@ func TestEventNodeUp(t *testing.T) {
 	}
 	}
 	session.pool.mu.RUnlock()
 	session.pool.mu.RUnlock()
 }
 }
+
+func TestEventFilter(t *testing.T) {
+	if err := ccm.AllUp(); err != nil {
+		t.Fatal(err)
+	}
+
+	status, err := ccm.Status()
+	if err != nil {
+		t.Fatal(err)
+	}
+	log.Printf("status=%+v\n", status)
+
+	cluster := createCluster()
+	cluster.HostFilter = WhiteListHostFilter(status["node1"].Addr)
+	session := createSessionFromCluster(cluster, t)
+	defer session.Close()
+
+	if _, ok := session.pool.getPool(status["node1"].Addr); !ok {
+		t.Errorf("should have %v in pool but dont", "node1")
+	}
+
+	for _, host := range [...]string{"node2", "node3"} {
+		_, ok := session.pool.getPool(status[host].Addr)
+		if ok {
+			t.Errorf("should not have %v in pool", host)
+		}
+	}
+
+	if t.Failed() {
+		t.FailNow()
+	}
+
+	if err := ccm.NodeDown("node2"); err != nil {
+		t.Fatal(err)
+	}
+
+	time.Sleep(5 * time.Second)
+
+	if err := ccm.NodeUp("node2"); err != nil {
+		t.Fatal(err)
+	}
+
+	time.Sleep(15 * time.Second)
+	for _, host := range [...]string{"node2", "node3"} {
+		_, ok := session.pool.getPool(status[host].Addr)
+		if ok {
+			t.Errorf("should not have %v in pool", host)
+		}
+	}
+
+	if t.Failed() {
+		t.FailNow()
+	}
+
+}

+ 7 - 1
filters.go

@@ -15,12 +15,18 @@ func (fn HostFilterFunc) Accept(host *HostInfo) bool {
 }
 }
 
 
 // AcceptAllFilter will accept all hosts
 // AcceptAllFilter will accept all hosts
-func AcceptAllFilterfunc() HostFilter {
+func AcceptAllFilter() HostFilter {
 	return HostFilterFunc(func(host *HostInfo) bool {
 	return HostFilterFunc(func(host *HostInfo) bool {
 		return true
 		return true
 	})
 	})
 }
 }
 
 
+func DenyAllFilter() HostFilter {
+	return HostFilterFunc(func(host *HostInfo) bool {
+		return false
+	})
+}
+
 // DataCentreHostFilter filters all hosts such that they are in the same data centre
 // DataCentreHostFilter filters all hosts such that they are in the same data centre
 // as the supplied data centre.
 // as the supplied data centre.
 func DataCentreHostFilter(dataCentre string) HostFilter {
 func DataCentreHostFilter(dataCentre string) HostFilter {

+ 90 - 0
filters_test.go

@@ -0,0 +1,90 @@
+package gocql
+
+import "testing"
+
+func TestFilter_WhiteList(t *testing.T) {
+	f := WhiteListHostFilter("addr1", "addr2")
+	tests := [...]struct {
+		addr   string
+		accept bool
+	}{
+		{"addr1", true},
+		{"addr2", true},
+		{"addr3", false},
+	}
+
+	for i, test := range tests {
+		if f.Accept(&HostInfo{peer: test.addr}) {
+			if !test.accept {
+				t.Errorf("%d: should not have been accepted but was", i)
+			}
+		} else if test.accept {
+			t.Errorf("%d: should have been accepted but wasn't", i)
+		}
+	}
+}
+
+func TestFilter_AllowAll(t *testing.T) {
+	f := AcceptAllFilter()
+	tests := [...]struct {
+		addr   string
+		accept bool
+	}{
+		{"addr1", true},
+		{"addr2", true},
+		{"addr3", true},
+	}
+
+	for i, test := range tests {
+		if f.Accept(&HostInfo{peer: test.addr}) {
+			if !test.accept {
+				t.Errorf("%d: should not have been accepted but was", i)
+			}
+		} else if test.accept {
+			t.Errorf("%d: should have been accepted but wasn't", i)
+		}
+	}
+}
+
+func TestFilter_DenyAll(t *testing.T) {
+	f := DenyAllFilter()
+	tests := [...]struct {
+		addr   string
+		accept bool
+	}{
+		{"addr1", false},
+		{"addr2", false},
+		{"addr3", false},
+	}
+
+	for i, test := range tests {
+		if f.Accept(&HostInfo{peer: test.addr}) {
+			if !test.accept {
+				t.Errorf("%d: should not have been accepted but was", i)
+			}
+		} else if test.accept {
+			t.Errorf("%d: should have been accepted but wasn't", i)
+		}
+	}
+}
+
+func TestFilter_DataCentre(t *testing.T) {
+	f := DataCentreHostFilter("dc1")
+	tests := [...]struct {
+		dc     string
+		accept bool
+	}{
+		{"dc1", true},
+		{"dc2", false},
+	}
+
+	for i, test := range tests {
+		if f.Accept(&HostInfo{dataCenter: test.dc}) {
+			if !test.accept {
+				t.Errorf("%d: should not have been accepted but was", i)
+			}
+		} else if test.accept {
+			t.Errorf("%d: should have been accepted but wasn't", i)
+		}
+	}
+}

+ 4 - 2
session.go

@@ -44,8 +44,6 @@ type Session struct {
 
 
 	mu sync.RWMutex
 	mu sync.RWMutex
 
 
-	hostFilter HostFilter
-
 	control *controlConn
 	control *controlConn
 
 
 	// event handlers
 	// event handlers
@@ -117,6 +115,10 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
+		for _, host := range hosts {
+			s.ring.addHost(host)
+		}
+
 	} else {
 	} else {
 		// we dont get host info
 		// we dont get host info
 		hosts = make([]*HostInfo, len(cfg.Hosts))
 		hosts = make([]*HostInfo, len(cfg.Hosts))