diff --git a/dns/client.go b/dns/client.go index d755134b..2f7c21d5 100644 --- a/dns/client.go +++ b/dns/client.go @@ -13,6 +13,7 @@ import ( "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" "golang.org/x/net/dns/dnsmessage" ) @@ -27,12 +28,21 @@ var ( var _ adapter.DNSClient = (*Client)(nil) type Client struct { - cache *cache.LruCache[dnsmessage.Question, dnsmessage.Message] + cache *cache.LruCache[dnsmessage.Question, *dnsmessage.Message] + disableCache bool + disableExpire bool } -func NewClient() *Client { - return &Client{ - cache: cache.New[dnsmessage.Question, dnsmessage.Message](), +func NewClient(options option.DNSClientOptions) *Client { + if options.DisableCache { + return &Client{ + disableCache: true, + } + } else { + return &Client{ + cache: cache.New[dnsmessage.Question, *dnsmessage.Message](), + disableExpire: options.DisableExpire, + } } } @@ -41,10 +51,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m return nil, E.New("empty query") } question := message.Questions[0] - cachedAnswer, cached := c.cache.Load(question) - if cached { - cachedAnswer.ID = message.ID - return &cachedAnswer, nil + if !c.disableCache { + cachedAnswer, cached := c.cache.Load(question) + if cached { + cachedAnswer.ID = message.ID + return cachedAnswer, nil + } } if !transport.Raw() { if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA { @@ -56,7 +68,9 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m if err != nil { return nil, err } - c.cache.StoreWithExpire(question, *response, calculateExpire(message)) + if !c.disableCache { + c.storeCache(question, response) + } return message, err } @@ -93,37 +107,39 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom } return sortAddresses(response4, response6, strategy), nil } - if strategy == C.DomainStrategyUseIPv4 { - response, err := c.questionCache(dnsmessage.Question{ - Name: dnsName, - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }) - if err != ErrNotCached { - return response, err - } - } else if strategy == C.DomainStrategyUseIPv6 { - response, err := c.questionCache(dnsmessage.Question{ - Name: dnsName, - Type: dnsmessage.TypeAAAA, - Class: dnsmessage.ClassINET, - }) - if err != ErrNotCached { - return response, err - } - } else { - response4, _ := c.questionCache(dnsmessage.Question{ - Name: dnsName, - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }) - response6, _ := c.questionCache(dnsmessage.Question{ - Name: dnsName, - Type: dnsmessage.TypeAAAA, - Class: dnsmessage.ClassINET, - }) - if len(response4) > 0 || len(response6) > 0 { - return sortAddresses(response4, response6, strategy), nil + if !c.disableCache { + if strategy == C.DomainStrategyUseIPv4 { + response, err := c.questionCache(dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + if err != ErrNotCached { + return response, err + } + } else if strategy == C.DomainStrategyUseIPv6 { + response, err := c.questionCache(dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }) + if err != ErrNotCached { + return response, err + } + } else { + response4, _ := c.questionCache(dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + response6, _ := c.questionCache(dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }) + if len(response4) > 0 || len(response6) > 0 { + return sortAddresses(response4, response6, strategy), nil + } } } var rCode dnsmessage.RCode @@ -135,70 +151,74 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom } else { rCode = dnsmessage.RCode(rCodeError) } + if c.disableCache { + return nil, err + } } header := dnsmessage.Header{ Response: true, Authoritative: true, RCode: rCode, } - expire := time.Now().Add(time.Second * time.Duration(DefaultTTL)) - if strategy != C.DomainStrategyUseIPv6 { - question4 := dnsmessage.Question{ - Name: dnsName, - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - } - response4 := common.Filter(response, func(addr netip.Addr) bool { - return addr.Is4() || addr.Is4In6() - }) - message4 := dnsmessage.Message{ - Header: header, - Questions: []dnsmessage.Question{question4}, - } - if len(response4) > 0 { - for _, address := range response4 { - message4.Answers = append(message4.Answers, dnsmessage.Resource{ - Header: dnsmessage.ResourceHeader{ - Name: question4.Name, - Class: question4.Class, - TTL: DefaultTTL, - }, - Body: &dnsmessage.AResource{ - A: address.As4(), - }, - }) + if !c.disableCache { + if strategy != C.DomainStrategyUseIPv6 { + question4 := dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, } - } - c.cache.StoreWithExpire(question4, message4, expire) - } - if strategy != C.DomainStrategyUseIPv4 { - question6 := dnsmessage.Question{ - Name: dnsName, - Type: dnsmessage.TypeAAAA, - Class: dnsmessage.ClassINET, - } - response6 := common.Filter(response, func(addr netip.Addr) bool { - return addr.Is6() && !addr.Is4In6() - }) - message6 := dnsmessage.Message{ - Header: header, - Questions: []dnsmessage.Question{question6}, - } - if len(response6) > 0 { - for _, address := range response6 { - message6.Answers = append(message6.Answers, dnsmessage.Resource{ - Header: dnsmessage.ResourceHeader{ - Name: question6.Name, - Class: question6.Class, - TTL: DefaultTTL, - }, - Body: &dnsmessage.AAAAResource{ - AAAA: address.As16(), - }, - }) + response4 := common.Filter(response, func(addr netip.Addr) bool { + return addr.Is4() || addr.Is4In6() + }) + message4 := &dnsmessage.Message{ + Header: header, + Questions: []dnsmessage.Question{question4}, } + if len(response4) > 0 { + for _, address := range response4 { + message4.Answers = append(message4.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: question4.Name, + Class: question4.Class, + TTL: DefaultTTL, + }, + Body: &dnsmessage.AResource{ + A: address.As4(), + }, + }) + } + } + c.storeCache(question4, message4) + } + if strategy != C.DomainStrategyUseIPv4 { + question6 := dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + } + response6 := common.Filter(response, func(addr netip.Addr) bool { + return addr.Is6() && !addr.Is4In6() + }) + message6 := &dnsmessage.Message{ + Header: header, + Questions: []dnsmessage.Question{question6}, + } + if len(response6) > 0 { + for _, address := range response6 { + message6.Answers = append(message6.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: question6.Name, + Class: question6.Class, + TTL: DefaultTTL, + }, + Body: &dnsmessage.AAAAResource{ + AAAA: address.As16(), + }, + }) + } + } + c.storeCache(question6, message6) } - c.cache.StoreWithExpire(question6, message6, expire) } return response, err } @@ -211,14 +231,19 @@ func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.Do } } -func calculateExpire(message *dnsmessage.Message) time.Time { +func (c *Client) storeCache(question dnsmessage.Question, message *dnsmessage.Message) { + if c.disableExpire { + c.cache.Store(question, message) + return + } timeToLive := DefaultTTL for _, answer := range message.Answers { if int(answer.Header.TTL) < timeToLive { timeToLive = int(answer.Header.TTL) } } - return time.Now().Add(time.Second * time.Duration(timeToLive)) + expire := time.Now().Add(time.Second * time.Duration(timeToLive)) + c.cache.StoreWithExpire(question, message, expire) } func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) { @@ -275,9 +300,11 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran Type: qType, Class: dnsmessage.ClassINET, } - cachedAddresses, err := c.questionCache(question) - if err != ErrNotCached { - return cachedAddresses, err + if !c.disableCache { + cachedAddresses, err := c.questionCache(question) + if err != ErrNotCached { + return cachedAddresses, err + } } message := dnsmessage.Message{ Header: dnsmessage.Header{ @@ -298,7 +325,7 @@ func (c *Client) questionCache(question dnsmessage.Question) ([]netip.Addr, erro if !cached { return nil, ErrNotCached } - return messageToAddresses(&response) + return messageToAddresses(response) } func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) { diff --git a/dns/transport_tcp.go b/dns/transport_tcp.go index e57d6ad7..945928b9 100644 --- a/dns/transport_tcp.go +++ b/dns/transport_tcp.go @@ -91,7 +91,7 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) { cancel() conn.err = err if err != nil { - t.logger.Warn("connection closed: ", err) + t.logger.Debug("connection closed: ", err) } } diff --git a/dns/transport_tls.go b/dns/transport_tls.go index 2a2fb17e..b28dbfd7 100644 --- a/dns/transport_tls.go +++ b/dns/transport_tls.go @@ -99,7 +99,7 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) { cancel() conn.err = err if err != nil { - t.logger.Warn("connection closed: ", err) + t.logger.Debug("connection closed: ", err) } } diff --git a/dns/transport_udp.go b/dns/transport_udp.go index 81f64142..a0358f80 100644 --- a/dns/transport_udp.go +++ b/dns/transport_udp.go @@ -87,7 +87,7 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) { cancel() conn.err = err if err != nil { - t.logger.Warn("connection closed: ", err) + t.logger.Debug("connection closed: ", err) } } diff --git a/option/dns.go b/option/dns.go index 17b463ab..f7f380e2 100644 --- a/option/dns.go +++ b/option/dns.go @@ -2,11 +2,17 @@ package option type DNSOptions struct { Servers []DNSServerOptions `json:"servers,omitempty"` + DNSClientOptions +} + +type DNSClientOptions struct { + DisableCache bool `json:"disable_cache,omitempty"` + DisableExpire bool `json:"disable_expire,omitempty"` } type DNSServerOptions struct { Tag string `json:"tag,omitempty"` Address string `json:"address"` - Detour string `json:"detour,omitempty"` AddressResolver string `json:"address_resolver,omitempty"` + DialerOptions }