common_test.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. package gocql
  2. import (
  3. "flag"
  4. "fmt"
  5. "log"
  6. "net"
  7. "reflect"
  8. "strings"
  9. "sync"
  10. "testing"
  11. "time"
  12. )
  13. var (
  14. flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
  15. flagProto = flag.Int("proto", 0, "protcol version")
  16. flagCQL = flag.String("cql", "3.0.0", "CQL version")
  17. flagRF = flag.Int("rf", 1, "replication factor for test keyspace")
  18. clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster")
  19. flagRetry = flag.Int("retries", 5, "number of times to retry queries")
  20. flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
  21. flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test")
  22. flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test")
  23. flagCompressTest = flag.String("compressor", "", "compressor to use")
  24. flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")
  25. flagCassVersion cassVersion
  26. clusterHosts []string
  27. )
  28. func init() {
  29. flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against")
  30. flag.Parse()
  31. clusterHosts = strings.Split(*flagCluster, ",")
  32. log.SetFlags(log.Lshortfile | log.LstdFlags)
  33. }
  34. func addSslOptions(cluster *ClusterConfig) *ClusterConfig {
  35. if *flagRunSslTest {
  36. cluster.SslOpts = &SslOptions{
  37. CertPath: "testdata/pki/gocql.crt",
  38. KeyPath: "testdata/pki/gocql.key",
  39. CaPath: "testdata/pki/ca.crt",
  40. EnableHostVerification: false,
  41. }
  42. }
  43. return cluster
  44. }
  45. var initOnce sync.Once
  46. func createTable(s *Session, table string) error {
  47. // lets just be really sure
  48. if err := s.control.awaitSchemaAgreement(); err != nil {
  49. log.Printf("error waiting for schema agreement pre create table=%q err=%v\n", table, err)
  50. return err
  51. }
  52. if err := s.Query(table).RetryPolicy(nil).Exec(); err != nil {
  53. log.Printf("error creating table table=%q err=%v\n", table, err)
  54. return err
  55. }
  56. if err := s.control.awaitSchemaAgreement(); err != nil {
  57. log.Printf("error waiting for schema agreement post create table=%q err=%v\n", table, err)
  58. return err
  59. }
  60. return nil
  61. }
  62. func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
  63. cluster := NewCluster(clusterHosts...)
  64. cluster.ProtoVersion = *flagProto
  65. cluster.CQLVersion = *flagCQL
  66. cluster.Timeout = *flagTimeout
  67. cluster.Consistency = Quorum
  68. cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
  69. if *flagRetry > 0 {
  70. cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
  71. }
  72. switch *flagCompressTest {
  73. case "snappy":
  74. cluster.Compressor = &SnappyCompressor{}
  75. case "":
  76. default:
  77. panic("invalid compressor: " + *flagCompressTest)
  78. }
  79. cluster = addSslOptions(cluster)
  80. for _, opt := range opts {
  81. opt(cluster)
  82. }
  83. return cluster
  84. }
  85. func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
  86. // TODO: tb.Helper()
  87. c := *cluster
  88. c.Keyspace = "system"
  89. c.Timeout = 30 * time.Second
  90. session, err := c.CreateSession()
  91. if err != nil {
  92. panic(err)
  93. }
  94. defer session.Close()
  95. err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace)
  96. if err != nil {
  97. panic(fmt.Sprintf("unable to drop keyspace: %v", err))
  98. }
  99. err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s
  100. WITH replication = {
  101. 'class' : 'SimpleStrategy',
  102. 'replication_factor' : %d
  103. }`, keyspace, *flagRF))
  104. if err != nil {
  105. panic(fmt.Sprintf("unable to create keyspace: %v", err))
  106. }
  107. }
  108. func createSessionFromCluster(cluster *ClusterConfig, tb testing.TB) *Session {
  109. // Drop and re-create the keyspace once. Different tests should use their own
  110. // individual tables, but can assume that the table does not exist before.
  111. initOnce.Do(func() {
  112. createKeyspace(tb, cluster, "gocql_test")
  113. })
  114. cluster.Keyspace = "gocql_test"
  115. session, err := cluster.CreateSession()
  116. if err != nil {
  117. tb.Fatal("createSession:", err)
  118. }
  119. if err := session.control.awaitSchemaAgreement(); err != nil {
  120. tb.Fatal(err)
  121. }
  122. return session
  123. }
  124. func createSession(tb testing.TB, opts ...func(config *ClusterConfig)) *Session {
  125. cluster := createCluster(opts...)
  126. return createSessionFromCluster(cluster, tb)
  127. }
  128. // createTestSession is hopefully moderately useful in actual unit tests
  129. func createTestSession() *Session {
  130. config := NewCluster()
  131. config.NumConns = 1
  132. config.Timeout = 0
  133. config.DisableInitialHostLookup = true
  134. config.IgnorePeerAddr = true
  135. config.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
  136. session := &Session{
  137. cfg: *config,
  138. connCfg: &ConnConfig{
  139. Timeout: 10 * time.Millisecond,
  140. Keepalive: 0,
  141. },
  142. policy: config.PoolConfig.HostSelectionPolicy,
  143. }
  144. session.pool = config.PoolConfig.buildPool(session)
  145. return session
  146. }
  147. func createViews(t *testing.T, session *Session) {
  148. if err := session.Query(`
  149. CREATE TYPE IF NOT EXISTS gocql_test.basicView (
  150. birthday timestamp,
  151. nationality text,
  152. weight text,
  153. height text); `).Exec(); err != nil {
  154. t.Fatalf("failed to create view with err: %v", err)
  155. }
  156. }
  157. func createFunctions(t *testing.T, session *Session) {
  158. if err := session.Query(`
  159. CREATE OR REPLACE FUNCTION gocql_test.avgState ( state tuple<int,bigint>, val int )
  160. CALLED ON NULL INPUT
  161. RETURNS tuple<int,bigint>
  162. LANGUAGE java AS
  163. $$if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;$$; `).Exec(); err != nil {
  164. t.Fatalf("failed to create function with err: %v", err)
  165. }
  166. if err := session.Query(`
  167. CREATE OR REPLACE FUNCTION gocql_test.avgFinal ( state tuple<int,bigint> )
  168. CALLED ON NULL INPUT
  169. RETURNS double
  170. LANGUAGE java AS
  171. $$double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);$$
  172. `).Exec(); err != nil {
  173. t.Fatalf("failed to create function with err: %v", err)
  174. }
  175. }
  176. func createAggregate(t *testing.T, session *Session) {
  177. createFunctions(t, session)
  178. if err := session.Query(`
  179. CREATE OR REPLACE AGGREGATE gocql_test.average(int)
  180. SFUNC avgState
  181. STYPE tuple<int,bigint>
  182. FINALFUNC avgFinal
  183. INITCOND (0,0);
  184. `).Exec(); err != nil {
  185. t.Fatalf("failed to create aggregate with err: %v", err)
  186. }
  187. }
  188. func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
  189. return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
  190. return newAddr, newPort
  191. })
  192. }
  193. func assertTrue(t *testing.T, description string, value bool) {
  194. if !value {
  195. t.Errorf("expected %s to be true", description)
  196. }
  197. }
  198. func assertEqual(t *testing.T, description string, expected, actual interface{}) {
  199. if expected != actual {
  200. t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
  201. }
  202. }
  203. func assertDeepEqual(t *testing.T, description string, expected, actual interface{}) {
  204. if !reflect.DeepEqual(expected, actual) {
  205. t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
  206. }
  207. }
  208. func assertNil(t *testing.T, description string, actual interface{}) {
  209. if actual != nil {
  210. t.Errorf("expected %s to be (nil) but was (%+v) instead", description, actual)
  211. }
  212. }
  213. func assertNotNil(t *testing.T, description string, actual interface{}) {
  214. if actual == nil {
  215. t.Errorf("expected %s not to be (nil)", description)
  216. }
  217. }