statement.go 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql/driver"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "reflect"
  11. "strings"
  12. "time"
  13. "github.com/xormplus/builder"
  14. "github.com/xormplus/core"
  15. )
// Statement accumulates all the SQL-related state (table, columns,
// conditions, ordering, limits, raw SQL, flags) that the builder methods on
// this type collect before a query is executed.
type Statement struct {
	// RefTable is the mapped table metadata; assigned by setRefValue,
	// setRefBean and Table.
	RefTable *core.Table
	// Engine is the owning engine, used for quoting, logging, dialect access
	// and table-name resolution.
	Engine *Engine
	// Start is the offset set by the optional second argument of Limit.
	Start int
	// LimitN is the row limit set by Limit.
	LimitN int
	// idParam is the primary key value(s) recorded by ID.
	idParam *core.PK
	// OrderStr is the ORDER BY expression built by OrderBy, Desc and Asc.
	OrderStr string
	// JoinStr is the JOIN clause text built by Join.
	JoinStr string
	// joinArgs are the bind arguments collected by Join.
	joinArgs []interface{}
	// GroupByStr is the GROUP BY expression set by GroupBy.
	GroupByStr string
	// HavingStr is the HAVING clause set by Having (stored with the
	// "HAVING " keyword prefix).
	HavingStr string
	// ColumnStr is the quoted, comma-separated column list built by Cols.
	ColumnStr string
	// selectStr is the custom select expression set by Select.
	selectStr string
	// useAllCols is switched on by AllCols.
	useAllCols bool
	// OmitStr is the quoted list of omitted columns built by Omit.
	OmitStr string
	// AltTableName is the table name set by Table; when non-empty it is
	// what TableName returns.
	AltTableName string
	// tableName is the table name derived from the bean by setRefValue and
	// setRefBean.
	tableName string
	// RawSQL and RawParams hold the raw statement and arguments set by SQL.
	RawSQL    string
	RawParams []interface{}
	// UseCascade is reset to true by Init.
	UseCascade bool
	// UseAutoJoin is not touched by Init; its consumers are outside this
	// part of the file.
	UseAutoJoin bool
	// StoreEngine and Charset are forwarded to the dialect by
	// genCreateTableSQL.
	StoreEngine string
	Charset     string
	// UseCache is reset to true by Init.
	UseCache bool
	// UseAutoTime is reset to true by Init.
	UseAutoTime bool
	// noAutoCondition is set by NoAutoCondition; when true mergeConds does
	// not derive conditions from the bean.
	noAutoCondition bool
	// IsDistinct is set by Distinct.
	IsDistinct bool
	// IsForUpdate is set by ForUpdate.
	IsForUpdate bool
	// TableAlias is the alias set by Alias; it replaces the table name as
	// column qualifier when a join is present.
	TableAlias string
	// allUseBool is set by UseBool when called without column names.
	allUseBool bool
	// checkVersion is reset to true by Init.
	checkVersion bool
	// unscoped is set by Unscoped.
	unscoped bool
	// columnMap holds the columns selected through Cols.
	columnMap columnMap
	// omitColumnMap holds the columns excluded through Omit.
	omitColumnMap columnMap
	// mustColumnMap is keyed by lower-cased column name; filled by MustCols.
	mustColumnMap map[string]bool
	// nullableMap is keyed by lower-cased column name; filled by Nullable.
	nullableMap map[string]bool
	// incrColumns, decrColumns and exprColumns are keyed by lower-cased
	// column name; filled by Incr, Decr and SetExpr respectively.
	incrColumns map[string]incrParam
	decrColumns map[string]decrParam
	exprColumns map[string]exprParam
	// cond is the accumulated condition built by Where, And, Or, In and
	// NotIn.
	cond builder.Cond
	// bufferSize is reset to 0 by Init; its consumers are outside this part
	// of the file.
	bufferSize int
	// context is reset to nil by Init; its consumers are outside this part
	// of the file.
	context ContextCache
}
  60. // Init reset all the statement's fields
  61. func (statement *Statement) Init() {
  62. statement.RefTable = nil
  63. statement.Start = 0
  64. statement.LimitN = 0
  65. statement.OrderStr = ""
  66. statement.UseCascade = true
  67. statement.JoinStr = ""
  68. statement.joinArgs = make([]interface{}, 0)
  69. statement.GroupByStr = ""
  70. statement.HavingStr = ""
  71. statement.ColumnStr = ""
  72. statement.OmitStr = ""
  73. statement.columnMap = columnMap{}
  74. statement.omitColumnMap = columnMap{}
  75. statement.AltTableName = ""
  76. statement.tableName = ""
  77. statement.idParam = nil
  78. statement.RawSQL = ""
  79. statement.RawParams = make([]interface{}, 0)
  80. statement.UseCache = true
  81. statement.UseAutoTime = true
  82. statement.noAutoCondition = false
  83. statement.IsDistinct = false
  84. statement.IsForUpdate = false
  85. statement.TableAlias = ""
  86. statement.selectStr = ""
  87. statement.allUseBool = false
  88. statement.useAllCols = false
  89. statement.mustColumnMap = make(map[string]bool)
  90. statement.nullableMap = make(map[string]bool)
  91. statement.checkVersion = true
  92. statement.unscoped = false
  93. statement.incrColumns = make(map[string]incrParam)
  94. statement.decrColumns = make(map[string]decrParam)
  95. statement.exprColumns = make(map[string]exprParam)
  96. statement.cond = builder.NewCond()
  97. statement.bufferSize = 0
  98. statement.context = nil
  99. }
  100. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  101. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  102. statement.noAutoCondition = true
  103. if len(no) > 0 {
  104. statement.noAutoCondition = no[0]
  105. }
  106. return statement
  107. }
  108. // Alias set the table alias
  109. func (statement *Statement) Alias(alias string) *Statement {
  110. statement.TableAlias = alias
  111. return statement
  112. }
  113. // SQL adds raw sql statement
  114. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  115. switch query.(type) {
  116. case (*builder.Builder):
  117. var err error
  118. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  119. if err != nil {
  120. statement.Engine.logger.Error(err)
  121. }
  122. case string:
  123. statement.RawSQL = query.(string)
  124. statement.RawParams = args
  125. default:
  126. statement.Engine.logger.Error("unsupported sql type")
  127. }
  128. return statement
  129. }
  130. // Where add Where statement
  131. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  132. return statement.And(query, args...)
  133. }
  134. // And add Where & and statement
  135. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  136. switch query.(type) {
  137. case string:
  138. isExpr := false
  139. var cargs []interface{}
  140. for i, _ := range args {
  141. if _, ok := args[i].(sqlExpr); ok {
  142. isExpr = true
  143. }
  144. cargs = append(cargs, args[i])
  145. }
  146. if isExpr {
  147. sqlStr, _ := ConvertToBoundSQL(query.(string), cargs)
  148. cond := builder.Expr(sqlStr)
  149. statement.cond = statement.cond.And(cond)
  150. } else {
  151. cond := builder.Expr(query.(string), args...)
  152. statement.cond = statement.cond.And(cond)
  153. }
  154. case map[string]interface{}:
  155. cond := builder.Eq(query.(map[string]interface{}))
  156. statement.cond = statement.cond.And(cond)
  157. case builder.Cond:
  158. cond := query.(builder.Cond)
  159. statement.cond = statement.cond.And(cond)
  160. for _, v := range args {
  161. if vv, ok := v.(builder.Cond); ok {
  162. statement.cond = statement.cond.And(vv)
  163. }
  164. }
  165. default:
  166. // TODO: not support condition type
  167. }
  168. return statement
  169. }
  170. // Or add Where & Or statement
  171. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  172. switch query.(type) {
  173. case string:
  174. isExpr := false
  175. var cargs []interface{}
  176. for i, _ := range args {
  177. if _, ok := args[i].(sqlExpr); ok {
  178. isExpr = true
  179. }
  180. cargs = append(cargs, args[i])
  181. }
  182. if isExpr {
  183. sqlStr, _ := ConvertToBoundSQL(query.(string), cargs)
  184. cond := builder.Expr(sqlStr)
  185. statement.cond = statement.cond.Or(cond)
  186. } else {
  187. cond := builder.Expr(query.(string), args...)
  188. statement.cond = statement.cond.Or(cond)
  189. }
  190. case map[string]interface{}:
  191. cond := builder.Eq(query.(map[string]interface{}))
  192. statement.cond = statement.cond.Or(cond)
  193. case builder.Cond:
  194. cond := query.(builder.Cond)
  195. statement.cond = statement.cond.Or(cond)
  196. for _, v := range args {
  197. if vv, ok := v.(builder.Cond); ok {
  198. statement.cond = statement.cond.Or(vv)
  199. }
  200. }
  201. default:
  202. // TODO: not support condition type
  203. }
  204. return statement
  205. }
  206. // In generate "Where column IN (?) " statement
  207. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  208. in := builder.In(statement.Engine.Quote(column), args...)
  209. statement.cond = statement.cond.And(in)
  210. return statement
  211. }
  212. // NotIn generate "Where column NOT IN (?) " statement
  213. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  214. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  215. statement.cond = statement.cond.And(notIn)
  216. return statement
  217. }
  218. func (statement *Statement) setRefValue(v reflect.Value) error {
  219. var err error
  220. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  221. if err != nil {
  222. return err
  223. }
  224. statement.tableName = statement.Engine.TableName(v, true)
  225. return nil
  226. }
  227. func (statement *Statement) setRefBean(bean interface{}) error {
  228. var err error
  229. statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
  230. if err != nil {
  231. return err
  232. }
  233. statement.tableName = statement.Engine.TableName(bean, true)
  234. return nil
  235. }
// buildUpdates walks the columns of statement.RefTable and derives, from the
// corresponding fields of bean, the `col = ?` assignment fragments and the
// matching bind arguments for an UPDATE.
//
// A column is skipped when it is a version column and includeVersion is
// false, is a created column, is an updated column and includeUpdated is
// false, is auto-increment and includeAutoIncr is false, is a deleted column
// and the statement is not unscoped, is listed in omitColumnMap, is missing
// from a non-empty columnMap, or is mapped ONLYFROMDB. Beyond that, zero
// values are skipped unless the column is required (useAllCols, a true entry
// in mustColumnMap, or a non-nil pointer field).
//
// NOTE(review): the includeNil parameter is shadowed inside the loop by a
// local initialised from useAllCols, so the caller's value is never read;
// the update parameter is not referenced anywhere in this function.
//
// It returns the assignment fragments and their arguments.
func (statement *Statement) buildUpdates(bean interface{},
	includeVersion, includeUpdated, includeNil,
	includeAutoIncr, update bool) ([]string, []interface{}) {
	// Snapshot the statement settings consulted for every column.
	engine := statement.Engine
	table := statement.RefTable
	allUseBool := statement.allUseBool
	useAllCols := statement.useAllCols
	mustColumnMap := statement.mustColumnMap
	nullableMap := statement.nullableMap
	columnMap := statement.columnMap
	omitColumnMap := statement.omitColumnMap
	unscoped := statement.unscoped

	var colNames = make([]string, 0)
	var args = make([]interface{}, 0)
	for _, col := range table.Columns() {
		// Column-level filters that do not depend on the bean's value.
		if !includeVersion && col.IsVersion {
			continue
		}
		if col.IsCreated {
			continue
		}
		if !includeUpdated && col.IsUpdated {
			continue
		}
		if !includeAutoIncr && col.IsAutoIncrement {
			continue
		}
		if col.IsDeleted && !unscoped {
			continue
		}
		if omitColumnMap.contain(col.Name) {
			continue
		}
		// An explicit column selection restricts the update to those columns.
		if len(columnMap) > 0 && !columnMap.contain(col.Name) {
			continue
		}

		if col.MapType == core.ONLYFROMDB {
			continue
		}

		fieldValuePtr, err := col.ValueOf(bean)
		if err != nil {
			// A column whose field cannot be resolved is logged and skipped.
			engine.logger.Error(err)
			continue
		}

		fieldValue := *fieldValuePtr
		fieldType := reflect.TypeOf(fieldValue.Interface())
		if fieldType == nil {
			// Nil interface value: there is no dynamic type to inspect.
			continue
		}

		requiredField := useAllCols
		// Shadows the includeNil parameter for the rest of this iteration.
		includeNil := useAllCols

		// An entry in mustColumnMap forces the column in (true) or out (false).
		if b, ok := getFlagForColumn(mustColumnMap, col); ok {
			if b {
				requiredField = true
			} else {
				continue
			}
		}

		// !evalphobia! set fieldValue as nil when column is nullable and zero-value
		if b, ok := getFlagForColumn(nullableMap, col); ok {
			if b && col.Nullable && isZero(fieldValue.Interface()) {
				// A typed nil pointer stands in for the value so that the
				// pointer branch below emits NULL for this column.
				var nilValue *int
				fieldValue = reflect.ValueOf(nilValue)
				fieldType = reflect.TypeOf(fieldValue.Interface())
				includeNil = true
			}
		}

		var val interface{}

		// Fields implementing core.Conversion (via pointer or value receiver)
		// supply their own database representation. On a ToDB error the
		// error is logged and val stays nil, but control still jumps to
		// APPEND, so the column is emitted with a nil argument.
		if fieldValue.CanAddr() {
			if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
				data, err := structConvert.ToDB()
				if err != nil {
					engine.logger.Error(err)
				} else {
					val = data
				}
				goto APPEND
			}
		}

		if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
			data, err := structConvert.ToDB()
			if err != nil {
				engine.logger.Error(err)
			} else {
				val = data
			}
			goto APPEND
		}

		if fieldType.Kind() == reflect.Ptr {
			if fieldValue.IsNil() {
				// Nil pointers become an explicit NULL assignment only when
				// includeNil is set; note this fragment is formatted "col=?"
				// without spaces, unlike the one at APPEND.
				if includeNil {
					args = append(args, nil)
					colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
				}
				continue
			} else if !fieldValue.IsValid() {
				continue
			} else {
				// dereference ptr type to instance type
				fieldValue = fieldValue.Elem()
				fieldType = reflect.TypeOf(fieldValue.Interface())
				// A non-nil pointer is always written, even if it points at
				// a zero value.
				requiredField = true
			}
		}

		switch fieldType.Kind() {
		case reflect.Bool:
			if allUseBool || requiredField {
				val = fieldValue.Interface()
			} else {
				// if a bool in a struct, it will not be as a condition because it default is false,
				// please use Where() instead
				continue
			}
		case reflect.String:
			if !requiredField && fieldValue.String() == "" {
				continue
			}
			// for MyString, should convert to string or panic
			if fieldType.String() != reflect.String.String() {
				val = fieldValue.String()
			} else {
				val = fieldValue.Interface()
			}
		case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
			if !requiredField && fieldValue.Int() == 0 {
				continue
			}
			val = fieldValue.Interface()
		case reflect.Float32, reflect.Float64:
			if !requiredField && fieldValue.Float() == 0.0 {
				continue
			}
			val = fieldValue.Interface()
		case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
			if !requiredField && fieldValue.Uint() == 0 {
				continue
			}
			// Unsigned values are passed on as a pointer to their int64
			// conversion.
			t := int64(fieldValue.Uint())
			val = reflect.ValueOf(&t).Interface()
		case reflect.Struct:
			if fieldType.ConvertibleTo(core.TimeType) {
				// Time-like structs are formatted by the engine for this column.
				t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
				if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
					continue
				}
				val = engine.formatColTime(col, t)
			} else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
				// The error returned by Value is discarded here.
				val, _ = nulType.Value()
			} else {
				if !col.SQLType.IsJson() {
					// Non-JSON struct: if it maps to a table with a single
					// primary key, the primary key value is used.
					engine.autoMapType(fieldValue)
					if table, ok := engine.Tables[fieldValue.Type()]; ok {
						if len(table.PrimaryKeys) == 1 {
							pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
							// fix non-int pk issues
							// NOTE(review): as written, the primary key is
							// only used when requiredField is false and the
							// key is non-zero; a required column is skipped
							// here. Confirm this is intended.
							if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
								val = pkField.Interface()
							} else {
								continue
							}
						} else {
							//TODO: how to handler?
							panic("not supported")
						}
					} else {
						val = fieldValue.Interface()
					}
				} else {
					// Blank struct could not be as update data
					if requiredField || !isStructZero(fieldValue) {
						bytes, err := json.Marshal(fieldValue.Interface())
						if err != nil {
							panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
						}
						// val is left nil when the column is neither text
						// nor blob.
						if col.SQLType.IsText() {
							val = string(bytes)
						} else if col.SQLType.IsBlob() {
							val = bytes
						}
					} else {
						continue
					}
				}
			}
		case reflect.Array, reflect.Slice, reflect.Map:
			if !requiredField {
				// NOTE(review): this compares two reflect.Value structs, not
				// the underlying data — confirm it can ever be true.
				if fieldValue == reflect.Zero(fieldType) {
					continue
				}
				if fieldType.Kind() == reflect.Array {
					if isArrayValueZero(fieldValue) {
						continue
					}
				} else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
					continue
				}
			}

			if col.SQLType.IsText() {
				// Text columns store the JSON encoding as a string.
				bytes, err := json.Marshal(fieldValue.Interface())
				if err != nil {
					engine.logger.Error(err)
					continue
				}
				val = string(bytes)
			} else if col.SQLType.IsBlob() {
				var bytes []byte
				var err error
				if fieldType.Kind() == reflect.Slice &&
					fieldType.Elem().Kind() == reflect.Uint8 {
					// []byte is passed through; an empty one is skipped even
					// when the column is required.
					if fieldValue.Len() > 0 {
						val = fieldValue.Bytes()
					} else {
						continue
					}
				} else if fieldType.Kind() == reflect.Array &&
					fieldType.Elem().Kind() == reflect.Uint8 {
					// NOTE(review): Slice(0, 0) yields an empty slice, and
					// reflect requires an array to be addressable for Slice —
					// confirm this is the intended value.
					val = fieldValue.Slice(0, 0).Interface()
				} else {
					// Other containers are stored as their JSON encoding.
					bytes, err = json.Marshal(fieldValue.Interface())
					if err != nil {
						engine.logger.Error(err)
						continue
					}
					val = bytes
				}
			} else {
				// Containers are only supported on text and blob columns.
				continue
			}
		default:
			val = fieldValue.Interface()
		}

	APPEND:
		args = append(args, val)
		// NOTE(review): val is appended to args before this check, so for a
		// primary key on the "ql" dialect args gains an entry that has no
		// matching fragment in colNames — confirm this is intended.
		if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
			continue
		}
		colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
	}

	return colNames, args
}
  477. func (statement *Statement) needTableName() bool {
  478. return len(statement.JoinStr) > 0
  479. }
  480. func (statement *Statement) colName(col *core.Column, tableName string) string {
  481. if statement.needTableName() {
  482. var nm = tableName
  483. if len(statement.TableAlias) > 0 {
  484. nm = statement.TableAlias
  485. }
  486. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  487. }
  488. return statement.Engine.Quote(col.Name)
  489. }
  490. // TableName return current tableName
  491. func (statement *Statement) TableName() string {
  492. if statement.AltTableName != "" {
  493. return statement.AltTableName
  494. }
  495. return statement.tableName
  496. }
  497. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  498. func (statement *Statement) ID(id interface{}) *Statement {
  499. idValue := reflect.ValueOf(id)
  500. idType := reflect.TypeOf(idValue.Interface())
  501. switch idType {
  502. case ptrPkType:
  503. if pkPtr, ok := (id).(*core.PK); ok {
  504. statement.idParam = pkPtr
  505. return statement
  506. }
  507. case pkType:
  508. if pk, ok := (id).(core.PK); ok {
  509. statement.idParam = &pk
  510. return statement
  511. }
  512. }
  513. switch idType.Kind() {
  514. case reflect.String:
  515. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  516. return statement
  517. }
  518. statement.idParam = &core.PK{id}
  519. return statement
  520. }
  521. // Incr Generate "Update ... Set column = column + arg" statement
  522. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  523. k := strings.ToLower(column)
  524. if len(arg) > 0 {
  525. statement.incrColumns[k] = incrParam{column, arg[0]}
  526. } else {
  527. statement.incrColumns[k] = incrParam{column, 1}
  528. }
  529. return statement
  530. }
  531. // Decr Generate "Update ... Set column = column - arg" statement
  532. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  533. k := strings.ToLower(column)
  534. if len(arg) > 0 {
  535. statement.decrColumns[k] = decrParam{column, arg[0]}
  536. } else {
  537. statement.decrColumns[k] = decrParam{column, 1}
  538. }
  539. return statement
  540. }
  541. // SetExpr Generate "Update ... Set column = {expression}" statement
  542. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  543. k := strings.ToLower(column)
  544. statement.exprColumns[k] = exprParam{column, expression}
  545. return statement
  546. }
  547. // Generate "Update ... Set column = column + arg" statement
  548. func (statement *Statement) getInc() map[string]incrParam {
  549. return statement.incrColumns
  550. }
  551. // Generate "Update ... Set column = column - arg" statement
  552. func (statement *Statement) getDec() map[string]decrParam {
  553. return statement.decrColumns
  554. }
  555. // Generate "Update ... Set column = {expression}" statement
  556. func (statement *Statement) getExpr() map[string]exprParam {
  557. return statement.exprColumns
  558. }
  559. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  560. newColumns := make([]string, 0)
  561. for _, col := range columns {
  562. col = strings.Replace(col, "`", "", -1)
  563. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  564. ccols := strings.Split(col, ",")
  565. for _, c := range ccols {
  566. fields := strings.Split(strings.TrimSpace(c), ".")
  567. if len(fields) == 1 {
  568. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  569. } else if len(fields) == 2 {
  570. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  571. statement.Engine.quote(fields[1]))
  572. } else {
  573. panic(errors.New("unwanted colnames"))
  574. }
  575. }
  576. }
  577. return newColumns
  578. }
  579. func (statement *Statement) colmap2NewColsWithQuote() []string {
  580. newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
  581. copy(newColumns, statement.columnMap)
  582. for i := 0; i < len(statement.columnMap); i++ {
  583. newColumns[i] = statement.Engine.Quote(newColumns[i])
  584. }
  585. return newColumns
  586. }
  587. // Distinct generates "DISTINCT col1, col2 " statement
  588. func (statement *Statement) Distinct(columns ...string) *Statement {
  589. statement.IsDistinct = true
  590. statement.Cols(columns...)
  591. return statement
  592. }
  593. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  594. func (statement *Statement) ForUpdate() *Statement {
  595. statement.IsForUpdate = true
  596. return statement
  597. }
  598. // Select replace select
  599. func (statement *Statement) Select(str string) *Statement {
  600. statement.selectStr = str
  601. return statement
  602. }
  603. // Cols generate "col1, col2" statement
  604. func (statement *Statement) Cols(columns ...string) *Statement {
  605. cols := col2NewCols(columns...)
  606. for _, nc := range cols {
  607. statement.columnMap.add(nc)
  608. }
  609. newColumns := statement.colmap2NewColsWithQuote()
  610. statement.ColumnStr = strings.Join(newColumns, ", ")
  611. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  612. return statement
  613. }
  614. // AllCols update use only: update all columns
  615. func (statement *Statement) AllCols() *Statement {
  616. statement.useAllCols = true
  617. return statement
  618. }
  619. // MustCols update use only: must update columns
  620. func (statement *Statement) MustCols(columns ...string) *Statement {
  621. newColumns := col2NewCols(columns...)
  622. for _, nc := range newColumns {
  623. statement.mustColumnMap[strings.ToLower(nc)] = true
  624. }
  625. return statement
  626. }
  627. // UseBool indicates that use bool fields as update contents and query contiditions
  628. func (statement *Statement) UseBool(columns ...string) *Statement {
  629. if len(columns) > 0 {
  630. statement.MustCols(columns...)
  631. } else {
  632. statement.allUseBool = true
  633. }
  634. return statement
  635. }
  636. // Omit do not use the columns
  637. func (statement *Statement) Omit(columns ...string) {
  638. newColumns := col2NewCols(columns...)
  639. for _, nc := range newColumns {
  640. statement.omitColumnMap = append(statement.omitColumnMap, nc)
  641. }
  642. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  643. }
  644. // Nullable Update use only: update columns to null when value is nullable and zero-value
  645. func (statement *Statement) Nullable(columns ...string) {
  646. newColumns := col2NewCols(columns...)
  647. for _, nc := range newColumns {
  648. statement.nullableMap[strings.ToLower(nc)] = true
  649. }
  650. }
  651. // Top generate LIMIT limit statement
  652. func (statement *Statement) Top(limit int) *Statement {
  653. statement.Limit(limit)
  654. return statement
  655. }
  656. // Limit generate LIMIT start, limit statement
  657. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  658. statement.LimitN = limit
  659. if len(start) > 0 {
  660. statement.Start = start[0]
  661. }
  662. return statement
  663. }
  664. // OrderBy generate "Order By order" statement
  665. func (statement *Statement) OrderBy(order string) *Statement {
  666. if len(statement.OrderStr) > 0 {
  667. statement.OrderStr += ", "
  668. }
  669. statement.OrderStr += order
  670. return statement
  671. }
  672. // Desc generate `ORDER BY xx DESC`
  673. func (statement *Statement) Desc(colNames ...string) *Statement {
  674. var buf builder.StringBuilder
  675. if len(statement.OrderStr) > 0 {
  676. fmt.Fprint(&buf, statement.OrderStr, ", ")
  677. }
  678. newColNames := statement.col2NewColsWithQuote(colNames...)
  679. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  680. statement.OrderStr = buf.String()
  681. return statement
  682. }
  683. // Asc provide asc order by query condition, the input parameters are columns.
  684. func (statement *Statement) Asc(colNames ...string) *Statement {
  685. var buf builder.StringBuilder
  686. if len(statement.OrderStr) > 0 {
  687. fmt.Fprint(&buf, statement.OrderStr, ", ")
  688. }
  689. newColNames := statement.col2NewColsWithQuote(colNames...)
  690. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  691. statement.OrderStr = buf.String()
  692. return statement
  693. }
  694. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  695. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  696. v := rValue(tableNameOrBean)
  697. t := v.Type()
  698. if t.Kind() == reflect.Struct {
  699. var err error
  700. statement.RefTable, err = statement.Engine.autoMapType(v)
  701. if err != nil {
  702. statement.Engine.logger.Error(err)
  703. return statement
  704. }
  705. }
  706. statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
  707. return statement
  708. }
  709. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  710. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  711. var buf builder.StringBuilder
  712. if len(statement.JoinStr) > 0 {
  713. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  714. } else {
  715. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  716. }
  717. tbName := statement.Engine.TableName(tablename, true)
  718. fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
  719. statement.JoinStr = buf.String()
  720. statement.joinArgs = append(statement.joinArgs, args...)
  721. return statement
  722. }
  723. // GroupBy generate "Group By keys" statement
  724. func (statement *Statement) GroupBy(keys string) *Statement {
  725. statement.GroupByStr = keys
  726. return statement
  727. }
  728. // Having generate "Having conditions" statement
  729. func (statement *Statement) Having(conditions string) *Statement {
  730. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  731. return statement
  732. }
  733. // Unscoped always disable struct tag "deleted"
  734. func (statement *Statement) Unscoped() *Statement {
  735. statement.unscoped = true
  736. return statement
  737. }
  738. func (statement *Statement) genColumnStr() string {
  739. if statement.RefTable == nil {
  740. return ""
  741. }
  742. var buf builder.StringBuilder
  743. columns := statement.RefTable.Columns()
  744. for _, col := range columns {
  745. if statement.omitColumnMap.contain(col.Name) {
  746. continue
  747. }
  748. if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
  749. continue
  750. }
  751. if col.MapType == core.ONLYTODB {
  752. continue
  753. }
  754. if buf.Len() != 0 {
  755. buf.WriteString(", ")
  756. }
  757. if statement.JoinStr != "" {
  758. if statement.TableAlias != "" {
  759. buf.WriteString(statement.TableAlias)
  760. } else {
  761. buf.WriteString(statement.TableName())
  762. }
  763. buf.WriteString(".")
  764. }
  765. statement.Engine.QuoteTo(&buf, col.Name)
  766. }
  767. return buf.String()
  768. }
  769. func (statement *Statement) genCreateTableSQL() string {
  770. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  771. statement.StoreEngine, statement.Charset)
  772. }
  773. func (statement *Statement) genIndexSQL() []string {
  774. var sqls []string
  775. tbName := statement.TableName()
  776. for _, index := range statement.RefTable.Indexes {
  777. if index.Type == core.IndexType {
  778. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  779. if sql != "" {
  780. sqls = append(sqls, sql)
  781. }
  782. }
  783. }
  784. return sqls
  785. }
  786. func uniqueName(tableName, uqeName string) string {
  787. return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  788. }
  789. func (statement *Statement) genUniqueSQL() []string {
  790. var sqls []string
  791. tbName := statement.TableName()
  792. for _, index := range statement.RefTable.Indexes {
  793. if index.Type == core.UniqueType {
  794. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  795. sqls = append(sqls, sql)
  796. }
  797. }
  798. return sqls
  799. }
  800. func (statement *Statement) genDelIndexSQL() []string {
  801. var sqls []string
  802. tbName := statement.TableName()
  803. idxPrefixName := strings.Replace(tbName, `"`, "", -1)
  804. idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
  805. for idxName, index := range statement.RefTable.Indexes {
  806. var rIdxName string
  807. if index.Type == core.UniqueType {
  808. rIdxName = uniqueName(idxPrefixName, idxName)
  809. } else if index.Type == core.IndexType {
  810. rIdxName = indexName(idxPrefixName, idxName)
  811. }
  812. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
  813. if statement.Engine.dialect.IndexOnTable() {
  814. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
  815. }
  816. sqls = append(sqls, sql)
  817. }
  818. return sqls
  819. }
  820. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  821. quote := statement.Engine.Quote
  822. sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
  823. col.String(statement.Engine.dialect))
  824. if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
  825. sql += " COMMENT '" + col.Comment + "'"
  826. }
  827. sql += ";"
  828. return sql, []interface{}{}
  829. }
  830. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  831. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  832. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  833. }
  834. func (statement *Statement) mergeConds(bean interface{}) error {
  835. if !statement.noAutoCondition {
  836. var addedTableName = (len(statement.JoinStr) > 0)
  837. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  838. if err != nil {
  839. return err
  840. }
  841. statement.cond = statement.cond.And(autoCond)
  842. }
  843. if err := statement.processIDParam(); err != nil {
  844. return err
  845. }
  846. return nil
  847. }
  848. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  849. if err := statement.mergeConds(bean); err != nil {
  850. return "", nil, err
  851. }
  852. return builder.ToSQL(statement.cond)
  853. }
  854. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  855. v := rValue(bean)
  856. isStruct := v.Kind() == reflect.Struct
  857. if isStruct {
  858. statement.setRefBean(bean)
  859. }
  860. var columnStr = statement.ColumnStr
  861. if len(statement.selectStr) > 0 {
  862. columnStr = statement.selectStr
  863. } else {
  864. // TODO: always generate column names, not use * even if join
  865. if len(statement.JoinStr) == 0 {
  866. if len(columnStr) == 0 {
  867. if len(statement.GroupByStr) > 0 {
  868. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  869. } else {
  870. columnStr = statement.genColumnStr()
  871. }
  872. }
  873. } else {
  874. if len(columnStr) == 0 {
  875. if len(statement.GroupByStr) > 0 {
  876. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  877. }
  878. }
  879. }
  880. }
  881. if len(columnStr) == 0 {
  882. columnStr = "*"
  883. }
  884. if isStruct {
  885. if err := statement.mergeConds(bean); err != nil {
  886. return "", nil, err
  887. }
  888. } else {
  889. if err := statement.processIDParam(); err != nil {
  890. return "", nil, err
  891. }
  892. }
  893. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  894. if err != nil {
  895. return "", nil, err
  896. }
  897. sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
  898. if err != nil {
  899. return "", nil, err
  900. }
  901. return sqlStr, append(statement.joinArgs, condArgs...), nil
  902. }
  903. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  904. var condSQL string
  905. var condArgs []interface{}
  906. var err error
  907. if len(beans) > 0 {
  908. statement.setRefBean(beans[0])
  909. condSQL, condArgs, err = statement.genConds(beans[0])
  910. } else {
  911. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  912. }
  913. if err != nil {
  914. return "", nil, err
  915. }
  916. var selectSQL = statement.selectStr
  917. if len(selectSQL) <= 0 {
  918. if statement.IsDistinct {
  919. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  920. } else {
  921. selectSQL = "count(*)"
  922. }
  923. }
  924. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
  925. if err != nil {
  926. return "", nil, err
  927. }
  928. return sqlStr, append(statement.joinArgs, condArgs...), nil
  929. }
  930. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  931. statement.setRefBean(bean)
  932. var sumStrs = make([]string, 0, len(columns))
  933. for _, colName := range columns {
  934. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  935. colName = statement.Engine.Quote(colName)
  936. }
  937. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  938. }
  939. sumSelect := strings.Join(sumStrs, ", ")
  940. condSQL, condArgs, err := statement.genConds(bean)
  941. if err != nil {
  942. return "", nil, err
  943. }
  944. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
  945. if err != nil {
  946. return "", nil, err
  947. }
  948. return sqlStr, append(statement.joinArgs, condArgs...), nil
  949. }
  950. func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
  951. var (
  952. distinct string
  953. dialect = statement.Engine.Dialect()
  954. quote = statement.Engine.Quote
  955. fromStr = " FROM "
  956. top, mssqlCondi, whereStr string
  957. )
  958. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  959. distinct = "DISTINCT "
  960. }
  961. if len(condSQL) > 0 {
  962. whereStr = " WHERE " + condSQL
  963. }
  964. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  965. fromStr += statement.TableName()
  966. } else {
  967. fromStr += quote(statement.TableName())
  968. }
  969. if statement.TableAlias != "" {
  970. if dialect.DBType() == core.ORACLE {
  971. fromStr += " " + quote(statement.TableAlias)
  972. } else {
  973. fromStr += " AS " + quote(statement.TableAlias)
  974. }
  975. }
  976. if statement.JoinStr != "" {
  977. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  978. }
  979. if dialect.DBType() == core.MSSQL {
  980. if statement.LimitN > 0 {
  981. top = fmt.Sprintf(" TOP %d ", statement.LimitN)
  982. }
  983. if statement.Start > 0 {
  984. var column string
  985. if len(statement.RefTable.PKColumns()) == 0 {
  986. for _, index := range statement.RefTable.Indexes {
  987. if len(index.Cols) == 1 {
  988. column = index.Cols[0]
  989. break
  990. }
  991. }
  992. if len(column) == 0 {
  993. column = statement.RefTable.ColumnsSeq()[0]
  994. }
  995. } else {
  996. column = statement.RefTable.PKColumns()[0].Name
  997. }
  998. if statement.needTableName() {
  999. if len(statement.TableAlias) > 0 {
  1000. column = statement.TableAlias + "." + column
  1001. } else {
  1002. column = statement.TableName() + "." + column
  1003. }
  1004. }
  1005. var orderStr string
  1006. if needOrderBy && len(statement.OrderStr) > 0 {
  1007. orderStr = " ORDER BY " + statement.OrderStr
  1008. }
  1009. var groupStr string
  1010. if len(statement.GroupByStr) > 0 {
  1011. groupStr = " GROUP BY " + statement.GroupByStr
  1012. }
  1013. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  1014. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  1015. }
  1016. }
  1017. var buf builder.StringBuilder
  1018. fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  1019. if len(mssqlCondi) > 0 {
  1020. if len(whereStr) > 0 {
  1021. fmt.Fprint(&buf, " AND ", mssqlCondi)
  1022. } else {
  1023. fmt.Fprint(&buf, " WHERE ", mssqlCondi)
  1024. }
  1025. }
  1026. if statement.GroupByStr != "" {
  1027. fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
  1028. }
  1029. if statement.HavingStr != "" {
  1030. fmt.Fprint(&buf, " ", statement.HavingStr)
  1031. }
  1032. if needOrderBy && statement.OrderStr != "" {
  1033. fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
  1034. }
  1035. if needLimit {
  1036. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1037. if statement.Start > 0 {
  1038. fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
  1039. } else if statement.LimitN > 0 {
  1040. fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
  1041. }
  1042. } else if dialect.DBType() == core.ORACLE {
  1043. if statement.Start != 0 || statement.LimitN != 0 {
  1044. oldString := buf.String()
  1045. buf.Reset()
  1046. fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
  1047. columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start)
  1048. }
  1049. }
  1050. }
  1051. if statement.IsForUpdate {
  1052. return dialect.ForUpdateSql(buf.String()), nil
  1053. }
  1054. return buf.String(), nil
  1055. }
  1056. func (statement *Statement) processIDParam() error {
  1057. if statement.idParam == nil || statement.RefTable == nil {
  1058. return nil
  1059. }
  1060. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1061. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1062. len(statement.RefTable.PrimaryKeys),
  1063. len(*statement.idParam),
  1064. )
  1065. }
  1066. for i, col := range statement.RefTable.PKColumns() {
  1067. var colName = statement.colName(col, statement.TableName())
  1068. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1069. }
  1070. return nil
  1071. }
  1072. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1073. var colnames = make([]string, len(cols))
  1074. for i, col := range cols {
  1075. if includeTableName {
  1076. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1077. "." + statement.Engine.Quote(col.Name)
  1078. } else {
  1079. colnames[i] = statement.Engine.Quote(col.Name)
  1080. }
  1081. }
  1082. return strings.Join(colnames, ", ")
  1083. }
  1084. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1085. if statement.RefTable != nil {
  1086. cols := statement.RefTable.PKColumns()
  1087. if len(cols) == 0 {
  1088. return ""
  1089. }
  1090. colstrs := statement.joinColumns(cols, false)
  1091. sqls := splitNNoCase(sqlStr, " from ", 2)
  1092. if len(sqls) != 2 {
  1093. return ""
  1094. }
  1095. var top string
  1096. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1097. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1098. }
  1099. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1100. return newsql
  1101. }
  1102. return ""
  1103. }
  1104. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1105. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1106. return "", ""
  1107. }
  1108. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1109. sqls := splitNNoCase(sqlStr, "where", 2)
  1110. if len(sqls) != 2 {
  1111. if len(sqls) == 1 {
  1112. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1113. colstrs, statement.Engine.Quote(statement.TableName()))
  1114. }
  1115. return "", ""
  1116. }
  1117. var whereStr = sqls[1]
  1118. //TODO: for postgres only, if any other database?
  1119. var paraStr string
  1120. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1121. paraStr = "$"
  1122. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1123. paraStr = ":"
  1124. }
  1125. if paraStr != "" {
  1126. if strings.Contains(sqls[1], paraStr) {
  1127. dollers := strings.Split(sqls[1], paraStr)
  1128. whereStr = dollers[0]
  1129. for i, c := range dollers[1:] {
  1130. ccs := strings.SplitN(c, " ", 2)
  1131. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1132. }
  1133. }
  1134. }
  1135. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1136. colstrs, statement.Engine.Quote(statement.TableName()),
  1137. whereStr)
  1138. }