瀏覽代碼

unix: drop dummy byte for oob in unixgram SendmsgN

This commit relaxes SendmsgN behavior of introducing a dummy 1-byte
payload when sending ancillary-only messages.
The fake payload is not needed for SOCK_DGRAM type sockets, and actually
breaks interoperability with other fd-passing software (journald is one
known example). This introduces an additional check to avoid injecting
dummy payload in such case.

Backport of https://go.googlesource.com/go/+/93da0b6e66f24c4c307e0df37ceb102a33306174
Full reference at https:/golang.org/issue/6476#issue-51285243

Change-Id: I7cf00a1c7cde75fe62e00b98ccba6ac8469b0493
Reviewed-on: https://go-review.googlesource.com/60190
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Luca Bruno 8 年之前
父節點
當前提交
bb24a47a89
共有 2 個文件被更改,包括 111 次插入86 次删除
  1. 99 84
      unix/creds_test.go
  2. 12 2
      unix/syscall_linux.go

+ 99 - 84
unix/creds_test.go

@@ -21,101 +21,116 @@ import (
 // sockets. The SO_PASSCRED socket option is enabled on the sending
 // socket for this to work.
 func TestSCMCredentials(t *testing.T) {
-	fds, err := unix.Socketpair(unix.AF_LOCAL, unix.SOCK_STREAM, 0)
-	if err != nil {
-		t.Fatalf("Socketpair: %v", err)
+	socketTypeTests := []struct {
+		socketType int
+		dataLen    int
+	}{
+		{
+			unix.SOCK_STREAM,
+			1,
+		}, {
+			unix.SOCK_DGRAM,
+			0,
+		},
 	}
-	defer unix.Close(fds[0])
-	defer unix.Close(fds[1])
 
-	err = unix.SetsockoptInt(fds[0], unix.SOL_SOCKET, unix.SO_PASSCRED, 1)
-	if err != nil {
-		t.Fatalf("SetsockoptInt: %v", err)
-	}
+	for _, tt := range socketTypeTests {
+		fds, err := unix.Socketpair(unix.AF_LOCAL, tt.socketType, 0)
+		if err != nil {
+			t.Fatalf("Socketpair: %v", err)
+		}
+		defer unix.Close(fds[0])
+		defer unix.Close(fds[1])
 
-	srvFile := os.NewFile(uintptr(fds[0]), "server")
-	defer srvFile.Close()
-	srv, err := net.FileConn(srvFile)
-	if err != nil {
-		t.Errorf("FileConn: %v", err)
-		return
-	}
-	defer srv.Close()
+		err = unix.SetsockoptInt(fds[0], unix.SOL_SOCKET, unix.SO_PASSCRED, 1)
+		if err != nil {
+			t.Fatalf("SetsockoptInt: %v", err)
+		}
 
-	cliFile := os.NewFile(uintptr(fds[1]), "client")
-	defer cliFile.Close()
-	cli, err := net.FileConn(cliFile)
-	if err != nil {
-		t.Errorf("FileConn: %v", err)
-		return
-	}
-	defer cli.Close()
+		srvFile := os.NewFile(uintptr(fds[0]), "server")
+		defer srvFile.Close()
+		srv, err := net.FileConn(srvFile)
+		if err != nil {
+			t.Errorf("FileConn: %v", err)
+			return
+		}
+		defer srv.Close()
+
+		cliFile := os.NewFile(uintptr(fds[1]), "client")
+		defer cliFile.Close()
+		cli, err := net.FileConn(cliFile)
+		if err != nil {
+			t.Errorf("FileConn: %v", err)
+			return
+		}
+		defer cli.Close()
+
+		var ucred unix.Ucred
+		if os.Getuid() != 0 {
+			ucred.Pid = int32(os.Getpid())
+			ucred.Uid = 0
+			ucred.Gid = 0
+			oob := unix.UnixCredentials(&ucred)
+			_, _, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
+			if op, ok := err.(*net.OpError); ok {
+				err = op.Err
+			}
+			if sys, ok := err.(*os.SyscallError); ok {
+				err = sys.Err
+			}
+			if err != syscall.EPERM {
+				t.Fatalf("WriteMsgUnix failed with %v, want EPERM", err)
+			}
+		}
 
-	var ucred unix.Ucred
-	if os.Getuid() != 0 {
 		ucred.Pid = int32(os.Getpid())
-		ucred.Uid = 0
-		ucred.Gid = 0
+		ucred.Uid = uint32(os.Getuid())
+		ucred.Gid = uint32(os.Getgid())
 		oob := unix.UnixCredentials(&ucred)
-		_, _, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
-		if op, ok := err.(*net.OpError); ok {
-			err = op.Err
+
+		// On SOCK_STREAM, this is internally going to send a dummy byte
+		n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
+		if err != nil {
+			t.Fatalf("WriteMsgUnix: %v", err)
 		}
-		if sys, ok := err.(*os.SyscallError); ok {
-			err = sys.Err
+		if n != 0 {
+			t.Fatalf("WriteMsgUnix n = %d, want 0", n)
 		}
-		if err != syscall.EPERM {
-			t.Fatalf("WriteMsgUnix failed with %v, want EPERM", err)
+		if oobn != len(oob) {
+			t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob))
 		}
-	}
-
-	ucred.Pid = int32(os.Getpid())
-	ucred.Uid = uint32(os.Getuid())
-	ucred.Gid = uint32(os.Getgid())
-	oob := unix.UnixCredentials(&ucred)
-
-	// this is going to send a dummy byte
-	n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
-	if err != nil {
-		t.Fatalf("WriteMsgUnix: %v", err)
-	}
-	if n != 0 {
-		t.Fatalf("WriteMsgUnix n = %d, want 0", n)
-	}
-	if oobn != len(oob) {
-		t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob))
-	}
 
-	oob2 := make([]byte, 10*len(oob))
-	n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2)
-	if err != nil {
-		t.Fatalf("ReadMsgUnix: %v", err)
-	}
-	if flags != 0 {
-		t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags)
-	}
-	if n != 1 {
-		t.Fatalf("ReadMsgUnix n = %d, want 1 (dummy byte)", n)
-	}
-	if oobn2 != oobn {
-		// without SO_PASSCRED set on the socket, ReadMsgUnix will
-		// return zero oob bytes
-		t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn)
-	}
-	oob2 = oob2[:oobn2]
-	if !bytes.Equal(oob, oob2) {
-		t.Fatal("ReadMsgUnix oob bytes don't match")
-	}
+		oob2 := make([]byte, 10*len(oob))
+		n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2)
+		if err != nil {
+			t.Fatalf("ReadMsgUnix: %v", err)
+		}
+		if flags != 0 {
+			t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags)
+		}
+		if n != tt.dataLen {
+			t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)
+		}
+		if oobn2 != oobn {
+			// without SO_PASSCRED set on the socket, ReadMsgUnix will
+			// return zero oob bytes
+			t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn)
+		}
+		oob2 = oob2[:oobn2]
+		if !bytes.Equal(oob, oob2) {
+			t.Fatal("ReadMsgUnix oob bytes don't match")
+		}
 
-	scm, err := unix.ParseSocketControlMessage(oob2)
-	if err != nil {
-		t.Fatalf("ParseSocketControlMessage: %v", err)
-	}
-	newUcred, err := unix.ParseUnixCredentials(&scm[0])
-	if err != nil {
-		t.Fatalf("ParseUnixCredentials: %v", err)
-	}
-	if *newUcred != ucred {
-		t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred)
+		scm, err := unix.ParseSocketControlMessage(oob2)
+		if err != nil {
+			t.Fatalf("ParseSocketControlMessage: %v", err)
+		}
+		newUcred, err := unix.ParseUnixCredentials(&scm[0])
+		if err != nil {
+			t.Fatalf("ParseUnixCredentials: %v", err)
+		}
+		if *newUcred != ucred {
+			t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred)
+		}
 	}
 }

+ 12 - 2
unix/syscall_linux.go

@@ -931,8 +931,13 @@ func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from
 	}
 	var dummy byte
 	if len(oob) > 0 {
+		var sockType int
+		sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE)
+		if err != nil {
+			return
+		}
 		// receive at least one normal byte
-		if len(p) == 0 {
+		if sockType != SOCK_DGRAM && len(p) == 0 {
 			iov.Base = &dummy
 			iov.SetLen(1)
 		}
@@ -978,8 +983,13 @@ func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error)
 	}
 	var dummy byte
 	if len(oob) > 0 {
+		var sockType int
+		sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE)
+		if err != nil {
+			return 0, err
+		}
 		// send at least one normal byte
-		if len(p) == 0 {
+		if sockType != SOCK_DGRAM && len(p) == 0 {
 			iov.Base = &dummy
 			iov.SetLen(1)
 		}