statement.go 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql/driver"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "reflect"
  11. "strings"
  12. "time"
  13. "github.com/xormplus/builder"
  14. "github.com/xormplus/core"
  15. )
  16. // Statement save all the sql info for executing SQL
  17. type Statement struct {
  18. RefTable *core.Table
  19. Engine *Engine
  20. Start int
  21. LimitN int
  22. idParam *core.PK
  23. OrderStr string
  24. JoinStr string
  25. joinArgs []interface{}
  26. GroupByStr string
  27. HavingStr string
  28. ColumnStr string
  29. selectStr string
  30. useAllCols bool
  31. OmitStr string
  32. AltTableName string
  33. tableName string
  34. RawSQL string
  35. RawParams []interface{}
  36. UseCascade bool
  37. UseAutoJoin bool
  38. StoreEngine string
  39. Charset string
  40. UseCache bool
  41. UseAutoTime bool
  42. noAutoCondition bool
  43. IsDistinct bool
  44. IsForUpdate bool
  45. TableAlias string
  46. allUseBool bool
  47. checkVersion bool
  48. unscoped bool
  49. columnMap columnMap
  50. omitColumnMap columnMap
  51. mustColumnMap map[string]bool
  52. nullableMap map[string]bool
  53. incrColumns map[string]incrParam
  54. decrColumns map[string]decrParam
  55. exprColumns map[string]exprParam
  56. cond builder.Cond
  57. bufferSize int
  58. context ContextCache
  59. lastError error
  60. }
  61. // Init reset all the statement's fields
  62. func (statement *Statement) Init() {
  63. statement.RefTable = nil
  64. statement.Start = 0
  65. statement.LimitN = 0
  66. statement.OrderStr = ""
  67. statement.UseCascade = true
  68. statement.JoinStr = ""
  69. statement.joinArgs = make([]interface{}, 0)
  70. statement.GroupByStr = ""
  71. statement.HavingStr = ""
  72. statement.ColumnStr = ""
  73. statement.OmitStr = ""
  74. statement.columnMap = columnMap{}
  75. statement.omitColumnMap = columnMap{}
  76. statement.AltTableName = ""
  77. statement.tableName = ""
  78. statement.idParam = nil
  79. statement.RawSQL = ""
  80. statement.RawParams = make([]interface{}, 0)
  81. statement.UseCache = true
  82. statement.UseAutoTime = true
  83. statement.noAutoCondition = false
  84. statement.IsDistinct = false
  85. statement.IsForUpdate = false
  86. statement.TableAlias = ""
  87. statement.selectStr = ""
  88. statement.allUseBool = false
  89. statement.useAllCols = false
  90. statement.mustColumnMap = make(map[string]bool)
  91. statement.nullableMap = make(map[string]bool)
  92. statement.checkVersion = true
  93. statement.unscoped = false
  94. statement.incrColumns = make(map[string]incrParam)
  95. statement.decrColumns = make(map[string]decrParam)
  96. statement.exprColumns = make(map[string]exprParam)
  97. statement.cond = builder.NewCond()
  98. statement.bufferSize = 0
  99. statement.context = nil
  100. statement.lastError = nil
  101. }
  102. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  103. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  104. statement.noAutoCondition = true
  105. if len(no) > 0 {
  106. statement.noAutoCondition = no[0]
  107. }
  108. return statement
  109. }
  110. // Alias set the table alias
  111. func (statement *Statement) Alias(alias string) *Statement {
  112. statement.TableAlias = alias
  113. return statement
  114. }
  115. // SQL adds raw sql statement
  116. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  117. switch query.(type) {
  118. case (*builder.Builder):
  119. var err error
  120. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  121. if err != nil {
  122. statement.lastError = err
  123. }
  124. case string:
  125. statement.RawSQL = query.(string)
  126. statement.RawParams = args
  127. default:
  128. statement.lastError = ErrUnSupportedSQLType
  129. }
  130. return statement
  131. }
  132. // Where add Where statement
  133. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  134. return statement.And(query, args...)
  135. }
  136. // And add Where & and statement
  137. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  138. switch query.(type) {
  139. case string:
  140. cond := builder.Expr(query.(string), args...)
  141. statement.cond = statement.cond.And(cond)
  142. case map[string]interface{}:
  143. cond := builder.Eq(query.(map[string]interface{}))
  144. statement.cond = statement.cond.And(cond)
  145. case builder.Cond:
  146. cond := query.(builder.Cond)
  147. statement.cond = statement.cond.And(cond)
  148. for _, v := range args {
  149. if vv, ok := v.(builder.Cond); ok {
  150. statement.cond = statement.cond.And(vv)
  151. }
  152. }
  153. default:
  154. statement.lastError = ErrConditionType
  155. }
  156. return statement
  157. }
  158. // Or add Where & Or statement
  159. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  160. switch query.(type) {
  161. case string:
  162. cond := builder.Expr(query.(string), args...)
  163. statement.cond = statement.cond.Or(cond)
  164. case map[string]interface{}:
  165. cond := builder.Eq(query.(map[string]interface{}))
  166. statement.cond = statement.cond.Or(cond)
  167. case builder.Cond:
  168. cond := query.(builder.Cond)
  169. statement.cond = statement.cond.Or(cond)
  170. for _, v := range args {
  171. if vv, ok := v.(builder.Cond); ok {
  172. statement.cond = statement.cond.Or(vv)
  173. }
  174. }
  175. default:
  176. // TODO: not support condition type
  177. }
  178. return statement
  179. }
  180. // In generate "Where column IN (?) " statement
  181. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  182. in := builder.In(statement.Engine.Quote(column), args...)
  183. statement.cond = statement.cond.And(in)
  184. return statement
  185. }
  186. // NotIn generate "Where column NOT IN (?) " statement
  187. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  188. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  189. statement.cond = statement.cond.And(notIn)
  190. return statement
  191. }
  192. func (statement *Statement) setRefValue(v reflect.Value) error {
  193. var err error
  194. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  195. if err != nil {
  196. return err
  197. }
  198. statement.tableName = statement.Engine.TableName(v, true)
  199. return nil
  200. }
  201. func (statement *Statement) setRefBean(bean interface{}) error {
  202. var err error
  203. statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
  204. if err != nil {
  205. return err
  206. }
  207. statement.tableName = statement.Engine.TableName(bean, true)
  208. return nil
  209. }
  210. // Auto generating update columnes and values according a struct
  211. func (statement *Statement) buildUpdates(bean interface{},
  212. includeVersion, includeUpdated, includeNil,
  213. includeAutoIncr, update bool) ([]string, []interface{}) {
  214. engine := statement.Engine
  215. table := statement.RefTable
  216. allUseBool := statement.allUseBool
  217. useAllCols := statement.useAllCols
  218. mustColumnMap := statement.mustColumnMap
  219. nullableMap := statement.nullableMap
  220. columnMap := statement.columnMap
  221. omitColumnMap := statement.omitColumnMap
  222. unscoped := statement.unscoped
  223. var colNames = make([]string, 0)
  224. var args = make([]interface{}, 0)
  225. for _, col := range table.Columns() {
  226. if !includeVersion && col.IsVersion {
  227. continue
  228. }
  229. if col.IsCreated {
  230. continue
  231. }
  232. if !includeUpdated && col.IsUpdated {
  233. continue
  234. }
  235. if !includeAutoIncr && col.IsAutoIncrement {
  236. continue
  237. }
  238. if col.IsDeleted && !unscoped {
  239. continue
  240. }
  241. if omitColumnMap.contain(col.Name) {
  242. continue
  243. }
  244. if len(columnMap) > 0 && !columnMap.contain(col.Name) {
  245. continue
  246. }
  247. if col.MapType == core.ONLYFROMDB {
  248. continue
  249. }
  250. fieldValuePtr, err := col.ValueOf(bean)
  251. if err != nil {
  252. engine.logger.Error(err)
  253. continue
  254. }
  255. fieldValue := *fieldValuePtr
  256. fieldType := reflect.TypeOf(fieldValue.Interface())
  257. if fieldType == nil {
  258. continue
  259. }
  260. requiredField := useAllCols
  261. includeNil := useAllCols
  262. if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  263. if b {
  264. requiredField = true
  265. } else {
  266. continue
  267. }
  268. }
  269. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  270. if b, ok := getFlagForColumn(nullableMap, col); ok {
  271. if b && col.Nullable && isZero(fieldValue.Interface()) {
  272. var nilValue *int
  273. fieldValue = reflect.ValueOf(nilValue)
  274. fieldType = reflect.TypeOf(fieldValue.Interface())
  275. includeNil = true
  276. }
  277. }
  278. var val interface{}
  279. if fieldValue.CanAddr() {
  280. if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  281. data, err := structConvert.ToDB()
  282. if err != nil {
  283. engine.logger.Error(err)
  284. } else {
  285. val = data
  286. }
  287. goto APPEND
  288. }
  289. }
  290. if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  291. data, err := structConvert.ToDB()
  292. if err != nil {
  293. engine.logger.Error(err)
  294. } else {
  295. val = data
  296. }
  297. goto APPEND
  298. }
  299. if fieldType.Kind() == reflect.Ptr {
  300. if fieldValue.IsNil() {
  301. if includeNil {
  302. args = append(args, nil)
  303. colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  304. }
  305. continue
  306. } else if !fieldValue.IsValid() {
  307. continue
  308. } else {
  309. // dereference ptr type to instance type
  310. fieldValue = fieldValue.Elem()
  311. fieldType = reflect.TypeOf(fieldValue.Interface())
  312. requiredField = true
  313. }
  314. }
  315. switch fieldType.Kind() {
  316. case reflect.Bool:
  317. if allUseBool || requiredField {
  318. val = fieldValue.Interface()
  319. } else {
  320. // if a bool in a struct, it will not be as a condition because it default is false,
  321. // please use Where() instead
  322. continue
  323. }
  324. case reflect.String:
  325. if !requiredField && fieldValue.String() == "" {
  326. continue
  327. }
  328. // for MyString, should convert to string or panic
  329. if fieldType.String() != reflect.String.String() {
  330. val = fieldValue.String()
  331. } else {
  332. val = fieldValue.Interface()
  333. }
  334. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  335. if !requiredField && fieldValue.Int() == 0 {
  336. continue
  337. }
  338. val = fieldValue.Interface()
  339. case reflect.Float32, reflect.Float64:
  340. if !requiredField && fieldValue.Float() == 0.0 {
  341. continue
  342. }
  343. val = fieldValue.Interface()
  344. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  345. if !requiredField && fieldValue.Uint() == 0 {
  346. continue
  347. }
  348. t := int64(fieldValue.Uint())
  349. val = reflect.ValueOf(&t).Interface()
  350. case reflect.Struct:
  351. if fieldType.ConvertibleTo(core.TimeType) {
  352. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  353. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  354. continue
  355. }
  356. val = engine.formatColTime(col, t)
  357. } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  358. val, _ = nulType.Value()
  359. } else {
  360. if !col.SQLType.IsJson() {
  361. engine.autoMapType(fieldValue)
  362. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  363. if len(table.PrimaryKeys) == 1 {
  364. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  365. // fix non-int pk issues
  366. if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  367. val = pkField.Interface()
  368. } else {
  369. continue
  370. }
  371. } else {
  372. //TODO: how to handler?
  373. panic("not supported")
  374. }
  375. } else {
  376. val = fieldValue.Interface()
  377. }
  378. } else {
  379. // Blank struct could not be as update data
  380. if requiredField || !isStructZero(fieldValue) {
  381. bytes, err := json.Marshal(fieldValue.Interface())
  382. if err != nil {
  383. panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  384. }
  385. if col.SQLType.IsText() {
  386. val = string(bytes)
  387. } else if col.SQLType.IsBlob() {
  388. val = bytes
  389. }
  390. } else {
  391. continue
  392. }
  393. }
  394. }
  395. case reflect.Array, reflect.Slice, reflect.Map:
  396. if !requiredField {
  397. if fieldValue == reflect.Zero(fieldType) {
  398. continue
  399. }
  400. if fieldType.Kind() == reflect.Array {
  401. if isArrayValueZero(fieldValue) {
  402. continue
  403. }
  404. } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  405. continue
  406. }
  407. }
  408. if col.SQLType.IsText() {
  409. bytes, err := json.Marshal(fieldValue.Interface())
  410. if err != nil {
  411. engine.logger.Error(err)
  412. continue
  413. }
  414. val = string(bytes)
  415. } else if col.SQLType.IsBlob() {
  416. var bytes []byte
  417. var err error
  418. if fieldType.Kind() == reflect.Slice &&
  419. fieldType.Elem().Kind() == reflect.Uint8 {
  420. if fieldValue.Len() > 0 {
  421. val = fieldValue.Bytes()
  422. } else {
  423. continue
  424. }
  425. } else if fieldType.Kind() == reflect.Array &&
  426. fieldType.Elem().Kind() == reflect.Uint8 {
  427. val = fieldValue.Slice(0, 0).Interface()
  428. } else {
  429. bytes, err = json.Marshal(fieldValue.Interface())
  430. if err != nil {
  431. engine.logger.Error(err)
  432. continue
  433. }
  434. val = bytes
  435. }
  436. } else {
  437. continue
  438. }
  439. default:
  440. val = fieldValue.Interface()
  441. }
  442. APPEND:
  443. args = append(args, val)
  444. if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  445. continue
  446. }
  447. colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  448. }
  449. return colNames, args
  450. }
  451. func (statement *Statement) needTableName() bool {
  452. return len(statement.JoinStr) > 0
  453. }
  454. func (statement *Statement) colName(col *core.Column, tableName string) string {
  455. if statement.needTableName() {
  456. var nm = tableName
  457. if len(statement.TableAlias) > 0 {
  458. nm = statement.TableAlias
  459. }
  460. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  461. }
  462. return statement.Engine.Quote(col.Name)
  463. }
  464. // TableName return current tableName
  465. func (statement *Statement) TableName() string {
  466. if statement.AltTableName != "" {
  467. return statement.AltTableName
  468. }
  469. return statement.tableName
  470. }
  471. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  472. func (statement *Statement) ID(id interface{}) *Statement {
  473. idValue := reflect.ValueOf(id)
  474. idType := reflect.TypeOf(idValue.Interface())
  475. switch idType {
  476. case ptrPkType:
  477. if pkPtr, ok := (id).(*core.PK); ok {
  478. statement.idParam = pkPtr
  479. return statement
  480. }
  481. case pkType:
  482. if pk, ok := (id).(core.PK); ok {
  483. statement.idParam = &pk
  484. return statement
  485. }
  486. }
  487. switch idType.Kind() {
  488. case reflect.String:
  489. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  490. return statement
  491. }
  492. statement.idParam = &core.PK{id}
  493. return statement
  494. }
  495. // Incr Generate "Update ... Set column = column + arg" statement
  496. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  497. k := strings.ToLower(column)
  498. if len(arg) > 0 {
  499. statement.incrColumns[k] = incrParam{column, arg[0]}
  500. } else {
  501. statement.incrColumns[k] = incrParam{column, 1}
  502. }
  503. return statement
  504. }
  505. // Decr Generate "Update ... Set column = column - arg" statement
  506. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  507. k := strings.ToLower(column)
  508. if len(arg) > 0 {
  509. statement.decrColumns[k] = decrParam{column, arg[0]}
  510. } else {
  511. statement.decrColumns[k] = decrParam{column, 1}
  512. }
  513. return statement
  514. }
  515. // SetExpr Generate "Update ... Set column = {expression}" statement
  516. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  517. k := strings.ToLower(column)
  518. statement.exprColumns[k] = exprParam{column, expression}
  519. return statement
  520. }
  521. // Generate "Update ... Set column = column + arg" statement
  522. func (statement *Statement) getInc() map[string]incrParam {
  523. return statement.incrColumns
  524. }
  525. // Generate "Update ... Set column = column - arg" statement
  526. func (statement *Statement) getDec() map[string]decrParam {
  527. return statement.decrColumns
  528. }
  529. // Generate "Update ... Set column = {expression}" statement
  530. func (statement *Statement) getExpr() map[string]exprParam {
  531. return statement.exprColumns
  532. }
  533. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  534. newColumns := make([]string, 0)
  535. for _, col := range columns {
  536. col = strings.Replace(col, "`", "", -1)
  537. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  538. ccols := strings.Split(col, ",")
  539. for _, c := range ccols {
  540. fields := strings.Split(strings.TrimSpace(c), ".")
  541. if len(fields) == 1 {
  542. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  543. } else if len(fields) == 2 {
  544. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  545. statement.Engine.quote(fields[1]))
  546. } else {
  547. panic(errors.New("unwanted colnames"))
  548. }
  549. }
  550. }
  551. return newColumns
  552. }
  553. func (statement *Statement) colmap2NewColsWithQuote() []string {
  554. newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
  555. copy(newColumns, statement.columnMap)
  556. for i := 0; i < len(statement.columnMap); i++ {
  557. newColumns[i] = statement.Engine.Quote(newColumns[i])
  558. }
  559. return newColumns
  560. }
  561. // Distinct generates "DISTINCT col1, col2 " statement
  562. func (statement *Statement) Distinct(columns ...string) *Statement {
  563. statement.IsDistinct = true
  564. statement.Cols(columns...)
  565. return statement
  566. }
  567. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  568. func (statement *Statement) ForUpdate() *Statement {
  569. statement.IsForUpdate = true
  570. return statement
  571. }
  572. // Select replace select
  573. func (statement *Statement) Select(str string) *Statement {
  574. statement.selectStr = str
  575. return statement
  576. }
  577. // Cols generate "col1, col2" statement
  578. func (statement *Statement) Cols(columns ...string) *Statement {
  579. cols := col2NewCols(columns...)
  580. for _, nc := range cols {
  581. statement.columnMap.add(nc)
  582. }
  583. newColumns := statement.colmap2NewColsWithQuote()
  584. statement.ColumnStr = strings.Join(newColumns, ", ")
  585. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  586. return statement
  587. }
  588. // AllCols update use only: update all columns
  589. func (statement *Statement) AllCols() *Statement {
  590. statement.useAllCols = true
  591. return statement
  592. }
  593. // MustCols update use only: must update columns
  594. func (statement *Statement) MustCols(columns ...string) *Statement {
  595. newColumns := col2NewCols(columns...)
  596. for _, nc := range newColumns {
  597. statement.mustColumnMap[strings.ToLower(nc)] = true
  598. }
  599. return statement
  600. }
  601. // UseBool indicates that use bool fields as update contents and query contiditions
  602. func (statement *Statement) UseBool(columns ...string) *Statement {
  603. if len(columns) > 0 {
  604. statement.MustCols(columns...)
  605. } else {
  606. statement.allUseBool = true
  607. }
  608. return statement
  609. }
  610. // Omit do not use the columns
  611. func (statement *Statement) Omit(columns ...string) {
  612. newColumns := col2NewCols(columns...)
  613. for _, nc := range newColumns {
  614. statement.omitColumnMap = append(statement.omitColumnMap, nc)
  615. }
  616. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  617. }
  618. // Nullable Update use only: update columns to null when value is nullable and zero-value
  619. func (statement *Statement) Nullable(columns ...string) {
  620. newColumns := col2NewCols(columns...)
  621. for _, nc := range newColumns {
  622. statement.nullableMap[strings.ToLower(nc)] = true
  623. }
  624. }
  625. // Top generate LIMIT limit statement
  626. func (statement *Statement) Top(limit int) *Statement {
  627. statement.Limit(limit)
  628. return statement
  629. }
  630. // Limit generate LIMIT start, limit statement
  631. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  632. statement.LimitN = limit
  633. if len(start) > 0 {
  634. statement.Start = start[0]
  635. }
  636. return statement
  637. }
  638. // OrderBy generate "Order By order" statement
  639. func (statement *Statement) OrderBy(order string) *Statement {
  640. if len(statement.OrderStr) > 0 {
  641. statement.OrderStr += ", "
  642. }
  643. statement.OrderStr += order
  644. return statement
  645. }
  646. // Desc generate `ORDER BY xx DESC`
  647. func (statement *Statement) Desc(colNames ...string) *Statement {
  648. var buf builder.StringBuilder
  649. if len(statement.OrderStr) > 0 {
  650. fmt.Fprint(&buf, statement.OrderStr, ", ")
  651. }
  652. newColNames := statement.col2NewColsWithQuote(colNames...)
  653. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  654. statement.OrderStr = buf.String()
  655. return statement
  656. }
  657. // Asc provide asc order by query condition, the input parameters are columns.
  658. func (statement *Statement) Asc(colNames ...string) *Statement {
  659. var buf builder.StringBuilder
  660. if len(statement.OrderStr) > 0 {
  661. fmt.Fprint(&buf, statement.OrderStr, ", ")
  662. }
  663. newColNames := statement.col2NewColsWithQuote(colNames...)
  664. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  665. statement.OrderStr = buf.String()
  666. return statement
  667. }
  668. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  669. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  670. v := rValue(tableNameOrBean)
  671. t := v.Type()
  672. if t.Kind() == reflect.Struct {
  673. var err error
  674. statement.RefTable, err = statement.Engine.autoMapType(v)
  675. if err != nil {
  676. statement.Engine.logger.Error(err)
  677. return statement
  678. }
  679. }
  680. statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
  681. return statement
  682. }
  683. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  684. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  685. var buf builder.StringBuilder
  686. if len(statement.JoinStr) > 0 {
  687. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  688. } else {
  689. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  690. }
  691. switch tp := tablename.(type) {
  692. case builder.Builder:
  693. subSQL, subQueryArgs, err := tp.ToSQL()
  694. if err != nil {
  695. statement.lastError = err
  696. return statement
  697. }
  698. tbs := strings.Split(tp.TableName(), ".")
  699. var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
  700. fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
  701. statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
  702. case *builder.Builder:
  703. subSQL, subQueryArgs, err := tp.ToSQL()
  704. if err != nil {
  705. statement.lastError = err
  706. return statement
  707. }
  708. tbs := strings.Split(tp.TableName(), ".")
  709. var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
  710. fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
  711. statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
  712. default:
  713. tbName := statement.Engine.TableName(tablename, true)
  714. fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
  715. }
  716. statement.JoinStr = buf.String()
  717. statement.joinArgs = append(statement.joinArgs, args...)
  718. return statement
  719. }
  720. // GroupBy generate "Group By keys" statement
  721. func (statement *Statement) GroupBy(keys string) *Statement {
  722. statement.GroupByStr = keys
  723. return statement
  724. }
  725. // Having generate "Having conditions" statement
  726. func (statement *Statement) Having(conditions string) *Statement {
  727. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  728. return statement
  729. }
  730. // Unscoped always disable struct tag "deleted"
  731. func (statement *Statement) Unscoped() *Statement {
  732. statement.unscoped = true
  733. return statement
  734. }
  735. func (statement *Statement) genColumnStr() string {
  736. if statement.RefTable == nil {
  737. return ""
  738. }
  739. var buf builder.StringBuilder
  740. columns := statement.RefTable.Columns()
  741. for _, col := range columns {
  742. if statement.omitColumnMap.contain(col.Name) {
  743. continue
  744. }
  745. if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
  746. continue
  747. }
  748. if col.MapType == core.ONLYTODB {
  749. continue
  750. }
  751. if buf.Len() != 0 {
  752. buf.WriteString(", ")
  753. }
  754. if statement.JoinStr != "" {
  755. if statement.TableAlias != "" {
  756. buf.WriteString(statement.TableAlias)
  757. } else {
  758. buf.WriteString(statement.TableName())
  759. }
  760. buf.WriteString(".")
  761. }
  762. statement.Engine.QuoteTo(&buf, col.Name)
  763. }
  764. return buf.String()
  765. }
  766. func (statement *Statement) genCreateTableSQL() string {
  767. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  768. statement.StoreEngine, statement.Charset)
  769. }
  770. func (statement *Statement) genIndexSQL() []string {
  771. var sqls []string
  772. tbName := statement.TableName()
  773. for _, index := range statement.RefTable.Indexes {
  774. if index.Type == core.IndexType {
  775. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  776. if sql != "" {
  777. sqls = append(sqls, sql)
  778. }
  779. }
  780. }
  781. return sqls
  782. }
  // uniqueName builds the name of a unique index as "UQE_<table>_<index>".
  func uniqueName(tableName, uqeName string) string {
  	const prefix = "UQE"
  	return fmt.Sprintf("%s_%s_%s", prefix, tableName, uqeName)
  }
  786. func (statement *Statement) genUniqueSQL() []string {
  787. var sqls []string
  788. tbName := statement.TableName()
  789. for _, index := range statement.RefTable.Indexes {
  790. if index.Type == core.UniqueType {
  791. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  792. sqls = append(sqls, sql)
  793. }
  794. }
  795. return sqls
  796. }
  797. func (statement *Statement) genDelIndexSQL() []string {
  798. var sqls []string
  799. tbName := statement.TableName()
  800. idxPrefixName := strings.Replace(tbName, `"`, "", -1)
  801. idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
  802. for idxName, index := range statement.RefTable.Indexes {
  803. var rIdxName string
  804. if index.Type == core.UniqueType {
  805. rIdxName = uniqueName(idxPrefixName, idxName)
  806. } else if index.Type == core.IndexType {
  807. rIdxName = indexName(idxPrefixName, idxName)
  808. }
  809. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
  810. if statement.Engine.dialect.IndexOnTable() {
  811. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
  812. }
  813. sqls = append(sqls, sql)
  814. }
  815. return sqls
  816. }
  817. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  818. quote := statement.Engine.Quote
  819. sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
  820. col.String(statement.Engine.dialect))
  821. if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
  822. sql += " COMMENT '" + col.Comment + "'"
  823. }
  824. sql += ";"
  825. return sql, []interface{}{}
  826. }
  827. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  828. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  829. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  830. }
  831. func (statement *Statement) mergeConds(bean interface{}) error {
  832. if !statement.noAutoCondition {
  833. var addedTableName = (len(statement.JoinStr) > 0)
  834. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  835. if err != nil {
  836. return err
  837. }
  838. statement.cond = statement.cond.And(autoCond)
  839. }
  840. if err := statement.processIDParam(); err != nil {
  841. return err
  842. }
  843. return nil
  844. }
  845. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  846. if err := statement.mergeConds(bean); err != nil {
  847. return "", nil, err
  848. }
  849. return builder.ToSQL(statement.cond)
  850. }
  851. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  852. v := rValue(bean)
  853. isStruct := v.Kind() == reflect.Struct
  854. if isStruct {
  855. statement.setRefBean(bean)
  856. }
  857. var columnStr = statement.ColumnStr
  858. if len(statement.selectStr) > 0 {
  859. columnStr = statement.selectStr
  860. } else {
  861. // TODO: always generate column names, not use * even if join
  862. if len(statement.JoinStr) == 0 {
  863. if len(columnStr) == 0 {
  864. if len(statement.GroupByStr) > 0 {
  865. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  866. } else {
  867. columnStr = statement.genColumnStr()
  868. }
  869. }
  870. } else {
  871. if len(columnStr) == 0 {
  872. if len(statement.GroupByStr) > 0 {
  873. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  874. }
  875. }
  876. }
  877. }
  878. if len(columnStr) == 0 {
  879. columnStr = "*"
  880. }
  881. if isStruct {
  882. if err := statement.mergeConds(bean); err != nil {
  883. return "", nil, err
  884. }
  885. } else {
  886. if err := statement.processIDParam(); err != nil {
  887. return "", nil, err
  888. }
  889. }
  890. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  891. if err != nil {
  892. return "", nil, err
  893. }
  894. sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
  895. if err != nil {
  896. return "", nil, err
  897. }
  898. return sqlStr, append(statement.joinArgs, condArgs...), nil
  899. }
  900. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  901. var condSQL string
  902. var condArgs []interface{}
  903. var err error
  904. if len(beans) > 0 {
  905. statement.setRefBean(beans[0])
  906. condSQL, condArgs, err = statement.genConds(beans[0])
  907. } else {
  908. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  909. }
  910. if err != nil {
  911. return "", nil, err
  912. }
  913. var selectSQL = statement.selectStr
  914. if len(selectSQL) <= 0 {
  915. if statement.IsDistinct {
  916. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  917. } else {
  918. selectSQL = "count(*)"
  919. }
  920. }
  921. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
  922. if err != nil {
  923. return "", nil, err
  924. }
  925. return sqlStr, append(statement.joinArgs, condArgs...), nil
  926. }
  927. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  928. statement.setRefBean(bean)
  929. var sumStrs = make([]string, 0, len(columns))
  930. for _, colName := range columns {
  931. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  932. colName = statement.Engine.Quote(colName)
  933. }
  934. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  935. }
  936. sumSelect := strings.Join(sumStrs, ", ")
  937. condSQL, condArgs, err := statement.genConds(bean)
  938. if err != nil {
  939. return "", nil, err
  940. }
  941. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
  942. if err != nil {
  943. return "", nil, err
  944. }
  945. return sqlStr, append(statement.joinArgs, condArgs...), nil
  946. }
  947. func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
  948. var (
  949. distinct string
  950. dialect = statement.Engine.Dialect()
  951. quote = statement.Engine.Quote
  952. fromStr = " FROM "
  953. top, mssqlCondi, whereStr string
  954. )
  955. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  956. distinct = "DISTINCT "
  957. }
  958. if len(condSQL) > 0 {
  959. whereStr = " WHERE " + condSQL
  960. }
  961. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  962. fromStr += statement.TableName()
  963. } else {
  964. fromStr += quote(statement.TableName())
  965. }
  966. if statement.TableAlias != "" {
  967. if dialect.DBType() == core.ORACLE {
  968. fromStr += " " + quote(statement.TableAlias)
  969. } else {
  970. fromStr += " AS " + quote(statement.TableAlias)
  971. }
  972. }
  973. if statement.JoinStr != "" {
  974. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  975. }
  976. if dialect.DBType() == core.MSSQL {
  977. if statement.LimitN > 0 {
  978. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  979. }
  980. if statement.Start > 0 {
  981. var column string
  982. if len(statement.RefTable.PKColumns()) == 0 {
  983. for _, index := range statement.RefTable.Indexes {
  984. if len(index.Cols) == 1 {
  985. column = index.Cols[0]
  986. break
  987. }
  988. }
  989. if len(column) == 0 {
  990. column = statement.RefTable.ColumnsSeq()[0]
  991. }
  992. } else {
  993. column = statement.RefTable.PKColumns()[0].Name
  994. }
  995. if statement.needTableName() {
  996. if len(statement.TableAlias) > 0 {
  997. column = statement.TableAlias + "." + column
  998. } else {
  999. column = statement.TableName() + "." + column
  1000. }
  1001. }
  1002. var orderStr string
  1003. if needOrderBy && len(statement.OrderStr) > 0 {
  1004. orderStr = " ORDER BY " + statement.OrderStr
  1005. }
  1006. var groupStr string
  1007. if len(statement.GroupByStr) > 0 {
  1008. groupStr = " GROUP BY " + statement.GroupByStr
  1009. }
  1010. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  1011. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  1012. }
  1013. }
  1014. var buf builder.StringBuilder
  1015. fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  1016. if len(mssqlCondi) > 0 {
  1017. if len(whereStr) > 0 {
  1018. fmt.Fprint(&buf, " AND ", mssqlCondi)
  1019. } else {
  1020. fmt.Fprint(&buf, " WHERE ", mssqlCondi)
  1021. }
  1022. }
  1023. if statement.GroupByStr != "" {
  1024. fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
  1025. }
  1026. if statement.HavingStr != "" {
  1027. fmt.Fprint(&buf, " ", statement.HavingStr)
  1028. }
  1029. if needOrderBy && statement.OrderStr != "" {
  1030. fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
  1031. }
  1032. if needLimit {
  1033. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1034. if statement.Start > 0 {
  1035. fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
  1036. } else if statement.LimitN > 0 {
  1037. fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
  1038. }
  1039. } else if dialect.DBType() == core.ORACLE {
  1040. if statement.Start != 0 || statement.LimitN != 0 {
  1041. oldString := buf.String()
  1042. buf.Reset()
  1043. fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
  1044. columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start)
  1045. }
  1046. }
  1047. }
  1048. if statement.IsForUpdate {
  1049. return dialect.ForUpdateSql(buf.String()), nil
  1050. }
  1051. return buf.String(), nil
  1052. }
  1053. func (statement *Statement) processIDParam() error {
  1054. if statement.idParam == nil || statement.RefTable == nil {
  1055. return nil
  1056. }
  1057. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1058. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1059. len(statement.RefTable.PrimaryKeys),
  1060. len(*statement.idParam),
  1061. )
  1062. }
  1063. for i, col := range statement.RefTable.PKColumns() {
  1064. var colName = statement.colName(col, statement.TableName())
  1065. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1066. }
  1067. return nil
  1068. }
  1069. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1070. var colnames = make([]string, len(cols))
  1071. for i, col := range cols {
  1072. if includeTableName {
  1073. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1074. "." + statement.Engine.Quote(col.Name)
  1075. } else {
  1076. colnames[i] = statement.Engine.Quote(col.Name)
  1077. }
  1078. }
  1079. return strings.Join(colnames, ", ")
  1080. }
  1081. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1082. if statement.RefTable != nil {
  1083. cols := statement.RefTable.PKColumns()
  1084. if len(cols) == 0 {
  1085. return ""
  1086. }
  1087. colstrs := statement.joinColumns(cols, false)
  1088. sqls := splitNNoCase(sqlStr, " from ", 2)
  1089. if len(sqls) != 2 {
  1090. return ""
  1091. }
  1092. var top string
  1093. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1094. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1095. }
  1096. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1097. return newsql
  1098. }
  1099. return ""
  1100. }
  1101. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1102. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1103. return "", ""
  1104. }
  1105. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1106. sqls := splitNNoCase(sqlStr, "where", 2)
  1107. if len(sqls) != 2 {
  1108. if len(sqls) == 1 {
  1109. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1110. colstrs, statement.Engine.Quote(statement.TableName()))
  1111. }
  1112. return "", ""
  1113. }
  1114. var whereStr = sqls[1]
  1115. //TODO: for postgres only, if any other database?
  1116. var paraStr string
  1117. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1118. paraStr = "$"
  1119. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1120. paraStr = ":"
  1121. }
  1122. if paraStr != "" {
  1123. if strings.Contains(sqls[1], paraStr) {
  1124. dollers := strings.Split(sqls[1], paraStr)
  1125. whereStr = dollers[0]
  1126. for i, c := range dollers[1:] {
  1127. ccs := strings.SplitN(c, " ", 2)
  1128. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1129. }
  1130. }
  1131. }
  1132. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1133. colstrs, statement.Engine.Quote(statement.TableName()),
  1134. whereStr)
  1135. }