Fakedns fix xUDP destination override (#1011)

* Fix UDP destination override

* Fix code style

* Fix fakedns object init

Do type convertion at runtime in case if user don't use fakedns in config.
Since dispatcher now depend on fakedns object, move the injection order of
fakedns to top (As a temporary solution)

* Amend logic for handing fakedns client

A map is used by server side when client turn on fakedns
Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
When target replies, server will restore the domain and send back to client.

Co-authored-by: hmol233 <82594500+hmol233@users.noreply.github.com>
This commit is contained in:
yuhan6665 2022-04-23 19:24:46 -04:00 committed by GitHub
parent c9df755426
commit b413066012
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 102 additions and 24 deletions

View file

@ -4,6 +4,7 @@ package dispatcher
import ( import (
"context" "context"
"fmt"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -92,13 +93,17 @@ type DefaultDispatcher struct {
router routing.Router router routing.Router
policy policy.Manager policy policy.Manager
stats stats.Manager stats stats.Manager
hosts dns.HostsLookup dns dns.Client
fdns dns.FakeDNSEngine
} }
func init() { func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
d := new(DefaultDispatcher) d := new(DefaultDispatcher)
if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
d.fdns = fdns
})
return d.Init(config.(*Config), om, router, pm, sm, dc) return d.Init(config.(*Config), om, router, pm, sm, dc)
}); err != nil { }); err != nil {
return nil, err return nil, err
@ -108,14 +113,12 @@ func init() {
} }
// Init initializes DefaultDispatcher. // Init initializes DefaultDispatcher.
func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error {
d.ohm = om d.ohm = om
d.router = router d.router = router
d.policy = pm d.policy = pm
d.stats = sm d.stats = sm
if hosts, ok := dc.(dns.HostsLookup); ok { d.dns = dns
d.hosts = hosts
}
return nil return nil
} }
@ -132,10 +135,77 @@ func (*DefaultDispatcher) Start() error {
// Close implements common.Closable. // Close implements common.Closable.
func (*DefaultDispatcher) Close() error { return nil } func (*DefaultDispatcher) Close() error { return nil }
func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) { func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) {
opt := pipe.OptionsFromContext(ctx) downOpt := pipe.OptionsFromContext(ctx)
uplinkReader, uplinkWriter := pipe.New(opt...) upOpt := downOpt
downlinkReader, downlinkWriter := pipe.New(opt...)
if network == net.Network_UDP {
var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns
// Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
// When target replies, server will restore the domain and send back to client.
// Note: this map is not global but per connection context
upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
for i, buffer := range mb {
if buffer.UDP == nil {
continue
}
addr := buffer.UDP.Address
if addr.Family().IsIP() {
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled {
domain := fkr0.GetDomainFromFakeDNS(addr)
if len(domain) > 0 {
buffer.UDP.Address = net.DomainAddress(domain)
newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
} else {
newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx))
}
}
} else {
if (ip2domain == nil) {
ip2domain = new(sync.Map)
newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx))
}
domain := addr.Domain()
ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false})
if err == nil {
for _, ip := range ips {
ip2domain.Store(ip.String(), domain)
}
newError("[fakedns client] candidate ip: " + fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
} else {
newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx))
}
}
}
return mb
}))
downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
for i, buffer := range mb {
if buffer.UDP == nil {
continue
}
addr := buffer.UDP.Address
if addr.Family().IsIP() {
if ip2domain == nil {
continue
}
if domain, found := ip2domain.Load(addr.IP().String()); found {
buffer.UDP.Address = net.DomainAddress(domain.(string))
newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
}
} else {
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok {
fakeIp := fkr0.GetFakeIPForDomain(addr.Domain())
buffer.UDP.Address = fakeIp[0]
newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
}
}
}
return mb
}))
}
uplinkReader, uplinkWriter := pipe.New(upOpt...)
downlinkReader, downlinkWriter := pipe.New(downOpt...)
inboundLink := &transport.Link{ inboundLink := &transport.Link{
Reader: downlinkReader, Reader: downlinkReader,
@ -178,17 +248,13 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *tran
return inboundLink, outboundLink return inboundLink, outboundLink
} }
func shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
domain := result.Domain() domain := result.Domain()
for _, d := range request.ExcludeForDomain { for _, d := range request.ExcludeForDomain {
if strings.ToLower(domain) == d { if strings.ToLower(domain) == d {
return false return false
} }
} }
var fakeDNSEngine dns.FakeDNSEngine
core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
fakeDNSEngine = fdns
})
protocolString := result.Protocol() protocolString := result.Protocol()
if resComp, ok := result.(SnifferResultComposite); ok { if resComp, ok := result.(SnifferResultComposite); ok {
protocolString = resComp.ProtocolForDomainResult() protocolString = resComp.ProtocolForDomainResult()
@ -197,7 +263,7 @@ func shouldOverride(ctx context.Context, result SniffResult, request session.Sni
if strings.HasPrefix(protocolString, p) { if strings.HasPrefix(protocolString, p) {
return true return true
} }
if fkr0, ok := fakeDNSEngine.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" && if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
destination.Address.Family().IsIP() && fkr0.IsIPInIPPool(destination.Address) { destination.Address.Family().IsIP() && fkr0.IsIPInIPPool(destination.Address) {
newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx)) newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx))
return true return true
@ -221,14 +287,14 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
Target: destination, Target: destination,
} }
ctx = session.ContextWithOutbound(ctx, ob) ctx = session.ContextWithOutbound(ctx, ob)
inbound, outbound := d.getLink(ctx)
content := session.ContentFromContext(ctx) content := session.ContentFromContext(ctx)
if content == nil { if content == nil {
content = new(session.Content) content = new(session.Content)
ctx = session.ContextWithContent(ctx, content) ctx = session.ContextWithContent(ctx, content)
} }
sniffingRequest := content.SniffingRequest sniffingRequest := content.SniffingRequest
inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest)
switch { switch {
case !sniffingRequest.Enabled: case !sniffingRequest.Enabled:
go d.routedDispatch(ctx, outbound, destination) go d.routedDispatch(ctx, outbound, destination)
@ -237,7 +303,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
result, err := sniffer(ctx, nil, true) result, err := sniffer(ctx, nil, true)
if err == nil { if err == nil {
content.Protocol = result.Protocol() content.Protocol = result.Protocol()
if shouldOverride(ctx, result, sniffingRequest, destination) { if d.shouldOverride(ctx, result, sniffingRequest, destination) {
domain := result.Domain() domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain) destination.Address = net.ParseAddress(domain)
@ -259,7 +325,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
if err == nil { if err == nil {
content.Protocol = result.Protocol() content.Protocol = result.Protocol()
} }
if err == nil && shouldOverride(ctx, result, sniffingRequest, destination) { if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
domain := result.Domain() domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain) destination.Address = net.ParseAddress(domain)
@ -298,7 +364,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
result, err := sniffer(ctx, nil, true) result, err := sniffer(ctx, nil, true)
if err == nil { if err == nil {
content.Protocol = result.Protocol() content.Protocol = result.Protocol()
if shouldOverride(ctx, result, sniffingRequest, destination) { if d.shouldOverride(ctx, result, sniffingRequest, destination) {
domain := result.Domain() domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain) destination.Address = net.ParseAddress(domain)
@ -320,7 +386,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
if err == nil { if err == nil {
content.Protocol = result.Protocol() content.Protocol = result.Protocol()
} }
if err == nil && shouldOverride(ctx, result, sniffingRequest, destination) { if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
domain := result.Domain() domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain) destination.Address = net.ParseAddress(domain)
@ -384,8 +450,8 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (Sni
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
ob := session.OutboundFromContext(ctx) ob := session.OutboundFromContext(ctx)
if d.hosts != nil && destination.Address.Family().IsDomain() { if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
proxied := d.hosts.LookupHosts(ob.Target.String()) proxied := hosts.LookupHosts(ob.Target.String())
if proxied != nil { if proxied != nil {
ro := ob.RouteTarget == destination ro := ob.RouteTarget == destination
destination.Address = *proxied destination.Address = *proxied

View file

@ -632,7 +632,7 @@ func (c *Config) Build() (*core.Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
config.App = append(config.App, serial.ToTypedMessage(r)) config.App = append([]*serial.TypedMessage{serial.ToTypedMessage(r)}, config.App...)
} }
if c.Observatory != nil { if c.Observatory != nil {

View file

@ -24,6 +24,7 @@ const (
type pipeOption struct { type pipeOption struct {
limit int32 // maximum buffer size in bytes limit int32 // maximum buffer size in bytes
discardOverflow bool discardOverflow bool
onTransmission func(buffer buf.MultiBuffer) buf.MultiBuffer
} }
func (o *pipeOption) isFull(curSize int32) bool { func (o *pipeOption) isFull(curSize int32) bool {
@ -137,6 +138,10 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
return nil return nil
} }
if p.option.onTransmission != nil {
mb = p.option.onTransmission(mb)
}
for { for {
err := p.writeMultiBufferInternal(mb) err := p.writeMultiBufferInternal(mb)
if err == nil { if err == nil {

View file

@ -3,6 +3,7 @@ package pipe
import ( import (
"context" "context"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
@ -25,6 +26,12 @@ func WithSizeLimit(limit int32) Option {
} }
} }
func OnTransmission(hook func(mb buf.MultiBuffer) buf.MultiBuffer) Option {
return func(option *pipeOption) {
option.onTransmission = hook
}
}
// DiscardOverflow returns an Option for Pipe to discard writes if full. // DiscardOverflow returns an Option for Pipe to discard writes if full.
func DiscardOverflow() Option { func DiscardOverflow() Option {
return func(opt *pipeOption) { return func(opt *pipeOption) {