mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-01-28 19:56:49 +00:00
Add NDIS inbound
This commit is contained in:
parent
e483c909b4
commit
79d3649a8b
7
box.go
7
box.go
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/taskmonitor"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
|
@ -84,7 +85,6 @@ func New(options Options) (*Box, error) {
|
|||
ctx = context.Background()
|
||||
}
|
||||
ctx = service.ContextWithDefaultRegistry(ctx)
|
||||
|
||||
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
|
||||
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
|
||||
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
||||
|
@ -101,7 +101,10 @@ func New(options Options) (*Box, error) {
|
|||
|
||||
ctx = pause.WithDefaultManager(ctx)
|
||||
experimentalOptions := common.PtrValueOrDefault(options.Experimental)
|
||||
applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug))
|
||||
debugOptions := common.PtrValueOrDefault(experimentalOptions.Debug)
|
||||
applyDebugOptions(debugOptions)
|
||||
ctx = conntrack.ContextWithDefaultTracker(ctx, debugOptions.OOMKiller, uint64(debugOptions.MemoryLimit))
|
||||
|
||||
var needCacheFile bool
|
||||
var needClashAPI bool
|
||||
var needV2RayAPI bool
|
||||
|
|
|
@ -1,54 +0,0 @@
|
|||
package conntrack
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
element *list.Element[io.Closer]
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn) (net.Conn, error) {
|
||||
connAccess.Lock()
|
||||
element := openConnection.PushBack(conn)
|
||||
connAccess.Unlock()
|
||||
if KillerEnabled {
|
||||
err := KillerCheck()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
element: element,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
if c.element.Value != nil {
|
||||
connAccess.Lock()
|
||||
if c.element.Value != nil {
|
||||
openConnection.Remove(c.element)
|
||||
c.element.Value = nil
|
||||
}
|
||||
connAccess.Unlock()
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *Conn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
14
common/conntrack/context.go
Normal file
14
common/conntrack/context.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func ContextWithDefaultTracker(ctx context.Context, killerEnabled bool, memoryLimit uint64) context.Context {
|
||||
if service.FromContext[Tracker](ctx) != nil {
|
||||
return ctx
|
||||
}
|
||||
return service.ContextWith[Tracker](ctx, NewDefaultTracker(killerEnabled, memoryLimit))
|
||||
}
|
245
common/conntrack/default.go
Normal file
245
common/conntrack/default.go
Normal file
|
@ -0,0 +1,245 @@
|
|||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
runtimeDebug "runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/memory"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
|
||||
var _ Tracker = (*DefaultTracker)(nil)
|
||||
|
||||
type DefaultTracker struct {
|
||||
connAccess sync.RWMutex
|
||||
connList list.List[net.Conn]
|
||||
connAddress map[netip.AddrPort]netip.AddrPort
|
||||
|
||||
packetConnAccess sync.RWMutex
|
||||
packetConnList list.List[AbstractPacketConn]
|
||||
packetConnAddress map[netip.AddrPort]bool
|
||||
|
||||
pendingAccess sync.RWMutex
|
||||
pendingList list.List[netip.AddrPort]
|
||||
|
||||
killerEnabled bool
|
||||
memoryLimit uint64
|
||||
killerLastCheck time.Time
|
||||
}
|
||||
|
||||
func NewDefaultTracker(killerEnabled bool, memoryLimit uint64) *DefaultTracker {
|
||||
return &DefaultTracker{
|
||||
connAddress: make(map[netip.AddrPort]netip.AddrPort),
|
||||
packetConnAddress: make(map[netip.AddrPort]bool),
|
||||
killerEnabled: killerEnabled,
|
||||
memoryLimit: memoryLimit,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) NewConn(conn net.Conn) (net.Conn, error) {
|
||||
err := t.KillerCheck()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
t.connAccess.Lock()
|
||||
element := t.connList.PushBack(conn)
|
||||
t.connAddress[M.AddrPortFromNet(conn.LocalAddr())] = M.AddrPortFromNet(conn.RemoteAddr())
|
||||
t.connAccess.Unlock()
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
closeFunc: common.OnceFunc(func() {
|
||||
t.removeConn(element)
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) NewConnEx(conn net.Conn) (N.CloseHandlerFunc, error) {
|
||||
err := t.KillerCheck()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
t.connAccess.Lock()
|
||||
element := t.connList.PushBack(conn)
|
||||
t.connAddress[M.AddrPortFromNet(conn.LocalAddr())] = M.AddrPortFromNet(conn.RemoteAddr())
|
||||
t.connAccess.Unlock()
|
||||
return N.OnceClose(func(it error) {
|
||||
t.removeConn(element)
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) NewPacketConn(conn net.PacketConn) (net.PacketConn, error) {
|
||||
err := t.KillerCheck()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
t.packetConnAccess.Lock()
|
||||
element := t.packetConnList.PushBack(conn)
|
||||
t.packetConnAddress[M.AddrPortFromNet(conn.LocalAddr())] = true
|
||||
t.packetConnAccess.Unlock()
|
||||
return &PacketConn{
|
||||
PacketConn: conn,
|
||||
closeFunc: common.OnceFunc(func() {
|
||||
t.removePacketConn(element)
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) NewPacketConnEx(conn AbstractPacketConn) (N.CloseHandlerFunc, error) {
|
||||
err := t.KillerCheck()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
t.packetConnAccess.Lock()
|
||||
element := t.packetConnList.PushBack(conn)
|
||||
t.packetConnAddress[M.AddrPortFromNet(conn.LocalAddr())] = true
|
||||
t.packetConnAccess.Unlock()
|
||||
return N.OnceClose(func(it error) {
|
||||
t.removePacketConn(element)
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) CheckConn(source netip.AddrPort, destination netip.AddrPort) bool {
|
||||
t.connAccess.RLock()
|
||||
defer t.connAccess.RUnlock()
|
||||
return t.connAddress[source] == destination
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) CheckPacketConn(source netip.AddrPort) bool {
|
||||
t.packetConnAccess.RLock()
|
||||
defer t.packetConnAccess.RUnlock()
|
||||
return t.packetConnAddress[source]
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) AddPendingDestination(destination netip.AddrPort) func() {
|
||||
t.pendingAccess.Lock()
|
||||
defer t.pendingAccess.Unlock()
|
||||
element := t.pendingList.PushBack(destination)
|
||||
return func() {
|
||||
t.pendingAccess.Lock()
|
||||
defer t.pendingAccess.Unlock()
|
||||
t.pendingList.Remove(element)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) CheckDestination(destination netip.AddrPort) bool {
|
||||
t.pendingAccess.RLock()
|
||||
defer t.pendingAccess.RUnlock()
|
||||
for element := t.pendingList.Front(); element != nil; element = element.Next() {
|
||||
if element.Value == destination {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) KillerCheck() error {
|
||||
if !t.killerEnabled {
|
||||
return nil
|
||||
}
|
||||
nowTime := time.Now()
|
||||
if nowTime.Sub(t.killerLastCheck) < 3*time.Second {
|
||||
return nil
|
||||
}
|
||||
t.killerLastCheck = nowTime
|
||||
if memory.Total() > t.memoryLimit {
|
||||
t.Close()
|
||||
go func() {
|
||||
time.Sleep(time.Second)
|
||||
runtimeDebug.FreeOSMemory()
|
||||
}()
|
||||
return E.New("out of memory")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) Count() int {
|
||||
t.connAccess.RLock()
|
||||
defer t.connAccess.RUnlock()
|
||||
t.packetConnAccess.RLock()
|
||||
defer t.packetConnAccess.RUnlock()
|
||||
return t.connList.Len() + t.packetConnList.Len()
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) Close() {
|
||||
t.connAccess.Lock()
|
||||
for element := t.connList.Front(); element != nil; element = element.Next() {
|
||||
element.Value.Close()
|
||||
}
|
||||
t.connList.Init()
|
||||
t.connAccess.Unlock()
|
||||
t.packetConnAccess.Lock()
|
||||
for element := t.packetConnList.Front(); element != nil; element = element.Next() {
|
||||
element.Value.Close()
|
||||
}
|
||||
t.packetConnList.Init()
|
||||
t.packetConnAccess.Unlock()
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) removeConn(element *list.Element[net.Conn]) {
|
||||
t.connAccess.Lock()
|
||||
defer t.connAccess.Unlock()
|
||||
delete(t.connAddress, M.AddrPortFromNet(element.Value.LocalAddr()))
|
||||
t.connList.Remove(element)
|
||||
}
|
||||
|
||||
func (t *DefaultTracker) removePacketConn(element *list.Element[AbstractPacketConn]) {
|
||||
t.packetConnAccess.Lock()
|
||||
defer t.packetConnAccess.Unlock()
|
||||
delete(t.packetConnAddress, M.AddrPortFromNet(element.Value.LocalAddr()))
|
||||
t.packetConnList.Remove(element)
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
closeFunc func()
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.closeFunc()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *Conn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
closeFunc func()
|
||||
}
|
||||
|
||||
func (c *PacketConn) Close() error {
|
||||
c.closeFunc()
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
|
||||
func (c *PacketConn) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *PacketConn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
package conntrack
|
||||
|
||||
import (
|
||||
runtimeDebug "runtime/debug"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/memory"
|
||||
)
|
||||
|
||||
var (
|
||||
KillerEnabled bool
|
||||
MemoryLimit uint64
|
||||
killerLastCheck time.Time
|
||||
)
|
||||
|
||||
func KillerCheck() error {
|
||||
if !KillerEnabled {
|
||||
return nil
|
||||
}
|
||||
nowTime := time.Now()
|
||||
if nowTime.Sub(killerLastCheck) < 3*time.Second {
|
||||
return nil
|
||||
}
|
||||
killerLastCheck = nowTime
|
||||
if memory.Total() > MemoryLimit {
|
||||
Close()
|
||||
go func() {
|
||||
time.Sleep(time.Second)
|
||||
runtimeDebug.FreeOSMemory()
|
||||
}()
|
||||
return E.New("out of memory")
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
package conntrack
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
element *list.Element[io.Closer]
|
||||
}
|
||||
|
||||
func NewPacketConn(conn net.PacketConn) (net.PacketConn, error) {
|
||||
connAccess.Lock()
|
||||
element := openConnection.PushBack(conn)
|
||||
connAccess.Unlock()
|
||||
if KillerEnabled {
|
||||
err := KillerCheck()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &PacketConn{
|
||||
PacketConn: conn,
|
||||
element: element,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *PacketConn) Close() error {
|
||||
if c.element.Value != nil {
|
||||
connAccess.Lock()
|
||||
if c.element.Value != nil {
|
||||
openConnection.Remove(c.element)
|
||||
c.element.Value = nil
|
||||
}
|
||||
connAccess.Unlock()
|
||||
}
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
|
||||
func (c *PacketConn) Upstream() any {
|
||||
return bufio.NewPacketConn(c.PacketConn)
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *PacketConn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
package conntrack
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
|
||||
var (
|
||||
connAccess sync.RWMutex
|
||||
openConnection list.List[io.Closer]
|
||||
)
|
||||
|
||||
func Count() int {
|
||||
if !Enabled {
|
||||
return 0
|
||||
}
|
||||
return openConnection.Len()
|
||||
}
|
||||
|
||||
func List() []io.Closer {
|
||||
if !Enabled {
|
||||
return nil
|
||||
}
|
||||
connAccess.RLock()
|
||||
defer connAccess.RUnlock()
|
||||
connList := make([]io.Closer, 0, openConnection.Len())
|
||||
for element := openConnection.Front(); element != nil; element = element.Next() {
|
||||
connList = append(connList, element.Value)
|
||||
}
|
||||
return connList
|
||||
}
|
||||
|
||||
func Close() {
|
||||
if !Enabled {
|
||||
return
|
||||
}
|
||||
connAccess.Lock()
|
||||
defer connAccess.Unlock()
|
||||
for element := openConnection.Front(); element != nil; element = element.Next() {
|
||||
common.Close(element.Value)
|
||||
element.Value = nil
|
||||
}
|
||||
openConnection.Init()
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
//go:build !with_conntrack
|
||||
|
||||
package conntrack
|
||||
|
||||
const Enabled = false
|
|
@ -1,5 +0,0 @@
|
|||
//go:build with_conntrack
|
||||
|
||||
package conntrack
|
||||
|
||||
const Enabled = true
|
32
common/conntrack/tracker.go
Normal file
32
common/conntrack/tracker.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
// TODO: add to N
|
||||
type AbstractPacketConn interface {
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
SetDeadline(t time.Time) error
|
||||
SetReadDeadline(t time.Time) error
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
type Tracker interface {
|
||||
NewConn(conn net.Conn) (net.Conn, error)
|
||||
NewPacketConn(conn net.PacketConn) (net.PacketConn, error)
|
||||
NewConnEx(conn net.Conn) (N.CloseHandlerFunc, error)
|
||||
NewPacketConnEx(conn AbstractPacketConn) (N.CloseHandlerFunc, error)
|
||||
CheckConn(source netip.AddrPort, destination netip.AddrPort) bool
|
||||
CheckPacketConn(source netip.AddrPort) bool
|
||||
AddPendingDestination(destination netip.AddrPort) func()
|
||||
CheckDestination(destination netip.AddrPort) bool
|
||||
KillerCheck() error
|
||||
Count() int
|
||||
Close()
|
||||
}
|
|
@ -28,6 +28,7 @@ var (
|
|||
)
|
||||
|
||||
type DefaultDialer struct {
|
||||
tracker conntrack.Tracker
|
||||
dialer4 tcpDialer
|
||||
dialer6 tcpDialer
|
||||
udpDialer4 net.Dialer
|
||||
|
@ -46,6 +47,7 @@ type DefaultDialer struct {
|
|||
}
|
||||
|
||||
func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDialer, error) {
|
||||
tracker := service.FromContext[conntrack.Tracker](ctx)
|
||||
networkManager := service.FromContext[adapter.NetworkManager](ctx)
|
||||
platformInterface := service.FromContext[platform.Interface](ctx)
|
||||
|
||||
|
@ -197,6 +199,7 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
|
|||
return nil, err
|
||||
}
|
||||
return &DefaultDialer{
|
||||
tracker: tracker,
|
||||
dialer4: tcpDialer4,
|
||||
dialer6: tcpDialer6,
|
||||
udpDialer4: udpDialer4,
|
||||
|
@ -219,18 +222,26 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address
|
|||
return nil, E.New("invalid address")
|
||||
}
|
||||
if d.networkStrategy == nil {
|
||||
if address.IsFqdn() {
|
||||
return nil, E.New("unexpected domain destination")
|
||||
}
|
||||
// Since pending check is only used by ndis, it is not performed for non-windows connections which are only supported on platform clients
|
||||
if d.tracker != nil {
|
||||
done := d.tracker.AddPendingDestination(address.AddrPort())
|
||||
defer done()
|
||||
}
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkUDP:
|
||||
if !address.IsIPv6() {
|
||||
return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
|
||||
return d.trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
|
||||
} else {
|
||||
return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
|
||||
return d.trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
|
||||
}
|
||||
}
|
||||
if !address.IsIPv6() {
|
||||
return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
|
||||
return d.trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
|
||||
} else {
|
||||
return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
|
||||
return d.trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
|
||||
}
|
||||
} else {
|
||||
return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
|
||||
|
@ -282,17 +293,17 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin
|
|||
if !fastFallback && !isPrimary {
|
||||
d.networkLastFallback.Store(time.Now())
|
||||
}
|
||||
return trackConn(conn, nil)
|
||||
return d.trackConn(conn, nil)
|
||||
}
|
||||
|
||||
func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
if d.networkStrategy == nil {
|
||||
if destination.IsIPv6() {
|
||||
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
|
||||
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
|
||||
} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
|
||||
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
|
||||
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
|
||||
} else {
|
||||
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
|
||||
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
|
||||
}
|
||||
} else {
|
||||
return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
|
||||
|
@ -329,23 +340,23 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
return trackPacketConn(packetConn, nil)
|
||||
return d.trackPacketConn(packetConn, nil)
|
||||
}
|
||||
|
||||
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
|
||||
return d.udpListener.ListenPacket(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func trackConn(conn net.Conn, err error) (net.Conn, error) {
|
||||
if !conntrack.Enabled || err != nil {
|
||||
func (d *DefaultDialer) trackConn(conn net.Conn, err error) (net.Conn, error) {
|
||||
if d.tracker == nil || err != nil {
|
||||
return conn, err
|
||||
}
|
||||
return conntrack.NewConn(conn)
|
||||
return d.tracker.NewConn(conn)
|
||||
}
|
||||
|
||||
func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
|
||||
if !conntrack.Enabled || err != nil {
|
||||
func (d *DefaultDialer) trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
|
||||
if err != nil {
|
||||
return conn, err
|
||||
}
|
||||
return conntrack.NewPacketConn(conn)
|
||||
return d.tracker.NewPacketConn(conn)
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ const (
|
|||
TypeVLESS = "vless"
|
||||
TypeTUIC = "tuic"
|
||||
TypeHysteria2 = "hysteria2"
|
||||
TypeNDIS = "ndis"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -80,6 +81,8 @@ func ProxyDisplayName(proxyType string) string {
|
|||
return "Selector"
|
||||
case TypeURLTest:
|
||||
return "URLTest"
|
||||
case TypeNDIS:
|
||||
return "NDIS"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
|
|
5
debug.go
5
debug.go
|
@ -3,7 +3,6 @@ package box
|
|||
import (
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
|
@ -26,9 +25,5 @@ func applyDebugOptions(options option.DebugOptions) {
|
|||
}
|
||||
if options.MemoryLimit != 0 {
|
||||
debug.SetMemoryLimit(int64(float64(options.MemoryLimit) / 1.5))
|
||||
conntrack.MemoryLimit = uint64(options.MemoryLimit)
|
||||
}
|
||||
if options.OOMKiller != nil {
|
||||
conntrack.KillerEnabled = *options.OOMKiller
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,8 +5,6 @@ import (
|
|||
"net"
|
||||
runtimeDebug "runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
)
|
||||
|
||||
func (c *CommandClient) CloseConnections() error {
|
||||
|
@ -19,7 +17,7 @@ func (c *CommandClient) CloseConnections() error {
|
|||
}
|
||||
|
||||
func (s *CommandServer) handleCloseConnections(conn net.Conn) error {
|
||||
conntrack.Close()
|
||||
tracker.Close()
|
||||
go func() {
|
||||
time.Sleep(time.Second)
|
||||
runtimeDebug.FreeOSMemory()
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
"github.com/sagernet/sing-box/experimental/clashapi"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/memory"
|
||||
|
@ -28,7 +27,7 @@ func (s *CommandServer) readStatus() StatusMessage {
|
|||
var message StatusMessage
|
||||
message.Memory = int64(memory.Inuse())
|
||||
message.Goroutines = int32(runtime.NumGoroutine())
|
||||
message.ConnectionsOut = int32(conntrack.Count())
|
||||
message.ConnectionsOut = int32(tracker.Count())
|
||||
|
||||
if s.service != nil {
|
||||
message.TrafficAvailable = true
|
||||
|
|
|
@ -7,17 +7,21 @@ import (
|
|||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
)
|
||||
|
||||
var tracker *conntrack.DefaultTracker
|
||||
|
||||
func SetMemoryLimit(enabled bool) {
|
||||
if tracker != nil {
|
||||
tracker.Close()
|
||||
}
|
||||
const memoryLimit = 45 * 1024 * 1024
|
||||
const memoryLimitGo = memoryLimit / 1.5
|
||||
if enabled {
|
||||
runtimeDebug.SetGCPercent(10)
|
||||
runtimeDebug.SetMemoryLimit(memoryLimitGo)
|
||||
conntrack.KillerEnabled = true
|
||||
conntrack.MemoryLimit = memoryLimit
|
||||
tracker = conntrack.NewDefaultTracker(true, memoryLimit)
|
||||
} else {
|
||||
runtimeDebug.SetGCPercent(100)
|
||||
runtimeDebug.SetMemoryLimit(math.MaxInt64)
|
||||
conntrack.KillerEnabled = false
|
||||
tracker = conntrack.NewDefaultTracker(false, 0)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
"github.com/sagernet/sing-box"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
"github.com/sagernet/sing-box/common/process"
|
||||
"github.com/sagernet/sing-box/common/urltest"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
|
@ -60,6 +61,7 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box
|
|||
useProcFS: platformInterface.UseProcFS(),
|
||||
}
|
||||
service.MustRegister[platform.Interface](ctx, platformWrapper)
|
||||
service.MustRegister[conntrack.Tracker](ctx, tracker)
|
||||
instance, err := box.New(box.Options{
|
||||
Context: ctx,
|
||||
Options: options,
|
||||
|
|
1
go.mod
1
go.mod
|
@ -41,6 +41,7 @@ require (
|
|||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/wiresock/ndisapi-go v0.0.0-20241230094942-3299a7566e08
|
||||
go.uber.org/zap v1.27.0
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||
golang.org/x/crypto v0.31.0
|
||||
|
|
3
go.sum
3
go.sum
|
@ -158,6 +158,8 @@ github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923 h1:tHNk7XK9GkmKUR6Gh8gV
|
|||
github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/wiresock/ndisapi-go v0.0.0-20241230094942-3299a7566e08 h1:is+7xN6CAKtgxt3mDSl9OQNvjfi6LggugSP07QhDtws=
|
||||
github.com/wiresock/ndisapi-go v0.0.0-20241230094942-3299a7566e08/go.mod h1:lFE7JYt3LC2UYJ31mRDwl/K35pbtxDnkSDlXrYzgyqg=
|
||||
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
|
||||
github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
|
||||
github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg=
|
||||
|
@ -191,6 +193,7 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
|
|
12
include/ndis.go
Normal file
12
include/ndis.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
//go:build windows && with_gvisor
|
||||
|
||||
package include
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/protocol/ndis"
|
||||
)
|
||||
|
||||
func registerNDISInbound(registry *inbound.Registry) {
|
||||
ndis.RegisterInbound(registry)
|
||||
}
|
20
include/ndis_nongvisor_stub.go
Normal file
20
include/ndis_nongvisor_stub.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
//go:build windows && !with_gvisor
|
||||
|
||||
package include
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-tun"
|
||||
)
|
||||
|
||||
func registerNDISInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.NDISInboundOptions](registry, C.TypeNDIS, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NDISInboundOptions) (adapter.Inbound, error) {
|
||||
return nil, tun.ErrGVisorNotIncluded
|
||||
})
|
||||
}
|
20
include/ndis_nonwindows_stub.go
Normal file
20
include/ndis_nonwindows_stub.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
//go:build !windows
|
||||
|
||||
package include
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func registerNDISInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.NDISInboundOptions](registry, C.TypeNDIS, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NDISInboundOptions) (adapter.Inbound, error) {
|
||||
return nil, E.New("NDIS is only supported in windows")
|
||||
})
|
||||
}
|
|
@ -51,6 +51,7 @@ func InboundRegistry() *inbound.Registry {
|
|||
|
||||
registerQUICInbounds(registry)
|
||||
registerStubForRemovedInbounds(registry)
|
||||
registerNDISInbound(registry)
|
||||
|
||||
return registry
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ type DebugOptions struct {
|
|||
PanicOnFault *bool `json:"panic_on_fault,omitempty"`
|
||||
TraceBack string `json:"trace_back,omitempty"`
|
||||
MemoryLimit MemoryBytes `json:"memory_limit,omitempty"`
|
||||
OOMKiller *bool `json:"oom_killer,omitempty"`
|
||||
OOMKiller bool `json:"oom_killer,omitempty"`
|
||||
}
|
||||
|
||||
type MemoryBytes uint64
|
||||
|
|
17
option/ndis.go
Normal file
17
option/ndis.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package option
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
)
|
||||
|
||||
type NDISInboundOptions struct {
|
||||
Network NetworkList `json:"network,omitempty"`
|
||||
RouteAddress badoption.Listable[netip.Prefix] `json:"route_address,omitempty"`
|
||||
RouteAddressSet badoption.Listable[string] `json:"route_address_set,omitempty"`
|
||||
RouteExcludeAddress badoption.Listable[netip.Prefix] `json:"route_exclude_address,omitempty"`
|
||||
RouteExcludeAddressSet badoption.Listable[string] `json:"route_exclude_address_set,omitempty"`
|
||||
InterfaceName string `json:"interface_name,omitempty"`
|
||||
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"`
|
||||
}
|
110
protocol/ndis/endpoint.go
Normal file
110
protocol/ndis/endpoint.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
//go:build windows
|
||||
|
||||
package ndis
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
|
||||
"github.com/wiresock/ndisapi-go"
|
||||
"github.com/wiresock/ndisapi-go/driver"
|
||||
)
|
||||
|
||||
var _ stack.LinkEndpoint = (*ndisEndpoint)(nil)
|
||||
|
||||
type ndisEndpoint struct {
|
||||
filter *driver.QueuedPacketFilter
|
||||
mtu uint32
|
||||
address tcpip.LinkAddress
|
||||
dispatcher stack.NetworkDispatcher
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) MTU() uint32 {
|
||||
return e.mtu
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) SetMTU(mtu uint32) {
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) MaxHeaderLength() uint16 {
|
||||
return header.EthernetMinimumSize
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) LinkAddress() tcpip.LinkAddress {
|
||||
return e.address
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
e.dispatcher = dispatcher
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) IsAttached() bool {
|
||||
return e.dispatcher != nil
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) Wait() {
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
return header.ARPHardwareEther
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) AddHeader(pkt *stack.PacketBuffer) {
|
||||
eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
|
||||
fields := header.EthernetFields{
|
||||
SrcAddr: pkt.EgressRoute.LocalLinkAddress,
|
||||
DstAddr: pkt.EgressRoute.RemoteLinkAddress,
|
||||
Type: pkt.NetworkProtocolNumber,
|
||||
}
|
||||
eth.Encode(&fields)
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) ParseHeader(pkt *stack.PacketBuffer) bool {
|
||||
_, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) Close() {
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) SetOnCloseAction(f func()) {
|
||||
}
|
||||
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() any {
|
||||
return new(ndisapi.IntermediateBuffer)
|
||||
},
|
||||
}
|
||||
|
||||
func (e *ndisEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
|
||||
for _, packetBuffer := range list.AsSlice() {
|
||||
ndisBuf := bufferPool.Get().(*ndisapi.IntermediateBuffer)
|
||||
viewList, offset := packetBuffer.AsViewList()
|
||||
var view *buffer.View
|
||||
for view = viewList.Front(); view != nil && offset >= view.Size(); view = view.Next() {
|
||||
offset -= view.Size()
|
||||
}
|
||||
index := copy(ndisBuf.Buffer[:], view.AsSlice()[offset:])
|
||||
for view = view.Next(); view != nil; view = view.Next() {
|
||||
index += copy(ndisBuf.Buffer[index:], view.AsSlice())
|
||||
}
|
||||
ndisBuf.Length = uint32(index)
|
||||
err := e.filter.InsertPacketToMstcp(ndisBuf)
|
||||
bufferPool.Put(ndisBuf)
|
||||
if err != nil {
|
||||
return 0, &tcpip.ErrAborted{}
|
||||
}
|
||||
}
|
||||
return list.Len(), nil
|
||||
}
|
203
protocol/ndis/inbound.go
Normal file
203
protocol/ndis/inbound.go
Normal file
|
@ -0,0 +1,203 @@
|
|||
//go:build windows
|
||||
|
||||
package ndis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
"github.com/sagernet/sing-box/common/taskmonitor"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
"github.com/sagernet/sing/service"
|
||||
|
||||
"github.com/wiresock/ndisapi-go"
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
func RegisterInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.NDISInboundOptions](registry, C.TypeNDIS, NewInbound)
|
||||
}
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
ctx context.Context
|
||||
router adapter.Router
|
||||
logger log.ContextLogger
|
||||
api *ndisapi.NdisApi
|
||||
tracker conntrack.Tracker
|
||||
routeAddress []netip.Prefix
|
||||
routeExcludeAddress []netip.Prefix
|
||||
routeRuleSet []adapter.RuleSet
|
||||
routeRuleSetCallback []*list.Element[adapter.RuleSetUpdateCallback]
|
||||
routeExcludeRuleSet []adapter.RuleSet
|
||||
routeExcludeRuleSetCallback []*list.Element[adapter.RuleSetUpdateCallback]
|
||||
stack *Stack
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NDISInboundOptions) (adapter.Inbound, error) {
|
||||
api, err := ndisapi.NewNdisApi()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create NDIS API")
|
||||
}
|
||||
//if !api.IsDriverLoaded() {
|
||||
// return nil, E.New("missing NDIS driver")
|
||||
//}
|
||||
networkManager := service.FromContext[adapter.NetworkManager](ctx)
|
||||
trackerOut := service.FromContext[conntrack.Tracker](ctx)
|
||||
var udpTimeout time.Duration
|
||||
if options.UDPTimeout != 0 {
|
||||
udpTimeout = time.Duration(options.UDPTimeout)
|
||||
} else {
|
||||
udpTimeout = C.UDPTimeout
|
||||
}
|
||||
var (
|
||||
routeRuleSet []adapter.RuleSet
|
||||
routeExcludeRuleSet []adapter.RuleSet
|
||||
)
|
||||
for _, routeAddressSet := range options.RouteAddressSet {
|
||||
ruleSet, loaded := router.RuleSet(routeAddressSet)
|
||||
if !loaded {
|
||||
return nil, E.New("parse route_address_set: rule-set not found: ", routeAddressSet)
|
||||
}
|
||||
ruleSet.IncRef()
|
||||
routeRuleSet = append(routeRuleSet, ruleSet)
|
||||
}
|
||||
for _, routeExcludeAddressSet := range options.RouteExcludeAddressSet {
|
||||
ruleSet, loaded := router.RuleSet(routeExcludeAddressSet)
|
||||
if !loaded {
|
||||
return nil, E.New("parse route_exclude_address_set: rule-set not found: ", routeExcludeAddressSet)
|
||||
}
|
||||
ruleSet.IncRef()
|
||||
routeExcludeRuleSet = append(routeExcludeRuleSet, ruleSet)
|
||||
}
|
||||
trackerIn := conntrack.NewDefaultTracker(false, 0)
|
||||
return &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeNDIS, tag),
|
||||
ctx: ctx,
|
||||
router: router,
|
||||
logger: logger,
|
||||
api: api,
|
||||
tracker: trackerIn,
|
||||
routeRuleSet: routeRuleSet,
|
||||
routeExcludeRuleSet: routeExcludeRuleSet,
|
||||
stack: &Stack{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
network: networkManager,
|
||||
trackerIn: trackerIn,
|
||||
trackerOut: trackerOut,
|
||||
api: api,
|
||||
udpTimeout: udpTimeout,
|
||||
routeAddress: options.RouteAddress,
|
||||
routeExcludeAddress: options.RouteExcludeAddress,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *Inbound) Start(stage adapter.StartStage) error {
|
||||
switch stage {
|
||||
case adapter.StartStateStart:
|
||||
monitor := taskmonitor.New(t.logger, C.StartTimeout)
|
||||
var (
|
||||
routeAddressSet []*netipx.IPSet
|
||||
routeExcludeAddressSet []*netipx.IPSet
|
||||
)
|
||||
for _, routeRuleSet := range t.routeRuleSet {
|
||||
ipSets := routeRuleSet.ExtractIPSet()
|
||||
if len(ipSets) == 0 {
|
||||
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeRuleSet.Name())
|
||||
}
|
||||
t.routeRuleSetCallback = append(t.routeRuleSetCallback, routeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
||||
routeRuleSet.DecRef()
|
||||
routeAddressSet = append(routeAddressSet, ipSets...)
|
||||
}
|
||||
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
|
||||
ipSets := routeExcludeRuleSet.ExtractIPSet()
|
||||
if len(ipSets) == 0 {
|
||||
t.logger.Warn("route_exclude_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
|
||||
}
|
||||
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
||||
routeExcludeRuleSet.DecRef()
|
||||
routeExcludeAddressSet = append(routeExcludeAddressSet, ipSets...)
|
||||
}
|
||||
t.stack.routeAddressSet = routeAddressSet
|
||||
t.stack.routeExcludeAddressSet = routeExcludeAddressSet
|
||||
monitor.Start("starting NDIS stack")
|
||||
t.stack.handler = t
|
||||
err := t.stack.Start()
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
return E.Cause(err, "starting NDIS stack")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Inbound) Close() error {
|
||||
if t.api != nil {
|
||||
t.stack.Close()
|
||||
t.api.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Inbound) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
|
||||
return t.router.PreMatch(adapter.InboundContext{
|
||||
Inbound: t.Tag(),
|
||||
InboundType: C.TypeNDIS,
|
||||
Network: network,
|
||||
Source: source,
|
||||
Destination: destination,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||
ctx = log.ContextWithNewID(ctx)
|
||||
var metadata adapter.InboundContext
|
||||
metadata.Inbound = t.Tag()
|
||||
metadata.InboundType = C.TypeNDIS
|
||||
metadata.Source = source
|
||||
metadata.Destination = destination
|
||||
t.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
|
||||
t.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||
done, err := t.tracker.NewConnEx(conn)
|
||||
if err != nil {
|
||||
t.logger.ErrorContext(ctx, E.Cause(err, "track inbound connection"))
|
||||
return
|
||||
}
|
||||
t.router.RouteConnectionEx(ctx, conn, metadata, N.AppendClose(onClose, done))
|
||||
}
|
||||
|
||||
func (t *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||
ctx = log.ContextWithNewID(ctx)
|
||||
var metadata adapter.InboundContext
|
||||
metadata.Inbound = t.Tag()
|
||||
metadata.InboundType = C.TypeNDIS
|
||||
metadata.Source = source
|
||||
metadata.Destination = destination
|
||||
t.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source)
|
||||
t.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
|
||||
done, err := t.tracker.NewPacketConnEx(conn)
|
||||
if err != nil {
|
||||
t.logger.ErrorContext(ctx, E.Cause(err, "track inbound connection"))
|
||||
return
|
||||
}
|
||||
t.router.RoutePacketConnectionEx(ctx, conn, metadata, N.AppendClose(onClose, done))
|
||||
}
|
||||
|
||||
func (t *Inbound) updateRouteAddressSet(it adapter.RuleSet) {
|
||||
t.stack.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet)
|
||||
t.stack.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet)
|
||||
}
|
267
protocol/ndis/stack.go
Normal file
267
protocol/ndis/stack.go
Normal file
|
@ -0,0 +1,267 @@
|
|||
//go:build windows
|
||||
|
||||
package ndis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
"github.com/wiresock/ndisapi-go"
|
||||
"github.com/wiresock/ndisapi-go/driver"
|
||||
"go4.org/netipx"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
type Stack struct {
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
network adapter.NetworkManager
|
||||
trackerIn conntrack.Tracker
|
||||
trackerOut conntrack.Tracker
|
||||
api *ndisapi.NdisApi
|
||||
handler tun.Handler
|
||||
udpTimeout time.Duration
|
||||
filter *driver.QueuedPacketFilter
|
||||
stack *stack.Stack
|
||||
endpoint *ndisEndpoint
|
||||
routeAddress []netip.Prefix
|
||||
routeExcludeAddress []netip.Prefix
|
||||
routeAddressSet []*netipx.IPSet
|
||||
routeExcludeAddressSet []*netipx.IPSet
|
||||
currentInterface *control.Interface
|
||||
}
|
||||
|
||||
func (s *Stack) Start() error {
|
||||
err := s.start(s.network.InterfaceMonitor().DefaultInterface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.network.InterfaceMonitor().RegisterCallback(s.updateDefaultInterface)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Stack) updateDefaultInterface(defaultInterface *control.Interface, flags int) {
|
||||
if s.currentInterface.Equals(*defaultInterface) {
|
||||
return
|
||||
}
|
||||
err := s.start(defaultInterface)
|
||||
if err != nil {
|
||||
s.logger.Error(E.Cause(err, "reconfigure NDIS at: ", defaultInterface.Name))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stack) start(defaultInterface *control.Interface) error {
|
||||
_ = s.Close()
|
||||
adapters, err := s.api.GetTcpipBoundAdaptersInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if defaultInterface != nil {
|
||||
for index := 0; index < int(adapters.AdapterCount); index++ {
|
||||
name := s.api.ConvertWindows2000AdapterName(string(adapters.AdapterNameList[index][:]))
|
||||
if name != defaultInterface.Name {
|
||||
continue
|
||||
}
|
||||
s.filter, err = driver.NewQueuedPacketFilter(s.api, adapters, nil, s.processOut)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
address := tcpip.LinkAddress(adapters.CurrentAddress[index][:])
|
||||
mtu := uint32(adapters.MTU[index])
|
||||
endpoint := &ndisEndpoint{
|
||||
filter: s.filter,
|
||||
mtu: mtu,
|
||||
address: address,
|
||||
}
|
||||
s.stack, err = tun.NewGVisorStack(endpoint)
|
||||
if err != nil {
|
||||
s.filter = nil
|
||||
return err
|
||||
}
|
||||
s.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(s.ctx, s.stack, s.handler).HandlePacket)
|
||||
s.stack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(s.ctx, s.stack, s.handler, s.udpTimeout).HandlePacket)
|
||||
err = s.filter.StartFilter(index)
|
||||
if err != nil {
|
||||
s.filter = nil
|
||||
s.stack.Close()
|
||||
s.stack = nil
|
||||
return err
|
||||
}
|
||||
s.endpoint = endpoint
|
||||
s.logger.Info("started at ", defaultInterface.Name)
|
||||
break
|
||||
}
|
||||
}
|
||||
s.currentInterface = defaultInterface
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Stack) Close() error {
|
||||
if s.filter != nil {
|
||||
s.filter.StopFilter()
|
||||
s.filter.Close()
|
||||
s.filter = nil
|
||||
}
|
||||
if s.stack != nil {
|
||||
s.stack.Close()
|
||||
for _, endpoint := range s.stack.CleanupEndpoints() {
|
||||
endpoint.Abort()
|
||||
}
|
||||
s.stack = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Stack) processOut(handle ndisapi.Handle, packet *ndisapi.IntermediateBuffer) ndisapi.FilterAction {
|
||||
if packet.Length < header.EthernetMinimumSize {
|
||||
return ndisapi.FilterActionPass
|
||||
}
|
||||
if s.endpoint.dispatcher == nil || s.filterPacket(packet.Buffer[:packet.Length]) {
|
||||
return ndisapi.FilterActionPass
|
||||
}
|
||||
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet.Buffer[:packet.Length]),
|
||||
})
|
||||
_, ok := packetBuffer.LinkHeader().Consume(header.EthernetMinimumSize)
|
||||
if !ok {
|
||||
packetBuffer.DecRef()
|
||||
return ndisapi.FilterActionPass
|
||||
}
|
||||
ethHdr := header.Ethernet(packetBuffer.LinkHeader().Slice())
|
||||
destinationAddress := ethHdr.DestinationAddress()
|
||||
if destinationAddress == header.EthernetBroadcastAddress {
|
||||
packetBuffer.PktType = tcpip.PacketBroadcast
|
||||
} else if header.IsMulticastEthernetAddress(destinationAddress) {
|
||||
packetBuffer.PktType = tcpip.PacketMulticast
|
||||
} else if destinationAddress == s.endpoint.address {
|
||||
packetBuffer.PktType = tcpip.PacketHost
|
||||
} else {
|
||||
packetBuffer.PktType = tcpip.PacketOtherHost
|
||||
}
|
||||
s.endpoint.dispatcher.DeliverNetworkPacket(ethHdr.Type(), packetBuffer)
|
||||
packetBuffer.DecRef()
|
||||
return ndisapi.FilterActionDrop
|
||||
}
|
||||
|
||||
func (s *Stack) filterPacket(packet []byte) bool {
|
||||
var ipHdr header.Network
|
||||
switch header.IPVersion(packet[header.EthernetMinimumSize:]) {
|
||||
case ipv4.Version:
|
||||
ipHdr = header.IPv4(packet[header.EthernetMinimumSize:])
|
||||
case ipv6.Version:
|
||||
ipHdr = header.IPv6(packet[header.EthernetMinimumSize:])
|
||||
default:
|
||||
return true
|
||||
}
|
||||
sourceAddr := tun.AddrFromAddress(ipHdr.SourceAddress())
|
||||
destinationAddr := tun.AddrFromAddress(ipHdr.DestinationAddress())
|
||||
if !destinationAddr.IsGlobalUnicast() {
|
||||
return true
|
||||
}
|
||||
var (
|
||||
transportProtocol tcpip.TransportProtocolNumber
|
||||
transportHdr header.Transport
|
||||
)
|
||||
switch ipHdr.TransportProtocol() {
|
||||
case tcp.ProtocolNumber:
|
||||
transportProtocol = header.TCPProtocolNumber
|
||||
transportHdr = header.TCP(ipHdr.Payload())
|
||||
case udp.ProtocolNumber:
|
||||
transportProtocol = header.UDPProtocolNumber
|
||||
transportHdr = header.UDP(ipHdr.Payload())
|
||||
default:
|
||||
return false
|
||||
}
|
||||
source := netip.AddrPortFrom(sourceAddr, transportHdr.SourcePort())
|
||||
destination := netip.AddrPortFrom(destinationAddr, transportHdr.DestinationPort())
|
||||
if transportProtocol == header.TCPProtocolNumber {
|
||||
if s.trackerIn.CheckConn(source, destination) {
|
||||
if debug.Enabled {
|
||||
s.logger.Trace("fall exists TCP ", source, " ", destination)
|
||||
}
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if s.trackerIn.CheckPacketConn(source) {
|
||||
if debug.Enabled {
|
||||
s.logger.Trace("fall exists UDP ", source, " ", destination)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(s.routeAddress) > 0 {
|
||||
var match bool
|
||||
for _, route := range s.routeAddress {
|
||||
if route.Contains(destinationAddr) {
|
||||
match = true
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(s.routeAddressSet) > 0 {
|
||||
var match bool
|
||||
for _, ipSet := range s.routeAddressSet {
|
||||
if ipSet.Contains(destinationAddr) {
|
||||
match = true
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(s.routeExcludeAddress) > 0 {
|
||||
for _, address := range s.routeExcludeAddress {
|
||||
if address.Contains(destinationAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(s.routeExcludeAddressSet) > 0 {
|
||||
for _, ipSet := range s.routeAddressSet {
|
||||
if ipSet.Contains(destinationAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.trackerOut.CheckDestination(destination) {
|
||||
if debug.Enabled {
|
||||
s.logger.Trace("passing pending ", source, " ", destination)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if transportProtocol == header.TCPProtocolNumber {
|
||||
if s.trackerOut.CheckConn(source, destination) {
|
||||
if debug.Enabled {
|
||||
s.logger.Trace("passing TCP ", source, " ", destination)
|
||||
}
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if s.trackerOut.CheckPacketConn(source) {
|
||||
if debug.Enabled {
|
||||
s.logger.Trace("passing UDP ", source, " ", destination)
|
||||
}
|
||||
}
|
||||
}
|
||||
if debug.Enabled {
|
||||
s.logger.Trace("fall ", source, " ", destination)
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -306,7 +306,6 @@ func (t *Inbound) Start(stage adapter.StartStage) error {
|
|||
t.tunOptions.Name = tun.CalculateInterfaceName("")
|
||||
}
|
||||
if t.platformInterface == nil || runtime.GOOS != "android" {
|
||||
t.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet)
|
||||
for _, routeRuleSet := range t.routeRuleSet {
|
||||
ipSets := routeRuleSet.ExtractIPSet()
|
||||
if len(ipSets) == 0 {
|
||||
|
@ -316,11 +315,10 @@ func (t *Inbound) Start(stage adapter.StartStage) error {
|
|||
routeRuleSet.DecRef()
|
||||
t.routeAddressSet = append(t.routeAddressSet, ipSets...)
|
||||
}
|
||||
t.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet)
|
||||
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
|
||||
ipSets := routeExcludeRuleSet.ExtractIPSet()
|
||||
if len(ipSets) == 0 {
|
||||
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
|
||||
t.logger.Warn("route_exclude_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
|
||||
}
|
||||
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
||||
routeExcludeRuleSet.DecRef()
|
||||
|
|
|
@ -35,6 +35,7 @@ var _ adapter.NetworkManager = (*NetworkManager)(nil)
|
|||
|
||||
type NetworkManager struct {
|
||||
logger logger.ContextLogger
|
||||
tracker conntrack.Tracker
|
||||
interfaceFinder *control.DefaultInterfaceFinder
|
||||
networkInterfaces atomic.TypedValue[[]adapter.NetworkInterface]
|
||||
|
||||
|
@ -57,6 +58,7 @@ type NetworkManager struct {
|
|||
func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOptions option.RouteOptions) (*NetworkManager, error) {
|
||||
nm := &NetworkManager{
|
||||
logger: logger,
|
||||
tracker: service.FromContext[conntrack.Tracker](ctx),
|
||||
interfaceFinder: control.NewDefaultInterfaceFinder(),
|
||||
autoDetectInterface: routeOptions.AutoDetectInterface,
|
||||
defaultOptions: adapter.NetworkOptions{
|
||||
|
@ -355,7 +357,7 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
|
|||
}
|
||||
|
||||
func (r *NetworkManager) ResetNetwork() {
|
||||
conntrack.Close()
|
||||
r.tracker.Close()
|
||||
|
||||
for _, endpoint := range r.endpoint.Endpoints() {
|
||||
listener, isListener := endpoint.(adapter.InterfaceUpdateListener)
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
"github.com/sagernet/sing-box/common/process"
|
||||
"github.com/sagernet/sing-box/common/sniff"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
|
@ -72,7 +71,10 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
|
|||
injectable.NewConnectionEx(ctx, conn, metadata, onClose)
|
||||
return nil
|
||||
}
|
||||
conntrack.KillerCheck()
|
||||
err := r.connTracker.KillerCheck()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metadata.Network = N.NetworkTCP
|
||||
switch metadata.Destination.Fqdn {
|
||||
case mux.Destination.Fqdn:
|
||||
|
@ -190,7 +192,10 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
|
|||
injectable.NewPacketConnectionEx(ctx, conn, metadata, onClose)
|
||||
return nil
|
||||
}
|
||||
conntrack.KillerCheck()
|
||||
err := r.connTracker.KillerCheck()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: move to UoT
|
||||
metadata.Network = N.NetworkUDP
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/geoip"
|
||||
"github.com/sagernet/sing-box/common/geosite"
|
||||
|
@ -38,6 +39,7 @@ type Router struct {
|
|||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
dnsLogger log.ContextLogger
|
||||
connTracker conntrack.Tracker
|
||||
inbound adapter.InboundManager
|
||||
outbound adapter.OutboundManager
|
||||
connection adapter.ConnectionManager
|
||||
|
@ -75,6 +77,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
|
|||
ctx: ctx,
|
||||
logger: logFactory.NewLogger("router"),
|
||||
dnsLogger: logFactory.NewLogger("dns"),
|
||||
connTracker: service.FromContext[conntrack.Tracker](ctx),
|
||||
inbound: service.FromContext[adapter.InboundManager](ctx),
|
||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||
connection: service.FromContext[adapter.ConnectionManager](ctx),
|
||||
|
|
Loading…
Reference in a new issue