Browse Source

windows: add LoadLibraryEx, add LazyDLL.System

Updates golang/go#14959

Change-Id: Ib91c359c3df919df0b30e584d38e56f79f3e3dc9
Reviewed-on: https://go-review.googlesource.com/21388
Reviewed-by: Russ Cox <rsc@golang.org>
Reviewed-by: Alex Brainman <alex.brainman@gmail.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Brad Fitzpatrick 9 years ago
parent
commit
3dff6e19a5

+ 108 - 16
windows/dll_windows.go

@@ -31,6 +31,10 @@ type DLL struct {
 }
 
 // LoadDLL loads DLL file into memory.
+//
+// Warning: using LoadDLL without an absolute path name is subject to
+// DLL preloading attacks. To safely load a system DLL, use LazyDLL
+// with System set to true, or use LoadLibraryEx directly.
 func LoadDLL(name string) (dll *DLL, err error) {
 	namep, err := UTF16PtrFromString(name)
 	if err != nil {
@@ -162,29 +166,48 @@ func (p *Proc) Call(a ...uintptr) (r1, r2 uintptr, lastErr error) {
 // call to its Handle method or to one of its
 // LazyProc's Addr method.
 type LazyDLL struct {
-	mu   sync.Mutex
-	dll  *DLL // non nil once DLL is loaded
 	Name string
+
+	// System determines whether the DLL must be loaded from the
+	// Windows System directory, bypassing the normal DLL search
+	// path.
+	System bool
+
+	mu  sync.Mutex
+	dll *DLL // non nil once DLL is loaded
 }
 
 // Load loads DLL file d.Name into memory. It returns an error if fails.
 // Load will not try to load DLL, if it is already loaded into memory.
 func (d *LazyDLL) Load() error {
 	// Non-racy version of:
-	// if d.dll == nil {
-	if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.dll))) == nil {
-		d.mu.Lock()
-		defer d.mu.Unlock()
-		if d.dll == nil {
-			dll, e := LoadDLL(d.Name)
-			if e != nil {
-				return e
-			}
-			// Non-racy version of:
-			// d.dll = dll
-			atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.dll)), unsafe.Pointer(dll))
-		}
+	// if d.dll != nil {
+	if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.dll))) != nil {
+		return nil
 	}
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	if d.dll != nil {
+		return nil
+	}
+
+	// kernel32.dll is special, since it's where LoadLibraryEx comes from.
+	// The kernel already special-cases its name, so it's always
+	// loaded from system32.
+	var dll *DLL
+	var err error
+	if d.Name == "kernel32.dll" {
+		dll, err = LoadDLL(d.Name)
+	} else {
+		dll, err = loadLibraryEx(d.Name, d.System)
+	}
+	if err != nil {
+		return err
+	}
+
+	// Non-racy version of:
+	// d.dll = dll
+	atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.dll)), unsafe.Pointer(dll))
 	return nil
 }
 
@@ -215,8 +238,9 @@ func NewLazyDLL(name string) *LazyDLL {
 // A LazyProc implements access to a procedure inside a LazyDLL.
 // It delays the lookup until the Addr method is called.
 type LazyProc struct {
-	mu   sync.Mutex
 	Name string
+
+	mu   sync.Mutex
 	l    *LazyDLL
 	proc *Proc
 }
@@ -273,3 +297,71 @@ func (p *LazyProc) Call(a ...uintptr) (r1, r2 uintptr, lastErr error) {
 	p.mustFind()
 	return p.proc.Call(a...)
 }
+
+var canDoSearchSystem32Once struct {
+	sync.Once
+	v bool
+}
+
+func initCanDoSearchSystem32() {
+	// https://msdn.microsoft.com/en-us/library/ms684179(v=vs.85).aspx says:
+	// "Windows 7, Windows Server 2008 R2, Windows Vista, and Windows
+	// Server 2008: The LOAD_LIBRARY_SEARCH_* flags are available on
+	// systems that have KB2533623 installed. To determine whether the
+	// flags are available, use GetProcAddress to get the address of the
+	// AddDllDirectory, RemoveDllDirectory, or SetDefaultDllDirectories
+	// function. If GetProcAddress succeeds, the LOAD_LIBRARY_SEARCH_*
+	// flags can be used with LoadLibraryEx."
+	canDoSearchSystem32Once.v = (modkernel32.NewProc("AddDllDirectory").Find() == nil)
+}
+
+func canDoSearchSystem32() bool {
+	canDoSearchSystem32Once.Do(initCanDoSearchSystem32)
+	return canDoSearchSystem32Once.v
+}
+
+func isBaseName(name string) bool {
+	for _, c := range name {
+		if c == ':' || c == '/' || c == '\\' {
+			return false
+		}
+	}
+	return true
+}
+
+// loadLibraryEx wraps the Windows LoadLibraryEx function.
+//
+// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms684179(v=vs.85).aspx
+//
+// If name is not an absolute path, LoadLibraryEx searches for the DLL
+// in a variety of automatic locations unless constrained by flags.
+// See: https://msdn.microsoft.com/en-us/library/ff919712%28VS.85%29.aspx
+func loadLibraryEx(name string, system bool) (*DLL, error) {
+	loadDLL := name
+	var flags uintptr
+	if system {
+		if canDoSearchSystem32() {
+			const LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
+			flags = LOAD_LIBRARY_SEARCH_SYSTEM32
+		} else if isBaseName(name) {
+			// WindowsXP or unpatched Windows machine
+			// trying to load "foo.dll" out of the system
+			// folder, but LoadLibraryEx doesn't support
+			// that yet on their system, so emulate it.
+			windir, _ := Getenv("WINDIR") // old var; apparently works on XP
+			if windir == "" {
+				return nil, errString("%WINDIR% not defined")
+			}
+			loadDLL = windir + "\\System32\\" + name
+		}
+	}
+	h, err := LoadLibraryEx(loadDLL, 0, flags)
+	if err != nil {
+		return nil, err
+	}
+	return &DLL{Name: name, Handle: h}, nil
+}
+
+type errString string
+
+func (s errString) Error() string { return string(s) }

+ 1 - 1
windows/registry/syscall.go

@@ -8,7 +8,7 @@ package registry
 
 import "syscall"
 
-//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go syscall.go
+//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -xsys -output zsyscall_windows.go syscall.go
 
 const (
 	_REG_OPTION_NON_VOLATILE = 0

+ 7 - 4
windows/registry/zsyscall_windows.go

@@ -2,14 +2,17 @@
 
 package registry
 
-import "unsafe"
-import "syscall"
+import (
+	"golang.org/x/sys/windows"
+	"syscall"
+	"unsafe"
+)
 
 var _ unsafe.Pointer
 
 var (
-	modadvapi32 = syscall.NewLazyDLL("advapi32.dll")
-	modkernel32 = syscall.NewLazyDLL("kernel32.dll")
+	modadvapi32 = &windows.LazyDLL{Name: "advapi32.dll", System: true}
+	modkernel32 = &windows.LazyDLL{Name: "kernel32.dll", System: true}
 
 	procRegCreateKeyExW           = modadvapi32.NewProc("RegCreateKeyExW")
 	procRegDeleteKeyW             = modadvapi32.NewProc("RegDeleteKeyW")

+ 2 - 1
windows/syscall_windows.go

@@ -14,7 +14,7 @@ import (
 	"unsafe"
 )
 
-//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go eventlog.go service.go syscall_windows.go security_windows.go
+//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -xsys -output zsyscall_windows.go eventlog.go service.go syscall_windows.go security_windows.go
 
 type Handle uintptr
 
@@ -84,6 +84,7 @@ func NewCallbackCDecl(fn interface{}) uintptr
 
 //sys	GetLastError() (lasterr error)
 //sys	LoadLibrary(libname string) (handle Handle, err error) = LoadLibraryW
+//sys	LoadLibraryEx(libname string, zero Handle, flags uintptr) (handle Handle, err error) = LoadLibraryExW
 //sys	FreeLibrary(handle Handle) (err error)
 //sys	GetProcAddress(module Handle, procname string) (proc uintptr, err error)
 //sys	GetVersion() (ver uint32, err error)

+ 38 - 13
windows/zsyscall_windows.go

@@ -2,23 +2,25 @@
 
 package windows
 
-import "unsafe"
-import "syscall"
+import (
+	"syscall"
+	"unsafe"
+)
 
 var _ unsafe.Pointer
 
 var (
-	modadvapi32 = syscall.NewLazyDLL("advapi32.dll")
-	modkernel32 = syscall.NewLazyDLL("kernel32.dll")
-	modshell32  = syscall.NewLazyDLL("shell32.dll")
-	modmswsock  = syscall.NewLazyDLL("mswsock.dll")
-	modcrypt32  = syscall.NewLazyDLL("crypt32.dll")
-	modws2_32   = syscall.NewLazyDLL("ws2_32.dll")
-	moddnsapi   = syscall.NewLazyDLL("dnsapi.dll")
-	modiphlpapi = syscall.NewLazyDLL("iphlpapi.dll")
-	modsecur32  = syscall.NewLazyDLL("secur32.dll")
-	modnetapi32 = syscall.NewLazyDLL("netapi32.dll")
-	moduserenv  = syscall.NewLazyDLL("userenv.dll")
+	modadvapi32 = &LazyDLL{Name: "advapi32.dll", System: true}
+	modkernel32 = &LazyDLL{Name: "kernel32.dll", System: true}
+	modshell32  = &LazyDLL{Name: "shell32.dll", System: true}
+	modmswsock  = &LazyDLL{Name: "mswsock.dll", System: true}
+	modcrypt32  = &LazyDLL{Name: "crypt32.dll", System: true}
+	modws2_32   = &LazyDLL{Name: "ws2_32.dll", System: true}
+	moddnsapi   = &LazyDLL{Name: "dnsapi.dll", System: true}
+	modiphlpapi = &LazyDLL{Name: "iphlpapi.dll", System: true}
+	modsecur32  = &LazyDLL{Name: "secur32.dll", System: true}
+	modnetapi32 = &LazyDLL{Name: "netapi32.dll", System: true}
+	moduserenv  = &LazyDLL{Name: "userenv.dll", System: true}
 
 	procRegisterEventSourceW               = modadvapi32.NewProc("RegisterEventSourceW")
 	procDeregisterEventSource              = modadvapi32.NewProc("DeregisterEventSource")
@@ -39,6 +41,7 @@ var (
 	procQueryServiceConfig2W               = modadvapi32.NewProc("QueryServiceConfig2W")
 	procGetLastError                       = modkernel32.NewProc("GetLastError")
 	procLoadLibraryW                       = modkernel32.NewProc("LoadLibraryW")
+	procLoadLibraryExW                     = modkernel32.NewProc("LoadLibraryExW")
 	procFreeLibrary                        = modkernel32.NewProc("FreeLibrary")
 	procGetProcAddress                     = modkernel32.NewProc("GetProcAddress")
 	procGetVersion                         = modkernel32.NewProc("GetVersion")
@@ -430,6 +433,28 @@ func _LoadLibrary(libname *uint16) (handle Handle, err error) {
 	return
 }
 
+func LoadLibraryEx(libname string, zero Handle, flags uintptr) (handle Handle, err error) {
+	var _p0 *uint16
+	_p0, err = syscall.UTF16PtrFromString(libname)
+	if err != nil {
+		return
+	}
+	return _LoadLibraryEx(_p0, zero, flags)
+}
+
+func _LoadLibraryEx(libname *uint16, zero Handle, flags uintptr) (handle Handle, err error) {
+	r0, _, e1 := syscall.Syscall(procLoadLibraryExW.Addr(), 3, uintptr(unsafe.Pointer(libname)), uintptr(zero), uintptr(flags))
+	handle = Handle(r0)
+	if handle == 0 {
+		if e1 != 0 {
+			err = error(e1)
+		} else {
+			err = syscall.EINVAL
+		}
+	}
+	return
+}
+
 func FreeLibrary(handle Handle) (err error) {
 	r1, _, e1 := syscall.Syscall(procFreeLibrary.Addr(), 1, uintptr(handle), 0, 0)
 	if r1 == 0 {