Selaa lähdekoodia

add SslOptions when NewCluster or createCluster is called if -runssl is enabled

xoraes 11 vuotta sitten
vanhempi
commit
d3232227c6
1 muutettua tiedostoa jossa 18 lisäystä ja 5 poistoa
  1. 18 5
      cassandra_test.go

+ 18 - 5
cassandra_test.go

@@ -19,7 +19,6 @@ import (
 	"testing"
 	"time"
 	"unicode"
-
 	"speter.net/go/exp/math/dec/inf"
 )
 
@@ -36,12 +35,23 @@ var (
 )
 
 func init() {
-
 	flag.Parse()
 	clusterHosts = strings.Split(*flagCluster, ",")
 	log.SetFlags(log.Lshortfile | log.LstdFlags)
 }
 
+func addSslOptions(cluster *ClusterConfig) *ClusterConfig {
+	if *flagRunSslTest {
+		cluster.SslOpts = &SslOptions{
+			CertPath:               "testdata/pki/gocql.crt",
+			KeyPath:                "testdata/pki/gocql.key",
+			CaPath:                 "testdata/pki/ca.crt",
+			EnableHostVerification: false,
+		}
+	}
+	return cluster
+}
+
 var initOnce sync.Once
 
 func createTable(s *Session, table string) error {
@@ -53,7 +63,7 @@ func createTable(s *Session, table string) error {
 	return err
 }
 
-var createCluster = func() *ClusterConfig {
+func createCluster() *ClusterConfig {
 	cluster := NewCluster(clusterHosts...)
 	cluster.ProtoVersion = *flagProto
 	cluster.CQLVersion = *flagCQL
@@ -62,7 +72,7 @@ var createCluster = func() *ClusterConfig {
 	if *flagRetry > 0 {
 		cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
 	}
-
+	cluster = addSslOptions(cluster)
 	return cluster
 }
 
@@ -115,7 +125,7 @@ func TestRingDiscovery(t *testing.T) {
 		cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
 	}
 	cluster.DiscoverHosts = true
-
+	cluster = addSslOptions(cluster)
 	session, err := cluster.CreateSession()
 	if err != nil {
 		t.Errorf("got error connecting to the cluster %v", err)
@@ -137,6 +147,7 @@ func TestRingDiscovery(t *testing.T) {
 
 func TestEmptyHosts(t *testing.T) {
 	cluster := NewCluster()
+	cluster = addSslOptions(cluster)
 	if session, err := cluster.CreateSession(); err == nil {
 		session.Close()
 		t.Error("expected err, got nil")
@@ -163,6 +174,7 @@ func TestInvalidKeyspace(t *testing.T) {
 	cluster.ProtoVersion = *flagProto
 	cluster.CQLVersion = *flagCQL
 	cluster.Keyspace = "invalidKeyspace"
+	cluster = addSslOptions(cluster)
 	session, err := cluster.CreateSession()
 	if err != nil {
 		if err != ErrNoConnectionsStarted {
@@ -473,6 +485,7 @@ func TestCreateSessionTimeout(t *testing.T) {
 		t.Fatal("no startup timeout")
 	}()
 	c := NewCluster("127.0.0.1:1")
+	c = addSslOptions(c)
 	_, err := c.CreateSession()
 
 	if err == nil {