From cb2e15f8a7cf804f70fe21263306afea6dc83d92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 25 Oct 2023 12:00:00 +0800 Subject: [PATCH] Fix UDP domain NAT --- adapter/inbound.go | 8 ++++ common/dialer/resolve.go | 4 +- outbound/default.go | 11 ++++- outbound/direct.go | 2 +- route/router.go | 2 +- test/box_test.go | 27 ++++++++++++ test/domain_inbound_test.go | 83 +++++++++++++++++++++++++++++++++++++ test/go.mod | 6 +-- test/go.sum | 4 +- test/hysteria2_test.go | 2 +- test/wireguard_test.go | 2 +- 11 files changed, 138 insertions(+), 13 deletions(-) create mode 100644 test/domain_inbound_test.go diff --git a/adapter/inbound.go b/adapter/inbound.go index 6a566dc2..2d24083c 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -75,3 +75,11 @@ func AppendContext(ctx context.Context) (context.Context, *InboundContext) { metadata = new(InboundContext) return WithContext(ctx, metadata), metadata } + +func ExtendContext(ctx context.Context) (context.Context, *InboundContext) { + var newMetadata InboundContext + if metadata := ContextFrom(ctx); metadata != nil { + newMetadata = *metadata + } + return WithContext(ctx, &newMetadata), &newMetadata +} diff --git a/common/dialer/resolve.go b/common/dialer/resolve.go index 9e20c81d..f2ee50db 100644 --- a/common/dialer/resolve.go +++ b/common/dialer/resolve.go @@ -36,7 +36,7 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina if !destination.IsFqdn() { return d.dialer.DialContext(ctx, network, destination) } - ctx, metadata := adapter.AppendContext(ctx) + ctx, metadata := adapter.ExtendContext(ctx) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) metadata.Destination = destination metadata.Domain = "" @@ -61,7 +61,7 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd if !destination.IsFqdn() { return d.dialer.ListenPacket(ctx, destination) } - ctx, metadata := adapter.AppendContext(ctx) + ctx, metadata := adapter.ExtendContext(ctx) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) metadata.Destination = destination metadata.Domain = "" diff --git a/outbound/default.go b/outbound/default.go index 0382825f..79ed7b33 100644 --- a/outbound/default.go +++ b/outbound/default.go @@ -17,6 +17,7 @@ import ( "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) @@ -119,7 +120,10 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, return err } if destinationAddress.IsValid() { - if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { + if metadata.Destination.IsFqdn() { + outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) + } + if natConn, loaded := common.Cast[*bufio.NATPacketConn](conn); loaded { natConn.UpdateDestination(destinationAddress) } } @@ -159,7 +163,10 @@ func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this return err } if destinationAddress.IsValid() { - if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { + if metadata.Destination.IsFqdn() { + outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) + } + if natConn, loaded := common.Cast[*bufio.NATPacketConn](conn); loaded { natConn.UpdateDestination(destinationAddress) } } diff --git a/outbound/direct.go b/outbound/direct.go index ed126830..d5a835c5 100644 --- a/outbound/direct.go +++ b/outbound/direct.go @@ -164,7 +164,7 @@ func (h *Direct) DialParallel(ctx context.Context, network string, destination M } func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - ctx, metadata := adapter.AppendContext(ctx) + ctx, metadata := adapter.ExtendContext(ctx) metadata.Outbound = h.tag metadata.Destination = destination switch h.overrideOption { diff --git a/route/router.go b/route/router.go index e02e9b39..e7658c18 100644 --- a/route/router.go +++ b/route/router.go @@ -835,7 +835,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m } } if metadata.FakeIP { - conn = fakeip.NewNATPacketConn(conn, metadata.OriginDestination, metadata.Destination) + conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination) } return detour.NewPacketConnection(ctx, conn, metadata) } diff --git a/test/box_test.go b/test/box_test.go index 4092979f..5009cc63 100644 --- a/test/box_test.go +++ b/test/box_test.go @@ -2,10 +2,15 @@ package main import ( "context" + "crypto/tls" + "io" "net" + "net/http" "testing" "time" + "github.com/sagernet/quic-go" + "github.com/sagernet/quic-go/http3" "github.com/sagernet/sing-box" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" @@ -74,6 +79,28 @@ func testSuit(t *testing.T, clientPort uint16, testPort uint16) { // require.NoError(t, testPacketConnTimeout(t, dialUDP)) } +func testQUIC(t *testing.T, clientPort uint16) { + dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "") + client := &http.Client{ + Transport: &http3.RoundTripper{ + Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + destination := M.ParseSocksaddr(addr) + udpConn, err := dialer.DialContext(ctx, N.NetworkUDP, destination) + if err != nil { + return nil, err + } + return quic.DialEarly(ctx, udpConn.(net.PacketConn), destination, tlsCfg, cfg) + }, + }, + } + response, err := client.Get("https://cloudflare.com/cdn-cgi/trace") + require.NoError(t, err) + require.Equal(t, http.StatusOK, response.StatusCode) + content, err := io.ReadAll(response.Body) + require.NoError(t, err) + println(string(content)) +} + func testSuitLargeUDP(t *testing.T, clientPort uint16, testPort uint16) { dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "") dialTCP := func() (net.Conn, error) { diff --git a/test/domain_inbound_test.go b/test/domain_inbound_test.go new file mode 100644 index 00000000..bf43aa98 --- /dev/null +++ b/test/domain_inbound_test.go @@ -0,0 +1,83 @@ +package main + +import ( + "net/netip" + "testing" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + dns "github.com/sagernet/sing-dns" + + "github.com/gofrs/uuid/v5" +) + +func TestTUICDomainUDP(t *testing.T) { + _, certPem, keyPem := createSelfSignedCertificate(t, "example.org") + startInstance(t, option.Options{ + Inbounds: []option.Inbound{ + { + Type: C.TypeMixed, + Tag: "mixed-in", + MixedOptions: option.HTTPMixedInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.NewListenAddress(netip.IPv4Unspecified()), + ListenPort: clientPort, + }, + }, + }, + { + Type: C.TypeTUIC, + TUICOptions: option.TUICInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.NewListenAddress(netip.IPv4Unspecified()), + ListenPort: serverPort, + InboundOptions: option.InboundOptions{ + DomainStrategy: option.DomainStrategy(dns.DomainStrategyUseIPv6), + }, + }, + Users: []option.TUICUser{{ + UUID: uuid.Nil.String(), + }}, + TLS: &option.InboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + KeyPath: keyPem, + }, + }, + }, + }, + Outbounds: []option.Outbound{ + { + Type: C.TypeDirect, + }, + { + Type: C.TypeTUIC, + Tag: "tuic-out", + TUICOptions: option.TUICOutboundOptions{ + ServerOptions: option.ServerOptions{ + Server: "127.0.0.1", + ServerPort: serverPort, + }, + UUID: uuid.Nil.String(), + TLS: &option.OutboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + }, + }, + }, + }, + Route: &option.RouteOptions{ + Rules: []option.Rule{ + { + DefaultOptions: option.DefaultRule{ + Inbound: []string{"mixed-in"}, + Outbound: "tuic-out", + }, + }, + }, + }, + }) + testQUIC(t, clientPort) +} diff --git a/test/go.mod b/test/go.mod index b4336c69..85cda496 100644 --- a/test/go.mod +++ b/test/go.mod @@ -10,7 +10,9 @@ require ( github.com/docker/docker v24.0.6+incompatible github.com/docker/go-connections v0.4.0 github.com/gofrs/uuid/v5 v5.0.0 + github.com/sagernet/quic-go v0.0.0-20230919101909-0cc6c5dcecee github.com/sagernet/sing v0.2.15 + github.com/sagernet/sing-dns v0.1.10 github.com/sagernet/sing-quic v0.1.2 github.com/sagernet/sing-shadowsocks v0.2.5 github.com/sagernet/sing-shadowsocks2 v0.1.4 @@ -73,12 +75,10 @@ require ( github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 // indirect github.com/sagernet/gvisor v0.0.0-20230627031050-1ab0276e0dd2 // indirect github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 // indirect - github.com/sagernet/quic-go v0.0.0-20230919101909-0cc6c5dcecee // indirect github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 // indirect - github.com/sagernet/sing-dns v0.1.10 // indirect github.com/sagernet/sing-mux v0.1.3 // indirect github.com/sagernet/sing-shadowtls v0.1.4 // indirect - github.com/sagernet/sing-tun v0.1.15 // indirect + github.com/sagernet/sing-tun v0.1.16 // indirect github.com/sagernet/sing-vmess v0.1.8 // indirect github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 // indirect github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6 // indirect diff --git a/test/go.sum b/test/go.sum index 9d2a8ca1..9e371949 100644 --- a/test/go.sum +++ b/test/go.sum @@ -147,8 +147,8 @@ github.com/sagernet/sing-shadowsocks2 v0.1.4 h1:vht2M8t3m5DTgXR2j24KbYOygG5aOp+M github.com/sagernet/sing-shadowsocks2 v0.1.4/go.mod h1:Mgdee99NxxNd5Zld3ixIs18yVs4x2dI2VTDDE1N14Wc= github.com/sagernet/sing-shadowtls v0.1.4 h1:aTgBSJEgnumzFenPvc+kbD9/W0PywzWevnVpEx6Tw3k= github.com/sagernet/sing-shadowtls v0.1.4/go.mod h1:F8NBgsY5YN2beQavdgdm1DPlhaKQlaL6lpDdcBglGK4= -github.com/sagernet/sing-tun v0.1.15 h1:XfHQD/dhCCQeespPojB4gRhADI1A/4mSLLJCnh5qUnQ= -github.com/sagernet/sing-tun v0.1.15/go.mod h1:zgRoBAtOM24QXx0IKYFEnuTtXPq1Z4rDYRWkP8kJm+g= +github.com/sagernet/sing-tun v0.1.16 h1:RHXYIVg6uacvdfbYMiPEz9VX5uu6mNrvP7u9yAH3oNc= +github.com/sagernet/sing-tun v0.1.16/go.mod h1:S3q8GCjeyRniK+KLmo4XqKY0bS3x2UdKkKbqxT/Agl8= github.com/sagernet/sing-vmess v0.1.8 h1:XVWad1RpTy9b5tPxdm5MCU8cGfrTGdR8qCq6HV2aCNc= github.com/sagernet/sing-vmess v0.1.8/go.mod h1:vhx32UNzTDUkNwOyIjcZQohre1CaytquC5mPplId8uA= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as= diff --git a/test/hysteria2_test.go b/test/hysteria2_test.go index 695e598b..04735eeb 100644 --- a/test/hysteria2_test.go +++ b/test/hysteria2_test.go @@ -97,7 +97,7 @@ func testHysteria2Self(t *testing.T, salamanderPassword string) { }, }, }) - testSuit(t, clientPort, testPort) + testSuitLargeUDP(t, clientPort, testPort) } func TestHysteria2Inbound(t *testing.T) { diff --git a/test/wireguard_test.go b/test/wireguard_test.go index 1889a3a6..50e87ee0 100644 --- a/test/wireguard_test.go +++ b/test/wireguard_test.go @@ -40,7 +40,7 @@ func _TestWireGuard(t *testing.T) { Server: "127.0.0.1", ServerPort: serverPort, }, - LocalAddress: []option.ListenPrefix{option.ListenPrefix(netip.MustParsePrefix("10.0.0.2/32"))}, + LocalAddress: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")}, PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=", PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=", },