diff --git a/adapter/router.go b/adapter/router.go index e1807747..cdc03961 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -32,6 +32,7 @@ type Router interface { LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) InterfaceFinder() control.InterfaceFinder + UpdateInterfaces() error DefaultInterface() string AutoDetectInterface() bool AutoDetectInterfaceFunc() control.Func diff --git a/box.go b/box.go index 434f0655..74307bb2 100644 --- a/box.go +++ b/box.go @@ -133,6 +133,12 @@ func New(options Options) (*Box, error) { if err != nil { return nil, err } + if options.PlatformInterface != nil { + err = options.PlatformInterface.Initialize(ctx, router) + if err != nil { + return nil, E.Cause(err, "initialize platform interface") + } + } preServices := make(map[string]adapter.Service) postServices := make(map[string]adapter.Service) if needClashAPI { diff --git a/experimental/libbox/iterator.go b/experimental/libbox/iterator.go index e23e2a7f..db64a259 100644 --- a/experimental/libbox/iterator.go +++ b/experimental/libbox/iterator.go @@ -29,3 +29,19 @@ func (i *iterator[T]) Next() T { func (i *iterator[T]) HasNext() bool { return len(i.values) > 0 } + +type abstractIterator[T any] interface { + Next() T + HasNext() bool +} + +func iteratorToArray[T any](iterator abstractIterator[T]) []T { + if iterator == nil { + return nil + } + var values []T + for iterator.HasNext() { + values = append(values, iterator.Next()) + } + return values +} diff --git a/experimental/libbox/monitor.go b/experimental/libbox/monitor.go new file mode 100644 index 00000000..685d6ccb --- /dev/null +++ b/experimental/libbox/monitor.go @@ -0,0 +1,183 @@ +package libbox + +import ( + "context" + "net" + "net/netip" + "sync" + + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/x/list" +) + +var ( + _ tun.DefaultInterfaceMonitor = (*platformDefaultInterfaceMonitor)(nil) + _ InterfaceUpdateListener = (*platformDefaultInterfaceMonitor)(nil) +) + +type platformDefaultInterfaceMonitor struct { + *platformInterfaceWrapper + errorHandler E.Handler + networkAddresses []networkAddress + defaultInterfaceName string + defaultInterfaceIndex int + element *list.Element[tun.NetworkUpdateCallback] + access sync.Mutex + callbacks list.List[tun.DefaultInterfaceUpdateCallback] +} + +type networkAddress struct { + interfaceName string + interfaceIndex int + addresses []netip.Prefix +} + +func (m *platformDefaultInterfaceMonitor) Start() error { + return m.iif.StartDefaultInterfaceMonitor(m) +} + +func (m *platformDefaultInterfaceMonitor) Close() error { + return m.iif.CloseDefaultInterfaceMonitor(m) +} + +func (m *platformDefaultInterfaceMonitor) DefaultInterfaceName(destination netip.Addr) string { + for _, address := range m.networkAddresses { + for _, prefix := range address.addresses { + if prefix.Contains(destination) { + return address.interfaceName + } + } + } + return m.defaultInterfaceName +} + +func (m *platformDefaultInterfaceMonitor) DefaultInterfaceIndex(destination netip.Addr) int { + for _, address := range m.networkAddresses { + for _, prefix := range address.addresses { + if prefix.Contains(destination) { + return address.interfaceIndex + } + } + } + return m.defaultInterfaceIndex +} + +func (m *platformDefaultInterfaceMonitor) OverrideAndroidVPN() bool { + return false +} + +func (m *platformDefaultInterfaceMonitor) AndroidVPNEnabled() bool { + return false +} + +func (m *platformDefaultInterfaceMonitor) RegisterCallback(callback tun.DefaultInterfaceUpdateCallback) *list.Element[tun.DefaultInterfaceUpdateCallback] { + m.access.Lock() + defer m.access.Unlock() + return m.callbacks.PushBack(callback) +} + +func (m *platformDefaultInterfaceMonitor) UnregisterCallback(element *list.Element[tun.DefaultInterfaceUpdateCallback]) { + m.access.Lock() + defer m.access.Unlock() + m.callbacks.Remove(element) +} + +func (m *platformDefaultInterfaceMonitor) UpdateDefaultInterface(interfaceName string, interfaceIndex32 int32) { + var err error + if m.iif.UsePlatformInterfaceGetter() { + err = m.updateInterfacesPlatform() + } else { + err = m.updateInterfaces() + } + if err == nil { + err = m.router.UpdateInterfaces() + } + if err != nil { + m.errorHandler.NewError(context.Background(), E.Cause(err, "update interfaces")) + } + interfaceIndex := int(interfaceIndex32) + if interfaceName == "" { + for _, netIf := range m.networkAddresses { + if netIf.interfaceIndex == interfaceIndex { + interfaceName = netIf.interfaceName + break + } + } + } else if interfaceIndex == -1 { + for _, netIf := range m.networkAddresses { + if netIf.interfaceName == interfaceName { + interfaceIndex = netIf.interfaceIndex + break + } + } + } + if interfaceName == "" { + m.errorHandler.NewError(context.Background(), E.New("invalid interface name for ", interfaceIndex)) + return + } else if interfaceIndex == -1 { + m.errorHandler.NewError(context.Background(), E.New("invalid interface index for ", interfaceName)) + return + } + if m.defaultInterfaceName == interfaceName && m.defaultInterfaceIndex == interfaceIndex { + return + } + m.defaultInterfaceName = interfaceName + m.defaultInterfaceIndex = interfaceIndex + m.access.Lock() + callbacks := m.callbacks.Array() + m.access.Unlock() + for _, callback := range callbacks { + err = callback(tun.EventInterfaceUpdate) + if err != nil { + m.errorHandler.NewError(context.Background(), err) + } + } +} + +func (m *platformDefaultInterfaceMonitor) updateInterfaces() error { + interfaces, err := net.Interfaces() + if err != nil { + return err + } + var addresses []networkAddress + for _, iif := range interfaces { + var netAddresses []net.Addr + netAddresses, err = iif.Addrs() + if err != nil { + return err + } + var address networkAddress + address.interfaceName = iif.Name + address.interfaceIndex = iif.Index + address.addresses = common.Map(common.FilterIsInstance(netAddresses, func(it net.Addr) (*net.IPNet, bool) { + value, loaded := it.(*net.IPNet) + return value, loaded + }), func(it *net.IPNet) netip.Prefix { + bits, _ := it.Mask.Size() + return netip.PrefixFrom(M.AddrFromIP(it.IP), bits) + }) + addresses = append(addresses, address) + } + m.networkAddresses = addresses + return nil +} + +func (m *platformDefaultInterfaceMonitor) updateInterfacesPlatform() error { + interfaces, err := m.Interfaces() + if err != nil { + return err + } + var addresses []networkAddress + for _, iif := range interfaces { + var address networkAddress + address.interfaceName = iif.Name + address.interfaceIndex = iif.Index + // address.addresses = common.Map(iif.Addresses, netip.MustParsePrefix) + addresses = append(addresses, address) + } + m.networkAddresses = addresses + return nil +} diff --git a/experimental/libbox/platform.go b/experimental/libbox/platform.go index 09195525..60198b22 100644 --- a/experimental/libbox/platform.go +++ b/experimental/libbox/platform.go @@ -1,6 +1,8 @@ package libbox -import "github.com/sagernet/sing-box/option" +import ( + "github.com/sagernet/sing-box/option" +) type PlatformInterface interface { AutoDetectInterfaceControl(fd int32) error @@ -10,6 +12,11 @@ type PlatformInterface interface { FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32, destinationAddress string, destinationPort int32) (int32, error) PackageNameByUid(uid int32) (string, error) UIDByPackageName(packageName string) (int32, error) + UsePlatformDefaultInterfaceMonitor() bool + StartDefaultInterfaceMonitor(listener InterfaceUpdateListener) error + CloseDefaultInterfaceMonitor(listener InterfaceUpdateListener) error + UsePlatformInterfaceGetter() bool + GetInterfaces() (NetworkInterfaceIterator, error) } type TunInterface interface { @@ -17,8 +24,19 @@ type TunInterface interface { Close() error } -type OnDemandRuleIterator interface { - Next() OnDemandRule +type InterfaceUpdateListener interface { + UpdateDefaultInterface(interfaceName string, interfaceIndex int32) +} + +type NetworkInterface struct { + Index int32 + MTU int32 + Name string + Addresses StringIterator +} + +type NetworkInterfaceIterator interface { + Next() *NetworkInterface HasNext() bool } @@ -31,6 +49,11 @@ type OnDemandRule interface { ProbeURL() string } +type OnDemandRuleIterator interface { + Next() OnDemandRule + HasNext() bool +} + type onDemandRule struct { option.OnDemandRule } diff --git a/experimental/libbox/platform/interface.go b/experimental/libbox/platform/interface.go index a16297f8..0f3680ca 100644 --- a/experimental/libbox/platform/interface.go +++ b/experimental/libbox/platform/interface.go @@ -1,17 +1,33 @@ package platform import ( + "context" "io" + "net/netip" + "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" ) type Interface interface { + Initialize(ctx context.Context, router adapter.Router) error AutoDetectInterfaceControl() control.Func OpenTun(options *tun.Options, platformOptions option.TunPlatformOptions) (tun.Tun, error) + UsePlatformDefaultInterfaceMonitor() bool + CreateDefaultInterfaceMonitor(errorHandler E.Handler) tun.DefaultInterfaceMonitor + UsePlatformInterfaceGetter() bool + Interfaces() ([]NetworkInterface, error) process.Searcher io.Writer } + +type NetworkInterface struct { + Index int + MTU int + Name string + Addresses []netip.Prefix +} diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index e6510170..b247a631 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -6,11 +6,13 @@ import ( "syscall" "github.com/sagernet/sing-box" + "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/experimental/libbox/internal/procfs" "github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" @@ -31,7 +33,7 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box instance, err := box.New(box.Options{ Context: ctx, Options: options, - PlatformInterface: &platformInterfaceWrapper{platformInterface, platformInterface.UseProcFS()}, + PlatformInterface: &platformInterfaceWrapper{iif: platformInterface, useProcFS: platformInterface.UseProcFS()}, }) if err != nil { cancel() @@ -58,6 +60,12 @@ var _ platform.Interface = (*platformInterfaceWrapper)(nil) type platformInterfaceWrapper struct { iif PlatformInterface useProcFS bool + router adapter.Router +} + +func (w *platformInterfaceWrapper) Initialize(ctx context.Context, router adapter.Router) error { + w.router = router + return nil } func (w *platformInterfaceWrapper) AutoDetectInterfaceControl() control.Func { @@ -122,3 +130,36 @@ func (w *platformInterfaceWrapper) FindProcessInfo(ctx context.Context, network packageName, _ := w.iif.PackageNameByUid(uid) return &process.Info{UserId: uid, PackageName: packageName}, nil } + +func (w *platformInterfaceWrapper) UsePlatformDefaultInterfaceMonitor() bool { + return w.iif.UsePlatformDefaultInterfaceMonitor() +} + +func (w *platformInterfaceWrapper) CreateDefaultInterfaceMonitor(errorHandler E.Handler) tun.DefaultInterfaceMonitor { + return &platformDefaultInterfaceMonitor{ + platformInterfaceWrapper: w, + errorHandler: errorHandler, + defaultInterfaceIndex: -1, + } +} + +func (w *platformInterfaceWrapper) UsePlatformInterfaceGetter() bool { + return w.iif.UsePlatformInterfaceGetter() +} + +func (w *platformInterfaceWrapper) Interfaces() ([]platform.NetworkInterface, error) { + interfaceIterator, err := w.iif.GetInterfaces() + if err != nil { + return nil, err + } + var interfaces []platform.NetworkInterface + for _, netInterface := range iteratorToArray[*NetworkInterface](interfaceIterator) { + interfaces = append(interfaces, platform.NetworkInterface{ + Index: int(netInterface.Index), + MTU: int(netInterface.MTU), + Name: netInterface.Name, + Addresses: common.Map(iteratorToArray[string](netInterface.Addresses), netip.MustParsePrefix), + }) + } + return interfaces, nil +} diff --git a/route/interface_finder.go b/route/interface_finder.go index 20688c9d..850f091f 100644 --- a/route/interface_finder.go +++ b/route/interface_finder.go @@ -9,7 +9,7 @@ import ( var _ control.InterfaceFinder = (*myInterfaceFinder)(nil) type myInterfaceFinder struct { - ifs []net.Interface + interfaces []net.Interface } func (f *myInterfaceFinder) update() error { @@ -17,12 +17,16 @@ func (f *myInterfaceFinder) update() error { if err != nil { return err } - f.ifs = ifs + f.interfaces = ifs return nil } +func (f *myInterfaceFinder) updateInterfaces(interfaces []net.Interface) { + f.interfaces = interfaces +} + func (f *myInterfaceFinder) InterfaceIndexByName(name string) (interfaceIndex int, err error) { - for _, netInterface := range f.ifs { + for _, netInterface := range f.interfaces { if netInterface.Name == name { return netInterface.Index, nil } @@ -36,7 +40,7 @@ func (f *myInterfaceFinder) InterfaceIndexByName(name string) (interfaceIndex in } func (f *myInterfaceFinder) InterfaceNameByIndex(index int) (interfaceName string, err error) { - for _, netInterface := range f.ifs { + for _, netInterface := range f.interfaces { if netInterface.Index == index { return netInterface.Name, nil } diff --git a/route/router.go b/route/router.go index 40ea6cbf..b9b86085 100644 --- a/route/router.go +++ b/route/router.go @@ -269,29 +269,33 @@ func NewRouter( router.transportMap = transportMap router.transportDomainStrategy = transportDomainStrategy + usePlatformDefaultInterfaceMonitor := platformInterface != nil && platformInterface.UsePlatformDefaultInterfaceMonitor() needInterfaceMonitor := options.AutoDetectInterface || common.Any(inbounds, func(inbound option.Inbound) bool { return inbound.HTTPOptions.SetSystemProxy || inbound.MixedOptions.SetSystemProxy || inbound.TunOptions.AutoRoute }) if needInterfaceMonitor { - networkMonitor, err := tun.NewNetworkUpdateMonitor(router) - if err == nil { - router.networkMonitor = networkMonitor - networkMonitor.RegisterCallback(router.interfaceFinder.update) + if !usePlatformDefaultInterfaceMonitor { + networkMonitor, err := tun.NewNetworkUpdateMonitor(router) + if err == nil { + router.networkMonitor = networkMonitor + networkMonitor.RegisterCallback(router.interfaceFinder.update) + } + interfaceMonitor, err := tun.NewDefaultInterfaceMonitor(router.networkMonitor, tun.DefaultInterfaceMonitorOptions{ + OverrideAndroidVPN: options.OverrideAndroidVPN, + }) + if err != nil { + return nil, E.New("auto_detect_interface unsupported on current platform") + } + interfaceMonitor.RegisterCallback(router.notifyNetworkUpdate) + router.interfaceMonitor = interfaceMonitor + } else { + interfaceMonitor := platformInterface.CreateDefaultInterfaceMonitor(router) + interfaceMonitor.RegisterCallback(router.notifyNetworkUpdate) + router.interfaceMonitor = interfaceMonitor } } - if router.networkMonitor != nil && needInterfaceMonitor { - interfaceMonitor, err := tun.NewDefaultInterfaceMonitor(router.networkMonitor, tun.DefaultInterfaceMonitorOptions{ - OverrideAndroidVPN: options.OverrideAndroidVPN, - }) - if err != nil { - return nil, E.New("auto_detect_interface unsupported on current platform") - } - interfaceMonitor.RegisterCallback(router.notifyNetworkUpdate) - router.interfaceMonitor = interfaceMonitor - } - needFindProcess := hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess needPackageManager := C.IsAndroid && platformInterface == nil && (needFindProcess || common.Any(inbounds, func(inbound option.Inbound) bool { return len(inbound.TunOptions.IncludePackage) > 0 || len(inbound.TunOptions.ExcludePackage) > 0 @@ -824,6 +828,25 @@ func (r *Router) InterfaceFinder() control.InterfaceFinder { return &r.interfaceFinder } +func (r *Router) UpdateInterfaces() error { + if r.platformInterface == nil || !r.platformInterface.UsePlatformInterfaceGetter() { + return r.interfaceFinder.update() + } else { + interfaces, err := r.platformInterface.Interfaces() + if err != nil { + return err + } + r.interfaceFinder.updateInterfaces(common.Map(interfaces, func(it platform.NetworkInterface) net.Interface { + return net.Interface{ + Name: it.Name, + Index: it.Index, + MTU: it.MTU, + } + })) + return nil + } +} + func (r *Router) AutoDetectInterface() bool { return r.autoDetectInterface } @@ -1137,7 +1160,7 @@ func (r *Router) NewError(ctx context.Context, err error) { } func (r *Router) notifyNetworkUpdate(int) error { - if C.IsAndroid { + if C.IsAndroid && r.platformInterface == nil { var vpnStatus string if r.interfaceMonitor.AndroidVPNEnabled() { vpnStatus = "enabled" @@ -1149,6 +1172,10 @@ func (r *Router) notifyNetworkUpdate(int) error { r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified())) } + if conntrack.Enabled { + conntrack.Close() + } + for _, outbound := range r.outbounds { listener, isListener := outbound.(adapter.InterfaceUpdateListener) if isListener { @@ -1158,9 +1185,5 @@ func (r *Router) notifyNetworkUpdate(int) error { } } } - - if conntrack.Enabled { - conntrack.Close() - } return nil } diff --git a/transport/dhcp/server.go b/transport/dhcp/server.go index 427017a9..56bd96ed 100644 --- a/transport/dhcp/server.go +++ b/transport/dhcp/server.go @@ -119,7 +119,7 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, func (t *Transport) fetchInterface() (*net.Interface, error) { interfaceName := t.interfaceName if t.autoInterface { - if t.router.NetworkMonitor() == nil { + if t.router.InterfaceMonitor() == nil { return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface") } interfaceName = t.router.InterfaceMonitor().DefaultInterfaceName(netip.Addr{})