Add multi network dialing

This commit is contained in:
世界 2024-11-12 19:37:10 +08:00
parent b973eca28c
commit f274186752
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
24 changed files with 988 additions and 335 deletions

View file

@ -3,8 +3,10 @@ package adapter
import ( import (
"context" "context"
"net/netip" "net/netip"
"time"
"github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/common/process"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -66,6 +68,8 @@ type InboundContext struct {
InboundOptions option.InboundOptions InboundOptions option.InboundOptions
UDPDisableDomainUnmapping bool UDPDisableDomainUnmapping bool
UDPConnect bool UDPConnect bool
NetworkStrategy C.NetworkStrategy
FallbackDelay time.Duration
DNSServer string DNSServer string

View file

@ -1,6 +1,9 @@
package adapter package adapter
import ( import (
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/control"
) )
@ -11,10 +14,10 @@ type NetworkManager interface {
UpdateInterfaces() error UpdateInterfaces() error
DefaultNetworkInterface() *NetworkInterface DefaultNetworkInterface() *NetworkInterface
NetworkInterfaces() []NetworkInterface NetworkInterfaces() []NetworkInterface
DefaultInterface() string
AutoDetectInterface() bool AutoDetectInterface() bool
AutoDetectInterfaceFunc() control.Func AutoDetectInterfaceFunc() control.Func
DefaultMark() uint32 ProtectFunc() control.Func
DefaultOptions() NetworkOptions
RegisterAutoRedirectOutputMark(mark uint32) error RegisterAutoRedirectOutputMark(mark uint32) error
AutoRedirectOutputMark() uint32 AutoRedirectOutputMark() uint32
NetworkMonitor() tun.NetworkUpdateMonitor NetworkMonitor() tun.NetworkUpdateMonitor
@ -24,6 +27,13 @@ type NetworkManager interface {
ResetNetwork() ResetNetwork()
} }
type NetworkOptions struct {
DefaultNetworkStrategy C.NetworkStrategy
DefaultFallbackDelay time.Duration
DefaultInterface string
DefaultMark uint32
}
type InterfaceUpdateListener interface { type InterfaceUpdateListener interface {
InterfaceUpdated() InterfaceUpdated()
} }

View file

@ -8,8 +8,8 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
@ -25,35 +25,11 @@ func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata a
var outConn net.Conn var outConn net.Conn
var err error var err error
if len(metadata.DestinationAddresses) > 0 { if len(metadata.DestinationAddresses) > 0 {
outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer {
} else { outConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.FallbackDelay)
outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) } else {
} outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses)
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
err = N.ReportConnHandshakeSuccess(conn, outConn)
if err != nil {
outConn.Close()
return err
}
return CopyEarlyConn(ctx, conn, outConn)
}
func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error {
defer conn.Close()
ctx = adapter.WithContext(ctx, &metadata)
var outConn net.Conn
var err error
if len(metadata.DestinationAddresses) > 0 {
outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses)
} else if metadata.Destination.IsFqdn() {
var destinationAddresses []netip.Addr
destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
if err != nil {
return N.ReportHandshakeFailure(conn, err)
} }
outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, destinationAddresses)
} else { } else {
outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
} }
@ -79,7 +55,11 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn,
) )
if metadata.UDPConnect { if metadata.UDPConnect {
if len(metadata.DestinationAddresses) > 0 { if len(metadata.DestinationAddresses) > 0 {
outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer {
outConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.FallbackDelay)
} else {
outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
}
} else { } else {
outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
} }
@ -93,7 +73,11 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn,
} }
} else { } else {
if len(metadata.DestinationAddresses) > 0 { if len(metadata.DestinationAddresses) > 0 {
outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer {
outPacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, parallelDialer, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.FallbackDelay)
} else {
outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
}
} else { } else {
outPacketConn, err = this.ListenPacket(ctx, metadata.Destination) outPacketConn, err = this.ListenPacket(ctx, metadata.Destination)
} }
@ -129,76 +113,6 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn,
return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn)) return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn))
} }
func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error {
defer conn.Close()
ctx = adapter.WithContext(ctx, &metadata)
var (
outPacketConn net.PacketConn
outConn net.Conn
destinationAddress netip.Addr
err error
)
if metadata.UDPConnect {
if len(metadata.DestinationAddresses) > 0 {
outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
} else if metadata.Destination.IsFqdn() {
var destinationAddresses []netip.Addr
destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, destinationAddresses)
} else {
outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
}
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr())
if connRemoteAddr != metadata.Destination.Addr {
destinationAddress = connRemoteAddr
}
} else {
if len(metadata.DestinationAddresses) > 0 {
outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
} else if metadata.Destination.IsFqdn() {
var destinationAddresses []netip.Addr
destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses)
} else {
outPacketConn, err = this.ListenPacket(ctx, metadata.Destination)
}
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
}
err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn)
if err != nil {
outPacketConn.Close()
return err
}
if destinationAddress.IsValid() {
if metadata.Destination.IsFqdn() {
outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
}
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
natConn.UpdateDestination(destinationAddress)
}
}
switch metadata.Protocol {
case C.ProtocolSTUN:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.STUNTimeout)
case C.ProtocolQUIC:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.QUICTimeout)
case C.ProtocolDNS:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
}
return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn))
}
func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error { func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error {
if cachedReader, isCached := conn.(N.CachedReader); isCached { if cachedReader, isCached := conn.(N.CachedReader); isCached {
payload := cachedReader.ReadCached() payload := cachedReader.ReadCached()

View file

@ -10,66 +10,93 @@ import (
"github.com/sagernet/sing-box/common/conntrack" "github.com/sagernet/sing-box/common/conntrack"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
var _ WireGuardListener = (*DefaultDialer)(nil) var (
_ ParallelInterfaceDialer = (*DefaultDialer)(nil)
_ WireGuardListener = (*DefaultDialer)(nil)
)
type DefaultDialer struct { type DefaultDialer struct {
dialer4 tcpDialer dialer4 tcpDialer
dialer6 tcpDialer dialer6 tcpDialer
udpDialer4 net.Dialer udpDialer4 net.Dialer
udpDialer6 net.Dialer udpDialer6 net.Dialer
udpListener net.ListenConfig udpListener net.ListenConfig
udpAddr4 string udpAddr4 string
udpAddr6 string udpAddr6 string
isWireGuardListener bool isWireGuardListener bool
networkManager adapter.NetworkManager
networkStrategy C.NetworkStrategy
networkFallbackDelay time.Duration
networkLastFallback atomic.TypedValue[time.Time]
} }
func NewDefault(networkManager adapter.NetworkManager, options option.DialerOptions) (*DefaultDialer, error) { func NewDefault(networkManager adapter.NetworkManager, options option.DialerOptions) (*DefaultDialer, error) {
var dialer net.Dialer var (
var listener net.ListenConfig dialer net.Dialer
listener net.ListenConfig
interfaceFinder control.InterfaceFinder
networkStrategy C.NetworkStrategy
networkFallbackDelay time.Duration
)
if networkManager != nil {
interfaceFinder = networkManager.InterfaceFinder()
} else {
interfaceFinder = control.NewDefaultInterfaceFinder()
}
if options.BindInterface != "" { if options.BindInterface != "" {
var interfaceFinder control.InterfaceFinder
if networkManager != nil {
interfaceFinder = networkManager.InterfaceFinder()
} else {
interfaceFinder = control.NewDefaultInterfaceFinder()
}
bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1) bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1)
dialer.Control = control.Append(dialer.Control, bindFunc) dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc)
} else if networkManager != nil && networkManager.AutoDetectInterface() {
bindFunc := networkManager.AutoDetectInterfaceFunc()
dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc)
} else if networkManager != nil && networkManager.DefaultInterface() != "" {
bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), networkManager.DefaultInterface(), -1)
dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc)
}
var autoRedirectOutputMark uint32
if networkManager != nil {
autoRedirectOutputMark = networkManager.AutoRedirectOutputMark()
}
if autoRedirectOutputMark > 0 {
dialer.Control = control.Append(dialer.Control, control.RoutingMark(autoRedirectOutputMark))
listener.Control = control.Append(listener.Control, control.RoutingMark(autoRedirectOutputMark))
} }
if options.RoutingMark > 0 { if options.RoutingMark > 0 {
dialer.Control = control.Append(dialer.Control, control.RoutingMark(options.RoutingMark)) dialer.Control = control.Append(dialer.Control, control.RoutingMark(options.RoutingMark))
listener.Control = control.Append(listener.Control, control.RoutingMark(options.RoutingMark)) listener.Control = control.Append(listener.Control, control.RoutingMark(options.RoutingMark))
}
if networkManager != nil {
autoRedirectOutputMark := networkManager.AutoRedirectOutputMark()
if autoRedirectOutputMark > 0 { if autoRedirectOutputMark > 0 {
return nil, E.New("`auto_redirect` with `route_[_exclude]_address_set is conflict with `routing_mark`") if options.RoutingMark > 0 {
return nil, E.New("`routing_mark` is conflict with `tun.auto_redirect` with `tun.route_[_exclude]_address_set")
}
dialer.Control = control.Append(dialer.Control, control.RoutingMark(autoRedirectOutputMark))
listener.Control = control.Append(listener.Control, control.RoutingMark(autoRedirectOutputMark))
} }
} else if networkManager != nil && networkManager.DefaultMark() > 0 { }
dialer.Control = control.Append(dialer.Control, control.RoutingMark(networkManager.DefaultMark())) if C.NetworkStrategy(options.NetworkStrategy) != C.NetworkStrategyDefault {
listener.Control = control.Append(listener.Control, control.RoutingMark(networkManager.DefaultMark())) if options.BindInterface != "" || options.Inet4BindAddress != nil || options.Inet6BindAddress != nil {
if autoRedirectOutputMark > 0 { return nil, E.New("`network_strategy` is conflict with `bind_interface`, `inet4_bind_address` and `inet6_bind_address`")
return nil, E.New("`auto_redirect` with `route_[_exclude]_address_set is conflict with `default_mark`") }
networkStrategy = C.NetworkStrategy(options.NetworkStrategy)
networkFallbackDelay = time.Duration(options.NetworkFallbackDelay)
if networkManager == nil || !networkManager.AutoDetectInterface() {
return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`")
}
}
if networkManager != nil && options.BindInterface == "" && options.Inet4BindAddress == nil && options.Inet6BindAddress == nil {
defaultOptions := networkManager.DefaultOptions()
if defaultOptions.DefaultInterface != "" {
bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), defaultOptions.DefaultInterface, -1)
dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc)
} else if networkManager.AutoDetectInterface() {
if defaultOptions.DefaultNetworkStrategy != C.NetworkStrategyDefault && C.NetworkStrategy(options.NetworkStrategy) == C.NetworkStrategyDefault {
networkStrategy = defaultOptions.DefaultNetworkStrategy
networkFallbackDelay = defaultOptions.DefaultFallbackDelay
bindFunc := networkManager.ProtectFunc()
dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc)
} else {
bindFunc := networkManager.AutoDetectInterfaceFunc()
dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc)
}
} }
} }
if options.ReuseAddr { if options.ReuseAddr {
@ -130,6 +157,9 @@ func NewDefault(networkManager adapter.NetworkManager, options option.DialerOpti
listener.Control = control.Append(listener.Control, controlFn) listener.Control = control.Append(listener.Control, controlFn)
} }
} }
if networkStrategy != C.NetworkStrategyDefault && options.TCPFastOpen {
return nil, E.New("`tcp_fast_open` is conflict with `network_strategy` or `route.default_network_strategy`")
}
tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen) tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
if err != nil { if err != nil {
return nil, err return nil, err
@ -139,14 +169,17 @@ func NewDefault(networkManager adapter.NetworkManager, options option.DialerOpti
return nil, err return nil, err
} }
return &DefaultDialer{ return &DefaultDialer{
tcpDialer4, dialer4: tcpDialer4,
tcpDialer6, dialer6: tcpDialer6,
udpDialer4, udpDialer4: udpDialer4,
udpDialer6, udpDialer6: udpDialer6,
listener, udpListener: listener,
udpAddr4, udpAddr4: udpAddr4,
udpAddr6, udpAddr6: udpAddr6,
options.IsWireGuardListener, isWireGuardListener: options.IsWireGuardListener,
networkManager: networkManager,
networkStrategy: networkStrategy,
networkFallbackDelay: networkFallbackDelay,
}, nil }, nil
} }
@ -154,33 +187,88 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address
if !address.IsValid() { if !address.IsValid() {
return nil, E.New("invalid address") return nil, E.New("invalid address")
} }
switch N.NetworkName(network) { if d.networkStrategy == C.NetworkStrategyDefault {
case N.NetworkUDP: switch N.NetworkName(network) {
if !address.IsIPv6() { case N.NetworkUDP:
return trackConn(d.udpDialer4.DialContext(ctx, network, address.String())) if !address.IsIPv6() {
} else { return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
return trackConn(d.udpDialer6.DialContext(ctx, network, address.String())) } else {
return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
}
}
if !address.IsIPv6() {
return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
} else {
return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
} }
}
if !address.IsIPv6() {
return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
} else { } else {
return trackConn(DialSlowContext(&d.dialer6, ctx, network, address)) return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkFallbackDelay)
} }
} }
func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network string, address M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) {
if strategy == C.NetworkStrategyDefault {
return d.DialContext(ctx, network, address)
}
if !d.networkManager.AutoDetectInterface() {
return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`")
}
var dialer net.Dialer
if N.NetworkName(network) == N.NetworkTCP {
dialer = dialerFromTCPDialer(d.dialer4)
} else {
dialer = d.udpDialer4
}
fastFallback := time.Now().Sub(d.networkLastFallback.Load()) < C.TCPTimeout
var (
conn net.Conn
isPrimary bool
err error
)
if !fastFallback {
conn, isPrimary, err = d.dialParallelInterface(ctx, dialer, network, address.String(), strategy, fallbackDelay)
} else {
conn, isPrimary, err = d.dialParallelInterfaceFastFallback(ctx, dialer, network, address.String(), strategy, fallbackDelay, d.networkLastFallback.Store)
}
if err != nil {
return nil, err
}
if !fastFallback && !isPrimary {
d.networkLastFallback.Store(time.Now())
}
return trackConn(conn, nil)
}
func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if destination.IsIPv6() { if d.networkStrategy == C.NetworkStrategyDefault {
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6)) if destination.IsIPv6() {
} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() { return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4)) } else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
} else {
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
}
} else { } else {
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4)) return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkFallbackDelay)
} }
} }
func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, error) {
if strategy == C.NetworkStrategyDefault {
return d.ListenPacket(ctx, destination)
}
if !d.networkManager.AutoDetectInterface() {
return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`")
}
network := N.NetworkUDP
if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
network += "4"
}
return trackPacketConn(d.listenSerialInterfacePacket(ctx, d.udpListener, network, "", strategy, fallbackDelay))
}
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) { func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
return trackPacketConn(d.udpListener.ListenPacket(context.Background(), network, address)) return trackPacketConn(d.listenSerialInterfacePacket(context.Background(), d.udpListener, network, address, d.networkStrategy, d.networkFallbackDelay))
} }
func trackConn(conn net.Conn, err error) (net.Conn, error) { func trackConn(conn net.Conn, err error) (net.Conn, error) {

View file

@ -13,3 +13,7 @@ type tcpDialer = tfo.Dialer
func newTCPDialer(dialer net.Dialer, tfoEnabled bool) (tcpDialer, error) { func newTCPDialer(dialer net.Dialer, tfoEnabled bool) (tcpDialer, error) {
return tfo.Dialer{Dialer: dialer, DisableTFO: !tfoEnabled}, nil return tfo.Dialer{Dialer: dialer, DisableTFO: !tfoEnabled}, nil
} }
func dialerFromTCPDialer(dialer tcpDialer) net.Dialer {
return dialer.Dialer
}

View file

@ -16,3 +16,7 @@ func newTCPDialer(dialer net.Dialer, tfoEnabled bool) (tcpDialer, error) {
} }
return dialer, nil return dialer, nil
} }
func dialerFromTCPDialer(dialer tcpDialer) net.Dialer {
return dialer
}

View file

@ -0,0 +1,241 @@
package dialer
import (
"context"
"net"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
N "github.com/sagernet/sing/common/network"
)
func (d *DefaultDialer) dialParallelInterface(ctx context.Context, dialer net.Dialer, network string, addr string, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, bool, error) {
primaryInterfaces, fallbackInterfaces := selectInterfaces(d.networkManager, strategy)
if len(primaryInterfaces)+len(fallbackInterfaces) == 0 {
return nil, false, E.New("no available network interface")
}
if fallbackDelay == 0 {
fallbackDelay = N.DefaultFallbackDelay
}
returned := make(chan struct{})
defer close(returned)
type dialResult struct {
net.Conn
error
primary bool
}
results := make(chan dialResult) // unbuffered
startRacer := func(ctx context.Context, primary bool, iif adapter.NetworkInterface) {
perNetDialer := dialer
perNetDialer.Control = control.Append(perNetDialer.Control, control.BindToInterface(nil, iif.Name, iif.Index))
conn, err := perNetDialer.DialContext(ctx, network, addr)
if err != nil {
select {
case results <- dialResult{error: E.Cause(err, "dial ", iif.Name, " (", iif.Name, ")"), primary: primary}:
case <-returned:
}
} else {
select {
case results <- dialResult{Conn: conn}:
case <-returned:
conn.Close()
}
}
}
primaryCtx, primaryCancel := context.WithCancel(ctx)
defer primaryCancel()
for _, iif := range primaryInterfaces {
go startRacer(primaryCtx, true, iif)
}
var (
fallbackTimer *time.Timer
fallbackChan <-chan time.Time
)
if len(fallbackInterfaces) > 0 {
fallbackTimer = time.NewTimer(fallbackDelay)
defer fallbackTimer.Stop()
fallbackChan = fallbackTimer.C
}
var errors []error
for {
select {
case <-fallbackChan:
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
for _, iif := range fallbackInterfaces {
go startRacer(fallbackCtx, false, iif)
}
case res := <-results:
if res.error == nil {
return res.Conn, res.primary, nil
}
errors = append(errors, res.error)
if len(errors) == len(primaryInterfaces)+len(fallbackInterfaces) {
return nil, false, E.Errors(errors...)
}
if res.primary && fallbackTimer != nil && fallbackTimer.Stop() {
fallbackTimer.Reset(0)
}
}
}
}
func (d *DefaultDialer) dialParallelInterfaceFastFallback(ctx context.Context, dialer net.Dialer, network string, addr string, strategy C.NetworkStrategy, fallbackDelay time.Duration, resetFastFallback func(time.Time)) (net.Conn, bool, error) {
primaryInterfaces, fallbackInterfaces := selectInterfaces(d.networkManager, strategy)
if len(primaryInterfaces)+len(fallbackInterfaces) == 0 {
return nil, false, E.New("no available network interface")
}
if fallbackDelay == 0 {
fallbackDelay = N.DefaultFallbackDelay
}
returned := make(chan struct{})
defer close(returned)
type dialResult struct {
net.Conn
error
primary bool
}
startAt := time.Now()
results := make(chan dialResult) // unbuffered
startRacer := func(ctx context.Context, primary bool, iif adapter.NetworkInterface) {
perNetDialer := dialer
perNetDialer.Control = control.Append(perNetDialer.Control, control.BindToInterface(nil, iif.Name, iif.Index))
conn, err := perNetDialer.DialContext(ctx, network, addr)
if err != nil {
select {
case results <- dialResult{error: E.Cause(err, "dial ", iif.Name, " (", iif.Name, ")"), primary: primary}:
case <-returned:
}
} else {
select {
case results <- dialResult{Conn: conn}:
case <-returned:
if primary && time.Since(startAt) <= fallbackDelay {
resetFastFallback(time.Time{})
}
conn.Close()
}
}
}
for _, iif := range primaryInterfaces {
go startRacer(ctx, true, iif)
}
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
for _, iif := range fallbackInterfaces {
go startRacer(fallbackCtx, false, iif)
}
var errors []error
for {
select {
case res := <-results:
if res.error == nil {
return res.Conn, res.primary, nil
}
errors = append(errors, res.error)
if len(errors) == len(primaryInterfaces)+len(fallbackInterfaces) {
return nil, false, E.Errors(errors...)
}
}
}
}
func (d *DefaultDialer) listenSerialInterfacePacket(ctx context.Context, listener net.ListenConfig, network string, addr string, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, error) {
primaryInterfaces, fallbackInterfaces := selectInterfaces(d.networkManager, strategy)
if len(primaryInterfaces)+len(fallbackInterfaces) == 0 {
return nil, E.New("no available network interface")
}
if fallbackDelay == 0 {
fallbackDelay = N.DefaultFallbackDelay
}
var errors []error
for _, primaryInterface := range primaryInterfaces {
perNetListener := listener
perNetListener.Control = control.Append(perNetListener.Control, control.BindToInterface(nil, primaryInterface.Name, primaryInterface.Index))
conn, err := perNetListener.ListenPacket(ctx, network, addr)
if err == nil {
return conn, nil
}
errors = append(errors, E.Cause(err, "listen ", primaryInterface.Name, " (", primaryInterface.Name, ")"))
}
for _, fallbackInterface := range fallbackInterfaces {
perNetListener := listener
perNetListener.Control = control.Append(perNetListener.Control, control.BindToInterface(nil, fallbackInterface.Name, fallbackInterface.Index))
conn, err := perNetListener.ListenPacket(ctx, network, addr)
if err == nil {
return conn, nil
}
errors = append(errors, E.Cause(err, "listen ", fallbackInterface.Name, " (", fallbackInterface.Name, ")"))
}
return nil, E.Errors(errors...)
}
func selectInterfaces(networkManager adapter.NetworkManager, strategy C.NetworkStrategy) (primaryInterfaces []adapter.NetworkInterface, fallbackInterfaces []adapter.NetworkInterface) {
interfaces := networkManager.NetworkInterfaces()
switch strategy {
case C.NetworkStrategyFallback:
defaultIf := networkManager.InterfaceMonitor().DefaultInterface()
if defaultIf != nil {
for _, iif := range interfaces {
if iif.Index == defaultIf.Index {
primaryInterfaces = append(primaryInterfaces, iif)
} else {
fallbackInterfaces = append(fallbackInterfaces, iif)
}
}
} else {
primaryInterfaces = interfaces
}
case C.NetworkStrategyHybrid:
primaryInterfaces = interfaces
case C.NetworkStrategyWIFI:
for _, iif := range interfaces {
if iif.Type == C.InterfaceTypeWIFI {
primaryInterfaces = append(primaryInterfaces, iif)
} else {
fallbackInterfaces = append(fallbackInterfaces, iif)
}
}
case C.NetworkStrategyCellular:
for _, iif := range interfaces {
if iif.Type == C.InterfaceTypeCellular {
primaryInterfaces = append(primaryInterfaces, iif)
} else {
fallbackInterfaces = append(fallbackInterfaces, iif)
}
}
case C.NetworkStrategyEthernet:
for _, iif := range interfaces {
if iif.Type == C.InterfaceTypeEthernet {
primaryInterfaces = append(primaryInterfaces, iif)
} else {
fallbackInterfaces = append(fallbackInterfaces, iif)
}
}
case C.NetworkStrategyWIFIOnly:
for _, iif := range interfaces {
if iif.Type == C.InterfaceTypeWIFI {
primaryInterfaces = append(primaryInterfaces, iif)
}
}
case C.NetworkStrategyCellularOnly:
for _, iif := range interfaces {
if iif.Type == C.InterfaceTypeCellular {
primaryInterfaces = append(primaryInterfaces, iif)
}
}
case C.NetworkStrategyEthernetOnly:
for _, iif := range interfaces {
if iif.Type == C.InterfaceTypeEthernet {
primaryInterfaces = append(primaryInterfaces, iif)
}
}
default:
panic(F.ToString("unknown network strategy: ", strategy))
}
return primaryInterfaces, fallbackInterfaces
}

View file

@ -0,0 +1,122 @@
package dialer
import (
"context"
"net"
"net/netip"
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func DialSerialNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) {
if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel {
return parallelDialer.DialParallelNetwork(ctx, network, destination, destinationAddresses, strategy, fallbackDelay)
}
var errors []error
for _, address := range destinationAddresses {
conn, err := dialer.DialParallelInterface(ctx, network, M.SocksaddrFrom(address, destination.Port), strategy, fallbackDelay)
if err == nil {
return conn, nil
}
errors = append(errors, err)
}
return nil, E.Errors(errors...)
}
func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) {
if fallbackDelay == 0 {
fallbackDelay = N.DefaultFallbackDelay
}
returned := make(chan struct{})
defer close(returned)
addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
return address.Is4() || address.Is4In6()
})
addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
return address.Is6() && !address.Is4In6()
})
if len(addresses4) == 0 || len(addresses6) == 0 {
return DialSerialNetwork(ctx, dialer, network, destination, destinationAddresses, strategy, fallbackDelay)
}
var primaries, fallbacks []netip.Addr
if preferIPv6 {
primaries = addresses6
fallbacks = addresses4
} else {
primaries = addresses4
fallbacks = addresses6
}
type dialResult struct {
net.Conn
error
primary bool
done bool
}
results := make(chan dialResult) // unbuffered
startRacer := func(ctx context.Context, primary bool) {
ras := primaries
if !primary {
ras = fallbacks
}
c, err := DialSerialNetwork(ctx, dialer, network, destination, ras, strategy, fallbackDelay)
select {
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
case <-returned:
if c != nil {
c.Close()
}
}
}
var primary, fallback dialResult
primaryCtx, primaryCancel := context.WithCancel(ctx)
defer primaryCancel()
go startRacer(primaryCtx, true)
fallbackTimer := time.NewTimer(fallbackDelay)
defer fallbackTimer.Stop()
for {
select {
case <-fallbackTimer.C:
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
go startRacer(fallbackCtx, false)
case res := <-results:
if res.error == nil {
return res.Conn, nil
}
if res.primary {
primary = res
} else {
fallback = res
}
if primary.done && fallback.done {
return nil, primary.error
}
if res.primary && fallbackTimer.Stop() {
fallbackTimer.Reset(0)
}
}
}
}
func ListenSerialNetworkPacket(ctx context.Context, dialer ParallelInterfaceDialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) {
if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel {
return parallelDialer.ListenSerialNetworkPacket(ctx, destination, destinationAddresses, strategy, fallbackDelay)
}
var errors []error
for _, address := range destinationAddresses {
conn, err := dialer.ListenSerialInterfacePacket(ctx, M.SocksaddrFrom(address, destination.Port), strategy, fallbackDelay)
if err == nil {
return conn, address, nil
}
errors = append(errors, err)
}
return nil, netip.Addr{}, E.Errors(errors...)
}

View file

@ -2,12 +2,16 @@ package dialer
import ( import (
"context" "context"
"net"
"net/netip"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
) )
@ -49,3 +53,35 @@ func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) {
} }
return dialer, nil return dialer, nil
} }
func NewDirect(ctx context.Context, options option.DialerOptions) (ParallelInterfaceDialer, error) {
if options.Detour != "" {
return nil, E.New("`detour` is not supported in direct context")
}
networkManager := service.FromContext[adapter.NetworkManager](ctx)
if options.IsWireGuardListener {
return NewDefault(networkManager, options)
}
dialer, err := NewDefault(networkManager, options)
if err != nil {
return nil, err
}
return NewResolveParallelInterfaceDialer(
service.FromContext[adapter.Router](ctx),
dialer,
true,
dns.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay),
), nil
}
type ParallelInterfaceDialer interface {
N.Dialer
DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error)
ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, error)
}
type ParallelNetworkDialer interface {
DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error)
ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error)
}

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
@ -14,7 +15,12 @@ import (
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
type ResolveDialer struct { var (
_ N.Dialer = (*resolveDialer)(nil)
_ ParallelInterfaceDialer = (*resolveParallelNetworkDialer)(nil)
)
type resolveDialer struct {
dialer N.Dialer dialer N.Dialer
parallel bool parallel bool
router adapter.Router router adapter.Router
@ -22,8 +28,8 @@ type ResolveDialer struct {
fallbackDelay time.Duration fallbackDelay time.Duration
} }
func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) *ResolveDialer { func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) N.Dialer {
return &ResolveDialer{ return &resolveDialer{
dialer, dialer,
parallel, parallel,
router, router,
@ -32,7 +38,25 @@ func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, str
} }
} }
func (d *ResolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { type resolveParallelNetworkDialer struct {
resolveDialer
dialer ParallelInterfaceDialer
}
func NewResolveParallelInterfaceDialer(router adapter.Router, dialer ParallelInterfaceDialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer {
return &resolveParallelNetworkDialer{
resolveDialer{
dialer,
parallel,
router,
strategy,
fallbackDelay,
},
dialer,
}
}
func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
@ -57,7 +81,7 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina
} }
} }
func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }
@ -82,6 +106,59 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
} }
func (d *ResolveDialer) Upstream() any { func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) {
if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
if err != nil {
return nil, err
}
if fallbackDelay == 0 {
fallbackDelay = d.fallbackDelay
}
if d.parallel {
return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, strategy, fallbackDelay)
} else {
return DialSerialNetwork(ctx, d.dialer, network, destination, addresses, strategy, fallbackDelay)
}
}
func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, error) {
if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
if err != nil {
return nil, err
}
conn, destinationAddress, err := ListenSerialNetworkPacket(ctx, d.dialer, destination, addresses, strategy, fallbackDelay)
if err != nil {
return nil, err
}
return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
}
func (d *resolveDialer) Upstream() any {
return d.dialer return d.dialer
} }

View file

@ -7,6 +7,7 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/shell" "github.com/sagernet/sing/common/shell"
@ -33,7 +34,7 @@ func NewSystemProxy(ctx context.Context, serverAddr M.Socksaddr, supportSOCKS bo
serverAddr: serverAddr, serverAddr: serverAddr,
supportSOCKS: supportSOCKS, supportSOCKS: supportSOCKS,
} }
proxy.element = interfaceMonitor.RegisterCallback(proxy.update) proxy.element = interfaceMonitor.RegisterCallback(proxy.routeUpdate)
return proxy, nil return proxy, nil
} }
@ -65,11 +66,8 @@ func (p *DarwinSystemProxy) Disable() error {
return err return err
} }
func (p *DarwinSystemProxy) update(event int) { func (p *DarwinSystemProxy) routeUpdate(defaultInterface *control.Interface, flags int) {
if event&tun.EventInterfaceUpdate == 0 { if !p.isEnabled || defaultInterface == nil {
return
}
if !p.isEnabled {
return return
} }
_ = p.update0() _ = p.update0()

View file

@ -1,8 +1,50 @@
package constant package constant
import (
"github.com/sagernet/sing/common"
F "github.com/sagernet/sing/common/format"
)
const ( const (
InterfaceTypeWIFI = "wifi" InterfaceTypeWIFI = "wifi"
InterfaceTypeCellular = "cellular" InterfaceTypeCellular = "cellular"
InterfaceTypeEthernet = "ethernet" InterfaceTypeEthernet = "ethernet"
InterfaceTypeOther = "other" InterfaceTypeOther = "other"
) )
type NetworkStrategy int
const (
NetworkStrategyDefault NetworkStrategy = iota
NetworkStrategyFallback
NetworkStrategyHybrid
NetworkStrategyWIFI
NetworkStrategyCellular
NetworkStrategyEthernet
NetworkStrategyWIFIOnly
NetworkStrategyCellularOnly
NetworkStrategyEthernetOnly
)
var (
NetworkStrategyToString = map[NetworkStrategy]string{
NetworkStrategyDefault: "default",
NetworkStrategyFallback: "fallback",
NetworkStrategyHybrid: "hybrid",
NetworkStrategyWIFI: "wifi",
NetworkStrategyCellular: "cellular",
NetworkStrategyEthernet: "ethernet",
NetworkStrategyWIFIOnly: "wifi_only",
NetworkStrategyCellularOnly: "cellular_only",
NetworkStrategyEthernetOnly: "ethernet_only",
}
StringToNetworkStrategy = common.ReverseMap(NetworkStrategyToString)
)
func (s NetworkStrategy) String() string {
name, loaded := NetworkStrategyToString[s]
if !loaded {
return F.ToString(int(s))
}
return name
}

View file

@ -67,7 +67,7 @@ func (m *platformDefaultInterfaceMonitor) UpdateDefaultInterface(interfaceName s
callbacks := m.callbacks.Array() callbacks := m.callbacks.Array()
m.defaultInterfaceAccess.Unlock() m.defaultInterfaceAccess.Unlock()
for _, callback := range callbacks { for _, callback := range callbacks {
callback(tun.EventInterfaceUpdate) callback(nil, 0)
} }
return return
} }
@ -86,6 +86,6 @@ func (m *platformDefaultInterfaceMonitor) UpdateDefaultInterface(interfaceName s
callbacks := m.callbacks.Array() callbacks := m.callbacks.Array()
m.defaultInterfaceAccess.Unlock() m.defaultInterfaceAccess.Unlock()
for _, callback := range callbacks { for _, callback := range callbacks {
callback(tun.EventInterfaceUpdate) callback(newInterface, 0)
} }
} }

View file

@ -65,21 +65,23 @@ type DialerOptionsWrapper interface {
} }
type DialerOptions struct { type DialerOptions struct {
Detour string `json:"detour,omitempty"` Detour string `json:"detour,omitempty"`
BindInterface string `json:"bind_interface,omitempty"` BindInterface string `json:"bind_interface,omitempty"`
Inet4BindAddress *badoption.Addr `json:"inet4_bind_address,omitempty"` Inet4BindAddress *badoption.Addr `json:"inet4_bind_address,omitempty"`
Inet6BindAddress *badoption.Addr `json:"inet6_bind_address,omitempty"` Inet6BindAddress *badoption.Addr `json:"inet6_bind_address,omitempty"`
ProtectPath string `json:"protect_path,omitempty"` ProtectPath string `json:"protect_path,omitempty"`
RoutingMark uint32 `json:"routing_mark,omitempty"` RoutingMark uint32 `json:"routing_mark,omitempty"`
ReuseAddr bool `json:"reuse_addr,omitempty"` ReuseAddr bool `json:"reuse_addr,omitempty"`
ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"`
TCPFastOpen bool `json:"tcp_fast_open,omitempty"` TCPFastOpen bool `json:"tcp_fast_open,omitempty"`
TCPMultiPath bool `json:"tcp_multi_path,omitempty"` TCPMultiPath bool `json:"tcp_multi_path,omitempty"`
UDPFragment *bool `json:"udp_fragment,omitempty"` UDPFragment *bool `json:"udp_fragment,omitempty"`
UDPFragmentDefault bool `json:"-"` UDPFragmentDefault bool `json:"-"`
DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"`
FallbackDelay badoption.Duration `json:"fallback_delay,omitempty"` NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"`
IsWireGuardListener bool `json:"-"` FallbackDelay badoption.Duration `json:"fallback_delay,omitempty"`
NetworkFallbackDelay badoption.Duration `json:"network_fallback_delay,omitempty"`
IsWireGuardListener bool `json:"-"`
} }
func (o *DialerOptions) TakeDialerOptions() DialerOptions { func (o *DialerOptions) TakeDialerOptions() DialerOptions {

View file

@ -1,16 +1,20 @@
package option package option
import "github.com/sagernet/sing/common/json/badoption"
type RouteOptions struct { type RouteOptions struct {
GeoIP *GeoIPOptions `json:"geoip,omitempty"` GeoIP *GeoIPOptions `json:"geoip,omitempty"`
Geosite *GeositeOptions `json:"geosite,omitempty"` Geosite *GeositeOptions `json:"geosite,omitempty"`
Rules []Rule `json:"rules,omitempty"` Rules []Rule `json:"rules,omitempty"`
RuleSet []RuleSet `json:"rule_set,omitempty"` RuleSet []RuleSet `json:"rule_set,omitempty"`
Final string `json:"final,omitempty"` Final string `json:"final,omitempty"`
FindProcess bool `json:"find_process,omitempty"` FindProcess bool `json:"find_process,omitempty"`
AutoDetectInterface bool `json:"auto_detect_interface,omitempty"` AutoDetectInterface bool `json:"auto_detect_interface,omitempty"`
OverrideAndroidVPN bool `json:"override_android_vpn,omitempty"` OverrideAndroidVPN bool `json:"override_android_vpn,omitempty"`
DefaultInterface string `json:"default_interface,omitempty"` DefaultInterface string `json:"default_interface,omitempty"`
DefaultMark uint32 `json:"default_mark,omitempty"` DefaultMark uint32 `json:"default_mark,omitempty"`
DefaultNetworkStrategy NetworkStrategy `json:"default_network_strategy,omitempty"`
DefaultFallbackDelay badoption.Duration `json:"default_fallback_delay,omitempty"`
} }
type GeoIPOptions struct { type GeoIPOptions struct {

View file

@ -137,14 +137,18 @@ func (r *DNSRuleAction) UnmarshalJSONContext(ctx context.Context, data []byte) e
} }
type RouteActionOptions struct { type RouteActionOptions struct {
Outbound string `json:"outbound,omitempty"` Outbound string `json:"outbound,omitempty"`
UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"`
UDPConnect bool `json:"udp_connect,omitempty"` FallbackDelay uint32 `json:"fallback_delay,omitempty"`
UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"`
UDPConnect bool `json:"udp_connect,omitempty"`
} }
type _RouteOptionsActionOptions struct { type _RouteOptionsActionOptions struct {
UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"`
UDPConnect bool `json:"udp_connect,omitempty"` FallbackDelay uint32 `json:"fallback_delay,omitempty"`
UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"`
UDPConnect bool `json:"udp_connect,omitempty"`
} }
type RouteOptionsActionOptions _RouteOptionsActionOptions type RouteOptionsActionOptions _RouteOptionsActionOptions

View file

@ -3,6 +3,7 @@ package option
import ( import (
"strings" "strings"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
@ -150,3 +151,23 @@ func DNSQueryTypeToString(queryType uint16) string {
} }
return F.ToString(queryType) return F.ToString(queryType)
} }
type NetworkStrategy C.NetworkStrategy
func (n NetworkStrategy) MarshalJSON() ([]byte, error) {
return json.Marshal(C.NetworkStrategy(n).String())
}
func (n *NetworkStrategy) UnmarshalJSON(content []byte) error {
var value string
err := json.Unmarshal(content, &value)
if err != nil {
return err
}
strategy, loaded := C.StringToNetworkStrategy[value]
if !loaded {
return E.New("unknown network strategy: ", value)
}
*n = NetworkStrategy(strategy)
return nil
}

View file

@ -13,6 +13,7 @@ import (
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
dns "github.com/sagernet/sing-dns" dns "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
@ -24,31 +25,38 @@ func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.DirectOutboundOptions](registry, C.TypeDirect, NewOutbound) outbound.Register[option.DirectOutboundOptions](registry, C.TypeDirect, NewOutbound)
} }
var _ N.ParallelDialer = (*Outbound)(nil) var (
_ N.ParallelDialer = (*Outbound)(nil)
_ dialer.ParallelNetworkDialer = (*Outbound)(nil)
)
type Outbound struct { type Outbound struct {
outbound.Adapter outbound.Adapter
logger logger.ContextLogger logger logger.ContextLogger
dialer N.Dialer dialer dialer.ParallelInterfaceDialer
domainStrategy dns.DomainStrategy domainStrategy dns.DomainStrategy
fallbackDelay time.Duration fallbackDelay time.Duration
overrideOption int networkStrategy C.NetworkStrategy
overrideDestination M.Socksaddr networkFallbackDelay time.Duration
overrideOption int
overrideDestination M.Socksaddr
// loopBack *loopBackDetector // loopBack *loopBackDetector
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (adapter.Outbound, error) {
options.UDPFragmentDefault = true options.UDPFragmentDefault = true
outboundDialer, err := dialer.New(ctx, options.DialerOptions) outboundDialer, err := dialer.NewDirect(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.DialerOptions),
logger: logger, logger: logger,
domainStrategy: dns.DomainStrategy(options.DomainStrategy), domainStrategy: dns.DomainStrategy(options.DomainStrategy),
fallbackDelay: time.Duration(options.FallbackDelay), fallbackDelay: time.Duration(options.FallbackDelay),
dialer: outboundDialer, networkStrategy: C.NetworkStrategy(options.NetworkStrategy),
networkFallbackDelay: time.Duration(options.NetworkFallbackDelay),
dialer: outboundDialer,
// loopBack: newLoopBackDetector(router), // loopBack: newLoopBackDetector(router),
} }
if options.ProxyProtocol != 0 { if options.ProxyProtocol != 0 {
@ -96,33 +104,6 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination
return h.dialer.DialContext(ctx, network, destination) return h.dialer.DialContext(ctx, network, destination)
} }
func (h *Outbound) DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) {
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
switch h.overrideOption {
case 1, 2:
// override address
return h.DialContext(ctx, network, destination)
case 3:
destination.Port = h.overrideDestination.Port
}
network = N.NetworkName(network)
switch network {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
var domainStrategy dns.DomainStrategy
if h.domainStrategy != dns.DomainStrategyAsIS {
domainStrategy = h.domainStrategy
} else {
domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)
}
return N.DialParallel(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, h.fallbackDelay)
}
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
ctx, metadata := adapter.ExtendContext(ctx) ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag() metadata.Outbound = h.Tag()
@ -154,6 +135,110 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
return conn, nil return conn, nil
} }
func (h *Outbound) DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) {
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
switch h.overrideOption {
case 1, 2:
// override address
return h.DialContext(ctx, network, destination)
case 3:
destination.Port = h.overrideDestination.Port
}
network = N.NetworkName(network)
switch network {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
var domainStrategy dns.DomainStrategy
if h.domainStrategy != dns.DomainStrategyAsIS {
domainStrategy = h.domainStrategy
} else {
domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)
}
switch domainStrategy {
case dns.DomainStrategyUseIPv4:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is4)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv4 address available for ", destination)
}
case dns.DomainStrategyUseIPv6:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is6)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv6 address available for ", destination)
}
}
return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, h.networkStrategy, h.fallbackDelay)
}
func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) {
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
switch h.overrideOption {
case 1, 2:
// override address
return h.DialContext(ctx, network, destination)
case 3:
destination.Port = h.overrideDestination.Port
}
network = N.NetworkName(network)
switch network {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
var domainStrategy dns.DomainStrategy
if h.domainStrategy != dns.DomainStrategyAsIS {
domainStrategy = h.domainStrategy
} else {
domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)
}
switch domainStrategy {
case dns.DomainStrategyUseIPv4:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is4)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv4 address available for ", destination)
}
case dns.DomainStrategyUseIPv6:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is6)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv6 address available for ", destination)
}
}
return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, networkStrategy, fallbackDelay)
}
func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) {
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
switch h.overrideOption {
case 1:
destination = h.overrideDestination
case 2:
newDestination := h.overrideDestination
newDestination.Port = destination.Port
destination = newDestination
case 3:
destination.Port = h.overrideDestination.Port
}
if h.overrideOption == 0 {
h.logger.InfoContext(ctx, "outbound packet connection")
} else {
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
conn, newDestination, err := dialer.ListenSerialNetworkPacket(ctx, h.dialer, destination, destinationAddresses, networkStrategy, fallbackDelay)
if err != nil {
return nil, netip.Addr{}, err
}
return conn, newDestination, nil
}
/*func (h *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { /*func (h *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) { if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) {
return E.New("reject loopback connection to ", metadata.Destination) return E.New("reject loopback connection to ", metadata.Destination)

View file

@ -10,7 +10,6 @@ import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
@ -115,23 +114,3 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
h.logger.InfoContext(ctx, "outbound packet connection to ", destination) h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
return h.client.ListenPacket(ctx, destination) return h.client.ListenPacket(ctx, destination)
} }
// TODO
// Deprecated
func (h *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
if h.resolve {
return outbound.NewDirectConnection(ctx, h.router, h, conn, metadata, dns.DomainStrategyUseIPv4)
} else {
return outbound.NewConnection(ctx, h, conn, metadata)
}
}
// TODO
// Deprecated
func (h *Outbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
if h.resolve {
return outbound.NewDirectPacketConnection(ctx, h.router, h, conn, metadata, dns.DomainStrategyUseIPv4)
} else {
return outbound.NewPacketConnection(ctx, h, conn, metadata)
}
}

View file

@ -16,7 +16,6 @@ import (
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard" "github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@ -231,15 +230,3 @@ func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
} }
return w.tunDevice.ListenPacket(ctx, destination) return w.tunDevice.ListenPacket(ctx, destination)
} }
// TODO
// Deprecated
func (w *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return outbound.NewDirectConnection(ctx, w.router, w, conn, metadata, dns.DomainStrategyAsIS)
}
// TODO
// Deprecated
func (w *Outbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return outbound.NewDirectPacketConnection(ctx, w.router, w, conn, metadata, dns.DomainStrategyAsIS)
}

View file

@ -8,6 +8,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"syscall" "syscall"
"time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/conntrack" "github.com/sagernet/sing-box/common/conntrack"
@ -38,8 +39,7 @@ type NetworkManager struct {
networkInterfaces atomic.TypedValue[[]adapter.NetworkInterface] networkInterfaces atomic.TypedValue[[]adapter.NetworkInterface]
autoDetectInterface bool autoDetectInterface bool
defaultInterface string defaultOptions adapter.NetworkOptions
defaultMark uint32
autoRedirectOutputMark uint32 autoRedirectOutputMark uint32
networkMonitor tun.NetworkUpdateMonitor networkMonitor tun.NetworkUpdateMonitor
@ -58,11 +58,23 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp
logger: logger, logger: logger,
interfaceFinder: control.NewDefaultInterfaceFinder(), interfaceFinder: control.NewDefaultInterfaceFinder(),
autoDetectInterface: routeOptions.AutoDetectInterface, autoDetectInterface: routeOptions.AutoDetectInterface,
defaultInterface: routeOptions.DefaultInterface, defaultOptions: adapter.NetworkOptions{
defaultMark: routeOptions.DefaultMark, DefaultInterface: routeOptions.DefaultInterface,
pauseManager: service.FromContext[pause.Manager](ctx), DefaultMark: routeOptions.DefaultMark,
platformInterface: service.FromContext[platform.Interface](ctx), DefaultNetworkStrategy: C.NetworkStrategy(routeOptions.DefaultNetworkStrategy),
outboundManager: service.FromContext[adapter.OutboundManager](ctx), DefaultFallbackDelay: time.Duration(routeOptions.DefaultFallbackDelay),
},
pauseManager: service.FromContext[pause.Manager](ctx),
platformInterface: service.FromContext[platform.Interface](ctx),
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
}
if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault {
if routeOptions.DefaultInterface != "" {
return nil, E.New("`default_network_strategy` is conflict with `default_interface`")
}
if !routeOptions.AutoDetectInterface {
return nil, E.New("`auto_detect_interface` is required by `default_network_strategy`")
}
} }
usePlatformDefaultInterfaceMonitor := nm.platformInterface != nil usePlatformDefaultInterfaceMonitor := nm.platformInterface != nil
enforceInterfaceMonitor := routeOptions.AutoDetectInterface enforceInterfaceMonitor := routeOptions.AutoDetectInterface
@ -84,12 +96,12 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp
if err != nil { if err != nil {
return nil, E.New("auto_detect_interface unsupported on current platform") return nil, E.New("auto_detect_interface unsupported on current platform")
} }
interfaceMonitor.RegisterCallback(nm.notifyNetworkUpdate) interfaceMonitor.RegisterCallback(nm.notifyInterfaceUpdate)
nm.interfaceMonitor = interfaceMonitor nm.interfaceMonitor = interfaceMonitor
} }
} else { } else {
interfaceMonitor := nm.platformInterface.CreateDefaultInterfaceMonitor(logger) interfaceMonitor := nm.platformInterface.CreateDefaultInterfaceMonitor(logger)
interfaceMonitor.RegisterCallback(nm.notifyNetworkUpdate) interfaceMonitor.RegisterCallback(nm.notifyInterfaceUpdate)
nm.interfaceMonitor = interfaceMonitor nm.interfaceMonitor = interfaceMonitor
} }
return nm, nil return nm, nil
@ -265,10 +277,6 @@ func (r *NetworkManager) NetworkInterfaces() []adapter.NetworkInterface {
return r.networkInterfaces.Load() return r.networkInterfaces.Load()
} }
func (r *NetworkManager) DefaultInterface() string {
return r.defaultInterface
}
func (r *NetworkManager) AutoDetectInterface() bool { func (r *NetworkManager) AutoDetectInterface() bool {
return r.autoDetectInterface return r.autoDetectInterface
} }
@ -301,8 +309,19 @@ func (r *NetworkManager) AutoDetectInterfaceFunc() control.Func {
} }
} }
func (r *NetworkManager) DefaultMark() uint32 { func (r *NetworkManager) ProtectFunc() control.Func {
return r.defaultMark if r.platformInterface != nil && r.platformInterface.UsePlatformAutoDetectInterfaceControl() {
return func(network, address string, conn syscall.RawConn) error {
return control.Raw(conn, func(fd uintptr) error {
return r.platformInterface.AutoDetectInterfaceControl(int(fd))
})
}
}
return nil
}
func (r *NetworkManager) DefaultOptions() adapter.NetworkOptions {
return r.defaultOptions
} }
func (r *NetworkManager) RegisterAutoRedirectOutputMark(mark uint32) error { func (r *NetworkManager) RegisterAutoRedirectOutputMark(mark uint32) error {
@ -344,45 +363,47 @@ func (r *NetworkManager) ResetNetwork() {
} }
} }
func (r *NetworkManager) notifyNetworkUpdate(event int) { func (r *NetworkManager) notifyInterfaceUpdate(defaultInterface *control.Interface, flags int) {
if event == tun.EventNoRoute { if defaultInterface == nil {
r.pauseManager.NetworkPause() r.pauseManager.NetworkPause()
r.logger.Error("missing default interface") r.logger.Error("missing default interface")
} else { return
r.pauseManager.NetworkWake() }
defaultInterface := r.DefaultNetworkInterface()
if defaultInterface == nil { r.pauseManager.NetworkWake()
panic("invalid interface context") var options []string
} options = append(options, F.ToString("index ", defaultInterface.Index))
var options []string if C.IsAndroid && r.platformInterface == nil {
options = append(options, F.ToString("index ", defaultInterface.Index)) var vpnStatus string
if C.IsAndroid && r.platformInterface == nil { if r.interfaceMonitor.AndroidVPNEnabled() {
var vpnStatus string vpnStatus = "enabled"
if r.interfaceMonitor.AndroidVPNEnabled() {
vpnStatus = "enabled"
} else {
vpnStatus = "disabled"
}
options = append(options, "vpn "+vpnStatus)
} else { } else {
if defaultInterface.Type != "" { vpnStatus = "disabled"
options = append(options, F.ToString("type ", defaultInterface.Type))
}
if defaultInterface.Expensive {
options = append(options, "expensive")
}
if defaultInterface.Constrained {
options = append(options, "constrained")
}
} }
r.logger.Info("updated default interface ", defaultInterface.Name, ", ", strings.Join(options, ", ")) options = append(options, "vpn "+vpnStatus)
if r.platformInterface != nil { } else if r.platformInterface != nil {
state := r.platformInterface.ReadWIFIState() networkInterface := common.Find(r.networkInterfaces.Load(), func(it adapter.NetworkInterface) bool {
if state != r.wifiState { return it.Interface.Index == defaultInterface.Index
r.wifiState = state })
if state.SSID != "" { if networkInterface.Type == "" {
r.logger.Info("updated WIFI state: SSID=", state.SSID, ", BSSID=", state.BSSID) // race
} return
}
options = append(options, F.ToString("type ", networkInterface.Type))
if networkInterface.Expensive {
options = append(options, "expensive")
}
if networkInterface.Constrained {
options = append(options, "constrained")
}
}
r.logger.Info("updated default interface ", defaultInterface.Name, ", ", strings.Join(options, ", "))
if r.platformInterface != nil {
state := r.platformInterface.ReadWIFIState()
if state != r.wifiState {
r.wifiState = state
if state.SSID != "" {
r.logger.Info("updated WIFI state: SSID=", state.SSID, ", BSSID=", state.BSSID)
} }
} }
} }

View file

@ -424,9 +424,13 @@ match:
} }
switch action := currentRule.Action().(type) { switch action := currentRule.Action().(type) {
case *rule.RuleActionRoute: case *rule.RuleActionRoute:
metadata.NetworkStrategy = action.NetworkStrategy
metadata.FallbackDelay = action.FallbackDelay
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
metadata.UDPConnect = action.UDPConnect metadata.UDPConnect = action.UDPConnect
case *rule.RuleActionRouteOptions: case *rule.RuleActionRouteOptions:
metadata.NetworkStrategy = action.NetworkStrategy
metadata.FallbackDelay = action.FallbackDelay
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
metadata.UDPConnect = action.UDPConnect metadata.UDPConnect = action.UDPConnect
case *rule.RuleActionSniff: case *rule.RuleActionSniff:

View file

@ -30,12 +30,16 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
return &RuleActionRoute{ return &RuleActionRoute{
Outbound: action.RouteOptions.Outbound, Outbound: action.RouteOptions.Outbound,
RuleActionRouteOptions: RuleActionRouteOptions{ RuleActionRouteOptions: RuleActionRouteOptions{
NetworkStrategy: C.NetworkStrategy(action.RouteOptions.NetworkStrategy),
FallbackDelay: time.Duration(action.RouteOptions.FallbackDelay),
UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping, UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping,
UDPConnect: action.RouteOptions.UDPConnect, UDPConnect: action.RouteOptions.UDPConnect,
}, },
}, nil }, nil
case C.RuleActionTypeRouteOptions: case C.RuleActionTypeRouteOptions:
return &RuleActionRouteOptions{ return &RuleActionRouteOptions{
NetworkStrategy: C.NetworkStrategy(action.RouteOptionsOptions.NetworkStrategy),
FallbackDelay: time.Duration(action.RouteOptionsOptions.FallbackDelay),
UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping, UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping,
UDPConnect: action.RouteOptionsOptions.UDPConnect, UDPConnect: action.RouteOptionsOptions.UDPConnect,
}, nil }, nil
@ -135,6 +139,8 @@ func (r *RuleActionRoute) String() string {
} }
type RuleActionRouteOptions struct { type RuleActionRouteOptions struct {
NetworkStrategy C.NetworkStrategy
FallbackDelay time.Duration
UDPDisableDomainUnmapping bool UDPDisableDomainUnmapping bool
UDPConnect bool UDPConnect bool
} }

View file

@ -166,7 +166,7 @@ func (t *Transport) updateServers() error {
} }
} }
func (t *Transport) interfaceUpdated(int) { func (t *Transport) interfaceUpdated(defaultInterface *control.Interface, flags int) {
err := t.updateServers() err := t.updateServers()
if err != nil { if err != nil {
t.options.Logger.Error("update servers: ", err) t.options.Logger.Error("update servers: ", err)