diff --git a/outbound/direct.go b/outbound/direct.go index 2d3e6f84..49ac760e 100644 --- a/outbound/direct.go +++ b/outbound/direct.go @@ -148,7 +148,7 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net if err != nil { return nil, err } - conn = h.loopBack.NewPacketConn(bufio.NewPacketConn(conn)) + conn = h.loopBack.NewPacketConn(bufio.NewPacketConn(conn), destination) if originDestination != destination { conn = bufio.NewNATPacketConn(bufio.NewPacketConn(conn), destination, originDestination) } @@ -156,14 +156,14 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net } func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - if h.loopBack.CheckConn(metadata.Source.AddrPort()) { + if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) { return E.New("reject loopback connection to ", metadata.Destination) } return NewConnection(ctx, h, conn, metadata) } func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { - if h.loopBack.CheckPacketConn(metadata.Source.AddrPort()) { + if h.loopBack.CheckPacketConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) { return E.New("reject loopback packet connection to ", metadata.Destination) } return NewPacketConnection(ctx, h, conn, metadata) diff --git a/outbound/direct_loopback_detect.go b/outbound/direct_loopback_detect.go index 41cd26bc..62cff876 100644 --- a/outbound/direct_loopback_detect.go +++ b/outbound/direct_loopback_detect.go @@ -10,58 +10,83 @@ import ( ) type loopBackDetector struct { + // router adapter.Router connAccess sync.RWMutex packetConnAccess sync.RWMutex - connMap map[netip.AddrPort]bool - packetConnMap map[netip.AddrPort]bool + connMap map[netip.AddrPort]netip.AddrPort + packetConnMap map[uint16]uint16 } -func newLoopBackDetector() *loopBackDetector { +func newLoopBackDetector( /*router adapter.Router*/ ) *loopBackDetector { return &loopBackDetector{ - connMap: make(map[netip.AddrPort]bool), - packetConnMap: make(map[netip.AddrPort]bool), + // router: router, + connMap: make(map[netip.AddrPort]netip.AddrPort), + packetConnMap: make(map[uint16]uint16), } } func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn { - connAddr := M.AddrPortFromNet(conn.LocalAddr()) - if !connAddr.IsValid() { + source := M.AddrPortFromNet(conn.LocalAddr()) + if !source.IsValid() { return conn } if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn { + /*if !source.Addr().IsLoopback() { + _, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr()) + if err != nil { + return conn + } + }*/ + if !N.IsPublicAddr(source.Addr()) { + return conn + } l.packetConnAccess.Lock() - l.packetConnMap[connAddr] = true + l.packetConnMap[source.Port()] = M.AddrPortFromNet(conn.RemoteAddr()).Port() l.packetConnAccess.Unlock() - return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connAddr: connAddr} + return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connPort: source.Port()} } else { l.connAccess.Lock() - l.connMap[connAddr] = true + l.connMap[source] = M.AddrPortFromNet(conn.RemoteAddr()) l.connAccess.Unlock() - return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: connAddr} + return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: source} } } -func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn) N.NetPacketConn { - connAddr := M.AddrPortFromNet(conn.LocalAddr()) - if !connAddr.IsValid() { +func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn, destination M.Socksaddr) N.NetPacketConn { + source := M.AddrPortFromNet(conn.LocalAddr()) + if !source.IsValid() { return conn } l.packetConnAccess.Lock() - l.packetConnMap[connAddr] = true + l.packetConnMap[source.Port()] = destination.AddrPort().Port() l.packetConnAccess.Unlock() - return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connAddr: connAddr} + return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connPort: source.Port()} } -func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool { +func (l *loopBackDetector) CheckConn(source netip.AddrPort, local netip.AddrPort) bool { l.connAccess.RLock() defer l.connAccess.RUnlock() - return l.connMap[connAddr] + destination, loaded := l.connMap[source] + return loaded && destination != local } -func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool { +func (l *loopBackDetector) CheckPacketConn(source netip.AddrPort, local netip.AddrPort) bool { + if !source.IsValid() { + return false + } + /*if !source.Addr().IsLoopback() { + _, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr()) + if err != nil { + return false + } + }*/ + if N.IsPublicAddr(source.Addr()) { + return false + } l.packetConnAccess.RLock() defer l.packetConnAccess.RUnlock() - return l.packetConnMap[connAddr] + destinationPort, loaded := l.packetConnMap[source.Port()] + return loaded && destinationPort != local.Port() } type loopBackDetectWrapper struct { @@ -95,14 +120,14 @@ func (w *loopBackDetectWrapper) Upstream() any { type loopBackDetectPacketWrapper struct { N.NetPacketConn detector *loopBackDetector - connAddr netip.AddrPort + connPort uint16 closeOnce sync.Once } func (w *loopBackDetectPacketWrapper) Close() error { w.closeOnce.Do(func() { w.detector.packetConnAccess.Lock() - delete(w.detector.packetConnMap, w.connAddr) + delete(w.detector.packetConnMap, w.connPort) w.detector.packetConnAccess.Unlock() }) return w.NetPacketConn.Close() @@ -128,14 +153,14 @@ type abstractUDPConn interface { type loopBackDetectUDPWrapper struct { abstractUDPConn detector *loopBackDetector - connAddr netip.AddrPort + connPort uint16 closeOnce sync.Once } func (w *loopBackDetectUDPWrapper) Close() error { w.closeOnce.Do(func() { w.detector.packetConnAccess.Lock() - delete(w.detector.packetConnMap, w.connAddr) + delete(w.detector.packetConnMap, w.connPort) w.detector.packetConnAccess.Unlock() }) return w.abstractUDPConn.Close()