statement.go 34 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "bytes"
  7. "database/sql/driver"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "reflect"
  12. "strings"
  13. "time"
  14. "github.com/go-xorm/builder"
  15. "github.com/xormplus/core"
  16. )
  17. type incrParam struct {
  18. colName string
  19. arg interface{}
  20. }
  21. type decrParam struct {
  22. colName string
  23. arg interface{}
  24. }
  25. type exprParam struct {
  26. colName string
  27. expr string
  28. }
  29. // Statement save all the sql info for executing SQL
  30. type Statement struct {
  31. RefTable *core.Table
  32. Engine *Engine
  33. Start int
  34. LimitN int
  35. idParam *core.PK
  36. OrderStr string
  37. JoinStr string
  38. joinArgs []interface{}
  39. GroupByStr string
  40. HavingStr string
  41. ColumnStr string
  42. selectStr string
  43. columnMap map[string]bool
  44. useAllCols bool
  45. OmitStr string
  46. AltTableName string
  47. tableName string
  48. RawSQL string
  49. RawParams []interface{}
  50. UseCascade bool
  51. UseAutoJoin bool
  52. StoreEngine string
  53. Charset string
  54. UseCache bool
  55. UseAutoTime bool
  56. noAutoCondition bool
  57. IsDistinct bool
  58. IsForUpdate bool
  59. TableAlias string
  60. allUseBool bool
  61. checkVersion bool
  62. unscoped bool
  63. mustColumnMap map[string]bool
  64. nullableMap map[string]bool
  65. incrColumns map[string]incrParam
  66. decrColumns map[string]decrParam
  67. exprColumns map[string]exprParam
  68. cond builder.Cond
  69. bufferSize int
  70. }
  71. // Init reset all the statement's fields
  72. func (statement *Statement) Init() {
  73. statement.RefTable = nil
  74. statement.Start = 0
  75. statement.LimitN = 0
  76. statement.OrderStr = ""
  77. statement.UseCascade = true
  78. statement.JoinStr = ""
  79. statement.joinArgs = make([]interface{}, 0)
  80. statement.GroupByStr = ""
  81. statement.HavingStr = ""
  82. statement.ColumnStr = ""
  83. statement.OmitStr = ""
  84. statement.columnMap = make(map[string]bool)
  85. statement.AltTableName = ""
  86. statement.tableName = ""
  87. statement.idParam = nil
  88. statement.RawSQL = ""
  89. statement.RawParams = make([]interface{}, 0)
  90. statement.UseCache = true
  91. statement.UseAutoTime = true
  92. statement.noAutoCondition = false
  93. statement.IsDistinct = false
  94. statement.IsForUpdate = false
  95. statement.TableAlias = ""
  96. statement.selectStr = ""
  97. statement.allUseBool = false
  98. statement.useAllCols = false
  99. statement.mustColumnMap = make(map[string]bool)
  100. statement.nullableMap = make(map[string]bool)
  101. statement.checkVersion = true
  102. statement.unscoped = false
  103. statement.incrColumns = make(map[string]incrParam)
  104. statement.decrColumns = make(map[string]decrParam)
  105. statement.exprColumns = make(map[string]exprParam)
  106. statement.cond = builder.NewCond()
  107. statement.bufferSize = 0
  108. }
  109. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  110. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  111. statement.noAutoCondition = true
  112. if len(no) > 0 {
  113. statement.noAutoCondition = no[0]
  114. }
  115. return statement
  116. }
  117. // Alias set the table alias
  118. func (statement *Statement) Alias(alias string) *Statement {
  119. statement.TableAlias = alias
  120. return statement
  121. }
  122. // SQL adds raw sql statement
  123. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  124. switch query.(type) {
  125. case (*builder.Builder):
  126. var err error
  127. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  128. if err != nil {
  129. statement.Engine.logger.Error(err)
  130. }
  131. case string:
  132. statement.RawSQL = query.(string)
  133. statement.RawParams = args
  134. default:
  135. statement.Engine.logger.Error("unsupported sql type")
  136. }
  137. return statement
  138. }
  139. // Where add Where statement
  140. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  141. return statement.And(query, args...)
  142. }
  143. // And add Where & and statement
  144. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  145. switch query.(type) {
  146. case string:
  147. cond := builder.Expr(query.(string), args...)
  148. statement.cond = statement.cond.And(cond)
  149. case builder.Cond:
  150. cond := query.(builder.Cond)
  151. statement.cond = statement.cond.And(cond)
  152. for _, v := range args {
  153. if vv, ok := v.(builder.Cond); ok {
  154. statement.cond = statement.cond.And(vv)
  155. }
  156. }
  157. default:
  158. // TODO: not support condition type
  159. }
  160. return statement
  161. }
  162. // Or add Where & Or statement
  163. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  164. switch query.(type) {
  165. case string:
  166. cond := builder.Expr(query.(string), args...)
  167. statement.cond = statement.cond.Or(cond)
  168. case builder.Cond:
  169. cond := query.(builder.Cond)
  170. statement.cond = statement.cond.Or(cond)
  171. for _, v := range args {
  172. if vv, ok := v.(builder.Cond); ok {
  173. statement.cond = statement.cond.Or(vv)
  174. }
  175. }
  176. default:
  177. // TODO: not support condition type
  178. }
  179. return statement
  180. }
  181. // In generate "Where column IN (?) " statement
  182. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  183. in := builder.In(statement.Engine.Quote(column), args...)
  184. statement.cond = statement.cond.And(in)
  185. return statement
  186. }
  187. // NotIn generate "Where column NOT IN (?) " statement
  188. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  189. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  190. statement.cond = statement.cond.And(notIn)
  191. return statement
  192. }
  193. func (statement *Statement) setRefValue(v reflect.Value) error {
  194. var err error
  195. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  196. if err != nil {
  197. return err
  198. }
  199. statement.tableName = statement.Engine.tbName(v)
  200. return nil
  201. }
  202. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  203. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  204. v := rValue(tableNameOrBean)
  205. t := v.Type()
  206. if t.Kind() == reflect.String {
  207. statement.AltTableName = tableNameOrBean.(string)
  208. } else if t.Kind() == reflect.Struct {
  209. var err error
  210. statement.RefTable, err = statement.Engine.autoMapType(v)
  211. if err != nil {
  212. statement.Engine.logger.Error(err)
  213. return statement
  214. }
  215. statement.AltTableName = statement.Engine.tbName(v)
  216. }
  217. return statement
  218. }
  219. // Auto generating update columnes and values according a struct
  220. func buildUpdates(engine *Engine, table *core.Table, bean interface{},
  221. includeVersion bool, includeUpdated bool, includeNil bool,
  222. includeAutoIncr bool, allUseBool bool, useAllCols bool,
  223. mustColumnMap map[string]bool, nullableMap map[string]bool,
  224. columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) {
  225. var colNames = make([]string, 0)
  226. var args = make([]interface{}, 0)
  227. for _, col := range table.Columns() {
  228. if !includeVersion && col.IsVersion {
  229. continue
  230. }
  231. if col.IsCreated {
  232. continue
  233. }
  234. if !includeUpdated && col.IsUpdated {
  235. continue
  236. }
  237. if !includeAutoIncr && col.IsAutoIncrement {
  238. continue
  239. }
  240. if col.IsDeleted && !unscoped {
  241. continue
  242. }
  243. if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use {
  244. continue
  245. }
  246. fieldValuePtr, err := col.ValueOf(bean)
  247. if err != nil {
  248. engine.logger.Error(err)
  249. continue
  250. }
  251. fieldValue := *fieldValuePtr
  252. fieldType := reflect.TypeOf(fieldValue.Interface())
  253. if fieldType == nil {
  254. continue
  255. }
  256. requiredField := useAllCols
  257. includeNil := useAllCols
  258. if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  259. if b {
  260. requiredField = true
  261. } else {
  262. continue
  263. }
  264. }
  265. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  266. if b, ok := getFlagForColumn(nullableMap, col); ok {
  267. if b && col.Nullable && isZero(fieldValue.Interface()) {
  268. var nilValue *int
  269. fieldValue = reflect.ValueOf(nilValue)
  270. fieldType = reflect.TypeOf(fieldValue.Interface())
  271. includeNil = true
  272. }
  273. }
  274. var val interface{}
  275. if fieldValue.CanAddr() {
  276. if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  277. data, err := structConvert.ToDB()
  278. if err != nil {
  279. engine.logger.Error(err)
  280. } else {
  281. val = data
  282. }
  283. goto APPEND
  284. }
  285. }
  286. if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  287. data, err := structConvert.ToDB()
  288. if err != nil {
  289. engine.logger.Error(err)
  290. } else {
  291. val = data
  292. }
  293. goto APPEND
  294. }
  295. if fieldType.Kind() == reflect.Ptr {
  296. if fieldValue.IsNil() {
  297. if includeNil {
  298. args = append(args, nil)
  299. colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  300. }
  301. continue
  302. } else if !fieldValue.IsValid() {
  303. continue
  304. } else {
  305. // dereference ptr type to instance type
  306. fieldValue = fieldValue.Elem()
  307. fieldType = reflect.TypeOf(fieldValue.Interface())
  308. requiredField = true
  309. }
  310. }
  311. switch fieldType.Kind() {
  312. case reflect.Bool:
  313. if allUseBool || requiredField {
  314. val = fieldValue.Interface()
  315. } else {
  316. // if a bool in a struct, it will not be as a condition because it default is false,
  317. // please use Where() instead
  318. continue
  319. }
  320. case reflect.String:
  321. if !requiredField && fieldValue.String() == "" {
  322. continue
  323. }
  324. // for MyString, should convert to string or panic
  325. if fieldType.String() != reflect.String.String() {
  326. val = fieldValue.String()
  327. } else {
  328. val = fieldValue.Interface()
  329. }
  330. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  331. if !requiredField && fieldValue.Int() == 0 {
  332. continue
  333. }
  334. val = fieldValue.Interface()
  335. case reflect.Float32, reflect.Float64:
  336. if !requiredField && fieldValue.Float() == 0.0 {
  337. continue
  338. }
  339. val = fieldValue.Interface()
  340. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  341. if !requiredField && fieldValue.Uint() == 0 {
  342. continue
  343. }
  344. t := int64(fieldValue.Uint())
  345. val = reflect.ValueOf(&t).Interface()
  346. case reflect.Struct:
  347. if fieldType.ConvertibleTo(core.TimeType) {
  348. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  349. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  350. continue
  351. }
  352. val = engine.formatColTime(col, t)
  353. } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  354. val, _ = nulType.Value()
  355. } else {
  356. if !col.SQLType.IsJson() {
  357. engine.autoMapType(fieldValue)
  358. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  359. if len(table.PrimaryKeys) == 1 {
  360. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  361. // fix non-int pk issues
  362. if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  363. val = pkField.Interface()
  364. } else {
  365. continue
  366. }
  367. } else {
  368. //TODO: how to handler?
  369. panic("not supported")
  370. }
  371. } else {
  372. val = fieldValue.Interface()
  373. }
  374. } else {
  375. // Blank struct could not be as update data
  376. if requiredField || !isStructZero(fieldValue) {
  377. bytes, err := json.Marshal(fieldValue.Interface())
  378. if err != nil {
  379. panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  380. }
  381. if col.SQLType.IsText() {
  382. val = string(bytes)
  383. } else if col.SQLType.IsBlob() {
  384. val = bytes
  385. }
  386. } else {
  387. continue
  388. }
  389. }
  390. }
  391. case reflect.Array, reflect.Slice, reflect.Map:
  392. if !requiredField {
  393. if fieldValue == reflect.Zero(fieldType) {
  394. continue
  395. }
  396. if fieldType.Kind() == reflect.Array {
  397. if isArrayValueZero(fieldValue) {
  398. continue
  399. }
  400. } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  401. continue
  402. }
  403. }
  404. if col.SQLType.IsText() {
  405. bytes, err := json.Marshal(fieldValue.Interface())
  406. if err != nil {
  407. engine.logger.Error(err)
  408. continue
  409. }
  410. val = string(bytes)
  411. } else if col.SQLType.IsBlob() {
  412. var bytes []byte
  413. var err error
  414. if fieldType.Kind() == reflect.Slice &&
  415. fieldType.Elem().Kind() == reflect.Uint8 {
  416. if fieldValue.Len() > 0 {
  417. val = fieldValue.Bytes()
  418. } else {
  419. continue
  420. }
  421. } else if fieldType.Kind() == reflect.Array &&
  422. fieldType.Elem().Kind() == reflect.Uint8 {
  423. val = fieldValue.Slice(0, 0).Interface()
  424. } else {
  425. bytes, err = json.Marshal(fieldValue.Interface())
  426. if err != nil {
  427. engine.logger.Error(err)
  428. continue
  429. }
  430. val = bytes
  431. }
  432. } else {
  433. continue
  434. }
  435. default:
  436. val = fieldValue.Interface()
  437. }
  438. APPEND:
  439. args = append(args, val)
  440. if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  441. continue
  442. }
  443. colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  444. }
  445. return colNames, args
  446. }
  447. func (statement *Statement) needTableName() bool {
  448. return len(statement.JoinStr) > 0
  449. }
  450. func (statement *Statement) colName(col *core.Column, tableName string) string {
  451. if statement.needTableName() {
  452. var nm = tableName
  453. if len(statement.TableAlias) > 0 {
  454. nm = statement.TableAlias
  455. }
  456. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  457. }
  458. return statement.Engine.Quote(col.Name)
  459. }
  460. // TableName return current tableName
  461. func (statement *Statement) TableName() string {
  462. if statement.AltTableName != "" {
  463. return statement.AltTableName
  464. }
  465. return statement.tableName
  466. }
  467. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  468. func (statement *Statement) ID(id interface{}) *Statement {
  469. idValue := reflect.ValueOf(id)
  470. idType := reflect.TypeOf(idValue.Interface())
  471. switch idType {
  472. case ptrPkType:
  473. if pkPtr, ok := (id).(*core.PK); ok {
  474. statement.idParam = pkPtr
  475. return statement
  476. }
  477. case pkType:
  478. if pk, ok := (id).(core.PK); ok {
  479. statement.idParam = &pk
  480. return statement
  481. }
  482. }
  483. switch idType.Kind() {
  484. case reflect.String:
  485. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  486. return statement
  487. }
  488. statement.idParam = &core.PK{id}
  489. return statement
  490. }
  491. // Incr Generate "Update ... Set column = column + arg" statement
  492. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  493. k := strings.ToLower(column)
  494. if len(arg) > 0 {
  495. statement.incrColumns[k] = incrParam{column, arg[0]}
  496. } else {
  497. statement.incrColumns[k] = incrParam{column, 1}
  498. }
  499. return statement
  500. }
  501. // Decr Generate "Update ... Set column = column - arg" statement
  502. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  503. k := strings.ToLower(column)
  504. if len(arg) > 0 {
  505. statement.decrColumns[k] = decrParam{column, arg[0]}
  506. } else {
  507. statement.decrColumns[k] = decrParam{column, 1}
  508. }
  509. return statement
  510. }
  511. // SetExpr Generate "Update ... Set column = {expression}" statement
  512. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  513. k := strings.ToLower(column)
  514. statement.exprColumns[k] = exprParam{column, expression}
  515. return statement
  516. }
  517. // Generate "Update ... Set column = column + arg" statement
  518. func (statement *Statement) getInc() map[string]incrParam {
  519. return statement.incrColumns
  520. }
  521. // Generate "Update ... Set column = column - arg" statement
  522. func (statement *Statement) getDec() map[string]decrParam {
  523. return statement.decrColumns
  524. }
  525. // Generate "Update ... Set column = {expression}" statement
  526. func (statement *Statement) getExpr() map[string]exprParam {
  527. return statement.exprColumns
  528. }
  529. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  530. newColumns := make([]string, 0)
  531. for _, col := range columns {
  532. col = strings.Replace(col, "`", "", -1)
  533. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  534. ccols := strings.Split(col, ",")
  535. for _, c := range ccols {
  536. fields := strings.Split(strings.TrimSpace(c), ".")
  537. if len(fields) == 1 {
  538. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  539. } else if len(fields) == 2 {
  540. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  541. statement.Engine.quote(fields[1]))
  542. } else {
  543. panic(errors.New("unwanted colnames"))
  544. }
  545. }
  546. }
  547. return newColumns
  548. }
  549. func (statement *Statement) colmap2NewColsWithQuote() []string {
  550. newColumns := make([]string, 0, len(statement.columnMap))
  551. for col := range statement.columnMap {
  552. fields := strings.Split(strings.TrimSpace(col), ".")
  553. if len(fields) == 1 {
  554. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  555. } else if len(fields) == 2 {
  556. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  557. statement.Engine.quote(fields[1]))
  558. } else {
  559. panic(errors.New("unwanted colnames"))
  560. }
  561. }
  562. return newColumns
  563. }
  564. // Distinct generates "DISTINCT col1, col2 " statement
  565. func (statement *Statement) Distinct(columns ...string) *Statement {
  566. statement.IsDistinct = true
  567. statement.Cols(columns...)
  568. return statement
  569. }
  570. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  571. func (statement *Statement) ForUpdate() *Statement {
  572. statement.IsForUpdate = true
  573. return statement
  574. }
  575. // Select replace select
  576. func (statement *Statement) Select(str string) *Statement {
  577. statement.selectStr = str
  578. return statement
  579. }
  580. // Cols generate "col1, col2" statement
  581. func (statement *Statement) Cols(columns ...string) *Statement {
  582. cols := col2NewCols(columns...)
  583. for _, nc := range cols {
  584. statement.columnMap[strings.ToLower(nc)] = true
  585. }
  586. newColumns := statement.colmap2NewColsWithQuote()
  587. statement.ColumnStr = strings.Join(newColumns, ", ")
  588. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  589. return statement
  590. }
  591. // AllCols update use only: update all columns
  592. func (statement *Statement) AllCols() *Statement {
  593. statement.useAllCols = true
  594. return statement
  595. }
  596. // MustCols update use only: must update columns
  597. func (statement *Statement) MustCols(columns ...string) *Statement {
  598. newColumns := col2NewCols(columns...)
  599. for _, nc := range newColumns {
  600. statement.mustColumnMap[strings.ToLower(nc)] = true
  601. }
  602. return statement
  603. }
  604. // UseBool indicates that use bool fields as update contents and query contiditions
  605. func (statement *Statement) UseBool(columns ...string) *Statement {
  606. if len(columns) > 0 {
  607. statement.MustCols(columns...)
  608. } else {
  609. statement.allUseBool = true
  610. }
  611. return statement
  612. }
  613. // Omit do not use the columns
  614. func (statement *Statement) Omit(columns ...string) {
  615. newColumns := col2NewCols(columns...)
  616. for _, nc := range newColumns {
  617. statement.columnMap[strings.ToLower(nc)] = false
  618. }
  619. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  620. }
  621. // Nullable Update use only: update columns to null when value is nullable and zero-value
  622. func (statement *Statement) Nullable(columns ...string) {
  623. newColumns := col2NewCols(columns...)
  624. for _, nc := range newColumns {
  625. statement.nullableMap[strings.ToLower(nc)] = true
  626. }
  627. }
  628. // Top generate LIMIT limit statement
  629. func (statement *Statement) Top(limit int) *Statement {
  630. statement.Limit(limit)
  631. return statement
  632. }
  633. // Limit generate LIMIT start, limit statement
  634. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  635. statement.LimitN = limit
  636. if len(start) > 0 {
  637. statement.Start = start[0]
  638. }
  639. return statement
  640. }
  641. // OrderBy generate "Order By order" statement
  642. func (statement *Statement) OrderBy(order string) *Statement {
  643. if len(statement.OrderStr) > 0 {
  644. statement.OrderStr += ", "
  645. }
  646. statement.OrderStr += order
  647. return statement
  648. }
  649. // Desc generate `ORDER BY xx DESC`
  650. func (statement *Statement) Desc(colNames ...string) *Statement {
  651. var buf bytes.Buffer
  652. fmt.Fprintf(&buf, statement.OrderStr)
  653. if len(statement.OrderStr) > 0 {
  654. fmt.Fprint(&buf, ", ")
  655. }
  656. newColNames := statement.col2NewColsWithQuote(colNames...)
  657. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  658. statement.OrderStr = buf.String()
  659. return statement
  660. }
  661. // Asc provide asc order by query condition, the input parameters are columns.
  662. func (statement *Statement) Asc(colNames ...string) *Statement {
  663. var buf bytes.Buffer
  664. fmt.Fprintf(&buf, statement.OrderStr)
  665. if len(statement.OrderStr) > 0 {
  666. fmt.Fprint(&buf, ", ")
  667. }
  668. newColNames := statement.col2NewColsWithQuote(colNames...)
  669. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  670. statement.OrderStr = buf.String()
  671. return statement
  672. }
  673. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  674. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  675. var buf bytes.Buffer
  676. if len(statement.JoinStr) > 0 {
  677. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  678. } else {
  679. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  680. }
  681. switch tablename.(type) {
  682. case []string:
  683. t := tablename.([]string)
  684. if len(t) > 1 {
  685. fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
  686. } else if len(t) == 1 {
  687. fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
  688. }
  689. case []interface{}:
  690. t := tablename.([]interface{})
  691. l := len(t)
  692. var table string
  693. if l > 0 {
  694. f := t[0]
  695. v := rValue(f)
  696. t := v.Type()
  697. if t.Kind() == reflect.String {
  698. table = f.(string)
  699. } else if t.Kind() == reflect.Struct {
  700. table = statement.Engine.tbName(v)
  701. }
  702. }
  703. if l > 1 {
  704. fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
  705. statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
  706. } else if l == 1 {
  707. fmt.Fprintf(&buf, statement.Engine.Quote(table))
  708. }
  709. default:
  710. fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
  711. }
  712. fmt.Fprintf(&buf, " ON %v", condition)
  713. statement.JoinStr = buf.String()
  714. statement.joinArgs = append(statement.joinArgs, args...)
  715. return statement
  716. }
  717. // GroupBy generate "Group By keys" statement
  718. func (statement *Statement) GroupBy(keys string) *Statement {
  719. statement.GroupByStr = keys
  720. return statement
  721. }
  722. // Having generate "Having conditions" statement
  723. func (statement *Statement) Having(conditions string) *Statement {
  724. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  725. return statement
  726. }
  727. // Unscoped always disable struct tag "deleted"
  728. func (statement *Statement) Unscoped() *Statement {
  729. statement.unscoped = true
  730. return statement
  731. }
  732. func (statement *Statement) genColumnStr() string {
  733. var buf bytes.Buffer
  734. if statement.RefTable == nil {
  735. return ""
  736. }
  737. columns := statement.RefTable.Columns()
  738. for _, col := range columns {
  739. if statement.OmitStr != "" {
  740. if _, ok := getFlagForColumn(statement.columnMap, col); ok {
  741. continue
  742. }
  743. }
  744. if col.MapType == core.ONLYTODB {
  745. continue
  746. }
  747. if buf.Len() != 0 {
  748. buf.WriteString(", ")
  749. }
  750. if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
  751. buf.WriteString("id() AS ")
  752. }
  753. if statement.JoinStr != "" {
  754. if statement.TableAlias != "" {
  755. buf.WriteString(statement.TableAlias)
  756. } else {
  757. buf.WriteString(statement.TableName())
  758. }
  759. buf.WriteString(".")
  760. }
  761. statement.Engine.QuoteTo(&buf, col.Name)
  762. }
  763. return buf.String()
  764. }
  765. func (statement *Statement) genCreateTableSQL() string {
  766. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  767. statement.StoreEngine, statement.Charset)
  768. }
  769. func (statement *Statement) genIndexSQL() []string {
  770. var sqls []string
  771. tbName := statement.TableName()
  772. quote := statement.Engine.Quote
  773. for idxName, index := range statement.RefTable.Indexes {
  774. if index.Type == core.IndexType {
  775. sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
  776. quote(tbName), quote(strings.Join(index.Cols, quote(","))))
  777. sqls = append(sqls, sql)
  778. }
  779. }
  780. return sqls
  781. }
  782. func uniqueName(tableName, uqeName string) string {
  783. return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  784. }
  785. func (statement *Statement) genUniqueSQL() []string {
  786. var sqls []string
  787. tbName := statement.TableName()
  788. for _, index := range statement.RefTable.Indexes {
  789. if index.Type == core.UniqueType {
  790. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  791. sqls = append(sqls, sql)
  792. }
  793. }
  794. return sqls
  795. }
  796. func (statement *Statement) genDelIndexSQL() []string {
  797. var sqls []string
  798. tbName := statement.TableName()
  799. for idxName, index := range statement.RefTable.Indexes {
  800. var rIdxName string
  801. if index.Type == core.UniqueType {
  802. rIdxName = uniqueName(tbName, idxName)
  803. } else if index.Type == core.IndexType {
  804. rIdxName = indexName(tbName, idxName)
  805. }
  806. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
  807. if statement.Engine.dialect.IndexOnTable() {
  808. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
  809. }
  810. sqls = append(sqls, sql)
  811. }
  812. return sqls
  813. }
  814. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  815. quote := statement.Engine.Quote
  816. sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(statement.TableName()),
  817. col.String(statement.Engine.dialect))
  818. return sql, []interface{}{}
  819. }
  820. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  821. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  822. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  823. }
  824. func (statement *Statement) mergeConds(bean interface{}) error {
  825. if !statement.noAutoCondition {
  826. var addedTableName = (len(statement.JoinStr) > 0)
  827. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  828. if err != nil {
  829. return err
  830. }
  831. statement.cond = statement.cond.And(autoCond)
  832. }
  833. if err := statement.processIDParam(); err != nil {
  834. return err
  835. }
  836. return nil
  837. }
  838. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  839. if err := statement.mergeConds(bean); err != nil {
  840. return "", nil, err
  841. }
  842. return builder.ToSQL(statement.cond)
  843. }
  844. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  845. v := rValue(bean)
  846. isStruct := v.Kind() == reflect.Struct
  847. if isStruct {
  848. statement.setRefValue(v)
  849. }
  850. var columnStr = statement.ColumnStr
  851. if len(statement.selectStr) > 0 {
  852. columnStr = statement.selectStr
  853. } else {
  854. // TODO: always generate column names, not use * even if join
  855. if len(statement.JoinStr) == 0 {
  856. if len(columnStr) == 0 {
  857. if len(statement.GroupByStr) > 0 {
  858. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  859. } else {
  860. columnStr = statement.genColumnStr()
  861. }
  862. }
  863. } else {
  864. if len(columnStr) == 0 {
  865. if len(statement.GroupByStr) > 0 {
  866. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  867. }
  868. }
  869. }
  870. }
  871. if len(columnStr) == 0 {
  872. columnStr = "*"
  873. }
  874. if isStruct {
  875. if err := statement.mergeConds(bean); err != nil {
  876. return "", nil, err
  877. }
  878. }
  879. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  880. if err != nil {
  881. return "", nil, err
  882. }
  883. sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
  884. if err != nil {
  885. return "", nil, err
  886. }
  887. return sqlStr, append(statement.joinArgs, condArgs...), nil
  888. }
  889. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  890. var condSQL string
  891. var condArgs []interface{}
  892. var err error
  893. if len(beans) > 0 {
  894. statement.setRefValue(rValue(beans[0]))
  895. condSQL, condArgs, err = statement.genConds(beans[0])
  896. } else {
  897. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  898. }
  899. if err != nil {
  900. return "", nil, err
  901. }
  902. var selectSQL = statement.selectStr
  903. if len(selectSQL) <= 0 {
  904. if statement.IsDistinct {
  905. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  906. } else {
  907. selectSQL = "count(*)"
  908. }
  909. }
  910. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
  911. if err != nil {
  912. return "", nil, err
  913. }
  914. return sqlStr, append(statement.joinArgs, condArgs...), nil
  915. }
  916. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  917. statement.setRefValue(rValue(bean))
  918. var sumStrs = make([]string, 0, len(columns))
  919. for _, colName := range columns {
  920. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  921. colName = statement.Engine.Quote(colName)
  922. }
  923. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  924. }
  925. sumSelect := strings.Join(sumStrs, ", ")
  926. condSQL, condArgs, err := statement.genConds(bean)
  927. if err != nil {
  928. return "", nil, err
  929. }
  930. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
  931. if err != nil {
  932. return "", nil, err
  933. }
  934. return sqlStr, append(statement.joinArgs, condArgs...), nil
  935. }
  936. func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
  937. var distinct string
  938. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  939. distinct = "DISTINCT "
  940. }
  941. var dialect = statement.Engine.Dialect()
  942. var quote = statement.Engine.Quote
  943. var top string
  944. var mssqlCondi string
  945. if err := statement.processIDParam(); err != nil {
  946. return "", err
  947. }
  948. var buf bytes.Buffer
  949. if len(condSQL) > 0 {
  950. fmt.Fprintf(&buf, " WHERE %v", condSQL)
  951. }
  952. var whereStr = buf.String()
  953. var fromStr = " FROM "
  954. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  955. fromStr += statement.TableName()
  956. } else {
  957. fromStr += quote(statement.TableName())
  958. }
  959. if statement.TableAlias != "" {
  960. if dialect.DBType() == core.ORACLE {
  961. fromStr += " " + quote(statement.TableAlias)
  962. } else {
  963. fromStr += " AS " + quote(statement.TableAlias)
  964. }
  965. }
  966. if statement.JoinStr != "" {
  967. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  968. }
  969. if dialect.DBType() == core.MSSQL {
  970. if statement.LimitN > 0 {
  971. top = fmt.Sprintf(" TOP %d ", statement.LimitN)
  972. }
  973. if statement.Start > 0 {
  974. var column string
  975. if len(statement.RefTable.PKColumns()) == 0 {
  976. for _, index := range statement.RefTable.Indexes {
  977. if len(index.Cols) == 1 {
  978. column = index.Cols[0]
  979. break
  980. }
  981. }
  982. if len(column) == 0 {
  983. column = statement.RefTable.ColumnsSeq()[0]
  984. }
  985. } else {
  986. column = statement.RefTable.PKColumns()[0].Name
  987. }
  988. if statement.needTableName() {
  989. if len(statement.TableAlias) > 0 {
  990. column = statement.TableAlias + "." + column
  991. } else {
  992. column = statement.TableName() + "." + column
  993. }
  994. }
  995. var orderStr string
  996. if len(statement.OrderStr) > 0 {
  997. orderStr = " ORDER BY " + statement.OrderStr
  998. }
  999. var groupStr string
  1000. if len(statement.GroupByStr) > 0 {
  1001. groupStr = " GROUP BY " + statement.GroupByStr
  1002. }
  1003. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  1004. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  1005. }
  1006. }
  1007. // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
  1008. a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  1009. if len(mssqlCondi) > 0 {
  1010. if len(whereStr) > 0 {
  1011. a += " AND " + mssqlCondi
  1012. } else {
  1013. a += " WHERE " + mssqlCondi
  1014. }
  1015. }
  1016. if statement.GroupByStr != "" {
  1017. a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
  1018. }
  1019. if statement.HavingStr != "" {
  1020. a = fmt.Sprintf("%v %v", a, statement.HavingStr)
  1021. }
  1022. if statement.OrderStr != "" {
  1023. a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
  1024. }
  1025. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1026. if statement.Start > 0 {
  1027. a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
  1028. } else if statement.LimitN > 0 {
  1029. a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
  1030. }
  1031. } else if dialect.DBType() == core.ORACLE {
  1032. if statement.Start != 0 || statement.LimitN != 0 {
  1033. a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
  1034. }
  1035. }
  1036. if statement.IsForUpdate {
  1037. a = dialect.ForUpdateSql(a)
  1038. }
  1039. return
  1040. }
  1041. func (statement *Statement) processIDParam() error {
  1042. if statement.idParam == nil {
  1043. return nil
  1044. }
  1045. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1046. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1047. len(statement.RefTable.PrimaryKeys),
  1048. len(*statement.idParam),
  1049. )
  1050. }
  1051. for i, col := range statement.RefTable.PKColumns() {
  1052. var colName = statement.colName(col, statement.TableName())
  1053. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1054. }
  1055. return nil
  1056. }
  1057. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1058. var colnames = make([]string, len(cols))
  1059. for i, col := range cols {
  1060. if includeTableName {
  1061. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1062. "." + statement.Engine.Quote(col.Name)
  1063. } else {
  1064. colnames[i] = statement.Engine.Quote(col.Name)
  1065. }
  1066. }
  1067. return strings.Join(colnames, ", ")
  1068. }
  1069. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1070. if statement.RefTable != nil {
  1071. cols := statement.RefTable.PKColumns()
  1072. if len(cols) == 0 {
  1073. return ""
  1074. }
  1075. colstrs := statement.joinColumns(cols, false)
  1076. sqls := splitNNoCase(sqlStr, " from ", 2)
  1077. if len(sqls) != 2 {
  1078. return ""
  1079. }
  1080. var top string
  1081. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1082. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1083. }
  1084. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1085. return newsql
  1086. }
  1087. return ""
  1088. }
  1089. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1090. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1091. return "", ""
  1092. }
  1093. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1094. sqls := splitNNoCase(sqlStr, "where", 2)
  1095. if len(sqls) != 2 {
  1096. if len(sqls) == 1 {
  1097. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1098. colstrs, statement.Engine.Quote(statement.TableName()))
  1099. }
  1100. return "", ""
  1101. }
  1102. var whereStr = sqls[1]
  1103. //TODO: for postgres only, if any other database?
  1104. var paraStr string
  1105. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1106. paraStr = "$"
  1107. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1108. paraStr = ":"
  1109. }
  1110. if paraStr != "" {
  1111. if strings.Contains(sqls[1], paraStr) {
  1112. dollers := strings.Split(sqls[1], paraStr)
  1113. whereStr = dollers[0]
  1114. for i, c := range dollers[1:] {
  1115. ccs := strings.SplitN(c, " ", 2)
  1116. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1117. }
  1118. }
  1119. }
  1120. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1121. colstrs, statement.Engine.Quote(statement.TableName()),
  1122. whereStr)
  1123. }