db.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. package core
  2. import (
  3. "database/sql"
  4. "errors"
  5. "reflect"
  6. "regexp"
  7. "sync"
  8. )
  9. func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
  10. vv := reflect.ValueOf(mp)
  11. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  12. return "", []interface{}{}, ErrNoMapPointer
  13. }
  14. args := make([]interface{}, 0)
  15. query = re.ReplaceAllStringFunc(query, func(src string) string {
  16. args = append(args, vv.Elem().MapIndex(reflect.ValueOf(src[1:])).Interface())
  17. return "?"
  18. })
  19. return query, args, nil
  20. }
  21. func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
  22. vv := reflect.ValueOf(st)
  23. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  24. return "", []interface{}{}, ErrNoStructPointer
  25. }
  26. args := make([]interface{}, 0)
  27. query = re.ReplaceAllStringFunc(query, func(src string) string {
  28. args = append(args, vv.Elem().FieldByName(src[1:]).Interface())
  29. return "?"
  30. })
  31. return query, args, nil
  32. }
  33. type DB struct {
  34. *sql.DB
  35. Mapper IMapper
  36. }
  37. func Open(driverName, dataSourceName string) (*DB, error) {
  38. db, err := sql.Open(driverName, dataSourceName)
  39. return &DB{db, NewCacheMapper(&SnakeMapper{})}, err
  40. }
  41. func FromDB(db *sql.DB) *DB {
  42. return &DB{db, NewCacheMapper(&SnakeMapper{})}
  43. }
  44. func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
  45. rows, err := db.DB.Query(query, args...)
  46. return &Rows{rows, db.Mapper}, err
  47. }
  48. func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
  49. query, args, err := MapToSlice(query, mp)
  50. if err != nil {
  51. return nil, err
  52. }
  53. return db.Query(query, args...)
  54. }
  55. func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
  56. query, args, err := StructToSlice(query, st)
  57. if err != nil {
  58. return nil, err
  59. }
  60. return db.Query(query, args...)
  61. }
  62. type Row struct {
  63. *sql.Row
  64. // One of these two will be non-nil:
  65. err error // deferred error for easy chaining
  66. Mapper IMapper
  67. }
  68. func (row *Row) Scan(dest ...interface{}) error {
  69. if row.err != nil {
  70. return row.err
  71. }
  72. return row.Row.Scan(dest...)
  73. }
  74. func (db *DB) QueryRow(query string, args ...interface{}) *Row {
  75. row := db.DB.QueryRow(query, args...)
  76. return &Row{row, nil, db.Mapper}
  77. }
  78. func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
  79. query, args, err := MapToSlice(query, mp)
  80. if err != nil {
  81. return &Row{nil, err, db.Mapper}
  82. }
  83. return db.QueryRow(query, args...)
  84. }
  85. func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
  86. query, args, err := StructToSlice(query, st)
  87. if err != nil {
  88. return &Row{nil, err, db.Mapper}
  89. }
  90. return db.QueryRow(query, args...)
  91. }
  92. type Stmt struct {
  93. *sql.Stmt
  94. Mapper IMapper
  95. names map[string]int
  96. }
  97. func (db *DB) Prepare(query string) (*Stmt, error) {
  98. names := make(map[string]int)
  99. var i int
  100. query = re.ReplaceAllStringFunc(query, func(src string) string {
  101. names[src[1:]] = i
  102. i += 1
  103. return "?"
  104. })
  105. stmt, err := db.DB.Prepare(query)
  106. if err != nil {
  107. return nil, err
  108. }
  109. return &Stmt{stmt, db.Mapper, names}, nil
  110. }
  111. func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
  112. vv := reflect.ValueOf(mp)
  113. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  114. return nil, errors.New("mp should be a map's pointer")
  115. }
  116. args := make([]interface{}, len(s.names))
  117. for k, i := range s.names {
  118. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  119. }
  120. return s.Stmt.Exec(args...)
  121. }
  122. func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
  123. vv := reflect.ValueOf(st)
  124. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  125. return nil, errors.New("mp should be a map's pointer")
  126. }
  127. args := make([]interface{}, len(s.names))
  128. for k, i := range s.names {
  129. args[i] = vv.Elem().FieldByName(k).Interface()
  130. }
  131. return s.Stmt.Exec(args...)
  132. }
  133. func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
  134. rows, err := s.Stmt.Query(args...)
  135. if err != nil {
  136. return nil, err
  137. }
  138. return &Rows{rows, s.Mapper}, nil
  139. }
  140. func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
  141. vv := reflect.ValueOf(mp)
  142. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  143. return nil, errors.New("mp should be a map's pointer")
  144. }
  145. args := make([]interface{}, len(s.names))
  146. for k, i := range s.names {
  147. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  148. }
  149. return s.Query(args...)
  150. }
  151. func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
  152. vv := reflect.ValueOf(st)
  153. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  154. return nil, errors.New("mp should be a map's pointer")
  155. }
  156. args := make([]interface{}, len(s.names))
  157. for k, i := range s.names {
  158. args[i] = vv.Elem().FieldByName(k).Interface()
  159. }
  160. return s.Query(args...)
  161. }
  162. func (s *Stmt) QueryRow(args ...interface{}) *Row {
  163. row := s.Stmt.QueryRow(args...)
  164. return &Row{row, nil, s.Mapper}
  165. }
  166. func (s *Stmt) QueryRowMap(mp interface{}) *Row {
  167. vv := reflect.ValueOf(mp)
  168. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  169. return &Row{nil, errors.New("mp should be a map's pointer"), s.Mapper}
  170. }
  171. args := make([]interface{}, len(s.names))
  172. for k, i := range s.names {
  173. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  174. }
  175. return s.QueryRow(args...)
  176. }
  177. func (s *Stmt) QueryRowStruct(st interface{}) *Row {
  178. vv := reflect.ValueOf(st)
  179. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  180. return &Row{nil, errors.New("st should be a struct's pointer"), s.Mapper}
  181. }
  182. args := make([]interface{}, len(s.names))
  183. for k, i := range s.names {
  184. args[i] = vv.Elem().FieldByName(k).Interface()
  185. }
  186. return s.QueryRow(args...)
  187. }
  // re matches a named placeholder: a '?' immediately followed by one or more
// word characters. The submatch is the name without the '?'; a bare '?' does
// not match.
var re = regexp.MustCompile(`[?](\w+)`)
  191. // insert into (name) values (?)
  192. // insert into (name) values (?name)
  193. func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
  194. query, args, err := MapToSlice(query, mp)
  195. if err != nil {
  196. return nil, err
  197. }
  198. return db.DB.Exec(query, args...)
  199. }
  200. func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
  201. query, args, err := StructToSlice(query, st)
  202. if err != nil {
  203. return nil, err
  204. }
  205. return db.DB.Exec(query, args...)
  206. }
  207. type Rows struct {
  208. *sql.Rows
  209. Mapper IMapper
  210. }
  211. // scan data to a struct's pointer according field index
  212. func (rs *Rows) ScanStructByIndex(dest ...interface{}) error {
  213. if len(dest) == 0 {
  214. return errors.New("at least one struct")
  215. }
  216. vvvs := make([]reflect.Value, len(dest))
  217. for i, s := range dest {
  218. vv := reflect.ValueOf(s)
  219. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  220. return errors.New("dest should be a struct's pointer")
  221. }
  222. vvvs[i] = vv.Elem()
  223. }
  224. cols, err := rs.Columns()
  225. if err != nil {
  226. return err
  227. }
  228. newDest := make([]interface{}, len(cols))
  229. var i = 0
  230. for _, vvv := range vvvs {
  231. for j := 0; j < vvv.NumField(); j++ {
  232. newDest[i] = vvv.Field(j).Addr().Interface()
  233. i = i + 1
  234. }
  235. }
  236. return rs.Rows.Scan(newDest...)
  237. }
  // EmptyScanner is a Scan destination that accepts any column value and
// discards it. ScanStructByName uses it for columns that have no matching
// struct field.
type EmptyScanner struct{}

// Scan ignores its argument and always succeeds.
func (EmptyScanner) Scan(interface{}) error {
	return nil
}
  var (
	// fieldCache maps a struct type to the index of each of its top-level
	// fields, keyed by Go field name. It is guarded by fieldCacheMutex.
	fieldCache      = make(map[reflect.Type]map[string]int)
	fieldCacheMutex sync.RWMutex
)

// fieldByName returns the top-level field of struct value v whose Go name is
// name. If there is no such field, it returns the zero reflect.Value, for
// which IsValid() is false. Field indices are computed once per struct type
// and cached.
func fieldByName(v reflect.Value, name string) reflect.Value {
	t := v.Type()

	fieldCacheMutex.RLock()
	cache, ok := fieldCache[t]
	fieldCacheMutex.RUnlock()

	if !ok {
		// Two goroutines may build the index for the same type at once. Both
		// results are identical, so whichever write lands last is fine.
		cache = make(map[string]int, t.NumField())
		for i := 0; i < t.NumField(); i++ {
			cache[t.Field(i).Name] = i
		}
		fieldCacheMutex.Lock()
		fieldCache[t] = cache
		fieldCacheMutex.Unlock()
	}

	if i, ok := cache[name]; ok {
		return v.Field(i)
	}
	// Do not return reflect.Zero(t). That is a valid but unaddressable struct
	// value, so a caller testing IsValid() would treat the missing field as
	// found and then panic on Addr().
	return reflect.Value{}
}
  266. // scan data to a struct's pointer according field name
  267. func (rs *Rows) ScanStructByName(dest interface{}) error {
  268. vv := reflect.ValueOf(dest)
  269. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  270. return errors.New("dest should be a struct's pointer")
  271. }
  272. cols, err := rs.Columns()
  273. if err != nil {
  274. return err
  275. }
  276. newDest := make([]interface{}, len(cols))
  277. var v EmptyScanner
  278. for j, name := range cols {
  279. f := fieldByName(vv.Elem(), rs.Mapper.Table2Obj(name))
  280. if f.IsValid() {
  281. newDest[j] = f.Addr().Interface()
  282. } else {
  283. newDest[j] = &v
  284. }
  285. }
  286. return rs.Rows.Scan(newDest...)
  287. }
  // cacheStruct is a pre-allocated slab of values of one type. idx is the index
// of the element most recently handed out by ReflectNew.
type cacheStruct struct {
	value reflect.Value
	idx   int
}

var (
	// reflectCache holds the current slab for each type requested from
	// ReflectNew. It is guarded by reflectCacheMutex.
	reflectCache      = make(map[reflect.Type]*cacheStruct)
	reflectCacheMutex sync.RWMutex
)

// ReflectNew returns a pointer (as a reflect.Value) to a zero value of typ,
// like reflect.New. To reduce allocations it carves values out of slabs of
// newSize elements. Every call returns a distinct element.
//
// The lookup, the index advance and the element selection all happen under
// the write lock, so concurrent callers never receive the same element.
func ReflectNew(typ reflect.Type) reflect.Value {
	const newSize = 200

	reflectCacheMutex.Lock()
	defer reflectCacheMutex.Unlock()

	cs, ok := reflectCache[typ]
	if !ok || cs.idx+1 >= newSize {
		// No slab yet for this type, or the current one is used up.
		cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), newSize, newSize), 0}
		reflectCache[typ] = cs
	} else {
		cs.idx++
	}
	return cs.value.Index(cs.idx).Addr()
}
  313. // scan data to a slice's pointer, slice's length should equal to columns' number
  314. func (rs *Rows) ScanSlice(dest interface{}) error {
  315. vv := reflect.ValueOf(dest)
  316. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice {
  317. return errors.New("dest should be a slice's pointer")
  318. }
  319. vvv := vv.Elem()
  320. cols, err := rs.Columns()
  321. if err != nil {
  322. return err
  323. }
  324. newDest := make([]interface{}, len(cols))
  325. for j := 0; j < len(cols); j++ {
  326. if j >= vvv.Len() {
  327. newDest[j] = reflect.New(vvv.Type().Elem()).Interface()
  328. } else {
  329. newDest[j] = vvv.Index(j).Addr().Interface()
  330. }
  331. }
  332. err = rs.Rows.Scan(newDest...)
  333. if err != nil {
  334. return err
  335. }
  336. srcLen := vvv.Len()
  337. for i := srcLen; i < len(cols); i++ {
  338. vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem())
  339. }
  340. return nil
  341. }
  342. // scan data to a map's pointer
  343. func (rs *Rows) ScanMap(dest interface{}) error {
  344. vv := reflect.ValueOf(dest)
  345. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  346. return errors.New("dest should be a map's pointer")
  347. }
  348. cols, err := rs.Columns()
  349. if err != nil {
  350. return err
  351. }
  352. newDest := make([]interface{}, len(cols))
  353. vvv := vv.Elem()
  354. for i, _ := range cols {
  355. newDest[i] = ReflectNew(vvv.Type().Elem()).Interface()
  356. }
  357. err = rs.Rows.Scan(newDest...)
  358. if err != nil {
  359. return err
  360. }
  361. for i, name := range cols {
  362. vname := reflect.ValueOf(name)
  363. vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
  364. }
  365. return nil
  366. }
  367. /*func (rs *Rows) ScanMap(dest interface{}) error {
  368. vv := reflect.ValueOf(dest)
  369. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  370. return errors.New("dest should be a map's pointer")
  371. }
  372. cols, err := rs.Columns()
  373. if err != nil {
  374. return err
  375. }
  376. newDest := make([]interface{}, len(cols))
  377. err = rs.ScanSlice(newDest)
  378. if err != nil {
  379. return err
  380. }
  381. vvv := vv.Elem()
  382. for i, name := range cols {
  383. vname := reflect.ValueOf(name)
  384. vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
  385. }
  386. return nil
  387. }*/
  388. type Tx struct {
  389. *sql.Tx
  390. Mapper IMapper
  391. }
  392. func (db *DB) Begin() (*Tx, error) {
  393. tx, err := db.DB.Begin()
  394. if err != nil {
  395. return nil, err
  396. }
  397. return &Tx{tx, db.Mapper}, nil
  398. }
  399. func (tx *Tx) Prepare(query string) (*Stmt, error) {
  400. names := make(map[string]int)
  401. var i int
  402. query = re.ReplaceAllStringFunc(query, func(src string) string {
  403. names[src[1:]] = i
  404. i += 1
  405. return "?"
  406. })
  407. stmt, err := tx.Tx.Prepare(query)
  408. if err != nil {
  409. return nil, err
  410. }
  411. return &Stmt{stmt, tx.Mapper, names}, nil
  412. }
  413. func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
  414. // TODO:
  415. return stmt
  416. }
  417. func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) {
  418. query, args, err := MapToSlice(query, mp)
  419. if err != nil {
  420. return nil, err
  421. }
  422. return tx.Tx.Exec(query, args...)
  423. }
  424. func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
  425. query, args, err := StructToSlice(query, st)
  426. if err != nil {
  427. return nil, err
  428. }
  429. return tx.Tx.Exec(query, args...)
  430. }
  431. func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
  432. rows, err := tx.Tx.Query(query, args...)
  433. if err != nil {
  434. return nil, err
  435. }
  436. return &Rows{rows, tx.Mapper}, nil
  437. }
  438. func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {
  439. query, args, err := MapToSlice(query, mp)
  440. if err != nil {
  441. return nil, err
  442. }
  443. return tx.Query(query, args...)
  444. }
  445. func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
  446. query, args, err := StructToSlice(query, st)
  447. if err != nil {
  448. return nil, err
  449. }
  450. return tx.Query(query, args...)
  451. }
  452. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
  453. row := tx.Tx.QueryRow(query, args...)
  454. return &Row{row, nil, tx.Mapper}
  455. }
  456. func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
  457. query, args, err := MapToSlice(query, mp)
  458. if err != nil {
  459. return &Row{nil, err, tx.Mapper}
  460. }
  461. return tx.QueryRow(query, args...)
  462. }
  463. func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
  464. query, args, err := StructToSlice(query, st)
  465. if err != nil {
  466. return &Row{nil, err, tx.Mapper}
  467. }
  468. return tx.QueryRow(query, args...)
  469. }