Ver código fonte

adding AddressTranslator interface and impl for use in ec2

This change introduces the AddressTranslator interface, which is
intended to translate peer addresses just before creating a connection
to those nodes. The primary use -- which is driving the change -- is
to be able to translate public IPs to private IPs in ec2.

This solution is common among other CQL driver implementations. The
specific implementation here also follows the convention set by
HostFilter.

Signed-off-by: Justin "Gus" Knowlden <gus@gusg.us>
Charles Frantz 9 anos atrás
pai
commit
d93ce32f1b
13 arquivos alterados com 325 adições e 9 exclusões
  1. 26 0
      address_translators.go
  2. 34 0
      address_translators_test.go
  3. 21 0
      cluster.go
  4. 53 0
      cluster_test.go
  5. 51 0
      common_test.go
  6. 3 1
      conn.go
  7. 2 2
      conn_test.go
  8. 2 2
      control.go
  9. 0 1
      events.go
  10. 1 1
      integration.sh
  11. 131 0
      session_connect_test.go
  12. 0 1
      session_test.go
  13. 1 1
      udt_test.go

+ 26 - 0
address_translators.go

@@ -0,0 +1,26 @@
+package gocql
+
+import "net"
+
+// AddressTranslator provides a way to translate node addresses (and ports) that are
+// discovered or received as a node event. This can be useful in an ec2 environment,
+// for instance, to translate public IPs to private IPs.
+type AddressTranslator interface {
+	// Translate will translate the provided address and/or port to another
+	// address and/or port. If no translation is possible, Translate will return the
+	// address and port provided to it.
+	Translate(addr net.IP, port int) (net.IP, int)
+}
+
+type AddressTranslatorFunc func(addr net.IP, port int) (net.IP, int)
+
+func (fn AddressTranslatorFunc) Translate(addr net.IP, port int) (net.IP, int) {
+	return fn(addr, port)
+}
+
+// IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op.
+func IdentityTranslator() AddressTranslator {
+	return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
+		return addr, port
+	})
+}

+ 34 - 0
address_translators_test.go

@@ -0,0 +1,34 @@
+package gocql
+
+import (
+	"net"
+	"testing"
+)
+
+func TestIdentityAddressTranslator_NilAddrAndZeroPort(t *testing.T) {
+	var tr AddressTranslator = IdentityTranslator()
+	hostIP := net.ParseIP("")
+	if hostIP != nil {
+		t.Errorf("expected host ip to be (nil) but was (%+v) instead", hostIP)
+	}
+
+	addr, port := tr.Translate(hostIP, 0)
+	if addr != nil {
+		t.Errorf("expected translated host to be (nil) but was (%+v) instead", addr)
+	}
+	assertEqual(t, "translated port", 0, port)
+}
+
+func TestIdentityAddressTranslator_HostProvided(t *testing.T) {
+	var tr AddressTranslator = IdentityTranslator()
+	hostIP := net.ParseIP("10.1.2.3")
+	if hostIP == nil {
+		t.Error("expected host ip not to be (nil)")
+	}
+
+	addr, port := tr.Translate(hostIP, 9042)
+	if !hostIP.Equal(addr) {
+		t.Errorf("expected translated addr to be (%+v) but was (%+v) instead", hostIP, addr)
+	}
+	assertEqual(t, "translated port", 9042, port)
+}

+ 21 - 0
cluster.go

@@ -6,6 +6,8 @@ package gocql
 
 import (
 	"errors"
+	"log"
+	"net"
 	"time"
 )
 
@@ -75,6 +77,10 @@ type ClusterConfig struct {
 	// via Discovery
 	HostFilter HostFilter
 
+	// AddressTranslator will translate addresses found on peer discovery and/or
+	// node change events.
+	AddressTranslator AddressTranslator
+
 	// If IgnorePeerAddr is true and the address in system.peers does not match
 	// the supplied host by either initial hosts or discovered via events then the
 	// host will be replaced with the supplied address.
@@ -146,6 +152,21 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 	return NewSession(*cfg)
 }
 
+// translateAddressPort is a helper method that will use the given AddressTranslator
+// if defined, to translate the given address and port into a possibly new address
+// and port, If no AddressTranslator or if an error occurs, the given address and
+// port will be returned.
+func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, int) {
+	if cfg.AddressTranslator == nil || len(addr) == 0 {
+		return addr, port
+	}
+	newAddr, newPort := cfg.AddressTranslator.Translate(addr, port)
+	if gocqlDebug {
+		log.Printf("gocql: translating address '%v:%d' to '%v:%d'", addr, port, newAddr, newPort)
+	}
+	return newAddr, newPort
+}
+
 var (
 	ErrNoHosts              = errors.New("no hosts provided")
 	ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")

+ 53 - 0
cluster_test.go

@@ -0,0 +1,53 @@
+package gocql
+
+import (
+	"testing"
+	"time"
+	"net"
+)
+
+func TestNewCluster_Defaults(t *testing.T) {
+	cfg := NewCluster()
+	assertEqual(t, "cluster config cql version", "3.0.0", cfg.CQLVersion)
+	assertEqual(t, "cluster config timeout", 600*time.Millisecond, cfg.Timeout)
+	assertEqual(t, "cluster config port", 9042, cfg.Port)
+	assertEqual(t, "cluster config num-conns", 2, cfg.NumConns)
+	assertEqual(t, "cluster config consistency", Quorum, cfg.Consistency)
+	assertEqual(t, "cluster config max prepared statements", defaultMaxPreparedStmts, cfg.MaxPreparedStmts)
+	assertEqual(t, "cluster config max routing key info", 1000, cfg.MaxRoutingKeyInfo)
+	assertEqual(t, "cluster config page-size", 5000, cfg.PageSize)
+	assertEqual(t, "cluster config default timestamp", true, cfg.DefaultTimestamp)
+	assertEqual(t, "cluster config max wait schema agreement", 60*time.Second, cfg.MaxWaitSchemaAgreement)
+	assertEqual(t, "cluster config reconnect interval", 60*time.Second, cfg.ReconnectInterval)
+}
+
+func TestNewCluster_WithHosts(t *testing.T) {
+	cfg := NewCluster("addr1", "addr2")
+	assertEqual(t, "cluster config hosts length", 2, len(cfg.Hosts))
+	assertEqual(t, "cluster config host 0", "addr1", cfg.Hosts[0])
+	assertEqual(t, "cluster config host 1", "addr2", cfg.Hosts[1])
+}
+
+func TestClusterConfig_translateAddressAndPort_NilTranslator(t *testing.T) {
+	cfg := NewCluster()
+	assertNil(t, "cluster config address translator", cfg.AddressTranslator)
+	newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 1234)
+	assertTrue(t, "same address as provided", net.ParseIP("10.0.0.1").Equal(newAddr))
+	assertEqual(t, "translated host and port", 1234, newPort)
+}
+
+func TestClusterConfig_translateAddressAndPort_EmptyAddr(t *testing.T) {
+	cfg := NewCluster()
+	cfg.AddressTranslator = staticAddressTranslator(net.ParseIP("10.10.10.10"), 5432)
+	newAddr, newPort := cfg.translateAddressPort(net.IP([]byte{}), 0)
+	assertTrue(t, "translated address is still empty", len(newAddr) == 0)
+	assertEqual(t, "translated port", 0, newPort)
+}
+
+func TestClusterConfig_translateAddressAndPort_Success(t *testing.T) {
+	cfg := NewCluster()
+	cfg.AddressTranslator = staticAddressTranslator(net.ParseIP("10.10.10.10"), 5432)
+	newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 2345)
+	assertTrue(t, "translated address", net.ParseIP("10.10.10.10").Equal(newAddr))
+	assertEqual(t, "translated port", 5432, newPort)
+}

+ 51 - 0
common_test.go

@@ -8,6 +8,7 @@ import (
 	"sync"
 	"testing"
 	"time"
+	"net"
 )
 
 var (
@@ -143,3 +144,53 @@ func createSession(tb testing.TB) *Session {
 	cluster := createCluster()
 	return createSessionFromCluster(cluster, tb)
 }
+
+// createTestSession is hopefully moderately useful in actual unit tests
+func createTestSession() *Session {
+	config := NewCluster()
+	config.NumConns = 1
+	config.Timeout = 0
+	config.DisableInitialHostLookup = true
+	config.IgnorePeerAddr = true
+	config.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
+	session := &Session{
+		cfg:    *config,
+		connCfg: &ConnConfig{
+			Timeout: 10*time.Millisecond,
+			Keepalive: 0,
+		},
+		policy: config.PoolConfig.HostSelectionPolicy,
+	}
+	session.pool = config.PoolConfig.buildPool(session)
+	return session
+}
+
+func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
+	return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
+		return newAddr, newPort
+	})
+}
+
+func assertTrue(t *testing.T, description string, value bool) {
+	if !value {
+		t.Errorf("expected %s to be true", description)
+	}
+}
+
+func assertEqual(t *testing.T, description string, expected, actual interface{}) {
+	if expected != actual {
+		t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
+	}
+}
+
+func assertNil(t *testing.T, description string, actual interface{}) {
+	if actual != nil {
+		t.Errorf("expected %s to be (nil) but was (%+v) instead", description, actual)
+	}
+}
+
+func assertNotNil(t *testing.T, description string, actual interface{}) {
+	if actual == nil {
+		t.Errorf("expected %s not to be (nil)", description)
+	}
+}

+ 3 - 1
conn.go

@@ -172,7 +172,9 @@ func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, ses
 	}
 
 	// TODO(zariel): handle ipv6 zone
-	addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()
+	translatedPeer, translatedPort := session.cfg.translateAddressPort(host.Peer(), host.Port())
+	addr := (&net.TCPAddr{IP: translatedPeer, Port: translatedPort}).String()
+	//addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()
 
 	if cfg.tlsConfig != nil {
 		// the TLS config is safe to be reused by connections but it must not

+ 2 - 2
conn_test.go

@@ -474,7 +474,7 @@ func TestStream0(t *testing.T) {
 		}
 	})
 
-	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession())
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -509,7 +509,7 @@ func TestConnClosedBlocked(t *testing.T) {
 		t.Log(err)
 	})
 
-	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession())
 	if err != nil {
 		t.Fatal(err)
 	}

+ 2 - 2
control.go

@@ -318,8 +318,8 @@ func (c *controlConn) reconnect(refreshring bool) {
 		}
 	}
 
-	// TODO: should have our own roundrobbin for hosts so that we can try each
-	// in succession and guantee that we get a different host each time.
+	// TODO: should have our own round-robin for hosts so that we can try each
+	// in succession and guarantee that we get a different host each time.
 	if newConn == nil {
 		host := c.session.ring.rrHost()
 		if host == nil {

+ 0 - 1
events.go

@@ -205,7 +205,6 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 	s.pool.addHost(hostInfo)
 	s.policy.AddHost(hostInfo)
 	hostInfo.setState(NodeUp)
-
 	if s.control != nil && !s.cfg.IgnorePeerAddr {
 		s.hostSource.refreshRing()
 	}

+ 1 - 1
integration.sh

@@ -64,7 +64,7 @@ function run_tests() {
 
 	local args="-gocql.timeout=60s -runssl -proto=$proto -rf=3 -clusterSize=$clusterSize -autowait=2000ms -compressor=snappy -gocql.cversion=$version -cluster=$(ccm liveset) ./..."
 
-	go test -v -tags unit
+    go test -v -tags unit
 
 	if [ "$auth" = true ]
 	then

+ 131 - 0
session_connect_test.go

@@ -0,0 +1,131 @@
+package gocql
+
+import (
+	"golang.org/x/net/context"
+	"net"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+)
+
+type OneConnTestServer struct {
+	Err  error
+	Addr net.IP
+	Port int
+
+	listener   net.Listener
+	acceptChan chan struct{}
+	mu         sync.Mutex
+	closed     bool
+}
+
+func NewOneConnTestServer() (*OneConnTestServer, error) {
+	lstn, err := net.Listen("tcp4", "localhost:0")
+	if err != nil {
+		return nil, err
+	}
+	addr, port := parseAddressPort(lstn.Addr().String())
+	return &OneConnTestServer{
+		listener:   lstn,
+		acceptChan: make(chan struct{}),
+		Addr:       addr,
+		Port:       port,
+	}, nil
+}
+
+func (c *OneConnTestServer) Accepted() chan struct{} {
+	return c.acceptChan
+}
+
+func (c *OneConnTestServer) Close() {
+	c.lockedClose()
+}
+
+func (c *OneConnTestServer) Serve() {
+	conn, err := c.listener.Accept()
+	c.Err = err
+	if conn != nil {
+		conn.Close()
+	}
+	c.lockedClose()
+}
+
+func (c *OneConnTestServer) lockedClose() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if !c.closed {
+		close(c.acceptChan)
+		c.listener.Close()
+		c.closed = true
+	}
+}
+
+func parseAddressPort(hostPort string) (net.IP, int) {
+	host, portStr, err := net.SplitHostPort(hostPort)
+	if err != nil {
+		return net.ParseIP(""), 0
+	}
+	port, _ := strconv.Atoi(portStr)
+	return net.ParseIP(host), port
+}
+
+func testConnErrorHandler(t *testing.T) ConnErrorHandler {
+	return connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
+		t.Errorf("in connection handler: %v", err)
+	})
+}
+
+func assertConnectionEventually(t *testing.T, wait time.Duration, srvr *OneConnTestServer) {
+	ctx, cancel := context.WithTimeout(context.Background(), wait)
+	defer cancel()
+
+	select {
+	case <-ctx.Done():
+		if ctx.Err() != nil {
+			t.Errorf("waiting for connection: %v", ctx.Err())
+		}
+	case <-srvr.Accepted():
+		if srvr.Err != nil {
+			t.Errorf("accepting connection: %v", srvr.Err)
+		}
+	}
+}
+
+func TestSession_connect_WithNoTranslator(t *testing.T) {
+	srvr, err := NewOneConnTestServer()
+	assertNil(t, "error when creating tcp server", err)
+	defer srvr.Close()
+
+	session := createTestSession()
+	defer session.Close()
+
+	go srvr.Serve()
+
+	Connect(&HostInfo{
+		peer: srvr.Addr,
+		port: srvr.Port,
+	}, session.connCfg, testConnErrorHandler(t), session)
+
+	assertConnectionEventually(t, 500*time.Millisecond, srvr)
+}
+
+func TestSession_connect_WithTranslator(t *testing.T) {
+	srvr, err := NewOneConnTestServer()
+	assertNil(t, "error when creating tcp server", err)
+	defer srvr.Close()
+
+	session := createTestSession()
+	defer session.Close()
+	session.cfg.AddressTranslator = staticAddressTranslator(srvr.Addr, srvr.Port)
+
+	go srvr.Serve()
+
+	// the provided address will be translated
+	Connect(&HostInfo{
+		peer: net.ParseIP("10.10.10.10"),
+		port: 5432,
+	}, session.connCfg, testConnErrorHandler(t), session)
+
+	assertConnectionEventually(t, 500*time.Millisecond, srvr)
+}

+ 0 - 1
session_test.go

@@ -8,7 +8,6 @@ import (
 )
 
 func TestSessionAPI(t *testing.T) {
-
 	cfg := &ClusterConfig{}
 
 	s := &Session{

+ 1 - 1
udt_test.go

@@ -147,7 +147,7 @@ func TestUDT_Reflect(t *testing.T) {
 	}
 
 	if *retrievedHorse != *insertedHorse {
-		t.Fatalf("exepcted to get %+v got %+v", insertedHorse, retrievedHorse)
+		t.Fatalf("expected to get %+v got %+v", insertedHorse, retrievedHorse)
 	}
 }