Add disable_cache option to dns rule

This commit is contained in:
世界 2022-07-24 14:05:06 +08:00
parent 8666631732
commit af19ba6119
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
6 changed files with 127 additions and 82 deletions

View file

@ -50,3 +50,8 @@ type Rule interface {
Outbound() string
String() string
}
type DNSRule interface {
Rule
DisableCache() bool
}

View file

@ -55,6 +55,7 @@ func (r DNSRule) MarshalJSON() ([]byte, error) {
var v any
switch r.Type {
case C.RuleTypeDefault:
r.Type = ""
v = r.DefaultOptions
case C.RuleTypeLogical:
v = r.LogicalOptions
@ -109,6 +110,7 @@ type DefaultDNSRule struct {
Outbound Listable[string] `json:"outbound,omitempty"`
Invert bool `json:"invert,omitempty"`
Server string `json:"server,omitempty"`
DisableCache bool `json:"disable_cache,omitempty"`
}
func (r DefaultDNSRule) IsValid() bool {
@ -135,13 +137,17 @@ func (r DefaultDNSRule) Equals(other DefaultDNSRule) bool {
common.ComparableSliceEquals(r.UserID, other.UserID) &&
common.ComparableSliceEquals(r.PackageName, other.PackageName) &&
common.ComparableSliceEquals(r.Outbound, other.Outbound) &&
r.Server == other.Server
r.Invert == other.Invert &&
r.Server == other.Server &&
r.DisableCache == other.DisableCache
}
type LogicalDNSRule struct {
Mode string `json:"mode"`
Rules []DefaultDNSRule `json:"rules,omitempty"`
Server string `json:"server,omitempty"`
Mode string `json:"mode"`
Rules []DefaultDNSRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"`
Server string `json:"server,omitempty"`
DisableCache bool `json:"disable_cache,omitempty"`
}
func (r LogicalDNSRule) IsValid() bool {
@ -151,5 +157,7 @@ func (r LogicalDNSRule) IsValid() bool {
func (r LogicalDNSRule) Equals(other LogicalDNSRule) bool {
return r.Mode == other.Mode &&
common.SliceEquals(r.Rules, other.Rules) &&
r.Server == other.Server
r.Invert == other.Invert &&
r.Server == other.Server &&
r.DisableCache == other.DisableCache
}

View file

@ -145,6 +145,7 @@ func (r DefaultRule) Equals(other DefaultRule) bool {
type LogicalRule struct {
Mode string `json:"mode"`
Rules []DefaultRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"`
}
@ -155,5 +156,6 @@ func (r LogicalRule) IsValid() bool {
func (r LogicalRule) Equals(other LogicalRule) bool {
return r.Mode == other.Mode &&
common.SliceEquals(r.Rules, other.Rules) &&
r.Invert == other.Invert &&
r.Outbound == other.Outbound
}

View file

@ -59,7 +59,7 @@ type Router struct {
geositeCache map[string]adapter.Rule
dnsClient *dns.Client
defaultDomainStrategy dns.DomainStrategy
dnsRules []adapter.Rule
dnsRules []adapter.DNSRule
defaultTransport dns.Transport
transports []dns.Transport
transportMap map[string]dns.Transport
@ -80,7 +80,7 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont
dnsLogger: dnsLogger,
outboundByTag: make(map[string]adapter.Outbound),
rules: make([]adapter.Rule, 0, len(options.Rules)),
dnsRules: make([]adapter.Rule, 0, len(dnsOptions.Rules)),
dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule),
needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule),
geoIPOptions: common.PtrValueOrDefault(options.GeoIP),
@ -536,15 +536,18 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
}
func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
return r.dnsClient.Exchange(ctx, r.matchDNS(ctx), message)
ctx, transport := r.matchDNS(ctx)
return r.dnsClient.Exchange(ctx, transport, message)
}
func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, strategy)
ctx, transport := r.matchDNS(ctx)
return r.dnsClient.Lookup(ctx, transport, domain, strategy)
}
func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) {
return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, r.defaultDomainStrategy)
ctx, transport := r.matchDNS(ctx)
return r.dnsClient.Lookup(ctx, transport, domain, r.defaultDomainStrategy)
}
func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
@ -586,23 +589,26 @@ func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, de
return nil, defaultOutbound
}
func (r *Router) matchDNS(ctx context.Context) dns.Transport {
func (r *Router) matchDNS(ctx context.Context) (context.Context, dns.Transport) {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
r.dnsLogger.WarnContext(ctx, "no context: ", reflect.TypeOf(ctx))
return r.defaultTransport
return ctx, r.defaultTransport
}
for i, rule := range r.dnsRules {
if rule.Match(metadata) {
if rule.DisableCache() {
ctx = dns.ContextWithDisableCache(ctx, true)
}
detour := rule.Outbound()
r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour)
if transport, loaded := r.transportMap[detour]; loaded {
return transport
return ctx, transport
}
r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour)
}
}
return r.defaultTransport
return ctx, r.defaultTransport
}
func (r *Router) InterfaceBindManager() control.BindManager {

View file

@ -49,10 +49,6 @@ type DefaultRule struct {
outbound string
}
func (r *DefaultRule) Type() string {
return C.RuleTypeDefault
}
type RuleItem interface {
Match(metadata *adapter.InboundContext) bool
String() string
@ -180,6 +176,10 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
return rule, nil
}
func (r *DefaultRule) Type() string {
return C.RuleTypeDefault
}
func (r *DefaultRule) Start() error {
for _, item := range r.allItems {
err := common.Start(item)
@ -261,9 +261,34 @@ var _ adapter.Rule = (*LogicalRule)(nil)
type LogicalRule struct {
mode string
rules []*DefaultRule
invert bool
outbound string
}
func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
r := &LogicalRule{
rules: make([]*DefaultRule, len(options.Rules)),
invert: options.Invert,
outbound: options.Outbound,
}
switch options.Mode {
case C.LogicalTypeAnd:
r.mode = C.LogicalTypeAnd
case C.LogicalTypeOr:
r.mode = C.LogicalTypeOr
default:
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subRule := range options.Rules {
rule, err := NewDefaultRule(router, logger, subRule)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
r.rules[i] = rule
}
return r, nil
}
func (r *LogicalRule) Type() string {
return C.RuleTypeLogical
}
@ -298,38 +323,15 @@ func (r *LogicalRule) Close() error {
return nil
}
func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
r := &LogicalRule{
rules: make([]*DefaultRule, len(options.Rules)),
outbound: options.Outbound,
}
switch options.Mode {
case C.LogicalTypeAnd:
r.mode = C.LogicalTypeAnd
case C.LogicalTypeOr:
r.mode = C.LogicalTypeOr
default:
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subRule := range options.Rules {
rule, err := NewDefaultRule(router, logger, subRule)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
r.rules[i] = rule
}
return r, nil
}
func (r *LogicalRule) Match(metadata *adapter.InboundContext) bool {
if r.mode == C.LogicalTypeAnd {
return common.All(r.rules, func(it *DefaultRule) bool {
return it.Match(metadata)
})
}) != r.invert
} else {
return common.Any(r.rules, func(it *DefaultRule) bool {
return it.Match(metadata)
})
}) != r.invert
}
}
@ -345,5 +347,9 @@ func (r *LogicalRule) String() string {
case C.LogicalTypeOr:
op = "||"
}
return "logical(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
if !r.invert {
return strings.Join(F.MapToString(r.rules), " "+op+" ")
} else {
return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
}
}

View file

@ -12,7 +12,7 @@ import (
F "github.com/sagernet/sing/common/format"
)
func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.Rule, error) {
func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.DNSRule, error) {
if common.IsEmptyByEquals(options) {
return nil, E.New("empty rule config")
}
@ -38,7 +38,7 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.
}
}
var _ adapter.Rule = (*DefaultDNSRule)(nil)
var _ adapter.DNSRule = (*DefaultDNSRule)(nil)
type DefaultDNSRule struct {
items []RuleItem
@ -46,16 +46,14 @@ type DefaultDNSRule struct {
allItems []RuleItem
invert bool
outbound string
}
func (r *DefaultDNSRule) Type() string {
return C.RuleTypeDefault
disableCache bool
}
func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) {
rule := &DefaultDNSRule{
invert: true,
outbound: options.Server,
invert: options.Invert,
outbound: options.Server,
disableCache: options.DisableCache,
}
if len(options.Inbound) > 0 {
item := NewInboundRule(options.Inbound)
@ -156,6 +154,10 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
return rule, nil
}
func (r *DefaultDNSRule) Type() string {
return C.RuleTypeDefault
}
func (r *DefaultDNSRule) Start() error {
for _, item := range r.allItems {
err := common.Start(item)
@ -213,16 +215,47 @@ func (r *DefaultDNSRule) Outbound() string {
return r.outbound
}
func (r *DefaultDNSRule) DisableCache() bool {
return r.disableCache
}
func (r *DefaultDNSRule) String() string {
return strings.Join(F.MapToString(r.allItems), " ")
}
var _ adapter.Rule = (*LogicalRule)(nil)
var _ adapter.DNSRule = (*LogicalDNSRule)(nil)
type LogicalDNSRule struct {
mode string
rules []*DefaultDNSRule
outbound string
mode string
rules []*DefaultDNSRule
invert bool
outbound string
disableCache bool
}
func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) {
r := &LogicalDNSRule{
rules: make([]*DefaultDNSRule, len(options.Rules)),
invert: options.Invert,
outbound: options.Server,
disableCache: options.DisableCache,
}
switch options.Mode {
case C.LogicalTypeAnd:
r.mode = C.LogicalTypeAnd
case C.LogicalTypeOr:
r.mode = C.LogicalTypeOr
default:
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subRule := range options.Rules {
rule, err := NewDefaultDNSRule(router, logger, subRule)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
r.rules[i] = rule
}
return r, nil
}
func (r *LogicalDNSRule) Type() string {
@ -259,38 +292,15 @@ func (r *LogicalDNSRule) Close() error {
return nil
}
func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) {
r := &LogicalDNSRule{
rules: make([]*DefaultDNSRule, len(options.Rules)),
outbound: options.Server,
}
switch options.Mode {
case C.LogicalTypeAnd:
r.mode = C.LogicalTypeAnd
case C.LogicalTypeOr:
r.mode = C.LogicalTypeOr
default:
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subRule := range options.Rules {
rule, err := NewDefaultDNSRule(router, logger, subRule)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
r.rules[i] = rule
}
return r, nil
}
func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
if r.mode == C.LogicalTypeAnd {
return common.All(r.rules, func(it *DefaultDNSRule) bool {
return it.Match(metadata)
})
}) != r.invert
} else {
return common.Any(r.rules, func(it *DefaultDNSRule) bool {
return it.Match(metadata)
})
}) != r.invert
}
}
@ -298,6 +308,10 @@ func (r *LogicalDNSRule) Outbound() string {
return r.outbound
}
func (r *LogicalDNSRule) DisableCache() bool {
return r.disableCache
}
func (r *LogicalDNSRule) String() string {
var op string
switch r.mode {
@ -306,5 +320,9 @@ func (r *LogicalDNSRule) String() string {
case C.LogicalTypeOr:
op = "||"
}
return "logical(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
if !r.invert {
return strings.Join(F.MapToString(r.rules), " "+op+" ")
} else {
return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
}
}