From 8e163e0a7d78f3604a8d053f0267f6374866aa1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 26 Jun 2024 00:43:51 +0800 Subject: [PATCH] Add inline rule-set & Add reload for local rule-set --- cmd/sing-box/cmd_rule_set_compile.go | 5 +- cmd/sing-box/cmd_rule_set_match.go | 5 +- cmd/sing-box/cmd_rule_set_upgrade.go | 5 +- common/tls/ech_server.go | 193 +++++++++------------------ common/tls/std_server.go | 65 +++------ constant/rule.go | 1 + docs/configuration/rule-set/index.md | 92 ++++++++----- docs/configuration/shared/tls.md | 18 ++- docs/configuration/shared/tls.zh.md | 16 ++- go.mod | 4 +- option/rule_set.go | 33 +++-- route/rule_set.go | 4 +- route/rule_set_local.go | 141 +++++++++++++------ route/rule_set_remote.go | 5 +- 14 files changed, 309 insertions(+), 278 deletions(-) diff --git a/cmd/sing-box/cmd_rule_set_compile.go b/cmd/sing-box/cmd_rule_set_compile.go index 7e3753c9..4fae4d99 100644 --- a/cmd/sing-box/cmd_rule_set_compile.go +++ b/cmd/sing-box/cmd_rule_set_compile.go @@ -56,7 +56,10 @@ func compileRuleSet(sourcePath string) error { if err != nil { return err } - ruleSet := plainRuleSet.Upgrade() + ruleSet, err := plainRuleSet.Upgrade() + if err != nil { + return err + } var outputPath string if flagRuleSetCompileOutput == flagRuleSetCompileDefaultOutput { if strings.HasSuffix(sourcePath, ".json") { diff --git a/cmd/sing-box/cmd_rule_set_match.go b/cmd/sing-box/cmd_rule_set_match.go index 8bf2ec7e..937458f2 100644 --- a/cmd/sing-box/cmd_rule_set_match.go +++ b/cmd/sing-box/cmd_rule_set_match.go @@ -63,7 +63,10 @@ func ruleSetMatch(sourcePath string, domain string) error { if err != nil { return err } - plainRuleSet = compat.Upgrade() + plainRuleSet, err = compat.Upgrade() + if err != nil { + return err + } case C.RuleSetFormatBinary: plainRuleSet, err = srs.Read(bytes.NewReader(content), false) if err != nil { diff --git a/cmd/sing-box/cmd_rule_set_upgrade.go b/cmd/sing-box/cmd_rule_set_upgrade.go index 0ec039fd..e885d849 100644 --- a/cmd/sing-box/cmd_rule_set_upgrade.go +++ b/cmd/sing-box/cmd_rule_set_upgrade.go @@ -61,7 +61,10 @@ func upgradeRuleSet(sourcePath string) error { log.Info("already up-to-date") return nil } - plainRuleSet := plainRuleSetCompat.Upgrade() + plainRuleSet, err := plainRuleSetCompat.Upgrade() + if err != nil { + return err + } buffer := new(bytes.Buffer) encoder := json.NewEncoder(buffer) encoder.SetIndent("", " ") diff --git a/common/tls/ech_server.go b/common/tls/ech_server.go index 43ddd820..ac3e6279 100644 --- a/common/tls/ech_server.go +++ b/common/tls/ech_server.go @@ -11,12 +11,11 @@ import ( "strings" cftls "github.com/sagernet/cloudflare-tls" + "github.com/sagernet/fswatch" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/ntp" - - "github.com/fsnotify/fsnotify" ) type echServerConfig struct { @@ -26,9 +25,8 @@ type echServerConfig struct { key []byte certificatePath string keyPath string - watcher *fsnotify.Watcher echKeyPath string - echWatcher *fsnotify.Watcher + watcher *fswatch.Watcher } func (c *echServerConfig) ServerName() string { @@ -66,146 +64,84 @@ func (c *echServerConfig) Clone() Config { } func (c *echServerConfig) Start() error { - if c.certificatePath != "" && c.keyPath != "" { - err := c.startWatcher() - if err != nil { - c.logger.Warn("create fsnotify watcher: ", err) - } - } - if c.echKeyPath != "" { - err := c.startECHWatcher() - if err != nil { - c.logger.Warn("create fsnotify watcher: ", err) - } + err := c.startWatcher() + if err != nil { + c.logger.Warn("create credentials watcher: ", err) } return nil } func (c *echServerConfig) startWatcher() error { - watcher, err := fsnotify.NewWatcher() - if err != nil { - return err - } + var watchPath []string if c.certificatePath != "" { - err = watcher.Add(c.certificatePath) - if err != nil { - return err - } + watchPath = append(watchPath, c.certificatePath) } if c.keyPath != "" { - err = watcher.Add(c.keyPath) - if err != nil { - return err - } + watchPath = append(watchPath, c.keyPath) + } + if c.echKeyPath != "" { + watchPath = append(watchPath, c.echKeyPath) + } + if len(watchPath) == 0 { + return nil + } + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: watchPath, + Callback: func(path string) { + err := c.credentialsUpdated(path) + if err != nil { + c.logger.Error(E.Cause(err, "reload credentials from ", path)) + } + }, + }) + if err != nil { + return err } c.watcher = watcher - go c.loopUpdate() return nil } -func (c *echServerConfig) loopUpdate() { - for { - select { - case event, ok := <-c.watcher.Events: - if !ok { - return - } - if event.Op&fsnotify.Write != fsnotify.Write { - continue - } - err := c.reloadKeyPair() +func (c *echServerConfig) credentialsUpdated(path string) error { + if path == c.certificatePath || path == c.keyPath { + if path == c.certificatePath { + certificate, err := os.ReadFile(c.certificatePath) if err != nil { - c.logger.Error(E.Cause(err, "reload TLS key pair")) + return err } - case err, ok := <-c.watcher.Errors: - if !ok { - return - } - c.logger.Error(E.Cause(err, "fsnotify error")) - } - } -} - -func (c *echServerConfig) reloadKeyPair() error { - if c.certificatePath != "" { - certificate, err := os.ReadFile(c.certificatePath) - if err != nil { - return E.Cause(err, "reload certificate from ", c.certificatePath) - } - c.certificate = certificate - } - if c.keyPath != "" { - key, err := os.ReadFile(c.keyPath) - if err != nil { - return E.Cause(err, "reload key from ", c.keyPath) - } - c.key = key - } - keyPair, err := cftls.X509KeyPair(c.certificate, c.key) - if err != nil { - return E.Cause(err, "reload key pair") - } - c.config.Certificates = []cftls.Certificate{keyPair} - c.logger.Info("reloaded TLS certificate") - return nil -} - -func (c *echServerConfig) startECHWatcher() error { - watcher, err := fsnotify.NewWatcher() - if err != nil { - return err - } - err = watcher.Add(c.echKeyPath) - if err != nil { - return err - } - c.echWatcher = watcher - go c.loopECHUpdate() - return nil -} - -func (c *echServerConfig) loopECHUpdate() { - for { - select { - case event, ok := <-c.echWatcher.Events: - if !ok { - return - } - if event.Op&fsnotify.Write != fsnotify.Write { - continue - } - err := c.reloadECHKey() + c.certificate = certificate + } else { + key, err := os.ReadFile(c.keyPath) if err != nil { - c.logger.Error(E.Cause(err, "reload ECH key")) + return err } - case err, ok := <-c.echWatcher.Errors: - if !ok { - return - } - c.logger.Error(E.Cause(err, "fsnotify error")) + c.key = key } + keyPair, err := cftls.X509KeyPair(c.certificate, c.key) + if err != nil { + return E.Cause(err, "parse key pair") + } + c.config.Certificates = []cftls.Certificate{keyPair} + c.logger.Info("reloaded TLS certificate") + } else { + echKeyContent, err := os.ReadFile(c.echKeyPath) + if err != nil { + return err + } + block, rest := pem.Decode(echKeyContent) + if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 { + return E.New("invalid ECH keys pem") + } + echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes) + if err != nil { + return E.Cause(err, "parse ECH keys") + } + echKeySet, err := cftls.EXP_NewECHKeySet(echKeys) + if err != nil { + return E.Cause(err, "create ECH key set") + } + c.config.ServerECHProvider = echKeySet + c.logger.Info("reloaded ECH keys") } -} - -func (c *echServerConfig) reloadECHKey() error { - echKeyContent, err := os.ReadFile(c.echKeyPath) - if err != nil { - return err - } - block, rest := pem.Decode(echKeyContent) - if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 { - return E.New("invalid ECH keys pem") - } - echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes) - if err != nil { - return E.Cause(err, "parse ECH keys") - } - echKeySet, err := cftls.EXP_NewECHKeySet(echKeys) - if err != nil { - return E.Cause(err, "create ECH key set") - } - c.config.ServerECHProvider = echKeySet - c.logger.Info("reloaded ECH keys") return nil } @@ -213,12 +149,7 @@ func (c *echServerConfig) Close() error { var err error if c.watcher != nil { err = E.Append(err, c.watcher.Close(), func(err error) error { - return E.Cause(err, "close certificate watcher") - }) - } - if c.echWatcher != nil { - err = E.Append(err, c.echWatcher.Close(), func(err error) error { - return E.Cause(err, "close ECH key watcher") + return E.Cause(err, "close credentials watcher") }) } return err diff --git a/common/tls/std_server.go b/common/tls/std_server.go index 7184bdb3..7001bd3a 100644 --- a/common/tls/std_server.go +++ b/common/tls/std_server.go @@ -7,14 +7,13 @@ import ( "os" "strings" + "github.com/sagernet/fswatch" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/ntp" - - "github.com/fsnotify/fsnotify" ) var errInsecureUnused = E.New("tls: insecure unused") @@ -27,7 +26,7 @@ type STDServerConfig struct { key []byte certificatePath string keyPath string - watcher *fsnotify.Watcher + watcher *fswatch.Watcher } func (c *STDServerConfig) ServerName() string { @@ -88,59 +87,37 @@ func (c *STDServerConfig) Start() error { } func (c *STDServerConfig) startWatcher() error { - watcher, err := fsnotify.NewWatcher() + var watchPath []string + if c.certificatePath != "" { + watchPath = append(watchPath, c.certificatePath) + } + if c.keyPath != "" { + watchPath = append(watchPath, c.keyPath) + } + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: watchPath, + Callback: func(path string) { + err := c.certificateUpdated(path) + if err != nil { + c.logger.Error(err) + } + }, + }) if err != nil { return err } - if c.certificatePath != "" { - err = watcher.Add(c.certificatePath) - if err != nil { - return err - } - } - if c.keyPath != "" { - err = watcher.Add(c.keyPath) - if err != nil { - return err - } - } c.watcher = watcher - go c.loopUpdate() return nil } -func (c *STDServerConfig) loopUpdate() { - for { - select { - case event, ok := <-c.watcher.Events: - if !ok { - return - } - if event.Op&fsnotify.Write != fsnotify.Write { - continue - } - err := c.reloadKeyPair() - if err != nil { - c.logger.Error(E.Cause(err, "reload TLS key pair")) - } - case err, ok := <-c.watcher.Errors: - if !ok { - return - } - c.logger.Error(E.Cause(err, "fsnotify error")) - } - } -} - -func (c *STDServerConfig) reloadKeyPair() error { - if c.certificatePath != "" { +func (c *STDServerConfig) certificateUpdated(path string) error { + if path == c.certificatePath { certificate, err := os.ReadFile(c.certificatePath) if err != nil { return E.Cause(err, "reload certificate from ", c.certificatePath) } c.certificate = certificate - } - if c.keyPath != "" { + } else if path == c.keyPath { key, err := os.ReadFile(c.keyPath) if err != nil { return E.Cause(err, "reload key from ", c.keyPath) diff --git a/constant/rule.go b/constant/rule.go index fb0f39e2..718e79a5 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -11,6 +11,7 @@ const ( ) const ( + RuleSetTypeInline = "inline" RuleSetTypeLocal = "local" RuleSetTypeRemote = "remote" RuleSetFormatSource = "source" diff --git a/docs/configuration/rule-set/index.md b/docs/configuration/rule-set/index.md index b92d80f3..dfe71d9e 100644 --- a/docs/configuration/rule-set/index.md +++ b/docs/configuration/rule-set/index.md @@ -1,48 +1,56 @@ +--- +icon: material/new-box +--- + +!!! quote "Changes in sing-box 1.10.0" + + :material-plus: `type: inline` + # rule-set !!! question "Since sing-box 1.8.0" ### Structure -```json -{ - "type": "", - "tag": "", - "format": "", - - ... // Typed Fields -} -``` +=== "Inline" -#### Local Structure + !!! question "Since sing-box 1.10.0" -```json -{ - "type": "local", - - ... - - "path": "" -} -``` + ```json + { + "type": "inline", // optional + "tag": "", + "rules": [] + } + ``` -#### Remote Structure +=== "Local File" -!!! info "" + ```json + { + "type": "local", + "tag": "", + "format": "source", // or binary + "path": "" + } + ``` - Remote rule-set will be cached if `experimental.cache_file.enabled`. +=== "Remote File" -```json -{ - "type": "remote", - - ..., - - "url": "", - "download_detour": "", - "update_interval": "" -} -``` + !!! info "" + + Remote rule-set will be cached if `experimental.cache_file.enabled`. + + ```json + { + "type": "remote", + "tag": "", + "format": "source", // or binary + "url": "", + "download_detour": "", // optional + "update_interval": "" // optional + } + ``` ### Fields @@ -58,11 +66,23 @@ Type of rule-set, `local` or `remote`. Tag of rule-set. +### Inline Fields + +!!! question "Since sing-box 1.10.0" + +#### rules + +==Required== + +List of [Headless Rule](./headless-rule.md/). + +### Local or Remote Fields + #### format ==Required== -Format of rule-set, `source` or `binary`. +Format of rule-set file, `source` or `binary`. ### Local Fields @@ -70,6 +90,10 @@ Format of rule-set, `source` or `binary`. ==Required== +!!! note "" + + Will be automatically reloaded if file modified since sing-box 1.10.0. + File path of rule-set. ### Remote Fields diff --git a/docs/configuration/shared/tls.md b/docs/configuration/shared/tls.md index b1441a8a..799aa0b0 100644 --- a/docs/configuration/shared/tls.md +++ b/docs/configuration/shared/tls.md @@ -178,6 +178,10 @@ The server certificate line array, in PEM format. #### certificate_path +!!! note "" + + Will be automatically reloaded if file modified. + The path to the server certificate, in PEM format. #### key @@ -190,6 +194,10 @@ The server private key line array, in PEM format. ==Server only== +!!! note "" + + Will be automatically reloaded if file modified. + The path to the server private key, in PEM format. ## Custom TLS support @@ -266,6 +274,10 @@ ECH key line array, in PEM format. ==Server only== +!!! note "" + + Will be automatically reloaded if file modified. + The path to ECH key, in PEM format. #### config @@ -397,8 +409,4 @@ A hexadecimal string with zero to eight digits. The maximum time difference between the server and the client. -Check disabled if empty. - -### Reload - -For server configuration, certificate, key and ECH key will be automatically reloaded if modified. \ No newline at end of file +Check disabled if empty. \ No newline at end of file diff --git a/docs/configuration/shared/tls.zh.md b/docs/configuration/shared/tls.zh.md index 360c4536..68de9845 100644 --- a/docs/configuration/shared/tls.zh.md +++ b/docs/configuration/shared/tls.zh.md @@ -176,12 +176,20 @@ TLS 版本值: #### certificate_path +!!! note "" + + 文件更改时将自动重新加载。 + 服务器 PEM 证书路径。 #### key ==仅服务器== +!!! note "" + + 文件更改时将自动重新加载。 + 服务器 PEM 私钥行数组。 #### key_path @@ -258,6 +266,10 @@ ECH PEM 密钥行数组 ==仅服务器== +!!! note "" + + 文件更改时将自动重新加载。 + ECH PEM 密钥路径 #### config @@ -384,7 +396,3 @@ ACME DNS01 验证字段。如果配置,将禁用其他验证方法。 服务器与和客户端之间允许的最大时间差。 默认禁用检查。 - -### 重载 - -对于服务器配置,如果修改,证书和密钥将自动重新加载。 \ No newline at end of file diff --git a/go.mod b/go.mod index 3fe7b7ec..ca5264c0 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/caddyserver/certmagic v0.20.0 github.com/cloudflare/circl v1.3.7 github.com/cretz/bine v0.2.0 - github.com/fsnotify/fsnotify v1.7.0 github.com/go-chi/chi/v5 v5.0.12 github.com/go-chi/cors v1.2.1 github.com/go-chi/render v1.0.3 @@ -22,6 +21,7 @@ require ( github.com/oschwald/maxminddb-golang v1.12.0 github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a github.com/sagernet/cloudflare-tls v0.0.0-20231208171750-a4483c1b7cd1 + github.com/sagernet/fswatch v0.1.1 github.com/sagernet/gomobile v0.1.3 github.com/sagernet/gvisor v0.0.0-20240428053021-e691de28565f github.com/sagernet/quic-go v0.45.1-beta.2 @@ -59,6 +59,7 @@ require ( github.com/ajg/form v1.5.1 // indirect github.com/andybalholm/brotli v1.0.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect @@ -81,7 +82,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qtls-go1-20 v0.4.1 // indirect - github.com/sagernet/fswatch v0.1.1 // indirect github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a // indirect github.com/sagernet/nftables v0.3.0-beta.4 // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/option/rule_set.go b/option/rule_set.go index ec32d0a1..002cadd4 100644 --- a/option/rule_set.go +++ b/option/rule_set.go @@ -17,6 +17,7 @@ type _RuleSet struct { Type string `json:"type"` Tag string `json:"tag"` Format string `json:"format"` + InlineOptions PlainRuleSet `json:"-"` LocalOptions LocalRuleSet `json:"-"` RemoteOptions RemoteRuleSet `json:"-"` } @@ -26,6 +27,9 @@ type RuleSet _RuleSet func (r RuleSet) MarshalJSON() ([]byte, error) { var v any switch r.Type { + case "", C.RuleSetTypeInline: + r.Type = "" + v = r.InlineOptions case C.RuleSetTypeLocal: v = r.LocalOptions case C.RuleSetTypeRemote: @@ -44,21 +48,26 @@ func (r *RuleSet) UnmarshalJSON(bytes []byte) error { if r.Tag == "" { return E.New("missing tag") } - switch r.Format { - case "": - return E.New("missing format") - case C.RuleSetFormatSource, C.RuleSetFormatBinary: - default: - return E.New("unknown rule-set format: " + r.Format) + if r.Type != C.RuleSetTypeInline { + switch r.Format { + case "": + return E.New("missing format") + case C.RuleSetFormatSource, C.RuleSetFormatBinary: + default: + return E.New("unknown rule-set format: " + r.Format) + } + } else { + r.Format = "" } var v any switch r.Type { + case "", C.RuleSetTypeInline: + r.Type = C.RuleSetTypeInline + v = &r.InlineOptions case C.RuleSetTypeLocal: v = &r.LocalOptions case C.RuleSetTypeRemote: v = &r.RemoteOptions - case "": - return E.New("missing type") default: return E.New("unknown rule-set type: " + r.Type) } @@ -214,15 +223,13 @@ func (r *PlainRuleSetCompat) UnmarshalJSON(bytes []byte) error { return nil } -func (r PlainRuleSetCompat) Upgrade() PlainRuleSet { - var result PlainRuleSet +func (r PlainRuleSetCompat) Upgrade() (PlainRuleSet, error) { switch r.Version { case C.RuleSetVersion1, C.RuleSetVersion2: - result = r.Options default: - panic("unknown rule-set version: " + F.ToString(r.Version)) + return PlainRuleSet{}, E.New("unknown rule-set version: " + F.ToString(r.Version)) } - return result + return r.Options, nil } type PlainRuleSet struct { diff --git a/route/rule_set.go b/route/rule_set.go index 92952c51..fd960b5e 100644 --- a/route/rule_set.go +++ b/route/rule_set.go @@ -20,8 +20,8 @@ import ( func NewRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) (adapter.RuleSet, error) { switch options.Type { - case C.RuleSetTypeLocal: - return NewLocalRuleSet(router, options) + case C.RuleSetTypeInline, C.RuleSetTypeLocal, "": + return NewLocalRuleSet(router, logger, options) case C.RuleSetTypeRemote: return NewRemoteRuleSet(ctx, router, logger, options), nil default: diff --git a/route/rule_set_local.go b/route/rule_set_local.go index aa8c3ff6..cf38f168 100644 --- a/route/rule_set_local.go +++ b/route/rule_set_local.go @@ -3,8 +3,10 @@ package route import ( "context" "os" + "path/filepath" "strings" + "github.com/sagernet/fswatch" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/srs" C "github.com/sagernet/sing-box/constant" @@ -14,6 +16,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/x/list" "go4.org/netipx" @@ -22,50 +25,55 @@ import ( var _ adapter.RuleSet = (*LocalRuleSet)(nil) type LocalRuleSet struct { - tag string - rules []adapter.HeadlessRule - metadata adapter.RuleSetMetadata - refs atomic.Int32 + router adapter.Router + logger logger.Logger + tag string + rules []adapter.HeadlessRule + metadata adapter.RuleSetMetadata + fileFormat string + watcher *fswatch.Watcher + refs atomic.Int32 } -func NewLocalRuleSet(router adapter.Router, options option.RuleSet) (*LocalRuleSet, error) { - var plainRuleSet option.PlainRuleSet - switch options.Format { - case C.RuleSetFormatSource, "": - content, err := os.ReadFile(options.LocalOptions.Path) - if err != nil { - return nil, err - } - compat, err := json.UnmarshalExtended[option.PlainRuleSetCompat](content) - if err != nil { - return nil, err - } - plainRuleSet = compat.Upgrade() - case C.RuleSetFormatBinary: - setFile, err := os.Open(options.LocalOptions.Path) - if err != nil { - return nil, err - } - plainRuleSet, err = srs.Read(setFile, false) - if err != nil { - return nil, err - } - default: - return nil, E.New("unknown rule-set format: ", options.Format) +func NewLocalRuleSet(router adapter.Router, logger logger.Logger, options option.RuleSet) (*LocalRuleSet, error) { + ruleSet := &LocalRuleSet{ + router: router, + logger: logger, + tag: options.Tag, + fileFormat: options.Format, } - rules := make([]adapter.HeadlessRule, len(plainRuleSet.Rules)) - var err error - for i, ruleOptions := range plainRuleSet.Rules { - rules[i], err = NewHeadlessRule(router, ruleOptions) + if options.Type == C.RuleSetTypeInline { + if len(options.InlineOptions.Rules) == 0 { + return nil, E.New("empty inline rule-set") + } + err := ruleSet.reloadRules(options.InlineOptions.Rules) if err != nil { - return nil, E.Cause(err, "parse rule_set.rules.[", i, "]") + return nil, err + } + } else { + err := ruleSet.reloadFile(options.LocalOptions.Path) + if err != nil { + return nil, err } } - var metadata adapter.RuleSetMetadata - metadata.ContainsProcessRule = hasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule) - metadata.ContainsWIFIRule = hasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule) - metadata.ContainsIPCIDRRule = hasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule) - return &LocalRuleSet{tag: options.Tag, rules: rules, metadata: metadata}, nil + if options.Type == C.RuleSetTypeLocal { + var watcher *fswatch.Watcher + filePath, _ := filepath.Abs(options.LocalOptions.Path) + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: []string{filePath}, + Callback: func(path string) { + uErr := ruleSet.reloadFile(path) + if uErr != nil { + logger.Error(E.Cause(uErr, "reload rule-set ", options.Tag)) + } + }, + }) + if err != nil { + return nil, err + } + ruleSet.watcher = watcher + } + return ruleSet, nil } func (s *LocalRuleSet) Name() string { @@ -77,6 +85,61 @@ func (s *LocalRuleSet) String() string { } func (s *LocalRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error { + if s.watcher != nil { + err := s.watcher.Start() + if err != nil { + s.logger.Error(E.Cause(err, "watch rule-set file")) + } + } + return nil +} + +func (s *LocalRuleSet) reloadFile(path string) error { + var plainRuleSet option.PlainRuleSet + switch s.fileFormat { + case C.RuleSetFormatSource, "": + content, err := os.ReadFile(path) + if err != nil { + return err + } + compat, err := json.UnmarshalExtended[option.PlainRuleSetCompat](content) + if err != nil { + return err + } + plainRuleSet, err = compat.Upgrade() + if err != nil { + return err + } + case C.RuleSetFormatBinary: + setFile, err := os.Open(path) + if err != nil { + return err + } + plainRuleSet, err = srs.Read(setFile, false) + if err != nil { + return err + } + default: + return E.New("unknown rule-set format: ", s.fileFormat) + } + return s.reloadRules(plainRuleSet.Rules) +} + +func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error { + rules := make([]adapter.HeadlessRule, len(headlessRules)) + var err error + for i, ruleOptions := range headlessRules { + rules[i], err = NewHeadlessRule(s.router, ruleOptions) + if err != nil { + return E.Cause(err, "parse rule_set.rules.[", i, "]") + } + } + var metadata adapter.RuleSetMetadata + metadata.ContainsProcessRule = hasHeadlessRule(headlessRules, isProcessHeadlessRule) + metadata.ContainsWIFIRule = hasHeadlessRule(headlessRules, isWIFIHeadlessRule) + metadata.ContainsIPCIDRRule = hasHeadlessRule(headlessRules, isIPCIDRHeadlessRule) + s.rules = rules + s.metadata = metadata return nil } @@ -117,7 +180,7 @@ func (s *LocalRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetU func (s *LocalRuleSet) Close() error { s.rules = nil - return nil + return common.Close(common.PtrOrNil(s.watcher)) } func (s *LocalRuleSet) Match(metadata *adapter.InboundContext) bool { diff --git a/route/rule_set_remote.go b/route/rule_set_remote.go index 1473a494..03662ee4 100644 --- a/route/rule_set_remote.go +++ b/route/rule_set_remote.go @@ -168,7 +168,10 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error { if err != nil { return err } - plainRuleSet = compat.Upgrade() + plainRuleSet, err = compat.Upgrade() + if err != nil { + return err + } case C.RuleSetFormatBinary: plainRuleSet, err = srs.Read(bytes.NewReader(content), false) if err != nil {