Add NDIS inbound

This commit is contained in:
世界 2025-01-03 18:34:07 +08:00
parent e483c909b4
commit 79d3649a8b
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
32 changed files with 1339 additions and 572 deletions

7
box.go
View file

@ -12,6 +12,7 @@ import (
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/taskmonitor"
"github.com/sagernet/sing-box/common/tls"
@ -84,7 +85,6 @@ func New(options Options) (*Box, error) {
ctx = context.Background()
}
ctx = service.ContextWithDefaultRegistry(ctx)
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
@ -101,7 +101,10 @@ func New(options Options) (*Box, error) {
ctx = pause.WithDefaultManager(ctx)
experimentalOptions := common.PtrValueOrDefault(options.Experimental)
applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug))
debugOptions := common.PtrValueOrDefault(experimentalOptions.Debug)
applyDebugOptions(debugOptions)
ctx = conntrack.ContextWithDefaultTracker(ctx, debugOptions.OOMKiller, uint64(debugOptions.MemoryLimit))
var needCacheFile bool
var needClashAPI bool
var needV2RayAPI bool

View file

@ -1,54 +0,0 @@
package conntrack
import (
"io"
"net"
"github.com/sagernet/sing/common/x/list"
)
type Conn struct {
net.Conn
element *list.Element[io.Closer]
}
func NewConn(conn net.Conn) (net.Conn, error) {
connAccess.Lock()
element := openConnection.PushBack(conn)
connAccess.Unlock()
if KillerEnabled {
err := KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
}
return &Conn{
Conn: conn,
element: element,
}, nil
}
func (c *Conn) Close() error {
if c.element.Value != nil {
connAccess.Lock()
if c.element.Value != nil {
openConnection.Remove(c.element)
c.element.Value = nil
}
connAccess.Unlock()
}
return c.Conn.Close()
}
func (c *Conn) Upstream() any {
return c.Conn
}
func (c *Conn) ReaderReplaceable() bool {
return true
}
func (c *Conn) WriterReplaceable() bool {
return true
}

View file

@ -0,0 +1,14 @@
package conntrack
import (
"context"
"github.com/sagernet/sing/service"
)
func ContextWithDefaultTracker(ctx context.Context, killerEnabled bool, memoryLimit uint64) context.Context {
if service.FromContext[Tracker](ctx) != nil {
return ctx
}
return service.ContextWith[Tracker](ctx, NewDefaultTracker(killerEnabled, memoryLimit))
}

245
common/conntrack/default.go Normal file
View file

@ -0,0 +1,245 @@
package conntrack
import (
"net"
"net/netip"
runtimeDebug "runtime/debug"
"sync"
"time"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/memory"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
)
var _ Tracker = (*DefaultTracker)(nil)
type DefaultTracker struct {
connAccess sync.RWMutex
connList list.List[net.Conn]
connAddress map[netip.AddrPort]netip.AddrPort
packetConnAccess sync.RWMutex
packetConnList list.List[AbstractPacketConn]
packetConnAddress map[netip.AddrPort]bool
pendingAccess sync.RWMutex
pendingList list.List[netip.AddrPort]
killerEnabled bool
memoryLimit uint64
killerLastCheck time.Time
}
func NewDefaultTracker(killerEnabled bool, memoryLimit uint64) *DefaultTracker {
return &DefaultTracker{
connAddress: make(map[netip.AddrPort]netip.AddrPort),
packetConnAddress: make(map[netip.AddrPort]bool),
killerEnabled: killerEnabled,
memoryLimit: memoryLimit,
}
}
func (t *DefaultTracker) NewConn(conn net.Conn) (net.Conn, error) {
err := t.KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
t.connAccess.Lock()
element := t.connList.PushBack(conn)
t.connAddress[M.AddrPortFromNet(conn.LocalAddr())] = M.AddrPortFromNet(conn.RemoteAddr())
t.connAccess.Unlock()
return &Conn{
Conn: conn,
closeFunc: common.OnceFunc(func() {
t.removeConn(element)
}),
}, nil
}
func (t *DefaultTracker) NewConnEx(conn net.Conn) (N.CloseHandlerFunc, error) {
err := t.KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
t.connAccess.Lock()
element := t.connList.PushBack(conn)
t.connAddress[M.AddrPortFromNet(conn.LocalAddr())] = M.AddrPortFromNet(conn.RemoteAddr())
t.connAccess.Unlock()
return N.OnceClose(func(it error) {
t.removeConn(element)
}), nil
}
func (t *DefaultTracker) NewPacketConn(conn net.PacketConn) (net.PacketConn, error) {
err := t.KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
t.packetConnAccess.Lock()
element := t.packetConnList.PushBack(conn)
t.packetConnAddress[M.AddrPortFromNet(conn.LocalAddr())] = true
t.packetConnAccess.Unlock()
return &PacketConn{
PacketConn: conn,
closeFunc: common.OnceFunc(func() {
t.removePacketConn(element)
}),
}, nil
}
func (t *DefaultTracker) NewPacketConnEx(conn AbstractPacketConn) (N.CloseHandlerFunc, error) {
err := t.KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
t.packetConnAccess.Lock()
element := t.packetConnList.PushBack(conn)
t.packetConnAddress[M.AddrPortFromNet(conn.LocalAddr())] = true
t.packetConnAccess.Unlock()
return N.OnceClose(func(it error) {
t.removePacketConn(element)
}), nil
}
func (t *DefaultTracker) CheckConn(source netip.AddrPort, destination netip.AddrPort) bool {
t.connAccess.RLock()
defer t.connAccess.RUnlock()
return t.connAddress[source] == destination
}
func (t *DefaultTracker) CheckPacketConn(source netip.AddrPort) bool {
t.packetConnAccess.RLock()
defer t.packetConnAccess.RUnlock()
return t.packetConnAddress[source]
}
func (t *DefaultTracker) AddPendingDestination(destination netip.AddrPort) func() {
t.pendingAccess.Lock()
defer t.pendingAccess.Unlock()
element := t.pendingList.PushBack(destination)
return func() {
t.pendingAccess.Lock()
defer t.pendingAccess.Unlock()
t.pendingList.Remove(element)
}
}
func (t *DefaultTracker) CheckDestination(destination netip.AddrPort) bool {
t.pendingAccess.RLock()
defer t.pendingAccess.RUnlock()
for element := t.pendingList.Front(); element != nil; element = element.Next() {
if element.Value == destination {
return true
}
}
return false
}
func (t *DefaultTracker) KillerCheck() error {
if !t.killerEnabled {
return nil
}
nowTime := time.Now()
if nowTime.Sub(t.killerLastCheck) < 3*time.Second {
return nil
}
t.killerLastCheck = nowTime
if memory.Total() > t.memoryLimit {
t.Close()
go func() {
time.Sleep(time.Second)
runtimeDebug.FreeOSMemory()
}()
return E.New("out of memory")
}
return nil
}
func (t *DefaultTracker) Count() int {
t.connAccess.RLock()
defer t.connAccess.RUnlock()
t.packetConnAccess.RLock()
defer t.packetConnAccess.RUnlock()
return t.connList.Len() + t.packetConnList.Len()
}
func (t *DefaultTracker) Close() {
t.connAccess.Lock()
for element := t.connList.Front(); element != nil; element = element.Next() {
element.Value.Close()
}
t.connList.Init()
t.connAccess.Unlock()
t.packetConnAccess.Lock()
for element := t.packetConnList.Front(); element != nil; element = element.Next() {
element.Value.Close()
}
t.packetConnList.Init()
t.packetConnAccess.Unlock()
}
func (t *DefaultTracker) removeConn(element *list.Element[net.Conn]) {
t.connAccess.Lock()
defer t.connAccess.Unlock()
delete(t.connAddress, M.AddrPortFromNet(element.Value.LocalAddr()))
t.connList.Remove(element)
}
func (t *DefaultTracker) removePacketConn(element *list.Element[AbstractPacketConn]) {
t.packetConnAccess.Lock()
defer t.packetConnAccess.Unlock()
delete(t.packetConnAddress, M.AddrPortFromNet(element.Value.LocalAddr()))
t.packetConnList.Remove(element)
}
type Conn struct {
net.Conn
closeFunc func()
}
func (c *Conn) Close() error {
c.closeFunc()
return c.Conn.Close()
}
func (c *Conn) Upstream() any {
return c.Conn
}
func (c *Conn) ReaderReplaceable() bool {
return true
}
func (c *Conn) WriterReplaceable() bool {
return true
}
type PacketConn struct {
net.PacketConn
closeFunc func()
}
func (c *PacketConn) Close() error {
c.closeFunc()
return c.PacketConn.Close()
}
func (c *PacketConn) Upstream() any {
return c.PacketConn
}
func (c *PacketConn) ReaderReplaceable() bool {
return true
}
func (c *PacketConn) WriterReplaceable() bool {
return true
}

View file

@ -1,35 +0,0 @@
package conntrack
import (
runtimeDebug "runtime/debug"
"time"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/memory"
)
var (
KillerEnabled bool
MemoryLimit uint64
killerLastCheck time.Time
)
func KillerCheck() error {
if !KillerEnabled {
return nil
}
nowTime := time.Now()
if nowTime.Sub(killerLastCheck) < 3*time.Second {
return nil
}
killerLastCheck = nowTime
if memory.Total() > MemoryLimit {
Close()
go func() {
time.Sleep(time.Second)
runtimeDebug.FreeOSMemory()
}()
return E.New("out of memory")
}
return nil
}

View file

@ -1,55 +0,0 @@
package conntrack
import (
"io"
"net"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/x/list"
)
type PacketConn struct {
net.PacketConn
element *list.Element[io.Closer]
}
func NewPacketConn(conn net.PacketConn) (net.PacketConn, error) {
connAccess.Lock()
element := openConnection.PushBack(conn)
connAccess.Unlock()
if KillerEnabled {
err := KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
}
return &PacketConn{
PacketConn: conn,
element: element,
}, nil
}
func (c *PacketConn) Close() error {
if c.element.Value != nil {
connAccess.Lock()
if c.element.Value != nil {
openConnection.Remove(c.element)
c.element.Value = nil
}
connAccess.Unlock()
}
return c.PacketConn.Close()
}
func (c *PacketConn) Upstream() any {
return bufio.NewPacketConn(c.PacketConn)
}
func (c *PacketConn) ReaderReplaceable() bool {
return true
}
func (c *PacketConn) WriterReplaceable() bool {
return true
}

View file

@ -1,47 +0,0 @@
package conntrack
import (
"io"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/x/list"
)
var (
connAccess sync.RWMutex
openConnection list.List[io.Closer]
)
func Count() int {
if !Enabled {
return 0
}
return openConnection.Len()
}
func List() []io.Closer {
if !Enabled {
return nil
}
connAccess.RLock()
defer connAccess.RUnlock()
connList := make([]io.Closer, 0, openConnection.Len())
for element := openConnection.Front(); element != nil; element = element.Next() {
connList = append(connList, element.Value)
}
return connList
}
func Close() {
if !Enabled {
return
}
connAccess.Lock()
defer connAccess.Unlock()
for element := openConnection.Front(); element != nil; element = element.Next() {
common.Close(element.Value)
element.Value = nil
}
openConnection.Init()
}

View file

@ -1,5 +0,0 @@
//go:build !with_conntrack
package conntrack
const Enabled = false

View file

@ -1,5 +0,0 @@
//go:build with_conntrack
package conntrack
const Enabled = true

View file

@ -0,0 +1,32 @@
package conntrack
import (
"net"
"net/netip"
"time"
N "github.com/sagernet/sing/common/network"
)
// TODO: add to N
type AbstractPacketConn interface {
Close() error
LocalAddr() net.Addr
SetDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
}
type Tracker interface {
NewConn(conn net.Conn) (net.Conn, error)
NewPacketConn(conn net.PacketConn) (net.PacketConn, error)
NewConnEx(conn net.Conn) (N.CloseHandlerFunc, error)
NewPacketConnEx(conn AbstractPacketConn) (N.CloseHandlerFunc, error)
CheckConn(source netip.AddrPort, destination netip.AddrPort) bool
CheckPacketConn(source netip.AddrPort) bool
AddPendingDestination(destination netip.AddrPort) func()
CheckDestination(destination netip.AddrPort) bool
KillerCheck() error
Count() int
Close()
}

View file

@ -28,6 +28,7 @@ var (
)
type DefaultDialer struct {
tracker conntrack.Tracker
dialer4 tcpDialer
dialer6 tcpDialer
udpDialer4 net.Dialer
@ -46,6 +47,7 @@ type DefaultDialer struct {
}
func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDialer, error) {
tracker := service.FromContext[conntrack.Tracker](ctx)
networkManager := service.FromContext[adapter.NetworkManager](ctx)
platformInterface := service.FromContext[platform.Interface](ctx)
@ -197,6 +199,7 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
return nil, err
}
return &DefaultDialer{
tracker: tracker,
dialer4: tcpDialer4,
dialer6: tcpDialer6,
udpDialer4: udpDialer4,
@ -219,18 +222,26 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address
return nil, E.New("invalid address")
}
if d.networkStrategy == nil {
if address.IsFqdn() {
return nil, E.New("unexpected domain destination")
}
// Since pending check is only used by ndis, it is not performed for non-windows connections which are only supported on platform clients
if d.tracker != nil {
done := d.tracker.AddPendingDestination(address.AddrPort())
defer done()
}
switch N.NetworkName(network) {
case N.NetworkUDP:
if !address.IsIPv6() {
return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
return d.trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
} else {
return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
return d.trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
}
}
if !address.IsIPv6() {
return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
return d.trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
} else {
return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
return d.trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
}
} else {
return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
@ -282,17 +293,17 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin
if !fastFallback && !isPrimary {
d.networkLastFallback.Store(time.Now())
}
return trackConn(conn, nil)
return d.trackConn(conn, nil)
}
func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if d.networkStrategy == nil {
if destination.IsIPv6() {
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
} else {
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
}
} else {
return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
@ -329,23 +340,23 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina
return nil, err
}
}
return trackPacketConn(packetConn, nil)
return d.trackPacketConn(packetConn, nil)
}
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
return d.udpListener.ListenPacket(context.Background(), network, address)
}
func trackConn(conn net.Conn, err error) (net.Conn, error) {
if !conntrack.Enabled || err != nil {
func (d *DefaultDialer) trackConn(conn net.Conn, err error) (net.Conn, error) {
if d.tracker == nil || err != nil {
return conn, err
}
return conntrack.NewConn(conn)
return d.tracker.NewConn(conn)
}
func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
if !conntrack.Enabled || err != nil {
func (d *DefaultDialer) trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
if err != nil {
return conn, err
}
return conntrack.NewPacketConn(conn)
return d.tracker.NewPacketConn(conn)
}

View file

@ -23,6 +23,7 @@ const (
TypeVLESS = "vless"
TypeTUIC = "tuic"
TypeHysteria2 = "hysteria2"
TypeNDIS = "ndis"
)
const (
@ -80,6 +81,8 @@ func ProxyDisplayName(proxyType string) string {
return "Selector"
case TypeURLTest:
return "URLTest"
case TypeNDIS:
return "NDIS"
default:
return "Unknown"
}

View file

@ -3,7 +3,6 @@ package box
import (
"runtime/debug"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/option"
)
@ -26,9 +25,5 @@ func applyDebugOptions(options option.DebugOptions) {
}
if options.MemoryLimit != 0 {
debug.SetMemoryLimit(int64(float64(options.MemoryLimit) / 1.5))
conntrack.MemoryLimit = uint64(options.MemoryLimit)
}
if options.OOMKiller != nil {
conntrack.KillerEnabled = *options.OOMKiller
}
}

View file

@ -5,8 +5,6 @@ import (
"net"
runtimeDebug "runtime/debug"
"time"
"github.com/sagernet/sing-box/common/conntrack"
)
func (c *CommandClient) CloseConnections() error {
@ -19,7 +17,7 @@ func (c *CommandClient) CloseConnections() error {
}
func (s *CommandServer) handleCloseConnections(conn net.Conn) error {
conntrack.Close()
tracker.Close()
go func() {
time.Sleep(time.Second)
runtimeDebug.FreeOSMemory()

View file

@ -6,7 +6,6 @@ import (
"runtime"
"time"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/experimental/clashapi"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/memory"
@ -28,7 +27,7 @@ func (s *CommandServer) readStatus() StatusMessage {
var message StatusMessage
message.Memory = int64(memory.Inuse())
message.Goroutines = int32(runtime.NumGoroutine())
message.ConnectionsOut = int32(conntrack.Count())
message.ConnectionsOut = int32(tracker.Count())
if s.service != nil {
message.TrafficAvailable = true

View file

@ -7,17 +7,21 @@ import (
"github.com/sagernet/sing-box/common/conntrack"
)
var tracker *conntrack.DefaultTracker
func SetMemoryLimit(enabled bool) {
if tracker != nil {
tracker.Close()
}
const memoryLimit = 45 * 1024 * 1024
const memoryLimitGo = memoryLimit / 1.5
if enabled {
runtimeDebug.SetGCPercent(10)
runtimeDebug.SetMemoryLimit(memoryLimitGo)
conntrack.KillerEnabled = true
conntrack.MemoryLimit = memoryLimit
tracker = conntrack.NewDefaultTracker(true, memoryLimit)
} else {
runtimeDebug.SetGCPercent(100)
runtimeDebug.SetMemoryLimit(math.MaxInt64)
conntrack.KillerEnabled = false
tracker = conntrack.NewDefaultTracker(false, 0)
}
}

View file

@ -12,6 +12,7 @@ import (
"github.com/sagernet/sing-box"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/common/process"
"github.com/sagernet/sing-box/common/urltest"
C "github.com/sagernet/sing-box/constant"
@ -60,6 +61,7 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box
useProcFS: platformInterface.UseProcFS(),
}
service.MustRegister[platform.Interface](ctx, platformWrapper)
service.MustRegister[conntrack.Tracker](ctx, tracker)
instance, err := box.New(box.Options{
Context: ctx,
Options: options,

1
go.mod
View file

@ -41,6 +41,7 @@ require (
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
github.com/spf13/cobra v1.8.1
github.com/stretchr/testify v1.9.0
github.com/wiresock/ndisapi-go v0.0.0-20241230094942-3299a7566e08
go.uber.org/zap v1.27.0
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/crypto v0.31.0

3
go.sum
View file

@ -158,6 +158,8 @@ github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923 h1:tHNk7XK9GkmKUR6Gh8gV
github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/wiresock/ndisapi-go v0.0.0-20241230094942-3299a7566e08 h1:is+7xN6CAKtgxt3mDSl9OQNvjfi6LggugSP07QhDtws=
github.com/wiresock/ndisapi-go v0.0.0-20241230094942-3299a7566e08/go.mod h1:lFE7JYt3LC2UYJ31mRDwl/K35pbtxDnkSDlXrYzgyqg=
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg=
@ -191,6 +193,7 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=

12
include/ndis.go Normal file
View file

@ -0,0 +1,12 @@
//go:build windows && with_gvisor
package include
import (
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/protocol/ndis"
)
func registerNDISInbound(registry *inbound.Registry) {
ndis.RegisterInbound(registry)
}

View file

@ -0,0 +1,20 @@
//go:build windows && !with_gvisor
package include
import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun"
)
func registerNDISInbound(registry *inbound.Registry) {
inbound.Register[option.NDISInboundOptions](registry, C.TypeNDIS, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NDISInboundOptions) (adapter.Inbound, error) {
return nil, tun.ErrGVisorNotIncluded
})
}

View file

@ -0,0 +1,20 @@
//go:build !windows
package include
import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func registerNDISInbound(registry *inbound.Registry) {
inbound.Register[option.NDISInboundOptions](registry, C.TypeNDIS, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NDISInboundOptions) (adapter.Inbound, error) {
return nil, E.New("NDIS is only supported in windows")
})
}

View file

@ -51,6 +51,7 @@ func InboundRegistry() *inbound.Registry {
registerQUICInbounds(registry)
registerStubForRemovedInbounds(registry)
registerNDISInbound(registry)
return registry
}

View file

@ -13,7 +13,7 @@ type DebugOptions struct {
PanicOnFault *bool `json:"panic_on_fault,omitempty"`
TraceBack string `json:"trace_back,omitempty"`
MemoryLimit MemoryBytes `json:"memory_limit,omitempty"`
OOMKiller *bool `json:"oom_killer,omitempty"`
OOMKiller bool `json:"oom_killer,omitempty"`
}
type MemoryBytes uint64

17
option/ndis.go Normal file
View file

@ -0,0 +1,17 @@
package option
import (
"net/netip"
"github.com/sagernet/sing/common/json/badoption"
)
type NDISInboundOptions struct {
Network NetworkList `json:"network,omitempty"`
RouteAddress badoption.Listable[netip.Prefix] `json:"route_address,omitempty"`
RouteAddressSet badoption.Listable[string] `json:"route_address_set,omitempty"`
RouteExcludeAddress badoption.Listable[netip.Prefix] `json:"route_exclude_address,omitempty"`
RouteExcludeAddressSet badoption.Listable[string] `json:"route_exclude_address_set,omitempty"`
InterfaceName string `json:"interface_name,omitempty"`
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"`
}

110
protocol/ndis/endpoint.go Normal file
View file

@ -0,0 +1,110 @@
//go:build windows
package ndis
import (
"sync"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/wiresock/ndisapi-go"
"github.com/wiresock/ndisapi-go/driver"
)
var _ stack.LinkEndpoint = (*ndisEndpoint)(nil)
type ndisEndpoint struct {
filter *driver.QueuedPacketFilter
mtu uint32
address tcpip.LinkAddress
dispatcher stack.NetworkDispatcher
}
func (e *ndisEndpoint) MTU() uint32 {
return e.mtu
}
func (e *ndisEndpoint) SetMTU(mtu uint32) {
}
func (e *ndisEndpoint) MaxHeaderLength() uint16 {
return header.EthernetMinimumSize
}
func (e *ndisEndpoint) LinkAddress() tcpip.LinkAddress {
return e.address
}
func (e *ndisEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
}
func (e *ndisEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return 0
}
func (e *ndisEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
func (e *ndisEndpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *ndisEndpoint) Wait() {
}
func (e *ndisEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareEther
}
func (e *ndisEndpoint) AddHeader(pkt *stack.PacketBuffer) {
eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
fields := header.EthernetFields{
SrcAddr: pkt.EgressRoute.LocalLinkAddress,
DstAddr: pkt.EgressRoute.RemoteLinkAddress,
Type: pkt.NetworkProtocolNumber,
}
eth.Encode(&fields)
}
func (e *ndisEndpoint) ParseHeader(pkt *stack.PacketBuffer) bool {
_, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
return ok
}
func (e *ndisEndpoint) Close() {
}
func (e *ndisEndpoint) SetOnCloseAction(f func()) {
}
var bufferPool = sync.Pool{
New: func() any {
return new(ndisapi.IntermediateBuffer)
},
}
func (e *ndisEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
for _, packetBuffer := range list.AsSlice() {
ndisBuf := bufferPool.Get().(*ndisapi.IntermediateBuffer)
viewList, offset := packetBuffer.AsViewList()
var view *buffer.View
for view = viewList.Front(); view != nil && offset >= view.Size(); view = view.Next() {
offset -= view.Size()
}
index := copy(ndisBuf.Buffer[:], view.AsSlice()[offset:])
for view = view.Next(); view != nil; view = view.Next() {
index += copy(ndisBuf.Buffer[index:], view.AsSlice())
}
ndisBuf.Length = uint32(index)
err := e.filter.InsertPacketToMstcp(ndisBuf)
bufferPool.Put(ndisBuf)
if err != nil {
return 0, &tcpip.ErrAborted{}
}
}
return list.Len(), nil
}

203
protocol/ndis/inbound.go Normal file
View file

@ -0,0 +1,203 @@
//go:build windows
package ndis
import (
"context"
"net"
"net/netip"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/wiresock/ndisapi-go"
"go4.org/netipx"
)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.NDISInboundOptions](registry, C.TypeNDIS, NewInbound)
}
type Inbound struct {
inbound.Adapter
ctx context.Context
router adapter.Router
logger log.ContextLogger
api *ndisapi.NdisApi
tracker conntrack.Tracker
routeAddress []netip.Prefix
routeExcludeAddress []netip.Prefix
routeRuleSet []adapter.RuleSet
routeRuleSetCallback []*list.Element[adapter.RuleSetUpdateCallback]
routeExcludeRuleSet []adapter.RuleSet
routeExcludeRuleSetCallback []*list.Element[adapter.RuleSetUpdateCallback]
stack *Stack
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NDISInboundOptions) (adapter.Inbound, error) {
api, err := ndisapi.NewNdisApi()
if err != nil {
return nil, E.Cause(err, "create NDIS API")
}
//if !api.IsDriverLoaded() {
// return nil, E.New("missing NDIS driver")
//}
networkManager := service.FromContext[adapter.NetworkManager](ctx)
trackerOut := service.FromContext[conntrack.Tracker](ctx)
var udpTimeout time.Duration
if options.UDPTimeout != 0 {
udpTimeout = time.Duration(options.UDPTimeout)
} else {
udpTimeout = C.UDPTimeout
}
var (
routeRuleSet []adapter.RuleSet
routeExcludeRuleSet []adapter.RuleSet
)
for _, routeAddressSet := range options.RouteAddressSet {
ruleSet, loaded := router.RuleSet(routeAddressSet)
if !loaded {
return nil, E.New("parse route_address_set: rule-set not found: ", routeAddressSet)
}
ruleSet.IncRef()
routeRuleSet = append(routeRuleSet, ruleSet)
}
for _, routeExcludeAddressSet := range options.RouteExcludeAddressSet {
ruleSet, loaded := router.RuleSet(routeExcludeAddressSet)
if !loaded {
return nil, E.New("parse route_exclude_address_set: rule-set not found: ", routeExcludeAddressSet)
}
ruleSet.IncRef()
routeExcludeRuleSet = append(routeExcludeRuleSet, ruleSet)
}
trackerIn := conntrack.NewDefaultTracker(false, 0)
return &Inbound{
Adapter: inbound.NewAdapter(C.TypeNDIS, tag),
ctx: ctx,
router: router,
logger: logger,
api: api,
tracker: trackerIn,
routeRuleSet: routeRuleSet,
routeExcludeRuleSet: routeExcludeRuleSet,
stack: &Stack{
ctx: ctx,
logger: logger,
network: networkManager,
trackerIn: trackerIn,
trackerOut: trackerOut,
api: api,
udpTimeout: udpTimeout,
routeAddress: options.RouteAddress,
routeExcludeAddress: options.RouteExcludeAddress,
},
}, nil
}
func (t *Inbound) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
monitor := taskmonitor.New(t.logger, C.StartTimeout)
var (
routeAddressSet []*netipx.IPSet
routeExcludeAddressSet []*netipx.IPSet
)
for _, routeRuleSet := range t.routeRuleSet {
ipSets := routeRuleSet.ExtractIPSet()
if len(ipSets) == 0 {
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeRuleSet.Name())
}
t.routeRuleSetCallback = append(t.routeRuleSetCallback, routeRuleSet.RegisterCallback(t.updateRouteAddressSet))
routeRuleSet.DecRef()
routeAddressSet = append(routeAddressSet, ipSets...)
}
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
ipSets := routeExcludeRuleSet.ExtractIPSet()
if len(ipSets) == 0 {
t.logger.Warn("route_exclude_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
}
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
routeExcludeRuleSet.DecRef()
routeExcludeAddressSet = append(routeExcludeAddressSet, ipSets...)
}
t.stack.routeAddressSet = routeAddressSet
t.stack.routeExcludeAddressSet = routeExcludeAddressSet
monitor.Start("starting NDIS stack")
t.stack.handler = t
err := t.stack.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "starting NDIS stack")
}
}
return nil
}
func (t *Inbound) Close() error {
if t.api != nil {
t.stack.Close()
t.api.Close()
}
return nil
}
func (t *Inbound) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
return t.router.PreMatch(adapter.InboundContext{
Inbound: t.Tag(),
InboundType: C.TypeNDIS,
Network: network,
Source: source,
Destination: destination,
})
}
func (t *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = t.Tag()
metadata.InboundType = C.TypeNDIS
metadata.Source = source
metadata.Destination = destination
t.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
t.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
done, err := t.tracker.NewConnEx(conn)
if err != nil {
t.logger.ErrorContext(ctx, E.Cause(err, "track inbound connection"))
return
}
t.router.RouteConnectionEx(ctx, conn, metadata, N.AppendClose(onClose, done))
}
func (t *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = t.Tag()
metadata.InboundType = C.TypeNDIS
metadata.Source = source
metadata.Destination = destination
t.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source)
t.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
done, err := t.tracker.NewPacketConnEx(conn)
if err != nil {
t.logger.ErrorContext(ctx, E.Cause(err, "track inbound connection"))
return
}
t.router.RoutePacketConnectionEx(ctx, conn, metadata, N.AppendClose(onClose, done))
}
func (t *Inbound) updateRouteAddressSet(it adapter.RuleSet) {
t.stack.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet)
t.stack.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet)
}

267
protocol/ndis/stack.go Normal file
View file

@ -0,0 +1,267 @@
//go:build windows
package ndis
import (
"context"
"net/netip"
"time"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/control"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/wiresock/ndisapi-go"
"github.com/wiresock/ndisapi-go/driver"
"go4.org/netipx"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
type Stack struct {
ctx context.Context
logger logger.ContextLogger
network adapter.NetworkManager
trackerIn conntrack.Tracker
trackerOut conntrack.Tracker
api *ndisapi.NdisApi
handler tun.Handler
udpTimeout time.Duration
filter *driver.QueuedPacketFilter
stack *stack.Stack
endpoint *ndisEndpoint
routeAddress []netip.Prefix
routeExcludeAddress []netip.Prefix
routeAddressSet []*netipx.IPSet
routeExcludeAddressSet []*netipx.IPSet
currentInterface *control.Interface
}
func (s *Stack) Start() error {
err := s.start(s.network.InterfaceMonitor().DefaultInterface())
if err != nil {
return err
}
s.network.InterfaceMonitor().RegisterCallback(s.updateDefaultInterface)
return nil
}
func (s *Stack) updateDefaultInterface(defaultInterface *control.Interface, flags int) {
if s.currentInterface.Equals(*defaultInterface) {
return
}
err := s.start(defaultInterface)
if err != nil {
s.logger.Error(E.Cause(err, "reconfigure NDIS at: ", defaultInterface.Name))
}
}
func (s *Stack) start(defaultInterface *control.Interface) error {
_ = s.Close()
adapters, err := s.api.GetTcpipBoundAdaptersInfo()
if err != nil {
return err
}
if defaultInterface != nil {
for index := 0; index < int(adapters.AdapterCount); index++ {
name := s.api.ConvertWindows2000AdapterName(string(adapters.AdapterNameList[index][:]))
if name != defaultInterface.Name {
continue
}
s.filter, err = driver.NewQueuedPacketFilter(s.api, adapters, nil, s.processOut)
if err != nil {
return err
}
address := tcpip.LinkAddress(adapters.CurrentAddress[index][:])
mtu := uint32(adapters.MTU[index])
endpoint := &ndisEndpoint{
filter: s.filter,
mtu: mtu,
address: address,
}
s.stack, err = tun.NewGVisorStack(endpoint)
if err != nil {
s.filter = nil
return err
}
s.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(s.ctx, s.stack, s.handler).HandlePacket)
s.stack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(s.ctx, s.stack, s.handler, s.udpTimeout).HandlePacket)
err = s.filter.StartFilter(index)
if err != nil {
s.filter = nil
s.stack.Close()
s.stack = nil
return err
}
s.endpoint = endpoint
s.logger.Info("started at ", defaultInterface.Name)
break
}
}
s.currentInterface = defaultInterface
return nil
}
func (s *Stack) Close() error {
if s.filter != nil {
s.filter.StopFilter()
s.filter.Close()
s.filter = nil
}
if s.stack != nil {
s.stack.Close()
for _, endpoint := range s.stack.CleanupEndpoints() {
endpoint.Abort()
}
s.stack = nil
}
return nil
}
func (s *Stack) processOut(handle ndisapi.Handle, packet *ndisapi.IntermediateBuffer) ndisapi.FilterAction {
if packet.Length < header.EthernetMinimumSize {
return ndisapi.FilterActionPass
}
if s.endpoint.dispatcher == nil || s.filterPacket(packet.Buffer[:packet.Length]) {
return ndisapi.FilterActionPass
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet.Buffer[:packet.Length]),
})
_, ok := packetBuffer.LinkHeader().Consume(header.EthernetMinimumSize)
if !ok {
packetBuffer.DecRef()
return ndisapi.FilterActionPass
}
ethHdr := header.Ethernet(packetBuffer.LinkHeader().Slice())
destinationAddress := ethHdr.DestinationAddress()
if destinationAddress == header.EthernetBroadcastAddress {
packetBuffer.PktType = tcpip.PacketBroadcast
} else if header.IsMulticastEthernetAddress(destinationAddress) {
packetBuffer.PktType = tcpip.PacketMulticast
} else if destinationAddress == s.endpoint.address {
packetBuffer.PktType = tcpip.PacketHost
} else {
packetBuffer.PktType = tcpip.PacketOtherHost
}
s.endpoint.dispatcher.DeliverNetworkPacket(ethHdr.Type(), packetBuffer)
packetBuffer.DecRef()
return ndisapi.FilterActionDrop
}
func (s *Stack) filterPacket(packet []byte) bool {
var ipHdr header.Network
switch header.IPVersion(packet[header.EthernetMinimumSize:]) {
case ipv4.Version:
ipHdr = header.IPv4(packet[header.EthernetMinimumSize:])
case ipv6.Version:
ipHdr = header.IPv6(packet[header.EthernetMinimumSize:])
default:
return true
}
sourceAddr := tun.AddrFromAddress(ipHdr.SourceAddress())
destinationAddr := tun.AddrFromAddress(ipHdr.DestinationAddress())
if !destinationAddr.IsGlobalUnicast() {
return true
}
var (
transportProtocol tcpip.TransportProtocolNumber
transportHdr header.Transport
)
switch ipHdr.TransportProtocol() {
case tcp.ProtocolNumber:
transportProtocol = header.TCPProtocolNumber
transportHdr = header.TCP(ipHdr.Payload())
case udp.ProtocolNumber:
transportProtocol = header.UDPProtocolNumber
transportHdr = header.UDP(ipHdr.Payload())
default:
return false
}
source := netip.AddrPortFrom(sourceAddr, transportHdr.SourcePort())
destination := netip.AddrPortFrom(destinationAddr, transportHdr.DestinationPort())
if transportProtocol == header.TCPProtocolNumber {
if s.trackerIn.CheckConn(source, destination) {
if debug.Enabled {
s.logger.Trace("fall exists TCP ", source, " ", destination)
}
return false
}
} else {
if s.trackerIn.CheckPacketConn(source) {
if debug.Enabled {
s.logger.Trace("fall exists UDP ", source, " ", destination)
}
}
}
if len(s.routeAddress) > 0 {
var match bool
for _, route := range s.routeAddress {
if route.Contains(destinationAddr) {
match = true
}
}
if !match {
return true
}
}
if len(s.routeAddressSet) > 0 {
var match bool
for _, ipSet := range s.routeAddressSet {
if ipSet.Contains(destinationAddr) {
match = true
}
}
if !match {
return true
}
}
if len(s.routeExcludeAddress) > 0 {
for _, address := range s.routeExcludeAddress {
if address.Contains(destinationAddr) {
return true
}
}
}
if len(s.routeExcludeAddressSet) > 0 {
for _, ipSet := range s.routeAddressSet {
if ipSet.Contains(destinationAddr) {
return true
}
}
}
if s.trackerOut.CheckDestination(destination) {
if debug.Enabled {
s.logger.Trace("passing pending ", source, " ", destination)
}
return true
}
if transportProtocol == header.TCPProtocolNumber {
if s.trackerOut.CheckConn(source, destination) {
if debug.Enabled {
s.logger.Trace("passing TCP ", source, " ", destination)
}
return true
}
} else {
if s.trackerOut.CheckPacketConn(source) {
if debug.Enabled {
s.logger.Trace("passing UDP ", source, " ", destination)
}
}
}
if debug.Enabled {
s.logger.Trace("fall ", source, " ", destination)
}
return false
}

View file

@ -306,7 +306,6 @@ func (t *Inbound) Start(stage adapter.StartStage) error {
t.tunOptions.Name = tun.CalculateInterfaceName("")
}
if t.platformInterface == nil || runtime.GOOS != "android" {
t.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet)
for _, routeRuleSet := range t.routeRuleSet {
ipSets := routeRuleSet.ExtractIPSet()
if len(ipSets) == 0 {
@ -316,11 +315,10 @@ func (t *Inbound) Start(stage adapter.StartStage) error {
routeRuleSet.DecRef()
t.routeAddressSet = append(t.routeAddressSet, ipSets...)
}
t.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet)
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
ipSets := routeExcludeRuleSet.ExtractIPSet()
if len(ipSets) == 0 {
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
t.logger.Warn("route_exclude_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
}
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
routeExcludeRuleSet.DecRef()

View file

@ -35,6 +35,7 @@ var _ adapter.NetworkManager = (*NetworkManager)(nil)
type NetworkManager struct {
logger logger.ContextLogger
tracker conntrack.Tracker
interfaceFinder *control.DefaultInterfaceFinder
networkInterfaces atomic.TypedValue[[]adapter.NetworkInterface]
@ -57,6 +58,7 @@ type NetworkManager struct {
func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOptions option.RouteOptions) (*NetworkManager, error) {
nm := &NetworkManager{
logger: logger,
tracker: service.FromContext[conntrack.Tracker](ctx),
interfaceFinder: control.NewDefaultInterfaceFinder(),
autoDetectInterface: routeOptions.AutoDetectInterface,
defaultOptions: adapter.NetworkOptions{
@ -355,7 +357,7 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
}
func (r *NetworkManager) ResetNetwork() {
conntrack.Close()
r.tracker.Close()
for _, endpoint := range r.endpoint.Endpoints() {
listener, isListener := endpoint.(adapter.InterfaceUpdateListener)

View file

@ -11,7 +11,6 @@ import (
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/common/process"
"github.com/sagernet/sing-box/common/sniff"
C "github.com/sagernet/sing-box/constant"
@ -72,7 +71,10 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
injectable.NewConnectionEx(ctx, conn, metadata, onClose)
return nil
}
conntrack.KillerCheck()
err := r.connTracker.KillerCheck()
if err != nil {
return err
}
metadata.Network = N.NetworkTCP
switch metadata.Destination.Fqdn {
case mux.Destination.Fqdn:
@ -190,7 +192,10 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
injectable.NewPacketConnectionEx(ctx, conn, metadata, onClose)
return nil
}
conntrack.KillerCheck()
err := r.connTracker.KillerCheck()
if err != nil {
return err
}
// TODO: move to UoT
metadata.Network = N.NetworkUDP

View file

@ -10,6 +10,7 @@ import (
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/geoip"
"github.com/sagernet/sing-box/common/geosite"
@ -38,6 +39,7 @@ type Router struct {
ctx context.Context
logger log.ContextLogger
dnsLogger log.ContextLogger
connTracker conntrack.Tracker
inbound adapter.InboundManager
outbound adapter.OutboundManager
connection adapter.ConnectionManager
@ -75,6 +77,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
ctx: ctx,
logger: logFactory.NewLogger("router"),
dnsLogger: logFactory.NewLogger("dns"),
connTracker: service.FromContext[conntrack.Tracker](ctx),
inbound: service.FromContext[adapter.InboundManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
connection: service.FromContext[adapter.ConnectionManager](ctx),