connection_test.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "context"
  11. "database/sql/driver"
  12. "errors"
  13. "net"
  14. "testing"
  15. )
  16. func TestInterpolateParams(t *testing.T) {
  17. mc := &mysqlConn{
  18. buf: newBuffer(nil),
  19. maxAllowedPacket: maxPacketSize,
  20. cfg: &Config{
  21. InterpolateParams: true,
  22. },
  23. }
  24. q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
  25. if err != nil {
  26. t.Errorf("Expected err=nil, got %#v", err)
  27. return
  28. }
  29. expected := `SELECT 42+'gopher'`
  30. if q != expected {
  31. t.Errorf("Expected: %q\nGot: %q", expected, q)
  32. }
  33. }
  34. func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
  35. mc := &mysqlConn{
  36. buf: newBuffer(nil),
  37. maxAllowedPacket: maxPacketSize,
  38. cfg: &Config{
  39. InterpolateParams: true,
  40. },
  41. }
  42. q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
  43. if err != driver.ErrSkip {
  44. t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
  45. }
  46. }
  47. // We don't support placeholder in string literal for now.
  48. // https://github.com/go-sql-driver/mysql/pull/490
  49. func TestInterpolateParamsPlaceholderInString(t *testing.T) {
  50. mc := &mysqlConn{
  51. buf: newBuffer(nil),
  52. maxAllowedPacket: maxPacketSize,
  53. cfg: &Config{
  54. InterpolateParams: true,
  55. },
  56. }
  57. q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
  58. // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
  59. if err != driver.ErrSkip {
  60. t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
  61. }
  62. }
  63. func TestCheckNamedValue(t *testing.T) {
  64. value := driver.NamedValue{Value: ^uint64(0)}
  65. x := &mysqlConn{}
  66. err := x.CheckNamedValue(&value)
  67. if err != nil {
  68. t.Fatal("uint64 high-bit not convertible", err)
  69. }
  70. if value.Value != "18446744073709551615" {
  71. t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
  72. }
  73. }
  74. // TestCleanCancel tests passed context is cancelled at start.
  75. // No packet should be sent. Connection should keep current status.
  76. func TestCleanCancel(t *testing.T) {
  77. mc := &mysqlConn{
  78. closech: make(chan struct{}),
  79. }
  80. mc.startWatcher()
  81. defer mc.cleanup()
  82. ctx, cancel := context.WithCancel(context.Background())
  83. cancel()
  84. for i := 0; i < 3; i++ { // Repeat same behavior
  85. err := mc.Ping(ctx)
  86. if err != context.Canceled {
  87. t.Errorf("expected context.Canceled, got %#v", err)
  88. }
  89. if mc.closed.IsSet() {
  90. t.Error("expected mc is not closed, closed actually")
  91. }
  92. if mc.watching {
  93. t.Error("expected watching is false, but true")
  94. }
  95. }
  96. }
  97. func TestPingMarkBadConnection(t *testing.T) {
  98. nc := badConnection{err: errors.New("boom")}
  99. ms := &mysqlConn{
  100. netConn: nc,
  101. buf: newBuffer(nc),
  102. maxAllowedPacket: defaultMaxAllowedPacket,
  103. }
  104. err := ms.Ping(context.Background())
  105. if err != driver.ErrBadConn {
  106. t.Errorf("expected driver.ErrBadConn, got %#v", err)
  107. }
  108. }
  109. func TestPingErrInvalidConn(t *testing.T) {
  110. nc := badConnection{err: errors.New("failed to write"), n: 10}
  111. ms := &mysqlConn{
  112. netConn: nc,
  113. buf: newBuffer(nc),
  114. maxAllowedPacket: defaultMaxAllowedPacket,
  115. closech: make(chan struct{}),
  116. }
  117. err := ms.Ping(context.Background())
  118. if err != ErrInvalidConn {
  119. t.Errorf("expected ErrInvalidConn, got %#v", err)
  120. }
  121. }
  122. type badConnection struct {
  123. n int
  124. err error
  125. net.Conn
  126. }
  127. func (bc badConnection) Write(b []byte) (n int, err error) {
  128. return bc.n, bc.err
  129. }
  130. func (bc badConnection) Close() error {
  131. return nil
  132. }