浏览代码

conn: only setup the TLS config once

we can reuse the TLS Config between connections, so avoid the
expensive read and cert parsing when dialing connections, instead
do this in the connection pool once during startup.
Chris Bannister 10 年之前
父节点
当前提交
ef3e59c0a6
共有 4 个文件被更改,包括 66 次插入36 次删除
  1. 4 1
      cluster.go
  2. 5 28
      conn.go
  3. 48 4
      connectionpool.go
  4. 9 3
      session_test.go

+ 4 - 1
cluster.go

@@ -113,7 +113,10 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 		cfg.NumStreams = maxStreams
 		cfg.NumStreams = maxStreams
 	}
 	}
 
 
-	pool := cfg.ConnPoolType(cfg)
+	pool, err := cfg.ConnPoolType(cfg)
+	if err != nil {
+		return nil, err
+	}
 
 
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	stmtsLRU.Lock()
 	stmtsLRU.Lock()

+ 5 - 28
conn.go

@@ -7,11 +7,9 @@ package gocql
 import (
 import (
 	"bufio"
 	"bufio"
 	"crypto/tls"
 	"crypto/tls"
-	"crypto/x509"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"io/ioutil"
 	"log"
 	"log"
 	"net"
 	"net"
 	"strconv"
 	"strconv"
@@ -81,7 +79,7 @@ type ConnConfig struct {
 	Compressor    Compressor
 	Compressor    Compressor
 	Authenticator Authenticator
 	Authenticator Authenticator
 	Keepalive     time.Duration
 	Keepalive     time.Duration
-	SslOpts       *SslOptions
+	TLSConfig     *tls.Config
 }
 }
 
 
 // Conn is a single connection to a Cassandra node. It can be used to execute
 // Conn is a single connection to a Cassandra node. It can be used to execute
@@ -115,31 +113,10 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		conn net.Conn
 		conn net.Conn
 	)
 	)
 
 
-	if cfg.SslOpts != nil {
-		certPool := x509.NewCertPool()
-		//ca cert is optional
-		if cfg.SslOpts.CaPath != "" {
-			pem, err := ioutil.ReadFile(cfg.SslOpts.CaPath)
-			if err != nil {
-				return nil, err
-			}
-			if !certPool.AppendCertsFromPEM(pem) {
-				return nil, errors.New("Failed parsing or appending certs")
-			}
-		}
-
-		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
-		if err != nil {
-			return nil, err
-		}
-
-		config := tls.Config{
-			Certificates: []tls.Certificate{mycert},
-			RootCAs:      certPool,
-		}
-
-		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
-		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
+	if cfg.TLSConfig != nil {
+		// the TLS config is safe to be reused by connections but it must not
+		// be modified after being used.
+		if conn, err = tls.Dial("tcp", addr, cfg.TLSConfig); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 	} else if conn, err = net.DialTimeout("tcp", addr, cfg.Timeout); err != nil {
 	} else if conn, err = net.DialTimeout("tcp", addr, cfg.Timeout); err != nil {

+ 48 - 4
connectionpool.go

@@ -1,6 +1,11 @@
 package gocql
 package gocql
 
 
 import (
 import (
+	"crypto/tls"
+	"crypto/x509"
+	"errors"
+	"fmt"
+	"io/ioutil"
 	"log"
 	"log"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -91,7 +96,7 @@ type ConnectionPool interface {
 }
 }
 
 
 //NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
 //NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
-type NewPoolFunc func(*ClusterConfig) ConnectionPool
+type NewPoolFunc func(*ClusterConfig) (ConnectionPool, error)
 
 
 //SimplePool is the current implementation of the connection pool inside gocql. This
 //SimplePool is the current implementation of the connection pool inside gocql. This
 //pool is meant to be a simple default used by gocql so users can get up and running
 //pool is meant to be a simple default used by gocql so users can get up and running
@@ -115,11 +120,42 @@ type SimplePool struct {
 	quit     bool
 	quit     bool
 	quitWait chan bool
 	quitWait chan bool
 	quitOnce sync.Once
 	quitOnce sync.Once
+
+	tlsConfig *tls.Config
+}
+
+func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
+	certPool := x509.NewCertPool()
+	// ca cert is optional
+	if sslOpts.CaPath != "" {
+		pem, err := ioutil.ReadFile(sslOpts.CaPath)
+		if err != nil {
+			return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
+		}
+
+		if !certPool.AppendCertsFromPEM(pem) {
+			return nil, errors.New("connectionpool: failed parsing or CA certs")
+		}
+	}
+
+	mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
+	if err != nil {
+		return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
+	}
+
+	config := &tls.Config{
+		Certificates: []tls.Certificate{mycert},
+		RootCAs:      certPool,
+	}
+
+	config.InsecureSkipVerify = !sslOpts.EnableHostVerification
+
+	return config, nil
 }
 }
 
 
 //NewSimplePool is the function used by gocql to create the simple connection pool.
 //NewSimplePool is the function used by gocql to create the simple connection pool.
 //This is the default if no other pool type is specified.
 //This is the default if no other pool type is specified.
-func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
+func NewSimplePool(cfg *ClusterConfig) (ConnectionPool, error) {
 	pool := &SimplePool{
 	pool := &SimplePool{
 		cfg:          cfg,
 		cfg:          cfg,
 		hostPool:     NewRoundRobin(),
 		hostPool:     NewRoundRobin(),
@@ -137,6 +173,14 @@ func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
 		pool.hosts[host] = &HostInfo{Peer: host}
 		pool.hosts[host] = &HostInfo{Peer: host}
 	}
 	}
 
 
+	if cfg.SslOpts != nil {
+		config, err := setupTLSConfig(cfg.SslOpts)
+		if err != nil {
+			return nil, err
+		}
+		pool.tlsConfig = config
+	}
+
 	//Walk through connecting to hosts. As soon as one host connects
 	//Walk through connecting to hosts. As soon as one host connects
 	//defer the remaining connections to cluster.fillPool()
 	//defer the remaining connections to cluster.fillPool()
 	for i := 0; i < len(cfg.Hosts); i++ {
 	for i := 0; i < len(cfg.Hosts); i++ {
@@ -149,7 +193,7 @@ func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
 		}
 		}
 	}
 	}
 
 
-	return pool
+	return pool, nil
 }
 }
 
 
 func (c *SimplePool) connect(addr string) error {
 func (c *SimplePool) connect(addr string) error {
@@ -162,7 +206,7 @@ func (c *SimplePool) connect(addr string) error {
 		Compressor:    c.cfg.Compressor,
 		Compressor:    c.cfg.Compressor,
 		Authenticator: c.cfg.Authenticator,
 		Authenticator: c.cfg.Authenticator,
 		Keepalive:     c.cfg.SocketKeepalive,
 		Keepalive:     c.cfg.SocketKeepalive,
-		SslOpts:       c.cfg.SslOpts,
+		TLSConfig:     c.tlsConfig,
 	}
 	}
 
 
 	conn, err := Connect(addr, cfg, c)
 	conn, err := Connect(addr, cfg, c)

+ 9 - 3
session_test.go

@@ -9,7 +9,10 @@ import (
 func TestSessionAPI(t *testing.T) {
 func TestSessionAPI(t *testing.T) {
 
 
 	cfg := ClusterConfig{}
 	cfg := ClusterConfig{}
-	pool := NewSimplePool(&cfg)
+	pool, err := NewSimplePool(&cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 
 	s := NewSession(pool, cfg)
 	s := NewSession(pool, cfg)
 
 
@@ -60,7 +63,7 @@ func TestSessionAPI(t *testing.T) {
 
 
 	testBatch := s.NewBatch(LoggedBatch)
 	testBatch := s.NewBatch(LoggedBatch)
 	testBatch.Query("test")
 	testBatch.Query("test")
-	err := s.ExecuteBatch(testBatch)
+	err = s.ExecuteBatch(testBatch)
 
 
 	if err != ErrNoConnections {
 	if err != ErrNoConnections {
 		t.Fatalf("expected session.ExecuteBatch to return '%v', got '%v'", ErrNoConnections, err)
 		t.Fatalf("expected session.ExecuteBatch to return '%v', got '%v'", ErrNoConnections, err)
@@ -151,7 +154,10 @@ func TestBatchBasicAPI(t *testing.T) {
 
 
 	cfg := ClusterConfig{}
 	cfg := ClusterConfig{}
 	cfg.RetryPolicy = &SimpleRetryPolicy{NumRetries: 2}
 	cfg.RetryPolicy = &SimpleRetryPolicy{NumRetries: 2}
-	pool := NewSimplePool(&cfg)
+	pool, err := NewSimplePool(&cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 
 	s := NewSession(pool, cfg)
 	s := NewSession(pool, cfg)
 	b := s.NewBatch(UnloggedBatch)
 	b := s.NewBatch(UnloggedBatch)