table_merge.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630
  1. // Copyright 2016 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "fmt"
  7. "reflect"
  8. "strings"
  9. "sync"
  10. "sync/atomic"
  11. )
  12. // Merge merges the src message into dst.
  13. // This assumes that dst and src of the same type and are non-nil.
  14. func (a *InternalMessageInfo) Merge(dst, src Message) {
  15. mi := atomicLoadMergeInfo(&a.merge)
  16. if mi == nil {
  17. mi = getMergeInfo(reflect.TypeOf(dst).Elem())
  18. atomicStoreMergeInfo(&a.merge, mi)
  19. }
  20. mi.merge(toPointer(&dst), toPointer(&src))
  21. }
  22. type mergeInfo struct {
  23. typ reflect.Type
  24. initialized int32 // 0: only typ is valid, 1: everything is valid
  25. lock sync.Mutex
  26. fields []mergeFieldInfo
  27. unrecognized field // Offset of XXX_unrecognized
  28. }
  29. type mergeFieldInfo struct {
  30. field field // Offset of field, guaranteed to be valid
  31. // isPointer reports whether the value in the field is a pointer.
  32. // This is true for the following situations:
  33. // * Pointer to struct
  34. // * Pointer to basic type (proto2 only)
  35. // * Slice (first value in slice header is a pointer)
  36. // * String (first value in string header is a pointer)
  37. isPointer bool
  38. // basicWidth reports the width of the field assuming that it is directly
  39. // embedded in the struct (as is the case for basic types in proto3).
  40. // The possible values are:
  41. // 0: invalid
  42. // 1: bool
  43. // 4: int32, uint32, float32
  44. // 8: int64, uint64, float64
  45. basicWidth int
  46. // Where dst and src are pointers to the types being merged.
  47. merge func(dst, src pointer)
  48. }
  49. var (
  50. mergeInfoMap = map[reflect.Type]*mergeInfo{}
  51. mergeInfoLock sync.Mutex
  52. )
  53. func getMergeInfo(t reflect.Type) *mergeInfo {
  54. mergeInfoLock.Lock()
  55. defer mergeInfoLock.Unlock()
  56. mi := mergeInfoMap[t]
  57. if mi == nil {
  58. mi = &mergeInfo{typ: t}
  59. mergeInfoMap[t] = mi
  60. }
  61. return mi
  62. }
  63. // merge merges src into dst assuming they are both of type *mi.typ.
  64. func (mi *mergeInfo) merge(dst, src pointer) {
  65. if dst.isNil() {
  66. panic("proto: nil destination")
  67. }
  68. if src.isNil() {
  69. return // Nothing to do.
  70. }
  71. if atomic.LoadInt32(&mi.initialized) == 0 {
  72. mi.computeMergeInfo()
  73. }
  74. for _, fi := range mi.fields {
  75. sfp := src.offset(fi.field)
  76. // As an optimization, we can avoid the merge function call cost
  77. // if we know for sure that the source will have no effect
  78. // by checking if it is the zero value.
  79. if unsafeAllowed {
  80. if fi.isPointer && sfp.getPointer().isNil() { // Could be slice or string
  81. continue
  82. }
  83. if fi.basicWidth > 0 {
  84. switch {
  85. case fi.basicWidth == 1 && !*sfp.toBool():
  86. continue
  87. case fi.basicWidth == 4 && *sfp.toUint32() == 0:
  88. continue
  89. case fi.basicWidth == 8 && *sfp.toUint64() == 0:
  90. continue
  91. }
  92. }
  93. }
  94. dfp := dst.offset(fi.field)
  95. fi.merge(dfp, sfp)
  96. }
  97. // TODO: Make this faster?
  98. out := dst.asPointerTo(mi.typ).Elem()
  99. in := src.asPointerTo(mi.typ).Elem()
  100. if emIn, err := extendable(in.Addr().Interface()); err == nil {
  101. emOut, _ := extendable(out.Addr().Interface())
  102. if emIn != nil {
  103. mergeExtension(emOut, emIn)
  104. }
  105. }
  106. if mi.unrecognized.IsValid() {
  107. if b := *src.offset(mi.unrecognized).toBytes(); len(b) > 0 {
  108. *dst.offset(mi.unrecognized).toBytes() = append([]byte(nil), b...)
  109. }
  110. }
  111. }
  112. func (mi *mergeInfo) computeMergeInfo() {
  113. mi.lock.Lock()
  114. defer mi.lock.Unlock()
  115. if mi.initialized != 0 {
  116. return
  117. }
  118. t := mi.typ
  119. n := t.NumField()
  120. props := GetProperties(t)
  121. for i := 0; i < n; i++ {
  122. f := t.Field(i)
  123. if strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" {
  124. continue
  125. }
  126. mfi := mergeFieldInfo{field: toField(&f, nil)}
  127. tf := f.Type
  128. // As an optimization, we can avoid the merge function call cost
  129. // if we know for sure that the source will have no effect
  130. // by checking if it is the zero value.
  131. if unsafeAllowed {
  132. switch tf.Kind() {
  133. case reflect.Ptr, reflect.Slice, reflect.String:
  134. // As a special case, we assume slices and strings are pointers
  135. // since we know that the first field in the SliceSlice or
  136. // StringHeader is a data pointer.
  137. mfi.isPointer = true
  138. case reflect.Bool:
  139. mfi.basicWidth = 1
  140. case reflect.Int32, reflect.Uint32, reflect.Float32:
  141. mfi.basicWidth = 4
  142. case reflect.Int64, reflect.Uint64, reflect.Float64:
  143. mfi.basicWidth = 8
  144. }
  145. }
  146. // Unwrap tf to get at its most basic type.
  147. var isPointer, isSlice bool
  148. if tf.Kind() == reflect.Slice && tf.Elem().Kind() != reflect.Uint8 {
  149. isSlice = true
  150. tf = tf.Elem()
  151. }
  152. if tf.Kind() == reflect.Ptr {
  153. isPointer = true
  154. tf = tf.Elem()
  155. }
  156. if isPointer && isSlice && tf.Kind() != reflect.Struct {
  157. panic("both pointer and slice for basic type in " + tf.Name())
  158. }
  159. switch tf.Kind() {
  160. case reflect.Int32:
  161. switch {
  162. case isSlice: // E.g., []int32
  163. mfi.merge = func(dst, src pointer) {
  164. // NOTE: toInt32Slice is not defined (see pointer_reflect.go).
  165. /*
  166. sfsp := src.toInt32Slice()
  167. if *sfsp != nil {
  168. dfsp := dst.toInt32Slice()
  169. *dfsp = append(*dfsp, *sfsp...)
  170. if *dfsp == nil {
  171. *dfsp = []int64{}
  172. }
  173. }
  174. */
  175. sfs := src.getInt32Slice()
  176. if sfs != nil {
  177. dfs := dst.getInt32Slice()
  178. dfs = append(dfs, sfs...)
  179. if dfs == nil {
  180. dfs = []int32{}
  181. }
  182. dst.setInt32Slice(dfs)
  183. }
  184. }
  185. case isPointer: // E.g., *int32
  186. mfi.merge = func(dst, src pointer) {
  187. // NOTE: toInt32Ptr is not defined (see pointer_reflect.go).
  188. /*
  189. sfpp := src.toInt32Ptr()
  190. if *sfpp != nil {
  191. dfpp := dst.toInt32Ptr()
  192. if *dfpp == nil {
  193. *dfpp = Int32(**sfpp)
  194. } else {
  195. **dfpp = **sfpp
  196. }
  197. }
  198. */
  199. sfp := src.getInt32Ptr()
  200. if sfp != nil {
  201. dfp := dst.getInt32Ptr()
  202. if dfp == nil {
  203. dst.setInt32Ptr(*sfp)
  204. } else {
  205. *dfp = *sfp
  206. }
  207. }
  208. }
  209. default: // E.g., int32
  210. mfi.merge = func(dst, src pointer) {
  211. if v := *src.toInt32(); v != 0 {
  212. *dst.toInt32() = v
  213. }
  214. }
  215. }
  216. case reflect.Int64:
  217. switch {
  218. case isSlice: // E.g., []int64
  219. mfi.merge = func(dst, src pointer) {
  220. sfsp := src.toInt64Slice()
  221. if *sfsp != nil {
  222. dfsp := dst.toInt64Slice()
  223. *dfsp = append(*dfsp, *sfsp...)
  224. if *dfsp == nil {
  225. *dfsp = []int64{}
  226. }
  227. }
  228. }
  229. case isPointer: // E.g., *int64
  230. mfi.merge = func(dst, src pointer) {
  231. sfpp := src.toInt64Ptr()
  232. if *sfpp != nil {
  233. dfpp := dst.toInt64Ptr()
  234. if *dfpp == nil {
  235. *dfpp = Int64(**sfpp)
  236. } else {
  237. **dfpp = **sfpp
  238. }
  239. }
  240. }
  241. default: // E.g., int64
  242. mfi.merge = func(dst, src pointer) {
  243. if v := *src.toInt64(); v != 0 {
  244. *dst.toInt64() = v
  245. }
  246. }
  247. }
  248. case reflect.Uint32:
  249. switch {
  250. case isSlice: // E.g., []uint32
  251. mfi.merge = func(dst, src pointer) {
  252. sfsp := src.toUint32Slice()
  253. if *sfsp != nil {
  254. dfsp := dst.toUint32Slice()
  255. *dfsp = append(*dfsp, *sfsp...)
  256. if *dfsp == nil {
  257. *dfsp = []uint32{}
  258. }
  259. }
  260. }
  261. case isPointer: // E.g., *uint32
  262. mfi.merge = func(dst, src pointer) {
  263. sfpp := src.toUint32Ptr()
  264. if *sfpp != nil {
  265. dfpp := dst.toUint32Ptr()
  266. if *dfpp == nil {
  267. *dfpp = Uint32(**sfpp)
  268. } else {
  269. **dfpp = **sfpp
  270. }
  271. }
  272. }
  273. default: // E.g., uint32
  274. mfi.merge = func(dst, src pointer) {
  275. if v := *src.toUint32(); v != 0 {
  276. *dst.toUint32() = v
  277. }
  278. }
  279. }
  280. case reflect.Uint64:
  281. switch {
  282. case isSlice: // E.g., []uint64
  283. mfi.merge = func(dst, src pointer) {
  284. sfsp := src.toUint64Slice()
  285. if *sfsp != nil {
  286. dfsp := dst.toUint64Slice()
  287. *dfsp = append(*dfsp, *sfsp...)
  288. if *dfsp == nil {
  289. *dfsp = []uint64{}
  290. }
  291. }
  292. }
  293. case isPointer: // E.g., *uint64
  294. mfi.merge = func(dst, src pointer) {
  295. sfpp := src.toUint64Ptr()
  296. if *sfpp != nil {
  297. dfpp := dst.toUint64Ptr()
  298. if *dfpp == nil {
  299. *dfpp = Uint64(**sfpp)
  300. } else {
  301. **dfpp = **sfpp
  302. }
  303. }
  304. }
  305. default: // E.g., uint64
  306. mfi.merge = func(dst, src pointer) {
  307. if v := *src.toUint64(); v != 0 {
  308. *dst.toUint64() = v
  309. }
  310. }
  311. }
  312. case reflect.Float32:
  313. switch {
  314. case isSlice: // E.g., []float32
  315. mfi.merge = func(dst, src pointer) {
  316. sfsp := src.toFloat32Slice()
  317. if *sfsp != nil {
  318. dfsp := dst.toFloat32Slice()
  319. *dfsp = append(*dfsp, *sfsp...)
  320. if *dfsp == nil {
  321. *dfsp = []float32{}
  322. }
  323. }
  324. }
  325. case isPointer: // E.g., *float32
  326. mfi.merge = func(dst, src pointer) {
  327. sfpp := src.toFloat32Ptr()
  328. if *sfpp != nil {
  329. dfpp := dst.toFloat32Ptr()
  330. if *dfpp == nil {
  331. *dfpp = Float32(**sfpp)
  332. } else {
  333. **dfpp = **sfpp
  334. }
  335. }
  336. }
  337. default: // E.g., float32
  338. mfi.merge = func(dst, src pointer) {
  339. if v := *src.toFloat32(); v != 0 {
  340. *dst.toFloat32() = v
  341. }
  342. }
  343. }
  344. case reflect.Float64:
  345. switch {
  346. case isSlice: // E.g., []float64
  347. mfi.merge = func(dst, src pointer) {
  348. sfsp := src.toFloat64Slice()
  349. if *sfsp != nil {
  350. dfsp := dst.toFloat64Slice()
  351. *dfsp = append(*dfsp, *sfsp...)
  352. if *dfsp == nil {
  353. *dfsp = []float64{}
  354. }
  355. }
  356. }
  357. case isPointer: // E.g., *float64
  358. mfi.merge = func(dst, src pointer) {
  359. sfpp := src.toFloat64Ptr()
  360. if *sfpp != nil {
  361. dfpp := dst.toFloat64Ptr()
  362. if *dfpp == nil {
  363. *dfpp = Float64(**sfpp)
  364. } else {
  365. **dfpp = **sfpp
  366. }
  367. }
  368. }
  369. default: // E.g., float64
  370. mfi.merge = func(dst, src pointer) {
  371. if v := *src.toFloat64(); v != 0 {
  372. *dst.toFloat64() = v
  373. }
  374. }
  375. }
  376. case reflect.Bool:
  377. switch {
  378. case isSlice: // E.g., []bool
  379. mfi.merge = func(dst, src pointer) {
  380. sfsp := src.toBoolSlice()
  381. if *sfsp != nil {
  382. dfsp := dst.toBoolSlice()
  383. *dfsp = append(*dfsp, *sfsp...)
  384. if *dfsp == nil {
  385. *dfsp = []bool{}
  386. }
  387. }
  388. }
  389. case isPointer: // E.g., *bool
  390. mfi.merge = func(dst, src pointer) {
  391. sfpp := src.toBoolPtr()
  392. if *sfpp != nil {
  393. dfpp := dst.toBoolPtr()
  394. if *dfpp == nil {
  395. *dfpp = Bool(**sfpp)
  396. } else {
  397. **dfpp = **sfpp
  398. }
  399. }
  400. }
  401. default: // E.g., bool
  402. mfi.merge = func(dst, src pointer) {
  403. if v := *src.toBool(); v {
  404. *dst.toBool() = v
  405. }
  406. }
  407. }
  408. case reflect.String:
  409. switch {
  410. case isSlice: // E.g., []string
  411. mfi.merge = func(dst, src pointer) {
  412. sfsp := src.toStringSlice()
  413. if *sfsp != nil {
  414. dfsp := dst.toStringSlice()
  415. *dfsp = append(*dfsp, *sfsp...)
  416. if *dfsp == nil {
  417. *dfsp = []string{}
  418. }
  419. }
  420. }
  421. case isPointer: // E.g., *string
  422. mfi.merge = func(dst, src pointer) {
  423. sfpp := src.toStringPtr()
  424. if *sfpp != nil {
  425. dfpp := dst.toStringPtr()
  426. if *dfpp == nil {
  427. *dfpp = String(**sfpp)
  428. } else {
  429. **dfpp = **sfpp
  430. }
  431. }
  432. }
  433. default: // E.g., string
  434. mfi.merge = func(dst, src pointer) {
  435. if v := *src.toString(); v != "" {
  436. *dst.toString() = v
  437. }
  438. }
  439. }
  440. case reflect.Slice:
  441. isProto3 := props.Prop[i].Proto3
  442. switch {
  443. case isPointer:
  444. panic("bad pointer in byte slice case in " + tf.Name())
  445. case tf.Elem().Kind() != reflect.Uint8:
  446. panic("bad element kind in byte slice case in " + tf.Name())
  447. case isSlice: // E.g., [][]byte
  448. mfi.merge = func(dst, src pointer) {
  449. sbsp := src.toBytesSlice()
  450. if *sbsp != nil {
  451. dbsp := dst.toBytesSlice()
  452. for _, sb := range *sbsp {
  453. if sb == nil {
  454. *dbsp = append(*dbsp, nil)
  455. } else {
  456. *dbsp = append(*dbsp, append([]byte{}, sb...))
  457. }
  458. }
  459. if *dbsp == nil {
  460. *dbsp = [][]byte{}
  461. }
  462. }
  463. }
  464. default: // E.g., []byte
  465. mfi.merge = func(dst, src pointer) {
  466. sbp := src.toBytes()
  467. if *sbp != nil {
  468. dbp := dst.toBytes()
  469. if !isProto3 || len(*sbp) > 0 {
  470. *dbp = append([]byte{}, *sbp...)
  471. }
  472. }
  473. }
  474. }
  475. case reflect.Struct:
  476. switch {
  477. case !isPointer:
  478. panic(fmt.Sprintf("message field %s without pointer", tf))
  479. case isSlice: // E.g., []*pb.T
  480. mi := getMergeInfo(tf)
  481. mfi.merge = func(dst, src pointer) {
  482. sps := src.getPointerSlice()
  483. if sps != nil {
  484. dps := dst.getPointerSlice()
  485. for _, sp := range sps {
  486. var dp pointer
  487. if !sp.isNil() {
  488. dp = valToPointer(reflect.New(tf))
  489. mi.merge(dp, sp)
  490. }
  491. dps = append(dps, dp)
  492. }
  493. if dps == nil {
  494. dps = []pointer{}
  495. }
  496. dst.setPointerSlice(dps)
  497. }
  498. }
  499. default: // E.g., *pb.T
  500. mi := getMergeInfo(tf)
  501. mfi.merge = func(dst, src pointer) {
  502. sp := src.getPointer()
  503. if !sp.isNil() {
  504. dp := dst.getPointer()
  505. if dp.isNil() {
  506. dp = valToPointer(reflect.New(tf))
  507. dst.setPointer(dp)
  508. }
  509. mi.merge(dp, sp)
  510. }
  511. }
  512. }
  513. case reflect.Map:
  514. switch {
  515. case isPointer || isSlice:
  516. panic("bad pointer or slice in map case in " + tf.Name())
  517. default: // E.g., map[K]V
  518. mfi.merge = func(dst, src pointer) {
  519. sm := src.asPointerTo(tf).Elem()
  520. if sm.Len() == 0 {
  521. return
  522. }
  523. dm := dst.asPointerTo(tf).Elem()
  524. if dm.IsNil() {
  525. dm.Set(reflect.MakeMap(tf))
  526. }
  527. switch tf.Elem().Kind() {
  528. case reflect.Ptr: // Proto struct (e.g., *T)
  529. for _, key := range sm.MapKeys() {
  530. val := sm.MapIndex(key)
  531. val = reflect.ValueOf(Clone(val.Interface().(Message)))
  532. dm.SetMapIndex(key, val)
  533. }
  534. case reflect.Slice: // E.g. Bytes type (e.g., []byte)
  535. for _, key := range sm.MapKeys() {
  536. val := sm.MapIndex(key)
  537. val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
  538. dm.SetMapIndex(key, val)
  539. }
  540. default: // Basic type (e.g., string)
  541. for _, key := range sm.MapKeys() {
  542. val := sm.MapIndex(key)
  543. dm.SetMapIndex(key, val)
  544. }
  545. }
  546. }
  547. }
  548. case reflect.Interface:
  549. // Must be oneof field.
  550. switch {
  551. case isPointer || isSlice:
  552. panic("bad pointer or slice in interface case in " + tf.Name())
  553. default: // E.g., interface{}
  554. // TODO: Make this faster?
  555. mfi.merge = func(dst, src pointer) {
  556. su := src.asPointerTo(tf).Elem()
  557. if !su.IsNil() {
  558. du := dst.asPointerTo(tf).Elem()
  559. typ := su.Elem().Type()
  560. if du.IsNil() || du.Elem().Type() != typ {
  561. du.Set(reflect.New(typ.Elem())) // Initialize interface if empty
  562. }
  563. sv := su.Elem().Elem().Field(0)
  564. if sv.Kind() == reflect.Ptr && sv.IsNil() {
  565. return
  566. }
  567. dv := du.Elem().Elem().Field(0)
  568. if dv.Kind() == reflect.Ptr && dv.IsNil() {
  569. dv.Set(reflect.New(sv.Type().Elem())) // Initialize proto message if empty
  570. }
  571. switch sv.Type().Kind() {
  572. case reflect.Ptr: // Proto struct (e.g., *T)
  573. Merge(dv.Interface().(Message), sv.Interface().(Message))
  574. case reflect.Slice: // E.g. Bytes type (e.g., []byte)
  575. dv.Set(reflect.ValueOf(append([]byte{}, sv.Bytes()...)))
  576. default: // Basic type (e.g., string)
  577. dv.Set(sv)
  578. }
  579. }
  580. }
  581. }
  582. default:
  583. panic(fmt.Sprintf("merger not found for type:%s", tf))
  584. }
  585. mi.fields = append(mi.fields, mfi)
  586. }
  587. expFunc := exporterFunc(t)
  588. mi.unrecognized = invalidField
  589. if f, ok := t.FieldByName("XXX_unrecognized"); ok {
  590. if f.Type != reflect.TypeOf([]byte{}) {
  591. panic("expected XXX_unrecognized to be of type []byte")
  592. }
  593. mi.unrecognized = toField(&f, nil)
  594. }
  595. if f, ok := t.FieldByName("unknownFields"); ok {
  596. if f.Type != reflect.TypeOf([]byte{}) {
  597. panic("expected unknownFields to be of type []byte")
  598. }
  599. mi.unrecognized = toField(&f, expFunc)
  600. }
  601. atomic.StoreInt32(&mi.initialized, 1)
  602. }