Fix buffer.UDP destination override (#2356)

This commit is contained in:
dyhkwong 2023-08-29 15:12:36 +08:00 committed by GitHub
parent e013dce1df
commit b8bd243df5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 75 additions and 94 deletions

View file

@ -4,7 +4,6 @@ package dispatcher
import ( import (
"context" "context"
"fmt"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -135,77 +134,10 @@ func (*DefaultDispatcher) Start() error {
// Close implements common.Closable. // Close implements common.Closable.
func (*DefaultDispatcher) Close() error { return nil } func (*DefaultDispatcher) Close() error { return nil }
func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) { func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
downOpt := pipe.OptionsFromContext(ctx) opt := pipe.OptionsFromContext(ctx)
upOpt := downOpt uplinkReader, uplinkWriter := pipe.New(opt...)
downlinkReader, downlinkWriter := pipe.New(opt...)
if network == net.Network_UDP {
var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns
// Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
// When target replies, server will restore the domain and send back to client.
// Note: this map is not global but per connection context
upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
for i, buffer := range mb {
if buffer.UDP == nil {
continue
}
addr := buffer.UDP.Address
if addr.Family().IsIP() {
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled {
domain := fkr0.GetDomainFromFakeDNS(addr)
if len(domain) > 0 {
buffer.UDP.Address = net.DomainAddress(domain)
newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
} else {
newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx))
}
}
} else {
if ip2domain == nil {
ip2domain = new(sync.Map)
newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx))
}
domain := addr.Domain()
ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false})
if err == nil {
for _, ip := range ips {
ip2domain.Store(ip.String(), domain)
}
newError("[fakedns client] candidate ip: "+fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
} else {
newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx))
}
}
}
return mb
}))
downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
for i, buffer := range mb {
if buffer.UDP == nil {
continue
}
addr := buffer.UDP.Address
if addr.Family().IsIP() {
if ip2domain == nil {
continue
}
if domain, found := ip2domain.Load(addr.IP().String()); found {
buffer.UDP.Address = net.DomainAddress(domain.(string))
newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
}
} else {
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok {
fakeIp := fkr0.GetFakeIPForDomain(addr.Domain())
buffer.UDP.Address = fakeIp[0]
newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
}
}
}
return mb
}))
}
uplinkReader, uplinkWriter := pipe.New(upOpt...)
downlinkReader, downlinkWriter := pipe.New(downOpt...)
inboundLink := &transport.Link{ inboundLink := &transport.Link{
Reader: downlinkReader, Reader: downlinkReader,
@ -263,7 +195,7 @@ func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResu
protocolString = resComp.ProtocolForDomainResult() protocolString = resComp.ProtocolForDomainResult()
} }
for _, p := range request.OverrideDestinationForProtocol { for _, p := range request.OverrideDestinationForProtocol {
if strings.HasPrefix(protocolString, p) { if strings.HasPrefix(protocolString, p) || strings.HasPrefix(protocolString, p) {
return true return true
} }
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" && if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
@ -287,7 +219,8 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
panic("Dispatcher: Invalid destination.") panic("Dispatcher: Invalid destination.")
} }
ob := &session.Outbound{ ob := &session.Outbound{
Target: destination, OriginalTarget: destination,
Target: destination,
} }
ctx = session.ContextWithOutbound(ctx, ob) ctx = session.ContextWithOutbound(ctx, ob)
content := session.ContentFromContext(ctx) content := session.ContentFromContext(ctx)
@ -295,9 +228,8 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
content = new(session.Content) content = new(session.Content)
ctx = session.ContextWithContent(ctx, content) ctx = session.ContextWithContent(ctx, content)
} }
sniffingRequest := content.SniffingRequest sniffingRequest := content.SniffingRequest
inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest) inbound, outbound := d.getLink(ctx)
if !sniffingRequest.Enabled { if !sniffingRequest.Enabled {
go d.routedDispatch(ctx, outbound, destination) go d.routedDispatch(ctx, outbound, destination)
} else { } else {
@ -314,7 +246,15 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
domain := result.Domain() domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain) destination.Address = net.ParseAddress(domain)
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" { protocol := result.Protocol()
if resComp, ok := result.(SnifferResultComposite); ok {
protocol = resComp.ProtocolForDomainResult()
}
isFakeIP := false
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
isFakeIP = true
}
if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
ob.RouteTarget = destination ob.RouteTarget = destination
} else { } else {
ob.Target = destination ob.Target = destination
@ -332,7 +272,8 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
return newError("Dispatcher: Invalid destination.") return newError("Dispatcher: Invalid destination.")
} }
ob := &session.Outbound{ ob := &session.Outbound{
Target: destination, OriginalTarget: destination,
Target: destination,
} }
ctx = session.ContextWithOutbound(ctx, ob) ctx = session.ContextWithOutbound(ctx, ob)
content := session.ContentFromContext(ctx) content := session.ContentFromContext(ctx)
@ -356,7 +297,15 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
domain := result.Domain() domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain) destination.Address = net.ParseAddress(domain)
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" { protocol := result.Protocol()
if resComp, ok := result.(SnifferResultComposite); ok {
protocol = resComp.ProtocolForDomainResult()
}
isFakeIP := false
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
isFakeIP = true
}
if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
ob.RouteTarget = destination ob.RouteTarget = destination
} else { } else {
ob.Target = destination ob.Target = destination

View file

@ -8,6 +8,7 @@ import (
"github.com/xtls/xray-core/app/proxyman" "github.com/xtls/xray-core/app/proxyman"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/mux" "github.com/xtls/xray-core/common/mux"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/common/net/cnc"
@ -166,6 +167,11 @@ func (h *Handler) Tag() string {
// Dispatch implements proxy.Outbound.Dispatch. // Dispatch implements proxy.Outbound.Dispatch.
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
outbound := session.OutboundFromContext(ctx)
if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
}
if h.mux != nil { if h.mux != nil {
test := func(err error) { test := func(err error) {
if err != nil { if err != nil {
@ -175,7 +181,6 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
common.Interrupt(link.Writer) common.Interrupt(link.Writer)
} }
} }
outbound := session.OutboundFromContext(ctx)
if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 { if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
switch h.udp443 { switch h.udp443 {
case "reject": case "reject":

38
common/buf/override.go Normal file
View file

@ -0,0 +1,38 @@
package buf
import (
"github.com/xtls/xray-core/common/net"
)
type EndpointOverrideReader struct {
Reader
Dest net.Address
OriginalDest net.Address
}
func (r *EndpointOverrideReader) ReadMultiBuffer() (MultiBuffer, error) {
mb, err := r.Reader.ReadMultiBuffer()
if err == nil {
for _, b := range mb {
if b.UDP != nil && b.UDP.Address == r.OriginalDest {
b.UDP.Address = r.Dest
}
}
}
return mb, err
}
type EndpointOverrideWriter struct {
Writer
Dest net.Address
OriginalDest net.Address
}
func (w *EndpointOverrideWriter) WriteMultiBuffer(mb MultiBuffer) error {
for _, b := range mb {
if b.UDP != nil && b.UDP.Address == w.Dest {
b.UDP.Address = w.OriginalDest
}
}
return w.Writer.WriteMultiBuffer(mb)
}

View file

@ -55,8 +55,9 @@ type Inbound struct {
// Outbound is the metadata of an outbound connection. // Outbound is the metadata of an outbound connection.
type Outbound struct { type Outbound struct {
// Target address of the outbound connection. // Target address of the outbound connection.
Target net.Destination OriginalTarget net.Destination
RouteTarget net.Destination Target net.Destination
RouteTarget net.Destination
// Gateway address // Gateway address
Gateway net.Address Gateway net.Address
} }

View file

@ -24,7 +24,6 @@ const (
type pipeOption struct { type pipeOption struct {
limit int32 // maximum buffer size in bytes limit int32 // maximum buffer size in bytes
discardOverflow bool discardOverflow bool
onTransmission func(buffer buf.MultiBuffer) buf.MultiBuffer
} }
func (o *pipeOption) isFull(curSize int32) bool { func (o *pipeOption) isFull(curSize int32) bool {
@ -141,10 +140,6 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
return nil return nil
} }
if p.option.onTransmission != nil {
mb = p.option.onTransmission(mb)
}
for { for {
err := p.writeMultiBufferInternal(mb) err := p.writeMultiBufferInternal(mb)
if err == nil { if err == nil {

View file

@ -3,7 +3,6 @@ package pipe
import ( import (
"context" "context"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
@ -26,12 +25,6 @@ func WithSizeLimit(limit int32) Option {
} }
} }
func OnTransmission(hook func(mb buf.MultiBuffer) buf.MultiBuffer) Option {
return func(option *pipeOption) {
option.onTransmission = hook
}
}
// DiscardOverflow returns an Option for Pipe to discard writes if full. // DiscardOverflow returns an Option for Pipe to discard writes if full.
func DiscardOverflow() Option { func DiscardOverflow() Option {
return func(opt *pipeOption) { return func(opt *pipeOption) {