diff --git a/outbound/direct.go b/outbound/direct.go index 259205e4..3925d2c8 100644 --- a/outbound/direct.go +++ b/outbound/direct.go @@ -30,6 +30,7 @@ type Direct struct { fallbackDelay time.Duration overrideOption int overrideDestination M.Socksaddr + loopBack *loopBackDetector } func NewDirect(router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (*Direct, error) { @@ -50,6 +51,7 @@ func NewDirect(router adapter.Router, logger log.ContextLogger, tag string, opti domainStrategy: dns.DomainStrategy(options.DomainStrategy), fallbackDelay: time.Duration(options.FallbackDelay), dialer: outboundDialer, + loopBack: newLoopBackDetector(), } if options.ProxyProtocol != 0 { return nil, E.New("Proxy Protocol is deprecated and removed in sing-box 1.6.0") @@ -88,7 +90,11 @@ func (h *Direct) DialContext(ctx context.Context, network string, destination M. case N.NetworkUDP: h.logger.InfoContext(ctx, "outbound packet connection to ", destination) } - return h.dialer.DialContext(ctx, network, destination) + conn, err := h.dialer.DialContext(ctx, network, destination) + if err != nil { + return nil, err + } + return h.loopBack.NewConn(conn), nil } func (h *Direct) DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) { @@ -142,6 +148,7 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net if err != nil { return nil, err } + conn = h.loopBack.NewPacketConn(conn) if originDestination != destination { conn = bufio.NewNATPacketConn(bufio.NewPacketConn(conn), destination, originDestination) } @@ -149,9 +156,15 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net } func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + if h.loopBack.CheckConn(metadata.Source.AddrPort()) { + return E.New("reject loopback connection to ", metadata.Destination) + } return NewConnection(ctx, h, conn, metadata) } func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + if h.loopBack.CheckPacketConn(metadata.Source.AddrPort()) { + return E.New("reject loopback packet connection to ", metadata.Destination) + } return NewPacketConnection(ctx, h, conn, metadata) } diff --git a/outbound/direct_loopback_detect.go b/outbound/direct_loopback_detect.go new file mode 100644 index 00000000..22e15cb7 --- /dev/null +++ b/outbound/direct_loopback_detect.go @@ -0,0 +1,153 @@ +package outbound + +import ( + "net" + "net/netip" + "sync" + + M "github.com/sagernet/sing/common/metadata" +) + +type loopBackDetector struct { + connAccess sync.RWMutex + packetConnAccess sync.RWMutex + connMap map[netip.AddrPort]bool + packetConnMap map[netip.AddrPort]bool +} + +func newLoopBackDetector() *loopBackDetector { + return &loopBackDetector{ + connMap: make(map[netip.AddrPort]bool), + packetConnMap: make(map[netip.AddrPort]bool), + } +} + +func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn { + connAddr := M.AddrPortFromNet(conn.LocalAddr()) + if !connAddr.IsValid() { + return conn + } + if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn { + l.packetConnAccess.Lock() + l.packetConnMap[connAddr] = true + l.packetConnAccess.Unlock() + return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connAddr: connAddr} + } else { + l.connAccess.Lock() + l.connMap[connAddr] = true + l.connAccess.Unlock() + return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: connAddr} + } +} + +func (l *loopBackDetector) NewPacketConn(conn net.PacketConn) net.PacketConn { + connAddr := M.AddrPortFromNet(conn.LocalAddr()) + if !connAddr.IsValid() { + return conn + } + l.packetConnAccess.Lock() + l.packetConnMap[connAddr] = true + l.packetConnAccess.Unlock() + return &loopBackDetectPacketWrapper{PacketConn: conn, detector: l, connAddr: connAddr} +} + +func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool { + l.connAccess.RLock() + defer l.connAccess.RUnlock() + return l.connMap[connAddr] +} + +func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool { + l.packetConnAccess.RLock() + defer l.packetConnAccess.RUnlock() + return l.packetConnMap[connAddr] +} + +type loopBackDetectWrapper struct { + net.Conn + detector *loopBackDetector + connAddr netip.AddrPort + closeOnce sync.Once +} + +func (w *loopBackDetectWrapper) Close() error { + w.closeOnce.Do(func() { + w.detector.connAccess.Lock() + delete(w.detector.connMap, w.connAddr) + w.detector.connAccess.Unlock() + }) + return w.Conn.Close() +} + +func (w *loopBackDetectWrapper) ReaderReplaceable() bool { + return true +} + +func (w *loopBackDetectWrapper) WriterReplaceable() bool { + return true +} + +func (w *loopBackDetectWrapper) Upstream() any { + return w.Conn +} + +type loopBackDetectPacketWrapper struct { + net.PacketConn + detector *loopBackDetector + connAddr netip.AddrPort + closeOnce sync.Once +} + +func (w *loopBackDetectPacketWrapper) Close() error { + w.closeOnce.Do(func() { + w.detector.packetConnAccess.Lock() + delete(w.detector.packetConnMap, w.connAddr) + w.detector.packetConnAccess.Unlock() + }) + return w.PacketConn.Close() +} + +func (w *loopBackDetectPacketWrapper) ReaderReplaceable() bool { + return true +} + +func (w *loopBackDetectPacketWrapper) WriterReplaceable() bool { + return true +} + +func (w *loopBackDetectPacketWrapper) Upstream() any { + return w.PacketConn +} + +type abstractUDPConn interface { + net.Conn + net.PacketConn +} + +type loopBackDetectUDPWrapper struct { + abstractUDPConn + detector *loopBackDetector + connAddr netip.AddrPort + closeOnce sync.Once +} + +func (w *loopBackDetectUDPWrapper) Close() error { + w.closeOnce.Do(func() { + w.detector.packetConnAccess.Lock() + delete(w.detector.packetConnMap, w.connAddr) + w.detector.packetConnAccess.Unlock() + }) + return w.abstractUDPConn.Close() +} + +func (w *loopBackDetectUDPWrapper) ReaderReplaceable() bool { + return true +} + +func (w *loopBackDetectUDPWrapper) WriterReplaceable() bool { + return true +} + +func (w *loopBackDetectUDPWrapper) Upstream() any { + return w.abstractUDPConn +}