utils.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 Julien Schmidt. All rights reserved.
  4. // http://www.julienschmidt.com
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla Public
  7. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  8. // You can obtain one at http://mozilla.org/MPL/2.0/.
  9. package mysql
  10. import (
  11. "crypto/sha1"
  12. "io"
  13. "log"
  14. "os"
  15. "regexp"
  16. "strings"
  17. )
  18. // Logger
  19. var (
  20. errLog *log.Logger
  21. )
  22. func init() {
  23. errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
  24. dsnPattern = regexp.MustCompile(
  25. `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
  26. `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
  27. `\/(?P<dbname>.*?)` + // /dbname
  28. `(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
  29. }
  30. // Data Source Name Parser
  31. var dsnPattern *regexp.Regexp
  32. func parseDSN(dsn string) *config {
  33. cfg := new(config)
  34. cfg.params = make(map[string]string)
  35. matches := dsnPattern.FindStringSubmatch(dsn)
  36. names := dsnPattern.SubexpNames()
  37. for i, match := range matches {
  38. switch names[i] {
  39. case "user":
  40. cfg.user = match
  41. case "passwd":
  42. cfg.passwd = match
  43. case "net":
  44. cfg.net = match
  45. case "addr":
  46. cfg.addr = match
  47. case "dbname":
  48. cfg.dbname = match
  49. case "params":
  50. for _, v := range strings.Split(match, "&") {
  51. param := strings.SplitN(v, "=", 2)
  52. if len(param) != 2 {
  53. continue
  54. }
  55. cfg.params[param[0]] = param[1]
  56. }
  57. }
  58. }
  59. // Set default network if empty
  60. if cfg.net == "" {
  61. cfg.net = "tcp"
  62. }
  63. // Set default adress if empty
  64. if cfg.addr == "" {
  65. cfg.addr = "127.0.0.1:3306"
  66. }
  67. return cfg
  68. }
  69. // Encrypt password using 4.1+ method
  70. // http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later
  71. func scramblePassword(scramble, password []byte) []byte {
  72. if len(password) == 0 {
  73. return nil
  74. }
  75. // stage1Hash = SHA1(password)
  76. crypt := sha1.New()
  77. crypt.Write(password)
  78. stage1 := crypt.Sum(nil)
  79. // scrambleHash = SHA1(scramble + SHA1(stage1Hash))
  80. // inner Hash
  81. crypt.Reset()
  82. crypt.Write(stage1)
  83. hash := crypt.Sum(nil)
  84. // outer Hash
  85. crypt.Reset()
  86. crypt.Write(scramble)
  87. crypt.Write(hash)
  88. scramble = crypt.Sum(nil)
  89. // token = scrambleHash XOR stage1Hash
  90. for i := range scramble {
  91. scramble[i] ^= stage1[i]
  92. }
  93. return scramble
  94. }
  95. /******************************************************************************
  96. * Convert from and to bytes *
  97. ******************************************************************************/
  98. func uint64ToBytes(n uint64) []byte {
  99. return []byte{
  100. byte(n),
  101. byte(n >> 8),
  102. byte(n >> 16),
  103. byte(n >> 24),
  104. byte(n >> 32),
  105. byte(n >> 40),
  106. byte(n >> 48),
  107. byte(n >> 56),
  108. }
  109. }
  110. func uint64ToString(n uint64) []byte {
  111. var a [20]byte
  112. i := 20
  113. // U+0030 = 0
  114. // ...
  115. // U+0039 = 9
  116. var q uint64
  117. for n >= 10 {
  118. i--
  119. q = n / 10
  120. a[i] = uint8(n-q*10) + 0x30
  121. n = q
  122. }
  123. i--
  124. a[i] = uint8(n) + 0x30
  125. return a[i:]
  126. }
  127. // treats string value as unsigned integer representation
  128. func stringToInt(b []byte) int {
  129. val := 0
  130. for i := range b {
  131. val *= 10
  132. val += int(b[i] - 0x30)
  133. }
  134. return val
  135. }
  136. func readLengthEnodedString(b []byte) ([]byte, bool, int, error) {
  137. // Get length
  138. num, isNull, n := readLengthEncodedInteger(b)
  139. if num < 1 {
  140. return nil, isNull, n, nil
  141. }
  142. n += int(num)
  143. // Check data length
  144. if len(b) >= n {
  145. return b[n-int(num) : n], false, n, nil
  146. }
  147. return nil, false, n, io.EOF
  148. }
  149. func skipLengthEnodedString(b []byte) (int, error) {
  150. // Get length
  151. num, _, n := readLengthEncodedInteger(b)
  152. if num < 1 {
  153. return n, nil
  154. }
  155. n += int(num)
  156. // Check data length
  157. if len(b) >= n {
  158. return n, nil
  159. }
  160. return n, io.EOF
  161. }
  162. func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int) {
  163. switch b[0] {
  164. // 251: NULL
  165. case 0xfb:
  166. n = 1
  167. isNull = true
  168. return
  169. // 252: value of following 2
  170. case 0xfc:
  171. num = uint64(b[1]) | uint64(b[2])<<8
  172. n = 3
  173. return
  174. // 253: value of following 3
  175. case 0xfd:
  176. num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16
  177. n = 4
  178. return
  179. // 254: value of following 8
  180. case 0xfe:
  181. num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
  182. uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
  183. uint64(b[7])<<48 | uint64(b[8])<<54
  184. n = 9
  185. return
  186. }
  187. // 0-250: value of first byte
  188. num = uint64(b[0])
  189. n = 1
  190. return
  191. }
  192. func lengthEncodedIntegerToBytes(n uint64) []byte {
  193. switch {
  194. case n <= 250:
  195. return []byte{byte(n)}
  196. case n <= 0xffff:
  197. return []byte{0xfc, byte(n), byte(n >> 8)}
  198. case n <= 0xffffff:
  199. return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
  200. }
  201. return nil
  202. }