Add process_name/package_name/user/user_id rule item

This commit is contained in:
世界 2022-07-23 19:01:41 +08:00
parent 4abf669d09
commit 5f6f33c464
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
23 changed files with 1191 additions and 55 deletions

View file

@ -4,6 +4,7 @@ import (
"context"
"net/netip"
"github.com/sagernet/sing-box/common/process"
"github.com/sagernet/sing-dns"
M "github.com/sagernet/sing/common/metadata"
)
@ -34,6 +35,7 @@ type InboundContext struct {
SourceGeoIPCode string
GeoIPCode string
ProcessInfo *process.Info
}
type inboundContextKey struct{}

View file

@ -0,0 +1,21 @@
package process
import (
"context"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
)
type Searcher interface {
FindProcessInfo(ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error)
}
var ErrNotFound = E.New("process not found")
type Info struct {
ProcessPath string
PackageName string
User string
UserId int32
}

View file

@ -0,0 +1,169 @@
package process
import (
"context"
"encoding/xml"
"io"
"net/netip"
"os"
"strconv"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/fsnotify/fsnotify"
)
var _ Searcher = (*androidSearcher)(nil)
type androidSearcher struct {
logger log.ContextLogger
watcher *fsnotify.Watcher
userMap map[string]int32
packageMap map[int32]string
sharedUserMap map[int32]string
}
func NewSearcher(logger log.ContextLogger) (Searcher, error) {
return &androidSearcher{logger: logger}, nil
}
func (s *androidSearcher) Start() error {
err := s.updatePackages()
if err != nil {
return E.Cause(err, "read packages list")
}
err = s.startWatcher()
if err != nil {
s.logger.Debug("create fsnotify watcher: ", err)
}
return nil
}
func (s *androidSearcher) startWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
err = watcher.Add("/data/system/packages.xml")
if err != nil {
return err
}
s.watcher = watcher
go s.loopUpdate()
return nil
}
func (s *androidSearcher) loopUpdate() {
select {
case _, ok := <-s.watcher.Events:
if !ok {
return
}
err := s.updatePackages()
if err != nil {
s.logger.Error(E.Cause(err, "update packages list"))
}
case err, ok := <-s.watcher.Errors:
if !ok {
return
}
s.logger.Error(E.Cause(err, "fsnotify error"))
}
}
func (s *androidSearcher) Close() error {
return common.Close(common.PtrOrNil(s.watcher))
}
func (s *androidSearcher) FindProcessInfo(ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) {
_, uid, err := resolveSocketByNetlink(network, srcIP, srcPort)
if err != nil {
return nil, err
}
if sharedUser, loaded := s.sharedUserMap[uid]; loaded {
return &Info{
UserId: uid,
PackageName: sharedUser,
}, nil
}
if packageName, loaded := s.packageMap[uid]; loaded {
return &Info{
UserId: uid,
PackageName: packageName,
}, nil
}
return &Info{UserId: uid}, nil
}
func (s *androidSearcher) updatePackages() error {
userMap := make(map[string]int32)
packageMap := make(map[int32]string)
sharedUserMap := make(map[int32]string)
packagesData, err := os.Open("/data/system/packages.xml")
if err != nil {
return err
}
decoder := xml.NewDecoder(packagesData)
var token xml.Token
for {
token, err = decoder.Token()
if err == io.EOF {
break
} else if err != nil {
return err
}
element, isStart := token.(xml.StartElement)
if !isStart {
continue
}
switch element.Name.Local {
case "package":
var name string
var userID int64
for _, attr := range element.Attr {
switch attr.Name.Local {
case "name":
name = attr.Value
case "userId", "sharedUserId":
userID, err = strconv.ParseInt(attr.Value, 10, 32)
if err != nil {
return err
}
}
}
if userID == 0 && name == "" {
continue
}
userMap[name] = int32(userID)
packageMap[int32(userID)] = name
case "shared-user":
var name string
var userID int64
for _, attr := range element.Attr {
switch attr.Name.Local {
case "name":
name = attr.Value
case "userId":
userID, err = strconv.ParseInt(attr.Value, 10, 32)
if err != nil {
return err
}
packageMap[int32(userID)] = name
}
}
if userID == 0 && name == "" {
continue
}
sharedUserMap[int32(userID)] = name
}
}
s.logger.Info("updated packages list: ", len(packageMap), " packages, ", len(sharedUserMap), " shared users")
s.userMap = userMap
s.packageMap = packageMap
s.sharedUserMap = sharedUserMap
return nil
}

View file

@ -0,0 +1,123 @@
package process
import (
"context"
"encoding/binary"
"net/netip"
"os"
"syscall"
"unsafe"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"golang.org/x/sys/unix"
)
var _ Searcher = (*darwinSearcher)(nil)
type darwinSearcher struct{}
func NewSearcher(logger log.ContextLogger) (Searcher, error) {
return &darwinSearcher{}, nil
}
func (d *darwinSearcher) FindProcessInfo(ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) {
processName, err := findProcessName(network, srcIP, srcPort)
if err != nil {
return nil, err
}
return &Info{ProcessPath: processName, UserId: -1}, nil
}
func findProcessName(network string, ip netip.Addr, port int) (string, error) {
var spath string
switch network {
case C.NetworkTCP:
spath = "net.inet.tcp.pcblist_n"
case C.NetworkUDP:
spath = "net.inet.udp.pcblist_n"
default:
return "", os.ErrInvalid
}
isIPv4 := ip.Is4()
value, err := syscall.Sysctl(spath)
if err != nil {
return "", err
}
buf := []byte(value)
// from darwin-xnu/bsd/netinet/in_pcblist.c:get_pcblist_n
// size/offset are round up (aligned) to 8 bytes in darwin
// rup8(sizeof(xinpcb_n)) + rup8(sizeof(xsocket_n)) +
// 2 * rup8(sizeof(xsockbuf_n)) + rup8(sizeof(xsockstat_n))
itemSize := 384
if network == C.NetworkTCP {
// rup8(sizeof(xtcpcb_n))
itemSize += 208
}
// skip the first xinpgen(24 bytes) block
for i := 24; i+itemSize <= len(buf); i += itemSize {
// offset of xinpcb_n and xsocket_n
inp, so := i, i+104
srcPort := binary.BigEndian.Uint16(buf[inp+18 : inp+20])
if uint16(port) != srcPort {
continue
}
// xinpcb_n.inp_vflag
flag := buf[inp+44]
var srcIP netip.Addr
switch {
case flag&0x1 > 0 && isIPv4:
// ipv4
srcIP = netip.AddrFrom4(*(*[4]byte)(buf[inp+76 : inp+80]))
case flag&0x2 > 0 && !isIPv4:
// ipv6
srcIP = netip.AddrFrom16(*(*[16]byte)(buf[inp+64 : inp+80]))
default:
continue
}
if ip != srcIP {
continue
}
// xsocket_n.so_last_pid
pid := readNativeUint32(buf[so+68 : so+72])
return getExecPathFromPID(pid)
}
return "", ErrNotFound
}
func getExecPathFromPID(pid uint32) (string, error) {
const (
procpidpathinfo = 0xb
procpidpathinfosize = 1024
proccallnumpidinfo = 0x2
)
buf := make([]byte, procpidpathinfosize)
_, _, errno := syscall.Syscall6(
syscall.SYS_PROC_INFO,
proccallnumpidinfo,
uintptr(pid),
procpidpathinfo,
0,
uintptr(unsafe.Pointer(&buf[0])),
procpidpathinfosize)
if errno != 0 {
return "", errno
}
return unix.ByteSliceToString(buf), nil
}
func readNativeUint32(b []byte) uint32 {
return *(*uint32)(unsafe.Pointer(&b[0]))
}

View file

@ -0,0 +1,35 @@
//go:build linux && !android
package process
import (
"context"
"net/netip"
"github.com/sagernet/sing-box/log"
)
var _ Searcher = (*linuxSearcher)(nil)
type linuxSearcher struct {
logger log.ContextLogger
}
func NewSearcher(logger log.ContextLogger) (Searcher, error) {
return &linuxSearcher{logger}, nil
}
func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) {
inode, uid, err := resolveSocketByNetlink(network, srcIP, srcPort)
if err != nil {
return nil, err
}
processPath, err := resolveProcessNameByProcSearch(inode, uid)
if err != nil {
s.logger.DebugContext(ctx, "find process path: ", err)
}
return &Info{
UserId: uid,
ProcessPath: processPath,
}, nil
}

View file

@ -0,0 +1,206 @@
//go:build linux
package process
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"net/netip"
"os"
"path"
"strings"
"syscall"
"unicode"
"unsafe"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
)
// from https://github.com/vishvananda/netlink/blob/bca67dfc8220b44ef582c9da4e9172bf1c9ec973/nl/nl_linux.go#L52-L62
var nativeEndian = func() binary.ByteOrder {
var x uint32 = 0x01020304
if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
return binary.BigEndian
}
return binary.LittleEndian
}()
const (
sizeOfSocketDiagRequest = syscall.SizeofNlMsghdr + 8 + 48
socketDiagByFamily = 20
pathProc = "/proc"
)
func resolveSocketByNetlink(network string, ip netip.Addr, srcPort int) (inode int32, uid int32, err error) {
var family byte
var protocol byte
switch network {
case C.NetworkTCP:
protocol = syscall.IPPROTO_TCP
case C.NetworkUDP:
protocol = syscall.IPPROTO_UDP
default:
return 0, 0, os.ErrInvalid
}
if ip.Is4() {
family = syscall.AF_INET
} else {
family = syscall.AF_INET6
}
req := packSocketDiagRequest(family, protocol, ip, uint16(srcPort))
socket, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_INET_DIAG)
if err != nil {
return 0, 0, E.Cause(err, "dial netlink")
}
defer syscall.Close(socket)
syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, &syscall.Timeval{Usec: 100})
syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &syscall.Timeval{Usec: 100})
if err = syscall.Connect(socket, &syscall.SockaddrNetlink{
Family: syscall.AF_NETLINK,
Pad: 0,
Pid: 0,
Groups: 0,
}); err != nil {
return 0, 0, err
}
if _, err = syscall.Write(socket, req); err != nil {
return 0, 0, E.Cause(err, "write netlink request")
}
_buffer := buf.StackNew()
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
n, err := syscall.Read(socket, buffer.FreeBytes())
if err != nil {
return 0, 0, E.Cause(err, "read netlink response")
}
buffer.Truncate(n)
messages, err := syscall.ParseNetlinkMessage(buffer.Bytes())
if err != nil {
return 0, 0, E.Cause(err, "parse netlink message")
} else if len(messages) == 0 {
return 0, 0, E.New("unexcepted netlink response")
}
message := messages[0]
if message.Header.Type&syscall.NLMSG_ERROR != 0 {
return 0, 0, E.New("netlink message: NLMSG_ERROR")
}
inode, uid = unpackSocketDiagResponse(&messages[0])
if inode < 0 || uid < 0 {
return 0, 0, E.New("invalid inode(", inode, ") or uid(", uid, ")")
}
return
}
func packSocketDiagRequest(family, protocol byte, source netip.Addr, sourcePort uint16) []byte {
s := make([]byte, 16)
copy(s, source.AsSlice())
buf := make([]byte, sizeOfSocketDiagRequest)
nativeEndian.PutUint32(buf[0:4], sizeOfSocketDiagRequest)
nativeEndian.PutUint16(buf[4:6], socketDiagByFamily)
nativeEndian.PutUint16(buf[6:8], syscall.NLM_F_REQUEST|syscall.NLM_F_DUMP)
nativeEndian.PutUint32(buf[8:12], 0)
nativeEndian.PutUint32(buf[12:16], 0)
buf[16] = family
buf[17] = protocol
buf[18] = 0
buf[19] = 0
nativeEndian.PutUint32(buf[20:24], 0xFFFFFFFF)
binary.BigEndian.PutUint16(buf[24:26], sourcePort)
binary.BigEndian.PutUint16(buf[26:28], 0)
copy(buf[28:44], s)
copy(buf[44:60], net.IPv6zero)
nativeEndian.PutUint32(buf[60:64], 0)
nativeEndian.PutUint64(buf[64:72], 0xFFFFFFFFFFFFFFFF)
return buf
}
func unpackSocketDiagResponse(msg *syscall.NetlinkMessage) (inode, uid int32) {
if len(msg.Data) < 72 {
return 0, 0
}
data := msg.Data
uid = int32(nativeEndian.Uint32(data[64:68]))
inode = int32(nativeEndian.Uint32(data[68:72]))
return
}
func resolveProcessNameByProcSearch(inode, uid int32) (string, error) {
files, err := os.ReadDir(pathProc)
if err != nil {
return "", err
}
buffer := make([]byte, syscall.PathMax)
socket := []byte(fmt.Sprintf("socket:[%d]", inode))
for _, f := range files {
if !f.IsDir() || !isPid(f.Name()) {
continue
}
info, err := f.Info()
if err != nil {
return "", err
}
if info.Sys().(*syscall.Stat_t).Uid != uint32(uid) {
continue
}
processPath := path.Join(pathProc, f.Name())
fdPath := path.Join(processPath, "fd")
fds, err := os.ReadDir(fdPath)
if err != nil {
continue
}
for _, fd := range fds {
n, err := syscall.Readlink(path.Join(fdPath, fd.Name()), buffer)
if err != nil {
continue
}
if bytes.Equal(buffer[:n], socket) {
return os.Readlink(path.Join(processPath, "exe"))
}
}
}
return "", fmt.Errorf("process of uid(%d),inode(%d) not found", uid, inode)
}
func isPid(s string) bool {
return strings.IndexFunc(s, func(r rune) bool {
return !unicode.IsDigit(r)
}) == -1
}

View file

@ -0,0 +1,13 @@
//go:build !linux && !windows && !darwin
package process
import (
"os"
"github.com/sagernet/sing-box/log"
)
func NewSearcher(logger log.ContextLogger) (Searcher, error) {
return nil, os.ErrInvalid
}

View file

@ -0,0 +1,235 @@
package process
import (
"context"
"fmt"
"net/netip"
"os"
"syscall"
"unsafe"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/windows"
)
var _ Searcher = (*windowsSearcher)(nil)
type windowsSearcher struct{}
func NewSearcher(logger log.ContextLogger) (Searcher, error) {
err := initWin32API()
if err != nil {
return nil, E.Cause(err, "init win32 api")
}
return &windowsSearcher{}, nil
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable")
procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable")
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW")
)
func initWin32API() error {
err := modiphlpapi.Load()
if err != nil {
return E.Cause(err, "load iphlpapi.dll")
}
err = procGetExtendedTcpTable.Find()
if err != nil {
return E.Cause(err, "load iphlpapi::GetExtendedTcpTable")
}
err = procGetExtendedUdpTable.Find()
if err != nil {
return E.Cause(err, "load iphlpapi::GetExtendedUdpTable")
}
err = modkernel32.Load()
if err != nil {
return E.Cause(err, "load kernel32.dll")
}
err = procQueryFullProcessImageNameW.Find()
if err != nil {
return E.Cause(err, "load kernel32::QueryFullProcessImageNameW")
}
return nil
}
func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) {
processName, err := findProcessName(network, srcIP, srcPort)
if err != nil {
return nil, err
}
return &Info{ProcessPath: processName, UserId: -1}, nil
}
func findProcessName(network string, ip netip.Addr, srcPort int) (string, error) {
family := windows.AF_INET
if ip.Is6() {
family = windows.AF_INET6
}
const (
tcpTablePidConn = 4
udpTablePid = 1
)
var class int
var fn uintptr
switch network {
case C.NetworkTCP:
fn = procGetExtendedTcpTable.Addr()
class = tcpTablePidConn
case C.NetworkUDP:
fn = procGetExtendedUdpTable.Addr()
class = udpTablePid
default:
return "", os.ErrInvalid
}
buf, err := getTransportTable(fn, family, class)
if err != nil {
return "", err
}
s := newSearcher(family == windows.AF_INET, network == C.NetworkTCP)
pid, err := s.Search(buf, ip, uint16(srcPort))
if err != nil {
return "", err
}
return getExecPathFromPID(pid)
}
type searcher struct {
itemSize int
port int
ip int
ipSize int
pid int
tcpState int
}
func (s *searcher) Search(b []byte, ip netip.Addr, port uint16) (uint32, error) {
n := int(readNativeUint32(b[:4]))
itemSize := s.itemSize
for i := 0; i < n; i++ {
row := b[4+itemSize*i : 4+itemSize*(i+1)]
if s.tcpState >= 0 {
tcpState := readNativeUint32(row[s.tcpState : s.tcpState+4])
// MIB_TCP_STATE_ESTAB, only check established connections for TCP
if tcpState != 5 {
continue
}
}
// according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian.
// this field can be illustrated as follows depends on different machine endianess:
// little endian: [ MSB LSB 0 0 ] interpret as native uint32 is ((LSB<<8)|MSB)
// big endian: [ 0 0 MSB LSB ] interpret as native uint32 is ((MSB<<8)|LSB)
// so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32
srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4])))
if srcPort != port {
continue
}
srcIP, _ := netip.AddrFromSlice(row[s.ip : s.ip+s.ipSize])
// windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto
if ip != srcIP && (!srcIP.IsUnspecified() || s.tcpState != -1) {
continue
}
pid := readNativeUint32(row[s.pid : s.pid+4])
return pid, nil
}
return 0, ErrNotFound
}
func newSearcher(isV4, isTCP bool) *searcher {
var itemSize, port, ip, ipSize, pid int
tcpState := -1
switch {
case isV4 && isTCP:
// struct MIB_TCPROW_OWNER_PID
itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0
case isV4 && !isTCP:
// struct MIB_UDPROW_OWNER_PID
itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8
case !isV4 && isTCP:
// struct MIB_TCP6ROW_OWNER_PID
itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48
case !isV4 && !isTCP:
// struct MIB_UDP6ROW_OWNER_PID
itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24
}
return &searcher{
itemSize: itemSize,
port: port,
ip: ip,
ipSize: ipSize,
pid: pid,
tcpState: tcpState,
}
}
func getTransportTable(fn uintptr, family int, class int) ([]byte, error) {
for size, buf := uint32(8), make([]byte, 8); ; {
ptr := unsafe.Pointer(&buf[0])
err, _, _ := syscall.SyscallN(fn, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0)
switch err {
case 0:
return buf, nil
case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER):
buf = make([]byte, size)
default:
return nil, fmt.Errorf("syscall error: %d", err)
}
}
}
func readNativeUint32(b []byte) uint32 {
return *(*uint32)(unsafe.Pointer(&b[0]))
}
func getExecPathFromPID(pid uint32) (string, error) {
// kernel process starts with a colon in order to distinguish with normal processes
switch pid {
case 0:
// reserved pid for system idle process
return ":System Idle Process", nil
case 4:
// reserved pid for windows kernel image
return ":System", nil
}
h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
if err != nil {
return "", err
}
defer windows.CloseHandle(h)
buf := make([]uint16, syscall.MAX_LONG_PATH)
size := uint32(len(buf))
r1, _, err := syscall.SyscallN(
procQueryFullProcessImageNameW.Addr(),
uintptr(h),
uintptr(1),
uintptr(unsafe.Pointer(&buf[0])),
uintptr(unsafe.Pointer(&size)),
)
if r1 == 0 {
return "", err
}
return syscall.UTF16ToString(buf[:size]), nil
}

View file

@ -0,0 +1,25 @@
//go:build cgo && linux && !android
package process
import (
"context"
"net/netip"
"os/user"
F "github.com/sagernet/sing/common/format"
)
func FindProcessInfo(searcher Searcher, ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) {
info, err := searcher.FindProcessInfo(ctx, network, srcIP, srcPort)
if err != nil {
return nil, err
}
if info.UserId != -1 {
osUser, _ := user.LookupId(F.ToString(info.UserId))
if osUser != nil {
info.User = osUser.Username
}
}
return info, nil
}

View file

@ -0,0 +1,12 @@
//go:build !cgo || !linux || android
package process
import (
"context"
"net/netip"
)
func FindProcessInfo(searcher Searcher, ctx context.Context, network string, srcIP netip.Addr, srcPort int) (*Info, error) {
return searcher.FindProcessInfo(ctx, network, srcIP, srcPort)
}

View file

@ -122,6 +122,23 @@ func castMetadata(metadata adapter.InboundContext) trafficontrol.Metadata {
} else {
domain = metadata.Destination.Fqdn
}
var processPath string
if metadata.ProcessInfo != nil {
if metadata.ProcessInfo.ProcessPath != "" {
processPath = metadata.ProcessInfo.ProcessPath
} else if metadata.ProcessInfo.PackageName != "" {
processPath = metadata.ProcessInfo.PackageName
}
if processPath == "" {
if metadata.ProcessInfo.UserId != -1 {
processPath = F.ToString(metadata.ProcessInfo.UserId)
}
} else if metadata.ProcessInfo.User != "" {
processPath = F.ToString(processPath, " (", metadata.ProcessInfo.User, ")")
} else if metadata.ProcessInfo.UserId != -1 {
processPath = F.ToString(processPath, " (", metadata.ProcessInfo.UserId, ")")
}
}
return trafficontrol.Metadata{
NetWork: metadata.Network,
Type: inbound,
@ -131,6 +148,7 @@ func castMetadata(metadata adapter.InboundContext) trafficontrol.Metadata {
DstPort: F.ToString(metadata.Destination.Port),
Host: domain,
DNSMode: "normal",
ProcessPath: processPath,
}
}

1
go.mod
View file

@ -4,6 +4,7 @@ go 1.18
require (
github.com/database64128/tfo-go v1.1.0
github.com/fsnotify/fsnotify v1.5.4
github.com/go-chi/chi/v5 v5.0.7
github.com/go-chi/cors v1.2.1
github.com/go-chi/render v1.0.1

3
go.sum
View file

@ -4,6 +4,8 @@ github.com/database64128/tfo-go v1.1.0/go.mod h1:95pOT8bnV3P2Lmu9upHNWFHz6dYGJ9c
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI=
github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU=
github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8=
github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
@ -68,6 +70,7 @@ golang.org/x/net v0.0.0-20220708220712-1185a9018129 h1:vucSRfWwTsoXro7P+3Cjlr6fl
golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=

View file

@ -91,7 +91,7 @@ func (r *DNSRule) UnmarshalJSON(bytes []byte) error {
type DefaultDNSRule struct {
Inbound Listable[string] `json:"inbound,omitempty"`
Network string `json:"network,omitempty"`
User Listable[string] `json:"user,omitempty"`
AuthUser Listable[string] `json:"auth_user,omitempty"`
Protocol Listable[string] `json:"protocol,omitempty"`
Domain Listable[string] `json:"domain,omitempty"`
DomainSuffix Listable[string] `json:"domain_suffix,omitempty"`
@ -102,6 +102,10 @@ type DefaultDNSRule struct {
SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"`
SourcePort Listable[uint16] `json:"source_port,omitempty"`
Port Listable[uint16] `json:"port,omitempty"`
ProcessName Listable[string] `json:"process_name,omitempty"`
PackageName Listable[string] `json:"package_name,omitempty"`
User Listable[string] `json:"user,omitempty"`
UserID Listable[int32] `json:"user_id,omitempty"`
Outbound Listable[string] `json:"outbound,omitempty"`
Server string `json:"server,omitempty"`
}
@ -126,6 +130,9 @@ func (r DefaultDNSRule) Equals(other DefaultDNSRule) bool {
common.ComparableSliceEquals(r.SourceIPCIDR, other.SourceIPCIDR) &&
common.ComparableSliceEquals(r.SourcePort, other.SourcePort) &&
common.ComparableSliceEquals(r.Port, other.Port) &&
common.ComparableSliceEquals(r.ProcessName, other.ProcessName) &&
common.ComparableSliceEquals(r.UserID, other.UserID) &&
common.ComparableSliceEquals(r.PackageName, other.PackageName) &&
common.ComparableSliceEquals(r.Outbound, other.Outbound) &&
r.Server == other.Server
}

View file

@ -13,6 +13,7 @@ type RouteOptions struct {
Geosite *GeositeOptions `json:"geosite,omitempty"`
Rules []Rule `json:"rules,omitempty"`
Final string `json:"final,omitempty"`
FindProcess bool `json:"find_process,omitempty"`
AutoDetectInterface bool `json:"auto_detect_interface,omitempty"`
DefaultInterface string `json:"default_interface,omitempty"`
}
@ -89,7 +90,7 @@ type DefaultRule struct {
Inbound Listable[string] `json:"inbound,omitempty"`
IPVersion int `json:"ip_version,omitempty"`
Network string `json:"network,omitempty"`
User Listable[string] `json:"user,omitempty"`
AuthUser Listable[string] `json:"auth_user,omitempty"`
Protocol Listable[string] `json:"protocol,omitempty"`
Domain Listable[string] `json:"domain,omitempty"`
DomainSuffix Listable[string] `json:"domain_suffix,omitempty"`
@ -102,6 +103,10 @@ type DefaultRule struct {
IPCIDR Listable[string] `json:"ip_cidr,omitempty"`
SourcePort Listable[uint16] `json:"source_port,omitempty"`
Port Listable[uint16] `json:"port,omitempty"`
ProcessName Listable[string] `json:"process_name,omitempty"`
PackageName Listable[string] `json:"package_name,omitempty"`
User Listable[string] `json:"user,omitempty"`
UserID Listable[int32] `json:"user_id,omitempty"`
Outbound string `json:"outbound,omitempty"`
}
@ -128,6 +133,10 @@ func (r DefaultRule) Equals(other DefaultRule) bool {
common.ComparableSliceEquals(r.IPCIDR, other.IPCIDR) &&
common.ComparableSliceEquals(r.SourcePort, other.SourcePort) &&
common.ComparableSliceEquals(r.Port, other.Port) &&
common.ComparableSliceEquals(r.ProcessName, other.ProcessName) &&
common.ComparableSliceEquals(r.PackageName, other.PackageName) &&
common.ComparableSliceEquals(r.User, other.User) &&
common.ComparableSliceEquals(r.UserID, other.UserID) &&
r.Outbound == other.Outbound
}

View file

@ -8,6 +8,7 @@ import (
"net/netip"
"net/url"
"os"
"os/user"
"path/filepath"
"reflect"
"strings"
@ -17,6 +18,7 @@ import (
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/geoip"
"github.com/sagernet/sing-box/common/geosite"
"github.com/sagernet/sing-box/common/process"
"github.com/sagernet/sing-box/common/sniff"
"github.com/sagernet/sing-box/common/urltest"
C "github.com/sagernet/sing-box/constant"
@ -42,15 +44,12 @@ type Router struct {
ctx context.Context
logger log.ContextLogger
dnsLogger log.ContextLogger
outbounds []adapter.Outbound
outboundByTag map[string]adapter.Outbound
rules []adapter.Rule
defaultDetour string
defaultOutboundForConnection adapter.Outbound
defaultOutboundForPacketConnection adapter.Outbound
needGeoIPDatabase bool
needGeositeDatabase bool
geoIPOptions option.GeoIPOptions
@ -58,22 +57,20 @@ type Router struct {
geoIPReader *geoip.Reader
geositeReader *geosite.Reader
geositeCache map[string]adapter.Rule
dnsClient *dns.Client
defaultDomainStrategy dns.DomainStrategy
dnsRules []adapter.Rule
defaultTransport dns.Transport
transports []dns.Transport
transportMap map[string]dns.Transport
interfaceBindManager control.BindManager
networkMonitor NetworkUpdateMonitor
autoDetectInterface bool
defaultInterface string
interfaceMonitor DefaultInterfaceMonitor
trafficController adapter.TrafficController
urlTestHistoryStorage *urltest.HistoryStorage
processSearcher process.Searcher
}
func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.ContextLogger, options option.RouteOptions, dnsOptions option.DNSOptions) (*Router, error) {
@ -84,8 +81,8 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont
outboundByTag: make(map[string]adapter.Outbound),
rules: make([]adapter.Rule, 0, len(options.Rules)),
dnsRules: make([]adapter.Rule, 0, len(dnsOptions.Rules)),
needGeoIPDatabase: hasGeoRule(options.Rules, isGeoIPRule) || hasGeoDNSRule(dnsOptions.Rules, isGeoIPDNSRule),
needGeositeDatabase: hasGeoRule(options.Rules, isGeositeRule) || hasGeoDNSRule(dnsOptions.Rules, isGeositeDNSRule),
needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule),
needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule),
geoIPOptions: common.PtrValueOrDefault(options.GeoIP),
geositeOptions: common.PtrValueOrDefault(options.Geosite),
geositeCache: make(map[string]adapter.Rule),
@ -221,6 +218,13 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont
}
router.interfaceMonitor = interfaceMonitor
}
if hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess {
searcher, err := process.NewSearcher(logger)
if err != nil {
return nil, E.Cause(err, "create process searcher")
}
router.processSearcher = searcher
}
return router, nil
}
@ -376,6 +380,14 @@ func (r *Router) Start() error {
return err
}
}
if r.processSearcher != nil {
if starter, isStarter := r.processSearcher.(common.Starter); isStarter {
err := starter.Start()
if err != nil {
return E.Cause(err, "initialize process searcher")
}
}
}
return nil
}
@ -396,6 +408,7 @@ func (r *Router) Close() error {
common.PtrOrNil(r.geoIPReader),
r.interfaceMonitor,
r.networkMonitor,
r.processSearcher,
)
}
@ -464,7 +477,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
metadata.DestinationAddresses = addresses
r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]")
}
matchedRule, detour := r.match(ctx, metadata, r.defaultOutboundForConnection)
matchedRule, detour := r.match(ctx, &metadata, r.defaultOutboundForConnection)
if !common.Contains(detour.Network(), C.NetworkTCP) {
conn.Close()
return E.New("missing supported outbound, closing connection")
@ -509,7 +522,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
metadata.DestinationAddresses = addresses
r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]")
}
matchedRule, detour := r.match(ctx, metadata, r.defaultOutboundForPacketConnection)
matchedRule, detour := r.match(ctx, &metadata, r.defaultOutboundForPacketConnection)
if !common.Contains(detour.Network(), C.NetworkUDP) {
conn.Close()
return E.New("missing supported outbound, closing packet connection")
@ -532,9 +545,34 @@ func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr
return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, r.defaultDomainStrategy)
}
func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
if r.processSearcher != nil {
processInfo, err := process.FindProcessInfo(r.processSearcher, ctx, metadata.Network, metadata.Source.Addr, int(metadata.Source.Port))
if err != nil {
r.logger.DebugContext(ctx, "failed to search process: ", err)
} else {
if processInfo.ProcessPath != "" {
r.logger.DebugContext(ctx, "found process path: ", processInfo.ProcessPath)
} else if processInfo.PackageName != "" {
r.logger.DebugContext(ctx, "found package name: ", processInfo.PackageName)
} else if processInfo.UserId != -1 {
if /*needUserName &&*/ true {
osUser, _ := user.LookupId(F.ToString(processInfo.UserId))
if osUser != nil {
processInfo.User = osUser.Username
}
}
if processInfo.User != "" {
r.logger.DebugContext(ctx, "found user: ", processInfo.User)
} else {
r.logger.DebugContext(ctx, "found user id: ", processInfo.UserId)
}
}
metadata.ProcessInfo = processInfo
}
}
for i, rule := range r.rules {
if rule.Match(&metadata) {
if rule.Match(metadata) {
detour := rule.Outbound()
r.logger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour)
if outbound, loaded := r.Outbound(detour); loaded {
@ -606,7 +644,7 @@ func (r *Router) URLTestHistoryStorage(create bool) *urltest.HistoryStorage {
return r.urlTestHistoryStorage
}
func hasGeoRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool {
func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool {
for _, rule := range rules {
switch rule.Type {
case C.RuleTypeDefault:
@ -624,7 +662,7 @@ func hasGeoRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bo
return false
}
func hasGeoDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bool) bool {
func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bool) bool {
for _, rule := range rules {
switch rule.Type {
case C.RuleTypeDefault:
@ -658,6 +696,14 @@ func isGeositeDNSRule(rule option.DefaultDNSRule) bool {
return len(rule.Geosite) > 0
}
func isProcessRule(rule option.DefaultRule) bool {
return len(rule.ProcessName) > 0
}
func isProcessDNSRule(rule option.DefaultDNSRule) bool {
return len(rule.ProcessName) > 0
}
func notPrivateNode(code string) bool {
return code != "private"
}

View file

@ -86,8 +86,8 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
return nil, E.New("invalid network: ", options.Network)
}
}
if len(options.User) > 0 {
item := NewUserItem(options.User)
if len(options.AuthUser) > 0 {
item := NewAuthUserItem(options.AuthUser)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
@ -155,6 +155,26 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.ProcessName) > 0 {
item := NewProcessItem(options.ProcessName)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.PackageName) > 0 {
item := NewPackageNameItem(options.PackageName)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.User) > 0 {
item := NewUserItem(options.User)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.UserID) > 0 {
item := NewUserIDItem(options.UserID)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
return rule, nil
}

37
route/rule_auth_user.go Normal file
View file

@ -0,0 +1,37 @@
package route
import (
"strings"
"github.com/sagernet/sing-box/adapter"
F "github.com/sagernet/sing/common/format"
)
var _ RuleItem = (*AuthUserItem)(nil)
type AuthUserItem struct {
users []string
userMap map[string]bool
}
func NewAuthUserItem(users []string) *AuthUserItem {
userMap := make(map[string]bool)
for _, protocol := range users {
userMap[protocol] = true
}
return &AuthUserItem{
users: users,
userMap: userMap,
}
}
func (r *AuthUserItem) Match(metadata *adapter.InboundContext) bool {
return r.userMap[metadata.User]
}
func (r *AuthUserItem) String() string {
if len(r.users) == 1 {
return F.ToString("auth_user=", r.users[0])
}
return F.ToString("auth_user=[", strings.Join(r.users, " "), "]")
}

View file

@ -70,8 +70,8 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
return nil, E.New("invalid network: ", options.Network)
}
}
if len(options.User) > 0 {
item := NewUserItem(options.User)
if len(options.AuthUser) > 0 {
item := NewAuthUserItem(options.AuthUser)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
@ -126,6 +126,26 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.ProcessName) > 0 {
item := NewProcessItem(options.ProcessName)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.PackageName) > 0 {
item := NewPackageNameItem(options.PackageName)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.User) > 0 {
item := NewUserItem(options.User)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.UserID) > 0 {
item := NewUserIDItem(options.UserID)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.Outbound) > 0 {
item := NewOutboundRule(options.Outbound)
rule.items = append(rule.items, item)

View file

@ -0,0 +1,43 @@
package route
import (
"strings"
"github.com/sagernet/sing-box/adapter"
)
var _ RuleItem = (*PackageNameItem)(nil)
type PackageNameItem struct {
packageNames []string
packageMap map[string]bool
}
func NewPackageNameItem(packageNameList []string) *PackageNameItem {
rule := &PackageNameItem{
packageNames: packageNameList,
packageMap: make(map[string]bool),
}
for _, packageName := range packageNameList {
rule.packageMap[packageName] = true
}
return rule
}
func (r *PackageNameItem) Match(metadata *adapter.InboundContext) bool {
if metadata.ProcessInfo == nil || metadata.ProcessInfo.PackageName == "" {
return false
}
return r.packageMap[metadata.ProcessInfo.PackageName]
}
func (r *PackageNameItem) String() string {
var description string
pLen := len(r.packageNames)
if pLen == 1 {
description = "package_name=" + r.packageNames[0]
} else {
description = "package_name=[" + strings.Join(r.packageNames, " ") + "]"
}
return description
}

44
route/rule_process.go Normal file
View file

@ -0,0 +1,44 @@
package route
import (
"path/filepath"
"strings"
"github.com/sagernet/sing-box/adapter"
)
var _ RuleItem = (*ProcessItem)(nil)
type ProcessItem struct {
processes []string
processMap map[string]bool
}
func NewProcessItem(processNameList []string) *ProcessItem {
rule := &ProcessItem{
processes: processNameList,
processMap: make(map[string]bool),
}
for _, processName := range processNameList {
rule.processMap[strings.ToLower(processName)] = true
}
return rule
}
func (r *ProcessItem) Match(metadata *adapter.InboundContext) bool {
if metadata.ProcessInfo == nil || metadata.ProcessInfo.ProcessPath == "" {
return false
}
return r.processMap[strings.ToLower(filepath.Base(metadata.ProcessInfo.ProcessPath))]
}
func (r *ProcessItem) String() string {
var description string
pLen := len(r.processes)
if pLen == 1 {
description = "process_name=" + r.processes[0]
} else {
description = "process_name=[" + strings.Join(r.processes, " ") + "]"
}
return description
}

View file

@ -26,7 +26,10 @@ func NewUserItem(users []string) *UserItem {
}
func (r *UserItem) Match(metadata *adapter.InboundContext) bool {
return r.userMap[metadata.User]
if metadata.ProcessInfo == nil || metadata.ProcessInfo.User == "" {
return false
}
return r.userMap[metadata.ProcessInfo.User]
}
func (r *UserItem) String() string {

44
route/rule_user_id.go Normal file
View file

@ -0,0 +1,44 @@
package route
import (
"strings"
"github.com/sagernet/sing-box/adapter"
F "github.com/sagernet/sing/common/format"
)
var _ RuleItem = (*UserIdItem)(nil)
type UserIdItem struct {
userIds []int32
userIdMap map[int32]bool
}
func NewUserIDItem(userIdList []int32) *UserIdItem {
rule := &UserIdItem{
userIds: userIdList,
userIdMap: make(map[int32]bool),
}
for _, userId := range userIdList {
rule.userIdMap[userId] = true
}
return rule
}
func (r *UserIdItem) Match(metadata *adapter.InboundContext) bool {
if metadata.ProcessInfo == nil || metadata.ProcessInfo.UserId == -1 {
return false
}
return r.userIdMap[metadata.ProcessInfo.UserId]
}
func (r *UserIdItem) String() string {
var description string
pLen := len(r.userIds)
if pLen == 1 {
description = "user_id=" + F.ToString(r.userIds[0])
} else {
description = "user_id=[" + strings.Join(F.MapToString(r.userIds), " ") + "]"
}
return description
}