Browse Source

icmp: add simple multipart message validation

This change adds simple validation for multipart messages to avoid
generating incorrect messages and introduces RawBody and RawExtension
to control message validation. RawBody and RawExtension are excluded
from normal message processing and can be used to construct crafted
messages for applications such as wire format testing.

Fixes golang/go#28686.

Change-Id: I56f51d6566142f5e1dcc75cfce5250801e583d6d
Reviewed-on: https://go-review.googlesource.com/c/net/+/155859
Run-TryBot: Mikio Hara <mikioh.public.networking@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Mikio Hara 7 years ago
parent
commit
56fb01167e
10 changed files with 385 additions and 66 deletions
  1. 19 1
      icmp/dstunreach.go
  2. 26 10
      icmp/echo.go
  3. 60 0
      icmp/extension.go
  4. 2 1
      icmp/message.go
  5. 201 13
      icmp/message_test.go
  6. 16 7
      icmp/messagebody.go
  7. 22 14
      icmp/multipart.go
  8. 3 11
      icmp/multipart_test.go
  9. 17 8
      icmp/paramprob.go
  10. 19 1
      icmp/timeexceeded.go

+ 19 - 1
icmp/dstunreach.go

@@ -4,6 +4,12 @@
 
 
 package icmp
 package icmp
 
 
+import (
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
+
 // A DstUnreach represents an ICMP destination unreachable message
 // A DstUnreach represents an ICMP destination unreachable message
 // body.
 // body.
 type DstUnreach struct {
 type DstUnreach struct {
@@ -17,11 +23,23 @@ func (p *DstUnreach) Len(proto int) int {
 		return 0
 		return 0
 	}
 	}
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
-	return 4 + l
+	return l
 }
 }
 
 
 // Marshal implements the Marshal method of MessageBody interface.
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *DstUnreach) Marshal(proto int) ([]byte, error) {
 func (p *DstUnreach) Marshal(proto int) ([]byte, error) {
+	var typ Type
+	switch proto {
+	case iana.ProtocolICMP:
+		typ = ipv4.ICMPTypeDestinationUnreachable
+	case iana.ProtocolIPv6ICMP:
+		typ = ipv6.ICMPTypeDestinationUnreachable
+	default:
+		return nil, errInvalidProtocol
+	}
+	if !validExtensions(typ, p.Extensions) {
+		return nil, errInvalidExtension
+	}
 	return marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
 	return marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
 }
 }
 
 

+ 26 - 10
icmp/echo.go

@@ -4,7 +4,13 @@
 
 
 package icmp
 package icmp
 
 
-import "encoding/binary"
+import (
+	"encoding/binary"
+
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
 
 
 // An Echo represents an ICMP echo request or reply message body.
 // An Echo represents an ICMP echo request or reply message body.
 type Echo struct {
 type Echo struct {
@@ -59,29 +65,39 @@ func (p *ExtendedEchoRequest) Len(proto int) int {
 		return 0
 		return 0
 	}
 	}
 	l, _ := multipartMessageBodyDataLen(proto, false, nil, p.Extensions)
 	l, _ := multipartMessageBodyDataLen(proto, false, nil, p.Extensions)
-	return 4 + l
+	return l
 }
 }
 
 
 // Marshal implements the Marshal method of MessageBody interface.
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *ExtendedEchoRequest) Marshal(proto int) ([]byte, error) {
 func (p *ExtendedEchoRequest) Marshal(proto int) ([]byte, error) {
+	var typ Type
+	switch proto {
+	case iana.ProtocolICMP:
+		typ = ipv4.ICMPTypeExtendedEchoRequest
+	case iana.ProtocolIPv6ICMP:
+		typ = ipv6.ICMPTypeExtendedEchoRequest
+	default:
+		return nil, errInvalidProtocol
+	}
+	if !validExtensions(typ, p.Extensions) {
+		return nil, errInvalidExtension
+	}
 	b, err := marshalMultipartMessageBody(proto, false, nil, p.Extensions)
 	b, err := marshalMultipartMessageBody(proto, false, nil, p.Extensions)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	bb := make([]byte, 4)
-	binary.BigEndian.PutUint16(bb[:2], uint16(p.ID))
-	bb[2] = byte(p.Seq)
+	binary.BigEndian.PutUint16(b[:2], uint16(p.ID))
+	b[2] = byte(p.Seq)
 	if p.Local {
 	if p.Local {
-		bb[3] |= 0x01
+		b[3] |= 0x01
 	}
 	}
-	bb = append(bb, b...)
-	return bb, nil
+	return b, nil
 }
 }
 
 
 // parseExtendedEchoRequest parses b as an ICMP extended echo request
 // parseExtendedEchoRequest parses b as an ICMP extended echo request
 // message body.
 // message body.
 func parseExtendedEchoRequest(proto int, typ Type, b []byte) (MessageBody, error) {
 func parseExtendedEchoRequest(proto int, typ Type, b []byte) (MessageBody, error) {
-	if len(b) < 4+4 {
+	if len(b) < 4 {
 		return nil, errMessageTooShort
 		return nil, errMessageTooShort
 	}
 	}
 	p := &ExtendedEchoRequest{ID: int(binary.BigEndian.Uint16(b[:2])), Seq: int(b[2])}
 	p := &ExtendedEchoRequest{ID: int(binary.BigEndian.Uint16(b[:2])), Seq: int(b[2])}
@@ -89,7 +105,7 @@ func parseExtendedEchoRequest(proto int, typ Type, b []byte) (MessageBody, error
 		p.Local = true
 		p.Local = true
 	}
 	}
 	var err error
 	var err error
-	_, p.Extensions, err = parseMultipartMessageBody(proto, typ, b[4:])
+	_, p.Extensions, err = parseMultipartMessageBody(proto, typ, b)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 60 - 0
icmp/extension.go

@@ -103,8 +103,68 @@ func parseExtensions(typ Type, b []byte, l int) ([]Extension, int, error) {
 				return nil, -1, err
 				return nil, -1, err
 			}
 			}
 			exts = append(exts, ext)
 			exts = append(exts, ext)
+		default:
+			ext := &RawExtension{Data: make([]byte, ol)}
+			copy(ext.Data, b[:ol])
+			exts = append(exts, ext)
 		}
 		}
 		b = b[ol:]
 		b = b[ol:]
 	}
 	}
 	return exts, l, nil
 	return exts, l, nil
 }
 }
+
+func validExtensions(typ Type, exts []Extension) bool {
+	switch typ {
+	case ipv4.ICMPTypeDestinationUnreachable, ipv4.ICMPTypeTimeExceeded, ipv4.ICMPTypeParameterProblem,
+		ipv6.ICMPTypeDestinationUnreachable, ipv6.ICMPTypeTimeExceeded:
+		for i := range exts {
+			switch exts[i].(type) {
+			case *MPLSLabelStack, *InterfaceInfo, *RawExtension:
+			default:
+				return false
+			}
+		}
+		return true
+	case ipv4.ICMPTypeExtendedEchoRequest, ipv6.ICMPTypeExtendedEchoRequest:
+		var n int
+		for i := range exts {
+			switch exts[i].(type) {
+			case *InterfaceIdent:
+				n++
+			case *RawExtension:
+			default:
+				return false
+			}
+		}
+		// Not a single InterfaceIdent object or a combo of
+		// RawExtension and InterfaceIdent objects is not
+		// allowed.
+		if n == 1 && len(exts) > 1 {
+			return false
+		}
+		return true
+	default:
+		return false
+	}
+}
+
+// A RawExtension represents a raw extension.
+//
+// A raw extension is excluded from message processing and can be used
+// to construct applications such as protocol conformance testing.
+type RawExtension struct {
+	Data []byte // data
+}
+
+// Len implements the Len method of Extension interface.
+func (p *RawExtension) Len(proto int) int {
+	if p == nil {
+		return 0
+	}
+	return len(p.Data)
+}
+
+// Marshal implements the Marshal method of Extension interface.
+func (p *RawExtension) Marshal(proto int) ([]byte, error) {
+	return p.Data, nil
+}

+ 2 - 1
icmp/message.go

@@ -34,6 +34,7 @@ var (
 	errHeaderTooShort   = errors.New("header too short")
 	errHeaderTooShort   = errors.New("header too short")
 	errBufferTooShort   = errors.New("buffer too short")
 	errBufferTooShort   = errors.New("buffer too short")
 	errOpNoSupport      = errors.New("operation not supported")
 	errOpNoSupport      = errors.New("operation not supported")
+	errInvalidBody      = errors.New("invalid body")
 	errNoExtension      = errors.New("no extension")
 	errNoExtension      = errors.New("no extension")
 	errInvalidExtension = errors.New("invalid extension")
 	errInvalidExtension = errors.New("invalid extension")
 )
 )
@@ -150,7 +151,7 @@ func ParseMessage(proto int, b []byte) (*Message, error) {
 		return nil, errInvalidProtocol
 		return nil, errInvalidProtocol
 	}
 	}
 	if fn, ok := parseFns[m.Type]; !ok {
 	if fn, ok := parseFns[m.Type]; !ok {
-		m.Body, err = parseDefaultMessageBody(proto, b[4:])
+		m.Body, err = parseRawBody(proto, b[4:])
 	} else {
 	} else {
 		m.Body, err = fn(proto, m.Type, b[4:])
 		m.Body, err = fn(proto, m.Type, b[4:])
 	}
 	}

+ 201 - 13
icmp/message_test.go

@@ -5,6 +5,7 @@
 package icmp_test
 package icmp_test
 
 
 import (
 import (
+	"bytes"
 	"net"
 	"net"
 	"reflect"
 	"reflect"
 	"testing"
 	"testing"
@@ -31,17 +32,19 @@ func TestMarshalAndParseMessage(t *testing.T) {
 			for _, psh := range pshs {
 			for _, psh := range pshs {
 				b, err := tm.Marshal(psh)
 				b, err := tm.Marshal(psh)
 				if err != nil {
 				if err != nil {
-					t.Fatal(err)
+					t.Fatalf("#%d: %v", i, err)
 				}
 				}
 				m, err := icmp.ParseMessage(proto, b)
 				m, err := icmp.ParseMessage(proto, b)
 				if err != nil {
 				if err != nil {
-					t.Fatal(err)
+					t.Fatalf("#%d: %v", i, err)
 				}
 				}
 				if m.Type != tm.Type || m.Code != tm.Code {
 				if m.Type != tm.Type || m.Code != tm.Code {
 					t.Errorf("#%d: got %#v; want %#v", i, m, &tm)
 					t.Errorf("#%d: got %#v; want %#v", i, m, &tm)
+					continue
 				}
 				}
 				if !reflect.DeepEqual(m.Body, tm.Body) {
 				if !reflect.DeepEqual(m.Body, tm.Body) {
 					t.Errorf("#%d: got %#v; want %#v", i, m.Body, tm.Body)
 					t.Errorf("#%d: got %#v; want %#v", i, m.Body, tm.Body)
+					continue
 				}
 				}
 			}
 			}
 		}
 		}
@@ -80,6 +83,13 @@ func TestMarshalAndParseMessage(t *testing.T) {
 					Type: ipv4.ICMPTypeExtendedEchoRequest, Code: 0,
 					Type: ipv4.ICMPTypeExtendedEchoRequest, Code: 0,
 					Body: &icmp.ExtendedEchoRequest{
 					Body: &icmp.ExtendedEchoRequest{
 						ID: 1, Seq: 2,
 						ID: 1, Seq: 2,
+						Extensions: []icmp.Extension{
+							&icmp.InterfaceIdent{
+								Class: 3,
+								Type:  1,
+								Name:  "en101",
+							},
+						},
 					},
 					},
 				},
 				},
 				{
 				{
@@ -88,12 +98,6 @@ func TestMarshalAndParseMessage(t *testing.T) {
 						State: 4 /* Delay */, Active: true, IPv4: true,
 						State: 4 /* Delay */, Active: true, IPv4: true,
 					},
 					},
 				},
 				},
-				{
-					Type: ipv4.ICMPTypePhoturis,
-					Body: &icmp.DefaultMessageBody{
-						Data: []byte{0x80, 0x40, 0x20, 0x10},
-					},
-				},
 			})
 			})
 	})
 	})
 	t.Run("IPv6", func(t *testing.T) {
 	t.Run("IPv6", func(t *testing.T) {
@@ -136,6 +140,13 @@ func TestMarshalAndParseMessage(t *testing.T) {
 					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 0,
 					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 0,
 					Body: &icmp.ExtendedEchoRequest{
 					Body: &icmp.ExtendedEchoRequest{
 						ID: 1, Seq: 2,
 						ID: 1, Seq: 2,
+						Extensions: []icmp.Extension{
+							&icmp.InterfaceIdent{
+								Class: 3,
+								Type:  2,
+								Index: 911,
+							},
+						},
 					},
 					},
 				},
 				},
 				{
 				{
@@ -144,12 +155,189 @@ func TestMarshalAndParseMessage(t *testing.T) {
 						State: 5 /* Probe */, Active: true, IPv6: true,
 						State: 5 /* Probe */, Active: true, IPv6: true,
 					},
 					},
 				},
 				},
-				{
-					Type: ipv6.ICMPTypeDuplicateAddressConfirmation,
-					Body: &icmp.DefaultMessageBody{
-						Data: []byte{0x80, 0x40, 0x20, 0x10},
+			})
+	})
+}
+
+func TestMarshalAndParseRawMessage(t *testing.T) {
+	t.Run("RawBody", func(t *testing.T) {
+		for i, tt := range []struct {
+			m               icmp.Message
+			wire            []byte
+			parseShouldFail bool
+		}{
+			{ // Nil body
+				m: icmp.Message{
+					Type: ipv4.ICMPTypeDestinationUnreachable, Code: 127,
+				},
+				wire: []byte{
+					0x03, 0x7f, 0xfc, 0x80,
+				},
+				parseShouldFail: true,
+			},
+			{ // Empty body
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDestinationUnreachable, Code: 128,
+					Body: &icmp.RawBody{},
+				},
+				wire: []byte{
+					0x01, 0x80, 0x00, 0x00,
+				},
+				parseShouldFail: true,
+			},
+			{ // Crafted body
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDuplicateAddressConfirmation, Code: 129,
+					Body: &icmp.RawBody{
+						Data: []byte{0xca, 0xfe},
 					},
 					},
 				},
 				},
-			})
+				wire: []byte{
+					0x9e, 0x81, 0x00, 0x00,
+					0xca, 0xfe,
+				},
+				parseShouldFail: false,
+			},
+		} {
+			b, err := tt.m.Marshal(nil)
+			if err != nil {
+				t.Errorf("#%d: %v", i, err)
+				continue
+			}
+			if !bytes.Equal(b, tt.wire) {
+				t.Errorf("#%d: got %#v; want %#v", i, b, tt.wire)
+				continue
+			}
+			m, err := icmp.ParseMessage(tt.m.Type.Protocol(), b)
+			if err != nil != tt.parseShouldFail {
+				t.Errorf("#%d: got %v, %v", i, m, err)
+				continue
+			}
+			if tt.parseShouldFail {
+				continue
+			}
+			if m.Type != tt.m.Type || m.Code != tt.m.Code {
+				t.Errorf("#%d: got %v; want %v", i, m, tt.m)
+				continue
+			}
+			if !bytes.Equal(m.Body.(*icmp.RawBody).Data, tt.m.Body.(*icmp.RawBody).Data) {
+				t.Errorf("#%d: got %#v; want %#v", i, m.Body, tt.m.Body)
+				continue
+			}
+		}
+	})
+	t.Run("RawExtension", func(t *testing.T) {
+		for i, tt := range []struct {
+			m    icmp.Message
+			wire []byte
+		}{
+			{ // Unaligned data and nil extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDestinationUnreachable, Code: 130,
+					Body: &icmp.DstUnreach{
+						Data: []byte("ERROR-INVOKING-PACKET"),
+					},
+				},
+				wire: []byte{
+					0x01, 0x82, 0x00, 0x00,
+					0x00, 0x00, 0x00, 0x00,
+					'E', 'R', 'R', 'O',
+					'R', '-', 'I', 'N',
+					'V', 'O', 'K', 'I',
+					'N', 'G', '-', 'P',
+					'A', 'C', 'K', 'E',
+					'T',
+				},
+			},
+			{ // Unaligned data and empty extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeDestinationUnreachable, Code: 131,
+					Body: &icmp.DstUnreach{
+						Data: []byte("ERROR-INVOKING-PACKET"),
+						Extensions: []icmp.Extension{
+							&icmp.RawExtension{},
+						},
+					},
+				},
+				wire: []byte{
+					0x01, 0x83, 0x00, 0x00,
+					0x02, 0x00, 0x00, 0x00,
+					'E', 'R', 'R', 'O',
+					'R', '-', 'I', 'N',
+					'V', 'O', 'K', 'I',
+					'N', 'G', '-', 'P',
+					'A', 'C', 'K', 'E',
+					'T',
+					0x20, 0x00, 0xdf, 0xff,
+				},
+			},
+			{ // Nil extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 132,
+					Body: &icmp.ExtendedEchoRequest{
+						ID: 1, Seq: 2, Local: true,
+					},
+				},
+				wire: []byte{
+					0xa0, 0x84, 0x00, 0x00,
+					0x00, 0x01, 0x02, 0x01,
+				},
+			},
+			{ // Empty extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 133,
+					Body: &icmp.ExtendedEchoRequest{
+						ID: 1, Seq: 2, Local: true,
+						Extensions: []icmp.Extension{
+							&icmp.RawExtension{},
+						},
+					},
+				},
+				wire: []byte{
+					0xa0, 0x85, 0x00, 0x00,
+					0x00, 0x01, 0x02, 0x01,
+					0x20, 0x00, 0xdf, 0xff,
+				},
+			},
+			{ // Crafted extension
+				m: icmp.Message{
+					Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 134,
+					Body: &icmp.ExtendedEchoRequest{
+						ID: 1, Seq: 2, Local: true,
+						Extensions: []icmp.Extension{
+							&icmp.RawExtension{
+								Data: []byte("CRAFTED"),
+							},
+						},
+					},
+				},
+				wire: []byte{
+					0xa0, 0x86, 0x00, 0x00,
+					0x00, 0x01, 0x02, 0x01,
+					0x20, 0x00, 0xc3, 0x21,
+					'C', 'R', 'A', 'F',
+					'T', 'E', 'D',
+				},
+			},
+		} {
+			b, err := tt.m.Marshal(nil)
+			if err != nil {
+				t.Errorf("#%d: %v", i, err)
+				continue
+			}
+			if !bytes.Equal(b, tt.wire) {
+				t.Errorf("#%d: got %#v; want %#v", i, b, tt.wire)
+				continue
+			}
+			m, err := icmp.ParseMessage(tt.m.Type.Protocol(), b)
+			if err != nil {
+				t.Errorf("#%d: %v", i, err)
+				continue
+			}
+			if m.Type != tt.m.Type || m.Code != tt.m.Code {
+				t.Errorf("#%d: got %v; want %v", i, m, tt.m)
+				continue
+			}
+		}
 	})
 	})
 }
 }

+ 16 - 7
icmp/messagebody.go

@@ -17,13 +17,17 @@ type MessageBody interface {
 	Marshal(proto int) ([]byte, error)
 	Marshal(proto int) ([]byte, error)
 }
 }
 
 
-// A DefaultMessageBody represents the default message body.
-type DefaultMessageBody struct {
+// A RawBody represents a raw message body.
+//
+// A raw message body is excluded from message processing and can be
+// used to construct applications such as protocol conformance
+// testing.
+type RawBody struct {
 	Data []byte // data
 	Data []byte // data
 }
 }
 
 
 // Len implements the Len method of MessageBody interface.
 // Len implements the Len method of MessageBody interface.
-func (p *DefaultMessageBody) Len(proto int) int {
+func (p *RawBody) Len(proto int) int {
 	if p == nil {
 	if p == nil {
 		return 0
 		return 0
 	}
 	}
@@ -31,13 +35,18 @@ func (p *DefaultMessageBody) Len(proto int) int {
 }
 }
 
 
 // Marshal implements the Marshal method of MessageBody interface.
 // Marshal implements the Marshal method of MessageBody interface.
-func (p *DefaultMessageBody) Marshal(proto int) ([]byte, error) {
+func (p *RawBody) Marshal(proto int) ([]byte, error) {
 	return p.Data, nil
 	return p.Data, nil
 }
 }
 
 
-// parseDefaultMessageBody parses b as an ICMP message body.
-func parseDefaultMessageBody(proto int, b []byte) (MessageBody, error) {
-	p := &DefaultMessageBody{Data: make([]byte, len(b))}
+// parseRawBody parses b as an ICMP message body.
+func parseRawBody(proto int, b []byte) (MessageBody, error) {
+	p := &RawBody{Data: make([]byte, len(b))}
 	copy(p.Data, b)
 	copy(p.Data, b)
 	return p, nil
 	return p, nil
 }
 }
+
+// A DefaultMessageBody represents the default message body.
+//
+// Deprecated: Use RawBody instead.
+type DefaultMessageBody = RawBody

+ 22 - 14
icmp/multipart.go

@@ -11,18 +11,24 @@ import "golang.org/x/net/internal/iana"
 // and a required length for a padded original datagram in wire
 // and a required length for a padded original datagram in wire
 // format.
 // format.
 func multipartMessageBodyDataLen(proto int, withOrigDgram bool, b []byte, exts []Extension) (bodyLen, dataLen int) {
 func multipartMessageBodyDataLen(proto int, withOrigDgram bool, b []byte, exts []Extension) (bodyLen, dataLen int) {
+	bodyLen = 4 // length of leading octets
+	var extLen int
+	var rawExt bool // raw extension may contain an empty object
 	for _, ext := range exts {
 	for _, ext := range exts {
-		bodyLen += ext.Len(proto)
-	}
-	if bodyLen > 0 {
-		if withOrigDgram {
-			dataLen = multipartMessageOrigDatagramLen(proto, b)
+		extLen += ext.Len(proto)
+		if _, ok := ext.(*RawExtension); ok {
+			rawExt = true
 		}
 		}
-		bodyLen += 4 // length of extension header
+	}
+	if extLen > 0 && withOrigDgram {
+		dataLen = multipartMessageOrigDatagramLen(proto, b)
 	} else {
 	} else {
 		dataLen = len(b)
 		dataLen = len(b)
 	}
 	}
-	bodyLen += dataLen
+	if extLen > 0 || rawExt {
+		bodyLen += 4 // length of extension header
+	}
+	bodyLen += dataLen + extLen
 	return bodyLen, dataLen
 	return bodyLen, dataLen
 }
 }
 
 
@@ -54,12 +60,11 @@ func multipartMessageOrigDatagramLen(proto int, b []byte) int {
 // It can be used for non-multipart message bodies when exts is nil.
 // It can be used for non-multipart message bodies when exts is nil.
 func marshalMultipartMessageBody(proto int, withOrigDgram bool, data []byte, exts []Extension) ([]byte, error) {
 func marshalMultipartMessageBody(proto int, withOrigDgram bool, data []byte, exts []Extension) ([]byte, error) {
 	bodyLen, dataLen := multipartMessageBodyDataLen(proto, withOrigDgram, data, exts)
 	bodyLen, dataLen := multipartMessageBodyDataLen(proto, withOrigDgram, data, exts)
-	b := make([]byte, 4+bodyLen)
+	b := make([]byte, bodyLen)
 	copy(b[4:], data)
 	copy(b[4:], data)
-	off := dataLen + 4
 	if len(exts) > 0 {
 	if len(exts) > 0 {
-		b[dataLen+4] = byte(extensionVersion << 4)
-		off += 4 // length of object header
+		b[4+dataLen] = byte(extensionVersion << 4)
+		off := 4 + dataLen + 4 // leading octets, data, extension header
 		for _, ext := range exts {
 		for _, ext := range exts {
 			switch ext := ext.(type) {
 			switch ext := ext.(type) {
 			case *MPLSLabelStack:
 			case *MPLSLabelStack:
@@ -78,11 +83,14 @@ func marshalMultipartMessageBody(proto int, withOrigDgram bool, data []byte, ext
 					return nil, err
 					return nil, err
 				}
 				}
 				off += ext.Len(proto)
 				off += ext.Len(proto)
+			case *RawExtension:
+				copy(b[off:], ext.Data)
+				off += ext.Len(proto)
 			}
 			}
 		}
 		}
-		s := checksum(b[dataLen+4:])
-		b[dataLen+4+2] ^= byte(s)
-		b[dataLen+4+3] ^= byte(s >> 8)
+		s := checksum(b[4+dataLen:])
+		b[4+dataLen+2] ^= byte(s)
+		b[4+dataLen+3] ^= byte(s >> 8)
 		if withOrigDgram {
 		if withOrigDgram {
 			switch proto {
 			switch proto {
 			case iana.ProtocolICMP:
 			case iana.ProtocolICMP:

+ 3 - 11
icmp/multipart_test.go

@@ -232,11 +232,6 @@ func TestMarshalAndParseMultipartMessage(t *testing.T) {
 							Type:  2,
 							Type:  2,
 							Index: 911,
 							Index: 911,
 						},
 						},
-						&icmp.InterfaceIdent{
-							Class: 3,
-							Type:  1,
-							Name:  "en101",
-						},
 					},
 					},
 				},
 				},
 			},
 			},
@@ -359,11 +354,6 @@ func TestMarshalAndParseMultipartMessage(t *testing.T) {
 				Body: &icmp.ExtendedEchoRequest{
 				Body: &icmp.ExtendedEchoRequest{
 					ID: 1, Seq: 2, Local: true,
 					ID: 1, Seq: 2, Local: true,
 					Extensions: []icmp.Extension{
 					Extensions: []icmp.Extension{
-						&icmp.InterfaceIdent{
-							Class: 3,
-							Type:  1,
-							Name:  "en101",
-						},
 						&icmp.InterfaceIdent{
 						&icmp.InterfaceIdent{
 							Class: 3,
 							Class: 3,
 							Type:  2,
 							Type:  2,
@@ -413,10 +403,12 @@ func dumpExtensions(gotExts, wantExts []icmp.Extension) string {
 			if !reflect.DeepEqual(got, want) {
 			if !reflect.DeepEqual(got, want) {
 				s += fmt.Sprintf("#%d: got %#v; want %#v\n", i, got, want)
 				s += fmt.Sprintf("#%d: got %#v; want %#v\n", i, got, want)
 			}
 			}
+		case *icmp.RawExtension:
+			s += fmt.Sprintf("#%d: raw extension\n", i)
 		}
 		}
 	}
 	}
 	if len(s) == 0 {
 	if len(s) == 0 {
-		return "<nil>"
+		s += "empty extension"
 	}
 	}
 	return s[:len(s)-1]
 	return s[:len(s)-1]
 }
 }

+ 17 - 8
icmp/paramprob.go

@@ -6,7 +6,9 @@ package icmp
 
 
 import (
 import (
 	"encoding/binary"
 	"encoding/binary"
+
 	"golang.org/x/net/internal/iana"
 	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
 )
 )
 
 
 // A ParamProb represents an ICMP parameter problem message body.
 // A ParamProb represents an ICMP parameter problem message body.
@@ -22,23 +24,30 @@ func (p *ParamProb) Len(proto int) int {
 		return 0
 		return 0
 	}
 	}
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
-	return 4 + l
+	return l
 }
 }
 
 
 // Marshal implements the Marshal method of MessageBody interface.
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *ParamProb) Marshal(proto int) ([]byte, error) {
 func (p *ParamProb) Marshal(proto int) ([]byte, error) {
-	if proto == iana.ProtocolIPv6ICMP {
+	switch proto {
+	case iana.ProtocolICMP:
+		if !validExtensions(ipv4.ICMPTypeParameterProblem, p.Extensions) {
+			return nil, errInvalidExtension
+		}
+		b, err := marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
+		if err != nil {
+			return nil, err
+		}
+		b[0] = byte(p.Pointer)
+		return b, nil
+	case iana.ProtocolIPv6ICMP:
 		b := make([]byte, p.Len(proto))
 		b := make([]byte, p.Len(proto))
 		binary.BigEndian.PutUint32(b[:4], uint32(p.Pointer))
 		binary.BigEndian.PutUint32(b[:4], uint32(p.Pointer))
 		copy(b[4:], p.Data)
 		copy(b[4:], p.Data)
 		return b, nil
 		return b, nil
+	default:
+		return nil, errInvalidProtocol
 	}
 	}
-	b, err := marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
-	if err != nil {
-		return nil, err
-	}
-	b[0] = byte(p.Pointer)
-	return b, nil
 }
 }
 
 
 // parseParamProb parses b as an ICMP parameter problem message body.
 // parseParamProb parses b as an ICMP parameter problem message body.

+ 19 - 1
icmp/timeexceeded.go

@@ -4,6 +4,12 @@
 
 
 package icmp
 package icmp
 
 
+import (
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
+
 // A TimeExceeded represents an ICMP time exceeded message body.
 // A TimeExceeded represents an ICMP time exceeded message body.
 type TimeExceeded struct {
 type TimeExceeded struct {
 	Data       []byte      // data, known as original datagram field
 	Data       []byte      // data, known as original datagram field
@@ -16,11 +22,23 @@ func (p *TimeExceeded) Len(proto int) int {
 		return 0
 		return 0
 	}
 	}
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
 	l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
-	return 4 + l
+	return l
 }
 }
 
 
 // Marshal implements the Marshal method of MessageBody interface.
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *TimeExceeded) Marshal(proto int) ([]byte, error) {
 func (p *TimeExceeded) Marshal(proto int) ([]byte, error) {
+	var typ Type
+	switch proto {
+	case iana.ProtocolICMP:
+		typ = ipv4.ICMPTypeTimeExceeded
+	case iana.ProtocolIPv6ICMP:
+		typ = ipv6.ICMPTypeTimeExceeded
+	default:
+		return nil, errInvalidProtocol
+	}
+	if !validExtensions(typ, p.Extensions) {
+		return nil, errInvalidExtension
+	}
 	return marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
 	return marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
 }
 }