Browse Source

Merge pull request #327 from Zariel/dont-read-ca-dial

conn: only setup the TLS config once
Chris Bannister 10 years ago
parent
commit
cb5d261084
4 changed files with 66 additions and 36 deletions
  1. 4 1
      cluster.go
  2. 5 28
      conn.go
  3. 48 4
      connectionpool.go
  4. 9 3
      session_test.go

+ 4 - 1
cluster.go

@@ -113,7 +113,10 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 		cfg.NumStreams = maxStreams
 		cfg.NumStreams = maxStreams
 	}
 	}
 
 
-	pool := cfg.ConnPoolType(cfg)
+	pool, err := cfg.ConnPoolType(cfg)
+	if err != nil {
+		return nil, err
+	}
 
 
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	stmtsLRU.Lock()
 	stmtsLRU.Lock()

+ 5 - 28
conn.go

@@ -7,11 +7,9 @@ package gocql
 import (
 import (
 	"bufio"
 	"bufio"
 	"crypto/tls"
 	"crypto/tls"
-	"crypto/x509"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"io/ioutil"
 	"log"
 	"log"
 	"net"
 	"net"
 	"strconv"
 	"strconv"
@@ -81,7 +79,7 @@ type ConnConfig struct {
 	Compressor    Compressor
 	Compressor    Compressor
 	Authenticator Authenticator
 	Authenticator Authenticator
 	Keepalive     time.Duration
 	Keepalive     time.Duration
-	SslOpts       *SslOptions
+	tlsConfig     *tls.Config
 }
 }
 
 
 // Conn is a single connection to a Cassandra node. It can be used to execute
 // Conn is a single connection to a Cassandra node. It can be used to execute
@@ -115,31 +113,10 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		conn net.Conn
 		conn net.Conn
 	)
 	)
 
 
-	if cfg.SslOpts != nil {
-		certPool := x509.NewCertPool()
-		//ca cert is optional
-		if cfg.SslOpts.CaPath != "" {
-			pem, err := ioutil.ReadFile(cfg.SslOpts.CaPath)
-			if err != nil {
-				return nil, err
-			}
-			if !certPool.AppendCertsFromPEM(pem) {
-				return nil, errors.New("Failed parsing or appending certs")
-			}
-		}
-
-		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
-		if err != nil {
-			return nil, err
-		}
-
-		config := tls.Config{
-			Certificates: []tls.Certificate{mycert},
-			RootCAs:      certPool,
-		}
-
-		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
-		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
+	if cfg.tlsConfig != nil {
+		// the TLS config is safe to be reused by connections but it must not
+		// be modified after being used.
+		if conn, err = tls.Dial("tcp", addr, cfg.tlsConfig); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 	} else if conn, err = net.DialTimeout("tcp", addr, cfg.Timeout); err != nil {
 	} else if conn, err = net.DialTimeout("tcp", addr, cfg.Timeout); err != nil {

+ 48 - 4
connectionpool.go

@@ -1,6 +1,11 @@
 package gocql
 package gocql
 
 
 import (
 import (
+	"crypto/tls"
+	"crypto/x509"
+	"errors"
+	"fmt"
+	"io/ioutil"
 	"log"
 	"log"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -91,7 +96,7 @@ type ConnectionPool interface {
 }
 }
 
 
 //NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
 //NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
-type NewPoolFunc func(*ClusterConfig) ConnectionPool
+type NewPoolFunc func(*ClusterConfig) (ConnectionPool, error)
 
 
 //SimplePool is the current implementation of the connection pool inside gocql. This
 //SimplePool is the current implementation of the connection pool inside gocql. This
 //pool is meant to be a simple default used by gocql so users can get up and running
 //pool is meant to be a simple default used by gocql so users can get up and running
@@ -115,11 +120,42 @@ type SimplePool struct {
 	quit     bool
 	quit     bool
 	quitWait chan bool
 	quitWait chan bool
 	quitOnce sync.Once
 	quitOnce sync.Once
+
+	tlsConfig *tls.Config
+}
+
+func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
+	certPool := x509.NewCertPool()
+	// ca cert is optional
+	if sslOpts.CaPath != "" {
+		pem, err := ioutil.ReadFile(sslOpts.CaPath)
+		if err != nil {
+			return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
+		}
+
+		if !certPool.AppendCertsFromPEM(pem) {
+			return nil, errors.New("connectionpool: failed parsing or CA certs")
+		}
+	}
+
+	mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
+	if err != nil {
+		return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
+	}
+
+	config := &tls.Config{
+		Certificates: []tls.Certificate{mycert},
+		RootCAs:      certPool,
+	}
+
+	config.InsecureSkipVerify = !sslOpts.EnableHostVerification
+
+	return config, nil
 }
 }
 
 
 //NewSimplePool is the function used by gocql to create the simple connection pool.
 //NewSimplePool is the function used by gocql to create the simple connection pool.
 //This is the default if no other pool type is specified.
 //This is the default if no other pool type is specified.
-func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
+func NewSimplePool(cfg *ClusterConfig) (ConnectionPool, error) {
 	pool := &SimplePool{
 	pool := &SimplePool{
 		cfg:          cfg,
 		cfg:          cfg,
 		hostPool:     NewRoundRobin(),
 		hostPool:     NewRoundRobin(),
@@ -137,6 +173,14 @@ func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
 		pool.hosts[host] = &HostInfo{Peer: host}
 		pool.hosts[host] = &HostInfo{Peer: host}
 	}
 	}
 
 
+	if cfg.SslOpts != nil {
+		config, err := setupTLSConfig(cfg.SslOpts)
+		if err != nil {
+			return nil, err
+		}
+		pool.tlsConfig = config
+	}
+
 	//Walk through connecting to hosts. As soon as one host connects
 	//Walk through connecting to hosts. As soon as one host connects
 	//defer the remaining connections to cluster.fillPool()
 	//defer the remaining connections to cluster.fillPool()
 	for i := 0; i < len(cfg.Hosts); i++ {
 	for i := 0; i < len(cfg.Hosts); i++ {
@@ -149,7 +193,7 @@ func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
 		}
 		}
 	}
 	}
 
 
-	return pool
+	return pool, nil
 }
 }
 
 
 func (c *SimplePool) connect(addr string) error {
 func (c *SimplePool) connect(addr string) error {
@@ -162,7 +206,7 @@ func (c *SimplePool) connect(addr string) error {
 		Compressor:    c.cfg.Compressor,
 		Compressor:    c.cfg.Compressor,
 		Authenticator: c.cfg.Authenticator,
 		Authenticator: c.cfg.Authenticator,
 		Keepalive:     c.cfg.SocketKeepalive,
 		Keepalive:     c.cfg.SocketKeepalive,
-		SslOpts:       c.cfg.SslOpts,
+		tlsConfig:     c.tlsConfig,
 	}
 	}
 
 
 	conn, err := Connect(addr, cfg, c)
 	conn, err := Connect(addr, cfg, c)

+ 9 - 3
session_test.go

@@ -9,7 +9,10 @@ import (
 func TestSessionAPI(t *testing.T) {
 func TestSessionAPI(t *testing.T) {
 
 
 	cfg := ClusterConfig{}
 	cfg := ClusterConfig{}
-	pool := NewSimplePool(&cfg)
+	pool, err := NewSimplePool(&cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 
 	s := NewSession(pool, cfg)
 	s := NewSession(pool, cfg)
 
 
@@ -60,7 +63,7 @@ func TestSessionAPI(t *testing.T) {
 
 
 	testBatch := s.NewBatch(LoggedBatch)
 	testBatch := s.NewBatch(LoggedBatch)
 	testBatch.Query("test")
 	testBatch.Query("test")
-	err := s.ExecuteBatch(testBatch)
+	err = s.ExecuteBatch(testBatch)
 
 
 	if err != ErrNoConnections {
 	if err != ErrNoConnections {
 		t.Fatalf("expected session.ExecuteBatch to return '%v', got '%v'", ErrNoConnections, err)
 		t.Fatalf("expected session.ExecuteBatch to return '%v', got '%v'", ErrNoConnections, err)
@@ -151,7 +154,10 @@ func TestBatchBasicAPI(t *testing.T) {
 
 
 	cfg := ClusterConfig{}
 	cfg := ClusterConfig{}
 	cfg.RetryPolicy = &SimpleRetryPolicy{NumRetries: 2}
 	cfg.RetryPolicy = &SimpleRetryPolicy{NumRetries: 2}
-	pool := NewSimplePool(&cfg)
+	pool, err := NewSimplePool(&cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 
 	s := NewSession(pool, cfg)
 	s := NewSession(pool, cfg)
 	b := s.NewBatch(UnloggedBatch)
 	b := s.NewBatch(UnloggedBatch)