Fix loopback detector

This commit is contained in:
世界 2024-04-17 21:35:26 +08:00
parent 76f20482f7
commit 7b0f5061dc
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
2 changed files with 52 additions and 27 deletions

View File

@ -148,7 +148,7 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn = h.loopBack.NewPacketConn(bufio.NewPacketConn(conn)) conn = h.loopBack.NewPacketConn(bufio.NewPacketConn(conn), destination)
if originDestination != destination { if originDestination != destination {
conn = bufio.NewNATPacketConn(bufio.NewPacketConn(conn), destination, originDestination) conn = bufio.NewNATPacketConn(bufio.NewPacketConn(conn), destination, originDestination)
} }
@ -156,14 +156,14 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
} }
func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
if h.loopBack.CheckConn(metadata.Source.AddrPort()) { if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) {
return E.New("reject loopback connection to ", metadata.Destination) return E.New("reject loopback connection to ", metadata.Destination)
} }
return NewConnection(ctx, h, conn, metadata) return NewConnection(ctx, h, conn, metadata)
} }
func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
if h.loopBack.CheckPacketConn(metadata.Source.AddrPort()) { if h.loopBack.CheckPacketConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) {
return E.New("reject loopback packet connection to ", metadata.Destination) return E.New("reject loopback packet connection to ", metadata.Destination)
} }
return NewPacketConnection(ctx, h, conn, metadata) return NewPacketConnection(ctx, h, conn, metadata)

View File

@ -10,58 +10,83 @@ import (
) )
type loopBackDetector struct { type loopBackDetector struct {
// router adapter.Router
connAccess sync.RWMutex connAccess sync.RWMutex
packetConnAccess sync.RWMutex packetConnAccess sync.RWMutex
connMap map[netip.AddrPort]bool connMap map[netip.AddrPort]netip.AddrPort
packetConnMap map[netip.AddrPort]bool packetConnMap map[uint16]uint16
} }
func newLoopBackDetector() *loopBackDetector { func newLoopBackDetector( /*router adapter.Router*/ ) *loopBackDetector {
return &loopBackDetector{ return &loopBackDetector{
connMap: make(map[netip.AddrPort]bool), // router: router,
packetConnMap: make(map[netip.AddrPort]bool), connMap: make(map[netip.AddrPort]netip.AddrPort),
packetConnMap: make(map[uint16]uint16),
} }
} }
func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn { func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn {
connAddr := M.AddrPortFromNet(conn.LocalAddr()) source := M.AddrPortFromNet(conn.LocalAddr())
if !connAddr.IsValid() { if !source.IsValid() {
return conn return conn
} }
if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn { if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn {
/*if !source.Addr().IsLoopback() {
_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
if err != nil {
return conn
}
}*/
if !N.IsPublicAddr(source.Addr()) {
return conn
}
l.packetConnAccess.Lock() l.packetConnAccess.Lock()
l.packetConnMap[connAddr] = true l.packetConnMap[source.Port()] = M.AddrPortFromNet(conn.RemoteAddr()).Port()
l.packetConnAccess.Unlock() l.packetConnAccess.Unlock()
return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connAddr: connAddr} return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connPort: source.Port()}
} else { } else {
l.connAccess.Lock() l.connAccess.Lock()
l.connMap[connAddr] = true l.connMap[source] = M.AddrPortFromNet(conn.RemoteAddr())
l.connAccess.Unlock() l.connAccess.Unlock()
return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: connAddr} return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: source}
} }
} }
func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn) N.NetPacketConn { func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn, destination M.Socksaddr) N.NetPacketConn {
connAddr := M.AddrPortFromNet(conn.LocalAddr()) source := M.AddrPortFromNet(conn.LocalAddr())
if !connAddr.IsValid() { if !source.IsValid() {
return conn return conn
} }
l.packetConnAccess.Lock() l.packetConnAccess.Lock()
l.packetConnMap[connAddr] = true l.packetConnMap[source.Port()] = destination.AddrPort().Port()
l.packetConnAccess.Unlock() l.packetConnAccess.Unlock()
return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connAddr: connAddr} return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connPort: source.Port()}
} }
func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool { func (l *loopBackDetector) CheckConn(source netip.AddrPort, local netip.AddrPort) bool {
l.connAccess.RLock() l.connAccess.RLock()
defer l.connAccess.RUnlock() defer l.connAccess.RUnlock()
return l.connMap[connAddr] destination, loaded := l.connMap[source]
return loaded && destination != local
} }
func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool { func (l *loopBackDetector) CheckPacketConn(source netip.AddrPort, local netip.AddrPort) bool {
if !source.IsValid() {
return false
}
/*if !source.Addr().IsLoopback() {
_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
if err != nil {
return false
}
}*/
if N.IsPublicAddr(source.Addr()) {
return false
}
l.packetConnAccess.RLock() l.packetConnAccess.RLock()
defer l.packetConnAccess.RUnlock() defer l.packetConnAccess.RUnlock()
return l.packetConnMap[connAddr] destinationPort, loaded := l.packetConnMap[source.Port()]
return loaded && destinationPort != local.Port()
} }
type loopBackDetectWrapper struct { type loopBackDetectWrapper struct {
@ -95,14 +120,14 @@ func (w *loopBackDetectWrapper) Upstream() any {
type loopBackDetectPacketWrapper struct { type loopBackDetectPacketWrapper struct {
N.NetPacketConn N.NetPacketConn
detector *loopBackDetector detector *loopBackDetector
connAddr netip.AddrPort connPort uint16
closeOnce sync.Once closeOnce sync.Once
} }
func (w *loopBackDetectPacketWrapper) Close() error { func (w *loopBackDetectPacketWrapper) Close() error {
w.closeOnce.Do(func() { w.closeOnce.Do(func() {
w.detector.packetConnAccess.Lock() w.detector.packetConnAccess.Lock()
delete(w.detector.packetConnMap, w.connAddr) delete(w.detector.packetConnMap, w.connPort)
w.detector.packetConnAccess.Unlock() w.detector.packetConnAccess.Unlock()
}) })
return w.NetPacketConn.Close() return w.NetPacketConn.Close()
@ -128,14 +153,14 @@ type abstractUDPConn interface {
type loopBackDetectUDPWrapper struct { type loopBackDetectUDPWrapper struct {
abstractUDPConn abstractUDPConn
detector *loopBackDetector detector *loopBackDetector
connAddr netip.AddrPort connPort uint16
closeOnce sync.Once closeOnce sync.Once
} }
func (w *loopBackDetectUDPWrapper) Close() error { func (w *loopBackDetectUDPWrapper) Close() error {
w.closeOnce.Do(func() { w.closeOnce.Do(func() {
w.detector.packetConnAccess.Lock() w.detector.packetConnAccess.Lock()
delete(w.detector.packetConnMap, w.connAddr) delete(w.detector.packetConnMap, w.connPort)
w.detector.packetConnAccess.Unlock() w.detector.packetConnAccess.Unlock()
}) })
return w.abstractUDPConn.Close() return w.abstractUDPConn.Close()