Explorar o código

codec: added mapSet and mapDelete helpers (safe/unsafe variants)

Also, renamed mapIndex to mapGet.
Ugorji Nwoke %!s(int64=6) %!d(string=hai) anos
pai
achega
19cb75f56d
Modificáronse 6 ficheiros con 89 adicións e 43 borrados
  1. 2 2
      codec/codec_test.go
  2. 13 13
      codec/decode.go
  3. 23 8
      codec/encode.go
  4. 1 0
      codec/fast-path.go.tmpl
  5. 11 7
      codec/helper_not_unsafe.go
  6. 39 13
      codec/helper_unsafe.go

+ 2 - 2
codec/codec_test.go

@@ -2856,14 +2856,14 @@ func TestMapRangeIndex(t *testing.T) {
 	}
 	testDeepEqualErr(len(m2c), 0, t, "all-keys-not-consumed")
 
-	// ---- test mapIndex
+	// ---- test mapGet
 
 	fnTestMapIndex := func(mi ...interface{}) {
 		for _, m0 := range mi {
 			m := reflect.ValueOf(m0)
 			rvv := mapAddressableRV(m.Type().Elem())
 			for _, k := range m.MapKeys() {
-				testDeepEqualErr(m.MapIndex(k).Interface(), mapIndex(m, k, rvv).Interface(), t, "map-index-eq")
+				testDeepEqualErr(m.MapIndex(k).Interface(), mapGet(m, k, rvv).Interface(), t, "map-index-eq")
 			}
 		}
 	}

+ 13 - 13
codec/decode.go

@@ -832,16 +832,16 @@ func (d *Decoder) kMap(f *codecFnInfo, rv reflect.Value) {
 
 	rvvMut := !isImmutableKind(vtypeKind)
 
-	// we do a mapGet if kind is mutable, and InterfaceReset=true if interface
-	var mapGet, mapSet bool
+	// we do a doMapGet if kind is mutable, and InterfaceReset=true if interface
+	var doMapGet, doMapSet bool
 	if !d.h.MapValueReset {
 		if rvvMut {
 			if vtypeKind == reflect.Interface {
 				if !d.h.InterfaceReset {
-					mapGet = true
+					doMapGet = true
 				}
 			} else {
-				mapGet = true
+				doMapGet = true
 			}
 		}
 	}
@@ -898,7 +898,7 @@ func (d *Decoder) kMap(f *codecFnInfo, rv reflect.Value) {
 		// i.e. TryDecodeAsNil never shares slices with other decDriver procedures
 		if dd.TryDecodeAsNil() {
 			if d.h.DeleteOnNilMapValue {
-				rv.SetMapIndex(rvk, reflect.Value{})
+				mapDelete(rv, rvk)
 			} else {
 				if ktypeIsString { // set to a real string (not string view)
 					rvk.SetString(d.string(kstrbs))
@@ -906,21 +906,21 @@ func (d *Decoder) kMap(f *codecFnInfo, rv reflect.Value) {
 				if !rvvz.IsValid() {
 					rvvz = reflect.Zero(vtype)
 				}
-				rv.SetMapIndex(rvk, rvvz)
+				mapSet(rv, rvk, rvvz)
 			}
 			continue
 		}
 
-		mapSet = true // set to false if u do a get, and its a non-nil pointer
-		if mapGet {
+		doMapSet = true // set to false if u do a get, and its a non-nil pointer
+		if doMapGet {
 			if !rvvaSet {
 				rvva = mapAddressableRV(vtype)
 				rvvaSet = true
 			}
-			rvv = mapIndex(rv, rvk, rvva) // reflect.Value{})
+			rvv = mapGet(rv, rvk, rvva) // reflect.Value{})
 			if vtypeKind == reflect.Ptr {
 				if rvv.IsValid() && !rvisnil(rvv) {
-					mapSet = false
+					doMapSet = false
 				} else {
 					rvv = reflect.New(vtype.Elem())
 				}
@@ -945,12 +945,12 @@ func (d *Decoder) kMap(f *codecFnInfo, rv reflect.Value) {
 
 		// We MUST be done with the stringview of the key, BEFORE decoding the value (rvv)
 		// so that we don't unknowingly reuse the rvk backing buffer during rvv decode.
-		if mapSet && ktypeIsString { // set to a real string (not string view)
+		if doMapSet && ktypeIsString { // set to a real string (not string view)
 			rvk.SetString(d.string(kstrbs))
 		}
 		d.decodeValue(rvv, valFn)
-		if mapSet {
-			rv.SetMapIndex(rvk, rvv)
+		if doMapSet {
+			mapSet(rv, rvk, rvv)
 		}
 		// if ktypeIsString {
 		// 	// keepAlive4StringView(kstrbs) // not needed, as reference is outside loop

+ 23 - 8
codec/encode.go

@@ -619,7 +619,7 @@ func (e *Encoder) kMapCanonical(rtkey, rtval reflect.Type, rv reflect.Value, val
 			e.mapElemKey()
 			e.e.EncodeBool(mksv[i].v)
 			e.mapElemValue()
-			e.encodeValue(mapIndex(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
+			e.encodeValue(mapGet(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
 		}
 	case reflect.String:
 		mksv := make([]stringRv, len(mks))
@@ -637,7 +637,7 @@ func (e *Encoder) kMapCanonical(rtkey, rtval reflect.Type, rv reflect.Value, val
 				e.e.EncodeStringEnc(cUTF8, mksv[i].v)
 			}
 			e.mapElemValue()
-			e.encodeValue(mapIndex(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
+			e.encodeValue(mapGet(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
 		}
 	case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, reflect.Uintptr:
 		mksv := make([]uint64Rv, len(mks))
@@ -651,7 +651,7 @@ func (e *Encoder) kMapCanonical(rtkey, rtval reflect.Type, rv reflect.Value, val
 			e.mapElemKey()
 			e.e.EncodeUint(mksv[i].v)
 			e.mapElemValue()
-			e.encodeValue(mapIndex(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
+			e.encodeValue(mapGet(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
 		}
 	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
 		mksv := make([]int64Rv, len(mks))
@@ -665,7 +665,7 @@ func (e *Encoder) kMapCanonical(rtkey, rtval reflect.Type, rv reflect.Value, val
 			e.mapElemKey()
 			e.e.EncodeInt(mksv[i].v)
 			e.mapElemValue()
-			e.encodeValue(mapIndex(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
+			e.encodeValue(mapGet(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
 		}
 	case reflect.Float32:
 		mksv := make([]float64Rv, len(mks))
@@ -679,7 +679,7 @@ func (e *Encoder) kMapCanonical(rtkey, rtval reflect.Type, rv reflect.Value, val
 			e.mapElemKey()
 			e.e.EncodeFloat32(float32(mksv[i].v))
 			e.mapElemValue()
-			e.encodeValue(mapIndex(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
+			e.encodeValue(mapGet(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
 		}
 	case reflect.Float64:
 		mksv := make([]float64Rv, len(mks))
@@ -693,7 +693,7 @@ func (e *Encoder) kMapCanonical(rtkey, rtval reflect.Type, rv reflect.Value, val
 			e.mapElemKey()
 			e.e.EncodeFloat64(mksv[i].v)
 			e.mapElemValue()
-			e.encodeValue(mapIndex(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
+			e.encodeValue(mapGet(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
 		}
 	case reflect.Struct:
 		if rv.Type() == timeTyp {
@@ -708,7 +708,7 @@ func (e *Encoder) kMapCanonical(rtkey, rtval reflect.Type, rv reflect.Value, val
 				e.mapElemKey()
 				e.e.EncodeTime(mksv[i].v)
 				e.mapElemValue()
-				e.encodeValue(mapIndex(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
+				e.encodeValue(mapGet(rv, mksv[i].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksv[i].r), valFn)
 			}
 			break
 		}
@@ -732,7 +732,7 @@ func (e *Encoder) kMapCanonical(rtkey, rtval reflect.Type, rv reflect.Value, val
 			e.mapElemKey()
 			e.asis(mksbv[j].v)
 			e.mapElemValue()
-			e.encodeValue(mapIndex(rv, mksbv[j].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksbv[j].r), valFn)
+			e.encodeValue(mapGet(rv, mksbv[j].r, rvv), valFn) // e.encodeValue(rv.MapIndex(mksbv[j].r), valFn)
 		}
 		bufp.end()
 	}
@@ -1339,6 +1339,21 @@ func (e *Encoder) mapElemValue() {
 	e.c = containerMapValue
 }
 
+// // Note: This is harder to inline, as there are 2 function calls inside.
+// func (e *Encoder) mapElemKeyOrValue(j uint8) {
+// 	if j == 0 {
+// 		if e.js {
+// 			e.jenc.WriteMapElemKey()
+// 		}
+// 		e.c = containerMapKey
+// 	} else {
+// 		if e.js {
+// 			e.jenc.WriteMapElemValue()
+// 		}
+// 		e.c = containerMapValue
+// 	}
+// }
+
 func (e *Encoder) mapEnd() {
 	e.e.WriteMapEnd()
 	e.c = containerMapEnd

+ 1 - 0
codec/fast-path.go.tmpl

@@ -161,6 +161,7 @@ func (fastpathT) {{ .MethodNamePfx "EncAsMap" false }}V(v []{{ .Elem }}, e *Enco
 	} else {
 		e.mapStart(len(v) / 2)
 		for j := range v {
+			{{/* e.mapElemKeyOrValue(uint8(j)%2) */ -}}
 			if j%2 == 0 {
 				e.mapElemKey()
 			} else {

+ 11 - 7
codec/helper_not_unsafe.go

@@ -323,17 +323,21 @@ func (e *Encoder) kUintptr(f *codecFnInfo, rv reflect.Value) {
 
 // ------------ map range and map indexing ----------
 
-func mapIndex(m, k, v reflect.Value) (vv reflect.Value) {
+func mapGet(m, k, v reflect.Value) (vv reflect.Value) {
 	return m.MapIndex(k)
-	// if vv.IsValid() && v.CanSet() {
-	// 	v.Set(vv)
-	// }
-	// return
 }
 
-// return an addressable reflect value that can be used in mapRange and mapIndex operations.
+func mapSet(m, k, v reflect.Value) {
+	m.SetMapIndex(k, v)
+}
+
+func mapDelete(m, k reflect.Value) {
+	m.SetMapIndex(k, reflect.Value{})
+}
+
+// return an addressable reflect value that can be used in mapRange and mapGet operations.
 //
-// all calls to mapIndex or mapRange will call here to get an addressable reflect.Value.
+// all calls to mapGet or mapRange will call here to get an addressable reflect.Value.
 func mapAddressableRV(t reflect.Type) (r reflect.Value) {
 	return // reflect.New(t).Elem()
 }

+ 39 - 13
codec/helper_unsafe.go

@@ -568,9 +568,9 @@ func (t *unsafeMapIter) Next() (r bool) {
 	if t.done {
 		return
 	}
-	unsafeMapSet(t.kptr, t.ktyp, t.it.key, t.kisref)
+	unsafeSet(t.kptr, t.ktyp, t.it.key, t.kisref)
 	if t.mapvalues {
-		unsafeMapSet(t.vptr, t.vtyp, t.it.value, t.visref)
+		unsafeSet(t.vptr, t.vtyp, t.it.value, t.visref)
 	}
 	return true
 }
@@ -586,7 +586,7 @@ func (t *unsafeMapIter) Value() (r reflect.Value) {
 	return
 }
 
-func unsafeMapSet(p, ptyp, p2 unsafe.Pointer, isref bool) {
+func unsafeSet(p, ptyp, p2 unsafe.Pointer, isref bool) {
 	if isref {
 		*(*unsafe.Pointer)(p) = *(*unsafe.Pointer)(p2) // p2
 	} else {
@@ -624,14 +624,16 @@ func mapRange(m, k, v reflect.Value, mapvalues bool) *unsafeMapIter {
 	return t
 }
 
-func mapIndex(m, k, v reflect.Value) (vv reflect.Value) {
-	var urv = (*unsafeReflectValue)(unsafe.Pointer(&k))
-	var kptr unsafe.Pointer
+func unsafeMapKVPtr(urv *unsafeReflectValue) unsafe.Pointer {
 	if urv.flag&unsafeFlagIndir == 0 {
-		kptr = unsafe.Pointer(&urv.ptr)
-	} else {
-		kptr = urv.ptr
+		return unsafe.Pointer(&urv.ptr)
 	}
+	return urv.ptr
+}
+
+func mapGet(m, k, v reflect.Value) (vv reflect.Value) {
+	var urv = (*unsafeReflectValue)(unsafe.Pointer(&k))
+	var kptr = unsafeMapKVPtr(urv)
 
 	urv = (*unsafeReflectValue)(unsafe.Pointer(&m))
 
@@ -643,13 +645,29 @@ func mapIndex(m, k, v reflect.Value) (vv reflect.Value) {
 
 	urv = (*unsafeReflectValue)(unsafe.Pointer(&v))
 
-	unsafeMapSet(urv.ptr, urv.typ, vvptr, refBitset.isset(byte(v.Kind())))
+	unsafeSet(urv.ptr, urv.typ, vvptr, refBitset.isset(byte(v.Kind())))
 	return v
 }
 
-// return an addressable reflect value that can be used in mapRange and mapIndex operations.
+func mapSet(m, k, v reflect.Value) {
+	var urv = (*unsafeReflectValue)(unsafe.Pointer(&k))
+	var kptr = unsafeMapKVPtr(urv)
+	urv = (*unsafeReflectValue)(unsafe.Pointer(&v))
+	var vptr = unsafeMapKVPtr(urv)
+	urv = (*unsafeReflectValue)(unsafe.Pointer(&m))
+	mapassign(urv.typ, rv2ptr(urv), kptr, vptr)
+}
+
+func mapDelete(m, k reflect.Value) {
+	var urv = (*unsafeReflectValue)(unsafe.Pointer(&k))
+	var kptr = unsafeMapKVPtr(urv)
+	urv = (*unsafeReflectValue)(unsafe.Pointer(&m))
+	mapdelete(urv.typ, rv2ptr(urv), kptr)
+}
+
+// return an addressable reflect value that can be used in mapRange and mapGet operations.
 //
-// all calls to mapIndex or mapRange will call here to get an addressable reflect.Value.
+// all calls to mapGet or mapRange will call here to get an addressable reflect.Value.
 func mapAddressableRV(t reflect.Type) (r reflect.Value) {
 	return reflect.New(t).Elem()
 }
@@ -664,7 +682,15 @@ func mapiternext(it unsafe.Pointer) (key unsafe.Pointer)
 
 //go:linkname mapaccess reflect.mapaccess
 //go:noescape
-func mapaccess(rtype unsafe.Pointer, m unsafe.Pointer, key unsafe.Pointer) (val unsafe.Pointer)
+func mapaccess(typ unsafe.Pointer, m unsafe.Pointer, key unsafe.Pointer) (val unsafe.Pointer)
+
+//go:linkname mapassign reflect.mapassign
+//go:noescape
+func mapassign(typ unsafe.Pointer, m unsafe.Pointer, key, val unsafe.Pointer)
+
+//go:linkname mapdelete reflect.mapdelete
+//go:noescape
+func mapdelete(typ unsafe.Pointer, m unsafe.Pointer, key unsafe.Pointer)
 
 //go:linkname typedmemmove reflect.typedmemmove
 //go:noescape