From b45cb0763ec64f0a1b01367c1f69204ecac3ac16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 20 Nov 2024 11:32:02 +0800 Subject: [PATCH] refactor: connection manager --- adapter/connections.go | 14 ++ box.go | 18 +- protocol/direct/inbound.go | 2 +- route/conn.go | 332 +++++++++++++++++++++++++++++++++++++ route/conn_monitor.go | 128 ++++++++++++++ route/conn_monitor_test.go | 43 +++++ route/dns.go | 6 +- route/geo_resources.go | 8 +- route/network.go | 8 + route/route.go | 42 ++--- route/router.go | 20 ++- 11 files changed, 569 insertions(+), 52 deletions(-) create mode 100644 adapter/connections.go create mode 100644 route/conn.go create mode 100644 route/conn_monitor.go create mode 100644 route/conn_monitor_test.go diff --git a/adapter/connections.go b/adapter/connections.go new file mode 100644 index 00000000..0682d05a --- /dev/null +++ b/adapter/connections.go @@ -0,0 +1,14 @@ +package adapter + +import ( + "context" + "net" + + N "github.com/sagernet/sing/common/network" +) + +type ConnectionManager interface { + Lifecycle + NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata InboundContext, onClose N.CloseHandlerFunc) + NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata InboundContext, onClose N.CloseHandlerFunc) +} diff --git a/box.go b/box.go index 3b69617f..44a29992 100644 --- a/box.go +++ b/box.go @@ -36,9 +36,10 @@ type Box struct { logFactory log.Factory logger log.ContextLogger network *route.NetworkManager - router *route.Router inbound *inbound.Manager outbound *outbound.Manager + connection *route.ConnectionManager + router *route.Router services []adapter.LifecycleService done chan struct{} } @@ -128,6 +129,8 @@ func New(options Options) (*Box, error) { return nil, E.Cause(err, "initialize network manager") } service.MustRegister[adapter.NetworkManager](ctx, networkManager) + connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection")) + service.MustRegister[adapter.ConnectionManager](ctx, connectionManager) router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS)) if err != nil { return nil, E.Cause(err, "initialize router") @@ -238,9 +241,10 @@ func New(options Options) (*Box, error) { } return &Box{ network: networkManager, - router: router, inbound: inboundManager, outbound: outboundManager, + connection: connectionManager, + router: router, createdAt: createdAt, logFactory: logFactory, logger: logFactory.Logger(), @@ -299,11 +303,11 @@ func (s *Box) preStart() error { if err != nil { return err } - err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound) + err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound) if err != nil { return err } - err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.router) + err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.connection, s.router) if err != nil { return err } @@ -323,7 +327,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound) + err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound) if err != nil { return err } @@ -331,7 +335,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound) + err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound) if err != nil { return err } @@ -350,7 +354,7 @@ func (s *Box) Close() error { close(s.done) } err := common.Close( - s.inbound, s.outbound, s.router, s.network, + s.inbound, s.outbound, s.router, s.connection, s.network, ) for _, lifecycleService := range s.services { err = E.Append(err, lifecycleService.Close(), func(err error) error { diff --git a/protocol/direct/inbound.go b/protocol/direct/inbound.go index 8415e21f..6db60d78 100644 --- a/protocol/direct/inbound.go +++ b/protocol/direct/inbound.go @@ -83,7 +83,7 @@ func (i *Inbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) { destination = i.overrideDestination case 2: destination = i.overrideDestination - destination.Port = source.Port + destination.Port = i.listener.UDPAddr().Port case 3: destination = source destination.Port = i.overrideDestination.Port diff --git a/route/conn.go b/route/conn.go new file mode 100644 index 00000000..594379cc --- /dev/null +++ b/route/conn.go @@ -0,0 +1,332 @@ +package route + +import ( + "context" + "io" + "net" + "net/netip" + "sync/atomic" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var _ adapter.ConnectionManager = (*ConnectionManager)(nil) + +type ConnectionManager struct { + logger logger.ContextLogger + monitor *ConnectionMonitor +} + +func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager { + return &ConnectionManager{ + logger: logger, + monitor: NewConnectionMonitor(), + } +} + +func (m *ConnectionManager) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateInitialize { + return nil + } + return m.monitor.Start() +} + +func (m *ConnectionManager) Close() error { + return m.monitor.Close() +} + +func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + ctx = adapter.WithContext(ctx, &metadata) + var ( + remoteConn net.Conn + err error + ) + if len(metadata.DestinationAddresses) > 0 { + remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else { + remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) + } + if err != nil { + N.CloseOnHandshakeFailure(conn, onClose, err) + m.logger.ErrorContext(ctx, "open outbound connection: ", err) + return + } + err = N.ReportConnHandshakeSuccess(conn, remoteConn) + if err != nil { + remoteConn.Close() + N.CloseOnHandshakeFailure(conn, onClose, err) + m.logger.ErrorContext(ctx, "report handshake success: ", err) + return + } + var done atomic.Bool + if ctx.Done() != nil { + onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn)) + } + go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose) + go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose) +} + +func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { + originSource := source + originDestination := destination + var readCounters, writeCounters []N.CountFunc + for { + source, readCounters = N.UnwrapCountReader(source, readCounters) + destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) + if cachedSrc, isCached := source.(N.CachedReader); isCached { + cachedBuffer := cachedSrc.ReadCached() + if cachedBuffer != nil { + dataLen := cachedBuffer.Len() + _, err := destination.Write(cachedBuffer.Bytes()) + cachedBuffer.Release() + if err != nil { + m.logger.ErrorContext(ctx, "connection upload payload: ", err) + if done.Swap(true) { + if onClose != nil { + onClose(err) + } + } + common.Close(originSource, originDestination) + return + } + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } + } + continue + } + break + } + _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters) + if err != nil { + common.Close(originSource, originDestination) + } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex { + err = duplexDst.CloseWrite() + if err != nil { + common.Close(originSource, originDestination) + } + } else { + common.Close(originDestination) + } + if done.Swap(true) { + if onClose != nil { + onClose(err) + } + common.Close(originSource, originDestination) + } + if !direction { + if err == nil { + m.logger.DebugContext(ctx, "connection upload finished") + } else if !E.IsClosedOrCanceled(err) { + m.logger.ErrorContext(ctx, "connection upload closed: ", err) + } else { + m.logger.TraceContext(ctx, "connection upload closed") + } + } else { + if err == nil { + m.logger.DebugContext(ctx, "connection download finished") + } else if !E.IsClosedOrCanceled(err) { + m.logger.ErrorContext(ctx, "connection download closed: ", err) + } else { + m.logger.TraceContext(ctx, "connection download closed") + } + } +} + +func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + ctx = adapter.WithContext(ctx, &metadata) + var ( + remotePacketConn net.PacketConn + remoteConn net.Conn + destinationAddress netip.Addr + err error + ) + if metadata.UDPConnect { + if len(metadata.DestinationAddresses) > 0 { + if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer { + remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else { + remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) + } + } else { + remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) + } + if err != nil { + N.CloseOnHandshakeFailure(conn, onClose, err) + m.logger.ErrorContext(ctx, "open outbound packet connection: ", err) + return + } + remotePacketConn = bufio.NewUnbindPacketConn(remoteConn) + connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr()) + if connRemoteAddr != metadata.Destination.Addr { + destinationAddress = connRemoteAddr + } + } else { + if len(metadata.DestinationAddresses) > 0 { + remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else { + remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination) + } + if err != nil { + N.CloseOnHandshakeFailure(conn, onClose, err) + m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err) + return + } + } + err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn) + if err != nil { + conn.Close() + remotePacketConn.Close() + m.logger.ErrorContext(ctx, "report handshake success: ", err) + return + } + if destinationAddress.IsValid() { + var originDestination M.Socksaddr + if metadata.RouteOriginalDestination.IsValid() { + originDestination = metadata.RouteOriginalDestination + } else { + originDestination = metadata.Destination + } + if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) { + if metadata.UDPDisableDomainUnmapping { + remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) + } else { + remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) + } + } + if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { + natConn.UpdateDestination(destinationAddress) + } + } + destination := bufio.NewPacketConn(remotePacketConn) + var done atomic.Bool + if ctx.Done() != nil { + onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn)) + } + go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose) + go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose) +} + +func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { + _, err := bufio.CopyPacket(destination, source) + /*var readCounters, writeCounters []N.CountFunc + var cachedPackets []*N.PacketBuffer + originSource := source + for { + source, readCounters = N.UnwrapCountPacketReader(source, readCounters) + destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters) + if cachedReader, isCached := source.(N.CachedPacketReader); isCached { + packet := cachedReader.ReadCachedPacket() + if packet != nil { + cachedPackets = append(cachedPackets, packet) + continue + } + } + break + } + var handled bool + if natConn, isNatConn := source.(udpnat.Conn); isNatConn { + natConn.SetHandler(&udpHijacker{ + ctx: ctx, + logger: m.logger, + source: natConn, + destination: destination, + direction: direction, + readCounters: readCounters, + writeCounters: writeCounters, + done: done, + onClose: onClose, + }) + handled = true + } + if cachedPackets != nil { + _, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters) + if err != nil { + common.Close(source, destination) + m.logger.ErrorContext(ctx, "packet upload payload: ", err) + return + } + } + if handled { + return + } + _, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/ + if !direction { + if E.IsClosedOrCanceled(err) { + m.logger.TraceContext(ctx, "packet upload closed") + } else { + m.logger.DebugContext(ctx, "packet upload closed: ", err) + } + } else { + if E.IsClosedOrCanceled(err) { + m.logger.TraceContext(ctx, "packet download closed") + } else { + m.logger.DebugContext(ctx, "packet download closed: ", err) + } + } + if !done.Swap(true) { + if onClose != nil { + onClose(err) + } + } + common.Close(source, destination) +} + +/*type udpHijacker struct { + ctx context.Context + logger logger.ContextLogger + source io.Closer + destination N.PacketWriter + direction bool + readCounters []N.CountFunc + writeCounters []N.CountFunc + done *atomic.Bool + onClose N.CloseHandlerFunc +} + +func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) { + dataLen := buffer.Len() + for _, counter := range u.readCounters { + counter(int64(dataLen)) + } + err := u.destination.WritePacket(buffer, source) + if err != nil { + common.Close(u.source, u.destination) + u.logger.DebugContext(u.ctx, "packet upload closed: ", err) + return + } + for _, counter := range u.writeCounters { + counter(int64(dataLen)) + } +} + +func (u *udpHijacker) Close() error { + var err error + if !u.done.Swap(true) { + err = common.Close(u.source, u.destination) + if u.onClose != nil { + u.onClose(net.ErrClosed) + } + } + if u.direction { + u.logger.TraceContext(u.ctx, "packet download closed") + } else { + u.logger.TraceContext(u.ctx, "packet upload closed") + } + return err +} + +func (u *udpHijacker) Upstream() any { + return u.destination +} +*/ diff --git a/route/conn_monitor.go b/route/conn_monitor.go new file mode 100644 index 00000000..9e271b82 --- /dev/null +++ b/route/conn_monitor.go @@ -0,0 +1,128 @@ +package route + +import ( + "context" + "io" + "reflect" + "sync" + "time" + + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/x/list" +) + +type ConnectionMonitor struct { + access sync.RWMutex + reloadChan chan struct{} + connections list.List[*monitorEntry] +} + +type monitorEntry struct { + ctx context.Context + closer io.Closer +} + +func NewConnectionMonitor() *ConnectionMonitor { + return &ConnectionMonitor{ + reloadChan: make(chan struct{}, 1), + } +} + +func (m *ConnectionMonitor) Add(ctx context.Context, closer io.Closer) N.CloseHandlerFunc { + m.access.Lock() + defer m.access.Unlock() + element := m.connections.PushBack(&monitorEntry{ + ctx: ctx, + closer: closer, + }) + select { + case <-m.reloadChan: + return nil + default: + select { + case m.reloadChan <- struct{}{}: + default: + } + } + return func(it error) { + m.access.Lock() + defer m.access.Unlock() + m.connections.Remove(element) + select { + case <-m.reloadChan: + default: + select { + case m.reloadChan <- struct{}{}: + default: + } + } + } +} + +func (m *ConnectionMonitor) Start() error { + go m.monitor() + return nil +} + +func (m *ConnectionMonitor) Close() error { + m.access.Lock() + defer m.access.Unlock() + close(m.reloadChan) + for element := m.connections.Front(); element != nil; element = element.Next() { + element.Value.closer.Close() + } + return nil +} + +func (m *ConnectionMonitor) monitor() { + var ( + selectCases []reflect.SelectCase + elements []*list.Element[*monitorEntry] + ) + rootCase := reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(m.reloadChan), + } + for { + m.access.RLock() + if m.connections.Len() == 0 { + m.access.RUnlock() + if _, loaded := <-m.reloadChan; !loaded { + return + } else { + continue + } + } + if len(elements) < m.connections.Len() { + elements = make([]*list.Element[*monitorEntry], 0, m.connections.Len()) + } + if len(selectCases) < m.connections.Len()+1 { + selectCases = make([]reflect.SelectCase, 0, m.connections.Len()+1) + } + elements = elements[:0] + selectCases = selectCases[:1] + selectCases[0] = rootCase + for element := m.connections.Front(); element != nil; element = element.Next() { + elements = append(elements, element) + selectCases = append(selectCases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(element.Value.ctx.Done()), + }) + } + m.access.RUnlock() + selected, _, loaded := reflect.Select(selectCases) + if selected == 0 { + if !loaded { + return + } else { + time.Sleep(time.Second) + continue + } + } + element := elements[selected-1] + m.access.Lock() + m.connections.Remove(element) + m.access.Unlock() + element.Value.closer.Close() // maybe go close + } +} diff --git a/route/conn_monitor_test.go b/route/conn_monitor_test.go new file mode 100644 index 00000000..a712bddc --- /dev/null +++ b/route/conn_monitor_test.go @@ -0,0 +1,43 @@ +package route_test + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/sagernet/sing-box/route" + + "github.com/stretchr/testify/require" +) + +func TestMonitor(t *testing.T) { + t.Parallel() + var closer myCloser + closer.Add(1) + monitor := route.NewConnectionMonitor() + require.NoError(t, monitor.Start()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + monitor.Add(ctx, &closer) + done := make(chan struct{}) + go func() { + closer.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(time.Second + 100*time.Millisecond): + t.Fatal("timeout") + } + cancel() + require.NoError(t, monitor.Close()) +} + +type myCloser struct { + sync.WaitGroup +} + +func (c *myCloser) Close() error { + c.Done() + return nil +} diff --git a/route/dns.go b/route/dns.go index b69be5a2..2c6efefe 100644 --- a/route/dns.go +++ b/route/dns.go @@ -32,15 +32,15 @@ func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata ad } func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) { - if uConn, isUDPNAT2 := conn.(*udpnat.Conn); isUDPNAT2 { + if natConn, isNatConn := conn.(udpnat.Conn); isNatConn { metadata.Destination = M.Socksaddr{} for _, packet := range packetBuffers { buffer := packet.Buffer destination := packet.Destination N.PutPacketBuffer(packet) - go ExchangeDNSPacket(ctx, r, uConn, buffer, metadata, destination) + go ExchangeDNSPacket(ctx, r, natConn, buffer, metadata, destination) } - uConn.SetHandler(&dnsHijacker{ + natConn.SetHandler(&dnsHijacker{ router: r, conn: conn, ctx: ctx, diff --git a/route/geo_resources.go b/route/geo_resources.go index c5a45ffd..8a8a3ef5 100644 --- a/route/geo_resources.go +++ b/route/geo_resources.go @@ -145,13 +145,13 @@ func (r *Router) downloadGeoIPDatabase(savePath string) error { r.logger.Info("downloading geoip database") var detour adapter.Outbound if r.geoIPOptions.DownloadDetour != "" { - outbound, loaded := r.outboundManager.Outbound(r.geoIPOptions.DownloadDetour) + outbound, loaded := r.outbound.Outbound(r.geoIPOptions.DownloadDetour) if !loaded { return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour) } detour = outbound } else { - detour = r.outboundManager.Default() + detour = r.outbound.Default() } if parentDir := filepath.Dir(savePath); parentDir != "" { @@ -200,13 +200,13 @@ func (r *Router) downloadGeositeDatabase(savePath string) error { r.logger.Info("downloading geosite database") var detour adapter.Outbound if r.geositeOptions.DownloadDetour != "" { - outbound, loaded := r.outboundManager.Outbound(r.geositeOptions.DownloadDetour) + outbound, loaded := r.outbound.Outbound(r.geositeOptions.DownloadDetour) if !loaded { return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour) } detour = outbound } else { - detour = r.outboundManager.Default() + detour = r.outbound.Default() } if parentDir := filepath.Dir(savePath); parentDir != "" { diff --git a/route/network.go b/route/network.go index e7c4df9c..510686e5 100644 --- a/route/network.go +++ b/route/network.go @@ -48,6 +48,7 @@ type NetworkManager struct { powerListener winpowrprof.EventListener pauseManager pause.Manager platformInterface platform.Interface + inboundManager adapter.InboundManager outboundManager adapter.OutboundManager wifiState adapter.WIFIState started bool @@ -354,6 +355,13 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState { func (r *NetworkManager) ResetNetwork() { conntrack.Close() + for _, inbound := range r.inboundManager.Inbounds() { + listener, isListener := inbound.(adapter.InterfaceUpdateListener) + if isListener { + listener.InterfaceUpdated() + } + } + for _, outbound := range r.outboundManager.Outbounds() { listener, isListener := outbound.(adapter.InterfaceUpdateListener) if isListener { diff --git a/route/route.go b/route/route.go index 7d21d315..a56016f9 100644 --- a/route/route.go +++ b/route/route.go @@ -11,7 +11,6 @@ import ( "time" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/common/conntrack" "github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/common/sniff" @@ -58,7 +57,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad if metadata.LastInbound == metadata.InboundDetour { return E.New("routing loop on detour: ", metadata.InboundDetour) } - detour, loaded := r.inboundManager.Get(metadata.InboundDetour) + detour, loaded := r.inbound.Get(metadata.InboundDetour) if !loaded { return E.New("inbound detour not found: ", metadata.InboundDetour) } @@ -96,7 +95,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad switch action := selectedRule.Action().(type) { case *rule.RuleActionRoute: var loaded bool - selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound) + selectedOutbound, loaded = r.outbound.Outbound(action.Outbound) if !loaded { buf.ReleaseMulti(buffers) return E.New("outbound not found: ", action.Outbound) @@ -118,7 +117,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad } } if selectedRule == nil { - defaultOutbound := r.outboundManager.Default() + defaultOutbound := r.outbound.Default() if !common.Contains(defaultOutbound.Network(), N.NetworkTCP) { buf.ReleaseMulti(buffers) return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag()) @@ -148,19 +147,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad } return nil } - // TODO - err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata) - if err != nil { - conn.Close() - if onClose != nil { - onClose(err) - } - return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")) - } else { - if onClose != nil { - onClose(nil) - } - } + r.connection.NewConnection(ctx, selectedOutbound, conn, metadata, onClose) return nil } @@ -199,7 +186,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m if metadata.LastInbound == metadata.InboundDetour { return E.New("routing loop on detour: ", metadata.InboundDetour) } - detour, loaded := r.inboundManager.Get(metadata.InboundDetour) + detour, loaded := r.inbound.Get(metadata.InboundDetour) if !loaded { return E.New("inbound detour not found: ", metadata.InboundDetour) } @@ -233,7 +220,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m switch action := selectedRule.Action().(type) { case *rule.RuleActionRoute: var loaded bool - selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound) + selectedOutbound, loaded = r.outbound.Outbound(action.Outbound) if !loaded { N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("outbound not found: ", action.Outbound) @@ -252,7 +239,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m } } if selectedRule == nil || selectReturn { - defaultOutbound := r.outboundManager.Default() + defaultOutbound := r.outbound.Default() if !common.Contains(defaultOutbound.Network(), N.NetworkUDP) { N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag()) @@ -278,12 +265,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m } return nil } - // TODO - err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata) - N.CloseOnHandshakeFailure(conn, onClose, err) - if err != nil { - return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")) - } + r.connection.NewPacketConnection(ctx, selectedOutbound, conn, metadata, onClose) return nil } @@ -450,8 +432,12 @@ match: } metadata.NetworkStrategy = routeOptions.NetworkStrategy metadata.FallbackDelay = routeOptions.FallbackDelay - metadata.UDPDisableDomainUnmapping = routeOptions.UDPDisableDomainUnmapping - metadata.UDPConnect = routeOptions.UDPConnect + if routeOptions.UDPDisableDomainUnmapping { + metadata.UDPDisableDomainUnmapping = true + } + if routeOptions.UDPConnect { + metadata.UDPConnect = true + } } switch action := currentRule.Action().(type) { case *rule.RuleActionSniff: diff --git a/route/router.go b/route/router.go index f23604b3..792391e2 100644 --- a/route/router.go +++ b/route/router.go @@ -38,9 +38,10 @@ type Router struct { ctx context.Context logger log.ContextLogger dnsLogger log.ContextLogger - inboundManager adapter.InboundManager - outboundManager adapter.OutboundManager - networkManager adapter.NetworkManager + inbound adapter.InboundManager + outbound adapter.OutboundManager + connection adapter.ConnectionManager + network adapter.NetworkManager rules []adapter.Rule needGeoIPDatabase bool needGeositeDatabase bool @@ -74,9 +75,10 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route ctx: ctx, logger: logFactory.NewLogger("router"), dnsLogger: logFactory.NewLogger("dns"), - inboundManager: service.FromContext[adapter.InboundManager](ctx), - outboundManager: service.FromContext[adapter.OutboundManager](ctx), - networkManager: service.FromContext[adapter.NetworkManager](ctx), + inbound: service.FromContext[adapter.InboundManager](ctx), + outbound: service.FromContext[adapter.OutboundManager](ctx), + connection: service.FromContext[adapter.ConnectionManager](ctx), + network: service.FromContext[adapter.NetworkManager](ctx), rules: make([]adapter.Rule, 0, len(options.Rules)), dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), ruleSetMap: make(map[string]adapter.RuleSet), @@ -260,7 +262,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route Context: ctx, Name: "local", Address: "local", - Dialer: common.Must1(dialer.NewDefault(router.networkManager, option.DialerOptions{})), + Dialer: common.Must1(dialer.NewDefault(router.network, option.DialerOptions{})), }))) } defaultTransport = transports[0] @@ -405,7 +407,7 @@ func (r *Router) Start(stage adapter.StartStage) error { monitor.Start("initialize process searcher") searcher, err := process.NewSearcher(process.Config{ Logger: r.logger, - PackageManager: r.networkManager.PackageManager(), + PackageManager: r.network.PackageManager(), }) monitor.Finish() if err != nil { @@ -507,7 +509,7 @@ func (r *Router) SetTracker(tracker adapter.ConnectionTracker) { } func (r *Router) ResetNetwork() { - r.networkManager.ResetNetwork() + r.network.ResetNetwork() for _, transport := range r.transports { transport.Reset() }