mssql.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900
  1. package mssql
  2. import (
  3. "context"
  4. "database/sql"
  5. "database/sql/driver"
  6. "encoding/binary"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "math"
  11. "net"
  12. "reflect"
  13. "strings"
  14. "time"
  15. )
  16. var driverInstance = &Driver{processQueryText: true}
  17. var driverInstanceNoProcess = &Driver{processQueryText: false}
  18. func init() {
  19. sql.Register("mssql", driverInstance)
  20. sql.Register("sqlserver", driverInstanceNoProcess)
  21. createDialer = func(p *connectParams) dialer {
  22. return tcpDialer{&net.Dialer{KeepAlive: p.keepAlive}}
  23. }
  24. }
  25. // Abstract the dialer for testing and for non-TCP based connections.
  26. type dialer interface {
  27. Dial(ctx context.Context, addr string) (net.Conn, error)
  28. }
  29. var createDialer func(p *connectParams) dialer
  30. type tcpDialer struct {
  31. nd *net.Dialer
  32. }
  33. func (d tcpDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
  34. return d.nd.DialContext(ctx, "tcp", addr)
  35. }
  36. type Driver struct {
  37. log optionalLogger
  38. processQueryText bool
  39. }
  40. // OpenConnector opens a new connector. Useful to dial with a context.
  41. func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
  42. params, err := parseConnectParams(dsn)
  43. if err != nil {
  44. return nil, err
  45. }
  46. return &Connector{
  47. params: params,
  48. driver: d,
  49. }, nil
  50. }
  51. func (d *Driver) Open(dsn string) (driver.Conn, error) {
  52. return d.open(context.Background(), dsn)
  53. }
  54. func SetLogger(logger Logger) {
  55. driverInstance.SetLogger(logger)
  56. driverInstanceNoProcess.SetLogger(logger)
  57. }
  58. func (d *Driver) SetLogger(logger Logger) {
  59. d.log = optionalLogger{logger}
  60. }
  61. // NewConnector creates a new connector from a DSN.
  62. // The returned connector may be used with sql.OpenDB.
  63. func NewConnector(dsn string) (*Connector, error) {
  64. params, err := parseConnectParams(dsn)
  65. if err != nil {
  66. return nil, err
  67. }
  68. c := &Connector{
  69. params: params,
  70. driver: driverInstanceNoProcess,
  71. }
  72. return c, nil
  73. }
  74. // Connector holds the parsed DSN and is ready to make a new connection
  75. // at any time.
  76. //
  77. // In the future, settings that cannot be passed through a string DSN
  78. // may be set directly on the connector.
  79. type Connector struct {
  80. params connectParams
  81. driver *Driver
  82. // SessionInitSQL is executed after marking a given session to be reset.
  83. // When not present, the next query will still reset the session to the
  84. // database defaults.
  85. //
  86. // When present the connection will immediately mark the session to
  87. // be reset, then execute the SessionInitSQL text to setup the session
  88. // that may be different from the base database defaults.
  89. //
  90. // For Example, the application relies on the following defaults
  91. // but is not allowed to set them at the database system level.
  92. //
  93. // SET XACT_ABORT ON;
  94. // SET TEXTSIZE -1;
  95. // SET ANSI_NULLS ON;
  96. // SET LOCK_TIMEOUT 10000;
  97. //
  98. // SessionInitSQL should not attempt to manually call sp_reset_connection.
  99. // This will happen at the TDS layer.
  100. //
  101. // SessionInitSQL is optional. The session will be reset even if
  102. // SessionInitSQL is empty.
  103. SessionInitSQL string
  104. }
  105. type Conn struct {
  106. connector *Connector
  107. sess *tdsSession
  108. transactionCtx context.Context
  109. resetSession bool
  110. processQueryText bool
  111. connectionGood bool
  112. outs map[string]interface{}
  113. }
  114. func (c *Conn) checkBadConn(err error) error {
  115. // this is a hack to address Issue #275
  116. // we set connectionGood flag to false if
  117. // error indicates that connection is not usable
  118. // but we return actual error instead of ErrBadConn
  119. // this will cause connection to stay in a pool
  120. // but next request to this connection will return ErrBadConn
  121. // it might be possible to revise this hack after
  122. // https://github.com/golang/go/issues/20807
  123. // is implemented
  124. switch err {
  125. case nil:
  126. return nil
  127. case io.EOF:
  128. c.connectionGood = false
  129. return driver.ErrBadConn
  130. case driver.ErrBadConn:
  131. // It is an internal programming error if driver.ErrBadConn
  132. // is ever passed to this function. driver.ErrBadConn should
  133. // only ever be returned in response to a *mssql.Conn.connectionGood == false
  134. // check in the external facing API.
  135. panic("driver.ErrBadConn in checkBadConn. This should not happen.")
  136. }
  137. switch err.(type) {
  138. case net.Error:
  139. c.connectionGood = false
  140. return err
  141. case StreamError:
  142. c.connectionGood = false
  143. return err
  144. default:
  145. return err
  146. }
  147. }
  148. func (c *Conn) clearOuts() {
  149. c.outs = nil
  150. }
  151. func (c *Conn) simpleProcessResp(ctx context.Context) error {
  152. tokchan := make(chan tokenStruct, 5)
  153. go processResponse(ctx, c.sess, tokchan, c.outs)
  154. c.clearOuts()
  155. for tok := range tokchan {
  156. switch token := tok.(type) {
  157. case doneStruct:
  158. if token.isError() {
  159. return c.checkBadConn(token.getError())
  160. }
  161. case error:
  162. return c.checkBadConn(token)
  163. }
  164. }
  165. return nil
  166. }
  167. func (c *Conn) Commit() error {
  168. if !c.connectionGood {
  169. return driver.ErrBadConn
  170. }
  171. if err := c.sendCommitRequest(); err != nil {
  172. return c.checkBadConn(err)
  173. }
  174. return c.simpleProcessResp(c.transactionCtx)
  175. }
  176. func (c *Conn) sendCommitRequest() error {
  177. headers := []headerStruct{
  178. {hdrtype: dataStmHdrTransDescr,
  179. data: transDescrHdr{c.sess.tranid, 1}.pack()},
  180. }
  181. reset := c.resetSession
  182. c.resetSession = false
  183. if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
  184. if c.sess.logFlags&logErrors != 0 {
  185. c.sess.log.Printf("Failed to send CommitXact with %v", err)
  186. }
  187. c.connectionGood = false
  188. return fmt.Errorf("Faild to send CommitXact: %v", err)
  189. }
  190. return nil
  191. }
  192. func (c *Conn) Rollback() error {
  193. if !c.connectionGood {
  194. return driver.ErrBadConn
  195. }
  196. if err := c.sendRollbackRequest(); err != nil {
  197. return c.checkBadConn(err)
  198. }
  199. return c.simpleProcessResp(c.transactionCtx)
  200. }
  201. func (c *Conn) sendRollbackRequest() error {
  202. headers := []headerStruct{
  203. {hdrtype: dataStmHdrTransDescr,
  204. data: transDescrHdr{c.sess.tranid, 1}.pack()},
  205. }
  206. reset := c.resetSession
  207. c.resetSession = false
  208. if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
  209. if c.sess.logFlags&logErrors != 0 {
  210. c.sess.log.Printf("Failed to send RollbackXact with %v", err)
  211. }
  212. c.connectionGood = false
  213. return fmt.Errorf("Failed to send RollbackXact: %v", err)
  214. }
  215. return nil
  216. }
  217. func (c *Conn) Begin() (driver.Tx, error) {
  218. return c.begin(context.Background(), isolationUseCurrent)
  219. }
  220. func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
  221. if !c.connectionGood {
  222. return nil, driver.ErrBadConn
  223. }
  224. err = c.sendBeginRequest(ctx, tdsIsolation)
  225. if err != nil {
  226. return nil, c.checkBadConn(err)
  227. }
  228. tx, err = c.processBeginResponse(ctx)
  229. if err != nil {
  230. return nil, c.checkBadConn(err)
  231. }
  232. return
  233. }
  234. func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
  235. c.transactionCtx = ctx
  236. headers := []headerStruct{
  237. {hdrtype: dataStmHdrTransDescr,
  238. data: transDescrHdr{0, 1}.pack()},
  239. }
  240. reset := c.resetSession
  241. c.resetSession = false
  242. if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
  243. if c.sess.logFlags&logErrors != 0 {
  244. c.sess.log.Printf("Failed to send BeginXact with %v", err)
  245. }
  246. c.connectionGood = false
  247. return fmt.Errorf("Failed to send BeginXact: %v", err)
  248. }
  249. return nil
  250. }
  251. func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
  252. if err := c.simpleProcessResp(ctx); err != nil {
  253. return nil, err
  254. }
  255. // successful BEGINXACT request will return sess.tranid
  256. // for started transaction
  257. return c, nil
  258. }
  259. func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
  260. params, err := parseConnectParams(dsn)
  261. if err != nil {
  262. return nil, err
  263. }
  264. return d.connect(ctx, params)
  265. }
  266. // connect to the server, using the provided context for dialing only.
  267. func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, error) {
  268. sess, err := connect(ctx, d.log, params)
  269. if err != nil {
  270. // main server failed, try fail-over partner
  271. if params.failOverPartner == "" {
  272. return nil, err
  273. }
  274. params.host = params.failOverPartner
  275. if params.failOverPort != 0 {
  276. params.port = params.failOverPort
  277. }
  278. sess, err = connect(ctx, d.log, params)
  279. if err != nil {
  280. // fail-over partner also failed, now fail
  281. return nil, err
  282. }
  283. }
  284. conn := &Conn{
  285. sess: sess,
  286. transactionCtx: context.Background(),
  287. processQueryText: d.processQueryText,
  288. connectionGood: true,
  289. }
  290. conn.sess.log = d.log
  291. return conn, nil
  292. }
  293. func (c *Conn) Close() error {
  294. return c.sess.buf.transport.Close()
  295. }
  296. type Stmt struct {
  297. c *Conn
  298. query string
  299. paramCount int
  300. notifSub *queryNotifSub
  301. }
  302. type queryNotifSub struct {
  303. msgText string
  304. options string
  305. timeout uint32
  306. }
  307. func (c *Conn) Prepare(query string) (driver.Stmt, error) {
  308. if !c.connectionGood {
  309. return nil, driver.ErrBadConn
  310. }
  311. if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
  312. return c.prepareCopyIn(context.Background(), query)
  313. }
  314. return c.prepareContext(context.Background(), query)
  315. }
  316. func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
  317. paramCount := -1
  318. if c.processQueryText {
  319. query, paramCount = parseParams(query)
  320. }
  321. return &Stmt{c, query, paramCount, nil}, nil
  322. }
  323. func (s *Stmt) Close() error {
  324. return nil
  325. }
  326. func (s *Stmt) SetQueryNotification(id, options string, timeout time.Duration) {
  327. to := uint32(timeout / time.Second)
  328. if to < 1 {
  329. to = 1
  330. }
  331. s.notifSub = &queryNotifSub{id, options, to}
  332. }
  333. func (s *Stmt) NumInput() int {
  334. return s.paramCount
  335. }
  336. func (s *Stmt) sendQuery(args []namedValue) (err error) {
  337. headers := []headerStruct{
  338. {hdrtype: dataStmHdrTransDescr,
  339. data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
  340. }
  341. if s.notifSub != nil {
  342. headers = append(headers,
  343. headerStruct{
  344. hdrtype: dataStmHdrQueryNotif,
  345. data: queryNotifHdr{
  346. s.notifSub.msgText,
  347. s.notifSub.options,
  348. s.notifSub.timeout,
  349. }.pack(),
  350. })
  351. }
  352. conn := s.c
  353. // no need to check number of parameters here, it is checked by database/sql
  354. if conn.sess.logFlags&logSQL != 0 {
  355. conn.sess.log.Println(s.query)
  356. }
  357. if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
  358. for i := 0; i < len(args); i++ {
  359. if len(args[i].Name) > 0 {
  360. s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value)
  361. } else {
  362. s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value)
  363. }
  364. }
  365. }
  366. reset := conn.resetSession
  367. conn.resetSession = false
  368. if len(args) == 0 {
  369. if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
  370. if conn.sess.logFlags&logErrors != 0 {
  371. conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
  372. }
  373. conn.connectionGood = false
  374. return fmt.Errorf("failed to send SQL Batch: %v", err)
  375. }
  376. } else {
  377. proc := sp_ExecuteSql
  378. var params []param
  379. if isProc(s.query) {
  380. proc.name = s.query
  381. params, _, err = s.makeRPCParams(args, 0)
  382. if err != nil {
  383. return
  384. }
  385. } else {
  386. var decls []string
  387. params, decls, err = s.makeRPCParams(args, 2)
  388. if err != nil {
  389. return
  390. }
  391. params[0] = makeStrParam(s.query)
  392. params[1] = makeStrParam(strings.Join(decls, ","))
  393. }
  394. if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
  395. if conn.sess.logFlags&logErrors != 0 {
  396. conn.sess.log.Printf("Failed to send Rpc with %v", err)
  397. }
  398. conn.connectionGood = false
  399. return fmt.Errorf("Failed to send RPC: %v", err)
  400. }
  401. }
  402. return
  403. }
  404. // isProc takes the query text in s and determines if it is a stored proc name
  405. // or SQL text.
  406. func isProc(s string) bool {
  407. if len(s) == 0 {
  408. return false
  409. }
  410. if s[0] == '[' && s[len(s)-1] == ']' && strings.ContainsAny(s, "\n\r") == false {
  411. return true
  412. }
  413. return !strings.ContainsAny(s, " \t\n\r;")
  414. }
  415. func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]param, []string, error) {
  416. var err error
  417. params := make([]param, len(args)+offset)
  418. decls := make([]string, len(args))
  419. for i, val := range args {
  420. params[i+offset], err = s.makeParam(val.Value)
  421. if err != nil {
  422. return nil, nil, err
  423. }
  424. var name string
  425. if len(val.Name) > 0 {
  426. name = "@" + val.Name
  427. } else {
  428. name = fmt.Sprintf("@p%d", val.Ordinal)
  429. }
  430. params[i+offset].Name = name
  431. decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+offset].ti))
  432. }
  433. return params, decls, nil
  434. }
  435. type namedValue struct {
  436. Name string
  437. Ordinal int
  438. Value driver.Value
  439. }
  440. func convertOldArgs(args []driver.Value) []namedValue {
  441. list := make([]namedValue, len(args))
  442. for i, v := range args {
  443. list[i] = namedValue{
  444. Ordinal: i + 1,
  445. Value: v,
  446. }
  447. }
  448. return list
  449. }
  450. func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
  451. return s.queryContext(context.Background(), convertOldArgs(args))
  452. }
  453. func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
  454. if !s.c.connectionGood {
  455. return nil, driver.ErrBadConn
  456. }
  457. if err = s.sendQuery(args); err != nil {
  458. return nil, s.c.checkBadConn(err)
  459. }
  460. return s.processQueryResponse(ctx)
  461. }
  462. func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
  463. tokchan := make(chan tokenStruct, 5)
  464. ctx, cancel := context.WithCancel(ctx)
  465. go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
  466. s.c.clearOuts()
  467. // process metadata
  468. var cols []columnStruct
  469. loop:
  470. for tok := range tokchan {
  471. switch token := tok.(type) {
  472. // By ignoring DONE token we effectively
  473. // skip empty result-sets.
  474. // This improves results in queries like that:
  475. // set nocount on; select 1
  476. // see TestIgnoreEmptyResults test
  477. //case doneStruct:
  478. //break loop
  479. case []columnStruct:
  480. cols = token
  481. break loop
  482. case doneStruct:
  483. if token.isError() {
  484. return nil, s.c.checkBadConn(token.getError())
  485. }
  486. case error:
  487. return nil, s.c.checkBadConn(token)
  488. }
  489. }
  490. res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
  491. return
  492. }
  493. func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
  494. return s.exec(context.Background(), convertOldArgs(args))
  495. }
  496. func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
  497. if !s.c.connectionGood {
  498. return nil, driver.ErrBadConn
  499. }
  500. if err = s.sendQuery(args); err != nil {
  501. return nil, s.c.checkBadConn(err)
  502. }
  503. if res, err = s.processExec(ctx); err != nil {
  504. return nil, s.c.checkBadConn(err)
  505. }
  506. return
  507. }
  508. func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
  509. tokchan := make(chan tokenStruct, 5)
  510. go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
  511. s.c.clearOuts()
  512. var rowCount int64
  513. for token := range tokchan {
  514. switch token := token.(type) {
  515. case doneInProcStruct:
  516. if token.Status&doneCount != 0 {
  517. rowCount += int64(token.RowCount)
  518. }
  519. case doneStruct:
  520. if token.Status&doneCount != 0 {
  521. rowCount += int64(token.RowCount)
  522. }
  523. if token.isError() {
  524. return nil, token.getError()
  525. }
  526. case error:
  527. return nil, token
  528. }
  529. }
  530. return &Result{s.c, rowCount}, nil
  531. }
  532. type Rows struct {
  533. stmt *Stmt
  534. cols []columnStruct
  535. tokchan chan tokenStruct
  536. nextCols []columnStruct
  537. cancel func()
  538. }
  539. func (rc *Rows) Close() error {
  540. rc.cancel()
  541. for _ = range rc.tokchan {
  542. }
  543. rc.tokchan = nil
  544. return nil
  545. }
  546. func (rc *Rows) Columns() (res []string) {
  547. res = make([]string, len(rc.cols))
  548. for i, col := range rc.cols {
  549. res[i] = col.ColName
  550. }
  551. return
  552. }
  553. func (rc *Rows) Next(dest []driver.Value) error {
  554. if !rc.stmt.c.connectionGood {
  555. return driver.ErrBadConn
  556. }
  557. if rc.nextCols != nil {
  558. return io.EOF
  559. }
  560. for tok := range rc.tokchan {
  561. switch tokdata := tok.(type) {
  562. case []columnStruct:
  563. rc.nextCols = tokdata
  564. return io.EOF
  565. case []interface{}:
  566. for i := range dest {
  567. dest[i] = tokdata[i]
  568. }
  569. return nil
  570. case doneStruct:
  571. if tokdata.isError() {
  572. return rc.stmt.c.checkBadConn(tokdata.getError())
  573. }
  574. case error:
  575. return rc.stmt.c.checkBadConn(tokdata)
  576. }
  577. }
  578. return io.EOF
  579. }
  580. func (rc *Rows) HasNextResultSet() bool {
  581. return rc.nextCols != nil
  582. }
  583. func (rc *Rows) NextResultSet() error {
  584. rc.cols = rc.nextCols
  585. rc.nextCols = nil
  586. if rc.cols == nil {
  587. return io.EOF
  588. }
  589. return nil
  590. }
  591. // It should return
  592. // the value type that can be used to scan types into. For example, the database
  593. // column type "bigint" this should return "reflect.TypeOf(int64(0))".
  594. func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
  595. return makeGoLangScanType(r.cols[index].ti)
  596. }
  597. // RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
  598. // database system type name without the length. Type names should be uppercase.
  599. // Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
  600. // "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
  601. // "TIMESTAMP".
  602. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
  603. return makeGoLangTypeName(r.cols[index].ti)
  604. }
  605. // RowsColumnTypeLength may be implemented by Rows. It should return the length
  606. // of the column type if the column is a variable length type. If the column is
  607. // not a variable length type ok should return false.
  608. // If length is not limited other than system limits, it should return math.MaxInt64.
  609. // The following are examples of returned values for various types:
  610. // TEXT (math.MaxInt64, true)
  611. // varchar(10) (10, true)
  612. // nvarchar(10) (10, true)
  613. // decimal (0, false)
  614. // int (0, false)
  615. // bytea(30) (30, true)
  616. func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
  617. return makeGoLangTypeLength(r.cols[index].ti)
  618. }
  619. // It should return
  620. // the precision and scale for decimal types. If not applicable, ok should be false.
  621. // The following are examples of returned values for various types:
  622. // decimal(38, 4) (38, 4, true)
  623. // int (0, 0, false)
  624. // decimal (math.MaxInt64, math.MaxInt64, true)
  625. func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
  626. return makeGoLangTypePrecisionScale(r.cols[index].ti)
  627. }
  628. // The nullable value should
  629. // be true if it is known the column may be null, or false if the column is known
  630. // to be not nullable.
  631. // If the column nullability is unknown, ok should be false.
  632. func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
  633. nullable = r.cols[index].Flags&colFlagNullable != 0
  634. ok = true
  635. return
  636. }
  637. func makeStrParam(val string) (res param) {
  638. res.ti.TypeId = typeNVarChar
  639. res.buffer = str2ucs2(val)
  640. res.ti.Size = len(res.buffer)
  641. return
  642. }
  643. func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
  644. if val == nil {
  645. res.ti.TypeId = typeNull
  646. res.buffer = nil
  647. res.ti.Size = 0
  648. return
  649. }
  650. switch val := val.(type) {
  651. case int64:
  652. res.ti.TypeId = typeIntN
  653. res.buffer = make([]byte, 8)
  654. res.ti.Size = 8
  655. binary.LittleEndian.PutUint64(res.buffer, uint64(val))
  656. case sql.NullInt64:
  657. // only null values should be getting here
  658. res.ti.TypeId = typeIntN
  659. res.ti.Size = 8
  660. res.buffer = []byte{}
  661. case float64:
  662. res.ti.TypeId = typeFltN
  663. res.ti.Size = 8
  664. res.buffer = make([]byte, 8)
  665. binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
  666. case sql.NullFloat64:
  667. // only null values should be getting here
  668. res.ti.TypeId = typeFltN
  669. res.ti.Size = 8
  670. res.buffer = []byte{}
  671. case []byte:
  672. res.ti.TypeId = typeBigVarBin
  673. res.ti.Size = len(val)
  674. res.buffer = val
  675. case string:
  676. res = makeStrParam(val)
  677. case sql.NullString:
  678. // only null values should be getting here
  679. res.ti.TypeId = typeNVarChar
  680. res.buffer = nil
  681. res.ti.Size = 8000
  682. case bool:
  683. res.ti.TypeId = typeBitN
  684. res.ti.Size = 1
  685. res.buffer = make([]byte, 1)
  686. if val {
  687. res.buffer[0] = 1
  688. }
  689. case sql.NullBool:
  690. // only null values should be getting here
  691. res.ti.TypeId = typeBitN
  692. res.ti.Size = 1
  693. res.buffer = []byte{}
  694. case time.Time:
  695. if s.c.sess.loginAck.TDSVersion >= verTDS73 {
  696. res.ti.TypeId = typeDateTimeOffsetN
  697. res.ti.Scale = 7
  698. res.buffer = encodeDateTimeOffset(val, int(res.ti.Scale))
  699. res.ti.Size = len(res.buffer)
  700. } else {
  701. res.ti.TypeId = typeDateTimeN
  702. res.buffer = encodeDateTime(val)
  703. res.ti.Size = len(res.buffer)
  704. }
  705. default:
  706. return s.makeParamExtra(val)
  707. }
  708. return
  709. }
  710. type Result struct {
  711. c *Conn
  712. rowsAffected int64
  713. }
  714. func (r *Result) RowsAffected() (int64, error) {
  715. return r.rowsAffected, nil
  716. }
  717. func (r *Result) LastInsertId() (int64, error) {
  718. s, err := r.c.Prepare("select cast(@@identity as bigint)")
  719. if err != nil {
  720. return 0, err
  721. }
  722. defer s.Close()
  723. rows, err := s.Query(nil)
  724. if err != nil {
  725. return 0, err
  726. }
  727. defer rows.Close()
  728. dest := make([]driver.Value, 1)
  729. err = rows.Next(dest)
  730. if err != nil {
  731. return 0, err
  732. }
  733. if dest[0] == nil {
  734. return -1, errors.New("There is no generated identity value")
  735. }
  736. lastInsertId := dest[0].(int64)
  737. return lastInsertId, nil
  738. }
  739. var _ driver.Pinger = &Conn{}
  740. // Ping is used to check if the remote server is available and satisfies the Pinger interface.
  741. func (c *Conn) Ping(ctx context.Context) error {
  742. if !c.connectionGood {
  743. return driver.ErrBadConn
  744. }
  745. stmt := &Stmt{c, `select 1;`, 0, nil}
  746. _, err := stmt.ExecContext(ctx, nil)
  747. return err
  748. }
  749. var _ driver.ConnBeginTx = &Conn{}
  750. // BeginTx satisfies ConnBeginTx.
  751. func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
  752. if !c.connectionGood {
  753. return nil, driver.ErrBadConn
  754. }
  755. if opts.ReadOnly {
  756. return nil, errors.New("Read-only transactions are not supported")
  757. }
  758. var tdsIsolation isoLevel
  759. switch sql.IsolationLevel(opts.Isolation) {
  760. case sql.LevelDefault:
  761. tdsIsolation = isolationUseCurrent
  762. case sql.LevelReadUncommitted:
  763. tdsIsolation = isolationReadUncommited
  764. case sql.LevelReadCommitted:
  765. tdsIsolation = isolationReadCommited
  766. case sql.LevelWriteCommitted:
  767. return nil, errors.New("LevelWriteCommitted isolation level is not supported")
  768. case sql.LevelRepeatableRead:
  769. tdsIsolation = isolationRepeatableRead
  770. case sql.LevelSnapshot:
  771. tdsIsolation = isolationSnapshot
  772. case sql.LevelSerializable:
  773. tdsIsolation = isolationSerializable
  774. case sql.LevelLinearizable:
  775. return nil, errors.New("LevelLinearizable isolation level is not supported")
  776. default:
  777. return nil, errors.New("Isolation level is not supported or unknown")
  778. }
  779. return c.begin(ctx, tdsIsolation)
  780. }
  781. func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
  782. if !c.connectionGood {
  783. return nil, driver.ErrBadConn
  784. }
  785. if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
  786. return c.prepareCopyIn(ctx, query)
  787. }
  788. return c.prepareContext(ctx, query)
  789. }
  790. func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
  791. if !s.c.connectionGood {
  792. return nil, driver.ErrBadConn
  793. }
  794. list := make([]namedValue, len(args))
  795. for i, nv := range args {
  796. list[i] = namedValue(nv)
  797. }
  798. return s.queryContext(ctx, list)
  799. }
  800. func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
  801. if !s.c.connectionGood {
  802. return nil, driver.ErrBadConn
  803. }
  804. list := make([]namedValue, len(args))
  805. for i, nv := range args {
  806. list[i] = namedValue(nv)
  807. }
  808. return s.exec(ctx, list)
  809. }