statement.go 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql/driver"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "reflect"
  11. "strings"
  12. "time"
  13. "github.com/go-xorm/builder"
  14. "github.com/xormplus/core"
  15. )
  16. // Statement save all the sql info for executing SQL
  17. type Statement struct {
  18. RefTable *core.Table
  19. Engine *Engine
  20. Start int
  21. LimitN int
  22. idParam *core.PK
  23. OrderStr string
  24. JoinStr string
  25. joinArgs []interface{}
  26. GroupByStr string
  27. HavingStr string
  28. ColumnStr string
  29. selectStr string
  30. useAllCols bool
  31. OmitStr string
  32. AltTableName string
  33. tableName string
  34. RawSQL string
  35. RawParams []interface{}
  36. UseCascade bool
  37. UseAutoJoin bool
  38. StoreEngine string
  39. Charset string
  40. UseCache bool
  41. UseAutoTime bool
  42. noAutoCondition bool
  43. IsDistinct bool
  44. IsForUpdate bool
  45. TableAlias string
  46. allUseBool bool
  47. checkVersion bool
  48. unscoped bool
  49. columnMap columnMap
  50. omitColumnMap columnMap
  51. mustColumnMap map[string]bool
  52. nullableMap map[string]bool
  53. incrColumns map[string]incrParam
  54. decrColumns map[string]decrParam
  55. exprColumns map[string]exprParam
  56. cond builder.Cond
  57. bufferSize int
  58. context ContextCache
  59. }
  60. // Init reset all the statement's fields
  61. func (statement *Statement) Init() {
  62. statement.RefTable = nil
  63. statement.Start = 0
  64. statement.LimitN = 0
  65. statement.OrderStr = ""
  66. statement.UseCascade = true
  67. statement.JoinStr = ""
  68. statement.joinArgs = make([]interface{}, 0)
  69. statement.GroupByStr = ""
  70. statement.HavingStr = ""
  71. statement.ColumnStr = ""
  72. statement.OmitStr = ""
  73. statement.columnMap = columnMap{}
  74. statement.omitColumnMap = columnMap{}
  75. statement.AltTableName = ""
  76. statement.tableName = ""
  77. statement.idParam = nil
  78. statement.RawSQL = ""
  79. statement.RawParams = make([]interface{}, 0)
  80. statement.UseCache = true
  81. statement.UseAutoTime = true
  82. statement.noAutoCondition = false
  83. statement.IsDistinct = false
  84. statement.IsForUpdate = false
  85. statement.TableAlias = ""
  86. statement.selectStr = ""
  87. statement.allUseBool = false
  88. statement.useAllCols = false
  89. statement.mustColumnMap = make(map[string]bool)
  90. statement.nullableMap = make(map[string]bool)
  91. statement.checkVersion = true
  92. statement.unscoped = false
  93. statement.incrColumns = make(map[string]incrParam)
  94. statement.decrColumns = make(map[string]decrParam)
  95. statement.exprColumns = make(map[string]exprParam)
  96. statement.cond = builder.NewCond()
  97. statement.bufferSize = 0
  98. statement.context = nil
  99. }
  100. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  101. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  102. statement.noAutoCondition = true
  103. if len(no) > 0 {
  104. statement.noAutoCondition = no[0]
  105. }
  106. return statement
  107. }
  108. // Alias set the table alias
  109. func (statement *Statement) Alias(alias string) *Statement {
  110. statement.TableAlias = alias
  111. return statement
  112. }
  113. // SQL adds raw sql statement
  114. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  115. switch query.(type) {
  116. case (*builder.Builder):
  117. var err error
  118. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  119. if err != nil {
  120. statement.Engine.logger.Error(err)
  121. }
  122. case string:
  123. statement.RawSQL = query.(string)
  124. statement.RawParams = args
  125. default:
  126. statement.Engine.logger.Error("unsupported sql type")
  127. }
  128. return statement
  129. }
  130. // Where add Where statement
  131. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  132. return statement.And(query, args...)
  133. }
  134. // And add Where & and statement
  135. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  136. switch query.(type) {
  137. case string:
  138. cond := builder.Expr(query.(string), args...)
  139. statement.cond = statement.cond.And(cond)
  140. case map[string]interface{}:
  141. cond := builder.Eq(query.(map[string]interface{}))
  142. statement.cond = statement.cond.And(cond)
  143. case builder.Cond:
  144. cond := query.(builder.Cond)
  145. statement.cond = statement.cond.And(cond)
  146. for _, v := range args {
  147. if vv, ok := v.(builder.Cond); ok {
  148. statement.cond = statement.cond.And(vv)
  149. }
  150. }
  151. default:
  152. // TODO: not support condition type
  153. }
  154. return statement
  155. }
  156. // Or add Where & Or statement
  157. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  158. switch query.(type) {
  159. case string:
  160. cond := builder.Expr(query.(string), args...)
  161. statement.cond = statement.cond.Or(cond)
  162. case map[string]interface{}:
  163. cond := builder.Eq(query.(map[string]interface{}))
  164. statement.cond = statement.cond.Or(cond)
  165. case builder.Cond:
  166. cond := query.(builder.Cond)
  167. statement.cond = statement.cond.Or(cond)
  168. for _, v := range args {
  169. if vv, ok := v.(builder.Cond); ok {
  170. statement.cond = statement.cond.Or(vv)
  171. }
  172. }
  173. default:
  174. // TODO: not support condition type
  175. }
  176. return statement
  177. }
  178. // In generate "Where column IN (?) " statement
  179. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  180. in := builder.In(statement.Engine.Quote(column), args...)
  181. statement.cond = statement.cond.And(in)
  182. return statement
  183. }
  184. // NotIn generate "Where column NOT IN (?) " statement
  185. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  186. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  187. statement.cond = statement.cond.And(notIn)
  188. return statement
  189. }
  190. func (statement *Statement) setRefValue(v reflect.Value) error {
  191. var err error
  192. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  193. if err != nil {
  194. return err
  195. }
  196. statement.tableName = statement.Engine.TableName(v, true)
  197. return nil
  198. }
  199. func (statement *Statement) setRefBean(bean interface{}) error {
  200. var err error
  201. statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
  202. if err != nil {
  203. return err
  204. }
  205. statement.tableName = statement.Engine.TableName(bean, true)
  206. return nil
  207. }
  208. // Auto generating update columnes and values according a struct
  209. func (statement *Statement) buildUpdates(bean interface{},
  210. includeVersion, includeUpdated, includeNil,
  211. includeAutoIncr, update bool) ([]string, []interface{}) {
  212. engine := statement.Engine
  213. table := statement.RefTable
  214. allUseBool := statement.allUseBool
  215. useAllCols := statement.useAllCols
  216. mustColumnMap := statement.mustColumnMap
  217. nullableMap := statement.nullableMap
  218. columnMap := statement.columnMap
  219. omitColumnMap := statement.omitColumnMap
  220. unscoped := statement.unscoped
  221. var colNames = make([]string, 0)
  222. var args = make([]interface{}, 0)
  223. for _, col := range table.Columns() {
  224. if !includeVersion && col.IsVersion {
  225. continue
  226. }
  227. if col.IsCreated {
  228. continue
  229. }
  230. if !includeUpdated && col.IsUpdated {
  231. continue
  232. }
  233. if !includeAutoIncr && col.IsAutoIncrement {
  234. continue
  235. }
  236. if col.IsDeleted && !unscoped {
  237. continue
  238. }
  239. if omitColumnMap.contain(col.Name) {
  240. continue
  241. }
  242. if len(columnMap) > 0 && !columnMap.contain(col.Name) {
  243. continue
  244. }
  245. if col.MapType == core.ONLYFROMDB {
  246. continue
  247. }
  248. fieldValuePtr, err := col.ValueOf(bean)
  249. if err != nil {
  250. engine.logger.Error(err)
  251. continue
  252. }
  253. fieldValue := *fieldValuePtr
  254. fieldType := reflect.TypeOf(fieldValue.Interface())
  255. if fieldType == nil {
  256. continue
  257. }
  258. requiredField := useAllCols
  259. includeNil := useAllCols
  260. if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  261. if b {
  262. requiredField = true
  263. } else {
  264. continue
  265. }
  266. }
  267. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  268. if b, ok := getFlagForColumn(nullableMap, col); ok {
  269. if b && col.Nullable && isZero(fieldValue.Interface()) {
  270. var nilValue *int
  271. fieldValue = reflect.ValueOf(nilValue)
  272. fieldType = reflect.TypeOf(fieldValue.Interface())
  273. includeNil = true
  274. }
  275. }
  276. var val interface{}
  277. if fieldValue.CanAddr() {
  278. if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  279. data, err := structConvert.ToDB()
  280. if err != nil {
  281. engine.logger.Error(err)
  282. } else {
  283. val = data
  284. }
  285. goto APPEND
  286. }
  287. }
  288. if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  289. data, err := structConvert.ToDB()
  290. if err != nil {
  291. engine.logger.Error(err)
  292. } else {
  293. val = data
  294. }
  295. goto APPEND
  296. }
  297. if fieldType.Kind() == reflect.Ptr {
  298. if fieldValue.IsNil() {
  299. if includeNil {
  300. args = append(args, nil)
  301. colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  302. }
  303. continue
  304. } else if !fieldValue.IsValid() {
  305. continue
  306. } else {
  307. // dereference ptr type to instance type
  308. fieldValue = fieldValue.Elem()
  309. fieldType = reflect.TypeOf(fieldValue.Interface())
  310. requiredField = true
  311. }
  312. }
  313. switch fieldType.Kind() {
  314. case reflect.Bool:
  315. if allUseBool || requiredField {
  316. val = fieldValue.Interface()
  317. } else {
  318. // if a bool in a struct, it will not be as a condition because it default is false,
  319. // please use Where() instead
  320. continue
  321. }
  322. case reflect.String:
  323. if !requiredField && fieldValue.String() == "" {
  324. continue
  325. }
  326. // for MyString, should convert to string or panic
  327. if fieldType.String() != reflect.String.String() {
  328. val = fieldValue.String()
  329. } else {
  330. val = fieldValue.Interface()
  331. }
  332. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  333. if !requiredField && fieldValue.Int() == 0 {
  334. continue
  335. }
  336. val = fieldValue.Interface()
  337. case reflect.Float32, reflect.Float64:
  338. if !requiredField && fieldValue.Float() == 0.0 {
  339. continue
  340. }
  341. val = fieldValue.Interface()
  342. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  343. if !requiredField && fieldValue.Uint() == 0 {
  344. continue
  345. }
  346. t := int64(fieldValue.Uint())
  347. val = reflect.ValueOf(&t).Interface()
  348. case reflect.Struct:
  349. if fieldType.ConvertibleTo(core.TimeType) {
  350. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  351. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  352. continue
  353. }
  354. val = engine.formatColTime(col, t)
  355. } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  356. val, _ = nulType.Value()
  357. } else {
  358. if !col.SQLType.IsJson() {
  359. engine.autoMapType(fieldValue)
  360. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  361. if len(table.PrimaryKeys) == 1 {
  362. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  363. // fix non-int pk issues
  364. if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  365. val = pkField.Interface()
  366. } else {
  367. continue
  368. }
  369. } else {
  370. //TODO: how to handler?
  371. panic("not supported")
  372. }
  373. } else {
  374. val = fieldValue.Interface()
  375. }
  376. } else {
  377. // Blank struct could not be as update data
  378. if requiredField || !isStructZero(fieldValue) {
  379. bytes, err := json.Marshal(fieldValue.Interface())
  380. if err != nil {
  381. panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  382. }
  383. if col.SQLType.IsText() {
  384. val = string(bytes)
  385. } else if col.SQLType.IsBlob() {
  386. val = bytes
  387. }
  388. } else {
  389. continue
  390. }
  391. }
  392. }
  393. case reflect.Array, reflect.Slice, reflect.Map:
  394. if !requiredField {
  395. if fieldValue == reflect.Zero(fieldType) {
  396. continue
  397. }
  398. if fieldType.Kind() == reflect.Array {
  399. if isArrayValueZero(fieldValue) {
  400. continue
  401. }
  402. } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  403. continue
  404. }
  405. }
  406. if col.SQLType.IsText() {
  407. bytes, err := json.Marshal(fieldValue.Interface())
  408. if err != nil {
  409. engine.logger.Error(err)
  410. continue
  411. }
  412. val = string(bytes)
  413. } else if col.SQLType.IsBlob() {
  414. var bytes []byte
  415. var err error
  416. if fieldType.Kind() == reflect.Slice &&
  417. fieldType.Elem().Kind() == reflect.Uint8 {
  418. if fieldValue.Len() > 0 {
  419. val = fieldValue.Bytes()
  420. } else {
  421. continue
  422. }
  423. } else if fieldType.Kind() == reflect.Array &&
  424. fieldType.Elem().Kind() == reflect.Uint8 {
  425. val = fieldValue.Slice(0, 0).Interface()
  426. } else {
  427. bytes, err = json.Marshal(fieldValue.Interface())
  428. if err != nil {
  429. engine.logger.Error(err)
  430. continue
  431. }
  432. val = bytes
  433. }
  434. } else {
  435. continue
  436. }
  437. default:
  438. val = fieldValue.Interface()
  439. }
  440. APPEND:
  441. args = append(args, val)
  442. if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  443. continue
  444. }
  445. colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  446. }
  447. return colNames, args
  448. }
  449. func (statement *Statement) needTableName() bool {
  450. return len(statement.JoinStr) > 0
  451. }
  452. func (statement *Statement) colName(col *core.Column, tableName string) string {
  453. if statement.needTableName() {
  454. var nm = tableName
  455. if len(statement.TableAlias) > 0 {
  456. nm = statement.TableAlias
  457. }
  458. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  459. }
  460. return statement.Engine.Quote(col.Name)
  461. }
  462. // TableName return current tableName
  463. func (statement *Statement) TableName() string {
  464. if statement.AltTableName != "" {
  465. return statement.AltTableName
  466. }
  467. return statement.tableName
  468. }
  469. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  470. func (statement *Statement) ID(id interface{}) *Statement {
  471. idValue := reflect.ValueOf(id)
  472. idType := reflect.TypeOf(idValue.Interface())
  473. switch idType {
  474. case ptrPkType:
  475. if pkPtr, ok := (id).(*core.PK); ok {
  476. statement.idParam = pkPtr
  477. return statement
  478. }
  479. case pkType:
  480. if pk, ok := (id).(core.PK); ok {
  481. statement.idParam = &pk
  482. return statement
  483. }
  484. }
  485. switch idType.Kind() {
  486. case reflect.String:
  487. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  488. return statement
  489. }
  490. statement.idParam = &core.PK{id}
  491. return statement
  492. }
  493. // Incr Generate "Update ... Set column = column + arg" statement
  494. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  495. k := strings.ToLower(column)
  496. if len(arg) > 0 {
  497. statement.incrColumns[k] = incrParam{column, arg[0]}
  498. } else {
  499. statement.incrColumns[k] = incrParam{column, 1}
  500. }
  501. return statement
  502. }
  503. // Decr Generate "Update ... Set column = column - arg" statement
  504. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  505. k := strings.ToLower(column)
  506. if len(arg) > 0 {
  507. statement.decrColumns[k] = decrParam{column, arg[0]}
  508. } else {
  509. statement.decrColumns[k] = decrParam{column, 1}
  510. }
  511. return statement
  512. }
  513. // SetExpr Generate "Update ... Set column = {expression}" statement
  514. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  515. k := strings.ToLower(column)
  516. statement.exprColumns[k] = exprParam{column, expression}
  517. return statement
  518. }
  519. // Generate "Update ... Set column = column + arg" statement
  520. func (statement *Statement) getInc() map[string]incrParam {
  521. return statement.incrColumns
  522. }
  523. // Generate "Update ... Set column = column - arg" statement
  524. func (statement *Statement) getDec() map[string]decrParam {
  525. return statement.decrColumns
  526. }
  527. // Generate "Update ... Set column = {expression}" statement
  528. func (statement *Statement) getExpr() map[string]exprParam {
  529. return statement.exprColumns
  530. }
  531. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  532. newColumns := make([]string, 0)
  533. for _, col := range columns {
  534. col = strings.Replace(col, "`", "", -1)
  535. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  536. ccols := strings.Split(col, ",")
  537. for _, c := range ccols {
  538. fields := strings.Split(strings.TrimSpace(c), ".")
  539. if len(fields) == 1 {
  540. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  541. } else if len(fields) == 2 {
  542. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  543. statement.Engine.quote(fields[1]))
  544. } else {
  545. panic(errors.New("unwanted colnames"))
  546. }
  547. }
  548. }
  549. return newColumns
  550. }
  551. func (statement *Statement) colmap2NewColsWithQuote() []string {
  552. newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
  553. copy(newColumns, statement.columnMap)
  554. for i := 0; i < len(statement.columnMap); i++ {
  555. newColumns[i] = statement.Engine.Quote(newColumns[i])
  556. }
  557. return newColumns
  558. }
  559. // Distinct generates "DISTINCT col1, col2 " statement
  560. func (statement *Statement) Distinct(columns ...string) *Statement {
  561. statement.IsDistinct = true
  562. statement.Cols(columns...)
  563. return statement
  564. }
  565. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  566. func (statement *Statement) ForUpdate() *Statement {
  567. statement.IsForUpdate = true
  568. return statement
  569. }
  570. // Select replace select
  571. func (statement *Statement) Select(str string) *Statement {
  572. statement.selectStr = str
  573. return statement
  574. }
  575. // Cols generate "col1, col2" statement
  576. func (statement *Statement) Cols(columns ...string) *Statement {
  577. cols := col2NewCols(columns...)
  578. for _, nc := range cols {
  579. statement.columnMap.add(nc)
  580. }
  581. newColumns := statement.colmap2NewColsWithQuote()
  582. statement.ColumnStr = strings.Join(newColumns, ", ")
  583. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  584. return statement
  585. }
  586. // AllCols update use only: update all columns
  587. func (statement *Statement) AllCols() *Statement {
  588. statement.useAllCols = true
  589. return statement
  590. }
  591. // MustCols update use only: must update columns
  592. func (statement *Statement) MustCols(columns ...string) *Statement {
  593. newColumns := col2NewCols(columns...)
  594. for _, nc := range newColumns {
  595. statement.mustColumnMap[strings.ToLower(nc)] = true
  596. }
  597. return statement
  598. }
  599. // UseBool indicates that use bool fields as update contents and query contiditions
  600. func (statement *Statement) UseBool(columns ...string) *Statement {
  601. if len(columns) > 0 {
  602. statement.MustCols(columns...)
  603. } else {
  604. statement.allUseBool = true
  605. }
  606. return statement
  607. }
  608. // Omit do not use the columns
  609. func (statement *Statement) Omit(columns ...string) {
  610. newColumns := col2NewCols(columns...)
  611. for _, nc := range newColumns {
  612. statement.omitColumnMap = append(statement.omitColumnMap, nc)
  613. }
  614. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  615. }
  616. // Nullable Update use only: update columns to null when value is nullable and zero-value
  617. func (statement *Statement) Nullable(columns ...string) {
  618. newColumns := col2NewCols(columns...)
  619. for _, nc := range newColumns {
  620. statement.nullableMap[strings.ToLower(nc)] = true
  621. }
  622. }
  623. // Top generate LIMIT limit statement
  624. func (statement *Statement) Top(limit int) *Statement {
  625. statement.Limit(limit)
  626. return statement
  627. }
  628. // Limit generate LIMIT start, limit statement
  629. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  630. statement.LimitN = limit
  631. if len(start) > 0 {
  632. statement.Start = start[0]
  633. }
  634. return statement
  635. }
  636. // OrderBy generate "Order By order" statement
  637. func (statement *Statement) OrderBy(order string) *Statement {
  638. if len(statement.OrderStr) > 0 {
  639. statement.OrderStr += ", "
  640. }
  641. statement.OrderStr += order
  642. return statement
  643. }
  644. // Desc generate `ORDER BY xx DESC`
  645. func (statement *Statement) Desc(colNames ...string) *Statement {
  646. var buf builder.StringBuilder
  647. if len(statement.OrderStr) > 0 {
  648. fmt.Fprint(&buf, statement.OrderStr, ", ")
  649. }
  650. newColNames := statement.col2NewColsWithQuote(colNames...)
  651. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  652. statement.OrderStr = buf.String()
  653. return statement
  654. }
  655. // Asc provide asc order by query condition, the input parameters are columns.
  656. func (statement *Statement) Asc(colNames ...string) *Statement {
  657. var buf builder.StringBuilder
  658. if len(statement.OrderStr) > 0 {
  659. fmt.Fprint(&buf, statement.OrderStr, ", ")
  660. }
  661. newColNames := statement.col2NewColsWithQuote(colNames...)
  662. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  663. statement.OrderStr = buf.String()
  664. return statement
  665. }
  666. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  667. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  668. v := rValue(tableNameOrBean)
  669. t := v.Type()
  670. if t.Kind() == reflect.Struct {
  671. var err error
  672. statement.RefTable, err = statement.Engine.autoMapType(v)
  673. if err != nil {
  674. statement.Engine.logger.Error(err)
  675. return statement
  676. }
  677. }
  678. statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
  679. return statement
  680. }
  681. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  682. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  683. var buf builder.StringBuilder
  684. if len(statement.JoinStr) > 0 {
  685. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  686. } else {
  687. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  688. }
  689. tbName := statement.Engine.TableName(tablename, true)
  690. fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
  691. statement.JoinStr = buf.String()
  692. statement.joinArgs = append(statement.joinArgs, args...)
  693. return statement
  694. }
  695. // GroupBy generate "Group By keys" statement
  696. func (statement *Statement) GroupBy(keys string) *Statement {
  697. statement.GroupByStr = keys
  698. return statement
  699. }
  700. // Having generate "Having conditions" statement
  701. func (statement *Statement) Having(conditions string) *Statement {
  702. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  703. return statement
  704. }
  705. // Unscoped always disable struct tag "deleted"
  706. func (statement *Statement) Unscoped() *Statement {
  707. statement.unscoped = true
  708. return statement
  709. }
  710. func (statement *Statement) genColumnStr() string {
  711. if statement.RefTable == nil {
  712. return ""
  713. }
  714. var buf builder.StringBuilder
  715. columns := statement.RefTable.Columns()
  716. for _, col := range columns {
  717. if statement.omitColumnMap.contain(col.Name) {
  718. continue
  719. }
  720. if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
  721. continue
  722. }
  723. if col.MapType == core.ONLYTODB {
  724. continue
  725. }
  726. if buf.Len() != 0 {
  727. buf.WriteString(", ")
  728. }
  729. if statement.JoinStr != "" {
  730. if statement.TableAlias != "" {
  731. buf.WriteString(statement.TableAlias)
  732. } else {
  733. buf.WriteString(statement.TableName())
  734. }
  735. buf.WriteString(".")
  736. }
  737. statement.Engine.QuoteTo(&buf, col.Name)
  738. }
  739. return buf.String()
  740. }
  741. func (statement *Statement) genCreateTableSQL() string {
  742. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  743. statement.StoreEngine, statement.Charset)
  744. }
  745. func (statement *Statement) genIndexSQL() []string {
  746. var sqls []string
  747. tbName := statement.TableName()
  748. for _, index := range statement.RefTable.Indexes {
  749. if index.Type == core.IndexType {
  750. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  751. if sql != "" {
  752. sqls = append(sqls, sql)
  753. }
  754. }
  755. }
  756. return sqls
  757. }
  // uniqueName returns the unique-index name "UQE_<tableName>_<uqeName>".
func uniqueName(tableName, uqeName string) string {
	return fmt.Sprint("UQE_", tableName, "_", uqeName)
}
  761. func (statement *Statement) genUniqueSQL() []string {
  762. var sqls []string
  763. tbName := statement.TableName()
  764. for _, index := range statement.RefTable.Indexes {
  765. if index.Type == core.UniqueType {
  766. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  767. sqls = append(sqls, sql)
  768. }
  769. }
  770. return sqls
  771. }
  772. func (statement *Statement) genDelIndexSQL() []string {
  773. var sqls []string
  774. tbName := statement.TableName()
  775. idxPrefixName := strings.Replace(tbName, `"`, "", -1)
  776. idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
  777. for idxName, index := range statement.RefTable.Indexes {
  778. var rIdxName string
  779. if index.Type == core.UniqueType {
  780. rIdxName = uniqueName(idxPrefixName, idxName)
  781. } else if index.Type == core.IndexType {
  782. rIdxName = indexName(idxPrefixName, idxName)
  783. }
  784. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
  785. if statement.Engine.dialect.IndexOnTable() {
  786. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
  787. }
  788. sqls = append(sqls, sql)
  789. }
  790. return sqls
  791. }
  792. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  793. quote := statement.Engine.Quote
  794. sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
  795. col.String(statement.Engine.dialect))
  796. if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
  797. sql += " COMMENT '" + col.Comment + "'"
  798. }
  799. sql += ";"
  800. return sql, []interface{}{}
  801. }
  802. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  803. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  804. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  805. }
  806. func (statement *Statement) mergeConds(bean interface{}) error {
  807. if !statement.noAutoCondition {
  808. var addedTableName = (len(statement.JoinStr) > 0)
  809. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  810. if err != nil {
  811. return err
  812. }
  813. statement.cond = statement.cond.And(autoCond)
  814. }
  815. if err := statement.processIDParam(); err != nil {
  816. return err
  817. }
  818. return nil
  819. }
  820. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  821. if err := statement.mergeConds(bean); err != nil {
  822. return "", nil, err
  823. }
  824. return builder.ToSQL(statement.cond)
  825. }
  826. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  827. v := rValue(bean)
  828. isStruct := v.Kind() == reflect.Struct
  829. if isStruct {
  830. statement.setRefBean(bean)
  831. }
  832. var columnStr = statement.ColumnStr
  833. if len(statement.selectStr) > 0 {
  834. columnStr = statement.selectStr
  835. } else {
  836. // TODO: always generate column names, not use * even if join
  837. if len(statement.JoinStr) == 0 {
  838. if len(columnStr) == 0 {
  839. if len(statement.GroupByStr) > 0 {
  840. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  841. } else {
  842. columnStr = statement.genColumnStr()
  843. }
  844. }
  845. } else {
  846. if len(columnStr) == 0 {
  847. if len(statement.GroupByStr) > 0 {
  848. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  849. }
  850. }
  851. }
  852. }
  853. if len(columnStr) == 0 {
  854. columnStr = "*"
  855. }
  856. if isStruct {
  857. if err := statement.mergeConds(bean); err != nil {
  858. return "", nil, err
  859. }
  860. } else {
  861. if err := statement.processIDParam(); err != nil {
  862. return "", nil, err
  863. }
  864. }
  865. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  866. if err != nil {
  867. return "", nil, err
  868. }
  869. sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
  870. if err != nil {
  871. return "", nil, err
  872. }
  873. return sqlStr, append(statement.joinArgs, condArgs...), nil
  874. }
  875. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  876. var condSQL string
  877. var condArgs []interface{}
  878. var err error
  879. if len(beans) > 0 {
  880. statement.setRefBean(beans[0])
  881. condSQL, condArgs, err = statement.genConds(beans[0])
  882. } else {
  883. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  884. }
  885. if err != nil {
  886. return "", nil, err
  887. }
  888. var selectSQL = statement.selectStr
  889. if len(selectSQL) <= 0 {
  890. if statement.IsDistinct {
  891. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  892. } else {
  893. selectSQL = "count(*)"
  894. }
  895. }
  896. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
  897. if err != nil {
  898. return "", nil, err
  899. }
  900. return sqlStr, append(statement.joinArgs, condArgs...), nil
  901. }
  902. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  903. statement.setRefBean(bean)
  904. var sumStrs = make([]string, 0, len(columns))
  905. for _, colName := range columns {
  906. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  907. colName = statement.Engine.Quote(colName)
  908. }
  909. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  910. }
  911. sumSelect := strings.Join(sumStrs, ", ")
  912. condSQL, condArgs, err := statement.genConds(bean)
  913. if err != nil {
  914. return "", nil, err
  915. }
  916. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
  917. if err != nil {
  918. return "", nil, err
  919. }
  920. return sqlStr, append(statement.joinArgs, condArgs...), nil
  921. }
  922. func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
  923. var (
  924. distinct string
  925. dialect = statement.Engine.Dialect()
  926. quote = statement.Engine.Quote
  927. fromStr = " FROM "
  928. top, mssqlCondi, whereStr string
  929. )
  930. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  931. distinct = "DISTINCT "
  932. }
  933. if len(condSQL) > 0 {
  934. whereStr = " WHERE " + condSQL
  935. }
  936. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  937. fromStr += statement.TableName()
  938. } else {
  939. fromStr += quote(statement.TableName())
  940. }
  941. if statement.TableAlias != "" {
  942. if dialect.DBType() == core.ORACLE {
  943. fromStr += " " + quote(statement.TableAlias)
  944. } else {
  945. fromStr += " AS " + quote(statement.TableAlias)
  946. }
  947. }
  948. if statement.JoinStr != "" {
  949. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  950. }
  951. if dialect.DBType() == core.MSSQL {
  952. if statement.LimitN > 0 {
  953. top = fmt.Sprintf(" TOP %d ", statement.LimitN)
  954. }
  955. if statement.Start > 0 {
  956. var column string
  957. if len(statement.RefTable.PKColumns()) == 0 {
  958. for _, index := range statement.RefTable.Indexes {
  959. if len(index.Cols) == 1 {
  960. column = index.Cols[0]
  961. break
  962. }
  963. }
  964. if len(column) == 0 {
  965. column = statement.RefTable.ColumnsSeq()[0]
  966. }
  967. } else {
  968. column = statement.RefTable.PKColumns()[0].Name
  969. }
  970. if statement.needTableName() {
  971. if len(statement.TableAlias) > 0 {
  972. column = statement.TableAlias + "." + column
  973. } else {
  974. column = statement.TableName() + "." + column
  975. }
  976. }
  977. var orderStr string
  978. if needOrderBy && len(statement.OrderStr) > 0 {
  979. orderStr = " ORDER BY " + statement.OrderStr
  980. }
  981. var groupStr string
  982. if len(statement.GroupByStr) > 0 {
  983. groupStr = " GROUP BY " + statement.GroupByStr
  984. }
  985. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  986. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  987. }
  988. }
  989. var buf builder.StringBuilder
  990. fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  991. if len(mssqlCondi) > 0 {
  992. if len(whereStr) > 0 {
  993. fmt.Fprint(&buf, " AND ", mssqlCondi)
  994. } else {
  995. fmt.Fprint(&buf, " WHERE ", mssqlCondi)
  996. }
  997. }
  998. if statement.GroupByStr != "" {
  999. fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
  1000. }
  1001. if statement.HavingStr != "" {
  1002. fmt.Fprint(&buf, " ", statement.HavingStr)
  1003. }
  1004. if needOrderBy && statement.OrderStr != "" {
  1005. fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
  1006. }
  1007. if needLimit {
  1008. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1009. if statement.Start > 0 {
  1010. fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
  1011. } else if statement.LimitN > 0 {
  1012. fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
  1013. }
  1014. } else if dialect.DBType() == core.ORACLE {
  1015. if statement.Start != 0 || statement.LimitN != 0 {
  1016. oldString := buf.String()
  1017. buf.Reset()
  1018. fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
  1019. columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start)
  1020. }
  1021. }
  1022. }
  1023. if statement.IsForUpdate {
  1024. return dialect.ForUpdateSql(buf.String()), nil
  1025. }
  1026. return buf.String(), nil
  1027. }
  1028. func (statement *Statement) processIDParam() error {
  1029. if statement.idParam == nil || statement.RefTable == nil {
  1030. return nil
  1031. }
  1032. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1033. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1034. len(statement.RefTable.PrimaryKeys),
  1035. len(*statement.idParam),
  1036. )
  1037. }
  1038. for i, col := range statement.RefTable.PKColumns() {
  1039. var colName = statement.colName(col, statement.TableName())
  1040. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1041. }
  1042. return nil
  1043. }
  1044. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1045. var colnames = make([]string, len(cols))
  1046. for i, col := range cols {
  1047. if includeTableName {
  1048. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1049. "." + statement.Engine.Quote(col.Name)
  1050. } else {
  1051. colnames[i] = statement.Engine.Quote(col.Name)
  1052. }
  1053. }
  1054. return strings.Join(colnames, ", ")
  1055. }
  1056. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1057. if statement.RefTable != nil {
  1058. cols := statement.RefTable.PKColumns()
  1059. if len(cols) == 0 {
  1060. return ""
  1061. }
  1062. colstrs := statement.joinColumns(cols, false)
  1063. sqls := splitNNoCase(sqlStr, " from ", 2)
  1064. if len(sqls) != 2 {
  1065. return ""
  1066. }
  1067. var top string
  1068. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1069. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1070. }
  1071. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1072. return newsql
  1073. }
  1074. return ""
  1075. }
  1076. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1077. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1078. return "", ""
  1079. }
  1080. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1081. sqls := splitNNoCase(sqlStr, "where", 2)
  1082. if len(sqls) != 2 {
  1083. if len(sqls) == 1 {
  1084. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1085. colstrs, statement.Engine.Quote(statement.TableName()))
  1086. }
  1087. return "", ""
  1088. }
  1089. var whereStr = sqls[1]
  1090. //TODO: for postgres only, if any other database?
  1091. var paraStr string
  1092. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1093. paraStr = "$"
  1094. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1095. paraStr = ":"
  1096. }
  1097. if paraStr != "" {
  1098. if strings.Contains(sqls[1], paraStr) {
  1099. dollers := strings.Split(sqls[1], paraStr)
  1100. whereStr = dollers[0]
  1101. for i, c := range dollers[1:] {
  1102. ccs := strings.SplitN(c, " ", 2)
  1103. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1104. }
  1105. }
  1106. }
  1107. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1108. colstrs, statement.Engine.Quote(statement.TableName()),
  1109. whereStr)
  1110. }