Override destination if replaced in hosts

This commit is contained in:
世界 2021-09-28 14:41:31 +08:00
parent 50e576081e
commit 27224868ab
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 44 additions and 6 deletions

View file

@ -4,6 +4,7 @@ package dispatcher
import ( import (
"context" "context"
"github.com/xtls/xray-core/features/dns"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -15,7 +16,6 @@ import (
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/core" "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
@ -92,13 +92,14 @@ type DefaultDispatcher struct {
router routing.Router router routing.Router
policy policy.Manager policy policy.Manager
stats stats.Manager stats stats.Manager
hosts dns.HostsLookup
} }
func init() { func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
d := new(DefaultDispatcher) d := new(DefaultDispatcher)
if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error { if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
return d.Init(config.(*Config), om, router, pm, sm) return d.Init(config.(*Config), om, router, pm, sm, dc)
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -107,11 +108,14 @@ func init() {
} }
// Init initializes DefaultDispatcher. // Init initializes DefaultDispatcher.
func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error { func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
d.ohm = om d.ohm = om
d.router = router d.router = router
d.policy = pm d.policy = pm
d.stats = sm d.stats = sm
if hosts, ok := dc.(dns.HostsLookup); ok {
d.hosts = hosts
}
return nil return nil
} }
@ -294,7 +298,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
result, err := sniffer(ctx, nil, true) result, err := sniffer(ctx, nil, true)
if err == nil { if err == nil {
content.Protocol = result.Protocol() content.Protocol = result.Protocol()
if shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) { if shouldOverride(ctx, result, sniffingRequest, destination) {
domain := result.Domain() domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain) destination.Address = net.ParseAddress(domain)
@ -316,7 +320,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
if err == nil { if err == nil {
content.Protocol = result.Protocol() content.Protocol = result.Protocol()
} }
if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) { if err == nil && shouldOverride(ctx, result, sniffingRequest, destination) {
domain := result.Domain() domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain) destination.Address = net.ParseAddress(domain)
@ -379,6 +383,20 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (Sni
} }
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
ob := session.OutboundFromContext(ctx)
if d.hosts != nil && destination.Address.Family().IsDomain() {
proxied := d.hosts.LookupHosts(ob.Target.String())
if proxied != nil {
ro := ob.RouteTarget == destination
destination.Address = *proxied
if ro {
ob.RouteTarget = destination
} else {
ob.Target = destination
}
}
}
var handler outbound.Handler var handler outbound.Handler
if d.router != nil { if d.router != nil {

View file

@ -223,6 +223,22 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, error) {
return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...)) return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...))
} }
// LookupHosts implements dns.HostsLookup.
func (s *DNS) LookupHosts(domain string) *net.Address {
domain = strings.TrimSuffix(domain, ".")
if domain == "" {
return nil
}
// Normalize the FQDN form query
addrs := s.hosts.Lookup(domain, *s.ipOption)
if len(addrs) > 0 {
newError("domain replaced: ", domain, " -> ", addrs[0].String()).AtInfo().WriteToLog()
return &addrs[0]
}
return nil
}
// GetIPOption implements ClientWithIPOption. // GetIPOption implements ClientWithIPOption.
func (s *DNS) GetIPOption() *dns.IPOption { func (s *DNS) GetIPOption() *dns.IPOption {
return s.ipOption return s.ipOption

View file

@ -24,6 +24,10 @@ type Client interface {
LookupIP(domain string, option IPOption) ([]net.IP, error) LookupIP(domain string, option IPOption) ([]net.IP, error)
} }
type HostsLookup interface {
LookupHosts(domain string) *net.Address
}
// ClientType returns the type of Client interface. Can be used for implementing common.HasType. // ClientType returns the type of Client interface. Can be used for implementing common.HasType.
// //
// xray:api:beta // xray:api:beta