mirror of
synced 2025-03-20 21:39:35 +00:00
* Add session context outbounds as a slice. A slice is needed for the dialer proxy, where two outbounds work on top of each other and there are, for example, two sets of target addresses. It also enables Xtls to correctly do splice copy by checking that both outbounds are ready to do direct copy * Fill outbound tag info * Splice now checks capability from all outbounds * Fix unit tests
448 lines
13 KiB
448 lines
13 KiB
package dispatcher
//go:generate go run github.com/xtls/xray-core/common/errors/errorgen
import (
	"sync"

	routing_session "github.com/xtls/xray-core/features/routing/session"
)
// errSniffingTimeout is returned by sniffer when its attempt counter exceeds
// the allowed number of tries without a conclusive content-sniffing result.
var errSniffingTimeout = newError("timeout on sniffing")
type cachedReader struct {
reader *pipe.Reader
cache buf.MultiBuffer
func (r *cachedReader) Cache(b *buf.Buffer) {
mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
if !mb.IsEmpty() {
r.cache, _ = buf.MergeMulti(r.cache, mb)
rawBytes := b.Extend(buf.Size)
n := r.cache.Copy(rawBytes)
b.Resize(0, int32(n))
func (r *cachedReader) readInternal() buf.MultiBuffer {
defer r.Unlock()
if r.cache != nil && !r.cache.IsEmpty() {
mb := r.cache
r.cache = nil
return mb
return nil
func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
mb := r.readInternal()
if mb != nil {
return mb, nil
return r.reader.ReadMultiBuffer()
func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
mb := r.readInternal()
if mb != nil {
return mb, nil
return r.reader.ReadMultiBufferTimeout(timeout)
func (r *cachedReader) Interrupt() {
if r.cache != nil {
r.cache = buf.ReleaseMulti(r.cache)
// DefaultDispatcher is a default implementation of Dispatcher.
type DefaultDispatcher struct {
ohm outbound.Manager
router routing.Router
policy policy.Manager
stats stats.Manager
dns dns.Client
fdns dns.FakeDNSEngine
func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
d := new(DefaultDispatcher)
if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
d.fdns = fdns
return d.Init(config.(*Config), om, router, pm, sm, dc)
}); err != nil {
return nil, err
return d, nil
// Init initializes DefaultDispatcher.
func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error {
d.ohm = om
d.router = router
d.policy = pm
d.stats = sm
d.dns = dns
return nil
// Type implements common.HasType.
func (*DefaultDispatcher) Type() interface{} {
return routing.DispatcherType()
// Start implements common.Runnable.
func (*DefaultDispatcher) Start() error {
return nil
// Close implements common.Closable. It does nothing and always returns nil.
func (*DefaultDispatcher) Close() error {
	return nil
}
func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
opt := pipe.OptionsFromContext(ctx)
uplinkReader, uplinkWriter := pipe.New(opt...)
downlinkReader, downlinkWriter := pipe.New(opt...)
inboundLink := &transport.Link{
Reader: downlinkReader,
Writer: uplinkWriter,
outboundLink := &transport.Link{
Reader: uplinkReader,
Writer: downlinkWriter,
sessionInbound := session.InboundFromContext(ctx)
var user *protocol.MemoryUser
if sessionInbound != nil {
user = sessionInbound.User
if user != nil && len(user.Email) > 0 {
p := d.policy.ForLevel(user.Level)
if p.Stats.UserUplink {
name := "user>>>" + user.Email + ">>>traffic>>>uplink"
if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
inboundLink.Writer = &SizeStatWriter{
Counter: c,
Writer: inboundLink.Writer,
if p.Stats.UserDownlink {
name := "user>>>" + user.Email + ">>>traffic>>>downlink"
if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
outboundLink.Writer = &SizeStatWriter{
Counter: c,
Writer: outboundLink.Writer,
return inboundLink, outboundLink
func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
domain := result.Domain()
if domain == "" {
return false
for _, d := range request.ExcludeForDomain {
if strings.ToLower(domain) == d {
return false
protocolString := result.Protocol()
if resComp, ok := result.(SnifferResultComposite); ok {
protocolString = resComp.ProtocolForDomainResult()
for _, p := range request.OverrideDestinationForProtocol {
if strings.HasPrefix(protocolString, p) || strings.HasPrefix(p, protocolString) {
return true
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
fkr0.IsIPInIPPool(destination.Address) {
newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx))
return true
if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok {
if resultSubset.IsProtoSubsetOf(p) {
return true
return false
// Dispatch implements routing.Dispatcher.
func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
if !destination.IsValid() {
panic("Dispatcher: Invalid destination.")
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
ob := outbounds[len(outbounds) - 1]
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx)
if content == nil {
content = new(session.Content)
ctx = session.ContextWithContent(ctx, content)
sniffingRequest := content.SniffingRequest
inbound, outbound := d.getLink(ctx)
if !sniffingRequest.Enabled {
go d.routedDispatch(ctx, outbound, destination)
} else {
go func() {
cReader := &cachedReader{
reader: outbound.Reader.(*pipe.Reader),
outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
if err == nil {
content.Protocol = result.Protocol()
if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain)
protocol := result.Protocol()
if resComp, ok := result.(SnifferResultComposite); ok {
protocol = resComp.ProtocolForDomainResult()
isFakeIP := false
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) {
isFakeIP = true
if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
ob.RouteTarget = destination
} else {
ob.Target = destination
d.routedDispatch(ctx, outbound, destination)
return inbound, nil
// DispatchLink implements routing.Dispatcher.
func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
if !destination.IsValid() {
return newError("Dispatcher: Invalid destination.")
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
ob := outbounds[len(outbounds) - 1]
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx)
if content == nil {
content = new(session.Content)
ctx = session.ContextWithContent(ctx, content)
sniffingRequest := content.SniffingRequest
if !sniffingRequest.Enabled {
d.routedDispatch(ctx, outbound, destination)
} else {
cReader := &cachedReader{
reader: outbound.Reader.(*pipe.Reader),
outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
if err == nil {
content.Protocol = result.Protocol()
if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain)
protocol := result.Protocol()
if resComp, ok := result.(SnifferResultComposite); ok {
protocol = resComp.ProtocolForDomainResult()
isFakeIP := false
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) {
isFakeIP = true
if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
ob.RouteTarget = destination
} else {
ob.Target = destination
d.routedDispatch(ctx, outbound, destination)
return nil
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
payload := buf.New()
defer payload.Release()
sniffer := NewSniffer(ctx)
metaresult, metadataErr := sniffer.SniffMetadata(ctx)
if metadataOnly {
return metaresult, metadataErr
contentResult, contentErr := func() (SniffResult, error) {
totalAttempt := 0
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
if totalAttempt > 2 {
return nil, errSniffingTimeout
if !payload.IsEmpty() {
result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
if err != common.ErrNoClue {
return result, err
if payload.IsFull() {
return nil, errUnknownContent
if contentErr != nil && metadataErr == nil {
return metaresult, nil
if contentErr == nil && metadataErr == nil {
return CompositeResult(metaresult, contentResult), nil
return contentResult, contentErr
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
proxied := hosts.LookupHosts(ob.Target.String())
if proxied != nil {
ro := ob.RouteTarget == destination
destination.Address = *proxied
if ro {
ob.RouteTarget = destination
} else {
ob.Target = destination
var handler outbound.Handler
routingLink := routing_session.AsRoutingContext(ctx)
inTag := routingLink.GetInboundTag()
isPickRoute := 0
if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
ctx = session.SetForcedOutboundTagToContext(ctx, "")
if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
isPickRoute = 1
newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
handler = h
} else {
newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx))
} else if d.router != nil {
if route, err := d.router.PickRoute(routingLink); err == nil {
outTag := route.GetOutboundTag()
if h := d.ohm.GetHandler(outTag); h != nil {
isPickRoute = 2
newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
handler = h
} else {
newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx))
} else {
newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx))
if handler == nil {
handler = d.ohm.GetDefaultHandler()
if handler == nil {
newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx))
ob.Tag = handler.Tag()
if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
if tag := handler.Tag(); tag != "" {
if inTag == "" {
accessMessage.Detour = tag
} else if isPickRoute == 1 {
accessMessage.Detour = inTag + " ==> " + tag
} else if isPickRoute == 2 {
accessMessage.Detour = inTag + " -> " + tag
} else {
accessMessage.Detour = inTag + " >> " + tag
handler.Dispatch(ctx, link)