tds.go 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367
  1. package mssql
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "encoding/binary"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "io/ioutil"
  11. "net"
  12. "net/url"
  13. "os"
  14. "sort"
  15. "strconv"
  16. "strings"
  17. "time"
  18. "unicode"
  19. "unicode/utf16"
  20. "unicode/utf8"
  21. )
  22. func parseInstances(msg []byte) map[string]map[string]string {
  23. results := map[string]map[string]string{}
  24. if len(msg) > 3 && msg[0] == 5 {
  25. out_s := string(msg[3:])
  26. tokens := strings.Split(out_s, ";")
  27. instdict := map[string]string{}
  28. got_name := false
  29. var name string
  30. for _, token := range tokens {
  31. if got_name {
  32. instdict[name] = token
  33. got_name = false
  34. } else {
  35. name = token
  36. if len(name) == 0 {
  37. if len(instdict) == 0 {
  38. break
  39. }
  40. results[strings.ToUpper(instdict["InstanceName"])] = instdict
  41. instdict = map[string]string{}
  42. continue
  43. }
  44. got_name = true
  45. }
  46. }
  47. }
  48. return results
  49. }
  50. func getInstances(ctx context.Context, address string) (map[string]map[string]string, error) {
  51. maxTime := 5 * time.Second
  52. dialer := &net.Dialer{
  53. Timeout: maxTime,
  54. }
  55. conn, err := dialer.DialContext(ctx, "udp", address+":1434")
  56. if err != nil {
  57. return nil, err
  58. }
  59. defer conn.Close()
  60. conn.SetDeadline(time.Now().Add(maxTime))
  61. _, err = conn.Write([]byte{3})
  62. if err != nil {
  63. return nil, err
  64. }
  65. var resp = make([]byte, 16*1024-1)
  66. read, err := conn.Read(resp)
  67. if err != nil {
  68. return nil, err
  69. }
  70. return parseInstances(resp[:read]), nil
  71. }
// tds versions
//
// 32-bit TDS protocol version identifiers. verTDS73 is an alias for
// verTDS73A.
const (
	verTDS70     = 0x70000000
	verTDS71     = 0x71000000
	verTDS71rev1 = 0x71000001
	verTDS72     = 0x72090002
	verTDS73A    = 0x730A0003
	verTDS73     = verTDS73A
	verTDS73B    = 0x730B0003
	verTDS74     = 0x74000004
)
// packet types
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
//
// Note that only packSQLBatch is declared with the packetType type; the
// remaining names are untyped integer constants.
const (
	packSQLBatch   packetType = 1
	packRPCRequest            = 3
	packReply                 = 4
	// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
	// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
	packAttention   = 6
	packBulkLoadBCP = 7
	packTransMgrReq = 14
	packNormal      = 15
	packLogin7      = 16
	packSSPIMessage = 17
	packPrelogin    = 18
)
// prelogin fields
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
//
// Option tokens used as keys of the field maps handled by writePrelogin and
// readPrelogin. preloginTERMINATOR marks the end of the option table.
const (
	preloginVERSION    = 0
	preloginENCRYPTION = 1
	preloginINSTOPT    = 2
	preloginTHREADID   = 3
	preloginMARS       = 4
	preloginTRACEID    = 5
	preloginTERMINATOR = 0xff
)
// Encryption option values.
// NOTE(review): presumably the payload of the preloginENCRYPTION field —
// confirm against the prelogin negotiation code.
const (
	encryptOff    = 0 // Encryption is available but off.
	encryptOn     = 1 // Encryption is available and on.
	encryptNotSup = 2 // Encryption is not available.
	encryptReq    = 3 // Encryption is required.
)
// tdsSession groups the state kept for one TDS connection.
// NOTE(review): the per-field notes below are inferred from names and types
// only; the code that fills these fields is outside this section — confirm.
type tdsSession struct {
	buf          *tdsBuffer     // packet buffer used to talk to the server
	loginAck     loginAckStruct // presumably the server's login acknowledgement
	database     string
	partner      string
	columns      []columnStruct
	tranid       uint64
	logFlags     uint64 // presumably a bitmask of the log* constants
	log          optionalLogger
	routedServer string
	routedPort   uint16
}
// Logging categories. Each value is a distinct bit, so several categories
// can be OR-ed into one flags value.
// NOTE(review): presumably matched against the numeric "log" connection
// parameter stored in connectParams.logFlags — confirm at the use sites.
const (
	logErrors      = 1
	logMessages    = 2
	logRows        = 4
	logSQL         = 8
	logParams      = 16
	logTransaction = 32
	logDebug       = 64
)
// columnStruct describes one column.
// NOTE(review): presumably one entry of result-set column metadata received
// from the server — confirm against the token parsing code.
type columnStruct struct {
	UserType uint32
	Flags    uint16
	ColName  string
	ti       typeInfo
}
  143. type keySlice []uint8
  144. func (p keySlice) Len() int { return len(p) }
  145. func (p keySlice) Less(i, j int) bool { return p[i] < p[j] }
  146. func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
  147. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  148. func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
  149. var err error
  150. w.BeginPacket(packPrelogin, false)
  151. offset := uint16(5*len(fields) + 1)
  152. keys := make(keySlice, 0, len(fields))
  153. for k, _ := range fields {
  154. keys = append(keys, k)
  155. }
  156. sort.Sort(keys)
  157. // writing header
  158. for _, k := range keys {
  159. err = w.WriteByte(k)
  160. if err != nil {
  161. return err
  162. }
  163. err = binary.Write(w, binary.BigEndian, offset)
  164. if err != nil {
  165. return err
  166. }
  167. v := fields[k]
  168. size := uint16(len(v))
  169. err = binary.Write(w, binary.BigEndian, size)
  170. if err != nil {
  171. return err
  172. }
  173. offset += size
  174. }
  175. err = w.WriteByte(preloginTERMINATOR)
  176. if err != nil {
  177. return err
  178. }
  179. // writing values
  180. for _, k := range keys {
  181. v := fields[k]
  182. written, err := w.Write(v)
  183. if err != nil {
  184. return err
  185. }
  186. if written != len(v) {
  187. return errors.New("Write method didn't write the whole value")
  188. }
  189. }
  190. return w.FinishPacket()
  191. }
  192. func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
  193. packet_type, err := r.BeginRead()
  194. if err != nil {
  195. return nil, err
  196. }
  197. struct_buf, err := ioutil.ReadAll(r)
  198. if err != nil {
  199. return nil, err
  200. }
  201. if packet_type != 4 {
  202. return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
  203. }
  204. offset := 0
  205. results := map[uint8][]byte{}
  206. for true {
  207. rec_type := struct_buf[offset]
  208. if rec_type == preloginTERMINATOR {
  209. break
  210. }
  211. rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
  212. rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
  213. value := struct_buf[rec_offset : rec_offset+rec_len]
  214. results[rec_type] = value
  215. offset += 5
  216. }
  217. return results, nil
  218. }
// OptionFlags2
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
//
// Bit values intended for the OptionFlags2 byte of the login record (see the
// login and loginHeader structs).
const (
	fLanguageFatal = 1
	fODBC          = 2
	fTransBoundary = 4
	fCacheConnect  = 8
	fIntSecurity   = 0x80
)
// TypeFlags
//
// Bit values for the TypeFlags byte of the login record. Only the read-only
// intent bit is given a name here; the lower five bits are described by the
// comments inside the block.
const (
	// 4 bits for fSQLType
	// 1 bit for fOLEDB
	fReadOnlyIntent = 32
)
// login holds the values sendLogin turns into a LOGIN7 packet.
//
// The fixed-size fields are copied verbatim into loginHeader. The string
// fields are encoded as UTF-16LE by str2ucs2 and appended after the header;
// Password is additionally obfuscated by manglePassword. SSPI is appended as
// raw bytes.
type login struct {
	TDSVersion     uint32
	PacketSize     uint32
	ClientProgVer  uint32
	ClientPID      uint32
	ConnectionID   uint32
	OptionFlags1   uint8
	OptionFlags2   uint8
	TypeFlags      uint8
	OptionFlags3   uint8
	ClientTimeZone int32
	ClientLCID     uint32
	HostName       string
	UserName       string
	Password       string
	AppName        string
	ServerName     string
	CtlIntName     string
	Language       string
	Database       string
	ClientID       [6]byte
	SSPI           []byte
	AtchDBFile     string
	ChangePassword string
}
// loginHeader is the fixed-size leading part of a LOGIN7 packet.
//
// sendLogin serializes it with binary.Write in little-endian order, so the
// declaration order and the field widths define the wire layout and must not
// be changed. Each Offset/Length pair locates one variable-length field that
// follows the header; offsets are counted from the start of the header, and
// Length covers the header plus all variable-length data.
//
// sendLogin leaves ExtensionOffset, ExtensionLenght and SSPILongLength at
// zero. "ExtensionLenght" is a misspelling of "Length"; it is kept as is
// because other code may refer to the field by this name.
type loginHeader struct {
	Length               uint32
	TDSVersion           uint32
	PacketSize           uint32
	ClientProgVer        uint32
	ClientPID            uint32
	ConnectionID         uint32
	OptionFlags1         uint8
	OptionFlags2         uint8
	TypeFlags            uint8
	OptionFlags3         uint8
	ClientTimeZone       int32
	ClientLCID           uint32
	HostNameOffset       uint16
	HostNameLength       uint16
	UserNameOffset       uint16
	UserNameLength       uint16
	PasswordOffset       uint16
	PasswordLength       uint16
	AppNameOffset        uint16
	AppNameLength        uint16
	ServerNameOffset     uint16
	ServerNameLength     uint16
	ExtensionOffset      uint16
	ExtensionLenght      uint16
	CtlIntNameOffset     uint16
	CtlIntNameLength     uint16
	LanguageOffset       uint16
	LanguageLength       uint16
	DatabaseOffset       uint16
	DatabaseLength       uint16
	ClientID             [6]byte
	SSPIOffset           uint16
	SSPILength           uint16
	AtchDBFileOffset     uint16
	AtchDBFileLength     uint16
	ChangePasswordOffset uint16
	ChangePasswordLength uint16
	SSPILongLength       uint32
}
  299. // convert Go string to UTF-16 encoded []byte (littleEndian)
  300. // done manually rather than using bytes and binary packages
  301. // for performance reasons
  302. func str2ucs2(s string) []byte {
  303. res := utf16.Encode([]rune(s))
  304. ucs2 := make([]byte, 2*len(res))
  305. for i := 0; i < len(res); i++ {
  306. ucs2[2*i] = byte(res[i])
  307. ucs2[2*i+1] = byte(res[i] >> 8)
  308. }
  309. return ucs2
  310. }
  311. func ucs22str(s []byte) (string, error) {
  312. if len(s)%2 != 0 {
  313. return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
  314. }
  315. buf := make([]uint16, len(s)/2)
  316. for i := 0; i < len(s); i += 2 {
  317. buf[i/2] = binary.LittleEndian.Uint16(s[i:])
  318. }
  319. return string(utf16.Decode(buf)), nil
  320. }
  321. func manglePassword(password string) []byte {
  322. var ucs2password []byte = str2ucs2(password)
  323. for i, ch := range ucs2password {
  324. ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5
  325. }
  326. return ucs2password
  327. }
  328. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  329. func sendLogin(w *tdsBuffer, login login) error {
  330. w.BeginPacket(packLogin7, false)
  331. hostname := str2ucs2(login.HostName)
  332. username := str2ucs2(login.UserName)
  333. password := manglePassword(login.Password)
  334. appname := str2ucs2(login.AppName)
  335. servername := str2ucs2(login.ServerName)
  336. ctlintname := str2ucs2(login.CtlIntName)
  337. language := str2ucs2(login.Language)
  338. database := str2ucs2(login.Database)
  339. atchdbfile := str2ucs2(login.AtchDBFile)
  340. changepassword := str2ucs2(login.ChangePassword)
  341. hdr := loginHeader{
  342. TDSVersion: login.TDSVersion,
  343. PacketSize: login.PacketSize,
  344. ClientProgVer: login.ClientProgVer,
  345. ClientPID: login.ClientPID,
  346. ConnectionID: login.ConnectionID,
  347. OptionFlags1: login.OptionFlags1,
  348. OptionFlags2: login.OptionFlags2,
  349. TypeFlags: login.TypeFlags,
  350. OptionFlags3: login.OptionFlags3,
  351. ClientTimeZone: login.ClientTimeZone,
  352. ClientLCID: login.ClientLCID,
  353. HostNameLength: uint16(utf8.RuneCountInString(login.HostName)),
  354. UserNameLength: uint16(utf8.RuneCountInString(login.UserName)),
  355. PasswordLength: uint16(utf8.RuneCountInString(login.Password)),
  356. AppNameLength: uint16(utf8.RuneCountInString(login.AppName)),
  357. ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)),
  358. CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)),
  359. LanguageLength: uint16(utf8.RuneCountInString(login.Language)),
  360. DatabaseLength: uint16(utf8.RuneCountInString(login.Database)),
  361. ClientID: login.ClientID,
  362. SSPILength: uint16(len(login.SSPI)),
  363. AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)),
  364. ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
  365. }
  366. offset := uint16(binary.Size(hdr))
  367. hdr.HostNameOffset = offset
  368. offset += uint16(len(hostname))
  369. hdr.UserNameOffset = offset
  370. offset += uint16(len(username))
  371. hdr.PasswordOffset = offset
  372. offset += uint16(len(password))
  373. hdr.AppNameOffset = offset
  374. offset += uint16(len(appname))
  375. hdr.ServerNameOffset = offset
  376. offset += uint16(len(servername))
  377. hdr.CtlIntNameOffset = offset
  378. offset += uint16(len(ctlintname))
  379. hdr.LanguageOffset = offset
  380. offset += uint16(len(language))
  381. hdr.DatabaseOffset = offset
  382. offset += uint16(len(database))
  383. hdr.SSPIOffset = offset
  384. offset += uint16(len(login.SSPI))
  385. hdr.AtchDBFileOffset = offset
  386. offset += uint16(len(atchdbfile))
  387. hdr.ChangePasswordOffset = offset
  388. offset += uint16(len(changepassword))
  389. hdr.Length = uint32(offset)
  390. var err error
  391. err = binary.Write(w, binary.LittleEndian, &hdr)
  392. if err != nil {
  393. return err
  394. }
  395. _, err = w.Write(hostname)
  396. if err != nil {
  397. return err
  398. }
  399. _, err = w.Write(username)
  400. if err != nil {
  401. return err
  402. }
  403. _, err = w.Write(password)
  404. if err != nil {
  405. return err
  406. }
  407. _, err = w.Write(appname)
  408. if err != nil {
  409. return err
  410. }
  411. _, err = w.Write(servername)
  412. if err != nil {
  413. return err
  414. }
  415. _, err = w.Write(ctlintname)
  416. if err != nil {
  417. return err
  418. }
  419. _, err = w.Write(language)
  420. if err != nil {
  421. return err
  422. }
  423. _, err = w.Write(database)
  424. if err != nil {
  425. return err
  426. }
  427. _, err = w.Write(login.SSPI)
  428. if err != nil {
  429. return err
  430. }
  431. _, err = w.Write(atchdbfile)
  432. if err != nil {
  433. return err
  434. }
  435. _, err = w.Write(changepassword)
  436. if err != nil {
  437. return err
  438. }
  439. return w.FinishPacket()
  440. }
  441. func readUcs2(r io.Reader, numchars int) (res string, err error) {
  442. buf := make([]byte, numchars*2)
  443. _, err = io.ReadFull(r, buf)
  444. if err != nil {
  445. return "", err
  446. }
  447. return ucs22str(buf)
  448. }
  449. func readUsVarChar(r io.Reader) (res string, err error) {
  450. var numchars uint16
  451. err = binary.Read(r, binary.LittleEndian, &numchars)
  452. if err != nil {
  453. return "", err
  454. }
  455. return readUcs2(r, int(numchars))
  456. }
  457. func writeUsVarChar(w io.Writer, s string) (err error) {
  458. buf := str2ucs2(s)
  459. var numchars int = len(buf) / 2
  460. if numchars > 0xffff {
  461. panic("invalid size for US_VARCHAR")
  462. }
  463. err = binary.Write(w, binary.LittleEndian, uint16(numchars))
  464. if err != nil {
  465. return
  466. }
  467. _, err = w.Write(buf)
  468. return
  469. }
  470. func readBVarChar(r io.Reader) (res string, err error) {
  471. var numchars uint8
  472. err = binary.Read(r, binary.LittleEndian, &numchars)
  473. if err != nil {
  474. return "", err
  475. }
  476. // A zero length could be returned, return an empty string
  477. if numchars == 0 {
  478. return "", nil
  479. }
  480. return readUcs2(r, int(numchars))
  481. }
  482. func writeBVarChar(w io.Writer, s string) (err error) {
  483. buf := str2ucs2(s)
  484. var numchars int = len(buf) / 2
  485. if numchars > 0xff {
  486. panic("invalid size for B_VARCHAR")
  487. }
  488. err = binary.Write(w, binary.LittleEndian, uint8(numchars))
  489. if err != nil {
  490. return
  491. }
  492. _, err = w.Write(buf)
  493. return
  494. }
  495. func readBVarByte(r io.Reader) (res []byte, err error) {
  496. var length uint8
  497. err = binary.Read(r, binary.LittleEndian, &length)
  498. if err != nil {
  499. return
  500. }
  501. res = make([]byte, length)
  502. _, err = io.ReadFull(r, res)
  503. return
  504. }
  505. func readUshort(r io.Reader) (res uint16, err error) {
  506. err = binary.Read(r, binary.LittleEndian, &res)
  507. return
  508. }
  509. func readByte(r io.Reader) (res byte, err error) {
  510. var b [1]byte
  511. _, err = r.Read(b[:])
  512. res = b[0]
  513. return
  514. }
// Packet Data Stream Headers
// http://msdn.microsoft.com/en-us/library/dd304953.aspx
//
// headerStruct is one data stream header as consumed by writeAllHeaders:
// hdrtype is written as a little-endian uint16 and data follows it verbatim.
type headerStruct struct {
	hdrtype uint16 // header type; see the dataStmHdr* constants
	data    []byte // already serialized header payload
}
// Data stream header types, intended for headerStruct.hdrtype.
const (
	dataStmHdrQueryNotif    = 1 // query notifications
	dataStmHdrTransDescr    = 2 // MARS transaction descriptor (required)
	dataStmHdrTraceActivity = 3
)
  526. // Query Notifications Header
  527. // http://msdn.microsoft.com/en-us/library/dd304949.aspx
  528. type queryNotifHdr struct {
  529. notifyId string
  530. ssbDeployment string
  531. notifyTimeout uint32
  532. }
  533. func (hdr queryNotifHdr) pack() (res []byte) {
  534. notifyId := str2ucs2(hdr.notifyId)
  535. ssbDeployment := str2ucs2(hdr.ssbDeployment)
  536. res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
  537. b := res
  538. binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
  539. b = b[2:]
  540. copy(b, notifyId)
  541. b = b[len(notifyId):]
  542. binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
  543. b = b[2:]
  544. copy(b, ssbDeployment)
  545. b = b[len(ssbDeployment):]
  546. binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
  547. return res
  548. }
  549. // MARS Transaction Descriptor Header
  550. // http://msdn.microsoft.com/en-us/library/dd340515.aspx
  551. type transDescrHdr struct {
  552. transDescr uint64 // transaction descriptor returned from ENVCHANGE
  553. outstandingReqCnt uint32 // outstanding request count
  554. }
  555. func (hdr transDescrHdr) pack() (res []byte) {
  556. res = make([]byte, 8+4)
  557. binary.LittleEndian.PutUint64(res, hdr.transDescr)
  558. binary.LittleEndian.PutUint32(res[8:], hdr.outstandingReqCnt)
  559. return res
  560. }
  561. func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
  562. // Calculating total length.
  563. var totallen uint32 = 4
  564. for _, hdr := range headers {
  565. totallen += 4 + 2 + uint32(len(hdr.data))
  566. }
  567. // writing
  568. err = binary.Write(w, binary.LittleEndian, totallen)
  569. if err != nil {
  570. return err
  571. }
  572. for _, hdr := range headers {
  573. var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
  574. err = binary.Write(w, binary.LittleEndian, headerlen)
  575. if err != nil {
  576. return err
  577. }
  578. err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
  579. if err != nil {
  580. return err
  581. }
  582. _, err = w.Write(hdr.data)
  583. if err != nil {
  584. return err
  585. }
  586. }
  587. return nil
  588. }
  589. func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
  590. buf.BeginPacket(packSQLBatch, resetSession)
  591. if err = writeAllHeaders(buf, headers); err != nil {
  592. return
  593. }
  594. _, err = buf.Write(str2ucs2(sqltext))
  595. if err != nil {
  596. return
  597. }
  598. return buf.FinishPacket()
  599. }
  600. // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
  601. // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
  602. func sendAttention(buf *tdsBuffer) error {
  603. buf.BeginPacket(packAttention, false)
  604. return buf.FinishPacket()
  605. }
// connectParams is the parsed form of a connection string, as produced by
// parseConnectParams. The notes below name the connection string key and the
// default for the fields whose parsing is visible in that function; fields
// without a note are filled in by code not documented here.
type connectParams struct {
	logFlags               uint64        // "log", parsed as a decimal number
	port                   uint64        // "port"; defaults to 1433
	host                   string        // part of "server" before the first backslash; ".", "(local)" and "" become "localhost"
	instance               string        // part of "server" after the first backslash, if any
	database               string        // "database"
	user                   string        // "user id"
	password               string        // "password"
	dial_timeout           time.Duration // "dial timeout" in seconds; defaults to 15 seconds
	conn_timeout           time.Duration // "connection timeout" in seconds; defaults to zero
	keepAlive              time.Duration // "keepalive" in seconds; defaults to 30 seconds
	encrypt                bool          // "encrypt", parsed as a boolean unless it is "disable"
	disableEncryption      bool          // set when "encrypt" equals "disable", ignoring case
	trustServerCertificate bool
	certificate            string
	hostInCertificate      string
	serverSPN              string
	workstation            string
	appname                string
	typeFlags              uint8
	failOverPartner        string
	failOverPort           uint64
	packetSize             uint16 // "packet size"; defaults to 4096, kept within 512 to 32767
}
  630. func splitConnectionString(dsn string) (res map[string]string) {
  631. res = map[string]string{}
  632. parts := strings.Split(dsn, ";")
  633. for _, part := range parts {
  634. if len(part) == 0 {
  635. continue
  636. }
  637. lst := strings.SplitN(part, "=", 2)
  638. name := strings.TrimSpace(strings.ToLower(lst[0]))
  639. if len(name) == 0 {
  640. continue
  641. }
  642. var value string = ""
  643. if len(lst) > 1 {
  644. value = strings.TrimSpace(lst[1])
  645. }
  646. res[name] = value
  647. }
  648. return res
  649. }
// Splits a URL in the ODBC format
//
// splitConnectionStringOdbc parses dsn, a ';'-separated list of key=value
// pairs, with a character-by-character state machine and returns the pairs
// as a map.
//
// Keys are normalized with normalizeOdbcKey. Whitespace before a key and
// before a value is skipped. A value is either bare, in which case it runs
// to the next ';' and loses its trailing whitespace, or enclosed in braces,
// in which case it is taken verbatim and "}}" stands for a literal '}'. A
// key that is not followed by '=' is stored with an empty value.
//
// An error is returned for an '=' where a key should start, for a braced
// value that is not closed, and for any character other than ';' or
// whitespace after a closing brace. The map returned along with an error
// holds the pairs parsed up to that point. Indexes in error messages are
// byte offsets into dsn.
func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
	res := map[string]string{}
	type parserState int
	const (
		// Before the start of a key
		parserStateBeforeKey parserState = iota
		// Inside a key
		parserStateKey
		// Beginning of a value. May be bare or braced
		parserStateBeginValue
		// Inside a bare value
		parserStateBareValue
		// Inside a braced value
		parserStateBracedValue
		// A closing brace inside a braced value.
		// May be the end of the value or an escaped closing brace, depending on the next character
		parserStateBracedValueClosingBrace
		// After a value. Next character should be a semicolon or whitespace.
		parserStateEndValue
	)
	var state = parserStateBeforeKey
	// key and value accumulate the pair currently being parsed; both are
	// reset once the pair has been stored in res.
	var key string
	var value string
	for i, c := range dsn {
		switch state {
		case parserStateBeforeKey:
			switch {
			case c == '=':
				return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
			case !unicode.IsSpace(c) && c != ';':
				// First character of a key. Whitespace and stray
				// semicolons before it are skipped.
				state = parserStateKey
				key += string(c)
			}
		case parserStateKey:
			switch c {
			case '=':
				key = normalizeOdbcKey(key)
				if len(key) == 0 {
					return res, fmt.Errorf("Unexpected end of key at index %d.", i)
				}
				state = parserStateBeginValue
			case ';':
				// Key without value
				key = normalizeOdbcKey(key)
				if len(key) == 0 {
					return res, fmt.Errorf("Unexpected end of key at index %d.", i)
				}
				res[key] = value
				key = ""
				value = ""
				state = parserStateBeforeKey
			default:
				key += string(c)
			}
		case parserStateBeginValue:
			switch {
			case c == '{':
				state = parserStateBracedValue
			case c == ';':
				// Empty value
				res[key] = value
				key = ""
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Ignore whitespace
			default:
				state = parserStateBareValue
				value += string(c)
			}
		case parserStateBareValue:
			if c == ';' {
				// Whitespace inside a bare value is kept; only the
				// trailing run is removed.
				res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
				key = ""
				value = ""
				state = parserStateBeforeKey
			} else {
				value += string(c)
			}
		case parserStateBracedValue:
			if c == '}' {
				state = parserStateBracedValueClosingBrace
			} else {
				value += string(c)
			}
		case parserStateBracedValueClosingBrace:
			if c == '}' {
				// Escaped closing brace
				value += string(c)
				state = parserStateBracedValue
				continue
			}
			// End of braced value
			res[key] = value
			key = ""
			value = ""
			// This character is the first character past the end,
			// so it needs to be parsed like the parserStateEndValue state.
			state = parserStateEndValue
			switch {
			case c == ';':
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Ignore whitespace
			default:
				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
			}
		case parserStateEndValue:
			switch {
			case c == ';':
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Ignore whitespace
			default:
				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
			}
		}
	}
	// The input is exhausted; finish whatever pair was still in progress.
	switch state {
	case parserStateBeforeKey: // Okay
	case parserStateKey: // Unfinished key. Treat as key without value.
		key = normalizeOdbcKey(key)
		if len(key) == 0 {
			return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
		}
		res[key] = value
	case parserStateBeginValue: // Empty value
		res[key] = value
	case parserStateBareValue:
		res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
	case parserStateBracedValue:
		return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
	case parserStateBracedValueClosingBrace: // End of braced value
		res[key] = value
	case parserStateEndValue: // Okay
	}
	return res, nil
}
  788. // Normalizes the given string as an ODBC-format key
  789. func normalizeOdbcKey(s string) string {
  790. return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
  791. }
  792. // Splits a URL of the form sqlserver://username:password@host/instance?param1=value&param2=value
  793. func splitConnectionStringURL(dsn string) (map[string]string, error) {
  794. res := map[string]string{}
  795. u, err := url.Parse(dsn)
  796. if err != nil {
  797. return res, err
  798. }
  799. if u.Scheme != "sqlserver" {
  800. return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
  801. }
  802. if u.User != nil {
  803. res["user id"] = u.User.Username()
  804. p, exists := u.User.Password()
  805. if exists {
  806. res["password"] = p
  807. }
  808. }
  809. host, port, err := net.SplitHostPort(u.Host)
  810. if err != nil {
  811. host = u.Host
  812. }
  813. if len(u.Path) > 0 {
  814. res["server"] = host + "\\" + u.Path[1:]
  815. } else {
  816. res["server"] = host
  817. }
  818. if len(port) > 0 {
  819. res["port"] = port
  820. }
  821. query := u.Query()
  822. for k, v := range query {
  823. if len(v) > 1 {
  824. return res, fmt.Errorf("key %s provided more than once", k)
  825. }
  826. res[strings.ToLower(k)] = v[0]
  827. }
  828. return res, nil
  829. }
  830. func parseConnectParams(dsn string) (connectParams, error) {
  831. var p connectParams
  832. var params map[string]string
  833. if strings.HasPrefix(dsn, "odbc:") {
  834. parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
  835. if err != nil {
  836. return p, err
  837. }
  838. params = parameters
  839. } else if strings.HasPrefix(dsn, "sqlserver://") {
  840. parameters, err := splitConnectionStringURL(dsn)
  841. if err != nil {
  842. return p, err
  843. }
  844. params = parameters
  845. } else {
  846. params = splitConnectionString(dsn)
  847. }
  848. strlog, ok := params["log"]
  849. if ok {
  850. var err error
  851. p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
  852. if err != nil {
  853. return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
  854. }
  855. }
  856. server := params["server"]
  857. parts := strings.SplitN(server, `\`, 2)
  858. p.host = parts[0]
  859. if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
  860. p.host = "localhost"
  861. }
  862. if len(parts) > 1 {
  863. p.instance = parts[1]
  864. }
  865. p.database = params["database"]
  866. p.user = params["user id"]
  867. p.password = params["password"]
  868. p.port = 1433
  869. strport, ok := params["port"]
  870. if ok {
  871. var err error
  872. p.port, err = strconv.ParseUint(strport, 10, 16)
  873. if err != nil {
  874. f := "Invalid tcp port '%v': %v"
  875. return p, fmt.Errorf(f, strport, err.Error())
  876. }
  877. }
  878. // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
  879. // Default packet size remains at 4096 bytes
  880. p.packetSize = 4096
  881. strpsize, ok := params["packet size"]
  882. if ok {
  883. var err error
  884. psize, err := strconv.ParseUint(strpsize, 0, 16)
  885. if err != nil {
  886. f := "Invalid packet size '%v': %v"
  887. return p, fmt.Errorf(f, strpsize, err.Error())
  888. }
  889. // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
  890. // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
  891. // a higher packet size, the server will respond with an ENVCHANGE request to
  892. // alter the packet size to 16383 bytes.
  893. p.packetSize = uint16(psize)
  894. if p.packetSize < 512 {
  895. p.packetSize = 512
  896. } else if p.packetSize > 32767 {
  897. p.packetSize = 32767
  898. }
  899. }
  900. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  901. //
  902. // Do not set a connection timeout. Use Context to manage such things.
  903. // Default to zero, but still allow it to be set.
  904. if strconntimeout, ok := params["connection timeout"]; ok {
  905. timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
  906. if err != nil {
  907. f := "Invalid connection timeout '%v': %v"
  908. return p, fmt.Errorf(f, strconntimeout, err.Error())
  909. }
  910. p.conn_timeout = time.Duration(timeout) * time.Second
  911. }
  912. p.dial_timeout = 15 * time.Second
  913. if strdialtimeout, ok := params["dial timeout"]; ok {
  914. timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
  915. if err != nil {
  916. f := "Invalid dial timeout '%v': %v"
  917. return p, fmt.Errorf(f, strdialtimeout, err.Error())
  918. }
  919. p.dial_timeout = time.Duration(timeout) * time.Second
  920. }
  921. // default keep alive should be 30 seconds according to spec:
  922. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  923. p.keepAlive = 30 * time.Second
  924. if keepAlive, ok := params["keepalive"]; ok {
  925. timeout, err := strconv.ParseUint(keepAlive, 10, 64)
  926. if err != nil {
  927. f := "Invalid keepAlive value '%s': %s"
  928. return p, fmt.Errorf(f, keepAlive, err.Error())
  929. }
  930. p.keepAlive = time.Duration(timeout) * time.Second
  931. }
  932. encrypt, ok := params["encrypt"]
  933. if ok {
  934. if strings.EqualFold(encrypt, "DISABLE") {
  935. p.disableEncryption = true
  936. } else {
  937. var err error
  938. p.encrypt, err = strconv.ParseBool(encrypt)
  939. if err != nil {
  940. f := "Invalid encrypt '%s': %s"
  941. return p, fmt.Errorf(f, encrypt, err.Error())
  942. }
  943. }
  944. } else {
  945. p.trustServerCertificate = true
  946. }
  947. trust, ok := params["trustservercertificate"]
  948. if ok {
  949. var err error
  950. p.trustServerCertificate, err = strconv.ParseBool(trust)
  951. if err != nil {
  952. f := "Invalid trust server certificate '%s': %s"
  953. return p, fmt.Errorf(f, trust, err.Error())
  954. }
  955. }
  956. p.certificate = params["certificate"]
  957. p.hostInCertificate, ok = params["hostnameincertificate"]
  958. if !ok {
  959. p.hostInCertificate = p.host
  960. }
  961. serverSPN, ok := params["serverspn"]
  962. if ok {
  963. p.serverSPN = serverSPN
  964. } else {
  965. p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
  966. }
  967. workstation, ok := params["workstation id"]
  968. if ok {
  969. p.workstation = workstation
  970. } else {
  971. workstation, err := os.Hostname()
  972. if err == nil {
  973. p.workstation = workstation
  974. }
  975. }
  976. appname, ok := params["app name"]
  977. if !ok {
  978. appname = "go-mssqldb"
  979. }
  980. p.appname = appname
  981. appintent, ok := params["applicationintent"]
  982. if ok {
  983. if appintent == "ReadOnly" {
  984. p.typeFlags |= fReadOnlyIntent
  985. }
  986. }
  987. failOverPartner, ok := params["failoverpartner"]
  988. if ok {
  989. p.failOverPartner = failOverPartner
  990. }
  991. failOverPort, ok := params["failoverport"]
  992. if ok {
  993. var err error
  994. p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
  995. if err != nil {
  996. f := "Invalid tcp port '%v': %v"
  997. return p, fmt.Errorf(f, failOverPort, err.Error())
  998. }
  999. }
  1000. return p, nil
  1001. }
  1002. type auth interface {
  1003. InitialBytes() ([]byte, error)
  1004. NextBytes([]byte) ([]byte, error)
  1005. Free()
  1006. }
  1007. // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
  1008. // list of IP addresses. So if there is more than one, try them all and
  1009. // use the first one that allows a connection.
  1010. func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err error) {
  1011. var ips []net.IP
  1012. ips, err = net.LookupIP(p.host)
  1013. if err != nil {
  1014. ip := net.ParseIP(p.host)
  1015. if ip == nil {
  1016. return nil, err
  1017. }
  1018. ips = []net.IP{ip}
  1019. }
  1020. if len(ips) == 1 {
  1021. d := createDialer(&p)
  1022. addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
  1023. conn, err = d.Dial(ctx, addr)
  1024. } else {
  1025. //Try Dials in parallel to avoid waiting for timeouts.
  1026. connChan := make(chan net.Conn, len(ips))
  1027. errChan := make(chan error, len(ips))
  1028. portStr := strconv.Itoa(int(p.port))
  1029. for _, ip := range ips {
  1030. go func(ip net.IP) {
  1031. d := createDialer(&p)
  1032. addr := net.JoinHostPort(ip.String(), portStr)
  1033. conn, err := d.Dial(ctx, addr)
  1034. if err == nil {
  1035. connChan <- conn
  1036. } else {
  1037. errChan <- err
  1038. }
  1039. }(ip)
  1040. }
  1041. // Wait for either the *first* successful connection, or all the errors
  1042. wait_loop:
  1043. for i, _ := range ips {
  1044. select {
  1045. case conn = <-connChan:
  1046. // Got a connection to use, close any others
  1047. go func(n int) {
  1048. for i := 0; i < n; i++ {
  1049. select {
  1050. case conn := <-connChan:
  1051. conn.Close()
  1052. case <-errChan:
  1053. }
  1054. }
  1055. }(len(ips) - i - 1)
  1056. // Remove any earlier errors we may have collected
  1057. err = nil
  1058. break wait_loop
  1059. case err = <-errChan:
  1060. }
  1061. }
  1062. }
  1063. // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
  1064. if conn == nil {
  1065. f := "Unable to open tcp connection with host '%v:%v': %v"
  1066. return nil, fmt.Errorf(f, p.host, p.port, err.Error())
  1067. }
  1068. return conn, err
  1069. }
  1070. func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tdsSession, err error) {
  1071. dialCtx := ctx
  1072. if p.dial_timeout > 0 {
  1073. var cancel func()
  1074. dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
  1075. defer cancel()
  1076. }
  1077. // if instance is specified use instance resolution service
  1078. if p.instance != "" {
  1079. p.instance = strings.ToUpper(p.instance)
  1080. instances, err := getInstances(dialCtx, p.host)
  1081. if err != nil {
  1082. f := "Unable to get instances from Sql Server Browser on host %v: %v"
  1083. return nil, fmt.Errorf(f, p.host, err.Error())
  1084. }
  1085. strport, ok := instances[p.instance]["tcp"]
  1086. if !ok {
  1087. f := "No instance matching '%v' returned from host '%v'"
  1088. return nil, fmt.Errorf(f, p.instance, p.host)
  1089. }
  1090. p.port, err = strconv.ParseUint(strport, 0, 16)
  1091. if err != nil {
  1092. f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
  1093. return nil, fmt.Errorf(f, strport, err.Error())
  1094. }
  1095. }
  1096. initiate_connection:
  1097. conn, err := dialConnection(dialCtx, p)
  1098. if err != nil {
  1099. return nil, err
  1100. }
  1101. toconn := newTimeoutConn(conn, p.conn_timeout)
  1102. outbuf := newTdsBuffer(p.packetSize, toconn)
  1103. sess := tdsSession{
  1104. buf: outbuf,
  1105. log: log,
  1106. logFlags: p.logFlags,
  1107. }
  1108. instance_buf := []byte(p.instance)
  1109. instance_buf = append(instance_buf, 0) // zero terminate instance name
  1110. var encrypt byte
  1111. if p.disableEncryption {
  1112. encrypt = encryptNotSup
  1113. } else if p.encrypt {
  1114. encrypt = encryptOn
  1115. } else {
  1116. encrypt = encryptOff
  1117. }
  1118. fields := map[uint8][]byte{
  1119. preloginVERSION: {0, 0, 0, 0, 0, 0},
  1120. preloginENCRYPTION: {encrypt},
  1121. preloginINSTOPT: instance_buf,
  1122. preloginTHREADID: {0, 0, 0, 0},
  1123. preloginMARS: {0}, // MARS disabled
  1124. }
  1125. err = writePrelogin(outbuf, fields)
  1126. if err != nil {
  1127. return nil, err
  1128. }
  1129. fields, err = readPrelogin(outbuf)
  1130. if err != nil {
  1131. return nil, err
  1132. }
  1133. encryptBytes, ok := fields[preloginENCRYPTION]
  1134. if !ok {
  1135. return nil, fmt.Errorf("Encrypt negotiation failed")
  1136. }
  1137. encrypt = encryptBytes[0]
  1138. if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
  1139. return nil, fmt.Errorf("Server does not support encryption")
  1140. }
  1141. if encrypt != encryptNotSup {
  1142. var config tls.Config
  1143. if p.certificate != "" {
  1144. pem, err := ioutil.ReadFile(p.certificate)
  1145. if err != nil {
  1146. return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
  1147. }
  1148. certs := x509.NewCertPool()
  1149. certs.AppendCertsFromPEM(pem)
  1150. config.RootCAs = certs
  1151. }
  1152. if p.trustServerCertificate {
  1153. config.InsecureSkipVerify = true
  1154. }
  1155. config.ServerName = p.hostInCertificate
  1156. // fix for https://github.com/denisenkom/go-mssqldb/issues/166
  1157. // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
  1158. // while SQL Server seems to expect one TCP segment per encrypted TDS package.
  1159. // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
  1160. config.DynamicRecordSizingDisabled = true
  1161. outbuf.transport = conn
  1162. toconn.buf = outbuf
  1163. tlsConn := tls.Client(toconn, &config)
  1164. err = tlsConn.Handshake()
  1165. toconn.buf = nil
  1166. outbuf.transport = tlsConn
  1167. if err != nil {
  1168. return nil, fmt.Errorf("TLS Handshake failed: %v", err)
  1169. }
  1170. if encrypt == encryptOff {
  1171. outbuf.afterFirst = func() {
  1172. outbuf.transport = toconn
  1173. }
  1174. }
  1175. }
  1176. login := login{
  1177. TDSVersion: verTDS74,
  1178. PacketSize: uint32(outbuf.PackageSize()),
  1179. Database: p.database,
  1180. OptionFlags2: fODBC, // to get unlimited TEXTSIZE
  1181. HostName: p.workstation,
  1182. ServerName: p.host,
  1183. AppName: p.appname,
  1184. TypeFlags: p.typeFlags,
  1185. }
  1186. auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)
  1187. if auth_ok {
  1188. login.SSPI, err = auth.InitialBytes()
  1189. if err != nil {
  1190. return nil, err
  1191. }
  1192. login.OptionFlags2 |= fIntSecurity
  1193. defer auth.Free()
  1194. } else {
  1195. login.UserName = p.user
  1196. login.Password = p.password
  1197. }
  1198. err = sendLogin(outbuf, login)
  1199. if err != nil {
  1200. return nil, err
  1201. }
  1202. // processing login response
  1203. var sspi_msg []byte
  1204. continue_login:
  1205. tokchan := make(chan tokenStruct, 5)
  1206. go processResponse(context.Background(), &sess, tokchan, nil)
  1207. success := false
  1208. for tok := range tokchan {
  1209. switch token := tok.(type) {
  1210. case sspiMsg:
  1211. sspi_msg, err = auth.NextBytes(token)
  1212. if err != nil {
  1213. return nil, err
  1214. }
  1215. case loginAckStruct:
  1216. success = true
  1217. sess.loginAck = token
  1218. case error:
  1219. return nil, fmt.Errorf("Login error: %s", token.Error())
  1220. case doneStruct:
  1221. if token.isError() {
  1222. return nil, fmt.Errorf("Login error: %s", token.getError())
  1223. }
  1224. }
  1225. }
  1226. if sspi_msg != nil {
  1227. outbuf.BeginPacket(packSSPIMessage, false)
  1228. _, err = outbuf.Write(sspi_msg)
  1229. if err != nil {
  1230. return nil, err
  1231. }
  1232. err = outbuf.FinishPacket()
  1233. if err != nil {
  1234. return nil, err
  1235. }
  1236. sspi_msg = nil
  1237. goto continue_login
  1238. }
  1239. if !success {
  1240. return nil, fmt.Errorf("Login failed")
  1241. }
  1242. if sess.routedServer != "" {
  1243. toconn.Close()
  1244. p.host = sess.routedServer
  1245. p.port = uint64(sess.routedPort)
  1246. goto initiate_connection
  1247. }
  1248. return &sess, nil
  1249. }