Browse Source

icmp: add simple multipart message validation

This change adds simple validation for multipart messages to avoid
generating incorrect messages and introduces RawBody and RawExtension
to control message validation. RawBody and RawExtension are excluded
from normal message processing and can be used to construct crafted
messages for applications such as wire format testing.

Fixes golang/go#28686.

Change-Id: I56f51d6566142f5e1dcc75cfce5250801e583d6d
Reviewed-on: https://go-review.googlesource.com/c/net/+/155859
Run-TryBot: Mikio Hara <mikioh.public.networking@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Mikio Hara 7 years ago
parent
commit
56fb01167e
10 changed files with 385 additions and 66 deletions
  1. 19 1
      icmp/dstunreach.go
  2. 26 10
      icmp/echo.go
  3. 60 0
      icmp/extension.go
  4. 2 1
      icmp/message.go
  5. 201 13
      icmp/message_test.go
  6. 16 7
      icmp/messagebody.go
  7. 22 14
      icmp/multipart.go
  8. 3 11
      icmp/multipart_test.go
  9. 17 8
      icmp/paramprob.go
  10. 19 1
      icmp/timeexceeded.go

+ 19 - 1
icmp/dstunreach.go

@@ -4,6 +4,12 @@
 
 package icmp
 
+import (
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
+
 // A DstUnreach represents an ICMP destination unreachable message
 // body.
 type DstUnreach struct {
@@ -17,11 +23,23 @@ func (p *DstUnreach) Len(proto int) int {
 		return 0
 	}
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
-	return 4 + l
+	return l
 }
 
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *DstUnreach) Marshal(proto int) ([]byte, error) {
+	var typ Type
+	switch proto {
+	case iana.ProtocolICMP:
+		typ = ipv4.ICMPTypeDestinationUnreachable
+	case iana.ProtocolIPv6ICMP:
+		typ = ipv6.ICMPTypeDestinationUnreachable
+	default:
+		return nil, errInvalidProtocol
+	}
+	if !validExtensions(typ, p.Extensions) {
+		return nil, errInvalidExtension
+	}
 	return marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
 }
 

+ 26 - 10
icmp/echo.go

@@ -4,7 +4,13 @@
 
 package icmp
 
-import "encoding/binary"
+import (
+	"encoding/binary"
+
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
 
 // An Echo represents an ICMP echo request or reply message body.
 type Echo struct {
@@ -59,29 +65,39 @@ func (p *ExtendedEchoRequest) Len(proto int) int {
 		return 0
 	}
 	l, _ := multipartMessageBodyDataLen(proto, false, nil, p.Extensions)
-	return 4 + l
+	return l
 }
 
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *ExtendedEchoRequest) Marshal(proto int) ([]byte, error) {
+	var typ Type
+	switch proto {
+	case iana.ProtocolICMP:
+		typ = ipv4.ICMPTypeExtendedEchoRequest
+	case iana.ProtocolIPv6ICMP:
+		typ = ipv6.ICMPTypeExtendedEchoRequest
+	default:
+		return nil, errInvalidProtocol
+	}
+	if !validExtensions(typ, p.Extensions) {
+		return nil, errInvalidExtension
+	}
 	b, err := marshalMultipartMessageBody(proto, false, nil, p.Extensions)
 	if err != nil {
 		return nil, err
 	}
-	bb := make([]byte, 4)
-	binary.BigEndian.PutUint16(bb[:2], uint16(p.ID))
-	bb[2] = byte(p.Seq)
+	binary.BigEndian.PutUint16(b[:2], uint16(p.ID))
+	b[2] = byte(p.Seq)
 	if p.Local {
-		bb[3] |= 0x01
+		b[3] |= 0x01
 	}
-	bb = append(bb, b...)
-	return bb, nil
+	return b, nil
 }
 
 // parseExtendedEchoRequest parses b as an ICMP extended echo request
 // message body.
 func parseExtendedEchoRequest(proto int, typ Type, b []byte) (MessageBody, error) {
-	if len(b) < 4+4 {
+	if len(b) < 4 {
 		return nil, errMessageTooShort
 	}
 	p := &ExtendedEchoRequest{ID: int(binary.BigEndian.Uint16(b[:2])), Seq: int(b[2])}
@@ -89,7 +105,7 @@ func parseExtendedEchoRequest(proto int, typ Type, b []byte) (MessageBody, error
 		p.Local = true
 	}
 	var err error
-	_, p.Extensions, err = parseMultipartMessageBody(proto, typ, b[4:])
+	_, p.Extensions, err = parseMultipartMessageBody(proto, typ, b)
 	if err != nil {
 		return nil, err
 	}

+ 60 - 0
icmp/extension.go

@@ -103,8 +103,68 @@ func parseExtensions(typ Type, b []byte, l int) ([]Extension, int, error) {
 				return nil, -1, err
 			}
 			exts = append(exts, ext)
+		default:
+			ext := &RawExtension{Data: make([]byte, ol)}
+			copy(ext.Data, b[:ol])
+			exts = append(exts, ext)
 		}
 		b = b[ol:]
 	}
 	return exts, l, nil
 }
+
+func validExtensions(typ Type, exts []Extension) bool {
+	switch typ {
+	case ipv4.ICMPTypeDestinationUnreachable, ipv4.ICMPTypeTimeExceeded, ipv4.ICMPTypeParameterProblem,
+		ipv6.ICMPTypeDestinationUnreachable, ipv6.ICMPTypeTimeExceeded:
+		for i := range exts {
+			switch exts[i].(type) {
+			case *MPLSLabelStack, *InterfaceInfo, *RawExtension:
+			default:
+				return false
+			}
+		}
+		return true
+	case ipv4.ICMPTypeExtendedEchoRequest, ipv6.ICMPTypeExtendedEchoRequest:
+		var n int
+		for i := range exts {
+			switch exts[i].(type) {
+			case *InterfaceIdent:
+				n++
+			case *RawExtension:
+			default:
+				return false
+			}
+		}
+		// Not a single InterfaceIdent object or a combo of
+		// RawExtension and InterfaceIdent objects is not
+		// allowed.
+		if n == 1 && len(exts) > 1 {
+			return false
+		}
+		return true
+	default:
+		return false
+	}
+}
+
+// A RawExtension represents a raw extension.
+//
+// A raw extension is excluded from message processing and can be used
+// to construct applications such as protocol conformance testing.
+type RawExtension struct {
+	Data []byte // data
+}
+
+// Len implements the Len method of Extension interface.
+func (p *RawExtension) Len(proto int) int {
+	if p == nil {
+		return 0
+	}
+	return len(p.Data)
+}
+
+// Marshal implements the Marshal method of Extension interface.
+func (p *RawExtension) Marshal(proto int) ([]byte, error) {
+	return p.Data, nil
+}

+ 2 - 1
icmp/message.go

@@ -34,6 +34,7 @@ var (
 	errHeaderTooShort   = errors.New("header too short")
 	errBufferTooShort   = errors.New("buffer too short")
 	errOpNoSupport      = errors.New("operation not supported")
+	errInvalidBody      = errors.New("invalid body")
 	errNoExtension      = errors.New("no extension")
 	errInvalidExtension = errors.New("invalid extension")
 )
@@ -150,7 +151,7 @@ func ParseMessage(proto int, b []byte) (*Message, error) {
 		return nil, errInvalidProtocol
 	}
 	if fn, ok := parseFns[m.Type]; !ok {
-		m.Body, err = parseDefaultMessageBody(proto, b[4:])
+		m.Body, err = parseRawBody(proto, b[4:])
 	} else {
 		m.Body, err = fn(proto, m.Type, b[4:])
 	}

+ 201 - 13
icmp/message_test.go

@@ -5,6 +5,7 @@
 package icmp_test
 
 import (
+	"bytes"
 	"net"
 	"reflect"
 	"testing"
@@ -31,17 +32,19 @@ func TestMarshalAndParseMessage(t *testing.T) {
 			for _, psh := range pshs {
 				b, err := tm.Marshal(psh)
 				if err != nil {
-					t.Fatal(err)
+					t.Fatalf("#%d: %v", i, err)
 				}
 				m, err := icmp.ParseMessage(proto, b)
 				if err != nil {
-					t.Fatal(err)
+					t.Fatalf("#%d: %v", i, err)
 				}
 				if m.Type != tm.Type || m.Code != tm.Code {
 					t.Errorf("#%d: got %#v; want %#v", i, m, &tm)
+					continue
 				}
 				if !reflect.DeepEqual(m.Body, tm.Body) {
 					t.Errorf("#%d: got %#v; want %#v", i, m.Body, tm.Body)
+					continue
 				}
 			}
 		}
@@ -80,6 +83,13 @@ func TestMarshalAndParseMessage(t *testing.T) {
 					Type: ipv4.ICMPTypeExtendedEchoRequest, Code: 0,
 					Body: &icmp.ExtendedEchoRequest{
 						ID: 1, Seq: 2,
+						Extensions: []icmp.Extension{
+							&icmp.InterfaceIdent{
+								Class: 3,
+								Type:  1,
+								Name:  "en101",
+							},
+						},
 					},
 				},
 				{
@@ -88,12 +98,6 @@ func TestMarshalAndParseMessage(t *testing.T) {
 						State: 4 /* Delay */, Active: true, IPv4: true,
 					},
 				},
-				{
-					Type: ipv4.ICMPTypePhoturis,
-					Body: &icmp.DefaultMessageBody{
-						Data: []byte{0x80, 0x40, 0x20, 0x10},
-					},
-				},
 			})
 	})
 	t.Run("IPv6", func(t *testing.T) {
@@ -136,6 +140,13 @@ func TestMarshalAndParseMessage(t *testing.T) {
 					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 0,
 					Body: &icmp.ExtendedEchoRequest{
 						ID: 1, Seq: 2,
+						Extensions: []icmp.Extension{
+							&icmp.InterfaceIdent{
+								Class: 3,
+								Type:  2,
+								Index: 911,
+							},
+						},
 					},
 				},
 				{
@@ -144,12 +155,189 @@ func TestMarshalAndParseMessage(t *testing.T) {
 						State: 5 /* Probe */, Active: true, IPv6: true,
 					},
 				},
-				{
-					Type: ipv6.ICMPTypeDuplicateAddressConfirmation,
-					Body: &icmp.DefaultMessageBody{
-						Data: []byte{0x80, 0x40, 0x20, 0x10},
+			})
+	})
+}
+
+func TestMarshalAndParseRawMessage(t *testing.T) {
+	t.Run("RawBody", func(t *testing.T) {
+		for i, tt := range []struct {
+			m               icmp.Message
+			wire            []byte
+			parseShouldFail bool
+		}{
+			{ // Nil body
+				m: icmp.Message{
+					Type: ipv4.ICMPTypeDestinationUnreachable, Code: 127,
+				},
+				wire: []byte{
+					0x03, 0x7f, 0xfc, 0x80,
+				},
+				parseShouldFail: true,
+			},
+			{ // Empty body
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDestinationUnreachable, Code: 128,
+					Body: &icmp.RawBody{},
+				},
+				wire: []byte{
+					0x01, 0x80, 0x00, 0x00,
+				},
+				parseShouldFail: true,
+			},
+			{ // Crafted body
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDuplicateAddressConfirmation, Code: 129,
+					Body: &icmp.RawBody{
+						Data: []byte{0xca, 0xfe},
 					},
 				},
-			})
+				wire: []byte{
+					0x9e, 0x81, 0x00, 0x00,
+					0xca, 0xfe,
+				},
+				parseShouldFail: false,
+			},
+		} {
+			b, err := tt.m.Marshal(nil)
+			if err != nil {
+				t.Errorf("#%d: %v", i, err)
+				continue
+			}
+			if !bytes.Equal(b, tt.wire) {
+				t.Errorf("#%d: got %#v; want %#v", i, b, tt.wire)
+				continue
+			}
+			m, err := icmp.ParseMessage(tt.m.Type.Protocol(), b)
+			if err != nil != tt.parseShouldFail {
+				t.Errorf("#%d: got %v, %v", i, m, err)
+				continue
+			}
+			if tt.parseShouldFail {
+				continue
+			}
+			if m.Type != tt.m.Type || m.Code != tt.m.Code {
+				t.Errorf("#%d: got %v; want %v", i, m, tt.m)
+				continue
+			}
+			if !bytes.Equal(m.Body.(*icmp.RawBody).Data, tt.m.Body.(*icmp.RawBody).Data) {
+				t.Errorf("#%d: got %#v; want %#v", i, m.Body, tt.m.Body)
+				continue
+			}
+		}
+	})
+	t.Run("RawExtension", func(t *testing.T) {
+		for i, tt := range []struct {
+			m    icmp.Message
+			wire []byte
+		}{
+			{ // Unaligned data and nil extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDestinationUnreachable, Code: 130,
+					Body: &icmp.DstUnreach{
+						Data: []byte("ERROR-INVOKING-PACKET"),
+					},
+				},
+				wire: []byte{
+					0x01, 0x82, 0x00, 0x00,
+					0x00, 0x00, 0x00, 0x00,
+					'E', 'R', 'R', 'O',
+					'R', '-', 'I', 'N',
+					'V', 'O', 'K', 'I',
+					'N', 'G', '-', 'P',
+					'A', 'C', 'K', 'E',
+					'T',
+				},
+			},
+			{ // Unaligned data and empty extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDestinationUnreachable, Code: 131,
+					Body: &icmp.DstUnreach{
+						Data: []byte("ERROR-INVOKING-PACKET"),
+						Extensions: []icmp.Extension{
+							&icmp.RawExtension{},
+						},
+					},
+				},
+				wire: []byte{
+					0x01, 0x83, 0x00, 0x00,
+					0x02, 0x00, 0x00, 0x00,
+					'E', 'R', 'R', 'O',
+					'R', '-', 'I', 'N',
+					'V', 'O', 'K', 'I',
+					'N', 'G', '-', 'P',
+					'A', 'C', 'K', 'E',
+					'T',
+					0x20, 0x00, 0xdf, 0xff,
+				},
+			},
+			{ // Nil extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 132,
+					Body: &icmp.ExtendedEchoRequest{
+						ID: 1, Seq: 2, Local: true,
+					},
+				},
+				wire: []byte{
+					0xa0, 0x84, 0x00, 0x00,
+					0x00, 0x01, 0x02, 0x01,
+				},
+			},
+			{ // Empty extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 133,
+					Body: &icmp.ExtendedEchoRequest{
+						ID: 1, Seq: 2, Local: true,
+						Extensions: []icmp.Extension{
+							&icmp.RawExtension{},
+						},
+					},
+				},
+				wire: []byte{
+					0xa0, 0x85, 0x00, 0x00,
+					0x00, 0x01, 0x02, 0x01,
+					0x20, 0x00, 0xdf, 0xff,
+				},
+			},
+			{ // Crafted extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 134,
+					Body: &icmp.ExtendedEchoRequest{
+						ID: 1, Seq: 2, Local: true,
+						Extensions: []icmp.Extension{
+							&icmp.RawExtension{
+								Data: []byte("CRAFTED"),
+							},
+						},
+					},
+				},
+				wire: []byte{
+					0xa0, 0x86, 0x00, 0x00,
+					0x00, 0x01, 0x02, 0x01,
+					0x20, 0x00, 0xc3, 0x21,
+					'C', 'R', 'A', 'F',
+					'T', 'E', 'D',
+				},
+			},
+		} {
+			b, err := tt.m.Marshal(nil)
+			if err != nil {
+				t.Errorf("#%d: %v", i, err)
+				continue
+			}
+			if !bytes.Equal(b, tt.wire) {
+				t.Errorf("#%d: got %#v; want %#v", i, b, tt.wire)
+				continue
+			}
+			m, err := icmp.ParseMessage(tt.m.Type.Protocol(), b)
+			if err != nil {
+				t.Errorf("#%d: %v", i, err)
+				continue
+			}
+			if m.Type != tt.m.Type || m.Code != tt.m.Code {
+				t.Errorf("#%d: got %v; want %v", i, m, tt.m)
+				continue
+			}
+		}
 	})
 }

+ 16 - 7
icmp/messagebody.go

@@ -17,13 +17,17 @@ type MessageBody interface {
 	Marshal(proto int) ([]byte, error)
 }
 
-// A DefaultMessageBody represents the default message body.
-type DefaultMessageBody struct {
+// A RawBody represents a raw message body.
+//
+// A raw message body is excluded from message processing and can be
+// used to construct applications such as protocol conformance
+// testing.
+type RawBody struct {
 	Data []byte // data
 }
 
 // Len implements the Len method of MessageBody interface.
-func (p *DefaultMessageBody) Len(proto int) int {
+func (p *RawBody) Len(proto int) int {
 	if p == nil {
 		return 0
 	}
@@ -31,13 +35,18 @@ func (p *DefaultMessageBody) Len(proto int) int {
 }
 
 // Marshal implements the Marshal method of MessageBody interface.
-func (p *DefaultMessageBody) Marshal(proto int) ([]byte, error) {
+func (p *RawBody) Marshal(proto int) ([]byte, error) {
 	return p.Data, nil
 }
 
-// parseDefaultMessageBody parses b as an ICMP message body.
-func parseDefaultMessageBody(proto int, b []byte) (MessageBody, error) {
-	p := &DefaultMessageBody{Data: make([]byte, len(b))}
+// parseRawBody parses b as an ICMP message body.
+func parseRawBody(proto int, b []byte) (MessageBody, error) {
+	p := &RawBody{Data: make([]byte, len(b))}
 	copy(p.Data, b)
 	return p, nil
 }
+
+// A DefaultMessageBody represents the default message body.
+//
+// Deprecated: Use RawBody instead.
+type DefaultMessageBody = RawBody

+ 22 - 14
icmp/multipart.go

@@ -11,18 +11,24 @@ import "golang.org/x/net/internal/iana"
 // and a required length for a padded original datagram in wire
 // format.
 func multipartMessageBodyDataLen(proto int, withOrigDgram bool, b []byte, exts []Extension) (bodyLen, dataLen int) {
+	bodyLen = 4 // length of leading octets
+	var extLen int
+	var rawExt bool // raw extension may contain an empty object
 	for _, ext := range exts {
-		bodyLen += ext.Len(proto)
-	}
-	if bodyLen > 0 {
-		if withOrigDgram {
-			dataLen = multipartMessageOrigDatagramLen(proto, b)
+		extLen += ext.Len(proto)
+		if _, ok := ext.(*RawExtension); ok {
+			rawExt = true
 		}
-		bodyLen += 4 // length of extension header
+	}
+	if extLen > 0 && withOrigDgram {
+		dataLen = multipartMessageOrigDatagramLen(proto, b)
 	} else {
 		dataLen = len(b)
 	}
-	bodyLen += dataLen
+	if extLen > 0 || rawExt {
+		bodyLen += 4 // length of extension header
+	}
+	bodyLen += dataLen + extLen
 	return bodyLen, dataLen
 }
 
@@ -54,12 +60,11 @@ func multipartMessageOrigDatagramLen(proto int, b []byte) int {
 // It can be used for non-multipart message bodies when exts is nil.
 func marshalMultipartMessageBody(proto int, withOrigDgram bool, data []byte, exts []Extension) ([]byte, error) {
 	bodyLen, dataLen := multipartMessageBodyDataLen(proto, withOrigDgram, data, exts)
-	b := make([]byte, 4+bodyLen)
+	b := make([]byte, bodyLen)
 	copy(b[4:], data)
-	off := dataLen + 4
 	if len(exts) > 0 {
-		b[dataLen+4] = byte(extensionVersion << 4)
-		off += 4 // length of object header
+		b[4+dataLen] = byte(extensionVersion << 4)
+		off := 4 + dataLen + 4 // leading octets, data, extension header
 		for _, ext := range exts {
 			switch ext := ext.(type) {
 			case *MPLSLabelStack:
@@ -78,11 +83,14 @@ func marshalMultipartMessageBody(proto int, withOrigDgram bool, data []byte, ext
 					return nil, err
 				}
 				off += ext.Len(proto)
+			case *RawExtension:
+				copy(b[off:], ext.Data)
+				off += ext.Len(proto)
 			}
 		}
-		s := checksum(b[dataLen+4:])
-		b[dataLen+4+2] ^= byte(s)
-		b[dataLen+4+3] ^= byte(s >> 8)
+		s := checksum(b[4+dataLen:])
+		b[4+dataLen+2] ^= byte(s)
+		b[4+dataLen+3] ^= byte(s >> 8)
 		if withOrigDgram {
 			switch proto {
 			case iana.ProtocolICMP:

+ 3 - 11
icmp/multipart_test.go

@@ -232,11 +232,6 @@ func TestMarshalAndParseMultipartMessage(t *testing.T) {
 							Type:  2,
 							Index: 911,
 						},
-						&icmp.InterfaceIdent{
-							Class: 3,
-							Type:  1,
-							Name:  "en101",
-						},
 					},
 				},
 			},
@@ -359,11 +354,6 @@ func TestMarshalAndParseMultipartMessage(t *testing.T) {
 				Body: &icmp.ExtendedEchoRequest{
 					ID: 1, Seq: 2, Local: true,
 					Extensions: []icmp.Extension{
-						&icmp.InterfaceIdent{
-							Class: 3,
-							Type:  1,
-							Name:  "en101",
-						},
 						&icmp.InterfaceIdent{
 							Class: 3,
 							Type:  2,
@@ -413,10 +403,12 @@ func dumpExtensions(gotExts, wantExts []icmp.Extension) string {
 			if !reflect.DeepEqual(got, want) {
 				s += fmt.Sprintf("#%d: got %#v; want %#v\n", i, got, want)
 			}
+		case *icmp.RawExtension:
+			s += fmt.Sprintf("#%d: raw extension\n", i)
 		}
 	}
 	if len(s) == 0 {
-		return "<nil>"
+		s += "empty extension"
 	}
 	return s[:len(s)-1]
 }

+ 17 - 8
icmp/paramprob.go

@@ -6,7 +6,9 @@ package icmp
 
 import (
 	"encoding/binary"
+
 	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
 )
 
 // A ParamProb represents an ICMP parameter problem message body.
@@ -22,23 +24,30 @@ func (p *ParamProb) Len(proto int) int {
 		return 0
 	}
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
-	return 4 + l
+	return l
 }
 
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *ParamProb) Marshal(proto int) ([]byte, error) {
-	if proto == iana.ProtocolIPv6ICMP {
+	switch proto {
+	case iana.ProtocolICMP:
+		if !validExtensions(ipv4.ICMPTypeParameterProblem, p.Extensions) {
+			return nil, errInvalidExtension
+		}
+		b, err := marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
+		if err != nil {
+			return nil, err
+		}
+		b[0] = byte(p.Pointer)
+		return b, nil
+	case iana.ProtocolIPv6ICMP:
 		b := make([]byte, p.Len(proto))
 		binary.BigEndian.PutUint32(b[:4], uint32(p.Pointer))
 		copy(b[4:], p.Data)
 		return b, nil
+	default:
+		return nil, errInvalidProtocol
 	}
-	b, err := marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
-	if err != nil {
-		return nil, err
-	}
-	b[0] = byte(p.Pointer)
-	return b, nil
 }
 
 // parseParamProb parses b as an ICMP parameter problem message body.

+ 19 - 1
icmp/timeexceeded.go

@@ -4,6 +4,12 @@
 
 package icmp
 
+import (
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
+
 // A TimeExceeded represents an ICMP time exceeded message body.
 type TimeExceeded struct {
 	Data       []byte      // data, known as original datagram field
@@ -16,11 +22,23 @@ func (p *TimeExceeded) Len(proto int) int {
 		return 0
 	}
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
-	return 4 + l
+	return l
 }
 
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *TimeExceeded) Marshal(proto int) ([]byte, error) {
+	var typ Type
+	switch proto {
+	case iana.ProtocolICMP:
+		typ = ipv4.ICMPTypeTimeExceeded
+	case iana.ProtocolIPv6ICMP:
+		typ = ipv6.ICMPTypeTimeExceeded
+	default:
+		return nil, errInvalidProtocol
+	}
+	if !validExtensions(typ, p.Extensions) {
+		return nil, errInvalidExtension
+	}
 	return marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
 }