Add disableCache/disableExpire option for dns client

This commit is contained in:
世界 2022-07-06 23:39:17 +08:00
parent 8a761d7e3b
commit ecac383477
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 138 additions and 105 deletions

View file

@ -13,6 +13,7 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
) )
@ -27,12 +28,21 @@ var (
var _ adapter.DNSClient = (*Client)(nil) var _ adapter.DNSClient = (*Client)(nil)
type Client struct { type Client struct {
cache *cache.LruCache[dnsmessage.Question, dnsmessage.Message] cache *cache.LruCache[dnsmessage.Question, *dnsmessage.Message]
disableCache bool
disableExpire bool
} }
func NewClient() *Client { func NewClient(options option.DNSClientOptions) *Client {
if options.DisableCache {
return &Client{ return &Client{
cache: cache.New[dnsmessage.Question, dnsmessage.Message](), disableCache: true,
}
} else {
return &Client{
cache: cache.New[dnsmessage.Question, *dnsmessage.Message](),
disableExpire: options.DisableExpire,
}
} }
} }
@ -41,10 +51,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
return nil, E.New("empty query") return nil, E.New("empty query")
} }
question := message.Questions[0] question := message.Questions[0]
if !c.disableCache {
cachedAnswer, cached := c.cache.Load(question) cachedAnswer, cached := c.cache.Load(question)
if cached { if cached {
cachedAnswer.ID = message.ID cachedAnswer.ID = message.ID
return &cachedAnswer, nil return cachedAnswer, nil
}
} }
if !transport.Raw() { if !transport.Raw() {
if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA { if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA {
@ -56,7 +68,9 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.cache.StoreWithExpire(question, *response, calculateExpire(message)) if !c.disableCache {
c.storeCache(question, response)
}
return message, err return message, err
} }
@ -93,6 +107,7 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
} }
return sortAddresses(response4, response6, strategy), nil return sortAddresses(response4, response6, strategy), nil
} }
if !c.disableCache {
if strategy == C.DomainStrategyUseIPv4 { if strategy == C.DomainStrategyUseIPv4 {
response, err := c.questionCache(dnsmessage.Question{ response, err := c.questionCache(dnsmessage.Question{
Name: dnsName, Name: dnsName,
@ -126,6 +141,7 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
return sortAddresses(response4, response6, strategy), nil return sortAddresses(response4, response6, strategy), nil
} }
} }
}
var rCode dnsmessage.RCode var rCode dnsmessage.RCode
response, err := transport.Lookup(ctx, domain, strategy) response, err := transport.Lookup(ctx, domain, strategy)
if err != nil { if err != nil {
@ -135,13 +151,16 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
} else { } else {
rCode = dnsmessage.RCode(rCodeError) rCode = dnsmessage.RCode(rCodeError)
} }
if c.disableCache {
return nil, err
}
} }
header := dnsmessage.Header{ header := dnsmessage.Header{
Response: true, Response: true,
Authoritative: true, Authoritative: true,
RCode: rCode, RCode: rCode,
} }
expire := time.Now().Add(time.Second * time.Duration(DefaultTTL)) if !c.disableCache {
if strategy != C.DomainStrategyUseIPv6 { if strategy != C.DomainStrategyUseIPv6 {
question4 := dnsmessage.Question{ question4 := dnsmessage.Question{
Name: dnsName, Name: dnsName,
@ -151,7 +170,7 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
response4 := common.Filter(response, func(addr netip.Addr) bool { response4 := common.Filter(response, func(addr netip.Addr) bool {
return addr.Is4() || addr.Is4In6() return addr.Is4() || addr.Is4In6()
}) })
message4 := dnsmessage.Message{ message4 := &dnsmessage.Message{
Header: header, Header: header,
Questions: []dnsmessage.Question{question4}, Questions: []dnsmessage.Question{question4},
} }
@ -169,7 +188,7 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
}) })
} }
} }
c.cache.StoreWithExpire(question4, message4, expire) c.storeCache(question4, message4)
} }
if strategy != C.DomainStrategyUseIPv4 { if strategy != C.DomainStrategyUseIPv4 {
question6 := dnsmessage.Question{ question6 := dnsmessage.Question{
@ -180,7 +199,7 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
response6 := common.Filter(response, func(addr netip.Addr) bool { response6 := common.Filter(response, func(addr netip.Addr) bool {
return addr.Is6() && !addr.Is4In6() return addr.Is6() && !addr.Is4In6()
}) })
message6 := dnsmessage.Message{ message6 := &dnsmessage.Message{
Header: header, Header: header,
Questions: []dnsmessage.Question{question6}, Questions: []dnsmessage.Question{question6},
} }
@ -198,7 +217,8 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
}) })
} }
} }
c.cache.StoreWithExpire(question6, message6, expire) c.storeCache(question6, message6)
}
} }
return response, err return response, err
} }
@ -211,14 +231,19 @@ func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.Do
} }
} }
func calculateExpire(message *dnsmessage.Message) time.Time { func (c *Client) storeCache(question dnsmessage.Question, message *dnsmessage.Message) {
if c.disableExpire {
c.cache.Store(question, message)
return
}
timeToLive := DefaultTTL timeToLive := DefaultTTL
for _, answer := range message.Answers { for _, answer := range message.Answers {
if int(answer.Header.TTL) < timeToLive { if int(answer.Header.TTL) < timeToLive {
timeToLive = int(answer.Header.TTL) timeToLive = int(answer.Header.TTL)
} }
} }
return time.Now().Add(time.Second * time.Duration(timeToLive)) expire := time.Now().Add(time.Second * time.Duration(timeToLive))
c.cache.StoreWithExpire(question, message, expire)
} }
func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) { func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) {
@ -275,10 +300,12 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran
Type: qType, Type: qType,
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
} }
if !c.disableCache {
cachedAddresses, err := c.questionCache(question) cachedAddresses, err := c.questionCache(question)
if err != ErrNotCached { if err != ErrNotCached {
return cachedAddresses, err return cachedAddresses, err
} }
}
message := dnsmessage.Message{ message := dnsmessage.Message{
Header: dnsmessage.Header{ Header: dnsmessage.Header{
ID: 0, ID: 0,
@ -298,7 +325,7 @@ func (c *Client) questionCache(question dnsmessage.Question) ([]netip.Addr, erro
if !cached { if !cached {
return nil, ErrNotCached return nil, ErrNotCached
} }
return messageToAddresses(&response) return messageToAddresses(response)
} }
func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) { func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) {

View file

@ -91,7 +91,7 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) {
cancel() cancel()
conn.err = err conn.err = err
if err != nil { if err != nil {
t.logger.Warn("connection closed: ", err) t.logger.Debug("connection closed: ", err)
} }
} }

View file

@ -99,7 +99,7 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) {
cancel() cancel()
conn.err = err conn.err = err
if err != nil { if err != nil {
t.logger.Warn("connection closed: ", err) t.logger.Debug("connection closed: ", err)
} }
} }

View file

@ -87,7 +87,7 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) {
cancel() cancel()
conn.err = err conn.err = err
if err != nil { if err != nil {
t.logger.Warn("connection closed: ", err) t.logger.Debug("connection closed: ", err)
} }
} }

View file

@ -2,11 +2,17 @@ package option
type DNSOptions struct { type DNSOptions struct {
Servers []DNSServerOptions `json:"servers,omitempty"` Servers []DNSServerOptions `json:"servers,omitempty"`
DNSClientOptions
}
type DNSClientOptions struct {
DisableCache bool `json:"disable_cache,omitempty"`
DisableExpire bool `json:"disable_expire,omitempty"`
} }
type DNSServerOptions struct { type DNSServerOptions struct {
Tag string `json:"tag,omitempty"` Tag string `json:"tag,omitempty"`
Address string `json:"address"` Address string `json:"address"`
Detour string `json:"detour,omitempty"`
AddressResolver string `json:"address_resolver,omitempty"` AddressResolver string `json:"address_resolver,omitempty"`
DialerOptions
} }