statement.go 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql/driver"
  7. "fmt"
  8. "reflect"
  9. "strings"
  10. "time"
  11. "github.com/xormplus/builder"
  12. "github.com/xormplus/core"
  13. )
  14. // Statement save all the sql info for executing SQL
  15. type Statement struct {
  16. RefTable *core.Table
  17. Engine *Engine
  18. Start int
  19. LimitN int
  20. idParam *core.PK
  21. OrderStr string
  22. JoinStr string
  23. joinArgs []interface{}
  24. GroupByStr string
  25. HavingStr string
  26. ColumnStr string
  27. selectStr string
  28. useAllCols bool
  29. OmitStr string
  30. AltTableName string
  31. tableName string
  32. RawSQL string
  33. RawParams []interface{}
  34. UseCascade bool
  35. UseAutoJoin bool
  36. StoreEngine string
  37. Charset string
  38. UseCache bool
  39. UseAutoTime bool
  40. noAutoCondition bool
  41. IsDistinct bool
  42. IsForUpdate bool
  43. TableAlias string
  44. allUseBool bool
  45. checkVersion bool
  46. unscoped bool
  47. columnMap columnMap
  48. omitColumnMap columnMap
  49. mustColumnMap map[string]bool
  50. nullableMap map[string]bool
  51. incrColumns map[string]incrParam
  52. decrColumns map[string]decrParam
  53. exprColumns map[string]exprParam
  54. cond builder.Cond
  55. bufferSize int
  56. context ContextCache
  57. lastError error
  58. }
  59. // Init reset all the statement's fields
  60. func (statement *Statement) Init() {
  61. statement.RefTable = nil
  62. statement.Start = 0
  63. statement.LimitN = 0
  64. statement.OrderStr = ""
  65. statement.UseCascade = true
  66. statement.JoinStr = ""
  67. statement.joinArgs = make([]interface{}, 0)
  68. statement.GroupByStr = ""
  69. statement.HavingStr = ""
  70. statement.ColumnStr = ""
  71. statement.OmitStr = ""
  72. statement.columnMap = columnMap{}
  73. statement.omitColumnMap = columnMap{}
  74. statement.AltTableName = ""
  75. statement.tableName = ""
  76. statement.idParam = nil
  77. statement.RawSQL = ""
  78. statement.RawParams = make([]interface{}, 0)
  79. statement.UseCache = true
  80. statement.UseAutoTime = true
  81. statement.noAutoCondition = false
  82. statement.IsDistinct = false
  83. statement.IsForUpdate = false
  84. statement.TableAlias = ""
  85. statement.selectStr = ""
  86. statement.allUseBool = false
  87. statement.useAllCols = false
  88. statement.mustColumnMap = make(map[string]bool)
  89. statement.nullableMap = make(map[string]bool)
  90. statement.checkVersion = true
  91. statement.unscoped = false
  92. statement.incrColumns = make(map[string]incrParam)
  93. statement.decrColumns = make(map[string]decrParam)
  94. statement.exprColumns = make(map[string]exprParam)
  95. statement.cond = builder.NewCond()
  96. statement.bufferSize = 0
  97. statement.context = nil
  98. statement.lastError = nil
  99. }
  100. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  101. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  102. statement.noAutoCondition = true
  103. if len(no) > 0 {
  104. statement.noAutoCondition = no[0]
  105. }
  106. return statement
  107. }
  108. // Alias set the table alias
  109. func (statement *Statement) Alias(alias string) *Statement {
  110. statement.TableAlias = alias
  111. return statement
  112. }
  113. // SQL adds raw sql statement
  114. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  115. switch query.(type) {
  116. case (*builder.Builder):
  117. var err error
  118. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  119. if err != nil {
  120. statement.lastError = err
  121. }
  122. case string:
  123. statement.RawSQL = query.(string)
  124. statement.RawParams = args
  125. default:
  126. statement.lastError = ErrUnSupportedSQLType
  127. }
  128. return statement
  129. }
  130. // Where add Where statement
  131. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  132. return statement.And(query, args...)
  133. }
  134. // And add Where & and statement
  135. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  136. switch query.(type) {
  137. case string:
  138. isExpr := false
  139. var cargs []interface{}
  140. for i, _ := range args {
  141. if _, ok := args[i].(sqlExpr); ok {
  142. isExpr = true
  143. }
  144. cargs = append(cargs, args[i])
  145. }
  146. if isExpr {
  147. sqlStr, _ := ConvertToBoundSQL(query.(string), cargs)
  148. cond := builder.Expr(sqlStr)
  149. statement.cond = statement.cond.And(cond)
  150. } else {
  151. cond := builder.Expr(query.(string), args...)
  152. statement.cond = statement.cond.And(cond)
  153. }
  154. case map[string]interface{}:
  155. cond := builder.Eq(query.(map[string]interface{}))
  156. statement.cond = statement.cond.And(cond)
  157. case builder.Cond:
  158. cond := query.(builder.Cond)
  159. statement.cond = statement.cond.And(cond)
  160. for _, v := range args {
  161. if vv, ok := v.(builder.Cond); ok {
  162. statement.cond = statement.cond.And(vv)
  163. }
  164. }
  165. default:
  166. statement.lastError = ErrConditionType
  167. }
  168. return statement
  169. }
  170. // Or add Where & Or statement
  171. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  172. switch query.(type) {
  173. case string:
  174. isExpr := false
  175. var cargs []interface{}
  176. for i, _ := range args {
  177. if _, ok := args[i].(sqlExpr); ok {
  178. isExpr = true
  179. }
  180. cargs = append(cargs, args[i])
  181. }
  182. if isExpr {
  183. sqlStr, _ := ConvertToBoundSQL(query.(string), cargs)
  184. cond := builder.Expr(sqlStr)
  185. statement.cond = statement.cond.Or(cond)
  186. } else {
  187. cond := builder.Expr(query.(string), args...)
  188. statement.cond = statement.cond.Or(cond)
  189. }
  190. case map[string]interface{}:
  191. cond := builder.Eq(query.(map[string]interface{}))
  192. statement.cond = statement.cond.Or(cond)
  193. case builder.Cond:
  194. cond := query.(builder.Cond)
  195. statement.cond = statement.cond.Or(cond)
  196. for _, v := range args {
  197. if vv, ok := v.(builder.Cond); ok {
  198. statement.cond = statement.cond.Or(vv)
  199. }
  200. }
  201. default:
  202. // TODO: not support condition type
  203. }
  204. return statement
  205. }
  206. // In generate "Where column IN (?) " statement
  207. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  208. in := builder.In(statement.Engine.Quote(column), args...)
  209. statement.cond = statement.cond.And(in)
  210. return statement
  211. }
  212. // NotIn generate "Where column NOT IN (?) " statement
  213. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  214. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  215. statement.cond = statement.cond.And(notIn)
  216. return statement
  217. }
  218. func (statement *Statement) setRefValue(v reflect.Value) error {
  219. var err error
  220. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  221. if err != nil {
  222. return err
  223. }
  224. statement.tableName = statement.Engine.TableName(v, true)
  225. return nil
  226. }
  227. func (statement *Statement) setRefBean(bean interface{}) error {
  228. var err error
  229. statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
  230. if err != nil {
  231. return err
  232. }
  233. statement.tableName = statement.Engine.TableName(bean, true)
  234. return nil
  235. }
  236. // Auto generating update columnes and values according a struct
  237. func (statement *Statement) buildUpdates(bean interface{},
  238. includeVersion, includeUpdated, includeNil,
  239. includeAutoIncr, update bool) ([]string, []interface{}) {
  240. engine := statement.Engine
  241. table := statement.RefTable
  242. allUseBool := statement.allUseBool
  243. useAllCols := statement.useAllCols
  244. mustColumnMap := statement.mustColumnMap
  245. nullableMap := statement.nullableMap
  246. columnMap := statement.columnMap
  247. omitColumnMap := statement.omitColumnMap
  248. unscoped := statement.unscoped
  249. var colNames = make([]string, 0)
  250. var args = make([]interface{}, 0)
  251. for _, col := range table.Columns() {
  252. if !includeVersion && col.IsVersion {
  253. continue
  254. }
  255. if col.IsCreated {
  256. continue
  257. }
  258. if !includeUpdated && col.IsUpdated {
  259. continue
  260. }
  261. if !includeAutoIncr && col.IsAutoIncrement {
  262. continue
  263. }
  264. if col.IsDeleted && !unscoped {
  265. continue
  266. }
  267. if omitColumnMap.contain(col.Name) {
  268. continue
  269. }
  270. if len(columnMap) > 0 && !columnMap.contain(col.Name) {
  271. continue
  272. }
  273. if col.MapType == core.ONLYFROMDB {
  274. continue
  275. }
  276. fieldValuePtr, err := col.ValueOf(bean)
  277. if err != nil {
  278. engine.logger.Error(err)
  279. continue
  280. }
  281. fieldValue := *fieldValuePtr
  282. fieldType := reflect.TypeOf(fieldValue.Interface())
  283. if fieldType == nil {
  284. continue
  285. }
  286. requiredField := useAllCols
  287. includeNil := useAllCols
  288. if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  289. if b {
  290. requiredField = true
  291. } else {
  292. continue
  293. }
  294. }
  295. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  296. if b, ok := getFlagForColumn(nullableMap, col); ok {
  297. if b && col.Nullable && isZero(fieldValue.Interface()) {
  298. var nilValue *int
  299. fieldValue = reflect.ValueOf(nilValue)
  300. fieldType = reflect.TypeOf(fieldValue.Interface())
  301. includeNil = true
  302. }
  303. }
  304. var val interface{}
  305. if fieldValue.CanAddr() {
  306. if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  307. data, err := structConvert.ToDB()
  308. if err != nil {
  309. engine.logger.Error(err)
  310. } else {
  311. val = data
  312. }
  313. goto APPEND
  314. }
  315. }
  316. if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  317. data, err := structConvert.ToDB()
  318. if err != nil {
  319. engine.logger.Error(err)
  320. } else {
  321. val = data
  322. }
  323. goto APPEND
  324. }
  325. if fieldType.Kind() == reflect.Ptr {
  326. if fieldValue.IsNil() {
  327. if includeNil {
  328. args = append(args, nil)
  329. colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  330. }
  331. continue
  332. } else if !fieldValue.IsValid() {
  333. continue
  334. } else {
  335. // dereference ptr type to instance type
  336. fieldValue = fieldValue.Elem()
  337. fieldType = reflect.TypeOf(fieldValue.Interface())
  338. requiredField = true
  339. }
  340. }
  341. switch fieldType.Kind() {
  342. case reflect.Bool:
  343. if allUseBool || requiredField {
  344. val = fieldValue.Interface()
  345. } else {
  346. // if a bool in a struct, it will not be as a condition because it default is false,
  347. // please use Where() instead
  348. continue
  349. }
  350. case reflect.String:
  351. if !requiredField && fieldValue.String() == "" {
  352. continue
  353. }
  354. // for MyString, should convert to string or panic
  355. if fieldType.String() != reflect.String.String() {
  356. val = fieldValue.String()
  357. } else {
  358. val = fieldValue.Interface()
  359. }
  360. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  361. if !requiredField && fieldValue.Int() == 0 {
  362. continue
  363. }
  364. val = fieldValue.Interface()
  365. case reflect.Float32, reflect.Float64:
  366. if !requiredField && fieldValue.Float() == 0.0 {
  367. continue
  368. }
  369. val = fieldValue.Interface()
  370. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  371. if !requiredField && fieldValue.Uint() == 0 {
  372. continue
  373. }
  374. t := int64(fieldValue.Uint())
  375. val = reflect.ValueOf(&t).Interface()
  376. case reflect.Struct:
  377. if fieldType.ConvertibleTo(core.TimeType) {
  378. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  379. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  380. continue
  381. }
  382. val = engine.formatColTime(col, t)
  383. } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  384. val, _ = nulType.Value()
  385. } else {
  386. if !col.SQLType.IsJson() {
  387. engine.autoMapType(fieldValue)
  388. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  389. if len(table.PrimaryKeys) == 1 {
  390. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  391. // fix non-int pk issues
  392. if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  393. val = pkField.Interface()
  394. } else {
  395. continue
  396. }
  397. } else {
  398. // TODO: how to handler?
  399. panic("not supported")
  400. }
  401. } else {
  402. val = fieldValue.Interface()
  403. }
  404. } else {
  405. // Blank struct could not be as update data
  406. if requiredField || !isStructZero(fieldValue) {
  407. bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
  408. if err != nil {
  409. panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  410. }
  411. if col.SQLType.IsText() {
  412. val = string(bytes)
  413. } else if col.SQLType.IsBlob() {
  414. val = bytes
  415. }
  416. } else {
  417. continue
  418. }
  419. }
  420. }
  421. case reflect.Array, reflect.Slice, reflect.Map:
  422. if !requiredField {
  423. if fieldValue == reflect.Zero(fieldType) {
  424. continue
  425. }
  426. if fieldType.Kind() == reflect.Array {
  427. if isArrayValueZero(fieldValue) {
  428. continue
  429. }
  430. } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  431. continue
  432. }
  433. }
  434. if col.SQLType.IsText() {
  435. bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
  436. if err != nil {
  437. engine.logger.Error(err)
  438. continue
  439. }
  440. val = string(bytes)
  441. } else if col.SQLType.IsBlob() {
  442. var bytes []byte
  443. var err error
  444. if fieldType.Kind() == reflect.Slice &&
  445. fieldType.Elem().Kind() == reflect.Uint8 {
  446. if fieldValue.Len() > 0 {
  447. val = fieldValue.Bytes()
  448. } else {
  449. continue
  450. }
  451. } else if fieldType.Kind() == reflect.Array &&
  452. fieldType.Elem().Kind() == reflect.Uint8 {
  453. val = fieldValue.Slice(0, 0).Interface()
  454. } else {
  455. bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
  456. if err != nil {
  457. engine.logger.Error(err)
  458. continue
  459. }
  460. val = bytes
  461. }
  462. } else {
  463. continue
  464. }
  465. default:
  466. val = fieldValue.Interface()
  467. }
  468. APPEND:
  469. args = append(args, val)
  470. if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  471. continue
  472. }
  473. colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  474. }
  475. return colNames, args
  476. }
  477. func (statement *Statement) needTableName() bool {
  478. return len(statement.JoinStr) > 0
  479. }
  480. func (statement *Statement) colName(col *core.Column, tableName string) string {
  481. if statement.needTableName() {
  482. var nm = tableName
  483. if len(statement.TableAlias) > 0 {
  484. nm = statement.TableAlias
  485. }
  486. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  487. }
  488. return statement.Engine.Quote(col.Name)
  489. }
  490. // TableName return current tableName
  491. func (statement *Statement) TableName() string {
  492. if statement.AltTableName != "" {
  493. return statement.AltTableName
  494. }
  495. return statement.tableName
  496. }
  497. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  498. func (statement *Statement) ID(id interface{}) *Statement {
  499. idValue := reflect.ValueOf(id)
  500. idType := reflect.TypeOf(idValue.Interface())
  501. switch idType {
  502. case ptrPkType:
  503. if pkPtr, ok := (id).(*core.PK); ok {
  504. statement.idParam = pkPtr
  505. return statement
  506. }
  507. case pkType:
  508. if pk, ok := (id).(core.PK); ok {
  509. statement.idParam = &pk
  510. return statement
  511. }
  512. }
  513. switch idType.Kind() {
  514. case reflect.String:
  515. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  516. return statement
  517. }
  518. statement.idParam = &core.PK{id}
  519. return statement
  520. }
  521. // Incr Generate "Update ... Set column = column + arg" statement
  522. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  523. k := strings.ToLower(column)
  524. if len(arg) > 0 {
  525. statement.incrColumns[k] = incrParam{column, arg[0]}
  526. } else {
  527. statement.incrColumns[k] = incrParam{column, 1}
  528. }
  529. return statement
  530. }
  531. // Decr Generate "Update ... Set column = column - arg" statement
  532. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  533. k := strings.ToLower(column)
  534. if len(arg) > 0 {
  535. statement.decrColumns[k] = decrParam{column, arg[0]}
  536. } else {
  537. statement.decrColumns[k] = decrParam{column, 1}
  538. }
  539. return statement
  540. }
  541. // SetExpr Generate "Update ... Set column = {expression}" statement
  542. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  543. k := strings.ToLower(column)
  544. statement.exprColumns[k] = exprParam{column, expression}
  545. return statement
  546. }
  547. // Generate "Update ... Set column = column + arg" statement
  548. func (statement *Statement) getInc() map[string]incrParam {
  549. return statement.incrColumns
  550. }
  551. // Generate "Update ... Set column = column - arg" statement
  552. func (statement *Statement) getDec() map[string]decrParam {
  553. return statement.decrColumns
  554. }
  555. // Generate "Update ... Set column = {expression}" statement
  556. func (statement *Statement) getExpr() map[string]exprParam {
  557. return statement.exprColumns
  558. }
  559. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  560. newColumns := make([]string, 0)
  561. quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
  562. for _, col := range columns {
  563. newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
  564. }
  565. return newColumns
  566. }
  567. func (statement *Statement) colmap2NewColsWithQuote() []string {
  568. newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
  569. copy(newColumns, statement.columnMap)
  570. for i := 0; i < len(statement.columnMap); i++ {
  571. newColumns[i] = statement.Engine.Quote(newColumns[i])
  572. }
  573. return newColumns
  574. }
  575. // Distinct generates "DISTINCT col1, col2 " statement
  576. func (statement *Statement) Distinct(columns ...string) *Statement {
  577. statement.IsDistinct = true
  578. statement.Cols(columns...)
  579. return statement
  580. }
  581. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  582. func (statement *Statement) ForUpdate() *Statement {
  583. statement.IsForUpdate = true
  584. return statement
  585. }
  586. // Select replace select
  587. func (statement *Statement) Select(str string) *Statement {
  588. statement.selectStr = str
  589. return statement
  590. }
  591. // Cols generate "col1, col2" statement
  592. func (statement *Statement) Cols(columns ...string) *Statement {
  593. cols := col2NewCols(columns...)
  594. for _, nc := range cols {
  595. statement.columnMap.add(nc)
  596. }
  597. newColumns := statement.colmap2NewColsWithQuote()
  598. statement.ColumnStr = strings.Join(newColumns, ", ")
  599. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  600. return statement
  601. }
  602. // AllCols update use only: update all columns
  603. func (statement *Statement) AllCols() *Statement {
  604. statement.useAllCols = true
  605. return statement
  606. }
  607. // MustCols update use only: must update columns
  608. func (statement *Statement) MustCols(columns ...string) *Statement {
  609. newColumns := col2NewCols(columns...)
  610. for _, nc := range newColumns {
  611. statement.mustColumnMap[strings.ToLower(nc)] = true
  612. }
  613. return statement
  614. }
  615. // UseBool indicates that use bool fields as update contents and query contiditions
  616. func (statement *Statement) UseBool(columns ...string) *Statement {
  617. if len(columns) > 0 {
  618. statement.MustCols(columns...)
  619. } else {
  620. statement.allUseBool = true
  621. }
  622. return statement
  623. }
  624. // Omit do not use the columns
  625. func (statement *Statement) Omit(columns ...string) {
  626. newColumns := col2NewCols(columns...)
  627. for _, nc := range newColumns {
  628. statement.omitColumnMap = append(statement.omitColumnMap, nc)
  629. }
  630. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  631. }
  632. // Nullable Update use only: update columns to null when value is nullable and zero-value
  633. func (statement *Statement) Nullable(columns ...string) {
  634. newColumns := col2NewCols(columns...)
  635. for _, nc := range newColumns {
  636. statement.nullableMap[strings.ToLower(nc)] = true
  637. }
  638. }
  639. // Top generate LIMIT limit statement
  640. func (statement *Statement) Top(limit int) *Statement {
  641. statement.Limit(limit)
  642. return statement
  643. }
  644. // Limit generate LIMIT start, limit statement
  645. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  646. statement.LimitN = limit
  647. if len(start) > 0 {
  648. statement.Start = start[0]
  649. }
  650. return statement
  651. }
  652. // OrderBy generate "Order By order" statement
  653. func (statement *Statement) OrderBy(order string) *Statement {
  654. if len(statement.OrderStr) > 0 {
  655. statement.OrderStr += ", "
  656. }
  657. statement.OrderStr += order
  658. return statement
  659. }
  660. // Desc generate `ORDER BY xx DESC`
  661. func (statement *Statement) Desc(colNames ...string) *Statement {
  662. var buf builder.StringBuilder
  663. if len(statement.OrderStr) > 0 {
  664. fmt.Fprint(&buf, statement.OrderStr, ", ")
  665. }
  666. newColNames := statement.col2NewColsWithQuote(colNames...)
  667. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  668. statement.OrderStr = buf.String()
  669. return statement
  670. }
  671. // Asc provide asc order by query condition, the input parameters are columns.
  672. func (statement *Statement) Asc(colNames ...string) *Statement {
  673. var buf builder.StringBuilder
  674. if len(statement.OrderStr) > 0 {
  675. fmt.Fprint(&buf, statement.OrderStr, ", ")
  676. }
  677. newColNames := statement.col2NewColsWithQuote(colNames...)
  678. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  679. statement.OrderStr = buf.String()
  680. return statement
  681. }
  682. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  683. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  684. v := rValue(tableNameOrBean)
  685. t := v.Type()
  686. if t.Kind() == reflect.Struct {
  687. var err error
  688. statement.RefTable, err = statement.Engine.autoMapType(v)
  689. if err != nil {
  690. statement.Engine.logger.Error(err)
  691. return statement
  692. }
  693. }
  694. statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
  695. return statement
  696. }
  697. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  698. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  699. var buf builder.StringBuilder
  700. if len(statement.JoinStr) > 0 {
  701. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  702. } else {
  703. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  704. }
  705. switch tp := tablename.(type) {
  706. case builder.Builder:
  707. subSQL, subQueryArgs, err := tp.ToSQL()
  708. if err != nil {
  709. statement.lastError = err
  710. return statement
  711. }
  712. tbs := strings.Split(tp.TableName(), ".")
  713. quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
  714. var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
  715. fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
  716. statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
  717. case *builder.Builder:
  718. subSQL, subQueryArgs, err := tp.ToSQL()
  719. if err != nil {
  720. statement.lastError = err
  721. return statement
  722. }
  723. tbs := strings.Split(tp.TableName(), ".")
  724. quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
  725. var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
  726. fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
  727. statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
  728. default:
  729. tbName := statement.Engine.TableName(tablename, true)
  730. fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
  731. }
  732. statement.JoinStr = buf.String()
  733. statement.joinArgs = append(statement.joinArgs, args...)
  734. return statement
  735. }
  736. // GroupBy generate "Group By keys" statement
  737. func (statement *Statement) GroupBy(keys string) *Statement {
  738. statement.GroupByStr = keys
  739. return statement
  740. }
  741. // Having generate "Having conditions" statement
  742. func (statement *Statement) Having(conditions string) *Statement {
  743. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  744. return statement
  745. }
  746. // Unscoped always disable struct tag "deleted"
  747. func (statement *Statement) Unscoped() *Statement {
  748. statement.unscoped = true
  749. return statement
  750. }
  751. func (statement *Statement) genColumnStr() string {
  752. if statement.RefTable == nil {
  753. return ""
  754. }
  755. var buf builder.StringBuilder
  756. columns := statement.RefTable.Columns()
  757. for _, col := range columns {
  758. if statement.omitColumnMap.contain(col.Name) {
  759. continue
  760. }
  761. if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
  762. continue
  763. }
  764. if col.MapType == core.ONLYTODB {
  765. continue
  766. }
  767. if buf.Len() != 0 {
  768. buf.WriteString(", ")
  769. }
  770. if statement.JoinStr != "" {
  771. if statement.TableAlias != "" {
  772. buf.WriteString(statement.TableAlias)
  773. } else {
  774. buf.WriteString(statement.TableName())
  775. }
  776. buf.WriteString(".")
  777. }
  778. statement.Engine.QuoteTo(&buf, col.Name)
  779. }
  780. return buf.String()
  781. }
  782. func (statement *Statement) genCreateTableSQL() string {
  783. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  784. statement.StoreEngine, statement.Charset)
  785. }
  786. func (statement *Statement) genIndexSQL() []string {
  787. var sqls []string
  788. tbName := statement.TableName()
  789. for _, index := range statement.RefTable.Indexes {
  790. if index.Type == core.IndexType {
  791. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  792. if sql != "" {
  793. sqls = append(sqls, sql)
  794. }
  795. }
  796. }
  797. return sqls
  798. }
  799. func uniqueName(tableName, uqeName string) string {
  800. return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  801. }
  802. func (statement *Statement) genUniqueSQL() []string {
  803. var sqls []string
  804. tbName := statement.TableName()
  805. for _, index := range statement.RefTable.Indexes {
  806. if index.Type == core.UniqueType {
  807. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  808. sqls = append(sqls, sql)
  809. }
  810. }
  811. return sqls
  812. }
  813. func (statement *Statement) genDelIndexSQL() []string {
  814. var sqls []string
  815. tbName := statement.TableName()
  816. idxPrefixName := strings.Replace(tbName, `"`, "", -1)
  817. idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
  818. for idxName, index := range statement.RefTable.Indexes {
  819. var rIdxName string
  820. if index.Type == core.UniqueType {
  821. rIdxName = uniqueName(idxPrefixName, idxName)
  822. } else if index.Type == core.IndexType {
  823. rIdxName = indexName(idxPrefixName, idxName)
  824. }
  825. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
  826. if statement.Engine.dialect.IndexOnTable() {
  827. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
  828. }
  829. sqls = append(sqls, sql)
  830. }
  831. return sqls
  832. }
  833. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  834. quote := statement.Engine.Quote
  835. sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
  836. col.String(statement.Engine.dialect))
  837. if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
  838. sql += " COMMENT '" + col.Comment + "'"
  839. }
  840. sql += ";"
  841. return sql, []interface{}{}
  842. }
  843. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  844. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  845. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  846. }
  847. func (statement *Statement) mergeConds(bean interface{}) error {
  848. if !statement.noAutoCondition {
  849. var addedTableName = (len(statement.JoinStr) > 0)
  850. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  851. if err != nil {
  852. return err
  853. }
  854. statement.cond = statement.cond.And(autoCond)
  855. }
  856. if err := statement.processIDParam(); err != nil {
  857. return err
  858. }
  859. return nil
  860. }
  861. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  862. if err := statement.mergeConds(bean); err != nil {
  863. return "", nil, err
  864. }
  865. return builder.ToSQL(statement.cond)
  866. }
  867. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  868. v := rValue(bean)
  869. isStruct := v.Kind() == reflect.Struct
  870. if isStruct {
  871. statement.setRefBean(bean)
  872. }
  873. var columnStr = statement.ColumnStr
  874. if len(statement.selectStr) > 0 {
  875. columnStr = statement.selectStr
  876. } else {
  877. // TODO: always generate column names, not use * even if join
  878. if len(statement.JoinStr) == 0 {
  879. if len(columnStr) == 0 {
  880. if len(statement.GroupByStr) > 0 {
  881. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  882. } else {
  883. columnStr = statement.genColumnStr()
  884. }
  885. }
  886. } else {
  887. if len(columnStr) == 0 {
  888. if len(statement.GroupByStr) > 0 {
  889. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  890. }
  891. }
  892. }
  893. }
  894. if len(columnStr) == 0 {
  895. columnStr = "*"
  896. }
  897. if isStruct {
  898. if err := statement.mergeConds(bean); err != nil {
  899. return "", nil, err
  900. }
  901. } else {
  902. if err := statement.processIDParam(); err != nil {
  903. return "", nil, err
  904. }
  905. }
  906. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  907. if err != nil {
  908. return "", nil, err
  909. }
  910. sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
  911. if err != nil {
  912. return "", nil, err
  913. }
  914. return sqlStr, append(statement.joinArgs, condArgs...), nil
  915. }
  916. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  917. var condSQL string
  918. var condArgs []interface{}
  919. var err error
  920. if len(beans) > 0 {
  921. statement.setRefBean(beans[0])
  922. condSQL, condArgs, err = statement.genConds(beans[0])
  923. } else {
  924. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  925. }
  926. if err != nil {
  927. return "", nil, err
  928. }
  929. var selectSQL = statement.selectStr
  930. if len(selectSQL) <= 0 {
  931. if statement.IsDistinct {
  932. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  933. } else {
  934. selectSQL = "count(*)"
  935. }
  936. }
  937. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
  938. if err != nil {
  939. return "", nil, err
  940. }
  941. return sqlStr, append(statement.joinArgs, condArgs...), nil
  942. }
  943. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  944. statement.setRefBean(bean)
  945. var sumStrs = make([]string, 0, len(columns))
  946. for _, colName := range columns {
  947. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  948. colName = statement.Engine.Quote(colName)
  949. }
  950. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  951. }
  952. sumSelect := strings.Join(sumStrs, ", ")
  953. condSQL, condArgs, err := statement.genConds(bean)
  954. if err != nil {
  955. return "", nil, err
  956. }
  957. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
  958. if err != nil {
  959. return "", nil, err
  960. }
  961. return sqlStr, append(statement.joinArgs, condArgs...), nil
  962. }
  963. func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
  964. var (
  965. distinct string
  966. dialect = statement.Engine.Dialect()
  967. quote = statement.Engine.Quote
  968. fromStr = " FROM "
  969. top, mssqlCondi, whereStr string
  970. )
  971. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  972. distinct = "DISTINCT "
  973. }
  974. if len(condSQL) > 0 {
  975. whereStr = " WHERE " + condSQL
  976. }
  977. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  978. fromStr += statement.TableName()
  979. } else {
  980. fromStr += quote(statement.TableName())
  981. }
  982. if statement.TableAlias != "" {
  983. if dialect.DBType() == core.ORACLE {
  984. fromStr += " " + quote(statement.TableAlias)
  985. } else {
  986. fromStr += " AS " + quote(statement.TableAlias)
  987. }
  988. }
  989. if statement.JoinStr != "" {
  990. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  991. }
  992. if dialect.DBType() == core.MSSQL {
  993. if statement.LimitN > 0 {
  994. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  995. }
  996. if statement.Start > 0 {
  997. var column string
  998. if len(statement.RefTable.PKColumns()) == 0 {
  999. for _, index := range statement.RefTable.Indexes {
  1000. if len(index.Cols) == 1 {
  1001. column = index.Cols[0]
  1002. break
  1003. }
  1004. }
  1005. if len(column) == 0 {
  1006. column = statement.RefTable.ColumnsSeq()[0]
  1007. }
  1008. } else {
  1009. column = statement.RefTable.PKColumns()[0].Name
  1010. }
  1011. if statement.needTableName() {
  1012. if len(statement.TableAlias) > 0 {
  1013. column = statement.TableAlias + "." + column
  1014. } else {
  1015. column = statement.TableName() + "." + column
  1016. }
  1017. }
  1018. var orderStr string
  1019. if needOrderBy && len(statement.OrderStr) > 0 {
  1020. orderStr = " ORDER BY " + statement.OrderStr
  1021. }
  1022. var groupStr string
  1023. if len(statement.GroupByStr) > 0 {
  1024. groupStr = " GROUP BY " + statement.GroupByStr
  1025. }
  1026. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  1027. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  1028. }
  1029. }
  1030. var buf builder.StringBuilder
  1031. fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  1032. if len(mssqlCondi) > 0 {
  1033. if len(whereStr) > 0 {
  1034. fmt.Fprint(&buf, " AND ", mssqlCondi)
  1035. } else {
  1036. fmt.Fprint(&buf, " WHERE ", mssqlCondi)
  1037. }
  1038. }
  1039. if statement.GroupByStr != "" {
  1040. fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
  1041. }
  1042. if statement.HavingStr != "" {
  1043. fmt.Fprint(&buf, " ", statement.HavingStr)
  1044. }
  1045. if needOrderBy && statement.OrderStr != "" {
  1046. fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
  1047. }
  1048. if needLimit {
  1049. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1050. if statement.Start > 0 {
  1051. fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
  1052. } else if statement.LimitN > 0 {
  1053. fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
  1054. }
  1055. } else if dialect.DBType() == core.ORACLE {
  1056. if statement.Start != 0 || statement.LimitN != 0 {
  1057. oldString := buf.String()
  1058. buf.Reset()
  1059. rawColStr := columnStr
  1060. if rawColStr == "*" {
  1061. rawColStr = "at.*"
  1062. }
  1063. fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
  1064. columnStr, rawColStr, oldString, statement.Start+statement.LimitN, statement.Start)
  1065. }
  1066. }
  1067. }
  1068. if statement.IsForUpdate {
  1069. return dialect.ForUpdateSql(buf.String()), nil
  1070. }
  1071. return buf.String(), nil
  1072. }
  1073. func (statement *Statement) processIDParam() error {
  1074. if statement.idParam != nil {
  1075. if statement.RefTable == nil {
  1076. return fmt.Errorf("ID condition is error, cant find ref table object")
  1077. } else if len(statement.RefTable.PKColumns()) == 0 {
  1078. return fmt.Errorf("ID condition is error, cant find table pkcolumns")
  1079. }
  1080. }
  1081. if statement.idParam == nil || statement.RefTable == nil {
  1082. return nil
  1083. }
  1084. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1085. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1086. len(statement.RefTable.PrimaryKeys),
  1087. len(*statement.idParam),
  1088. )
  1089. }
  1090. for i, col := range statement.RefTable.PKColumns() {
  1091. var colName = statement.colName(col, statement.TableName())
  1092. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1093. }
  1094. return nil
  1095. }
  1096. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1097. var colnames = make([]string, len(cols))
  1098. for i, col := range cols {
  1099. if includeTableName {
  1100. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1101. "." + statement.Engine.Quote(col.Name)
  1102. } else {
  1103. colnames[i] = statement.Engine.Quote(col.Name)
  1104. }
  1105. }
  1106. return strings.Join(colnames, ", ")
  1107. }
  1108. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1109. if statement.RefTable != nil {
  1110. cols := statement.RefTable.PKColumns()
  1111. if len(cols) == 0 {
  1112. return ""
  1113. }
  1114. colstrs := statement.joinColumns(cols, false)
  1115. sqls := splitNNoCase(sqlStr, " from ", 2)
  1116. if len(sqls) != 2 {
  1117. return ""
  1118. }
  1119. var top string
  1120. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1121. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1122. }
  1123. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1124. return newsql
  1125. }
  1126. return ""
  1127. }
  1128. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1129. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1130. return "", ""
  1131. }
  1132. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1133. sqls := splitNNoCase(sqlStr, "where", 2)
  1134. if len(sqls) != 2 {
  1135. if len(sqls) == 1 {
  1136. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1137. colstrs, statement.Engine.Quote(statement.TableName()))
  1138. }
  1139. return "", ""
  1140. }
  1141. var whereStr = sqls[1]
  1142. // TODO: for postgres only, if any other database?
  1143. var paraStr string
  1144. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1145. paraStr = "$"
  1146. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1147. paraStr = ":"
  1148. }
  1149. if paraStr != "" {
  1150. if strings.Contains(sqls[1], paraStr) {
  1151. dollers := strings.Split(sqls[1], paraStr)
  1152. whereStr = dollers[0]
  1153. for i, c := range dollers[1:] {
  1154. ccs := strings.SplitN(c, " ", 2)
  1155. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1156. }
  1157. }
  1158. }
  1159. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1160. colstrs, statement.Engine.Quote(statement.TableName()),
  1161. whereStr)
  1162. }