utils.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 Julien Schmidt. All rights reserved.
  4. // http://www.julienschmidt.com
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla Public
  7. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  8. // You can obtain one at http://mozilla.org/MPL/2.0/.
  9. package mysql
  10. import (
  11. "bytes"
  12. "crypto/sha1"
  13. "io"
  14. "log"
  15. "math"
  16. "os"
  17. "regexp"
  18. "strconv"
  19. "strings"
  20. )
  // Package-level loggers: errLog writes to stderr and dbgLog to stdout, both
  // with the "[MySQL] " prefix and the standard date/time flags.
  //
  // They are initialized at their declaration (not in an init function) so
  // they are already non-nil when any init() in this package runs, regardless
  // of the order in which the package's files are initialized.
  var (
  	errLog = log.New(os.Stderr, "[MySQL] ", log.LstdFlags)
  	dbgLog = log.New(os.Stdout, "[MySQL] ", log.LstdFlags)
  )
  
  // dsnPattern splits a Data Source Name of the form
  //
  //	[user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
  //
  // into the named groups user, passwd, net, addr, dbname and params.
  // It is compiled once, during package variable initialization.
  var dsnPattern = regexp.MustCompile(
  	`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
  		`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
  		`\/(?P<dbname>.*?)` + // /dbname
  		`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
  37. func parseDSN(dsn string) *config {
  38. cfg := new(config)
  39. cfg.params = make(map[string]string)
  40. matches := dsnPattern.FindStringSubmatch(dsn)
  41. names := dsnPattern.SubexpNames()
  42. for i, match := range matches {
  43. switch names[i] {
  44. case "user":
  45. cfg.user = match
  46. case "passwd":
  47. cfg.passwd = match
  48. case "net":
  49. cfg.net = match
  50. case "addr":
  51. cfg.addr = match
  52. case "dbname":
  53. cfg.dbname = match
  54. case "params":
  55. for _, v := range strings.Split(match, "&") {
  56. param := strings.SplitN(v, "=", 2)
  57. if len(param) != 2 {
  58. continue
  59. }
  60. cfg.params[param[0]] = param[1]
  61. }
  62. }
  63. }
  64. // Set default network if empty
  65. if cfg.net == "" {
  66. cfg.net = "tcp"
  67. }
  68. // Set default adress if empty
  69. if cfg.addr == "" {
  70. cfg.addr = "127.0.0.1:3306"
  71. }
  72. return cfg
  73. }
  // scramblePassword computes the MySQL 4.1+ authentication token for password
  // and the server-provided scramble:
  //
  //	token = SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
  //
  // An empty password yields a nil token.
  // http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later
  func scramblePassword(scramble, password []byte) (result []byte) {
  	if len(password) == 0 {
  		return nil
  	}
  	h := sha1.New()
  
  	// stage1 = SHA1(password)
  	h.Write(password)
  	stage1 := h.Sum(nil)
  
  	// stage2 = SHA1(stage1)
  	h.Reset()
  	h.Write(stage1)
  	stage2 := h.Sum(nil)
  
  	// mix = SHA1(scramble + stage2)
  	h.Reset()
  	h.Write(scramble)
  	h.Write(stage2)
  	mix := h.Sum(nil)
  
  	token := make([]byte, sha1.Size)
  	for i := range token {
  		token[i] = mix[i] ^ stage1[i]
  	}
  	return token
  }
  101. /******************************************************************************
  102. * Read data-types from bytes *
  103. ******************************************************************************/
  // readSlice returns the prefix of data up to (not including) the first
  // occurrence of delim. If delim does not occur, it returns all of data
  // together with io.EOF. The result shares memory with data.
  func readSlice(data []byte, delim byte) (slice []byte, err error) {
  	if idx := bytes.IndexByte(data, delim); idx >= 0 {
  		return data[:idx], nil
  	}
  	return data, io.EOF
  }
  115. func readLengthCodedBinary(data []byte) (*[]byte, int, bool, error) {
  116. // Get length
  117. num, n, err := bytesToLengthCodedBinary(data)
  118. if err != nil {
  119. return nil, n, true, err
  120. }
  121. // Check data length
  122. if len(data) < n+int(num) {
  123. return nil, n, true, io.EOF
  124. }
  125. // Check if value is NULL
  126. var isNull bool
  127. if data[0] == 251 {
  128. isNull = true
  129. } else {
  130. isNull = false
  131. }
  132. // Get bytes
  133. b := data[n : n+int(num)]
  134. n += int(num)
  135. return &b, n, isNull, err
  136. }
  137. func readAndDropLengthCodedBinary(data []byte) (n int, err error) {
  138. // Get length
  139. num, n, err := bytesToLengthCodedBinary(data)
  140. if err != nil {
  141. return
  142. }
  143. // Check data length
  144. if len(data) < n+int(num) {
  145. err = io.EOF
  146. return
  147. }
  148. n += int(num)
  149. return
  150. }
  151. /******************************************************************************
  152. * Convert from and to bytes *
  153. ******************************************************************************/
  // byteToUint8 returns b as a uint8. byte is an alias for uint8 in Go, so this
  // is an identity conversion.
  func byteToUint8(b byte) (n uint8) {
  	return uint8(b)
  }
  // bytesToUint16 decodes the first two bytes of b as a little-endian uint16.
  // It panics if len(b) < 2.
  func bytesToUint16(b []byte) (n uint16) {
  	lo, hi := uint16(b[0]), uint16(b[1])
  	return hi<<8 | lo
  }
  // uint24ToBytes encodes the low 24 bits of n as 3 little-endian bytes. The
  // high byte of n is discarded.
  func uint24ToBytes(n uint32) (b []byte) {
  	return []byte{byte(n), byte(n >> 8), byte(n >> 16)}
  }
  // bytesToUint32 decodes the first four bytes of b as a little-endian uint32.
  // It panics if len(b) < 4.
  func bytesToUint32(b []byte) (n uint32) {
  	return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
  }
  // uint32ToBytes encodes n as 4 little-endian bytes.
  func uint32ToBytes(n uint32) (b []byte) {
  	return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
  }
  // bytesToUint64 decodes the first eight bytes of b as a little-endian
  // uint64. It panics if len(b) < 8.
  func bytesToUint64(b []byte) (n uint64) {
  	// Walk from the most significant byte down, shifting the accumulator.
  	for i := 7; i >= 0; i-- {
  		n = n<<8 | uint64(b[i])
  	}
  	return n
  }
  // uint64ToBytes encodes n as 8 little-endian bytes.
  func uint64ToBytes(n uint64) (b []byte) {
  	b = make([]byte, 0, 8)
  	for shift := uint(0); shift < 64; shift += 8 {
  		b = append(b, byte(n>>shift))
  	}
  	return b
  }
  196. func int64ToBytes(n int64) []byte {
  197. return uint64ToBytes(uint64(n))
  198. }
  199. func bytesToFloat32(b []byte) float32 {
  200. return math.Float32frombits(bytesToUint32(b))
  201. }
  202. func bytesToFloat64(b []byte) float64 {
  203. return math.Float64frombits(bytesToUint64(b))
  204. }
  205. func float64ToBytes(f float64) []byte {
  206. return uint64ToBytes(math.Float64bits(f))
  207. }
  // lengthCodedBinaryPrefixError reports a first byte that cannot start a
  // length-coded binary (0xff is not a defined length prefix).
  type lengthCodedBinaryPrefixError byte
  
  func (e lengthCodedBinaryPrefixError) Error() string {
  	return "mysql: invalid length-coded binary prefix byte " + strconv.Itoa(int(e))
  }
  
  // bytesToLengthCodedBinary decodes the length-coded binary at the start of b.
  //
  // n is the number of bytes the encoding occupies:
  //   - 0-250: the value is the first byte itself (n = 1),
  //   - 251:   NULL marker, length 0 (n = 1),
  //   - 252:   2-byte little-endian value follows (n = 3),
  //   - 253:   3-byte little-endian value follows (n = 4),
  //   - 254:   8-byte little-endian value follows (n = 9).
  //
  // Errors:
  //   - empty b: io.EOF with n = 0;
  //   - b shorter than the encoding requires: io.EOF with n set to the
  //     required size;
  //   - first byte 255: lengthCodedBinaryPrefixError with n = 0.
  //
  // Previously the empty and 255 cases panicked.
  func bytesToLengthCodedBinary(b []byte) (length uint64, n int, err error) {
  	if len(b) == 0 {
  		return 0, 0, io.EOF
  	}
  	switch b[0] {
  	case 251: // NULL
  		return 0, 1, nil
  	case 252:
  		n = 3
  	case 253:
  		n = 4
  	case 254:
  		n = 9
  	case 255:
  		return 0, 0, lengthCodedBinaryPrefixError(b[0])
  	default: // 0-250: value of first byte
  		return uint64(b[0]), 1, nil
  	}
  	if len(b) < n {
  		return 0, n, io.EOF
  	}
  	// Assemble the little-endian value in place instead of copying into a
  	// temporary 8-byte buffer.
  	for i := n - 1; i >= 1; i-- {
  		length = length<<8 | uint64(b[i])
  	}
  	return length, n, nil
  }
  // lengthCodedBinaryToBytes encodes n as a length-coded binary:
  //   - n <= 250:       the single byte n,
  //   - n <= 0xffff:    0xfc followed by 2 little-endian bytes,
  //   - n <= 0xffffff:  0xfd followed by 3 little-endian bytes,
  //   - otherwise:      0xfe followed by 8 little-endian bytes.
  //
  // The last case was previously missing, so larger values produced a nil
  // slice.
  func lengthCodedBinaryToBytes(n uint64) (b []byte) {
  	switch {
  	case n <= 250:
  		return []byte{byte(n)}
  	case n <= 0xffff:
  		return []byte{0xfc, byte(n), byte(n >> 8)}
  	case n <= 0xffffff:
  		return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
  	}
  	b = make([]byte, 9)
  	b[0] = 0xfe
  	for i := 1; i < 9; i++ {
  		b[i] = byte(n >> (8 * uint(i-1)))
  	}
  	return b
  }
  // intToByteStr returns the base-10 text representation of i as bytes.
  func intToByteStr(i int64) (b []byte) {
  	return strconv.AppendInt(nil, i, 10)
  }
  // uintToByteStr returns the base-10 text representation of u as bytes.
  func uintToByteStr(u uint64) (b []byte) {
  	return strconv.AppendUint(nil, u, 10)
  }
  // float32ToByteStr formats f in plain decimal notation ('f'), using the
  // fewest digits that round-trip at 32-bit precision.
  func float32ToByteStr(f float32) (b []byte) {
  	return strconv.AppendFloat(nil, float64(f), 'f', -1, 32)
  }
  // float64ToByteStr formats f in plain decimal notation ('f'), using the
  // fewest digits that round-trip at 64-bit precision.
  func float64ToByteStr(f float64) (b []byte) {
  	return strconv.AppendFloat(nil, f, 'f', -1, 64)
  }