From d30abe097bfd7bb82d2774d048b22bc8953d71d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= <i@sekai.icu>
Date: Thu, 7 Nov 2024 13:44:00 +0800
Subject: [PATCH] Implement udp connect

---
 adapter/outbound/default.go | 106 +++++++++++++++++++++++++-----------
 1 file changed, 75 insertions(+), 31 deletions(-)

diff --git a/adapter/outbound/default.go b/adapter/outbound/default.go
index 78b9bfd8..68d675ec 100644
--- a/adapter/outbound/default.go
+++ b/adapter/outbound/default.go
@@ -68,28 +68,47 @@ func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dial
 
 func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error {
 	ctx = adapter.WithContext(ctx, &metadata)
-	var outConn net.PacketConn
-	var destinationAddress netip.Addr
-	var err error
-	if len(metadata.DestinationAddresses) > 0 {
-		outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
+	var (
+		outPacketConn      net.PacketConn
+		outConn            net.Conn
+		destinationAddress netip.Addr
+		err                error
+	)
+	if metadata.UDPConnect {
+		if len(metadata.DestinationAddresses) > 0 {
+			outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
+		} else {
+			outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
+		}
+		if err != nil {
+			return N.ReportHandshakeFailure(conn, err)
+		}
+		outPacketConn = bufio.NewUnbindPacketConn(outConn)
+		connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr())
+		if connRemoteAddr != metadata.Destination.Addr {
+			destinationAddress = connRemoteAddr
+		}
 	} else {
-		outConn, err = this.ListenPacket(ctx, metadata.Destination)
+		if len(metadata.DestinationAddresses) > 0 {
+			outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
+		} else {
+			outPacketConn, err = this.ListenPacket(ctx, metadata.Destination)
+		}
+		if err != nil {
+			return N.ReportHandshakeFailure(conn, err)
+		}
 	}
+	err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn)
 	if err != nil {
-		return N.ReportHandshakeFailure(conn, err)
-	}
-	err = N.ReportPacketConnHandshakeSuccess(conn, outConn)
-	if err != nil {
-		outConn.Close()
+		outPacketConn.Close()
 		return err
 	}
 	if destinationAddress.IsValid() {
 		if metadata.Destination.IsFqdn() {
 			if metadata.UDPDisableDomainUnmapping {
-				outConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
+				outPacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
 			} else {
-				outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
+				outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
 			}
 		}
 		if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
@@ -104,37 +123,62 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn,
 	case C.ProtocolDNS:
 		ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
 	}
-	return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn))
+	return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn))
 }
 
 func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error {
 	ctx = adapter.WithContext(ctx, &metadata)
-	var outConn net.PacketConn
-	var destinationAddress netip.Addr
-	var err error
-	if len(metadata.DestinationAddresses) > 0 {
-		outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
-	} else if metadata.Destination.IsFqdn() {
-		var destinationAddresses []netip.Addr
-		destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
+	var (
+		outPacketConn      net.PacketConn
+		outConn            net.Conn
+		destinationAddress netip.Addr
+		err                error
+	)
+	if metadata.UDPConnect {
+		if len(metadata.DestinationAddresses) > 0 {
+			outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
+		} else if metadata.Destination.IsFqdn() {
+			var destinationAddresses []netip.Addr
+			destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
+			if err != nil {
+				return N.ReportHandshakeFailure(conn, err)
+			}
+			outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, destinationAddresses)
+		} else {
+			outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
+		}
 		if err != nil {
 			return N.ReportHandshakeFailure(conn, err)
 		}
-		outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses)
+		connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr())
+		if connRemoteAddr != metadata.Destination.Addr {
+			destinationAddress = connRemoteAddr
+		}
 	} else {
-		outConn, err = this.ListenPacket(ctx, metadata.Destination)
+		if len(metadata.DestinationAddresses) > 0 {
+			outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
+		} else if metadata.Destination.IsFqdn() {
+			var destinationAddresses []netip.Addr
+			destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
+			if err != nil {
+				return N.ReportHandshakeFailure(conn, err)
+			}
+			outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses)
+		} else {
+			outPacketConn, err = this.ListenPacket(ctx, metadata.Destination)
+		}
+		if err != nil {
+			return N.ReportHandshakeFailure(conn, err)
+		}
 	}
+	err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn)
 	if err != nil {
-		return N.ReportHandshakeFailure(conn, err)
-	}
-	err = N.ReportPacketConnHandshakeSuccess(conn, outConn)
-	if err != nil {
-		outConn.Close()
+		outPacketConn.Close()
 		return err
 	}
 	if destinationAddress.IsValid() {
 		if metadata.Destination.IsFqdn() {
-			outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
+			outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
 		}
 		if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
 			natConn.UpdateDestination(destinationAddress)
@@ -148,7 +192,7 @@ func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this
 	case C.ProtocolDNS:
 		ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
 	}
-	return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn))
+	return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn))
 }
 
 func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error {