tablib_sql.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package tablib
  2. import (
  3. "bytes"
  4. "regexp"
  5. "strconv"
  6. "strings"
  7. "time"
  8. )
  // Identifiers of the SQL dialects the exporter can target.
  var (
  	typePostgres = "postgres"
  	typeMySQL    = "mysql"
  )

  // defaults maps "<kind>.<dialect>" to the SQL column type used for that kind
  // of column. "various" is the fallback for columns holding values of more
  // than one type; "numeric" is used for columns holding only integers and
  // floats.
  var defaults = map[string]string{
  	"various." + typePostgres: "TEXT",
  	"various." + typeMySQL:    "VARCHAR(100)",
  	"numeric." + typePostgres: "NUMERIC",
  	"numeric." + typeMySQL:    "DOUBLE",
  }
  16. // columnSQLType determines the type of a column
  17. // if throughout the whole column values have the same type then this type is
  18. // returned, otherwise the VARCHAR/TEXT type is returned.
  19. // numeric types are coerced into DOUBLE/NUMERIC
  20. func (d *Dataset) columnSQLType(header, dbType string) (string, []interface{}) {
  21. types := 0
  22. currentType := ""
  23. maxString := 0
  24. values := d.Column(header)
  25. for _, c := range values {
  26. switch c.(type) {
  27. case uint, uint8, uint16, uint32, uint64,
  28. int, int8, int16, int32, int64,
  29. float32, float64:
  30. if currentType != "numeric" {
  31. currentType = "numeric"
  32. types++
  33. }
  34. case time.Time:
  35. if currentType != "time" {
  36. currentType = "time"
  37. types++
  38. }
  39. case string:
  40. if currentType != "string" {
  41. currentType = "string"
  42. types++
  43. }
  44. if len(c.(string)) > maxString {
  45. maxString = len(c.(string))
  46. }
  47. }
  48. }
  49. if types > 1 {
  50. return defaults["various."+dbType], values
  51. }
  52. switch currentType {
  53. case "numeric":
  54. return defaults["numeric."+dbType], values
  55. case "time":
  56. return "TIMESTAMP", values
  57. default:
  58. if dbType == typePostgres {
  59. return "TEXT", values
  60. }
  61. return "VARCHAR(" + strconv.Itoa(maxString) + ")", values
  62. }
  63. }
  // isStringColumn reports whether the SQL column type c is a character type,
  // i.e. whether it starts with VARCHAR or TEXT.
  func isStringColumn(c string) bool {
  	for _, prefix := range []string{"VARCHAR", "TEXT"} {
  		if strings.HasPrefix(c, prefix) {
  			return true
  		}
  	}
  	return false
  }
  68. // MySQL returns a string representing a suite of MySQL commands
  69. // recreating the Dataset into a table.
  70. func (d *Dataset) MySQL(table string) *Exportable {
  71. return d.sql(table, typeMySQL)
  72. }
  73. // Postgres returns a string representing a suite of Postgres commands
  74. // recreating the Dataset into a table.
  75. func (d *Dataset) Postgres(table string) *Exportable {
  76. return d.sql(table, typePostgres)
  77. }
  78. // sql returns a string representing a suite of SQL commands
  79. // recreating the Dataset into a table.
  80. func (d *Dataset) sql(table, dbType string) *Exportable {
  81. b := newBuffer()
  82. tableSQL, columnTypes, columnValues := d.createTable(table, dbType)
  83. b.WriteString(tableSQL)
  84. reg, _ := regexp.Compile("[']")
  85. // inserts
  86. for i := range d.data {
  87. b.WriteString("INSERT INTO " + table + " VALUES(" + strconv.Itoa(i+1) + ", ")
  88. for j, col := range d.headers {
  89. asStr := d.asString(columnValues[col][i])
  90. if isStringColumn(columnTypes[col]) {
  91. b.WriteString("'" + reg.ReplaceAllString(asStr, "''") + "'")
  92. } else if strings.HasPrefix(columnTypes[col], "TIMESTAMP") {
  93. if dbType == typeMySQL {
  94. b.WriteString("CONVERT_TZ('" + asStr[:10] + " " + asStr[11:19] + "', '" + asStr[len(asStr)-6:] + "', 'SYSTEM')")
  95. } else {
  96. b.WriteString("'" + asStr + "'") // simpler with Postgres
  97. }
  98. } else {
  99. b.WriteString(asStr)
  100. }
  101. if j < len(d.headers)-1 {
  102. b.WriteString(", ")
  103. }
  104. }
  105. b.WriteString(");\n")
  106. }
  107. b.WriteString("\nCOMMIT;\n")
  108. return newExportable(b)
  109. }
  110. func (d *Dataset) createTable(table, dbType string) (string, map[string]string, map[string][]interface{}) {
  111. var b bytes.Buffer
  112. columnValues := make(map[string][]interface{})
  113. columnTypes := make(map[string]string)
  114. // create table
  115. b.WriteString("CREATE TABLE IF NOT EXISTS " + table)
  116. if dbType == typePostgres {
  117. b.WriteString("\n(\n\tid SERIAL PRIMARY KEY,\n")
  118. } else {
  119. b.WriteString("\n(\n\tid INT NOT NULL AUTO_INCREMENT PRIMARY KEY,\n")
  120. }
  121. for i, h := range d.headers {
  122. b.WriteString("\t" + h)
  123. t, v := d.columnSQLType(h, dbType)
  124. columnValues[h] = v
  125. columnTypes[h] = t
  126. b.WriteString(" " + t)
  127. if i < len(d.headers)-1 {
  128. b.WriteString(",")
  129. }
  130. b.WriteString("\n")
  131. }
  132. b.WriteString(");\n\n")
  133. return b.String(), columnTypes, columnValues
  134. }