|
- // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
- //
- // Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this file,
- // You can obtain one at http://mozilla.org/MPL/2.0/.
- package mysql
- import (
- "bytes"
- "errors"
- "net"
- "testing"
- "time"
- )
// Sentinel errors returned by mockConn so that tests can tell the different
// simulated failure modes apart.
var (
	// errConnClosed is returned by Read and Write after Close was called.
	errConnClosed = errors.New("connection is closed")

	// errConnTooManyReads is returned by Read once the maxReads limit is exceeded.
	errConnTooManyReads = errors.New("too many reads")

	// errConnTooManyWrites is returned by Write once the maxWrites limit is exceeded.
	errConnTooManyWrites = errors.New("too many writes")
)
// mockConn is an in-memory stand-in for a net.Conn used by the tests.
// Reads are served from data, writes are recorded in written, and both can be
// limited to a maximum number of calls.
type mockConn struct {
	// laddr and raddr are handed out by LocalAddr and RemoteAddr.
	laddr, raddr net.Addr

	// data holds the bytes still to be returned by Read.
	data []byte
	// written accumulates everything passed to Write.
	written []byte
	// queuedReplies are moved into data, one per non-empty Write.
	queuedReplies [][]byte

	// closed is set by Close; Read and Write fail afterwards.
	closed bool

	// read is the total number of bytes handed out by Read.
	read int
	// reads and writes count the calls made on the open connection.
	reads, writes int
	// maxReads and maxWrites cap the number of calls; values <= 0 mean no limit.
	maxReads, maxWrites int
}
- func (m *mockConn) Read(b []byte) (n int, err error) {
- if m.closed {
- return 0, errConnClosed
- }
- m.reads++
- if m.maxReads > 0 && m.reads > m.maxReads {
- return 0, errConnTooManyReads
- }
- n = copy(b, m.data)
- m.read += n
- m.data = m.data[n:]
- return
- }
- func (m *mockConn) Write(b []byte) (n int, err error) {
- if m.closed {
- return 0, errConnClosed
- }
- m.writes++
- if m.maxWrites > 0 && m.writes > m.maxWrites {
- return 0, errConnTooManyWrites
- }
- n = len(b)
- m.written = append(m.written, b...)
- if n > 0 && len(m.queuedReplies) > 0 {
- m.data = m.queuedReplies[0]
- m.queuedReplies = m.queuedReplies[1:]
- }
- return
- }
- func (m *mockConn) Close() error {
- m.closed = true
- return nil
- }
- func (m *mockConn) LocalAddr() net.Addr {
- return m.laddr
- }
- func (m *mockConn) RemoteAddr() net.Addr {
- return m.raddr
- }
- func (m *mockConn) SetDeadline(t time.Time) error {
- return nil
- }
- func (m *mockConn) SetReadDeadline(t time.Time) error {
- return nil
- }
- func (m *mockConn) SetWriteDeadline(t time.Time) error {
- return nil
- }
- // make sure mockConn implements the net.Conn interface
- var _ net.Conn = new(mockConn)
- func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
- conn := new(mockConn)
- mc := &mysqlConn{
- buf: newBuffer(conn),
- cfg: NewConfig(),
- netConn: conn,
- closech: make(chan struct{}),
- maxAllowedPacket: defaultMaxAllowedPacket,
- sequence: sequence,
- }
- return conn, mc
- }
- func TestReadPacketSingleByte(t *testing.T) {
- conn := new(mockConn)
- mc := &mysqlConn{
- buf: newBuffer(conn),
- }
- conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
- conn.maxReads = 1
- packet, err := mc.readPacket()
- if err != nil {
- t.Fatal(err)
- }
- if len(packet) != 1 {
- t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet))
- }
- if packet[0] != 0xff {
- t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0])
- }
- }
- func TestReadPacketWrongSequenceID(t *testing.T) {
- conn := new(mockConn)
- mc := &mysqlConn{
- buf: newBuffer(conn),
- }
- // too low sequence id
- conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
- conn.maxReads = 1
- mc.sequence = 1
- _, err := mc.readPacket()
- if err != ErrPktSync {
- t.Errorf("expected ErrPktSync, got %v", err)
- }
- // reset
- conn.reads = 0
- mc.sequence = 0
- mc.buf = newBuffer(conn)
- // too high sequence id
- conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
- _, err = mc.readPacket()
- if err != ErrPktSyncMul {
- t.Errorf("expected ErrPktSyncMul, got %v", err)
- }
- }
- func TestReadPacketSplit(t *testing.T) {
- conn := new(mockConn)
- mc := &mysqlConn{
- buf: newBuffer(conn),
- }
- data := make([]byte, maxPacketSize*2+4*3)
- const pkt2ofs = maxPacketSize + 4
- const pkt3ofs = 2 * (maxPacketSize + 4)
- // case 1: payload has length maxPacketSize
- data = data[:pkt2ofs+4]
- // 1st packet has maxPacketSize length and sequence id 0
- // ff ff ff 00 ...
- data[0] = 0xff
- data[1] = 0xff
- data[2] = 0xff
- // mark the payload start and end of 1st packet so that we can check if the
- // content was correctly appended
- data[4] = 0x11
- data[maxPacketSize+3] = 0x22
- // 2nd packet has payload length 0 and squence id 1
- // 00 00 00 01
- data[pkt2ofs+3] = 0x01
- conn.data = data
- conn.maxReads = 3
- packet, err := mc.readPacket()
- if err != nil {
- t.Fatal(err)
- }
- if len(packet) != maxPacketSize {
- t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet))
- }
- if packet[0] != 0x11 {
- t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
- }
- if packet[maxPacketSize-1] != 0x22 {
- t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1])
- }
- // case 2: payload has length which is a multiple of maxPacketSize
- data = data[:cap(data)]
- // 2nd packet now has maxPacketSize length
- data[pkt2ofs] = 0xff
- data[pkt2ofs+1] = 0xff
- data[pkt2ofs+2] = 0xff
- // mark the payload start and end of the 2nd packet
- data[pkt2ofs+4] = 0x33
- data[pkt2ofs+maxPacketSize+3] = 0x44
- // 3rd packet has payload length 0 and squence id 2
- // 00 00 00 02
- data[pkt3ofs+3] = 0x02
- conn.data = data
- conn.reads = 0
- conn.maxReads = 5
- mc.sequence = 0
- packet, err = mc.readPacket()
- if err != nil {
- t.Fatal(err)
- }
- if len(packet) != 2*maxPacketSize {
- t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet))
- }
- if packet[0] != 0x11 {
- t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
- }
- if packet[2*maxPacketSize-1] != 0x44 {
- t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1])
- }
- // case 3: payload has a length larger maxPacketSize, which is not an exact
- // multiple of it
- data = data[:pkt2ofs+4+42]
- data[pkt2ofs] = 0x2a
- data[pkt2ofs+1] = 0x00
- data[pkt2ofs+2] = 0x00
- data[pkt2ofs+4+41] = 0x44
- conn.data = data
- conn.reads = 0
- conn.maxReads = 4
- mc.sequence = 0
- packet, err = mc.readPacket()
- if err != nil {
- t.Fatal(err)
- }
- if len(packet) != maxPacketSize+42 {
- t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet))
- }
- if packet[0] != 0x11 {
- t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
- }
- if packet[maxPacketSize+41] != 0x44 {
- t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41])
- }
- }
- func TestReadPacketFail(t *testing.T) {
- conn := new(mockConn)
- mc := &mysqlConn{
- buf: newBuffer(conn),
- closech: make(chan struct{}),
- }
- // illegal empty (stand-alone) packet
- conn.data = []byte{0x00, 0x00, 0x00, 0x00}
- conn.maxReads = 1
- _, err := mc.readPacket()
- if err != ErrInvalidConn {
- t.Errorf("expected ErrInvalidConn, got %v", err)
- }
- // reset
- conn.reads = 0
- mc.sequence = 0
- mc.buf = newBuffer(conn)
- // fail to read header
- conn.closed = true
- _, err = mc.readPacket()
- if err != ErrInvalidConn {
- t.Errorf("expected ErrInvalidConn, got %v", err)
- }
- // reset
- conn.closed = false
- conn.reads = 0
- mc.sequence = 0
- mc.buf = newBuffer(conn)
- // fail to read body
- conn.maxReads = 1
- _, err = mc.readPacket()
- if err != ErrInvalidConn {
- t.Errorf("expected ErrInvalidConn, got %v", err)
- }
- }
- // https://github.com/go-sql-driver/mysql/pull/801
- // not-NUL terminated plugin_name in init packet
- func TestRegression801(t *testing.T) {
- conn := new(mockConn)
- mc := &mysqlConn{
- buf: newBuffer(conn),
- cfg: new(Config),
- sequence: 42,
- closech: make(chan struct{}),
- }
- conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,
- 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77,
- 50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95,
- 112, 97, 115, 115, 119, 111, 114, 100}
- conn.maxReads = 1
- authData, pluginName, err := mc.readHandshakePacket()
- if err != nil {
- t.Fatalf("got error: %v", err)
- }
- if pluginName != "mysql_native_password" {
- t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName)
- }
- expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114,
- 47, 85, 75, 109, 99, 51, 77, 50, 64}
- if !bytes.Equal(authData, expectedAuthData) {
- t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData)
- }
- }
|