statement.go 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql/driver"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "reflect"
  11. "strings"
  12. "time"
  13. "github.com/go-xorm/builder"
  14. "github.com/xormplus/core"
  15. )
  16. // Statement save all the sql info for executing SQL
  17. type Statement struct {
  18. RefTable *core.Table
  19. Engine *Engine
  20. Start int
  21. LimitN int
  22. idParam *core.PK
  23. OrderStr string
  24. JoinStr string
  25. joinArgs []interface{}
  26. GroupByStr string
  27. HavingStr string
  28. ColumnStr string
  29. selectStr string
  30. useAllCols bool
  31. OmitStr string
  32. AltTableName string
  33. tableName string
  34. RawSQL string
  35. RawParams []interface{}
  36. UseCascade bool
  37. UseAutoJoin bool
  38. StoreEngine string
  39. Charset string
  40. UseCache bool
  41. UseAutoTime bool
  42. noAutoCondition bool
  43. IsDistinct bool
  44. IsForUpdate bool
  45. TableAlias string
  46. allUseBool bool
  47. checkVersion bool
  48. unscoped bool
  49. columnMap columnMap
  50. omitColumnMap columnMap
  51. mustColumnMap map[string]bool
  52. nullableMap map[string]bool
  53. incrColumns map[string]incrParam
  54. decrColumns map[string]decrParam
  55. exprColumns map[string]exprParam
  56. cond builder.Cond
  57. bufferSize int
  58. }
  59. // Init reset all the statement's fields
  60. func (statement *Statement) Init() {
  61. statement.RefTable = nil
  62. statement.Start = 0
  63. statement.LimitN = 0
  64. statement.OrderStr = ""
  65. statement.UseCascade = true
  66. statement.JoinStr = ""
  67. statement.joinArgs = make([]interface{}, 0)
  68. statement.GroupByStr = ""
  69. statement.HavingStr = ""
  70. statement.ColumnStr = ""
  71. statement.OmitStr = ""
  72. statement.columnMap = columnMap{}
  73. statement.omitColumnMap = columnMap{}
  74. statement.AltTableName = ""
  75. statement.tableName = ""
  76. statement.idParam = nil
  77. statement.RawSQL = ""
  78. statement.RawParams = make([]interface{}, 0)
  79. statement.UseCache = true
  80. statement.UseAutoTime = true
  81. statement.noAutoCondition = false
  82. statement.IsDistinct = false
  83. statement.IsForUpdate = false
  84. statement.TableAlias = ""
  85. statement.selectStr = ""
  86. statement.allUseBool = false
  87. statement.useAllCols = false
  88. statement.mustColumnMap = make(map[string]bool)
  89. statement.nullableMap = make(map[string]bool)
  90. statement.checkVersion = true
  91. statement.unscoped = false
  92. statement.incrColumns = make(map[string]incrParam)
  93. statement.decrColumns = make(map[string]decrParam)
  94. statement.exprColumns = make(map[string]exprParam)
  95. statement.cond = builder.NewCond()
  96. statement.bufferSize = 0
  97. }
  98. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  99. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  100. statement.noAutoCondition = true
  101. if len(no) > 0 {
  102. statement.noAutoCondition = no[0]
  103. }
  104. return statement
  105. }
  106. // Alias set the table alias
  107. func (statement *Statement) Alias(alias string) *Statement {
  108. statement.TableAlias = alias
  109. return statement
  110. }
  111. // SQL adds raw sql statement
  112. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  113. switch query.(type) {
  114. case (*builder.Builder):
  115. var err error
  116. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  117. if err != nil {
  118. statement.Engine.logger.Error(err)
  119. }
  120. case string:
  121. statement.RawSQL = query.(string)
  122. statement.RawParams = args
  123. default:
  124. statement.Engine.logger.Error("unsupported sql type")
  125. }
  126. return statement
  127. }
  128. // Where add Where statement
  129. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  130. return statement.And(query, args...)
  131. }
  132. // And add Where & and statement
  133. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  134. switch query.(type) {
  135. case string:
  136. cond := builder.Expr(query.(string), args...)
  137. statement.cond = statement.cond.And(cond)
  138. case map[string]interface{}:
  139. cond := builder.Eq(query.(map[string]interface{}))
  140. statement.cond = statement.cond.And(cond)
  141. case builder.Cond:
  142. cond := query.(builder.Cond)
  143. statement.cond = statement.cond.And(cond)
  144. for _, v := range args {
  145. if vv, ok := v.(builder.Cond); ok {
  146. statement.cond = statement.cond.And(vv)
  147. }
  148. }
  149. default:
  150. // TODO: not support condition type
  151. }
  152. return statement
  153. }
  154. // Or add Where & Or statement
  155. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  156. switch query.(type) {
  157. case string:
  158. cond := builder.Expr(query.(string), args...)
  159. statement.cond = statement.cond.Or(cond)
  160. case map[string]interface{}:
  161. cond := builder.Eq(query.(map[string]interface{}))
  162. statement.cond = statement.cond.Or(cond)
  163. case builder.Cond:
  164. cond := query.(builder.Cond)
  165. statement.cond = statement.cond.Or(cond)
  166. for _, v := range args {
  167. if vv, ok := v.(builder.Cond); ok {
  168. statement.cond = statement.cond.Or(vv)
  169. }
  170. }
  171. default:
  172. // TODO: not support condition type
  173. }
  174. return statement
  175. }
  176. // In generate "Where column IN (?) " statement
  177. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  178. in := builder.In(statement.Engine.Quote(column), args...)
  179. statement.cond = statement.cond.And(in)
  180. return statement
  181. }
  182. // NotIn generate "Where column NOT IN (?) " statement
  183. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  184. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  185. statement.cond = statement.cond.And(notIn)
  186. return statement
  187. }
  188. func (statement *Statement) setRefValue(v reflect.Value) error {
  189. var err error
  190. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  191. if err != nil {
  192. return err
  193. }
  194. statement.tableName = statement.Engine.TableName(v, true)
  195. return nil
  196. }
  197. func (statement *Statement) setRefBean(bean interface{}) error {
  198. var err error
  199. statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
  200. if err != nil {
  201. return err
  202. }
  203. statement.tableName = statement.Engine.TableName(bean, true)
  204. return nil
  205. }
  206. // Auto generating update columnes and values according a struct
  207. func (statement *Statement) buildUpdates(bean interface{},
  208. includeVersion, includeUpdated, includeNil,
  209. includeAutoIncr, update bool) ([]string, []interface{}) {
  210. engine := statement.Engine
  211. table := statement.RefTable
  212. allUseBool := statement.allUseBool
  213. useAllCols := statement.useAllCols
  214. mustColumnMap := statement.mustColumnMap
  215. nullableMap := statement.nullableMap
  216. columnMap := statement.columnMap
  217. omitColumnMap := statement.omitColumnMap
  218. unscoped := statement.unscoped
  219. var colNames = make([]string, 0)
  220. var args = make([]interface{}, 0)
  221. for _, col := range table.Columns() {
  222. if !includeVersion && col.IsVersion {
  223. continue
  224. }
  225. if col.IsCreated {
  226. continue
  227. }
  228. if !includeUpdated && col.IsUpdated {
  229. continue
  230. }
  231. if !includeAutoIncr && col.IsAutoIncrement {
  232. continue
  233. }
  234. if col.IsDeleted && !unscoped {
  235. continue
  236. }
  237. if omitColumnMap.contain(col.Name) {
  238. continue
  239. }
  240. if len(columnMap) > 0 && !columnMap.contain(col.Name) {
  241. continue
  242. }
  243. if col.MapType == core.ONLYFROMDB {
  244. continue
  245. }
  246. fieldValuePtr, err := col.ValueOf(bean)
  247. if err != nil {
  248. engine.logger.Error(err)
  249. continue
  250. }
  251. fieldValue := *fieldValuePtr
  252. fieldType := reflect.TypeOf(fieldValue.Interface())
  253. if fieldType == nil {
  254. continue
  255. }
  256. requiredField := useAllCols
  257. includeNil := useAllCols
  258. if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  259. if b {
  260. requiredField = true
  261. } else {
  262. continue
  263. }
  264. }
  265. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  266. if b, ok := getFlagForColumn(nullableMap, col); ok {
  267. if b && col.Nullable && isZero(fieldValue.Interface()) {
  268. var nilValue *int
  269. fieldValue = reflect.ValueOf(nilValue)
  270. fieldType = reflect.TypeOf(fieldValue.Interface())
  271. includeNil = true
  272. }
  273. }
  274. var val interface{}
  275. if fieldValue.CanAddr() {
  276. if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  277. data, err := structConvert.ToDB()
  278. if err != nil {
  279. engine.logger.Error(err)
  280. } else {
  281. val = data
  282. }
  283. goto APPEND
  284. }
  285. }
  286. if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  287. data, err := structConvert.ToDB()
  288. if err != nil {
  289. engine.logger.Error(err)
  290. } else {
  291. val = data
  292. }
  293. goto APPEND
  294. }
  295. if fieldType.Kind() == reflect.Ptr {
  296. if fieldValue.IsNil() {
  297. if includeNil {
  298. args = append(args, nil)
  299. colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  300. }
  301. continue
  302. } else if !fieldValue.IsValid() {
  303. continue
  304. } else {
  305. // dereference ptr type to instance type
  306. fieldValue = fieldValue.Elem()
  307. fieldType = reflect.TypeOf(fieldValue.Interface())
  308. requiredField = true
  309. }
  310. }
  311. switch fieldType.Kind() {
  312. case reflect.Bool:
  313. if allUseBool || requiredField {
  314. val = fieldValue.Interface()
  315. } else {
  316. // if a bool in a struct, it will not be as a condition because it default is false,
  317. // please use Where() instead
  318. continue
  319. }
  320. case reflect.String:
  321. if !requiredField && fieldValue.String() == "" {
  322. continue
  323. }
  324. // for MyString, should convert to string or panic
  325. if fieldType.String() != reflect.String.String() {
  326. val = fieldValue.String()
  327. } else {
  328. val = fieldValue.Interface()
  329. }
  330. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  331. if !requiredField && fieldValue.Int() == 0 {
  332. continue
  333. }
  334. val = fieldValue.Interface()
  335. case reflect.Float32, reflect.Float64:
  336. if !requiredField && fieldValue.Float() == 0.0 {
  337. continue
  338. }
  339. val = fieldValue.Interface()
  340. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  341. if !requiredField && fieldValue.Uint() == 0 {
  342. continue
  343. }
  344. t := int64(fieldValue.Uint())
  345. val = reflect.ValueOf(&t).Interface()
  346. case reflect.Struct:
  347. if fieldType.ConvertibleTo(core.TimeType) {
  348. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  349. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  350. continue
  351. }
  352. val = engine.formatColTime(col, t)
  353. } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  354. val, _ = nulType.Value()
  355. } else {
  356. if !col.SQLType.IsJson() {
  357. engine.autoMapType(fieldValue)
  358. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  359. if len(table.PrimaryKeys) == 1 {
  360. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  361. // fix non-int pk issues
  362. if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  363. val = pkField.Interface()
  364. } else {
  365. continue
  366. }
  367. } else {
  368. //TODO: how to handler?
  369. panic("not supported")
  370. }
  371. } else {
  372. val = fieldValue.Interface()
  373. }
  374. } else {
  375. // Blank struct could not be as update data
  376. if requiredField || !isStructZero(fieldValue) {
  377. bytes, err := json.Marshal(fieldValue.Interface())
  378. if err != nil {
  379. panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  380. }
  381. if col.SQLType.IsText() {
  382. val = string(bytes)
  383. } else if col.SQLType.IsBlob() {
  384. val = bytes
  385. }
  386. } else {
  387. continue
  388. }
  389. }
  390. }
  391. case reflect.Array, reflect.Slice, reflect.Map:
  392. if !requiredField {
  393. if fieldValue == reflect.Zero(fieldType) {
  394. continue
  395. }
  396. if fieldType.Kind() == reflect.Array {
  397. if isArrayValueZero(fieldValue) {
  398. continue
  399. }
  400. } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  401. continue
  402. }
  403. }
  404. if col.SQLType.IsText() {
  405. bytes, err := json.Marshal(fieldValue.Interface())
  406. if err != nil {
  407. engine.logger.Error(err)
  408. continue
  409. }
  410. val = string(bytes)
  411. } else if col.SQLType.IsBlob() {
  412. var bytes []byte
  413. var err error
  414. if fieldType.Kind() == reflect.Slice &&
  415. fieldType.Elem().Kind() == reflect.Uint8 {
  416. if fieldValue.Len() > 0 {
  417. val = fieldValue.Bytes()
  418. } else {
  419. continue
  420. }
  421. } else if fieldType.Kind() == reflect.Array &&
  422. fieldType.Elem().Kind() == reflect.Uint8 {
  423. val = fieldValue.Slice(0, 0).Interface()
  424. } else {
  425. bytes, err = json.Marshal(fieldValue.Interface())
  426. if err != nil {
  427. engine.logger.Error(err)
  428. continue
  429. }
  430. val = bytes
  431. }
  432. } else {
  433. continue
  434. }
  435. default:
  436. val = fieldValue.Interface()
  437. }
  438. APPEND:
  439. args = append(args, val)
  440. if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  441. continue
  442. }
  443. colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  444. }
  445. return colNames, args
  446. }
  447. func (statement *Statement) needTableName() bool {
  448. return len(statement.JoinStr) > 0
  449. }
  450. func (statement *Statement) colName(col *core.Column, tableName string) string {
  451. if statement.needTableName() {
  452. var nm = tableName
  453. if len(statement.TableAlias) > 0 {
  454. nm = statement.TableAlias
  455. }
  456. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  457. }
  458. return statement.Engine.Quote(col.Name)
  459. }
  460. // TableName return current tableName
  461. func (statement *Statement) TableName() string {
  462. if statement.AltTableName != "" {
  463. return statement.AltTableName
  464. }
  465. return statement.tableName
  466. }
  467. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  468. func (statement *Statement) ID(id interface{}) *Statement {
  469. idValue := reflect.ValueOf(id)
  470. idType := reflect.TypeOf(idValue.Interface())
  471. switch idType {
  472. case ptrPkType:
  473. if pkPtr, ok := (id).(*core.PK); ok {
  474. statement.idParam = pkPtr
  475. return statement
  476. }
  477. case pkType:
  478. if pk, ok := (id).(core.PK); ok {
  479. statement.idParam = &pk
  480. return statement
  481. }
  482. }
  483. switch idType.Kind() {
  484. case reflect.String:
  485. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  486. return statement
  487. }
  488. statement.idParam = &core.PK{id}
  489. return statement
  490. }
  491. // Incr Generate "Update ... Set column = column + arg" statement
  492. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  493. k := strings.ToLower(column)
  494. if len(arg) > 0 {
  495. statement.incrColumns[k] = incrParam{column, arg[0]}
  496. } else {
  497. statement.incrColumns[k] = incrParam{column, 1}
  498. }
  499. return statement
  500. }
  501. // Decr Generate "Update ... Set column = column - arg" statement
  502. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  503. k := strings.ToLower(column)
  504. if len(arg) > 0 {
  505. statement.decrColumns[k] = decrParam{column, arg[0]}
  506. } else {
  507. statement.decrColumns[k] = decrParam{column, 1}
  508. }
  509. return statement
  510. }
  511. // SetExpr Generate "Update ... Set column = {expression}" statement
  512. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  513. k := strings.ToLower(column)
  514. statement.exprColumns[k] = exprParam{column, expression}
  515. return statement
  516. }
  517. // Generate "Update ... Set column = column + arg" statement
  518. func (statement *Statement) getInc() map[string]incrParam {
  519. return statement.incrColumns
  520. }
  521. // Generate "Update ... Set column = column - arg" statement
  522. func (statement *Statement) getDec() map[string]decrParam {
  523. return statement.decrColumns
  524. }
  525. // Generate "Update ... Set column = {expression}" statement
  526. func (statement *Statement) getExpr() map[string]exprParam {
  527. return statement.exprColumns
  528. }
  529. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  530. newColumns := make([]string, 0)
  531. for _, col := range columns {
  532. col = strings.Replace(col, "`", "", -1)
  533. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  534. ccols := strings.Split(col, ",")
  535. for _, c := range ccols {
  536. fields := strings.Split(strings.TrimSpace(c), ".")
  537. if len(fields) == 1 {
  538. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  539. } else if len(fields) == 2 {
  540. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  541. statement.Engine.quote(fields[1]))
  542. } else {
  543. panic(errors.New("unwanted colnames"))
  544. }
  545. }
  546. }
  547. return newColumns
  548. }
  549. func (statement *Statement) colmap2NewColsWithQuote() []string {
  550. newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
  551. copy(newColumns, statement.columnMap)
  552. for i := 0; i < len(statement.columnMap); i++ {
  553. newColumns[i] = statement.Engine.Quote(newColumns[i])
  554. }
  555. return newColumns
  556. }
  557. // Distinct generates "DISTINCT col1, col2 " statement
  558. func (statement *Statement) Distinct(columns ...string) *Statement {
  559. statement.IsDistinct = true
  560. statement.Cols(columns...)
  561. return statement
  562. }
  563. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  564. func (statement *Statement) ForUpdate() *Statement {
  565. statement.IsForUpdate = true
  566. return statement
  567. }
  568. // Select replace select
  569. func (statement *Statement) Select(str string) *Statement {
  570. statement.selectStr = str
  571. return statement
  572. }
  573. // Cols generate "col1, col2" statement
  574. func (statement *Statement) Cols(columns ...string) *Statement {
  575. cols := col2NewCols(columns...)
  576. for _, nc := range cols {
  577. statement.columnMap.add(nc)
  578. }
  579. newColumns := statement.colmap2NewColsWithQuote()
  580. statement.ColumnStr = strings.Join(newColumns, ", ")
  581. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  582. return statement
  583. }
  584. // AllCols update use only: update all columns
  585. func (statement *Statement) AllCols() *Statement {
  586. statement.useAllCols = true
  587. return statement
  588. }
  589. // MustCols update use only: must update columns
  590. func (statement *Statement) MustCols(columns ...string) *Statement {
  591. newColumns := col2NewCols(columns...)
  592. for _, nc := range newColumns {
  593. statement.mustColumnMap[strings.ToLower(nc)] = true
  594. }
  595. return statement
  596. }
  597. // UseBool indicates that use bool fields as update contents and query contiditions
  598. func (statement *Statement) UseBool(columns ...string) *Statement {
  599. if len(columns) > 0 {
  600. statement.MustCols(columns...)
  601. } else {
  602. statement.allUseBool = true
  603. }
  604. return statement
  605. }
  606. // Omit do not use the columns
  607. func (statement *Statement) Omit(columns ...string) {
  608. newColumns := col2NewCols(columns...)
  609. for _, nc := range newColumns {
  610. statement.omitColumnMap = append(statement.omitColumnMap, nc)
  611. }
  612. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  613. }
  614. // Nullable Update use only: update columns to null when value is nullable and zero-value
  615. func (statement *Statement) Nullable(columns ...string) {
  616. newColumns := col2NewCols(columns...)
  617. for _, nc := range newColumns {
  618. statement.nullableMap[strings.ToLower(nc)] = true
  619. }
  620. }
  621. // Top generate LIMIT limit statement
  622. func (statement *Statement) Top(limit int) *Statement {
  623. statement.Limit(limit)
  624. return statement
  625. }
  626. // Limit generate LIMIT start, limit statement
  627. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  628. statement.LimitN = limit
  629. if len(start) > 0 {
  630. statement.Start = start[0]
  631. }
  632. return statement
  633. }
  634. // OrderBy generate "Order By order" statement
  635. func (statement *Statement) OrderBy(order string) *Statement {
  636. if len(statement.OrderStr) > 0 {
  637. statement.OrderStr += ", "
  638. }
  639. statement.OrderStr += order
  640. return statement
  641. }
  642. // Desc generate `ORDER BY xx DESC`
  643. func (statement *Statement) Desc(colNames ...string) *Statement {
  644. var buf builder.StringBuilder
  645. if len(statement.OrderStr) > 0 {
  646. fmt.Fprint(&buf, statement.OrderStr, ", ")
  647. }
  648. newColNames := statement.col2NewColsWithQuote(colNames...)
  649. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  650. statement.OrderStr = buf.String()
  651. return statement
  652. }
  653. // Asc provide asc order by query condition, the input parameters are columns.
  654. func (statement *Statement) Asc(colNames ...string) *Statement {
  655. var buf builder.StringBuilder
  656. if len(statement.OrderStr) > 0 {
  657. fmt.Fprint(&buf, statement.OrderStr, ", ")
  658. }
  659. newColNames := statement.col2NewColsWithQuote(colNames...)
  660. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  661. statement.OrderStr = buf.String()
  662. return statement
  663. }
  664. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  665. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  666. v := rValue(tableNameOrBean)
  667. t := v.Type()
  668. if t.Kind() == reflect.Struct {
  669. var err error
  670. statement.RefTable, err = statement.Engine.autoMapType(v)
  671. if err != nil {
  672. statement.Engine.logger.Error(err)
  673. return statement
  674. }
  675. }
  676. statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
  677. return statement
  678. }
  679. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  680. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  681. var buf builder.StringBuilder
  682. if len(statement.JoinStr) > 0 {
  683. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  684. } else {
  685. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  686. }
  687. tbName := statement.Engine.TableName(tablename, true)
  688. fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
  689. statement.JoinStr = buf.String()
  690. statement.joinArgs = append(statement.joinArgs, args...)
  691. return statement
  692. }
  693. // GroupBy generate "Group By keys" statement
  694. func (statement *Statement) GroupBy(keys string) *Statement {
  695. statement.GroupByStr = keys
  696. return statement
  697. }
  698. // Having generate "Having conditions" statement
  699. func (statement *Statement) Having(conditions string) *Statement {
  700. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  701. return statement
  702. }
  703. // Unscoped always disable struct tag "deleted"
  704. func (statement *Statement) Unscoped() *Statement {
  705. statement.unscoped = true
  706. return statement
  707. }
  708. func (statement *Statement) genColumnStr() string {
  709. if statement.RefTable == nil {
  710. return ""
  711. }
  712. var buf builder.StringBuilder
  713. columns := statement.RefTable.Columns()
  714. for _, col := range columns {
  715. if statement.omitColumnMap.contain(col.Name) {
  716. continue
  717. }
  718. if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
  719. continue
  720. }
  721. if col.MapType == core.ONLYTODB {
  722. continue
  723. }
  724. if buf.Len() != 0 {
  725. buf.WriteString(", ")
  726. }
  727. if statement.JoinStr != "" {
  728. if statement.TableAlias != "" {
  729. buf.WriteString(statement.TableAlias)
  730. } else {
  731. buf.WriteString(statement.TableName())
  732. }
  733. buf.WriteString(".")
  734. }
  735. statement.Engine.QuoteTo(&buf, col.Name)
  736. }
  737. return buf.String()
  738. }
  739. func (statement *Statement) genCreateTableSQL() string {
  740. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  741. statement.StoreEngine, statement.Charset)
  742. }
  743. func (statement *Statement) genIndexSQL() []string {
  744. var sqls []string
  745. tbName := statement.TableName()
  746. for _, index := range statement.RefTable.Indexes {
  747. if index.Type == core.IndexType {
  748. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  749. if sql != "" {
  750. sqls = append(sqls, sql)
  751. }
  752. }
  753. }
  754. return sqls
  755. }
  756. func uniqueName(tableName, uqeName string) string {
  757. return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  758. }
  759. func (statement *Statement) genUniqueSQL() []string {
  760. var sqls []string
  761. tbName := statement.TableName()
  762. for _, index := range statement.RefTable.Indexes {
  763. if index.Type == core.UniqueType {
  764. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  765. sqls = append(sqls, sql)
  766. }
  767. }
  768. return sqls
  769. }
  770. func (statement *Statement) genDelIndexSQL() []string {
  771. var sqls []string
  772. tbName := statement.TableName()
  773. idxPrefixName := strings.Replace(tbName, `"`, "", -1)
  774. idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
  775. for idxName, index := range statement.RefTable.Indexes {
  776. var rIdxName string
  777. if index.Type == core.UniqueType {
  778. rIdxName = uniqueName(idxPrefixName, idxName)
  779. } else if index.Type == core.IndexType {
  780. rIdxName = indexName(idxPrefixName, idxName)
  781. }
  782. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
  783. if statement.Engine.dialect.IndexOnTable() {
  784. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
  785. }
  786. sqls = append(sqls, sql)
  787. }
  788. return sqls
  789. }
  790. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  791. quote := statement.Engine.Quote
  792. sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
  793. col.String(statement.Engine.dialect))
  794. if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
  795. sql += " COMMENT '" + col.Comment + "'"
  796. }
  797. sql += ";"
  798. return sql, []interface{}{}
  799. }
  800. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  801. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  802. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  803. }
  804. func (statement *Statement) mergeConds(bean interface{}) error {
  805. if !statement.noAutoCondition {
  806. var addedTableName = (len(statement.JoinStr) > 0)
  807. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  808. if err != nil {
  809. return err
  810. }
  811. statement.cond = statement.cond.And(autoCond)
  812. }
  813. if err := statement.processIDParam(); err != nil {
  814. return err
  815. }
  816. return nil
  817. }
  818. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  819. if err := statement.mergeConds(bean); err != nil {
  820. return "", nil, err
  821. }
  822. return builder.ToSQL(statement.cond)
  823. }
  824. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  825. v := rValue(bean)
  826. isStruct := v.Kind() == reflect.Struct
  827. if isStruct {
  828. statement.setRefBean(bean)
  829. }
  830. var columnStr = statement.ColumnStr
  831. if len(statement.selectStr) > 0 {
  832. columnStr = statement.selectStr
  833. } else {
  834. // TODO: always generate column names, not use * even if join
  835. if len(statement.JoinStr) == 0 {
  836. if len(columnStr) == 0 {
  837. if len(statement.GroupByStr) > 0 {
  838. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  839. } else {
  840. columnStr = statement.genColumnStr()
  841. }
  842. }
  843. } else {
  844. if len(columnStr) == 0 {
  845. if len(statement.GroupByStr) > 0 {
  846. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  847. }
  848. }
  849. }
  850. }
  851. if len(columnStr) == 0 {
  852. columnStr = "*"
  853. }
  854. if isStruct {
  855. if err := statement.mergeConds(bean); err != nil {
  856. return "", nil, err
  857. }
  858. } else {
  859. if err := statement.processIDParam(); err != nil {
  860. return "", nil, err
  861. }
  862. }
  863. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  864. if err != nil {
  865. return "", nil, err
  866. }
  867. sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
  868. if err != nil {
  869. return "", nil, err
  870. }
  871. return sqlStr, append(statement.joinArgs, condArgs...), nil
  872. }
  873. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  874. var condSQL string
  875. var condArgs []interface{}
  876. var err error
  877. if len(beans) > 0 {
  878. statement.setRefBean(beans[0])
  879. condSQL, condArgs, err = statement.genConds(beans[0])
  880. } else {
  881. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  882. }
  883. if err != nil {
  884. return "", nil, err
  885. }
  886. var selectSQL = statement.selectStr
  887. if len(selectSQL) <= 0 {
  888. if statement.IsDistinct {
  889. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  890. } else {
  891. selectSQL = "count(*)"
  892. }
  893. }
  894. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
  895. if err != nil {
  896. return "", nil, err
  897. }
  898. return sqlStr, append(statement.joinArgs, condArgs...), nil
  899. }
  900. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  901. statement.setRefBean(bean)
  902. var sumStrs = make([]string, 0, len(columns))
  903. for _, colName := range columns {
  904. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  905. colName = statement.Engine.Quote(colName)
  906. }
  907. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  908. }
  909. sumSelect := strings.Join(sumStrs, ", ")
  910. condSQL, condArgs, err := statement.genConds(bean)
  911. if err != nil {
  912. return "", nil, err
  913. }
  914. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
  915. if err != nil {
  916. return "", nil, err
  917. }
  918. return sqlStr, append(statement.joinArgs, condArgs...), nil
  919. }
  920. func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
  921. var (
  922. distinct string
  923. dialect = statement.Engine.Dialect()
  924. quote = statement.Engine.Quote
  925. fromStr = " FROM "
  926. top, mssqlCondi, whereStr string
  927. )
  928. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  929. distinct = "DISTINCT "
  930. }
  931. if len(condSQL) > 0 {
  932. whereStr = " WHERE " + condSQL
  933. }
  934. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  935. fromStr += statement.TableName()
  936. } else {
  937. fromStr += quote(statement.TableName())
  938. }
  939. if statement.TableAlias != "" {
  940. if dialect.DBType() == core.ORACLE {
  941. fromStr += " " + quote(statement.TableAlias)
  942. } else {
  943. fromStr += " AS " + quote(statement.TableAlias)
  944. }
  945. }
  946. if statement.JoinStr != "" {
  947. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  948. }
  949. if dialect.DBType() == core.MSSQL {
  950. if statement.LimitN > 0 {
  951. top = fmt.Sprintf(" TOP %d ", statement.LimitN)
  952. }
  953. if statement.Start > 0 {
  954. var column string
  955. if len(statement.RefTable.PKColumns()) == 0 {
  956. for _, index := range statement.RefTable.Indexes {
  957. if len(index.Cols) == 1 {
  958. column = index.Cols[0]
  959. break
  960. }
  961. }
  962. if len(column) == 0 {
  963. column = statement.RefTable.ColumnsSeq()[0]
  964. }
  965. } else {
  966. column = statement.RefTable.PKColumns()[0].Name
  967. }
  968. if statement.needTableName() {
  969. if len(statement.TableAlias) > 0 {
  970. column = statement.TableAlias + "." + column
  971. } else {
  972. column = statement.TableName() + "." + column
  973. }
  974. }
  975. var orderStr string
  976. if needOrderBy && len(statement.OrderStr) > 0 {
  977. orderStr = " ORDER BY " + statement.OrderStr
  978. }
  979. var groupStr string
  980. if len(statement.GroupByStr) > 0 {
  981. groupStr = " GROUP BY " + statement.GroupByStr
  982. }
  983. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  984. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  985. }
  986. }
  987. var buf builder.StringBuilder
  988. fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  989. if len(mssqlCondi) > 0 {
  990. if len(whereStr) > 0 {
  991. fmt.Fprint(&buf, " AND ", mssqlCondi)
  992. } else {
  993. fmt.Fprint(&buf, " WHERE ", mssqlCondi)
  994. }
  995. }
  996. if statement.GroupByStr != "" {
  997. fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
  998. }
  999. if statement.HavingStr != "" {
  1000. fmt.Fprint(&buf, " ", statement.HavingStr)
  1001. }
  1002. if needOrderBy && statement.OrderStr != "" {
  1003. fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
  1004. }
  1005. if needLimit {
  1006. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1007. if statement.Start > 0 {
  1008. fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
  1009. } else if statement.LimitN > 0 {
  1010. fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
  1011. }
  1012. } else if dialect.DBType() == core.ORACLE {
  1013. if statement.Start != 0 || statement.LimitN != 0 {
  1014. oldString := buf.String()
  1015. buf.Reset()
  1016. fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
  1017. columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start)
  1018. }
  1019. }
  1020. }
  1021. if statement.IsForUpdate {
  1022. return dialect.ForUpdateSql(buf.String()), nil
  1023. }
  1024. return buf.String(), nil
  1025. }
  1026. func (statement *Statement) processIDParam() error {
  1027. if statement.idParam == nil || statement.RefTable == nil {
  1028. return nil
  1029. }
  1030. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1031. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1032. len(statement.RefTable.PrimaryKeys),
  1033. len(*statement.idParam),
  1034. )
  1035. }
  1036. for i, col := range statement.RefTable.PKColumns() {
  1037. var colName = statement.colName(col, statement.TableName())
  1038. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1039. }
  1040. return nil
  1041. }
  1042. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1043. var colnames = make([]string, len(cols))
  1044. for i, col := range cols {
  1045. if includeTableName {
  1046. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1047. "." + statement.Engine.Quote(col.Name)
  1048. } else {
  1049. colnames[i] = statement.Engine.Quote(col.Name)
  1050. }
  1051. }
  1052. return strings.Join(colnames, ", ")
  1053. }
  1054. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1055. if statement.RefTable != nil {
  1056. cols := statement.RefTable.PKColumns()
  1057. if len(cols) == 0 {
  1058. return ""
  1059. }
  1060. colstrs := statement.joinColumns(cols, false)
  1061. sqls := splitNNoCase(sqlStr, " from ", 2)
  1062. if len(sqls) != 2 {
  1063. return ""
  1064. }
  1065. var top string
  1066. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1067. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1068. }
  1069. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1070. return newsql
  1071. }
  1072. return ""
  1073. }
  1074. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1075. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1076. return "", ""
  1077. }
  1078. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1079. sqls := splitNNoCase(sqlStr, "where", 2)
  1080. if len(sqls) != 2 {
  1081. if len(sqls) == 1 {
  1082. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1083. colstrs, statement.Engine.Quote(statement.TableName()))
  1084. }
  1085. return "", ""
  1086. }
  1087. var whereStr = sqls[1]
  1088. //TODO: for postgres only, if any other database?
  1089. var paraStr string
  1090. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1091. paraStr = "$"
  1092. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1093. paraStr = ":"
  1094. }
  1095. if paraStr != "" {
  1096. if strings.Contains(sqls[1], paraStr) {
  1097. dollers := strings.Split(sqls[1], paraStr)
  1098. whereStr = dollers[0]
  1099. for i, c := range dollers[1:] {
  1100. ccs := strings.SplitN(c, " ", 2)
  1101. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1102. }
  1103. }
  1104. }
  1105. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1106. colstrs, statement.Engine.Quote(statement.TableName()),
  1107. whereStr)
  1108. }