package route import ( "context" "io" "net" "net/netip" "sync" "sync/atomic" "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" ) var _ adapter.ConnectionManager = (*ConnectionManager)(nil) type ConnectionManager struct { logger logger.ContextLogger access sync.Mutex connections list.List[io.Closer] } func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager { return &ConnectionManager{ logger: logger, } } func (m *ConnectionManager) Start(stage adapter.StartStage) error { return nil } func (m *ConnectionManager) Close() error { m.access.Lock() defer m.access.Unlock() for element := m.connections.Front(); element != nil; element = element.Next() { common.Close(element.Value) } m.connections.Init() return nil } func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = adapter.WithContext(ctx, &metadata) var ( remoteConn net.Conn err error ) if len(metadata.DestinationAddresses) > 0 { remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) } if err != nil { err = E.Cause(err, "open outbound connection") N.CloseOnHandshakeFailure(conn, onClose, err) m.logger.ErrorContext(ctx, err) return } err = N.ReportConnHandshakeSuccess(conn, remoteConn) if err != nil { err = E.Cause(err, "report handshake success") remoteConn.Close() N.CloseOnHandshakeFailure(conn, onClose, err) m.logger.ErrorContext(ctx, err) return } m.access.Lock() element := m.connections.PushBack(conn) m.access.Unlock() onClose = N.AppendClose(onClose, func(it error) { m.access.Lock() defer m.access.Unlock() m.connections.Remove(element) }) var done atomic.Bool go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose) go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose) } func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = adapter.WithContext(ctx, &metadata) var ( remotePacketConn net.PacketConn remoteConn net.Conn destinationAddress netip.Addr err error ) if metadata.UDPConnect { if len(metadata.DestinationAddresses) > 0 { if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer { remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) } } else { remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) } if err != nil { N.CloseOnHandshakeFailure(conn, onClose, err) m.logger.ErrorContext(ctx, "open outbound packet connection: ", err) return } remotePacketConn = bufio.NewUnbindPacketConn(remoteConn) connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr()) if connRemoteAddr != metadata.Destination.Addr { destinationAddress = connRemoteAddr } } else { if len(metadata.DestinationAddresses) > 0 { remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination) } if err != nil { N.CloseOnHandshakeFailure(conn, onClose, err) m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err) return } } err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn) if err != nil { conn.Close() remotePacketConn.Close() m.logger.ErrorContext(ctx, "report handshake success: ", err) return } if destinationAddress.IsValid() { var originDestination M.Socksaddr if metadata.RouteOriginalDestination.IsValid() { originDestination = metadata.RouteOriginalDestination } else { originDestination = metadata.Destination } if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) { if metadata.UDPDisableDomainUnmapping { remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) } else { remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) } } if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { natConn.UpdateDestination(destinationAddress) } } var udpTimeout time.Duration if metadata.UDPTimeout > 0 { udpTimeout = metadata.UDPTimeout } else { protocol := metadata.Protocol if protocol == "" { protocol = C.PortProtocols[metadata.Destination.Port] } if protocol != "" { udpTimeout = C.ProtocolTimeouts[protocol] } } if udpTimeout > 0 { ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout) } destination := bufio.NewPacketConn(remotePacketConn) m.access.Lock() element := m.connections.PushBack(conn) m.access.Unlock() onClose = N.AppendClose(onClose, func(it error) { m.access.Lock() defer m.access.Unlock() m.connections.Remove(element) }) var done atomic.Bool go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose) go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose) } func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { originSource := source originDestination := destination var readCounters, writeCounters []N.CountFunc for { source, readCounters = N.UnwrapCountReader(source, readCounters) destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) if cachedSrc, isCached := source.(N.CachedReader); isCached { cachedBuffer := cachedSrc.ReadCached() if cachedBuffer != nil { dataLen := cachedBuffer.Len() _, err := destination.Write(cachedBuffer.Bytes()) cachedBuffer.Release() if err != nil { if done.Swap(true) { onClose(err) } common.Close(originSource, originDestination) if !direction { m.logger.ErrorContext(ctx, "connection upload payload: ", err) } else { m.logger.ErrorContext(ctx, "connection download payload: ", err) } return } for _, counter := range readCounters { counter(int64(dataLen)) } for _, counter := range writeCounters { counter(int64(dataLen)) } } continue } break } _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters) if err != nil { common.Close(originDestination) } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex { err = duplexDst.CloseWrite() if err != nil { common.Close(originSource, originDestination) } } else { common.Close(originDestination) } if done.Swap(true) { onClose(err) common.Close(originSource, originDestination) } if !direction { if err == nil { m.logger.DebugContext(ctx, "connection upload finished") } else if !E.IsClosedOrCanceled(err) { m.logger.ErrorContext(ctx, "connection upload closed: ", err) } else { m.logger.TraceContext(ctx, "connection upload closed") } } else { if err == nil { m.logger.DebugContext(ctx, "connection download finished") } else if !E.IsClosedOrCanceled(err) { m.logger.ErrorContext(ctx, "connection download closed: ", err) } else { m.logger.TraceContext(ctx, "connection download closed") } } } func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { _, err := bufio.CopyPacket(destination, source) if !direction { if E.IsClosedOrCanceled(err) { m.logger.TraceContext(ctx, "packet upload closed") } else { m.logger.DebugContext(ctx, "packet upload closed: ", err) } } else { if E.IsClosedOrCanceled(err) { m.logger.TraceContext(ctx, "packet download closed") } else { m.logger.DebugContext(ctx, "packet download closed: ", err) } } if !done.Swap(true) { onClose(err) } common.Close(source, destination) }