Просмотр исходного кода

Merge remote-tracking branch 'upstream/master'

Bo Blanton 9 лет назад
Родитель
Сommit
e21827980b
13 измененных файлов с 325 добавлено и 9 удалено
  1. 26 0
      address_translators.go
  2. 34 0
      address_translators_test.go
  3. 21 0
      cluster.go
  4. 53 0
      cluster_test.go
  5. 51 0
      common_test.go
  6. 3 1
      conn.go
  7. 2 2
      conn_test.go
  8. 2 2
      control.go
  9. 0 1
      events.go
  10. 1 1
      integration.sh
  11. 131 0
      session_connect_test.go
  12. 0 1
      session_test.go
  13. 1 1
      udt_test.go

+ 26 - 0
address_translators.go

@@ -0,0 +1,26 @@
+package gocql
+
+import "net"
+
+// AddressTranslator provides a way to translate node addresses (and ports) that are
+// discovered or received as a node event. This can be useful in an ec2 environment,
+// for instance, to translate public IPs to private IPs.
+type AddressTranslator interface {
+	// Translate will translate the provided address and/or port to another
+	// address and/or port. If no translation is possible, Translate will return the
+	// address and port provided to it.
+	Translate(addr net.IP, port int) (net.IP, int)
+}
+
+type AddressTranslatorFunc func(addr net.IP, port int) (net.IP, int)
+
+func (fn AddressTranslatorFunc) Translate(addr net.IP, port int) (net.IP, int) {
+	return fn(addr, port)
+}
+
+// IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op.
+func IdentityTranslator() AddressTranslator {
+	return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
+		return addr, port
+	})
+}

+ 34 - 0
address_translators_test.go

@@ -0,0 +1,34 @@
+package gocql
+
+import (
+	"net"
+	"testing"
+)
+
+func TestIdentityAddressTranslator_NilAddrAndZeroPort(t *testing.T) {
+	var tr AddressTranslator = IdentityTranslator()
+	hostIP := net.ParseIP("")
+	if hostIP != nil {
+		t.Errorf("expected host ip to be (nil) but was (%+v) instead", hostIP)
+	}
+
+	addr, port := tr.Translate(hostIP, 0)
+	if addr != nil {
+		t.Errorf("expected translated host to be (nil) but was (%+v) instead", addr)
+	}
+	assertEqual(t, "translated port", 0, port)
+}
+
+func TestIdentityAddressTranslator_HostProvided(t *testing.T) {
+	var tr AddressTranslator = IdentityTranslator()
+	hostIP := net.ParseIP("10.1.2.3")
+	if hostIP == nil {
+		t.Error("expected host ip not to be (nil)")
+	}
+
+	addr, port := tr.Translate(hostIP, 9042)
+	if !hostIP.Equal(addr) {
+		t.Errorf("expected translated addr to be (%+v) but was (%+v) instead", hostIP, addr)
+	}
+	assertEqual(t, "translated port", 9042, port)
+}

+ 21 - 0
cluster.go

@@ -6,6 +6,8 @@ package gocql
 
 import (
 	"errors"
+	"log"
+	"net"
 	"time"
 )
 
@@ -75,6 +77,10 @@ type ClusterConfig struct {
 	// via Discovery
 	HostFilter HostFilter
 
+	// AddressTranslator will translate addresses found on peer discovery and/or
+	// node change events.
+	AddressTranslator AddressTranslator
+
 	// If IgnorePeerAddr is true and the address in system.peers does not match
 	// the supplied host by either initial hosts or discovered via events then the
 	// host will be replaced with the supplied address.
@@ -146,6 +152,21 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 	return NewSession(*cfg)
 }
 
+// translateAddressPort is a helper method that will use the given AddressTranslator
+// if defined, to translate the given address and port into a possibly new address
+// and port, If no AddressTranslator or if an error occurs, the given address and
+// port will be returned.
+func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, int) {
+	if cfg.AddressTranslator == nil || len(addr) == 0 {
+		return addr, port
+	}
+	newAddr, newPort := cfg.AddressTranslator.Translate(addr, port)
+	if gocqlDebug {
+		log.Printf("gocql: translating address '%v:%d' to '%v:%d'", addr, port, newAddr, newPort)
+	}
+	return newAddr, newPort
+}
+
 var (
 	ErrNoHosts              = errors.New("no hosts provided")
 	ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")

+ 53 - 0
cluster_test.go

@@ -0,0 +1,53 @@
+package gocql
+
+import (
+	"testing"
+	"time"
+	"net"
+)
+
+func TestNewCluster_Defaults(t *testing.T) {
+	cfg := NewCluster()
+	assertEqual(t, "cluster config cql version", "3.0.0", cfg.CQLVersion)
+	assertEqual(t, "cluster config timeout", 600*time.Millisecond, cfg.Timeout)
+	assertEqual(t, "cluster config port", 9042, cfg.Port)
+	assertEqual(t, "cluster config num-conns", 2, cfg.NumConns)
+	assertEqual(t, "cluster config consistency", Quorum, cfg.Consistency)
+	assertEqual(t, "cluster config max prepared statements", defaultMaxPreparedStmts, cfg.MaxPreparedStmts)
+	assertEqual(t, "cluster config max routing key info", 1000, cfg.MaxRoutingKeyInfo)
+	assertEqual(t, "cluster config page-size", 5000, cfg.PageSize)
+	assertEqual(t, "cluster config default timestamp", true, cfg.DefaultTimestamp)
+	assertEqual(t, "cluster config max wait schema agreement", 60*time.Second, cfg.MaxWaitSchemaAgreement)
+	assertEqual(t, "cluster config reconnect interval", 60*time.Second, cfg.ReconnectInterval)
+}
+
+func TestNewCluster_WithHosts(t *testing.T) {
+	cfg := NewCluster("addr1", "addr2")
+	assertEqual(t, "cluster config hosts length", 2, len(cfg.Hosts))
+	assertEqual(t, "cluster config host 0", "addr1", cfg.Hosts[0])
+	assertEqual(t, "cluster config host 1", "addr2", cfg.Hosts[1])
+}
+
+func TestClusterConfig_translateAddressAndPort_NilTranslator(t *testing.T) {
+	cfg := NewCluster()
+	assertNil(t, "cluster config address translator", cfg.AddressTranslator)
+	newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 1234)
+	assertTrue(t, "same address as provided", net.ParseIP("10.0.0.1").Equal(newAddr))
+	assertEqual(t, "translated host and port", 1234, newPort)
+}
+
+func TestClusterConfig_translateAddressAndPort_EmptyAddr(t *testing.T) {
+	cfg := NewCluster()
+	cfg.AddressTranslator = staticAddressTranslator(net.ParseIP("10.10.10.10"), 5432)
+	newAddr, newPort := cfg.translateAddressPort(net.IP([]byte{}), 0)
+	assertTrue(t, "translated address is still empty", len(newAddr) == 0)
+	assertEqual(t, "translated port", 0, newPort)
+}
+
+func TestClusterConfig_translateAddressAndPort_Success(t *testing.T) {
+	cfg := NewCluster()
+	cfg.AddressTranslator = staticAddressTranslator(net.ParseIP("10.10.10.10"), 5432)
+	newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 2345)
+	assertTrue(t, "translated address", net.ParseIP("10.10.10.10").Equal(newAddr))
+	assertEqual(t, "translated port", 5432, newPort)
+}

+ 51 - 0
common_test.go

@@ -8,6 +8,7 @@ import (
 	"sync"
 	"testing"
 	"time"
+	"net"
 )
 
 var (
@@ -143,3 +144,53 @@ func createSession(tb testing.TB) *Session {
 	cluster := createCluster()
 	return createSessionFromCluster(cluster, tb)
 }
+
+// createTestSession is hopefully moderately useful in actual unit tests
+func createTestSession() *Session {
+	config := NewCluster()
+	config.NumConns = 1
+	config.Timeout = 0
+	config.DisableInitialHostLookup = true
+	config.IgnorePeerAddr = true
+	config.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
+	session := &Session{
+		cfg:    *config,
+		connCfg: &ConnConfig{
+			Timeout: 10*time.Millisecond,
+			Keepalive: 0,
+		},
+		policy: config.PoolConfig.HostSelectionPolicy,
+	}
+	session.pool = config.PoolConfig.buildPool(session)
+	return session
+}
+
+func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
+	return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
+		return newAddr, newPort
+	})
+}
+
+func assertTrue(t *testing.T, description string, value bool) {
+	if !value {
+		t.Errorf("expected %s to be true", description)
+	}
+}
+
+func assertEqual(t *testing.T, description string, expected, actual interface{}) {
+	if expected != actual {
+		t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
+	}
+}
+
+func assertNil(t *testing.T, description string, actual interface{}) {
+	if actual != nil {
+		t.Errorf("expected %s to be (nil) but was (%+v) instead", description, actual)
+	}
+}
+
+func assertNotNil(t *testing.T, description string, actual interface{}) {
+	if actual == nil {
+		t.Errorf("expected %s not to be (nil)", description)
+	}
+}

+ 3 - 1
conn.go

@@ -172,7 +172,9 @@ func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, ses
 	}
 
 	// TODO(zariel): handle ipv6 zone
-	addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()
+	translatedPeer, translatedPort := session.cfg.translateAddressPort(host.Peer(), host.Port())
+	addr := (&net.TCPAddr{IP: translatedPeer, Port: translatedPort}).String()
+	//addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()
 
 	if cfg.tlsConfig != nil {
 		// the TLS config is safe to be reused by connections but it must not

+ 2 - 2
conn_test.go

@@ -474,7 +474,7 @@ func TestStream0(t *testing.T) {
 		}
 	})
 
-	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession())
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -509,7 +509,7 @@ func TestConnClosedBlocked(t *testing.T) {
 		t.Log(err)
 	})
 
-	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession())
 	if err != nil {
 		t.Fatal(err)
 	}

+ 2 - 2
control.go

@@ -318,8 +318,8 @@ func (c *controlConn) reconnect(refreshring bool) {
 		}
 	}
 
-	// TODO: should have our own roundrobbin for hosts so that we can try each
-	// in succession and guantee that we get a different host each time.
+	// TODO: should have our own round-robin for hosts so that we can try each
+	// in succession and guarantee that we get a different host each time.
 	if newConn == nil {
 		host := c.session.ring.rrHost()
 		if host == nil {

+ 0 - 1
events.go

@@ -205,7 +205,6 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 	s.pool.addHost(hostInfo)
 	s.policy.AddHost(hostInfo)
 	hostInfo.setState(NodeUp)
-
 	if s.control != nil && !s.cfg.IgnorePeerAddr {
 		s.hostSource.refreshRing()
 	}

+ 1 - 1
integration.sh

@@ -64,7 +64,7 @@ function run_tests() {
 
 	local args="-gocql.timeout=60s -runssl -proto=$proto -rf=3 -clusterSize=$clusterSize -autowait=2000ms -compressor=snappy -gocql.cversion=$version -cluster=$(ccm liveset) ./..."
 
-	go test -v -tags unit
+    go test -v -tags unit
 
 	if [ "$auth" = true ]
 	then

+ 131 - 0
session_connect_test.go

@@ -0,0 +1,131 @@
+package gocql
+
+import (
+	"golang.org/x/net/context"
+	"net"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+)
+
+type OneConnTestServer struct {
+	Err  error
+	Addr net.IP
+	Port int
+
+	listener   net.Listener
+	acceptChan chan struct{}
+	mu         sync.Mutex
+	closed     bool
+}
+
+func NewOneConnTestServer() (*OneConnTestServer, error) {
+	lstn, err := net.Listen("tcp4", "localhost:0")
+	if err != nil {
+		return nil, err
+	}
+	addr, port := parseAddressPort(lstn.Addr().String())
+	return &OneConnTestServer{
+		listener:   lstn,
+		acceptChan: make(chan struct{}),
+		Addr:       addr,
+		Port:       port,
+	}, nil
+}
+
+func (c *OneConnTestServer) Accepted() chan struct{} {
+	return c.acceptChan
+}
+
+func (c *OneConnTestServer) Close() {
+	c.lockedClose()
+}
+
+func (c *OneConnTestServer) Serve() {
+	conn, err := c.listener.Accept()
+	c.Err = err
+	if conn != nil {
+		conn.Close()
+	}
+	c.lockedClose()
+}
+
+func (c *OneConnTestServer) lockedClose() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if !c.closed {
+		close(c.acceptChan)
+		c.listener.Close()
+		c.closed = true
+	}
+}
+
+func parseAddressPort(hostPort string) (net.IP, int) {
+	host, portStr, err := net.SplitHostPort(hostPort)
+	if err != nil {
+		return net.ParseIP(""), 0
+	}
+	port, _ := strconv.Atoi(portStr)
+	return net.ParseIP(host), port
+}
+
+func testConnErrorHandler(t *testing.T) ConnErrorHandler {
+	return connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
+		t.Errorf("in connection handler: %v", err)
+	})
+}
+
+func assertConnectionEventually(t *testing.T, wait time.Duration, srvr *OneConnTestServer) {
+	ctx, cancel := context.WithTimeout(context.Background(), wait)
+	defer cancel()
+
+	select {
+	case <-ctx.Done():
+		if ctx.Err() != nil {
+			t.Errorf("waiting for connection: %v", ctx.Err())
+		}
+	case <-srvr.Accepted():
+		if srvr.Err != nil {
+			t.Errorf("accepting connection: %v", srvr.Err)
+		}
+	}
+}
+
+func TestSession_connect_WithNoTranslator(t *testing.T) {
+	srvr, err := NewOneConnTestServer()
+	assertNil(t, "error when creating tcp server", err)
+	defer srvr.Close()
+
+	session := createTestSession()
+	defer session.Close()
+
+	go srvr.Serve()
+
+	Connect(&HostInfo{
+		peer: srvr.Addr,
+		port: srvr.Port,
+	}, session.connCfg, testConnErrorHandler(t), session)
+
+	assertConnectionEventually(t, 500*time.Millisecond, srvr)
+}
+
+func TestSession_connect_WithTranslator(t *testing.T) {
+	srvr, err := NewOneConnTestServer()
+	assertNil(t, "error when creating tcp server", err)
+	defer srvr.Close()
+
+	session := createTestSession()
+	defer session.Close()
+	session.cfg.AddressTranslator = staticAddressTranslator(srvr.Addr, srvr.Port)
+
+	go srvr.Serve()
+
+	// the provided address will be translated
+	Connect(&HostInfo{
+		peer: net.ParseIP("10.10.10.10"),
+		port: 5432,
+	}, session.connCfg, testConnErrorHandler(t), session)
+
+	assertConnectionEventually(t, 500*time.Millisecond, srvr)
+}

+ 0 - 1
session_test.go

@@ -8,7 +8,6 @@ import (
 )
 
 func TestSessionAPI(t *testing.T) {
-
 	cfg := &ClusterConfig{}
 
 	s := &Session{

+ 1 - 1
udt_test.go

@@ -147,7 +147,7 @@ func TestUDT_Reflect(t *testing.T) {
 	}
 
 	if *retrievedHorse != *insertedHorse {
-		t.Fatalf("exepcted to get %+v got %+v", insertedHorse, retrievedHorse)
+		t.Fatalf("expected to get %+v got %+v", insertedHorse, retrievedHorse)
 	}
 }