dialect.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package core
  5. import (
  6. "fmt"
  7. "strings"
  8. "time"
  9. )
  // DbType names a database backend. RegisterDialect and QueryDialect key
  // their registry by the lower-cased value of a DbType.
  type DbType string

  // Uri describes one database connection: which backend it targets, where
  // the server lives, which database and credentials to use, and a few
  // connection options. Field order matches the historical declaration.
  type Uri struct {
  	DbType DbType // backend this connection targets; Base.DBType returns it

  	// Network endpoint. NOTE(review): Proto presumably names the transport
  	// (e.g. "tcp") — confirm against the DSN parsers.
  	Proto, Host, Port string

  	// Database to open and the credentials used to open it. DbName is also
  	// used as TABLE_SCHEMA by Base.IsColumnExist.
  	DbName, User, Passwd string

  	// Charset is the fallback used by Base.CreateTableSql when no explicit
  	// charset is passed and the dialect supports charsets.
  	Charset string

  	// NOTE(review): presumably local and remote addresses for the
  	// connection; nothing in this file reads them — confirm with the drivers.
  	Laddr, Raddr string

  	Timeout time.Duration // NOTE(review): assumed to be a connect timeout — confirm
  	Schema  string        // NOTE(review): presumably a schema qualifier (e.g. PostgreSQL) — confirm
  }
  25. // a dialect is a driver's wrapper
  26. type Dialect interface {
  27. SetLogger(logger ILogger)
  28. Init(*DB, *Uri, string, string) error
  29. URI() *Uri
  30. DB() *DB
  31. DBType() DbType
  32. SqlType(*Column) string
  33. FormatBytes(b []byte) string
  34. DriverName() string
  35. DataSourceName() string
  36. IsReserved(string) bool
  37. Quote(string) string
  38. AndStr() string
  39. OrStr() string
  40. EqStr() string
  41. RollBackStr() string
  42. AutoIncrStr() string
  43. SupportInsertMany() bool
  44. SupportEngine() bool
  45. SupportCharset() bool
  46. SupportDropIfExists() bool
  47. IndexOnTable() bool
  48. ShowCreateNull() bool
  49. IndexCheckSql(tableName, idxName string) (string, []interface{})
  50. TableCheckSql(tableName string) (string, []interface{})
  51. IsColumnExist(tableName string, colName string) (bool, error)
  52. CreateTableSql(table *Table, tableName, storeEngine, charset string) string
  53. DropTableSql(tableName string) string
  54. CreateIndexSql(tableName string, index *Index) string
  55. DropIndexSql(tableName string, index *Index) string
  56. ModifyColumnSql(tableName string, col *Column) string
  57. ForUpdateSql(query string) string
  58. // CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error
  59. // MustDropTable(tableName string) error
  60. GetColumns(tableName string) ([]string, map[string]*Column, error)
  61. GetTables() ([]*Table, error)
  62. GetIndexes(tableName string) (map[string]*Index, error)
  63. Filters() []Filter
  64. SetParams(params map[string]string)
  65. }
  66. func OpenDialect(dialect Dialect) (*DB, error) {
  67. return Open(dialect.DriverName(), dialect.DataSourceName())
  68. }
  69. // Base represents a basic dialect and all real dialects could embed this struct
  70. type Base struct {
  71. db *DB
  72. dialect Dialect
  73. driverName string
  74. dataSourceName string
  75. logger ILogger
  76. *Uri
  77. }
  78. func (b *Base) DB() *DB {
  79. return b.db
  80. }
  81. func (b *Base) SetLogger(logger ILogger) {
  82. b.logger = logger
  83. }
  84. func (b *Base) Init(db *DB, dialect Dialect, uri *Uri, drivername, dataSourceName string) error {
  85. b.db, b.dialect, b.Uri = db, dialect, uri
  86. b.driverName, b.dataSourceName = drivername, dataSourceName
  87. return nil
  88. }
  89. func (b *Base) URI() *Uri {
  90. return b.Uri
  91. }
  92. func (b *Base) DBType() DbType {
  93. return b.Uri.DbType
  94. }
  95. func (b *Base) FormatBytes(bs []byte) string {
  96. return fmt.Sprintf("0x%x", bs)
  97. }
  98. func (b *Base) DriverName() string {
  99. return b.driverName
  100. }
  101. func (b *Base) ShowCreateNull() bool {
  102. return true
  103. }
  104. func (b *Base) DataSourceName() string {
  105. return b.dataSourceName
  106. }
  107. func (b *Base) AndStr() string {
  108. return "AND"
  109. }
  110. func (b *Base) OrStr() string {
  111. return "OR"
  112. }
  113. func (b *Base) EqStr() string {
  114. return "="
  115. }
  116. func (db *Base) RollBackStr() string {
  117. return "ROLL BACK"
  118. }
  119. func (db *Base) SupportDropIfExists() bool {
  120. return true
  121. }
  122. func (db *Base) DropTableSql(tableName string) string {
  123. quote := db.dialect.Quote
  124. return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))
  125. }
  126. func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) {
  127. db.LogSQL(query, args)
  128. rows, err := db.DB().Query(query, args...)
  129. if err != nil {
  130. return false, err
  131. }
  132. defer rows.Close()
  133. if rows.Next() {
  134. return true, nil
  135. }
  136. return false, nil
  137. }
  138. func (db *Base) IsColumnExist(tableName, colName string) (bool, error) {
  139. query := fmt.Sprintf(
  140. "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?",
  141. db.dialect.Quote("COLUMN_NAME"),
  142. db.dialect.Quote("INFORMATION_SCHEMA"),
  143. db.dialect.Quote("COLUMNS"),
  144. db.dialect.Quote("TABLE_SCHEMA"),
  145. db.dialect.Quote("TABLE_NAME"),
  146. db.dialect.Quote("COLUMN_NAME"),
  147. )
  148. return db.HasRecords(query, db.DbName, tableName, colName)
  149. }
  150. /*
  151. func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error {
  152. sql, args := db.dialect.TableCheckSql(tableName)
  153. rows, err := db.DB().Query(sql, args...)
  154. if db.Logger != nil {
  155. db.Logger.Info("[sql]", sql, args)
  156. }
  157. if err != nil {
  158. return err
  159. }
  160. defer rows.Close()
  161. if rows.Next() {
  162. return nil
  163. }
  164. sql = db.dialect.CreateTableSql(table, tableName, storeEngine, charset)
  165. _, err = db.DB().Exec(sql)
  166. if db.Logger != nil {
  167. db.Logger.Info("[sql]", sql)
  168. }
  169. return err
  170. }*/
  171. func (db *Base) CreateIndexSql(tableName string, index *Index) string {
  172. quote := db.dialect.Quote
  173. var unique string
  174. var idxName string
  175. if index.Type == UniqueType {
  176. unique = " UNIQUE"
  177. }
  178. idxName = index.XName(tableName)
  179. return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
  180. quote(idxName), quote(tableName),
  181. quote(strings.Join(index.Cols, quote(","))))
  182. }
  183. func (db *Base) DropIndexSql(tableName string, index *Index) string {
  184. quote := db.dialect.Quote
  185. var name string
  186. if index.IsRegular {
  187. name = index.XName(tableName)
  188. } else {
  189. name = index.Name
  190. }
  191. return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
  192. }
  193. func (db *Base) ModifyColumnSql(tableName string, col *Column) string {
  194. return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, col.StringNoPk(db.dialect))
  195. }
  196. func (b *Base) CreateTableSql(table *Table, tableName, storeEngine, charset string) string {
  197. var sql string
  198. sql = "CREATE TABLE IF NOT EXISTS "
  199. if tableName == "" {
  200. tableName = table.Name
  201. }
  202. sql += b.dialect.Quote(tableName)
  203. sql += " ("
  204. if len(table.ColumnsSeq()) > 0 {
  205. pkList := table.PrimaryKeys
  206. for _, colName := range table.ColumnsSeq() {
  207. col := table.GetColumn(colName)
  208. if col.IsPrimaryKey && len(pkList) == 1 {
  209. sql += col.String(b.dialect)
  210. } else {
  211. sql += col.StringNoPk(b.dialect)
  212. }
  213. sql = strings.TrimSpace(sql)
  214. if b.DriverName() == MYSQL && len(col.Comment) > 0 {
  215. sql += " COMMENT '" + col.Comment + "'"
  216. }
  217. sql += ", "
  218. }
  219. if len(pkList) > 1 {
  220. sql += "PRIMARY KEY ( "
  221. sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(",")))
  222. sql += " ), "
  223. }
  224. sql = sql[:len(sql)-2]
  225. }
  226. sql += ")"
  227. if b.dialect.SupportEngine() && storeEngine != "" {
  228. sql += " ENGINE=" + storeEngine
  229. }
  230. if b.dialect.SupportCharset() {
  231. if len(charset) == 0 {
  232. charset = b.dialect.URI().Charset
  233. }
  234. if len(charset) > 0 {
  235. sql += " DEFAULT CHARSET " + charset
  236. }
  237. }
  238. return sql
  239. }
  240. func (b *Base) ForUpdateSql(query string) string {
  241. return query + " FOR UPDATE"
  242. }
  243. func (b *Base) LogSQL(sql string, args []interface{}) {
  244. if b.logger != nil && b.logger.IsShowSQL() {
  245. if len(args) > 0 {
  246. b.logger.Infof("[SQL] %v %v", sql, args)
  247. } else {
  248. b.logger.Infof("[SQL] %v", sql)
  249. }
  250. }
  251. }
  252. func (b *Base) SetParams(params map[string]string) {
  253. }
  254. var (
  255. dialects = map[string]func() Dialect{}
  256. )
  257. // RegisterDialect register database dialect
  258. func RegisterDialect(dbName DbType, dialectFunc func() Dialect) {
  259. if dialectFunc == nil {
  260. panic("core: Register dialect is nil")
  261. }
  262. dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect
  263. }
  264. // QueryDialect query if registered database dialect
  265. func QueryDialect(dbName DbType) Dialect {
  266. if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
  267. return d()
  268. }
  269. return nil
  270. }