statement.go 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
import (
	"bytes"
	"database/sql/driver"
	"encoding/json"
	"errors"
	"fmt"
	"reflect"
	"sort"
	"strings"
	"time"

	"github.com/go-xorm/builder"
	"github.com/xormplus/core"
)
  17. type incrParam struct {
  18. colName string
  19. arg interface{}
  20. }
  21. type decrParam struct {
  22. colName string
  23. arg interface{}
  24. }
  25. type exprParam struct {
  26. colName string
  27. expr string
  28. }
  29. // Statement save all the sql info for executing SQL
  30. type Statement struct {
  31. RefTable *core.Table
  32. Engine *Engine
  33. Start int
  34. LimitN int
  35. idParam *core.PK
  36. OrderStr string
  37. JoinStr string
  38. joinArgs []interface{}
  39. GroupByStr string
  40. HavingStr string
  41. ColumnStr string
  42. selectStr string
  43. columnMap map[string]bool
  44. useAllCols bool
  45. OmitStr string
  46. AltTableName string
  47. tableName string
  48. RawSQL string
  49. RawParams []interface{}
  50. UseCascade bool
  51. UseAutoJoin bool
  52. StoreEngine string
  53. Charset string
  54. UseCache bool
  55. UseAutoTime bool
  56. noAutoCondition bool
  57. IsDistinct bool
  58. IsForUpdate bool
  59. TableAlias string
  60. allUseBool bool
  61. checkVersion bool
  62. unscoped bool
  63. mustColumnMap map[string]bool
  64. nullableMap map[string]bool
  65. incrColumns map[string]incrParam
  66. decrColumns map[string]decrParam
  67. exprColumns map[string]exprParam
  68. cond builder.Cond
  69. }
  70. // Init reset all the statement's fields
  71. func (statement *Statement) Init() {
  72. statement.RefTable = nil
  73. statement.Start = 0
  74. statement.LimitN = 0
  75. statement.OrderStr = ""
  76. statement.UseCascade = true
  77. statement.JoinStr = ""
  78. statement.joinArgs = make([]interface{}, 0)
  79. statement.GroupByStr = ""
  80. statement.HavingStr = ""
  81. statement.ColumnStr = ""
  82. statement.OmitStr = ""
  83. statement.columnMap = make(map[string]bool)
  84. statement.AltTableName = ""
  85. statement.tableName = ""
  86. statement.idParam = nil
  87. statement.RawSQL = ""
  88. statement.RawParams = make([]interface{}, 0)
  89. statement.UseCache = true
  90. statement.UseAutoTime = true
  91. statement.noAutoCondition = false
  92. statement.IsDistinct = false
  93. statement.IsForUpdate = false
  94. statement.TableAlias = ""
  95. statement.selectStr = ""
  96. statement.allUseBool = false
  97. statement.useAllCols = false
  98. statement.mustColumnMap = make(map[string]bool)
  99. statement.nullableMap = make(map[string]bool)
  100. statement.checkVersion = true
  101. statement.unscoped = false
  102. statement.incrColumns = make(map[string]incrParam)
  103. statement.decrColumns = make(map[string]decrParam)
  104. statement.exprColumns = make(map[string]exprParam)
  105. statement.cond = builder.NewCond()
  106. }
  107. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  108. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  109. statement.noAutoCondition = true
  110. if len(no) > 0 {
  111. statement.noAutoCondition = no[0]
  112. }
  113. return statement
  114. }
  115. // Alias set the table alias
  116. func (statement *Statement) Alias(alias string) *Statement {
  117. statement.TableAlias = alias
  118. return statement
  119. }
  120. // SQL adds raw sql statement
  121. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  122. switch query.(type) {
  123. case (*builder.Builder):
  124. var err error
  125. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  126. if err != nil {
  127. statement.Engine.logger.Error(err)
  128. }
  129. case string:
  130. statement.RawSQL = query.(string)
  131. statement.RawParams = args
  132. default:
  133. statement.Engine.logger.Error("unsupported sql type")
  134. }
  135. return statement
  136. }
  137. // Where add Where statement
  138. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  139. return statement.And(query, args...)
  140. }
  141. // And add Where & and statement
  142. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  143. switch query.(type) {
  144. case string:
  145. cond := builder.Expr(query.(string), args...)
  146. statement.cond = statement.cond.And(cond)
  147. case builder.Cond:
  148. cond := query.(builder.Cond)
  149. statement.cond = statement.cond.And(cond)
  150. for _, v := range args {
  151. if vv, ok := v.(builder.Cond); ok {
  152. statement.cond = statement.cond.And(vv)
  153. }
  154. }
  155. default:
  156. // TODO: not support condition type
  157. }
  158. return statement
  159. }
  160. // Or add Where & Or statement
  161. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  162. switch query.(type) {
  163. case string:
  164. cond := builder.Expr(query.(string), args...)
  165. statement.cond = statement.cond.Or(cond)
  166. case builder.Cond:
  167. cond := query.(builder.Cond)
  168. statement.cond = statement.cond.Or(cond)
  169. for _, v := range args {
  170. if vv, ok := v.(builder.Cond); ok {
  171. statement.cond = statement.cond.Or(vv)
  172. }
  173. }
  174. default:
  175. // TODO: not support condition type
  176. }
  177. return statement
  178. }
  179. // In generate "Where column IN (?) " statement
  180. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  181. in := builder.In(statement.Engine.Quote(column), args...)
  182. statement.cond = statement.cond.And(in)
  183. return statement
  184. }
  185. // NotIn generate "Where column NOT IN (?) " statement
  186. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  187. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  188. statement.cond = statement.cond.And(notIn)
  189. return statement
  190. }
  191. func (statement *Statement) setRefValue(v reflect.Value) error {
  192. var err error
  193. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  194. if err != nil {
  195. return err
  196. }
  197. statement.tableName = statement.Engine.tbName(v)
  198. return nil
  199. }
  200. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  201. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  202. v := rValue(tableNameOrBean)
  203. t := v.Type()
  204. if t.Kind() == reflect.String {
  205. statement.AltTableName = tableNameOrBean.(string)
  206. } else if t.Kind() == reflect.Struct {
  207. var err error
  208. statement.RefTable, err = statement.Engine.autoMapType(v)
  209. if err != nil {
  210. statement.Engine.logger.Error(err)
  211. return statement
  212. }
  213. statement.AltTableName = statement.Engine.tbName(v)
  214. }
  215. return statement
  216. }
  217. // Auto generating update columnes and values according a struct
  218. func buildUpdates(engine *Engine, table *core.Table, bean interface{},
  219. includeVersion bool, includeUpdated bool, includeNil bool,
  220. includeAutoIncr bool, allUseBool bool, useAllCols bool,
  221. mustColumnMap map[string]bool, nullableMap map[string]bool,
  222. columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) {
  223. var colNames = make([]string, 0)
  224. var args = make([]interface{}, 0)
  225. for _, col := range table.Columns() {
  226. if !includeVersion && col.IsVersion {
  227. continue
  228. }
  229. if col.IsCreated {
  230. continue
  231. }
  232. if !includeUpdated && col.IsUpdated {
  233. continue
  234. }
  235. if !includeAutoIncr && col.IsAutoIncrement {
  236. continue
  237. }
  238. if col.IsDeleted && !unscoped {
  239. continue
  240. }
  241. if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use {
  242. continue
  243. }
  244. fieldValuePtr, err := col.ValueOf(bean)
  245. if err != nil {
  246. engine.logger.Error(err)
  247. continue
  248. }
  249. fieldValue := *fieldValuePtr
  250. fieldType := reflect.TypeOf(fieldValue.Interface())
  251. if fieldType == nil {
  252. continue
  253. }
  254. requiredField := useAllCols
  255. includeNil := useAllCols
  256. if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  257. if b {
  258. requiredField = true
  259. } else {
  260. continue
  261. }
  262. }
  263. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  264. if b, ok := getFlagForColumn(nullableMap, col); ok {
  265. if b && col.Nullable && isZero(fieldValue.Interface()) {
  266. var nilValue *int
  267. fieldValue = reflect.ValueOf(nilValue)
  268. fieldType = reflect.TypeOf(fieldValue.Interface())
  269. includeNil = true
  270. }
  271. }
  272. var val interface{}
  273. if fieldValue.CanAddr() {
  274. if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  275. data, err := structConvert.ToDB()
  276. if err != nil {
  277. engine.logger.Error(err)
  278. } else {
  279. val = data
  280. }
  281. goto APPEND
  282. }
  283. }
  284. if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  285. data, err := structConvert.ToDB()
  286. if err != nil {
  287. engine.logger.Error(err)
  288. } else {
  289. val = data
  290. }
  291. goto APPEND
  292. }
  293. if fieldType.Kind() == reflect.Ptr {
  294. if fieldValue.IsNil() {
  295. if includeNil {
  296. args = append(args, nil)
  297. colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  298. }
  299. continue
  300. } else if !fieldValue.IsValid() {
  301. continue
  302. } else {
  303. // dereference ptr type to instance type
  304. fieldValue = fieldValue.Elem()
  305. fieldType = reflect.TypeOf(fieldValue.Interface())
  306. requiredField = true
  307. }
  308. }
  309. switch fieldType.Kind() {
  310. case reflect.Bool:
  311. if allUseBool || requiredField {
  312. val = fieldValue.Interface()
  313. } else {
  314. // if a bool in a struct, it will not be as a condition because it default is false,
  315. // please use Where() instead
  316. continue
  317. }
  318. case reflect.String:
  319. if !requiredField && fieldValue.String() == "" {
  320. continue
  321. }
  322. // for MyString, should convert to string or panic
  323. if fieldType.String() != reflect.String.String() {
  324. val = fieldValue.String()
  325. } else {
  326. val = fieldValue.Interface()
  327. }
  328. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  329. if !requiredField && fieldValue.Int() == 0 {
  330. continue
  331. }
  332. val = fieldValue.Interface()
  333. case reflect.Float32, reflect.Float64:
  334. if !requiredField && fieldValue.Float() == 0.0 {
  335. continue
  336. }
  337. val = fieldValue.Interface()
  338. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  339. if !requiredField && fieldValue.Uint() == 0 {
  340. continue
  341. }
  342. t := int64(fieldValue.Uint())
  343. val = reflect.ValueOf(&t).Interface()
  344. case reflect.Struct:
  345. if fieldType.ConvertibleTo(core.TimeType) {
  346. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  347. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  348. continue
  349. }
  350. val = engine.formatColTime(col, t)
  351. } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  352. val, _ = nulType.Value()
  353. } else {
  354. if !col.SQLType.IsJson() {
  355. engine.autoMapType(fieldValue)
  356. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  357. if len(table.PrimaryKeys) == 1 {
  358. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  359. // fix non-int pk issues
  360. if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  361. val = pkField.Interface()
  362. } else {
  363. continue
  364. }
  365. } else {
  366. //TODO: how to handler?
  367. panic("not supported")
  368. }
  369. } else {
  370. val = fieldValue.Interface()
  371. }
  372. } else {
  373. // Blank struct could not be as update data
  374. if requiredField || !isStructZero(fieldValue) {
  375. bytes, err := json.Marshal(fieldValue.Interface())
  376. if err != nil {
  377. panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  378. }
  379. if col.SQLType.IsText() {
  380. val = string(bytes)
  381. } else if col.SQLType.IsBlob() {
  382. val = bytes
  383. }
  384. } else {
  385. continue
  386. }
  387. }
  388. }
  389. case reflect.Array, reflect.Slice, reflect.Map:
  390. if !requiredField {
  391. if fieldValue == reflect.Zero(fieldType) {
  392. continue
  393. }
  394. if fieldType.Kind() == reflect.Array {
  395. if isArrayValueZero(fieldValue) {
  396. continue
  397. }
  398. } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  399. continue
  400. }
  401. }
  402. if col.SQLType.IsText() {
  403. bytes, err := json.Marshal(fieldValue.Interface())
  404. if err != nil {
  405. engine.logger.Error(err)
  406. continue
  407. }
  408. val = string(bytes)
  409. } else if col.SQLType.IsBlob() {
  410. var bytes []byte
  411. var err error
  412. if fieldType.Kind() == reflect.Slice &&
  413. fieldType.Elem().Kind() == reflect.Uint8 {
  414. if fieldValue.Len() > 0 {
  415. val = fieldValue.Bytes()
  416. } else {
  417. continue
  418. }
  419. } else if fieldType.Kind() == reflect.Array &&
  420. fieldType.Elem().Kind() == reflect.Uint8 {
  421. val = fieldValue.Slice(0, 0).Interface()
  422. } else {
  423. bytes, err = json.Marshal(fieldValue.Interface())
  424. if err != nil {
  425. engine.logger.Error(err)
  426. continue
  427. }
  428. val = bytes
  429. }
  430. } else {
  431. continue
  432. }
  433. default:
  434. val = fieldValue.Interface()
  435. }
  436. APPEND:
  437. args = append(args, val)
  438. if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  439. continue
  440. }
  441. colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  442. }
  443. return colNames, args
  444. }
  445. func (statement *Statement) needTableName() bool {
  446. return len(statement.JoinStr) > 0
  447. }
  448. func (statement *Statement) colName(col *core.Column, tableName string) string {
  449. if statement.needTableName() {
  450. var nm = tableName
  451. if len(statement.TableAlias) > 0 {
  452. nm = statement.TableAlias
  453. }
  454. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  455. }
  456. return statement.Engine.Quote(col.Name)
  457. }
  458. // TableName return current tableName
  459. func (statement *Statement) TableName() string {
  460. if statement.AltTableName != "" {
  461. return statement.AltTableName
  462. }
  463. return statement.tableName
  464. }
  465. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  466. func (statement *Statement) ID(id interface{}) *Statement {
  467. idValue := reflect.ValueOf(id)
  468. idType := reflect.TypeOf(idValue.Interface())
  469. switch idType {
  470. case ptrPkType:
  471. if pkPtr, ok := (id).(*core.PK); ok {
  472. statement.idParam = pkPtr
  473. return statement
  474. }
  475. case pkType:
  476. if pk, ok := (id).(core.PK); ok {
  477. statement.idParam = &pk
  478. return statement
  479. }
  480. }
  481. switch idType.Kind() {
  482. case reflect.String:
  483. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  484. return statement
  485. }
  486. statement.idParam = &core.PK{id}
  487. return statement
  488. }
  489. // Incr Generate "Update ... Set column = column + arg" statement
  490. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  491. k := strings.ToLower(column)
  492. if len(arg) > 0 {
  493. statement.incrColumns[k] = incrParam{column, arg[0]}
  494. } else {
  495. statement.incrColumns[k] = incrParam{column, 1}
  496. }
  497. return statement
  498. }
  499. // Decr Generate "Update ... Set column = column - arg" statement
  500. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  501. k := strings.ToLower(column)
  502. if len(arg) > 0 {
  503. statement.decrColumns[k] = decrParam{column, arg[0]}
  504. } else {
  505. statement.decrColumns[k] = decrParam{column, 1}
  506. }
  507. return statement
  508. }
  509. // SetExpr Generate "Update ... Set column = {expression}" statement
  510. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  511. k := strings.ToLower(column)
  512. statement.exprColumns[k] = exprParam{column, expression}
  513. return statement
  514. }
  515. // Generate "Update ... Set column = column + arg" statement
  516. func (statement *Statement) getInc() map[string]incrParam {
  517. return statement.incrColumns
  518. }
  519. // Generate "Update ... Set column = column - arg" statement
  520. func (statement *Statement) getDec() map[string]decrParam {
  521. return statement.decrColumns
  522. }
  523. // Generate "Update ... Set column = {expression}" statement
  524. func (statement *Statement) getExpr() map[string]exprParam {
  525. return statement.exprColumns
  526. }
  527. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  528. newColumns := make([]string, 0)
  529. for _, col := range columns {
  530. col = strings.Replace(col, "`", "", -1)
  531. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  532. ccols := strings.Split(col, ",")
  533. for _, c := range ccols {
  534. fields := strings.Split(strings.TrimSpace(c), ".")
  535. if len(fields) == 1 {
  536. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  537. } else if len(fields) == 2 {
  538. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  539. statement.Engine.quote(fields[1]))
  540. } else {
  541. panic(errors.New("unwanted colnames"))
  542. }
  543. }
  544. }
  545. return newColumns
  546. }
  547. func (statement *Statement) colmap2NewColsWithQuote() []string {
  548. newColumns := make([]string, 0, len(statement.columnMap))
  549. for col := range statement.columnMap {
  550. fields := strings.Split(strings.TrimSpace(col), ".")
  551. if len(fields) == 1 {
  552. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  553. } else if len(fields) == 2 {
  554. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  555. statement.Engine.quote(fields[1]))
  556. } else {
  557. panic(errors.New("unwanted colnames"))
  558. }
  559. }
  560. return newColumns
  561. }
  562. // Distinct generates "DISTINCT col1, col2 " statement
  563. func (statement *Statement) Distinct(columns ...string) *Statement {
  564. statement.IsDistinct = true
  565. statement.Cols(columns...)
  566. return statement
  567. }
  568. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  569. func (statement *Statement) ForUpdate() *Statement {
  570. statement.IsForUpdate = true
  571. return statement
  572. }
  573. // Select replace select
  574. func (statement *Statement) Select(str string) *Statement {
  575. statement.selectStr = str
  576. return statement
  577. }
  578. // Cols generate "col1, col2" statement
  579. func (statement *Statement) Cols(columns ...string) *Statement {
  580. cols := col2NewCols(columns...)
  581. for _, nc := range cols {
  582. statement.columnMap[strings.ToLower(nc)] = true
  583. }
  584. newColumns := statement.colmap2NewColsWithQuote()
  585. statement.ColumnStr = strings.Join(newColumns, ", ")
  586. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  587. return statement
  588. }
  589. // AllCols update use only: update all columns
  590. func (statement *Statement) AllCols() *Statement {
  591. statement.useAllCols = true
  592. return statement
  593. }
  594. // MustCols update use only: must update columns
  595. func (statement *Statement) MustCols(columns ...string) *Statement {
  596. newColumns := col2NewCols(columns...)
  597. for _, nc := range newColumns {
  598. statement.mustColumnMap[strings.ToLower(nc)] = true
  599. }
  600. return statement
  601. }
  602. // UseBool indicates that use bool fields as update contents and query contiditions
  603. func (statement *Statement) UseBool(columns ...string) *Statement {
  604. if len(columns) > 0 {
  605. statement.MustCols(columns...)
  606. } else {
  607. statement.allUseBool = true
  608. }
  609. return statement
  610. }
  611. // Omit do not use the columns
  612. func (statement *Statement) Omit(columns ...string) {
  613. newColumns := col2NewCols(columns...)
  614. for _, nc := range newColumns {
  615. statement.columnMap[strings.ToLower(nc)] = false
  616. }
  617. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  618. }
  619. // Nullable Update use only: update columns to null when value is nullable and zero-value
  620. func (statement *Statement) Nullable(columns ...string) {
  621. newColumns := col2NewCols(columns...)
  622. for _, nc := range newColumns {
  623. statement.nullableMap[strings.ToLower(nc)] = true
  624. }
  625. }
  626. // Top generate LIMIT limit statement
  627. func (statement *Statement) Top(limit int) *Statement {
  628. statement.Limit(limit)
  629. return statement
  630. }
  631. // Limit generate LIMIT start, limit statement
  632. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  633. statement.LimitN = limit
  634. if len(start) > 0 {
  635. statement.Start = start[0]
  636. }
  637. return statement
  638. }
  639. // OrderBy generate "Order By order" statement
  640. func (statement *Statement) OrderBy(order string) *Statement {
  641. if len(statement.OrderStr) > 0 {
  642. statement.OrderStr += ", "
  643. }
  644. statement.OrderStr += order
  645. return statement
  646. }
  647. // Desc generate `ORDER BY xx DESC`
  648. func (statement *Statement) Desc(colNames ...string) *Statement {
  649. var buf bytes.Buffer
  650. fmt.Fprintf(&buf, statement.OrderStr)
  651. if len(statement.OrderStr) > 0 {
  652. fmt.Fprint(&buf, ", ")
  653. }
  654. newColNames := statement.col2NewColsWithQuote(colNames...)
  655. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  656. statement.OrderStr = buf.String()
  657. return statement
  658. }
  659. // Asc provide asc order by query condition, the input parameters are columns.
  660. func (statement *Statement) Asc(colNames ...string) *Statement {
  661. var buf bytes.Buffer
  662. fmt.Fprintf(&buf, statement.OrderStr)
  663. if len(statement.OrderStr) > 0 {
  664. fmt.Fprint(&buf, ", ")
  665. }
  666. newColNames := statement.col2NewColsWithQuote(colNames...)
  667. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  668. statement.OrderStr = buf.String()
  669. return statement
  670. }
  671. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  672. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  673. var buf bytes.Buffer
  674. if len(statement.JoinStr) > 0 {
  675. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  676. } else {
  677. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  678. }
  679. switch tablename.(type) {
  680. case []string:
  681. t := tablename.([]string)
  682. if len(t) > 1 {
  683. fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
  684. } else if len(t) == 1 {
  685. fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
  686. }
  687. case []interface{}:
  688. t := tablename.([]interface{})
  689. l := len(t)
  690. var table string
  691. if l > 0 {
  692. f := t[0]
  693. v := rValue(f)
  694. t := v.Type()
  695. if t.Kind() == reflect.String {
  696. table = f.(string)
  697. } else if t.Kind() == reflect.Struct {
  698. table = statement.Engine.tbName(v)
  699. }
  700. }
  701. if l > 1 {
  702. fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
  703. statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
  704. } else if l == 1 {
  705. fmt.Fprintf(&buf, statement.Engine.Quote(table))
  706. }
  707. default:
  708. fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
  709. }
  710. fmt.Fprintf(&buf, " ON %v", condition)
  711. statement.JoinStr = buf.String()
  712. statement.joinArgs = append(statement.joinArgs, args...)
  713. return statement
  714. }
  715. // GroupBy generate "Group By keys" statement
  716. func (statement *Statement) GroupBy(keys string) *Statement {
  717. statement.GroupByStr = keys
  718. return statement
  719. }
  720. // Having generate "Having conditions" statement
  721. func (statement *Statement) Having(conditions string) *Statement {
  722. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  723. return statement
  724. }
  725. // Unscoped always disable struct tag "deleted"
  726. func (statement *Statement) Unscoped() *Statement {
  727. statement.unscoped = true
  728. return statement
  729. }
  730. func (statement *Statement) genColumnStr() string {
  731. var buf bytes.Buffer
  732. if statement.RefTable == nil {
  733. return ""
  734. }
  735. columns := statement.RefTable.Columns()
  736. for _, col := range columns {
  737. if statement.OmitStr != "" {
  738. if _, ok := getFlagForColumn(statement.columnMap, col); ok {
  739. continue
  740. }
  741. }
  742. if col.MapType == core.ONLYTODB {
  743. continue
  744. }
  745. if buf.Len() != 0 {
  746. buf.WriteString(", ")
  747. }
  748. if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
  749. buf.WriteString("id() AS ")
  750. }
  751. if statement.JoinStr != "" {
  752. if statement.TableAlias != "" {
  753. buf.WriteString(statement.TableAlias)
  754. } else {
  755. buf.WriteString(statement.TableName())
  756. }
  757. buf.WriteString(".")
  758. }
  759. statement.Engine.QuoteTo(&buf, col.Name)
  760. }
  761. return buf.String()
  762. }
  763. func (statement *Statement) genCreateTableSQL() string {
  764. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  765. statement.StoreEngine, statement.Charset)
  766. }
  767. func (statement *Statement) genIndexSQL() []string {
  768. var sqls []string
  769. tbName := statement.TableName()
  770. quote := statement.Engine.Quote
  771. for idxName, index := range statement.RefTable.Indexes {
  772. if index.Type == core.IndexType {
  773. sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
  774. quote(tbName), quote(strings.Join(index.Cols, quote(","))))
  775. sqls = append(sqls, sql)
  776. }
  777. }
  778. return sqls
  779. }
  780. func uniqueName(tableName, uqeName string) string {
  781. return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  782. }
  783. func (statement *Statement) genUniqueSQL() []string {
  784. var sqls []string
  785. tbName := statement.TableName()
  786. for _, index := range statement.RefTable.Indexes {
  787. if index.Type == core.UniqueType {
  788. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  789. sqls = append(sqls, sql)
  790. }
  791. }
  792. return sqls
  793. }
  794. func (statement *Statement) genDelIndexSQL() []string {
  795. var sqls []string
  796. tbName := statement.TableName()
  797. for idxName, index := range statement.RefTable.Indexes {
  798. var rIdxName string
  799. if index.Type == core.UniqueType {
  800. rIdxName = uniqueName(tbName, idxName)
  801. } else if index.Type == core.IndexType {
  802. rIdxName = indexName(tbName, idxName)
  803. }
  804. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
  805. if statement.Engine.dialect.IndexOnTable() {
  806. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
  807. }
  808. sqls = append(sqls, sql)
  809. }
  810. return sqls
  811. }
  812. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  813. quote := statement.Engine.Quote
  814. sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(statement.TableName()),
  815. col.String(statement.Engine.dialect))
  816. return sql, []interface{}{}
  817. }
  818. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  819. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  820. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  821. }
  822. func (statement *Statement) mergeConds(bean interface{}) error {
  823. if !statement.noAutoCondition {
  824. var addedTableName = (len(statement.JoinStr) > 0)
  825. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  826. if err != nil {
  827. return err
  828. }
  829. statement.cond = statement.cond.And(autoCond)
  830. }
  831. if err := statement.processIDParam(); err != nil {
  832. return err
  833. }
  834. return nil
  835. }
  836. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  837. if err := statement.mergeConds(bean); err != nil {
  838. return "", nil, err
  839. }
  840. return builder.ToSQL(statement.cond)
  841. }
  842. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  843. v := rValue(bean)
  844. isStruct := v.Kind() == reflect.Struct
  845. if isStruct {
  846. statement.setRefValue(v)
  847. }
  848. var columnStr = statement.ColumnStr
  849. if len(statement.selectStr) > 0 {
  850. columnStr = statement.selectStr
  851. } else {
  852. // TODO: always generate column names, not use * even if join
  853. if len(statement.JoinStr) == 0 {
  854. if len(columnStr) == 0 {
  855. if len(statement.GroupByStr) > 0 {
  856. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  857. } else {
  858. columnStr = statement.genColumnStr()
  859. }
  860. }
  861. } else {
  862. if len(columnStr) == 0 {
  863. if len(statement.GroupByStr) > 0 {
  864. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  865. }
  866. }
  867. }
  868. }
  869. if len(columnStr) == 0 {
  870. columnStr = "*"
  871. }
  872. if isStruct {
  873. if err := statement.mergeConds(bean); err != nil {
  874. return "", nil, err
  875. }
  876. }
  877. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  878. if err != nil {
  879. return "", nil, err
  880. }
  881. sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
  882. if err != nil {
  883. return "", nil, err
  884. }
  885. return sqlStr, append(statement.joinArgs, condArgs...), nil
  886. }
  887. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  888. var condSQL string
  889. var condArgs []interface{}
  890. var err error
  891. if len(beans) > 0 {
  892. statement.setRefValue(rValue(beans[0]))
  893. condSQL, condArgs, err = statement.genConds(beans[0])
  894. } else {
  895. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  896. }
  897. if err != nil {
  898. return "", nil, err
  899. }
  900. var selectSQL = statement.selectStr
  901. if len(selectSQL) <= 0 {
  902. if statement.IsDistinct {
  903. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  904. } else {
  905. selectSQL = "count(*)"
  906. }
  907. }
  908. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
  909. if err != nil {
  910. return "", nil, err
  911. }
  912. return sqlStr, append(statement.joinArgs, condArgs...), nil
  913. }
  914. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  915. statement.setRefValue(rValue(bean))
  916. var sumStrs = make([]string, 0, len(columns))
  917. for _, colName := range columns {
  918. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  919. colName = statement.Engine.Quote(colName)
  920. }
  921. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  922. }
  923. sumSelect := strings.Join(sumStrs, ", ")
  924. condSQL, condArgs, err := statement.genConds(bean)
  925. if err != nil {
  926. return "", nil, err
  927. }
  928. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
  929. if err != nil {
  930. return "", nil, err
  931. }
  932. return sqlStr, append(statement.joinArgs, condArgs...), nil
  933. }
  934. func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
  935. var distinct string
  936. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  937. distinct = "DISTINCT "
  938. }
  939. var dialect = statement.Engine.Dialect()
  940. var quote = statement.Engine.Quote
  941. var top string
  942. var mssqlCondi string
  943. if err := statement.processIDParam(); err != nil {
  944. return "", err
  945. }
  946. var buf bytes.Buffer
  947. if len(condSQL) > 0 {
  948. fmt.Fprintf(&buf, " WHERE %v", condSQL)
  949. }
  950. var whereStr = buf.String()
  951. var fromStr = " FROM "
  952. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  953. fromStr += statement.TableName()
  954. } else {
  955. fromStr += quote(statement.TableName())
  956. }
  957. if statement.TableAlias != "" {
  958. if dialect.DBType() == core.ORACLE {
  959. fromStr += " " + quote(statement.TableAlias)
  960. } else {
  961. fromStr += " AS " + quote(statement.TableAlias)
  962. }
  963. }
  964. if statement.JoinStr != "" {
  965. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  966. }
  967. if dialect.DBType() == core.MSSQL {
  968. if statement.LimitN > 0 {
  969. top = fmt.Sprintf(" TOP %d ", statement.LimitN)
  970. }
  971. if statement.Start > 0 {
  972. var column string
  973. if len(statement.RefTable.PKColumns()) == 0 {
  974. for _, index := range statement.RefTable.Indexes {
  975. if len(index.Cols) == 1 {
  976. column = index.Cols[0]
  977. break
  978. }
  979. }
  980. if len(column) == 0 {
  981. column = statement.RefTable.ColumnsSeq()[0]
  982. }
  983. } else {
  984. column = statement.RefTable.PKColumns()[0].Name
  985. }
  986. if statement.needTableName() {
  987. if len(statement.TableAlias) > 0 {
  988. column = statement.TableAlias + "." + column
  989. } else {
  990. column = statement.TableName() + "." + column
  991. }
  992. }
  993. var orderStr string
  994. if len(statement.OrderStr) > 0 {
  995. orderStr = " ORDER BY " + statement.OrderStr
  996. }
  997. var groupStr string
  998. if len(statement.GroupByStr) > 0 {
  999. groupStr = " GROUP BY " + statement.GroupByStr
  1000. }
  1001. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  1002. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  1003. }
  1004. }
  1005. // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
  1006. a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  1007. if len(mssqlCondi) > 0 {
  1008. if len(whereStr) > 0 {
  1009. a += " AND " + mssqlCondi
  1010. } else {
  1011. a += " WHERE " + mssqlCondi
  1012. }
  1013. }
  1014. if statement.GroupByStr != "" {
  1015. a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
  1016. }
  1017. if statement.HavingStr != "" {
  1018. a = fmt.Sprintf("%v %v", a, statement.HavingStr)
  1019. }
  1020. if statement.OrderStr != "" {
  1021. a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
  1022. }
  1023. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1024. if statement.Start > 0 {
  1025. a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
  1026. } else if statement.LimitN > 0 {
  1027. a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
  1028. }
  1029. } else if dialect.DBType() == core.ORACLE {
  1030. if statement.Start != 0 || statement.LimitN != 0 {
  1031. a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
  1032. }
  1033. }
  1034. if statement.IsForUpdate {
  1035. a = dialect.ForUpdateSql(a)
  1036. }
  1037. return
  1038. }
  1039. func (statement *Statement) processIDParam() error {
  1040. if statement.idParam == nil {
  1041. return nil
  1042. }
  1043. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1044. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1045. len(statement.RefTable.PrimaryKeys),
  1046. len(*statement.idParam),
  1047. )
  1048. }
  1049. for i, col := range statement.RefTable.PKColumns() {
  1050. var colName = statement.colName(col, statement.TableName())
  1051. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1052. }
  1053. return nil
  1054. }
  1055. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1056. var colnames = make([]string, len(cols))
  1057. for i, col := range cols {
  1058. if includeTableName {
  1059. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1060. "." + statement.Engine.Quote(col.Name)
  1061. } else {
  1062. colnames[i] = statement.Engine.Quote(col.Name)
  1063. }
  1064. }
  1065. return strings.Join(colnames, ", ")
  1066. }
  1067. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1068. if statement.RefTable != nil {
  1069. cols := statement.RefTable.PKColumns()
  1070. if len(cols) == 0 {
  1071. return ""
  1072. }
  1073. colstrs := statement.joinColumns(cols, false)
  1074. sqls := splitNNoCase(sqlStr, " from ", 2)
  1075. if len(sqls) != 2 {
  1076. return ""
  1077. }
  1078. var top string
  1079. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1080. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1081. }
  1082. return fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1083. }
  1084. return ""
  1085. }
  1086. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1087. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1088. return "", ""
  1089. }
  1090. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1091. sqls := splitNNoCase(sqlStr, "where", 2)
  1092. if len(sqls) != 2 {
  1093. if len(sqls) == 1 {
  1094. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1095. colstrs, statement.Engine.Quote(statement.TableName()))
  1096. }
  1097. return "", ""
  1098. }
  1099. var whereStr = sqls[1]
  1100. //TODO: for postgres only, if any other database?
  1101. var paraStr string
  1102. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1103. paraStr = "$"
  1104. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1105. paraStr = ":"
  1106. }
  1107. if paraStr != "" {
  1108. if strings.Contains(sqls[1], paraStr) {
  1109. dollers := strings.Split(sqls[1], paraStr)
  1110. whereStr = dollers[0]
  1111. for i, c := range dollers[1:] {
  1112. ccs := strings.SplitN(c, " ", 2)
  1113. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1114. }
  1115. }
  1116. }
  1117. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1118. colstrs, statement.Engine.Quote(statement.TableName()),
  1119. whereStr)
  1120. }