utils.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 Julien Schmidt. All rights reserved.
  4. // http://www.julienschmidt.com
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla Public
  7. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  8. // You can obtain one at http://mozilla.org/MPL/2.0/.
  9. package mysql
  10. import (
  11. "bytes"
  12. "crypto/sha1"
  13. "io"
  14. "math"
  15. "regexp"
  16. "strconv"
  17. "strings"
  18. )
  19. var dsnPattern *regexp.Regexp
  20. func init() {
  21. dsnPattern = regexp.MustCompile(
  22. `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
  23. `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
  24. `\/(?P<dbname>.*?)` + // /dbname
  25. `(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
  26. }
  27. func parseDSN(dsn string) *config {
  28. cfg := new(config)
  29. cfg.params = make(map[string]string)
  30. matches := dsnPattern.FindStringSubmatch(dsn)
  31. names := dsnPattern.SubexpNames()
  32. for i, match := range matches {
  33. switch names[i] {
  34. case "user":
  35. cfg.user = match
  36. case "passwd":
  37. cfg.passwd = match
  38. case "net":
  39. cfg.net = match
  40. case "addr":
  41. cfg.addr = match
  42. case "dbname":
  43. cfg.dbname = match
  44. case "params":
  45. for _, v := range strings.Split(match, "&") {
  46. param := strings.SplitN(v, "=", 2)
  47. if len(param) != 2 {
  48. continue
  49. }
  50. cfg.params[param[0]] = param[1]
  51. }
  52. }
  53. }
  54. // Set default network if empty
  55. if cfg.net == "" {
  56. cfg.net = "tcp"
  57. }
  58. // Set default adress if empty
  59. if cfg.addr == "" {
  60. cfg.addr = "127.0.0.1:3306"
  61. }
  62. return cfg
  63. }
  64. // Encrypt password using 4.1+ method
  65. // http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later
  66. func scramblePassword(scramble, password []byte) (result []byte) {
  67. if len(password) == 0 {
  68. return
  69. }
  70. // stage1Hash = SHA1(password)
  71. crypt := sha1.New()
  72. crypt.Write(password)
  73. stage1Hash := crypt.Sum(nil)
  74. // scrambleHash = SHA1(scramble + SHA1(stage1Hash))
  75. // inner Hash
  76. crypt.Reset()
  77. crypt.Write(stage1Hash)
  78. scrambleHash := crypt.Sum(nil)
  79. // outer Hash
  80. crypt.Reset()
  81. crypt.Write(scramble)
  82. crypt.Write(scrambleHash)
  83. scrambleHash = crypt.Sum(nil)
  84. // token = scrambleHash XOR stage1Hash
  85. result = make([]byte, 20)
  86. for i := range result {
  87. result[i] = scrambleHash[i] ^ stage1Hash[i]
  88. }
  89. return
  90. }
  91. /******************************************************************************
  92. * Read data-types from bytes *
  93. ******************************************************************************/
  94. // Read a slice from the data slice
  95. func readSlice(data []byte, delim byte) (slice []byte, e error) {
  96. pos := bytes.IndexByte(data, delim)
  97. if pos > -1 {
  98. slice = data[:pos]
  99. } else {
  100. slice = data
  101. e = io.EOF
  102. }
  103. return
  104. }
  105. func readLengthCodedBinary(data []byte) (b []byte, n int, isNull bool, e error) {
  106. // Get length
  107. num, n, e := bytesToLengthCodedBinary(data)
  108. if e != nil {
  109. return
  110. }
  111. // Check data length
  112. if len(data) < n+int(num) {
  113. e = io.EOF
  114. return
  115. }
  116. // Check if null
  117. if data[0] == 251 {
  118. isNull = true
  119. } else {
  120. isNull = false
  121. }
  122. // Get bytes
  123. b = data[n : n+int(num)]
  124. n += int(num)
  125. return
  126. }
  127. func readAndDropLengthCodedBinary(data []byte) (n int, e error) {
  128. // Get length
  129. num, n, e := bytesToLengthCodedBinary(data)
  130. if e != nil {
  131. return
  132. }
  133. // Check data length
  134. if len(data) < n+int(num) {
  135. e = io.EOF
  136. return
  137. }
  138. n += int(num)
  139. return
  140. }
  141. /******************************************************************************
  142. * Convert from and to bytes *
  143. ******************************************************************************/
  144. func byteToUint8(b byte) (n uint8) {
  145. n |= uint8(b)
  146. return
  147. }
  148. func bytesToUint16(b []byte) (n uint16) {
  149. n |= uint16(b[0])
  150. n |= uint16(b[1]) << 8
  151. return
  152. }
  153. func uint24ToBytes(n uint32) (b []byte) {
  154. b = make([]byte, 3)
  155. for i := uint8(0); i < 3; i++ {
  156. b[i] = byte(n >> (i * 8))
  157. }
  158. return
  159. }
  160. func bytesToUint32(b []byte) (n uint32) {
  161. for i := uint8(0); i < 4; i++ {
  162. n |= uint32(b[i]) << (i * 8)
  163. }
  164. return
  165. }
  166. func uint32ToBytes(n uint32) (b []byte) {
  167. b = make([]byte, 4)
  168. for i := uint8(0); i < 4; i++ {
  169. b[i] = byte(n >> (i * 8))
  170. }
  171. return
  172. }
  173. func bytesToUint64(b []byte) (n uint64) {
  174. for i := uint8(0); i < 8; i++ {
  175. n |= uint64(b[i]) << (i * 8)
  176. }
  177. return
  178. }
  179. func uint64ToBytes(n uint64) (b []byte) {
  180. b = make([]byte, 8)
  181. for i := uint8(0); i < 8; i++ {
  182. b[i] = byte(n >> (i * 8))
  183. }
  184. return
  185. }
  186. func int64ToBytes(n int64) []byte {
  187. return uint64ToBytes(uint64(n))
  188. }
  189. func bytesToFloat32(b []byte) float32 {
  190. return math.Float32frombits(bytesToUint32(b))
  191. }
  192. func bytesToFloat64(b []byte) float64 {
  193. return math.Float64frombits(bytesToUint64(b))
  194. }
  195. func float64ToBytes(f float64) []byte {
  196. return uint64ToBytes(math.Float64bits(f))
  197. }
  198. func bytesToLengthCodedBinary(b []byte) (length uint64, n int, e error) {
  199. switch {
  200. // 0-250: value of first byte
  201. case b[0] <= 250:
  202. length = uint64(b[0])
  203. n = 1
  204. return
  205. // 251: NULL
  206. case b[0] == 251:
  207. length = 0
  208. n = 1
  209. return
  210. // 252: value of following 2
  211. case b[0] == 252:
  212. n = 3
  213. // 253: value of following 3
  214. case b[0] == 253:
  215. n = 4
  216. // 254: value of following 8
  217. case b[0] == 254:
  218. n = 9
  219. }
  220. if len(b) < n {
  221. e = io.EOF
  222. return
  223. }
  224. // get Length
  225. tmp := make([]byte, 8)
  226. copy(tmp, b[1:n])
  227. length = bytesToUint64(tmp)
  228. return
  229. }
  230. func lengthCodedBinaryToBytes(n uint64) (b []byte) {
  231. switch {
  232. case n <= 250:
  233. b = []byte{byte(n)}
  234. case n <= 0xffff:
  235. b = []byte{0xfc, byte(n), byte(n >> 8)}
  236. case n <= 0xffffff:
  237. b = []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
  238. }
  239. return
  240. }
  241. func intToByteStr(i int64) (d *[]byte) {
  242. tmp := make([]byte, 0)
  243. tmp = strconv.AppendInt(tmp, i, 10)
  244. return &tmp
  245. }
  246. func uintToByteStr(u uint64) (d *[]byte) {
  247. tmp := make([]byte, 0)
  248. tmp = strconv.AppendUint(tmp, u, 10)
  249. return &tmp
  250. }
  251. func float32ToByteStr(f float32) (d *[]byte) {
  252. tmp := make([]byte, 0)
  253. tmp = strconv.AppendFloat(tmp, float64(f), 'f', -1, 32)
  254. return &tmp
  255. }
  256. func float64ToByteStr(f float64) (d *[]byte) {
  257. tmp := make([]byte, 0)
  258. tmp = strconv.AppendFloat(tmp, f, 'f', -1, 64)
  259. return &tmp
  260. }