浏览代码

conn: only setup the TLS config once

we can reuse the TLS Config between connections, so avoid the
expensive read and cert parsing when dialing connections, instead
do this in the connection pool once during startup.
Chris Bannister 10 年之前
父节点
当前提交
ef3e59c0a6
共有 4 个文件被更改,包括 66 次插入36 次删除
  1. 4 1
      cluster.go
  2. 5 28
      conn.go
  3. 48 4
      connectionpool.go
  4. 9 3
      session_test.go

+ 4 - 1
cluster.go

@@ -113,7 +113,10 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 		cfg.NumStreams = maxStreams
 	}
 
-	pool := cfg.ConnPoolType(cfg)
+	pool, err := cfg.ConnPoolType(cfg)
+	if err != nil {
+		return nil, err
+	}
 
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	stmtsLRU.Lock()

+ 5 - 28
conn.go

@@ -7,11 +7,9 @@ package gocql
 import (
 	"bufio"
 	"crypto/tls"
-	"crypto/x509"
 	"errors"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"log"
 	"net"
 	"strconv"
@@ -81,7 +79,7 @@ type ConnConfig struct {
 	Compressor    Compressor
 	Authenticator Authenticator
 	Keepalive     time.Duration
-	SslOpts       *SslOptions
+	TLSConfig     *tls.Config
 }
 
 // Conn is a single connection to a Cassandra node. It can be used to execute
@@ -115,31 +113,10 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		conn net.Conn
 	)
 
-	if cfg.SslOpts != nil {
-		certPool := x509.NewCertPool()
-		//ca cert is optional
-		if cfg.SslOpts.CaPath != "" {
-			pem, err := ioutil.ReadFile(cfg.SslOpts.CaPath)
-			if err != nil {
-				return nil, err
-			}
-			if !certPool.AppendCertsFromPEM(pem) {
-				return nil, errors.New("Failed parsing or appending certs")
-			}
-		}
-
-		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
-		if err != nil {
-			return nil, err
-		}
-
-		config := tls.Config{
-			Certificates: []tls.Certificate{mycert},
-			RootCAs:      certPool,
-		}
-
-		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
-		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
+	if cfg.TLSConfig != nil {
+		// the TLS config is safe to be reused by connections but it must not
+		// be modified after being used.
+		if conn, err = tls.Dial("tcp", addr, cfg.TLSConfig); err != nil {
 			return nil, err
 		}
 	} else if conn, err = net.DialTimeout("tcp", addr, cfg.Timeout); err != nil {

+ 48 - 4
connectionpool.go

@@ -1,6 +1,11 @@
 package gocql
 
 import (
+	"crypto/tls"
+	"crypto/x509"
+	"errors"
+	"fmt"
+	"io/ioutil"
 	"log"
 	"sync"
 	"time"
@@ -91,7 +96,7 @@ type ConnectionPool interface {
 }
 
 //NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
-type NewPoolFunc func(*ClusterConfig) ConnectionPool
+type NewPoolFunc func(*ClusterConfig) (ConnectionPool, error)
 
 //SimplePool is the current implementation of the connection pool inside gocql. This
 //pool is meant to be a simple default used by gocql so users can get up and running
@@ -115,11 +120,42 @@ type SimplePool struct {
 	quit     bool
 	quitWait chan bool
 	quitOnce sync.Once
+
+	tlsConfig *tls.Config
+}
+
+func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
+	certPool := x509.NewCertPool()
+	// ca cert is optional
+	if sslOpts.CaPath != "" {
+		pem, err := ioutil.ReadFile(sslOpts.CaPath)
+		if err != nil {
+			return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
+		}
+
+		if !certPool.AppendCertsFromPEM(pem) {
+			return nil, errors.New("connectionpool: failed parsing or CA certs")
+		}
+	}
+
+	mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
+	if err != nil {
+		return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
+	}
+
+	config := &tls.Config{
+		Certificates: []tls.Certificate{mycert},
+		RootCAs:      certPool,
+	}
+
+	config.InsecureSkipVerify = !sslOpts.EnableHostVerification
+
+	return config, nil
 }
 
 //NewSimplePool is the function used by gocql to create the simple connection pool.
 //This is the default if no other pool type is specified.
-func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
+func NewSimplePool(cfg *ClusterConfig) (ConnectionPool, error) {
 	pool := &SimplePool{
 		cfg:          cfg,
 		hostPool:     NewRoundRobin(),
@@ -137,6 +173,14 @@ func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
 		pool.hosts[host] = &HostInfo{Peer: host}
 	}
 
+	if cfg.SslOpts != nil {
+		config, err := setupTLSConfig(cfg.SslOpts)
+		if err != nil {
+			return nil, err
+		}
+		pool.tlsConfig = config
+	}
+
 	//Walk through connecting to hosts. As soon as one host connects
 	//defer the remaining connections to cluster.fillPool()
 	for i := 0; i < len(cfg.Hosts); i++ {
@@ -149,7 +193,7 @@ func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
 		}
 	}
 
-	return pool
+	return pool, nil
 }
 
 func (c *SimplePool) connect(addr string) error {
@@ -162,7 +206,7 @@ func (c *SimplePool) connect(addr string) error {
 		Compressor:    c.cfg.Compressor,
 		Authenticator: c.cfg.Authenticator,
 		Keepalive:     c.cfg.SocketKeepalive,
-		SslOpts:       c.cfg.SslOpts,
+		TLSConfig:     c.tlsConfig,
 	}
 
 	conn, err := Connect(addr, cfg, c)

+ 9 - 3
session_test.go

@@ -9,7 +9,10 @@ import (
 func TestSessionAPI(t *testing.T) {
 
 	cfg := ClusterConfig{}
-	pool := NewSimplePool(&cfg)
+	pool, err := NewSimplePool(&cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 	s := NewSession(pool, cfg)
 
@@ -60,7 +63,7 @@ func TestSessionAPI(t *testing.T) {
 
 	testBatch := s.NewBatch(LoggedBatch)
 	testBatch.Query("test")
-	err := s.ExecuteBatch(testBatch)
+	err = s.ExecuteBatch(testBatch)
 
 	if err != ErrNoConnections {
 		t.Fatalf("expected session.ExecuteBatch to return '%v', got '%v'", ErrNoConnections, err)
@@ -151,7 +154,10 @@ func TestBatchBasicAPI(t *testing.T) {
 
 	cfg := ClusterConfig{}
 	cfg.RetryPolicy = &SimpleRetryPolicy{NumRetries: 2}
-	pool := NewSimplePool(&cfg)
+	pool, err := NewSimplePool(&cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 	s := NewSession(pool, cfg)
 	b := s.NewBatch(UnloggedBatch)