session_insert.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. "github.com/xormplus/core"
  12. )
  13. // Insert insert one or more beans
  14. func (session *Session) Insert(beans ...interface{}) (int64, error) {
  15. var affected int64
  16. var err error
  17. if session.isAutoClose {
  18. defer session.Close()
  19. }
  20. for _, bean := range beans {
  21. sliceValue := reflect.Indirect(reflect.ValueOf(bean))
  22. if sliceValue.Kind() == reflect.Slice {
  23. size := sliceValue.Len()
  24. if size > 0 {
  25. if session.engine.SupportInsertMany() {
  26. cnt, err := session.innerInsertMulti(bean)
  27. if err != nil {
  28. return affected, err
  29. }
  30. affected += cnt
  31. } else {
  32. for i := 0; i < size; i++ {
  33. cnt, err := session.innerInsert(sliceValue.Index(i).Interface())
  34. if err != nil {
  35. return affected, err
  36. }
  37. affected += cnt
  38. }
  39. }
  40. }
  41. } else {
  42. cnt, err := session.innerInsert(bean)
  43. if err != nil {
  44. return affected, err
  45. }
  46. affected += cnt
  47. }
  48. }
  49. return affected, err
  50. }
  51. func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
  52. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  53. if sliceValue.Kind() != reflect.Slice {
  54. return 0, errors.New("needs a pointer to a slice")
  55. }
  56. if sliceValue.Len() <= 0 {
  57. return 0, errors.New("could not insert a empty slice")
  58. }
  59. if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil {
  60. return 0, err
  61. }
  62. if len(session.statement.TableName()) <= 0 {
  63. return 0, ErrTableNotFound
  64. }
  65. table := session.statement.RefTable
  66. size := sliceValue.Len()
  67. var colNames []string
  68. var colMultiPlaces []string
  69. var args []interface{}
  70. var cols []*core.Column
  71. for i := 0; i < size; i++ {
  72. v := sliceValue.Index(i)
  73. vv := reflect.Indirect(v)
  74. elemValue := v.Interface()
  75. var colPlaces []string
  76. // handle BeforeInsertProcessor
  77. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  78. for _, closure := range session.beforeClosures {
  79. closure(elemValue)
  80. }
  81. if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok {
  82. processor.BeforeInsert()
  83. }
  84. // --
  85. if i == 0 {
  86. for _, col := range table.Columns() {
  87. ptrFieldValue, err := col.ValueOfV(&vv)
  88. if err != nil {
  89. return 0, err
  90. }
  91. fieldValue := *ptrFieldValue
  92. if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
  93. continue
  94. }
  95. if col.MapType == core.ONLYFROMDB {
  96. continue
  97. }
  98. if col.IsDeleted {
  99. continue
  100. }
  101. if session.statement.ColumnStr != "" {
  102. if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
  103. continue
  104. }
  105. }
  106. if session.statement.OmitStr != "" {
  107. if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
  108. continue
  109. }
  110. }
  111. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
  112. val, t := session.engine.NowTime2(col.SQLType.Name)
  113. args = append(args, val)
  114. var colName = col.Name
  115. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  116. col := table.GetColumn(colName)
  117. setColumnTime(bean, col, t)
  118. })
  119. } else if col.IsVersion && session.statement.checkVersion {
  120. args = append(args, 1)
  121. var colName = col.Name
  122. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  123. col := table.GetColumn(colName)
  124. setColumnInt(bean, col, 1)
  125. })
  126. } else {
  127. arg, err := session.value2Interface(col, fieldValue)
  128. if err != nil {
  129. return 0, err
  130. }
  131. args = append(args, arg)
  132. }
  133. colNames = append(colNames, col.Name)
  134. cols = append(cols, col)
  135. colPlaces = append(colPlaces, "?")
  136. }
  137. } else {
  138. for _, col := range cols {
  139. ptrFieldValue, err := col.ValueOfV(&vv)
  140. if err != nil {
  141. return 0, err
  142. }
  143. fieldValue := *ptrFieldValue
  144. if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
  145. continue
  146. }
  147. if col.MapType == core.ONLYFROMDB {
  148. continue
  149. }
  150. if col.IsDeleted {
  151. continue
  152. }
  153. if session.statement.ColumnStr != "" {
  154. if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
  155. continue
  156. }
  157. }
  158. if session.statement.OmitStr != "" {
  159. if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
  160. continue
  161. }
  162. }
  163. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
  164. val, t := session.engine.NowTime2(col.SQLType.Name)
  165. args = append(args, val)
  166. var colName = col.Name
  167. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  168. col := table.GetColumn(colName)
  169. setColumnTime(bean, col, t)
  170. })
  171. } else if col.IsVersion && session.statement.checkVersion {
  172. args = append(args, 1)
  173. var colName = col.Name
  174. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  175. col := table.GetColumn(colName)
  176. setColumnInt(bean, col, 1)
  177. })
  178. } else {
  179. arg, err := session.value2Interface(col, fieldValue)
  180. if err != nil {
  181. return 0, err
  182. }
  183. args = append(args, arg)
  184. }
  185. colPlaces = append(colPlaces, "?")
  186. }
  187. }
  188. colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
  189. }
  190. cleanupProcessorsClosures(&session.beforeClosures)
  191. var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
  192. var statement string
  193. if session.engine.dialect.DBType() == core.ORACLE {
  194. sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
  195. temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
  196. session.engine.Quote(session.statement.TableName()),
  197. session.engine.QuoteStr(),
  198. strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
  199. session.engine.QuoteStr())
  200. statement = fmt.Sprintf(sql,
  201. session.engine.Quote(session.statement.TableName()),
  202. session.engine.QuoteStr(),
  203. strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
  204. session.engine.QuoteStr(),
  205. strings.Join(colMultiPlaces, temp))
  206. } else {
  207. statement = fmt.Sprintf(sql,
  208. session.engine.Quote(session.statement.TableName()),
  209. session.engine.QuoteStr(),
  210. strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
  211. session.engine.QuoteStr(),
  212. strings.Join(colMultiPlaces, "),("))
  213. }
  214. res, err := session.exec(statement, args...)
  215. if err != nil {
  216. return 0, err
  217. }
  218. if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
  219. session.cacheInsert(session.statement.TableName())
  220. }
  221. lenAfterClosures := len(session.afterClosures)
  222. for i := 0; i < size; i++ {
  223. elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
  224. // handle AfterInsertProcessor
  225. if session.isAutoCommit {
  226. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  227. for _, closure := range session.afterClosures {
  228. closure(elemValue)
  229. }
  230. if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  231. processor.AfterInsert()
  232. }
  233. } else {
  234. if lenAfterClosures > 0 {
  235. if value, has := session.afterInsertBeans[elemValue]; has && value != nil {
  236. *value = append(*value, session.afterClosures...)
  237. } else {
  238. afterClosures := make([]func(interface{}), lenAfterClosures)
  239. copy(afterClosures, session.afterClosures)
  240. session.afterInsertBeans[elemValue] = &afterClosures
  241. }
  242. } else {
  243. if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  244. session.afterInsertBeans[elemValue] = nil
  245. }
  246. }
  247. }
  248. }
  249. cleanupProcessorsClosures(&session.afterClosures)
  250. return res.RowsAffected()
  251. }
  252. // InsertMulti insert multiple records
  253. func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
  254. if session.isAutoClose {
  255. defer session.Close()
  256. }
  257. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  258. if sliceValue.Kind() != reflect.Slice {
  259. return 0, ErrParamsType
  260. }
  261. if sliceValue.Len() <= 0 {
  262. return 0, nil
  263. }
  264. return session.innerInsertMulti(rowsSlicePtr)
  265. }
  266. func (session *Session) innerInsert(bean interface{}) (int64, error) {
  267. if err := session.statement.setRefValue(rValue(bean)); err != nil {
  268. return 0, err
  269. }
  270. if len(session.statement.TableName()) <= 0 {
  271. return 0, ErrTableNotFound
  272. }
  273. table := session.statement.RefTable
  274. // handle BeforeInsertProcessor
  275. for _, closure := range session.beforeClosures {
  276. closure(bean)
  277. }
  278. cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
  279. if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
  280. processor.BeforeInsert()
  281. }
  282. // --
  283. colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false)
  284. if err != nil {
  285. return 0, err
  286. }
  287. // insert expr columns, override if exists
  288. exprColumns := session.statement.getExpr()
  289. exprColVals := make([]string, 0, len(exprColumns))
  290. for _, v := range exprColumns {
  291. // remove the expr columns
  292. for i, colName := range colNames {
  293. if colName == v.colName {
  294. colNames = append(colNames[:i], colNames[i+1:]...)
  295. args = append(args[:i], args[i+1:]...)
  296. }
  297. }
  298. // append expr column to the end
  299. colNames = append(colNames, v.colName)
  300. exprColVals = append(exprColVals, v.expr)
  301. }
  302. colPlaces := strings.Repeat("?, ", len(colNames)-len(exprColumns))
  303. if len(exprColVals) > 0 {
  304. colPlaces = colPlaces + strings.Join(exprColVals, ", ")
  305. } else {
  306. if len(colPlaces) > 0 {
  307. colPlaces = colPlaces[0 : len(colPlaces)-2]
  308. }
  309. }
  310. var sqlStr string
  311. if len(colPlaces) > 0 {
  312. sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
  313. session.engine.Quote(session.statement.TableName()),
  314. session.engine.QuoteStr(),
  315. strings.Join(colNames, session.engine.Quote(", ")),
  316. session.engine.QuoteStr(),
  317. colPlaces)
  318. } else {
  319. if session.engine.dialect.DBType() == core.MYSQL {
  320. sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(session.statement.TableName()))
  321. } else {
  322. sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.engine.Quote(session.statement.TableName()))
  323. }
  324. }
  325. handleAfterInsertProcessorFunc := func(bean interface{}) {
  326. if session.isAutoCommit {
  327. for _, closure := range session.afterClosures {
  328. closure(bean)
  329. }
  330. if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
  331. processor.AfterInsert()
  332. }
  333. } else {
  334. lenAfterClosures := len(session.afterClosures)
  335. if lenAfterClosures > 0 {
  336. if value, has := session.afterInsertBeans[bean]; has && value != nil {
  337. *value = append(*value, session.afterClosures...)
  338. } else {
  339. afterClosures := make([]func(interface{}), lenAfterClosures)
  340. copy(afterClosures, session.afterClosures)
  341. session.afterInsertBeans[bean] = &afterClosures
  342. }
  343. } else {
  344. if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
  345. session.afterInsertBeans[bean] = nil
  346. }
  347. }
  348. }
  349. cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
  350. }
  351. // for postgres, many of them didn't implement lastInsertId, so we should
  352. // implemented it ourself.
  353. if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
  354. res, err := session.queryBytes("select seq_atable.currval from dual", args...)
  355. if err != nil {
  356. return 0, err
  357. }
  358. handleAfterInsertProcessorFunc(bean)
  359. if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
  360. session.cacheInsert(session.statement.TableName())
  361. }
  362. if table.Version != "" && session.statement.checkVersion {
  363. verValue, err := table.VersionColumn().ValueOf(bean)
  364. if err != nil {
  365. session.engine.logger.Error(err)
  366. } else if verValue.IsValid() && verValue.CanSet() {
  367. verValue.SetInt(1)
  368. }
  369. }
  370. if len(res) < 1 {
  371. return 0, errors.New("insert no error but not returned id")
  372. }
  373. idByte := res[0][table.AutoIncrement]
  374. id, err := strconv.ParseInt(string(idByte), 10, 64)
  375. if err != nil || id <= 0 {
  376. return 1, err
  377. }
  378. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  379. if err != nil {
  380. session.engine.logger.Error(err)
  381. }
  382. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  383. return 1, nil
  384. }
  385. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  386. return 1, nil
  387. } else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
  388. //assert table.AutoIncrement != ""
  389. sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
  390. res, err := session.queryBytes(sqlStr, args...)
  391. if err != nil {
  392. return 0, err
  393. }
  394. handleAfterInsertProcessorFunc(bean)
  395. if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
  396. session.cacheInsert(session.statement.TableName())
  397. }
  398. if table.Version != "" && session.statement.checkVersion {
  399. verValue, err := table.VersionColumn().ValueOf(bean)
  400. if err != nil {
  401. session.engine.logger.Error(err)
  402. } else if verValue.IsValid() && verValue.CanSet() {
  403. verValue.SetInt(1)
  404. }
  405. }
  406. if len(res) < 1 {
  407. return 0, errors.New("insert no error but not returned id")
  408. }
  409. idByte := res[0][table.AutoIncrement]
  410. id, err := strconv.ParseInt(string(idByte), 10, 64)
  411. if err != nil || id <= 0 {
  412. return 1, err
  413. }
  414. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  415. if err != nil {
  416. session.engine.logger.Error(err)
  417. }
  418. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  419. return 1, nil
  420. }
  421. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  422. return 1, nil
  423. } else {
  424. res, err := session.exec(sqlStr, args...)
  425. if err != nil {
  426. return 0, err
  427. }
  428. defer handleAfterInsertProcessorFunc(bean)
  429. if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
  430. session.cacheInsert(session.statement.TableName())
  431. }
  432. if table.Version != "" && session.statement.checkVersion {
  433. verValue, err := table.VersionColumn().ValueOf(bean)
  434. if err != nil {
  435. session.engine.logger.Error(err)
  436. } else if verValue.IsValid() && verValue.CanSet() {
  437. verValue.SetInt(1)
  438. }
  439. }
  440. if table.AutoIncrement == "" {
  441. return res.RowsAffected()
  442. }
  443. var id int64
  444. id, err = res.LastInsertId()
  445. if err != nil || id <= 0 {
  446. return res.RowsAffected()
  447. }
  448. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  449. if err != nil {
  450. session.engine.logger.Error(err)
  451. }
  452. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  453. return res.RowsAffected()
  454. }
  455. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  456. return res.RowsAffected()
  457. }
  458. }
  459. // InsertOne insert only one struct into database as a record.
  460. // The in parameter bean must a struct or a point to struct. The return
  461. // parameter is inserted and error
  462. func (session *Session) InsertOne(bean interface{}) (int64, error) {
  463. if session.isAutoClose {
  464. defer session.Close()
  465. }
  466. return session.innerInsert(bean)
  467. }
  468. func (session *Session) cacheInsert(tables ...string) error {
  469. if session.statement.RefTable == nil {
  470. return ErrCacheFailed
  471. }
  472. table := session.statement.RefTable
  473. cacher := session.engine.getCacher2(table)
  474. for _, t := range tables {
  475. session.engine.logger.Debug("[cache] clear sql:", t)
  476. cacher.ClearIds(t)
  477. }
  478. return nil
  479. }