Fix connect domain for IP outbounds

This commit is contained in:
世界 2023-09-06 19:13:39 +08:00
parent 7082cf277e
commit 1402bdab41
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 88 additions and 8 deletions

View file

@ -70,6 +70,28 @@ func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata a
return CopyEarlyConn(ctx, conn, outConn) return CopyEarlyConn(ctx, conn, outConn)
} }
func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
ctx = adapter.WithContext(ctx, &metadata)
var outConn net.Conn
var err error
if len(metadata.DestinationAddresses) > 0 {
outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses)
} else if metadata.Destination.IsFqdn() {
var destinationAddresses []netip.Addr
destinationAddresses, err = router.LookupDefault(ctx, metadata.Destination.Fqdn)
if err != nil {
return N.HandshakeFailure(conn, err)
}
outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, destinationAddresses)
} else {
outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
}
if err != nil {
return N.HandshakeFailure(conn, err)
}
return CopyEarlyConn(ctx, conn, outConn)
}
func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error { func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error {
ctx = adapter.WithContext(ctx, &metadata) ctx = adapter.WithContext(ctx, &metadata)
var outConn net.PacketConn var outConn net.PacketConn
@ -99,6 +121,42 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn,
return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn))
} }
func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error {
ctx = adapter.WithContext(ctx, &metadata)
var outConn net.PacketConn
var destinationAddress netip.Addr
var err error
if len(metadata.DestinationAddresses) > 0 {
outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
} else if metadata.Destination.IsFqdn() {
var destinationAddresses []netip.Addr
destinationAddresses, err = router.LookupDefault(ctx, metadata.Destination.Fqdn)
if err != nil {
return N.HandshakeFailure(conn, err)
}
outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses)
} else {
outConn, err = this.ListenPacket(ctx, metadata.Destination)
}
if err != nil {
return N.HandshakeFailure(conn, err)
}
if destinationAddress.IsValid() {
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
natConn.UpdateDestination(destinationAddress)
}
}
switch metadata.Protocol {
case C.ProtocolSTUN:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.STUNTimeout)
case C.ProtocolQUIC:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.QUICTimeout)
case C.ProtocolDNS:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
}
return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn))
}
func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error { func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error {
if cachedReader, isCached := conn.(N.CachedReader); isCached { if cachedReader, isCached := conn.(N.CachedReader); isCached {
payload := cachedReader.ReadCached() payload := cachedReader.ReadCached()

View file

@ -80,11 +80,11 @@ func (h *Socks) DialContext(ctx context.Context, network string, destination M.S
return nil, E.Extend(N.ErrUnknownNetwork, network) return nil, E.Extend(N.ErrUnknownNetwork, network)
} }
if h.resolve && destination.IsFqdn() { if h.resolve && destination.IsFqdn() {
addrs, err := h.router.LookupDefault(ctx, destination.Fqdn) destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return N.DialSerial(ctx, h.client, network, destination, addrs) return N.DialSerial(ctx, h.client, network, destination, destinationAddresses)
} }
return h.client.DialContext(ctx, network, destination) return h.client.DialContext(ctx, network, destination)
} }
@ -97,14 +97,25 @@ func (h *Socks) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.
h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination) h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination)
return h.uotClient.ListenPacket(ctx, destination) return h.uotClient.ListenPacket(ctx, destination)
} }
if h.resolve && destination.IsFqdn() {
destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
packetConn, _, err := N.ListenSerial(ctx, h.client, destination, destinationAddresses)
if err != nil {
return nil, err
}
return packetConn, nil
}
h.logger.InfoContext(ctx, "outbound packet connection to ", destination) h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
return h.client.ListenPacket(ctx, destination) return h.client.ListenPacket(ctx, destination)
} }
func (h *Socks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (h *Socks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return NewConnection(ctx, h, conn, metadata) return NewDirectConnection(ctx, h.router, h, conn, metadata)
} }
func (h *Socks) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (h *Socks) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return NewPacketConnection(ctx, h, conn, metadata) return NewDirectPacketConnection(ctx, h.router, h, conn, metadata)
} }

View file

@ -202,26 +202,37 @@ func (w *WireGuard) DialContext(ctx context.Context, network string, destination
w.logger.InfoContext(ctx, "outbound packet connection to ", destination) w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
} }
if destination.IsFqdn() { if destination.IsFqdn() {
addrs, err := w.router.LookupDefault(ctx, destination.Fqdn) destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return N.DialSerial(ctx, w.tunDevice, network, destination, addrs) return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses)
} }
return w.tunDevice.DialContext(ctx, network, destination) return w.tunDevice.DialContext(ctx, network, destination)
} }
func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination) w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses)
if err != nil {
return nil, err
}
return packetConn, err
}
return w.tunDevice.ListenPacket(ctx, destination) return w.tunDevice.ListenPacket(ctx, destination)
} }
func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return NewConnection(ctx, w, conn, metadata) return NewDirectConnection(ctx, w.router, w, conn, metadata)
} }
func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return NewPacketConnection(ctx, w, conn, metadata) return NewDirectPacketConnection(ctx, w.router, w, conn, metadata)
} }
func (w *WireGuard) Start() error { func (w *WireGuard) Start() error {