diff --git a/outbound/default.go b/outbound/default.go index 46967f98..d7319976 100644 --- a/outbound/default.go +++ b/outbound/default.go @@ -68,6 +68,10 @@ func NewEarlyConnection(ctx context.Context, this N.Dialer, conn net.Conn, metad } func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error { + switch metadata.Protocol { + case C.ProtocolQUIC, C.ProtocolDNS: + return connectPacketConnection(ctx, this, conn, metadata) + } ctx = adapter.WithContext(ctx, &metadata) var outConn net.PacketConn var err error @@ -92,6 +96,31 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) } +func connectPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error { + ctx = adapter.WithContext(ctx, &metadata) + var outConn net.Conn + var err error + if len(metadata.DestinationAddresses) > 0 { + outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) + } else { + outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) + } + if err != nil { + return N.HandshakeFailure(conn, err) + } + if metadata.Protocol != "" { + switch metadata.Protocol { + case C.ProtocolQUIC: + ctx, conn = canceler.NewPacketConn(ctx, conn, C.QUICTimeout) + case C.ProtocolDNS: + ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout) + case C.ProtocolSTUN: + ctx, conn = canceler.NewPacketConn(ctx, conn, C.STUNTimeout) + } + } + return bufio.CopyPacketConn(ctx, conn, bufio.NewUnbindPacketConn(outConn)) +} + func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error { if cachedReader, isCached := serverConn.(N.CachedReader); isCached { payload := cachedReader.ReadCached()