From 340e74eed4644b22349319c3449327742a170585 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 28 Nov 2023 23:47:32 +0800 Subject: [PATCH] Allow nested logical rules --- experimental/clashapi.go | 48 ++++++++++++++++++++++------------- option/rule.go | 21 +++++++++++---- option/rule_dns.go | 25 +++++++++++++----- route/router.go | 4 +-- route/router_geo_resources.go | 12 +++------ route/rule_default.go | 8 +++--- route/rule_dns.go | 8 +++--- 7 files changed, 79 insertions(+), 47 deletions(-) diff --git a/experimental/clashapi.go b/experimental/clashapi.go index 894d40a7..805fbd5b 100644 --- a/experimental/clashapi.go +++ b/experimental/clashapi.go @@ -5,6 +5,7 @@ import ( "os" "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" @@ -27,24 +28,37 @@ func NewClashServer(ctx context.Context, router adapter.Router, logFactory log.O func CalculateClashModeList(options option.Options) []string { var clashMode []string - for _, dnsRule := range common.PtrValueOrDefault(options.DNS).Rules { - if dnsRule.DefaultOptions.ClashMode != "" && !common.Contains(clashMode, dnsRule.DefaultOptions.ClashMode) { - clashMode = append(clashMode, dnsRule.DefaultOptions.ClashMode) - } - for _, defaultRule := range dnsRule.LogicalOptions.Rules { - if defaultRule.ClashMode != "" && !common.Contains(clashMode, defaultRule.ClashMode) { - clashMode = append(clashMode, defaultRule.ClashMode) - } - } - } - for _, rule := range common.PtrValueOrDefault(options.Route).Rules { - if rule.DefaultOptions.ClashMode != "" && !common.Contains(clashMode, rule.DefaultOptions.ClashMode) { - clashMode = append(clashMode, rule.DefaultOptions.ClashMode) - } - for _, defaultRule := range rule.LogicalOptions.Rules { - if defaultRule.ClashMode != "" && !common.Contains(clashMode, defaultRule.ClashMode) { - clashMode = append(clashMode, defaultRule.ClashMode) + clashMode = append(clashMode, extraClashModeFromRule(common.PtrValueOrDefault(options.Route).Rules)...) + clashMode = append(clashMode, extraClashModeFromDNSRule(common.PtrValueOrDefault(options.DNS).Rules)...) + clashMode = common.FilterNotDefault(common.Uniq(clashMode)) + return clashMode +} + +func extraClashModeFromRule(rules []option.Rule) []string { + var clashMode []string + for _, rule := range rules { + switch rule.Type { + case C.RuleTypeDefault: + if rule.DefaultOptions.ClashMode != "" { + clashMode = append(clashMode, rule.DefaultOptions.ClashMode) } + case C.RuleTypeLogical: + clashMode = append(clashMode, extraClashModeFromRule(rule.LogicalOptions.Rules)...) + } + } + return clashMode +} + +func extraClashModeFromDNSRule(rules []option.DNSRule) []string { + var clashMode []string + for _, rule := range rules { + switch rule.Type { + case C.RuleTypeDefault: + if rule.DefaultOptions.ClashMode != "" { + clashMode = append(clashMode, rule.DefaultOptions.ClashMode) + } + case C.RuleTypeLogical: + clashMode = append(clashMode, extraClashModeFromDNSRule(rule.LogicalOptions.Rules)...) } } return clashMode diff --git a/option/rule.go b/option/rule.go index 8caba96e..4f404202 100644 --- a/option/rule.go +++ b/option/rule.go @@ -53,6 +53,17 @@ func (r *Rule) UnmarshalJSON(bytes []byte) error { return nil } +func (r Rule) IsValid() bool { + switch r.Type { + case C.RuleTypeDefault: + return r.DefaultOptions.IsValid() + case C.RuleTypeLogical: + return r.LogicalOptions.IsValid() + default: + panic("unknown rule type: " + r.Type) + } +} + type DefaultRule struct { Inbound Listable[string] `json:"inbound,omitempty"` IPVersion int `json:"ip_version,omitempty"` @@ -92,12 +103,12 @@ func (r DefaultRule) IsValid() bool { } type LogicalRule struct { - Mode string `json:"mode"` - Rules []DefaultRule `json:"rules,omitempty"` - Invert bool `json:"invert,omitempty"` - Outbound string `json:"outbound,omitempty"` + Mode string `json:"mode"` + Rules []Rule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Outbound string `json:"outbound,omitempty"` } func (r LogicalRule) IsValid() bool { - return len(r.Rules) > 0 && common.All(r.Rules, DefaultRule.IsValid) + return len(r.Rules) > 0 && common.All(r.Rules, Rule.IsValid) } diff --git a/option/rule_dns.go b/option/rule_dns.go index ba572b9a..fca34322 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -53,6 +53,17 @@ func (r *DNSRule) UnmarshalJSON(bytes []byte) error { return nil } +func (r DNSRule) IsValid() bool { + switch r.Type { + case C.RuleTypeDefault: + return r.DefaultOptions.IsValid() + case C.RuleTypeLogical: + return r.LogicalOptions.IsValid() + default: + panic("unknown DNS rule type: " + r.Type) + } +} + type DefaultDNSRule struct { Inbound Listable[string] `json:"inbound,omitempty"` IPVersion int `json:"ip_version,omitempty"` @@ -96,14 +107,14 @@ func (r DefaultDNSRule) IsValid() bool { } type LogicalDNSRule struct { - Mode string `json:"mode"` - Rules []DefaultDNSRule `json:"rules,omitempty"` - Invert bool `json:"invert,omitempty"` - Server string `json:"server,omitempty"` - DisableCache bool `json:"disable_cache,omitempty"` - RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` + Mode string `json:"mode"` + Rules []DNSRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Server string `json:"server,omitempty"` + DisableCache bool `json:"disable_cache,omitempty"` + RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` } func (r LogicalDNSRule) IsValid() bool { - return len(r.Rules) > 0 && common.All(r.Rules, DefaultDNSRule.IsValid) + return len(r.Rules) > 0 && common.All(r.Rules, DNSRule.IsValid) } diff --git a/route/router.go b/route/router.go index 2e5514ac..21455831 100644 --- a/route/router.go +++ b/route/router.go @@ -128,14 +128,14 @@ func NewRouter( Logger: router.dnsLogger, }) for i, ruleOptions := range options.Rules { - routeRule, err := NewRule(router, router.logger, ruleOptions) + routeRule, err := NewRule(router, router.logger, ruleOptions, true) if err != nil { return nil, E.Cause(err, "parse rule[", i, "]") } router.rules = append(router.rules, routeRule) } for i, dnsRuleOptions := range dnsOptions.Rules { - dnsRule, err := NewDNSRule(router, router.logger, dnsRuleOptions) + dnsRule, err := NewDNSRule(router, router.logger, dnsRuleOptions, true) if err != nil { return nil, E.Cause(err, "parse dns rule[", i, "]") } diff --git a/route/router_geo_resources.go b/route/router_geo_resources.go index 8715cf92..638d00df 100644 --- a/route/router_geo_resources.go +++ b/route/router_geo_resources.go @@ -252,10 +252,8 @@ func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool return true } case C.RuleTypeLogical: - for _, subRule := range rule.LogicalOptions.Rules { - if cond(subRule) { - return true - } + if hasRule(rule.LogicalOptions.Rules, cond) { + return true } } } @@ -270,10 +268,8 @@ func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bo return true } case C.RuleTypeLogical: - for _, subRule := range rule.LogicalOptions.Rules { - if cond(subRule) { - return true - } + if hasDNSRule(rule.LogicalOptions.Rules, cond) { + return true } } } diff --git a/route/rule_default.go b/route/rule_default.go index 2d62f97a..8c8473ab 100644 --- a/route/rule_default.go +++ b/route/rule_default.go @@ -8,13 +8,13 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) -func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (adapter.Rule, error) { +func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule, checkOutbound bool) (adapter.Rule, error) { switch options.Type { case "", C.RuleTypeDefault: if !options.DefaultOptions.IsValid() { return nil, E.New("missing conditions") } - if options.DefaultOptions.Outbound == "" { + if options.DefaultOptions.Outbound == "" && checkOutbound { return nil, E.New("missing outbound field") } return NewDefaultRule(router, logger, options.DefaultOptions) @@ -22,7 +22,7 @@ func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rul if !options.LogicalOptions.IsValid() { return nil, E.New("missing conditions") } - if options.LogicalOptions.Outbound == "" { + if options.LogicalOptions.Outbound == "" && checkOutbound { return nil, E.New("missing outbound field") } return NewLogicalRule(router, logger, options.LogicalOptions) @@ -220,7 +220,7 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt return nil, E.New("unknown logical mode: ", options.Mode) } for i, subRule := range options.Rules { - rule, err := NewDefaultRule(router, logger, subRule) + rule, err := NewRule(router, logger, subRule, false) if err != nil { return nil, E.Cause(err, "sub rule[", i, "]") } diff --git a/route/rule_dns.go b/route/rule_dns.go index 5132f024..b4449325 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -8,13 +8,13 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) -func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.DNSRule, error) { +func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule, checkServer bool) (adapter.DNSRule, error) { switch options.Type { case "", C.RuleTypeDefault: if !options.DefaultOptions.IsValid() { return nil, E.New("missing conditions") } - if options.DefaultOptions.Server == "" { + if options.DefaultOptions.Server == "" && checkServer { return nil, E.New("missing server field") } return NewDefaultDNSRule(router, logger, options.DefaultOptions) @@ -22,7 +22,7 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option. if !options.LogicalOptions.IsValid() { return nil, E.New("missing conditions") } - if options.LogicalOptions.Server == "" { + if options.LogicalOptions.Server == "" && checkServer { return nil, E.New("missing server field") } return NewLogicalDNSRule(router, logger, options.LogicalOptions) @@ -228,7 +228,7 @@ func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options return nil, E.New("unknown logical mode: ", options.Mode) } for i, subRule := range options.Rules { - rule, err := NewDefaultDNSRule(router, logger, subRule) + rule, err := NewDNSRule(router, logger, subRule, false) if err != nil { return nil, E.Cause(err, "sub rule[", i, "]") }