common_test.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. package gocql
  2. import (
  3. "flag"
  4. "fmt"
  5. "log"
  6. "net"
  7. "reflect"
  8. "strings"
  9. "sync"
  10. "testing"
  11. "time"
  12. )
  13. var (
  14. flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
  15. flagProto = flag.Int("proto", 0, "protcol version")
  16. flagCQL = flag.String("cql", "3.0.0", "CQL version")
  17. flagRF = flag.Int("rf", 1, "replication factor for test keyspace")
  18. clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster")
  19. flagRetry = flag.Int("retries", 5, "number of times to retry queries")
  20. flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
  21. flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test")
  22. flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test")
  23. flagCompressTest = flag.String("compressor", "", "compressor to use")
  24. flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")
  25. flagCassVersion cassVersion
  26. )
  27. func init() {
  28. flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against")
  29. log.SetFlags(log.Lshortfile | log.LstdFlags)
  30. }
  31. func getClusterHosts() []string {
  32. return strings.Split(*flagCluster, ",")
  33. }
  34. func addSslOptions(cluster *ClusterConfig) *ClusterConfig {
  35. if *flagRunSslTest {
  36. cluster.SslOpts = &SslOptions{
  37. CertPath: "testdata/pki/gocql.crt",
  38. KeyPath: "testdata/pki/gocql.key",
  39. CaPath: "testdata/pki/ca.crt",
  40. EnableHostVerification: false,
  41. }
  42. }
  43. return cluster
  44. }
  45. var initOnce sync.Once
  46. func createTable(s *Session, table string) error {
  47. // lets just be really sure
  48. if err := s.control.awaitSchemaAgreement(); err != nil {
  49. log.Printf("error waiting for schema agreement pre create table=%q err=%v\n", table, err)
  50. return err
  51. }
  52. if err := s.Query(table).RetryPolicy(nil).Exec(); err != nil {
  53. log.Printf("error creating table table=%q err=%v\n", table, err)
  54. return err
  55. }
  56. if err := s.control.awaitSchemaAgreement(); err != nil {
  57. log.Printf("error waiting for schema agreement post create table=%q err=%v\n", table, err)
  58. return err
  59. }
  60. return nil
  61. }
  62. func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
  63. clusterHosts := getClusterHosts()
  64. cluster := NewCluster(clusterHosts...)
  65. cluster.ProtoVersion = *flagProto
  66. cluster.CQLVersion = *flagCQL
  67. cluster.Timeout = *flagTimeout
  68. cluster.Consistency = Quorum
  69. cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
  70. if *flagRetry > 0 {
  71. cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
  72. }
  73. switch *flagCompressTest {
  74. case "snappy":
  75. cluster.Compressor = &SnappyCompressor{}
  76. case "":
  77. default:
  78. panic("invalid compressor: " + *flagCompressTest)
  79. }
  80. cluster = addSslOptions(cluster)
  81. for _, opt := range opts {
  82. opt(cluster)
  83. }
  84. return cluster
  85. }
  86. func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
  87. // TODO: tb.Helper()
  88. c := *cluster
  89. c.Keyspace = "system"
  90. c.Timeout = 30 * time.Second
  91. session, err := c.CreateSession()
  92. if err != nil {
  93. panic(err)
  94. }
  95. defer session.Close()
  96. err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace)
  97. if err != nil {
  98. panic(fmt.Sprintf("unable to drop keyspace: %v", err))
  99. }
  100. err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s
  101. WITH replication = {
  102. 'class' : 'SimpleStrategy',
  103. 'replication_factor' : %d
  104. }`, keyspace, *flagRF))
  105. if err != nil {
  106. panic(fmt.Sprintf("unable to create keyspace: %v", err))
  107. }
  108. }
  109. func createSessionFromCluster(cluster *ClusterConfig, tb testing.TB) *Session {
  110. // Drop and re-create the keyspace once. Different tests should use their own
  111. // individual tables, but can assume that the table does not exist before.
  112. initOnce.Do(func() {
  113. createKeyspace(tb, cluster, "gocql_test")
  114. })
  115. cluster.Keyspace = "gocql_test"
  116. session, err := cluster.CreateSession()
  117. if err != nil {
  118. tb.Fatal("createSession:", err)
  119. }
  120. if err := session.control.awaitSchemaAgreement(); err != nil {
  121. tb.Fatal(err)
  122. }
  123. return session
  124. }
  125. func createSession(tb testing.TB, opts ...func(config *ClusterConfig)) *Session {
  126. cluster := createCluster(opts...)
  127. return createSessionFromCluster(cluster, tb)
  128. }
  129. // createTestSession is hopefully moderately useful in actual unit tests
  130. func createTestSession() *Session {
  131. config := NewCluster()
  132. config.NumConns = 1
  133. config.Timeout = 0
  134. config.DisableInitialHostLookup = true
  135. config.IgnorePeerAddr = true
  136. config.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
  137. session := &Session{
  138. cfg: *config,
  139. connCfg: &ConnConfig{
  140. Timeout: 10 * time.Millisecond,
  141. Keepalive: 0,
  142. },
  143. policy: config.PoolConfig.HostSelectionPolicy,
  144. }
  145. session.pool = config.PoolConfig.buildPool(session)
  146. return session
  147. }
  148. func createViews(t *testing.T, session *Session) {
  149. if err := session.Query(`
  150. CREATE TYPE IF NOT EXISTS gocql_test.basicView (
  151. birthday timestamp,
  152. nationality text,
  153. weight text,
  154. height text); `).Exec(); err != nil {
  155. t.Fatalf("failed to create view with err: %v", err)
  156. }
  157. }
  158. func createFunctions(t *testing.T, session *Session) {
  159. if err := session.Query(`
  160. CREATE OR REPLACE FUNCTION gocql_test.avgState ( state tuple<int,bigint>, val int )
  161. CALLED ON NULL INPUT
  162. RETURNS tuple<int,bigint>
  163. LANGUAGE java AS
  164. $$if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;$$; `).Exec(); err != nil {
  165. t.Fatalf("failed to create function with err: %v", err)
  166. }
  167. if err := session.Query(`
  168. CREATE OR REPLACE FUNCTION gocql_test.avgFinal ( state tuple<int,bigint> )
  169. CALLED ON NULL INPUT
  170. RETURNS double
  171. LANGUAGE java AS
  172. $$double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);$$
  173. `).Exec(); err != nil {
  174. t.Fatalf("failed to create function with err: %v", err)
  175. }
  176. }
  177. func createAggregate(t *testing.T, session *Session) {
  178. createFunctions(t, session)
  179. if err := session.Query(`
  180. CREATE OR REPLACE AGGREGATE gocql_test.average(int)
  181. SFUNC avgState
  182. STYPE tuple<int,bigint>
  183. FINALFUNC avgFinal
  184. INITCOND (0,0);
  185. `).Exec(); err != nil {
  186. t.Fatalf("failed to create aggregate with err: %v", err)
  187. }
  188. }
  189. func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
  190. return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
  191. return newAddr, newPort
  192. })
  193. }
  194. func assertTrue(t *testing.T, description string, value bool) {
  195. t.Helper()
  196. if !value {
  197. t.Fatalf("expected %s to be true", description)
  198. }
  199. }
  200. func assertEqual(t *testing.T, description string, expected, actual interface{}) {
  201. t.Helper()
  202. if expected != actual {
  203. t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
  204. }
  205. }
  206. func assertDeepEqual(t *testing.T, description string, expected, actual interface{}) {
  207. t.Helper()
  208. if !reflect.DeepEqual(expected, actual) {
  209. t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
  210. }
  211. }
  212. func assertNil(t *testing.T, description string, actual interface{}) {
  213. t.Helper()
  214. if actual != nil {
  215. t.Fatalf("expected %s to be (nil) but was (%+v) instead", description, actual)
  216. }
  217. }
  218. func assertNotNil(t *testing.T, description string, actual interface{}) {
  219. t.Helper()
  220. if actual == nil {
  221. t.Fatalf("expected %s not to be (nil)", description)
  222. }
  223. }