packets_test.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "errors"
  12. "net"
  13. "testing"
  14. "time"
  15. )
  16. var (
  17. errConnClosed = errors.New("connection is closed")
  18. errConnTooManyReads = errors.New("too many reads")
  19. errConnTooManyWrites = errors.New("too many writes")
  20. )
  21. // struct to mock a net.Conn for testing purposes
  22. type mockConn struct {
  23. laddr net.Addr
  24. raddr net.Addr
  25. data []byte
  26. closed bool
  27. read int
  28. written int
  29. reads int
  30. writes int
  31. maxReads int
  32. maxWrites int
  33. }
  34. func (m *mockConn) Read(b []byte) (n int, err error) {
  35. if m.closed {
  36. return 0, errConnClosed
  37. }
  38. m.reads++
  39. if m.maxReads > 0 && m.reads > m.maxReads {
  40. return 0, errConnTooManyReads
  41. }
  42. n = copy(b, m.data)
  43. m.read += n
  44. m.data = m.data[n:]
  45. return
  46. }
  47. func (m *mockConn) Write(b []byte) (n int, err error) {
  48. if m.closed {
  49. return 0, errConnClosed
  50. }
  51. m.writes++
  52. if m.maxWrites > 0 && m.writes > m.maxWrites {
  53. return 0, errConnTooManyWrites
  54. }
  55. n = len(b)
  56. m.written += n
  57. return
  58. }
  59. func (m *mockConn) Close() error {
  60. m.closed = true
  61. return nil
  62. }
  63. func (m *mockConn) LocalAddr() net.Addr {
  64. return m.laddr
  65. }
  66. func (m *mockConn) RemoteAddr() net.Addr {
  67. return m.raddr
  68. }
  69. func (m *mockConn) SetDeadline(t time.Time) error {
  70. return nil
  71. }
  72. func (m *mockConn) SetReadDeadline(t time.Time) error {
  73. return nil
  74. }
  75. func (m *mockConn) SetWriteDeadline(t time.Time) error {
  76. return nil
  77. }
  78. // make sure mockConn implements the net.Conn interface
  79. var _ net.Conn = new(mockConn)
  80. func TestReadPacketSingleByte(t *testing.T) {
  81. conn := new(mockConn)
  82. mc := &mysqlConn{
  83. buf: newBuffer(conn),
  84. }
  85. conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
  86. conn.maxReads = 1
  87. packet, err := mc.readPacket()
  88. if err != nil {
  89. t.Fatal(err)
  90. }
  91. if len(packet) != 1 {
  92. t.Fatalf("unexpected packet lenght: expected %d, got %d", 1, len(packet))
  93. }
  94. if packet[0] != 0xff {
  95. t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0])
  96. }
  97. }
  98. func TestReadPacketWrongSequenceID(t *testing.T) {
  99. conn := new(mockConn)
  100. mc := &mysqlConn{
  101. buf: newBuffer(conn),
  102. }
  103. // too low sequence id
  104. conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
  105. conn.maxReads = 1
  106. mc.sequence = 1
  107. _, err := mc.readPacket()
  108. if err != ErrPktSync {
  109. t.Errorf("expected ErrPktSync, got %v", err)
  110. }
  111. // reset
  112. conn.reads = 0
  113. mc.sequence = 0
  114. mc.buf = newBuffer(conn)
  115. // too high sequence id
  116. conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
  117. _, err = mc.readPacket()
  118. if err != ErrPktSyncMul {
  119. t.Errorf("expected ErrPktSyncMul, got %v", err)
  120. }
  121. }
  122. func TestReadPacketSplit(t *testing.T) {
  123. conn := new(mockConn)
  124. mc := &mysqlConn{
  125. buf: newBuffer(conn),
  126. }
  127. data := make([]byte, maxPacketSize*2+4*3)
  128. const pkt2ofs = maxPacketSize + 4
  129. const pkt3ofs = 2 * (maxPacketSize + 4)
  130. // case 1: payload has length maxPacketSize
  131. data = data[:pkt2ofs+4]
  132. // 1st packet has maxPacketSize length and sequence id 0
  133. // ff ff ff 00 ...
  134. data[0] = 0xff
  135. data[1] = 0xff
  136. data[2] = 0xff
  137. // mark the payload start and end of 1st packet so that we can check if the
  138. // content was correctly appended
  139. data[4] = 0x11
  140. data[maxPacketSize+3] = 0x22
  141. // 2nd packet has payload length 0 and squence id 1
  142. // 00 00 00 01
  143. data[pkt2ofs+3] = 0x01
  144. conn.data = data
  145. conn.maxReads = 3
  146. packet, err := mc.readPacket()
  147. if err != nil {
  148. t.Fatal(err)
  149. }
  150. if len(packet) != maxPacketSize {
  151. t.Fatalf("unexpected packet lenght: expected %d, got %d", maxPacketSize, len(packet))
  152. }
  153. if packet[0] != 0x11 {
  154. t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
  155. }
  156. if packet[maxPacketSize-1] != 0x22 {
  157. t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1])
  158. }
  159. // case 2: payload has length which is a multiple of maxPacketSize
  160. data = data[:cap(data)]
  161. // 2nd packet now has maxPacketSize length
  162. data[pkt2ofs] = 0xff
  163. data[pkt2ofs+1] = 0xff
  164. data[pkt2ofs+2] = 0xff
  165. // mark the payload start and end of the 2nd packet
  166. data[pkt2ofs+4] = 0x33
  167. data[pkt2ofs+maxPacketSize+3] = 0x44
  168. // 3rd packet has payload length 0 and squence id 2
  169. // 00 00 00 02
  170. data[pkt3ofs+3] = 0x02
  171. conn.data = data
  172. conn.reads = 0
  173. conn.maxReads = 5
  174. mc.sequence = 0
  175. packet, err = mc.readPacket()
  176. if err != nil {
  177. t.Fatal(err)
  178. }
  179. if len(packet) != 2*maxPacketSize {
  180. t.Fatalf("unexpected packet lenght: expected %d, got %d", 2*maxPacketSize, len(packet))
  181. }
  182. if packet[0] != 0x11 {
  183. t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
  184. }
  185. if packet[2*maxPacketSize-1] != 0x44 {
  186. t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1])
  187. }
  188. // case 3: payload has a length larger maxPacketSize, which is not an exact
  189. // multiple of it
  190. data = data[:pkt2ofs+4+42]
  191. data[pkt2ofs] = 0x2a
  192. data[pkt2ofs+1] = 0x00
  193. data[pkt2ofs+2] = 0x00
  194. data[pkt2ofs+4+41] = 0x44
  195. conn.data = data
  196. conn.reads = 0
  197. conn.maxReads = 4
  198. mc.sequence = 0
  199. packet, err = mc.readPacket()
  200. if err != nil {
  201. t.Fatal(err)
  202. }
  203. if len(packet) != maxPacketSize+42 {
  204. t.Fatalf("unexpected packet lenght: expected %d, got %d", maxPacketSize+42, len(packet))
  205. }
  206. if packet[0] != 0x11 {
  207. t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
  208. }
  209. if packet[maxPacketSize+41] != 0x44 {
  210. t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41])
  211. }
  212. }
  213. func TestReadPacketFail(t *testing.T) {
  214. conn := new(mockConn)
  215. mc := &mysqlConn{
  216. buf: newBuffer(conn),
  217. }
  218. // illegal empty (stand-alone) packet
  219. conn.data = []byte{0x00, 0x00, 0x00, 0x00}
  220. conn.maxReads = 1
  221. _, err := mc.readPacket()
  222. if err != driver.ErrBadConn {
  223. t.Errorf("expected ErrBadConn, got %v", err)
  224. }
  225. // reset
  226. conn.reads = 0
  227. mc.sequence = 0
  228. mc.buf = newBuffer(conn)
  229. // fail to read header
  230. conn.closed = true
  231. _, err = mc.readPacket()
  232. if err != driver.ErrBadConn {
  233. t.Errorf("expected ErrBadConn, got %v", err)
  234. }
  235. // reset
  236. conn.closed = false
  237. conn.reads = 0
  238. mc.sequence = 0
  239. mc.buf = newBuffer(conn)
  240. // fail to read body
  241. conn.maxReads = 1
  242. _, err = mc.readPacket()
  243. if err != driver.ErrBadConn {
  244. t.Errorf("expected ErrBadConn, got %v", err)
  245. }
  246. }