diff --git a/common/dialer/default.go b/common/dialer/default.go index 53fc6dc7..654d13a3 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -53,8 +53,9 @@ var warnTFOOnUnsupportedPlatform = warning.New( ) type DefaultDialer struct { - tfo.Dialer - net.ListenConfig + dialer tfo.Dialer + udpDialer net.Dialer + udpListener net.ListenConfig bindUDPAddr string } @@ -111,24 +112,28 @@ func NewDefault(router adapter.Router, options option.DialerOptions) *DefaultDia warnTFOOnUnsupportedPlatform.Check() } var bindUDPAddr string + udpDialer := dialer bindAddress := netip.Addr(options.BindAddress) if bindAddress.IsValid() { dialer.LocalAddr = &net.TCPAddr{ IP: bindAddress.AsSlice(), } + udpDialer.LocalAddr = &net.UDPAddr{ + IP: bindAddress.AsSlice(), + } bindUDPAddr = M.SocksaddrFrom(bindAddress, 0).String() } - return &DefaultDialer{tfo.Dialer{Dialer: dialer, DisableTFO: !options.TCPFastOpen}, listener, bindUDPAddr} + return &DefaultDialer{tfo.Dialer{Dialer: dialer, DisableTFO: !options.TCPFastOpen}, udpDialer, listener, bindUDPAddr} } func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) { - return d.Dialer.DialContext(ctx, network, address.Unwrap().String()) + switch N.NetworkName(network) { + case N.NetworkUDP: + return d.udpDialer.DialContext(ctx, network, address.String()) + } + return d.dialer.DialContext(ctx, network, address.Unwrap().String()) } func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return d.ListenConfig.ListenPacket(ctx, N.NetworkUDP, d.bindUDPAddr) -} - -func (d *DefaultDialer) Upstream() any { - return &d.Dialer + return d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.bindUDPAddr) }