Selaa lähdekoodia

Address a few more OOB slice read issues

address oob slice read issues
Becca Petrin 6 vuotta sitten
vanhempi
commit
aa384b608e
4 muutettua tiedostoa jossa 147 lisäystä ja 17 poistoa
  1. 93 17
      keytab/keytab.go
  2. 35 0
      keytab/keytab_test.go
  3. 4 0
      spnego/spnego.go
  4. 15 0
      spnego/spnego_test.go

+ 93 - 17
keytab/keytab.go

@@ -141,6 +141,10 @@ func (kt *Keytab) Write(w io.Writer) (int, error) {
 
 
 // Unmarshal byte slice of Keytab data into Keytab type.
 // Unmarshal byte slice of Keytab data into Keytab type.
 func (kt *Keytab) Unmarshal(b []byte) error {
 func (kt *Keytab) Unmarshal(b []byte) error {
+	if len(b) < 2 {
+		return fmt.Errorf("byte array is less than 2 bytes: %d", len(b))
+	}
+
 	//The first byte of the file always has the value 5
 	//The first byte of the file always has the value 5
 	if b[0] != keytabFirstByte {
 	if b[0] != keytabFirstByte {
 		return errors.New("invalid keytab data. First byte does not equal 5")
 		return errors.New("invalid keytab data. First byte does not equal 5")
@@ -165,7 +169,10 @@ func (kt *Keytab) Unmarshal(b []byte) error {
 	*/
 	*/
 	// n tracks position in the byte array
 	// n tracks position in the byte array
 	n := 2
 	n := 2
-	l := readInt32(b, &n, &endian)
+	l, err := readInt32(b, &n, &endian)
+	if err != nil {
+		return err
+	}
 	for l != 0 {
 	for l != 0 {
 		if l < 0 {
 		if l < 0 {
 			//Zero padded so skip over
 			//Zero padded so skip over
@@ -173,6 +180,12 @@ func (kt *Keytab) Unmarshal(b []byte) error {
 			n = n + int(l)
 			n = n + int(l)
 		} else {
 		} else {
 			//fmt.Printf("Bytes for entry: %v\n", b[n:n+int(l)])
 			//fmt.Printf("Bytes for entry: %v\n", b[n:n+int(l)])
+			if n < 0 {
+				return fmt.Errorf("%d can't be less than zero", n)
+			}
+			if n+int(l) > len(b) {
+				return fmt.Errorf("%s's length is less than %d", b, n+int(l))
+			}
 			eb := b[n : n+int(l)]
 			eb := b[n : n+int(l)]
 			n = n + int(l)
 			n = n + int(l)
 			ke := newKeytabEntry()
 			ke := newKeytabEntry()
@@ -180,10 +193,25 @@ func (kt *Keytab) Unmarshal(b []byte) error {
 			var p int
 			var p int
 			var err error
 			var err error
 			parsePrincipal(eb, &p, kt, &ke, &endian)
 			parsePrincipal(eb, &p, kt, &ke, &endian)
-			ke.Timestamp = readTimestamp(eb, &p, &endian)
-			ke.KVNO8 = uint8(readInt8(eb, &p, &endian))
-			ke.Key.KeyType = int32(readInt16(eb, &p, &endian))
-			kl := int(readInt16(eb, &p, &endian))
+			ke.Timestamp, err = readTimestamp(eb, &p, &endian)
+			if err != nil {
+				return err
+			}
+			rei8, err := readInt8(eb, &p, &endian)
+			if err != nil {
+				return err
+			}
+			ke.KVNO8 = uint8(rei8)
+			rei16, err := readInt16(eb, &p, &endian)
+			if err != nil {
+				return err
+			}
+			ke.Key.KeyType = int32(rei16)
+			rei16, err = readInt16(eb, &p, &endian)
+			if err != nil {
+				return err
+			}
+			kl := int(rei16)
 			ke.Key.KeyValue, err = readBytes(eb, &p, kl, &endian)
 			ke.Key.KeyValue, err = readBytes(eb, &p, kl, &endian)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
@@ -193,7 +221,11 @@ func (kt *Keytab) Unmarshal(b []byte) error {
 			// and that the value of the 32-bit integer contained in those bytes is non-zero.
 			// and that the value of the 32-bit integer contained in those bytes is non-zero.
 			if len(eb)-p >= 4 {
 			if len(eb)-p >= 4 {
 				// The 32-bit key may be present
 				// The 32-bit key may be present
-				ke.KVNO = uint32(readInt32(eb, &p, &endian))
+				ri32, err := readInt32(eb, &p, &endian)
+				if err != nil {
+					return err
+				}
+				ke.KVNO = uint32(ri32)
 			}
 			}
 			if ke.KVNO == 0 {
 			if ke.KVNO == 0 {
 				// Handles if the value from the last 4 bytes was zero and also if there are not the 4 bytes present. Makes sense to put the same value here as KVNO8
 				// Handles if the value from the last 4 bytes was zero and also if there are not the 4 bytes present. Makes sense to put the same value here as KVNO8
@@ -203,11 +235,15 @@ func (kt *Keytab) Unmarshal(b []byte) error {
 			kt.Entries = append(kt.Entries, ke)
 			kt.Entries = append(kt.Entries, ke)
 		}
 		}
 		// Check if there are still 4 bytes left to read
 		// Check if there are still 4 bytes left to read
-		if n > len(b) || len(b[n:]) < 4 {
+		// Also check that n is greater than zero
+		if n < 0 || n > len(b) || len(b[n:]) < 4 {
 			break
 			break
 		}
 		}
 		// Read the size of the next entry
 		// Read the size of the next entry
-		l = readInt32(b, &n, &endian)
+		l, err = readInt32(b, &n, &endian)
+		if err != nil {
+			return err
+		}
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -253,19 +289,29 @@ func (e entry) marshal(v int) ([]byte, error) {
 
 
 // Parse the Keytab bytes of a principal into a Keytab entry's principal.
 // Parse the Keytab bytes of a principal into a Keytab entry's principal.
 func parsePrincipal(b []byte, p *int, kt *Keytab, ke *entry, e *binary.ByteOrder) error {
 func parsePrincipal(b []byte, p *int, kt *Keytab, ke *entry, e *binary.ByteOrder) error {
-	ke.Principal.NumComponents = readInt16(b, p, e)
+	var err error
+	ke.Principal.NumComponents, err = readInt16(b, p, e)
+	if err != nil {
+		return err
+	}
 	if kt.version == 1 {
 	if kt.version == 1 {
 		//In version 1 the number of components includes the realm. Minus 1 to make consistent with version 2
 		//In version 1 the number of components includes the realm. Minus 1 to make consistent with version 2
 		ke.Principal.NumComponents--
 		ke.Principal.NumComponents--
 	}
 	}
-	lenRealm := readInt16(b, p, e)
+	lenRealm, err := readInt16(b, p, e)
+	if err != nil {
+		return err
+	}
 	realmB, err := readBytes(b, p, int(lenRealm), e)
 	realmB, err := readBytes(b, p, int(lenRealm), e)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	ke.Principal.Realm = string(realmB)
 	ke.Principal.Realm = string(realmB)
 	for i := 0; i < int(ke.Principal.NumComponents); i++ {
 	for i := 0; i < int(ke.Principal.NumComponents); i++ {
-		l := readInt16(b, p, e)
+		l, err := readInt16(b, p, e)
+		if err != nil {
+			return err
+		}
 		compB, err := readBytes(b, p, int(l), e)
 		compB, err := readBytes(b, p, int(l), e)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
@@ -274,7 +320,10 @@ func parsePrincipal(b []byte, p *int, kt *Keytab, ke *entry, e *binary.ByteOrder
 	}
 	}
 	if kt.version != 1 {
 	if kt.version != 1 {
 		//Name Type is omitted in version 1
 		//Name Type is omitted in version 1
-		ke.Principal.NameType = readInt32(b, p, e)
+		ke.Principal.NameType, err = readInt32(b, p, e)
+		if err != nil {
+			return err
+		}
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -327,12 +376,23 @@ func marshalString(s string, v int) ([]byte, error) {
 }
 }
 
 
 // Read bytes representing a timestamp.
 // Read bytes representing a timestamp.
-func readTimestamp(b []byte, p *int, e *binary.ByteOrder) time.Time {
-	return time.Unix(int64(readInt32(b, p, e)), 0)
+func readTimestamp(b []byte, p *int, e *binary.ByteOrder) (time.Time, error) {
+	i32, err := readInt32(b, p, e)
+	if err != nil {
+		return time.Time{}, err
+	}
+	return time.Unix(int64(i32), 0), nil
 }
 }
 
 
 // Read bytes representing an eight bit integer.
 // Read bytes representing an eight bit integer.
-func readInt8(b []byte, p *int, e *binary.ByteOrder) (i int8) {
+func readInt8(b []byte, p *int, e *binary.ByteOrder) (i int8, err error) {
+	if *p < 0 {
+		return 0, fmt.Errorf("%d cannot be less than zero", *p)
+	}
+
+	if (*p + 1) > len(b) {
+		return 0, fmt.Errorf("%s's length is less than %d", b, *p+1)
+	}
 	buf := bytes.NewBuffer(b[*p : *p+1])
 	buf := bytes.NewBuffer(b[*p : *p+1])
 	binary.Read(buf, *e, &i)
 	binary.Read(buf, *e, &i)
 	*p++
 	*p++
@@ -340,7 +400,15 @@ func readInt8(b []byte, p *int, e *binary.ByteOrder) (i int8) {
 }
 }
 
 
 // Read bytes representing a sixteen bit integer.
 // Read bytes representing a sixteen bit integer.
-func readInt16(b []byte, p *int, e *binary.ByteOrder) (i int16) {
+func readInt16(b []byte, p *int, e *binary.ByteOrder) (i int16, err error) {
+	if *p < 0 {
+		return 0, fmt.Errorf("%d cannot be less than zero", *p)
+	}
+
+	if (*p + 2) > len(b) {
+		return 0, fmt.Errorf("%s's length is less than %d", b, *p+2)
+	}
+
 	buf := bytes.NewBuffer(b[*p : *p+2])
 	buf := bytes.NewBuffer(b[*p : *p+2])
 	binary.Read(buf, *e, &i)
 	binary.Read(buf, *e, &i)
 	*p += 2
 	*p += 2
@@ -348,7 +416,15 @@ func readInt16(b []byte, p *int, e *binary.ByteOrder) (i int16) {
 }
 }
 
 
 // Read bytes representing a thirty two bit integer.
 // Read bytes representing a thirty two bit integer.
-func readInt32(b []byte, p *int, e *binary.ByteOrder) (i int32) {
+func readInt32(b []byte, p *int, e *binary.ByteOrder) (i int32, err error) {
+	if *p < 0 {
+		return 0, fmt.Errorf("%d cannot be less than zero", *p)
+	}
+
+	if (*p + 4) > len(b) {
+		return 0, fmt.Errorf("%s's length is less than %d", b, *p+4)
+	}
+
 	buf := bytes.NewBuffer(b[*p : *p+4])
 	buf := bytes.NewBuffer(b[*p : *p+4])
 	binary.Read(buf, *e, &i)
 	binary.Read(buf, *e, &i)
 	*p += 4
 	*p += 4

+ 35 - 0
keytab/keytab_test.go

@@ -1,6 +1,7 @@
 package keytab
 package keytab
 
 
 import (
 import (
+	"encoding/base64"
 	"encoding/binary"
 	"encoding/binary"
 	"encoding/hex"
 	"encoding/hex"
 	"os"
 	"os"
@@ -109,3 +110,37 @@ func TestReadBytes(t *testing.T) {
 		t.Fatal("err should be given because negative s was given")
 		t.Fatal("err should be given because negative s was given")
 	}
 	}
 }
 }
+
+func TestUnmarshalPotentialPanics(t *testing.T) {
+	kt := New()
+
+	// Test a good keytab with bad bytes to unmarshal. These should
+	// return errors, but not panic.
+	if err := kt.Unmarshal(nil); err == nil {
+		t.Fatal("should have errored, input is absent")
+	}
+	if err := kt.Unmarshal([]byte{}); err == nil {
+		t.Fatal("should have errored, input is empty")
+	}
+	// Incorrect first byte.
+	if err := kt.Unmarshal([]byte{4}); err == nil {
+		t.Fatal("should have errored, input isn't long enough")
+	}
+	// First byte, but no further content.
+	if err := kt.Unmarshal([]byte{5}); err == nil {
+		t.Fatal("should have errored, input isn't long enough")
+	}
+}
+
+// cxf testing stuff
+func TestBadKeytabs(t *testing.T) {
+	badPayloads := make([]string, 3)
+	badPayloads = append(badPayloads, "BQIwMDAwMDA=")
+	badPayloads = append(badPayloads, "BQIAAAAwAAEACjAwMDAwMDAwMDAAIDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAw")
+	badPayloads = append(badPayloads, "BQKAAAAA")
+	for _, v := range badPayloads {
+		decodedKt, _ := base64.StdEncoding.DecodeString(v)
+		parsedKt := new(Keytab)
+		parsedKt.Unmarshal(decodedKt)
+	}
+}

+ 4 - 0
spnego/spnego.go

@@ -132,6 +132,10 @@ func (s *SPNEGOToken) Marshal() ([]byte, error) {
 func (s *SPNEGOToken) Unmarshal(b []byte) error {
 func (s *SPNEGOToken) Unmarshal(b []byte) error {
 	var r []byte
 	var r []byte
 	var err error
 	var err error
+	// We need some data in the array
+	if len(b) < 1 {
+		return fmt.Errorf("provided byte array is empty")
+	}
 	if b[0] != byte(161) {
 	if b[0] != byte(161) {
 		// Not a NegTokenResp/Targ could be a NegTokenInit
 		// Not a NegTokenResp/Targ could be a NegTokenInit
 		var oid asn1.ObjectIdentifier
 		var oid asn1.ObjectIdentifier

+ 15 - 0
spnego/spnego_test.go

@@ -37,6 +37,21 @@ func TestUnmarshal_SPNEGO_Init(t *testing.T) {
 	assert.NotZero(t, len(s.NegTokenInit.MechTokenBytes), "MechToken is zero in length")
 	assert.NotZero(t, len(s.NegTokenInit.MechTokenBytes), "MechToken is zero in length")
 }
 }
 
 
+func TestUnMarshal_SPNEGO_Empty(t *testing.T) {
+	sp := new(SPNEGOToken)
+
+	// The following tests are intended to ensure we don't panic.
+	if err := sp.Unmarshal(nil); err == nil {
+		t.Fatal("should have errored, input is absent")
+	}
+	if err := sp.Unmarshal([]byte{}); err == nil {
+		t.Fatal("should have errored, input is empty")
+	}
+	if err := sp.Unmarshal([]byte{1}); err == nil {
+		t.Fatal("should have errored, input is too low")
+	}
+}
+
 func TestUnmarshal_SPNEGO_RespTarg(t *testing.T) {
 func TestUnmarshal_SPNEGO_RespTarg(t *testing.T) {
 	t.Parallel()
 	t.Parallel()
 	b, err := hex.DecodeString(testGSSAPIResp)
 	b, err := hex.DecodeString(testGSSAPIResp)