浏览代码

windows/registry: copy latest changes from internal/syscall/registry

This CL includes changes from:

http://golang.org/cl/9805
http://golang.org/cl/9806
http://golang.org/cl/9901

Change-Id: I1f41a8215f9f760c0d3b84596e37bf48bf4c9bc2
Reviewed-on: https://go-review.googlesource.com/10132
Reviewed-by: Rob Pike <r@golang.org>
Alex Brainman 10 年之前
父节点
当前提交
87f732a730
共有 3 个文件被更改,包括 93 次插入3 次删除
  1. 11 0
      windows/registry/export_test.go
  2. 65 0
      windows/registry/registry_test.go
  3. 17 3
      windows/registry/value.go

+ 11 - 0
windows/registry/export_test.go

@@ -0,0 +1,11 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build windows
+
+package registry
+
+func (k Key) SetValue(name string, valtype uint32, data []byte) error {
+	return k.setValue(name, valtype, data)
+}

+ 65 - 0
windows/registry/registry_test.go

@@ -615,3 +615,68 @@ func TestExpandString(t *testing.T) {
 		t.Errorf("want %q string expanded, got %q", want, got)
 		t.Errorf("want %q string expanded, got %q", want, got)
 	}
 	}
 }
 }
+
+func TestInvalidValues(t *testing.T) {
+	softwareK, err := registry.OpenKey(registry.CURRENT_USER, "Software", registry.QUERY_VALUE)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer softwareK.Close()
+
+	testKName := randKeyName("TestInvalidValues_")
+
+	k, exist, err := registry.CreateKey(softwareK, testKName, registry.CREATE_SUB_KEY|registry.QUERY_VALUE|registry.SET_VALUE)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer k.Close()
+
+	if exist {
+		t.Fatalf("key %q already exists", testKName)
+	}
+
+	defer registry.DeleteKey(softwareK, testKName)
+
+	var tests = []struct {
+		Type uint32
+		Name string
+		Data []byte
+	}{
+		{registry.DWORD, "Dword1", nil},
+		{registry.DWORD, "Dword2", []byte{1, 2, 3}},
+		{registry.QWORD, "Qword1", nil},
+		{registry.QWORD, "Qword2", []byte{1, 2, 3}},
+		{registry.QWORD, "Qword3", []byte{1, 2, 3, 4, 5, 6, 7}},
+		{registry.MULTI_SZ, "MultiString1", nil},
+		{registry.MULTI_SZ, "MultiString2", []byte{0}},
+		{registry.MULTI_SZ, "MultiString3", []byte{'a', 'b', 0}},
+		{registry.MULTI_SZ, "MultiString4", []byte{'a', 0, 0, 'b', 0}},
+		{registry.MULTI_SZ, "MultiString5", []byte{'a', 0, 0}},
+	}
+
+	for _, test := range tests {
+		err := k.SetValue(test.Name, test.Type, test.Data)
+		if err != nil {
+			t.Fatalf("SetValue for %q failed: %v", test.Name, err)
+		}
+	}
+
+	for _, test := range tests {
+		switch test.Type {
+		case registry.DWORD, registry.QWORD:
+			value, valType, err := k.GetIntegerValue(test.Name)
+			if err == nil {
+				t.Errorf("GetIntegerValue(%q) succeeded. Returns type=%d value=%v", test.Name, valType, value)
+			}
+		case registry.MULTI_SZ:
+			value, valType, err := k.GetStringsValue(test.Name)
+			if err == nil {
+				if len(value) != 0 {
+					t.Errorf("GetStringsValue(%q) succeeded. Returns type=%d value=%v", test.Name, valType, value)
+				}
+			}
+		default:
+			t.Errorf("unsupported type %d for %s value", test.Type, test.Name)
+		}
+	}
+}

+ 17 - 3
windows/registry/value.go

@@ -130,7 +130,7 @@ func ExpandString(value string) (string, error) {
 			return "", err
 			return "", err
 		}
 		}
 		if n <= uint32(len(r)) {
 		if n <= uint32(len(r)) {
-			u := (*[1 << 10]uint16)(unsafe.Pointer(&r[0]))[:]
+			u := (*[1 << 15]uint16)(unsafe.Pointer(&r[0]))[:]
 			return syscall.UTF16ToString(u), nil
 			return syscall.UTF16ToString(u), nil
 		}
 		}
 		r = make([]uint16, n)
 		r = make([]uint16, n)
@@ -150,9 +150,17 @@ func (k Key) GetStringsValue(name string) (val []string, valtype uint32, err err
 	if typ != MULTI_SZ {
 	if typ != MULTI_SZ {
 		return nil, typ, ErrUnexpectedType
 		return nil, typ, ErrUnexpectedType
 	}
 	}
-	val = make([]string, 0, 5)
+	if len(data) == 0 {
+		return nil, typ, nil
+	}
 	p := (*[1 << 24]uint16)(unsafe.Pointer(&data[0]))[:len(data)/2]
 	p := (*[1 << 24]uint16)(unsafe.Pointer(&data[0]))[:len(data)/2]
-	p = p[:len(p)-1] // remove terminating nil
+	if len(p) == 0 {
+		return nil, typ, nil
+	}
+	if p[len(p)-1] == 0 {
+		p = p[:len(p)-1] // remove terminating null
+	}
+	val = make([]string, 0, 5)
 	from := 0
 	from := 0
 	for i, c := range p {
 	for i, c := range p {
 		if c == 0 {
 		if c == 0 {
@@ -175,8 +183,14 @@ func (k Key) GetIntegerValue(name string) (val uint64, valtype uint32, err error
 	}
 	}
 	switch typ {
 	switch typ {
 	case DWORD:
 	case DWORD:
+		if len(data) != 4 {
+			return 0, typ, errors.New("DWORD value is not 4 bytes long")
+		}
 		return uint64(*(*uint32)(unsafe.Pointer(&data[0]))), DWORD, nil
 		return uint64(*(*uint32)(unsafe.Pointer(&data[0]))), DWORD, nil
 	case QWORD:
 	case QWORD:
+		if len(data) != 8 {
+			return 0, typ, errors.New("QWORD value is not 8 bytes long")
+		}
 		return uint64(*(*uint64)(unsafe.Pointer(&data[0]))), QWORD, nil
 		return uint64(*(*uint64)(unsafe.Pointer(&data[0]))), QWORD, nil
 	default:
 	default:
 		return 0, typ, ErrUnexpectedType
 		return 0, typ, ErrUnexpectedType