db.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683
  1. package core
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "errors"
  6. "fmt"
  7. "reflect"
  8. "regexp"
  9. "sync"
  10. )
  11. func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
  12. vv := reflect.ValueOf(mp)
  13. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  14. return "", []interface{}{}, ErrNoMapPointer
  15. }
  16. args := make([]interface{}, 0, len(vv.Elem().MapKeys()))
  17. var err error
  18. query = re.ReplaceAllStringFunc(query, func(src string) string {
  19. v := vv.Elem().MapIndex(reflect.ValueOf(src[1:]))
  20. if !v.IsValid() {
  21. err = fmt.Errorf("map key %s is missing", src[1:])
  22. } else {
  23. args = append(args, v.Interface())
  24. }
  25. return "?"
  26. })
  27. return query, args, err
  28. }
  29. func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
  30. vv := reflect.ValueOf(st)
  31. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  32. return "", []interface{}{}, ErrNoStructPointer
  33. }
  34. args := make([]interface{}, 0)
  35. var err error
  36. query = re.ReplaceAllStringFunc(query, func(src string) string {
  37. fv := vv.Elem().FieldByName(src[1:]).Interface()
  38. if v, ok := fv.(driver.Valuer); ok {
  39. var value driver.Value
  40. value, err = v.Value()
  41. if err != nil {
  42. return "?"
  43. }
  44. args = append(args, value)
  45. } else {
  46. args = append(args, fv)
  47. }
  48. return "?"
  49. })
  50. if err != nil {
  51. return "", []interface{}{}, err
  52. }
  53. return query, args, nil
  54. }
  55. type DB struct {
  56. *sql.DB
  57. Mapper IMapper
  58. }
  59. func Open(driverName, dataSourceName string) (*DB, error) {
  60. db, err := sql.Open(driverName, dataSourceName)
  61. if err != nil {
  62. return nil, err
  63. }
  64. return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil
  65. }
  66. func FromDB(db *sql.DB) *DB {
  67. return &DB{db, NewCacheMapper(&SnakeMapper{})}
  68. }
  69. func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
  70. rows, err := db.DB.Query(query, args...)
  71. if err != nil {
  72. if rows != nil {
  73. rows.Close()
  74. }
  75. return nil, err
  76. }
  77. return &Rows{rows, db.Mapper}, nil
  78. }
  79. func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
  80. query, args, err := MapToSlice(query, mp)
  81. if err != nil {
  82. return nil, err
  83. }
  84. return db.Query(query, args...)
  85. }
  86. func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
  87. query, args, err := StructToSlice(query, st)
  88. if err != nil {
  89. return nil, err
  90. }
  91. return db.Query(query, args...)
  92. }
  93. type Row struct {
  94. rows *Rows
  95. // One of these two will be non-nil:
  96. err error // deferred error for easy chaining
  97. }
  98. func (row *Row) Columns() ([]string, error) {
  99. if row.err != nil {
  100. return nil, row.err
  101. }
  102. return row.rows.Columns()
  103. }
  104. func (row *Row) Scan(dest ...interface{}) error {
  105. if row.err != nil {
  106. return row.err
  107. }
  108. defer row.rows.Close()
  109. for _, dp := range dest {
  110. if _, ok := dp.(*sql.RawBytes); ok {
  111. return errors.New("sql: RawBytes isn't allowed on Row.Scan")
  112. }
  113. }
  114. if !row.rows.Next() {
  115. if err := row.rows.Err(); err != nil {
  116. return err
  117. }
  118. return sql.ErrNoRows
  119. }
  120. err := row.rows.Scan(dest...)
  121. if err != nil {
  122. return err
  123. }
  124. // Make sure the query can be processed to completion with no errors.
  125. if err := row.rows.Close(); err != nil {
  126. return err
  127. }
  128. return nil
  129. }
  130. func (row *Row) ScanStructByName(dest interface{}) error {
  131. if row.err != nil {
  132. return row.err
  133. }
  134. if !row.rows.Next() {
  135. if err := row.rows.Err(); err != nil {
  136. return err
  137. }
  138. return sql.ErrNoRows
  139. }
  140. return row.rows.ScanStructByName(dest)
  141. }
  142. func (row *Row) ScanStructByIndex(dest interface{}) error {
  143. if row.err != nil {
  144. return row.err
  145. }
  146. if !row.rows.Next() {
  147. if err := row.rows.Err(); err != nil {
  148. return err
  149. }
  150. return sql.ErrNoRows
  151. }
  152. return row.rows.ScanStructByIndex(dest)
  153. }
  154. // scan data to a slice's pointer, slice's length should equal to columns' number
  155. func (row *Row) ScanSlice(dest interface{}) error {
  156. if row.err != nil {
  157. return row.err
  158. }
  159. if !row.rows.Next() {
  160. if err := row.rows.Err(); err != nil {
  161. return err
  162. }
  163. return sql.ErrNoRows
  164. }
  165. return row.rows.ScanSlice(dest)
  166. }
  167. // scan data to a map's pointer
  168. func (row *Row) ScanMap(dest interface{}) error {
  169. if row.err != nil {
  170. return row.err
  171. }
  172. if !row.rows.Next() {
  173. if err := row.rows.Err(); err != nil {
  174. return err
  175. }
  176. return sql.ErrNoRows
  177. }
  178. return row.rows.ScanMap(dest)
  179. }
  180. func (db *DB) QueryRow(query string, args ...interface{}) *Row {
  181. rows, err := db.Query(query, args...)
  182. if err != nil {
  183. return &Row{nil, err}
  184. }
  185. return &Row{rows, nil}
  186. }
  187. func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
  188. query, args, err := MapToSlice(query, mp)
  189. if err != nil {
  190. return &Row{nil, err}
  191. }
  192. return db.QueryRow(query, args...)
  193. }
  194. func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
  195. query, args, err := StructToSlice(query, st)
  196. if err != nil {
  197. return &Row{nil, err}
  198. }
  199. return db.QueryRow(query, args...)
  200. }
  201. type Stmt struct {
  202. *sql.Stmt
  203. Mapper IMapper
  204. names map[string]int
  205. }
  206. func (db *DB) Prepare(query string) (*Stmt, error) {
  207. names := make(map[string]int)
  208. var i int
  209. query = re.ReplaceAllStringFunc(query, func(src string) string {
  210. names[src[1:]] = i
  211. i += 1
  212. return "?"
  213. })
  214. stmt, err := db.DB.Prepare(query)
  215. if err != nil {
  216. return nil, err
  217. }
  218. return &Stmt{stmt, db.Mapper, names}, nil
  219. }
  220. func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
  221. vv := reflect.ValueOf(mp)
  222. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  223. return nil, errors.New("mp should be a map's pointer")
  224. }
  225. args := make([]interface{}, len(s.names))
  226. for k, i := range s.names {
  227. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  228. }
  229. return s.Stmt.Exec(args...)
  230. }
  231. func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
  232. vv := reflect.ValueOf(st)
  233. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  234. return nil, errors.New("mp should be a map's pointer")
  235. }
  236. args := make([]interface{}, len(s.names))
  237. for k, i := range s.names {
  238. args[i] = vv.Elem().FieldByName(k).Interface()
  239. }
  240. return s.Stmt.Exec(args...)
  241. }
  242. func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
  243. rows, err := s.Stmt.Query(args...)
  244. if err != nil {
  245. return nil, err
  246. }
  247. return &Rows{rows, s.Mapper}, nil
  248. }
  249. func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
  250. vv := reflect.ValueOf(mp)
  251. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  252. return nil, errors.New("mp should be a map's pointer")
  253. }
  254. args := make([]interface{}, len(s.names))
  255. for k, i := range s.names {
  256. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  257. }
  258. return s.Query(args...)
  259. }
  260. func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
  261. vv := reflect.ValueOf(st)
  262. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  263. return nil, errors.New("mp should be a map's pointer")
  264. }
  265. args := make([]interface{}, len(s.names))
  266. for k, i := range s.names {
  267. args[i] = vv.Elem().FieldByName(k).Interface()
  268. }
  269. return s.Query(args...)
  270. }
  271. func (s *Stmt) QueryRow(args ...interface{}) *Row {
  272. rows, err := s.Query(args...)
  273. return &Row{rows, err}
  274. }
  275. func (s *Stmt) QueryRowMap(mp interface{}) *Row {
  276. vv := reflect.ValueOf(mp)
  277. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  278. return &Row{nil, errors.New("mp should be a map's pointer")}
  279. }
  280. args := make([]interface{}, len(s.names))
  281. for k, i := range s.names {
  282. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  283. }
  284. return s.QueryRow(args...)
  285. }
  286. func (s *Stmt) QueryRowStruct(st interface{}) *Row {
  287. vv := reflect.ValueOf(st)
  288. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  289. return &Row{nil, errors.New("st should be a struct's pointer")}
  290. }
  291. args := make([]interface{}, len(s.names))
  292. for k, i := range s.names {
  293. args[i] = vv.Elem().FieldByName(k).Interface()
  294. }
  295. return s.QueryRow(args...)
  296. }
  297. var (
  298. re = regexp.MustCompile(`[?](\w+)`)
  299. )
  300. // insert into (name) values (?)
  301. // insert into (name) values (?name)
  302. func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
  303. query, args, err := MapToSlice(query, mp)
  304. if err != nil {
  305. return nil, err
  306. }
  307. return db.DB.Exec(query, args...)
  308. }
  309. func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
  310. query, args, err := StructToSlice(query, st)
  311. if err != nil {
  312. return nil, err
  313. }
  314. return db.DB.Exec(query, args...)
  315. }
  316. type Rows struct {
  317. *sql.Rows
  318. Mapper IMapper
  319. }
  320. // scan data to a struct's pointer according field index
  321. func (rs *Rows) ScanStructByIndex(dest ...interface{}) error {
  322. if len(dest) == 0 {
  323. return errors.New("at least one struct")
  324. }
  325. vvvs := make([]reflect.Value, len(dest))
  326. for i, s := range dest {
  327. vv := reflect.ValueOf(s)
  328. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  329. return errors.New("dest should be a struct's pointer")
  330. }
  331. vvvs[i] = vv.Elem()
  332. }
  333. cols, err := rs.Columns()
  334. if err != nil {
  335. return err
  336. }
  337. newDest := make([]interface{}, len(cols))
  338. var i = 0
  339. for _, vvv := range vvvs {
  340. for j := 0; j < vvv.NumField(); j++ {
  341. newDest[i] = vvv.Field(j).Addr().Interface()
  342. i = i + 1
  343. }
  344. }
  345. return rs.Rows.Scan(newDest...)
  346. }
  347. type EmptyScanner struct {
  348. }
  349. func (EmptyScanner) Scan(src interface{}) error {
  350. return nil
  351. }
  352. var (
  353. fieldCache = make(map[reflect.Type]map[string]int)
  354. fieldCacheMutex sync.RWMutex
  355. )
  356. func fieldByName(v reflect.Value, name string) reflect.Value {
  357. t := v.Type()
  358. fieldCacheMutex.RLock()
  359. cache, ok := fieldCache[t]
  360. fieldCacheMutex.RUnlock()
  361. if !ok {
  362. cache = make(map[string]int)
  363. for i := 0; i < v.NumField(); i++ {
  364. cache[t.Field(i).Name] = i
  365. }
  366. fieldCacheMutex.Lock()
  367. fieldCache[t] = cache
  368. fieldCacheMutex.Unlock()
  369. }
  370. if i, ok := cache[name]; ok {
  371. return v.Field(i)
  372. }
  373. return reflect.Zero(t)
  374. }
  375. // scan data to a struct's pointer according field name
  376. func (rs *Rows) ScanStructByName(dest interface{}) error {
  377. vv := reflect.ValueOf(dest)
  378. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  379. return errors.New("dest should be a struct's pointer")
  380. }
  381. cols, err := rs.Columns()
  382. if err != nil {
  383. return err
  384. }
  385. newDest := make([]interface{}, len(cols))
  386. var v EmptyScanner
  387. for j, name := range cols {
  388. f := fieldByName(vv.Elem(), rs.Mapper.Table2Obj(name))
  389. if f.IsValid() {
  390. newDest[j] = f.Addr().Interface()
  391. } else {
  392. newDest[j] = &v
  393. }
  394. }
  395. return rs.Rows.Scan(newDest...)
  396. }
  397. type cacheStruct struct {
  398. value reflect.Value
  399. idx int
  400. }
  401. var (
  402. reflectCache = make(map[reflect.Type]*cacheStruct)
  403. reflectCacheMutex sync.RWMutex
  404. )
  405. func ReflectNew(typ reflect.Type) reflect.Value {
  406. reflectCacheMutex.RLock()
  407. cs, ok := reflectCache[typ]
  408. reflectCacheMutex.RUnlock()
  409. const newSize = 200
  410. if !ok || cs.idx+1 > newSize-1 {
  411. cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), newSize, newSize), 0}
  412. reflectCacheMutex.Lock()
  413. reflectCache[typ] = cs
  414. reflectCacheMutex.Unlock()
  415. } else {
  416. reflectCacheMutex.Lock()
  417. cs.idx = cs.idx + 1
  418. reflectCacheMutex.Unlock()
  419. }
  420. return cs.value.Index(cs.idx).Addr()
  421. }
  422. // scan data to a slice's pointer, slice's length should equal to columns' number
  423. func (rs *Rows) ScanSlice(dest interface{}) error {
  424. vv := reflect.ValueOf(dest)
  425. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice {
  426. return errors.New("dest should be a slice's pointer")
  427. }
  428. vvv := vv.Elem()
  429. cols, err := rs.Columns()
  430. if err != nil {
  431. return err
  432. }
  433. newDest := make([]interface{}, len(cols))
  434. for j := 0; j < len(cols); j++ {
  435. if j >= vvv.Len() {
  436. newDest[j] = reflect.New(vvv.Type().Elem()).Interface()
  437. } else {
  438. newDest[j] = vvv.Index(j).Addr().Interface()
  439. }
  440. }
  441. err = rs.Rows.Scan(newDest...)
  442. if err != nil {
  443. return err
  444. }
  445. srcLen := vvv.Len()
  446. for i := srcLen; i < len(cols); i++ {
  447. vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem())
  448. }
  449. return nil
  450. }
  451. // scan data to a map's pointer
  452. func (rs *Rows) ScanMap(dest interface{}) error {
  453. vv := reflect.ValueOf(dest)
  454. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  455. return errors.New("dest should be a map's pointer")
  456. }
  457. cols, err := rs.Columns()
  458. if err != nil {
  459. return err
  460. }
  461. newDest := make([]interface{}, len(cols))
  462. vvv := vv.Elem()
  463. for i, _ := range cols {
  464. newDest[i] = ReflectNew(vvv.Type().Elem()).Interface()
  465. //v := reflect.New(vvv.Type().Elem())
  466. //newDest[i] = v.Interface()
  467. }
  468. err = rs.Rows.Scan(newDest...)
  469. if err != nil {
  470. return err
  471. }
  472. for i, name := range cols {
  473. vname := reflect.ValueOf(name)
  474. vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
  475. }
  476. return nil
  477. }
  478. /*func (rs *Rows) ScanMap(dest interface{}) error {
  479. vv := reflect.ValueOf(dest)
  480. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  481. return errors.New("dest should be a map's pointer")
  482. }
  483. cols, err := rs.Columns()
  484. if err != nil {
  485. return err
  486. }
  487. newDest := make([]interface{}, len(cols))
  488. err = rs.ScanSlice(newDest)
  489. if err != nil {
  490. return err
  491. }
  492. vvv := vv.Elem()
  493. for i, name := range cols {
  494. vname := reflect.ValueOf(name)
  495. vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
  496. }
  497. return nil
  498. }*/
  499. type Tx struct {
  500. *sql.Tx
  501. Mapper IMapper
  502. }
  503. func (db *DB) Begin() (*Tx, error) {
  504. tx, err := db.DB.Begin()
  505. if err != nil {
  506. return nil, err
  507. }
  508. return &Tx{tx, db.Mapper}, nil
  509. }
  510. func (tx *Tx) Prepare(query string) (*Stmt, error) {
  511. names := make(map[string]int)
  512. var i int
  513. query = re.ReplaceAllStringFunc(query, func(src string) string {
  514. names[src[1:]] = i
  515. i += 1
  516. return "?"
  517. })
  518. stmt, err := tx.Tx.Prepare(query)
  519. if err != nil {
  520. return nil, err
  521. }
  522. return &Stmt{stmt, tx.Mapper, names}, nil
  523. }
  524. func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
  525. // TODO:
  526. return stmt
  527. }
  528. func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) {
  529. query, args, err := MapToSlice(query, mp)
  530. if err != nil {
  531. return nil, err
  532. }
  533. return tx.Tx.Exec(query, args...)
  534. }
  535. func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
  536. query, args, err := StructToSlice(query, st)
  537. if err != nil {
  538. return nil, err
  539. }
  540. return tx.Tx.Exec(query, args...)
  541. }
  542. func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
  543. rows, err := tx.Tx.Query(query, args...)
  544. if err != nil {
  545. return nil, err
  546. }
  547. return &Rows{rows, tx.Mapper}, nil
  548. }
  549. func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {
  550. query, args, err := MapToSlice(query, mp)
  551. if err != nil {
  552. return nil, err
  553. }
  554. return tx.Query(query, args...)
  555. }
  556. func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
  557. query, args, err := StructToSlice(query, st)
  558. if err != nil {
  559. return nil, err
  560. }
  561. return tx.Query(query, args...)
  562. }
  563. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
  564. rows, err := tx.Query(query, args...)
  565. return &Row{rows, err}
  566. }
  567. func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
  568. query, args, err := MapToSlice(query, mp)
  569. if err != nil {
  570. return &Row{nil, err}
  571. }
  572. return tx.QueryRow(query, args...)
  573. }
  574. func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
  575. query, args, err := StructToSlice(query, st)
  576. if err != nil {
  577. return &Row{nil, err}
  578. }
  579. return tx.QueryRow(query, args...)
  580. }