Fix loopback detector

This commit is contained in:
世界 2024-04-17 21:35:26 +08:00
parent 76f20482f7
commit 7b0f5061dc
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 52 additions and 27 deletions

View file

@ -148,7 +148,7 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
if err != nil {
return nil, err
}
conn = h.loopBack.NewPacketConn(bufio.NewPacketConn(conn))
conn = h.loopBack.NewPacketConn(bufio.NewPacketConn(conn), destination)
if originDestination != destination {
conn = bufio.NewNATPacketConn(bufio.NewPacketConn(conn), destination, originDestination)
}
@ -156,14 +156,14 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
}
func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
if h.loopBack.CheckConn(metadata.Source.AddrPort()) {
if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) {
return E.New("reject loopback connection to ", metadata.Destination)
}
return NewConnection(ctx, h, conn, metadata)
}
func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
if h.loopBack.CheckPacketConn(metadata.Source.AddrPort()) {
if h.loopBack.CheckPacketConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) {
return E.New("reject loopback packet connection to ", metadata.Destination)
}
return NewPacketConnection(ctx, h, conn, metadata)

View file

@ -10,58 +10,83 @@ import (
)
type loopBackDetector struct {
// router adapter.Router
connAccess sync.RWMutex
packetConnAccess sync.RWMutex
connMap map[netip.AddrPort]bool
packetConnMap map[netip.AddrPort]bool
connMap map[netip.AddrPort]netip.AddrPort
packetConnMap map[uint16]uint16
}
func newLoopBackDetector() *loopBackDetector {
func newLoopBackDetector( /*router adapter.Router*/ ) *loopBackDetector {
return &loopBackDetector{
connMap: make(map[netip.AddrPort]bool),
packetConnMap: make(map[netip.AddrPort]bool),
// router: router,
connMap: make(map[netip.AddrPort]netip.AddrPort),
packetConnMap: make(map[uint16]uint16),
}
}
func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn {
connAddr := M.AddrPortFromNet(conn.LocalAddr())
if !connAddr.IsValid() {
source := M.AddrPortFromNet(conn.LocalAddr())
if !source.IsValid() {
return conn
}
if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn {
/*if !source.Addr().IsLoopback() {
_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
if err != nil {
return conn
}
}*/
if !N.IsPublicAddr(source.Addr()) {
return conn
}
l.packetConnAccess.Lock()
l.packetConnMap[connAddr] = true
l.packetConnMap[source.Port()] = M.AddrPortFromNet(conn.RemoteAddr()).Port()
l.packetConnAccess.Unlock()
return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connAddr: connAddr}
return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connPort: source.Port()}
} else {
l.connAccess.Lock()
l.connMap[connAddr] = true
l.connMap[source] = M.AddrPortFromNet(conn.RemoteAddr())
l.connAccess.Unlock()
return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: connAddr}
return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: source}
}
}
func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn) N.NetPacketConn {
connAddr := M.AddrPortFromNet(conn.LocalAddr())
if !connAddr.IsValid() {
func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn, destination M.Socksaddr) N.NetPacketConn {
source := M.AddrPortFromNet(conn.LocalAddr())
if !source.IsValid() {
return conn
}
l.packetConnAccess.Lock()
l.packetConnMap[connAddr] = true
l.packetConnMap[source.Port()] = destination.AddrPort().Port()
l.packetConnAccess.Unlock()
return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connAddr: connAddr}
return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connPort: source.Port()}
}
func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool {
func (l *loopBackDetector) CheckConn(source netip.AddrPort, local netip.AddrPort) bool {
l.connAccess.RLock()
defer l.connAccess.RUnlock()
return l.connMap[connAddr]
destination, loaded := l.connMap[source]
return loaded && destination != local
}
func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool {
func (l *loopBackDetector) CheckPacketConn(source netip.AddrPort, local netip.AddrPort) bool {
if !source.IsValid() {
return false
}
/*if !source.Addr().IsLoopback() {
_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
if err != nil {
return false
}
}*/
if N.IsPublicAddr(source.Addr()) {
return false
}
l.packetConnAccess.RLock()
defer l.packetConnAccess.RUnlock()
return l.packetConnMap[connAddr]
destinationPort, loaded := l.packetConnMap[source.Port()]
return loaded && destinationPort != local.Port()
}
type loopBackDetectWrapper struct {
@ -95,14 +120,14 @@ func (w *loopBackDetectWrapper) Upstream() any {
type loopBackDetectPacketWrapper struct {
N.NetPacketConn
detector *loopBackDetector
connAddr netip.AddrPort
connPort uint16
closeOnce sync.Once
}
func (w *loopBackDetectPacketWrapper) Close() error {
w.closeOnce.Do(func() {
w.detector.packetConnAccess.Lock()
delete(w.detector.packetConnMap, w.connAddr)
delete(w.detector.packetConnMap, w.connPort)
w.detector.packetConnAccess.Unlock()
})
return w.NetPacketConn.Close()
@ -128,14 +153,14 @@ type abstractUDPConn interface {
type loopBackDetectUDPWrapper struct {
abstractUDPConn
detector *loopBackDetector
connAddr netip.AddrPort
connPort uint16
closeOnce sync.Once
}
func (w *loopBackDetectUDPWrapper) Close() error {
w.closeOnce.Do(func() {
w.detector.packetConnAccess.Lock()
delete(w.detector.packetConnMap, w.connAddr)
delete(w.detector.packetConnMap, w.connPort)
w.detector.packetConnAccess.Unlock()
})
return w.abstractUDPConn.Close()