utils.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 Julien Schmidt. All rights reserved.
  4. // http://www.julienschmidt.com
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla Public
  7. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  8. // You can obtain one at http://mozilla.org/MPL/2.0/.
  9. package mysql
  import (
	"bytes"
	"crypto/sha1"
	"errors"
	"io"
	"log"
	"math"
	"os"
	"regexp"
	"strconv"
	"strings"
)
  // errLog is the driver's error logger. It writes to stderr with a
// "[MySQL] " prefix followed by the date, time and file:line of the caller.
var errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)

// dsnPattern matches a Data Source Name of the form
//
//	[user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
//
// exposing the named capture groups user, passwd, net, addr, dbname and
// params, which parseDSN reads by name.
//
// Both package variables use var initializers rather than an init function:
// package-level variables are initialized before any init function runs, so
// they are ready regardless of source-file order within the package.
var dsnPattern = regexp.MustCompile(
	`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
		`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
		`\/(?P<dbname>.*?)` + // /dbname
		`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
  35. func parseDSN(dsn string) *config {
  36. cfg := new(config)
  37. cfg.params = make(map[string]string)
  38. matches := dsnPattern.FindStringSubmatch(dsn)
  39. names := dsnPattern.SubexpNames()
  40. for i, match := range matches {
  41. switch names[i] {
  42. case "user":
  43. cfg.user = match
  44. case "passwd":
  45. cfg.passwd = match
  46. case "net":
  47. cfg.net = match
  48. case "addr":
  49. cfg.addr = match
  50. case "dbname":
  51. cfg.dbname = match
  52. case "params":
  53. for _, v := range strings.Split(match, "&") {
  54. param := strings.SplitN(v, "=", 2)
  55. if len(param) != 2 {
  56. continue
  57. }
  58. cfg.params[param[0]] = param[1]
  59. }
  60. }
  61. }
  62. // Set default network if empty
  63. if cfg.net == "" {
  64. cfg.net = "tcp"
  65. }
  66. // Set default adress if empty
  67. if cfg.addr == "" {
  68. cfg.addr = "127.0.0.1:3306"
  69. }
  70. return cfg
  71. }
  // scramblePassword computes the MySQL 4.1+ authentication token
//
//	SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// from the scramble bytes (presumably the server's auth challenge) and the
// cleartext password. An empty password yields a nil token.
// http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later
func scramblePassword(scramble, password []byte) (result []byte) {
	if len(password) == 0 {
		return nil
	}

	hash := sha1.New()

	// stage1 = SHA1(password)
	hash.Write(password)
	stage1 := hash.Sum(nil)

	// stage2 = SHA1(stage1)
	hash.Reset()
	hash.Write(stage1)
	stage2 := hash.Sum(nil)

	// mask = SHA1(scramble + stage2)
	hash.Reset()
	hash.Write(scramble)
	hash.Write(stage2)
	mask := hash.Sum(nil)

	result = make([]byte, sha1.Size)
	for i := range result {
		result[i] = stage1[i] ^ mask[i]
	}
	return result
}
  99. /******************************************************************************
  100. * Read data-types from bytes *
  101. ******************************************************************************/
  // readSlice returns the prefix of data that precedes the first occurrence of
// delim; the delimiter itself is not included. When delim does not occur,
// all of data is returned together with io.EOF. The result aliases data.
func readSlice(data []byte, delim byte) (slice []byte, err error) {
	if end := bytes.IndexByte(data, delim); end >= 0 {
		return data[:end], nil
	}
	return data, io.EOF
}
  113. func readLengthCodedBinary(data []byte) (*[]byte, int, bool, error) {
  114. // Get length
  115. num, n, err := bytesToLengthCodedBinary(data)
  116. if err != nil {
  117. return nil, n, true, err
  118. }
  119. // Check data length
  120. if len(data) < n+int(num) {
  121. return nil, n, true, io.EOF
  122. }
  123. // Check if value is NULL
  124. var isNull bool
  125. if data[0] == 251 {
  126. isNull = true
  127. } else {
  128. isNull = false
  129. }
  130. // Get bytes
  131. b := data[n : n+int(num)]
  132. n += int(num)
  133. return &b, n, isNull, err
  134. }
  135. func readAndDropLengthCodedBinary(data []byte) (n int, err error) {
  136. // Get length
  137. num, n, err := bytesToLengthCodedBinary(data)
  138. if err != nil {
  139. return
  140. }
  141. // Check data length
  142. if len(data) < n+int(num) {
  143. err = io.EOF
  144. return
  145. }
  146. n += int(num)
  147. return
  148. }
  149. /******************************************************************************
  150. * Convert from and to bytes *
  151. ******************************************************************************/
  // byteToUint8 returns b as a uint8. byte is an alias for uint8, so the value
// is passed through unchanged.
func byteToUint8(b byte) (n uint8) {
	n = b
	return n
}
  // bytesToUint16 decodes the first two bytes of b as a little-endian uint16.
// It panics if len(b) < 2; further bytes are ignored.
func bytesToUint16(b []byte) (n uint16) {
	lo, hi := uint16(b[0]), uint16(b[1])
	return hi<<8 | lo
}
  // uint24ToBytes encodes the low 24 bits of n as 3 little-endian bytes; the
// most significant byte of n is discarded.
func uint24ToBytes(n uint32) (b []byte) {
	return []byte{byte(n), byte(n >> 8), byte(n >> 16)}
}
  // bytesToUint32 decodes the first four bytes of b as a little-endian uint32.
// It panics if len(b) < 4; further bytes are ignored.
func bytesToUint32(b []byte) (n uint32) {
	return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
  // uint32ToBytes encodes n as 4 little-endian bytes.
func uint32ToBytes(n uint32) (b []byte) {
	return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
}
  // bytesToUint64 decodes the first eight bytes of b as a little-endian uint64.
// It panics if len(b) < 8; further bytes are ignored.
func bytesToUint64(b []byte) (n uint64) {
	// Walk from the most significant byte down, shifting the accumulator.
	for i := 7; i >= 0; i-- {
		n = n<<8 | uint64(b[i])
	}
	return n
}
  // uint64ToBytes encodes n as 8 little-endian bytes.
func uint64ToBytes(n uint64) (b []byte) {
	b = make([]byte, 8)
	for i := range b {
		b[i] = byte(n)
		n >>= 8
	}
	return b
}
  194. func int64ToBytes(n int64) []byte {
  195. return uint64ToBytes(uint64(n))
  196. }
  197. func bytesToFloat32(b []byte) float32 {
  198. return math.Float32frombits(bytesToUint32(b))
  199. }
  200. func bytesToFloat64(b []byte) float64 {
  201. return math.Float64frombits(bytesToUint64(b))
  202. }
  203. func float64ToBytes(f float64) []byte {
  204. return uint64ToBytes(math.Float64bits(f))
  205. }
  // bytesToLengthCodedBinary decodes the length-coded binary header at the
// start of b. It returns the decoded value, the number of header bytes n,
// and an error:
//
//	0x00-0xfa (0-250): the value is the first byte itself, n = 1
//	0xfb      (251):   NULL; value 0, n = 1
//	0xfc      (252):   value in the following 2 bytes, n = 3
//	0xfd      (253):   value in the following 3 bytes, n = 4
//	0xfe      (254):   value in the following 8 bytes, n = 9
//
// Multi-byte values are little-endian. If b is shorter than the header
// requires, io.EOF is returned with n set to the required header length.
// An empty b yields io.EOF with n = 0, and the undefined prefix 0xff yields
// an error with n = 0. Previously both cases panicked.
func bytesToLengthCodedBinary(b []byte) (length uint64, n int, err error) {
	if len(b) == 0 {
		return 0, 0, io.EOF
	}

	switch first := b[0]; {
	case first <= 250:
		return uint64(first), 1, nil
	case first == 251:
		return 0, 1, nil
	case first == 252:
		n = 3
	case first == 253:
		n = 4
	case first == 254:
		n = 9
	default:
		return 0, 0, errors.New("invalid length-coded binary prefix 0xff")
	}

	if len(b) < n {
		return 0, n, io.EOF
	}
	// Little-endian: accumulate from the most significant byte b[n-1] down to
	// b[1]. This avoids allocating a scratch buffer on every call.
	for i := n - 1; i >= 1; i-- {
		length = length<<8 | uint64(b[i])
	}
	return length, n, nil
}
  // lengthCodedBinaryToBytes encodes n as a length-coded binary, the inverse of
// bytesToLengthCodedBinary:
//
//	n <= 250:      1 byte, the value itself
//	n <= 0xffff:   0xfc followed by 2 little-endian bytes
//	n <= 0xffffff: 0xfd followed by 3 little-endian bytes
//	otherwise:     0xfe followed by 8 little-endian bytes
//
// 251 is never emitted as a single byte because 0xfb is the NULL marker.
func lengthCodedBinaryToBytes(n uint64) (b []byte) {
	switch {
	case n <= 250:
		return []byte{byte(n)}
	case n <= 0xffff:
		return []byte{0xfc, byte(n), byte(n >> 8)}
	case n <= 0xffffff:
		return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
	default:
		// Values above 0xffffff need the 8-byte form; without this case they
		// would be encoded as an empty slice and the length silently lost.
		return []byte{0xfe,
			byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24),
			byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)}
	}
}
  // intToByteStr returns the base-10 text of i as bytes, e.g. -42 -> "-42".
func intToByteStr(i int64) (b []byte) {
	return []byte(strconv.FormatInt(i, 10))
}
  // uintToByteStr returns the base-10 text of u as bytes.
func uintToByteStr(u uint64) (b []byte) {
	return []byte(strconv.FormatUint(u, 10))
}
  // float32ToByteStr returns f in plain decimal notation (no exponent), using
// the fewest digits that round-trip at 32-bit precision.
func float32ToByteStr(f float32) (b []byte) {
	return []byte(strconv.FormatFloat(float64(f), 'f', -1, 32))
}
  // float64ToByteStr returns f in plain decimal notation (no exponent), using
// the fewest digits that round-trip at 64-bit precision.
func float64ToByteStr(f float64) (b []byte) {
	return []byte(strconv.FormatFloat(f, 'f', -1, 64))
}