statement.go 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "bytes"
  7. "database/sql/driver"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "reflect"
  12. "strings"
  13. "time"
  14. "github.com/go-xorm/builder"
  15. "github.com/xormplus/core"
  16. )
  17. // Statement save all the sql info for executing SQL
  18. type Statement struct {
  19. RefTable *core.Table
  20. Engine *Engine
  21. Start int
  22. LimitN int
  23. idParam *core.PK
  24. OrderStr string
  25. JoinStr string
  26. joinArgs []interface{}
  27. GroupByStr string
  28. HavingStr string
  29. ColumnStr string
  30. selectStr string
  31. useAllCols bool
  32. OmitStr string
  33. AltTableName string
  34. tableName string
  35. RawSQL string
  36. RawParams []interface{}
  37. UseCascade bool
  38. UseAutoJoin bool
  39. StoreEngine string
  40. Charset string
  41. UseCache bool
  42. UseAutoTime bool
  43. noAutoCondition bool
  44. IsDistinct bool
  45. IsForUpdate bool
  46. TableAlias string
  47. allUseBool bool
  48. checkVersion bool
  49. unscoped bool
  50. columnMap columnMap
  51. omitColumnMap columnMap
  52. mustColumnMap map[string]bool
  53. nullableMap map[string]bool
  54. incrColumns map[string]incrParam
  55. decrColumns map[string]decrParam
  56. exprColumns map[string]exprParam
  57. cond builder.Cond
  58. bufferSize int
  59. }
  60. // Init reset all the statement's fields
  61. func (statement *Statement) Init() {
  62. statement.RefTable = nil
  63. statement.Start = 0
  64. statement.LimitN = 0
  65. statement.OrderStr = ""
  66. statement.UseCascade = true
  67. statement.JoinStr = ""
  68. statement.joinArgs = make([]interface{}, 0)
  69. statement.GroupByStr = ""
  70. statement.HavingStr = ""
  71. statement.ColumnStr = ""
  72. statement.OmitStr = ""
  73. statement.columnMap = columnMap{}
  74. statement.omitColumnMap = columnMap{}
  75. statement.AltTableName = ""
  76. statement.tableName = ""
  77. statement.idParam = nil
  78. statement.RawSQL = ""
  79. statement.RawParams = make([]interface{}, 0)
  80. statement.UseCache = true
  81. statement.UseAutoTime = true
  82. statement.noAutoCondition = false
  83. statement.IsDistinct = false
  84. statement.IsForUpdate = false
  85. statement.TableAlias = ""
  86. statement.selectStr = ""
  87. statement.allUseBool = false
  88. statement.useAllCols = false
  89. statement.mustColumnMap = make(map[string]bool)
  90. statement.nullableMap = make(map[string]bool)
  91. statement.checkVersion = true
  92. statement.unscoped = false
  93. statement.incrColumns = make(map[string]incrParam)
  94. statement.decrColumns = make(map[string]decrParam)
  95. statement.exprColumns = make(map[string]exprParam)
  96. statement.cond = builder.NewCond()
  97. statement.bufferSize = 0
  98. }
  99. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  100. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  101. statement.noAutoCondition = true
  102. if len(no) > 0 {
  103. statement.noAutoCondition = no[0]
  104. }
  105. return statement
  106. }
  107. // Alias set the table alias
  108. func (statement *Statement) Alias(alias string) *Statement {
  109. statement.TableAlias = alias
  110. return statement
  111. }
  112. // SQL adds raw sql statement
  113. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  114. switch query.(type) {
  115. case (*builder.Builder):
  116. var err error
  117. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  118. if err != nil {
  119. statement.Engine.logger.Error(err)
  120. }
  121. case string:
  122. statement.RawSQL = query.(string)
  123. statement.RawParams = args
  124. default:
  125. statement.Engine.logger.Error("unsupported sql type")
  126. }
  127. return statement
  128. }
  129. // Where add Where statement
  130. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  131. return statement.And(query, args...)
  132. }
  133. // And add Where & and statement
  134. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  135. switch query.(type) {
  136. case string:
  137. cond := builder.Expr(query.(string), args...)
  138. statement.cond = statement.cond.And(cond)
  139. case map[string]interface{}:
  140. cond := builder.Eq(query.(map[string]interface{}))
  141. statement.cond = statement.cond.And(cond)
  142. case builder.Cond:
  143. cond := query.(builder.Cond)
  144. statement.cond = statement.cond.And(cond)
  145. for _, v := range args {
  146. if vv, ok := v.(builder.Cond); ok {
  147. statement.cond = statement.cond.And(vv)
  148. }
  149. }
  150. default:
  151. // TODO: not support condition type
  152. }
  153. return statement
  154. }
  155. // Or add Where & Or statement
  156. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  157. switch query.(type) {
  158. case string:
  159. cond := builder.Expr(query.(string), args...)
  160. statement.cond = statement.cond.Or(cond)
  161. case map[string]interface{}:
  162. cond := builder.Eq(query.(map[string]interface{}))
  163. statement.cond = statement.cond.Or(cond)
  164. case builder.Cond:
  165. cond := query.(builder.Cond)
  166. statement.cond = statement.cond.Or(cond)
  167. for _, v := range args {
  168. if vv, ok := v.(builder.Cond); ok {
  169. statement.cond = statement.cond.Or(vv)
  170. }
  171. }
  172. default:
  173. // TODO: not support condition type
  174. }
  175. return statement
  176. }
  177. // In generate "Where column IN (?) " statement
  178. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  179. in := builder.In(statement.Engine.Quote(column), args...)
  180. statement.cond = statement.cond.And(in)
  181. return statement
  182. }
  183. // NotIn generate "Where column NOT IN (?) " statement
  184. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  185. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  186. statement.cond = statement.cond.And(notIn)
  187. return statement
  188. }
  189. func (statement *Statement) setRefValue(v reflect.Value) error {
  190. var err error
  191. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  192. if err != nil {
  193. return err
  194. }
  195. statement.tableName = statement.Engine.TableName(v, true)
  196. return nil
  197. }
  198. func (statement *Statement) setRefBean(bean interface{}) error {
  199. var err error
  200. statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
  201. if err != nil {
  202. return err
  203. }
  204. statement.tableName = statement.Engine.TableName(bean, true)
  205. return nil
  206. }
  207. // Auto generating update columnes and values according a struct
  208. func (statement *Statement) buildUpdates(bean interface{},
  209. includeVersion, includeUpdated, includeNil,
  210. includeAutoIncr, update bool) ([]string, []interface{}) {
  211. engine := statement.Engine
  212. table := statement.RefTable
  213. allUseBool := statement.allUseBool
  214. useAllCols := statement.useAllCols
  215. mustColumnMap := statement.mustColumnMap
  216. nullableMap := statement.nullableMap
  217. columnMap := statement.columnMap
  218. omitColumnMap := statement.omitColumnMap
  219. unscoped := statement.unscoped
  220. var colNames = make([]string, 0)
  221. var args = make([]interface{}, 0)
  222. for _, col := range table.Columns() {
  223. if !includeVersion && col.IsVersion {
  224. continue
  225. }
  226. if col.IsCreated {
  227. continue
  228. }
  229. if !includeUpdated && col.IsUpdated {
  230. continue
  231. }
  232. if !includeAutoIncr && col.IsAutoIncrement {
  233. continue
  234. }
  235. if col.IsDeleted && !unscoped {
  236. continue
  237. }
  238. if omitColumnMap.contain(col.Name) {
  239. continue
  240. }
  241. if len(columnMap) > 0 && !columnMap.contain(col.Name) {
  242. continue
  243. }
  244. if col.MapType == core.ONLYFROMDB {
  245. continue
  246. }
  247. fieldValuePtr, err := col.ValueOf(bean)
  248. if err != nil {
  249. engine.logger.Error(err)
  250. continue
  251. }
  252. fieldValue := *fieldValuePtr
  253. fieldType := reflect.TypeOf(fieldValue.Interface())
  254. if fieldType == nil {
  255. continue
  256. }
  257. requiredField := useAllCols
  258. includeNil := useAllCols
  259. if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  260. if b {
  261. requiredField = true
  262. } else {
  263. continue
  264. }
  265. }
  266. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  267. if b, ok := getFlagForColumn(nullableMap, col); ok {
  268. if b && col.Nullable && isZero(fieldValue.Interface()) {
  269. var nilValue *int
  270. fieldValue = reflect.ValueOf(nilValue)
  271. fieldType = reflect.TypeOf(fieldValue.Interface())
  272. includeNil = true
  273. }
  274. }
  275. var val interface{}
  276. if fieldValue.CanAddr() {
  277. if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  278. data, err := structConvert.ToDB()
  279. if err != nil {
  280. engine.logger.Error(err)
  281. } else {
  282. val = data
  283. }
  284. goto APPEND
  285. }
  286. }
  287. if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  288. data, err := structConvert.ToDB()
  289. if err != nil {
  290. engine.logger.Error(err)
  291. } else {
  292. val = data
  293. }
  294. goto APPEND
  295. }
  296. if fieldType.Kind() == reflect.Ptr {
  297. if fieldValue.IsNil() {
  298. if includeNil {
  299. args = append(args, nil)
  300. colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  301. }
  302. continue
  303. } else if !fieldValue.IsValid() {
  304. continue
  305. } else {
  306. // dereference ptr type to instance type
  307. fieldValue = fieldValue.Elem()
  308. fieldType = reflect.TypeOf(fieldValue.Interface())
  309. requiredField = true
  310. }
  311. }
  312. switch fieldType.Kind() {
  313. case reflect.Bool:
  314. if allUseBool || requiredField {
  315. val = fieldValue.Interface()
  316. } else {
  317. // if a bool in a struct, it will not be as a condition because it default is false,
  318. // please use Where() instead
  319. continue
  320. }
  321. case reflect.String:
  322. if !requiredField && fieldValue.String() == "" {
  323. continue
  324. }
  325. // for MyString, should convert to string or panic
  326. if fieldType.String() != reflect.String.String() {
  327. val = fieldValue.String()
  328. } else {
  329. val = fieldValue.Interface()
  330. }
  331. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  332. if !requiredField && fieldValue.Int() == 0 {
  333. continue
  334. }
  335. val = fieldValue.Interface()
  336. case reflect.Float32, reflect.Float64:
  337. if !requiredField && fieldValue.Float() == 0.0 {
  338. continue
  339. }
  340. val = fieldValue.Interface()
  341. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  342. if !requiredField && fieldValue.Uint() == 0 {
  343. continue
  344. }
  345. t := int64(fieldValue.Uint())
  346. val = reflect.ValueOf(&t).Interface()
  347. case reflect.Struct:
  348. if fieldType.ConvertibleTo(core.TimeType) {
  349. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  350. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  351. continue
  352. }
  353. val = engine.formatColTime(col, t)
  354. } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  355. val, _ = nulType.Value()
  356. } else {
  357. if !col.SQLType.IsJson() {
  358. engine.autoMapType(fieldValue)
  359. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  360. if len(table.PrimaryKeys) == 1 {
  361. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  362. // fix non-int pk issues
  363. if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  364. val = pkField.Interface()
  365. } else {
  366. continue
  367. }
  368. } else {
  369. //TODO: how to handler?
  370. panic("not supported")
  371. }
  372. } else {
  373. val = fieldValue.Interface()
  374. }
  375. } else {
  376. // Blank struct could not be as update data
  377. if requiredField || !isStructZero(fieldValue) {
  378. bytes, err := json.Marshal(fieldValue.Interface())
  379. if err != nil {
  380. panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  381. }
  382. if col.SQLType.IsText() {
  383. val = string(bytes)
  384. } else if col.SQLType.IsBlob() {
  385. val = bytes
  386. }
  387. } else {
  388. continue
  389. }
  390. }
  391. }
  392. case reflect.Array, reflect.Slice, reflect.Map:
  393. if !requiredField {
  394. if fieldValue == reflect.Zero(fieldType) {
  395. continue
  396. }
  397. if fieldType.Kind() == reflect.Array {
  398. if isArrayValueZero(fieldValue) {
  399. continue
  400. }
  401. } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  402. continue
  403. }
  404. }
  405. if col.SQLType.IsText() {
  406. bytes, err := json.Marshal(fieldValue.Interface())
  407. if err != nil {
  408. engine.logger.Error(err)
  409. continue
  410. }
  411. val = string(bytes)
  412. } else if col.SQLType.IsBlob() {
  413. var bytes []byte
  414. var err error
  415. if fieldType.Kind() == reflect.Slice &&
  416. fieldType.Elem().Kind() == reflect.Uint8 {
  417. if fieldValue.Len() > 0 {
  418. val = fieldValue.Bytes()
  419. } else {
  420. continue
  421. }
  422. } else if fieldType.Kind() == reflect.Array &&
  423. fieldType.Elem().Kind() == reflect.Uint8 {
  424. val = fieldValue.Slice(0, 0).Interface()
  425. } else {
  426. bytes, err = json.Marshal(fieldValue.Interface())
  427. if err != nil {
  428. engine.logger.Error(err)
  429. continue
  430. }
  431. val = bytes
  432. }
  433. } else {
  434. continue
  435. }
  436. default:
  437. val = fieldValue.Interface()
  438. }
  439. APPEND:
  440. args = append(args, val)
  441. if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  442. continue
  443. }
  444. colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  445. }
  446. return colNames, args
  447. }
  448. func (statement *Statement) needTableName() bool {
  449. return len(statement.JoinStr) > 0
  450. }
  451. func (statement *Statement) colName(col *core.Column, tableName string) string {
  452. if statement.needTableName() {
  453. var nm = tableName
  454. if len(statement.TableAlias) > 0 {
  455. nm = statement.TableAlias
  456. }
  457. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  458. }
  459. return statement.Engine.Quote(col.Name)
  460. }
  461. // TableName return current tableName
  462. func (statement *Statement) TableName() string {
  463. if statement.AltTableName != "" {
  464. return statement.AltTableName
  465. }
  466. return statement.tableName
  467. }
  468. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  469. func (statement *Statement) ID(id interface{}) *Statement {
  470. idValue := reflect.ValueOf(id)
  471. idType := reflect.TypeOf(idValue.Interface())
  472. switch idType {
  473. case ptrPkType:
  474. if pkPtr, ok := (id).(*core.PK); ok {
  475. statement.idParam = pkPtr
  476. return statement
  477. }
  478. case pkType:
  479. if pk, ok := (id).(core.PK); ok {
  480. statement.idParam = &pk
  481. return statement
  482. }
  483. }
  484. switch idType.Kind() {
  485. case reflect.String:
  486. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  487. return statement
  488. }
  489. statement.idParam = &core.PK{id}
  490. return statement
  491. }
  492. // Incr Generate "Update ... Set column = column + arg" statement
  493. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  494. k := strings.ToLower(column)
  495. if len(arg) > 0 {
  496. statement.incrColumns[k] = incrParam{column, arg[0]}
  497. } else {
  498. statement.incrColumns[k] = incrParam{column, 1}
  499. }
  500. return statement
  501. }
  502. // Decr Generate "Update ... Set column = column - arg" statement
  503. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  504. k := strings.ToLower(column)
  505. if len(arg) > 0 {
  506. statement.decrColumns[k] = decrParam{column, arg[0]}
  507. } else {
  508. statement.decrColumns[k] = decrParam{column, 1}
  509. }
  510. return statement
  511. }
  512. // SetExpr Generate "Update ... Set column = {expression}" statement
  513. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  514. k := strings.ToLower(column)
  515. statement.exprColumns[k] = exprParam{column, expression}
  516. return statement
  517. }
  518. // Generate "Update ... Set column = column + arg" statement
  519. func (statement *Statement) getInc() map[string]incrParam {
  520. return statement.incrColumns
  521. }
  522. // Generate "Update ... Set column = column - arg" statement
  523. func (statement *Statement) getDec() map[string]decrParam {
  524. return statement.decrColumns
  525. }
  526. // Generate "Update ... Set column = {expression}" statement
  527. func (statement *Statement) getExpr() map[string]exprParam {
  528. return statement.exprColumns
  529. }
  530. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  531. newColumns := make([]string, 0)
  532. for _, col := range columns {
  533. col = strings.Replace(col, "`", "", -1)
  534. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  535. ccols := strings.Split(col, ",")
  536. for _, c := range ccols {
  537. fields := strings.Split(strings.TrimSpace(c), ".")
  538. if len(fields) == 1 {
  539. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  540. } else if len(fields) == 2 {
  541. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  542. statement.Engine.quote(fields[1]))
  543. } else {
  544. panic(errors.New("unwanted colnames"))
  545. }
  546. }
  547. }
  548. return newColumns
  549. }
  550. func (statement *Statement) colmap2NewColsWithQuote() []string {
  551. newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
  552. copy(newColumns, statement.columnMap)
  553. for i := 0; i < len(statement.columnMap); i++ {
  554. newColumns[i] = statement.Engine.Quote(newColumns[i])
  555. }
  556. return newColumns
  557. }
  558. // Distinct generates "DISTINCT col1, col2 " statement
  559. func (statement *Statement) Distinct(columns ...string) *Statement {
  560. statement.IsDistinct = true
  561. statement.Cols(columns...)
  562. return statement
  563. }
  564. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  565. func (statement *Statement) ForUpdate() *Statement {
  566. statement.IsForUpdate = true
  567. return statement
  568. }
  569. // Select replace select
  570. func (statement *Statement) Select(str string) *Statement {
  571. statement.selectStr = str
  572. return statement
  573. }
  574. // Cols generate "col1, col2" statement
  575. func (statement *Statement) Cols(columns ...string) *Statement {
  576. cols := col2NewCols(columns...)
  577. for _, nc := range cols {
  578. statement.columnMap.add(nc)
  579. }
  580. newColumns := statement.colmap2NewColsWithQuote()
  581. statement.ColumnStr = strings.Join(newColumns, ", ")
  582. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  583. return statement
  584. }
  585. // AllCols update use only: update all columns
  586. func (statement *Statement) AllCols() *Statement {
  587. statement.useAllCols = true
  588. return statement
  589. }
  590. // MustCols update use only: must update columns
  591. func (statement *Statement) MustCols(columns ...string) *Statement {
  592. newColumns := col2NewCols(columns...)
  593. for _, nc := range newColumns {
  594. statement.mustColumnMap[strings.ToLower(nc)] = true
  595. }
  596. return statement
  597. }
  598. // UseBool indicates that use bool fields as update contents and query contiditions
  599. func (statement *Statement) UseBool(columns ...string) *Statement {
  600. if len(columns) > 0 {
  601. statement.MustCols(columns...)
  602. } else {
  603. statement.allUseBool = true
  604. }
  605. return statement
  606. }
  607. // Omit do not use the columns
  608. func (statement *Statement) Omit(columns ...string) {
  609. newColumns := col2NewCols(columns...)
  610. for _, nc := range newColumns {
  611. statement.omitColumnMap = append(statement.omitColumnMap, nc)
  612. }
  613. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  614. }
  615. // Nullable Update use only: update columns to null when value is nullable and zero-value
  616. func (statement *Statement) Nullable(columns ...string) {
  617. newColumns := col2NewCols(columns...)
  618. for _, nc := range newColumns {
  619. statement.nullableMap[strings.ToLower(nc)] = true
  620. }
  621. }
  622. // Top generate LIMIT limit statement
  623. func (statement *Statement) Top(limit int) *Statement {
  624. statement.Limit(limit)
  625. return statement
  626. }
  627. // Limit generate LIMIT start, limit statement
  628. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  629. statement.LimitN = limit
  630. if len(start) > 0 {
  631. statement.Start = start[0]
  632. }
  633. return statement
  634. }
  635. // OrderBy generate "Order By order" statement
  636. func (statement *Statement) OrderBy(order string) *Statement {
  637. if len(statement.OrderStr) > 0 {
  638. statement.OrderStr += ", "
  639. }
  640. statement.OrderStr += order
  641. return statement
  642. }
  643. // Desc generate `ORDER BY xx DESC`
  644. func (statement *Statement) Desc(colNames ...string) *Statement {
  645. var buf bytes.Buffer
  646. fmt.Fprintf(&buf, statement.OrderStr)
  647. if len(statement.OrderStr) > 0 {
  648. fmt.Fprint(&buf, ", ")
  649. }
  650. newColNames := statement.col2NewColsWithQuote(colNames...)
  651. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  652. statement.OrderStr = buf.String()
  653. return statement
  654. }
  655. // Asc provide asc order by query condition, the input parameters are columns.
  656. func (statement *Statement) Asc(colNames ...string) *Statement {
  657. var buf bytes.Buffer
  658. fmt.Fprintf(&buf, statement.OrderStr)
  659. if len(statement.OrderStr) > 0 {
  660. fmt.Fprint(&buf, ", ")
  661. }
  662. newColNames := statement.col2NewColsWithQuote(colNames...)
  663. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  664. statement.OrderStr = buf.String()
  665. return statement
  666. }
  667. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  668. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  669. v := rValue(tableNameOrBean)
  670. t := v.Type()
  671. if t.Kind() == reflect.Struct {
  672. var err error
  673. statement.RefTable, err = statement.Engine.autoMapType(v)
  674. if err != nil {
  675. statement.Engine.logger.Error(err)
  676. return statement
  677. }
  678. }
  679. statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
  680. return statement
  681. }
  682. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  683. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  684. var buf bytes.Buffer
  685. if len(statement.JoinStr) > 0 {
  686. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  687. } else {
  688. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  689. }
  690. tbName := statement.Engine.TableName(tablename, true)
  691. fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
  692. statement.JoinStr = buf.String()
  693. statement.joinArgs = append(statement.joinArgs, args...)
  694. return statement
  695. }
  696. // GroupBy generate "Group By keys" statement
  697. func (statement *Statement) GroupBy(keys string) *Statement {
  698. statement.GroupByStr = keys
  699. return statement
  700. }
  701. // Having generate "Having conditions" statement
  702. func (statement *Statement) Having(conditions string) *Statement {
  703. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  704. return statement
  705. }
  706. // Unscoped always disable struct tag "deleted"
  707. func (statement *Statement) Unscoped() *Statement {
  708. statement.unscoped = true
  709. return statement
  710. }
  711. func (statement *Statement) genColumnStr() string {
  712. var buf bytes.Buffer
  713. if statement.RefTable == nil {
  714. return ""
  715. }
  716. columns := statement.RefTable.Columns()
  717. for _, col := range columns {
  718. if statement.omitColumnMap.contain(col.Name) {
  719. continue
  720. }
  721. if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
  722. continue
  723. }
  724. if col.MapType == core.ONLYTODB {
  725. continue
  726. }
  727. if buf.Len() != 0 {
  728. buf.WriteString(", ")
  729. }
  730. if statement.JoinStr != "" {
  731. if statement.TableAlias != "" {
  732. buf.WriteString(statement.TableAlias)
  733. } else {
  734. buf.WriteString(statement.TableName())
  735. }
  736. buf.WriteString(".")
  737. }
  738. statement.Engine.QuoteTo(&buf, col.Name)
  739. }
  740. return buf.String()
  741. }
  742. func (statement *Statement) genCreateTableSQL() string {
  743. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  744. statement.StoreEngine, statement.Charset)
  745. }
  746. func (statement *Statement) genIndexSQL() []string {
  747. var sqls []string
  748. tbName := statement.TableName()
  749. for _, index := range statement.RefTable.Indexes {
  750. if index.Type == core.IndexType {
  751. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  752. if sql != "" {
  753. sqls = append(sqls, sql)
  754. }
  755. }
  756. }
  757. return sqls
  758. }
  // uniqueName returns the generated name of a unique index,
  // "UQE_<tableName>_<uqeName>".
  func uniqueName(tableName, uqeName string) (name string) {
  	name = fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  	return name
  }
  762. func (statement *Statement) genUniqueSQL() []string {
  763. var sqls []string
  764. tbName := statement.TableName()
  765. for _, index := range statement.RefTable.Indexes {
  766. if index.Type == core.UniqueType {
  767. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  768. sqls = append(sqls, sql)
  769. }
  770. }
  771. return sqls
  772. }
  773. func (statement *Statement) genDelIndexSQL() []string {
  774. var sqls []string
  775. tbName := statement.TableName()
  776. idxPrefixName := strings.Replace(tbName, `"`, "", -1)
  777. idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
  778. for idxName, index := range statement.RefTable.Indexes {
  779. var rIdxName string
  780. if index.Type == core.UniqueType {
  781. rIdxName = uniqueName(idxPrefixName, idxName)
  782. } else if index.Type == core.IndexType {
  783. rIdxName = indexName(idxPrefixName, idxName)
  784. }
  785. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
  786. if statement.Engine.dialect.IndexOnTable() {
  787. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
  788. }
  789. sqls = append(sqls, sql)
  790. }
  791. return sqls
  792. }
  793. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  794. quote := statement.Engine.Quote
  795. sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
  796. col.String(statement.Engine.dialect))
  797. if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
  798. sql += " COMMENT '" + col.Comment + "'"
  799. }
  800. sql += ";"
  801. return sql, []interface{}{}
  802. }
  803. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  804. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  805. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  806. }
  807. func (statement *Statement) mergeConds(bean interface{}) error {
  808. if !statement.noAutoCondition {
  809. var addedTableName = (len(statement.JoinStr) > 0)
  810. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  811. if err != nil {
  812. return err
  813. }
  814. statement.cond = statement.cond.And(autoCond)
  815. }
  816. if err := statement.processIDParam(); err != nil {
  817. return err
  818. }
  819. return nil
  820. }
  821. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  822. if err := statement.mergeConds(bean); err != nil {
  823. return "", nil, err
  824. }
  825. return builder.ToSQL(statement.cond)
  826. }
  827. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  828. v := rValue(bean)
  829. isStruct := v.Kind() == reflect.Struct
  830. if isStruct {
  831. statement.setRefBean(bean)
  832. }
  833. var columnStr = statement.ColumnStr
  834. if len(statement.selectStr) > 0 {
  835. columnStr = statement.selectStr
  836. } else {
  837. // TODO: always generate column names, not use * even if join
  838. if len(statement.JoinStr) == 0 {
  839. if len(columnStr) == 0 {
  840. if len(statement.GroupByStr) > 0 {
  841. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  842. } else {
  843. columnStr = statement.genColumnStr()
  844. }
  845. }
  846. } else {
  847. if len(columnStr) == 0 {
  848. if len(statement.GroupByStr) > 0 {
  849. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  850. }
  851. }
  852. }
  853. }
  854. if len(columnStr) == 0 {
  855. columnStr = "*"
  856. }
  857. if isStruct {
  858. if err := statement.mergeConds(bean); err != nil {
  859. return "", nil, err
  860. }
  861. } else {
  862. if err := statement.processIDParam(); err != nil {
  863. return "", nil, err
  864. }
  865. }
  866. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  867. if err != nil {
  868. return "", nil, err
  869. }
  870. sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
  871. if err != nil {
  872. return "", nil, err
  873. }
  874. return sqlStr, append(statement.joinArgs, condArgs...), nil
  875. }
  876. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  877. var condSQL string
  878. var condArgs []interface{}
  879. var err error
  880. if len(beans) > 0 {
  881. statement.setRefBean(beans[0])
  882. condSQL, condArgs, err = statement.genConds(beans[0])
  883. } else {
  884. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  885. }
  886. if err != nil {
  887. return "", nil, err
  888. }
  889. var selectSQL = statement.selectStr
  890. if len(selectSQL) <= 0 {
  891. if statement.IsDistinct {
  892. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  893. } else {
  894. selectSQL = "count(*)"
  895. }
  896. }
  897. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
  898. if err != nil {
  899. return "", nil, err
  900. }
  901. return sqlStr, append(statement.joinArgs, condArgs...), nil
  902. }
  903. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  904. statement.setRefBean(bean)
  905. var sumStrs = make([]string, 0, len(columns))
  906. for _, colName := range columns {
  907. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  908. colName = statement.Engine.Quote(colName)
  909. }
  910. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  911. }
  912. sumSelect := strings.Join(sumStrs, ", ")
  913. condSQL, condArgs, err := statement.genConds(bean)
  914. if err != nil {
  915. return "", nil, err
  916. }
  917. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
  918. if err != nil {
  919. return "", nil, err
  920. }
  921. return sqlStr, append(statement.joinArgs, condArgs...), nil
  922. }
  923. func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (a string, err error) {
  924. var distinct string
  925. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  926. distinct = "DISTINCT "
  927. }
  928. var dialect = statement.Engine.Dialect()
  929. var quote = statement.Engine.Quote
  930. var top string
  931. var mssqlCondi string
  932. var buf bytes.Buffer
  933. if len(condSQL) > 0 {
  934. fmt.Fprintf(&buf, " WHERE %v", condSQL)
  935. }
  936. var whereStr = buf.String()
  937. var fromStr = " FROM "
  938. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  939. fromStr += statement.TableName()
  940. } else {
  941. fromStr += quote(statement.TableName())
  942. }
  943. if statement.TableAlias != "" {
  944. if dialect.DBType() == core.ORACLE {
  945. fromStr += " " + quote(statement.TableAlias)
  946. } else {
  947. fromStr += " AS " + quote(statement.TableAlias)
  948. }
  949. }
  950. if statement.JoinStr != "" {
  951. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  952. }
  953. if dialect.DBType() == core.MSSQL {
  954. if statement.LimitN > 0 {
  955. top = fmt.Sprintf(" TOP %d ", statement.LimitN)
  956. }
  957. if statement.Start > 0 {
  958. var column string
  959. if len(statement.RefTable.PKColumns()) == 0 {
  960. for _, index := range statement.RefTable.Indexes {
  961. if len(index.Cols) == 1 {
  962. column = index.Cols[0]
  963. break
  964. }
  965. }
  966. if len(column) == 0 {
  967. column = statement.RefTable.ColumnsSeq()[0]
  968. }
  969. } else {
  970. column = statement.RefTable.PKColumns()[0].Name
  971. }
  972. if statement.needTableName() {
  973. if len(statement.TableAlias) > 0 {
  974. column = statement.TableAlias + "." + column
  975. } else {
  976. column = statement.TableName() + "." + column
  977. }
  978. }
  979. var orderStr string
  980. if needOrderBy && len(statement.OrderStr) > 0 {
  981. orderStr = " ORDER BY " + statement.OrderStr
  982. }
  983. var groupStr string
  984. if len(statement.GroupByStr) > 0 {
  985. groupStr = " GROUP BY " + statement.GroupByStr
  986. }
  987. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  988. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  989. }
  990. }
  991. // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
  992. a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  993. if len(mssqlCondi) > 0 {
  994. if len(whereStr) > 0 {
  995. a += " AND " + mssqlCondi
  996. } else {
  997. a += " WHERE " + mssqlCondi
  998. }
  999. }
  1000. if statement.GroupByStr != "" {
  1001. a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
  1002. }
  1003. if statement.HavingStr != "" {
  1004. a = fmt.Sprintf("%v %v", a, statement.HavingStr)
  1005. }
  1006. if needOrderBy && statement.OrderStr != "" {
  1007. a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
  1008. }
  1009. if needLimit {
  1010. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1011. if statement.Start > 0 {
  1012. a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
  1013. } else if statement.LimitN > 0 {
  1014. a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
  1015. }
  1016. } else if dialect.DBType() == core.ORACLE {
  1017. if statement.Start != 0 || statement.LimitN != 0 {
  1018. a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
  1019. }
  1020. }
  1021. }
  1022. if statement.IsForUpdate {
  1023. a = dialect.ForUpdateSql(a)
  1024. }
  1025. return
  1026. }
  1027. func (statement *Statement) processIDParam() error {
  1028. if statement.idParam == nil || statement.RefTable == nil {
  1029. return nil
  1030. }
  1031. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1032. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1033. len(statement.RefTable.PrimaryKeys),
  1034. len(*statement.idParam),
  1035. )
  1036. }
  1037. for i, col := range statement.RefTable.PKColumns() {
  1038. var colName = statement.colName(col, statement.TableName())
  1039. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1040. }
  1041. return nil
  1042. }
  1043. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1044. var colnames = make([]string, len(cols))
  1045. for i, col := range cols {
  1046. if includeTableName {
  1047. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1048. "." + statement.Engine.Quote(col.Name)
  1049. } else {
  1050. colnames[i] = statement.Engine.Quote(col.Name)
  1051. }
  1052. }
  1053. return strings.Join(colnames, ", ")
  1054. }
  1055. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1056. if statement.RefTable != nil {
  1057. cols := statement.RefTable.PKColumns()
  1058. if len(cols) == 0 {
  1059. return ""
  1060. }
  1061. colstrs := statement.joinColumns(cols, false)
  1062. sqls := splitNNoCase(sqlStr, " from ", 2)
  1063. if len(sqls) != 2 {
  1064. return ""
  1065. }
  1066. var top string
  1067. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1068. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1069. }
  1070. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1071. return newsql
  1072. }
  1073. return ""
  1074. }
  1075. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1076. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1077. return "", ""
  1078. }
  1079. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1080. sqls := splitNNoCase(sqlStr, "where", 2)
  1081. if len(sqls) != 2 {
  1082. if len(sqls) == 1 {
  1083. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1084. colstrs, statement.Engine.Quote(statement.TableName()))
  1085. }
  1086. return "", ""
  1087. }
  1088. var whereStr = sqls[1]
  1089. //TODO: for postgres only, if any other database?
  1090. var paraStr string
  1091. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1092. paraStr = "$"
  1093. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1094. paraStr = ":"
  1095. }
  1096. if paraStr != "" {
  1097. if strings.Contains(sqls[1], paraStr) {
  1098. dollers := strings.Split(sqls[1], paraStr)
  1099. whereStr = dollers[0]
  1100. for i, c := range dollers[1:] {
  1101. ccs := strings.SplitN(c, " ", 2)
  1102. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1103. }
  1104. }
  1105. }
  1106. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1107. colstrs, statement.Engine.Quote(statement.TableName()),
  1108. whereStr)
  1109. }