package route import ( "context" "io" "net" "net/netip" "sync/atomic" "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) var _ adapter.ConnectionManager = (*ConnectionManager)(nil) type ConnectionManager struct { logger logger.ContextLogger monitor *ConnectionMonitor } func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager { return &ConnectionManager{ logger: logger, monitor: NewConnectionMonitor(), } } func (m *ConnectionManager) Start() error { return m.monitor.Start() } func (m *ConnectionManager) Close() error { return m.monitor.Close() } func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = adapter.WithContext(ctx, &metadata) var ( remoteConn net.Conn err error ) if len(metadata.DestinationAddresses) > 0 { remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) } if err != nil { N.CloseOnHandshakeFailure(conn, onClose, err) m.logger.ErrorContext(ctx, "open outbound connection: ", err) return } err = N.ReportConnHandshakeSuccess(conn, remoteConn) if err != nil { remoteConn.Close() N.CloseOnHandshakeFailure(conn, onClose, err) m.logger.ErrorContext(ctx, "report handshake success: ", err) return } var done atomic.Bool if ctx.Done() != nil { onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn)) } go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose) go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose) } func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { originSource := source var readCounters, writeCounters []N.CountFunc for { source, readCounters = N.UnwrapCountReader(source, readCounters) destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) if cachedSrc, isCached := source.(N.CachedReader); isCached { cachedBuffer := cachedSrc.ReadCached() if cachedBuffer != nil { dataLen := cachedBuffer.Len() _, err := destination.Write(cachedBuffer.Bytes()) cachedBuffer.Release() if err != nil { m.logger.ErrorContext(ctx, "connection upload payload: ", err) if done.Swap(true) { if onClose != nil { onClose(err) } } common.Close(source, destination) return } for _, counter := range readCounters { counter(int64(dataLen)) } for _, counter := range writeCounters { counter(int64(dataLen)) } } continue } break } _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters) if err != nil { common.Close(destination, source) } else if _, dstDuplex := destination.(N.WriteCloser); dstDuplex { N.CloseWrite(destination) } else { common.Close(destination) } if done.Swap(true) { if onClose != nil { onClose(err) } common.Close(source, destination) } if !direction { if err == nil { m.logger.DebugContext(ctx, "connection upload finished") } else if !E.IsClosedOrCanceled(err) { m.logger.ErrorContext(ctx, "connection upload closed: ", err) } else { m.logger.TraceContext(ctx, "connection upload closed") } } else { if err == nil { m.logger.DebugContext(ctx, "connection download finished") } else if !E.IsClosedOrCanceled(err) { m.logger.ErrorContext(ctx, "connection download closed: ", err) } else { m.logger.TraceContext(ctx, "connection download closed") } } } func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = adapter.WithContext(ctx, &metadata) var ( remotePacketConn net.PacketConn remoteConn net.Conn destinationAddress netip.Addr err error ) if metadata.UDPConnect { if len(metadata.DestinationAddresses) > 0 { if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer { remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) } } else { remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) } if err != nil { N.CloseOnHandshakeFailure(conn, onClose, err) m.logger.ErrorContext(ctx, "open outbound packet connection: ", err) return } remotePacketConn = bufio.NewUnbindPacketConn(remoteConn) connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr()) if connRemoteAddr != metadata.Destination.Addr { destinationAddress = connRemoteAddr } } else { if len(metadata.DestinationAddresses) > 0 { remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination) } if err != nil { N.CloseOnHandshakeFailure(conn, onClose, err) m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err) return } } err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn) if err != nil { conn.Close() remotePacketConn.Close() m.logger.ErrorContext(ctx, "report handshake success: ", err) return } if destinationAddress.IsValid() { var originDestination M.Socksaddr if metadata.RouteOriginalDestination.IsValid() { originDestination = metadata.RouteOriginalDestination } else { originDestination = metadata.Destination } if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) { if metadata.UDPDisableDomainUnmapping { remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) } else { remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) } } if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { natConn.UpdateDestination(destinationAddress) } } var udpTimeout time.Duration if metadata.UDPTimeout > 0 { udpTimeout = metadata.UDPTimeout } else { protocol := metadata.Protocol if protocol == "" { protocol = C.PortProtocols[metadata.Destination.Port] } if protocol != "" { udpTimeout = C.ProtocolTimeouts[protocol] } } if udpTimeout > 0 { ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout) } destination := bufio.NewPacketConn(remotePacketConn) if ctx.Done() != nil { onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn)) } var done atomic.Bool go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose) go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose) } func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { _, err := bufio.CopyPacket(destination, source) /*var readCounters, writeCounters []N.CountFunc var cachedPackets []*N.PacketBuffer originSource := source for { source, readCounters = N.UnwrapCountPacketReader(source, readCounters) destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters) if cachedReader, isCached := source.(N.CachedPacketReader); isCached { packet := cachedReader.ReadCachedPacket() if packet != nil { cachedPackets = append(cachedPackets, packet) continue } } break } var handled bool if natConn, isNatConn := source.(udpnat.Conn); isNatConn { natConn.SetHandler(&udpHijacker{ ctx: ctx, logger: m.logger, source: natConn, destination: destination, direction: direction, readCounters: readCounters, writeCounters: writeCounters, done: done, onClose: onClose, }) handled = true } if cachedPackets != nil { _, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters) if err != nil { common.Close(source, destination) m.logger.ErrorContext(ctx, "packet upload payload: ", err) return } } if handled { return } _, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/ if !direction { if E.IsClosedOrCanceled(err) { m.logger.TraceContext(ctx, "packet upload closed") } else { m.logger.DebugContext(ctx, "packet upload closed: ", err) } } else { if E.IsClosedOrCanceled(err) { m.logger.TraceContext(ctx, "packet download closed") } else { m.logger.DebugContext(ctx, "packet download closed: ", err) } } if !done.Swap(true) { if onClose != nil { onClose(err) } } common.Close(source, destination) } /*type udpHijacker struct { ctx context.Context logger logger.ContextLogger source io.Closer destination N.PacketWriter direction bool readCounters []N.CountFunc writeCounters []N.CountFunc done *atomic.Bool onClose N.CloseHandlerFunc } func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) { dataLen := buffer.Len() for _, counter := range u.readCounters { counter(int64(dataLen)) } err := u.destination.WritePacket(buffer, source) if err != nil { common.Close(u.source, u.destination) u.logger.DebugContext(u.ctx, "packet upload closed: ", err) return } for _, counter := range u.writeCounters { counter(int64(dataLen)) } } func (u *udpHijacker) Close() error { var err error if !u.done.Swap(true) { err = common.Close(u.source, u.destination) if u.onClose != nil { u.onClose(net.ErrClosed) } } if u.direction { u.logger.TraceContext(u.ctx, "packet download closed") } else { u.logger.TraceContext(u.ctx, "packet upload closed") } return err } func (u *udpHijacker) Upstream() any { return u.destination } */