Add disableCache/disableExpire option for dns client

This commit is contained in:
世界 2022-07-06 23:39:17 +08:00
parent 8a761d7e3b
commit ecac383477
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 138 additions and 105 deletions

View file

@ -13,6 +13,7 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
) )
@ -27,12 +28,21 @@ var (
var _ adapter.DNSClient = (*Client)(nil) var _ adapter.DNSClient = (*Client)(nil)
type Client struct { type Client struct {
cache *cache.LruCache[dnsmessage.Question, dnsmessage.Message] cache *cache.LruCache[dnsmessage.Question, *dnsmessage.Message]
disableCache bool
disableExpire bool
} }
func NewClient() *Client { func NewClient(options option.DNSClientOptions) *Client {
return &Client{ if options.DisableCache {
cache: cache.New[dnsmessage.Question, dnsmessage.Message](), return &Client{
disableCache: true,
}
} else {
return &Client{
cache: cache.New[dnsmessage.Question, *dnsmessage.Message](),
disableExpire: options.DisableExpire,
}
} }
} }
@ -41,10 +51,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
return nil, E.New("empty query") return nil, E.New("empty query")
} }
question := message.Questions[0] question := message.Questions[0]
cachedAnswer, cached := c.cache.Load(question) if !c.disableCache {
if cached { cachedAnswer, cached := c.cache.Load(question)
cachedAnswer.ID = message.ID if cached {
return &cachedAnswer, nil cachedAnswer.ID = message.ID
return cachedAnswer, nil
}
} }
if !transport.Raw() { if !transport.Raw() {
if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA { if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA {
@ -56,7 +68,9 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.cache.StoreWithExpire(question, *response, calculateExpire(message)) if !c.disableCache {
c.storeCache(question, response)
}
return message, err return message, err
} }
@ -93,37 +107,39 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
} }
return sortAddresses(response4, response6, strategy), nil return sortAddresses(response4, response6, strategy), nil
} }
if strategy == C.DomainStrategyUseIPv4 { if !c.disableCache {
response, err := c.questionCache(dnsmessage.Question{ if strategy == C.DomainStrategyUseIPv4 {
Name: dnsName, response, err := c.questionCache(dnsmessage.Question{
Type: dnsmessage.TypeA, Name: dnsName,
Class: dnsmessage.ClassINET, Type: dnsmessage.TypeA,
}) Class: dnsmessage.ClassINET,
if err != ErrNotCached { })
return response, err if err != ErrNotCached {
} return response, err
} else if strategy == C.DomainStrategyUseIPv6 { }
response, err := c.questionCache(dnsmessage.Question{ } else if strategy == C.DomainStrategyUseIPv6 {
Name: dnsName, response, err := c.questionCache(dnsmessage.Question{
Type: dnsmessage.TypeAAAA, Name: dnsName,
Class: dnsmessage.ClassINET, Type: dnsmessage.TypeAAAA,
}) Class: dnsmessage.ClassINET,
if err != ErrNotCached { })
return response, err if err != ErrNotCached {
} return response, err
} else { }
response4, _ := c.questionCache(dnsmessage.Question{ } else {
Name: dnsName, response4, _ := c.questionCache(dnsmessage.Question{
Type: dnsmessage.TypeA, Name: dnsName,
Class: dnsmessage.ClassINET, Type: dnsmessage.TypeA,
}) Class: dnsmessage.ClassINET,
response6, _ := c.questionCache(dnsmessage.Question{ })
Name: dnsName, response6, _ := c.questionCache(dnsmessage.Question{
Type: dnsmessage.TypeAAAA, Name: dnsName,
Class: dnsmessage.ClassINET, Type: dnsmessage.TypeAAAA,
}) Class: dnsmessage.ClassINET,
if len(response4) > 0 || len(response6) > 0 { })
return sortAddresses(response4, response6, strategy), nil if len(response4) > 0 || len(response6) > 0 {
return sortAddresses(response4, response6, strategy), nil
}
} }
} }
var rCode dnsmessage.RCode var rCode dnsmessage.RCode
@ -135,70 +151,74 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
} else { } else {
rCode = dnsmessage.RCode(rCodeError) rCode = dnsmessage.RCode(rCodeError)
} }
if c.disableCache {
return nil, err
}
} }
header := dnsmessage.Header{ header := dnsmessage.Header{
Response: true, Response: true,
Authoritative: true, Authoritative: true,
RCode: rCode, RCode: rCode,
} }
expire := time.Now().Add(time.Second * time.Duration(DefaultTTL)) if !c.disableCache {
if strategy != C.DomainStrategyUseIPv6 { if strategy != C.DomainStrategyUseIPv6 {
question4 := dnsmessage.Question{ question4 := dnsmessage.Question{
Name: dnsName, Name: dnsName,
Type: dnsmessage.TypeA, Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
}
response4 := common.Filter(response, func(addr netip.Addr) bool {
return addr.Is4() || addr.Is4In6()
})
message4 := dnsmessage.Message{
Header: header,
Questions: []dnsmessage.Question{question4},
}
if len(response4) > 0 {
for _, address := range response4 {
message4.Answers = append(message4.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: question4.Name,
Class: question4.Class,
TTL: DefaultTTL,
},
Body: &dnsmessage.AResource{
A: address.As4(),
},
})
} }
} response4 := common.Filter(response, func(addr netip.Addr) bool {
c.cache.StoreWithExpire(question4, message4, expire) return addr.Is4() || addr.Is4In6()
} })
if strategy != C.DomainStrategyUseIPv4 { message4 := &dnsmessage.Message{
question6 := dnsmessage.Question{ Header: header,
Name: dnsName, Questions: []dnsmessage.Question{question4},
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
}
response6 := common.Filter(response, func(addr netip.Addr) bool {
return addr.Is6() && !addr.Is4In6()
})
message6 := dnsmessage.Message{
Header: header,
Questions: []dnsmessage.Question{question6},
}
if len(response6) > 0 {
for _, address := range response6 {
message6.Answers = append(message6.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: question6.Name,
Class: question6.Class,
TTL: DefaultTTL,
},
Body: &dnsmessage.AAAAResource{
AAAA: address.As16(),
},
})
} }
if len(response4) > 0 {
for _, address := range response4 {
message4.Answers = append(message4.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: question4.Name,
Class: question4.Class,
TTL: DefaultTTL,
},
Body: &dnsmessage.AResource{
A: address.As4(),
},
})
}
}
c.storeCache(question4, message4)
}
if strategy != C.DomainStrategyUseIPv4 {
question6 := dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
}
response6 := common.Filter(response, func(addr netip.Addr) bool {
return addr.Is6() && !addr.Is4In6()
})
message6 := &dnsmessage.Message{
Header: header,
Questions: []dnsmessage.Question{question6},
}
if len(response6) > 0 {
for _, address := range response6 {
message6.Answers = append(message6.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: question6.Name,
Class: question6.Class,
TTL: DefaultTTL,
},
Body: &dnsmessage.AAAAResource{
AAAA: address.As16(),
},
})
}
}
c.storeCache(question6, message6)
} }
c.cache.StoreWithExpire(question6, message6, expire)
} }
return response, err return response, err
} }
@ -211,14 +231,19 @@ func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.Do
} }
} }
func calculateExpire(message *dnsmessage.Message) time.Time { func (c *Client) storeCache(question dnsmessage.Question, message *dnsmessage.Message) {
if c.disableExpire {
c.cache.Store(question, message)
return
}
timeToLive := DefaultTTL timeToLive := DefaultTTL
for _, answer := range message.Answers { for _, answer := range message.Answers {
if int(answer.Header.TTL) < timeToLive { if int(answer.Header.TTL) < timeToLive {
timeToLive = int(answer.Header.TTL) timeToLive = int(answer.Header.TTL)
} }
} }
return time.Now().Add(time.Second * time.Duration(timeToLive)) expire := time.Now().Add(time.Second * time.Duration(timeToLive))
c.cache.StoreWithExpire(question, message, expire)
} }
func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) { func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) {
@ -275,9 +300,11 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran
Type: qType, Type: qType,
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
} }
cachedAddresses, err := c.questionCache(question) if !c.disableCache {
if err != ErrNotCached { cachedAddresses, err := c.questionCache(question)
return cachedAddresses, err if err != ErrNotCached {
return cachedAddresses, err
}
} }
message := dnsmessage.Message{ message := dnsmessage.Message{
Header: dnsmessage.Header{ Header: dnsmessage.Header{
@ -298,7 +325,7 @@ func (c *Client) questionCache(question dnsmessage.Question) ([]netip.Addr, erro
if !cached { if !cached {
return nil, ErrNotCached return nil, ErrNotCached
} }
return messageToAddresses(&response) return messageToAddresses(response)
} }
func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) { func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) {

View file

@ -91,7 +91,7 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) {
cancel() cancel()
conn.err = err conn.err = err
if err != nil { if err != nil {
t.logger.Warn("connection closed: ", err) t.logger.Debug("connection closed: ", err)
} }
} }

View file

@ -99,7 +99,7 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) {
cancel() cancel()
conn.err = err conn.err = err
if err != nil { if err != nil {
t.logger.Warn("connection closed: ", err) t.logger.Debug("connection closed: ", err)
} }
} }

View file

@ -87,7 +87,7 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) {
cancel() cancel()
conn.err = err conn.err = err
if err != nil { if err != nil {
t.logger.Warn("connection closed: ", err) t.logger.Debug("connection closed: ", err)
} }
} }

View file

@ -2,11 +2,17 @@ package option
type DNSOptions struct { type DNSOptions struct {
Servers []DNSServerOptions `json:"servers,omitempty"` Servers []DNSServerOptions `json:"servers,omitempty"`
DNSClientOptions
}
type DNSClientOptions struct {
DisableCache bool `json:"disable_cache,omitempty"`
DisableExpire bool `json:"disable_expire,omitempty"`
} }
type DNSServerOptions struct { type DNSServerOptions struct {
Tag string `json:"tag,omitempty"` Tag string `json:"tag,omitempty"`
Address string `json:"address"` Address string `json:"address"`
Detour string `json:"detour,omitempty"`
AddressResolver string `json:"address_resolver,omitempty"` AddressResolver string `json:"address_resolver,omitempty"`
DialerOptions
} }