mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-25 18:11:28 +00:00
Add disableCache/disableExpire option for dns client
This commit is contained in:
parent
8a761d7e3b
commit
ecac383477
229
dns/client.go
229
dns/client.go
|
@ -13,6 +13,7 @@ import (
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
|
||||||
"golang.org/x/net/dns/dnsmessage"
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
)
|
)
|
||||||
|
@ -27,12 +28,21 @@ var (
|
||||||
var _ adapter.DNSClient = (*Client)(nil)
|
var _ adapter.DNSClient = (*Client)(nil)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
cache *cache.LruCache[dnsmessage.Question, dnsmessage.Message]
|
cache *cache.LruCache[dnsmessage.Question, *dnsmessage.Message]
|
||||||
|
disableCache bool
|
||||||
|
disableExpire bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient() *Client {
|
func NewClient(options option.DNSClientOptions) *Client {
|
||||||
return &Client{
|
if options.DisableCache {
|
||||||
cache: cache.New[dnsmessage.Question, dnsmessage.Message](),
|
return &Client{
|
||||||
|
disableCache: true,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return &Client{
|
||||||
|
cache: cache.New[dnsmessage.Question, *dnsmessage.Message](),
|
||||||
|
disableExpire: options.DisableExpire,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,10 +51,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
||||||
return nil, E.New("empty query")
|
return nil, E.New("empty query")
|
||||||
}
|
}
|
||||||
question := message.Questions[0]
|
question := message.Questions[0]
|
||||||
cachedAnswer, cached := c.cache.Load(question)
|
if !c.disableCache {
|
||||||
if cached {
|
cachedAnswer, cached := c.cache.Load(question)
|
||||||
cachedAnswer.ID = message.ID
|
if cached {
|
||||||
return &cachedAnswer, nil
|
cachedAnswer.ID = message.ID
|
||||||
|
return cachedAnswer, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if !transport.Raw() {
|
if !transport.Raw() {
|
||||||
if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA {
|
if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA {
|
||||||
|
@ -56,7 +68,9 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.cache.StoreWithExpire(question, *response, calculateExpire(message))
|
if !c.disableCache {
|
||||||
|
c.storeCache(question, response)
|
||||||
|
}
|
||||||
return message, err
|
return message, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,37 +107,39 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
|
||||||
}
|
}
|
||||||
return sortAddresses(response4, response6, strategy), nil
|
return sortAddresses(response4, response6, strategy), nil
|
||||||
}
|
}
|
||||||
if strategy == C.DomainStrategyUseIPv4 {
|
if !c.disableCache {
|
||||||
response, err := c.questionCache(dnsmessage.Question{
|
if strategy == C.DomainStrategyUseIPv4 {
|
||||||
Name: dnsName,
|
response, err := c.questionCache(dnsmessage.Question{
|
||||||
Type: dnsmessage.TypeA,
|
Name: dnsName,
|
||||||
Class: dnsmessage.ClassINET,
|
Type: dnsmessage.TypeA,
|
||||||
})
|
Class: dnsmessage.ClassINET,
|
||||||
if err != ErrNotCached {
|
})
|
||||||
return response, err
|
if err != ErrNotCached {
|
||||||
}
|
return response, err
|
||||||
} else if strategy == C.DomainStrategyUseIPv6 {
|
}
|
||||||
response, err := c.questionCache(dnsmessage.Question{
|
} else if strategy == C.DomainStrategyUseIPv6 {
|
||||||
Name: dnsName,
|
response, err := c.questionCache(dnsmessage.Question{
|
||||||
Type: dnsmessage.TypeAAAA,
|
Name: dnsName,
|
||||||
Class: dnsmessage.ClassINET,
|
Type: dnsmessage.TypeAAAA,
|
||||||
})
|
Class: dnsmessage.ClassINET,
|
||||||
if err != ErrNotCached {
|
})
|
||||||
return response, err
|
if err != ErrNotCached {
|
||||||
}
|
return response, err
|
||||||
} else {
|
}
|
||||||
response4, _ := c.questionCache(dnsmessage.Question{
|
} else {
|
||||||
Name: dnsName,
|
response4, _ := c.questionCache(dnsmessage.Question{
|
||||||
Type: dnsmessage.TypeA,
|
Name: dnsName,
|
||||||
Class: dnsmessage.ClassINET,
|
Type: dnsmessage.TypeA,
|
||||||
})
|
Class: dnsmessage.ClassINET,
|
||||||
response6, _ := c.questionCache(dnsmessage.Question{
|
})
|
||||||
Name: dnsName,
|
response6, _ := c.questionCache(dnsmessage.Question{
|
||||||
Type: dnsmessage.TypeAAAA,
|
Name: dnsName,
|
||||||
Class: dnsmessage.ClassINET,
|
Type: dnsmessage.TypeAAAA,
|
||||||
})
|
Class: dnsmessage.ClassINET,
|
||||||
if len(response4) > 0 || len(response6) > 0 {
|
})
|
||||||
return sortAddresses(response4, response6, strategy), nil
|
if len(response4) > 0 || len(response6) > 0 {
|
||||||
|
return sortAddresses(response4, response6, strategy), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var rCode dnsmessage.RCode
|
var rCode dnsmessage.RCode
|
||||||
|
@ -135,70 +151,74 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
|
||||||
} else {
|
} else {
|
||||||
rCode = dnsmessage.RCode(rCodeError)
|
rCode = dnsmessage.RCode(rCodeError)
|
||||||
}
|
}
|
||||||
|
if c.disableCache {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
header := dnsmessage.Header{
|
header := dnsmessage.Header{
|
||||||
Response: true,
|
Response: true,
|
||||||
Authoritative: true,
|
Authoritative: true,
|
||||||
RCode: rCode,
|
RCode: rCode,
|
||||||
}
|
}
|
||||||
expire := time.Now().Add(time.Second * time.Duration(DefaultTTL))
|
if !c.disableCache {
|
||||||
if strategy != C.DomainStrategyUseIPv6 {
|
if strategy != C.DomainStrategyUseIPv6 {
|
||||||
question4 := dnsmessage.Question{
|
question4 := dnsmessage.Question{
|
||||||
Name: dnsName,
|
Name: dnsName,
|
||||||
Type: dnsmessage.TypeA,
|
Type: dnsmessage.TypeA,
|
||||||
Class: dnsmessage.ClassINET,
|
Class: dnsmessage.ClassINET,
|
||||||
}
|
|
||||||
response4 := common.Filter(response, func(addr netip.Addr) bool {
|
|
||||||
return addr.Is4() || addr.Is4In6()
|
|
||||||
})
|
|
||||||
message4 := dnsmessage.Message{
|
|
||||||
Header: header,
|
|
||||||
Questions: []dnsmessage.Question{question4},
|
|
||||||
}
|
|
||||||
if len(response4) > 0 {
|
|
||||||
for _, address := range response4 {
|
|
||||||
message4.Answers = append(message4.Answers, dnsmessage.Resource{
|
|
||||||
Header: dnsmessage.ResourceHeader{
|
|
||||||
Name: question4.Name,
|
|
||||||
Class: question4.Class,
|
|
||||||
TTL: DefaultTTL,
|
|
||||||
},
|
|
||||||
Body: &dnsmessage.AResource{
|
|
||||||
A: address.As4(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
response4 := common.Filter(response, func(addr netip.Addr) bool {
|
||||||
c.cache.StoreWithExpire(question4, message4, expire)
|
return addr.Is4() || addr.Is4In6()
|
||||||
}
|
})
|
||||||
if strategy != C.DomainStrategyUseIPv4 {
|
message4 := &dnsmessage.Message{
|
||||||
question6 := dnsmessage.Question{
|
Header: header,
|
||||||
Name: dnsName,
|
Questions: []dnsmessage.Question{question4},
|
||||||
Type: dnsmessage.TypeAAAA,
|
|
||||||
Class: dnsmessage.ClassINET,
|
|
||||||
}
|
|
||||||
response6 := common.Filter(response, func(addr netip.Addr) bool {
|
|
||||||
return addr.Is6() && !addr.Is4In6()
|
|
||||||
})
|
|
||||||
message6 := dnsmessage.Message{
|
|
||||||
Header: header,
|
|
||||||
Questions: []dnsmessage.Question{question6},
|
|
||||||
}
|
|
||||||
if len(response6) > 0 {
|
|
||||||
for _, address := range response6 {
|
|
||||||
message6.Answers = append(message6.Answers, dnsmessage.Resource{
|
|
||||||
Header: dnsmessage.ResourceHeader{
|
|
||||||
Name: question6.Name,
|
|
||||||
Class: question6.Class,
|
|
||||||
TTL: DefaultTTL,
|
|
||||||
},
|
|
||||||
Body: &dnsmessage.AAAAResource{
|
|
||||||
AAAA: address.As16(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
if len(response4) > 0 {
|
||||||
|
for _, address := range response4 {
|
||||||
|
message4.Answers = append(message4.Answers, dnsmessage.Resource{
|
||||||
|
Header: dnsmessage.ResourceHeader{
|
||||||
|
Name: question4.Name,
|
||||||
|
Class: question4.Class,
|
||||||
|
TTL: DefaultTTL,
|
||||||
|
},
|
||||||
|
Body: &dnsmessage.AResource{
|
||||||
|
A: address.As4(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.storeCache(question4, message4)
|
||||||
|
}
|
||||||
|
if strategy != C.DomainStrategyUseIPv4 {
|
||||||
|
question6 := dnsmessage.Question{
|
||||||
|
Name: dnsName,
|
||||||
|
Type: dnsmessage.TypeAAAA,
|
||||||
|
Class: dnsmessage.ClassINET,
|
||||||
|
}
|
||||||
|
response6 := common.Filter(response, func(addr netip.Addr) bool {
|
||||||
|
return addr.Is6() && !addr.Is4In6()
|
||||||
|
})
|
||||||
|
message6 := &dnsmessage.Message{
|
||||||
|
Header: header,
|
||||||
|
Questions: []dnsmessage.Question{question6},
|
||||||
|
}
|
||||||
|
if len(response6) > 0 {
|
||||||
|
for _, address := range response6 {
|
||||||
|
message6.Answers = append(message6.Answers, dnsmessage.Resource{
|
||||||
|
Header: dnsmessage.ResourceHeader{
|
||||||
|
Name: question6.Name,
|
||||||
|
Class: question6.Class,
|
||||||
|
TTL: DefaultTTL,
|
||||||
|
},
|
||||||
|
Body: &dnsmessage.AAAAResource{
|
||||||
|
AAAA: address.As16(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.storeCache(question6, message6)
|
||||||
}
|
}
|
||||||
c.cache.StoreWithExpire(question6, message6, expire)
|
|
||||||
}
|
}
|
||||||
return response, err
|
return response, err
|
||||||
}
|
}
|
||||||
|
@ -211,14 +231,19 @@ func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.Do
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateExpire(message *dnsmessage.Message) time.Time {
|
func (c *Client) storeCache(question dnsmessage.Question, message *dnsmessage.Message) {
|
||||||
|
if c.disableExpire {
|
||||||
|
c.cache.Store(question, message)
|
||||||
|
return
|
||||||
|
}
|
||||||
timeToLive := DefaultTTL
|
timeToLive := DefaultTTL
|
||||||
for _, answer := range message.Answers {
|
for _, answer := range message.Answers {
|
||||||
if int(answer.Header.TTL) < timeToLive {
|
if int(answer.Header.TTL) < timeToLive {
|
||||||
timeToLive = int(answer.Header.TTL)
|
timeToLive = int(answer.Header.TTL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return time.Now().Add(time.Second * time.Duration(timeToLive))
|
expire := time.Now().Add(time.Second * time.Duration(timeToLive))
|
||||||
|
c.cache.StoreWithExpire(question, message, expire)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) {
|
func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) {
|
||||||
|
@ -275,9 +300,11 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran
|
||||||
Type: qType,
|
Type: qType,
|
||||||
Class: dnsmessage.ClassINET,
|
Class: dnsmessage.ClassINET,
|
||||||
}
|
}
|
||||||
cachedAddresses, err := c.questionCache(question)
|
if !c.disableCache {
|
||||||
if err != ErrNotCached {
|
cachedAddresses, err := c.questionCache(question)
|
||||||
return cachedAddresses, err
|
if err != ErrNotCached {
|
||||||
|
return cachedAddresses, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
message := dnsmessage.Message{
|
message := dnsmessage.Message{
|
||||||
Header: dnsmessage.Header{
|
Header: dnsmessage.Header{
|
||||||
|
@ -298,7 +325,7 @@ func (c *Client) questionCache(question dnsmessage.Question) ([]netip.Addr, erro
|
||||||
if !cached {
|
if !cached {
|
||||||
return nil, ErrNotCached
|
return nil, ErrNotCached
|
||||||
}
|
}
|
||||||
return messageToAddresses(&response)
|
return messageToAddresses(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) {
|
func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) {
|
||||||
|
|
|
@ -91,7 +91,7 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) {
|
||||||
cancel()
|
cancel()
|
||||||
conn.err = err
|
conn.err = err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.logger.Warn("connection closed: ", err)
|
t.logger.Debug("connection closed: ", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,7 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) {
|
||||||
cancel()
|
cancel()
|
||||||
conn.err = err
|
conn.err = err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.logger.Warn("connection closed: ", err)
|
t.logger.Debug("connection closed: ", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -87,7 +87,7 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) {
|
||||||
cancel()
|
cancel()
|
||||||
conn.err = err
|
conn.err = err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.logger.Warn("connection closed: ", err)
|
t.logger.Debug("connection closed: ", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,11 +2,17 @@ package option
|
||||||
|
|
||||||
type DNSOptions struct {
|
type DNSOptions struct {
|
||||||
Servers []DNSServerOptions `json:"servers,omitempty"`
|
Servers []DNSServerOptions `json:"servers,omitempty"`
|
||||||
|
DNSClientOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
type DNSClientOptions struct {
|
||||||
|
DisableCache bool `json:"disable_cache,omitempty"`
|
||||||
|
DisableExpire bool `json:"disable_expire,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DNSServerOptions struct {
|
type DNSServerOptions struct {
|
||||||
Tag string `json:"tag,omitempty"`
|
Tag string `json:"tag,omitempty"`
|
||||||
Address string `json:"address"`
|
Address string `json:"address"`
|
||||||
Detour string `json:"detour,omitempty"`
|
|
||||||
AddressResolver string `json:"address_resolver,omitempty"`
|
AddressResolver string `json:"address_resolver,omitempty"`
|
||||||
|
DialerOptions
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue