// Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. package mysql import ( "bytes" "crypto/rand" "crypto/rsa" "crypto/sha1" "crypto/tls" "crypto/x509" "database/sql/driver" "encoding/binary" "encoding/pem" "errors" "fmt" "io" "math" "time" ) // Packets documentation: // http://dev.mysql.com/doc/internals/en/client-server-protocol.html // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte for { // read packet header data, err := mc.buf.readNext(4) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } errLog.Print(err) mc.Close() return nil, ErrInvalidConn } // packet length [24 bit] pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) // check packet sync [8 bit] if data[3] != mc.sequence { if data[3] > mc.sequence { return nil, ErrPktSyncMul } return nil, ErrPktSync } mc.sequence++ // packets with length 0 terminate a previous packet which is a // multiple of (2^24)−1 bytes long if pktLen == 0 { // there was no previous packet if prevData == nil { errLog.Print(ErrMalformPkt) mc.Close() return nil, ErrInvalidConn } return prevData, nil } // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } errLog.Print(err) mc.Close() return nil, ErrInvalidConn } // return data if this was the last packet if pktLen < maxPacketSize { // zero allocations for non-split packets if prevData == nil { return data, nil } return append(prevData, data...), nil } prevData = append(prevData, data...) } } // Write packet buffer 'data' func (mc *mysqlConn) writePacket(data []byte) error { pktLen := len(data) - 4 if pktLen > mc.maxAllowedPacket { return ErrPktTooLarge } for { var size int if pktLen >= maxPacketSize { data[0] = 0xff data[1] = 0xff data[2] = 0xff size = maxPacketSize } else { data[0] = byte(pktLen) data[1] = byte(pktLen >> 8) data[2] = byte(pktLen >> 16) size = pktLen } data[3] = mc.sequence // Write packet if mc.writeTimeout > 0 { if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { return err } } n, err := mc.netConn.Write(data[:4+size]) if err == nil && n == 4+size { mc.sequence++ if size != maxPacketSize { return nil } pktLen -= size data = data[size:] continue } // Handle error if err == nil { // n != len(data) mc.cleanup() errLog.Print(ErrMalformPkt) } else { if cerr := mc.canceled.Value(); cerr != nil { return cerr } if n == 0 && pktLen == len(data)-4 { // only for the first loop iteration when nothing was written yet return errBadConnNoWrite } mc.cleanup() errLog.Print(err) } return ErrInvalidConn } } /****************************************************************************** * Initialisation Process * ******************************************************************************/ // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake func (mc *mysqlConn) readInitPacket() ([]byte, string, error) { data, err := mc.readPacket() if err != nil { // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since // in connection initialization we don't risk retrying non-idempotent actions. if err == ErrInvalidConn { return nil, "", driver.ErrBadConn } return nil, "", err } if data[0] == iERR { return nil, "", mc.handleErrorPacket(data) } // protocol version [1 byte] if data[0] < minProtocolVersion { return nil, "", fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, ) } // server version [null terminated string] // connection id [4 bytes] pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // first part of the password cipher [8 bytes] cipher := data[pos : pos+8] // (filler) always 0x00 [1 byte] pos += 8 + 1 // capability flags (lower 2 bytes) [2 bytes] mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) if mc.flags&clientProtocol41 == 0 { return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { return nil, "", ErrNoTLS } pos += 2 pluginName := "mysql_native_password" if len(data) > pos { // character set [1 byte] // status flags [2 bytes] // capability flags (upper 2 bytes) [2 bytes] // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] pos += 1 + 2 + 2 + 1 + 10 // second part of the password cipher [mininum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) // // The web documentation is ambiguous about the length. However, // according to mysql-5.7/sql/auth/sql_authentication.cc line 538, // the 13th byte is "\0 byte, terminating the second part of // a scramble". So the second part of the password cipher is // a NULL terminated string that's at least 13 bytes with the // last byte being NULL. // // The official Python library uses the fixed length 12 // which seems to work but technically could have a hidden bug. cipher = append(cipher, data[pos:pos+12]...) pos += 13 // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) // \NUL otherwise if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { pluginName = string(data[pos : pos+end]) } else { pluginName = string(data[pos:]) } // make a memory safe copy of the cipher slice var b [20]byte copy(b[:], cipher) return b[:], pluginName, nil } // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], cipher) return b[:], pluginName, nil } // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error { if pluginName != "mysql_native_password" && pluginName != "caching_sha2_password" { return fmt.Errorf("unknown authentication plugin name '%s'", pluginName) } // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | clientLongPassword | clientTransactions | clientLocalFiles | clientPluginAuth | clientMultiResults | mc.flags&clientLongFlag if mc.cfg.ClientFoundRows { clientFlags |= clientFoundRows } // To enable TLS / SSL if mc.cfg.tls != nil { clientFlags |= clientSSL } if mc.cfg.MultiStatements { clientFlags |= clientMultiStatements } // User Password var scrambleBuff []byte switch pluginName { case "mysql_native_password": scrambleBuff = scramblePassword(cipher, []byte(mc.cfg.Passwd)) case "caching_sha2_password": scrambleBuff = scrambleCachingSha2Password(cipher, []byte(mc.cfg.Passwd)) } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 // To specify a db name if n := len(mc.cfg.DBName); n > 0 { clientFlags |= clientConnectWithDB pktLen += n + 1 } // Calculate packet length and get buffer with that size data := mc.buf.takeSmallBuffer(pktLen + 4) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) return errBadConnNoWrite } // ClientFlags [32 bit] data[4] = byte(clientFlags) data[5] = byte(clientFlags >> 8) data[6] = byte(clientFlags >> 16) data[7] = byte(clientFlags >> 24) // MaxPacketSize [32 bit] (none) data[8] = 0x00 data[9] = 0x00 data[10] = 0x00 data[11] = 0x00 // Charset [1 byte] var found bool data[12], found = collations[mc.cfg.Collation] if !found { // Note possibility for false negatives: // could be triggered although the collation is valid if the // collations map does not contain entries the server supports. return errors.New("unknown collation") } // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.tls != nil { // Send TLS / SSL request packet if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { return err } // Switch to TLS tlsConn := tls.Client(mc.netConn, mc.cfg.tls) if err := tlsConn.Handshake(); err != nil { return err } mc.netConn = tlsConn mc.buf.nc = tlsConn } // Filler [23 bytes] (all 0x00) pos := 13 for ; pos < 13+23; pos++ { data[pos] = 0 } // User [null terminated string] if len(mc.cfg.User) > 0 { pos += copy(data[pos:], mc.cfg.User) } data[pos] = 0x00 pos++ // ScrambleBuffer [length encoded integer] data[pos] = byte(len(scrambleBuff)) pos += 1 + copy(data[pos+1:], scrambleBuff) // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { pos += copy(data[pos:], mc.cfg.DBName) data[pos] = 0x00 pos++ } pos += copy(data[pos:], pluginName) data[pos] = 0x00 // Send Auth packet return mc.writePacket(data) } // Client old authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // User password // https://dev.mysql.com/doc/internals/en/old-password-authentication.html // Old password authentication only need and will need 8-byte challenge. scrambleBuff := scrambleOldPassword(cipher[:8], []byte(mc.cfg.Passwd)) // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) + 1 data := mc.buf.takeSmallBuffer(4 + pktLen) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) return errBadConnNoWrite } // Add the scrambled password [null terminated string] copy(data[4:], scrambleBuff) data[4+pktLen-1] = 0x00 return mc.writePacket(data) } // Client clear text authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writeClearAuthPacket() error { // Calculate the packet length and add a tailing 0 pktLen := len(mc.cfg.Passwd) + 1 data := mc.buf.takeSmallBuffer(4 + pktLen) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) return errBadConnNoWrite } // Add the clear password [null terminated string] copy(data[4:], mc.cfg.Passwd) data[4+pktLen-1] = 0x00 return mc.writePacket(data) } // Native password authentication method // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html // Native password authentication only need and will need 20-byte challenge. scrambleBuff := scramblePassword(cipher[0:20], []byte(mc.cfg.Passwd)) // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) data := mc.buf.takeSmallBuffer(4 + pktLen) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) return errBadConnNoWrite } // Add the scramble copy(data[4:], scrambleBuff) return mc.writePacket(data) } // Caching sha2 authentication. Public key request and send encrypted password // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writePublicKeyAuthPacket(cipher []byte) error { // request public key data := mc.buf.takeSmallBuffer(4 + 1) data[4] = cachingSha2PasswordRequestPublicKey mc.writePacket(data) data, err := mc.readPacket() if err != nil { return err } block, _ := pem.Decode(data[1:]) pub, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return err } plain := make([]byte, len(mc.cfg.Passwd)+1) copy(plain, mc.cfg.Passwd) for i := range plain { j := i % len(cipher) plain[i] ^= cipher[j] } sha1 := sha1.New() enc, _ := rsa.EncryptOAEP(sha1, rand.Reader, pub.(*rsa.PublicKey), plain, nil) data = mc.buf.takeSmallBuffer(4 + len(enc)) copy(data[4:], enc) return mc.writePacket(data) } /****************************************************************************** * Command Packets * ******************************************************************************/ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 data := mc.buf.takeSmallBuffer(4 + 1) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) return errBadConnNoWrite } // Add command byte data[4] = command // Send CMD packet return mc.writePacket(data) } func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence mc.sequence = 0 pktLen := 1 + len(arg) data := mc.buf.takeBuffer(pktLen + 4) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) return errBadConnNoWrite } // Add command byte data[4] = command // Add arg copy(data[5:], arg) // Send CMD packet return mc.writePacket(data) } func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 data := mc.buf.takeSmallBuffer(4 + 1 + 4) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) return errBadConnNoWrite } // Add command byte data[4] = command // Add arg [32 bit] data[5] = byte(arg) data[6] = byte(arg >> 8) data[7] = byte(arg >> 16) data[8] = byte(arg >> 24) // Send CMD packet return mc.writePacket(data) } /****************************************************************************** * Result Packets * ******************************************************************************/ func readAuthSwitch(data []byte) ([]byte, error) { if len(data) > 1 { pluginEndIndex := bytes.IndexByte(data, 0x00) plugin := string(data[1:pluginEndIndex]) cipher := data[pluginEndIndex+1:] switch plugin { case "mysql_old_password": // using old_passwords return cipher, ErrOldPassword case "mysql_clear_password": // using clear text password return cipher, ErrCleartextPassword case "mysql_native_password": // using mysql default authentication method return cipher, ErrNativePassword default: return cipher, ErrUnknownPlugin } } // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest return nil, ErrOldPassword } // Returns error if Packet is not an 'Result OK'-Packet func (mc *mysqlConn) readResultOK() ([]byte, error) { data, err := mc.readPacket() if err != nil { return nil, err } // packet indicator switch data[0] { case iOK: return nil, mc.handleOkPacket(data) case iAuthMoreData: return data[1:], nil case iEOF: return readAuthSwitch(data) default: // Error otherwise return nil, mc.handleErrorPacket(data) } } // Result Set Header Packet // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { data, err := mc.readPacket() if err == nil { switch data[0] { case iOK: return 0, mc.handleOkPacket(data) case iERR: return 0, mc.handleErrorPacket(data) case iLocalInFile: return 0, mc.handleInFileRequest(string(data[1:])) } // column count num, _, n := readLengthEncodedInteger(data) if n-len(data) == 0 { return int(num), nil } return 0, ErrMalformPkt } return 0, err } // Error Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet func (mc *mysqlConn) handleErrorPacket(data []byte) error { if data[0] != iERR { return ErrMalformPkt } // 0xff [1 byte] // Error Number [16 bit uint] errno := binary.LittleEndian.Uint16(data[1:3]) // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { // Oops; we are connected to a read-only connection, and won't be able // to issue any write statements. Since RejectReadOnly is configured, // we throw away this connection hoping this one would have write // permission. This is specifically for a possible race condition // during failover (e.g. on AWS Aurora). See README.md for more. // // We explicitly close the connection before returning // driver.ErrBadConn to ensure that `database/sql` purges this // connection and initiates a new one for next statement next time. mc.Close() return driver.ErrBadConn } pos := 3 // SQL State [optional: # + 5bytes string] if data[3] == 0x23 { //sqlstate := string(data[4 : 4+5]) pos = 9 } // Error Message [string] return &MySQLError{ Number: errno, Message: string(data[pos:]), } } func readStatus(b []byte) statusFlag { return statusFlag(b[0]) | statusFlag(b[1])<<8 } // Ok Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet func (mc *mysqlConn) handleOkPacket(data []byte) error { var n, m int // 0x00 [1 byte] // Affected rows [Length Coded Binary] mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) // Insert id [Length Coded Binary] mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) // server_status [2 bytes] mc.status = readStatus(data[1+n+m : 1+n+m+2]) if mc.status&statusMoreResultsExists != 0 { return nil } // warning count [2 bytes] return nil } // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) for i := 0; ; i++ { data, err := mc.readPacket() if err != nil { return nil, err } // EOF Packet if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { if i == count { return columns, nil } return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) } // Catalog pos, err := skipLengthEncodedString(data) if err != nil { return nil, err } // Database [len coded string] n, err := skipLengthEncodedString(data[pos:]) if err != nil { return nil, err } pos += n // Table [len coded string] if mc.cfg.ColumnsWithAlias { tableName, _, n, err := readLengthEncodedString(data[pos:]) if err != nil { return nil, err } pos += n columns[i].tableName = string(tableName) } else { n, err = skipLengthEncodedString(data[pos:]) if err != nil { return nil, err } pos += n } // Original table [len coded string] n, err = skipLengthEncodedString(data[pos:]) if err != nil { return nil, err } pos += n // Name [len coded string] name, _, n, err := readLengthEncodedString(data[pos:]) if err != nil { return nil, err } columns[i].name = string(name) pos += n // Original name [len coded string] n, err = skipLengthEncodedString(data[pos:]) if err != nil { return nil, err } pos += n // Filler [uint8] pos++ // Charset [charset, collation uint8] columns[i].charSet = data[pos] pos += 2 // Length [uint32] columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) pos += 4 // Field type [uint8] columns[i].fieldType = fieldType(data[pos]) pos++ // Flags [uint16] columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) pos += 2 // Decimals [uint8] columns[i].decimals = data[pos] //pos++ // Default value [len coded binary] //if pos < len(data) { // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) //} } } // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc if rows.rs.done { return io.EOF } data, err := mc.readPacket() if err != nil { return err } // EOF Packet if data[0] == iEOF && len(data) == 5 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil } return io.EOF } if data[0] == iERR { rows.mc = nil return mc.handleErrorPacket(data) } // RowSet Packet var n int var isNull bool pos := 0 for i := range dest { // Read bytes and convert to string dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) pos += n if err == nil { if !isNull { if !mc.parseTime { continue } else { switch rows.rs.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( string(dest[i].([]byte)), mc.cfg.Loc, ) if err == nil { continue } default: continue } } } else { dest[i] = nil continue } } return err // err != nil } return nil } // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read func (mc *mysqlConn) readUntilEOF() error { for { data, err := mc.readPacket() if err != nil { return err } switch data[0] { case iERR: return mc.handleErrorPacket(data) case iEOF: if len(data) == 5 { mc.status = readStatus(data[3:]) } return nil } } } /****************************************************************************** * Prepared Statements * ******************************************************************************/ // Prepare Result Packets // http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { data, err := stmt.mc.readPacket() if err == nil { // packet indicator [1 byte] if data[0] != iOK { return 0, stmt.mc.handleErrorPacket(data) } // statement id [4 bytes] stmt.id = binary.LittleEndian.Uint32(data[1:5]) // Column count [16 bit uint] columnCount := binary.LittleEndian.Uint16(data[5:7]) // Param count [16 bit uint] stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) // Reserved [8 bit] // Warning count [16 bit uint] return columnCount, nil } return 0, err } // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen // After the header (bytes 0-3) follows before the data: // 1 byte command // 4 bytes stmtID // 2 bytes paramID const dataOffset = 1 + 4 + 2 // Can not use the write buffer since // a) the buffer is too small // b) it is in use data := make([]byte, 4+1+4+2+len(arg)) copy(data[4+dataOffset:], arg) for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { if dataOffset+argLen < maxLen { pktLen = dataOffset + argLen } stmt.mc.sequence = 0 // Add command byte [1 byte] data[4] = comStmtSendLongData // Add stmtID [32 bit] data[5] = byte(stmt.id) data[6] = byte(stmt.id >> 8) data[7] = byte(stmt.id >> 16) data[8] = byte(stmt.id >> 24) // Add paramID [16 bit] data[9] = byte(paramID) data[10] = byte(paramID >> 8) // Send CMD packet err := stmt.mc.writePacket(data[:4+pktLen]) if err == nil { data = data[pktLen-dataOffset:] continue } return err } // Reset Packet Sequence stmt.mc.sequence = 0 return nil } // Execute Prepared Statement // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( "argument count mismatch (got: %d; has: %d)", len(args), stmt.paramCount, ) } const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc // Determine threshould dynamically to avoid packet size shortage. longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) if longDataSize < 64 { longDataSize = 64 } // Reset packet-sequence mc.sequence = 0 var data []byte if len(args) == 0 { data = mc.buf.takeBuffer(minPktLen) } else { data = mc.buf.takeCompleteBuffer() } if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) return errBadConnNoWrite } // command [1 byte] data[4] = comStmtExecute // statement_id [4 bytes] data[5] = byte(stmt.id) data[6] = byte(stmt.id >> 8) data[7] = byte(stmt.id >> 16) data[8] = byte(stmt.id >> 24) // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] data[9] = 0x00 // iteration_count (uint32(1)) [4 bytes] data[10] = 0x01 data[11] = 0x00 data[12] = 0x00 data[13] = 0x00 if len(args) > 0 { pos := minPktLen var nullMask []byte if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { // buffer has to be extended but we don't know by how much so // we depend on append after all data with known sizes fit. // We stop at that because we deal with a lot of columns here // which makes the required allocation size hard to guess. tmp := make([]byte, pos+maskLen+typesLen) copy(tmp[:pos], data[:pos]) data = tmp nullMask = data[pos : pos+maskLen] pos += maskLen } else { nullMask = data[pos : pos+maskLen] for i := 0; i < maskLen; i++ { nullMask[i] = 0 } pos += maskLen } // newParameterBoundFlag 1 [1 byte] data[pos] = 0x01 pos++ // type of each parameter [len(args)*2 bytes] paramTypes := data[pos:] pos += len(args) * 2 // value of each parameter [n bytes] paramValues := data[pos:pos] valuesCap := cap(paramValues) for i, arg := range args { // build NULL-bitmap if arg == nil { nullMask[i/8] |= 1 << (uint(i) & 7) paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 continue } // cache types and values switch v := arg.(type) { case int64: paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { paramValues = paramValues[:len(paramValues)+8] binary.LittleEndian.PutUint64( paramValues[len(paramValues)-8:], uint64(v), ) } else { paramValues = append(paramValues, uint64ToBytes(uint64(v))..., ) } case float64: paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { paramValues = paramValues[:len(paramValues)+8] binary.LittleEndian.PutUint64( paramValues[len(paramValues)-8:], math.Float64bits(v), ) } else { paramValues = append(paramValues, uint64ToBytes(math.Float64bits(v))..., ) } case bool: paramTypes[i+i] = byte(fieldTypeTiny) paramTypes[i+i+1] = 0x00 if v { paramValues = append(paramValues, 0x01) } else { paramValues = append(paramValues, 0x00) } case []byte: // Common case (non-nil value) first if v != nil { paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) paramValues = append(paramValues, v...) } else { if err := stmt.writeCommandLongData(i, v); err != nil { return err } } continue } // Handle []byte(nil) as a NULL value nullMask[i/8] |= 1 << (uint(i) & 7) paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 case string: paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) paramValues = append(paramValues, v...) } else { if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { return err } } case time.Time: paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 var a [64]byte var b = a[:0] if v.IsZero() { b = append(b, "0000-00-00"...) } else { b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) } paramValues = appendLengthEncodedInteger(paramValues, uint64(len(b)), ) paramValues = append(paramValues, b...) default: return fmt.Errorf("can not convert type: %T", arg) } } // Check if param values exceeded the available buffer // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) mc.buf.buf = data } pos += len(paramValues) data = data[:pos] } return mc.writePacket(data) } func (mc *mysqlConn) discardResults() error { for mc.status&statusMoreResultsExists != 0 { resLen, err := mc.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns if err := mc.readUntilEOF(); err != nil { return err } // rows if err := mc.readUntilEOF(); err != nil { return err } } } return nil } // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html func (rows *binaryRows) readRow(dest []driver.Value) error { data, err := rows.mc.readPacket() if err != nil { return err } // packet indicator [1 byte] if data[0] != iOK { // EOF Packet if data[0] == iEOF && len(data) == 5 { rows.mc.status = readStatus(data[3:]) rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil } return io.EOF } mc := rows.mc rows.mc = nil // Error otherwise return mc.handleErrorPacket(data) } // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] pos := 1 + (len(dest)+7+2)>>3 nullMask := data[1:pos] for i := range dest { // Field is NULL // (byte >> bit-pos) % 2 == 1 if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { dest[i] = nil continue } // Convert to byte-coded string switch rows.rs.columns[i].fieldType { case fieldTypeNULL: dest[i] = nil continue // Numeric Types case fieldTypeTiny: if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(data[pos]) } else { dest[i] = int64(int8(data[pos])) } pos++ continue case fieldTypeShort, fieldTypeYear: if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) } else { dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) } pos += 2 continue case fieldTypeInt24, fieldTypeLong: if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) } else { dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) } pos += 4 continue case fieldTypeLongLong: if rows.rs.columns[i].flags&flagUnsigned != 0 { val := binary.LittleEndian.Uint64(data[pos : pos+8]) if val > math.MaxInt64 { dest[i] = uint64ToString(val) } else { dest[i] = int64(val) } } else { dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8])) } pos += 8 continue case fieldTypeFloat: dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) pos += 4 continue case fieldTypeDouble: dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8])) pos += 8 continue // Length coded Binary Strings case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: var isNull bool var n int dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) pos += n if err == nil { if !isNull { continue } else { dest[i] = nil continue } } return err case fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] num, isNull, n := readLengthEncodedInteger(data[pos:]) pos += n switch { case isNull: dest[i] = nil continue case rows.rs.columns[i].fieldType == fieldTypeTime: // database/sql does not support an equivalent to TIME, return a string var dstlen uint8 switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 8 case 1, 2, 3, 4, 5, 6: dstlen = 8 + 1 + decimals default: return fmt.Errorf( "protocol error, illegal decimals value %d", rows.rs.columns[i].decimals, ) } dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) case rows.mc.parseTime: dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: var dstlen uint8 if rows.rs.columns[i].fieldType == fieldTypeDate { dstlen = 10 } else { switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 19 case 1, 2, 3, 4, 5, 6: dstlen = 19 + 1 + decimals default: return fmt.Errorf( "protocol error, illegal decimals value %d", rows.rs.columns[i].decimals, ) } } dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) } if err == nil { pos += int(num) continue } else { return err } // Please report if this happens! default: return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) } } return nil }