conn_test.go 32 KB


  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build all unit
  5. package gocql
  6. import (
  7. "bufio"
  8. "bytes"
  9. "context"
  10. "crypto/tls"
  11. "crypto/x509"
  12. "fmt"
  13. "io"
  14. "io/ioutil"
  15. "math/rand"
  16. "net"
  17. "os"
  18. "strings"
  19. "sync"
  20. "sync/atomic"
  21. "testing"
  22. "time"
  23. "github.com/gocql/gocql/internal/streams"
  24. )
  25. const (
  26. defaultProto = protoVersion2
  27. )
  28. func TestApprove(t *testing.T) {
  29. tests := map[bool]bool{
  30. approve("org.apache.cassandra.auth.PasswordAuthenticator"): true,
  31. approve("com.instaclustr.cassandra.auth.SharedSecretAuthenticator"): true,
  32. approve("com.datastax.bdp.cassandra.auth.DseAuthenticator"): true,
  33. approve("io.aiven.cassandra.auth.AivenAuthenticator"): true,
  34. approve("com.amazon.helenus.auth.HelenusAuthenticator"): true,
  35. approve("com.apache.cassandra.auth.FakeAuthenticator"): false,
  36. }
  37. for k, v := range tests {
  38. if k != v {
  39. t.Fatalf("expected '%v', got '%v'", k, v)
  40. }
  41. }
  42. }
  43. func TestJoinHostPort(t *testing.T) {
  44. tests := map[string]string{
  45. "127.0.0.1:0": JoinHostPort("127.0.0.1", 0),
  46. "127.0.0.1:1": JoinHostPort("127.0.0.1:1", 9142),
  47. "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:0": JoinHostPort("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 0),
  48. "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1": JoinHostPort("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1", 9142),
  49. }
  50. for k, v := range tests {
  51. if k != v {
  52. t.Fatalf("expected '%v', got '%v'", k, v)
  53. }
  54. }
  55. }
  56. func testCluster(proto protoVersion, addresses ...string) *ClusterConfig {
  57. cluster := NewCluster(addresses...)
  58. cluster.ProtoVersion = int(proto)
  59. cluster.disableControlConn = true
  60. return cluster
  61. }
  62. func TestSimple(t *testing.T) {
  63. srv := NewTestServer(t, defaultProto, context.Background())
  64. defer srv.Stop()
  65. cluster := testCluster(defaultProto, srv.Address)
  66. db, err := cluster.CreateSession()
  67. if err != nil {
  68. t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
  69. }
  70. if err := db.Query("void").Exec(); err != nil {
  71. t.Fatalf("0x%x: %v", defaultProto, err)
  72. }
  73. }
  74. func TestSSLSimple(t *testing.T) {
  75. srv := NewSSLTestServer(t, defaultProto, context.Background())
  76. defer srv.Stop()
  77. db, err := createTestSslCluster(srv.Address, defaultProto, true).CreateSession()
  78. if err != nil {
  79. t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
  80. }
  81. if err := db.Query("void").Exec(); err != nil {
  82. t.Fatalf("0x%x: %v", defaultProto, err)
  83. }
  84. }
  85. func TestSSLSimpleNoClientCert(t *testing.T) {
  86. srv := NewSSLTestServer(t, defaultProto, context.Background())
  87. defer srv.Stop()
  88. db, err := createTestSslCluster(srv.Address, defaultProto, false).CreateSession()
  89. if err != nil {
  90. t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
  91. }
  92. if err := db.Query("void").Exec(); err != nil {
  93. t.Fatalf("0x%x: %v", defaultProto, err)
  94. }
  95. }
  96. func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *ClusterConfig {
  97. cluster := testCluster(proto, addr)
  98. sslOpts := &SslOptions{
  99. CaPath: "testdata/pki/ca.crt",
  100. EnableHostVerification: false,
  101. }
  102. if useClientCert {
  103. sslOpts.CertPath = "testdata/pki/gocql.crt"
  104. sslOpts.KeyPath = "testdata/pki/gocql.key"
  105. }
  106. cluster.SslOpts = sslOpts
  107. return cluster
  108. }
  109. func TestClosed(t *testing.T) {
  110. t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
  111. srv := NewTestServer(t, defaultProto, context.Background())
  112. defer srv.Stop()
  113. session, err := newTestSession(defaultProto, srv.Address)
  114. if err != nil {
  115. t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
  116. }
  117. session.Close()
  118. if err := session.Query("void").Exec(); err != ErrSessionClosed {
  119. t.Fatalf("0x%x: expected %#v, got %#v", defaultProto, ErrSessionClosed, err)
  120. }
  121. }
  122. func newTestSession(proto protoVersion, addresses ...string) (*Session, error) {
  123. return testCluster(proto, addresses...).CreateSession()
  124. }
  125. func TestDNSLookupConnected(t *testing.T) {
  126. log := &testLogger{}
  127. Logger = log
  128. defer func() {
  129. Logger = &defaultLogger{}
  130. }()
  131. // Override the defaul DNS resolver and restore at the end
  132. failDNS = true
  133. defer func() { failDNS = false }()
  134. srv := NewTestServer(t, defaultProto, context.Background())
  135. defer srv.Stop()
  136. cluster := NewCluster("cassandra1.invalid", srv.Address, "cassandra2.invalid")
  137. cluster.ProtoVersion = int(defaultProto)
  138. cluster.disableControlConn = true
  139. // CreateSession() should attempt to resolve the DNS name "cassandraX.invalid"
  140. // and fail, but continue to connect via srv.Address
  141. _, err := cluster.CreateSession()
  142. if err != nil {
  143. t.Fatal("CreateSession() should have connected")
  144. }
  145. if !strings.Contains(log.String(), "gocql: dns error") {
  146. t.Fatalf("Expected to receive dns error log message - got '%s' instead", log.String())
  147. }
  148. }
  149. func TestDNSLookupError(t *testing.T) {
  150. log := &testLogger{}
  151. Logger = log
  152. defer func() {
  153. Logger = &defaultLogger{}
  154. }()
  155. // Override the defaul DNS resolver and restore at the end
  156. failDNS = true
  157. defer func() { failDNS = false }()
  158. cluster := NewCluster("cassandra1.invalid", "cassandra2.invalid")
  159. cluster.ProtoVersion = int(defaultProto)
  160. cluster.disableControlConn = true
  161. // CreateSession() should attempt to resolve each DNS name "cassandraX.invalid"
  162. // and fail since it could not resolve any dns entries
  163. _, err := cluster.CreateSession()
  164. if err == nil {
  165. t.Fatal("CreateSession() should have returned an error")
  166. }
  167. if !strings.Contains(log.String(), "gocql: dns error") {
  168. t.Fatalf("Expected to receive dns error log message - got '%s' instead", log.String())
  169. }
  170. if err.Error() != "gocql: unable to create session: failed to resolve any of the provided hostnames" {
  171. t.Fatalf("Expected CreateSession() to fail with message - got '%s' instead", err.Error())
  172. }
  173. }
  174. func TestStartupTimeout(t *testing.T) {
  175. ctx, cancel := context.WithCancel(context.Background())
  176. log := &testLogger{}
  177. Logger = log
  178. defer func() {
  179. Logger = &defaultLogger{}
  180. }()
  181. srv := NewTestServer(t, defaultProto, ctx)
  182. defer srv.Stop()
  183. // Tell the server to never respond to Startup frame
  184. atomic.StoreInt32(&srv.TimeoutOnStartup, 1)
  185. startTime := time.Now()
  186. cluster := NewCluster(srv.Address)
  187. cluster.ProtoVersion = int(defaultProto)
  188. cluster.disableControlConn = true
  189. // Set very long query connection timeout
  190. // so we know CreateSession() is using the ConnectTimeout
  191. cluster.Timeout = time.Second * 5
  192. // Create session should timeout during connect attempt
  193. _, err := cluster.CreateSession()
  194. if err == nil {
  195. t.Fatal("CreateSession() should have returned a timeout error")
  196. }
  197. elapsed := time.Since(startTime)
  198. if elapsed > time.Second*5 {
  199. t.Fatal("ConnectTimeout is not respected")
  200. }
  201. if !strings.Contains(err.Error(), "no connections were made when creating the session") {
  202. t.Fatalf("Expected to receive no connections error - got '%s'", err)
  203. }
  204. if !strings.Contains(log.String(), "no response to connection startup within timeout") {
  205. t.Fatalf("Expected to receive timeout log message - got '%s'", log.String())
  206. }
  207. cancel()
  208. }
  209. func TestTimeout(t *testing.T) {
  210. ctx, cancel := context.WithCancel(context.Background())
  211. srv := NewTestServer(t, defaultProto, ctx)
  212. defer srv.Stop()
  213. db, err := newTestSession(defaultProto, srv.Address)
  214. if err != nil {
  215. t.Fatalf("NewCluster: %v", err)
  216. }
  217. defer db.Close()
  218. var wg sync.WaitGroup
  219. wg.Add(1)
  220. go func() {
  221. defer wg.Done()
  222. select {
  223. case <-time.After(5 * time.Second):
  224. t.Errorf("no timeout")
  225. case <-ctx.Done():
  226. }
  227. }()
  228. if err := db.Query("kill").WithContext(ctx).Exec(); err == nil {
  229. t.Fatal("expected error got nil")
  230. }
  231. cancel()
  232. wg.Wait()
  233. }
  234. func TestCancel(t *testing.T) {
  235. ctx, cancel := context.WithCancel(context.Background())
  236. defer cancel()
  237. srv := NewTestServer(t, defaultProto, ctx)
  238. defer srv.Stop()
  239. cluster := testCluster(defaultProto, srv.Address)
  240. cluster.Timeout = 1 * time.Second
  241. db, err := cluster.CreateSession()
  242. if err != nil {
  243. t.Fatalf("NewCluster: %v", err)
  244. }
  245. defer db.Close()
  246. qry := db.Query("timeout").WithContext(ctx)
  247. // Make sure we finish the query without leftovers
  248. var wg sync.WaitGroup
  249. wg.Add(1)
  250. go func() {
  251. if err := qry.Exec(); err != context.Canceled {
  252. t.Fatalf("expected to get context cancel error: '%v', got '%v'", context.Canceled, err)
  253. }
  254. wg.Done()
  255. }()
  256. // The query will timeout after about 1 seconds, so cancel it after a short pause
  257. time.AfterFunc(20*time.Millisecond, cancel)
  258. wg.Wait()
  259. }
  260. type testQueryObserver struct {
  261. metrics map[string]*hostMetrics
  262. verbose bool
  263. }
  264. func (o *testQueryObserver) ObserveQuery(ctx context.Context, q ObservedQuery) {
  265. host := q.Host.ConnectAddress().String()
  266. o.metrics[host] = q.Metrics
  267. if o.verbose {
  268. Logger.Printf("Observed query %q. Returned %v rows, took %v on host %q with %v attempts and total latency %v. Error: %q\n",
  269. q.Statement, q.Rows, q.End.Sub(q.Start), host, q.Metrics.Attempts, q.Metrics.TotalLatency, q.Err)
  270. }
  271. }
  272. func (o *testQueryObserver) GetMetrics(host *HostInfo) *hostMetrics {
  273. return o.metrics[host.ConnectAddress().String()]
  274. }
  275. // TestQueryRetry will test to make sure that gocql will execute
  276. // the exact amount of retry queries designated by the user.
  277. func TestQueryRetry(t *testing.T) {
  278. ctx, cancel := context.WithCancel(context.Background())
  279. defer cancel()
  280. srv := NewTestServer(t, defaultProto, ctx)
  281. defer srv.Stop()
  282. db, err := newTestSession(defaultProto, srv.Address)
  283. if err != nil {
  284. t.Fatalf("NewCluster: %v", err)
  285. }
  286. defer db.Close()
  287. go func() {
  288. select {
  289. case <-ctx.Done():
  290. return
  291. case <-time.After(5 * time.Second):
  292. t.Errorf("no timeout")
  293. }
  294. }()
  295. rt := &SimpleRetryPolicy{NumRetries: 1}
  296. qry := db.Query("kill").RetryPolicy(rt)
  297. if err := qry.Exec(); err == nil {
  298. t.Fatalf("expected error")
  299. }
  300. requests := atomic.LoadInt64(&srv.nKillReq)
  301. attempts := qry.Attempts()
  302. if requests != int64(attempts) {
  303. t.Fatalf("expected requests %v to match query attempts %v", requests, attempts)
  304. }
  305. // the query will only be attempted once, but is being retried
  306. if requests != int64(rt.NumRetries) {
  307. t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
  308. }
  309. }
  310. func TestQueryMultinodeWithMetrics(t *testing.T) {
  311. log := &testLogger{}
  312. Logger = log
  313. defer func() {
  314. Logger = &defaultLogger{}
  315. os.Stdout.WriteString(log.String())
  316. }()
  317. // Build a 3 node cluster to test host metric mapping
  318. var nodes []*TestServer
  319. var addresses = []string{
  320. "127.0.0.1",
  321. "127.0.0.2",
  322. "127.0.0.3",
  323. }
  324. // Can do with 1 context for all servers
  325. ctx := context.Background()
  326. for _, ip := range addresses {
  327. srv := NewTestServerWithAddress(ip+":0", t, defaultProto, ctx)
  328. defer srv.Stop()
  329. nodes = append(nodes, srv)
  330. }
  331. db, err := newTestSession(defaultProto, nodes[0].Address, nodes[1].Address, nodes[2].Address)
  332. if err != nil {
  333. t.Fatalf("NewCluster: %v", err)
  334. }
  335. defer db.Close()
  336. // 1 retry per host
  337. rt := &SimpleRetryPolicy{NumRetries: 3}
  338. observer := &testQueryObserver{metrics: make(map[string]*hostMetrics), verbose: false}
  339. qry := db.Query("kill").RetryPolicy(rt).Observer(observer)
  340. if err := qry.Exec(); err == nil {
  341. t.Fatalf("expected error")
  342. }
  343. for i, ip := range addresses {
  344. host := &HostInfo{connectAddress: net.ParseIP(ip)}
  345. queryMetric := qry.metrics.hostMetrics(host)
  346. observedMetrics := observer.GetMetrics(host)
  347. requests := int(atomic.LoadInt64(&nodes[i].nKillReq))
  348. hostAttempts := queryMetric.Attempts
  349. if requests != hostAttempts {
  350. t.Fatalf("expected requests %v to match query attempts %v", requests, hostAttempts)
  351. }
  352. if hostAttempts != observedMetrics.Attempts {
  353. t.Fatalf("expected observed attempts %v to match query attempts %v on host %v", observedMetrics.Attempts, hostAttempts, ip)
  354. }
  355. hostLatency := queryMetric.TotalLatency
  356. observedLatency := observedMetrics.TotalLatency
  357. if hostLatency != observedLatency {
  358. t.Fatalf("expected observed latency %v to match query latency %v on host %v", observedLatency, hostLatency, ip)
  359. }
  360. }
  361. // the query will only be attempted once, but is being retried
  362. attempts := qry.Attempts()
  363. if attempts != rt.NumRetries {
  364. t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, attempts)
  365. }
  366. }
  367. type testRetryPolicy struct {
  368. NumRetries int
  369. }
  370. func (t *testRetryPolicy) Attempt(qry RetryableQuery) bool {
  371. return qry.Attempts() <= t.NumRetries
  372. }
  373. func (t *testRetryPolicy) GetRetryType(err error) RetryType {
  374. return Retry
  375. }
  376. func TestSpeculativeExecution(t *testing.T) {
  377. log := &testLogger{}
  378. Logger = log
  379. defer func() {
  380. Logger = &defaultLogger{}
  381. os.Stdout.WriteString(log.String())
  382. }()
  383. // Build a 3 node cluster
  384. var nodes []*TestServer
  385. var addresses = []string{
  386. "127.0.0.1",
  387. "127.0.0.2",
  388. "127.0.0.3",
  389. }
  390. // Can do with 1 context for all servers
  391. ctx := context.Background()
  392. for _, ip := range addresses {
  393. srv := NewTestServerWithAddress(ip+":0", t, defaultProto, ctx)
  394. defer srv.Stop()
  395. nodes = append(nodes, srv)
  396. }
  397. db, err := newTestSession(defaultProto, nodes[0].Address, nodes[1].Address, nodes[2].Address)
  398. if err != nil {
  399. t.Fatalf("NewCluster: %v", err)
  400. }
  401. defer db.Close()
  402. // Create a test retry policy, 6 retries will cover 2 executions
  403. rt := &testRetryPolicy{NumRetries: 8}
  404. // test Speculative policy with 1 additional execution
  405. sp := &SimpleSpeculativeExecution{NumAttempts: 1, TimeoutDelay: 200 * time.Millisecond}
  406. // Build the query
  407. qry := db.Query("speculative").RetryPolicy(rt).SetSpeculativeExecutionPolicy(sp).Idempotent(true)
  408. // Execute the query and close, check that it doesn't error out
  409. if err := qry.Exec(); err != nil {
  410. t.Errorf("The query failed with '%v'!\n", err)
  411. }
  412. requests1 := atomic.LoadInt64(&nodes[0].nKillReq)
  413. requests2 := atomic.LoadInt64(&nodes[1].nKillReq)
  414. requests3 := atomic.LoadInt64(&nodes[2].nKillReq)
  415. // Spec Attempts == 1, so expecting to see only 1 regular + 1 speculative = 2 nodes attempted
  416. if requests1 != 0 && requests2 != 0 && requests3 != 0 {
  417. t.Error("error: all 3 nodes were attempted, should have been only 2")
  418. }
  419. // Only the 4th request will generate results, so
  420. if requests1 != 4 && requests2 != 4 && requests3 != 4 {
  421. t.Error("error: none of 3 nodes was attempted 4 times!")
  422. }
  423. // "speculative" query will succeed on one arbitrary node after 4 attempts, so
  424. // expecting to see 4 (on successful node) + not more than 2 (as cancelled on another node) == 6
  425. if requests1+requests2+requests3 > 6 {
  426. t.Errorf("error: expected to see 6 attempts, got %v\n", requests1+requests2+requests3)
  427. }
  428. }
  429. func TestStreams_Protocol1(t *testing.T) {
  430. srv := NewTestServer(t, protoVersion1, context.Background())
  431. defer srv.Stop()
  432. // TODO: these are more like session tests and should instead operate
  433. // on a single Conn
  434. cluster := testCluster(protoVersion1, srv.Address)
  435. cluster.NumConns = 1
  436. cluster.ProtoVersion = 1
  437. db, err := cluster.CreateSession()
  438. if err != nil {
  439. t.Fatal(err)
  440. }
  441. defer db.Close()
  442. var wg sync.WaitGroup
  443. for i := 1; i < 128; i++ {
  444. // here were just validating that if we send NumStream request we get
  445. // a response for every stream and the lengths for the queries are set
  446. // correctly.
  447. wg.Add(1)
  448. go func() {
  449. defer wg.Done()
  450. if err := db.Query("void").Exec(); err != nil {
  451. t.Error(err)
  452. }
  453. }()
  454. }
  455. wg.Wait()
  456. }
  457. func TestStreams_Protocol3(t *testing.T) {
  458. srv := NewTestServer(t, protoVersion3, context.Background())
  459. defer srv.Stop()
  460. // TODO: these are more like session tests and should instead operate
  461. // on a single Conn
  462. cluster := testCluster(protoVersion3, srv.Address)
  463. cluster.NumConns = 1
  464. cluster.ProtoVersion = 3
  465. db, err := cluster.CreateSession()
  466. if err != nil {
  467. t.Fatal(err)
  468. }
  469. defer db.Close()
  470. for i := 1; i < 32768; i++ {
  471. // the test server processes each conn synchronously
  472. // here were just validating that if we send NumStream request we get
  473. // a response for every stream and the lengths for the queries are set
  474. // correctly.
  475. if err = db.Query("void").Exec(); err != nil {
  476. t.Fatal(err)
  477. }
  478. }
  479. }
  480. func BenchmarkProtocolV3(b *testing.B) {
  481. srv := NewTestServer(b, protoVersion3, context.Background())
  482. defer srv.Stop()
  483. // TODO: these are more like session tests and should instead operate
  484. // on a single Conn
  485. cluster := NewCluster(srv.Address)
  486. cluster.NumConns = 1
  487. cluster.ProtoVersion = 3
  488. db, err := cluster.CreateSession()
  489. if err != nil {
  490. b.Fatal(err)
  491. }
  492. defer db.Close()
  493. b.ResetTimer()
  494. b.ReportAllocs()
  495. for i := 0; i < b.N; i++ {
  496. if err = db.Query("void").Exec(); err != nil {
  497. b.Fatal(err)
  498. }
  499. }
  500. }
  501. // This tests that the policy connection pool handles SSL correctly
  502. func TestPolicyConnPoolSSL(t *testing.T) {
  503. srv := NewSSLTestServer(t, defaultProto, context.Background())
  504. defer srv.Stop()
  505. cluster := createTestSslCluster(srv.Address, defaultProto, true)
  506. cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
  507. db, err := cluster.CreateSession()
  508. if err != nil {
  509. t.Fatalf("failed to create new session: %v", err)
  510. }
  511. if err := db.Query("void").Exec(); err != nil {
  512. t.Fatalf("query failed due to error: %v", err)
  513. }
  514. db.Close()
  515. // wait for the pool to drain
  516. time.Sleep(100 * time.Millisecond)
  517. size := db.pool.Size()
  518. if size != 0 {
  519. t.Fatalf("connection pool did not drain, still contains %d connections", size)
  520. }
  521. }
  522. func TestQueryTimeout(t *testing.T) {
  523. srv := NewTestServer(t, defaultProto, context.Background())
  524. defer srv.Stop()
  525. cluster := testCluster(defaultProto, srv.Address)
  526. // Set the timeout arbitrarily low so that the query hits the timeout in a
  527. // timely manner.
  528. cluster.Timeout = 1 * time.Millisecond
  529. db, err := cluster.CreateSession()
  530. if err != nil {
  531. t.Fatalf("NewCluster: %v", err)
  532. }
  533. defer db.Close()
  534. ch := make(chan error, 1)
  535. go func() {
  536. err := db.Query("timeout").Exec()
  537. if err != nil {
  538. ch <- err
  539. return
  540. }
  541. t.Errorf("err was nil, expected to get a timeout after %v", db.cfg.Timeout)
  542. }()
  543. select {
  544. case err := <-ch:
  545. if err != ErrTimeoutNoResponse {
  546. t.Fatalf("expected to get %v for timeout got %v", ErrTimeoutNoResponse, err)
  547. }
  548. case <-time.After(40*time.Millisecond + db.cfg.Timeout):
  549. // ensure that the query goroutines have been scheduled
  550. t.Fatalf("query did not timeout after %v", db.cfg.Timeout)
  551. }
  552. }
  553. func BenchmarkSingleConn(b *testing.B) {
  554. srv := NewTestServer(b, 3, context.Background())
  555. defer srv.Stop()
  556. cluster := testCluster(3, srv.Address)
  557. // Set the timeout arbitrarily low so that the query hits the timeout in a
  558. // timely manner.
  559. cluster.Timeout = 500 * time.Millisecond
  560. cluster.NumConns = 1
  561. db, err := cluster.CreateSession()
  562. if err != nil {
  563. b.Fatalf("NewCluster: %v", err)
  564. }
  565. defer db.Close()
  566. b.ResetTimer()
  567. b.RunParallel(func(pb *testing.PB) {
  568. for pb.Next() {
  569. err := db.Query("void").Exec()
  570. if err != nil {
  571. b.Error(err)
  572. return
  573. }
  574. }
  575. })
  576. }
  577. func TestQueryTimeoutReuseStream(t *testing.T) {
  578. t.Skip("no longer tests anything")
  579. // TODO(zariel): move this to conn test, we really just want to check what
  580. // happens when a conn is
  581. srv := NewTestServer(t, defaultProto, context.Background())
  582. defer srv.Stop()
  583. cluster := testCluster(defaultProto, srv.Address)
  584. // Set the timeout arbitrarily low so that the query hits the timeout in a
  585. // timely manner.
  586. cluster.Timeout = 1 * time.Millisecond
  587. cluster.NumConns = 1
  588. db, err := cluster.CreateSession()
  589. if err != nil {
  590. t.Fatalf("NewCluster: %v", err)
  591. }
  592. defer db.Close()
  593. db.Query("slow").Exec()
  594. err = db.Query("void").Exec()
  595. if err != nil {
  596. t.Fatal(err)
  597. }
  598. }
  599. func TestQueryTimeoutClose(t *testing.T) {
  600. srv := NewTestServer(t, defaultProto, context.Background())
  601. defer srv.Stop()
  602. cluster := testCluster(defaultProto, srv.Address)
  603. // Set the timeout arbitrarily low so that the query hits the timeout in a
  604. // timely manner.
  605. cluster.Timeout = 1000 * time.Millisecond
  606. cluster.NumConns = 1
  607. db, err := cluster.CreateSession()
  608. if err != nil {
  609. t.Fatalf("NewCluster: %v", err)
  610. }
  611. ch := make(chan error)
  612. go func() {
  613. err := db.Query("timeout").Exec()
  614. ch <- err
  615. }()
  616. // ensure that the above goroutine gets sheduled
  617. time.Sleep(50 * time.Millisecond)
  618. db.Close()
  619. select {
  620. case err = <-ch:
  621. case <-time.After(1 * time.Second):
  622. t.Fatal("timedout waiting to get a response once cluster is closed")
  623. }
  624. if err != ErrConnectionClosed {
  625. t.Fatalf("expected to get %v got %v", ErrConnectionClosed, err)
  626. }
  627. }
  628. func TestStream0(t *testing.T) {
  629. // TODO: replace this with type check
  630. const expErr = "gocql: received unexpected frame on stream 0"
  631. var buf bytes.Buffer
  632. f := newFramer(nil, &buf, nil, protoVersion4)
  633. f.writeHeader(0, opResult, 0)
  634. f.writeInt(resultKindVoid)
  635. f.wbuf[0] |= 0x80
  636. if err := f.finishWrite(); err != nil {
  637. t.Fatal(err)
  638. }
  639. conn := &Conn{
  640. r: bufio.NewReader(&buf),
  641. streams: streams.New(protoVersion4),
  642. }
  643. err := conn.recv(context.Background())
  644. if err == nil {
  645. t.Fatal("expected to get an error on stream 0")
  646. } else if !strings.HasPrefix(err.Error(), expErr) {
  647. t.Fatalf("expected to get error prefix %q got %q", expErr, err.Error())
  648. }
  649. }
  650. func TestContext_Timeout(t *testing.T) {
  651. ctx, cancel := context.WithCancel(context.Background())
  652. defer cancel()
  653. srv := NewTestServer(t, defaultProto, ctx)
  654. defer srv.Stop()
  655. cluster := testCluster(defaultProto, srv.Address)
  656. cluster.Timeout = 5 * time.Second
  657. db, err := cluster.CreateSession()
  658. if err != nil {
  659. t.Fatal(err)
  660. }
  661. defer db.Close()
  662. ctx, cancel = context.WithCancel(ctx)
  663. cancel()
  664. err = db.Query("timeout").WithContext(ctx).Exec()
  665. if err != context.Canceled {
  666. t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
  667. }
  668. }
  669. // tcpConnPair returns a matching set of a TCP client side and server side connection.
  670. func tcpConnPair() (s, c net.Conn, err error) {
  671. l, err := net.Listen("tcp", "localhost:0")
  672. if err != nil {
  673. // maybe ipv6 works, if ipv4 fails?
  674. l, err = net.Listen("tcp6", "[::1]:0")
  675. if err != nil {
  676. return nil, nil, err
  677. }
  678. }
  679. defer l.Close() // we only try to accept one connection, so will stop listening.
  680. addr := l.Addr()
  681. done := make(chan struct{})
  682. var errDial error
  683. go func(done chan<- struct{}) {
  684. c, errDial = net.Dial(addr.Network(), addr.String())
  685. close(done)
  686. }(done)
  687. s, err = l.Accept()
  688. <-done
  689. if err == nil {
  690. err = errDial
  691. }
  692. if err != nil {
  693. if s != nil {
  694. s.Close()
  695. }
  696. if c != nil {
  697. c.Close()
  698. }
  699. }
  700. return s, c, err
  701. }
  702. func TestWriteCoalescing(t *testing.T) {
  703. ctx, cancel := context.WithCancel(context.Background())
  704. defer cancel()
  705. server, client, err := tcpConnPair()
  706. if err != nil {
  707. t.Fatal(err)
  708. }
  709. done := make(chan struct{}, 1)
  710. var (
  711. buf bytes.Buffer
  712. bufMutex sync.Mutex
  713. )
  714. go func() {
  715. defer close(done)
  716. defer server.Close()
  717. var err error
  718. b := make([]byte, 256)
  719. var n int
  720. for {
  721. if n, err = server.Read(b); err != nil {
  722. break
  723. }
  724. bufMutex.Lock()
  725. buf.Write(b[:n])
  726. bufMutex.Unlock()
  727. }
  728. if err != io.EOF {
  729. t.Errorf("unexpected read error: %v", err)
  730. }
  731. }()
  732. w := &writeCoalescer{
  733. c: client,
  734. writeCh: make(chan struct{}),
  735. cond: sync.NewCond(&sync.Mutex{}),
  736. quit: ctx.Done(),
  737. running: true,
  738. }
  739. go func() {
  740. if _, err := w.Write([]byte("one")); err != nil {
  741. t.Error(err)
  742. }
  743. }()
  744. go func() {
  745. if _, err := w.Write([]byte("two")); err != nil {
  746. t.Error(err)
  747. }
  748. }()
  749. bufMutex.Lock()
  750. if buf.Len() != 0 {
  751. t.Fatalf("expected buffer to be empty have: %v", buf.String())
  752. }
  753. bufMutex.Unlock()
  754. for true {
  755. w.cond.L.Lock()
  756. if len(w.buffers) == 2 {
  757. w.cond.L.Unlock()
  758. break
  759. }
  760. w.cond.L.Unlock()
  761. }
  762. w.flush()
  763. client.Close()
  764. <-done
  765. if got := buf.String(); got != "onetwo" && got != "twoone" {
  766. t.Fatalf("expected to get %q got %q", "onetwo or twoone", got)
  767. }
  768. }
  769. func TestWriteCoalescing_WriteAfterClose(t *testing.T) {
  770. ctx, cancel := context.WithCancel(context.Background())
  771. defer cancel()
  772. var buf bytes.Buffer
  773. defer cancel()
  774. server, client, err := tcpConnPair()
  775. if err != nil {
  776. t.Fatal(err)
  777. }
  778. done := make(chan struct{}, 1)
  779. go func() {
  780. io.Copy(&buf, server)
  781. server.Close()
  782. close(done)
  783. }()
  784. w := newWriteCoalescer(client, 0, 5*time.Millisecond, ctx.Done())
  785. // ensure 1 write works
  786. if _, err := w.Write([]byte("one")); err != nil {
  787. t.Fatal(err)
  788. }
  789. client.Close()
  790. <-done
  791. if v := buf.String(); v != "one" {
  792. t.Fatalf("expected buffer to be %q got %q", "one", v)
  793. }
  794. // now close and do a write, we should error
  795. cancel()
  796. client.Close() // close client conn too, since server won't see the answer anyway.
  797. if _, err := w.Write([]byte("two")); err == nil {
  798. t.Fatal("expected to get error for write after closing")
  799. } else if err != io.EOF {
  800. t.Fatalf("expected to get EOF got %v", err)
  801. }
  802. }
  803. type recordingFrameHeaderObserver struct {
  804. t *testing.T
  805. mu sync.Mutex
  806. frames []ObservedFrameHeader
  807. }
  808. func (r *recordingFrameHeaderObserver) ObserveFrameHeader(ctx context.Context, frm ObservedFrameHeader) {
  809. r.mu.Lock()
  810. r.frames = append(r.frames, frm)
  811. r.mu.Unlock()
  812. }
  813. func (r *recordingFrameHeaderObserver) getFrames() []ObservedFrameHeader {
  814. r.mu.Lock()
  815. defer r.mu.Unlock()
  816. return r.frames
  817. }
  818. func TestFrameHeaderObserver(t *testing.T) {
  819. srv := NewTestServer(t, defaultProto, context.Background())
  820. defer srv.Stop()
  821. cluster := testCluster(defaultProto, srv.Address)
  822. cluster.NumConns = 1
  823. observer := &recordingFrameHeaderObserver{t: t}
  824. cluster.FrameHeaderObserver = observer
  825. db, err := cluster.CreateSession()
  826. if err != nil {
  827. t.Fatal(err)
  828. }
  829. if err := db.Query("void").Exec(); err != nil {
  830. t.Fatal(err)
  831. }
  832. frames := observer.getFrames()
  833. expFrames := []frameOp{opSupported, opReady, opResult}
  834. if len(frames) != len(expFrames) {
  835. t.Fatalf("Expected to receive %d frames, instead received %d", len(expFrames), len(frames))
  836. }
  837. for i, op := range expFrames {
  838. if op != frames[i].Opcode {
  839. t.Fatalf("expected frame %d to be %v got %v", i, op, frames[i])
  840. }
  841. }
  842. voidResultFrame := frames[2]
  843. if voidResultFrame.Length != int32(4) {
  844. t.Fatalf("Expected to receive frame with body length 4, instead received body length %d", voidResultFrame.Length)
  845. }
  846. }
  847. func NewTestServerWithAddress(addr string, t testing.TB, protocol uint8, ctx context.Context) *TestServer {
  848. laddr, err := net.ResolveTCPAddr("tcp", addr)
  849. if err != nil {
  850. t.Fatal(err)
  851. }
  852. listen, err := net.ListenTCP("tcp", laddr)
  853. if err != nil {
  854. t.Fatal(err)
  855. }
  856. headerSize := 8
  857. if protocol > protoVersion2 {
  858. headerSize = 9
  859. }
  860. ctx, cancel := context.WithCancel(ctx)
  861. srv := &TestServer{
  862. Address: listen.Addr().String(),
  863. listen: listen,
  864. t: t,
  865. protocol: protocol,
  866. headerSize: headerSize,
  867. ctx: ctx,
  868. cancel: cancel,
  869. }
  870. go srv.closeWatch()
  871. go srv.serve()
  872. return srv
  873. }
  874. func NewTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer {
  875. return NewTestServerWithAddress("127.0.0.1:0", t, protocol, ctx)
  876. }
  877. func NewSSLTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer {
  878. pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
  879. certPool := x509.NewCertPool()
  880. if !certPool.AppendCertsFromPEM(pem) {
  881. t.Fatalf("Failed parsing or appending certs")
  882. }
  883. mycert, err := tls.LoadX509KeyPair("testdata/pki/cassandra.crt", "testdata/pki/cassandra.key")
  884. if err != nil {
  885. t.Fatalf("could not load cert")
  886. }
  887. config := &tls.Config{
  888. Certificates: []tls.Certificate{mycert},
  889. RootCAs: certPool,
  890. }
  891. listen, err := tls.Listen("tcp", "127.0.0.1:0", config)
  892. if err != nil {
  893. t.Fatal(err)
  894. }
  895. headerSize := 8
  896. if protocol > protoVersion2 {
  897. headerSize = 9
  898. }
  899. ctx, cancel := context.WithCancel(ctx)
  900. srv := &TestServer{
  901. Address: listen.Addr().String(),
  902. listen: listen,
  903. t: t,
  904. protocol: protocol,
  905. headerSize: headerSize,
  906. ctx: ctx,
  907. cancel: cancel,
  908. }
  909. go srv.closeWatch()
  910. go srv.serve()
  911. return srv
  912. }
// TestServer is an in-process fake CQL server used by the tests in this
// file. serve accepts connections on listen, readFrame decodes request
// frames, and process answers each one with a canned response.
type TestServer struct {
	// Address is the listener's address, taken from listen.Addr().String().
	Address string
	// TimeoutOnStartup, when > 0 (read with atomic.LoadInt32), makes process
	// withhold its reply to STARTUP until the server context is done.
	TimeoutOnStartup int32
	// t receives errors reported via errorLocked and host.
	t testing.TB
	// nreq counts request frames read by serve (updated with
	// atomic.AddUint64). NOTE: 64-bit values accessed through sync/atomic
	// must be 64-bit aligned on 32-bit platforms, so check alignment before
	// reordering the fields of this struct.
	nreq uint64
	// listen is the TCP or TLS listener that serve accepts from.
	listen net.Listener
	// nKillReq counts "kill" and "speculative" queries (updated atomically).
	nKillReq int64
	// compressor is not referenced by the code in this section; readFrame
	// builds framers with a nil compressor. NOTE(review): confirm whether
	// anything sets or reads it.
	compressor Compressor
	// protocol is the protocol version the server accepts and replies with.
	protocol byte
	// headerSize is the request header length readFrame reads: 8 for
	// protocol v2 and below, 9 above (set by the constructors).
	headerSize int
	// ctx is derived from the constructor's context. cancel is invoked by
	// closeLocked, and closeWatch stops the server once ctx is done.
	ctx    context.Context
	cancel context.CancelFunc
	// quit is not referenced by the code in this section.
	// NOTE(review): possibly unused; confirm before removing.
	quit chan struct{}
	// mu guards closed.
	mu     sync.Mutex
	closed bool
}
  929. func (srv *TestServer) session() (*Session, error) {
  930. return testCluster(protoVersion(srv.protocol), srv.Address).CreateSession()
  931. }
  932. func (srv *TestServer) host() *HostInfo {
  933. hosts, err := hostInfo(srv.Address, 9042)
  934. if err != nil {
  935. srv.t.Fatal(err)
  936. }
  937. return hosts[0]
  938. }
  939. func (srv *TestServer) closeWatch() {
  940. <-srv.ctx.Done()
  941. srv.mu.Lock()
  942. defer srv.mu.Unlock()
  943. srv.closeLocked()
  944. }
  945. func (srv *TestServer) serve() {
  946. defer srv.listen.Close()
  947. for !srv.isClosed() {
  948. conn, err := srv.listen.Accept()
  949. if err != nil {
  950. break
  951. }
  952. go func(conn net.Conn) {
  953. defer conn.Close()
  954. for !srv.isClosed() {
  955. framer, err := srv.readFrame(conn)
  956. if err != nil {
  957. if err == io.EOF {
  958. return
  959. }
  960. srv.errorLocked(err)
  961. return
  962. }
  963. atomic.AddUint64(&srv.nreq, 1)
  964. go srv.process(framer)
  965. }
  966. }(conn)
  967. }
  968. }
  969. func (srv *TestServer) isClosed() bool {
  970. srv.mu.Lock()
  971. defer srv.mu.Unlock()
  972. return srv.closed
  973. }
  974. func (srv *TestServer) closeLocked() {
  975. if srv.closed {
  976. return
  977. }
  978. srv.closed = true
  979. srv.listen.Close()
  980. srv.cancel()
  981. }
  982. func (srv *TestServer) Stop() {
  983. srv.mu.Lock()
  984. defer srv.mu.Unlock()
  985. srv.closeLocked()
  986. }
  987. func (srv *TestServer) errorLocked(err interface{}) {
  988. srv.mu.Lock()
  989. defer srv.mu.Unlock()
  990. if srv.closed {
  991. return
  992. }
  993. srv.t.Error(err)
  994. }
  995. func (srv *TestServer) process(f *framer) {
  996. head := f.header
  997. if head == nil {
  998. srv.errorLocked("process frame with a nil header")
  999. return
  1000. }
  1001. switch head.op {
  1002. case opStartup:
  1003. if atomic.LoadInt32(&srv.TimeoutOnStartup) > 0 {
  1004. // Do not respond to startup command
  1005. // wait until we get a cancel signal
  1006. select {
  1007. case <-srv.ctx.Done():
  1008. return
  1009. }
  1010. }
  1011. f.writeHeader(0, opReady, head.stream)
  1012. case opOptions:
  1013. f.writeHeader(0, opSupported, head.stream)
  1014. f.writeShort(0)
  1015. case opQuery:
  1016. query := f.readLongString()
  1017. first := query
  1018. if n := strings.Index(query, " "); n > 0 {
  1019. first = first[:n]
  1020. }
  1021. switch strings.ToLower(first) {
  1022. case "kill":
  1023. atomic.AddInt64(&srv.nKillReq, 1)
  1024. f.writeHeader(0, opError, head.stream)
  1025. f.writeInt(0x1001)
  1026. f.writeString("query killed")
  1027. case "use":
  1028. f.writeInt(resultKindKeyspace)
  1029. f.writeString(strings.TrimSpace(query[3:]))
  1030. case "void":
  1031. f.writeHeader(0, opResult, head.stream)
  1032. f.writeInt(resultKindVoid)
  1033. case "timeout":
  1034. <-srv.ctx.Done()
  1035. return
  1036. case "slow":
  1037. go func() {
  1038. f.writeHeader(0, opResult, head.stream)
  1039. f.writeInt(resultKindVoid)
  1040. f.wbuf[0] = srv.protocol | 0x80
  1041. select {
  1042. case <-srv.ctx.Done():
  1043. return
  1044. case <-time.After(50 * time.Millisecond):
  1045. f.finishWrite()
  1046. }
  1047. }()
  1048. return
  1049. case "speculative":
  1050. atomic.AddInt64(&srv.nKillReq, 1)
  1051. if atomic.LoadInt64(&srv.nKillReq) > 3 {
  1052. f.writeHeader(0, opResult, head.stream)
  1053. f.writeInt(resultKindVoid)
  1054. f.writeString("speculative query success on the node " + srv.Address)
  1055. } else {
  1056. f.writeHeader(0, opError, head.stream)
  1057. f.writeInt(0x1001)
  1058. f.writeString("speculative error")
  1059. rand.Seed(time.Now().UnixNano())
  1060. <-time.After(time.Millisecond * 120)
  1061. }
  1062. default:
  1063. f.writeHeader(0, opResult, head.stream)
  1064. f.writeInt(resultKindVoid)
  1065. }
  1066. case opError:
  1067. f.writeHeader(0, opError, head.stream)
  1068. f.wbuf = append(f.wbuf, f.rbuf...)
  1069. default:
  1070. f.writeHeader(0, opError, head.stream)
  1071. f.writeInt(0)
  1072. f.writeString("not supported")
  1073. }
  1074. f.wbuf[0] = srv.protocol | 0x80
  1075. if err := f.finishWrite(); err != nil {
  1076. srv.errorLocked(err)
  1077. }
  1078. }
  1079. func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) {
  1080. buf := make([]byte, srv.headerSize)
  1081. head, err := readHeader(conn, buf)
  1082. if err != nil {
  1083. return nil, err
  1084. }
  1085. framer := newFramer(conn, conn, nil, srv.protocol)
  1086. err = framer.readFrame(&head)
  1087. if err != nil {
  1088. return nil, err
  1089. }
  1090. // should be a request frame
  1091. if head.version.response() {
  1092. return nil, fmt.Errorf("expected to read a request frame got version: %v", head.version)
  1093. } else if head.version.version() != srv.protocol {
  1094. return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.version.version())
  1095. }
  1096. return framer, nil
  1097. }