From a0577540353e9c8acd42774d8e7a72f5676d19e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 26 Aug 2022 17:25:36 +0800 Subject: [PATCH] Revert linux process searcher --- common/process/searcher_android.go | 12 +-- common/process/searcher_linux.go | 6 +- common/process/searcher_linux_shared.go | 115 ++++++++++++++++++++---- go.mod | 2 +- 4 files changed, 110 insertions(+), 25 deletions(-) diff --git a/common/process/searcher_android.go b/common/process/searcher_android.go index b8ef817a..638c00c5 100644 --- a/common/process/searcher_android.go +++ b/common/process/searcher_android.go @@ -18,21 +18,21 @@ func NewSearcher(config Config) (Searcher, error) { } func (s *androidSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) { - socket, err := resolveSocketByNetlink(network, source, destination) + _, uid, err := resolveSocketByNetlink(network, source, destination) if err != nil { return nil, err } - if sharedPackage, loaded := s.packageManager.SharedPackageByID(socket.UID); loaded { + if sharedPackage, loaded := s.packageManager.SharedPackageByID(uid); loaded { return &Info{ - UserId: int32(socket.UID), + UserId: int32(uid), PackageName: sharedPackage, }, nil } - if packageName, loaded := s.packageManager.PackageByID(socket.UID); loaded { + if packageName, loaded := s.packageManager.PackageByID(uid); loaded { return &Info{ - UserId: int32(socket.UID), + UserId: int32(uid), PackageName: packageName, }, nil } - return &Info{UserId: int32(socket.UID)}, nil + return &Info{UserId: int32(uid)}, nil } diff --git a/common/process/searcher_linux.go b/common/process/searcher_linux.go index 3462740e..39470205 100644 --- a/common/process/searcher_linux.go +++ b/common/process/searcher_linux.go @@ -20,16 +20,16 @@ func NewSearcher(config Config) (Searcher, error) { } func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) { - socket, err := resolveSocketByNetlink(network, source, destination) + inode, uid, err := resolveSocketByNetlink(network, source, destination) if err != nil { return nil, err } - processPath, err := resolveProcessNameByProcSearch(socket.INode, socket.UID) + processPath, err := resolveProcessNameByProcSearch(inode, uid) if err != nil { s.logger.DebugContext(ctx, "find process path: ", err) } return &Info{ - UserId: int32(socket.UID), + UserId: int32(uid), ProcessPath: processPath, }, nil } diff --git a/common/process/searcher_linux_shared.go b/common/process/searcher_linux_shared.go index 1114c07d..67e24a5f 100644 --- a/common/process/searcher_linux_shared.go +++ b/common/process/searcher_linux_shared.go @@ -6,6 +6,7 @@ import ( "bytes" "encoding/binary" "fmt" + "net" "net/netip" "os" "path" @@ -14,7 +15,9 @@ import ( "unicode" "unsafe" - "github.com/sagernet/netlink" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" ) @@ -34,7 +37,7 @@ const ( pathProc = "/proc" ) -func resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (*netlink.Socket, error) { +func resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) { var family uint8 var protocol uint8 @@ -44,28 +47,110 @@ func resolveSocketByNetlink(network string, source netip.AddrPort, destination n case N.NetworkUDP: protocol = syscall.IPPROTO_UDP default: - return nil, os.ErrInvalid + return 0, 0, os.ErrInvalid } + if source.Addr().Is4() { family = syscall.AF_INET } else { family = syscall.AF_INET6 } - sockets, err := netlink.SocketGet(family, protocol, source, netip.AddrPortFrom(netip.IPv6Unspecified(), 0)) - if err == nil { - sockets, err = netlink.SocketGet(family, protocol, source, destination) - } + + req := packSocketDiagRequest(family, protocol, source) + + socket, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_INET_DIAG) if err != nil { - return nil, err + return 0, 0, E.Cause(err, "dial netlink") } - if len(sockets) > 1 { - for _, socket := range sockets { - if socket.ID.DestinationPort == destination.Port() { - return socket, nil - } - } + defer syscall.Close(socket) + + syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, &syscall.Timeval{Usec: 100}) + syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &syscall.Timeval{Usec: 100}) + + err = syscall.Connect(socket, &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pad: 0, + Pid: 0, + Groups: 0, + }) + if err != nil { + return } - return sockets[0], nil + + _, err = syscall.Write(socket, req) + if err != nil { + return 0, 0, E.Cause(err, "write netlink request") + } + + _buffer := buf.StackNew() + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + + n, err := syscall.Read(socket, buffer.FreeBytes()) + if err != nil { + return 0, 0, E.Cause(err, "read netlink response") + } + + buffer.Truncate(n) + + messages, err := syscall.ParseNetlinkMessage(buffer.Bytes()) + if err != nil { + return 0, 0, E.Cause(err, "parse netlink message") + } else if len(messages) == 0 { + return 0, 0, E.New("unexcepted netlink response") + } + + message := messages[0] + if message.Header.Type&syscall.NLMSG_ERROR != 0 { + return 0, 0, E.New("netlink message: NLMSG_ERROR") + } + + inode, uid = unpackSocketDiagResponse(&messages[0]) + return +} + +func packSocketDiagRequest(family, protocol byte, source netip.AddrPort) []byte { + s := make([]byte, 16) + copy(s, source.Addr().AsSlice()) + + buf := make([]byte, sizeOfSocketDiagRequest) + + nativeEndian.PutUint32(buf[0:4], sizeOfSocketDiagRequest) + nativeEndian.PutUint16(buf[4:6], socketDiagByFamily) + nativeEndian.PutUint16(buf[6:8], syscall.NLM_F_REQUEST|syscall.NLM_F_DUMP) + nativeEndian.PutUint32(buf[8:12], 0) + nativeEndian.PutUint32(buf[12:16], 0) + + buf[16] = family + buf[17] = protocol + buf[18] = 0 + buf[19] = 0 + nativeEndian.PutUint32(buf[20:24], 0xFFFFFFFF) + + binary.BigEndian.PutUint16(buf[24:26], source.Port()) + binary.BigEndian.PutUint16(buf[26:28], 0) + + copy(buf[28:44], s) + copy(buf[44:60], net.IPv6zero) + + nativeEndian.PutUint32(buf[60:64], 0) + nativeEndian.PutUint64(buf[64:72], 0xFFFFFFFFFFFFFFFF) + + return buf +} + +func unpackSocketDiagResponse(msg *syscall.NetlinkMessage) (inode, uid uint32) { + if len(msg.Data) < 72 { + return 0, 0 + } + + data := msg.Data + + uid = nativeEndian.Uint32(data[64:68]) + inode = nativeEndian.Uint32(data[68:72]) + + return } func resolveProcessNameByProcSearch(inode, uid uint32) (string, error) { diff --git a/go.mod b/go.mod index 480dd7bc..3b39b066 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( github.com/oschwald/maxminddb-golang v1.10.0 github.com/pires/go-proxyproto v0.6.2 github.com/sagernet/certmagic v0.0.0-20220819042630-4a57f8b6853a - github.com/sagernet/netlink v0.0.0-20220820041223-3cd8365d17ac github.com/sagernet/quic-go v0.0.0-20220818150011-de611ab3e2bb github.com/sagernet/sing v0.0.0-20220825093630-185d87918290 github.com/sagernet/sing-dns v0.0.0-20220822023312-3e086b06d666 @@ -58,6 +57,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sagernet/abx-go v0.0.0-20220819185957-dba1257d738e // indirect github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 // indirect + github.com/sagernet/netlink v0.0.0-20220820041223-3cd8365d17ac // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect go.uber.org/multierr v1.6.0 // indirect