statement.go 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "bytes"
  7. "database/sql/driver"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "reflect"
  12. "strings"
  13. "time"
  14. "github.com/go-xorm/builder"
  15. "github.com/xormplus/core"
  16. )
// incrParam is a pending "column = column + arg" update registered by Incr.
// Field order matters: Incr builds it with an unkeyed composite literal.
type incrParam struct {
	colName string      // column name exactly as passed to Incr
	arg     interface{} // increment amount (1 when Incr is called without one)
}

// decrParam is a pending "column = column - arg" update registered by Decr.
// Field order matters: Decr builds it with an unkeyed composite literal.
type decrParam struct {
	colName string      // column name exactly as passed to Decr
	arg     interface{} // decrement amount (1 when Decr is called without one)
}

// exprParam is a pending "column = expr" update registered by SetExpr.
// Field order matters: SetExpr builds it with an unkeyed composite literal.
type exprParam struct {
	colName string // column name exactly as passed to SetExpr
	expr    string // raw SQL expression assigned to the column
}
// Statement saves all the SQL info for executing SQL: the target table, the
// WHERE condition tree, column selections, ordering/limits and per-column
// update rules accumulated by its chainable methods, or a raw SQL string set
// through SQL.
type Statement struct {
	// Target table metadata, mapped from a bean by setRefValue or Table.
	RefTable *core.Table
	// Owning engine; used for quoting, table mapping and logging.
	Engine *Engine
	// Row offset and row count, set by Limit (and Top).
	Start  int
	LimitN int
	// Primary-key value(s), set by Id.
	IdParam *core.PK
	// ORDER BY fragment built by OrderBy, Desc and Asc.
	OrderStr string
	// Accumulated JOIN clauses; when non-empty, colName qualifies column names.
	JoinStr string
	// NOTE(review): presumably the arguments bound to JoinStr — confirm against Join.
	joinArgs []interface{}
	// NOTE(review): presumably the GROUP BY / HAVING fragments; not set in this section.
	GroupByStr string
	HavingStr  string
	// Quoted column list built by Cols.
	ColumnStr string
	// Custom select expression set by Select.
	selectStr string
	// Lower-cased column name -> true (Cols) or false (Omit).
	columnMap map[string]bool
	// Set by AllCols.
	useAllCols bool
	// Quoted list of omitted columns, set by Omit.
	OmitStr string
	// Explicit table name set by Table; TableName prefers it over tableName.
	AltTableName string
	// Table name derived from the bean by setRefValue.
	tableName string
	// Raw SQL and its arguments, set by SQL.
	RawSQL    string
	RawParams []interface{}
	// Init sets UseCascade to true; not otherwise used in this section.
	UseCascade bool
	// Not reset by Init; not used in this section.
	UseAutoJoin bool
	StoreEngine string
	Charset     string
	// Init sets UseCache and UseAutoTime to true; not otherwise used in this section.
	UseCache    bool
	UseAutoTime bool
	// Set by NoAutoCondition: do not turn the bean's fields into conditions.
	noAutoCondition bool
	// Set by Distinct and ForUpdate respectively.
	IsDistinct  bool
	IsForUpdate bool
	// Set by Alias; colName prefers it over the table name when joining.
	TableAlias string
	// Set by UseBool when called without column names.
	allUseBool bool
	// Init sets checkVersion to true; not otherwise used in this section.
	checkVersion bool
	// Init sets unscoped to false. NOTE(review): presumably passed as the
	// unscoped argument of buildUpdates/buildConds — confirm at call sites.
	unscoped bool
	// Lower-cased columns forced by MustCols (and UseBool with names).
	mustColumnMap map[string]bool
	// Lower-cased columns marked by Nullable.
	nullableMap map[string]bool
	// Pending Incr / Decr / SetExpr operations keyed by lower-cased column name.
	incrColumns map[string]incrParam
	decrColumns map[string]decrParam
	exprColumns map[string]exprParam
	// WHERE condition tree built by Where, And, Or, In and NotIn.
	cond builder.Cond
}
  70. // Init reset all the statment's fields
  71. func (statement *Statement) Init() {
  72. statement.RefTable = nil
  73. statement.Start = 0
  74. statement.LimitN = 0
  75. statement.OrderStr = ""
  76. statement.UseCascade = true
  77. statement.JoinStr = ""
  78. statement.joinArgs = make([]interface{}, 0)
  79. statement.GroupByStr = ""
  80. statement.HavingStr = ""
  81. statement.ColumnStr = ""
  82. statement.OmitStr = ""
  83. statement.columnMap = make(map[string]bool)
  84. statement.AltTableName = ""
  85. statement.tableName = ""
  86. statement.IdParam = nil
  87. statement.RawSQL = ""
  88. statement.RawParams = make([]interface{}, 0)
  89. statement.UseCache = true
  90. statement.UseAutoTime = true
  91. statement.noAutoCondition = false
  92. statement.IsDistinct = false
  93. statement.IsForUpdate = false
  94. statement.TableAlias = ""
  95. statement.selectStr = ""
  96. statement.allUseBool = false
  97. statement.useAllCols = false
  98. statement.mustColumnMap = make(map[string]bool)
  99. statement.nullableMap = make(map[string]bool)
  100. statement.checkVersion = true
  101. statement.unscoped = false
  102. statement.incrColumns = make(map[string]incrParam)
  103. statement.decrColumns = make(map[string]decrParam)
  104. statement.exprColumns = make(map[string]exprParam)
  105. statement.cond = builder.NewCond()
  106. }
  107. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  108. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  109. statement.noAutoCondition = true
  110. if len(no) > 0 {
  111. statement.noAutoCondition = no[0]
  112. }
  113. return statement
  114. }
  115. // Alias set the table alias
  116. func (statement *Statement) Alias(alias string) *Statement {
  117. statement.TableAlias = alias
  118. return statement
  119. }
  120. // Sql add the raw sql statement
  121. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  122. switch query.(type) {
  123. case (*builder.Builder):
  124. var err error
  125. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  126. if err != nil {
  127. statement.Engine.logger.Error(err)
  128. }
  129. case string:
  130. statement.RawSQL = query.(string)
  131. statement.RawParams = args
  132. default:
  133. statement.Engine.logger.Error("unsupported sql type")
  134. }
  135. return statement
  136. }
  137. // Where add Where statment
  138. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  139. return statement.And(query, args...)
  140. }
  141. // And add Where & and statment
  142. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  143. switch query.(type) {
  144. case string:
  145. cond := builder.Expr(query.(string), args...)
  146. statement.cond = statement.cond.And(cond)
  147. case builder.Cond:
  148. cond := query.(builder.Cond)
  149. statement.cond = statement.cond.And(cond)
  150. for _, v := range args {
  151. if vv, ok := v.(builder.Cond); ok {
  152. statement.cond = statement.cond.And(vv)
  153. }
  154. }
  155. default:
  156. // TODO: not support condition type
  157. }
  158. return statement
  159. }
  160. // Or add Where & Or statment
  161. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  162. switch query.(type) {
  163. case string:
  164. cond := builder.Expr(query.(string), args...)
  165. statement.cond = statement.cond.Or(cond)
  166. case builder.Cond:
  167. cond := query.(builder.Cond)
  168. statement.cond = statement.cond.Or(cond)
  169. for _, v := range args {
  170. if vv, ok := v.(builder.Cond); ok {
  171. statement.cond = statement.cond.Or(vv)
  172. }
  173. }
  174. default:
  175. // TODO: not support condition type
  176. }
  177. return statement
  178. }
  179. // In generate "Where column IN (?) " statment
  180. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  181. if len(args) == 0 {
  182. return statement
  183. }
  184. in := builder.In(column, args...)
  185. statement.cond = statement.cond.And(in)
  186. return statement
  187. }
  188. // NotIn generate "Where column IN (?) " statment
  189. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  190. if len(args) == 0 {
  191. return statement
  192. }
  193. in := builder.NotIn(column, args...)
  194. statement.cond = statement.cond.And(in)
  195. return statement
  196. }
  197. func (statement *Statement) setRefValue(v reflect.Value) {
  198. statement.RefTable = statement.Engine.autoMapType(reflect.Indirect(v))
  199. statement.tableName = statement.Engine.tbName(v)
  200. }
  201. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  202. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  203. v := rValue(tableNameOrBean)
  204. t := v.Type()
  205. if t.Kind() == reflect.String {
  206. statement.AltTableName = tableNameOrBean.(string)
  207. } else if t.Kind() == reflect.Struct {
  208. statement.RefTable = statement.Engine.autoMapType(v)
  209. statement.AltTableName = statement.Engine.tbName(v)
  210. }
  211. return statement
  212. }
  213. // Auto generating update columnes and values according a struct
  214. func buildUpdates(engine *Engine, table *core.Table, bean interface{},
  215. includeVersion bool, includeUpdated bool, includeNil bool,
  216. includeAutoIncr bool, allUseBool bool, useAllCols bool,
  217. mustColumnMap map[string]bool, nullableMap map[string]bool,
  218. columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) {
  219. var colNames = make([]string, 0)
  220. var args = make([]interface{}, 0)
  221. for _, col := range table.Columns() {
  222. if !includeVersion && col.IsVersion {
  223. continue
  224. }
  225. if col.IsCreated {
  226. continue
  227. }
  228. if !includeUpdated && col.IsUpdated {
  229. continue
  230. }
  231. if !includeAutoIncr && col.IsAutoIncrement {
  232. continue
  233. }
  234. if col.IsDeleted && !unscoped {
  235. continue
  236. }
  237. if use, ok := columnMap[col.Name]; ok && !use {
  238. continue
  239. }
  240. fieldValuePtr, err := col.ValueOf(bean)
  241. if err != nil {
  242. engine.logger.Error(err)
  243. continue
  244. }
  245. fieldValue := *fieldValuePtr
  246. fieldType := reflect.TypeOf(fieldValue.Interface())
  247. requiredField := useAllCols
  248. includeNil := useAllCols
  249. lColName := strings.ToLower(col.Name)
  250. if b, ok := mustColumnMap[lColName]; ok {
  251. if b {
  252. requiredField = true
  253. } else {
  254. continue
  255. }
  256. }
  257. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  258. if b, ok := nullableMap[lColName]; ok {
  259. if b && col.Nullable && isZero(fieldValue.Interface()) {
  260. var nilValue *int
  261. fieldValue = reflect.ValueOf(nilValue)
  262. fieldType = reflect.TypeOf(fieldValue.Interface())
  263. includeNil = true
  264. }
  265. }
  266. var val interface{}
  267. if fieldValue.CanAddr() {
  268. if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  269. data, err := structConvert.ToDB()
  270. if err != nil {
  271. engine.logger.Error(err)
  272. } else {
  273. val = data
  274. }
  275. goto APPEND
  276. }
  277. }
  278. if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  279. data, err := structConvert.ToDB()
  280. if err != nil {
  281. engine.logger.Error(err)
  282. } else {
  283. val = data
  284. }
  285. goto APPEND
  286. }
  287. if fieldType.Kind() == reflect.Ptr {
  288. if fieldValue.IsNil() {
  289. if includeNil {
  290. args = append(args, nil)
  291. colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  292. }
  293. continue
  294. } else if !fieldValue.IsValid() {
  295. continue
  296. } else {
  297. // dereference ptr type to instance type
  298. fieldValue = fieldValue.Elem()
  299. fieldType = reflect.TypeOf(fieldValue.Interface())
  300. requiredField = true
  301. }
  302. }
  303. switch fieldType.Kind() {
  304. case reflect.Bool:
  305. if allUseBool || requiredField {
  306. val = fieldValue.Interface()
  307. } else {
  308. // if a bool in a struct, it will not be as a condition because it default is false,
  309. // please use Where() instead
  310. continue
  311. }
  312. case reflect.String:
  313. if !requiredField && fieldValue.String() == "" {
  314. continue
  315. }
  316. // for MyString, should convert to string or panic
  317. if fieldType.String() != reflect.String.String() {
  318. val = fieldValue.String()
  319. } else {
  320. val = fieldValue.Interface()
  321. }
  322. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  323. if !requiredField && fieldValue.Int() == 0 {
  324. continue
  325. }
  326. val = fieldValue.Interface()
  327. case reflect.Float32, reflect.Float64:
  328. if !requiredField && fieldValue.Float() == 0.0 {
  329. continue
  330. }
  331. val = fieldValue.Interface()
  332. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  333. if !requiredField && fieldValue.Uint() == 0 {
  334. continue
  335. }
  336. t := int64(fieldValue.Uint())
  337. val = reflect.ValueOf(&t).Interface()
  338. case reflect.Struct:
  339. if fieldType.ConvertibleTo(core.TimeType) {
  340. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  341. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  342. continue
  343. }
  344. val = engine.FormatTime(col.SQLType.Name, t)
  345. } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  346. val, _ = nulType.Value()
  347. } else {
  348. if !col.SQLType.IsJson() {
  349. engine.autoMapType(fieldValue)
  350. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  351. if len(table.PrimaryKeys) == 1 {
  352. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  353. // fix non-int pk issues
  354. if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  355. val = pkField.Interface()
  356. } else {
  357. continue
  358. }
  359. } else {
  360. //TODO: how to handler?
  361. panic("not supported")
  362. }
  363. } else {
  364. val = fieldValue.Interface()
  365. }
  366. } else {
  367. // Blank struct could not be as update data
  368. if requiredField || !isStructZero(fieldValue) {
  369. bytes, err := json.Marshal(fieldValue.Interface())
  370. if err != nil {
  371. panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  372. }
  373. if col.SQLType.IsText() {
  374. val = string(bytes)
  375. } else if col.SQLType.IsBlob() {
  376. val = bytes
  377. }
  378. } else {
  379. continue
  380. }
  381. }
  382. }
  383. case reflect.Array, reflect.Slice, reflect.Map:
  384. if !requiredField {
  385. if fieldValue == reflect.Zero(fieldType) {
  386. continue
  387. }
  388. if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  389. continue
  390. }
  391. }
  392. if col.SQLType.IsText() {
  393. bytes, err := json.Marshal(fieldValue.Interface())
  394. if err != nil {
  395. engine.logger.Error(err)
  396. continue
  397. }
  398. val = string(bytes)
  399. } else if col.SQLType.IsBlob() {
  400. var bytes []byte
  401. var err error
  402. if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
  403. fieldType.Elem().Kind() == reflect.Uint8 {
  404. if fieldValue.Len() > 0 {
  405. val = fieldValue.Bytes()
  406. } else {
  407. continue
  408. }
  409. } else {
  410. bytes, err = json.Marshal(fieldValue.Interface())
  411. if err != nil {
  412. engine.logger.Error(err)
  413. continue
  414. }
  415. val = bytes
  416. }
  417. } else {
  418. continue
  419. }
  420. default:
  421. val = fieldValue.Interface()
  422. }
  423. APPEND:
  424. args = append(args, val)
  425. if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  426. continue
  427. }
  428. colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  429. }
  430. return colNames, args
  431. }
  432. func (statement *Statement) needTableName() bool {
  433. return len(statement.JoinStr) > 0
  434. }
  435. func (statement *Statement) colName(col *core.Column, tableName string) string {
  436. if statement.needTableName() {
  437. var nm = tableName
  438. if len(statement.TableAlias) > 0 {
  439. nm = statement.TableAlias
  440. }
  441. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  442. }
  443. return statement.Engine.Quote(col.Name)
  444. }
  445. func buildConds(engine *Engine, table *core.Table, bean interface{},
  446. includeVersion bool, includeUpdated bool, includeNil bool,
  447. includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
  448. mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
  449. var conds []builder.Cond
  450. for _, col := range table.Columns() {
  451. if !includeVersion && col.IsVersion {
  452. continue
  453. }
  454. if !includeUpdated && col.IsUpdated {
  455. continue
  456. }
  457. if !includeAutoIncr && col.IsAutoIncrement {
  458. continue
  459. }
  460. if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text {
  461. continue
  462. }
  463. if col.SQLType.IsJson() {
  464. continue
  465. }
  466. var colName string
  467. if addedTableName {
  468. var nm = tableName
  469. if len(aliasName) > 0 {
  470. nm = aliasName
  471. }
  472. colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
  473. } else {
  474. colName = engine.Quote(col.Name)
  475. }
  476. fieldValuePtr, err := col.ValueOf(bean)
  477. if err != nil {
  478. engine.logger.Error(err)
  479. continue
  480. }
  481. if col.IsDeleted && !unscoped { // tag "deleted" is enabled
  482. conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"}))
  483. }
  484. fieldValue := *fieldValuePtr
  485. if fieldValue.Interface() == nil {
  486. continue
  487. }
  488. fieldType := reflect.TypeOf(fieldValue.Interface())
  489. requiredField := useAllCols
  490. if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok {
  491. if b {
  492. requiredField = true
  493. } else {
  494. continue
  495. }
  496. }
  497. if fieldType.Kind() == reflect.Ptr {
  498. if fieldValue.IsNil() {
  499. if includeNil {
  500. conds = append(conds, builder.Eq{colName: nil})
  501. }
  502. continue
  503. } else if !fieldValue.IsValid() {
  504. continue
  505. } else {
  506. // dereference ptr type to instance type
  507. fieldValue = fieldValue.Elem()
  508. fieldType = reflect.TypeOf(fieldValue.Interface())
  509. requiredField = true
  510. }
  511. }
  512. var val interface{}
  513. switch fieldType.Kind() {
  514. case reflect.Bool:
  515. if allUseBool || requiredField {
  516. val = fieldValue.Interface()
  517. } else {
  518. // if a bool in a struct, it will not be as a condition because it default is false,
  519. // please use Where() instead
  520. continue
  521. }
  522. case reflect.String:
  523. if !requiredField && fieldValue.String() == "" {
  524. continue
  525. }
  526. // for MyString, should convert to string or panic
  527. if fieldType.String() != reflect.String.String() {
  528. val = fieldValue.String()
  529. } else {
  530. val = fieldValue.Interface()
  531. }
  532. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  533. if !requiredField && fieldValue.Int() == 0 {
  534. continue
  535. }
  536. val = fieldValue.Interface()
  537. case reflect.Float32, reflect.Float64:
  538. if !requiredField && fieldValue.Float() == 0.0 {
  539. continue
  540. }
  541. val = fieldValue.Interface()
  542. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  543. if !requiredField && fieldValue.Uint() == 0 {
  544. continue
  545. }
  546. t := int64(fieldValue.Uint())
  547. val = reflect.ValueOf(&t).Interface()
  548. case reflect.Struct:
  549. if fieldType.ConvertibleTo(core.TimeType) {
  550. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  551. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  552. continue
  553. }
  554. val = engine.FormatTime(col.SQLType.Name, t)
  555. } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
  556. continue
  557. } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
  558. val, _ = valNul.Value()
  559. if val == nil {
  560. continue
  561. }
  562. } else {
  563. if col.SQLType.IsJson() {
  564. if col.SQLType.IsText() {
  565. bytes, err := json.Marshal(fieldValue.Interface())
  566. if err != nil {
  567. engine.logger.Error(err)
  568. continue
  569. }
  570. val = string(bytes)
  571. } else if col.SQLType.IsBlob() {
  572. var bytes []byte
  573. var err error
  574. bytes, err = json.Marshal(fieldValue.Interface())
  575. if err != nil {
  576. engine.logger.Error(err)
  577. continue
  578. }
  579. val = bytes
  580. }
  581. } else {
  582. engine.autoMapType(fieldValue)
  583. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  584. if len(table.PrimaryKeys) == 1 {
  585. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  586. // fix non-int pk issues
  587. //if pkField.Int() != 0 {
  588. if pkField.IsValid() && !isZero(pkField.Interface()) {
  589. val = pkField.Interface()
  590. } else {
  591. continue
  592. }
  593. } else {
  594. //TODO: how to handler?
  595. panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys))
  596. }
  597. } else {
  598. val = fieldValue.Interface()
  599. }
  600. }
  601. }
  602. case reflect.Array, reflect.Slice, reflect.Map:
  603. if fieldValue == reflect.Zero(fieldType) {
  604. continue
  605. }
  606. if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  607. continue
  608. }
  609. if col.SQLType.IsText() {
  610. bytes, err := json.Marshal(fieldValue.Interface())
  611. if err != nil {
  612. engine.logger.Error(err)
  613. continue
  614. }
  615. val = string(bytes)
  616. } else if col.SQLType.IsBlob() {
  617. var bytes []byte
  618. var err error
  619. if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
  620. fieldType.Elem().Kind() == reflect.Uint8 {
  621. if fieldValue.Len() > 0 {
  622. val = fieldValue.Bytes()
  623. } else {
  624. continue
  625. }
  626. } else {
  627. bytes, err = json.Marshal(fieldValue.Interface())
  628. if err != nil {
  629. engine.logger.Error(err)
  630. continue
  631. }
  632. val = bytes
  633. }
  634. } else {
  635. continue
  636. }
  637. default:
  638. val = fieldValue.Interface()
  639. }
  640. conds = append(conds, builder.Eq{colName: val})
  641. }
  642. return builder.And(conds...), nil
  643. }
  644. // TableName return current tableName
  645. func (statement *Statement) TableName() string {
  646. if statement.AltTableName != "" {
  647. return statement.AltTableName
  648. }
  649. return statement.tableName
  650. }
  651. // Id generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?"
  652. func (statement *Statement) Id(id interface{}) *Statement {
  653. idValue := reflect.ValueOf(id)
  654. idType := reflect.TypeOf(idValue.Interface())
  655. switch idType {
  656. case ptrPkType:
  657. if pkPtr, ok := (id).(*core.PK); ok {
  658. statement.IdParam = pkPtr
  659. return statement
  660. }
  661. case pkType:
  662. if pk, ok := (id).(core.PK); ok {
  663. statement.IdParam = &pk
  664. return statement
  665. }
  666. }
  667. switch idType.Kind() {
  668. case reflect.String:
  669. statement.IdParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  670. return statement
  671. }
  672. statement.IdParam = &core.PK{id}
  673. return statement
  674. }
  675. // Incr Generate "Update ... Set column = column + arg" statment
  676. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  677. k := strings.ToLower(column)
  678. if len(arg) > 0 {
  679. statement.incrColumns[k] = incrParam{column, arg[0]}
  680. } else {
  681. statement.incrColumns[k] = incrParam{column, 1}
  682. }
  683. return statement
  684. }
  685. // Decr Generate "Update ... Set column = column - arg" statment
  686. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  687. k := strings.ToLower(column)
  688. if len(arg) > 0 {
  689. statement.decrColumns[k] = decrParam{column, arg[0]}
  690. } else {
  691. statement.decrColumns[k] = decrParam{column, 1}
  692. }
  693. return statement
  694. }
  695. // SetExpr Generate "Update ... Set column = {expression}" statment
  696. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  697. k := strings.ToLower(column)
  698. statement.exprColumns[k] = exprParam{column, expression}
  699. return statement
  700. }
  701. // Generate "Update ... Set column = column + arg" statment
  702. func (statement *Statement) getInc() map[string]incrParam {
  703. return statement.incrColumns
  704. }
  705. // Generate "Update ... Set column = column - arg" statment
  706. func (statement *Statement) getDec() map[string]decrParam {
  707. return statement.decrColumns
  708. }
  709. // Generate "Update ... Set column = {expression}" statment
  710. func (statement *Statement) getExpr() map[string]exprParam {
  711. return statement.exprColumns
  712. }
  713. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  714. newColumns := make([]string, 0)
  715. for _, col := range columns {
  716. col = strings.Replace(col, "`", "", -1)
  717. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  718. ccols := strings.Split(col, ",")
  719. for _, c := range ccols {
  720. fields := strings.Split(strings.TrimSpace(c), ".")
  721. if len(fields) == 1 {
  722. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  723. } else if len(fields) == 2 {
  724. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  725. statement.Engine.quote(fields[1]))
  726. } else {
  727. panic(errors.New("unwanted colnames"))
  728. }
  729. }
  730. }
  731. return newColumns
  732. }
  733. // Generate "Distince col1, col2 " statment
  734. func (statement *Statement) Distinct(columns ...string) *Statement {
  735. statement.IsDistinct = true
  736. statement.Cols(columns...)
  737. return statement
  738. }
  739. // Generate "SELECT ... FOR UPDATE" statment
  740. func (statement *Statement) ForUpdate() *Statement {
  741. statement.IsForUpdate = true
  742. return statement
  743. }
  744. // Select replace select
  745. func (s *Statement) Select(str string) *Statement {
  746. s.selectStr = str
  747. return s
  748. }
  749. // Cols generate "col1, col2" statement
  750. func (statement *Statement) Cols(columns ...string) *Statement {
  751. cols := col2NewCols(columns...)
  752. for _, nc := range cols {
  753. statement.columnMap[strings.ToLower(nc)] = true
  754. }
  755. newColumns := statement.col2NewColsWithQuote(columns...)
  756. statement.ColumnStr = strings.Join(newColumns, ", ")
  757. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  758. return statement
  759. }
  760. // AllCols update use only: update all columns
  761. func (statement *Statement) AllCols() *Statement {
  762. statement.useAllCols = true
  763. return statement
  764. }
  765. // MustCols update use only: must update columns
  766. func (statement *Statement) MustCols(columns ...string) *Statement {
  767. newColumns := col2NewCols(columns...)
  768. for _, nc := range newColumns {
  769. statement.mustColumnMap[strings.ToLower(nc)] = true
  770. }
  771. return statement
  772. }
  773. // UseBool indicates that use bool fields as update contents and query contiditions
  774. func (statement *Statement) UseBool(columns ...string) *Statement {
  775. if len(columns) > 0 {
  776. statement.MustCols(columns...)
  777. } else {
  778. statement.allUseBool = true
  779. }
  780. return statement
  781. }
  782. // Omit do not use the columns
  783. func (statement *Statement) Omit(columns ...string) {
  784. newColumns := col2NewCols(columns...)
  785. for _, nc := range newColumns {
  786. statement.columnMap[strings.ToLower(nc)] = false
  787. }
  788. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  789. }
  790. // Nullable Update use only: update columns to null when value is nullable and zero-value
  791. func (statement *Statement) Nullable(columns ...string) {
  792. newColumns := col2NewCols(columns...)
  793. for _, nc := range newColumns {
  794. statement.nullableMap[strings.ToLower(nc)] = true
  795. }
  796. }
  797. // Top generate LIMIT limit statement
  798. func (statement *Statement) Top(limit int) *Statement {
  799. statement.Limit(limit)
  800. return statement
  801. }
  802. // Limit generate LIMIT start, limit statement
  803. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  804. statement.LimitN = limit
  805. if len(start) > 0 {
  806. statement.Start = start[0]
  807. }
  808. return statement
  809. }
  810. // OrderBy generate "Order By order" statement
  811. func (statement *Statement) OrderBy(order string) *Statement {
  812. if len(statement.OrderStr) > 0 {
  813. statement.OrderStr += ", "
  814. }
  815. statement.OrderStr += order
  816. return statement
  817. }
  818. // Desc generate `ORDER BY xx DESC`
  819. func (statement *Statement) Desc(colNames ...string) *Statement {
  820. var buf bytes.Buffer
  821. fmt.Fprintf(&buf, statement.OrderStr)
  822. if len(statement.OrderStr) > 0 {
  823. fmt.Fprint(&buf, ", ")
  824. }
  825. newColNames := statement.col2NewColsWithQuote(colNames...)
  826. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  827. statement.OrderStr = buf.String()
  828. return statement
  829. }
  830. // Asc provide asc order by query condition, the input parameters are columns.
  831. func (statement *Statement) Asc(colNames ...string) *Statement {
  832. var buf bytes.Buffer
  833. fmt.Fprintf(&buf, statement.OrderStr)
  834. if len(statement.OrderStr) > 0 {
  835. fmt.Fprint(&buf, ", ")
  836. }
  837. newColNames := statement.col2NewColsWithQuote(colNames...)
  838. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  839. statement.OrderStr = buf.String()
  840. return statement
  841. }
  842. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  843. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  844. var buf bytes.Buffer
  845. if len(statement.JoinStr) > 0 {
  846. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  847. } else {
  848. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  849. }
  850. switch tablename.(type) {
  851. case []string:
  852. t := tablename.([]string)
  853. if len(t) > 1 {
  854. fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
  855. } else if len(t) == 1 {
  856. fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
  857. }
  858. case []interface{}:
  859. t := tablename.([]interface{})
  860. l := len(t)
  861. var table string
  862. if l > 0 {
  863. f := t[0]
  864. v := rValue(f)
  865. t := v.Type()
  866. if t.Kind() == reflect.String {
  867. table = f.(string)
  868. } else if t.Kind() == reflect.Struct {
  869. table = statement.Engine.tbName(v)
  870. }
  871. }
  872. if l > 1 {
  873. fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
  874. statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
  875. } else if l == 1 {
  876. fmt.Fprintf(&buf, statement.Engine.Quote(table))
  877. }
  878. default:
  879. fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
  880. }
  881. fmt.Fprintf(&buf, " ON %v", condition)
  882. statement.JoinStr = buf.String()
  883. statement.joinArgs = append(statement.joinArgs, args...)
  884. return statement
  885. }
  886. // GroupBy generate "Group By keys" statement
  887. func (statement *Statement) GroupBy(keys string) *Statement {
  888. statement.GroupByStr = keys
  889. return statement
  890. }
  891. // Having generate "Having conditions" statement
  892. func (statement *Statement) Having(conditions string) *Statement {
  893. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  894. return statement
  895. }
  896. // Unscoped always disable struct tag "deleted"
  897. func (statement *Statement) Unscoped() *Statement {
  898. statement.unscoped = true
  899. return statement
  900. }
  901. func (statement *Statement) genColumnStr() string {
  902. table := statement.RefTable
  903. var colNames []string
  904. for _, col := range table.Columns() {
  905. if statement.OmitStr != "" {
  906. if _, ok := statement.columnMap[strings.ToLower(col.Name)]; ok {
  907. continue
  908. }
  909. }
  910. if col.MapType == core.ONLYTODB {
  911. continue
  912. }
  913. if statement.JoinStr != "" {
  914. var name string
  915. if statement.TableAlias != "" {
  916. name = statement.Engine.Quote(statement.TableAlias)
  917. } else {
  918. name = statement.Engine.Quote(statement.TableName())
  919. }
  920. name += "." + statement.Engine.Quote(col.Name)
  921. if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
  922. colNames = append(colNames, "id() AS "+name)
  923. } else {
  924. colNames = append(colNames, name)
  925. }
  926. } else {
  927. name := statement.Engine.Quote(col.Name)
  928. if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
  929. colNames = append(colNames, "id() AS "+name)
  930. } else {
  931. colNames = append(colNames, name)
  932. }
  933. }
  934. }
  935. return strings.Join(colNames, ", ")
  936. }
  937. func (statement *Statement) genCreateTableSQL() string {
  938. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  939. statement.StoreEngine, statement.Charset)
  940. }
  941. func (s *Statement) genIndexSQL() []string {
  942. var sqls []string
  943. tbName := s.TableName()
  944. quote := s.Engine.Quote
  945. for idxName, index := range s.RefTable.Indexes {
  946. if index.Type == core.IndexType {
  947. sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
  948. quote(tbName), quote(strings.Join(index.Cols, quote(","))))
  949. sqls = append(sqls, sql)
  950. }
  951. }
  952. return sqls
  953. }
  954. func uniqueName(tableName, uqeName string) string {
  955. return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  956. }
  957. func (s *Statement) genUniqueSQL() []string {
  958. var sqls []string
  959. tbName := s.TableName()
  960. for _, index := range s.RefTable.Indexes {
  961. if index.Type == core.UniqueType {
  962. sql := s.Engine.dialect.CreateIndexSql(tbName, index)
  963. sqls = append(sqls, sql)
  964. }
  965. }
  966. return sqls
  967. }
  968. func (s *Statement) genDelIndexSQL() []string {
  969. var sqls []string
  970. tbName := s.TableName()
  971. for idxName, index := range s.RefTable.Indexes {
  972. var rIdxName string
  973. if index.Type == core.UniqueType {
  974. rIdxName = uniqueName(tbName, idxName)
  975. } else if index.Type == core.IndexType {
  976. rIdxName = indexName(tbName, idxName)
  977. }
  978. sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(rIdxName))
  979. if s.Engine.dialect.IndexOnTable() {
  980. sql += fmt.Sprintf(" ON %v", s.Engine.Quote(s.TableName()))
  981. }
  982. sqls = append(sqls, sql)
  983. }
  984. return sqls
  985. }
  986. func (s *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  987. quote := s.Engine.Quote
  988. sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(s.TableName()),
  989. col.String(s.Engine.dialect))
  990. return sql, []interface{}{}
  991. }
  992. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  993. return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  994. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  995. }
  996. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  997. if !statement.noAutoCondition {
  998. var addedTableName = (len(statement.JoinStr) > 0)
  999. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  1000. if err != nil {
  1001. return "", nil, err
  1002. }
  1003. statement.cond = statement.cond.And(autoCond)
  1004. }
  1005. statement.processIdParam()
  1006. return builder.ToSQL(statement.cond)
  1007. }
  1008. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
  1009. statement.setRefValue(rValue(bean))
  1010. var columnStr = statement.ColumnStr
  1011. if len(statement.selectStr) > 0 {
  1012. columnStr = statement.selectStr
  1013. } else {
  1014. // TODO: always generate column names, not use * even if join
  1015. if len(statement.JoinStr) == 0 {
  1016. if len(columnStr) == 0 {
  1017. if len(statement.GroupByStr) > 0 {
  1018. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  1019. } else {
  1020. columnStr = statement.genColumnStr()
  1021. }
  1022. }
  1023. } else {
  1024. if len(columnStr) == 0 {
  1025. if len(statement.GroupByStr) > 0 {
  1026. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  1027. } else {
  1028. columnStr = "*"
  1029. }
  1030. }
  1031. }
  1032. }
  1033. condSQL, condArgs, _ := statement.genConds(bean)
  1034. return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
  1035. }
  1036. func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) {
  1037. statement.setRefValue(rValue(bean))
  1038. condSQL, condArgs, _ := statement.genConds(bean)
  1039. return statement.genSelectSQL("count(*)", condSQL), append(statement.joinArgs, condArgs...)
  1040. }
  1041. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) {
  1042. statement.setRefValue(rValue(bean))
  1043. var sumStrs = make([]string, 0, len(columns))
  1044. for _, colName := range columns {
  1045. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  1046. }
  1047. condSQL, condArgs, _ := statement.genConds(bean)
  1048. return statement.genSelectSQL(strings.Join(sumStrs, ", "), condSQL), append(statement.joinArgs, condArgs...)
  1049. }
  1050. func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
  1051. var distinct string
  1052. if statement.IsDistinct {
  1053. distinct = "DISTINCT "
  1054. }
  1055. var dialect = statement.Engine.Dialect()
  1056. var quote = statement.Engine.Quote
  1057. var top string
  1058. var mssqlCondi string
  1059. statement.processIdParam()
  1060. var buf bytes.Buffer
  1061. if len(condSQL) > 0 {
  1062. fmt.Fprintf(&buf, " WHERE %v", condSQL)
  1063. }
  1064. var whereStr = buf.String()
  1065. var fromStr = " FROM " + quote(statement.TableName())
  1066. if statement.TableAlias != "" {
  1067. if dialect.DBType() == core.ORACLE {
  1068. fromStr += " " + quote(statement.TableAlias)
  1069. } else {
  1070. fromStr += " AS " + quote(statement.TableAlias)
  1071. }
  1072. }
  1073. if statement.JoinStr != "" {
  1074. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  1075. }
  1076. if dialect.DBType() == core.MSSQL {
  1077. if statement.LimitN > 0 {
  1078. top = fmt.Sprintf(" TOP %d ", statement.LimitN)
  1079. }
  1080. if statement.Start > 0 {
  1081. var column = "(id)"
  1082. if len(statement.RefTable.PKColumns()) == 0 {
  1083. for _, index := range statement.RefTable.Indexes {
  1084. if len(index.Cols) == 1 {
  1085. column = index.Cols[0]
  1086. break
  1087. }
  1088. }
  1089. if len(column) == 0 {
  1090. column = statement.RefTable.ColumnsSeq()[0]
  1091. }
  1092. }
  1093. var orderStr string
  1094. if len(statement.OrderStr) > 0 {
  1095. orderStr = " ORDER BY " + statement.OrderStr
  1096. }
  1097. var groupStr string
  1098. if len(statement.GroupByStr) > 0 {
  1099. groupStr = " GROUP BY " + statement.GroupByStr
  1100. }
  1101. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  1102. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  1103. }
  1104. }
  1105. // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
  1106. a = fmt.Sprintf("SELECT %v%v%v%v%v", top, distinct, columnStr, fromStr, whereStr)
  1107. if len(mssqlCondi) > 0 {
  1108. if len(whereStr) > 0 {
  1109. a += " AND " + mssqlCondi
  1110. } else {
  1111. a += " WHERE " + mssqlCondi
  1112. }
  1113. }
  1114. if statement.GroupByStr != "" {
  1115. a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
  1116. }
  1117. if statement.HavingStr != "" {
  1118. a = fmt.Sprintf("%v %v", a, statement.HavingStr)
  1119. }
  1120. if statement.OrderStr != "" {
  1121. a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
  1122. }
  1123. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1124. if statement.Start > 0 {
  1125. a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
  1126. } else if statement.LimitN > 0 {
  1127. a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
  1128. }
  1129. } else if dialect.DBType() == core.ORACLE {
  1130. if statement.Start != 0 || statement.LimitN != 0 {
  1131. a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
  1132. }
  1133. }
  1134. if statement.IsForUpdate {
  1135. a = dialect.ForUpdateSql(a)
  1136. }
  1137. return
  1138. }
  1139. func (statement *Statement) processIdParam() {
  1140. if statement.IdParam == nil {
  1141. return
  1142. }
  1143. for i, col := range statement.RefTable.PKColumns() {
  1144. var colName = statement.colName(col, statement.TableName())
  1145. if i < len(*(statement.IdParam)) {
  1146. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.IdParam))[i]})
  1147. } else {
  1148. statement.cond = statement.cond.And(builder.Eq{colName: ""})
  1149. }
  1150. }
  1151. }
  1152. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1153. var colnames = make([]string, len(cols))
  1154. for i, col := range cols {
  1155. if includeTableName {
  1156. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1157. "." + statement.Engine.Quote(col.Name)
  1158. } else {
  1159. colnames[i] = statement.Engine.Quote(col.Name)
  1160. }
  1161. }
  1162. return strings.Join(colnames, ", ")
  1163. }
  1164. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1165. if statement.RefTable != nil {
  1166. cols := statement.RefTable.PKColumns()
  1167. if len(cols) == 0 {
  1168. return ""
  1169. }
  1170. colstrs := statement.joinColumns(cols, false)
  1171. sqls := splitNNoCase(sqlStr, " from ", 2)
  1172. if len(sqls) != 2 {
  1173. return ""
  1174. }
  1175. return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1])
  1176. }
  1177. return ""
  1178. }
  1179. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1180. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1181. return "", ""
  1182. }
  1183. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1184. sqls := splitNNoCase(sqlStr, "where", 2)
  1185. if len(sqls) != 2 {
  1186. if len(sqls) == 1 {
  1187. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1188. colstrs, statement.Engine.Quote(statement.TableName()))
  1189. }
  1190. return "", ""
  1191. }
  1192. var whereStr = sqls[1]
  1193. //TODO: for postgres only, if any other database?
  1194. var paraStr string
  1195. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1196. paraStr = "$"
  1197. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1198. paraStr = ":"
  1199. }
  1200. if paraStr != "" {
  1201. if strings.Contains(sqls[1], paraStr) {
  1202. dollers := strings.Split(sqls[1], paraStr)
  1203. whereStr = dollers[0]
  1204. for i, c := range dollers[1:] {
  1205. ccs := strings.SplitN(c, " ", 2)
  1206. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1207. }
  1208. }
  1209. }
  1210. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1211. colstrs, statement.Engine.Quote(statement.TableName()),
  1212. whereStr)
  1213. }