Bläddra i källkod

added ssl/tls feature for gocql

Nick DHUPIA 11 år sedan
förälder
incheckning
12fdedb378
3 ändrade filer med 37 tillägg och 2 borttagningar
  1. 1 0
      cluster.go
  2. 35 2
      conn.go
  3. 1 0
      connectionpool.go

+ 1 - 0
cluster.go

@@ -62,6 +62,7 @@ type ClusterConfig struct {
 	DiscoverHosts    bool          // If set, gocql will attempt to automatically discover other members of the Cassandra cluster (default: false)
 	MaxPreparedStmts int           // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
 	Discovery        DiscoveryConfig
+	SslOpts          *SslOptions
 }
 
 // NewCluster generates a new config for the default cluster implementation.

+ 35 - 2
conn.go

@@ -6,8 +6,11 @@ package gocql
 
 import (
 	"bufio"
+	"crypto/tls"
+	"crypto/x509"
 	"errors"
 	"fmt"
+	"io/ioutil"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -44,6 +47,13 @@ func (p PasswordAuthenticator) Success(data []byte) error {
 	return nil
 }
 
+type SslOptions struct {
+	CertPath               string
+	KeyPath                string
+	CaPath                 string
+	EnableHostVerification bool //most of the time people will want to not verify host they are connecting to
+}
+
 type ConnConfig struct {
 	ProtoVersion  int
 	CQLVersion    string
@@ -52,6 +62,7 @@ type ConnConfig struct {
 	Compressor    Compressor
 	Authenticator Authenticator
 	Keepalive     time.Duration
+	SslOpts       *SslOptions
 }
 
 // Conn is a single connection to a Cassandra node. It can be used to execute
@@ -80,8 +91,30 @@ type Conn struct {
 // Connect establishes a connection to a Cassandra node.
 // You must also call the Serve method before you can execute any queries.
 func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
-	conn, err := net.DialTimeout("tcp", addr, cfg.Timeout)
-	if err != nil {
+	var (
+		err  error
+		conn net.Conn
+	)
+	if cfg.SslOpts != nil {
+		pem, err := ioutil.ReadFile(cfg.SslOpts.CaPath)
+		certPool := x509.NewCertPool()
+		if !certPool.AppendCertsFromPEM(pem) {
+			panic("Failed parsing or appending certs")
+		}
+		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
+		if err != nil {
+			panic(err)
+		}
+
+		config := tls.Config{
+			Certificates: []tls.Certificate{mycert},
+			RootCAs:      certPool,
+		}
+		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
+		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
+			return nil, err
+		}
+	} else if conn, err = net.DialTimeout("tcp", addr, cfg.Timeout); err != nil {
 		return nil, err
 	}
 

+ 1 - 0
connectionpool.go

@@ -169,6 +169,7 @@ func (c *SimplePool) connect(addr string) error {
 		Compressor:    c.cfg.Compressor,
 		Authenticator: c.cfg.Authenticator,
 		Keepalive:     c.cfg.SocketKeepalive,
+		SslOpts:       c.cfg.SslOpts,
 	}
 
 	conn, err := Connect(addr, cfg, c)