packets.go 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 Julien Schmidt. All rights reserved.
  4. // http://www.julienschmidt.com
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla Public
  7. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  8. // You can obtain one at http://mozilla.org/MPL/2.0/.
  9. package mysql
  10. import (
  11. "bytes"
  12. "database/sql/driver"
  13. "encoding/binary"
  14. "errors"
  15. "fmt"
  16. "io"
  17. "math"
  18. "time"
  19. )
  20. // Packets documentation:
  21. // http://dev.mysql.com/doc/internals/en/client-server-protocol.html
  22. // Read packet to buffer 'data'
  23. func (mc *mysqlConn) readPacket() (data []byte, err error) {
  24. // Read packet header
  25. data, err = mc.buf.readNext(4)
  26. if err != nil {
  27. errLog.Print(err.Error())
  28. return nil, driver.ErrBadConn
  29. }
  30. // Packet Length [24 bit]
  31. pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
  32. if pktLen < 1 {
  33. errLog.Print(errMalformPkt.Error())
  34. return nil, driver.ErrBadConn
  35. }
  36. // Check Packet Sync [8 bit]
  37. if data[3] != mc.sequence {
  38. if data[3] > mc.sequence {
  39. return nil, errPktSyncMul
  40. } else {
  41. return nil, errPktSync
  42. }
  43. }
  44. mc.sequence++
  45. // Read packet body [pktLen bytes]
  46. data, err = mc.buf.readNext(pktLen)
  47. if err == nil {
  48. if pktLen < maxPacketSize {
  49. return data, nil
  50. }
  51. // More data
  52. var data2 []byte
  53. data2, err = mc.readPacket()
  54. if err == nil {
  55. return append(data, data2...), nil
  56. }
  57. }
  58. errLog.Print(err.Error())
  59. return nil, driver.ErrBadConn
  60. }
  61. // Write packet buffer 'data'
  62. // The packet header must be already included
  63. func (mc *mysqlConn) writePacket(data []byte) error {
  64. if len(data)-4 <= mc.maxWriteSize { // Can send data at once
  65. // Write packet
  66. n, err := mc.netConn.Write(data)
  67. if err == nil && n == len(data) {
  68. mc.sequence++
  69. return nil
  70. }
  71. // Handle error
  72. if err == nil { // n != len(data)
  73. errLog.Print(errMalformPkt.Error())
  74. } else {
  75. errLog.Print(err.Error())
  76. }
  77. return driver.ErrBadConn
  78. }
  79. // Must split packet
  80. return mc.splitPacket(data)
  81. }
  82. func (mc *mysqlConn) splitPacket(data []byte) (err error) {
  83. pktLen := len(data) - 4
  84. if pktLen > mc.maxPacketAllowed {
  85. return errPktTooLarge
  86. }
  87. for pktLen >= maxPacketSize {
  88. data[0] = 0xff
  89. data[1] = 0xff
  90. data[2] = 0xff
  91. data[3] = mc.sequence
  92. // Write packet
  93. n, err := mc.netConn.Write(data[:4+maxPacketSize])
  94. if err == nil && n == 4+maxPacketSize {
  95. mc.sequence++
  96. data = data[maxPacketSize:]
  97. pktLen -= maxPacketSize
  98. continue
  99. }
  100. // Handle error
  101. if err == nil { // n != len(data)
  102. errLog.Print(errMalformPkt.Error())
  103. } else {
  104. errLog.Print(err.Error())
  105. }
  106. return driver.ErrBadConn
  107. }
  108. data[0] = byte(pktLen)
  109. data[1] = byte(pktLen >> 8)
  110. data[2] = byte(pktLen >> 16)
  111. data[3] = mc.sequence
  112. return mc.writePacket(data)
  113. }
  114. /******************************************************************************
  115. * Initialisation Process *
  116. ******************************************************************************/
  117. // Handshake Initialization Packet
  118. // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
  119. func (mc *mysqlConn) readInitPacket() (err error) {
  120. data, err := mc.readPacket()
  121. if err != nil {
  122. return
  123. }
  124. if data[0] == iERR {
  125. return mc.handleErrorPacket(data)
  126. }
  127. // protocol version [1 byte]
  128. if data[0] < minProtocolVersion {
  129. err = fmt.Errorf(
  130. "Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
  131. data[0],
  132. minProtocolVersion)
  133. }
  134. // server version [null terminated string]
  135. // connection id [4 bytes]
  136. pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
  137. // first part of the password cipher [8 bytes]
  138. mc.cipher = append(mc.cipher, data[pos:pos+8]...)
  139. // (filler) always 0x00 [1 byte]
  140. pos += 8 + 1
  141. // capability flags (lower 2 bytes) [2 bytes]
  142. mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
  143. if mc.flags&clientProtocol41 == 0 {
  144. err = errors.New("MySQL-Server does not support required Protocol 41+")
  145. }
  146. pos += 2
  147. if len(data) > pos {
  148. // character set [1 byte]
  149. mc.charset = data[pos]
  150. // status flags [2 bytes]
  151. // capability flags (upper 2 bytes) [2 bytes]
  152. // length of auth-plugin-data [1 byte]
  153. // reserved (all [00]) [10 bytes]
  154. pos += 1 + 2 + 2 + 1 + 10
  155. // second part of the password cipher [12? bytes]
  156. // The documentation is ambiguous about the length.
  157. // The official Python library uses the fixed length 12
  158. // which is not documented but seems to work.
  159. mc.cipher = append(mc.cipher, data[pos:pos+12]...)
  160. if data[len(data)-1] == 0 {
  161. return
  162. }
  163. return errMalformPkt
  164. }
  165. return
  166. }
  167. // Client Authentication Packet
  168. // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
  169. func (mc *mysqlConn) writeAuthPacket() error {
  170. // Adjust client flags based on server support
  171. clientFlags := uint32(
  172. clientProtocol41 |
  173. clientSecureConn |
  174. clientLongPassword |
  175. clientTransactions |
  176. clientLocalFiles,
  177. )
  178. if mc.flags&clientLongFlag > 0 {
  179. clientFlags |= uint32(clientLongFlag)
  180. }
  181. // User Password
  182. scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd))
  183. mc.cipher = nil
  184. pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
  185. // To specify a db name
  186. if len(mc.cfg.dbname) > 0 {
  187. clientFlags |= uint32(clientConnectWithDB)
  188. pktLen += len(mc.cfg.dbname) + 1
  189. }
  190. // Calculate packet length and make buffer with that size
  191. data := make([]byte, pktLen+4)
  192. // Add the packet header [24bit length + 1 byte sequence]
  193. data[0] = byte(pktLen)
  194. data[1] = byte(pktLen >> 8)
  195. data[2] = byte(pktLen >> 16)
  196. data[3] = mc.sequence
  197. // ClientFlags [32 bit]
  198. data[4] = byte(clientFlags)
  199. data[5] = byte(clientFlags >> 8)
  200. data[6] = byte(clientFlags >> 16)
  201. data[7] = byte(clientFlags >> 24)
  202. // MaxPacketSize [32 bit] (none)
  203. //data[8] = 0x00
  204. //data[9] = 0x00
  205. //data[10] = 0x00
  206. //data[11] = 0x00
  207. // Charset [1 byte]
  208. data[12] = mc.charset
  209. // Filler [23 bytes] (all 0x00)
  210. pos := 13 + 23
  211. // User [null terminated string]
  212. if len(mc.cfg.user) > 0 {
  213. pos += copy(data[pos:], mc.cfg.user)
  214. }
  215. //data[pos] = 0x00
  216. pos++
  217. // ScrambleBuffer [length encoded integer]
  218. data[pos] = byte(len(scrambleBuff))
  219. pos += 1 + copy(data[pos+1:], scrambleBuff)
  220. // Databasename [null terminated string]
  221. if len(mc.cfg.dbname) > 0 {
  222. pos += copy(data[pos:], mc.cfg.dbname)
  223. //data[pos] = 0x00
  224. }
  225. // Send Auth packet
  226. return mc.writePacket(data)
  227. }
  228. /******************************************************************************
  229. * Command Packets *
  230. ******************************************************************************/
  231. func (mc *mysqlConn) writeCommandPacket(command byte) error {
  232. // Reset Packet Sequence
  233. mc.sequence = 0
  234. // Send CMD packet
  235. return mc.writePacket([]byte{
  236. // Add the packet header [24bit length + 1 byte sequence]
  237. 0x05, // 5 bytes long
  238. 0x00,
  239. 0x00,
  240. 0x00, // mc.sequence
  241. // Add command byte
  242. command,
  243. })
  244. }
  245. func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
  246. // Reset Packet Sequence
  247. mc.sequence = 0
  248. pktLen := 1 + len(arg)
  249. data := make([]byte, pktLen+4)
  250. // Add the packet header [24bit length + 1 byte sequence]
  251. data[0] = byte(pktLen)
  252. data[1] = byte(pktLen >> 8)
  253. data[2] = byte(pktLen >> 16)
  254. //data[3] = mc.sequence
  255. // Add command byte
  256. data[4] = command
  257. // Add arg
  258. copy(data[5:], arg)
  259. // Send CMD packet
  260. return mc.writePacket(data)
  261. }
  262. func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
  263. // Reset Packet Sequence
  264. mc.sequence = 0
  265. // Send CMD packet
  266. return mc.writePacket([]byte{
  267. // Add the packet header [24bit length + 1 byte sequence]
  268. 0x05, // 5 bytes long
  269. 0x00,
  270. 0x00,
  271. 0x00, // mc.sequence
  272. // Add command byte
  273. command,
  274. // Add arg [32 bit]
  275. byte(arg),
  276. byte(arg >> 8),
  277. byte(arg >> 16),
  278. byte(arg >> 24),
  279. })
  280. }
  281. /******************************************************************************
  282. * Result Packets *
  283. ******************************************************************************/
  284. // Returns error if Packet is not an 'Result OK'-Packet
  285. func (mc *mysqlConn) readResultOK() error {
  286. data, err := mc.readPacket()
  287. if err == nil {
  288. // packet indicator
  289. switch data[0] {
  290. case iOK:
  291. mc.handleOkPacket(data)
  292. return nil
  293. case iEOF: // someone is using old_passwords
  294. return errOldPassword
  295. default: // Error otherwise
  296. return mc.handleErrorPacket(data)
  297. }
  298. }
  299. return err
  300. }
  301. // Result Set Header Packet
  302. // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset
  303. func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
  304. data, err := mc.readPacket()
  305. if err == nil {
  306. switch data[0] {
  307. case iOK:
  308. mc.handleOkPacket(data)
  309. return 0, nil
  310. case iERR:
  311. return 0, mc.handleErrorPacket(data)
  312. case iLocalInFile:
  313. return 0, mc.handleInFileRequest(string(data[1:]))
  314. }
  315. // column count
  316. num, _, n := readLengthEncodedInteger(data)
  317. if n-len(data) == 0 {
  318. return int(num), nil
  319. }
  320. return 0, errMalformPkt
  321. }
  322. return 0, err
  323. }
  324. // Error Packet
  325. // http://dev.mysql.com/doc/internals/en/overview.html#packet-ERR_Packet
  326. func (mc *mysqlConn) handleErrorPacket(data []byte) error {
  327. if data[0] != iERR {
  328. return errMalformPkt
  329. }
  330. // 0xff [1 byte]
  331. // Error Number [16 bit uint]
  332. errno := binary.LittleEndian.Uint16(data[1:3])
  333. pos := 3
  334. // SQL State [optional: # + 5bytes string]
  335. //sqlstate := string(data[pos : pos+6])
  336. if data[pos] == 0x23 {
  337. pos = 9
  338. }
  339. // Error Message [string]
  340. return fmt.Errorf("Error %d: %s", errno, string(data[pos:]))
  341. }
  342. // Ok Packet
  343. // http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
  344. func (mc *mysqlConn) handleOkPacket(data []byte) {
  345. var n int
  346. // 0x00 [1 byte]
  347. // Affected rows [Length Coded Binary]
  348. mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
  349. // Insert id [Length Coded Binary]
  350. mc.insertId, _, _ = readLengthEncodedInteger(data[1+n:])
  351. // server_status [2 bytes]
  352. // warning count [2 bytes]
  353. // message [until end of packet]
  354. }
  355. // Read Packets as Field Packets until EOF-Packet or an Error appears
  356. // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41
  357. func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
  358. var data []byte
  359. var i, pos, n int
  360. var name []byte
  361. columns = make([]mysqlField, count)
  362. for {
  363. data, err = mc.readPacket()
  364. if err != nil {
  365. return
  366. }
  367. // EOF Packet
  368. if data[0] == iEOF && len(data) == 5 {
  369. if i != count {
  370. err = fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
  371. }
  372. return
  373. }
  374. // Catalog
  375. pos, err = skipLengthEnodedString(data)
  376. if err != nil {
  377. return
  378. }
  379. // Database [len coded string]
  380. n, err = skipLengthEnodedString(data[pos:])
  381. if err != nil {
  382. return
  383. }
  384. pos += n
  385. // Table [len coded string]
  386. n, err = skipLengthEnodedString(data[pos:])
  387. if err != nil {
  388. return
  389. }
  390. pos += n
  391. // Original table [len coded string]
  392. n, err = skipLengthEnodedString(data[pos:])
  393. if err != nil {
  394. return
  395. }
  396. pos += n
  397. // Name [len coded string]
  398. name, _, n, err = readLengthEnodedString(data[pos:])
  399. if err != nil {
  400. return
  401. }
  402. columns[i].name = string(name)
  403. pos += n
  404. // Original name [len coded string]
  405. n, err = skipLengthEnodedString(data[pos:])
  406. if err != nil {
  407. return
  408. }
  409. // Filler [1 byte]
  410. // Charset [16 bit uint]
  411. // Length [32 bit uint]
  412. pos += n + 1 + 2 + 4
  413. // Field type [byte]
  414. columns[i].fieldType = data[pos]
  415. pos++
  416. // Flags [16 bit uint]
  417. columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
  418. //pos += 2
  419. // Decimals [8 bit uint]
  420. //pos++
  421. // Default value [len coded binary]
  422. //if pos < len(data) {
  423. // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
  424. //}
  425. i++
  426. }
  427. return
  428. }
  429. // Read Packets as Field Packets until EOF-Packet or an Error appears
  430. // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow
  431. func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
  432. data, err := rows.mc.readPacket()
  433. if err != nil {
  434. return
  435. }
  436. // EOF Packet
  437. if data[0] == iEOF && len(data) == 5 {
  438. return io.EOF
  439. }
  440. // RowSet Packet
  441. var n int
  442. var isNull bool
  443. pos := 0
  444. for i := range dest {
  445. // Read bytes and convert to string
  446. dest[i], isNull, n, err = readLengthEnodedString(data[pos:])
  447. pos += n
  448. if err == nil {
  449. if !isNull {
  450. continue
  451. } else {
  452. dest[i] = nil
  453. continue
  454. }
  455. }
  456. return // err
  457. }
  458. return
  459. }
  460. // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
  461. func (mc *mysqlConn) readUntilEOF() (err error) {
  462. var data []byte
  463. for {
  464. data, err = mc.readPacket()
  465. // No Err and no EOF Packet
  466. if err == nil && (data[0] != iEOF || len(data) != 5) {
  467. continue
  468. }
  469. return // Err or EOF
  470. }
  471. return
  472. }
  473. /******************************************************************************
  474. * Prepared Statements *
  475. ******************************************************************************/
  476. // Prepare Result Packets
  477. // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
  478. func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) {
  479. data, err := stmt.mc.readPacket()
  480. if err == nil {
  481. // Position
  482. pos := 0
  483. // packet indicator [1 byte]
  484. if data[pos] != iOK {
  485. err = stmt.mc.handleErrorPacket(data)
  486. return
  487. }
  488. pos++
  489. // statement id [4 bytes]
  490. stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
  491. pos += 4
  492. // Column count [16 bit uint]
  493. columnCount = binary.LittleEndian.Uint16(data[pos : pos+2])
  494. pos += 2
  495. // Param count [16 bit uint]
  496. stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2]))
  497. pos += 2
  498. // Warning count [16 bit uint]
  499. // bytesToUint16(data[pos : pos+2])
  500. }
  501. return
  502. }
  503. // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data
  504. func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) {
  505. maxLen := stmt.mc.maxPacketAllowed - 1
  506. pktLen := maxLen
  507. argLen := len(arg)
  508. data := make([]byte, 4+1+4+2+argLen)
  509. copy(data[4+1+4+2:], arg)
  510. for argLen > 0 {
  511. if 1+4+2+argLen < maxLen {
  512. pktLen = 1 + 4 + 2 + argLen
  513. }
  514. // Add the packet header [24bit length + 1 byte sequence]
  515. data[0] = byte(pktLen)
  516. data[1] = byte(pktLen >> 8)
  517. data[2] = byte(pktLen >> 16)
  518. data[3] = 0x00 // mc.sequence
  519. // Add command byte [1 byte]
  520. data[4] = comStmtSendLongData
  521. // Add stmtID [32 bit]
  522. data[5] = byte(stmt.id)
  523. data[6] = byte(stmt.id >> 8)
  524. data[7] = byte(stmt.id >> 16)
  525. data[8] = byte(stmt.id >> 24)
  526. // Add paramID [16 bit]
  527. data[9] = byte(paramID)
  528. data[10] = byte(paramID >> 8)
  529. // Send CMD packet
  530. err = stmt.mc.writePacket(data[:4+pktLen])
  531. if err == nil {
  532. argLen -= pktLen - (1 + 4 + 2)
  533. data = data[pktLen-(1+4+2):]
  534. continue
  535. }
  536. return err
  537. }
  538. // Reset Packet Sequence
  539. stmt.mc.sequence = 0
  540. return nil
  541. }
  542. // Execute Prepared Statement
  543. // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-execute
  544. func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
  545. if len(args) != stmt.paramCount {
  546. return fmt.Errorf(
  547. "Arguments count mismatch (Got: %d Has: %d",
  548. len(args),
  549. stmt.paramCount)
  550. }
  551. // Reset packet-sequence
  552. stmt.mc.sequence = 0
  553. pktLen := 1 + 4 + 1 + 4 + ((stmt.paramCount + 7) >> 3) + 1 + (stmt.paramCount << 1)
  554. paramValues := make([][]byte, stmt.paramCount)
  555. paramTypes := make([]byte, (stmt.paramCount << 1))
  556. bitMask := uint64(0)
  557. var i int
  558. for i = range args {
  559. // build NULL-bitmap
  560. if args[i] == nil {
  561. bitMask += 1 << uint(i)
  562. paramTypes[i<<1] = fieldTypeNULL
  563. continue
  564. }
  565. // cache types and values
  566. switch v := args[i].(type) {
  567. case int64:
  568. paramTypes[i<<1] = fieldTypeLongLong
  569. paramValues[i] = uint64ToBytes(uint64(v))
  570. pktLen += 8
  571. continue
  572. case float64:
  573. paramTypes[i<<1] = fieldTypeDouble
  574. paramValues[i] = uint64ToBytes(math.Float64bits(v))
  575. pktLen += 8
  576. continue
  577. case bool:
  578. paramTypes[i<<1] = fieldTypeTiny
  579. pktLen++
  580. if v {
  581. paramValues[i] = []byte{0x01}
  582. } else {
  583. paramValues[i] = []byte{0x00}
  584. }
  585. continue
  586. case []byte:
  587. paramTypes[i<<1] = fieldTypeString
  588. if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 {
  589. paramValues[i] = append(
  590. lengthEncodedIntegerToBytes(uint64(len(v))),
  591. v...,
  592. )
  593. pktLen += len(paramValues[i])
  594. continue
  595. } else {
  596. err := stmt.writeCommandLongData(i, v)
  597. if err == nil {
  598. continue
  599. }
  600. return err
  601. }
  602. case string:
  603. paramTypes[i<<1] = fieldTypeString
  604. if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 {
  605. paramValues[i] = append(
  606. lengthEncodedIntegerToBytes(uint64(len(v))),
  607. []byte(v)...,
  608. )
  609. pktLen += len(paramValues[i])
  610. continue
  611. } else {
  612. err := stmt.writeCommandLongData(i, []byte(v))
  613. if err == nil {
  614. continue
  615. }
  616. return err
  617. }
  618. case time.Time:
  619. paramTypes[i<<1] = fieldTypeString
  620. val := []byte(v.Format(timeFormat))
  621. paramValues[i] = append(
  622. lengthEncodedIntegerToBytes(uint64(len(val))),
  623. val...,
  624. )
  625. pktLen += len(paramValues[i])
  626. continue
  627. default:
  628. return fmt.Errorf("Can't convert type: %T", args[i])
  629. }
  630. }
  631. data := make([]byte, pktLen+4)
  632. // packet header [4 bytes]
  633. data[0] = byte(pktLen)
  634. data[1] = byte(pktLen >> 8)
  635. data[2] = byte(pktLen >> 16)
  636. data[3] = stmt.mc.sequence
  637. // command [1 byte]
  638. data[4] = comStmtExecute
  639. // statement_id [4 bytes]
  640. data[5] = byte(stmt.id)
  641. data[6] = byte(stmt.id >> 8)
  642. data[7] = byte(stmt.id >> 16)
  643. data[8] = byte(stmt.id >> 24)
  644. // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
  645. //data[9] = 0x00
  646. // iteration_count (uint32(1)) [4 bytes]
  647. data[10] = 0x01
  648. //data[11] = 0x00
  649. //data[12] = 0x00
  650. //data[13] = 0x00
  651. if stmt.paramCount > 0 {
  652. // NULL-bitmap [(param_count+7)/8 bytes]
  653. pos := 14 + ((stmt.paramCount + 7) >> 3)
  654. // Convert bitMask to bytes
  655. for i = 14; i < pos; i++ {
  656. data[i] = byte(bitMask >> uint((i-14)<<3))
  657. }
  658. // newParameterBoundFlag 1 [1 byte]
  659. data[pos] = 0x01
  660. pos++
  661. // type of parameters [param_count*2 bytes]
  662. pos += copy(data[pos:], paramTypes)
  663. // values for the parameters [n bytes]
  664. for i = range paramValues {
  665. pos += copy(data[pos:], paramValues[i])
  666. }
  667. }
  668. return stmt.mc.writePacket(data)
  669. }
  670. // http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow
  671. func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
  672. data, err := rc.mc.readPacket()
  673. if err != nil {
  674. return
  675. }
  676. // packet indicator [1 byte]
  677. if data[0] != iOK {
  678. // EOF Packet
  679. if data[0] == iEOF && len(data) == 5 {
  680. return io.EOF
  681. } else {
  682. // Error otherwise
  683. return rc.mc.handleErrorPacket(data)
  684. }
  685. }
  686. // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
  687. pos := 1 + (len(dest)+7+2)>>3
  688. nullBitMap := data[1:pos]
  689. // values [rest]
  690. var n int
  691. var unsigned bool
  692. for i := range dest {
  693. // Field is NULL
  694. // (byte >> bit-pos) % 2 == 1
  695. if ((nullBitMap[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
  696. dest[i] = nil
  697. continue
  698. }
  699. unsigned = rc.columns[i].flags&flagUnsigned != 0
  700. // Convert to byte-coded string
  701. switch rc.columns[i].fieldType {
  702. case fieldTypeNULL:
  703. dest[i] = nil
  704. continue
  705. // Numeric Types
  706. case fieldTypeTiny:
  707. if unsigned {
  708. dest[i] = int64(data[pos])
  709. } else {
  710. dest[i] = int64(int8(data[pos]))
  711. }
  712. pos++
  713. continue
  714. case fieldTypeShort, fieldTypeYear:
  715. if unsigned {
  716. dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
  717. } else {
  718. dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
  719. }
  720. pos += 2
  721. continue
  722. case fieldTypeInt24, fieldTypeLong:
  723. if unsigned {
  724. dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
  725. } else {
  726. dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
  727. }
  728. pos += 4
  729. continue
  730. case fieldTypeLongLong:
  731. if unsigned {
  732. val := binary.LittleEndian.Uint64(data[pos : pos+8])
  733. if val > math.MaxInt64 {
  734. dest[i] = uint64ToString(val)
  735. } else {
  736. dest[i] = int64(val)
  737. }
  738. } else {
  739. dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
  740. }
  741. pos += 8
  742. continue
  743. case fieldTypeFloat:
  744. dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
  745. pos += 4
  746. continue
  747. case fieldTypeDouble:
  748. dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
  749. pos += 8
  750. continue
  751. // Length coded Binary Strings
  752. case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
  753. fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
  754. fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
  755. fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
  756. var isNull bool
  757. dest[i], isNull, n, err = readLengthEnodedString(data[pos:])
  758. pos += n
  759. if err == nil {
  760. if !isNull {
  761. continue
  762. } else {
  763. dest[i] = nil
  764. continue
  765. }
  766. }
  767. return // err
  768. // Date YYYY-MM-DD
  769. case fieldTypeDate, fieldTypeNewDate:
  770. var num uint64
  771. var isNull bool
  772. num, isNull, n = readLengthEncodedInteger(data[pos:])
  773. pos += n
  774. if num == 0 {
  775. if isNull {
  776. dest[i] = nil
  777. continue
  778. } else {
  779. dest[i] = []byte("0000-00-00")
  780. continue
  781. }
  782. } else {
  783. dest[i] = []byte(fmt.Sprintf("%04d-%02d-%02d",
  784. binary.LittleEndian.Uint16(data[pos:pos+2]),
  785. data[pos+2],
  786. data[pos+3]))
  787. pos += int(num)
  788. continue
  789. }
  790. // Time [-][H]HH:MM:SS[.fractal]
  791. case fieldTypeTime:
  792. var num uint64
  793. var isNull bool
  794. num, isNull, n = readLengthEncodedInteger(data[pos:])
  795. pos += n
  796. if num == 0 {
  797. if isNull {
  798. dest[i] = nil
  799. continue
  800. } else {
  801. dest[i] = []byte("00:00:00")
  802. continue
  803. }
  804. }
  805. var sign byte
  806. if data[pos] == 1 {
  807. sign = byte('-')
  808. }
  809. switch num {
  810. case 8:
  811. dest[i] = []byte(fmt.Sprintf(
  812. "%c%02d:%02d:%02d",
  813. sign,
  814. uint16(data[pos+1])*24+uint16(data[pos+5]),
  815. data[pos+6],
  816. data[pos+7],
  817. ))
  818. pos += 8
  819. continue
  820. case 12:
  821. dest[i] = []byte(fmt.Sprintf(
  822. "%c%02d:%02d:%02d.%06d",
  823. sign,
  824. uint16(data[pos+1])*24+uint16(data[pos+5]),
  825. data[pos+6],
  826. data[pos+7],
  827. binary.LittleEndian.Uint32(data[pos+8:pos+12]),
  828. ))
  829. pos += 12
  830. continue
  831. default:
  832. return fmt.Errorf("Invalid TIME-packet length %d", num)
  833. }
  834. // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
  835. case fieldTypeTimestamp, fieldTypeDateTime:
  836. var num uint64
  837. var isNull bool
  838. num, isNull, n = readLengthEncodedInteger(data[pos:])
  839. pos += n
  840. if num == 0 {
  841. if isNull {
  842. dest[i] = nil
  843. continue
  844. } else {
  845. dest[i] = []byte("0000-00-00 00:00:00")
  846. continue
  847. }
  848. }
  849. switch num {
  850. case 4:
  851. dest[i] = []byte(fmt.Sprintf(
  852. "%04d-%02d-%02d 00:00:00",
  853. binary.LittleEndian.Uint16(data[pos:pos+2]),
  854. data[pos+2],
  855. data[pos+3],
  856. ))
  857. pos += 4
  858. continue
  859. case 7:
  860. dest[i] = []byte(fmt.Sprintf(
  861. "%04d-%02d-%02d %02d:%02d:%02d",
  862. binary.LittleEndian.Uint16(data[pos:pos+2]),
  863. data[pos+2],
  864. data[pos+3],
  865. data[pos+4],
  866. data[pos+5],
  867. data[pos+6],
  868. ))
  869. pos += 7
  870. continue
  871. case 11:
  872. dest[i] = []byte(fmt.Sprintf(
  873. "%04d-%02d-%02d %02d:%02d:%02d.%06d",
  874. binary.LittleEndian.Uint16(data[pos:pos+2]),
  875. data[pos+2],
  876. data[pos+3],
  877. data[pos+4],
  878. data[pos+5],
  879. data[pos+6],
  880. binary.LittleEndian.Uint32(data[pos+7:pos+11]),
  881. ))
  882. pos += 11
  883. continue
  884. default:
  885. return fmt.Errorf("Invalid DATETIME-packet length %d", num)
  886. }
  887. // Please report if this happens!
  888. default:
  889. return fmt.Errorf("Unknown FieldType %d", rc.columns[i].fieldType)
  890. }
  891. }
  892. return
  893. }