mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-01-30 04:36:54 +00:00
Add NDIS inbound
This commit is contained in:
parent
e483c909b4
commit
79d3649a8b
7
box.go
7
box.go
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/sagernet/sing-box/adapter/endpoint"
|
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||||
"github.com/sagernet/sing-box/adapter/inbound"
|
"github.com/sagernet/sing-box/adapter/inbound"
|
||||||
"github.com/sagernet/sing-box/adapter/outbound"
|
"github.com/sagernet/sing-box/adapter/outbound"
|
||||||
|
"github.com/sagernet/sing-box/common/conntrack"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
"github.com/sagernet/sing-box/common/taskmonitor"
|
"github.com/sagernet/sing-box/common/taskmonitor"
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
@ -84,7 +85,6 @@ func New(options Options) (*Box, error) {
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
}
|
}
|
||||||
ctx = service.ContextWithDefaultRegistry(ctx)
|
ctx = service.ContextWithDefaultRegistry(ctx)
|
||||||
|
|
||||||
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
|
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
|
||||||
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
|
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
|
||||||
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
||||||
|
@ -101,7 +101,10 @@ func New(options Options) (*Box, error) {
|
||||||
|
|
||||||
ctx = pause.WithDefaultManager(ctx)
|
ctx = pause.WithDefaultManager(ctx)
|
||||||
experimentalOptions := common.PtrValueOrDefault(options.Experimental)
|
experimentalOptions := common.PtrValueOrDefault(options.Experimental)
|
||||||
applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug))
|
debugOptions := common.PtrValueOrDefault(experimentalOptions.Debug)
|
||||||
|
applyDebugOptions(debugOptions)
|
||||||
|
ctx = conntrack.ContextWithDefaultTracker(ctx, debugOptions.OOMKiller, uint64(debugOptions.MemoryLimit))
|
||||||
|
|
||||||
var needCacheFile bool
|
var needCacheFile bool
|
||||||
var needClashAPI bool
|
var needClashAPI bool
|
||||||
var needV2RayAPI bool
|
var needV2RayAPI bool
|
||||||
|
|
|
@ -1,54 +0,0 @@
|
||||||
package conntrack
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/x/list"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Conn struct {
|
|
||||||
net.Conn
|
|
||||||
element *list.Element[io.Closer]
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConn(conn net.Conn) (net.Conn, error) {
|
|
||||||
connAccess.Lock()
|
|
||||||
element := openConnection.PushBack(conn)
|
|
||||||
connAccess.Unlock()
|
|
||||||
if KillerEnabled {
|
|
||||||
err := KillerCheck()
|
|
||||||
if err != nil {
|
|
||||||
conn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &Conn{
|
|
||||||
Conn: conn,
|
|
||||||
element: element,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
|
||||||
if c.element.Value != nil {
|
|
||||||
connAccess.Lock()
|
|
||||||
if c.element.Value != nil {
|
|
||||||
openConnection.Remove(c.element)
|
|
||||||
c.element.Value = nil
|
|
||||||
}
|
|
||||||
connAccess.Unlock()
|
|
||||||
}
|
|
||||||
return c.Conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) Upstream() any {
|
|
||||||
return c.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) ReaderReplaceable() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) WriterReplaceable() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
14
common/conntrack/context.go
Normal file
14
common/conntrack/context.go
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ContextWithDefaultTracker(ctx context.Context, killerEnabled bool, memoryLimit uint64) context.Context {
|
||||||
|
if service.FromContext[Tracker](ctx) != nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return service.ContextWith[Tracker](ctx, NewDefaultTracker(killerEnabled, memoryLimit))
|
||||||
|
}
|
245
common/conntrack/default.go
Normal file
245
common/conntrack/default.go
Normal file
|
@ -0,0 +1,245 @@
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
runtimeDebug "runtime/debug"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/memory"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/common/x/list"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ Tracker = (*DefaultTracker)(nil)
|
||||||
|
|
||||||
|
type DefaultTracker struct {
|
||||||
|
connAccess sync.RWMutex
|
||||||
|
connList list.List[net.Conn]
|
||||||
|
connAddress map[netip.AddrPort]netip.AddrPort
|
||||||
|
|
||||||
|
packetConnAccess sync.RWMutex
|
||||||
|
packetConnList list.List[AbstractPacketConn]
|
||||||
|
packetConnAddress map[netip.AddrPort]bool
|
||||||
|
|
||||||
|
pendingAccess sync.RWMutex
|
||||||
|
pendingList list.List[netip.AddrPort]
|
||||||
|
|
||||||
|
killerEnabled bool
|
||||||
|
memoryLimit uint64
|
||||||
|
killerLastCheck time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDefaultTracker(killerEnabled bool, memoryLimit uint64) *DefaultTracker {
|
||||||
|
return &DefaultTracker{
|
||||||
|
connAddress: make(map[netip.AddrPort]netip.AddrPort),
|
||||||
|
packetConnAddress: make(map[netip.AddrPort]bool),
|
||||||
|
killerEnabled: killerEnabled,
|
||||||
|
memoryLimit: memoryLimit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) NewConn(conn net.Conn) (net.Conn, error) {
|
||||||
|
err := t.KillerCheck()
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.connAccess.Lock()
|
||||||
|
element := t.connList.PushBack(conn)
|
||||||
|
t.connAddress[M.AddrPortFromNet(conn.LocalAddr())] = M.AddrPortFromNet(conn.RemoteAddr())
|
||||||
|
t.connAccess.Unlock()
|
||||||
|
return &Conn{
|
||||||
|
Conn: conn,
|
||||||
|
closeFunc: common.OnceFunc(func() {
|
||||||
|
t.removeConn(element)
|
||||||
|
}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) NewConnEx(conn net.Conn) (N.CloseHandlerFunc, error) {
|
||||||
|
err := t.KillerCheck()
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.connAccess.Lock()
|
||||||
|
element := t.connList.PushBack(conn)
|
||||||
|
t.connAddress[M.AddrPortFromNet(conn.LocalAddr())] = M.AddrPortFromNet(conn.RemoteAddr())
|
||||||
|
t.connAccess.Unlock()
|
||||||
|
return N.OnceClose(func(it error) {
|
||||||
|
t.removeConn(element)
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) NewPacketConn(conn net.PacketConn) (net.PacketConn, error) {
|
||||||
|
err := t.KillerCheck()
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.packetConnAccess.Lock()
|
||||||
|
element := t.packetConnList.PushBack(conn)
|
||||||
|
t.packetConnAddress[M.AddrPortFromNet(conn.LocalAddr())] = true
|
||||||
|
t.packetConnAccess.Unlock()
|
||||||
|
return &PacketConn{
|
||||||
|
PacketConn: conn,
|
||||||
|
closeFunc: common.OnceFunc(func() {
|
||||||
|
t.removePacketConn(element)
|
||||||
|
}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) NewPacketConnEx(conn AbstractPacketConn) (N.CloseHandlerFunc, error) {
|
||||||
|
err := t.KillerCheck()
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.packetConnAccess.Lock()
|
||||||
|
element := t.packetConnList.PushBack(conn)
|
||||||
|
t.packetConnAddress[M.AddrPortFromNet(conn.LocalAddr())] = true
|
||||||
|
t.packetConnAccess.Unlock()
|
||||||
|
return N.OnceClose(func(it error) {
|
||||||
|
t.removePacketConn(element)
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) CheckConn(source netip.AddrPort, destination netip.AddrPort) bool {
|
||||||
|
t.connAccess.RLock()
|
||||||
|
defer t.connAccess.RUnlock()
|
||||||
|
return t.connAddress[source] == destination
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) CheckPacketConn(source netip.AddrPort) bool {
|
||||||
|
t.packetConnAccess.RLock()
|
||||||
|
defer t.packetConnAccess.RUnlock()
|
||||||
|
return t.packetConnAddress[source]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) AddPendingDestination(destination netip.AddrPort) func() {
|
||||||
|
t.pendingAccess.Lock()
|
||||||
|
defer t.pendingAccess.Unlock()
|
||||||
|
element := t.pendingList.PushBack(destination)
|
||||||
|
return func() {
|
||||||
|
t.pendingAccess.Lock()
|
||||||
|
defer t.pendingAccess.Unlock()
|
||||||
|
t.pendingList.Remove(element)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) CheckDestination(destination netip.AddrPort) bool {
|
||||||
|
t.pendingAccess.RLock()
|
||||||
|
defer t.pendingAccess.RUnlock()
|
||||||
|
for element := t.pendingList.Front(); element != nil; element = element.Next() {
|
||||||
|
if element.Value == destination {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) KillerCheck() error {
|
||||||
|
if !t.killerEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
nowTime := time.Now()
|
||||||
|
if nowTime.Sub(t.killerLastCheck) < 3*time.Second {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
t.killerLastCheck = nowTime
|
||||||
|
if memory.Total() > t.memoryLimit {
|
||||||
|
t.Close()
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
runtimeDebug.FreeOSMemory()
|
||||||
|
}()
|
||||||
|
return E.New("out of memory")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) Count() int {
|
||||||
|
t.connAccess.RLock()
|
||||||
|
defer t.connAccess.RUnlock()
|
||||||
|
t.packetConnAccess.RLock()
|
||||||
|
defer t.packetConnAccess.RUnlock()
|
||||||
|
return t.connList.Len() + t.packetConnList.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) Close() {
|
||||||
|
t.connAccess.Lock()
|
||||||
|
for element := t.connList.Front(); element != nil; element = element.Next() {
|
||||||
|
element.Value.Close()
|
||||||
|
}
|
||||||
|
t.connList.Init()
|
||||||
|
t.connAccess.Unlock()
|
||||||
|
t.packetConnAccess.Lock()
|
||||||
|
for element := t.packetConnList.Front(); element != nil; element = element.Next() {
|
||||||
|
element.Value.Close()
|
||||||
|
}
|
||||||
|
t.packetConnList.Init()
|
||||||
|
t.packetConnAccess.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) removeConn(element *list.Element[net.Conn]) {
|
||||||
|
t.connAccess.Lock()
|
||||||
|
defer t.connAccess.Unlock()
|
||||||
|
delete(t.connAddress, M.AddrPortFromNet(element.Value.LocalAddr()))
|
||||||
|
t.connList.Remove(element)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DefaultTracker) removePacketConn(element *list.Element[AbstractPacketConn]) {
|
||||||
|
t.packetConnAccess.Lock()
|
||||||
|
defer t.packetConnAccess.Unlock()
|
||||||
|
delete(t.packetConnAddress, M.AddrPortFromNet(element.Value.LocalAddr()))
|
||||||
|
t.packetConnList.Remove(element)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
closeFunc func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
c.closeFunc()
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Upstream() any {
|
||||||
|
return c.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ReaderReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) WriterReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketConn struct {
|
||||||
|
net.PacketConn
|
||||||
|
closeFunc func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) Close() error {
|
||||||
|
c.closeFunc()
|
||||||
|
return c.PacketConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) Upstream() any {
|
||||||
|
return c.PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) ReaderReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) WriterReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
|
@ -1,35 +0,0 @@
|
||||||
package conntrack
|
|
||||||
|
|
||||||
import (
|
|
||||||
runtimeDebug "runtime/debug"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
"github.com/sagernet/sing/common/memory"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
KillerEnabled bool
|
|
||||||
MemoryLimit uint64
|
|
||||||
killerLastCheck time.Time
|
|
||||||
)
|
|
||||||
|
|
||||||
func KillerCheck() error {
|
|
||||||
if !KillerEnabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
nowTime := time.Now()
|
|
||||||
if nowTime.Sub(killerLastCheck) < 3*time.Second {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
killerLastCheck = nowTime
|
|
||||||
if memory.Total() > MemoryLimit {
|
|
||||||
Close()
|
|
||||||
go func() {
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
runtimeDebug.FreeOSMemory()
|
|
||||||
}()
|
|
||||||
return E.New("out of memory")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,55 +0,0 @@
|
||||||
package conntrack
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/bufio"
|
|
||||||
"github.com/sagernet/sing/common/x/list"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PacketConn struct {
|
|
||||||
net.PacketConn
|
|
||||||
element *list.Element[io.Closer]
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPacketConn(conn net.PacketConn) (net.PacketConn, error) {
|
|
||||||
connAccess.Lock()
|
|
||||||
element := openConnection.PushBack(conn)
|
|
||||||
connAccess.Unlock()
|
|
||||||
if KillerEnabled {
|
|
||||||
err := KillerCheck()
|
|
||||||
if err != nil {
|
|
||||||
conn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &PacketConn{
|
|
||||||
PacketConn: conn,
|
|
||||||
element: element,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PacketConn) Close() error {
|
|
||||||
if c.element.Value != nil {
|
|
||||||
connAccess.Lock()
|
|
||||||
if c.element.Value != nil {
|
|
||||||
openConnection.Remove(c.element)
|
|
||||||
c.element.Value = nil
|
|
||||||
}
|
|
||||||
connAccess.Unlock()
|
|
||||||
}
|
|
||||||
return c.PacketConn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PacketConn) Upstream() any {
|
|
||||||
return bufio.NewPacketConn(c.PacketConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PacketConn) ReaderReplaceable() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PacketConn) WriterReplaceable() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
|
@ -1,47 +0,0 @@
|
||||||
package conntrack
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/x/list"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
connAccess sync.RWMutex
|
|
||||||
openConnection list.List[io.Closer]
|
|
||||||
)
|
|
||||||
|
|
||||||
func Count() int {
|
|
||||||
if !Enabled {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return openConnection.Len()
|
|
||||||
}
|
|
||||||
|
|
||||||
func List() []io.Closer {
|
|
||||||
if !Enabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
connAccess.RLock()
|
|
||||||
defer connAccess.RUnlock()
|
|
||||||
connList := make([]io.Closer, 0, openConnection.Len())
|
|
||||||
for element := openConnection.Front(); element != nil; element = element.Next() {
|
|
||||||
connList = append(connList, element.Value)
|
|
||||||
}
|
|
||||||
return connList
|
|
||||||
}
|
|
||||||
|
|
||||||
func Close() {
|
|
||||||
if !Enabled {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
connAccess.Lock()
|
|
||||||
defer connAccess.Unlock()
|
|
||||||
for element := openConnection.Front(); element != nil; element = element.Next() {
|
|
||||||
common.Close(element.Value)
|
|
||||||
element.Value = nil
|
|
||||||
}
|
|
||||||
openConnection.Init()
|
|
||||||
}
|
|
|
@ -1,5 +0,0 @@
|
||||||
//go:build !with_conntrack
|
|
||||||
|
|
||||||
package conntrack
|
|
||||||
|
|
||||||
const Enabled = false
|
|
|
@ -1,5 +0,0 @@
|
||||||
//go:build with_conntrack
|
|
||||||
|
|
||||||
package conntrack
|
|
||||||
|
|
||||||
const Enabled = true
|
|
32
common/conntrack/tracker.go
Normal file
32
common/conntrack/tracker.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: add to N
|
||||||
|
type AbstractPacketConn interface {
|
||||||
|
Close() error
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
SetDeadline(t time.Time) error
|
||||||
|
SetReadDeadline(t time.Time) error
|
||||||
|
SetWriteDeadline(t time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tracker interface {
|
||||||
|
NewConn(conn net.Conn) (net.Conn, error)
|
||||||
|
NewPacketConn(conn net.PacketConn) (net.PacketConn, error)
|
||||||
|
NewConnEx(conn net.Conn) (N.CloseHandlerFunc, error)
|
||||||
|
NewPacketConnEx(conn AbstractPacketConn) (N.CloseHandlerFunc, error)
|
||||||
|
CheckConn(source netip.AddrPort, destination netip.AddrPort) bool
|
||||||
|
CheckPacketConn(source netip.AddrPort) bool
|
||||||
|
AddPendingDestination(destination netip.AddrPort) func()
|
||||||
|
CheckDestination(destination netip.AddrPort) bool
|
||||||
|
KillerCheck() error
|
||||||
|
Count() int
|
||||||
|
Close()
|
||||||
|
}
|
|
@ -28,6 +28,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type DefaultDialer struct {
|
type DefaultDialer struct {
|
||||||
|
tracker conntrack.Tracker
|
||||||
dialer4 tcpDialer
|
dialer4 tcpDialer
|
||||||
dialer6 tcpDialer
|
dialer6 tcpDialer
|
||||||
udpDialer4 net.Dialer
|
udpDialer4 net.Dialer
|
||||||
|
@ -46,6 +47,7 @@ type DefaultDialer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDialer, error) {
|
func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDialer, error) {
|
||||||
|
tracker := service.FromContext[conntrack.Tracker](ctx)
|
||||||
networkManager := service.FromContext[adapter.NetworkManager](ctx)
|
networkManager := service.FromContext[adapter.NetworkManager](ctx)
|
||||||
platformInterface := service.FromContext[platform.Interface](ctx)
|
platformInterface := service.FromContext[platform.Interface](ctx)
|
||||||
|
|
||||||
|
@ -197,6 +199,7 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &DefaultDialer{
|
return &DefaultDialer{
|
||||||
|
tracker: tracker,
|
||||||
dialer4: tcpDialer4,
|
dialer4: tcpDialer4,
|
||||||
dialer6: tcpDialer6,
|
dialer6: tcpDialer6,
|
||||||
udpDialer4: udpDialer4,
|
udpDialer4: udpDialer4,
|
||||||
|
@ -219,18 +222,26 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address
|
||||||
return nil, E.New("invalid address")
|
return nil, E.New("invalid address")
|
||||||
}
|
}
|
||||||
if d.networkStrategy == nil {
|
if d.networkStrategy == nil {
|
||||||
|
if address.IsFqdn() {
|
||||||
|
return nil, E.New("unexpected domain destination")
|
||||||
|
}
|
||||||
|
// Since pending check is only used by ndis, it is not performed for non-windows connections which are only supported on platform clients
|
||||||
|
if d.tracker != nil {
|
||||||
|
done := d.tracker.AddPendingDestination(address.AddrPort())
|
||||||
|
defer done()
|
||||||
|
}
|
||||||
switch N.NetworkName(network) {
|
switch N.NetworkName(network) {
|
||||||
case N.NetworkUDP:
|
case N.NetworkUDP:
|
||||||
if !address.IsIPv6() {
|
if !address.IsIPv6() {
|
||||||
return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
|
return d.trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
|
||||||
} else {
|
} else {
|
||||||
return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
|
return d.trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !address.IsIPv6() {
|
if !address.IsIPv6() {
|
||||||
return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
|
return d.trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
|
||||||
} else {
|
} else {
|
||||||
return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
|
return d.trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
|
return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
|
||||||
|
@ -282,17 +293,17 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin
|
||||||
if !fastFallback && !isPrimary {
|
if !fastFallback && !isPrimary {
|
||||||
d.networkLastFallback.Store(time.Now())
|
d.networkLastFallback.Store(time.Now())
|
||||||
}
|
}
|
||||||
return trackConn(conn, nil)
|
return d.trackConn(conn, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
if d.networkStrategy == nil {
|
if d.networkStrategy == nil {
|
||||||
if destination.IsIPv6() {
|
if destination.IsIPv6() {
|
||||||
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
|
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
|
||||||
} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
|
} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
|
||||||
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
|
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
|
||||||
} else {
|
} else {
|
||||||
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
|
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
|
return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
|
||||||
|
@ -329,23 +340,23 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return trackPacketConn(packetConn, nil)
|
return d.trackPacketConn(packetConn, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
|
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
|
||||||
return d.udpListener.ListenPacket(context.Background(), network, address)
|
return d.udpListener.ListenPacket(context.Background(), network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
func trackConn(conn net.Conn, err error) (net.Conn, error) {
|
func (d *DefaultDialer) trackConn(conn net.Conn, err error) (net.Conn, error) {
|
||||||
if !conntrack.Enabled || err != nil {
|
if d.tracker == nil || err != nil {
|
||||||
return conn, err
|
return conn, err
|
||||||
}
|
}
|
||||||
return conntrack.NewConn(conn)
|
return d.tracker.NewConn(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
|
func (d *DefaultDialer) trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
|
||||||
if !conntrack.Enabled || err != nil {
|
if err != nil {
|
||||||
return conn, err
|
return conn, err
|
||||||
}
|
}
|
||||||
return conntrack.NewPacketConn(conn)
|
return d.tracker.NewPacketConn(conn)
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ const (
|
||||||
TypeVLESS = "vless"
|
TypeVLESS = "vless"
|
||||||
TypeTUIC = "tuic"
|
TypeTUIC = "tuic"
|
||||||
TypeHysteria2 = "hysteria2"
|
TypeHysteria2 = "hysteria2"
|
||||||
|
TypeNDIS = "ndis"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -80,6 +81,8 @@ func ProxyDisplayName(proxyType string) string {
|
||||||
return "Selector"
|
return "Selector"
|
||||||
case TypeURLTest:
|
case TypeURLTest:
|
||||||
return "URLTest"
|
return "URLTest"
|
||||||
|
case TypeNDIS:
|
||||||
|
return "NDIS"
|
||||||
default:
|
default:
|
||||||
return "Unknown"
|
return "Unknown"
|
||||||
}
|
}
|
||||||
|
|
5
debug.go
5
debug.go
|
@ -3,7 +3,6 @@ package box
|
||||||
import (
|
import (
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/common/conntrack"
|
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,9 +25,5 @@ func applyDebugOptions(options option.DebugOptions) {
|
||||||
}
|
}
|
||||||
if options.MemoryLimit != 0 {
|
if options.MemoryLimit != 0 {
|
||||||
debug.SetMemoryLimit(int64(float64(options.MemoryLimit) / 1.5))
|
debug.SetMemoryLimit(int64(float64(options.MemoryLimit) / 1.5))
|
||||||
conntrack.MemoryLimit = uint64(options.MemoryLimit)
|
|
||||||
}
|
|
||||||
if options.OOMKiller != nil {
|
|
||||||
conntrack.KillerEnabled = *options.OOMKiller
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,8 +5,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
runtimeDebug "runtime/debug"
|
runtimeDebug "runtime/debug"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/common/conntrack"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *CommandClient) CloseConnections() error {
|
func (c *CommandClient) CloseConnections() error {
|
||||||
|
@ -19,7 +17,7 @@ func (c *CommandClient) CloseConnections() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *CommandServer) handleCloseConnections(conn net.Conn) error {
|
func (s *CommandServer) handleCloseConnections(conn net.Conn) error {
|
||||||
conntrack.Close()
|
tracker.Close()
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
runtimeDebug.FreeOSMemory()
|
runtimeDebug.FreeOSMemory()
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/common/conntrack"
|
|
||||||
"github.com/sagernet/sing-box/experimental/clashapi"
|
"github.com/sagernet/sing-box/experimental/clashapi"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/memory"
|
"github.com/sagernet/sing/common/memory"
|
||||||
|
@ -28,7 +27,7 @@ func (s *CommandServer) readStatus() StatusMessage {
|
||||||
var message StatusMessage
|
var message StatusMessage
|
||||||
message.Memory = int64(memory.Inuse())
|
message.Memory = int64(memory.Inuse())
|
||||||
message.Goroutines = int32(runtime.NumGoroutine())
|
message.Goroutines = int32(runtime.NumGoroutine())
|
||||||
message.ConnectionsOut = int32(conntrack.Count())
|
message.ConnectionsOut = int32(tracker.Count())
|
||||||
|
|
||||||
if s.service != nil {
|
if s.service != nil {
|
||||||
message.TrafficAvailable = true
|
message.TrafficAvailable = true
|
||||||
|
|
|
@ -7,17 +7,21 @@ import (
|
||||||
"github.com/sagernet/sing-box/common/conntrack"
|
"github.com/sagernet/sing-box/common/conntrack"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var tracker *conntrack.DefaultTracker
|
||||||
|
|
||||||
func SetMemoryLimit(enabled bool) {
|
func SetMemoryLimit(enabled bool) {
|
||||||
|
if tracker != nil {
|
||||||
|
tracker.Close()
|
||||||
|
}
|
||||||
const memoryLimit = 45 * 1024 * 1024
|
const memoryLimit = 45 * 1024 * 1024
|
||||||
const memoryLimitGo = memoryLimit / 1.5
|
const memoryLimitGo = memoryLimit / 1.5
|
||||||
if enabled {
|
if enabled {
|
||||||
runtimeDebug.SetGCPercent(10)
|
runtimeDebug.SetGCPercent(10)
|
||||||
runtimeDebug.SetMemoryLimit(memoryLimitGo)
|
runtimeDebug.SetMemoryLimit(memoryLimitGo)
|
||||||
conntrack.KillerEnabled = true
|
tracker = conntrack.NewDefaultTracker(true, memoryLimit)
|
||||||
conntrack.MemoryLimit = memoryLimit
|
|
||||||
} else {
|
} else {
|
||||||
runtimeDebug.SetGCPercent(100)
|
runtimeDebug.SetGCPercent(100)
|
||||||
runtimeDebug.SetMemoryLimit(math.MaxInt64)
|
runtimeDebug.SetMemoryLimit(math.MaxInt64)
|
||||||
conntrack.KillerEnabled = false
|
tracker = conntrack.NewDefaultTracker(false, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
|
|
||||||
"github.com/sagernet/sing-box"
|
"github.com/sagernet/sing-box"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/conntrack"
|
||||||
"github.com/sagernet/sing-box/common/process"
|
"github.com/sagernet/sing-box/common/process"
|
||||||
"github.com/sagernet/sing-box/common/urltest"
|
"github.com/sagernet/sing-box/common/urltest"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
@ -60,6 +61,7 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box
|
||||||
useProcFS: platformInterface.UseProcFS(),
|
useProcFS: platformInterface.UseProcFS(),
|
||||||
}
|
}
|
||||||
service.MustRegister[platform.Interface](ctx, platformWrapper)
|
service.MustRegister[platform.Interface](ctx, platformWrapper)
|
||||||
|
service.MustRegister[conntrack.Tracker](ctx, tracker)
|
||||||
instance, err := box.New(box.Options{
|
instance, err := box.New(box.Options{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
Options: options,
|
Options: options,
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -41,6 +41,7 @@ require (
|
||||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
|
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
|
||||||
github.com/spf13/cobra v1.8.1
|
github.com/spf13/cobra v1.8.1
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
|
github.com/wiresock/ndisapi-go v0.0.0-20241230094942-3299a7566e08
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||||
golang.org/x/crypto v0.31.0
|
golang.org/x/crypto v0.31.0
|
||||||
|
|
3
go.sum
3
go.sum
|
@ -158,6 +158,8 @@ github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923 h1:tHNk7XK9GkmKUR6Gh8gV
|
||||||
github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264=
|
github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264=
|
||||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||||
|
github.com/wiresock/ndisapi-go v0.0.0-20241230094942-3299a7566e08 h1:is+7xN6CAKtgxt3mDSl9OQNvjfi6LggugSP07QhDtws=
|
||||||
|
github.com/wiresock/ndisapi-go v0.0.0-20241230094942-3299a7566e08/go.mod h1:lFE7JYt3LC2UYJ31mRDwl/K35pbtxDnkSDlXrYzgyqg=
|
||||||
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
|
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
|
||||||
github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
|
github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
|
||||||
github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg=
|
github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg=
|
||||||
|
@ -191,6 +193,7 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
|
12
include/ndis.go
Normal file
12
include/ndis.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
//go:build windows && with_gvisor
|
||||||
|
|
||||||
|
package include
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sagernet/sing-box/adapter/inbound"
|
||||||
|
"github.com/sagernet/sing-box/protocol/ndis"
|
||||||
|
)
|
||||||
|
|
||||||
|
func registerNDISInbound(registry *inbound.Registry) {
|
||||||
|
ndis.RegisterInbound(registry)
|
||||||
|
}
|
20
include/ndis_nongvisor_stub.go
Normal file
20
include/ndis_nongvisor_stub.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
//go:build windows && !with_gvisor
|
||||||
|
|
||||||
|
package include
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/adapter/inbound"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing-tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func registerNDISInbound(registry *inbound.Registry) {
|
||||||
|
inbound.Register[option.NDISInboundOptions](registry, C.TypeNDIS, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NDISInboundOptions) (adapter.Inbound, error) {
|
||||||
|
return nil, tun.ErrGVisorNotIncluded
|
||||||
|
})
|
||||||
|
}
|
20
include/ndis_nonwindows_stub.go
Normal file
20
include/ndis_nonwindows_stub.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package include
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/adapter/inbound"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
func registerNDISInbound(registry *inbound.Registry) {
|
||||||
|
inbound.Register[option.NDISInboundOptions](registry, C.TypeNDIS, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NDISInboundOptions) (adapter.Inbound, error) {
|
||||||
|
return nil, E.New("NDIS is only supported in windows")
|
||||||
|
})
|
||||||
|
}
|
|
@ -51,6 +51,7 @@ func InboundRegistry() *inbound.Registry {
|
||||||
|
|
||||||
registerQUICInbounds(registry)
|
registerQUICInbounds(registry)
|
||||||
registerStubForRemovedInbounds(registry)
|
registerStubForRemovedInbounds(registry)
|
||||||
|
registerNDISInbound(registry)
|
||||||
|
|
||||||
return registry
|
return registry
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ type DebugOptions struct {
|
||||||
PanicOnFault *bool `json:"panic_on_fault,omitempty"`
|
PanicOnFault *bool `json:"panic_on_fault,omitempty"`
|
||||||
TraceBack string `json:"trace_back,omitempty"`
|
TraceBack string `json:"trace_back,omitempty"`
|
||||||
MemoryLimit MemoryBytes `json:"memory_limit,omitempty"`
|
MemoryLimit MemoryBytes `json:"memory_limit,omitempty"`
|
||||||
OOMKiller *bool `json:"oom_killer,omitempty"`
|
OOMKiller bool `json:"oom_killer,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MemoryBytes uint64
|
type MemoryBytes uint64
|
||||||
|
|
17
option/ndis.go
Normal file
17
option/ndis.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
package option
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common/json/badoption"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NDISInboundOptions struct {
|
||||||
|
Network NetworkList `json:"network,omitempty"`
|
||||||
|
RouteAddress badoption.Listable[netip.Prefix] `json:"route_address,omitempty"`
|
||||||
|
RouteAddressSet badoption.Listable[string] `json:"route_address_set,omitempty"`
|
||||||
|
RouteExcludeAddress badoption.Listable[netip.Prefix] `json:"route_exclude_address,omitempty"`
|
||||||
|
RouteExcludeAddressSet badoption.Listable[string] `json:"route_exclude_address_set,omitempty"`
|
||||||
|
InterfaceName string `json:"interface_name,omitempty"`
|
||||||
|
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"`
|
||||||
|
}
|
110
protocol/ndis/endpoint.go
Normal file
110
protocol/ndis/endpoint.go
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package ndis
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sagernet/gvisor/pkg/buffer"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
"github.com/wiresock/ndisapi-go"
|
||||||
|
"github.com/wiresock/ndisapi-go/driver"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ stack.LinkEndpoint = (*ndisEndpoint)(nil)
|
||||||
|
|
||||||
|
type ndisEndpoint struct {
|
||||||
|
filter *driver.QueuedPacketFilter
|
||||||
|
mtu uint32
|
||||||
|
address tcpip.LinkAddress
|
||||||
|
dispatcher stack.NetworkDispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) MTU() uint32 {
|
||||||
|
return e.mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) SetMTU(mtu uint32) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) MaxHeaderLength() uint16 {
|
||||||
|
return header.EthernetMinimumSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) LinkAddress() tcpip.LinkAddress {
|
||||||
|
return e.address
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||||
|
e.dispatcher = dispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) IsAttached() bool {
|
||||||
|
return e.dispatcher != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) Wait() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) ARPHardwareType() header.ARPHardwareType {
|
||||||
|
return header.ARPHardwareEther
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) AddHeader(pkt *stack.PacketBuffer) {
|
||||||
|
eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
|
||||||
|
fields := header.EthernetFields{
|
||||||
|
SrcAddr: pkt.EgressRoute.LocalLinkAddress,
|
||||||
|
DstAddr: pkt.EgressRoute.RemoteLinkAddress,
|
||||||
|
Type: pkt.NetworkProtocolNumber,
|
||||||
|
}
|
||||||
|
eth.Encode(&fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) ParseHeader(pkt *stack.PacketBuffer) bool {
|
||||||
|
_, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) Close() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) SetOnCloseAction(f func()) {
|
||||||
|
}
|
||||||
|
|
||||||
|
var bufferPool = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return new(ndisapi.IntermediateBuffer)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ndisEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
|
||||||
|
for _, packetBuffer := range list.AsSlice() {
|
||||||
|
ndisBuf := bufferPool.Get().(*ndisapi.IntermediateBuffer)
|
||||||
|
viewList, offset := packetBuffer.AsViewList()
|
||||||
|
var view *buffer.View
|
||||||
|
for view = viewList.Front(); view != nil && offset >= view.Size(); view = view.Next() {
|
||||||
|
offset -= view.Size()
|
||||||
|
}
|
||||||
|
index := copy(ndisBuf.Buffer[:], view.AsSlice()[offset:])
|
||||||
|
for view = view.Next(); view != nil; view = view.Next() {
|
||||||
|
index += copy(ndisBuf.Buffer[index:], view.AsSlice())
|
||||||
|
}
|
||||||
|
ndisBuf.Length = uint32(index)
|
||||||
|
err := e.filter.InsertPacketToMstcp(ndisBuf)
|
||||||
|
bufferPool.Put(ndisBuf)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &tcpip.ErrAborted{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return list.Len(), nil
|
||||||
|
}
|
203
protocol/ndis/inbound.go
Normal file
203
protocol/ndis/inbound.go
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package ndis
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/adapter/inbound"
|
||||||
|
"github.com/sagernet/sing-box/common/conntrack"
|
||||||
|
"github.com/sagernet/sing-box/common/taskmonitor"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/common/x/list"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
|
||||||
|
"github.com/wiresock/ndisapi-go"
|
||||||
|
"go4.org/netipx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RegisterInbound(registry *inbound.Registry) {
|
||||||
|
inbound.Register[option.NDISInboundOptions](registry, C.TypeNDIS, NewInbound)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Inbound struct {
|
||||||
|
inbound.Adapter
|
||||||
|
ctx context.Context
|
||||||
|
router adapter.Router
|
||||||
|
logger log.ContextLogger
|
||||||
|
api *ndisapi.NdisApi
|
||||||
|
tracker conntrack.Tracker
|
||||||
|
routeAddress []netip.Prefix
|
||||||
|
routeExcludeAddress []netip.Prefix
|
||||||
|
routeRuleSet []adapter.RuleSet
|
||||||
|
routeRuleSetCallback []*list.Element[adapter.RuleSetUpdateCallback]
|
||||||
|
routeExcludeRuleSet []adapter.RuleSet
|
||||||
|
routeExcludeRuleSetCallback []*list.Element[adapter.RuleSetUpdateCallback]
|
||||||
|
stack *Stack
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NDISInboundOptions) (adapter.Inbound, error) {
|
||||||
|
api, err := ndisapi.NewNdisApi()
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "create NDIS API")
|
||||||
|
}
|
||||||
|
//if !api.IsDriverLoaded() {
|
||||||
|
// return nil, E.New("missing NDIS driver")
|
||||||
|
//}
|
||||||
|
networkManager := service.FromContext[adapter.NetworkManager](ctx)
|
||||||
|
trackerOut := service.FromContext[conntrack.Tracker](ctx)
|
||||||
|
var udpTimeout time.Duration
|
||||||
|
if options.UDPTimeout != 0 {
|
||||||
|
udpTimeout = time.Duration(options.UDPTimeout)
|
||||||
|
} else {
|
||||||
|
udpTimeout = C.UDPTimeout
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
routeRuleSet []adapter.RuleSet
|
||||||
|
routeExcludeRuleSet []adapter.RuleSet
|
||||||
|
)
|
||||||
|
for _, routeAddressSet := range options.RouteAddressSet {
|
||||||
|
ruleSet, loaded := router.RuleSet(routeAddressSet)
|
||||||
|
if !loaded {
|
||||||
|
return nil, E.New("parse route_address_set: rule-set not found: ", routeAddressSet)
|
||||||
|
}
|
||||||
|
ruleSet.IncRef()
|
||||||
|
routeRuleSet = append(routeRuleSet, ruleSet)
|
||||||
|
}
|
||||||
|
for _, routeExcludeAddressSet := range options.RouteExcludeAddressSet {
|
||||||
|
ruleSet, loaded := router.RuleSet(routeExcludeAddressSet)
|
||||||
|
if !loaded {
|
||||||
|
return nil, E.New("parse route_exclude_address_set: rule-set not found: ", routeExcludeAddressSet)
|
||||||
|
}
|
||||||
|
ruleSet.IncRef()
|
||||||
|
routeExcludeRuleSet = append(routeExcludeRuleSet, ruleSet)
|
||||||
|
}
|
||||||
|
trackerIn := conntrack.NewDefaultTracker(false, 0)
|
||||||
|
return &Inbound{
|
||||||
|
Adapter: inbound.NewAdapter(C.TypeNDIS, tag),
|
||||||
|
ctx: ctx,
|
||||||
|
router: router,
|
||||||
|
logger: logger,
|
||||||
|
api: api,
|
||||||
|
tracker: trackerIn,
|
||||||
|
routeRuleSet: routeRuleSet,
|
||||||
|
routeExcludeRuleSet: routeExcludeRuleSet,
|
||||||
|
stack: &Stack{
|
||||||
|
ctx: ctx,
|
||||||
|
logger: logger,
|
||||||
|
network: networkManager,
|
||||||
|
trackerIn: trackerIn,
|
||||||
|
trackerOut: trackerOut,
|
||||||
|
api: api,
|
||||||
|
udpTimeout: udpTimeout,
|
||||||
|
routeAddress: options.RouteAddress,
|
||||||
|
routeExcludeAddress: options.RouteExcludeAddress,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
switch stage {
|
||||||
|
case adapter.StartStateStart:
|
||||||
|
monitor := taskmonitor.New(t.logger, C.StartTimeout)
|
||||||
|
var (
|
||||||
|
routeAddressSet []*netipx.IPSet
|
||||||
|
routeExcludeAddressSet []*netipx.IPSet
|
||||||
|
)
|
||||||
|
for _, routeRuleSet := range t.routeRuleSet {
|
||||||
|
ipSets := routeRuleSet.ExtractIPSet()
|
||||||
|
if len(ipSets) == 0 {
|
||||||
|
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeRuleSet.Name())
|
||||||
|
}
|
||||||
|
t.routeRuleSetCallback = append(t.routeRuleSetCallback, routeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
||||||
|
routeRuleSet.DecRef()
|
||||||
|
routeAddressSet = append(routeAddressSet, ipSets...)
|
||||||
|
}
|
||||||
|
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
|
||||||
|
ipSets := routeExcludeRuleSet.ExtractIPSet()
|
||||||
|
if len(ipSets) == 0 {
|
||||||
|
t.logger.Warn("route_exclude_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
|
||||||
|
}
|
||||||
|
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
||||||
|
routeExcludeRuleSet.DecRef()
|
||||||
|
routeExcludeAddressSet = append(routeExcludeAddressSet, ipSets...)
|
||||||
|
}
|
||||||
|
t.stack.routeAddressSet = routeAddressSet
|
||||||
|
t.stack.routeExcludeAddressSet = routeExcludeAddressSet
|
||||||
|
monitor.Start("starting NDIS stack")
|
||||||
|
t.stack.handler = t
|
||||||
|
err := t.stack.Start()
|
||||||
|
monitor.Finish()
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "starting NDIS stack")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Inbound) Close() error {
|
||||||
|
if t.api != nil {
|
||||||
|
t.stack.Close()
|
||||||
|
t.api.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Inbound) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
|
||||||
|
return t.router.PreMatch(adapter.InboundContext{
|
||||||
|
Inbound: t.Tag(),
|
||||||
|
InboundType: C.TypeNDIS,
|
||||||
|
Network: network,
|
||||||
|
Source: source,
|
||||||
|
Destination: destination,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||||
|
ctx = log.ContextWithNewID(ctx)
|
||||||
|
var metadata adapter.InboundContext
|
||||||
|
metadata.Inbound = t.Tag()
|
||||||
|
metadata.InboundType = C.TypeNDIS
|
||||||
|
metadata.Source = source
|
||||||
|
metadata.Destination = destination
|
||||||
|
t.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
|
||||||
|
t.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||||
|
done, err := t.tracker.NewConnEx(conn)
|
||||||
|
if err != nil {
|
||||||
|
t.logger.ErrorContext(ctx, E.Cause(err, "track inbound connection"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.router.RouteConnectionEx(ctx, conn, metadata, N.AppendClose(onClose, done))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||||
|
ctx = log.ContextWithNewID(ctx)
|
||||||
|
var metadata adapter.InboundContext
|
||||||
|
metadata.Inbound = t.Tag()
|
||||||
|
metadata.InboundType = C.TypeNDIS
|
||||||
|
metadata.Source = source
|
||||||
|
metadata.Destination = destination
|
||||||
|
t.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source)
|
||||||
|
t.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
|
||||||
|
done, err := t.tracker.NewPacketConnEx(conn)
|
||||||
|
if err != nil {
|
||||||
|
t.logger.ErrorContext(ctx, E.Cause(err, "track inbound connection"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.router.RoutePacketConnectionEx(ctx, conn, metadata, N.AppendClose(onClose, done))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Inbound) updateRouteAddressSet(it adapter.RuleSet) {
|
||||||
|
t.stack.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet)
|
||||||
|
t.stack.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet)
|
||||||
|
}
|
267
protocol/ndis/stack.go
Normal file
267
protocol/ndis/stack.go
Normal file
|
@ -0,0 +1,267 @@
|
||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package ndis
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/gvisor/pkg/buffer"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/conntrack"
|
||||||
|
"github.com/sagernet/sing-tun"
|
||||||
|
"github.com/sagernet/sing/common/control"
|
||||||
|
"github.com/sagernet/sing/common/debug"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
|
||||||
|
"github.com/wiresock/ndisapi-go"
|
||||||
|
"github.com/wiresock/ndisapi-go/driver"
|
||||||
|
"go4.org/netipx"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Stack struct {
|
||||||
|
ctx context.Context
|
||||||
|
logger logger.ContextLogger
|
||||||
|
network adapter.NetworkManager
|
||||||
|
trackerIn conntrack.Tracker
|
||||||
|
trackerOut conntrack.Tracker
|
||||||
|
api *ndisapi.NdisApi
|
||||||
|
handler tun.Handler
|
||||||
|
udpTimeout time.Duration
|
||||||
|
filter *driver.QueuedPacketFilter
|
||||||
|
stack *stack.Stack
|
||||||
|
endpoint *ndisEndpoint
|
||||||
|
routeAddress []netip.Prefix
|
||||||
|
routeExcludeAddress []netip.Prefix
|
||||||
|
routeAddressSet []*netipx.IPSet
|
||||||
|
routeExcludeAddressSet []*netipx.IPSet
|
||||||
|
currentInterface *control.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stack) Start() error {
|
||||||
|
err := s.start(s.network.InterfaceMonitor().DefaultInterface())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.network.InterfaceMonitor().RegisterCallback(s.updateDefaultInterface)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stack) updateDefaultInterface(defaultInterface *control.Interface, flags int) {
|
||||||
|
if s.currentInterface.Equals(*defaultInterface) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := s.start(defaultInterface)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error(E.Cause(err, "reconfigure NDIS at: ", defaultInterface.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stack) start(defaultInterface *control.Interface) error {
|
||||||
|
_ = s.Close()
|
||||||
|
adapters, err := s.api.GetTcpipBoundAdaptersInfo()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if defaultInterface != nil {
|
||||||
|
for index := 0; index < int(adapters.AdapterCount); index++ {
|
||||||
|
name := s.api.ConvertWindows2000AdapterName(string(adapters.AdapterNameList[index][:]))
|
||||||
|
if name != defaultInterface.Name {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.filter, err = driver.NewQueuedPacketFilter(s.api, adapters, nil, s.processOut)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
address := tcpip.LinkAddress(adapters.CurrentAddress[index][:])
|
||||||
|
mtu := uint32(adapters.MTU[index])
|
||||||
|
endpoint := &ndisEndpoint{
|
||||||
|
filter: s.filter,
|
||||||
|
mtu: mtu,
|
||||||
|
address: address,
|
||||||
|
}
|
||||||
|
s.stack, err = tun.NewGVisorStack(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
s.filter = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(s.ctx, s.stack, s.handler).HandlePacket)
|
||||||
|
s.stack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(s.ctx, s.stack, s.handler, s.udpTimeout).HandlePacket)
|
||||||
|
err = s.filter.StartFilter(index)
|
||||||
|
if err != nil {
|
||||||
|
s.filter = nil
|
||||||
|
s.stack.Close()
|
||||||
|
s.stack = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.endpoint = endpoint
|
||||||
|
s.logger.Info("started at ", defaultInterface.Name)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.currentInterface = defaultInterface
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stack) Close() error {
|
||||||
|
if s.filter != nil {
|
||||||
|
s.filter.StopFilter()
|
||||||
|
s.filter.Close()
|
||||||
|
s.filter = nil
|
||||||
|
}
|
||||||
|
if s.stack != nil {
|
||||||
|
s.stack.Close()
|
||||||
|
for _, endpoint := range s.stack.CleanupEndpoints() {
|
||||||
|
endpoint.Abort()
|
||||||
|
}
|
||||||
|
s.stack = nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stack) processOut(handle ndisapi.Handle, packet *ndisapi.IntermediateBuffer) ndisapi.FilterAction {
|
||||||
|
if packet.Length < header.EthernetMinimumSize {
|
||||||
|
return ndisapi.FilterActionPass
|
||||||
|
}
|
||||||
|
if s.endpoint.dispatcher == nil || s.filterPacket(packet.Buffer[:packet.Length]) {
|
||||||
|
return ndisapi.FilterActionPass
|
||||||
|
}
|
||||||
|
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(packet.Buffer[:packet.Length]),
|
||||||
|
})
|
||||||
|
_, ok := packetBuffer.LinkHeader().Consume(header.EthernetMinimumSize)
|
||||||
|
if !ok {
|
||||||
|
packetBuffer.DecRef()
|
||||||
|
return ndisapi.FilterActionPass
|
||||||
|
}
|
||||||
|
ethHdr := header.Ethernet(packetBuffer.LinkHeader().Slice())
|
||||||
|
destinationAddress := ethHdr.DestinationAddress()
|
||||||
|
if destinationAddress == header.EthernetBroadcastAddress {
|
||||||
|
packetBuffer.PktType = tcpip.PacketBroadcast
|
||||||
|
} else if header.IsMulticastEthernetAddress(destinationAddress) {
|
||||||
|
packetBuffer.PktType = tcpip.PacketMulticast
|
||||||
|
} else if destinationAddress == s.endpoint.address {
|
||||||
|
packetBuffer.PktType = tcpip.PacketHost
|
||||||
|
} else {
|
||||||
|
packetBuffer.PktType = tcpip.PacketOtherHost
|
||||||
|
}
|
||||||
|
s.endpoint.dispatcher.DeliverNetworkPacket(ethHdr.Type(), packetBuffer)
|
||||||
|
packetBuffer.DecRef()
|
||||||
|
return ndisapi.FilterActionDrop
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stack) filterPacket(packet []byte) bool {
|
||||||
|
var ipHdr header.Network
|
||||||
|
switch header.IPVersion(packet[header.EthernetMinimumSize:]) {
|
||||||
|
case ipv4.Version:
|
||||||
|
ipHdr = header.IPv4(packet[header.EthernetMinimumSize:])
|
||||||
|
case ipv6.Version:
|
||||||
|
ipHdr = header.IPv6(packet[header.EthernetMinimumSize:])
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
sourceAddr := tun.AddrFromAddress(ipHdr.SourceAddress())
|
||||||
|
destinationAddr := tun.AddrFromAddress(ipHdr.DestinationAddress())
|
||||||
|
if !destinationAddr.IsGlobalUnicast() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
transportProtocol tcpip.TransportProtocolNumber
|
||||||
|
transportHdr header.Transport
|
||||||
|
)
|
||||||
|
switch ipHdr.TransportProtocol() {
|
||||||
|
case tcp.ProtocolNumber:
|
||||||
|
transportProtocol = header.TCPProtocolNumber
|
||||||
|
transportHdr = header.TCP(ipHdr.Payload())
|
||||||
|
case udp.ProtocolNumber:
|
||||||
|
transportProtocol = header.UDPProtocolNumber
|
||||||
|
transportHdr = header.UDP(ipHdr.Payload())
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
source := netip.AddrPortFrom(sourceAddr, transportHdr.SourcePort())
|
||||||
|
destination := netip.AddrPortFrom(destinationAddr, transportHdr.DestinationPort())
|
||||||
|
if transportProtocol == header.TCPProtocolNumber {
|
||||||
|
if s.trackerIn.CheckConn(source, destination) {
|
||||||
|
if debug.Enabled {
|
||||||
|
s.logger.Trace("fall exists TCP ", source, " ", destination)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if s.trackerIn.CheckPacketConn(source) {
|
||||||
|
if debug.Enabled {
|
||||||
|
s.logger.Trace("fall exists UDP ", source, " ", destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(s.routeAddress) > 0 {
|
||||||
|
var match bool
|
||||||
|
for _, route := range s.routeAddress {
|
||||||
|
if route.Contains(destinationAddr) {
|
||||||
|
match = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(s.routeAddressSet) > 0 {
|
||||||
|
var match bool
|
||||||
|
for _, ipSet := range s.routeAddressSet {
|
||||||
|
if ipSet.Contains(destinationAddr) {
|
||||||
|
match = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(s.routeExcludeAddress) > 0 {
|
||||||
|
for _, address := range s.routeExcludeAddress {
|
||||||
|
if address.Contains(destinationAddr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(s.routeExcludeAddressSet) > 0 {
|
||||||
|
for _, ipSet := range s.routeAddressSet {
|
||||||
|
if ipSet.Contains(destinationAddr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.trackerOut.CheckDestination(destination) {
|
||||||
|
if debug.Enabled {
|
||||||
|
s.logger.Trace("passing pending ", source, " ", destination)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if transportProtocol == header.TCPProtocolNumber {
|
||||||
|
if s.trackerOut.CheckConn(source, destination) {
|
||||||
|
if debug.Enabled {
|
||||||
|
s.logger.Trace("passing TCP ", source, " ", destination)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if s.trackerOut.CheckPacketConn(source) {
|
||||||
|
if debug.Enabled {
|
||||||
|
s.logger.Trace("passing UDP ", source, " ", destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if debug.Enabled {
|
||||||
|
s.logger.Trace("fall ", source, " ", destination)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
|
@ -306,7 +306,6 @@ func (t *Inbound) Start(stage adapter.StartStage) error {
|
||||||
t.tunOptions.Name = tun.CalculateInterfaceName("")
|
t.tunOptions.Name = tun.CalculateInterfaceName("")
|
||||||
}
|
}
|
||||||
if t.platformInterface == nil || runtime.GOOS != "android" {
|
if t.platformInterface == nil || runtime.GOOS != "android" {
|
||||||
t.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet)
|
|
||||||
for _, routeRuleSet := range t.routeRuleSet {
|
for _, routeRuleSet := range t.routeRuleSet {
|
||||||
ipSets := routeRuleSet.ExtractIPSet()
|
ipSets := routeRuleSet.ExtractIPSet()
|
||||||
if len(ipSets) == 0 {
|
if len(ipSets) == 0 {
|
||||||
|
@ -316,11 +315,10 @@ func (t *Inbound) Start(stage adapter.StartStage) error {
|
||||||
routeRuleSet.DecRef()
|
routeRuleSet.DecRef()
|
||||||
t.routeAddressSet = append(t.routeAddressSet, ipSets...)
|
t.routeAddressSet = append(t.routeAddressSet, ipSets...)
|
||||||
}
|
}
|
||||||
t.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet)
|
|
||||||
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
|
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
|
||||||
ipSets := routeExcludeRuleSet.ExtractIPSet()
|
ipSets := routeExcludeRuleSet.ExtractIPSet()
|
||||||
if len(ipSets) == 0 {
|
if len(ipSets) == 0 {
|
||||||
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
|
t.logger.Warn("route_exclude_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
|
||||||
}
|
}
|
||||||
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
||||||
routeExcludeRuleSet.DecRef()
|
routeExcludeRuleSet.DecRef()
|
||||||
|
|
|
@ -35,6 +35,7 @@ var _ adapter.NetworkManager = (*NetworkManager)(nil)
|
||||||
|
|
||||||
type NetworkManager struct {
|
type NetworkManager struct {
|
||||||
logger logger.ContextLogger
|
logger logger.ContextLogger
|
||||||
|
tracker conntrack.Tracker
|
||||||
interfaceFinder *control.DefaultInterfaceFinder
|
interfaceFinder *control.DefaultInterfaceFinder
|
||||||
networkInterfaces atomic.TypedValue[[]adapter.NetworkInterface]
|
networkInterfaces atomic.TypedValue[[]adapter.NetworkInterface]
|
||||||
|
|
||||||
|
@ -57,6 +58,7 @@ type NetworkManager struct {
|
||||||
func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOptions option.RouteOptions) (*NetworkManager, error) {
|
func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOptions option.RouteOptions) (*NetworkManager, error) {
|
||||||
nm := &NetworkManager{
|
nm := &NetworkManager{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
tracker: service.FromContext[conntrack.Tracker](ctx),
|
||||||
interfaceFinder: control.NewDefaultInterfaceFinder(),
|
interfaceFinder: control.NewDefaultInterfaceFinder(),
|
||||||
autoDetectInterface: routeOptions.AutoDetectInterface,
|
autoDetectInterface: routeOptions.AutoDetectInterface,
|
||||||
defaultOptions: adapter.NetworkOptions{
|
defaultOptions: adapter.NetworkOptions{
|
||||||
|
@ -355,7 +357,7 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *NetworkManager) ResetNetwork() {
|
func (r *NetworkManager) ResetNetwork() {
|
||||||
conntrack.Close()
|
r.tracker.Close()
|
||||||
|
|
||||||
for _, endpoint := range r.endpoint.Endpoints() {
|
for _, endpoint := range r.endpoint.Endpoints() {
|
||||||
listener, isListener := endpoint.(adapter.InterfaceUpdateListener)
|
listener, isListener := endpoint.(adapter.InterfaceUpdateListener)
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/conntrack"
|
|
||||||
"github.com/sagernet/sing-box/common/process"
|
"github.com/sagernet/sing-box/common/process"
|
||||||
"github.com/sagernet/sing-box/common/sniff"
|
"github.com/sagernet/sing-box/common/sniff"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
@ -72,7 +71,10 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
|
||||||
injectable.NewConnectionEx(ctx, conn, metadata, onClose)
|
injectable.NewConnectionEx(ctx, conn, metadata, onClose)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
conntrack.KillerCheck()
|
err := r.connTracker.KillerCheck()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
metadata.Network = N.NetworkTCP
|
metadata.Network = N.NetworkTCP
|
||||||
switch metadata.Destination.Fqdn {
|
switch metadata.Destination.Fqdn {
|
||||||
case mux.Destination.Fqdn:
|
case mux.Destination.Fqdn:
|
||||||
|
@ -190,7 +192,10 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
|
||||||
injectable.NewPacketConnectionEx(ctx, conn, metadata, onClose)
|
injectable.NewPacketConnectionEx(ctx, conn, metadata, onClose)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
conntrack.KillerCheck()
|
err := r.connTracker.KillerCheck()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: move to UoT
|
// TODO: move to UoT
|
||||||
metadata.Network = N.NetworkUDP
|
metadata.Network = N.NetworkUDP
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/conntrack"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
"github.com/sagernet/sing-box/common/geoip"
|
"github.com/sagernet/sing-box/common/geoip"
|
||||||
"github.com/sagernet/sing-box/common/geosite"
|
"github.com/sagernet/sing-box/common/geosite"
|
||||||
|
@ -38,6 +39,7 @@ type Router struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
dnsLogger log.ContextLogger
|
dnsLogger log.ContextLogger
|
||||||
|
connTracker conntrack.Tracker
|
||||||
inbound adapter.InboundManager
|
inbound adapter.InboundManager
|
||||||
outbound adapter.OutboundManager
|
outbound adapter.OutboundManager
|
||||||
connection adapter.ConnectionManager
|
connection adapter.ConnectionManager
|
||||||
|
@ -75,6 +77,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
logger: logFactory.NewLogger("router"),
|
logger: logFactory.NewLogger("router"),
|
||||||
dnsLogger: logFactory.NewLogger("dns"),
|
dnsLogger: logFactory.NewLogger("dns"),
|
||||||
|
connTracker: service.FromContext[conntrack.Tracker](ctx),
|
||||||
inbound: service.FromContext[adapter.InboundManager](ctx),
|
inbound: service.FromContext[adapter.InboundManager](ctx),
|
||||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||||
connection: service.FromContext[adapter.ConnectionManager](ctx),
|
connection: service.FromContext[adapter.ConnectionManager](ctx),
|
||||||
|
|
Loading…
Reference in a new issue