connection_test.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "context"
  11. "database/sql/driver"
  12. "errors"
  13. "net"
  14. "testing"
  15. )
  16. func TestInterpolateParams(t *testing.T) {
  17. mc := &mysqlConn{
  18. buf: newBuffer(nil),
  19. maxAllowedPacket: maxPacketSize,
  20. cfg: &Config{
  21. InterpolateParams: true,
  22. },
  23. }
  24. q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
  25. if err != nil {
  26. t.Errorf("Expected err=nil, got %#v", err)
  27. return
  28. }
  29. expected := `SELECT 42+'gopher'`
  30. if q != expected {
  31. t.Errorf("Expected: %q\nGot: %q", expected, q)
  32. }
  33. }
  34. func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
  35. mc := &mysqlConn{
  36. buf: newBuffer(nil),
  37. maxAllowedPacket: maxPacketSize,
  38. cfg: &Config{
  39. InterpolateParams: true,
  40. },
  41. }
  42. q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
  43. if err != driver.ErrSkip {
  44. t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
  45. }
  46. }
  47. // We don't support placeholder in string literal for now.
  48. // https://github.com/go-sql-driver/mysql/pull/490
  49. func TestInterpolateParamsPlaceholderInString(t *testing.T) {
  50. mc := &mysqlConn{
  51. buf: newBuffer(nil),
  52. maxAllowedPacket: maxPacketSize,
  53. cfg: &Config{
  54. InterpolateParams: true,
  55. },
  56. }
  57. q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
  58. // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
  59. if err != driver.ErrSkip {
  60. t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
  61. }
  62. }
  63. func TestInterpolateParamsUint64(t *testing.T) {
  64. mc := &mysqlConn{
  65. buf: newBuffer(nil),
  66. maxAllowedPacket: maxPacketSize,
  67. cfg: &Config{
  68. InterpolateParams: true,
  69. },
  70. }
  71. q, err := mc.interpolateParams("SELECT ?", []driver.Value{uint64(42)})
  72. if err != nil {
  73. t.Errorf("Expected err=nil, got err=%#v, q=%#v", err, q)
  74. }
  75. if q != "SELECT 42" {
  76. t.Errorf("Expected uint64 interpolation to work, got q=%#v", q)
  77. }
  78. }
  79. func TestCheckNamedValue(t *testing.T) {
  80. value := driver.NamedValue{Value: ^uint64(0)}
  81. x := &mysqlConn{}
  82. err := x.CheckNamedValue(&value)
  83. if err != nil {
  84. t.Fatal("uint64 high-bit not convertible", err)
  85. }
  86. if value.Value != ^uint64(0) {
  87. t.Fatalf("uint64 high-bit converted, got %#v %T", value.Value, value.Value)
  88. }
  89. }
  90. // TestCleanCancel tests passed context is cancelled at start.
  91. // No packet should be sent. Connection should keep current status.
  92. func TestCleanCancel(t *testing.T) {
  93. mc := &mysqlConn{
  94. closech: make(chan struct{}),
  95. }
  96. mc.startWatcher()
  97. defer mc.cleanup()
  98. ctx, cancel := context.WithCancel(context.Background())
  99. cancel()
  100. for i := 0; i < 3; i++ { // Repeat same behavior
  101. err := mc.Ping(ctx)
  102. if err != context.Canceled {
  103. t.Errorf("expected context.Canceled, got %#v", err)
  104. }
  105. if mc.closed.IsSet() {
  106. t.Error("expected mc is not closed, closed actually")
  107. }
  108. if mc.watching {
  109. t.Error("expected watching is false, but true")
  110. }
  111. }
  112. }
  113. func TestPingMarkBadConnection(t *testing.T) {
  114. nc := badConnection{err: errors.New("boom")}
  115. ms := &mysqlConn{
  116. netConn: nc,
  117. buf: newBuffer(nc),
  118. maxAllowedPacket: defaultMaxAllowedPacket,
  119. }
  120. err := ms.Ping(context.Background())
  121. if err != driver.ErrBadConn {
  122. t.Errorf("expected driver.ErrBadConn, got %#v", err)
  123. }
  124. }
  125. func TestPingErrInvalidConn(t *testing.T) {
  126. nc := badConnection{err: errors.New("failed to write"), n: 10}
  127. ms := &mysqlConn{
  128. netConn: nc,
  129. buf: newBuffer(nc),
  130. maxAllowedPacket: defaultMaxAllowedPacket,
  131. closech: make(chan struct{}),
  132. }
  133. err := ms.Ping(context.Background())
  134. if err != ErrInvalidConn {
  135. t.Errorf("expected ErrInvalidConn, got %#v", err)
  136. }
  137. }
  // badConnection is a test stub for net.Conn. Write returns a fixed byte count
// and error, and Close is a no-op that reports success. Any other net.Conn
// method is forwarded to the embedded net.Conn. The tests in this file leave
// that field nil, so calling another method would panic.
type badConnection struct {
	n   int   // byte count reported by Write
	err error // error returned by Write
	net.Conn
}

// Write ignores its input and returns the configured byte count and error.
func (c badConnection) Write(_ []byte) (int, error) {
	return c.n, c.err
}

// Close does nothing and always returns nil.
func (c badConnection) Close() error {
	return nil
}