decoder.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
// Package ndr provides the ability to unmarshal NDR encoded byte streams into Go data structures.
  2. package ndr
  3. import (
  4. "bufio"
  5. "fmt"
  6. "io"
  7. "reflect"
  8. "strings"
  9. )
  10. // Struct tag values
  11. const (
  12. TagConformant = "conformant"
  13. TagVarying = "varying"
  14. TagPointer = "pointer"
  15. TagPipe = "pipe"
  16. )
// Decoder unmarshals NDR byte stream data into a Go struct representation.
//
// Create one with NewDecoder and call Decode. The fields below carry state
// shared by the recursive helpers while a decode is in progress.
type Decoder struct {
	r             *bufio.Reader // source of the data
	size          int           // bytes buffered right after NewDecoder's initial Peek
	ch            CommonHeader  // NDR common header
	ph            PrivateHeader // NDR private header
	conformantMax []uint32      // conformant max values that were moved to the beginning of the structure (filled by scanConformantArrays)
	s             interface{}   // pointer to the structure being populated
	current       []string      // path of struct/field names currently being filled; used in error messages
}
  27. type deferedPtr struct {
  28. v reflect.Value
  29. tag reflect.StructTag
  30. }
  31. // NewDecoder creates a new instance of a NDR Decoder.
  32. func NewDecoder(r io.Reader) *Decoder {
  33. dec := new(Decoder)
  34. dec.r = bufio.NewReader(r)
  35. dec.r.Peek(int(commonHeaderBytes)) // For some reason an operation is needed on the buffer to initialise it so Buffered() != 0
  36. dec.size = dec.r.Buffered()
  37. return dec
  38. }
  39. // Decode unmarshals the NDR encoded bytes into the pointer of a struct provided.
  40. func (dec *Decoder) Decode(s interface{}) error {
  41. dec.s = s
  42. err := dec.readCommonHeader()
  43. if err != nil {
  44. return err
  45. }
  46. err = dec.readPrivateHeader()
  47. if err != nil {
  48. return err
  49. }
  50. _, err = dec.r.Discard(4) //The next 4 bytes are an RPC unique pointer referent. We just skip these.
  51. if err != nil {
  52. return Errorf("unable to process byte stream: %v", err)
  53. }
  54. return dec.process(s, reflect.StructTag(""))
  55. }
  56. func (dec *Decoder) process(s interface{}, tag reflect.StructTag) error {
  57. // Scan for conformant fields as their max counts are moved to the beginning
  58. // http://pubs.opengroup.org/onlinepubs/9629399/chap14.htm#tagfcjh_37
  59. err := dec.scanConformantArrays(s, tag)
  60. if err != nil {
  61. return err
  62. }
  63. // Recursively fill the struct fields
  64. var localDef []deferedPtr
  65. err = dec.fill(s, tag, &localDef)
  66. if err != nil {
  67. return Errorf("could not decode: %v", err)
  68. }
  69. // Read any deferred referents associated with pointers
  70. for _, p := range localDef {
  71. err = dec.process(p.v, p.tag)
  72. if err != nil {
  73. return fmt.Errorf("could not decode deferred referent: %v", err)
  74. }
  75. }
  76. return nil
  77. }
  78. // scanConformantArrays scans the structure for embedded conformant fields and captures the maximum element counts for
  79. // dimensions of the array that are moved to the beginning of the structure.
  80. func (dec *Decoder) scanConformantArrays(s interface{}, tag reflect.StructTag) error {
  81. err := dec.conformantScan(s, tag)
  82. if err != nil {
  83. return fmt.Errorf("failed to scan for embedded conformant arrays: %v", err)
  84. }
  85. for i := range dec.conformantMax {
  86. dec.conformantMax[i], err = dec.readUint32()
  87. if err != nil {
  88. return fmt.Errorf("could not read preceding conformant max count index %d: %v", i, err)
  89. }
  90. }
  91. return nil
  92. }
  93. // conformantScan inspects the structure's fields for whether they are conformant.
  94. func (dec *Decoder) conformantScan(s interface{}, tag reflect.StructTag) error {
  95. ndrTag := parseTags(tag)
  96. if ndrTag.HasValue(TagPointer) {
  97. return nil
  98. }
  99. v := getReflectValue(s)
  100. switch v.Kind() {
  101. case reflect.Struct:
  102. for i := 0; i < v.NumField(); i++ {
  103. err := dec.conformantScan(v.Field(i), v.Type().Field(i).Tag)
  104. if err != nil {
  105. return err
  106. }
  107. }
  108. case reflect.String:
  109. if !ndrTag.HasValue(TagConformant) {
  110. break
  111. }
  112. dec.conformantMax = append(dec.conformantMax, uint32(0))
  113. case reflect.Slice:
  114. if !ndrTag.HasValue(TagConformant) {
  115. break
  116. }
  117. d, t := sliceDimensions(v.Type())
  118. for i := 0; i < d; i++ {
  119. dec.conformantMax = append(dec.conformantMax, uint32(0))
  120. }
  121. // For string arrays there is a common max for the strings within the array.
  122. if t.Kind() == reflect.String {
  123. dec.conformantMax = append(dec.conformantMax, uint32(0))
  124. }
  125. }
  126. return nil
  127. }
  128. func (dec *Decoder) isPointer(v reflect.Value, tag reflect.StructTag, def *[]deferedPtr) (bool, error) {
  129. // Pointer so defer filling the referent
  130. ndrTag := parseTags(tag)
  131. if ndrTag.HasValue(TagPointer) {
  132. p, err := dec.readUint32()
  133. if err != nil {
  134. return true, fmt.Errorf("could not read pointer: %v", err)
  135. }
  136. ndrTag.delete(TagPointer)
  137. if p != 0 {
  138. // if pointer is not zero add to the deferred items at end of stream
  139. *def = append(*def, deferedPtr{v, ndrTag.StructTag()})
  140. }
  141. return true, nil
  142. }
  143. return false, nil
  144. }
  145. func getReflectValue(s interface{}) (v reflect.Value) {
  146. if r, ok := s.(reflect.Value); ok {
  147. v = r
  148. } else {
  149. if reflect.ValueOf(s).Kind() == reflect.Ptr {
  150. v = reflect.ValueOf(s).Elem()
  151. }
  152. }
  153. return
  154. }
  155. // fill populates fields with values from the NDR byte stream.
  156. func (dec *Decoder) fill(s interface{}, tag reflect.StructTag, localDef *[]deferedPtr) error {
  157. v := getReflectValue(s)
  158. //// Pointer so defer filling the referent
  159. ptr, err := dec.isPointer(v, tag, localDef)
  160. if err != nil {
  161. return fmt.Errorf("could not process struct field(%s): %v", strings.Join(dec.current, "/"), err)
  162. }
  163. if ptr {
  164. return nil
  165. }
  166. // Populate the value from the byte stream
  167. switch v.Kind() {
  168. case reflect.Struct:
  169. dec.current = append(dec.current, v.Type().Name()) //Track the current field being filled
  170. // in case struct is a union, track this and the selected union field for efficiency
  171. var unionTag reflect.Value
  172. var unionField string // field to fill if struct is a union
  173. // Go through each field in the struct and recursively fill
  174. for i := 0; i < v.NumField(); i++ {
  175. fieldName := v.Type().Field(i).Name
  176. dec.current = append(dec.current, fieldName) //Track the current field being filled
  177. //fmt.Fprintf(os.Stderr, "DEBUG Decoding: %s\n", strings.Join(dec.current, "/"))
  178. structTag := v.Type().Field(i).Tag
  179. ndrTag := parseTags(structTag)
  180. // Union handling
  181. if !unionTag.IsValid() {
  182. // Is this field a union tag?
  183. unionTag = dec.isUnion(v.Field(i), structTag)
  184. } else {
  185. // What is the selected field value of the union if we don't already know
  186. if unionField == "" {
  187. unionField, err = unionSelectedField(v, unionTag)
  188. if err != nil {
  189. return fmt.Errorf("could not determine selected union value field for %s with discriminat"+
  190. " tag %s: %v", v.Type().Name(), unionTag, err)
  191. }
  192. }
  193. if ndrTag.HasValue(TagUnionField) && fieldName != unionField {
  194. // is a union and this field has not been selected so will skip it.
  195. dec.current = dec.current[:len(dec.current)-1] //This field has been skipped so remove it from the current field tracker
  196. continue
  197. }
  198. }
  199. // Check if field is a pointer
  200. if v.Field(i).Type().Implements(reflect.TypeOf(new(RawBytes)).Elem()) &&
  201. v.Field(i).Type().Kind() == reflect.Slice && v.Field(i).Type().Elem().Kind() == reflect.Uint8 {
  202. //field is for rawbytes
  203. structTag, err = addSizeToTag(v, v.Field(i), structTag)
  204. if err != nil {
  205. return fmt.Errorf("could not get rawbytes field(%s) size: %v", strings.Join(dec.current, "/"), err)
  206. }
  207. ptr, err := dec.isPointer(v.Field(i), structTag, localDef)
  208. if err != nil {
  209. return fmt.Errorf("could not process struct field(%s): %v", strings.Join(dec.current, "/"), err)
  210. }
  211. if !ptr {
  212. err := dec.readRawBytes(v.Field(i), structTag)
  213. if err != nil {
  214. return fmt.Errorf("could not fill raw bytes struct field(%s): %v", strings.Join(dec.current, "/"), err)
  215. }
  216. }
  217. } else {
  218. err := dec.fill(v.Field(i), structTag, localDef)
  219. if err != nil {
  220. return fmt.Errorf("could not fill struct field(%s): %v", strings.Join(dec.current, "/"), err)
  221. }
  222. }
  223. dec.current = dec.current[:len(dec.current)-1] //This field has been filled so remove it from the current field tracker
  224. }
  225. dec.current = dec.current[:len(dec.current)-1] //This field has been filled so remove it from the current field tracker
  226. case reflect.Bool:
  227. i, err := dec.readBool()
  228. if err != nil {
  229. return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
  230. }
  231. v.Set(reflect.ValueOf(i))
  232. case reflect.Uint8:
  233. i, err := dec.readUint8()
  234. if err != nil {
  235. return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
  236. }
  237. v.Set(reflect.ValueOf(i))
  238. case reflect.Uint16:
  239. i, err := dec.readUint16()
  240. if err != nil {
  241. return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
  242. }
  243. v.Set(reflect.ValueOf(i))
  244. case reflect.Uint32:
  245. i, err := dec.readUint32()
  246. if err != nil {
  247. return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
  248. }
  249. v.Set(reflect.ValueOf(i))
  250. case reflect.Uint64:
  251. i, err := dec.readUint64()
  252. if err != nil {
  253. return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
  254. }
  255. v.Set(reflect.ValueOf(i))
  256. case reflect.Int8:
  257. i, err := dec.readInt8()
  258. if err != nil {
  259. return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
  260. }
  261. v.Set(reflect.ValueOf(i))
  262. case reflect.Int16:
  263. i, err := dec.readInt16()
  264. if err != nil {
  265. return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
  266. }
  267. v.Set(reflect.ValueOf(i))
  268. case reflect.Int32:
  269. i, err := dec.readInt32()
  270. if err != nil {
  271. return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
  272. }
  273. v.Set(reflect.ValueOf(i))
  274. case reflect.Int64:
  275. i, err := dec.readInt64()
  276. if err != nil {
  277. return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
  278. }
  279. v.Set(reflect.ValueOf(i))
  280. case reflect.String:
  281. ndrTag := parseTags(tag)
  282. conformant := ndrTag.HasValue(TagConformant)
  283. // strings are always varying so this is assumed without an explicit tag
  284. var s string
  285. var err error
  286. if conformant {
  287. s, err = dec.readConformantVaryingString(localDef)
  288. if err != nil {
  289. return fmt.Errorf("could not fill with conformant varying string: %v", err)
  290. }
  291. } else {
  292. s, err = dec.readVaryingString(localDef)
  293. if err != nil {
  294. return fmt.Errorf("could not fill with varying string: %v", err)
  295. }
  296. }
  297. v.Set(reflect.ValueOf(s))
  298. case reflect.Float32:
  299. i, err := dec.readFloat32()
  300. if err != nil {
  301. return fmt.Errorf("could not fill %v: %v", v.Type().Name(), err)
  302. }
  303. v.Set(reflect.ValueOf(i))
  304. case reflect.Float64:
  305. i, err := dec.readFloat64()
  306. if err != nil {
  307. return fmt.Errorf("could not fill %v: %v", v.Type().Name(), err)
  308. }
  309. v.Set(reflect.ValueOf(i))
  310. case reflect.Array:
  311. err := dec.fillFixedArray(v, tag, localDef)
  312. if err != nil {
  313. return err
  314. }
  315. case reflect.Slice:
  316. if v.Type().Implements(reflect.TypeOf(new(RawBytes)).Elem()) && v.Type().Elem().Kind() == reflect.Uint8 {
  317. //field is for rawbytes
  318. err := dec.readRawBytes(v, tag)
  319. if err != nil {
  320. return fmt.Errorf("could not fill raw bytes struct field(%s): %v", strings.Join(dec.current, "/"), err)
  321. }
  322. break
  323. }
  324. ndrTag := parseTags(tag)
  325. conformant := ndrTag.HasValue(TagConformant)
  326. varying := ndrTag.HasValue(TagVarying)
  327. if ndrTag.HasValue(TagPipe) {
  328. err := dec.fillPipe(v, tag)
  329. if err != nil {
  330. return err
  331. }
  332. break
  333. }
  334. _, t := sliceDimensions(v.Type())
  335. if t.Kind() == reflect.String && !ndrTag.HasValue(subStringArrayValue) {
  336. // String array
  337. err := dec.readStringsArray(v, tag, localDef)
  338. if err != nil {
  339. return err
  340. }
  341. break
  342. }
  343. // varying is assumed as fixed arrays use the Go array type rather than slice
  344. if conformant && varying {
  345. err := dec.fillConformantVaryingArray(v, tag, localDef)
  346. if err != nil {
  347. return err
  348. }
  349. } else if !conformant && varying {
  350. err := dec.fillVaryingArray(v, tag, localDef)
  351. if err != nil {
  352. return err
  353. }
  354. } else {
  355. //default to conformant and not varying
  356. err := dec.fillConformantArray(v, tag, localDef)
  357. if err != nil {
  358. return err
  359. }
  360. }
  361. default:
  362. return fmt.Errorf("unsupported type")
  363. }
  364. return nil
  365. }
  366. // readBytes returns a number of bytes from the NDR byte stream.
  367. func (dec *Decoder) readBytes(n int) ([]byte, error) {
  368. //TODO make this take an int64 as input to allow for larger values on all systems?
  369. b := make([]byte, n, n)
  370. m, err := dec.r.Read(b)
  371. if err != nil || m != n {
  372. return b, fmt.Errorf("error reading bytes from stream: %v", err)
  373. }
  374. return b, nil
  375. }