statement.go 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql/driver"
  7. "errors"
  8. "fmt"
  9. "reflect"
  10. "strings"
  11. "time"
  12. "github.com/xormplus/builder"
  13. "github.com/xormplus/core"
  14. )
  // Statement saves all the SQL information needed for executing a query.
  // Its fields are filled in by the builder-style methods of this type and
  // restored to their defaults by Init.
  type Statement struct {
  	RefTable *core.Table // table mapping, set by setRefValue, setRefBean and Table
  	Engine *Engine // engine used for quoting, table-name resolution and logging; not touched by Init
  	Start int // offset, set from Limit's optional second argument
  	LimitN int // row limit, set by Limit (and Top)
  	idParam *core.PK // primary key value(s), set by ID
  	OrderStr string // ORDER BY list accumulated by OrderBy, Asc and Desc
  	JoinStr string // JOIN clauses accumulated by Join
  	joinArgs []interface{} // bind arguments collected by Join
  	GroupByStr string // set by GroupBy
  	HavingStr string // set by Having; stored including the "HAVING " prefix
  	ColumnStr string // quoted, comma-separated column list built by Cols
  	selectStr string // set by Select
  	useAllCols bool // set by AllCols
  	OmitStr string // quoted column list built by Omit
  	AltTableName string // set by Table; takes precedence over tableName in TableName
  	tableName string // set by setRefValue and setRefBean
  	RawSQL string // set by SQL
  	RawParams []interface{} // set by SQL
  	UseCascade bool // Init sets this to true
  	UseAutoJoin bool // not touched by Init
  	StoreEngine string // passed to the dialect in genCreateTableSQL; not touched by Init
  	Charset string // passed to the dialect in genCreateTableSQL; not touched by Init
  	UseCache bool // Init sets this to true
  	UseAutoTime bool // Init sets this to true
  	noAutoCondition bool // set by NoAutoCondition
  	IsDistinct bool // set by Distinct
  	IsForUpdate bool // set by ForUpdate
  	TableAlias string // set by Alias
  	allUseBool bool // set by UseBool when called without columns
  	checkVersion bool // Init sets this to true
  	unscoped bool // set by Unscoped
  	columnMap columnMap // columns selected through Cols
  	omitColumnMap columnMap // columns excluded through Omit
  	mustColumnMap map[string]bool // lower-cased column names registered by MustCols
  	nullableMap map[string]bool // lower-cased column names registered by Nullable
  	incrColumns map[string]incrParam // keyed by lower-cased column name, filled by Incr
  	decrColumns map[string]decrParam // keyed by lower-cased column name, filled by Decr
  	exprColumns map[string]exprParam // keyed by lower-cased column name, filled by SetExpr
  	cond builder.Cond // condition accumulated by Where, And, Or, In and NotIn
  	bufferSize int // Init sets this to 0
  	context ContextCache // Init sets this to nil
  	lastError error // most recent error recorded by SQL, And or Join
  }
  60. // Init reset all the statement's fields
  61. func (statement *Statement) Init() {
  62. statement.RefTable = nil
  63. statement.Start = 0
  64. statement.LimitN = 0
  65. statement.OrderStr = ""
  66. statement.UseCascade = true
  67. statement.JoinStr = ""
  68. statement.joinArgs = make([]interface{}, 0)
  69. statement.GroupByStr = ""
  70. statement.HavingStr = ""
  71. statement.ColumnStr = ""
  72. statement.OmitStr = ""
  73. statement.columnMap = columnMap{}
  74. statement.omitColumnMap = columnMap{}
  75. statement.AltTableName = ""
  76. statement.tableName = ""
  77. statement.idParam = nil
  78. statement.RawSQL = ""
  79. statement.RawParams = make([]interface{}, 0)
  80. statement.UseCache = true
  81. statement.UseAutoTime = true
  82. statement.noAutoCondition = false
  83. statement.IsDistinct = false
  84. statement.IsForUpdate = false
  85. statement.TableAlias = ""
  86. statement.selectStr = ""
  87. statement.allUseBool = false
  88. statement.useAllCols = false
  89. statement.mustColumnMap = make(map[string]bool)
  90. statement.nullableMap = make(map[string]bool)
  91. statement.checkVersion = true
  92. statement.unscoped = false
  93. statement.incrColumns = make(map[string]incrParam)
  94. statement.decrColumns = make(map[string]decrParam)
  95. statement.exprColumns = make(map[string]exprParam)
  96. statement.cond = builder.NewCond()
  97. statement.bufferSize = 0
  98. statement.context = nil
  99. statement.lastError = nil
  100. }
  101. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  102. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  103. statement.noAutoCondition = true
  104. if len(no) > 0 {
  105. statement.noAutoCondition = no[0]
  106. }
  107. return statement
  108. }
  // Alias sets the table alias and returns the statement for chaining.
  // When a join is present, colName and genColumnStr use the alias instead of
  // the table name as the column prefix.
  func (statement *Statement) Alias(alias string) *Statement {
  	statement.TableAlias = alias
  	return statement
  }
  114. // SQL adds raw sql statement
  115. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  116. switch query.(type) {
  117. case (*builder.Builder):
  118. var err error
  119. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  120. if err != nil {
  121. statement.lastError = err
  122. }
  123. case string:
  124. statement.RawSQL = query.(string)
  125. statement.RawParams = args
  126. default:
  127. statement.lastError = ErrUnSupportedSQLType
  128. }
  129. return statement
  130. }
  // Where adds a condition to the statement. It is a thin alias for And and
  // accepts the same query/args combinations.
  func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  	return statement.And(query, args...)
  }
  135. // And add Where & and statement
  136. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  137. switch query.(type) {
  138. case string:
  139. isExpr := false
  140. var cargs []interface{}
  141. for i, _ := range args {
  142. if _, ok := args[i].(sqlExpr); ok {
  143. isExpr = true
  144. }
  145. cargs = append(cargs, args[i])
  146. }
  147. if isExpr {
  148. sqlStr, _ := ConvertToBoundSQL(query.(string), cargs)
  149. cond := builder.Expr(sqlStr)
  150. statement.cond = statement.cond.And(cond)
  151. } else {
  152. cond := builder.Expr(query.(string), args...)
  153. statement.cond = statement.cond.And(cond)
  154. }
  155. case map[string]interface{}:
  156. cond := builder.Eq(query.(map[string]interface{}))
  157. statement.cond = statement.cond.And(cond)
  158. case builder.Cond:
  159. cond := query.(builder.Cond)
  160. statement.cond = statement.cond.And(cond)
  161. for _, v := range args {
  162. if vv, ok := v.(builder.Cond); ok {
  163. statement.cond = statement.cond.And(vv)
  164. }
  165. }
  166. default:
  167. statement.lastError = ErrConditionType
  168. }
  169. return statement
  170. }
  171. // Or add Where & Or statement
  172. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  173. switch query.(type) {
  174. case string:
  175. isExpr := false
  176. var cargs []interface{}
  177. for i, _ := range args {
  178. if _, ok := args[i].(sqlExpr); ok {
  179. isExpr = true
  180. }
  181. cargs = append(cargs, args[i])
  182. }
  183. if isExpr {
  184. sqlStr, _ := ConvertToBoundSQL(query.(string), cargs)
  185. cond := builder.Expr(sqlStr)
  186. statement.cond = statement.cond.Or(cond)
  187. } else {
  188. cond := builder.Expr(query.(string), args...)
  189. statement.cond = statement.cond.Or(cond)
  190. }
  191. case map[string]interface{}:
  192. cond := builder.Eq(query.(map[string]interface{}))
  193. statement.cond = statement.cond.Or(cond)
  194. case builder.Cond:
  195. cond := query.(builder.Cond)
  196. statement.cond = statement.cond.Or(cond)
  197. for _, v := range args {
  198. if vv, ok := v.(builder.Cond); ok {
  199. statement.cond = statement.cond.Or(vv)
  200. }
  201. }
  202. default:
  203. // TODO: not support condition type
  204. }
  205. return statement
  206. }
  207. // In generate "Where column IN (?) " statement
  208. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  209. in := builder.In(statement.Engine.Quote(column), args...)
  210. statement.cond = statement.cond.And(in)
  211. return statement
  212. }
  213. // NotIn generate "Where column NOT IN (?) " statement
  214. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  215. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  216. statement.cond = statement.cond.And(notIn)
  217. return statement
  218. }
  219. func (statement *Statement) setRefValue(v reflect.Value) error {
  220. var err error
  221. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  222. if err != nil {
  223. return err
  224. }
  225. statement.tableName = statement.Engine.TableName(v, true)
  226. return nil
  227. }
  228. func (statement *Statement) setRefBean(bean interface{}) error {
  229. var err error
  230. statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
  231. if err != nil {
  232. return err
  233. }
  234. statement.tableName = statement.Engine.TableName(bean, true)
  235. return nil
  236. }
  // buildUpdates walks the columns of statement.RefTable and, for every column
  // that passes the filters below, reads the matching field from bean and
  // produces one `col = ?` fragment plus its bind value.
  //
  // Parameters:
  //   - bean: the value whose fields are read through col.ValueOf.
  //   - includeVersion, includeUpdated, includeAutoIncr: when false, version,
  //     updated and auto-increment columns are skipped.
  //   - includeNil: NOTE(review): shadowed by a loop-local variable of the same
  //     name, so the value passed in is never read.
  //   - update: NOTE(review): not referenced anywhere in the body.
  //
  // It returns the assignment fragments and the bind arguments.
  // NOTE(review): for a primary-key column on the "ql" dialect the value is
  // appended to args without a matching fragment, so the two slices can differ
  // in length.
  func (statement *Statement) buildUpdates(bean interface{},
  	includeVersion, includeUpdated, includeNil,
  	includeAutoIncr, update bool) ([]string, []interface{}) {
  	engine := statement.Engine
  	table := statement.RefTable
  	allUseBool := statement.allUseBool
  	useAllCols := statement.useAllCols
  	mustColumnMap := statement.mustColumnMap
  	nullableMap := statement.nullableMap
  	columnMap := statement.columnMap
  	omitColumnMap := statement.omitColumnMap
  	unscoped := statement.unscoped

  	var colNames = make([]string, 0)
  	var args = make([]interface{}, 0)
  	for _, col := range table.Columns() {
  		// Column-level filters: skip columns excluded by kind or by the
  		// Cols/Omit selections.
  		if !includeVersion && col.IsVersion {
  			continue
  		}
  		if col.IsCreated {
  			continue
  		}
  		if !includeUpdated && col.IsUpdated {
  			continue
  		}
  		if !includeAutoIncr && col.IsAutoIncrement {
  			continue
  		}
  		if col.IsDeleted && !unscoped {
  			continue
  		}
  		if omitColumnMap.contain(col.Name) {
  			continue
  		}
  		if len(columnMap) > 0 && !columnMap.contain(col.Name) {
  			continue
  		}

  		if col.MapType == core.ONLYFROMDB {
  			continue
  		}

  		fieldValuePtr, err := col.ValueOf(bean)
  		if err != nil {
  			engine.logger.Error(err)
  			continue
  		}

  		fieldValue := *fieldValuePtr
  		fieldType := reflect.TypeOf(fieldValue.Interface())
  		if fieldType == nil {
  			continue
  		}

  		// requiredField forces the column into the update even when the
  		// field holds its zero value.
  		requiredField := useAllCols
  		// This declaration shadows the includeNil parameter.
  		includeNil := useAllCols

  		// An entry in mustColumnMap either forces the column in (true) or
  		// excludes it entirely (false).
  		if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  			if b {
  				requiredField = true
  			} else {
  				continue
  			}
  		}

  		// !evalphobia! set fieldValue as nil when column is nullable and zero-value
  		if b, ok := getFlagForColumn(nullableMap, col); ok {
  			if b && col.Nullable && isZero(fieldValue.Interface()) {
  				var nilValue *int
  				fieldValue = reflect.ValueOf(nilValue)
  				fieldType = reflect.TypeOf(fieldValue.Interface())
  				includeNil = true
  			}
  		}

  		var val interface{}

  		// Fields implementing core.Conversion (via pointer or value) supply
  		// their own database representation. If ToDB fails the error is
  		// logged and val stays nil, but the column is still appended.
  		if fieldValue.CanAddr() {
  			if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  				data, err := structConvert.ToDB()
  				if err != nil {
  					engine.logger.Error(err)
  				} else {
  					val = data
  				}
  				goto APPEND
  			}
  		}

  		if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  			data, err := structConvert.ToDB()
  			if err != nil {
  				engine.logger.Error(err)
  			} else {
  				val = data
  			}
  			goto APPEND
  		}

  		if fieldType.Kind() == reflect.Ptr {
  			if fieldValue.IsNil() {
  				// A nil pointer is written as NULL only when includeNil is
  				// set. Note this fragment is formatted "col=?" while the
  				// APPEND path below uses "col = ?".
  				if includeNil {
  					args = append(args, nil)
  					colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  				}
  				continue
  			} else if !fieldValue.IsValid() {
  				continue
  			} else {
  				// dereference ptr type to instance type
  				fieldValue = fieldValue.Elem()
  				fieldType = reflect.TypeOf(fieldValue.Interface())
  				requiredField = true
  			}
  		}

  		// Per-kind handling: zero values are skipped unless requiredField.
  		switch fieldType.Kind() {
  		case reflect.Bool:
  			if allUseBool || requiredField {
  				val = fieldValue.Interface()
  			} else {
  				// if a bool in a struct, it will not be as a condition because it default is false,
  				// please use Where() instead
  				continue
  			}
  		case reflect.String:
  			if !requiredField && fieldValue.String() == "" {
  				continue
  			}
  			// for MyString, should convert to string or panic
  			if fieldType.String() != reflect.String.String() {
  				val = fieldValue.String()
  			} else {
  				val = fieldValue.Interface()
  			}
  		case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  			if !requiredField && fieldValue.Int() == 0 {
  				continue
  			}
  			val = fieldValue.Interface()
  		case reflect.Float32, reflect.Float64:
  			if !requiredField && fieldValue.Float() == 0.0 {
  				continue
  			}
  			val = fieldValue.Interface()
  		case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  			if !requiredField && fieldValue.Uint() == 0 {
  				continue
  			}
  			// Unsigned values are converted to int64 and passed as *int64.
  			t := int64(fieldValue.Uint())
  			val = reflect.ValueOf(&t).Interface()
  		case reflect.Struct:
  			if fieldType.ConvertibleTo(core.TimeType) {
  				t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  				if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  					continue
  				}
  				val = engine.formatColTime(col, t)
  			} else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  				// The error returned by Value is discarded here.
  				val, _ = nulType.Value()
  			} else {
  				if !col.SQLType.IsJson() {
  					// A struct field mapped to a table with a single primary
  					// key contributes that key's value.
  					engine.autoMapType(fieldValue)
  					if table, ok := engine.Tables[fieldValue.Type()]; ok {
  						if len(table.PrimaryKeys) == 1 {
  							pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  							// fix non-int pk issues
  							// NOTE(review): as written the key is used only when
  							// requiredField is false and the key is non-zero;
  							// confirm the requiredField case is meant to skip.
  							if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  								val = pkField.Interface()
  							} else {
  								continue
  							}
  						} else {
  							//TODO: how to handler?
  							panic("not supported")
  						}
  					} else {
  						val = fieldValue.Interface()
  					}
  				} else {
  					// Blank struct could not be as update data
  					if requiredField || !isStructZero(fieldValue) {
  						bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
  						if err != nil {
  							panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  						}
  						// If the column is neither text nor blob, val stays nil
  						// and is still appended below.
  						if col.SQLType.IsText() {
  							val = string(bytes)
  						} else if col.SQLType.IsBlob() {
  							val = bytes
  						}
  					} else {
  						continue
  					}
  				}
  			}
  		case reflect.Array, reflect.Slice, reflect.Map:
  			if !requiredField {
  				// NOTE(review): this compares two reflect.Value structs with ==,
  				// which is not a comparison of the underlying values; the
  				// checks that follow do the effective emptiness test.
  				if fieldValue == reflect.Zero(fieldType) {
  					continue
  				}
  				if fieldType.Kind() == reflect.Array {
  					if isArrayValueZero(fieldValue) {
  						continue
  					}
  				} else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  					continue
  				}
  			}

  			// Text columns store the JSON encoding; blob columns store raw
  			// bytes for byte slices/arrays and JSON otherwise. Any other
  			// column type is skipped.
  			if col.SQLType.IsText() {
  				bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
  				if err != nil {
  					engine.logger.Error(err)
  					continue
  				}
  				val = string(bytes)
  			} else if col.SQLType.IsBlob() {
  				var bytes []byte
  				var err error
  				if fieldType.Kind() == reflect.Slice &&
  					fieldType.Elem().Kind() == reflect.Uint8 {
  					if fieldValue.Len() > 0 {
  						val = fieldValue.Bytes()
  					} else {
  						continue
  					}
  				} else if fieldType.Kind() == reflect.Array &&
  					fieldType.Elem().Kind() == reflect.Uint8 {
  					// NOTE(review): Slice(0, 0) yields a zero-length slice;
  					// confirm whether the whole array was intended.
  					val = fieldValue.Slice(0, 0).Interface()
  				} else {
  					bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
  					if err != nil {
  						engine.logger.Error(err)
  						continue
  					}
  					val = bytes
  				}
  			} else {
  				continue
  			}
  		default:
  			val = fieldValue.Interface()
  		}

  	APPEND:
  		args = append(args, val)
  		// On the "ql" dialect a primary-key column gets no fragment, although
  		// its value has already been appended to args above.
  		if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  			continue
  		}
  		colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  	}

  	return colNames, args
  }
  478. func (statement *Statement) needTableName() bool {
  479. return len(statement.JoinStr) > 0
  480. }
  481. func (statement *Statement) colName(col *core.Column, tableName string) string {
  482. if statement.needTableName() {
  483. var nm = tableName
  484. if len(statement.TableAlias) > 0 {
  485. nm = statement.TableAlias
  486. }
  487. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  488. }
  489. return statement.Engine.Quote(col.Name)
  490. }
  491. // TableName return current tableName
  492. func (statement *Statement) TableName() string {
  493. if statement.AltTableName != "" {
  494. return statement.AltTableName
  495. }
  496. return statement.tableName
  497. }
  498. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  499. func (statement *Statement) ID(id interface{}) *Statement {
  500. idValue := reflect.ValueOf(id)
  501. idType := reflect.TypeOf(idValue.Interface())
  502. switch idType {
  503. case ptrPkType:
  504. if pkPtr, ok := (id).(*core.PK); ok {
  505. statement.idParam = pkPtr
  506. return statement
  507. }
  508. case pkType:
  509. if pk, ok := (id).(core.PK); ok {
  510. statement.idParam = &pk
  511. return statement
  512. }
  513. }
  514. switch idType.Kind() {
  515. case reflect.String:
  516. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  517. return statement
  518. }
  519. statement.idParam = &core.PK{id}
  520. return statement
  521. }
  522. // Incr Generate "Update ... Set column = column + arg" statement
  523. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  524. k := strings.ToLower(column)
  525. if len(arg) > 0 {
  526. statement.incrColumns[k] = incrParam{column, arg[0]}
  527. } else {
  528. statement.incrColumns[k] = incrParam{column, 1}
  529. }
  530. return statement
  531. }
  532. // Decr Generate "Update ... Set column = column - arg" statement
  533. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  534. k := strings.ToLower(column)
  535. if len(arg) > 0 {
  536. statement.decrColumns[k] = decrParam{column, arg[0]}
  537. } else {
  538. statement.decrColumns[k] = decrParam{column, 1}
  539. }
  540. return statement
  541. }
  542. // SetExpr Generate "Update ... Set column = {expression}" statement
  543. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  544. k := strings.ToLower(column)
  545. statement.exprColumns[k] = exprParam{column, expression}
  546. return statement
  547. }
  // getInc returns the increment parameters registered through Incr, keyed by
  // lower-cased column name. The map itself is returned, not a copy.
  func (statement *Statement) getInc() map[string]incrParam {
  	return statement.incrColumns
  }
  // getDec returns the decrement parameters registered through Decr, keyed by
  // lower-cased column name. The map itself is returned, not a copy.
  func (statement *Statement) getDec() map[string]decrParam {
  	return statement.decrColumns
  }
  // getExpr returns the expression parameters registered through SetExpr, keyed
  // by lower-cased column name. The map itself is returned, not a copy.
  func (statement *Statement) getExpr() map[string]exprParam {
  	return statement.exprColumns
  }
  560. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  561. newColumns := make([]string, 0)
  562. for _, col := range columns {
  563. col = strings.Replace(col, "`", "", -1)
  564. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  565. ccols := strings.Split(col, ",")
  566. for _, c := range ccols {
  567. fields := strings.Split(strings.TrimSpace(c), ".")
  568. if len(fields) == 1 {
  569. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  570. } else if len(fields) == 2 {
  571. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  572. statement.Engine.quote(fields[1]))
  573. } else {
  574. panic(errors.New("unwanted colnames"))
  575. }
  576. }
  577. }
  578. return newColumns
  579. }
  580. func (statement *Statement) colmap2NewColsWithQuote() []string {
  581. newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
  582. copy(newColumns, statement.columnMap)
  583. for i := 0; i < len(statement.columnMap); i++ {
  584. newColumns[i] = statement.Engine.Quote(newColumns[i])
  585. }
  586. return newColumns
  587. }
  588. // Distinct generates "DISTINCT col1, col2 " statement
  589. func (statement *Statement) Distinct(columns ...string) *Statement {
  590. statement.IsDistinct = true
  591. statement.Cols(columns...)
  592. return statement
  593. }
  // ForUpdate sets the IsForUpdate flag, requesting a "SELECT ... FOR UPDATE"
  // statement, and returns the statement for chaining.
  func (statement *Statement) ForUpdate() *Statement {
  	statement.IsForUpdate = true
  	return statement
  }
  // Select stores str as the statement's select expression, replacing any
  // previous value, and returns the statement for chaining.
  func (statement *Statement) Select(str string) *Statement {
  	statement.selectStr = str
  	return statement
  }
  604. // Cols generate "col1, col2" statement
  605. func (statement *Statement) Cols(columns ...string) *Statement {
  606. cols := col2NewCols(columns...)
  607. for _, nc := range cols {
  608. statement.columnMap.add(nc)
  609. }
  610. newColumns := statement.colmap2NewColsWithQuote()
  611. statement.ColumnStr = strings.Join(newColumns, ", ")
  612. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  613. return statement
  614. }
  // AllCols is for updates only: it sets useAllCols so that buildUpdates treats
  // every column as required, including fields holding zero values.
  func (statement *Statement) AllCols() *Statement {
  	statement.useAllCols = true
  	return statement
  }
  620. // MustCols update use only: must update columns
  621. func (statement *Statement) MustCols(columns ...string) *Statement {
  622. newColumns := col2NewCols(columns...)
  623. for _, nc := range newColumns {
  624. statement.mustColumnMap[strings.ToLower(nc)] = true
  625. }
  626. return statement
  627. }
  628. // UseBool indicates that use bool fields as update contents and query contiditions
  629. func (statement *Statement) UseBool(columns ...string) *Statement {
  630. if len(columns) > 0 {
  631. statement.MustCols(columns...)
  632. } else {
  633. statement.allUseBool = true
  634. }
  635. return statement
  636. }
  637. // Omit do not use the columns
  638. func (statement *Statement) Omit(columns ...string) {
  639. newColumns := col2NewCols(columns...)
  640. for _, nc := range newColumns {
  641. statement.omitColumnMap = append(statement.omitColumnMap, nc)
  642. }
  643. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  644. }
  645. // Nullable Update use only: update columns to null when value is nullable and zero-value
  646. func (statement *Statement) Nullable(columns ...string) {
  647. newColumns := col2NewCols(columns...)
  648. for _, nc := range newColumns {
  649. statement.nullableMap[strings.ToLower(nc)] = true
  650. }
  651. }
  652. // Top generate LIMIT limit statement
  653. func (statement *Statement) Top(limit int) *Statement {
  654. statement.Limit(limit)
  655. return statement
  656. }
  657. // Limit generate LIMIT start, limit statement
  658. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  659. statement.LimitN = limit
  660. if len(start) > 0 {
  661. statement.Start = start[0]
  662. }
  663. return statement
  664. }
  665. // OrderBy generate "Order By order" statement
  666. func (statement *Statement) OrderBy(order string) *Statement {
  667. if len(statement.OrderStr) > 0 {
  668. statement.OrderStr += ", "
  669. }
  670. statement.OrderStr += order
  671. return statement
  672. }
  673. // Desc generate `ORDER BY xx DESC`
  674. func (statement *Statement) Desc(colNames ...string) *Statement {
  675. var buf builder.StringBuilder
  676. if len(statement.OrderStr) > 0 {
  677. fmt.Fprint(&buf, statement.OrderStr, ", ")
  678. }
  679. newColNames := statement.col2NewColsWithQuote(colNames...)
  680. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  681. statement.OrderStr = buf.String()
  682. return statement
  683. }
  684. // Asc provide asc order by query condition, the input parameters are columns.
  685. func (statement *Statement) Asc(colNames ...string) *Statement {
  686. var buf builder.StringBuilder
  687. if len(statement.OrderStr) > 0 {
  688. fmt.Fprint(&buf, statement.OrderStr, ", ")
  689. }
  690. newColNames := statement.col2NewColsWithQuote(colNames...)
  691. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  692. statement.OrderStr = buf.String()
  693. return statement
  694. }
  695. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  696. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  697. v := rValue(tableNameOrBean)
  698. t := v.Type()
  699. if t.Kind() == reflect.Struct {
  700. var err error
  701. statement.RefTable, err = statement.Engine.autoMapType(v)
  702. if err != nil {
  703. statement.Engine.logger.Error(err)
  704. return statement
  705. }
  706. }
  707. statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
  708. return statement
  709. }
  710. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  711. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  712. var buf builder.StringBuilder
  713. if len(statement.JoinStr) > 0 {
  714. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  715. } else {
  716. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  717. }
  718. switch tp := tablename.(type) {
  719. case builder.Builder:
  720. subSQL, subQueryArgs, err := tp.ToSQL()
  721. if err != nil {
  722. statement.lastError = err
  723. return statement
  724. }
  725. tbs := strings.Split(tp.TableName(), ".")
  726. var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
  727. fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
  728. statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
  729. case *builder.Builder:
  730. subSQL, subQueryArgs, err := tp.ToSQL()
  731. if err != nil {
  732. statement.lastError = err
  733. return statement
  734. }
  735. tbs := strings.Split(tp.TableName(), ".")
  736. var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
  737. fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
  738. statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
  739. default:
  740. tbName := statement.Engine.TableName(tablename, true)
  741. fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
  742. }
  743. statement.JoinStr = buf.String()
  744. statement.joinArgs = append(statement.joinArgs, args...)
  745. return statement
  746. }
  // GroupBy stores keys as the statement's GROUP BY expression, replacing any
  // previous value, and returns the statement for chaining.
  func (statement *Statement) GroupBy(keys string) *Statement {
  	statement.GroupByStr = keys
  	return statement
  }
  // Having stores the HAVING clause of the statement. HavingStr is saved with
  // the "HAVING " keyword already prepended to conditions; a later call
  // replaces the earlier value.
  func (statement *Statement) Having(conditions string) *Statement {
  	statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  	return statement
  }
  // Unscoped disables the special handling of columns tagged "deleted" by
  // setting the unscoped flag; buildUpdates then no longer skips such columns.
  func (statement *Statement) Unscoped() *Statement {
  	statement.unscoped = true
  	return statement
  }
  762. func (statement *Statement) genColumnStr() string {
  763. if statement.RefTable == nil {
  764. return ""
  765. }
  766. var buf builder.StringBuilder
  767. columns := statement.RefTable.Columns()
  768. for _, col := range columns {
  769. if statement.omitColumnMap.contain(col.Name) {
  770. continue
  771. }
  772. if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
  773. continue
  774. }
  775. if col.MapType == core.ONLYTODB {
  776. continue
  777. }
  778. if buf.Len() != 0 {
  779. buf.WriteString(", ")
  780. }
  781. if statement.JoinStr != "" {
  782. if statement.TableAlias != "" {
  783. buf.WriteString(statement.TableAlias)
  784. } else {
  785. buf.WriteString(statement.TableName())
  786. }
  787. buf.WriteString(".")
  788. }
  789. statement.Engine.QuoteTo(&buf, col.Name)
  790. }
  791. return buf.String()
  792. }
  // genCreateTableSQL asks the engine's dialect for the CREATE TABLE statement
  // of RefTable, passing the current table name together with the statement's
  // StoreEngine and Charset settings.
  func (statement *Statement) genCreateTableSQL() string {
  	return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  		statement.StoreEngine, statement.Charset)
  }
  797. func (statement *Statement) genIndexSQL() []string {
  798. var sqls []string
  799. tbName := statement.TableName()
  800. for _, index := range statement.RefTable.Indexes {
  801. if index.Type == core.IndexType {
  802. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  803. if sql != "" {
  804. sqls = append(sqls, sql)
  805. }
  806. }
  807. }
  808. return sqls
  809. }
  810. func uniqueName(tableName, uqeName string) string {
  811. return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  812. }
  813. func (statement *Statement) genUniqueSQL() []string {
  814. var sqls []string
  815. tbName := statement.TableName()
  816. for _, index := range statement.RefTable.Indexes {
  817. if index.Type == core.UniqueType {
  818. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  819. sqls = append(sqls, sql)
  820. }
  821. }
  822. return sqls
  823. }
  824. func (statement *Statement) genDelIndexSQL() []string {
  825. var sqls []string
  826. tbName := statement.TableName()
  827. idxPrefixName := strings.Replace(tbName, `"`, "", -1)
  828. idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
  829. for idxName, index := range statement.RefTable.Indexes {
  830. var rIdxName string
  831. if index.Type == core.UniqueType {
  832. rIdxName = uniqueName(idxPrefixName, idxName)
  833. } else if index.Type == core.IndexType {
  834. rIdxName = indexName(idxPrefixName, idxName)
  835. }
  836. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
  837. if statement.Engine.dialect.IndexOnTable() {
  838. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
  839. }
  840. sqls = append(sqls, sql)
  841. }
  842. return sqls
  843. }
  844. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  845. quote := statement.Engine.Quote
  846. sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
  847. col.String(statement.Engine.dialect))
  848. if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
  849. sql += " COMMENT '" + col.Comment + "'"
  850. }
  851. sql += ";"
  852. return sql, []interface{}{}
  853. }
  854. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  855. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  856. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  857. }
  858. func (statement *Statement) mergeConds(bean interface{}) error {
  859. if !statement.noAutoCondition {
  860. var addedTableName = (len(statement.JoinStr) > 0)
  861. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  862. if err != nil {
  863. return err
  864. }
  865. statement.cond = statement.cond.And(autoCond)
  866. }
  867. if err := statement.processIDParam(); err != nil {
  868. return err
  869. }
  870. return nil
  871. }
  872. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  873. if err := statement.mergeConds(bean); err != nil {
  874. return "", nil, err
  875. }
  876. return builder.ToSQL(statement.cond)
  877. }
  878. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  879. v := rValue(bean)
  880. isStruct := v.Kind() == reflect.Struct
  881. if isStruct {
  882. statement.setRefBean(bean)
  883. }
  884. var columnStr = statement.ColumnStr
  885. if len(statement.selectStr) > 0 {
  886. columnStr = statement.selectStr
  887. } else {
  888. // TODO: always generate column names, not use * even if join
  889. if len(statement.JoinStr) == 0 {
  890. if len(columnStr) == 0 {
  891. if len(statement.GroupByStr) > 0 {
  892. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  893. } else {
  894. columnStr = statement.genColumnStr()
  895. }
  896. }
  897. } else {
  898. if len(columnStr) == 0 {
  899. if len(statement.GroupByStr) > 0 {
  900. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  901. }
  902. }
  903. }
  904. }
  905. if len(columnStr) == 0 {
  906. columnStr = "*"
  907. }
  908. if isStruct {
  909. if err := statement.mergeConds(bean); err != nil {
  910. return "", nil, err
  911. }
  912. } else {
  913. if err := statement.processIDParam(); err != nil {
  914. return "", nil, err
  915. }
  916. }
  917. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  918. if err != nil {
  919. return "", nil, err
  920. }
  921. sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
  922. if err != nil {
  923. return "", nil, err
  924. }
  925. return sqlStr, append(statement.joinArgs, condArgs...), nil
  926. }
  927. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  928. var condSQL string
  929. var condArgs []interface{}
  930. var err error
  931. if len(beans) > 0 {
  932. statement.setRefBean(beans[0])
  933. condSQL, condArgs, err = statement.genConds(beans[0])
  934. } else {
  935. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  936. }
  937. if err != nil {
  938. return "", nil, err
  939. }
  940. var selectSQL = statement.selectStr
  941. if len(selectSQL) <= 0 {
  942. if statement.IsDistinct {
  943. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  944. } else {
  945. selectSQL = "count(*)"
  946. }
  947. }
  948. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
  949. if err != nil {
  950. return "", nil, err
  951. }
  952. return sqlStr, append(statement.joinArgs, condArgs...), nil
  953. }
  954. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  955. statement.setRefBean(bean)
  956. var sumStrs = make([]string, 0, len(columns))
  957. for _, colName := range columns {
  958. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  959. colName = statement.Engine.Quote(colName)
  960. }
  961. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  962. }
  963. sumSelect := strings.Join(sumStrs, ", ")
  964. condSQL, condArgs, err := statement.genConds(bean)
  965. if err != nil {
  966. return "", nil, err
  967. }
  968. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
  969. if err != nil {
  970. return "", nil, err
  971. }
  972. return sqlStr, append(statement.joinArgs, condArgs...), nil
  973. }
  974. func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
  975. var (
  976. distinct string
  977. dialect = statement.Engine.Dialect()
  978. quote = statement.Engine.Quote
  979. fromStr = " FROM "
  980. top, mssqlCondi, whereStr string
  981. )
  982. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  983. distinct = "DISTINCT "
  984. }
  985. if len(condSQL) > 0 {
  986. whereStr = " WHERE " + condSQL
  987. }
  988. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  989. fromStr += statement.TableName()
  990. } else {
  991. fromStr += quote(statement.TableName())
  992. }
  993. if statement.TableAlias != "" {
  994. if dialect.DBType() == core.ORACLE {
  995. fromStr += " " + quote(statement.TableAlias)
  996. } else {
  997. fromStr += " AS " + quote(statement.TableAlias)
  998. }
  999. }
  1000. if statement.JoinStr != "" {
  1001. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  1002. }
  1003. if dialect.DBType() == core.MSSQL {
  1004. if statement.LimitN > 0 {
  1005. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1006. }
  1007. if statement.Start > 0 {
  1008. var column string
  1009. if len(statement.RefTable.PKColumns()) == 0 {
  1010. for _, index := range statement.RefTable.Indexes {
  1011. if len(index.Cols) == 1 {
  1012. column = index.Cols[0]
  1013. break
  1014. }
  1015. }
  1016. if len(column) == 0 {
  1017. column = statement.RefTable.ColumnsSeq()[0]
  1018. }
  1019. } else {
  1020. column = statement.RefTable.PKColumns()[0].Name
  1021. }
  1022. if statement.needTableName() {
  1023. if len(statement.TableAlias) > 0 {
  1024. column = statement.TableAlias + "." + column
  1025. } else {
  1026. column = statement.TableName() + "." + column
  1027. }
  1028. }
  1029. var orderStr string
  1030. if needOrderBy && len(statement.OrderStr) > 0 {
  1031. orderStr = " ORDER BY " + statement.OrderStr
  1032. }
  1033. var groupStr string
  1034. if len(statement.GroupByStr) > 0 {
  1035. groupStr = " GROUP BY " + statement.GroupByStr
  1036. }
  1037. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  1038. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  1039. }
  1040. }
  1041. var buf builder.StringBuilder
  1042. fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  1043. if len(mssqlCondi) > 0 {
  1044. if len(whereStr) > 0 {
  1045. fmt.Fprint(&buf, " AND ", mssqlCondi)
  1046. } else {
  1047. fmt.Fprint(&buf, " WHERE ", mssqlCondi)
  1048. }
  1049. }
  1050. if statement.GroupByStr != "" {
  1051. fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
  1052. }
  1053. if statement.HavingStr != "" {
  1054. fmt.Fprint(&buf, " ", statement.HavingStr)
  1055. }
  1056. if needOrderBy && statement.OrderStr != "" {
  1057. fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
  1058. }
  1059. if needLimit {
  1060. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1061. if statement.Start > 0 {
  1062. fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
  1063. } else if statement.LimitN > 0 {
  1064. fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
  1065. }
  1066. } else if dialect.DBType() == core.ORACLE {
  1067. if statement.Start != 0 || statement.LimitN != 0 {
  1068. oldString := buf.String()
  1069. buf.Reset()
  1070. rawColStr := columnStr
  1071. if rawColStr == "*" {
  1072. rawColStr = "at.*"
  1073. }
  1074. fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
  1075. columnStr, rawColStr, oldString, statement.Start+statement.LimitN, statement.Start)
  1076. }
  1077. }
  1078. }
  1079. if statement.IsForUpdate {
  1080. return dialect.ForUpdateSql(buf.String()), nil
  1081. }
  1082. return buf.String(), nil
  1083. }
  1084. func (statement *Statement) processIDParam() error {
  1085. if statement.idParam == nil || statement.RefTable == nil {
  1086. return nil
  1087. }
  1088. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1089. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1090. len(statement.RefTable.PrimaryKeys),
  1091. len(*statement.idParam),
  1092. )
  1093. }
  1094. for i, col := range statement.RefTable.PKColumns() {
  1095. var colName = statement.colName(col, statement.TableName())
  1096. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1097. }
  1098. return nil
  1099. }
  1100. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1101. var colnames = make([]string, len(cols))
  1102. for i, col := range cols {
  1103. if includeTableName {
  1104. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1105. "." + statement.Engine.Quote(col.Name)
  1106. } else {
  1107. colnames[i] = statement.Engine.Quote(col.Name)
  1108. }
  1109. }
  1110. return strings.Join(colnames, ", ")
  1111. }
  1112. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1113. if statement.RefTable != nil {
  1114. cols := statement.RefTable.PKColumns()
  1115. if len(cols) == 0 {
  1116. return ""
  1117. }
  1118. colstrs := statement.joinColumns(cols, false)
  1119. sqls := splitNNoCase(sqlStr, " from ", 2)
  1120. if len(sqls) != 2 {
  1121. return ""
  1122. }
  1123. var top string
  1124. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1125. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1126. }
  1127. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1128. return newsql
  1129. }
  1130. return ""
  1131. }
  1132. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1133. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1134. return "", ""
  1135. }
  1136. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1137. sqls := splitNNoCase(sqlStr, "where", 2)
  1138. if len(sqls) != 2 {
  1139. if len(sqls) == 1 {
  1140. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1141. colstrs, statement.Engine.Quote(statement.TableName()))
  1142. }
  1143. return "", ""
  1144. }
  1145. var whereStr = sqls[1]
  1146. //TODO: for postgres only, if any other database?
  1147. var paraStr string
  1148. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1149. paraStr = "$"
  1150. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1151. paraStr = ":"
  1152. }
  1153. if paraStr != "" {
  1154. if strings.Contains(sqls[1], paraStr) {
  1155. dollers := strings.Split(sqls[1], paraStr)
  1156. whereStr = dollers[0]
  1157. for i, c := range dollers[1:] {
  1158. ccs := strings.SplitN(c, " ", 2)
  1159. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1160. }
  1161. }
  1162. }
  1163. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1164. colstrs, statement.Engine.Quote(statement.TableName()),
  1165. whereStr)
  1166. }