packets.go 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "bytes"
  11. "crypto/tls"
  12. "database/sql/driver"
  13. "encoding/binary"
  14. "fmt"
  15. "io"
  16. "math"
  17. "time"
  18. )
  19. // Packets documentation:
  20. // http://dev.mysql.com/doc/internals/en/client-server-protocol.html
  21. // Read packet to buffer 'data'
  22. func (mc *mysqlConn) readPacket() ([]byte, error) {
  23. var payload []byte
  24. for {
  25. // Read packet header
  26. data, err := mc.buf.readNext(4)
  27. if err != nil {
  28. errLog.Print(err)
  29. mc.Close()
  30. return nil, driver.ErrBadConn
  31. }
  32. // Packet Length [24 bit]
  33. pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
  34. if pktLen < 1 {
  35. errLog.Print(ErrMalformPkt)
  36. mc.Close()
  37. return nil, driver.ErrBadConn
  38. }
  39. // Check Packet Sync [8 bit]
  40. if data[3] != mc.sequence {
  41. if data[3] > mc.sequence {
  42. return nil, ErrPktSyncMul
  43. } else {
  44. return nil, ErrPktSync
  45. }
  46. }
  47. mc.sequence++
  48. // Read packet body [pktLen bytes]
  49. data, err = mc.buf.readNext(pktLen)
  50. if err != nil {
  51. errLog.Print(err)
  52. mc.Close()
  53. return nil, driver.ErrBadConn
  54. }
  55. isLastPacket := (pktLen < maxPacketSize)
  56. // Zero allocations for non-splitting packets
  57. if isLastPacket && payload == nil {
  58. return data, nil
  59. }
  60. payload = append(payload, data...)
  61. if isLastPacket {
  62. return payload, nil
  63. }
  64. }
  65. }
  66. // Write packet buffer 'data'
  67. func (mc *mysqlConn) writePacket(data []byte) error {
  68. pktLen := len(data) - 4
  69. if pktLen > mc.maxPacketAllowed {
  70. return ErrPktTooLarge
  71. }
  72. for {
  73. var size int
  74. if pktLen >= maxPacketSize {
  75. data[0] = 0xff
  76. data[1] = 0xff
  77. data[2] = 0xff
  78. size = maxPacketSize
  79. } else {
  80. data[0] = byte(pktLen)
  81. data[1] = byte(pktLen >> 8)
  82. data[2] = byte(pktLen >> 16)
  83. size = pktLen
  84. }
  85. data[3] = mc.sequence
  86. // Write packet
  87. n, err := mc.netConn.Write(data[:4+size])
  88. if err == nil && n == 4+size {
  89. mc.sequence++
  90. if size != maxPacketSize {
  91. return nil
  92. }
  93. pktLen -= size
  94. data = data[size:]
  95. continue
  96. }
  97. // Handle error
  98. if err == nil { // n != len(data)
  99. errLog.Print(ErrMalformPkt)
  100. } else {
  101. errLog.Print(err)
  102. }
  103. return driver.ErrBadConn
  104. }
  105. }
  106. /******************************************************************************
  107. * Initialisation Process *
  108. ******************************************************************************/
  109. // Handshake Initialization Packet
  110. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
  111. func (mc *mysqlConn) readInitPacket() ([]byte, error) {
  112. data, err := mc.readPacket()
  113. if err != nil {
  114. return nil, err
  115. }
  116. if data[0] == iERR {
  117. return nil, mc.handleErrorPacket(data)
  118. }
  119. // protocol version [1 byte]
  120. if data[0] < minProtocolVersion {
  121. return nil, fmt.Errorf(
  122. "Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
  123. data[0],
  124. minProtocolVersion,
  125. )
  126. }
  127. // server version [null terminated string]
  128. // connection id [4 bytes]
  129. pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
  130. // first part of the password cipher [8 bytes]
  131. cipher := data[pos : pos+8]
  132. // (filler) always 0x00 [1 byte]
  133. pos += 8 + 1
  134. // capability flags (lower 2 bytes) [2 bytes]
  135. mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
  136. if mc.flags&clientProtocol41 == 0 {
  137. return nil, ErrOldProtocol
  138. }
  139. if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
  140. return nil, ErrNoTLS
  141. }
  142. pos += 2
  143. if len(data) > pos {
  144. // character set [1 byte]
  145. // status flags [2 bytes]
  146. // capability flags (upper 2 bytes) [2 bytes]
  147. // length of auth-plugin-data [1 byte]
  148. // reserved (all [00]) [10 bytes]
  149. pos += 1 + 2 + 2 + 1 + 10
  150. // second part of the password cipher [mininum 13 bytes],
  151. // where len=MAX(13, length of auth-plugin-data - 8)
  152. //
  153. // The web documentation is ambiguous about the length. However,
  154. // according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
  155. // the 13th byte is "\0 byte, terminating the second part of
  156. // a scramble". So the second part of the password cipher is
  157. // a NULL terminated string that's at least 13 bytes with the
  158. // last byte being NULL.
  159. //
  160. // The official Python library uses the fixed length 12
  161. // which seems to work but technically could have a hidden bug.
  162. cipher = append(cipher, data[pos:pos+12]...)
  163. // TODO: Verify string termination
  164. // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
  165. // \NUL otherwise
  166. //
  167. //if data[len(data)-1] == 0 {
  168. // return
  169. //}
  170. //return ErrMalformPkt
  171. // make a memory safe copy of the cipher slice
  172. var b [20]byte
  173. copy(b[:], cipher)
  174. return b[:], nil
  175. }
  176. // make a memory safe copy of the cipher slice
  177. var b [8]byte
  178. copy(b[:], cipher)
  179. return b[:], nil
  180. }
  181. // Client Authentication Packet
  182. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
  183. func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
  184. // Adjust client flags based on server support
  185. clientFlags := clientProtocol41 |
  186. clientSecureConn |
  187. clientLongPassword |
  188. clientTransactions |
  189. clientLocalFiles |
  190. clientPluginAuth |
  191. mc.flags&clientLongFlag
  192. if mc.cfg.ClientFoundRows {
  193. clientFlags |= clientFoundRows
  194. }
  195. // To enable TLS / SSL
  196. if mc.cfg.TLS != nil {
  197. clientFlags |= clientSSL
  198. }
  199. // User Password
  200. scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
  201. pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
  202. // To specify a db name
  203. if n := len(mc.cfg.DBName); n > 0 {
  204. clientFlags |= clientConnectWithDB
  205. pktLen += n + 1
  206. }
  207. // Calculate packet length and get buffer with that size
  208. data := mc.buf.takeSmallBuffer(pktLen + 4)
  209. if data == nil {
  210. // can not take the buffer. Something must be wrong with the connection
  211. errLog.Print(ErrBusyBuffer)
  212. return driver.ErrBadConn
  213. }
  214. // ClientFlags [32 bit]
  215. data[4] = byte(clientFlags)
  216. data[5] = byte(clientFlags >> 8)
  217. data[6] = byte(clientFlags >> 16)
  218. data[7] = byte(clientFlags >> 24)
  219. // MaxPacketSize [32 bit] (none)
  220. data[8] = 0x00
  221. data[9] = 0x00
  222. data[10] = 0x00
  223. data[11] = 0x00
  224. // Charset [1 byte]
  225. data[12] = mc.cfg.Collation
  226. // SSL Connection Request Packet
  227. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
  228. if mc.cfg.TLS != nil {
  229. // Send TLS / SSL request packet
  230. if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
  231. return err
  232. }
  233. // Switch to TLS
  234. tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
  235. if err := tlsConn.Handshake(); err != nil {
  236. return err
  237. }
  238. mc.netConn = tlsConn
  239. mc.buf.rd = tlsConn
  240. }
  241. // Filler [23 bytes] (all 0x00)
  242. pos := 13
  243. for ; pos < 13+23; pos++ {
  244. data[pos] = 0
  245. }
  246. // User [null terminated string]
  247. if len(mc.cfg.User) > 0 {
  248. pos += copy(data[pos:], mc.cfg.User)
  249. }
  250. data[pos] = 0x00
  251. pos++
  252. // ScrambleBuffer [length encoded integer]
  253. data[pos] = byte(len(scrambleBuff))
  254. pos += 1 + copy(data[pos+1:], scrambleBuff)
  255. // Databasename [null terminated string]
  256. if len(mc.cfg.DBName) > 0 {
  257. pos += copy(data[pos:], mc.cfg.DBName)
  258. data[pos] = 0x00
  259. pos++
  260. }
  261. // Assume native client during response
  262. pos += copy(data[pos:], "mysql_native_password")
  263. data[pos] = 0x00
  264. // Send Auth packet
  265. return mc.writePacket(data)
  266. }
  267. // Client old authentication packet
  268. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
  269. func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
  270. // User password
  271. scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd))
  272. // Calculate the packet length and add a tailing 0
  273. pktLen := len(scrambleBuff) + 1
  274. data := mc.buf.takeSmallBuffer(4 + pktLen)
  275. if data == nil {
  276. // can not take the buffer. Something must be wrong with the connection
  277. errLog.Print(ErrBusyBuffer)
  278. return driver.ErrBadConn
  279. }
  280. // Add the scrambled password [null terminated string]
  281. copy(data[4:], scrambleBuff)
  282. data[4+pktLen-1] = 0x00
  283. return mc.writePacket(data)
  284. }
  285. // Client clear text authentication packet
  286. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
  287. func (mc *mysqlConn) writeClearAuthPacket() error {
  288. // Calculate the packet length and add a tailing 0
  289. pktLen := len(mc.cfg.Passwd) + 1
  290. data := mc.buf.takeSmallBuffer(4 + pktLen)
  291. if data == nil {
  292. // can not take the buffer. Something must be wrong with the connection
  293. errLog.Print(ErrBusyBuffer)
  294. return driver.ErrBadConn
  295. }
  296. // Add the clear password [null terminated string]
  297. copy(data[4:], mc.cfg.Passwd)
  298. data[4+pktLen-1] = 0x00
  299. return mc.writePacket(data)
  300. }
  301. /******************************************************************************
  302. * Command Packets *
  303. ******************************************************************************/
  304. func (mc *mysqlConn) writeCommandPacket(command byte) error {
  305. // Reset Packet Sequence
  306. mc.sequence = 0
  307. data := mc.buf.takeSmallBuffer(4 + 1)
  308. if data == nil {
  309. // can not take the buffer. Something must be wrong with the connection
  310. errLog.Print(ErrBusyBuffer)
  311. return driver.ErrBadConn
  312. }
  313. // Add command byte
  314. data[4] = command
  315. // Send CMD packet
  316. return mc.writePacket(data)
  317. }
  318. func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
  319. // Reset Packet Sequence
  320. mc.sequence = 0
  321. pktLen := 1 + len(arg)
  322. data := mc.buf.takeBuffer(pktLen + 4)
  323. if data == nil {
  324. // can not take the buffer. Something must be wrong with the connection
  325. errLog.Print(ErrBusyBuffer)
  326. return driver.ErrBadConn
  327. }
  328. // Add command byte
  329. data[4] = command
  330. // Add arg
  331. copy(data[5:], arg)
  332. // Send CMD packet
  333. return mc.writePacket(data)
  334. }
  335. func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
  336. // Reset Packet Sequence
  337. mc.sequence = 0
  338. data := mc.buf.takeSmallBuffer(4 + 1 + 4)
  339. if data == nil {
  340. // can not take the buffer. Something must be wrong with the connection
  341. errLog.Print(ErrBusyBuffer)
  342. return driver.ErrBadConn
  343. }
  344. // Add command byte
  345. data[4] = command
  346. // Add arg [32 bit]
  347. data[5] = byte(arg)
  348. data[6] = byte(arg >> 8)
  349. data[7] = byte(arg >> 16)
  350. data[8] = byte(arg >> 24)
  351. // Send CMD packet
  352. return mc.writePacket(data)
  353. }
  354. /******************************************************************************
  355. * Result Packets *
  356. ******************************************************************************/
  357. // Returns error if Packet is not an 'Result OK'-Packet
  358. func (mc *mysqlConn) readResultOK() error {
  359. data, err := mc.readPacket()
  360. if err == nil {
  361. // packet indicator
  362. switch data[0] {
  363. case iOK:
  364. return mc.handleOkPacket(data)
  365. case iEOF:
  366. if len(data) > 1 {
  367. plugin := string(data[1:bytes.IndexByte(data, 0x00)])
  368. if plugin == "mysql_old_password" {
  369. // using old_passwords
  370. return ErrOldPassword
  371. } else if plugin == "mysql_clear_password" {
  372. // using clear text password
  373. return ErrCleartextPassword
  374. } else {
  375. return ErrUnknownPlugin
  376. }
  377. } else {
  378. return ErrOldPassword
  379. }
  380. default: // Error otherwise
  381. return mc.handleErrorPacket(data)
  382. }
  383. }
  384. return err
  385. }
  386. // Result Set Header Packet
  387. // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
  388. func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
  389. data, err := mc.readPacket()
  390. if err == nil {
  391. switch data[0] {
  392. case iOK:
  393. return 0, mc.handleOkPacket(data)
  394. case iERR:
  395. return 0, mc.handleErrorPacket(data)
  396. case iLocalInFile:
  397. return 0, mc.handleInFileRequest(string(data[1:]))
  398. }
  399. // column count
  400. num, _, n := readLengthEncodedInteger(data)
  401. if n-len(data) == 0 {
  402. return int(num), nil
  403. }
  404. return 0, ErrMalformPkt
  405. }
  406. return 0, err
  407. }
  408. // Error Packet
  409. // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
  410. func (mc *mysqlConn) handleErrorPacket(data []byte) error {
  411. if data[0] != iERR {
  412. return ErrMalformPkt
  413. }
  414. // 0xff [1 byte]
  415. // Error Number [16 bit uint]
  416. errno := binary.LittleEndian.Uint16(data[1:3])
  417. pos := 3
  418. // SQL State [optional: # + 5bytes string]
  419. if data[3] == 0x23 {
  420. //sqlstate := string(data[4 : 4+5])
  421. pos = 9
  422. }
  423. // Error Message [string]
  424. return &MySQLError{
  425. Number: errno,
  426. Message: string(data[pos:]),
  427. }
  428. }
  429. // Ok Packet
  430. // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
  431. func (mc *mysqlConn) handleOkPacket(data []byte) error {
  432. var n, m int
  433. // 0x00 [1 byte]
  434. // Affected rows [Length Coded Binary]
  435. mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
  436. // Insert id [Length Coded Binary]
  437. mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
  438. // server_status [2 bytes]
  439. mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
  440. // warning count [2 bytes]
  441. if !mc.strict {
  442. return nil
  443. } else {
  444. pos := 1 + n + m + 2
  445. if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
  446. return mc.getWarnings()
  447. }
  448. return nil
  449. }
  450. }
  451. // Read Packets as Field Packets until EOF-Packet or an Error appears
  452. // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
  453. func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
  454. columns := make([]mysqlField, count)
  455. for i := 0; ; i++ {
  456. data, err := mc.readPacket()
  457. if err != nil {
  458. return nil, err
  459. }
  460. // EOF Packet
  461. if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
  462. if i == count {
  463. return columns, nil
  464. }
  465. return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
  466. }
  467. // Catalog
  468. pos, err := skipLengthEncodedString(data)
  469. if err != nil {
  470. return nil, err
  471. }
  472. // Database [len coded string]
  473. n, err := skipLengthEncodedString(data[pos:])
  474. if err != nil {
  475. return nil, err
  476. }
  477. pos += n
  478. // Table [len coded string]
  479. if mc.cfg.ColumnsWithAlias {
  480. tableName, _, n, err := readLengthEncodedString(data[pos:])
  481. if err != nil {
  482. return nil, err
  483. }
  484. pos += n
  485. columns[i].tableName = string(tableName)
  486. } else {
  487. n, err = skipLengthEncodedString(data[pos:])
  488. if err != nil {
  489. return nil, err
  490. }
  491. pos += n
  492. }
  493. // Original table [len coded string]
  494. n, err = skipLengthEncodedString(data[pos:])
  495. if err != nil {
  496. return nil, err
  497. }
  498. pos += n
  499. // Name [len coded string]
  500. name, _, n, err := readLengthEncodedString(data[pos:])
  501. if err != nil {
  502. return nil, err
  503. }
  504. columns[i].name = string(name)
  505. pos += n
  506. // Original name [len coded string]
  507. n, err = skipLengthEncodedString(data[pos:])
  508. if err != nil {
  509. return nil, err
  510. }
  511. // Filler [uint8]
  512. // Charset [charset, collation uint8]
  513. // Length [uint32]
  514. pos += n + 1 + 2 + 4
  515. // Field type [uint8]
  516. columns[i].fieldType = data[pos]
  517. pos++
  518. // Flags [uint16]
  519. columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
  520. pos += 2
  521. // Decimals [uint8]
  522. columns[i].decimals = data[pos]
  523. //pos++
  524. // Default value [len coded binary]
  525. //if pos < len(data) {
  526. // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
  527. //}
  528. }
  529. }
  530. // Read Packets as Field Packets until EOF-Packet or an Error appears
  531. // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
  532. func (rows *textRows) readRow(dest []driver.Value) error {
  533. mc := rows.mc
  534. data, err := mc.readPacket()
  535. if err != nil {
  536. return err
  537. }
  538. // EOF Packet
  539. if data[0] == iEOF && len(data) == 5 {
  540. rows.mc = nil
  541. return io.EOF
  542. }
  543. if data[0] == iERR {
  544. rows.mc = nil
  545. return mc.handleErrorPacket(data)
  546. }
  547. // RowSet Packet
  548. var n int
  549. var isNull bool
  550. pos := 0
  551. for i := range dest {
  552. // Read bytes and convert to string
  553. dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
  554. pos += n
  555. if err == nil {
  556. if !isNull {
  557. if !mc.parseTime {
  558. continue
  559. } else {
  560. switch rows.columns[i].fieldType {
  561. case fieldTypeTimestamp, fieldTypeDateTime,
  562. fieldTypeDate, fieldTypeNewDate:
  563. dest[i], err = parseDateTime(
  564. string(dest[i].([]byte)),
  565. mc.cfg.Loc,
  566. )
  567. if err == nil {
  568. continue
  569. }
  570. default:
  571. continue
  572. }
  573. }
  574. } else {
  575. dest[i] = nil
  576. continue
  577. }
  578. }
  579. return err // err != nil
  580. }
  581. return nil
  582. }
  583. // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
  584. func (mc *mysqlConn) readUntilEOF() error {
  585. for {
  586. data, err := mc.readPacket()
  587. // No Err and no EOF Packet
  588. if err == nil && data[0] != iEOF {
  589. continue
  590. }
  591. return err // Err or EOF
  592. }
  593. }
  594. /******************************************************************************
  595. * Prepared Statements *
  596. ******************************************************************************/
  597. // Prepare Result Packets
  598. // http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
  599. func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
  600. data, err := stmt.mc.readPacket()
  601. if err == nil {
  602. // packet indicator [1 byte]
  603. if data[0] != iOK {
  604. return 0, stmt.mc.handleErrorPacket(data)
  605. }
  606. // statement id [4 bytes]
  607. stmt.id = binary.LittleEndian.Uint32(data[1:5])
  608. // Column count [16 bit uint]
  609. columnCount := binary.LittleEndian.Uint16(data[5:7])
  610. // Param count [16 bit uint]
  611. stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
  612. // Reserved [8 bit]
  613. // Warning count [16 bit uint]
  614. if !stmt.mc.strict {
  615. return columnCount, nil
  616. } else {
  617. // Check for warnings count > 0, only available in MySQL > 4.1
  618. if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
  619. return columnCount, stmt.mc.getWarnings()
  620. }
  621. return columnCount, nil
  622. }
  623. }
  624. return 0, err
  625. }
  626. // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
  627. func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
  628. maxLen := stmt.mc.maxPacketAllowed - 1
  629. pktLen := maxLen
  630. // After the header (bytes 0-3) follows before the data:
  631. // 1 byte command
  632. // 4 bytes stmtID
  633. // 2 bytes paramID
  634. const dataOffset = 1 + 4 + 2
  635. // Can not use the write buffer since
  636. // a) the buffer is too small
  637. // b) it is in use
  638. data := make([]byte, 4+1+4+2+len(arg))
  639. copy(data[4+dataOffset:], arg)
  640. for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
  641. if dataOffset+argLen < maxLen {
  642. pktLen = dataOffset + argLen
  643. }
  644. stmt.mc.sequence = 0
  645. // Add command byte [1 byte]
  646. data[4] = comStmtSendLongData
  647. // Add stmtID [32 bit]
  648. data[5] = byte(stmt.id)
  649. data[6] = byte(stmt.id >> 8)
  650. data[7] = byte(stmt.id >> 16)
  651. data[8] = byte(stmt.id >> 24)
  652. // Add paramID [16 bit]
  653. data[9] = byte(paramID)
  654. data[10] = byte(paramID >> 8)
  655. // Send CMD packet
  656. err := stmt.mc.writePacket(data[:4+pktLen])
  657. if err == nil {
  658. data = data[pktLen-dataOffset:]
  659. continue
  660. }
  661. return err
  662. }
  663. // Reset Packet Sequence
  664. stmt.mc.sequence = 0
  665. return nil
  666. }
  667. // Execute Prepared Statement
  668. // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
  669. func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
  670. if len(args) != stmt.paramCount {
  671. return fmt.Errorf(
  672. "Arguments count mismatch (Got: %d Has: %d)",
  673. len(args),
  674. stmt.paramCount,
  675. )
  676. }
  677. const minPktLen = 4 + 1 + 4 + 1 + 4
  678. mc := stmt.mc
  679. // Reset packet-sequence
  680. mc.sequence = 0
  681. var data []byte
  682. if len(args) == 0 {
  683. data = mc.buf.takeBuffer(minPktLen)
  684. } else {
  685. data = mc.buf.takeCompleteBuffer()
  686. }
  687. if data == nil {
  688. // can not take the buffer. Something must be wrong with the connection
  689. errLog.Print(ErrBusyBuffer)
  690. return driver.ErrBadConn
  691. }
  692. // command [1 byte]
  693. data[4] = comStmtExecute
  694. // statement_id [4 bytes]
  695. data[5] = byte(stmt.id)
  696. data[6] = byte(stmt.id >> 8)
  697. data[7] = byte(stmt.id >> 16)
  698. data[8] = byte(stmt.id >> 24)
  699. // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
  700. data[9] = 0x00
  701. // iteration_count (uint32(1)) [4 bytes]
  702. data[10] = 0x01
  703. data[11] = 0x00
  704. data[12] = 0x00
  705. data[13] = 0x00
  706. if len(args) > 0 {
  707. pos := minPktLen
  708. var nullMask []byte
  709. if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
  710. // buffer has to be extended but we don't know by how much so
  711. // we depend on append after all data with known sizes fit.
  712. // We stop at that because we deal with a lot of columns here
  713. // which makes the required allocation size hard to guess.
  714. tmp := make([]byte, pos+maskLen+typesLen)
  715. copy(tmp[:pos], data[:pos])
  716. data = tmp
  717. nullMask = data[pos : pos+maskLen]
  718. pos += maskLen
  719. } else {
  720. nullMask = data[pos : pos+maskLen]
  721. for i := 0; i < maskLen; i++ {
  722. nullMask[i] = 0
  723. }
  724. pos += maskLen
  725. }
  726. // newParameterBoundFlag 1 [1 byte]
  727. data[pos] = 0x01
  728. pos++
  729. // type of each parameter [len(args)*2 bytes]
  730. paramTypes := data[pos:]
  731. pos += len(args) * 2
  732. // value of each parameter [n bytes]
  733. paramValues := data[pos:pos]
  734. valuesCap := cap(paramValues)
  735. for i, arg := range args {
  736. // build NULL-bitmap
  737. if arg == nil {
  738. nullMask[i/8] |= 1 << (uint(i) & 7)
  739. paramTypes[i+i] = fieldTypeNULL
  740. paramTypes[i+i+1] = 0x00
  741. continue
  742. }
  743. // cache types and values
  744. switch v := arg.(type) {
  745. case int64:
  746. paramTypes[i+i] = fieldTypeLongLong
  747. paramTypes[i+i+1] = 0x00
  748. if cap(paramValues)-len(paramValues)-8 >= 0 {
  749. paramValues = paramValues[:len(paramValues)+8]
  750. binary.LittleEndian.PutUint64(
  751. paramValues[len(paramValues)-8:],
  752. uint64(v),
  753. )
  754. } else {
  755. paramValues = append(paramValues,
  756. uint64ToBytes(uint64(v))...,
  757. )
  758. }
  759. case float64:
  760. paramTypes[i+i] = fieldTypeDouble
  761. paramTypes[i+i+1] = 0x00
  762. if cap(paramValues)-len(paramValues)-8 >= 0 {
  763. paramValues = paramValues[:len(paramValues)+8]
  764. binary.LittleEndian.PutUint64(
  765. paramValues[len(paramValues)-8:],
  766. math.Float64bits(v),
  767. )
  768. } else {
  769. paramValues = append(paramValues,
  770. uint64ToBytes(math.Float64bits(v))...,
  771. )
  772. }
  773. case bool:
  774. paramTypes[i+i] = fieldTypeTiny
  775. paramTypes[i+i+1] = 0x00
  776. if v {
  777. paramValues = append(paramValues, 0x01)
  778. } else {
  779. paramValues = append(paramValues, 0x00)
  780. }
  781. case []byte:
  782. // Common case (non-nil value) first
  783. if v != nil {
  784. paramTypes[i+i] = fieldTypeString
  785. paramTypes[i+i+1] = 0x00
  786. if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
  787. paramValues = appendLengthEncodedInteger(paramValues,
  788. uint64(len(v)),
  789. )
  790. paramValues = append(paramValues, v...)
  791. } else {
  792. if err := stmt.writeCommandLongData(i, v); err != nil {
  793. return err
  794. }
  795. }
  796. continue
  797. }
  798. // Handle []byte(nil) as a NULL value
  799. nullMask[i/8] |= 1 << (uint(i) & 7)
  800. paramTypes[i+i] = fieldTypeNULL
  801. paramTypes[i+i+1] = 0x00
  802. case string:
  803. paramTypes[i+i] = fieldTypeString
  804. paramTypes[i+i+1] = 0x00
  805. if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
  806. paramValues = appendLengthEncodedInteger(paramValues,
  807. uint64(len(v)),
  808. )
  809. paramValues = append(paramValues, v...)
  810. } else {
  811. if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
  812. return err
  813. }
  814. }
  815. case time.Time:
  816. paramTypes[i+i] = fieldTypeString
  817. paramTypes[i+i+1] = 0x00
  818. var val []byte
  819. if v.IsZero() {
  820. val = []byte("0000-00-00")
  821. } else {
  822. val = []byte(v.In(mc.cfg.Loc).Format(timeFormat))
  823. }
  824. paramValues = appendLengthEncodedInteger(paramValues,
  825. uint64(len(val)),
  826. )
  827. paramValues = append(paramValues, val...)
  828. default:
  829. return fmt.Errorf("Can't convert type: %T", arg)
  830. }
  831. }
  832. // Check if param values exceeded the available buffer
  833. // In that case we must build the data packet with the new values buffer
  834. if valuesCap != cap(paramValues) {
  835. data = append(data[:pos], paramValues...)
  836. mc.buf.buf = data
  837. }
  838. pos += len(paramValues)
  839. data = data[:pos]
  840. }
  841. return mc.writePacket(data)
  842. }
  843. // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
  844. func (rows *binaryRows) readRow(dest []driver.Value) error {
  845. data, err := rows.mc.readPacket()
  846. if err != nil {
  847. return err
  848. }
  849. // packet indicator [1 byte]
  850. if data[0] != iOK {
  851. rows.mc = nil
  852. // EOF Packet
  853. if data[0] == iEOF && len(data) == 5 {
  854. return io.EOF
  855. }
  856. // Error otherwise
  857. return rows.mc.handleErrorPacket(data)
  858. }
  859. // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
  860. pos := 1 + (len(dest)+7+2)>>3
  861. nullMask := data[1:pos]
  862. for i := range dest {
  863. // Field is NULL
  864. // (byte >> bit-pos) % 2 == 1
  865. if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
  866. dest[i] = nil
  867. continue
  868. }
  869. // Convert to byte-coded string
  870. switch rows.columns[i].fieldType {
  871. case fieldTypeNULL:
  872. dest[i] = nil
  873. continue
  874. // Numeric Types
  875. case fieldTypeTiny:
  876. if rows.columns[i].flags&flagUnsigned != 0 {
  877. dest[i] = int64(data[pos])
  878. } else {
  879. dest[i] = int64(int8(data[pos]))
  880. }
  881. pos++
  882. continue
  883. case fieldTypeShort, fieldTypeYear:
  884. if rows.columns[i].flags&flagUnsigned != 0 {
  885. dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
  886. } else {
  887. dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
  888. }
  889. pos += 2
  890. continue
  891. case fieldTypeInt24, fieldTypeLong:
  892. if rows.columns[i].flags&flagUnsigned != 0 {
  893. dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
  894. } else {
  895. dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
  896. }
  897. pos += 4
  898. continue
  899. case fieldTypeLongLong:
  900. if rows.columns[i].flags&flagUnsigned != 0 {
  901. val := binary.LittleEndian.Uint64(data[pos : pos+8])
  902. if val > math.MaxInt64 {
  903. dest[i] = uint64ToString(val)
  904. } else {
  905. dest[i] = int64(val)
  906. }
  907. } else {
  908. dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
  909. }
  910. pos += 8
  911. continue
  912. case fieldTypeFloat:
  913. dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
  914. pos += 4
  915. continue
  916. case fieldTypeDouble:
  917. dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
  918. pos += 8
  919. continue
  920. // Length coded Binary Strings
  921. case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
  922. fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
  923. fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
  924. fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
  925. var isNull bool
  926. var n int
  927. dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
  928. pos += n
  929. if err == nil {
  930. if !isNull {
  931. continue
  932. } else {
  933. dest[i] = nil
  934. continue
  935. }
  936. }
  937. return err
  938. case
  939. fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
  940. fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal]
  941. fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
  942. num, isNull, n := readLengthEncodedInteger(data[pos:])
  943. pos += n
  944. switch {
  945. case isNull:
  946. dest[i] = nil
  947. continue
  948. case rows.columns[i].fieldType == fieldTypeTime:
  949. // database/sql does not support an equivalent to TIME, return a string
  950. var dstlen uint8
  951. switch decimals := rows.columns[i].decimals; decimals {
  952. case 0x00, 0x1f:
  953. dstlen = 8
  954. case 1, 2, 3, 4, 5, 6:
  955. dstlen = 8 + 1 + decimals
  956. default:
  957. return fmt.Errorf(
  958. "MySQL protocol error, illegal decimals value %d",
  959. rows.columns[i].decimals,
  960. )
  961. }
  962. dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
  963. case rows.mc.parseTime:
  964. dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
  965. default:
  966. var dstlen uint8
  967. if rows.columns[i].fieldType == fieldTypeDate {
  968. dstlen = 10
  969. } else {
  970. switch decimals := rows.columns[i].decimals; decimals {
  971. case 0x00, 0x1f:
  972. dstlen = 19
  973. case 1, 2, 3, 4, 5, 6:
  974. dstlen = 19 + 1 + decimals
  975. default:
  976. return fmt.Errorf(
  977. "MySQL protocol error, illegal decimals value %d",
  978. rows.columns[i].decimals,
  979. )
  980. }
  981. }
  982. dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false)
  983. }
  984. if err == nil {
  985. pos += int(num)
  986. continue
  987. } else {
  988. return err
  989. }
  990. // Please report if this happens!
  991. default:
  992. return fmt.Errorf("Unknown FieldType %d", rows.columns[i].fieldType)
  993. }
  994. }
  995. return nil
  996. }