common_test.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package gocql
import (
	"flag"
	"fmt"
	"log"
	"net"
	"reflect"
	"strings"
	"sync"
	"testing"
	"time"
)
  12. var (
  13. flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
  14. flagProto = flag.Int("proto", 0, "protcol version")
  15. flagCQL = flag.String("cql", "3.0.0", "CQL version")
  16. flagRF = flag.Int("rf", 1, "replication factor for test keyspace")
  17. clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster")
  18. flagRetry = flag.Int("retries", 5, "number of times to retry queries")
  19. flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
  20. flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test")
  21. flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test")
  22. flagCompressTest = flag.String("compressor", "", "compressor to use")
  23. flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")
  24. flagCassVersion cassVersion
  25. clusterHosts []string
  26. )
  27. func init() {
  28. flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against")
  29. flag.Parse()
  30. clusterHosts = strings.Split(*flagCluster, ",")
  31. log.SetFlags(log.Lshortfile | log.LstdFlags)
  32. }
  33. func addSslOptions(cluster *ClusterConfig) *ClusterConfig {
  34. if *flagRunSslTest {
  35. cluster.SslOpts = &SslOptions{
  36. CertPath: "testdata/pki/gocql.crt",
  37. KeyPath: "testdata/pki/gocql.key",
  38. CaPath: "testdata/pki/ca.crt",
  39. EnableHostVerification: false,
  40. }
  41. }
  42. return cluster
  43. }
  44. var initOnce sync.Once
  45. func createTable(s *Session, table string) error {
  46. // lets just be really sure
  47. if err := s.control.awaitSchemaAgreement(); err != nil {
  48. log.Printf("error waiting for schema agreement pre create table=%q err=%v\n", table, err)
  49. return err
  50. }
  51. if err := s.Query(table).RetryPolicy(nil).Exec(); err != nil {
  52. log.Printf("error creating table table=%q err=%v\n", table, err)
  53. return err
  54. }
  55. if err := s.control.awaitSchemaAgreement(); err != nil {
  56. log.Printf("error waiting for schema agreement post create table=%q err=%v\n", table, err)
  57. return err
  58. }
  59. return nil
  60. }
  61. func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
  62. cluster := NewCluster(clusterHosts...)
  63. cluster.ProtoVersion = *flagProto
  64. cluster.CQLVersion = *flagCQL
  65. cluster.Timeout = *flagTimeout
  66. cluster.Consistency = Quorum
  67. cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
  68. if *flagRetry > 0 {
  69. cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
  70. }
  71. switch *flagCompressTest {
  72. case "snappy":
  73. cluster.Compressor = &SnappyCompressor{}
  74. case "":
  75. default:
  76. panic("invalid compressor: " + *flagCompressTest)
  77. }
  78. cluster = addSslOptions(cluster)
  79. for _, opt := range opts {
  80. opt(cluster)
  81. }
  82. return cluster
  83. }
  84. func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
  85. // TODO: tb.Helper()
  86. c := *cluster
  87. c.Keyspace = "system"
  88. c.Timeout = 30 * time.Second
  89. session, err := c.CreateSession()
  90. if err != nil {
  91. panic(err)
  92. }
  93. defer session.Close()
  94. err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace)
  95. if err != nil {
  96. panic(fmt.Sprintf("unable to drop keyspace: %v", err))
  97. }
  98. err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s
  99. WITH replication = {
  100. 'class' : 'SimpleStrategy',
  101. 'replication_factor' : %d
  102. }`, keyspace, *flagRF))
  103. if err != nil {
  104. panic(fmt.Sprintf("unable to create keyspace: %v", err))
  105. }
  106. }
  107. func createSessionFromCluster(cluster *ClusterConfig, tb testing.TB) *Session {
  108. // Drop and re-create the keyspace once. Different tests should use their own
  109. // individual tables, but can assume that the table does not exist before.
  110. initOnce.Do(func() {
  111. createKeyspace(tb, cluster, "gocql_test")
  112. })
  113. cluster.Keyspace = "gocql_test"
  114. session, err := cluster.CreateSession()
  115. if err != nil {
  116. tb.Fatal("createSession:", err)
  117. }
  118. if err := session.control.awaitSchemaAgreement(); err != nil {
  119. tb.Fatal(err)
  120. }
  121. return session
  122. }
  123. func createSession(tb testing.TB, opts ...func(config *ClusterConfig)) *Session {
  124. cluster := createCluster(opts...)
  125. return createSessionFromCluster(cluster, tb)
  126. }
  127. // createTestSession is hopefully moderately useful in actual unit tests
  128. func createTestSession() *Session {
  129. config := NewCluster()
  130. config.NumConns = 1
  131. config.Timeout = 0
  132. config.DisableInitialHostLookup = true
  133. config.IgnorePeerAddr = true
  134. config.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
  135. session := &Session{
  136. cfg: *config,
  137. connCfg: &ConnConfig{
  138. Timeout: 10 * time.Millisecond,
  139. Keepalive: 0,
  140. },
  141. policy: config.PoolConfig.HostSelectionPolicy,
  142. }
  143. session.pool = config.PoolConfig.buildPool(session)
  144. return session
  145. }
  146. func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
  147. return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
  148. return newAddr, newPort
  149. })
  150. }
  151. func assertTrue(t *testing.T, description string, value bool) {
  152. if !value {
  153. t.Errorf("expected %s to be true", description)
  154. }
  155. }
  156. func assertEqual(t *testing.T, description string, expected, actual interface{}) {
  157. if expected != actual {
  158. t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
  159. }
  160. }
  161. func assertNil(t *testing.T, description string, actual interface{}) {
  162. if actual != nil {
  163. t.Errorf("expected %s to be (nil) but was (%+v) instead", description, actual)
  164. }
  165. }
  166. func assertNotNil(t *testing.T, description string, actual interface{}) {
  167. if actual == nil {
  168. t.Errorf("expected %s not to be (nil)", description)
  169. }
  170. }