// packets_test.go (8.0 KB)
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "bytes"
  11. "errors"
  12. "net"
  13. "testing"
  14. "time"
  15. )
  16. var (
  17. errConnClosed = errors.New("connection is closed")
  18. errConnTooManyReads = errors.New("too many reads")
  19. errConnTooManyWrites = errors.New("too many writes")
  20. )
  21. // struct to mock a net.Conn for testing purposes
  22. type mockConn struct {
  23. laddr net.Addr
  24. raddr net.Addr
  25. data []byte
  26. written []byte
  27. queuedReplies [][]byte
  28. closed bool
  29. read int
  30. reads int
  31. writes int
  32. maxReads int
  33. maxWrites int
  34. }
  35. func (m *mockConn) Read(b []byte) (n int, err error) {
  36. if m.closed {
  37. return 0, errConnClosed
  38. }
  39. m.reads++
  40. if m.maxReads > 0 && m.reads > m.maxReads {
  41. return 0, errConnTooManyReads
  42. }
  43. n = copy(b, m.data)
  44. m.read += n
  45. m.data = m.data[n:]
  46. return
  47. }
  48. func (m *mockConn) Write(b []byte) (n int, err error) {
  49. if m.closed {
  50. return 0, errConnClosed
  51. }
  52. m.writes++
  53. if m.maxWrites > 0 && m.writes > m.maxWrites {
  54. return 0, errConnTooManyWrites
  55. }
  56. n = len(b)
  57. m.written = append(m.written, b...)
  58. if n > 0 && len(m.queuedReplies) > 0 {
  59. m.data = m.queuedReplies[0]
  60. m.queuedReplies = m.queuedReplies[1:]
  61. }
  62. return
  63. }
  64. func (m *mockConn) Close() error {
  65. m.closed = true
  66. return nil
  67. }
  68. func (m *mockConn) LocalAddr() net.Addr {
  69. return m.laddr
  70. }
  71. func (m *mockConn) RemoteAddr() net.Addr {
  72. return m.raddr
  73. }
  74. func (m *mockConn) SetDeadline(t time.Time) error {
  75. return nil
  76. }
  77. func (m *mockConn) SetReadDeadline(t time.Time) error {
  78. return nil
  79. }
  80. func (m *mockConn) SetWriteDeadline(t time.Time) error {
  81. return nil
  82. }
  83. // make sure mockConn implements the net.Conn interface
  84. var _ net.Conn = new(mockConn)
  85. func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
  86. conn := new(mockConn)
  87. mc := &mysqlConn{
  88. buf: newBuffer(conn),
  89. cfg: NewConfig(),
  90. netConn: conn,
  91. closech: make(chan struct{}),
  92. maxAllowedPacket: defaultMaxAllowedPacket,
  93. sequence: sequence,
  94. }
  95. return conn, mc
  96. }
  97. func TestReadPacketSingleByte(t *testing.T) {
  98. conn := new(mockConn)
  99. mc := &mysqlConn{
  100. buf: newBuffer(conn),
  101. }
  102. conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
  103. conn.maxReads = 1
  104. packet, err := mc.readPacket()
  105. if err != nil {
  106. t.Fatal(err)
  107. }
  108. if len(packet) != 1 {
  109. t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet))
  110. }
  111. if packet[0] != 0xff {
  112. t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0])
  113. }
  114. }
  115. func TestReadPacketWrongSequenceID(t *testing.T) {
  116. conn := new(mockConn)
  117. mc := &mysqlConn{
  118. buf: newBuffer(conn),
  119. }
  120. // too low sequence id
  121. conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
  122. conn.maxReads = 1
  123. mc.sequence = 1
  124. _, err := mc.readPacket()
  125. if err != ErrPktSync {
  126. t.Errorf("expected ErrPktSync, got %v", err)
  127. }
  128. // reset
  129. conn.reads = 0
  130. mc.sequence = 0
  131. mc.buf = newBuffer(conn)
  132. // too high sequence id
  133. conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
  134. _, err = mc.readPacket()
  135. if err != ErrPktSyncMul {
  136. t.Errorf("expected ErrPktSyncMul, got %v", err)
  137. }
  138. }
  139. func TestReadPacketSplit(t *testing.T) {
  140. conn := new(mockConn)
  141. mc := &mysqlConn{
  142. buf: newBuffer(conn),
  143. }
  144. data := make([]byte, maxPacketSize*2+4*3)
  145. const pkt2ofs = maxPacketSize + 4
  146. const pkt3ofs = 2 * (maxPacketSize + 4)
  147. // case 1: payload has length maxPacketSize
  148. data = data[:pkt2ofs+4]
  149. // 1st packet has maxPacketSize length and sequence id 0
  150. // ff ff ff 00 ...
  151. data[0] = 0xff
  152. data[1] = 0xff
  153. data[2] = 0xff
  154. // mark the payload start and end of 1st packet so that we can check if the
  155. // content was correctly appended
  156. data[4] = 0x11
  157. data[maxPacketSize+3] = 0x22
  158. // 2nd packet has payload length 0 and squence id 1
  159. // 00 00 00 01
  160. data[pkt2ofs+3] = 0x01
  161. conn.data = data
  162. conn.maxReads = 3
  163. packet, err := mc.readPacket()
  164. if err != nil {
  165. t.Fatal(err)
  166. }
  167. if len(packet) != maxPacketSize {
  168. t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet))
  169. }
  170. if packet[0] != 0x11 {
  171. t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
  172. }
  173. if packet[maxPacketSize-1] != 0x22 {
  174. t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1])
  175. }
  176. // case 2: payload has length which is a multiple of maxPacketSize
  177. data = data[:cap(data)]
  178. // 2nd packet now has maxPacketSize length
  179. data[pkt2ofs] = 0xff
  180. data[pkt2ofs+1] = 0xff
  181. data[pkt2ofs+2] = 0xff
  182. // mark the payload start and end of the 2nd packet
  183. data[pkt2ofs+4] = 0x33
  184. data[pkt2ofs+maxPacketSize+3] = 0x44
  185. // 3rd packet has payload length 0 and squence id 2
  186. // 00 00 00 02
  187. data[pkt3ofs+3] = 0x02
  188. conn.data = data
  189. conn.reads = 0
  190. conn.maxReads = 5
  191. mc.sequence = 0
  192. packet, err = mc.readPacket()
  193. if err != nil {
  194. t.Fatal(err)
  195. }
  196. if len(packet) != 2*maxPacketSize {
  197. t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet))
  198. }
  199. if packet[0] != 0x11 {
  200. t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
  201. }
  202. if packet[2*maxPacketSize-1] != 0x44 {
  203. t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1])
  204. }
  205. // case 3: payload has a length larger maxPacketSize, which is not an exact
  206. // multiple of it
  207. data = data[:pkt2ofs+4+42]
  208. data[pkt2ofs] = 0x2a
  209. data[pkt2ofs+1] = 0x00
  210. data[pkt2ofs+2] = 0x00
  211. data[pkt2ofs+4+41] = 0x44
  212. conn.data = data
  213. conn.reads = 0
  214. conn.maxReads = 4
  215. mc.sequence = 0
  216. packet, err = mc.readPacket()
  217. if err != nil {
  218. t.Fatal(err)
  219. }
  220. if len(packet) != maxPacketSize+42 {
  221. t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet))
  222. }
  223. if packet[0] != 0x11 {
  224. t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
  225. }
  226. if packet[maxPacketSize+41] != 0x44 {
  227. t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41])
  228. }
  229. }
  230. func TestReadPacketFail(t *testing.T) {
  231. conn := new(mockConn)
  232. mc := &mysqlConn{
  233. buf: newBuffer(conn),
  234. closech: make(chan struct{}),
  235. }
  236. // illegal empty (stand-alone) packet
  237. conn.data = []byte{0x00, 0x00, 0x00, 0x00}
  238. conn.maxReads = 1
  239. _, err := mc.readPacket()
  240. if err != ErrInvalidConn {
  241. t.Errorf("expected ErrInvalidConn, got %v", err)
  242. }
  243. // reset
  244. conn.reads = 0
  245. mc.sequence = 0
  246. mc.buf = newBuffer(conn)
  247. // fail to read header
  248. conn.closed = true
  249. _, err = mc.readPacket()
  250. if err != ErrInvalidConn {
  251. t.Errorf("expected ErrInvalidConn, got %v", err)
  252. }
  253. // reset
  254. conn.closed = false
  255. conn.reads = 0
  256. mc.sequence = 0
  257. mc.buf = newBuffer(conn)
  258. // fail to read body
  259. conn.maxReads = 1
  260. _, err = mc.readPacket()
  261. if err != ErrInvalidConn {
  262. t.Errorf("expected ErrInvalidConn, got %v", err)
  263. }
  264. }
  265. // https://github.com/go-sql-driver/mysql/pull/801
  266. // not-NUL terminated plugin_name in init packet
  267. func TestRegression801(t *testing.T) {
  268. conn := new(mockConn)
  269. mc := &mysqlConn{
  270. buf: newBuffer(conn),
  271. cfg: new(Config),
  272. sequence: 42,
  273. closech: make(chan struct{}),
  274. }
  275. conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,
  276. 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0,
  277. 0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77,
  278. 50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95,
  279. 112, 97, 115, 115, 119, 111, 114, 100}
  280. conn.maxReads = 1
  281. authData, pluginName, err := mc.readHandshakePacket()
  282. if err != nil {
  283. t.Fatalf("got error: %v", err)
  284. }
  285. if pluginName != "mysql_native_password" {
  286. t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName)
  287. }
  288. expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114,
  289. 47, 85, 75, 109, 99, 51, 77, 50, 64}
  290. if !bytes.Equal(authData, expectedAuthData) {
  291. t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData)
  292. }
  293. }