From b456aff4acc0b559165eba1e55fd42b13fb98006 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 22 Oct 2024 22:01:28 +0800 Subject: [PATCH] Implement resolve(server) --- adapter/inbound.go | 14 ++++--- route/route.go | 2 +- route/route_dns.go | 93 ++++++++++++++++++++++++++++------------------ 3 files changed, 65 insertions(+), 44 deletions(-) diff --git a/adapter/inbound.go b/adapter/inbound.go index f4d5802f..300f57e3 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -50,12 +50,14 @@ type InboundContext struct { // Deprecated InboundOptions option.InboundOptions UDPDisableDomainUnmapping bool - DestinationAddresses []netip.Addr - SourceGeoIPCode string - GeoIPCode string - ProcessInfo *process.Info - QueryType uint16 - FakeIP bool + DNSServer string + + DestinationAddresses []netip.Addr + SourceGeoIPCode string + GeoIPCode string + ProcessInfo *process.Info + QueryType uint16 + FakeIP bool // rule cache diff --git a/route/route.go b/route/route.go index cecd0f2a..56493bd1 100644 --- a/route/route.go +++ b/route/route.go @@ -584,7 +584,7 @@ func (r *Router) actionSniff( func (r *Router) actionResolve(ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionResolve) error { if metadata.Destination.IsFqdn() { - // TODO: check if WithContext is necessary + metadata.DNSServer = action.Server addresses, err := r.Lookup(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn, action.Strategy) if err != nil { return err diff --git a/route/route_dns.go b/route/route_dns.go index 43eb61e6..60aff6a9 100644 --- a/route/route_dns.go +++ b/route/route_dns.go @@ -185,41 +185,7 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS cached bool err error ) - responseAddrs, cached = r.dnsClient.LookupCache(ctx, domain, strategy) - if cached { - if len(responseAddrs) == 0 { - return nil, dns.RCodeNameError - } - return responseAddrs, nil - } - r.dnsLogger.DebugContext(ctx, "lookup domain ", domain) - ctx, metadata := adapter.ExtendContext(ctx) - metadata.Destination = M.Socksaddr{} - metadata.Domain = domain - var ( - transport dns.Transport - options dns.QueryOptions - rule adapter.DNSRule - ruleIndex int - ) - ruleIndex = -1 - for { - dnsCtx := adapter.OverrideContext(ctx) - var addressLimit bool - transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true) - if strategy != dns.DomainStrategyAsIS { - options.Strategy = strategy - } - if rule != nil && rule.WithAddressLimit() { - addressLimit = true - responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool { - metadata.DestinationAddresses = responseAddrs - return rule.MatchAddressLimit(metadata) - }) - } else { - addressLimit = false - responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options) - } + printResult := func() { if err != nil { if errors.Is(err, dns.ErrResponseRejectedCached) { r.dnsLogger.DebugContext(ctx, "response rejected for ", domain, " (cached)") @@ -232,10 +198,63 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result") err = dns.RCodeNameError } - if !addressLimit || err == nil { - break + } + responseAddrs, cached = r.dnsClient.LookupCache(ctx, domain, strategy) + if cached { + if len(responseAddrs) == 0 { + return nil, dns.RCodeNameError + } + return responseAddrs, nil + } + r.dnsLogger.DebugContext(ctx, "lookup domain ", domain) + ctx, metadata := adapter.ExtendContext(ctx) + metadata.Destination = M.Socksaddr{} + metadata.Domain = domain + if metadata.DNSServer != "" { + transport, loaded := r.transportMap[metadata.DNSServer] + if !loaded { + return nil, E.New("transport not found: ", metadata.DNSServer) + } + if strategy == dns.DomainStrategyAsIS { + if transportDomainStrategy, loaded := r.transportDomainStrategy[transport]; loaded { + strategy = transportDomainStrategy + } else { + strategy = r.defaultDomainStrategy + } + } + responseAddrs, err = r.dnsClient.Lookup(ctx, transport, domain, dns.QueryOptions{Strategy: strategy}) + } else { + var ( + transport dns.Transport + options dns.QueryOptions + rule adapter.DNSRule + ruleIndex int + ) + ruleIndex = -1 + for { + dnsCtx := adapter.OverrideContext(ctx) + var addressLimit bool + transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true) + if strategy != dns.DomainStrategyAsIS { + options.Strategy = strategy + } + if rule != nil && rule.WithAddressLimit() { + addressLimit = true + responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool { + metadata.DestinationAddresses = responseAddrs + return rule.MatchAddressLimit(metadata) + }) + } else { + addressLimit = false + responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options) + } + if !addressLimit || err == nil { + break + } + printResult() } } + printResult() if len(responseAddrs) > 0 { r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " ")) }