bulkcopy.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. package mssql
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/binary"
  6. "fmt"
  7. "math"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. "time"
  12. )
// Bulk is a bulk-copy (BCP) operation bound to a single connection and a
// single destination table. Create one with CreateBulk or CreateBulkContext,
// feed rows with AddRow, and finish with Done.
type Bulk struct {
	// ctx is used only for AddRow and Done methods.
	// This could be removed if AddRow and Done accepted
	// a ctx field as well, which is available with the
	// database/sql call.
	ctx context.Context

	cn          *Conn          // connection the bulk load runs on
	metadata    []columnStruct // column metadata of the destination table (filled by getMetadata)
	bulkColumns []columnStruct // metadata for just the columns being loaded, in insert order
	columnsName []string       // column names supplied by the caller
	tablename   string         // destination table name
	numRows     int            // number of rows written so far via AddRow

	headerSent bool // set once the INSERT BULK command and column-metadata token were sent

	Options BulkOptions // options rendered into the WITH(...) clause of INSERT BULK
	Debug   bool        // when true, dlogf writes diagnostics to the session log
}
// BulkOptions configures the INSERT BULK statement issued for a Bulk
// operation. Zero values mean the corresponding option is omitted.
type BulkOptions struct {
	CheckConstraints  bool     // emit CHECK_CONSTRAINTS
	FireTriggers      bool     // emit FIRE_TRIGGERS
	KeepNulls         bool     // emit KEEP_NULLS
	KilobytesPerBatch int      // emit KILOBYTES_PER_BATCH = n when > 0
	RowsPerBatch      int      // emit ROWS_PER_BATCH = n when > 0
	Order             []string // emit ORDER(col1,col2,...) when non-empty
	Tablock           bool     // emit TABLOCK
}
  38. type DataValue interface{}
  39. func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
  40. b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns}
  41. b.Debug = false
  42. return &b
  43. }
  44. func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) {
  45. b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns}
  46. b.Debug = false
  47. return &b
  48. }
  49. func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) {
  50. //get table columns info
  51. err = b.getMetadata(ctx)
  52. if err != nil {
  53. return err
  54. }
  55. //match the columns
  56. for _, colname := range b.columnsName {
  57. var bulkCol *columnStruct
  58. for _, m := range b.metadata {
  59. if m.ColName == colname {
  60. bulkCol = &m
  61. break
  62. }
  63. }
  64. if bulkCol != nil {
  65. if bulkCol.ti.TypeId == typeUdt {
  66. //send udt as binary
  67. bulkCol.ti.TypeId = typeBigVarBin
  68. }
  69. b.bulkColumns = append(b.bulkColumns, *bulkCol)
  70. b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
  71. } else {
  72. return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename)
  73. }
  74. }
  75. //create the bulk command
  76. //columns definitions
  77. var col_defs bytes.Buffer
  78. for i, col := range b.bulkColumns {
  79. if i != 0 {
  80. col_defs.WriteString(", ")
  81. }
  82. col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti))
  83. }
  84. //options
  85. var with_opts []string
  86. if b.Options.CheckConstraints {
  87. with_opts = append(with_opts, "CHECK_CONSTRAINTS")
  88. }
  89. if b.Options.FireTriggers {
  90. with_opts = append(with_opts, "FIRE_TRIGGERS")
  91. }
  92. if b.Options.KeepNulls {
  93. with_opts = append(with_opts, "KEEP_NULLS")
  94. }
  95. if b.Options.KilobytesPerBatch > 0 {
  96. with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch))
  97. }
  98. if b.Options.RowsPerBatch > 0 {
  99. with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch))
  100. }
  101. if len(b.Options.Order) > 0 {
  102. with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ",")))
  103. }
  104. if b.Options.Tablock {
  105. with_opts = append(with_opts, "TABLOCK")
  106. }
  107. var with_part string
  108. if len(with_opts) > 0 {
  109. with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ","))
  110. }
  111. query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part)
  112. stmt, err := b.cn.PrepareContext(ctx, query)
  113. if err != nil {
  114. return fmt.Errorf("Prepare failed: %s", err.Error())
  115. }
  116. b.dlogf(query)
  117. _, err = stmt.(*Stmt).ExecContext(ctx, nil)
  118. if err != nil {
  119. return err
  120. }
  121. b.headerSent = true
  122. var buf = b.cn.sess.buf
  123. buf.BeginPacket(packBulkLoadBCP, false)
  124. // Send the columns metadata.
  125. columnMetadata := b.createColMetadata()
  126. _, err = buf.Write(columnMetadata)
  127. return
  128. }
  129. // AddRow immediately writes the row to the destination table.
  130. // The arguments are the row values in the order they were specified.
  131. func (b *Bulk) AddRow(row []interface{}) (err error) {
  132. if !b.headerSent {
  133. err = b.sendBulkCommand(b.ctx)
  134. if err != nil {
  135. return
  136. }
  137. }
  138. if len(row) != len(b.bulkColumns) {
  139. return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d",
  140. len(row), len(b.bulkColumns))
  141. }
  142. bytes, err := b.makeRowData(row)
  143. if err != nil {
  144. return
  145. }
  146. _, err = b.cn.sess.buf.Write(bytes)
  147. if err != nil {
  148. return
  149. }
  150. b.numRows = b.numRows + 1
  151. return
  152. }
  153. func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
  154. buf := new(bytes.Buffer)
  155. buf.WriteByte(byte(tokenRow))
  156. var logcol bytes.Buffer
  157. for i, col := range b.bulkColumns {
  158. if b.Debug {
  159. logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i]))
  160. }
  161. param, err := b.makeParam(row[i], col)
  162. if err != nil {
  163. return nil, fmt.Errorf("bulkcopy: %s", err.Error())
  164. }
  165. if col.ti.Writer == nil {
  166. return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x",
  167. col.ColName, col.ti.TypeId)
  168. }
  169. err = col.ti.Writer(buf, param.ti, param.buffer)
  170. if err != nil {
  171. return nil, fmt.Errorf("bulkcopy: %s", err.Error())
  172. }
  173. }
  174. b.dlogf("row[%d] %s\n", b.numRows, logcol.String())
  175. return buf.Bytes(), nil
  176. }
  177. func (b *Bulk) Done() (rowcount int64, err error) {
  178. if b.headerSent == false {
  179. //no rows had been sent
  180. return 0, nil
  181. }
  182. var buf = b.cn.sess.buf
  183. buf.WriteByte(byte(tokenDone))
  184. binary.Write(buf, binary.LittleEndian, uint16(doneFinal))
  185. binary.Write(buf, binary.LittleEndian, uint16(0)) // curcmd
  186. if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
  187. binary.Write(buf, binary.LittleEndian, uint64(0)) //rowcount 0
  188. } else {
  189. binary.Write(buf, binary.LittleEndian, uint32(0)) //rowcount 0
  190. }
  191. buf.FinishPacket()
  192. tokchan := make(chan tokenStruct, 5)
  193. go processResponse(b.ctx, b.cn.sess, tokchan, nil)
  194. var rowCount int64
  195. for token := range tokchan {
  196. switch token := token.(type) {
  197. case doneStruct:
  198. if token.Status&doneCount != 0 {
  199. rowCount = int64(token.RowCount)
  200. }
  201. if token.isError() {
  202. return 0, token.getError()
  203. }
  204. case error:
  205. return 0, b.cn.checkBadConn(token)
  206. }
  207. }
  208. return rowCount, nil
  209. }
// createColMetadata builds the COLMETADATA token payload describing the
// columns being bulk-loaded. The byte order below is the TDS wire format:
// token byte, column count, then per column the usertype, flags, type info,
// (for text/ntext/image) the table name, and finally the column name.
func (b *Bulk) createColMetadata() []byte {
	buf := new(bytes.Buffer)
	buf.WriteByte(byte(tokenColMetadata))                              // token
	binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count

	for i, col := range b.bulkColumns {
		// usertype is 4 bytes on TDS 7.2+, 2 bytes on older versions.
		if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
			binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) // usertype, always 0?
		} else {
			binary.Write(buf, binary.LittleEndian, uint16(col.UserType))
		}
		binary.Write(buf, binary.LittleEndian, uint16(col.Flags))

		writeTypeInfo(buf, &b.bulkColumns[i].ti)

		// Legacy LOB types additionally carry the owning table name as a
		// UCS-2 string preceded by its character count.
		if col.ti.TypeId == typeNText ||
			col.ti.TypeId == typeText ||
			col.ti.TypeId == typeImage {
			tablename_ucs2 := str2ucs2(b.tablename)
			binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2))
			buf.Write(tablename_ucs2)
		}

		// Column name: one-byte character count followed by UCS-2 bytes.
		colname_ucs2 := str2ucs2(col.ColName)
		buf.WriteByte(uint8(len(colname_ucs2) / 2))
		buf.Write(colname_ucs2)
	}

	return buf.Bytes()
}
  235. func (b *Bulk) getMetadata(ctx context.Context) (err error) {
  236. stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON")
  237. if err != nil {
  238. return
  239. }
  240. _, err = stmt.ExecContext(ctx, nil)
  241. if err != nil {
  242. return
  243. }
  244. // Get columns info.
  245. stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
  246. if err != nil {
  247. return
  248. }
  249. rows, err := stmt.QueryContext(ctx, nil)
  250. if err != nil {
  251. return fmt.Errorf("get columns info failed: %v", err)
  252. }
  253. b.metadata = rows.(*Rows).cols
  254. if b.Debug {
  255. for _, col := range b.metadata {
  256. b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n",
  257. col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec,
  258. col.Flags, col.ti.Collation.LcidAndFlags)
  259. }
  260. }
  261. return rows.Close()
  262. }
  263. func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) {
  264. res.ti.Size = col.ti.Size
  265. res.ti.TypeId = col.ti.TypeId
  266. if val == nil {
  267. res.ti.Size = 0
  268. return
  269. }
  270. switch col.ti.TypeId {
  271. case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN:
  272. var intvalue int64
  273. switch val := val.(type) {
  274. case int:
  275. intvalue = int64(val)
  276. case int32:
  277. intvalue = int64(val)
  278. case int64:
  279. intvalue = val
  280. default:
  281. err = fmt.Errorf("mssql: invalid type for int column")
  282. return
  283. }
  284. res.buffer = make([]byte, res.ti.Size)
  285. if col.ti.Size == 1 {
  286. res.buffer[0] = byte(intvalue)
  287. } else if col.ti.Size == 2 {
  288. binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue))
  289. } else if col.ti.Size == 4 {
  290. binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue))
  291. } else if col.ti.Size == 8 {
  292. binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue))
  293. }
  294. case typeFlt4, typeFlt8, typeFltN:
  295. var floatvalue float64
  296. switch val := val.(type) {
  297. case float32:
  298. floatvalue = float64(val)
  299. case float64:
  300. floatvalue = val
  301. case int:
  302. floatvalue = float64(val)
  303. case int64:
  304. floatvalue = float64(val)
  305. default:
  306. err = fmt.Errorf("mssql: invalid type for float column: %s", val)
  307. return
  308. }
  309. if col.ti.Size == 4 {
  310. res.buffer = make([]byte, 4)
  311. binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue)))
  312. } else if col.ti.Size == 8 {
  313. res.buffer = make([]byte, 8)
  314. binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue))
  315. }
  316. case typeNVarChar, typeNText, typeNChar:
  317. switch val := val.(type) {
  318. case string:
  319. res.buffer = str2ucs2(val)
  320. case []byte:
  321. res.buffer = val
  322. default:
  323. err = fmt.Errorf("mssql: invalid type for nvarchar column: %s", val)
  324. return
  325. }
  326. res.ti.Size = len(res.buffer)
  327. case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar:
  328. switch val := val.(type) {
  329. case string:
  330. res.buffer = []byte(val)
  331. case []byte:
  332. res.buffer = val
  333. default:
  334. err = fmt.Errorf("mssql: invalid type for varchar column: %s", val)
  335. return
  336. }
  337. res.ti.Size = len(res.buffer)
  338. case typeBit, typeBitN:
  339. if reflect.TypeOf(val).Kind() != reflect.Bool {
  340. err = fmt.Errorf("mssql: invalid type for bit column: %s", val)
  341. return
  342. }
  343. res.ti.TypeId = typeBitN
  344. res.ti.Size = 1
  345. res.buffer = make([]byte, 1)
  346. if val.(bool) {
  347. res.buffer[0] = 1
  348. }
  349. case typeDateTime2N:
  350. switch val := val.(type) {
  351. case time.Time:
  352. res.buffer = encodeDateTime2(val, int(col.ti.Scale))
  353. res.ti.Size = len(res.buffer)
  354. default:
  355. err = fmt.Errorf("mssql: invalid type for datetime2 column: %s", val)
  356. return
  357. }
  358. case typeDateTimeOffsetN:
  359. switch val := val.(type) {
  360. case time.Time:
  361. res.buffer = encodeDateTimeOffset(val, int(res.ti.Scale))
  362. res.ti.Size = len(res.buffer)
  363. default:
  364. err = fmt.Errorf("mssql: invalid type for datetimeoffset column: %s", val)
  365. return
  366. }
  367. case typeDateN:
  368. switch val := val.(type) {
  369. case time.Time:
  370. res.buffer = encodeDate(val)
  371. res.ti.Size = len(res.buffer)
  372. default:
  373. err = fmt.Errorf("mssql: invalid type for date column: %s", val)
  374. return
  375. }
  376. case typeDateTime, typeDateTimeN, typeDateTim4:
  377. switch val := val.(type) {
  378. case time.Time:
  379. if col.ti.Size == 4 {
  380. res.buffer = encodeDateTim4(val)
  381. res.ti.Size = len(res.buffer)
  382. } else if col.ti.Size == 8 {
  383. res.buffer = encodeDateTime(val)
  384. res.ti.Size = len(res.buffer)
  385. } else {
  386. err = fmt.Errorf("mssql: invalid size of column")
  387. }
  388. default:
  389. err = fmt.Errorf("mssql: invalid type for datetime column: %s", val)
  390. }
  391. // case typeMoney, typeMoney4, typeMoneyN:
  392. case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
  393. var value float64
  394. switch v := val.(type) {
  395. case int:
  396. value = float64(v)
  397. case int8:
  398. value = float64(v)
  399. case int16:
  400. value = float64(v)
  401. case int32:
  402. value = float64(v)
  403. case int64:
  404. value = float64(v)
  405. case float32:
  406. value = float64(v)
  407. case float64:
  408. value = v
  409. case string:
  410. if value, err = strconv.ParseFloat(v, 64); err != nil {
  411. return res, fmt.Errorf("bulk: unable to convert string to float: %v", err)
  412. }
  413. default:
  414. return res, fmt.Errorf("unknown value for decimal: %#v", v)
  415. }
  416. perc := col.ti.Prec
  417. scale := col.ti.Scale
  418. var dec Decimal
  419. dec, err = Float64ToDecimalScale(value, scale)
  420. if err != nil {
  421. return res, err
  422. }
  423. dec.prec = perc
  424. var length byte
  425. switch {
  426. case perc <= 9:
  427. length = 4
  428. case perc <= 19:
  429. length = 8
  430. case perc <= 28:
  431. length = 12
  432. default:
  433. length = 16
  434. }
  435. buf := make([]byte, length+1)
  436. // first byte length written by typeInfo.writer
  437. res.ti.Size = int(length) + 1
  438. // second byte sign
  439. if value < 0 {
  440. buf[0] = 0
  441. } else {
  442. buf[0] = 1
  443. }
  444. ub := dec.UnscaledBytes()
  445. l := len(ub)
  446. if l > int(length) {
  447. err = fmt.Errorf("decimal out of range: %s", dec)
  448. return res, err
  449. }
  450. // reverse the bytes
  451. for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 {
  452. buf[i] = ub[j]
  453. }
  454. res.buffer = buf
  455. case typeBigVarBin:
  456. switch val := val.(type) {
  457. case []byte:
  458. res.ti.Size = len(val)
  459. res.buffer = val
  460. default:
  461. err = fmt.Errorf("mssql: invalid type for Binary column: %s", val)
  462. return
  463. }
  464. case typeGuid:
  465. switch val := val.(type) {
  466. case []byte:
  467. res.ti.Size = len(val)
  468. res.buffer = val
  469. default:
  470. err = fmt.Errorf("mssql: invalid type for Guid column: %s", val)
  471. return
  472. }
  473. default:
  474. err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
  475. }
  476. return
  477. }
  478. func (b *Bulk) dlogf(format string, v ...interface{}) {
  479. if b.Debug {
  480. b.cn.sess.log.Printf(format, v...)
  481. }
  482. }