mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-25 10:01:30 +00:00
Add disableCache/disableExpire option for dns client
This commit is contained in:
parent
8a761d7e3b
commit
ecac383477
229
dns/client.go
229
dns/client.go
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
@ -27,12 +28,21 @@ var (
|
|||
var _ adapter.DNSClient = (*Client)(nil)
|
||||
|
||||
type Client struct {
|
||||
cache *cache.LruCache[dnsmessage.Question, dnsmessage.Message]
|
||||
cache *cache.LruCache[dnsmessage.Question, *dnsmessage.Message]
|
||||
disableCache bool
|
||||
disableExpire bool
|
||||
}
|
||||
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
cache: cache.New[dnsmessage.Question, dnsmessage.Message](),
|
||||
func NewClient(options option.DNSClientOptions) *Client {
|
||||
if options.DisableCache {
|
||||
return &Client{
|
||||
disableCache: true,
|
||||
}
|
||||
} else {
|
||||
return &Client{
|
||||
cache: cache.New[dnsmessage.Question, *dnsmessage.Message](),
|
||||
disableExpire: options.DisableExpire,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -41,10 +51,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||
return nil, E.New("empty query")
|
||||
}
|
||||
question := message.Questions[0]
|
||||
cachedAnswer, cached := c.cache.Load(question)
|
||||
if cached {
|
||||
cachedAnswer.ID = message.ID
|
||||
return &cachedAnswer, nil
|
||||
if !c.disableCache {
|
||||
cachedAnswer, cached := c.cache.Load(question)
|
||||
if cached {
|
||||
cachedAnswer.ID = message.ID
|
||||
return cachedAnswer, nil
|
||||
}
|
||||
}
|
||||
if !transport.Raw() {
|
||||
if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA {
|
||||
|
@ -56,7 +68,9 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.cache.StoreWithExpire(question, *response, calculateExpire(message))
|
||||
if !c.disableCache {
|
||||
c.storeCache(question, response)
|
||||
}
|
||||
return message, err
|
||||
}
|
||||
|
||||
|
@ -93,37 +107,39 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
|
|||
}
|
||||
return sortAddresses(response4, response6, strategy), nil
|
||||
}
|
||||
if strategy == C.DomainStrategyUseIPv4 {
|
||||
response, err := c.questionCache(dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
if err != ErrNotCached {
|
||||
return response, err
|
||||
}
|
||||
} else if strategy == C.DomainStrategyUseIPv6 {
|
||||
response, err := c.questionCache(dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
if err != ErrNotCached {
|
||||
return response, err
|
||||
}
|
||||
} else {
|
||||
response4, _ := c.questionCache(dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
response6, _ := c.questionCache(dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
if len(response4) > 0 || len(response6) > 0 {
|
||||
return sortAddresses(response4, response6, strategy), nil
|
||||
if !c.disableCache {
|
||||
if strategy == C.DomainStrategyUseIPv4 {
|
||||
response, err := c.questionCache(dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
if err != ErrNotCached {
|
||||
return response, err
|
||||
}
|
||||
} else if strategy == C.DomainStrategyUseIPv6 {
|
||||
response, err := c.questionCache(dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
if err != ErrNotCached {
|
||||
return response, err
|
||||
}
|
||||
} else {
|
||||
response4, _ := c.questionCache(dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
response6, _ := c.questionCache(dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
if len(response4) > 0 || len(response6) > 0 {
|
||||
return sortAddresses(response4, response6, strategy), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
var rCode dnsmessage.RCode
|
||||
|
@ -135,70 +151,74 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
|
|||
} else {
|
||||
rCode = dnsmessage.RCode(rCodeError)
|
||||
}
|
||||
if c.disableCache {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
header := dnsmessage.Header{
|
||||
Response: true,
|
||||
Authoritative: true,
|
||||
RCode: rCode,
|
||||
}
|
||||
expire := time.Now().Add(time.Second * time.Duration(DefaultTTL))
|
||||
if strategy != C.DomainStrategyUseIPv6 {
|
||||
question4 := dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
}
|
||||
response4 := common.Filter(response, func(addr netip.Addr) bool {
|
||||
return addr.Is4() || addr.Is4In6()
|
||||
})
|
||||
message4 := dnsmessage.Message{
|
||||
Header: header,
|
||||
Questions: []dnsmessage.Question{question4},
|
||||
}
|
||||
if len(response4) > 0 {
|
||||
for _, address := range response4 {
|
||||
message4.Answers = append(message4.Answers, dnsmessage.Resource{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: question4.Name,
|
||||
Class: question4.Class,
|
||||
TTL: DefaultTTL,
|
||||
},
|
||||
Body: &dnsmessage.AResource{
|
||||
A: address.As4(),
|
||||
},
|
||||
})
|
||||
if !c.disableCache {
|
||||
if strategy != C.DomainStrategyUseIPv6 {
|
||||
question4 := dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
}
|
||||
}
|
||||
c.cache.StoreWithExpire(question4, message4, expire)
|
||||
}
|
||||
if strategy != C.DomainStrategyUseIPv4 {
|
||||
question6 := dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
}
|
||||
response6 := common.Filter(response, func(addr netip.Addr) bool {
|
||||
return addr.Is6() && !addr.Is4In6()
|
||||
})
|
||||
message6 := dnsmessage.Message{
|
||||
Header: header,
|
||||
Questions: []dnsmessage.Question{question6},
|
||||
}
|
||||
if len(response6) > 0 {
|
||||
for _, address := range response6 {
|
||||
message6.Answers = append(message6.Answers, dnsmessage.Resource{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: question6.Name,
|
||||
Class: question6.Class,
|
||||
TTL: DefaultTTL,
|
||||
},
|
||||
Body: &dnsmessage.AAAAResource{
|
||||
AAAA: address.As16(),
|
||||
},
|
||||
})
|
||||
response4 := common.Filter(response, func(addr netip.Addr) bool {
|
||||
return addr.Is4() || addr.Is4In6()
|
||||
})
|
||||
message4 := &dnsmessage.Message{
|
||||
Header: header,
|
||||
Questions: []dnsmessage.Question{question4},
|
||||
}
|
||||
if len(response4) > 0 {
|
||||
for _, address := range response4 {
|
||||
message4.Answers = append(message4.Answers, dnsmessage.Resource{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: question4.Name,
|
||||
Class: question4.Class,
|
||||
TTL: DefaultTTL,
|
||||
},
|
||||
Body: &dnsmessage.AResource{
|
||||
A: address.As4(),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
c.storeCache(question4, message4)
|
||||
}
|
||||
if strategy != C.DomainStrategyUseIPv4 {
|
||||
question6 := dnsmessage.Question{
|
||||
Name: dnsName,
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
}
|
||||
response6 := common.Filter(response, func(addr netip.Addr) bool {
|
||||
return addr.Is6() && !addr.Is4In6()
|
||||
})
|
||||
message6 := &dnsmessage.Message{
|
||||
Header: header,
|
||||
Questions: []dnsmessage.Question{question6},
|
||||
}
|
||||
if len(response6) > 0 {
|
||||
for _, address := range response6 {
|
||||
message6.Answers = append(message6.Answers, dnsmessage.Resource{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: question6.Name,
|
||||
Class: question6.Class,
|
||||
TTL: DefaultTTL,
|
||||
},
|
||||
Body: &dnsmessage.AAAAResource{
|
||||
AAAA: address.As16(),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
c.storeCache(question6, message6)
|
||||
}
|
||||
c.cache.StoreWithExpire(question6, message6, expire)
|
||||
}
|
||||
return response, err
|
||||
}
|
||||
|
@ -211,14 +231,19 @@ func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.Do
|
|||
}
|
||||
}
|
||||
|
||||
func calculateExpire(message *dnsmessage.Message) time.Time {
|
||||
func (c *Client) storeCache(question dnsmessage.Question, message *dnsmessage.Message) {
|
||||
if c.disableExpire {
|
||||
c.cache.Store(question, message)
|
||||
return
|
||||
}
|
||||
timeToLive := DefaultTTL
|
||||
for _, answer := range message.Answers {
|
||||
if int(answer.Header.TTL) < timeToLive {
|
||||
timeToLive = int(answer.Header.TTL)
|
||||
}
|
||||
}
|
||||
return time.Now().Add(time.Second * time.Duration(timeToLive))
|
||||
expire := time.Now().Add(time.Second * time.Duration(timeToLive))
|
||||
c.cache.StoreWithExpire(question, message, expire)
|
||||
}
|
||||
|
||||
func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) {
|
||||
|
@ -275,9 +300,11 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran
|
|||
Type: qType,
|
||||
Class: dnsmessage.ClassINET,
|
||||
}
|
||||
cachedAddresses, err := c.questionCache(question)
|
||||
if err != ErrNotCached {
|
||||
return cachedAddresses, err
|
||||
if !c.disableCache {
|
||||
cachedAddresses, err := c.questionCache(question)
|
||||
if err != ErrNotCached {
|
||||
return cachedAddresses, err
|
||||
}
|
||||
}
|
||||
message := dnsmessage.Message{
|
||||
Header: dnsmessage.Header{
|
||||
|
@ -298,7 +325,7 @@ func (c *Client) questionCache(question dnsmessage.Question) ([]netip.Addr, erro
|
|||
if !cached {
|
||||
return nil, ErrNotCached
|
||||
}
|
||||
return messageToAddresses(&response)
|
||||
return messageToAddresses(response)
|
||||
}
|
||||
|
||||
func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) {
|
||||
|
|
|
@ -91,7 +91,7 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) {
|
|||
cancel()
|
||||
conn.err = err
|
||||
if err != nil {
|
||||
t.logger.Warn("connection closed: ", err)
|
||||
t.logger.Debug("connection closed: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -99,7 +99,7 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) {
|
|||
cancel()
|
||||
conn.err = err
|
||||
if err != nil {
|
||||
t.logger.Warn("connection closed: ", err)
|
||||
t.logger.Debug("connection closed: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) {
|
|||
cancel()
|
||||
conn.err = err
|
||||
if err != nil {
|
||||
t.logger.Warn("connection closed: ", err)
|
||||
t.logger.Debug("connection closed: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,11 +2,17 @@ package option
|
|||
|
||||
type DNSOptions struct {
|
||||
Servers []DNSServerOptions `json:"servers,omitempty"`
|
||||
DNSClientOptions
|
||||
}
|
||||
|
||||
type DNSClientOptions struct {
|
||||
DisableCache bool `json:"disable_cache,omitempty"`
|
||||
DisableExpire bool `json:"disable_expire,omitempty"`
|
||||
}
|
||||
|
||||
type DNSServerOptions struct {
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Address string `json:"address"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
AddressResolver string `json:"address_resolver,omitempty"`
|
||||
DialerOptions
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue