Browse Source

icmp: add extensions for MPLS

This change implements ICMP multi-part message marshaler, parser and
extensions for MPLS which are used for route trace applications as
described in RFC 4950.

API breaking changes:

type MessageBody interface, Len() int
type Extension interface, Len() int
type Extension interface, Marshal() ([]byte, error)

are replaced with

type MessageBody interface, Len(int) int
type Extension interface, Len(int) int
type Extension interface, Marshal(int) ([]byte, error)

Change-Id: Iee1f2e03916d49b8dfe3a89fe682c702d40ecc85
Reviewed-on: https://go-review.googlesource.com/2794
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Mikio Hara 11 years ago
parent
commit
71586c3cf9
13 changed files with 698 additions and 63 deletions
  1. 9 10
      icmp/dstunreach.go
  2. 1 1
      icmp/echo.go
  3. 67 2
      icmp/extension.go
  4. 158 0
      icmp/extension_test.go
  5. 20 13
      icmp/message.go
  6. 6 6
      icmp/message_test.go
  7. 3 2
      icmp/messagebody.go
  8. 75 0
      icmp/mpls.go
  9. 103 0
      icmp/multipart.go
  10. 223 0
      icmp/multipart_test.go
  11. 2 2
      icmp/packettoobig.go
  12. 22 17
      icmp/paramprob.go
  13. 9 10
      icmp/timeexceeded.go

+ 9 - 10
icmp/dstunreach.go

@@ -12,31 +12,30 @@ type DstUnreach struct {
 }
 }
 
 
 // Len implements the Len method of MessageBody interface.
 // Len implements the Len method of MessageBody interface.
-func (p *DstUnreach) Len() int {
+func (p *DstUnreach) Len(proto int) int {
 	if p == nil {
 	if p == nil {
 		return 0
 		return 0
 	}
 	}
-	return 4 + len(p.Data)
+	l, _ := multipartMessageBodyDataLen(proto, p.Data, p.Extensions)
+	return l
 }
 }
 
 
 // Marshal implements the Marshal method of MessageBody interface.
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *DstUnreach) Marshal(proto int) ([]byte, error) {
 func (p *DstUnreach) Marshal(proto int) ([]byte, error) {
-	b := make([]byte, 4+len(p.Data))
-	copy(b[4:], p.Data)
-	return b, nil
+	return marshalMultipartMessageBody(proto, p.Data, p.Extensions)
 }
 }
 
 
 // parseDstUnreach parses b as an ICMP destination unreachable message
 // parseDstUnreach parses b as an ICMP destination unreachable message
 // body.
 // body.
 func parseDstUnreach(proto int, b []byte) (MessageBody, error) {
 func parseDstUnreach(proto int, b []byte) (MessageBody, error) {
-	bodyLen := len(b)
-	if bodyLen < 4 {
+	if len(b) < 4 {
 		return nil, errMessageTooShort
 		return nil, errMessageTooShort
 	}
 	}
 	p := &DstUnreach{}
 	p := &DstUnreach{}
-	if bodyLen > 4 {
-		p.Data = make([]byte, bodyLen-4)
-		copy(p.Data, b[4:])
+	var err error
+	p.Data, p.Extensions, err = parseMultipartMessageBody(proto, b)
+	if err != nil {
+		return nil, err
 	}
 	}
 	return p, nil
 	return p, nil
 }
 }

+ 1 - 1
icmp/echo.go

@@ -12,7 +12,7 @@ type Echo struct {
 }
 }
 
 
 // Len implements the Len method of MessageBody interface.
 // Len implements the Len method of MessageBody interface.
-func (p *Echo) Len() int {
+func (p *Echo) Len(proto int) int {
 	if p == nil {
 	if p == nil {
 		return 0
 		return 0
 	}
 	}

+ 67 - 2
icmp/extension.go

@@ -7,10 +7,75 @@ package icmp
 // An Extension represents an ICMP extension.
 // An Extension represents an ICMP extension.
 type Extension interface {
 type Extension interface {
 	// Len returns the length of ICMP extension.
 	// Len returns the length of ICMP extension.
-	Len() int
+	// Proto must be either the ICMPv4 or ICMPv6 protocol number.
+	Len(proto int) int
 
 
 	// Marshal returns the binary enconding of ICMP extension.
 	// Marshal returns the binary enconding of ICMP extension.
-	Marshal() ([]byte, error)
+	// Proto must be either the ICMPv4 or ICMPv6 protocol number.
+	Marshal(proto int) ([]byte, error)
 }
 }
 
 
 const extensionVersion = 2
 const extensionVersion = 2
+
+func validExtensionHeader(b []byte) bool {
+	v := int(b[0]&0xf0) >> 4
+	s := uint16(b[2])<<8 | uint16(b[3])
+	if s != 0 {
+		s = checksum(b)
+	}
+	if v != extensionVersion || s != 0 {
+		return false
+	}
+	return true
+}
+
+// parseExtensions parses b as a list of ICMP extensions.
+// The length attribute l must be the length attribute field in
+// received icmp messages.
+//
+// It will return a list of ICMP extensions and an adjusted length
+// attribute that represents the length of the padded original
+// datagram field. Otherwise, it returns an error.
+func parseExtensions(b []byte, l int) ([]Extension, int, error) {
+	// Still a lot of non-RFC 4884 compliant implementations are
+	// out there. Set the length attribute l to 128 when it looks
+	// inappropriate for backwards compatibility.
+	//
+	// A minimal extension at least requires 8 octets; 4 octets
+	// for an extension header, and 4 octets for a single object
+	// header.
+	//
+	// See RFC 4884 for further information.
+	if 128 > l || l+8 > len(b) {
+		l = 128
+	}
+	if l+8 > len(b) {
+		return nil, -1, errNoExtension
+	}
+	if !validExtensionHeader(b[l:]) {
+		if l == 128 {
+			return nil, -1, errNoExtension
+		}
+		l = 128
+		if !validExtensionHeader(b[l:]) {
+			return nil, -1, errNoExtension
+		}
+	}
+	var exts []Extension
+	for b = b[l+4:]; len(b) >= 4; {
+		ol := int(b[0])<<8 | int(b[1])
+		if 4 > ol || ol > len(b) {
+			break
+		}
+		switch b[2] {
+		case classMPLSLabelStack:
+			ext, err := parseMPLSLabelStack(b[:ol])
+			if err != nil {
+				return nil, -1, err
+			}
+			exts = append(exts, ext)
+		}
+		b = b[ol:]
+	}
+	return exts, l, nil
+}

+ 158 - 0
icmp/extension_test.go

@@ -0,0 +1,158 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package icmp
+
+import (
+	"reflect"
+	"testing"
+
+	"golang.org/x/net/internal/iana"
+)
+
+var marshalAndParseExtensionTests = []struct {
+	proto int
+	hdr   []byte
+	obj   []byte
+	exts  []Extension
+}{
+	// MPLS label stack with no label
+	{
+		proto: iana.ProtocolICMP,
+		hdr: []byte{
+			0x20, 0x00, 0x00, 0x00,
+		},
+		obj: []byte{
+			0x00, 0x04, 0x01, 0x01,
+		},
+		exts: []Extension{
+			&MPLSLabelStack{
+				Class: classMPLSLabelStack,
+				Type:  typeIncomingMPLSLabelStack,
+			},
+		},
+	},
+	// MPLS label stack with a single label
+	{
+		proto: iana.ProtocolIPv6ICMP,
+		hdr: []byte{
+			0x20, 0x00, 0x00, 0x00,
+		},
+		obj: []byte{
+			0x00, 0x08, 0x01, 0x01,
+			0x03, 0xe8, 0xe9, 0xff,
+		},
+		exts: []Extension{
+			&MPLSLabelStack{
+				Class: classMPLSLabelStack,
+				Type:  typeIncomingMPLSLabelStack,
+				Labels: []MPLSLabel{
+					{
+						Label: 16014,
+						TC:    0x4,
+						S:     true,
+						TTL:   255,
+					},
+				},
+			},
+		},
+	},
+	// MPLS label stack with multiple labels
+	{
+		proto: iana.ProtocolICMP,
+		hdr: []byte{
+			0x20, 0x00, 0x00, 0x00,
+		},
+		obj: []byte{
+			0x00, 0x0c, 0x01, 0x01,
+			0x03, 0xe8, 0xde, 0xfe,
+			0x03, 0xe8, 0xe1, 0xff,
+		},
+		exts: []Extension{
+			&MPLSLabelStack{
+				Class: classMPLSLabelStack,
+				Type:  typeIncomingMPLSLabelStack,
+				Labels: []MPLSLabel{
+					{
+						Label: 16013,
+						TC:    0x7,
+						S:     false,
+						TTL:   254,
+					},
+					{
+						Label: 16014,
+						TC:    0,
+						S:     true,
+						TTL:   255,
+					},
+				},
+			},
+		},
+	},
+}
+
+func TestMarshalAndParseExtension(t *testing.T) {
+	for i, tt := range marshalAndParseExtensionTests {
+		for j, ext := range tt.exts {
+			var err error
+			var b []byte
+			switch ext := ext.(type) {
+			case *MPLSLabelStack:
+				b, err = ext.Marshal(tt.proto)
+				if err != nil {
+					t.Errorf("#%v/%v: %v", i, j, err)
+					continue
+				}
+			}
+			if !reflect.DeepEqual(b, tt.obj) {
+				t.Errorf("#%v/%v: got %#v; want %#v", i, j, b, tt.obj)
+				continue
+			}
+		}
+
+		for j, wire := range []struct {
+			data     []byte // original datagram
+			inlattr  int    // length of padded original datagram, a hint
+			outlattr int    // length of padded original datagram, a want
+			err      error
+		}{
+			{nil, 0, -1, errNoExtension},
+			{make([]byte, 127), 128, -1, errNoExtension},
+
+			{make([]byte, 128), 127, -1, errNoExtension},
+			{make([]byte, 128), 128, -1, errNoExtension},
+			{make([]byte, 128), 129, -1, errNoExtension},
+
+			{append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 127, 128, nil},
+			{append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 128, 128, nil},
+			{append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 129, 128, nil},
+
+			{append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 511, -1, errNoExtension},
+			{append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 512, 512, nil},
+			{append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 513, -1, errNoExtension},
+		} {
+			exts, l, err := parseExtensions(wire.data, wire.inlattr)
+			if err != wire.err {
+				t.Errorf("#%v/%v: got %v; want %v", i, j, err, wire.err)
+				continue
+			}
+			if wire.err != nil {
+				continue
+			}
+			if l != wire.outlattr {
+				t.Errorf("#%v/%v: got %v; want %v", i, j, l, wire.outlattr)
+			}
+			if !reflect.DeepEqual(exts, tt.exts) {
+				for j, ext := range exts {
+					switch ext := ext.(type) {
+					case *MPLSLabelStack:
+						want := tt.exts[j].(*MPLSLabelStack)
+						t.Errorf("#%v/%v: got %#v; want %#v", i, j, ext, want)
+					}
+				}
+				continue
+			}
+		}
+	}
+}

+ 20 - 13
icmp/message.go

@@ -8,6 +8,7 @@
 //
 //
 // ICMPv4 and ICMPv6 are defined in RFC 792 and RFC 4443.
 // ICMPv4 and ICMPv6 are defined in RFC 792 and RFC 4443.
 // Multi-part message support for ICMP is defined in RFC 4884.
 // Multi-part message support for ICMP is defined in RFC 4884.
+// ICMP extensions for MPLS are defined in RFC 4950.
 package icmp // import "golang.org/x/net/icmp"
 package icmp // import "golang.org/x/net/icmp"
 
 
 import (
 import (
@@ -25,8 +26,23 @@ var (
 	errHeaderTooShort  = errors.New("header too short")
 	errHeaderTooShort  = errors.New("header too short")
 	errBufferTooShort  = errors.New("buffer too short")
 	errBufferTooShort  = errors.New("buffer too short")
 	errOpNoSupport     = errors.New("operation not supported")
 	errOpNoSupport     = errors.New("operation not supported")
+	errNoExtension     = errors.New("no extension")
 )
 )
 
 
+func checksum(b []byte) uint16 {
+	csumcv := len(b) - 1 // checksum coverage
+	s := uint32(0)
+	for i := 0; i < csumcv; i += 2 {
+		s += uint32(b[i+1])<<8 | uint32(b[i])
+	}
+	if csumcv&1 == 0 {
+		s += uint32(b[csumcv])
+	}
+	s = s>>16 + s&0xffff
+	s = s + s>>16
+	return ^uint16(s)
+}
+
 // A Type represents an ICMP message type.
 // A Type represents an ICMP message type.
 type Type interface {
 type Type interface {
 	Protocol() int
 	Protocol() int
@@ -63,7 +79,7 @@ func (m *Message) Marshal(psh []byte) ([]byte, error) {
 	if m.Type.Protocol() == iana.ProtocolIPv6ICMP && psh != nil {
 	if m.Type.Protocol() == iana.ProtocolIPv6ICMP && psh != nil {
 		b = append(psh, b...)
 		b = append(psh, b...)
 	}
 	}
-	if m.Body != nil && m.Body.Len() != 0 {
+	if m.Body != nil && m.Body.Len(m.Type.Protocol()) != 0 {
 		mb, err := m.Body.Marshal(m.Type.Protocol())
 		mb, err := m.Body.Marshal(m.Type.Protocol())
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
@@ -77,20 +93,11 @@ func (m *Message) Marshal(psh []byte) ([]byte, error) {
 		off, l := 2*net.IPv6len, len(b)-len(psh)
 		off, l := 2*net.IPv6len, len(b)-len(psh)
 		b[off], b[off+1], b[off+2], b[off+3] = byte(l>>24), byte(l>>16), byte(l>>8), byte(l)
 		b[off], b[off+1], b[off+2], b[off+3] = byte(l>>24), byte(l>>16), byte(l>>8), byte(l)
 	}
 	}
-	csumcv := len(b) - 1 // checksum coverage
-	s := uint32(0)
-	for i := 0; i < csumcv; i += 2 {
-		s += uint32(b[i+1])<<8 | uint32(b[i])
-	}
-	if csumcv&1 == 0 {
-		s += uint32(b[csumcv])
-	}
-	s = s>>16 + s&0xffff
-	s = s + s>>16
+	s := checksum(b)
 	// Place checksum back in header; using ^= avoids the
 	// Place checksum back in header; using ^= avoids the
 	// assumption the checksum bytes are zero.
 	// assumption the checksum bytes are zero.
-	b[len(psh)+2] ^= byte(^s)
-	b[len(psh)+3] ^= byte(^s >> 8)
+	b[len(psh)+2] ^= byte(s)
+	b[len(psh)+3] ^= byte(s >> 8)
 	return b[len(psh):], nil
 	return b[len(psh):], nil
 }
 }
 
 

+ 6 - 6
icmp/message_test.go

@@ -51,7 +51,7 @@ var marshalAndParseMessageForIPv4Tests = []icmp.Message{
 }
 }
 
 
 func TestMarshalAndParseMessageForIPv4(t *testing.T) {
 func TestMarshalAndParseMessageForIPv4(t *testing.T) {
-	for _, tt := range marshalAndParseMessageForIPv4Tests {
+	for i, tt := range marshalAndParseMessageForIPv4Tests {
 		b, err := tt.Marshal(nil)
 		b, err := tt.Marshal(nil)
 		if err != nil {
 		if err != nil {
 			t.Fatal(err)
 			t.Fatal(err)
@@ -61,10 +61,10 @@ func TestMarshalAndParseMessageForIPv4(t *testing.T) {
 			t.Fatal(err)
 			t.Fatal(err)
 		}
 		}
 		if m.Type != tt.Type || m.Code != tt.Code {
 		if m.Type != tt.Type || m.Code != tt.Code {
-			t.Errorf("got %v; want %v", m, &tt)
+			t.Errorf("#%v: got %v; want %v", i, m, &tt)
 		}
 		}
 		if !reflect.DeepEqual(m.Body, tt.Body) {
 		if !reflect.DeepEqual(m.Body, tt.Body) {
-			t.Errorf("got %v; want %v", m.Body, tt.Body)
+			t.Errorf("#%v: got %v; want %v", i, m.Body, tt.Body)
 		}
 		}
 	}
 	}
 }
 }
@@ -113,7 +113,7 @@ var marshalAndParseMessageForIPv6Tests = []icmp.Message{
 
 
 func TestMarshalAndParseMessageForIPv6(t *testing.T) {
 func TestMarshalAndParseMessageForIPv6(t *testing.T) {
 	pshicmp := icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1"))
 	pshicmp := icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1"))
-	for _, tt := range marshalAndParseMessageForIPv6Tests {
+	for i, tt := range marshalAndParseMessageForIPv6Tests {
 		for _, psh := range [][]byte{pshicmp, nil} {
 		for _, psh := range [][]byte{pshicmp, nil} {
 			b, err := tt.Marshal(psh)
 			b, err := tt.Marshal(psh)
 			if err != nil {
 			if err != nil {
@@ -124,10 +124,10 @@ func TestMarshalAndParseMessageForIPv6(t *testing.T) {
 				t.Fatal(err)
 				t.Fatal(err)
 			}
 			}
 			if m.Type != tt.Type || m.Code != tt.Code {
 			if m.Type != tt.Type || m.Code != tt.Code {
-				t.Errorf("got %v; want %v", m, &tt)
+				t.Errorf("#%v: got %v; want %v", i, m, &tt)
 			}
 			}
 			if !reflect.DeepEqual(m.Body, tt.Body) {
 			if !reflect.DeepEqual(m.Body, tt.Body) {
-				t.Errorf("got %v; want %v", m.Body, tt.Body)
+				t.Errorf("#%v: got %v; want %v", i, m.Body, tt.Body)
 			}
 			}
 		}
 		}
 	}
 	}

+ 3 - 2
icmp/messagebody.go

@@ -7,7 +7,8 @@ package icmp
 // A MessageBody represents an ICMP message body.
 // A MessageBody represents an ICMP message body.
 type MessageBody interface {
 type MessageBody interface {
 	// Len returns the length of ICMP message body.
 	// Len returns the length of ICMP message body.
-	Len() int
+	// Proto must be either the ICMPv4 or ICMPv6 protocol number.
+	Len(proto int) int
 
 
 	// Marshal returns the binary enconding of ICMP message body.
 	// Marshal returns the binary enconding of ICMP message body.
 	// Proto must be either the ICMPv4 or ICMPv6 protocol number.
 	// Proto must be either the ICMPv4 or ICMPv6 protocol number.
@@ -20,7 +21,7 @@ type DefaultMessageBody struct {
 }
 }
 
 
 // Len implements the Len method of MessageBody interface.
 // Len implements the Len method of MessageBody interface.
-func (p *DefaultMessageBody) Len() int {
+func (p *DefaultMessageBody) Len(proto int) int {
 	if p == nil {
 	if p == nil {
 		return 0
 		return 0
 	}
 	}

+ 75 - 0
icmp/mpls.go

@@ -0,0 +1,75 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package icmp
+
+// A MPLSLabel represents a MPLS label stack entry.
+type MPLSLabel struct {
+	Label int  // label value
+	TC    int  // traffic class; formerly experimental use
+	S     bool // bottom of stack
+	TTL   int  // time to live
+}
+
+const (
+	classMPLSLabelStack        = 1
+	typeIncomingMPLSLabelStack = 1
+)
+
+// A MPLSLabelStack represents a MPLS label stack.
+type MPLSLabelStack struct {
+	Class  int // extension object class number
+	Type   int // extension object sub-type
+	Labels []MPLSLabel
+}
+
+// Len implements the Len method of Extension interface.
+func (ls *MPLSLabelStack) Len(proto int) int {
+	return 4 + (4 * len(ls.Labels))
+}
+
+// Marshal implements the Marshal method of Extension interface.
+func (ls *MPLSLabelStack) Marshal(proto int) ([]byte, error) {
+	b := make([]byte, ls.Len(proto))
+	if err := ls.marshal(proto, b); err != nil {
+		return nil, err
+	}
+	return b, nil
+}
+
+func (ls *MPLSLabelStack) marshal(proto int, b []byte) error {
+	l := ls.Len(proto)
+	b[0], b[1] = byte(l>>8), byte(l)
+	b[2], b[3] = classMPLSLabelStack, typeIncomingMPLSLabelStack
+	off := 4
+	for _, ll := range ls.Labels {
+		b[off], b[off+1], b[off+2] = byte(ll.Label>>12), byte(ll.Label>>4&0xff), byte(ll.Label<<4&0xf0)
+		b[off+2] |= byte(ll.TC << 1 & 0x0e)
+		if ll.S {
+			b[off+2] |= 0x1
+		}
+		b[off+3] = byte(ll.TTL)
+		off += 4
+	}
+	return nil
+}
+
+func parseMPLSLabelStack(b []byte) (Extension, error) {
+	ls := &MPLSLabelStack{
+		Class: int(b[2]),
+		Type:  int(b[3]),
+	}
+	for b = b[4:]; len(b) >= 4; b = b[4:] {
+		ll := MPLSLabel{
+			Label: int(b[0])<<12 | int(b[1])<<4 | int(b[2])>>4,
+			TC:    int(b[2]&0x0e) >> 1,
+			TTL:   int(b[3]),
+		}
+		if b[2]&0x1 != 0 {
+			ll.S = true
+		}
+		ls.Labels = append(ls.Labels, ll)
+	}
+	return ls, nil
+}

+ 103 - 0
icmp/multipart.go

@@ -0,0 +1,103 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package icmp
+
+import "golang.org/x/net/internal/iana"
+
+// multipartMessageBodyDataLen takes b as an original datagram and
+// exts as extensions, and returns a required length for message body
+// and a required length for a padded original datagram in wire
+// format.
+func multipartMessageBodyDataLen(proto int, b []byte, exts []Extension) (bodyLen, dataLen int) {
+	for _, ext := range exts {
+		bodyLen += ext.Len(proto)
+	}
+	if bodyLen > 0 {
+		dataLen = multipartMessageOrigDatagramLen(proto, b)
+		bodyLen += 4 // length of extension header
+	} else {
+		dataLen = len(b)
+	}
+	bodyLen += dataLen
+	return bodyLen, dataLen
+}
+
+// multipartMessageOrigDatagramLen takes b as an original datagram,
+// and returns a required length for a padded orignal datagram in wire
+// format.
+func multipartMessageOrigDatagramLen(proto int, b []byte) int {
+	roundup := func(b []byte, align int) int {
+		// According to RFC 4884, the padded original datagram
+		// field must contain at least 128 octets.
+		if len(b) < 128 {
+			return 128
+		}
+		r := len(b)
+		return (r + align) &^ (align - 1)
+	}
+	switch proto {
+	case iana.ProtocolICMP:
+		return roundup(b, 4)
+	case iana.ProtocolIPv6ICMP:
+		return roundup(b, 8)
+	default:
+		return len(b)
+	}
+}
+
+// marshalMultipartMessageBody takes data as an original datagram and
+// exts as extesnsions, and returns a binary encoding of message body.
+// It can be used for non-multipart message bodies when exts is nil.
+func marshalMultipartMessageBody(proto int, data []byte, exts []Extension) ([]byte, error) {
+	bodyLen, dataLen := multipartMessageBodyDataLen(proto, data, exts)
+	b := make([]byte, 4+bodyLen)
+	copy(b[4:], data)
+	off := dataLen + 4
+	if len(exts) > 0 {
+		b[dataLen+4] = byte(extensionVersion << 4)
+		off += 4 // length of object header
+		for _, ext := range exts {
+			switch ext := ext.(type) {
+			case *MPLSLabelStack:
+				if err := ext.marshal(proto, b[off:]); err != nil {
+					return nil, err
+				}
+				off += ext.Len(proto)
+			}
+		}
+		s := checksum(b[dataLen+4:])
+		b[dataLen+4+2] ^= byte(s)
+		b[dataLen+4+3] ^= byte(s >> 8)
+		switch proto {
+		case iana.ProtocolICMP:
+			b[1] = byte(dataLen / 4)
+		case iana.ProtocolIPv6ICMP:
+			b[0] = byte(dataLen / 8)
+		}
+	}
+	return b, nil
+}
+
+// parseMultipartMessageBody parses b as either a non-multipart
+// message body or a multipart message body.
+func parseMultipartMessageBody(proto int, b []byte) ([]byte, []Extension, error) {
+	var l int
+	switch proto {
+	case iana.ProtocolICMP:
+		l = 4 * int(b[1])
+	case iana.ProtocolIPv6ICMP:
+		l = 8 * int(b[0])
+	}
+	if len(b) == 4 {
+		return nil, nil, nil
+	}
+	exts, l, err := parseExtensions(b[4:], l)
+	if err != nil {
+		l = len(b) - 4
+	}
+	data := make([]byte, l)
+	copy(data, b[4:])
+	return data, exts, nil
+}

+ 223 - 0
icmp/multipart_test.go

@@ -0,0 +1,223 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package icmp_test
+
+import (
+	"fmt"
+	"net"
+	"reflect"
+	"testing"
+
+	"golang.org/x/net/icmp"
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
+
+var marshalAndParseMultipartMessageForIPv4Tests = []icmp.Message{
+	{
+		Type: ipv4.ICMPTypeDestinationUnreachable, Code: 15,
+		Body: &icmp.DstUnreach{
+			Data: []byte("ERROR-INVOKING-PACKET"),
+			Extensions: []icmp.Extension{
+				&icmp.MPLSLabelStack{
+					Class: 1,
+					Type:  1,
+					Labels: []icmp.MPLSLabel{
+						{
+							Label: 16014,
+							TC:    0x4,
+							S:     true,
+							TTL:   255,
+						},
+					},
+				},
+			},
+		},
+	},
+	{
+		Type: ipv4.ICMPTypeTimeExceeded, Code: 1,
+		Body: &icmp.TimeExceeded{
+			Data: []byte("ERROR-INVOKING-PACKET"),
+			Extensions: []icmp.Extension{
+				&icmp.MPLSLabelStack{
+					Class: 1,
+					Type:  1,
+					Labels: []icmp.MPLSLabel{
+						{
+							Label: 16014,
+							TC:    0x4,
+							S:     true,
+							TTL:   255,
+						},
+					},
+				},
+			},
+		},
+	},
+	{
+		Type: ipv4.ICMPTypeParameterProblem, Code: 2,
+		Body: &icmp.ParamProb{
+			Pointer: 8,
+			Data:    []byte("ERROR-INVOKING-PACKET"),
+			Extensions: []icmp.Extension{
+				&icmp.MPLSLabelStack{
+					Class: 1,
+					Type:  1,
+					Labels: []icmp.MPLSLabel{
+						{
+							Label: 16014,
+							TC:    0x4,
+							S:     true,
+							TTL:   255,
+						},
+					},
+				},
+			},
+		},
+	},
+}
+
+func TestMarshalAndParseMultipartMessageForIPv4(t *testing.T) {
+	for i, tt := range marshalAndParseMultipartMessageForIPv4Tests {
+		b, err := tt.Marshal(nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if b[5] != 32 {
+			t.Errorf("#%v: got %v; want 32", i, b[5])
+		}
+		m, err := icmp.ParseMessage(iana.ProtocolICMP, b)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if m.Type != tt.Type || m.Code != tt.Code {
+			t.Errorf("#%v: got %v; want %v", i, m, &tt)
+		}
+		switch m.Type {
+		case ipv4.ICMPTypeDestinationUnreachable:
+			got, want := m.Body.(*icmp.DstUnreach), tt.Body.(*icmp.DstUnreach)
+			if !reflect.DeepEqual(got.Extensions, want.Extensions) {
+				t.Errorf("#%v: got %#v; want %#v", i, got.Extensions, want.Extensions)
+			}
+			if len(got.Data) != 128 {
+				t.Errorf("#%v: got %v; want 128", i, len(got.Data))
+			}
+		case ipv4.ICMPTypeTimeExceeded:
+			got, want := m.Body.(*icmp.TimeExceeded), tt.Body.(*icmp.TimeExceeded)
+			if !reflect.DeepEqual(got.Extensions, want.Extensions) {
+				t.Error(dumpExtensions(i, got.Extensions, want.Extensions))
+			}
+			if len(got.Data) != 128 {
+				t.Errorf("#%v: got %v; want 128", i, len(got.Data))
+			}
+		case ipv4.ICMPTypeParameterProblem:
+			got, want := m.Body.(*icmp.ParamProb), tt.Body.(*icmp.ParamProb)
+			if !reflect.DeepEqual(got.Extensions, want.Extensions) {
+				t.Error(dumpExtensions(i, got.Extensions, want.Extensions))
+			}
+			if len(got.Data) != 128 {
+				t.Errorf("#%v: got %v; want 128", i, len(got.Data))
+			}
+		}
+	}
+}
+
+var marshalAndParseMultipartMessageForIPv6Tests = []icmp.Message{
+	{
+		Type: ipv6.ICMPTypeDestinationUnreachable, Code: 6,
+		Body: &icmp.DstUnreach{
+			Data: []byte("ERROR-INVOKING-PACKET"),
+			Extensions: []icmp.Extension{
+				&icmp.MPLSLabelStack{
+					Class: 1,
+					Type:  1,
+					Labels: []icmp.MPLSLabel{
+						{
+							Label: 16014,
+							TC:    0x4,
+							S:     true,
+							TTL:   255,
+						},
+					},
+				},
+			},
+		},
+	},
+	{
+		Type: ipv6.ICMPTypeTimeExceeded, Code: 1,
+		Body: &icmp.TimeExceeded{
+			Data: []byte("ERROR-INVOKING-PACKET"),
+			Extensions: []icmp.Extension{
+				&icmp.MPLSLabelStack{
+					Class: 1,
+					Type:  1,
+					Labels: []icmp.MPLSLabel{
+						{
+							Label: 16014,
+							TC:    0x4,
+							S:     true,
+							TTL:   255,
+						},
+					},
+				},
+			},
+		},
+	},
+}
+
+func TestMarshalAndParseMultipartMessageForIPv6(t *testing.T) {
+	pshicmp := icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1"))
+	for i, tt := range marshalAndParseMultipartMessageForIPv6Tests {
+		for _, psh := range [][]byte{pshicmp, nil} {
+			b, err := tt.Marshal(psh)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if b[4] != 16 {
+				t.Errorf("#%v: got %v; want 16", i, b[4])
+			}
+			m, err := icmp.ParseMessage(iana.ProtocolIPv6ICMP, b)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if m.Type != tt.Type || m.Code != tt.Code {
+				t.Errorf("#%v: got %v; want %v", i, m, &tt)
+			}
+			switch m.Type {
+			case ipv6.ICMPTypeDestinationUnreachable:
+				got, want := m.Body.(*icmp.DstUnreach), tt.Body.(*icmp.DstUnreach)
+				if !reflect.DeepEqual(got.Extensions, want.Extensions) {
+					t.Error(dumpExtensions(i, got.Extensions, want.Extensions))
+				}
+				if len(got.Data) != 128 {
+					t.Errorf("#%v: got %v; want 128", i, len(got.Data))
+				}
+			case ipv6.ICMPTypeTimeExceeded:
+				got, want := m.Body.(*icmp.TimeExceeded), tt.Body.(*icmp.TimeExceeded)
+				if !reflect.DeepEqual(got.Extensions, want.Extensions) {
+					t.Error(dumpExtensions(i, got.Extensions, want.Extensions))
+				}
+				if len(got.Data) != 128 {
+					t.Errorf("#%v: got %v; want 128", i, len(got.Data))
+				}
+			}
+		}
+	}
+}
+
+func dumpExtensions(i int, gotExts, wantExts []icmp.Extension) string {
+	var s string
+	for j, got := range gotExts {
+		switch got := got.(type) {
+		case *icmp.MPLSLabelStack:
+			want := wantExts[j].(*icmp.MPLSLabelStack)
+			if !reflect.DeepEqual(got, want) {
+				s += fmt.Sprintf("#%v/%v: got %#v; want %#v\n", i, j, got, want)
+			}
+		}
+	}
+	return s[:len(s)-1]
+}

+ 2 - 2
icmp/packettoobig.go

@@ -7,11 +7,11 @@ package icmp
 // A PacketTooBig represents an ICMP packet too big message body.
 // A PacketTooBig represents an ICMP packet too big message body.
 type PacketTooBig struct {
 type PacketTooBig struct {
 	MTU  int    // maximum transmission unit of the nexthop link
 	MTU  int    // maximum transmission unit of the nexthop link
-	Data []byte // data
+	Data []byte // data, known as original datagram field
 }
 }
 
 
 // Len implements the Len method of MessageBody interface.
 // Len implements the Len method of MessageBody interface.
-func (p *PacketTooBig) Len() int {
+func (p *PacketTooBig) Len(proto int) int {
 	if p == nil {
 	if p == nil {
 		return 0
 		return 0
 	}
 	}

+ 22 - 17
icmp/paramprob.go

@@ -14,42 +14,47 @@ type ParamProb struct {
 }
 }
 
 
 // Len implements the Len method of MessageBody interface.
 // Len implements the Len method of MessageBody interface.
-func (p *ParamProb) Len() int {
+func (p *ParamProb) Len(proto int) int {
 	if p == nil {
 	if p == nil {
 		return 0
 		return 0
 	}
 	}
-	return 4 + len(p.Data)
+	l, _ := multipartMessageBodyDataLen(proto, p.Data, p.Extensions)
+	return l
 }
 }
 
 
 // Marshal implements the Marshal method of MessageBody interface.
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *ParamProb) Marshal(proto int) ([]byte, error) {
 func (p *ParamProb) Marshal(proto int) ([]byte, error) {
-	b := make([]byte, 4+len(p.Data))
-	switch proto {
-	case iana.ProtocolICMP:
-		b[0] = byte(p.Pointer)
-	case iana.ProtocolIPv6ICMP:
+	if proto == iana.ProtocolIPv6ICMP {
+		b := make([]byte, 4+p.Len(proto))
 		b[0], b[1], b[2], b[3] = byte(p.Pointer>>24), byte(p.Pointer>>16), byte(p.Pointer>>8), byte(p.Pointer)
 		b[0], b[1], b[2], b[3] = byte(p.Pointer>>24), byte(p.Pointer>>16), byte(p.Pointer>>8), byte(p.Pointer)
+		copy(b[4:], p.Data)
+		return b, nil
 	}
 	}
-	copy(b[4:], p.Data)
+	b, err := marshalMultipartMessageBody(proto, p.Data, p.Extensions)
+	if err != nil {
+		return nil, err
+	}
+	b[0] = byte(p.Pointer)
 	return b, nil
 	return b, nil
 }
 }
 
 
 // parseParamProb parses b as an ICMP parameter problem message body.
 // parseParamProb parses b as an ICMP parameter problem message body.
 func parseParamProb(proto int, b []byte) (MessageBody, error) {
 func parseParamProb(proto int, b []byte) (MessageBody, error) {
-	bodyLen := len(b)
-	if bodyLen < 4 {
+	if len(b) < 4 {
 		return nil, errMessageTooShort
 		return nil, errMessageTooShort
 	}
 	}
 	p := &ParamProb{}
 	p := &ParamProb{}
-	switch proto {
-	case iana.ProtocolICMP:
-		p.Pointer = uintptr(b[0])
-	case iana.ProtocolIPv6ICMP:
+	if proto == iana.ProtocolIPv6ICMP {
 		p.Pointer = uintptr(b[0])<<24 | uintptr(b[1])<<16 | uintptr(b[2])<<8 | uintptr(b[3])
 		p.Pointer = uintptr(b[0])<<24 | uintptr(b[1])<<16 | uintptr(b[2])<<8 | uintptr(b[3])
-	}
-	if bodyLen > 4 {
-		p.Data = make([]byte, bodyLen-4)
+		p.Data = make([]byte, len(b)-4)
 		copy(p.Data, b[4:])
 		copy(p.Data, b[4:])
+		return p, nil
+	}
+	p.Pointer = uintptr(b[0])
+	var err error
+	p.Data, p.Extensions, err = parseMultipartMessageBody(proto, b)
+	if err != nil {
+		return nil, err
 	}
 	}
 	return p, nil
 	return p, nil
 }
 }

+ 9 - 10
icmp/timeexceeded.go

@@ -11,30 +11,29 @@ type TimeExceeded struct {
 }
 }
 
 
 // Len implements the Len method of MessageBody interface.
 // Len implements the Len method of MessageBody interface.
-func (p *TimeExceeded) Len() int {
+func (p *TimeExceeded) Len(proto int) int {
 	if p == nil {
 	if p == nil {
 		return 0
 		return 0
 	}
 	}
-	return 4 + len(p.Data)
+	l, _ := multipartMessageBodyDataLen(proto, p.Data, p.Extensions)
+	return l
 }
 }
 
 
 // Marshal implements the Marshal method of MessageBody interface.
 // Marshal implements the Marshal method of MessageBody interface.
 func (p *TimeExceeded) Marshal(proto int) ([]byte, error) {
 func (p *TimeExceeded) Marshal(proto int) ([]byte, error) {
-	b := make([]byte, 4+len(p.Data))
-	copy(b[4:], p.Data)
-	return b, nil
+	return marshalMultipartMessageBody(proto, p.Data, p.Extensions)
 }
 }
 
 
 // parseTimeExceeded parses b as an ICMP time exceeded message body.
 // parseTimeExceeded parses b as an ICMP time exceeded message body.
 func parseTimeExceeded(proto int, b []byte) (MessageBody, error) {
 func parseTimeExceeded(proto int, b []byte) (MessageBody, error) {
-	bodyLen := len(b)
-	if bodyLen < 4 {
+	if len(b) < 4 {
 		return nil, errMessageTooShort
 		return nil, errMessageTooShort
 	}
 	}
 	p := &TimeExceeded{}
 	p := &TimeExceeded{}
-	if bodyLen > 4 {
-		p.Data = make([]byte, bodyLen-4)
-		copy(p.Data, b[4:])
+	var err error
+	p.Data, p.Extensions, err = parseMultipartMessageBody(proto, b)
+	if err != nil {
+		return nil, err
 	}
 	}
 	return p, nil
 	return p, nil
 }
 }