diff --git a/adapter/inbound.go b/adapter/inbound.go index 7ee00f1c..c8d77dcc 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -4,6 +4,7 @@ import ( "context" "net/netip" + "github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-dns" M "github.com/sagernet/sing/common/metadata" ) @@ -34,6 +35,7 @@ type InboundContext struct { SourceGeoIPCode string GeoIPCode string + ProcessInfo *process.Info } type inboundContextKey struct{} diff --git a/common/process/searcher.go b/common/process/searcher.go new file mode 100644 index 00000000..cdecd333 --- /dev/null +++ b/common/process/searcher.go @@ -0,0 +1,21 @@ +package process + +import ( + "context" + "net/netip" + + E "github.com/sagernet/sing/common/exceptions" +) + +type Searcher interface { + FindProcessInfo(ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) +} + +var ErrNotFound = E.New("process not found") + +type Info struct { + ProcessPath string + PackageName string + User string + UserId int32 +} diff --git a/common/process/searcher_android.go b/common/process/searcher_android.go new file mode 100644 index 00000000..2da6413c --- /dev/null +++ b/common/process/searcher_android.go @@ -0,0 +1,169 @@ +package process + +import ( + "context" + "encoding/xml" + "io" + "net/netip" + "os" + "strconv" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/fsnotify/fsnotify" +) + +var _ Searcher = (*androidSearcher)(nil) + +type androidSearcher struct { + logger log.ContextLogger + watcher *fsnotify.Watcher + userMap map[string]int32 + packageMap map[int32]string + sharedUserMap map[int32]string +} + +func NewSearcher(logger log.ContextLogger) (Searcher, error) { + return &androidSearcher{logger: logger}, nil +} + +func (s *androidSearcher) Start() error { + err := s.updatePackages() + if err != nil { + return E.Cause(err, "read packages list") + } + err = s.startWatcher() + if err != nil { + s.logger.Debug("create fsnotify watcher: ", err) + } + return nil +} + +func (s *androidSearcher) startWatcher() error { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + err = watcher.Add("/data/system/packages.xml") + if err != nil { + return err + } + s.watcher = watcher + go s.loopUpdate() + return nil +} + +func (s *androidSearcher) loopUpdate() { + select { + case _, ok := <-s.watcher.Events: + if !ok { + return + } + err := s.updatePackages() + if err != nil { + s.logger.Error(E.Cause(err, "update packages list")) + } + case err, ok := <-s.watcher.Errors: + if !ok { + return + } + s.logger.Error(E.Cause(err, "fsnotify error")) + } +} + +func (s *androidSearcher) Close() error { + return common.Close(common.PtrOrNil(s.watcher)) +} + +func (s *androidSearcher) FindProcessInfo(ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) { + _, uid, err := resolveSocketByNetlink(network, srcIP, srcPort) + if err != nil { + return nil, err + } + if sharedUser, loaded := s.sharedUserMap[uid]; loaded { + return &Info{ + UserId: uid, + PackageName: sharedUser, + }, nil + } + if packageName, loaded := s.packageMap[uid]; loaded { + return &Info{ + UserId: uid, + PackageName: packageName, + }, nil + } + return &Info{UserId: uid}, nil +} + +func (s *androidSearcher) updatePackages() error { + userMap := make(map[string]int32) + packageMap := make(map[int32]string) + sharedUserMap := make(map[int32]string) + packagesData, err := os.Open("/data/system/packages.xml") + if err != nil { + return err + } + decoder := xml.NewDecoder(packagesData) + var token xml.Token + for { + token, err = decoder.Token() + if err == io.EOF { + break + } else if err != nil { + return err + } + + element, isStart := token.(xml.StartElement) + if !isStart { + continue + } + + switch element.Name.Local { + case "package": + var name string + var userID int64 + for _, attr := range element.Attr { + switch attr.Name.Local { + case "name": + name = attr.Value + case "userId", "sharedUserId": + userID, err = strconv.ParseInt(attr.Value, 10, 32) + if err != nil { + return err + } + } + } + if userID == 0 && name == "" { + continue + } + userMap[name] = int32(userID) + packageMap[int32(userID)] = name + case "shared-user": + var name string + var userID int64 + for _, attr := range element.Attr { + switch attr.Name.Local { + case "name": + name = attr.Value + case "userId": + userID, err = strconv.ParseInt(attr.Value, 10, 32) + if err != nil { + return err + } + packageMap[int32(userID)] = name + } + } + if userID == 0 && name == "" { + continue + } + sharedUserMap[int32(userID)] = name + } + } + s.logger.Info("updated packages list: ", len(packageMap), " packages, ", len(sharedUserMap), " shared users") + s.userMap = userMap + s.packageMap = packageMap + s.sharedUserMap = sharedUserMap + return nil +} diff --git a/common/process/searcher_darwin.go b/common/process/searcher_darwin.go new file mode 100644 index 00000000..2988debf --- /dev/null +++ b/common/process/searcher_darwin.go @@ -0,0 +1,123 @@ +package process + +import ( + "context" + "encoding/binary" + "net/netip" + "os" + "syscall" + "unsafe" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + + "golang.org/x/sys/unix" +) + +var _ Searcher = (*darwinSearcher)(nil) + +type darwinSearcher struct{} + +func NewSearcher(logger log.ContextLogger) (Searcher, error) { + return &darwinSearcher{}, nil +} + +func (d *darwinSearcher) FindProcessInfo(ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) { + processName, err := findProcessName(network, srcIP, srcPort) + if err != nil { + return nil, err + } + return &Info{ProcessPath: processName, UserId: -1}, nil +} + +func findProcessName(network string, ip netip.Addr, port int) (string, error) { + var spath string + switch network { + case C.NetworkTCP: + spath = "net.inet.tcp.pcblist_n" + case C.NetworkUDP: + spath = "net.inet.udp.pcblist_n" + default: + return "", os.ErrInvalid + } + + isIPv4 := ip.Is4() + + value, err := syscall.Sysctl(spath) + if err != nil { + return "", err + } + + buf := []byte(value) + + // from darwin-xnu/bsd/netinet/in_pcblist.c:get_pcblist_n + // size/offset are round up (aligned) to 8 bytes in darwin + // rup8(sizeof(xinpcb_n)) + rup8(sizeof(xsocket_n)) + + // 2 * rup8(sizeof(xsockbuf_n)) + rup8(sizeof(xsockstat_n)) + itemSize := 384 + if network == C.NetworkTCP { + // rup8(sizeof(xtcpcb_n)) + itemSize += 208 + } + // skip the first xinpgen(24 bytes) block + for i := 24; i+itemSize <= len(buf); i += itemSize { + // offset of xinpcb_n and xsocket_n + inp, so := i, i+104 + + srcPort := binary.BigEndian.Uint16(buf[inp+18 : inp+20]) + if uint16(port) != srcPort { + continue + } + + // xinpcb_n.inp_vflag + flag := buf[inp+44] + + var srcIP netip.Addr + switch { + case flag&0x1 > 0 && isIPv4: + // ipv4 + srcIP = netip.AddrFrom4(*(*[4]byte)(buf[inp+76 : inp+80])) + case flag&0x2 > 0 && !isIPv4: + // ipv6 + srcIP = netip.AddrFrom16(*(*[16]byte)(buf[inp+64 : inp+80])) + default: + continue + } + + if ip != srcIP { + continue + } + + // xsocket_n.so_last_pid + pid := readNativeUint32(buf[so+68 : so+72]) + return getExecPathFromPID(pid) + } + + return "", ErrNotFound +} + +func getExecPathFromPID(pid uint32) (string, error) { + const ( + procpidpathinfo = 0xb + procpidpathinfosize = 1024 + proccallnumpidinfo = 0x2 + ) + buf := make([]byte, procpidpathinfosize) + _, _, errno := syscall.Syscall6( + syscall.SYS_PROC_INFO, + proccallnumpidinfo, + uintptr(pid), + procpidpathinfo, + 0, + uintptr(unsafe.Pointer(&buf[0])), + procpidpathinfosize) + if errno != 0 { + return "", errno + } + + return unix.ByteSliceToString(buf), nil +} + +func readNativeUint32(b []byte) uint32 { + return *(*uint32)(unsafe.Pointer(&b[0])) +} diff --git a/common/process/searcher_linux.go b/common/process/searcher_linux.go new file mode 100644 index 00000000..64295cb4 --- /dev/null +++ b/common/process/searcher_linux.go @@ -0,0 +1,35 @@ +//go:build linux && !android + +package process + +import ( + "context" + "net/netip" + + "github.com/sagernet/sing-box/log" +) + +var _ Searcher = (*linuxSearcher)(nil) + +type linuxSearcher struct { + logger log.ContextLogger +} + +func NewSearcher(logger log.ContextLogger) (Searcher, error) { + return &linuxSearcher{logger}, nil +} + +func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) { + inode, uid, err := resolveSocketByNetlink(network, srcIP, srcPort) + if err != nil { + return nil, err + } + processPath, err := resolveProcessNameByProcSearch(inode, uid) + if err != nil { + s.logger.DebugContext(ctx, "find process path: ", err) + } + return &Info{ + UserId: uid, + ProcessPath: processPath, + }, nil +} diff --git a/common/process/searcher_linux_shared.go b/common/process/searcher_linux_shared.go new file mode 100644 index 00000000..7aea0e40 --- /dev/null +++ b/common/process/searcher_linux_shared.go @@ -0,0 +1,206 @@ +//go:build linux + +package process + +import ( + "bytes" + "encoding/binary" + "fmt" + "net" + "net/netip" + "os" + "path" + "strings" + "syscall" + "unicode" + "unsafe" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" +) + +// from https://github.com/vishvananda/netlink/blob/bca67dfc8220b44ef582c9da4e9172bf1c9ec973/nl/nl_linux.go#L52-L62 +var nativeEndian = func() binary.ByteOrder { + var x uint32 = 0x01020304 + if *(*byte)(unsafe.Pointer(&x)) == 0x01 { + return binary.BigEndian + } + + return binary.LittleEndian +}() + +const ( + sizeOfSocketDiagRequest = syscall.SizeofNlMsghdr + 8 + 48 + socketDiagByFamily = 20 + pathProc = "/proc" +) + +func resolveSocketByNetlink(network string, ip netip.Addr, srcPort int) (inode int32, uid int32, err error) { + var family byte + var protocol byte + + switch network { + case C.NetworkTCP: + protocol = syscall.IPPROTO_TCP + case C.NetworkUDP: + protocol = syscall.IPPROTO_UDP + default: + return 0, 0, os.ErrInvalid + } + + if ip.Is4() { + family = syscall.AF_INET + } else { + family = syscall.AF_INET6 + } + + req := packSocketDiagRequest(family, protocol, ip, uint16(srcPort)) + + socket, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_INET_DIAG) + if err != nil { + return 0, 0, E.Cause(err, "dial netlink") + } + defer syscall.Close(socket) + + syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, &syscall.Timeval{Usec: 100}) + syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &syscall.Timeval{Usec: 100}) + + if err = syscall.Connect(socket, &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pad: 0, + Pid: 0, + Groups: 0, + }); err != nil { + return 0, 0, err + } + + if _, err = syscall.Write(socket, req); err != nil { + return 0, 0, E.Cause(err, "write netlink request") + } + + _buffer := buf.StackNew() + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + + n, err := syscall.Read(socket, buffer.FreeBytes()) + if err != nil { + return 0, 0, E.Cause(err, "read netlink response") + } + + buffer.Truncate(n) + + messages, err := syscall.ParseNetlinkMessage(buffer.Bytes()) + if err != nil { + return 0, 0, E.Cause(err, "parse netlink message") + } else if len(messages) == 0 { + return 0, 0, E.New("unexcepted netlink response") + } + + message := messages[0] + if message.Header.Type&syscall.NLMSG_ERROR != 0 { + return 0, 0, E.New("netlink message: NLMSG_ERROR") + } + + inode, uid = unpackSocketDiagResponse(&messages[0]) + if inode < 0 || uid < 0 { + return 0, 0, E.New("invalid inode(", inode, ") or uid(", uid, ")") + } + return +} + +func packSocketDiagRequest(family, protocol byte, source netip.Addr, sourcePort uint16) []byte { + s := make([]byte, 16) + copy(s, source.AsSlice()) + + buf := make([]byte, sizeOfSocketDiagRequest) + + nativeEndian.PutUint32(buf[0:4], sizeOfSocketDiagRequest) + nativeEndian.PutUint16(buf[4:6], socketDiagByFamily) + nativeEndian.PutUint16(buf[6:8], syscall.NLM_F_REQUEST|syscall.NLM_F_DUMP) + nativeEndian.PutUint32(buf[8:12], 0) + nativeEndian.PutUint32(buf[12:16], 0) + + buf[16] = family + buf[17] = protocol + buf[18] = 0 + buf[19] = 0 + nativeEndian.PutUint32(buf[20:24], 0xFFFFFFFF) + + binary.BigEndian.PutUint16(buf[24:26], sourcePort) + binary.BigEndian.PutUint16(buf[26:28], 0) + + copy(buf[28:44], s) + copy(buf[44:60], net.IPv6zero) + + nativeEndian.PutUint32(buf[60:64], 0) + nativeEndian.PutUint64(buf[64:72], 0xFFFFFFFFFFFFFFFF) + + return buf +} + +func unpackSocketDiagResponse(msg *syscall.NetlinkMessage) (inode, uid int32) { + if len(msg.Data) < 72 { + return 0, 0 + } + + data := msg.Data + + uid = int32(nativeEndian.Uint32(data[64:68])) + inode = int32(nativeEndian.Uint32(data[68:72])) + + return +} + +func resolveProcessNameByProcSearch(inode, uid int32) (string, error) { + files, err := os.ReadDir(pathProc) + if err != nil { + return "", err + } + + buffer := make([]byte, syscall.PathMax) + socket := []byte(fmt.Sprintf("socket:[%d]", inode)) + + for _, f := range files { + if !f.IsDir() || !isPid(f.Name()) { + continue + } + + info, err := f.Info() + if err != nil { + return "", err + } + if info.Sys().(*syscall.Stat_t).Uid != uint32(uid) { + continue + } + + processPath := path.Join(pathProc, f.Name()) + fdPath := path.Join(processPath, "fd") + + fds, err := os.ReadDir(fdPath) + if err != nil { + continue + } + + for _, fd := range fds { + n, err := syscall.Readlink(path.Join(fdPath, fd.Name()), buffer) + if err != nil { + continue + } + + if bytes.Equal(buffer[:n], socket) { + return os.Readlink(path.Join(processPath, "exe")) + } + } + } + + return "", fmt.Errorf("process of uid(%d),inode(%d) not found", uid, inode) +} + +func isPid(s string) bool { + return strings.IndexFunc(s, func(r rune) bool { + return !unicode.IsDigit(r) + }) == -1 +} diff --git a/common/process/searcher_stub.go b/common/process/searcher_stub.go new file mode 100644 index 00000000..ff128517 --- /dev/null +++ b/common/process/searcher_stub.go @@ -0,0 +1,13 @@ +//go:build !linux && !windows && !darwin + +package process + +import ( + "os" + + "github.com/sagernet/sing-box/log" +) + +func NewSearcher(logger log.ContextLogger) (Searcher, error) { + return nil, os.ErrInvalid +} diff --git a/common/process/searcher_windows.go b/common/process/searcher_windows.go new file mode 100644 index 00000000..ae3c3a7f --- /dev/null +++ b/common/process/searcher_windows.go @@ -0,0 +1,235 @@ +package process + +import ( + "context" + "fmt" + "net/netip" + "os" + "syscall" + "unsafe" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" + + "golang.org/x/sys/windows" +) + +var _ Searcher = (*windowsSearcher)(nil) + +type windowsSearcher struct{} + +func NewSearcher(logger log.ContextLogger) (Searcher, error) { + err := initWin32API() + if err != nil { + return nil, E.Cause(err, "init win32 api") + } + return &windowsSearcher{}, nil +} + +var ( + modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") + procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable") + procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW") +) + +func initWin32API() error { + err := modiphlpapi.Load() + if err != nil { + return E.Cause(err, "load iphlpapi.dll") + } + + err = procGetExtendedTcpTable.Find() + if err != nil { + return E.Cause(err, "load iphlpapi::GetExtendedTcpTable") + } + + err = procGetExtendedUdpTable.Find() + if err != nil { + return E.Cause(err, "load iphlpapi::GetExtendedUdpTable") + } + + err = modkernel32.Load() + if err != nil { + return E.Cause(err, "load kernel32.dll") + } + + err = procQueryFullProcessImageNameW.Find() + if err != nil { + return E.Cause(err, "load kernel32::QueryFullProcessImageNameW") + } + + return nil +} + +func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) { + processName, err := findProcessName(network, srcIP, srcPort) + if err != nil { + return nil, err + } + return &Info{ProcessPath: processName, UserId: -1}, nil +} + +func findProcessName(network string, ip netip.Addr, srcPort int) (string, error) { + family := windows.AF_INET + if ip.Is6() { + family = windows.AF_INET6 + } + + const ( + tcpTablePidConn = 4 + udpTablePid = 1 + ) + + var class int + var fn uintptr + switch network { + case C.NetworkTCP: + fn = procGetExtendedTcpTable.Addr() + class = tcpTablePidConn + case C.NetworkUDP: + fn = procGetExtendedUdpTable.Addr() + class = udpTablePid + default: + return "", os.ErrInvalid + } + + buf, err := getTransportTable(fn, family, class) + if err != nil { + return "", err + } + + s := newSearcher(family == windows.AF_INET, network == C.NetworkTCP) + + pid, err := s.Search(buf, ip, uint16(srcPort)) + if err != nil { + return "", err + } + return getExecPathFromPID(pid) +} + +type searcher struct { + itemSize int + port int + ip int + ipSize int + pid int + tcpState int +} + +func (s *searcher) Search(b []byte, ip netip.Addr, port uint16) (uint32, error) { + n := int(readNativeUint32(b[:4])) + itemSize := s.itemSize + for i := 0; i < n; i++ { + row := b[4+itemSize*i : 4+itemSize*(i+1)] + + if s.tcpState >= 0 { + tcpState := readNativeUint32(row[s.tcpState : s.tcpState+4]) + // MIB_TCP_STATE_ESTAB, only check established connections for TCP + if tcpState != 5 { + continue + } + } + + // according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian. + // this field can be illustrated as follows depends on different machine endianess: + // little endian: [ MSB LSB 0 0 ] interpret as native uint32 is ((LSB<<8)|MSB) + // big endian: [ 0 0 MSB LSB ] interpret as native uint32 is ((MSB<<8)|LSB) + // so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32 + srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4]))) + if srcPort != port { + continue + } + + srcIP, _ := netip.AddrFromSlice(row[s.ip : s.ip+s.ipSize]) + // windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto + if ip != srcIP && (!srcIP.IsUnspecified() || s.tcpState != -1) { + continue + } + + pid := readNativeUint32(row[s.pid : s.pid+4]) + return pid, nil + } + return 0, ErrNotFound +} + +func newSearcher(isV4, isTCP bool) *searcher { + var itemSize, port, ip, ipSize, pid int + tcpState := -1 + switch { + case isV4 && isTCP: + // struct MIB_TCPROW_OWNER_PID + itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0 + case isV4 && !isTCP: + // struct MIB_UDPROW_OWNER_PID + itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8 + case !isV4 && isTCP: + // struct MIB_TCP6ROW_OWNER_PID + itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48 + case !isV4 && !isTCP: + // struct MIB_UDP6ROW_OWNER_PID + itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24 + } + + return &searcher{ + itemSize: itemSize, + port: port, + ip: ip, + ipSize: ipSize, + pid: pid, + tcpState: tcpState, + } +} + +func getTransportTable(fn uintptr, family int, class int) ([]byte, error) { + for size, buf := uint32(8), make([]byte, 8); ; { + ptr := unsafe.Pointer(&buf[0]) + err, _, _ := syscall.SyscallN(fn, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0) + + switch err { + case 0: + return buf, nil + case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER): + buf = make([]byte, size) + default: + return nil, fmt.Errorf("syscall error: %d", err) + } + } +} + +func readNativeUint32(b []byte) uint32 { + return *(*uint32)(unsafe.Pointer(&b[0])) +} + +func getExecPathFromPID(pid uint32) (string, error) { + // kernel process starts with a colon in order to distinguish with normal processes + switch pid { + case 0: + // reserved pid for system idle process + return ":System Idle Process", nil + case 4: + // reserved pid for windows kernel image + return ":System", nil + } + h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid) + if err != nil { + return "", err + } + defer windows.CloseHandle(h) + + buf := make([]uint16, syscall.MAX_LONG_PATH) + size := uint32(len(buf)) + r1, _, err := syscall.SyscallN( + procQueryFullProcessImageNameW.Addr(), + uintptr(h), + uintptr(1), + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&size)), + ) + if r1 == 0 { + return "", err + } + return syscall.UTF16ToString(buf[:size]), nil +} diff --git a/common/process/searcher_with_name.go b/common/process/searcher_with_name.go new file mode 100644 index 00000000..ed7ffffb --- /dev/null +++ b/common/process/searcher_with_name.go @@ -0,0 +1,25 @@ +//go:build cgo && linux && !android + +package process + +import ( + "context" + "net/netip" + "os/user" + + F "github.com/sagernet/sing/common/format" +) + +func FindProcessInfo(searcher Searcher, ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) { + info, err := searcher.FindProcessInfo(ctx, network, srcIP, srcPort) + if err != nil { + return nil, err + } + if info.UserId != -1 { + osUser, _ := user.LookupId(F.ToString(info.UserId)) + if osUser != nil { + info.User = osUser.Username + } + } + return info, nil +} diff --git a/common/process/searcher_without_name.go b/common/process/searcher_without_name.go new file mode 100644 index 00000000..6b22d0c7 --- /dev/null +++ b/common/process/searcher_without_name.go @@ -0,0 +1,12 @@ +//go:build !cgo || !linux || android + +package process + +import ( + "context" + "net/netip" +) + +func FindProcessInfo(searcher Searcher, ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) { + return searcher.FindProcessInfo(ctx, network, srcIP, srcPort) +} diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index fc926588..f14bff5d 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -122,15 +122,33 @@ func castMetadata(metadata adapter.InboundContext) trafficontrol.Metadata { } else { domain = metadata.Destination.Fqdn } + var processPath string + if metadata.ProcessInfo != nil { + if metadata.ProcessInfo.ProcessPath != "" { + processPath = metadata.ProcessInfo.ProcessPath + } else if metadata.ProcessInfo.PackageName != "" { + processPath = metadata.ProcessInfo.PackageName + } + if processPath == "" { + if metadata.ProcessInfo.UserId != -1 { + processPath = F.ToString(metadata.ProcessInfo.UserId) + } + } else if metadata.ProcessInfo.User != "" { + processPath = F.ToString(processPath, " (", metadata.ProcessInfo.User, ")") + } else if metadata.ProcessInfo.UserId != -1 { + processPath = F.ToString(processPath, " (", metadata.ProcessInfo.UserId, ")") + } + } return trafficontrol.Metadata{ - NetWork: metadata.Network, - Type: inbound, - SrcIP: metadata.Source.Addr, - DstIP: metadata.Destination.Addr, - SrcPort: F.ToString(metadata.Source.Port), - DstPort: F.ToString(metadata.Destination.Port), - Host: domain, - DNSMode: "normal", + NetWork: metadata.Network, + Type: inbound, + SrcIP: metadata.Source.Addr, + DstIP: metadata.Destination.Addr, + SrcPort: F.ToString(metadata.Source.Port), + DstPort: F.ToString(metadata.Destination.Port), + Host: domain, + DNSMode: "normal", + ProcessPath: processPath, } } diff --git a/go.mod b/go.mod index 3748136e..94f8ebef 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.18 require ( github.com/database64128/tfo-go v1.1.0 + github.com/fsnotify/fsnotify v1.5.4 github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/cors v1.2.1 github.com/go-chi/render v1.0.1 diff --git a/go.sum b/go.sum index ea3d7e34..00956b08 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/database64128/tfo-go v1.1.0/go.mod h1:95pOT8bnV3P2Lmu9upHNWFHz6dYGJ9c github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= +github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= @@ -68,6 +70,7 @@ golang.org/x/net v0.0.0-20220708220712-1185a9018129 h1:vucSRfWwTsoXro7P+3Cjlr6fl golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= diff --git a/option/dns.go b/option/dns.go index f7ef5738..27559af5 100644 --- a/option/dns.go +++ b/option/dns.go @@ -91,7 +91,7 @@ func (r *DNSRule) UnmarshalJSON(bytes []byte) error { type DefaultDNSRule struct { Inbound Listable[string] `json:"inbound,omitempty"` Network string `json:"network,omitempty"` - User Listable[string] `json:"user,omitempty"` + AuthUser Listable[string] `json:"auth_user,omitempty"` Protocol Listable[string] `json:"protocol,omitempty"` Domain Listable[string] `json:"domain,omitempty"` DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` @@ -102,6 +102,10 @@ type DefaultDNSRule struct { SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` SourcePort Listable[uint16] `json:"source_port,omitempty"` Port Listable[uint16] `json:"port,omitempty"` + ProcessName Listable[string] `json:"process_name,omitempty"` + PackageName Listable[string] `json:"package_name,omitempty"` + User Listable[string] `json:"user,omitempty"` + UserID Listable[int32] `json:"user_id,omitempty"` Outbound Listable[string] `json:"outbound,omitempty"` Server string `json:"server,omitempty"` } @@ -126,6 +130,9 @@ func (r DefaultDNSRule) Equals(other DefaultDNSRule) bool { common.ComparableSliceEquals(r.SourceIPCIDR, other.SourceIPCIDR) && common.ComparableSliceEquals(r.SourcePort, other.SourcePort) && common.ComparableSliceEquals(r.Port, other.Port) && + common.ComparableSliceEquals(r.ProcessName, other.ProcessName) && + common.ComparableSliceEquals(r.UserID, other.UserID) && + common.ComparableSliceEquals(r.PackageName, other.PackageName) && common.ComparableSliceEquals(r.Outbound, other.Outbound) && r.Server == other.Server } diff --git a/option/route.go b/option/route.go index badaf3e8..50f5b4b3 100644 --- a/option/route.go +++ b/option/route.go @@ -13,6 +13,7 @@ type RouteOptions struct { Geosite *GeositeOptions `json:"geosite,omitempty"` Rules []Rule `json:"rules,omitempty"` Final string `json:"final,omitempty"` + FindProcess bool `json:"find_process,omitempty"` AutoDetectInterface bool `json:"auto_detect_interface,omitempty"` DefaultInterface string `json:"default_interface,omitempty"` } @@ -89,7 +90,7 @@ type DefaultRule struct { Inbound Listable[string] `json:"inbound,omitempty"` IPVersion int `json:"ip_version,omitempty"` Network string `json:"network,omitempty"` - User Listable[string] `json:"user,omitempty"` + AuthUser Listable[string] `json:"auth_user,omitempty"` Protocol Listable[string] `json:"protocol,omitempty"` Domain Listable[string] `json:"domain,omitempty"` DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` @@ -102,6 +103,10 @@ type DefaultRule struct { IPCIDR Listable[string] `json:"ip_cidr,omitempty"` SourcePort Listable[uint16] `json:"source_port,omitempty"` Port Listable[uint16] `json:"port,omitempty"` + ProcessName Listable[string] `json:"process_name,omitempty"` + PackageName Listable[string] `json:"package_name,omitempty"` + User Listable[string] `json:"user,omitempty"` + UserID Listable[int32] `json:"user_id,omitempty"` Outbound string `json:"outbound,omitempty"` } @@ -128,6 +133,10 @@ func (r DefaultRule) Equals(other DefaultRule) bool { common.ComparableSliceEquals(r.IPCIDR, other.IPCIDR) && common.ComparableSliceEquals(r.SourcePort, other.SourcePort) && common.ComparableSliceEquals(r.Port, other.Port) && + common.ComparableSliceEquals(r.ProcessName, other.ProcessName) && + common.ComparableSliceEquals(r.PackageName, other.PackageName) && + common.ComparableSliceEquals(r.User, other.User) && + common.ComparableSliceEquals(r.UserID, other.UserID) && r.Outbound == other.Outbound } diff --git a/route/router.go b/route/router.go index 96931ca2..827188cf 100644 --- a/route/router.go +++ b/route/router.go @@ -8,6 +8,7 @@ import ( "net/netip" "net/url" "os" + "os/user" "path/filepath" "reflect" "strings" @@ -17,6 +18,7 @@ import ( "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/geoip" "github.com/sagernet/sing-box/common/geosite" + "github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/common/sniff" "github.com/sagernet/sing-box/common/urltest" C "github.com/sagernet/sing-box/constant" @@ -39,41 +41,36 @@ import ( var _ adapter.Router = (*Router)(nil) type Router struct { - ctx context.Context - logger log.ContextLogger - dnsLogger log.ContextLogger - - outbounds []adapter.Outbound - outboundByTag map[string]adapter.Outbound - rules []adapter.Rule - + ctx context.Context + logger log.ContextLogger + dnsLogger log.ContextLogger + outbounds []adapter.Outbound + outboundByTag map[string]adapter.Outbound + rules []adapter.Rule defaultDetour string defaultOutboundForConnection adapter.Outbound defaultOutboundForPacketConnection adapter.Outbound - - needGeoIPDatabase bool - needGeositeDatabase bool - geoIPOptions option.GeoIPOptions - geositeOptions option.GeositeOptions - geoIPReader *geoip.Reader - geositeReader *geosite.Reader - geositeCache map[string]adapter.Rule - - dnsClient *dns.Client - defaultDomainStrategy dns.DomainStrategy - dnsRules []adapter.Rule - defaultTransport dns.Transport - transports []dns.Transport - transportMap map[string]dns.Transport - - interfaceBindManager control.BindManager - networkMonitor NetworkUpdateMonitor - autoDetectInterface bool - defaultInterface string - interfaceMonitor DefaultInterfaceMonitor - - trafficController adapter.TrafficController - urlTestHistoryStorage *urltest.HistoryStorage + needGeoIPDatabase bool + needGeositeDatabase bool + geoIPOptions option.GeoIPOptions + geositeOptions option.GeositeOptions + geoIPReader *geoip.Reader + geositeReader *geosite.Reader + geositeCache map[string]adapter.Rule + dnsClient *dns.Client + defaultDomainStrategy dns.DomainStrategy + dnsRules []adapter.Rule + defaultTransport dns.Transport + transports []dns.Transport + transportMap map[string]dns.Transport + interfaceBindManager control.BindManager + networkMonitor NetworkUpdateMonitor + autoDetectInterface bool + defaultInterface string + interfaceMonitor DefaultInterfaceMonitor + trafficController adapter.TrafficController + urlTestHistoryStorage *urltest.HistoryStorage + processSearcher process.Searcher } func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.ContextLogger, options option.RouteOptions, dnsOptions option.DNSOptions) (*Router, error) { @@ -84,8 +81,8 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont outboundByTag: make(map[string]adapter.Outbound), rules: make([]adapter.Rule, 0, len(options.Rules)), dnsRules: make([]adapter.Rule, 0, len(dnsOptions.Rules)), - needGeoIPDatabase: hasGeoRule(options.Rules, isGeoIPRule) || hasGeoDNSRule(dnsOptions.Rules, isGeoIPDNSRule), - needGeositeDatabase: hasGeoRule(options.Rules, isGeositeRule) || hasGeoDNSRule(dnsOptions.Rules, isGeositeDNSRule), + needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule), + needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule), geoIPOptions: common.PtrValueOrDefault(options.GeoIP), geositeOptions: common.PtrValueOrDefault(options.Geosite), geositeCache: make(map[string]adapter.Rule), @@ -221,6 +218,13 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont } router.interfaceMonitor = interfaceMonitor } + if hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess { + searcher, err := process.NewSearcher(logger) + if err != nil { + return nil, E.Cause(err, "create process searcher") + } + router.processSearcher = searcher + } return router, nil } @@ -376,6 +380,14 @@ func (r *Router) Start() error { return err } } + if r.processSearcher != nil { + if starter, isStarter := r.processSearcher.(common.Starter); isStarter { + err := starter.Start() + if err != nil { + return E.Cause(err, "initialize process searcher") + } + } + } return nil } @@ -396,6 +408,7 @@ func (r *Router) Close() error { common.PtrOrNil(r.geoIPReader), r.interfaceMonitor, r.networkMonitor, + r.processSearcher, ) } @@ -464,7 +477,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad metadata.DestinationAddresses = addresses r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") } - matchedRule, detour := r.match(ctx, metadata, r.defaultOutboundForConnection) + matchedRule, detour := r.match(ctx, &metadata, r.defaultOutboundForConnection) if !common.Contains(detour.Network(), C.NetworkTCP) { conn.Close() return E.New("missing supported outbound, closing connection") @@ -509,7 +522,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m metadata.DestinationAddresses = addresses r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") } - matchedRule, detour := r.match(ctx, metadata, r.defaultOutboundForPacketConnection) + matchedRule, detour := r.match(ctx, &metadata, r.defaultOutboundForPacketConnection) if !common.Contains(detour.Network(), C.NetworkUDP) { conn.Close() return E.New("missing supported outbound, closing packet connection") @@ -532,9 +545,34 @@ func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, r.defaultDomainStrategy) } -func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) { +func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) { + if r.processSearcher != nil { + processInfo, err := process.FindProcessInfo(r.processSearcher, ctx, metadata.Network, metadata.Source.Addr, int(metadata.Source.Port)) + if err != nil { + r.logger.DebugContext(ctx, "failed to search process: ", err) + } else { + if processInfo.ProcessPath != "" { + r.logger.DebugContext(ctx, "found process path: ", processInfo.ProcessPath) + } else if processInfo.PackageName != "" { + r.logger.DebugContext(ctx, "found package name: ", processInfo.PackageName) + } else if processInfo.UserId != -1 { + if /*needUserName &&*/ true { + osUser, _ := user.LookupId(F.ToString(processInfo.UserId)) + if osUser != nil { + processInfo.User = osUser.Username + } + } + if processInfo.User != "" { + r.logger.DebugContext(ctx, "found user: ", processInfo.User) + } else { + r.logger.DebugContext(ctx, "found user id: ", processInfo.UserId) + } + } + metadata.ProcessInfo = processInfo + } + } for i, rule := range r.rules { - if rule.Match(&metadata) { + if rule.Match(metadata) { detour := rule.Outbound() r.logger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) if outbound, loaded := r.Outbound(detour); loaded { @@ -606,7 +644,7 @@ func (r *Router) URLTestHistoryStorage(create bool) *urltest.HistoryStorage { return r.urlTestHistoryStorage } -func hasGeoRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool { +func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool { for _, rule := range rules { switch rule.Type { case C.RuleTypeDefault: @@ -624,7 +662,7 @@ func hasGeoRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bo return false } -func hasGeoDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bool) bool { +func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bool) bool { for _, rule := range rules { switch rule.Type { case C.RuleTypeDefault: @@ -658,6 +696,14 @@ func isGeositeDNSRule(rule option.DefaultDNSRule) bool { return len(rule.Geosite) > 0 } +func isProcessRule(rule option.DefaultRule) bool { + return len(rule.ProcessName) > 0 +} + +func isProcessDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.ProcessName) > 0 +} + func notPrivateNode(code string) bool { return code != "private" } diff --git a/route/rule.go b/route/rule.go index 55970b9c..45a28a75 100644 --- a/route/rule.go +++ b/route/rule.go @@ -86,8 +86,8 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt return nil, E.New("invalid network: ", options.Network) } } - if len(options.User) > 0 { - item := NewUserItem(options.User) + if len(options.AuthUser) > 0 { + item := NewAuthUserItem(options.AuthUser) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } @@ -155,6 +155,26 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.ProcessName) > 0 { + item := NewProcessItem(options.ProcessName) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.PackageName) > 0 { + item := NewPackageNameItem(options.PackageName) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.User) > 0 { + item := NewUserItem(options.User) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.UserID) > 0 { + item := NewUserIDItem(options.UserID) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } return rule, nil } diff --git a/route/rule_auth_user.go b/route/rule_auth_user.go new file mode 100644 index 00000000..fbe053e6 --- /dev/null +++ b/route/rule_auth_user.go @@ -0,0 +1,37 @@ +package route + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" + F "github.com/sagernet/sing/common/format" +) + +var _ RuleItem = (*AuthUserItem)(nil) + +type AuthUserItem struct { + users []string + userMap map[string]bool +} + +func NewAuthUserItem(users []string) *AuthUserItem { + userMap := make(map[string]bool) + for _, protocol := range users { + userMap[protocol] = true + } + return &AuthUserItem{ + users: users, + userMap: userMap, + } +} + +func (r *AuthUserItem) Match(metadata *adapter.InboundContext) bool { + return r.userMap[metadata.User] +} + +func (r *AuthUserItem) String() string { + if len(r.users) == 1 { + return F.ToString("auth_user=", r.users[0]) + } + return F.ToString("auth_user=[", strings.Join(r.users, " "), "]") +} diff --git a/route/rule_dns.go b/route/rule_dns.go index 6a7e5f3d..f1975c5e 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -70,8 +70,8 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options return nil, E.New("invalid network: ", options.Network) } } - if len(options.User) > 0 { - item := NewUserItem(options.User) + if len(options.AuthUser) > 0 { + item := NewAuthUserItem(options.AuthUser) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } @@ -126,6 +126,26 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.ProcessName) > 0 { + item := NewProcessItem(options.ProcessName) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.PackageName) > 0 { + item := NewPackageNameItem(options.PackageName) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.User) > 0 { + item := NewUserItem(options.User) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.UserID) > 0 { + item := NewUserIDItem(options.UserID) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.Outbound) > 0 { item := NewOutboundRule(options.Outbound) rule.items = append(rule.items, item) diff --git a/route/rule_package_name.go b/route/rule_package_name.go new file mode 100644 index 00000000..d1ca09eb --- /dev/null +++ b/route/rule_package_name.go @@ -0,0 +1,43 @@ +package route + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" +) + +var _ RuleItem = (*PackageNameItem)(nil) + +type PackageNameItem struct { + packageNames []string + packageMap map[string]bool +} + +func NewPackageNameItem(packageNameList []string) *PackageNameItem { + rule := &PackageNameItem{ + packageNames: packageNameList, + packageMap: make(map[string]bool), + } + for _, packageName := range packageNameList { + rule.packageMap[packageName] = true + } + return rule +} + +func (r *PackageNameItem) Match(metadata *adapter.InboundContext) bool { + if metadata.ProcessInfo == nil || metadata.ProcessInfo.PackageName == "" { + return false + } + return r.packageMap[metadata.ProcessInfo.PackageName] +} + +func (r *PackageNameItem) String() string { + var description string + pLen := len(r.packageNames) + if pLen == 1 { + description = "package_name=" + r.packageNames[0] + } else { + description = "package_name=[" + strings.Join(r.packageNames, " ") + "]" + } + return description +} diff --git a/route/rule_process.go b/route/rule_process.go new file mode 100644 index 00000000..ec874a4f --- /dev/null +++ b/route/rule_process.go @@ -0,0 +1,44 @@ +package route + +import ( + "path/filepath" + "strings" + + "github.com/sagernet/sing-box/adapter" +) + +var _ RuleItem = (*ProcessItem)(nil) + +type ProcessItem struct { + processes []string + processMap map[string]bool +} + +func NewProcessItem(processNameList []string) *ProcessItem { + rule := &ProcessItem{ + processes: processNameList, + processMap: make(map[string]bool), + } + for _, processName := range processNameList { + rule.processMap[strings.ToLower(processName)] = true + } + return rule +} + +func (r *ProcessItem) Match(metadata *adapter.InboundContext) bool { + if metadata.ProcessInfo == nil || metadata.ProcessInfo.ProcessPath == "" { + return false + } + return r.processMap[strings.ToLower(filepath.Base(metadata.ProcessInfo.ProcessPath))] +} + +func (r *ProcessItem) String() string { + var description string + pLen := len(r.processes) + if pLen == 1 { + description = "process_name=" + r.processes[0] + } else { + description = "process_name=[" + strings.Join(r.processes, " ") + "]" + } + return description +} diff --git a/route/rule_user.go b/route/rule_user.go index 43d9f48d..bed97fba 100644 --- a/route/rule_user.go +++ b/route/rule_user.go @@ -26,7 +26,10 @@ func NewUserItem(users []string) *UserItem { } func (r *UserItem) Match(metadata *adapter.InboundContext) bool { - return r.userMap[metadata.User] + if metadata.ProcessInfo == nil || metadata.ProcessInfo.User == "" { + return false + } + return r.userMap[metadata.ProcessInfo.User] } func (r *UserItem) String() string { diff --git a/route/rule_user_id.go b/route/rule_user_id.go new file mode 100644 index 00000000..43ab704e --- /dev/null +++ b/route/rule_user_id.go @@ -0,0 +1,44 @@ +package route + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" + F "github.com/sagernet/sing/common/format" +) + +var _ RuleItem = (*UserIdItem)(nil) + +type UserIdItem struct { + userIds []int32 + userIdMap map[int32]bool +} + +func NewUserIDItem(userIdList []int32) *UserIdItem { + rule := &UserIdItem{ + userIds: userIdList, + userIdMap: make(map[int32]bool), + } + for _, userId := range userIdList { + rule.userIdMap[userId] = true + } + return rule +} + +func (r *UserIdItem) Match(metadata *adapter.InboundContext) bool { + if metadata.ProcessInfo == nil || metadata.ProcessInfo.UserId == -1 { + return false + } + return r.userIdMap[metadata.ProcessInfo.UserId] +} + +func (r *UserIdItem) String() string { + var description string + pLen := len(r.userIds) + if pLen == 1 { + description = "user_id=" + F.ToString(r.userIds[0]) + } else { + description = "user_id=[" + strings.Join(F.MapToString(r.userIds), " ") + "]" + } + return description +}