mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-10 02:53:11 +00:00
b413066012
* Fix UDP destination override * Fix code style * Fix fakedns object init Do type conversion at runtime in case the user doesn't use fakedns in the config. Since the dispatcher now depends on the fakedns object, move the injection order of fakedns to the top (as a temporary solution) * Amend logic for handling fakedns client A map is used by the server side when the client turns on fakedns. The client will send the domain address in buffer.UDP.Address; the server records all possible target IP addrs. When the target replies, the server will restore the domain and send it back to the client. Co-authored-by: hmol233 <82594500+hmol233@users.noreply.github.com>
526 lines
16 KiB
Go
526 lines
16 KiB
Go
package dispatcher
|
|
|
|
//go:generate go run github.com/xtls/xray-core/common/errors/errorgen
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/xtls/xray-core/common"
|
|
"github.com/xtls/xray-core/common/buf"
|
|
"github.com/xtls/xray-core/common/log"
|
|
"github.com/xtls/xray-core/common/net"
|
|
"github.com/xtls/xray-core/common/protocol"
|
|
"github.com/xtls/xray-core/common/session"
|
|
"github.com/xtls/xray-core/core"
|
|
"github.com/xtls/xray-core/features/dns"
|
|
"github.com/xtls/xray-core/features/outbound"
|
|
"github.com/xtls/xray-core/features/policy"
|
|
"github.com/xtls/xray-core/features/routing"
|
|
routing_session "github.com/xtls/xray-core/features/routing/session"
|
|
"github.com/xtls/xray-core/features/stats"
|
|
"github.com/xtls/xray-core/transport"
|
|
"github.com/xtls/xray-core/transport/pipe"
|
|
)
|
|
|
|
// errSniffingTimeout is returned by the content sniffer once it has used up
// its read attempts without reaching a verdict.
var errSniffingTimeout = newError("timeout on sniffing")
|
|
|
|
type cachedReader struct {
|
|
sync.Mutex
|
|
reader *pipe.Reader
|
|
cache buf.MultiBuffer
|
|
}
|
|
|
|
func (r *cachedReader) Cache(b *buf.Buffer) {
|
|
mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
|
|
r.Lock()
|
|
if !mb.IsEmpty() {
|
|
r.cache, _ = buf.MergeMulti(r.cache, mb)
|
|
}
|
|
b.Clear()
|
|
rawBytes := b.Extend(buf.Size)
|
|
n := r.cache.Copy(rawBytes)
|
|
b.Resize(0, int32(n))
|
|
r.Unlock()
|
|
}
|
|
|
|
func (r *cachedReader) readInternal() buf.MultiBuffer {
|
|
r.Lock()
|
|
defer r.Unlock()
|
|
|
|
if r.cache != nil && !r.cache.IsEmpty() {
|
|
mb := r.cache
|
|
r.cache = nil
|
|
return mb
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
|
mb := r.readInternal()
|
|
if mb != nil {
|
|
return mb, nil
|
|
}
|
|
|
|
return r.reader.ReadMultiBuffer()
|
|
}
|
|
|
|
func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
|
|
mb := r.readInternal()
|
|
if mb != nil {
|
|
return mb, nil
|
|
}
|
|
|
|
return r.reader.ReadMultiBufferTimeout(timeout)
|
|
}
|
|
|
|
func (r *cachedReader) Interrupt() {
|
|
r.Lock()
|
|
if r.cache != nil {
|
|
r.cache = buf.ReleaseMulti(r.cache)
|
|
}
|
|
r.Unlock()
|
|
r.reader.Interrupt()
|
|
}
|
|
|
|
// DefaultDispatcher is a default implementation of Dispatcher.
|
|
type DefaultDispatcher struct {
|
|
ohm outbound.Manager
|
|
router routing.Router
|
|
policy policy.Manager
|
|
stats stats.Manager
|
|
dns dns.Client
|
|
fdns dns.FakeDNSEngine
|
|
}
|
|
|
|
func init() {
|
|
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
|
d := new(DefaultDispatcher)
|
|
if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
|
|
core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
|
|
d.fdns = fdns
|
|
})
|
|
return d.Init(config.(*Config), om, router, pm, sm, dc)
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
return d, nil
|
|
}))
|
|
}
|
|
|
|
// Init initializes DefaultDispatcher.
|
|
func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error {
|
|
d.ohm = om
|
|
d.router = router
|
|
d.policy = pm
|
|
d.stats = sm
|
|
d.dns = dns
|
|
return nil
|
|
}
|
|
|
|
// Type implements common.HasType.
|
|
func (*DefaultDispatcher) Type() interface{} {
|
|
return routing.DispatcherType()
|
|
}
|
|
|
|
// Start implements common.Runnable.
|
|
func (*DefaultDispatcher) Start() error {
|
|
return nil
|
|
}
|
|
|
|
// Close implements common.Closable.
|
|
func (*DefaultDispatcher) Close() error { return nil }
|
|
|
|
func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) {
|
|
downOpt := pipe.OptionsFromContext(ctx)
|
|
upOpt := downOpt
|
|
|
|
if network == net.Network_UDP {
|
|
var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns
|
|
// Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
|
|
// When target replies, server will restore the domain and send back to client.
|
|
// Note: this map is not global but per connection context
|
|
upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
|
|
for i, buffer := range mb {
|
|
if buffer.UDP == nil {
|
|
continue
|
|
}
|
|
addr := buffer.UDP.Address
|
|
if addr.Family().IsIP() {
|
|
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled {
|
|
domain := fkr0.GetDomainFromFakeDNS(addr)
|
|
if len(domain) > 0 {
|
|
buffer.UDP.Address = net.DomainAddress(domain)
|
|
newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
|
|
} else {
|
|
newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx))
|
|
}
|
|
}
|
|
} else {
|
|
if (ip2domain == nil) {
|
|
ip2domain = new(sync.Map)
|
|
newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx))
|
|
}
|
|
domain := addr.Domain()
|
|
ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false})
|
|
if err == nil {
|
|
for _, ip := range ips {
|
|
ip2domain.Store(ip.String(), domain)
|
|
}
|
|
newError("[fakedns client] candidate ip: " + fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
|
|
} else {
|
|
newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx))
|
|
}
|
|
}
|
|
}
|
|
return mb
|
|
}))
|
|
downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
|
|
for i, buffer := range mb {
|
|
if buffer.UDP == nil {
|
|
continue
|
|
}
|
|
addr := buffer.UDP.Address
|
|
if addr.Family().IsIP() {
|
|
if ip2domain == nil {
|
|
continue
|
|
}
|
|
if domain, found := ip2domain.Load(addr.IP().String()); found {
|
|
buffer.UDP.Address = net.DomainAddress(domain.(string))
|
|
newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
|
|
}
|
|
} else {
|
|
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok {
|
|
fakeIp := fkr0.GetFakeIPForDomain(addr.Domain())
|
|
buffer.UDP.Address = fakeIp[0]
|
|
newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
|
|
}
|
|
}
|
|
}
|
|
return mb
|
|
}))
|
|
}
|
|
uplinkReader, uplinkWriter := pipe.New(upOpt...)
|
|
downlinkReader, downlinkWriter := pipe.New(downOpt...)
|
|
|
|
inboundLink := &transport.Link{
|
|
Reader: downlinkReader,
|
|
Writer: uplinkWriter,
|
|
}
|
|
|
|
outboundLink := &transport.Link{
|
|
Reader: uplinkReader,
|
|
Writer: downlinkWriter,
|
|
}
|
|
|
|
sessionInbound := session.InboundFromContext(ctx)
|
|
var user *protocol.MemoryUser
|
|
if sessionInbound != nil {
|
|
user = sessionInbound.User
|
|
}
|
|
|
|
if user != nil && len(user.Email) > 0 {
|
|
p := d.policy.ForLevel(user.Level)
|
|
if p.Stats.UserUplink {
|
|
name := "user>>>" + user.Email + ">>>traffic>>>uplink"
|
|
if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
|
|
inboundLink.Writer = &SizeStatWriter{
|
|
Counter: c,
|
|
Writer: inboundLink.Writer,
|
|
}
|
|
}
|
|
}
|
|
if p.Stats.UserDownlink {
|
|
name := "user>>>" + user.Email + ">>>traffic>>>downlink"
|
|
if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
|
|
outboundLink.Writer = &SizeStatWriter{
|
|
Counter: c,
|
|
Writer: outboundLink.Writer,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return inboundLink, outboundLink
|
|
}
|
|
|
|
func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
|
|
domain := result.Domain()
|
|
for _, d := range request.ExcludeForDomain {
|
|
if strings.ToLower(domain) == d {
|
|
return false
|
|
}
|
|
}
|
|
protocolString := result.Protocol()
|
|
if resComp, ok := result.(SnifferResultComposite); ok {
|
|
protocolString = resComp.ProtocolForDomainResult()
|
|
}
|
|
for _, p := range request.OverrideDestinationForProtocol {
|
|
if strings.HasPrefix(protocolString, p) {
|
|
return true
|
|
}
|
|
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
|
|
destination.Address.Family().IsIP() && fkr0.IsIPInIPPool(destination.Address) {
|
|
newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx))
|
|
return true
|
|
}
|
|
if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok {
|
|
if resultSubset.IsProtoSubsetOf(p) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// Dispatch implements routing.Dispatcher.
|
|
func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
|
|
if !destination.IsValid() {
|
|
panic("Dispatcher: Invalid destination.")
|
|
}
|
|
ob := &session.Outbound{
|
|
Target: destination,
|
|
}
|
|
ctx = session.ContextWithOutbound(ctx, ob)
|
|
content := session.ContentFromContext(ctx)
|
|
if content == nil {
|
|
content = new(session.Content)
|
|
ctx = session.ContextWithContent(ctx, content)
|
|
}
|
|
|
|
sniffingRequest := content.SniffingRequest
|
|
inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest)
|
|
switch {
|
|
case !sniffingRequest.Enabled:
|
|
go d.routedDispatch(ctx, outbound, destination)
|
|
case destination.Network != net.Network_TCP:
|
|
// Only metadata sniff will be used for non tcp connection
|
|
result, err := sniffer(ctx, nil, true)
|
|
if err == nil {
|
|
content.Protocol = result.Protocol()
|
|
if d.shouldOverride(ctx, result, sniffingRequest, destination) {
|
|
domain := result.Domain()
|
|
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
|
|
destination.Address = net.ParseAddress(domain)
|
|
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
|
|
ob.RouteTarget = destination
|
|
} else {
|
|
ob.Target = destination
|
|
}
|
|
}
|
|
}
|
|
go d.routedDispatch(ctx, outbound, destination)
|
|
default:
|
|
go func() {
|
|
cReader := &cachedReader{
|
|
reader: outbound.Reader.(*pipe.Reader),
|
|
}
|
|
outbound.Reader = cReader
|
|
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly)
|
|
if err == nil {
|
|
content.Protocol = result.Protocol()
|
|
}
|
|
if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
|
|
domain := result.Domain()
|
|
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
|
|
destination.Address = net.ParseAddress(domain)
|
|
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
|
|
ob.RouteTarget = destination
|
|
} else {
|
|
ob.Target = destination
|
|
}
|
|
}
|
|
d.routedDispatch(ctx, outbound, destination)
|
|
}()
|
|
}
|
|
return inbound, nil
|
|
}
|
|
|
|
// DispatchLink implements routing.Dispatcher.
|
|
func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
|
|
if !destination.IsValid() {
|
|
return newError("Dispatcher: Invalid destination.")
|
|
}
|
|
ob := &session.Outbound{
|
|
Target: destination,
|
|
}
|
|
ctx = session.ContextWithOutbound(ctx, ob)
|
|
content := session.ContentFromContext(ctx)
|
|
if content == nil {
|
|
content = new(session.Content)
|
|
ctx = session.ContextWithContent(ctx, content)
|
|
}
|
|
sniffingRequest := content.SniffingRequest
|
|
switch {
|
|
case !sniffingRequest.Enabled:
|
|
go d.routedDispatch(ctx, outbound, destination)
|
|
case destination.Network != net.Network_TCP:
|
|
// Only metadata sniff will be used for non tcp connection
|
|
result, err := sniffer(ctx, nil, true)
|
|
if err == nil {
|
|
content.Protocol = result.Protocol()
|
|
if d.shouldOverride(ctx, result, sniffingRequest, destination) {
|
|
domain := result.Domain()
|
|
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
|
|
destination.Address = net.ParseAddress(domain)
|
|
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
|
|
ob.RouteTarget = destination
|
|
} else {
|
|
ob.Target = destination
|
|
}
|
|
}
|
|
}
|
|
go d.routedDispatch(ctx, outbound, destination)
|
|
default:
|
|
go func() {
|
|
cReader := &cachedReader{
|
|
reader: outbound.Reader.(*pipe.Reader),
|
|
}
|
|
outbound.Reader = cReader
|
|
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly)
|
|
if err == nil {
|
|
content.Protocol = result.Protocol()
|
|
}
|
|
if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
|
|
domain := result.Domain()
|
|
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
|
|
destination.Address = net.ParseAddress(domain)
|
|
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
|
|
ob.RouteTarget = destination
|
|
} else {
|
|
ob.Target = destination
|
|
}
|
|
}
|
|
d.routedDispatch(ctx, outbound, destination)
|
|
}()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (SniffResult, error) {
|
|
payload := buf.New()
|
|
defer payload.Release()
|
|
|
|
sniffer := NewSniffer(ctx)
|
|
|
|
metaresult, metadataErr := sniffer.SniffMetadata(ctx)
|
|
|
|
if metadataOnly {
|
|
return metaresult, metadataErr
|
|
}
|
|
|
|
contentResult, contentErr := func() (SniffResult, error) {
|
|
totalAttempt := 0
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
default:
|
|
totalAttempt++
|
|
if totalAttempt > 2 {
|
|
return nil, errSniffingTimeout
|
|
}
|
|
|
|
cReader.Cache(payload)
|
|
if !payload.IsEmpty() {
|
|
result, err := sniffer.Sniff(ctx, payload.Bytes())
|
|
if err != common.ErrNoClue {
|
|
return result, err
|
|
}
|
|
}
|
|
if payload.IsFull() {
|
|
return nil, errUnknownContent
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
if contentErr != nil && metadataErr == nil {
|
|
return metaresult, nil
|
|
}
|
|
if contentErr == nil && metadataErr == nil {
|
|
return CompositeResult(metaresult, contentResult), nil
|
|
}
|
|
return contentResult, contentErr
|
|
}
|
|
|
|
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
|
|
ob := session.OutboundFromContext(ctx)
|
|
if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
|
|
proxied := hosts.LookupHosts(ob.Target.String())
|
|
if proxied != nil {
|
|
ro := ob.RouteTarget == destination
|
|
destination.Address = *proxied
|
|
if ro {
|
|
ob.RouteTarget = destination
|
|
} else {
|
|
ob.Target = destination
|
|
}
|
|
}
|
|
}
|
|
|
|
var handler outbound.Handler
|
|
|
|
routingLink := routing_session.AsRoutingContext(ctx)
|
|
inTag := routingLink.GetInboundTag()
|
|
isPickRoute := 0
|
|
if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
|
|
ctx = session.SetForcedOutboundTagToContext(ctx, "")
|
|
if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
|
|
isPickRoute = 1
|
|
newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
|
|
handler = h
|
|
} else {
|
|
newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx))
|
|
common.Close(link.Writer)
|
|
common.Interrupt(link.Reader)
|
|
return
|
|
}
|
|
} else if d.router != nil {
|
|
if route, err := d.router.PickRoute(routingLink); err == nil {
|
|
outTag := route.GetOutboundTag()
|
|
if h := d.ohm.GetHandler(outTag); h != nil {
|
|
isPickRoute = 2
|
|
newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
|
|
handler = h
|
|
} else {
|
|
newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx))
|
|
}
|
|
} else {
|
|
newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx))
|
|
}
|
|
}
|
|
|
|
if handler == nil {
|
|
handler = d.ohm.GetDefaultHandler()
|
|
}
|
|
|
|
if handler == nil {
|
|
newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx))
|
|
common.Close(link.Writer)
|
|
common.Interrupt(link.Reader)
|
|
return
|
|
}
|
|
|
|
if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
|
|
if tag := handler.Tag(); tag != "" {
|
|
if inTag == "" {
|
|
accessMessage.Detour = tag
|
|
} else if isPickRoute == 1 {
|
|
accessMessage.Detour = inTag + " ==> " + tag
|
|
} else if isPickRoute == 2 {
|
|
accessMessage.Detour = inTag + " -> " + tag
|
|
} else {
|
|
accessMessage.Detour = inTag + " >> " + tag
|
|
}
|
|
}
|
|
log.Record(accessMessage)
|
|
}
|
|
|
|
handler.Dispatch(ctx, link)
|
|
}
|