conn_test.go 7.5 KB


  1. // +build all unit
  2. package gocql
import (
	"crypto/tls"
	"crypto/x509"
	"io"
	"io/ioutil"
	"net"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)
  15. type TestServer struct {
  16. Address string
  17. t *testing.T
  18. nreq uint64
  19. listen net.Listener
  20. nKillReq uint64
  21. }
  22. func TestSimple(t *testing.T) {
  23. srv := NewTestServer(t)
  24. defer srv.Stop()
  25. db, err := NewCluster(srv.Address).CreateSession()
  26. if err != nil {
  27. t.Errorf("NewCluster: %v", err)
  28. }
  29. if err := db.Query("void").Exec(); err != nil {
  30. t.Error(err)
  31. }
  32. }
  33. func TestSSLSimple(t *testing.T) {
  34. srv := NewSSLTestServer(t)
  35. defer srv.Stop()
  36. db, err := createTestSslCluster(srv.Address).CreateSession()
  37. if err != nil {
  38. t.Errorf("NewCluster: %v", err)
  39. }
  40. if err := db.Query("void").Exec(); err != nil {
  41. t.Error(err)
  42. }
  43. }
  44. func createTestSslCluster(hosts string) *ClusterConfig {
  45. cluster := NewCluster(hosts)
  46. cluster.SslOpts = &SslOptions{
  47. CertPath: "testdata/pki/gocql.crt",
  48. KeyPath: "testdata/pki/gocql.key",
  49. CaPath: "testdata/pki/ca.crt",
  50. EnableHostVerification: false,
  51. }
  52. return cluster
  53. }
  54. func TestClosed(t *testing.T) {
  55. t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
  56. srv := NewTestServer(t)
  57. defer srv.Stop()
  58. session, err := NewCluster(srv.Address).CreateSession()
  59. if err != nil {
  60. t.Errorf("NewCluster: %v", err)
  61. }
  62. session.Close()
  63. if err := session.Query("void").Exec(); err != ErrSessionClosed {
  64. t.Errorf("expected %#v, got %#v", ErrSessionClosed, err)
  65. }
  66. }
  67. func TestTimeout(t *testing.T) {
  68. srv := NewTestServer(t)
  69. defer srv.Stop()
  70. db, err := NewCluster(srv.Address).CreateSession()
  71. if err != nil {
  72. t.Errorf("NewCluster: %v", err)
  73. }
  74. go func() {
  75. <-time.After(2 * time.Second)
  76. t.Fatal("no timeout")
  77. }()
  78. if err := db.Query("kill").Exec(); err == nil {
  79. t.Fatal("expected error")
  80. }
  81. }
  82. // TestQueryRetry will test to make sure that gocql will execute
  83. // the exact amount of retry queries designated by the user.
  84. func TestQueryRetry(t *testing.T) {
  85. srv := NewTestServer(t)
  86. defer srv.Stop()
  87. db, err := NewCluster(srv.Address).CreateSession()
  88. if err != nil {
  89. t.Errorf("NewCluster: %v", err)
  90. }
  91. go func() {
  92. <-time.After(5 * time.Second)
  93. t.Fatal("no timeout")
  94. }()
  95. rt := &SimpleRetryPolicy{NumRetries: 1}
  96. qry := db.Query("kill").RetryPolicy(rt)
  97. if err := qry.Exec(); err == nil {
  98. t.Fatal("expected error")
  99. }
  100. requests := srv.nKillReq
  101. if requests != uint64(qry.Attempts()) {
  102. t.Fatalf("expected requests %v to match query attemps %v", requests, qry.Attempts())
  103. }
  104. //Minus 1 from the requests variable since there is the initial query attempt
  105. if requests-1 != uint64(rt.NumRetries) {
  106. t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
  107. }
  108. }
  109. func TestSlowQuery(t *testing.T) {
  110. srv := NewTestServer(t)
  111. defer srv.Stop()
  112. db, err := NewCluster(srv.Address).CreateSession()
  113. if err != nil {
  114. t.Errorf("NewCluster: %v", err)
  115. }
  116. if err := db.Query("slow").Exec(); err != nil {
  117. t.Fatal(err)
  118. }
  119. }
  120. func TestRoundRobin(t *testing.T) {
  121. servers := make([]*TestServer, 5)
  122. addrs := make([]string, len(servers))
  123. for i := 0; i < len(servers); i++ {
  124. servers[i] = NewTestServer(t)
  125. addrs[i] = servers[i].Address
  126. defer servers[i].Stop()
  127. }
  128. cluster := NewCluster(addrs...)
  129. db, err := cluster.CreateSession()
  130. time.Sleep(1 * time.Second) //Sleep to allow the Cluster.fillPool to complete
  131. if err != nil {
  132. t.Errorf("NewCluster: %v", err)
  133. }
  134. var wg sync.WaitGroup
  135. wg.Add(5)
  136. for i := 0; i < 5; i++ {
  137. go func() {
  138. for j := 0; j < 5; j++ {
  139. if err := db.Query("void").Exec(); err != nil {
  140. t.Fatal(err)
  141. }
  142. }
  143. wg.Done()
  144. }()
  145. }
  146. wg.Wait()
  147. diff := 0
  148. for i := 1; i < len(servers); i++ {
  149. d := 0
  150. if servers[i].nreq > servers[i-1].nreq {
  151. d = int(servers[i].nreq - servers[i-1].nreq)
  152. } else {
  153. d = int(servers[i-1].nreq - servers[i].nreq)
  154. }
  155. if d > diff {
  156. diff = d
  157. }
  158. }
  159. if diff > 0 {
  160. t.Fatal("diff:", diff)
  161. }
  162. }
  163. func TestConnClosing(t *testing.T) {
  164. t.Skip("Skipping until test can be ran reliably")
  165. srv := NewTestServer(t)
  166. defer srv.Stop()
  167. db, err := NewCluster(srv.Address).CreateSession()
  168. if err != nil {
  169. t.Errorf("NewCluster: %v", err)
  170. }
  171. defer db.Close()
  172. numConns := db.cfg.NumConns
  173. count := db.cfg.NumStreams * numConns
  174. wg := &sync.WaitGroup{}
  175. wg.Add(count)
  176. for i := 0; i < count; i++ {
  177. go func(wg *sync.WaitGroup) {
  178. wg.Done()
  179. db.Query("kill").Exec()
  180. }(wg)
  181. }
  182. wg.Wait()
  183. time.Sleep(1 * time.Second) //Sleep so the fillPool can complete.
  184. pool := db.Pool.(ConnectionPool)
  185. conns := pool.Size()
  186. if conns != numConns {
  187. t.Fatalf("Expected to have %d connections but have %d", numConns, conns)
  188. }
  189. }
  190. func NewTestServer(t *testing.T) *TestServer {
  191. laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
  192. if err != nil {
  193. t.Fatal(err)
  194. }
  195. listen, err := net.ListenTCP("tcp", laddr)
  196. if err != nil {
  197. t.Fatal(err)
  198. }
  199. srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
  200. go srv.serve()
  201. return srv
  202. }
  203. func NewSSLTestServer(t *testing.T) *TestServer {
  204. pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
  205. certPool := x509.NewCertPool()
  206. if !certPool.AppendCertsFromPEM(pem) {
  207. t.Errorf("Failed parsing or appending certs")
  208. }
  209. mycert, err := tls.LoadX509KeyPair("testdata/pki/cassandra.crt", "testdata/pki/cassandra.key")
  210. if err != nil {
  211. t.Errorf("could not load cert")
  212. }
  213. config := &tls.Config{
  214. Certificates: []tls.Certificate{mycert},
  215. RootCAs: certPool,
  216. }
  217. listen, err := tls.Listen("tcp", "127.0.0.1:0", config)
  218. if err != nil {
  219. t.Fatal(err)
  220. }
  221. srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
  222. go srv.serve()
  223. return srv
  224. }
  225. func (srv *TestServer) serve() {
  226. defer srv.listen.Close()
  227. for {
  228. conn, err := srv.listen.Accept()
  229. if err != nil {
  230. break
  231. }
  232. go func(conn net.Conn) {
  233. defer conn.Close()
  234. for {
  235. frame := srv.readFrame(conn)
  236. atomic.AddUint64(&srv.nreq, 1)
  237. srv.process(frame, conn)
  238. }
  239. }(conn)
  240. }
  241. }
  242. func (srv *TestServer) Stop() {
  243. srv.listen.Close()
  244. }
  245. func (srv *TestServer) process(frame frame, conn net.Conn) {
  246. switch frame[3] {
  247. case opStartup:
  248. frame = frame[:headerSize]
  249. frame.setHeader(protoResponse, 0, frame[2], opReady)
  250. case opQuery:
  251. input := frame
  252. input.skipHeader()
  253. query := strings.TrimSpace(input.readLongString())
  254. frame = frame[:headerSize]
  255. frame.setHeader(protoResponse, 0, frame[2], opResult)
  256. first := query
  257. if n := strings.Index(query, " "); n > 0 {
  258. first = first[:n]
  259. }
  260. switch strings.ToLower(first) {
  261. case "kill":
  262. atomic.AddUint64(&srv.nKillReq, 1)
  263. select {}
  264. case "slow":
  265. go func() {
  266. <-time.After(1 * time.Second)
  267. frame.writeInt(resultKindVoid)
  268. frame.setLength(len(frame) - headerSize)
  269. if _, err := conn.Write(frame); err != nil {
  270. return
  271. }
  272. }()
  273. return
  274. case "use":
  275. frame.writeInt(3)
  276. frame.writeString(strings.TrimSpace(query[3:]))
  277. case "void":
  278. frame.writeInt(resultKindVoid)
  279. default:
  280. frame.writeInt(resultKindVoid)
  281. }
  282. default:
  283. frame = frame[:headerSize]
  284. frame.setHeader(protoResponse, 0, frame[2], opError)
  285. frame.writeInt(0)
  286. frame.writeString("not supported")
  287. }
  288. frame.setLength(len(frame) - headerSize)
  289. if _, err := conn.Write(frame); err != nil {
  290. return
  291. }
  292. }
  293. func (srv *TestServer) readFrame(conn net.Conn) frame {
  294. frame := make(frame, headerSize, headerSize+512)
  295. if _, err := io.ReadFull(conn, frame); err != nil {
  296. srv.t.Fatal(err)
  297. }
  298. if n := frame.Length(); n > 0 {
  299. frame.grow(n)
  300. if _, err := io.ReadFull(conn, frame[headerSize:]); err != nil {
  301. srv.t.Fatal(err)
  302. }
  303. }
  304. return frame
  305. }