Implement udp connect

This commit is contained in:
世界 2024-11-07 13:44:00 +08:00
parent 59a607e303
commit 1133cf3ef5
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -20,6 +20,7 @@ import (
) )
func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error { func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
defer conn.Close()
ctx = adapter.WithContext(ctx, &metadata) ctx = adapter.WithContext(ctx, &metadata)
var outConn net.Conn var outConn net.Conn
var err error var err error
@ -40,6 +41,7 @@ func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata a
} }
func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error { func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error {
defer conn.Close()
ctx = adapter.WithContext(ctx, &metadata) ctx = adapter.WithContext(ctx, &metadata)
var outConn net.Conn var outConn net.Conn
var err error var err error
@ -67,29 +69,49 @@ func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dial
} }
func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error { func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error {
defer conn.Close()
ctx = adapter.WithContext(ctx, &metadata) ctx = adapter.WithContext(ctx, &metadata)
var outConn net.PacketConn var (
var destinationAddress netip.Addr outPacketConn net.PacketConn
var err error outConn net.Conn
if len(metadata.DestinationAddresses) > 0 { destinationAddress netip.Addr
outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) err error
)
if metadata.UDPConnect {
if len(metadata.DestinationAddresses) > 0 {
outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
} else {
outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
}
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
outPacketConn = bufio.NewUnbindPacketConn(outConn)
connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr())
if connRemoteAddr != metadata.Destination.Addr {
destinationAddress = connRemoteAddr
}
} else { } else {
outConn, err = this.ListenPacket(ctx, metadata.Destination) if len(metadata.DestinationAddresses) > 0 {
outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
} else {
outPacketConn, err = this.ListenPacket(ctx, metadata.Destination)
}
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
} }
err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn)
if err != nil { if err != nil {
return N.ReportHandshakeFailure(conn, err) outPacketConn.Close()
}
err = N.ReportPacketConnHandshakeSuccess(conn, outConn)
if err != nil {
outConn.Close()
return err return err
} }
if destinationAddress.IsValid() { if destinationAddress.IsValid() {
if metadata.Destination.IsFqdn() { if metadata.Destination.IsFqdn() {
if metadata.UDPDisableDomainUnmapping { if metadata.UDPDisableDomainUnmapping {
outConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) outPacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
} else { } else {
outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
} }
} }
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
@ -104,37 +126,63 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn,
case C.ProtocolDNS: case C.ProtocolDNS:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout) ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
} }
return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn))
} }
func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error { func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error {
defer conn.Close()
ctx = adapter.WithContext(ctx, &metadata) ctx = adapter.WithContext(ctx, &metadata)
var outConn net.PacketConn var (
var destinationAddress netip.Addr outPacketConn net.PacketConn
var err error outConn net.Conn
if len(metadata.DestinationAddresses) > 0 { destinationAddress netip.Addr
outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) err error
} else if metadata.Destination.IsFqdn() { )
var destinationAddresses []netip.Addr if metadata.UDPConnect {
destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) if len(metadata.DestinationAddresses) > 0 {
outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
} else if metadata.Destination.IsFqdn() {
var destinationAddresses []netip.Addr
destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, destinationAddresses)
} else {
outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
}
if err != nil { if err != nil {
return N.ReportHandshakeFailure(conn, err) return N.ReportHandshakeFailure(conn, err)
} }
outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses) connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr())
if connRemoteAddr != metadata.Destination.Addr {
destinationAddress = connRemoteAddr
}
} else { } else {
outConn, err = this.ListenPacket(ctx, metadata.Destination) if len(metadata.DestinationAddresses) > 0 {
outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
} else if metadata.Destination.IsFqdn() {
var destinationAddresses []netip.Addr
destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses)
} else {
outPacketConn, err = this.ListenPacket(ctx, metadata.Destination)
}
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
} }
err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn)
if err != nil { if err != nil {
return N.ReportHandshakeFailure(conn, err) outPacketConn.Close()
}
err = N.ReportPacketConnHandshakeSuccess(conn, outConn)
if err != nil {
outConn.Close()
return err return err
} }
if destinationAddress.IsValid() { if destinationAddress.IsValid() {
if metadata.Destination.IsFqdn() { if metadata.Destination.IsFqdn() {
outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
} }
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
natConn.UpdateDestination(destinationAddress) natConn.UpdateDestination(destinationAddress)
@ -148,7 +196,7 @@ func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this
case C.ProtocolDNS: case C.ProtocolDNS:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout) ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
} }
return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn))
} }
func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error { func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error {