scan.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. // Copyright 2012 Gary Burd
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License"): you may
  4. // not use this file except in compliance with the License. You may obtain
  5. // a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  11. // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  12. // License for the specific language governing permissions and limitations
  13. // under the License.
  14. package redis
  15. import (
  16. "errors"
  17. "fmt"
  18. "reflect"
  19. "strconv"
  20. "strings"
  21. "sync"
  22. )
  // ensureLen resizes the slice held in d to exactly n elements.
  //
  // When d's capacity already covers n, the slice is resliced in place, so the
  // existing backing array (and any values stored in it) is kept. Otherwise d
  // is replaced by a new zeroed slice whose length and capacity are both n.
  func ensureLen(d reflect.Value, n int) {
  	if n <= d.Cap() {
  		d.SetLen(n)
  		return
  	}
  	d.Set(reflect.MakeSlice(d.Type(), n, n))
  }
  // cannotConvert returns the error reported when the reply value s cannot be
  // stored in the destination d.
  //
  // Either side may be absent. A nil s or an invalid (zero) d is rendered as
  // "<nil>" instead of panicking. Callers build d with reflect.ValueOf on
  // arbitrary dest arguments, and reflect.ValueOf(nil) yields an invalid Value
  // on which d.Type() would panic.
  func cannotConvert(d reflect.Value, s interface{}) error {
  	sname := "<nil>"
  	if st := reflect.TypeOf(s); st != nil {
  		sname = st.String()
  	}
  	dname := "<nil>"
  	if d.IsValid() {
  		dname = d.Type().String()
  	}
  	return fmt.Errorf("redigo: Scan cannot convert from %s to %s", sname, dname)
  }
  34. func convertAssignBytes(d reflect.Value, s []byte) (err error) {
  35. switch d.Type().Kind() {
  36. case reflect.Float32, reflect.Float64:
  37. var x float64
  38. x, err = strconv.ParseFloat(string(s), d.Type().Bits())
  39. d.SetFloat(x)
  40. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  41. var x int64
  42. x, err = strconv.ParseInt(string(s), 10, d.Type().Bits())
  43. d.SetInt(x)
  44. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  45. var x uint64
  46. x, err = strconv.ParseUint(string(s), 10, d.Type().Bits())
  47. d.SetUint(x)
  48. case reflect.Bool:
  49. var x bool
  50. x, err = strconv.ParseBool(string(s))
  51. d.SetBool(x)
  52. case reflect.String:
  53. d.SetString(string(s))
  54. case reflect.Slice:
  55. if d.Type().Elem().Kind() != reflect.Uint8 {
  56. err = cannotConvert(d, s)
  57. } else {
  58. d.SetBytes(s)
  59. }
  60. default:
  61. err = cannotConvert(d, s)
  62. }
  63. return
  64. }
  65. func convertAssignInt(d reflect.Value, s int64) (err error) {
  66. switch d.Type().Kind() {
  67. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  68. d.SetInt(s)
  69. if d.Int() != s {
  70. err = strconv.ErrRange
  71. d.SetInt(0)
  72. }
  73. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  74. if s < 0 {
  75. err = strconv.ErrRange
  76. } else {
  77. x := uint64(s)
  78. d.SetUint(x)
  79. if d.Uint() != x {
  80. err = strconv.ErrRange
  81. d.SetUint(0)
  82. }
  83. }
  84. case reflect.Bool:
  85. d.SetBool(s != 0)
  86. default:
  87. err = cannotConvert(d, s)
  88. }
  89. return
  90. }
  91. func convertAssignValues(d reflect.Value, s []interface{}) (err error) {
  92. if d.Type().Kind() != reflect.Slice {
  93. return cannotConvert(d, s)
  94. }
  95. ensureLen(d, len(s))
  96. for i := 0; i < len(s); i++ {
  97. switch s := s[i].(type) {
  98. case []byte:
  99. err = convertAssignBytes(d.Index(i), s)
  100. case int64:
  101. err = convertAssignInt(d.Index(i), s)
  102. default:
  103. err = cannotConvert(d, s)
  104. }
  105. if err != nil {
  106. break
  107. }
  108. }
  109. return
  110. }
  111. func convertAssign(d interface{}, s interface{}) (err error) {
  112. // Handle the most common destination types using type switches and
  113. // fall back to reflection for all other types.
  114. switch s := s.(type) {
  115. case nil:
  116. // ingore
  117. case []byte:
  118. switch d := d.(type) {
  119. case *string:
  120. *d = string(s)
  121. case *int:
  122. *d, err = strconv.Atoi(string(s))
  123. case *bool:
  124. *d, err = strconv.ParseBool(string(s))
  125. case *[]byte:
  126. *d = s
  127. case *interface{}:
  128. *d = s
  129. case nil:
  130. // skip value
  131. default:
  132. if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
  133. err = cannotConvert(d, s)
  134. } else {
  135. err = convertAssignBytes(d.Elem(), s)
  136. }
  137. }
  138. case int64:
  139. switch d := d.(type) {
  140. case *int:
  141. x := int(s)
  142. if int64(x) != s {
  143. err = strconv.ErrRange
  144. x = 0
  145. }
  146. *d = x
  147. case *bool:
  148. *d = s != 0
  149. case *interface{}:
  150. *d = s
  151. case nil:
  152. // skip value
  153. default:
  154. if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
  155. err = cannotConvert(d, s)
  156. } else {
  157. err = convertAssignInt(d.Elem(), s)
  158. }
  159. }
  160. case []interface{}:
  161. switch d := d.(type) {
  162. case *[]interface{}:
  163. *d = s
  164. case *interface{}:
  165. *d = s
  166. case nil:
  167. // skip value
  168. default:
  169. if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
  170. err = cannotConvert(d, s)
  171. } else {
  172. err = convertAssignValues(d.Elem(), s)
  173. }
  174. }
  175. case Error:
  176. err = s
  177. default:
  178. err = cannotConvert(reflect.ValueOf(d), s)
  179. }
  180. return
  181. }
  182. // Scan copies from the multi-bulk src to the values pointed at by dest.
  183. //
  184. // The values pointed at by dest must be an integer, float, boolean, string, or
  185. // []byte. Scan uses the standard strconv package to convert bulk values to
  186. // numeric and boolean types.
  187. //
  188. // If a dest value is nil, then the corresponding src value is skipped.
  189. //
  190. // If the multi-bulk value is nil, then the corresponding dest value is not
  191. // modified.
  192. //
  193. // To enable easy use of Scan in a loop, Scan returns the slice of src
  194. // following the copied values.
  195. func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) {
  196. if len(src) < len(dest) {
  197. return nil, errors.New("redigo: Scan multibulk short")
  198. }
  199. var err error
  200. for i, d := range dest {
  201. err = convertAssign(d, src[i])
  202. if err != nil {
  203. break
  204. }
  205. }
  206. return src[len(dest):], err
  207. }
  type (
  	// fieldSpec describes a single struct field addressed by a Redis name.
  	fieldSpec struct {
  		name  string // field name, or the name given in its `redis` tag
  		index []int  // path to the field, as passed to reflect.Value.FieldByIndex
  		// omitEmpty bool (not implemented)
  	}

  	// structSpec is the compiled set of fields of one struct type.
  	structSpec struct {
  		m map[string]*fieldSpec // fields keyed by name
  		l []*fieldSpec          // fields in the order they were recorded
  	}
  )

  // fieldSpec looks up the field called name. It returns nil when the struct
  // has no such field.
  func (ss *structSpec) fieldSpec(name []byte) *fieldSpec {
  	if fs, ok := ss.m[string(name)]; ok {
  		return fs
  	}
  	return nil
  }
  220. func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) {
  221. for i := 0; i < t.NumField(); i++ {
  222. f := t.Field(i)
  223. switch {
  224. case f.PkgPath != "":
  225. // Ignore unexported fields.
  226. case f.Anonymous:
  227. // TODO: Handle pointers. Requires change to decoder and
  228. // protection against infinite recursion.
  229. if f.Type.Kind() == reflect.Struct {
  230. compileStructSpec(f.Type, depth, append(index, i), ss)
  231. }
  232. default:
  233. fs := &fieldSpec{name: f.Name}
  234. tag := f.Tag.Get("redis")
  235. p := strings.Split(tag, ",")
  236. if len(p) > 0 {
  237. if p[0] == "-" {
  238. continue
  239. }
  240. if len(p[0]) > 0 {
  241. fs.name = p[0]
  242. }
  243. for _, s := range p[1:] {
  244. switch s {
  245. //case "omitempty":
  246. // fs.omitempty = true
  247. default:
  248. panic(errors.New("redigo: unknown field flag " + s + " for type " + t.Name()))
  249. }
  250. }
  251. }
  252. d, found := depth[fs.name]
  253. if !found {
  254. d = 1 << 30
  255. }
  256. switch {
  257. case len(index) == d:
  258. // At same depth, remove from result.
  259. delete(ss.m, fs.name)
  260. j := 0
  261. for i := 0; i < len(ss.l); i++ {
  262. if fs.name != ss.l[i].name {
  263. ss.l[j] = ss.l[i]
  264. j += 1
  265. }
  266. }
  267. ss.l = ss.l[:j]
  268. case len(index) < d:
  269. fs.index = make([]int, len(index)+1)
  270. copy(fs.index, index)
  271. fs.index[len(index)] = i
  272. depth[fs.name] = len(index)
  273. ss.m[fs.name] = fs
  274. ss.l = append(ss.l, fs)
  275. }
  276. }
  277. }
  278. }
  279. var (
  280. structSpecMutex sync.RWMutex
  281. structSpecCache = make(map[reflect.Type]*structSpec)
  282. defaultFieldSpec = &fieldSpec{}
  283. )
  284. func structSpecForType(t reflect.Type) *structSpec {
  285. structSpecMutex.RLock()
  286. ss, found := structSpecCache[t]
  287. structSpecMutex.RUnlock()
  288. if found {
  289. return ss
  290. }
  291. structSpecMutex.Lock()
  292. defer structSpecMutex.Unlock()
  293. ss, found = structSpecCache[t]
  294. if found {
  295. return ss
  296. }
  297. ss = &structSpec{m: make(map[string]*fieldSpec)}
  298. compileStructSpec(t, make(map[string]int), nil, ss)
  299. structSpecCache[t] = ss
  300. return ss
  301. }
  302. var scanStructValueError = errors.New("redigo: ScanStruct value must be non-nil pointer to a struct.")
  303. // ScanStruct scans a multi-bulk src containing alternating names and values to
  304. // a struct. The HGETALL and CONFIG GET commands return replies in this format.
  305. //
  306. // ScanStruct uses the struct field name to match values in the response. Use
  307. // 'redis' field tag to override the name:
  308. //
  309. // Field int `redis:"myName"`
  310. //
  311. // Fields with the tag redis:"-" are ignored.
  312. //
  313. // Integer, float boolean string and []byte fields are supported. Scan uses
  314. // the standard strconv package to convert bulk values to numeric and boolean
  315. // types.
  316. //
  317. // If the multi-bulk value is nil, then the corresponding field is not
  318. // modified.
  319. func ScanStruct(src []interface{}, dest interface{}) error {
  320. d := reflect.ValueOf(dest)
  321. if d.Kind() != reflect.Ptr || d.IsNil() {
  322. return scanStructValueError
  323. }
  324. d = d.Elem()
  325. if d.Kind() != reflect.Struct {
  326. return scanStructValueError
  327. }
  328. ss := structSpecForType(d.Type())
  329. if len(src)%2 != 0 {
  330. return errors.New("redigo: ScanStruct expects even number of values in values")
  331. }
  332. for i := 0; i < len(src); i += 2 {
  333. name, ok := src[i].([]byte)
  334. if !ok {
  335. return errors.New("redigo: ScanStruct key not a bulk value")
  336. }
  337. fs := ss.fieldSpec(name)
  338. if fs == nil {
  339. continue
  340. }
  341. f := d.FieldByIndex(fs.index)
  342. var err error
  343. switch s := src[i+1].(type) {
  344. case nil:
  345. // ignore
  346. case []byte:
  347. err = convertAssignBytes(f, s)
  348. case int64:
  349. err = convertAssignInt(f, s)
  350. default:
  351. err = cannotConvert(f, s)
  352. }
  353. if err != nil {
  354. return err
  355. }
  356. }
  357. return nil
  358. }
  359. var (
  360. scanSliceValueError = errors.New("redigo: ScanSlice dest must be non-nil pointer to a struct.")
  361. scanSliceSrcError = errors.New("redigo: ScanSlice src element must be bulk or nil.")
  362. )
  363. // ScanSlice scans multi-bulk src to the slice pointed to by dest. The elements
  364. // the dest slice must be integer, float, boolean, string, struct or pointer to
  365. // struct values.
  366. //
  367. // Struct fields must be integer, float, boolean or string values. All struct
  368. // fields are used unless a subset is specified using fieldNames.
  369. func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error {
  370. d := reflect.ValueOf(dest)
  371. if d.Kind() != reflect.Ptr || d.IsNil() {
  372. return scanSliceValueError
  373. }
  374. d = d.Elem()
  375. if d.Kind() != reflect.Slice {
  376. return scanSliceValueError
  377. }
  378. isPtr := false
  379. t := d.Type().Elem()
  380. if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
  381. isPtr = true
  382. t = t.Elem()
  383. }
  384. if t.Kind() != reflect.Struct {
  385. ensureLen(d, len(src))
  386. for i, s := range src {
  387. if s == nil {
  388. continue
  389. }
  390. s, ok := s.([]byte)
  391. if !ok {
  392. return scanSliceSrcError
  393. }
  394. if err := convertAssignBytes(d.Index(i), s); err != nil {
  395. return err
  396. }
  397. }
  398. return nil
  399. }
  400. ss := structSpecForType(t)
  401. fss := ss.l
  402. if len(fieldNames) > 0 {
  403. fss = make([]*fieldSpec, len(fieldNames))
  404. for i, name := range fieldNames {
  405. fss[i] = ss.m[name]
  406. if fss[i] == nil {
  407. return errors.New("redigo: bad field name " + name)
  408. }
  409. }
  410. }
  411. n := len(src) / len(fss)
  412. if n*len(fss) != len(src) {
  413. return errors.New("redigo: length of ScanSlice not a multiple of struct field count.")
  414. }
  415. ensureLen(d, n)
  416. for i := 0; i < n; i++ {
  417. d := d.Index(i)
  418. if isPtr {
  419. if d.IsNil() {
  420. d.Set(reflect.New(t))
  421. }
  422. d = d.Elem()
  423. }
  424. for j, fs := range fss {
  425. s := src[i*len(fss)+j]
  426. if s == nil {
  427. continue
  428. }
  429. sb, ok := s.([]byte)
  430. if !ok {
  431. return scanSliceSrcError
  432. }
  433. d := d.FieldByIndex(fs.index)
  434. if err := convertAssignBytes(d, sb); err != nil {
  435. return err
  436. }
  437. }
  438. }
  439. return nil
  440. }
  // Args is a helper for constructing command arguments from structured
  // values. Its methods return the extended Args, so calls can be chained.
  type Args []interface{}

  // Add returns the result of appending value to args. As with the built-in
  // append, the returned Args may share its backing array with args.
  func (args Args) Add(value interface{}) Args {
  	extended := append(args, value)
  	return extended
  }
  447. // AddFlat returns the result of appending the flattened value of v to args.
  448. //
  449. // Maps are flattened by appending the alternating keys and map values to args.
  450. //
  451. // Slices are flattened by appending the slice elements to args.
  452. //
  453. // Structs are flattened by appending the alternating field names and field
  454. // values to args. If v is a nil struct pointer, then nothing is appended. The
  455. // 'redis' field tag overrides struct field names. See ScanStruct for more
  456. // information on the use of the 'redis' field tag.
  457. //
  458. // Other types are appended to args as is.
  459. func (args Args) AddFlat(v interface{}) Args {
  460. rv := reflect.ValueOf(v)
  461. switch rv.Kind() {
  462. case reflect.Struct:
  463. args = flattenStruct(args, rv)
  464. case reflect.Slice:
  465. for i := 0; i < rv.Len(); i++ {
  466. args = append(args, rv.Index(i).Interface())
  467. }
  468. case reflect.Map:
  469. for _, k := range rv.MapKeys() {
  470. args = append(args, k.Interface(), rv.MapIndex(k).Interface())
  471. }
  472. case reflect.Ptr:
  473. if rv.Type().Elem().Kind() == reflect.Struct {
  474. if !rv.IsNil() {
  475. args = flattenStruct(args, rv.Elem())
  476. }
  477. } else {
  478. args = append(args, v)
  479. }
  480. default:
  481. args = append(args, v)
  482. }
  483. return args
  484. }
  485. func flattenStruct(args Args, v reflect.Value) Args {
  486. ss := structSpecForType(v.Type())
  487. for _, fs := range ss.l {
  488. fv := v.FieldByIndex(fs.index)
  489. args = append(args, fs.name, fv.Interface())
  490. }
  491. return args
  492. }