From 0ac8e6e8d80d4f58a2e866e455b601c2fcae8e2b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= <i@sekai.icu>
Date: Tue, 22 Oct 2024 22:01:28 +0800
Subject: [PATCH] Implement resolve(server)

---
 adapter/inbound.go | 14 ++++---
 route/route.go     |  2 +-
 route/route_dns.go | 93 ++++++++++++++++++++++++++++------------------
 3 files changed, 65 insertions(+), 44 deletions(-)

diff --git a/adapter/inbound.go b/adapter/inbound.go
index f4d5802f..300f57e3 100644
--- a/adapter/inbound.go
+++ b/adapter/inbound.go
@@ -50,12 +50,14 @@ type InboundContext struct {
 	// Deprecated
 	InboundOptions            option.InboundOptions
 	UDPDisableDomainUnmapping bool
-	DestinationAddresses      []netip.Addr
-	SourceGeoIPCode           string
-	GeoIPCode                 string
-	ProcessInfo               *process.Info
-	QueryType                 uint16
-	FakeIP                    bool
+	DNSServer                 string
+
+	DestinationAddresses []netip.Addr
+	SourceGeoIPCode      string
+	GeoIPCode            string
+	ProcessInfo          *process.Info
+	QueryType            uint16
+	FakeIP               bool
 
 	// rule cache
 
diff --git a/route/route.go b/route/route.go
index cecd0f2a..56493bd1 100644
--- a/route/route.go
+++ b/route/route.go
@@ -584,7 +584,7 @@ func (r *Router) actionSniff(
 
 func (r *Router) actionResolve(ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionResolve) error {
 	if metadata.Destination.IsFqdn() {
-		// TODO: check if WithContext is necessary
+		metadata.DNSServer = action.Server
 		addresses, err := r.Lookup(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn, action.Strategy)
 		if err != nil {
 			return err
diff --git a/route/route_dns.go b/route/route_dns.go
index 43eb61e6..60aff6a9 100644
--- a/route/route_dns.go
+++ b/route/route_dns.go
@@ -185,41 +185,7 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS
 		cached        bool
 		err           error
 	)
-	responseAddrs, cached = r.dnsClient.LookupCache(ctx, domain, strategy)
-	if cached {
-		if len(responseAddrs) == 0 {
-			return nil, dns.RCodeNameError
-		}
-		return responseAddrs, nil
-	}
-	r.dnsLogger.DebugContext(ctx, "lookup domain ", domain)
-	ctx, metadata := adapter.ExtendContext(ctx)
-	metadata.Destination = M.Socksaddr{}
-	metadata.Domain = domain
-	var (
-		transport dns.Transport
-		options   dns.QueryOptions
-		rule      adapter.DNSRule
-		ruleIndex int
-	)
-	ruleIndex = -1
-	for {
-		dnsCtx := adapter.OverrideContext(ctx)
-		var addressLimit bool
-		transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true)
-		if strategy != dns.DomainStrategyAsIS {
-			options.Strategy = strategy
-		}
-		if rule != nil && rule.WithAddressLimit() {
-			addressLimit = true
-			responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool {
-				metadata.DestinationAddresses = responseAddrs
-				return rule.MatchAddressLimit(metadata)
-			})
-		} else {
-			addressLimit = false
-			responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options)
-		}
+	printResult := func() {
 		if err != nil {
 			if errors.Is(err, dns.ErrResponseRejectedCached) {
 				r.dnsLogger.DebugContext(ctx, "response rejected for ", domain, " (cached)")
@@ -232,10 +198,63 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS
 			r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
 			err = dns.RCodeNameError
 		}
-		if !addressLimit || err == nil {
-			break
+	}
+	responseAddrs, cached = r.dnsClient.LookupCache(ctx, domain, strategy)
+	if cached {
+		if len(responseAddrs) == 0 {
+			return nil, dns.RCodeNameError
+		}
+		return responseAddrs, nil
+	}
+	r.dnsLogger.DebugContext(ctx, "lookup domain ", domain)
+	ctx, metadata := adapter.ExtendContext(ctx)
+	metadata.Destination = M.Socksaddr{}
+	metadata.Domain = domain
+	if metadata.DNSServer != "" {
+		transport, loaded := r.transportMap[metadata.DNSServer]
+		if !loaded {
+			return nil, E.New("transport not found: ", metadata.DNSServer)
+		}
+		if strategy == dns.DomainStrategyAsIS {
+			if transportDomainStrategy, loaded := r.transportDomainStrategy[transport]; loaded {
+				strategy = transportDomainStrategy
+			} else {
+				strategy = r.defaultDomainStrategy
+			}
+		}
+		responseAddrs, err = r.dnsClient.Lookup(ctx, transport, domain, dns.QueryOptions{Strategy: strategy})
+	} else {
+		var (
+			transport dns.Transport
+			options   dns.QueryOptions
+			rule      adapter.DNSRule
+			ruleIndex int
+		)
+		ruleIndex = -1
+		for {
+			dnsCtx := adapter.OverrideContext(ctx)
+			var addressLimit bool
+			transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true)
+			if strategy != dns.DomainStrategyAsIS {
+				options.Strategy = strategy
+			}
+			if rule != nil && rule.WithAddressLimit() {
+				addressLimit = true
+				responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool {
+					metadata.DestinationAddresses = responseAddrs
+					return rule.MatchAddressLimit(metadata)
+				})
+			} else {
+				addressLimit = false
+				responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options)
+			}
+			if !addressLimit || err == nil {
+				break
+			}
+			printResult()
 		}
 	}
+	printResult()
 	if len(responseAddrs) > 0 {
 		r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " "))
 	}