diff --git a/dns/client.go b/dns/client.go index 0f593ec7..4eb529e8 100644 --- a/dns/client.go +++ b/dns/client.go @@ -31,24 +31,32 @@ type Client struct { cache *cache.LruCache[dnsmessage.Question, *dnsmessage.Message] disableCache bool disableExpire bool + strategy C.DomainStrategy } func NewClient(options option.DNSClientOptions) *Client { - if options.DisableCache { - return &Client{ - disableCache: true, - } - } else { - return &Client{ - cache: cache.New[dnsmessage.Question, *dnsmessage.Message](), - disableExpire: options.DisableExpire, - } + client := &Client{ + disableCache: options.DisableCache, + disableExpire: options.DisableExpire, + strategy: C.DomainStrategy(options.Strategy), } + if !options.DisableCache { + client.cache = cache.New[dnsmessage.Question, *dnsmessage.Message]() + } + return client } func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message) (*dnsmessage.Message, error) { - if len(message.Questions) == 0 { - return nil, E.New("empty query") + if len(message.Questions) != 1 { + responseMessage := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: message.ID, + RCode: dnsmessage.RCodeFormatError, + Response: true, + RecursionDesired: true, + }, + } + return &responseMessage, nil } question := message.Questions[0] if !c.disableCache { @@ -64,6 +72,18 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m } return nil, ErrNoRawSupport } + if question.Type == dnsmessage.TypeA && c.strategy == C.DomainStrategyUseIPv6 || question.Type == dnsmessage.TypeAAAA && c.strategy == C.DomainStrategyUseIPv4 { + responseMessage := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: message.ID, + RCode: dnsmessage.RCodeNameError, + Response: true, + RecursionDesired: true, + }, + Questions: []dnsmessage.Question{question}, + } + return &responseMessage, nil + } messageId := message.ID response, err := transport.Exchange(ctx, message) if err != nil { diff --git a/option/dns.go b/option/dns.go index ba21bb2a..8909437f 100644 --- a/option/dns.go +++ b/option/dns.go @@ -9,10 +9,9 @@ import ( ) type DNSOptions struct { - Servers []DNSServerOptions `json:"servers,omitempty"` - Rules []DNSRule `json:"rules,omitempty"` - Final string `json:"final,omitempty"` - Strategy DomainStrategy `json:"strategy,omitempty"` + Servers []DNSServerOptions `json:"servers,omitempty"` + Rules []DNSRule `json:"rules,omitempty"` + Final string `json:"final,omitempty"` DNSClientOptions } @@ -20,13 +19,13 @@ func (o DNSOptions) Equals(other DNSOptions) bool { return common.ComparableSliceEquals(o.Servers, other.Servers) && common.SliceEquals(o.Rules, other.Rules) && o.Final == other.Final && - o.Strategy == other.Strategy && o.DNSClientOptions == other.DNSClientOptions } type DNSClientOptions struct { - DisableCache bool `json:"disable_cache,omitempty"` - DisableExpire bool `json:"disable_expire,omitempty"` + Strategy DomainStrategy `json:"strategy,omitempty"` + DisableCache bool `json:"disable_cache,omitempty"` + DisableExpire bool `json:"disable_expire,omitempty"` } type DNSServerOptions struct {