Allow nested logical rules

This commit is contained in:
世界 2023-11-28 23:47:32 +08:00
parent ad93b45021
commit 340e74eed4
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 79 additions and 47 deletions

View file

@ -5,6 +5,7 @@ import (
"os" "os"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -27,24 +28,37 @@ func NewClashServer(ctx context.Context, router adapter.Router, logFactory log.O
func CalculateClashModeList(options option.Options) []string { func CalculateClashModeList(options option.Options) []string {
var clashMode []string var clashMode []string
for _, dnsRule := range common.PtrValueOrDefault(options.DNS).Rules { clashMode = append(clashMode, extraClashModeFromRule(common.PtrValueOrDefault(options.Route).Rules)...)
if dnsRule.DefaultOptions.ClashMode != "" && !common.Contains(clashMode, dnsRule.DefaultOptions.ClashMode) { clashMode = append(clashMode, extraClashModeFromDNSRule(common.PtrValueOrDefault(options.DNS).Rules)...)
clashMode = append(clashMode, dnsRule.DefaultOptions.ClashMode) clashMode = common.FilterNotDefault(common.Uniq(clashMode))
} return clashMode
for _, defaultRule := range dnsRule.LogicalOptions.Rules { }
if defaultRule.ClashMode != "" && !common.Contains(clashMode, defaultRule.ClashMode) {
clashMode = append(clashMode, defaultRule.ClashMode) func extraClashModeFromRule(rules []option.Rule) []string {
} var clashMode []string
} for _, rule := range rules {
} switch rule.Type {
for _, rule := range common.PtrValueOrDefault(options.Route).Rules { case C.RuleTypeDefault:
if rule.DefaultOptions.ClashMode != "" && !common.Contains(clashMode, rule.DefaultOptions.ClashMode) { if rule.DefaultOptions.ClashMode != "" {
clashMode = append(clashMode, rule.DefaultOptions.ClashMode) clashMode = append(clashMode, rule.DefaultOptions.ClashMode)
}
for _, defaultRule := range rule.LogicalOptions.Rules {
if defaultRule.ClashMode != "" && !common.Contains(clashMode, defaultRule.ClashMode) {
clashMode = append(clashMode, defaultRule.ClashMode)
} }
case C.RuleTypeLogical:
clashMode = append(clashMode, extraClashModeFromRule(rule.LogicalOptions.Rules)...)
}
}
return clashMode
}
func extraClashModeFromDNSRule(rules []option.DNSRule) []string {
var clashMode []string
for _, rule := range rules {
switch rule.Type {
case C.RuleTypeDefault:
if rule.DefaultOptions.ClashMode != "" {
clashMode = append(clashMode, rule.DefaultOptions.ClashMode)
}
case C.RuleTypeLogical:
clashMode = append(clashMode, extraClashModeFromDNSRule(rule.LogicalOptions.Rules)...)
} }
} }
return clashMode return clashMode

View file

@ -53,6 +53,17 @@ func (r *Rule) UnmarshalJSON(bytes []byte) error {
return nil return nil
} }
func (r Rule) IsValid() bool {
switch r.Type {
case C.RuleTypeDefault:
return r.DefaultOptions.IsValid()
case C.RuleTypeLogical:
return r.LogicalOptions.IsValid()
default:
panic("unknown rule type: " + r.Type)
}
}
type DefaultRule struct { type DefaultRule struct {
Inbound Listable[string] `json:"inbound,omitempty"` Inbound Listable[string] `json:"inbound,omitempty"`
IPVersion int `json:"ip_version,omitempty"` IPVersion int `json:"ip_version,omitempty"`
@ -92,12 +103,12 @@ func (r DefaultRule) IsValid() bool {
} }
type LogicalRule struct { type LogicalRule struct {
Mode string `json:"mode"` Mode string `json:"mode"`
Rules []DefaultRule `json:"rules,omitempty"` Rules []Rule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"` Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"` Outbound string `json:"outbound,omitempty"`
} }
func (r LogicalRule) IsValid() bool { func (r LogicalRule) IsValid() bool {
return len(r.Rules) > 0 && common.All(r.Rules, DefaultRule.IsValid) return len(r.Rules) > 0 && common.All(r.Rules, Rule.IsValid)
} }

View file

@ -53,6 +53,17 @@ func (r *DNSRule) UnmarshalJSON(bytes []byte) error {
return nil return nil
} }
func (r DNSRule) IsValid() bool {
switch r.Type {
case C.RuleTypeDefault:
return r.DefaultOptions.IsValid()
case C.RuleTypeLogical:
return r.LogicalOptions.IsValid()
default:
panic("unknown DNS rule type: " + r.Type)
}
}
type DefaultDNSRule struct { type DefaultDNSRule struct {
Inbound Listable[string] `json:"inbound,omitempty"` Inbound Listable[string] `json:"inbound,omitempty"`
IPVersion int `json:"ip_version,omitempty"` IPVersion int `json:"ip_version,omitempty"`
@ -96,14 +107,14 @@ func (r DefaultDNSRule) IsValid() bool {
} }
type LogicalDNSRule struct { type LogicalDNSRule struct {
Mode string `json:"mode"` Mode string `json:"mode"`
Rules []DefaultDNSRule `json:"rules,omitempty"` Rules []DNSRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"` Invert bool `json:"invert,omitempty"`
Server string `json:"server,omitempty"` Server string `json:"server,omitempty"`
DisableCache bool `json:"disable_cache,omitempty"` DisableCache bool `json:"disable_cache,omitempty"`
RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"`
} }
func (r LogicalDNSRule) IsValid() bool { func (r LogicalDNSRule) IsValid() bool {
return len(r.Rules) > 0 && common.All(r.Rules, DefaultDNSRule.IsValid) return len(r.Rules) > 0 && common.All(r.Rules, DNSRule.IsValid)
} }

View file

@ -128,14 +128,14 @@ func NewRouter(
Logger: router.dnsLogger, Logger: router.dnsLogger,
}) })
for i, ruleOptions := range options.Rules { for i, ruleOptions := range options.Rules {
routeRule, err := NewRule(router, router.logger, ruleOptions) routeRule, err := NewRule(router, router.logger, ruleOptions, true)
if err != nil { if err != nil {
return nil, E.Cause(err, "parse rule[", i, "]") return nil, E.Cause(err, "parse rule[", i, "]")
} }
router.rules = append(router.rules, routeRule) router.rules = append(router.rules, routeRule)
} }
for i, dnsRuleOptions := range dnsOptions.Rules { for i, dnsRuleOptions := range dnsOptions.Rules {
dnsRule, err := NewDNSRule(router, router.logger, dnsRuleOptions) dnsRule, err := NewDNSRule(router, router.logger, dnsRuleOptions, true)
if err != nil { if err != nil {
return nil, E.Cause(err, "parse dns rule[", i, "]") return nil, E.Cause(err, "parse dns rule[", i, "]")
} }

View file

@ -252,10 +252,8 @@ func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool
return true return true
} }
case C.RuleTypeLogical: case C.RuleTypeLogical:
for _, subRule := range rule.LogicalOptions.Rules { if hasRule(rule.LogicalOptions.Rules, cond) {
if cond(subRule) { return true
return true
}
} }
} }
} }
@ -270,10 +268,8 @@ func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bo
return true return true
} }
case C.RuleTypeLogical: case C.RuleTypeLogical:
for _, subRule := range rule.LogicalOptions.Rules { if hasDNSRule(rule.LogicalOptions.Rules, cond) {
if cond(subRule) { return true
return true
}
} }
} }
} }

View file

@ -8,13 +8,13 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (adapter.Rule, error) { func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule, checkOutbound bool) (adapter.Rule, error) {
switch options.Type { switch options.Type {
case "", C.RuleTypeDefault: case "", C.RuleTypeDefault:
if !options.DefaultOptions.IsValid() { if !options.DefaultOptions.IsValid() {
return nil, E.New("missing conditions") return nil, E.New("missing conditions")
} }
if options.DefaultOptions.Outbound == "" { if options.DefaultOptions.Outbound == "" && checkOutbound {
return nil, E.New("missing outbound field") return nil, E.New("missing outbound field")
} }
return NewDefaultRule(router, logger, options.DefaultOptions) return NewDefaultRule(router, logger, options.DefaultOptions)
@ -22,7 +22,7 @@ func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rul
if !options.LogicalOptions.IsValid() { if !options.LogicalOptions.IsValid() {
return nil, E.New("missing conditions") return nil, E.New("missing conditions")
} }
if options.LogicalOptions.Outbound == "" { if options.LogicalOptions.Outbound == "" && checkOutbound {
return nil, E.New("missing outbound field") return nil, E.New("missing outbound field")
} }
return NewLogicalRule(router, logger, options.LogicalOptions) return NewLogicalRule(router, logger, options.LogicalOptions)
@ -220,7 +220,7 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt
return nil, E.New("unknown logical mode: ", options.Mode) return nil, E.New("unknown logical mode: ", options.Mode)
} }
for i, subRule := range options.Rules { for i, subRule := range options.Rules {
rule, err := NewDefaultRule(router, logger, subRule) rule, err := NewRule(router, logger, subRule, false)
if err != nil { if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]") return nil, E.Cause(err, "sub rule[", i, "]")
} }

View file

@ -8,13 +8,13 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.DNSRule, error) { func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule, checkServer bool) (adapter.DNSRule, error) {
switch options.Type { switch options.Type {
case "", C.RuleTypeDefault: case "", C.RuleTypeDefault:
if !options.DefaultOptions.IsValid() { if !options.DefaultOptions.IsValid() {
return nil, E.New("missing conditions") return nil, E.New("missing conditions")
} }
if options.DefaultOptions.Server == "" { if options.DefaultOptions.Server == "" && checkServer {
return nil, E.New("missing server field") return nil, E.New("missing server field")
} }
return NewDefaultDNSRule(router, logger, options.DefaultOptions) return NewDefaultDNSRule(router, logger, options.DefaultOptions)
@ -22,7 +22,7 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.
if !options.LogicalOptions.IsValid() { if !options.LogicalOptions.IsValid() {
return nil, E.New("missing conditions") return nil, E.New("missing conditions")
} }
if options.LogicalOptions.Server == "" { if options.LogicalOptions.Server == "" && checkServer {
return nil, E.New("missing server field") return nil, E.New("missing server field")
} }
return NewLogicalDNSRule(router, logger, options.LogicalOptions) return NewLogicalDNSRule(router, logger, options.LogicalOptions)
@ -228,7 +228,7 @@ func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options
return nil, E.New("unknown logical mode: ", options.Mode) return nil, E.New("unknown logical mode: ", options.Mode)
} }
for i, subRule := range options.Rules { for i, subRule := range options.Rules {
rule, err := NewDefaultDNSRule(router, logger, subRule) rule, err := NewDNSRule(router, logger, subRule, false)
if err != nil { if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]") return nil, E.Cause(err, "sub rule[", i, "]")
} }