浏览代码

Merge pull request #286 from phillipCouto/address_formatting

Moved address formatting logic to separate function
Ben Hood 11 年之前
父节点
当前提交
397f2cd84f
共有 3 个文件被更改,包括 30 次插入15 次删除
  1. 12 0
      conn.go
  2. 14 0
      conn_test.go
  3. 4 15
      connectionpool.go

+ 12 - 0
conn.go

@@ -12,6 +12,8 @@ import (
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
+	"strconv"
+	"strings"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
@@ -21,6 +23,16 @@ const defaultFrameSize = 4096
 const flagResponse = 0x80
 const flagResponse = 0x80
 const maskVersion = 0x7F
 const maskVersion = 0x7F
 
 
+//JoinHostPort is a utility to return a address string that can be used
+//gocql.Conn to form a connection with a host.
+func JoinHostPort(addr string, port int) string {
+	addr = strings.TrimSpace(addr)
+	if _, _, err := net.SplitHostPort(addr); err != nil {
+		addr = net.JoinHostPort(addr, strconv.Itoa(port))
+	}
+	return addr
+}
+
 type Authenticator interface {
 type Authenticator interface {
 	Challenge(req []byte) (resp []byte, auth Authenticator, err error)
 	Challenge(req []byte) (resp []byte, auth Authenticator, err error)
 	Success(data []byte) error
 	Success(data []byte) error

+ 14 - 0
conn_test.go

@@ -15,6 +15,20 @@ import (
 	"time"
 	"time"
 )
 )
 
 
+func TestJoinHostPort(t *testing.T) {
+	tests := map[string]string{
+		"127.0.0.1:0":                                 JoinHostPort("127.0.0.1", 0),
+		"127.0.0.1:1":                                 JoinHostPort("127.0.0.1:1", 9142),
+		"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:0": JoinHostPort("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 0),
+		"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1": JoinHostPort("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1", 9142),
+	}
+	for k, v := range tests {
+		if k != v {
+			t.Fatalf("expected '%v', got '%v'", k, v)
+		}
+	}
+}
+
 type TestServer struct {
 type TestServer struct {
 	Address  string
 	Address  string
 	t        *testing.T
 	t        *testing.T

+ 4 - 15
connectionpool.go

@@ -2,9 +2,6 @@ package gocql
 
 
 import (
 import (
 	"log"
 	"log"
-	"net"
-	"strconv"
-	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
 )
 )
@@ -29,10 +26,8 @@ Example of Single Connection Pool:
 	}
 	}
 
 
 	func NewSingleConnection(cfg *ClusterConfig) ConnectionPool {
 	func NewSingleConnection(cfg *ClusterConfig) ConnectionPool {
-		addr := strings.TrimSpace(cfg.Hosts[0])
-		if strings.Index(addr, ":") < 0 {
-			addr = fmt.Sprintf("%s:%d", addr, cfg.Port)
-		}
+		addr := JoinHostPort(cfg.Hosts[0], cfg.Port)
+
 		connCfg := ConnConfig{
 		connCfg := ConnConfig{
 			ProtoVersion:  cfg.ProtoVersion,
 			ProtoVersion:  cfg.ProtoVersion,
 			CQLVersion:    cfg.CQLVersion,
 			CQLVersion:    cfg.CQLVersion,
@@ -145,10 +140,7 @@ func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
 	//Walk through connecting to hosts. As soon as one host connects
 	//Walk through connecting to hosts. As soon as one host connects
 	//defer the remaining connections to cluster.fillPool()
 	//defer the remaining connections to cluster.fillPool()
 	for i := 0; i < len(cfg.Hosts); i++ {
 	for i := 0; i < len(cfg.Hosts); i++ {
-		addr := strings.TrimSpace(cfg.Hosts[i])
-		if _, _, err := net.SplitHostPort(addr); err != nil {
-			addr = net.JoinHostPort(addr, strconv.Itoa(cfg.Port))
-		}
+		addr := JoinHostPort(cfg.Hosts[i], cfg.Port)
 
 
 		if pool.connect(addr) == nil {
 		if pool.connect(addr) == nil {
 			pool.cFillingPool <- 1
 			pool.cFillingPool <- 1
@@ -236,10 +228,7 @@ func (c *SimplePool) fillPool() {
 
 
 	//Walk through list of defined hosts
 	//Walk through list of defined hosts
 	for host := range c.hosts {
 	for host := range c.hosts {
-		addr := strings.TrimSpace(host)
-		if _, _, err := net.SplitHostPort(addr); err != nil {
-			addr = net.JoinHostPort(addr, strconv.Itoa(c.cfg.Port))
-		}
+		addr := JoinHostPort(host, c.cfg.Port)
 
 
 		numConns := 1
 		numConns := 1
 		//See if the host already has connections in the pool
 		//See if the host already has connections in the pool