refactor: connection manager

This commit is contained in:
世界 2024-11-20 11:32:02 +08:00
parent 48d102a0ab
commit b45cb0763e
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
11 changed files with 569 additions and 52 deletions

14
adapter/connections.go Normal file
View file

@ -0,0 +1,14 @@
package adapter
import (
"context"
"net"
N "github.com/sagernet/sing/common/network"
)
type ConnectionManager interface {
Lifecycle
NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata InboundContext, onClose N.CloseHandlerFunc)
NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata InboundContext, onClose N.CloseHandlerFunc)
}

18
box.go
View file

@ -36,9 +36,10 @@ type Box struct {
logFactory log.Factory logFactory log.Factory
logger log.ContextLogger logger log.ContextLogger
network *route.NetworkManager network *route.NetworkManager
router *route.Router
inbound *inbound.Manager inbound *inbound.Manager
outbound *outbound.Manager outbound *outbound.Manager
connection *route.ConnectionManager
router *route.Router
services []adapter.LifecycleService services []adapter.LifecycleService
done chan struct{} done chan struct{}
} }
@ -128,6 +129,8 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "initialize network manager") return nil, E.Cause(err, "initialize network manager")
} }
service.MustRegister[adapter.NetworkManager](ctx, networkManager) service.MustRegister[adapter.NetworkManager](ctx, networkManager)
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS)) router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS))
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize router") return nil, E.Cause(err, "initialize router")
@ -238,9 +241,10 @@ func New(options Options) (*Box, error) {
} }
return &Box{ return &Box{
network: networkManager, network: networkManager,
router: router,
inbound: inboundManager, inbound: inboundManager,
outbound: outboundManager, outbound: outboundManager,
connection: connectionManager,
router: router,
createdAt: createdAt, createdAt: createdAt,
logFactory: logFactory, logFactory: logFactory,
logger: logFactory.Logger(), logger: logFactory.Logger(),
@ -299,11 +303,11 @@ func (s *Box) preStart() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound) err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound)
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.router) err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.connection, s.router)
if err != nil { if err != nil {
return err return err
} }
@ -323,7 +327,7 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound) err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound)
if err != nil { if err != nil {
return err return err
} }
@ -331,7 +335,7 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound) err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound)
if err != nil { if err != nil {
return err return err
} }
@ -350,7 +354,7 @@ func (s *Box) Close() error {
close(s.done) close(s.done)
} }
err := common.Close( err := common.Close(
s.inbound, s.outbound, s.router, s.network, s.inbound, s.outbound, s.router, s.connection, s.network,
) )
for _, lifecycleService := range s.services { for _, lifecycleService := range s.services {
err = E.Append(err, lifecycleService.Close(), func(err error) error { err = E.Append(err, lifecycleService.Close(), func(err error) error {

View file

@ -83,7 +83,7 @@ func (i *Inbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
destination = i.overrideDestination destination = i.overrideDestination
case 2: case 2:
destination = i.overrideDestination destination = i.overrideDestination
destination.Port = source.Port destination.Port = i.listener.UDPAddr().Port
case 3: case 3:
destination = source destination = source
destination.Port = i.overrideDestination.Port destination.Port = i.overrideDestination.Port

332
route/conn.go Normal file
View file

@ -0,0 +1,332 @@
package route
import (
"context"
"io"
"net"
"net/netip"
"sync/atomic"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
type ConnectionManager struct {
logger logger.ContextLogger
monitor *ConnectionMonitor
}
func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
return &ConnectionManager{
logger: logger,
monitor: NewConnectionMonitor(),
}
}
func (m *ConnectionManager) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateInitialize {
return nil
}
return m.monitor.Start()
}
func (m *ConnectionManager) Close() error {
return m.monitor.Close()
}
func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
ctx = adapter.WithContext(ctx, &metadata)
var (
remoteConn net.Conn
err error
)
if len(metadata.DestinationAddresses) > 0 {
remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
} else {
remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
}
if err != nil {
N.CloseOnHandshakeFailure(conn, onClose, err)
m.logger.ErrorContext(ctx, "open outbound connection: ", err)
return
}
err = N.ReportConnHandshakeSuccess(conn, remoteConn)
if err != nil {
remoteConn.Close()
N.CloseOnHandshakeFailure(conn, onClose, err)
m.logger.ErrorContext(ctx, "report handshake success: ", err)
return
}
var done atomic.Bool
if ctx.Done() != nil {
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
}
go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
}
func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
originSource := source
originDestination := destination
var readCounters, writeCounters []N.CountFunc
for {
source, readCounters = N.UnwrapCountReader(source, readCounters)
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
if cachedSrc, isCached := source.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
dataLen := cachedBuffer.Len()
_, err := destination.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err != nil {
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
if done.Swap(true) {
if onClose != nil {
onClose(err)
}
}
common.Close(originSource, originDestination)
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
}
continue
}
break
}
_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
if err != nil {
common.Close(originSource, originDestination)
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
err = duplexDst.CloseWrite()
if err != nil {
common.Close(originSource, originDestination)
}
} else {
common.Close(originDestination)
}
if done.Swap(true) {
if onClose != nil {
onClose(err)
}
common.Close(originSource, originDestination)
}
if !direction {
if err == nil {
m.logger.DebugContext(ctx, "connection upload finished")
} else if !E.IsClosedOrCanceled(err) {
m.logger.ErrorContext(ctx, "connection upload closed: ", err)
} else {
m.logger.TraceContext(ctx, "connection upload closed")
}
} else {
if err == nil {
m.logger.DebugContext(ctx, "connection download finished")
} else if !E.IsClosedOrCanceled(err) {
m.logger.ErrorContext(ctx, "connection download closed: ", err)
} else {
m.logger.TraceContext(ctx, "connection download closed")
}
}
}
func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
ctx = adapter.WithContext(ctx, &metadata)
var (
remotePacketConn net.PacketConn
remoteConn net.Conn
destinationAddress netip.Addr
err error
)
if metadata.UDPConnect {
if len(metadata.DestinationAddresses) > 0 {
if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer {
remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
} else {
remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
}
} else {
remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
}
if err != nil {
N.CloseOnHandshakeFailure(conn, onClose, err)
m.logger.ErrorContext(ctx, "open outbound packet connection: ", err)
return
}
remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
if connRemoteAddr != metadata.Destination.Addr {
destinationAddress = connRemoteAddr
}
} else {
if len(metadata.DestinationAddresses) > 0 {
remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
} else {
remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
}
if err != nil {
N.CloseOnHandshakeFailure(conn, onClose, err)
m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err)
return
}
}
err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
if err != nil {
conn.Close()
remotePacketConn.Close()
m.logger.ErrorContext(ctx, "report handshake success: ", err)
return
}
if destinationAddress.IsValid() {
var originDestination M.Socksaddr
if metadata.RouteOriginalDestination.IsValid() {
originDestination = metadata.RouteOriginalDestination
} else {
originDestination = metadata.Destination
}
if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
if metadata.UDPDisableDomainUnmapping {
remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
} else {
remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
}
}
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
natConn.UpdateDestination(destinationAddress)
}
}
destination := bufio.NewPacketConn(remotePacketConn)
var done atomic.Bool
if ctx.Done() != nil {
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
}
go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
}
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
_, err := bufio.CopyPacket(destination, source)
/*var readCounters, writeCounters []N.CountFunc
var cachedPackets []*N.PacketBuffer
originSource := source
for {
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters)
if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
packet := cachedReader.ReadCachedPacket()
if packet != nil {
cachedPackets = append(cachedPackets, packet)
continue
}
}
break
}
var handled bool
if natConn, isNatConn := source.(udpnat.Conn); isNatConn {
natConn.SetHandler(&udpHijacker{
ctx: ctx,
logger: m.logger,
source: natConn,
destination: destination,
direction: direction,
readCounters: readCounters,
writeCounters: writeCounters,
done: done,
onClose: onClose,
})
handled = true
}
if cachedPackets != nil {
_, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters)
if err != nil {
common.Close(source, destination)
m.logger.ErrorContext(ctx, "packet upload payload: ", err)
return
}
}
if handled {
return
}
_, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/
if !direction {
if E.IsClosedOrCanceled(err) {
m.logger.TraceContext(ctx, "packet upload closed")
} else {
m.logger.DebugContext(ctx, "packet upload closed: ", err)
}
} else {
if E.IsClosedOrCanceled(err) {
m.logger.TraceContext(ctx, "packet download closed")
} else {
m.logger.DebugContext(ctx, "packet download closed: ", err)
}
}
if !done.Swap(true) {
if onClose != nil {
onClose(err)
}
}
common.Close(source, destination)
}
/*type udpHijacker struct {
ctx context.Context
logger logger.ContextLogger
source io.Closer
destination N.PacketWriter
direction bool
readCounters []N.CountFunc
writeCounters []N.CountFunc
done *atomic.Bool
onClose N.CloseHandlerFunc
}
func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
dataLen := buffer.Len()
for _, counter := range u.readCounters {
counter(int64(dataLen))
}
err := u.destination.WritePacket(buffer, source)
if err != nil {
common.Close(u.source, u.destination)
u.logger.DebugContext(u.ctx, "packet upload closed: ", err)
return
}
for _, counter := range u.writeCounters {
counter(int64(dataLen))
}
}
func (u *udpHijacker) Close() error {
var err error
if !u.done.Swap(true) {
err = common.Close(u.source, u.destination)
if u.onClose != nil {
u.onClose(net.ErrClosed)
}
}
if u.direction {
u.logger.TraceContext(u.ctx, "packet download closed")
} else {
u.logger.TraceContext(u.ctx, "packet upload closed")
}
return err
}
func (u *udpHijacker) Upstream() any {
return u.destination
}
*/

128
route/conn_monitor.go Normal file
View file

@ -0,0 +1,128 @@
package route
import (
"context"
"io"
"reflect"
"sync"
"time"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
)
type ConnectionMonitor struct {
access sync.RWMutex
reloadChan chan struct{}
connections list.List[*monitorEntry]
}
type monitorEntry struct {
ctx context.Context
closer io.Closer
}
func NewConnectionMonitor() *ConnectionMonitor {
return &ConnectionMonitor{
reloadChan: make(chan struct{}, 1),
}
}
func (m *ConnectionMonitor) Add(ctx context.Context, closer io.Closer) N.CloseHandlerFunc {
m.access.Lock()
defer m.access.Unlock()
element := m.connections.PushBack(&monitorEntry{
ctx: ctx,
closer: closer,
})
select {
case <-m.reloadChan:
return nil
default:
select {
case m.reloadChan <- struct{}{}:
default:
}
}
return func(it error) {
m.access.Lock()
defer m.access.Unlock()
m.connections.Remove(element)
select {
case <-m.reloadChan:
default:
select {
case m.reloadChan <- struct{}{}:
default:
}
}
}
}
func (m *ConnectionMonitor) Start() error {
go m.monitor()
return nil
}
func (m *ConnectionMonitor) Close() error {
m.access.Lock()
defer m.access.Unlock()
close(m.reloadChan)
for element := m.connections.Front(); element != nil; element = element.Next() {
element.Value.closer.Close()
}
return nil
}
func (m *ConnectionMonitor) monitor() {
var (
selectCases []reflect.SelectCase
elements []*list.Element[*monitorEntry]
)
rootCase := reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(m.reloadChan),
}
for {
m.access.RLock()
if m.connections.Len() == 0 {
m.access.RUnlock()
if _, loaded := <-m.reloadChan; !loaded {
return
} else {
continue
}
}
if len(elements) < m.connections.Len() {
elements = make([]*list.Element[*monitorEntry], 0, m.connections.Len())
}
if len(selectCases) < m.connections.Len()+1 {
selectCases = make([]reflect.SelectCase, 0, m.connections.Len()+1)
}
elements = elements[:0]
selectCases = selectCases[:1]
selectCases[0] = rootCase
for element := m.connections.Front(); element != nil; element = element.Next() {
elements = append(elements, element)
selectCases = append(selectCases, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(element.Value.ctx.Done()),
})
}
m.access.RUnlock()
selected, _, loaded := reflect.Select(selectCases)
if selected == 0 {
if !loaded {
return
} else {
time.Sleep(time.Second)
continue
}
}
element := elements[selected-1]
m.access.Lock()
m.connections.Remove(element)
m.access.Unlock()
element.Value.closer.Close() // maybe go close
}
}

View file

@ -0,0 +1,43 @@
package route_test
import (
"context"
"sync"
"testing"
"time"
"github.com/sagernet/sing-box/route"
"github.com/stretchr/testify/require"
)
func TestMonitor(t *testing.T) {
t.Parallel()
var closer myCloser
closer.Add(1)
monitor := route.NewConnectionMonitor()
require.NoError(t, monitor.Start())
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
monitor.Add(ctx, &closer)
done := make(chan struct{})
go func() {
closer.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(time.Second + 100*time.Millisecond):
t.Fatal("timeout")
}
cancel()
require.NoError(t, monitor.Close())
}
type myCloser struct {
sync.WaitGroup
}
func (c *myCloser) Close() error {
c.Done()
return nil
}

View file

@ -32,15 +32,15 @@ func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata ad
} }
func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) { func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) {
if uConn, isUDPNAT2 := conn.(*udpnat.Conn); isUDPNAT2 { if natConn, isNatConn := conn.(udpnat.Conn); isNatConn {
metadata.Destination = M.Socksaddr{} metadata.Destination = M.Socksaddr{}
for _, packet := range packetBuffers { for _, packet := range packetBuffers {
buffer := packet.Buffer buffer := packet.Buffer
destination := packet.Destination destination := packet.Destination
N.PutPacketBuffer(packet) N.PutPacketBuffer(packet)
go ExchangeDNSPacket(ctx, r, uConn, buffer, metadata, destination) go ExchangeDNSPacket(ctx, r, natConn, buffer, metadata, destination)
} }
uConn.SetHandler(&dnsHijacker{ natConn.SetHandler(&dnsHijacker{
router: r, router: r,
conn: conn, conn: conn,
ctx: ctx, ctx: ctx,

View file

@ -145,13 +145,13 @@ func (r *Router) downloadGeoIPDatabase(savePath string) error {
r.logger.Info("downloading geoip database") r.logger.Info("downloading geoip database")
var detour adapter.Outbound var detour adapter.Outbound
if r.geoIPOptions.DownloadDetour != "" { if r.geoIPOptions.DownloadDetour != "" {
outbound, loaded := r.outboundManager.Outbound(r.geoIPOptions.DownloadDetour) outbound, loaded := r.outbound.Outbound(r.geoIPOptions.DownloadDetour)
if !loaded { if !loaded {
return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour) return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour)
} }
detour = outbound detour = outbound
} else { } else {
detour = r.outboundManager.Default() detour = r.outbound.Default()
} }
if parentDir := filepath.Dir(savePath); parentDir != "" { if parentDir := filepath.Dir(savePath); parentDir != "" {
@ -200,13 +200,13 @@ func (r *Router) downloadGeositeDatabase(savePath string) error {
r.logger.Info("downloading geosite database") r.logger.Info("downloading geosite database")
var detour adapter.Outbound var detour adapter.Outbound
if r.geositeOptions.DownloadDetour != "" { if r.geositeOptions.DownloadDetour != "" {
outbound, loaded := r.outboundManager.Outbound(r.geositeOptions.DownloadDetour) outbound, loaded := r.outbound.Outbound(r.geositeOptions.DownloadDetour)
if !loaded { if !loaded {
return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour) return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour)
} }
detour = outbound detour = outbound
} else { } else {
detour = r.outboundManager.Default() detour = r.outbound.Default()
} }
if parentDir := filepath.Dir(savePath); parentDir != "" { if parentDir := filepath.Dir(savePath); parentDir != "" {

View file

@ -48,6 +48,7 @@ type NetworkManager struct {
powerListener winpowrprof.EventListener powerListener winpowrprof.EventListener
pauseManager pause.Manager pauseManager pause.Manager
platformInterface platform.Interface platformInterface platform.Interface
inboundManager adapter.InboundManager
outboundManager adapter.OutboundManager outboundManager adapter.OutboundManager
wifiState adapter.WIFIState wifiState adapter.WIFIState
started bool started bool
@ -354,6 +355,13 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
func (r *NetworkManager) ResetNetwork() { func (r *NetworkManager) ResetNetwork() {
conntrack.Close() conntrack.Close()
for _, inbound := range r.inboundManager.Inbounds() {
listener, isListener := inbound.(adapter.InterfaceUpdateListener)
if isListener {
listener.InterfaceUpdated()
}
}
for _, outbound := range r.outboundManager.Outbounds() { for _, outbound := range r.outboundManager.Outbounds() {
listener, isListener := outbound.(adapter.InterfaceUpdateListener) listener, isListener := outbound.(adapter.InterfaceUpdateListener)
if isListener { if isListener {

View file

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/conntrack" "github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/common/process"
"github.com/sagernet/sing-box/common/sniff" "github.com/sagernet/sing-box/common/sniff"
@ -58,7 +57,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if metadata.LastInbound == metadata.InboundDetour { if metadata.LastInbound == metadata.InboundDetour {
return E.New("routing loop on detour: ", metadata.InboundDetour) return E.New("routing loop on detour: ", metadata.InboundDetour)
} }
detour, loaded := r.inboundManager.Get(metadata.InboundDetour) detour, loaded := r.inbound.Get(metadata.InboundDetour)
if !loaded { if !loaded {
return E.New("inbound detour not found: ", metadata.InboundDetour) return E.New("inbound detour not found: ", metadata.InboundDetour)
} }
@ -96,7 +95,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
switch action := selectedRule.Action().(type) { switch action := selectedRule.Action().(type) {
case *rule.RuleActionRoute: case *rule.RuleActionRoute:
var loaded bool var loaded bool
selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound) selectedOutbound, loaded = r.outbound.Outbound(action.Outbound)
if !loaded { if !loaded {
buf.ReleaseMulti(buffers) buf.ReleaseMulti(buffers)
return E.New("outbound not found: ", action.Outbound) return E.New("outbound not found: ", action.Outbound)
@ -118,7 +117,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
} }
} }
if selectedRule == nil { if selectedRule == nil {
defaultOutbound := r.outboundManager.Default() defaultOutbound := r.outbound.Default()
if !common.Contains(defaultOutbound.Network(), N.NetworkTCP) { if !common.Contains(defaultOutbound.Network(), N.NetworkTCP) {
buf.ReleaseMulti(buffers) buf.ReleaseMulti(buffers)
return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag()) return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag())
@ -148,19 +147,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
} }
return nil return nil
} }
// TODO r.connection.NewConnection(ctx, selectedOutbound, conn, metadata, onClose)
err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata)
if err != nil {
conn.Close()
if onClose != nil {
onClose(err)
}
return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
} else {
if onClose != nil {
onClose(nil)
}
}
return nil return nil
} }
@ -199,7 +186,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
if metadata.LastInbound == metadata.InboundDetour { if metadata.LastInbound == metadata.InboundDetour {
return E.New("routing loop on detour: ", metadata.InboundDetour) return E.New("routing loop on detour: ", metadata.InboundDetour)
} }
detour, loaded := r.inboundManager.Get(metadata.InboundDetour) detour, loaded := r.inbound.Get(metadata.InboundDetour)
if !loaded { if !loaded {
return E.New("inbound detour not found: ", metadata.InboundDetour) return E.New("inbound detour not found: ", metadata.InboundDetour)
} }
@ -233,7 +220,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
switch action := selectedRule.Action().(type) { switch action := selectedRule.Action().(type) {
case *rule.RuleActionRoute: case *rule.RuleActionRoute:
var loaded bool var loaded bool
selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound) selectedOutbound, loaded = r.outbound.Outbound(action.Outbound)
if !loaded { if !loaded {
N.ReleaseMultiPacketBuffer(packetBuffers) N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("outbound not found: ", action.Outbound) return E.New("outbound not found: ", action.Outbound)
@ -252,7 +239,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
} }
} }
if selectedRule == nil || selectReturn { if selectedRule == nil || selectReturn {
defaultOutbound := r.outboundManager.Default() defaultOutbound := r.outbound.Default()
if !common.Contains(defaultOutbound.Network(), N.NetworkUDP) { if !common.Contains(defaultOutbound.Network(), N.NetworkUDP) {
N.ReleaseMultiPacketBuffer(packetBuffers) N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag()) return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag())
@ -278,12 +265,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
} }
return nil return nil
} }
// TODO r.connection.NewPacketConnection(ctx, selectedOutbound, conn, metadata, onClose)
err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil {
return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
}
return nil return nil
} }
@ -450,8 +432,12 @@ match:
} }
metadata.NetworkStrategy = routeOptions.NetworkStrategy metadata.NetworkStrategy = routeOptions.NetworkStrategy
metadata.FallbackDelay = routeOptions.FallbackDelay metadata.FallbackDelay = routeOptions.FallbackDelay
metadata.UDPDisableDomainUnmapping = routeOptions.UDPDisableDomainUnmapping if routeOptions.UDPDisableDomainUnmapping {
metadata.UDPConnect = routeOptions.UDPConnect metadata.UDPDisableDomainUnmapping = true
}
if routeOptions.UDPConnect {
metadata.UDPConnect = true
}
} }
switch action := currentRule.Action().(type) { switch action := currentRule.Action().(type) {
case *rule.RuleActionSniff: case *rule.RuleActionSniff:

View file

@ -38,9 +38,10 @@ type Router struct {
ctx context.Context ctx context.Context
logger log.ContextLogger logger log.ContextLogger
dnsLogger log.ContextLogger dnsLogger log.ContextLogger
inboundManager adapter.InboundManager inbound adapter.InboundManager
outboundManager adapter.OutboundManager outbound adapter.OutboundManager
networkManager adapter.NetworkManager connection adapter.ConnectionManager
network adapter.NetworkManager
rules []adapter.Rule rules []adapter.Rule
needGeoIPDatabase bool needGeoIPDatabase bool
needGeositeDatabase bool needGeositeDatabase bool
@ -74,9 +75,10 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
ctx: ctx, ctx: ctx,
logger: logFactory.NewLogger("router"), logger: logFactory.NewLogger("router"),
dnsLogger: logFactory.NewLogger("dns"), dnsLogger: logFactory.NewLogger("dns"),
inboundManager: service.FromContext[adapter.InboundManager](ctx), inbound: service.FromContext[adapter.InboundManager](ctx),
outboundManager: service.FromContext[adapter.OutboundManager](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx),
networkManager: service.FromContext[adapter.NetworkManager](ctx), connection: service.FromContext[adapter.ConnectionManager](ctx),
network: service.FromContext[adapter.NetworkManager](ctx),
rules: make([]adapter.Rule, 0, len(options.Rules)), rules: make([]adapter.Rule, 0, len(options.Rules)),
dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
ruleSetMap: make(map[string]adapter.RuleSet), ruleSetMap: make(map[string]adapter.RuleSet),
@ -260,7 +262,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
Context: ctx, Context: ctx,
Name: "local", Name: "local",
Address: "local", Address: "local",
Dialer: common.Must1(dialer.NewDefault(router.networkManager, option.DialerOptions{})), Dialer: common.Must1(dialer.NewDefault(router.network, option.DialerOptions{})),
}))) })))
} }
defaultTransport = transports[0] defaultTransport = transports[0]
@ -405,7 +407,7 @@ func (r *Router) Start(stage adapter.StartStage) error {
monitor.Start("initialize process searcher") monitor.Start("initialize process searcher")
searcher, err := process.NewSearcher(process.Config{ searcher, err := process.NewSearcher(process.Config{
Logger: r.logger, Logger: r.logger,
PackageManager: r.networkManager.PackageManager(), PackageManager: r.network.PackageManager(),
}) })
monitor.Finish() monitor.Finish()
if err != nil { if err != nil {
@ -507,7 +509,7 @@ func (r *Router) SetTracker(tracker adapter.ConnectionTracker) {
} }
func (r *Router) ResetNetwork() { func (r *Router) ResetNetwork() {
r.networkManager.ResetNetwork() r.network.ResetNetwork()
for _, transport := range r.transports { for _, transport := range r.transports {
transport.Reset() transport.Reset()
} }