refactor: Platform Interfaces

This commit is contained in:
世界 2024-11-11 16:23:45 +08:00
parent ca21816051
commit f92fb77bed
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
15 changed files with 267 additions and 251 deletions

View file

@ -9,6 +9,8 @@ type NetworkManager interface {
Lifecycle Lifecycle
InterfaceFinder() control.InterfaceFinder InterfaceFinder() control.InterfaceFinder
UpdateInterfaces() error UpdateInterfaces() error
DefaultNetworkInterface() *NetworkInterface
NetworkInterfaces() []NetworkInterface
DefaultInterface() string DefaultInterface() string
AutoDetectInterface() bool AutoDetectInterface() bool
AutoDetectInterfaceFunc() control.Func AutoDetectInterfaceFunc() control.Func
@ -21,3 +23,20 @@ type NetworkManager interface {
WIFIState() WIFIState WIFIState() WIFIState
ResetNetwork() ResetNetwork()
} }
type InterfaceUpdateListener interface {
InterfaceUpdated()
}
type WIFIState struct {
SSID string
BSSID string
}
type NetworkInterface struct {
control.Interface
Type string
DNSServers []string
Expensive bool
Constrained bool
}

View file

@ -119,12 +119,3 @@ func (c *HTTPStartContext) Close() {
client.CloseIdleConnections() client.CloseIdleConnections()
} }
} }
type InterfaceUpdateListener interface {
InterfaceUpdated()
}
type WIFIState struct {
SSID string
BSSID string
}

View file

@ -2,7 +2,6 @@ package settings
import ( import (
"context" "context"
"net/netip"
"strconv" "strconv"
"strings" "strings"
@ -77,14 +76,14 @@ func (p *DarwinSystemProxy) update(event int) {
} }
func (p *DarwinSystemProxy) update0() error { func (p *DarwinSystemProxy) update0() error {
newInterfaceName := p.monitor.DefaultInterfaceName(netip.IPv4Unspecified()) newInterface := p.monitor.DefaultInterface()
if p.interfaceName == newInterfaceName { if p.interfaceName == newInterface.Name {
return nil return nil
} }
if p.interfaceName != "" { if p.interfaceName != "" {
_ = p.Disable() _ = p.Disable()
} }
p.interfaceName = newInterfaceName p.interfaceName = newInterface.Name
interfaceDisplayName, err := getInterfaceDisplayName(p.interfaceName) interfaceDisplayName, err := getInterfaceDisplayName(p.interfaceName)
if err != nil { if err != nil {
return err return err

8
constant/network.go Normal file
View file

@ -0,0 +1,8 @@
package constant
const (
InterfaceTypeWIFI = "wifi"
InterfaceTypeCellular = "cellular"
InterfaceTypeEthernet = "ethernet"
InterfaceTypeOther = "other"
)

View file

@ -6,7 +6,7 @@ import (
const IsAndroid = goos.IsAndroid == 1 const IsAndroid = goos.IsAndroid == 1
const IsDarwin = goos.IsDarwin == 1 const IsDarwin = goos.IsDarwin == 1 || goos.IsIos == 1
const IsDragonfly = goos.IsDragonfly == 1 const IsDragonfly = goos.IsDragonfly == 1

View file

@ -74,11 +74,7 @@ func (s *platformInterfaceStub) CreateDefaultInterfaceMonitor(logger logger.Logg
return (*interfaceMonitorStub)(nil) return (*interfaceMonitorStub)(nil)
} }
func (s *platformInterfaceStub) UsePlatformInterfaceGetter() bool { func (s *platformInterfaceStub) Interfaces() ([]adapter.NetworkInterface, error) {
return true
}
func (s *platformInterfaceStub) Interfaces() ([]control.Interface, error) {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }
@ -111,16 +107,8 @@ func (s *interfaceMonitorStub) Close() error {
return os.ErrInvalid return os.ErrInvalid
} }
func (s *interfaceMonitorStub) DefaultInterfaceName(destination netip.Addr) string { func (s *interfaceMonitorStub) DefaultInterface() *control.Interface {
return "" return nil
}
func (s *interfaceMonitorStub) DefaultInterfaceIndex(destination netip.Addr) int {
return -1
}
func (s *interfaceMonitorStub) DefaultInterface(destination netip.Addr) (string, int) {
return "", -1
} }
func (s *interfaceMonitorStub) OverrideAndroidVPN() bool { func (s *interfaceMonitorStub) OverrideAndroidVPN() bool {

View file

@ -1,4 +1,4 @@
//go:build !linux //go:build !unix
package libbox package libbox

View file

@ -1,3 +1,5 @@
//go:build unix
package libbox package libbox
import ( import (

View file

@ -1,15 +1,10 @@
package libbox package libbox
import ( import (
"net"
"net/netip"
"sync"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
) )
@ -20,19 +15,9 @@ var (
type platformDefaultInterfaceMonitor struct { type platformDefaultInterfaceMonitor struct {
*platformInterfaceWrapper *platformInterfaceWrapper
networkAddresses []networkAddress element *list.Element[tun.NetworkUpdateCallback]
defaultInterfaceName string callbacks list.List[tun.DefaultInterfaceUpdateCallback]
defaultInterfaceIndex int logger logger.Logger
element *list.Element[tun.NetworkUpdateCallback]
access sync.Mutex
callbacks list.List[tun.DefaultInterfaceUpdateCallback]
logger logger.Logger
}
type networkAddress struct {
interfaceName string
interfaceIndex int
addresses []netip.Prefix
} }
func (m *platformDefaultInterfaceMonitor) Start() error { func (m *platformDefaultInterfaceMonitor) Start() error {
@ -43,37 +28,10 @@ func (m *platformDefaultInterfaceMonitor) Close() error {
return m.iif.CloseDefaultInterfaceMonitor(m) return m.iif.CloseDefaultInterfaceMonitor(m)
} }
func (m *platformDefaultInterfaceMonitor) DefaultInterfaceName(destination netip.Addr) string { func (m *platformDefaultInterfaceMonitor) DefaultInterface() *control.Interface {
for _, address := range m.networkAddresses { m.defaultInterfaceAccess.Lock()
for _, prefix := range address.addresses { defer m.defaultInterfaceAccess.Unlock()
if prefix.Contains(destination) { return m.defaultInterface
return address.interfaceName
}
}
}
return m.defaultInterfaceName
}
func (m *platformDefaultInterfaceMonitor) DefaultInterfaceIndex(destination netip.Addr) int {
for _, address := range m.networkAddresses {
for _, prefix := range address.addresses {
if prefix.Contains(destination) {
return address.interfaceIndex
}
}
}
return m.defaultInterfaceIndex
}
func (m *platformDefaultInterfaceMonitor) DefaultInterface(destination netip.Addr) (string, int) {
for _, address := range m.networkAddresses {
for _, prefix := range address.addresses {
if prefix.Contains(destination) {
return address.interfaceName, address.interfaceIndex
}
}
}
return m.defaultInterfaceName, m.defaultInterfaceIndex
} }
func (m *platformDefaultInterfaceMonitor) OverrideAndroidVPN() bool { func (m *platformDefaultInterfaceMonitor) OverrideAndroidVPN() bool {
@ -85,104 +43,57 @@ func (m *platformDefaultInterfaceMonitor) AndroidVPNEnabled() bool {
} }
func (m *platformDefaultInterfaceMonitor) RegisterCallback(callback tun.DefaultInterfaceUpdateCallback) *list.Element[tun.DefaultInterfaceUpdateCallback] { func (m *platformDefaultInterfaceMonitor) RegisterCallback(callback tun.DefaultInterfaceUpdateCallback) *list.Element[tun.DefaultInterfaceUpdateCallback] {
m.access.Lock() m.defaultInterfaceAccess.Lock()
defer m.access.Unlock() defer m.defaultInterfaceAccess.Unlock()
return m.callbacks.PushBack(callback) return m.callbacks.PushBack(callback)
} }
func (m *platformDefaultInterfaceMonitor) UnregisterCallback(element *list.Element[tun.DefaultInterfaceUpdateCallback]) { func (m *platformDefaultInterfaceMonitor) UnregisterCallback(element *list.Element[tun.DefaultInterfaceUpdateCallback]) {
m.access.Lock() m.defaultInterfaceAccess.Lock()
defer m.access.Unlock() defer m.defaultInterfaceAccess.Unlock()
m.callbacks.Remove(element) m.callbacks.Remove(element)
} }
func (m *platformDefaultInterfaceMonitor) UpdateDefaultInterface(interfaceName string, interfaceIndex32 int32) { func (m *platformDefaultInterfaceMonitor) UpdateDefaultInterface(interfaceName string, interfaceIndex32 int32, isExpensive bool, isConstrained bool) {
if sFixAndroidStack { if sFixAndroidStack {
go m.updateDefaultInterface(interfaceName, interfaceIndex32) go m.updateDefaultInterface(interfaceName, interfaceIndex32, isExpensive, isConstrained)
} else { } else {
m.updateDefaultInterface(interfaceName, interfaceIndex32) m.updateDefaultInterface(interfaceName, interfaceIndex32, isExpensive, isConstrained)
} }
} }
func (m *platformDefaultInterfaceMonitor) updateDefaultInterface(interfaceName string, interfaceIndex32 int32) { func (m *platformDefaultInterfaceMonitor) updateDefaultInterface(interfaceName string, interfaceIndex32 int32, isExpensive bool, isConstrained bool) {
if interfaceName == "" || interfaceIndex32 == -1 { m.isExpensive = isExpensive
m.defaultInterfaceName = "" m.isConstrained = isConstrained
m.defaultInterfaceIndex = -1 err := m.networkManager.UpdateInterfaces()
m.access.Lock()
callbacks := m.callbacks.Array()
m.access.Unlock()
for _, callback := range callbacks {
callback(tun.EventNoRoute)
}
return
}
var err error
if m.iif.UsePlatformInterfaceGetter() {
err = m.updateInterfacesPlatform()
} else {
err = m.updateInterfaces()
}
if err == nil {
err = m.networkManager.UpdateInterfaces()
}
if err != nil { if err != nil {
m.logger.Error(E.Cause(err, "update interfaces")) m.logger.Error(E.Cause(err, "update interfaces"))
} }
interfaceIndex := int(interfaceIndex32) m.defaultInterfaceAccess.Lock()
if m.defaultInterfaceName == interfaceName && m.defaultInterfaceIndex == interfaceIndex { if interfaceIndex32 == -1 {
m.defaultInterface = nil
callbacks := m.callbacks.Array()
m.defaultInterfaceAccess.Unlock()
for _, callback := range callbacks {
callback(tun.EventInterfaceUpdate)
}
return
}
oldInterface := m.defaultInterface
newInterface, err := m.networkManager.InterfaceFinder().ByIndex(int(interfaceIndex32))
if err != nil {
m.defaultInterfaceAccess.Unlock()
m.logger.Error(E.Cause(err, "find updated interface: ", interfaceName))
return
}
m.defaultInterface = newInterface
if oldInterface != nil && oldInterface.Name == m.defaultInterface.Name && oldInterface.Index == m.defaultInterface.Index {
m.defaultInterfaceAccess.Unlock()
return return
} }
m.defaultInterfaceName = interfaceName
m.defaultInterfaceIndex = interfaceIndex
m.access.Lock()
callbacks := m.callbacks.Array() callbacks := m.callbacks.Array()
m.access.Unlock() m.defaultInterfaceAccess.Unlock()
for _, callback := range callbacks { for _, callback := range callbacks {
callback(tun.EventInterfaceUpdate) callback(tun.EventInterfaceUpdate)
} }
} }
func (m *platformDefaultInterfaceMonitor) updateInterfaces() error {
interfaces, err := net.Interfaces()
if err != nil {
return err
}
var addresses []networkAddress
for _, iif := range interfaces {
var netAddresses []net.Addr
netAddresses, err = iif.Addrs()
if err != nil {
return err
}
var address networkAddress
address.interfaceName = iif.Name
address.interfaceIndex = iif.Index
address.addresses = common.Map(common.FilterIsInstance(netAddresses, func(it net.Addr) (*net.IPNet, bool) {
value, loaded := it.(*net.IPNet)
return value, loaded
}), func(it *net.IPNet) netip.Prefix {
bits, _ := it.Mask.Size()
return netip.PrefixFrom(M.AddrFromIP(it.IP), bits)
})
addresses = append(addresses, address)
}
m.networkAddresses = addresses
return nil
}
func (m *platformDefaultInterfaceMonitor) updateInterfacesPlatform() error {
interfaces, err := m.Interfaces()
if err != nil {
return err
}
var addresses []networkAddress
for _, iif := range interfaces {
var address networkAddress
address.interfaceName = iif.Name
address.interfaceIndex = iif.Index
// address.addresses = common.Map(iif.Addresses, netip.MustParsePrefix)
addresses = append(addresses, address)
}
m.networkAddresses = addresses
return nil
}

View file

@ -1,6 +1,7 @@
package libbox package libbox
import ( import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
) )
@ -13,10 +14,8 @@ type PlatformInterface interface {
FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32, destinationAddress string, destinationPort int32) (int32, error) FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32, destinationAddress string, destinationPort int32) (int32, error)
PackageNameByUid(uid int32) (string, error) PackageNameByUid(uid int32) (string, error)
UIDByPackageName(packageName string) (int32, error) UIDByPackageName(packageName string) (int32, error)
UsePlatformDefaultInterfaceMonitor() bool
StartDefaultInterfaceMonitor(listener InterfaceUpdateListener) error StartDefaultInterfaceMonitor(listener InterfaceUpdateListener) error
CloseDefaultInterfaceMonitor(listener InterfaceUpdateListener) error CloseDefaultInterfaceMonitor(listener InterfaceUpdateListener) error
UsePlatformInterfaceGetter() bool
GetInterfaces() (NetworkInterfaceIterator, error) GetInterfaces() (NetworkInterfaceIterator, error)
UnderNetworkExtension() bool UnderNetworkExtension() bool
IncludeAllNetworks() bool IncludeAllNetworks() bool
@ -31,15 +30,26 @@ type TunInterface interface {
} }
type InterfaceUpdateListener interface { type InterfaceUpdateListener interface {
UpdateDefaultInterface(interfaceName string, interfaceIndex int32) UpdateDefaultInterface(interfaceName string, interfaceIndex int32, isExpensive bool, isConstrained bool)
} }
const (
InterfaceTypeWIFI = C.InterfaceTypeWIFI
InterfaceTypeCellular = C.InterfaceTypeCellular
InterfaceTypeEthernet = C.InterfaceTypeEthernet
InterfaceTypeOther = C.InterfaceTypeOther
)
type NetworkInterface struct { type NetworkInterface struct {
Index int32 Index int32
MTU int32 MTU int32
Name string Name string
Addresses StringIterator Addresses StringIterator
Flags int32 Flags int32
Type string
DNSServer StringIterator
Metered bool
} }
type WIFIState struct { type WIFIState struct {

View file

@ -5,7 +5,6 @@ import (
"github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/common/process"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/control"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
) )
@ -14,10 +13,8 @@ type Interface interface {
UsePlatformAutoDetectInterfaceControl() bool UsePlatformAutoDetectInterfaceControl() bool
AutoDetectInterfaceControl(fd int) error AutoDetectInterfaceControl(fd int) error
OpenTun(options *tun.Options, platformOptions option.TunPlatformOptions) (tun.Tun, error) OpenTun(options *tun.Options, platformOptions option.TunPlatformOptions) (tun.Tun, error)
UsePlatformDefaultInterfaceMonitor() bool
CreateDefaultInterfaceMonitor(logger logger.Logger) tun.DefaultInterfaceMonitor CreateDefaultInterfaceMonitor(logger logger.Logger) tun.DefaultInterfaceMonitor
UsePlatformInterfaceGetter() bool Interfaces() ([]adapter.NetworkInterface, error)
Interfaces() ([]control.Interface, error)
UnderNetworkExtension() bool UnderNetworkExtension() bool
IncludeAllNetworks() bool IncludeAllNetworks() bool
ClearDNSCache() ClearDNSCache()

View file

@ -6,6 +6,7 @@ import (
"os" "os"
"runtime" "runtime"
runtimeDebug "runtime/debug" runtimeDebug "runtime/debug"
"sync"
"syscall" "syscall"
"time" "time"
@ -54,7 +55,10 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
urlTestHistoryStorage := urltest.NewHistoryStorage() urlTestHistoryStorage := urltest.NewHistoryStorage()
ctx = service.ContextWithPtr(ctx, urlTestHistoryStorage) ctx = service.ContextWithPtr(ctx, urlTestHistoryStorage)
platformWrapper := &platformInterfaceWrapper{iif: platformInterface, useProcFS: platformInterface.UseProcFS()} platformWrapper := &platformInterfaceWrapper{
iif: platformInterface,
useProcFS: platformInterface.UseProcFS(),
}
service.MustRegister[platform.Interface](ctx, platformWrapper) service.MustRegister[platform.Interface](ctx, platformWrapper)
instance, err := box.New(box.Options{ instance, err := box.New(box.Options{
Context: ctx, Context: ctx,
@ -119,9 +123,14 @@ var (
) )
type platformInterfaceWrapper struct { type platformInterfaceWrapper struct {
iif PlatformInterface iif PlatformInterface
useProcFS bool useProcFS bool
networkManager adapter.NetworkManager networkManager adapter.NetworkManager
myTunName string
defaultInterfaceAccess sync.Mutex
defaultInterface *control.Interface
isExpensive bool
isConstrained bool
} }
func (w *platformInterfaceWrapper) Initialize(networkManager adapter.NetworkManager) error { func (w *platformInterfaceWrapper) Initialize(networkManager adapter.NetworkManager) error {
@ -161,38 +170,42 @@ func (w *platformInterfaceWrapper) OpenTun(options *tun.Options, platformOptions
return nil, E.Cause(err, "dup tun file descriptor") return nil, E.Cause(err, "dup tun file descriptor")
} }
options.FileDescriptor = dupFd options.FileDescriptor = dupFd
w.myTunName = options.Name
return tun.New(*options) return tun.New(*options)
} }
func (w *platformInterfaceWrapper) UsePlatformDefaultInterfaceMonitor() bool {
return w.iif.UsePlatformDefaultInterfaceMonitor()
}
func (w *platformInterfaceWrapper) CreateDefaultInterfaceMonitor(logger logger.Logger) tun.DefaultInterfaceMonitor { func (w *platformInterfaceWrapper) CreateDefaultInterfaceMonitor(logger logger.Logger) tun.DefaultInterfaceMonitor {
return &platformDefaultInterfaceMonitor{ return &platformDefaultInterfaceMonitor{
platformInterfaceWrapper: w, platformInterfaceWrapper: w,
defaultInterfaceIndex: -1,
logger: logger, logger: logger,
} }
} }
func (w *platformInterfaceWrapper) UsePlatformInterfaceGetter() bool { func (w *platformInterfaceWrapper) Interfaces() ([]adapter.NetworkInterface, error) {
return w.iif.UsePlatformInterfaceGetter()
}
func (w *platformInterfaceWrapper) Interfaces() ([]control.Interface, error) {
interfaceIterator, err := w.iif.GetInterfaces() interfaceIterator, err := w.iif.GetInterfaces()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var interfaces []control.Interface var interfaces []adapter.NetworkInterface
for _, netInterface := range iteratorToArray[*NetworkInterface](interfaceIterator) { for _, netInterface := range iteratorToArray[*NetworkInterface](interfaceIterator) {
interfaces = append(interfaces, control.Interface{ if netInterface.Name == w.myTunName {
Index: int(netInterface.Index), continue
MTU: int(netInterface.MTU), }
Name: netInterface.Name, w.defaultInterfaceAccess.Lock()
Addresses: common.Map(iteratorToArray[string](netInterface.Addresses), netip.MustParsePrefix), isDefault := w.defaultInterface != nil && int(netInterface.Index) == w.defaultInterface.Index
Flags: linkFlags(uint32(netInterface.Flags)), w.defaultInterfaceAccess.Unlock()
interfaces = append(interfaces, adapter.NetworkInterface{
Interface: control.Interface{
Index: int(netInterface.Index),
MTU: int(netInterface.MTU),
Name: netInterface.Name,
Addresses: common.Map(iteratorToArray[string](netInterface.Addresses), netip.MustParsePrefix),
Flags: linkFlags(uint32(netInterface.Flags)),
},
Type: netInterface.Type,
DNSServers: iteratorToArray[string](netInterface.DNSServer),
Expensive: netInterface.Metered || isDefault && w.isExpensive,
Constrained: isDefault && w.isConstrained,
}) })
} }
return interfaces, nil return interfaces, nil

View file

@ -33,7 +33,7 @@ func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn {
} }
if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn { if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn {
if !source.Addr().IsLoopback() { if !source.Addr().IsLoopback() {
_, err := l.networkManager.InterfaceFinder().InterfaceByAddr(source.Addr()) _, err := l.networkManager.InterfaceFinder().ByAddr(source.Addr())
if err != nil { if err != nil {
return conn return conn
} }
@ -59,7 +59,7 @@ func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn, destination M.Soc
return conn return conn
} }
if !source.Addr().IsLoopback() { if !source.Addr().IsLoopback() {
_, err := l.networkManager.InterfaceFinder().InterfaceByAddr(source.Addr()) _, err := l.networkManager.InterfaceFinder().ByAddr(source.Addr())
if err != nil { if err != nil {
return conn return conn
} }
@ -82,7 +82,7 @@ func (l *loopBackDetector) CheckPacketConn(source netip.AddrPort, local netip.Ad
return false return false
} }
if !source.Addr().IsLoopback() { if !source.Addr().IsLoopback() {
_, err := l.networkManager.InterfaceFinder().InterfaceByAddr(source.Addr()) _, err := l.networkManager.InterfaceFinder().ByAddr(source.Addr())
if err != nil { if err != nil {
return false return false
} }

View file

@ -3,9 +3,10 @@ package route
import ( import (
"context" "context"
"errors" "errors"
"net/netip" "net"
"os" "os"
"runtime" "runtime"
"strings"
"syscall" "syscall"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
@ -15,33 +16,41 @@ import (
"github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/winpowrprof" "github.com/sagernet/sing/common/winpowrprof"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause" "github.com/sagernet/sing/service/pause"
"golang.org/x/exp/slices"
) )
var _ adapter.NetworkManager = (*NetworkManager)(nil) var _ adapter.NetworkManager = (*NetworkManager)(nil)
type NetworkManager struct { type NetworkManager struct {
logger logger.ContextLogger logger logger.ContextLogger
interfaceFinder *control.DefaultInterfaceFinder interfaceFinder *control.DefaultInterfaceFinder
networkInterfaces atomic.TypedValue[[]adapter.NetworkInterface]
autoDetectInterface bool autoDetectInterface bool
defaultInterface string defaultInterface string
defaultMark uint32 defaultMark uint32
autoRedirectOutputMark uint32 autoRedirectOutputMark uint32
networkMonitor tun.NetworkUpdateMonitor
interfaceMonitor tun.DefaultInterfaceMonitor networkMonitor tun.NetworkUpdateMonitor
packageManager tun.PackageManager interfaceMonitor tun.DefaultInterfaceMonitor
powerListener winpowrprof.EventListener packageManager tun.PackageManager
pauseManager pause.Manager powerListener winpowrprof.EventListener
platformInterface platform.Interface pauseManager pause.Manager
outboundManager adapter.OutboundManager platformInterface platform.Interface
wifiState adapter.WIFIState outboundManager adapter.OutboundManager
started bool wifiState adapter.WIFIState
started bool
} }
func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOptions option.RouteOptions) (*NetworkManager, error) { func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOptions option.RouteOptions) (*NetworkManager, error) {
@ -55,7 +64,7 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp
platformInterface: service.FromContext[platform.Interface](ctx), platformInterface: service.FromContext[platform.Interface](ctx),
outboundManager: service.FromContext[adapter.OutboundManager](ctx), outboundManager: service.FromContext[adapter.OutboundManager](ctx),
} }
usePlatformDefaultInterfaceMonitor := nm.platformInterface != nil && nm.platformInterface.UsePlatformDefaultInterfaceMonitor() usePlatformDefaultInterfaceMonitor := nm.platformInterface != nil
enforceInterfaceMonitor := routeOptions.AutoDetectInterface enforceInterfaceMonitor := routeOptions.AutoDetectInterface
if !usePlatformDefaultInterfaceMonitor { if !usePlatformDefaultInterfaceMonitor {
networkMonitor, err := tun.NewNetworkUpdateMonitor(logger) networkMonitor, err := tun.NewNetworkUpdateMonitor(logger)
@ -87,17 +96,17 @@ func (r *NetworkManager) Start(stage adapter.StartStage) error {
monitor := taskmonitor.New(r.logger, C.StartTimeout) monitor := taskmonitor.New(r.logger, C.StartTimeout)
switch stage { switch stage {
case adapter.StartStateInitialize: case adapter.StartStateInitialize:
if r.interfaceMonitor != nil { if r.networkMonitor != nil {
monitor.Start("initialize interface monitor") monitor.Start("initialize network monitor")
err := r.interfaceMonitor.Start() err := r.networkMonitor.Start()
monitor.Finish() monitor.Finish()
if err != nil { if err != nil {
return err return err
} }
} }
if r.networkMonitor != nil { if r.interfaceMonitor != nil {
monitor.Start("initialize network monitor") monitor.Start("initialize interface monitor")
err := r.networkMonitor.Start() err := r.interfaceMonitor.Start()
monitor.Finish() monitor.Finish()
if err != nil { if err != nil {
return err return err
@ -148,20 +157,6 @@ func (r *NetworkManager) Start(stage adapter.StartStage) error {
func (r *NetworkManager) Close() error { func (r *NetworkManager) Close() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout) monitor := taskmonitor.New(r.logger, C.StopTimeout)
var err error var err error
if r.interfaceMonitor != nil {
monitor.Start("close interface monitor")
err = E.Append(err, r.interfaceMonitor.Close(), func(err error) error {
return E.Cause(err, "close interface monitor")
})
monitor.Finish()
}
if r.networkMonitor != nil {
monitor.Start("close network monitor")
err = E.Append(err, r.networkMonitor.Close(), func(err error) error {
return E.Cause(err, "close network monitor")
})
monitor.Finish()
}
if r.packageManager != nil { if r.packageManager != nil {
monitor.Start("close package manager") monitor.Start("close package manager")
err = E.Append(err, r.packageManager.Close(), func(err error) error { err = E.Append(err, r.packageManager.Close(), func(err error) error {
@ -176,6 +171,20 @@ func (r *NetworkManager) Close() error {
}) })
monitor.Finish() monitor.Finish()
} }
if r.interfaceMonitor != nil {
monitor.Start("close interface monitor")
err = E.Append(err, r.interfaceMonitor.Close(), func(err error) error {
return E.Cause(err, "close interface monitor")
})
monitor.Finish()
}
if r.networkMonitor != nil {
monitor.Start("close network monitor")
err = E.Append(err, r.networkMonitor.Close(), func(err error) error {
return E.Cause(err, "close network monitor")
})
monitor.Finish()
}
return nil return nil
} }
@ -184,18 +193,75 @@ func (r *NetworkManager) InterfaceFinder() control.InterfaceFinder {
} }
func (r *NetworkManager) UpdateInterfaces() error { func (r *NetworkManager) UpdateInterfaces() error {
if r.platformInterface == nil || !r.platformInterface.UsePlatformInterfaceGetter() { if r.platformInterface == nil {
return r.interfaceFinder.Update() return r.interfaceFinder.Update()
} else { } else {
interfaces, err := r.platformInterface.Interfaces() interfaces, err := r.platformInterface.Interfaces()
if err != nil { if err != nil {
return err return err
} }
r.interfaceFinder.UpdateInterfaces(interfaces) if C.IsDarwin {
err = r.interfaceFinder.Update()
if err != nil {
return err
}
// NEInterface only provides name,index and type
interfaces = common.Map(interfaces, func(it adapter.NetworkInterface) adapter.NetworkInterface {
iif, _ := r.interfaceFinder.ByIndex(it.Index)
if iif != nil {
it.Interface = *iif
}
return it
})
} else {
r.interfaceFinder.UpdateInterfaces(common.Map(interfaces, func(it adapter.NetworkInterface) control.Interface { return it.Interface }))
}
oldInterfaces := r.networkInterfaces.Load()
newInterfaces := common.Filter(interfaces, func(it adapter.NetworkInterface) bool {
return it.Flags&net.FlagUp != 0
})
r.networkInterfaces.Store(newInterfaces)
if len(newInterfaces) > 0 && !slices.EqualFunc(oldInterfaces, newInterfaces, func(oldInterface adapter.NetworkInterface, newInterface adapter.NetworkInterface) bool {
return oldInterface.Interface.Index == newInterface.Interface.Index &&
oldInterface.Interface.Name == newInterface.Interface.Name &&
oldInterface.Interface.Flags == newInterface.Interface.Flags &&
oldInterface.Type == newInterface.Type &&
oldInterface.Expensive == newInterface.Expensive &&
oldInterface.Constrained == newInterface.Constrained
}) {
r.logger.Info("updated available networks: ", strings.Join(common.Map(newInterfaces, func(it adapter.NetworkInterface) string {
var options []string
options = append(options, F.ToString(it.Type))
if it.Expensive {
options = append(options, "expensive")
}
if it.Constrained {
options = append(options, "constrained")
}
return F.ToString(it.Name, " (", strings.Join(options, ", "), ")")
}), ", "))
}
return nil return nil
} }
} }
func (r *NetworkManager) DefaultNetworkInterface() *adapter.NetworkInterface {
iif := r.interfaceMonitor.DefaultInterface()
if iif == nil {
return nil
}
for _, it := range r.networkInterfaces.Load() {
if it.Interface.Index == iif.Index {
return &it
}
}
return &adapter.NetworkInterface{Interface: *iif}
}
func (r *NetworkManager) NetworkInterfaces() []adapter.NetworkInterface {
return r.networkInterfaces.Load()
}
func (r *NetworkManager) DefaultInterface() string { func (r *NetworkManager) DefaultInterface() string {
return r.defaultInterface return r.defaultInterface
} }
@ -217,18 +283,17 @@ func (r *NetworkManager) AutoDetectInterfaceFunc() control.Func {
} }
return control.BindToInterfaceFunc(r.interfaceFinder, func(network string, address string) (interfaceName string, interfaceIndex int, err error) { return control.BindToInterfaceFunc(r.interfaceFinder, func(network string, address string) (interfaceName string, interfaceIndex int, err error) {
remoteAddr := M.ParseSocksaddr(address).Addr remoteAddr := M.ParseSocksaddr(address).Addr
if C.IsLinux { if remoteAddr.IsValid() {
interfaceName, interfaceIndex = r.interfaceMonitor.DefaultInterface(remoteAddr) iif, err := r.interfaceFinder.ByAddr(remoteAddr)
if interfaceIndex == -1 { if err == nil {
err = tun.ErrNoRoute return iif.Name, iif.Index, nil
}
} else {
interfaceIndex = r.interfaceMonitor.DefaultInterfaceIndex(remoteAddr)
if interfaceIndex == -1 {
err = tun.ErrNoRoute
} }
} }
return defaultInterface := r.interfaceMonitor.DefaultInterface()
if defaultInterface == nil {
return "", -1, tun.ErrNoRoute
}
return defaultInterface.Name, defaultInterface.Index, nil
}) })
} }
} }
@ -282,6 +347,12 @@ func (r *NetworkManager) notifyNetworkUpdate(event int) {
r.logger.Error("missing default interface") r.logger.Error("missing default interface")
} else { } else {
r.pauseManager.NetworkWake() r.pauseManager.NetworkWake()
defaultInterface := r.DefaultNetworkInterface()
if defaultInterface == nil {
panic("invalid interface context")
}
var options []string
options = append(options, F.ToString("index ", defaultInterface.Index))
if C.IsAndroid && r.platformInterface == nil { if C.IsAndroid && r.platformInterface == nil {
var vpnStatus string var vpnStatus string
if r.interfaceMonitor.AndroidVPNEnabled() { if r.interfaceMonitor.AndroidVPNEnabled() {
@ -289,17 +360,24 @@ func (r *NetworkManager) notifyNetworkUpdate(event int) {
} else { } else {
vpnStatus = "disabled" vpnStatus = "disabled"
} }
r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified()), ", vpn ", vpnStatus) options = append(options, "vpn "+vpnStatus)
} else { } else {
r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified())) if defaultInterface.Type != "" {
options = append(options, F.ToString("type ", defaultInterface.Type))
}
if defaultInterface.Expensive {
options = append(options, "expensive")
}
if defaultInterface.Constrained {
options = append(options, "constrained")
}
} }
r.logger.Info("updated default interface ", defaultInterface.Name, ", ", strings.Join(options, ", "))
if r.platformInterface != nil { if r.platformInterface != nil {
state := r.platformInterface.ReadWIFIState() state := r.platformInterface.ReadWIFIState()
if state != r.wifiState { if state != r.wifiState {
r.wifiState = state r.wifiState = state
if state.SSID == "" && state.BSSID == "" { if state.SSID != "" {
r.logger.Info("updated WIFI state: disconnected")
} else {
r.logger.Info("updated WIFI state: SSID=", state.SSID, ", BSSID=", state.BSSID) r.logger.Info("updated WIFI state: SSID=", state.SSID, ", BSSID=", state.BSSID)
} }
} }
@ -309,7 +387,6 @@ func (r *NetworkManager) notifyNetworkUpdate(event int) {
if !r.started { if !r.started {
return return
} }
r.ResetNetwork() r.ResetNetwork()
} }

View file

@ -119,18 +119,19 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg,
return nil, err return nil, err
} }
func (t *Transport) fetchInterface() (*net.Interface, error) { func (t *Transport) fetchInterface() (*control.Interface, error) {
interfaceName := t.interfaceName
if t.autoInterface { if t.autoInterface {
if t.networkManager.InterfaceMonitor() == nil { if t.networkManager.InterfaceMonitor() == nil {
return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface") return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface")
} }
interfaceName = t.networkManager.InterfaceMonitor().DefaultInterfaceName(netip.Addr{}) defaultInterface := t.networkManager.InterfaceMonitor().DefaultInterface()
if defaultInterface == nil {
return nil, E.New("missing default interface")
}
return defaultInterface, nil
} else {
return t.networkManager.InterfaceFinder().ByName(t.interfaceName)
} }
if interfaceName == "" {
return nil, E.New("missing default interface")
}
return net.InterfaceByName(interfaceName)
} }
func (t *Transport) fetchServers() error { func (t *Transport) fetchServers() error {
@ -172,7 +173,7 @@ func (t *Transport) interfaceUpdated(int) {
} }
} }
func (t *Transport) fetchServers0(ctx context.Context, iface *net.Interface) error { func (t *Transport) fetchServers0(ctx context.Context, iface *control.Interface) error {
var listener net.ListenConfig var listener net.ListenConfig
listener.Control = control.Append(listener.Control, control.BindToInterface(t.networkManager.InterfaceFinder(), iface.Name, iface.Index)) listener.Control = control.Append(listener.Control, control.BindToInterface(t.networkManager.InterfaceFinder(), iface.Name, iface.Index))
listener.Control = control.Append(listener.Control, control.ReuseAddr()) listener.Control = control.Append(listener.Control, control.ReuseAddr())
@ -206,7 +207,7 @@ func (t *Transport) fetchServers0(ctx context.Context, iface *net.Interface) err
return group.Run(ctx) return group.Run(ctx)
} }
func (t *Transport) fetchServersResponse(iface *net.Interface, packetConn net.PacketConn, transactionID dhcpv4.TransactionID) error { func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn net.PacketConn, transactionID dhcpv4.TransactionID) error {
buffer := buf.NewSize(dhcpv4.MaxMessageSize) buffer := buf.NewSize(dhcpv4.MaxMessageSize)
defer buffer.Release() defer buffer.Release()
@ -246,7 +247,7 @@ func (t *Transport) fetchServersResponse(iface *net.Interface, packetConn net.Pa
} }
} }
func (t *Transport) recreateServers(iface *net.Interface, serverAddrs []netip.Addr) error { func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []netip.Addr) error {
if len(serverAddrs) > 0 { if len(serverAddrs) > 0 {
t.options.Logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string { t.options.Logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string {
return it.String() return it.String()