// session_connect_test.go (2.8 KB)
  1. package gocql
  2. import (
  3. "context"
  4. "net"
  5. "strconv"
  6. "sync"
  7. "testing"
  8. "time"
  9. )
// OneConnTestServer is a minimal TCP server for tests. It listens on a
// localhost IPv4 port (see NewOneConnTestServer), accepts a single connection
// in Serve, and then shuts itself down.
type OneConnTestServer struct {
	// Err records the error returned by the single Accept call in Serve.
	Err error
	// Addr and Port are the IP and port the listener is bound to, as parsed
	// from the listener's address by NewOneConnTestServer.
	Addr net.IP
	Port int

	listener   net.Listener
	acceptChan chan struct{} // closed by lockedClose; returned by Accepted
	mu         sync.Mutex    // guards closed so the shutdown runs only once
	closed     bool          // set by lockedClose after shutting down
}
  19. func NewOneConnTestServer() (*OneConnTestServer, error) {
  20. lstn, err := net.Listen("tcp4", "localhost:0")
  21. if err != nil {
  22. return nil, err
  23. }
  24. addr, port := parseAddressPort(lstn.Addr().String())
  25. return &OneConnTestServer{
  26. listener: lstn,
  27. acceptChan: make(chan struct{}),
  28. Addr: addr,
  29. Port: port,
  30. }, nil
  31. }
// Accepted returns the channel that lockedClose closes when the server shuts
// down. That happens after Serve has returned from its single Accept call,
// but also when Close is called, so a closed channel alone does not prove a
// connection was accepted.
func (c *OneConnTestServer) Accepted() chan struct{} {
	return c.acceptChan
}
// Close shuts the server down by delegating to lockedClose, which closes the
// Accepted channel and the listener at most once; repeated calls, or calls
// after Serve has already shut the server down, do nothing further.
func (c *OneConnTestServer) Close() {
	c.lockedClose()
}
  38. func (c *OneConnTestServer) Serve() {
  39. conn, err := c.listener.Accept()
  40. c.Err = err
  41. if conn != nil {
  42. conn.Close()
  43. }
  44. c.lockedClose()
  45. }
  46. func (c *OneConnTestServer) lockedClose() {
  47. c.mu.Lock()
  48. defer c.mu.Unlock()
  49. if !c.closed {
  50. close(c.acceptChan)
  51. c.listener.Close()
  52. c.closed = true
  53. }
  54. }
  55. func parseAddressPort(hostPort string) (net.IP, int) {
  56. host, portStr, err := net.SplitHostPort(hostPort)
  57. if err != nil {
  58. return net.ParseIP(""), 0
  59. }
  60. port, _ := strconv.Atoi(portStr)
  61. return net.ParseIP(host), port
  62. }
  63. func testConnErrorHandler(t *testing.T) ConnErrorHandler {
  64. return connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
  65. t.Errorf("in connection handler: %v", err)
  66. })
  67. }
  68. func assertConnectionEventually(t *testing.T, wait time.Duration, srvr *OneConnTestServer) {
  69. ctx, cancel := context.WithTimeout(context.Background(), wait)
  70. defer cancel()
  71. select {
  72. case <-ctx.Done():
  73. if ctx.Err() != nil {
  74. t.Errorf("waiting for connection: %v", ctx.Err())
  75. }
  76. case <-srvr.Accepted():
  77. if srvr.Err != nil {
  78. t.Errorf("accepting connection: %v", srvr.Err)
  79. }
  80. }
  81. }
  82. func TestSession_connect_WithNoTranslator(t *testing.T) {
  83. srvr, err := NewOneConnTestServer()
  84. assertNil(t, "error when creating tcp server", err)
  85. defer srvr.Close()
  86. session := createTestSession()
  87. defer session.Close()
  88. go srvr.Serve()
  89. Connect(&HostInfo{
  90. connectAddress: srvr.Addr,
  91. port: srvr.Port,
  92. }, session.connCfg, testConnErrorHandler(t), session)
  93. assertConnectionEventually(t, 500*time.Millisecond, srvr)
  94. }
  95. func TestSession_connect_WithTranslator(t *testing.T) {
  96. srvr, err := NewOneConnTestServer()
  97. assertNil(t, "error when creating tcp server", err)
  98. defer srvr.Close()
  99. session := createTestSession()
  100. defer session.Close()
  101. session.cfg.AddressTranslator = staticAddressTranslator(srvr.Addr, srvr.Port)
  102. go srvr.Serve()
  103. // the provided address will be translated
  104. Connect(&HostInfo{
  105. connectAddress: net.ParseIP("10.10.10.10"),
  106. port: 5432,
  107. }, session.connCfg, testConnErrorHandler(t), session)
  108. assertConnectionEventually(t, 500*time.Millisecond, srvr)
  109. }