Explorar el Código

fix #293 copy extensions

Tao Wen hace 7 años
padre
commit
10a568c511
Se han modificado 5 ficheros con 50 adiciones y 25 borrados
  1. 4 4
      adapter.go
  2. 13 10
      config.go
  3. 4 2
      reflect.go
  4. 19 7
      reflect_extension.go
  5. 10 2
      reflect_map.go

+ 4 - 4
adapter.go

@@ -98,7 +98,7 @@ func (adapter *Decoder) Buffered() io.Reader {
 func (adapter *Decoder) UseNumber() {
 	cfg := adapter.iter.cfg.configBeforeFrozen
 	cfg.UseNumber = true
-	adapter.iter.cfg = cfg.frozeWithCacheReuse()
+	adapter.iter.cfg = cfg.frozeWithCacheReuse(adapter.iter.cfg.extraExtensions)
 }
 
 // DisallowUnknownFields causes the Decoder to return an error when the destination
@@ -107,7 +107,7 @@ func (adapter *Decoder) UseNumber() {
 func (adapter *Decoder) DisallowUnknownFields() {
 	cfg := adapter.iter.cfg.configBeforeFrozen
 	cfg.DisallowUnknownFields = true
-	adapter.iter.cfg = cfg.frozeWithCacheReuse()
+	adapter.iter.cfg = cfg.frozeWithCacheReuse(adapter.iter.cfg.extraExtensions)
 }
 
 // NewEncoder same as json.NewEncoder
@@ -132,14 +132,14 @@ func (adapter *Encoder) Encode(val interface{}) error {
 func (adapter *Encoder) SetIndent(prefix, indent string) {
 	config := adapter.stream.cfg.configBeforeFrozen
 	config.IndentionStep = len(indent)
-	adapter.stream.cfg = config.frozeWithCacheReuse()
+	adapter.stream.cfg = config.frozeWithCacheReuse(adapter.stream.cfg.extraExtensions)
 }
 
 // SetEscapeHTML escape html by default, set to false to disable
 func (adapter *Encoder) SetEscapeHTML(escapeHTML bool) {
 	config := adapter.stream.cfg.configBeforeFrozen
 	config.EscapeHTML = escapeHTML
-	adapter.stream.cfg = config.frozeWithCacheReuse()
+	adapter.stream.cfg = config.frozeWithCacheReuse(adapter.stream.cfg.extraExtensions)
 }
 
 // Valid reports whether data is a valid JSON encoding.

+ 13 - 10
config.go

@@ -74,7 +74,9 @@ type frozenConfig struct {
 	disallowUnknownFields         bool
 	decoderCache                  *concurrent.Map
 	encoderCache                  *concurrent.Map
-	extensions                    []Extension
+	encoderExtension              Extension
+	decoderExtension              Extension
+	extraExtensions               []Extension
 	streamPool                    *sync.Pool
 	iteratorPool                  *sync.Pool
 	caseSensitive                 bool
@@ -158,22 +160,21 @@ func (cfg Config) Froze() API {
 	if cfg.ValidateJsonRawMessage {
 		api.validateJsonRawMessage(encoderExtension)
 	}
-	if len(encoderExtension) > 0 {
-		api.extensions = append(api.extensions, encoderExtension)
-	}
-	if len(decoderExtension) > 0 {
-		api.extensions = append(api.extensions, decoderExtension)
-	}
+	api.encoderExtension = encoderExtension
+	api.decoderExtension = decoderExtension
 	api.configBeforeFrozen = cfg
 	return api
 }
 
-func (cfg Config) frozeWithCacheReuse() *frozenConfig {
+func (cfg Config) frozeWithCacheReuse(extraExtensions []Extension) *frozenConfig {
 	api := getFrozenConfigFromCache(cfg)
 	if api != nil {
 		return api
 	}
 	api = cfg.Froze().(*frozenConfig)
+	for _, extension := range extraExtensions {
+		api.RegisterExtension(extension)
+	}
 	addFrozenConfigToCache(cfg, api)
 	return api
 }
@@ -219,7 +220,9 @@ func (cfg *frozenConfig) getTagKey() string {
 }
 
 func (cfg *frozenConfig) RegisterExtension(extension Extension) {
-	cfg.extensions = append(cfg.extensions, extension)
+	cfg.extraExtensions = append(cfg.extraExtensions, extension)
+	copied := cfg.configBeforeFrozen
+	cfg.configBeforeFrozen = copied
 }
 
 type lossyFloat32Encoder struct {
@@ -314,7 +317,7 @@ func (cfg *frozenConfig) MarshalIndent(v interface{}, prefix, indent string) ([]
 	}
 	newCfg := cfg.configBeforeFrozen
 	newCfg.IndentionStep = len(indent)
-	return newCfg.frozeWithCacheReuse().Marshal(v)
+	return newCfg.frozeWithCacheReuse(cfg.extraExtensions).Marshal(v)
 }
 
 func (cfg *frozenConfig) UnmarshalFromString(str string, v interface{}) error {

+ 4 - 2
reflect.go

@@ -120,7 +120,8 @@ func decoderOfType(ctx *ctx, typ reflect2.Type) ValDecoder {
 	for _, extension := range extensions {
 		decoder = extension.DecorateDecoder(typ, decoder)
 	}
-	for _, extension := range ctx.extensions {
+	decoder = ctx.decoderExtension.DecorateDecoder(typ, decoder)
+	for _, extension := range ctx.extraExtensions {
 		decoder = extension.DecorateDecoder(typ, decoder)
 	}
 	return decoder
@@ -222,7 +223,8 @@ func encoderOfType(ctx *ctx, typ reflect2.Type) ValEncoder {
 	for _, extension := range extensions {
 		encoder = extension.DecorateEncoder(typ, encoder)
 	}
-	for _, extension := range ctx.extensions {
+	encoder = ctx.encoderExtension.DecorateEncoder(typ, encoder)
+	for _, extension := range ctx.extraExtensions {
 		encoder = extension.DecorateEncoder(typ, encoder)
 	}
 	return encoder

+ 19 - 7
reflect_extension.go

@@ -246,7 +246,8 @@ func getTypeDecoderFromExtension(ctx *ctx, typ reflect2.Type) ValDecoder {
 		for _, extension := range extensions {
 			decoder = extension.DecorateDecoder(typ, decoder)
 		}
-		for _, extension := range ctx.extensions {
+		decoder = ctx.decoderExtension.DecorateDecoder(typ, decoder)
+		for _, extension := range ctx.extraExtensions {
 			decoder = extension.DecorateDecoder(typ, decoder)
 		}
 	}
@@ -259,14 +260,18 @@ func _getTypeDecoderFromExtension(ctx *ctx, typ reflect2.Type) ValDecoder {
 			return decoder
 		}
 	}
-	for _, extension := range ctx.extensions {
+	decoder := ctx.decoderExtension.CreateDecoder(typ)
+	if decoder != nil {
+		return decoder
+	}
+	for _, extension := range ctx.extraExtensions {
 		decoder := extension.CreateDecoder(typ)
 		if decoder != nil {
 			return decoder
 		}
 	}
 	typeName := typ.String()
-	decoder := typeDecoders[typeName]
+	decoder = typeDecoders[typeName]
 	if decoder != nil {
 		return decoder
 	}
@@ -286,7 +291,8 @@ func getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
 		for _, extension := range extensions {
 			encoder = extension.DecorateEncoder(typ, encoder)
 		}
-		for _, extension := range ctx.extensions {
+		encoder = ctx.encoderExtension.DecorateEncoder(typ, encoder)
+		for _, extension := range ctx.extraExtensions {
 			encoder = extension.DecorateEncoder(typ, encoder)
 		}
 	}
@@ -300,14 +306,18 @@ func _getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
 			return encoder
 		}
 	}
-	for _, extension := range ctx.extensions {
+	encoder := ctx.encoderExtension.CreateEncoder(typ)
+	if encoder != nil {
+		return encoder
+	}
+	for _, extension := range ctx.extraExtensions {
 		encoder := extension.CreateEncoder(typ)
 		if encoder != nil {
 			return encoder
 		}
 	}
 	typeName := typ.String()
-	encoder := typeEncoders[typeName]
+	encoder = typeEncoders[typeName]
 	if encoder != nil {
 		return encoder
 	}
@@ -393,7 +403,9 @@ func createStructDescriptor(ctx *ctx, typ reflect2.Type, bindings []*Binding, em
 	for _, extension := range extensions {
 		extension.UpdateStructDescriptor(structDescriptor)
 	}
-	for _, extension := range ctx.extensions {
+	ctx.encoderExtension.UpdateStructDescriptor(structDescriptor)
+	ctx.decoderExtension.UpdateStructDescriptor(structDescriptor)
+	for _, extension := range ctx.extraExtensions {
 		extension.UpdateStructDescriptor(structDescriptor)
 	}
 	processTags(structDescriptor, ctx.frozenConfig)

+ 10 - 2
reflect_map.go

@@ -39,7 +39,11 @@ func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder {
 }
 
 func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
-	for _, extension := range ctx.extensions {
+	decoder := ctx.decoderExtension.CreateMapKeyDecoder(typ)
+	if decoder != nil {
+		return decoder
+	}
+	for _, extension := range ctx.extraExtensions {
 		decoder := extension.CreateMapKeyDecoder(typ)
 		if decoder != nil {
 			return decoder
@@ -77,7 +81,11 @@ func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
 }
 
 func encoderOfMapKey(ctx *ctx, typ reflect2.Type) ValEncoder {
-	for _, extension := range ctx.extensions {
+	encoder := ctx.encoderExtension.CreateMapKeyEncoder(typ)
+	if encoder != nil {
+		return encoder
+	}
+	for _, extension := range ctx.extraExtensions {
 		encoder := extension.CreateMapKeyEncoder(typ)
 		if encoder != nil {
 			return encoder