Add inline rule-set & Add reload for local rule-set

This commit is contained in:
世界 2024-06-26 00:43:51 +08:00
parent f98faaf1ea
commit 8e163e0a7d
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
14 changed files with 309 additions and 278 deletions

View file

@ -56,7 +56,10 @@ func compileRuleSet(sourcePath string) error {
if err != nil { if err != nil {
return err return err
} }
ruleSet := plainRuleSet.Upgrade() ruleSet, err := plainRuleSet.Upgrade()
if err != nil {
return err
}
var outputPath string var outputPath string
if flagRuleSetCompileOutput == flagRuleSetCompileDefaultOutput { if flagRuleSetCompileOutput == flagRuleSetCompileDefaultOutput {
if strings.HasSuffix(sourcePath, ".json") { if strings.HasSuffix(sourcePath, ".json") {

View file

@ -63,7 +63,10 @@ func ruleSetMatch(sourcePath string, domain string) error {
if err != nil { if err != nil {
return err return err
} }
plainRuleSet = compat.Upgrade() plainRuleSet, err = compat.Upgrade()
if err != nil {
return err
}
case C.RuleSetFormatBinary: case C.RuleSetFormatBinary:
plainRuleSet, err = srs.Read(bytes.NewReader(content), false) plainRuleSet, err = srs.Read(bytes.NewReader(content), false)
if err != nil { if err != nil {

View file

@ -61,7 +61,10 @@ func upgradeRuleSet(sourcePath string) error {
log.Info("already up-to-date") log.Info("already up-to-date")
return nil return nil
} }
plainRuleSet := plainRuleSetCompat.Upgrade() plainRuleSet, err := plainRuleSetCompat.Upgrade()
if err != nil {
return err
}
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
encoder := json.NewEncoder(buffer) encoder := json.NewEncoder(buffer)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")

View file

@ -11,12 +11,11 @@ import (
"strings" "strings"
cftls "github.com/sagernet/cloudflare-tls" cftls "github.com/sagernet/cloudflare-tls"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/ntp"
"github.com/fsnotify/fsnotify"
) )
type echServerConfig struct { type echServerConfig struct {
@ -26,9 +25,8 @@ type echServerConfig struct {
key []byte key []byte
certificatePath string certificatePath string
keyPath string keyPath string
watcher *fsnotify.Watcher
echKeyPath string echKeyPath string
echWatcher *fsnotify.Watcher watcher *fswatch.Watcher
} }
func (c *echServerConfig) ServerName() string { func (c *echServerConfig) ServerName() string {
@ -66,146 +64,84 @@ func (c *echServerConfig) Clone() Config {
} }
func (c *echServerConfig) Start() error { func (c *echServerConfig) Start() error {
if c.certificatePath != "" && c.keyPath != "" { err := c.startWatcher()
err := c.startWatcher() if err != nil {
if err != nil { c.logger.Warn("create credentials watcher: ", err)
c.logger.Warn("create fsnotify watcher: ", err)
}
}
if c.echKeyPath != "" {
err := c.startECHWatcher()
if err != nil {
c.logger.Warn("create fsnotify watcher: ", err)
}
} }
return nil return nil
} }
func (c *echServerConfig) startWatcher() error { func (c *echServerConfig) startWatcher() error {
watcher, err := fsnotify.NewWatcher() var watchPath []string
if err != nil {
return err
}
if c.certificatePath != "" { if c.certificatePath != "" {
err = watcher.Add(c.certificatePath) watchPath = append(watchPath, c.certificatePath)
if err != nil {
return err
}
} }
if c.keyPath != "" { if c.keyPath != "" {
err = watcher.Add(c.keyPath) watchPath = append(watchPath, c.keyPath)
if err != nil { }
return err if c.echKeyPath != "" {
} watchPath = append(watchPath, c.echKeyPath)
}
if len(watchPath) == 0 {
return nil
}
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: watchPath,
Callback: func(path string) {
err := c.credentialsUpdated(path)
if err != nil {
c.logger.Error(E.Cause(err, "reload credentials from ", path))
}
},
})
if err != nil {
return err
} }
c.watcher = watcher c.watcher = watcher
go c.loopUpdate()
return nil return nil
} }
func (c *echServerConfig) loopUpdate() { func (c *echServerConfig) credentialsUpdated(path string) error {
for { if path == c.certificatePath || path == c.keyPath {
select { if path == c.certificatePath {
case event, ok := <-c.watcher.Events: certificate, err := os.ReadFile(c.certificatePath)
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadKeyPair()
if err != nil { if err != nil {
c.logger.Error(E.Cause(err, "reload TLS key pair")) return err
} }
case err, ok := <-c.watcher.Errors: c.certificate = certificate
if !ok { } else {
return key, err := os.ReadFile(c.keyPath)
}
c.logger.Error(E.Cause(err, "fsnotify error"))
}
}
}
func (c *echServerConfig) reloadKeyPair() error {
if c.certificatePath != "" {
certificate, err := os.ReadFile(c.certificatePath)
if err != nil {
return E.Cause(err, "reload certificate from ", c.certificatePath)
}
c.certificate = certificate
}
if c.keyPath != "" {
key, err := os.ReadFile(c.keyPath)
if err != nil {
return E.Cause(err, "reload key from ", c.keyPath)
}
c.key = key
}
keyPair, err := cftls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "reload key pair")
}
c.config.Certificates = []cftls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
return nil
}
func (c *echServerConfig) startECHWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
err = watcher.Add(c.echKeyPath)
if err != nil {
return err
}
c.echWatcher = watcher
go c.loopECHUpdate()
return nil
}
func (c *echServerConfig) loopECHUpdate() {
for {
select {
case event, ok := <-c.echWatcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadECHKey()
if err != nil { if err != nil {
c.logger.Error(E.Cause(err, "reload ECH key")) return err
} }
case err, ok := <-c.echWatcher.Errors: c.key = key
if !ok {
return
}
c.logger.Error(E.Cause(err, "fsnotify error"))
} }
keyPair, err := cftls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "parse key pair")
}
c.config.Certificates = []cftls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
} else {
echKeyContent, err := os.ReadFile(c.echKeyPath)
if err != nil {
return err
}
block, rest := pem.Decode(echKeyContent)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return E.New("invalid ECH keys pem")
}
echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes)
if err != nil {
return E.Cause(err, "parse ECH keys")
}
echKeySet, err := cftls.EXP_NewECHKeySet(echKeys)
if err != nil {
return E.Cause(err, "create ECH key set")
}
c.config.ServerECHProvider = echKeySet
c.logger.Info("reloaded ECH keys")
} }
}
func (c *echServerConfig) reloadECHKey() error {
echKeyContent, err := os.ReadFile(c.echKeyPath)
if err != nil {
return err
}
block, rest := pem.Decode(echKeyContent)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return E.New("invalid ECH keys pem")
}
echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes)
if err != nil {
return E.Cause(err, "parse ECH keys")
}
echKeySet, err := cftls.EXP_NewECHKeySet(echKeys)
if err != nil {
return E.Cause(err, "create ECH key set")
}
c.config.ServerECHProvider = echKeySet
c.logger.Info("reloaded ECH keys")
return nil return nil
} }
@ -213,12 +149,7 @@ func (c *echServerConfig) Close() error {
var err error var err error
if c.watcher != nil { if c.watcher != nil {
err = E.Append(err, c.watcher.Close(), func(err error) error { err = E.Append(err, c.watcher.Close(), func(err error) error {
return E.Cause(err, "close certificate watcher") return E.Cause(err, "close credentials watcher")
})
}
if c.echWatcher != nil {
err = E.Append(err, c.echWatcher.Close(), func(err error) error {
return E.Cause(err, "close ECH key watcher")
}) })
} }
return err return err

View file

@ -7,14 +7,13 @@ import (
"os" "os"
"strings" "strings"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/ntp"
"github.com/fsnotify/fsnotify"
) )
var errInsecureUnused = E.New("tls: insecure unused") var errInsecureUnused = E.New("tls: insecure unused")
@ -27,7 +26,7 @@ type STDServerConfig struct {
key []byte key []byte
certificatePath string certificatePath string
keyPath string keyPath string
watcher *fsnotify.Watcher watcher *fswatch.Watcher
} }
func (c *STDServerConfig) ServerName() string { func (c *STDServerConfig) ServerName() string {
@ -88,59 +87,37 @@ func (c *STDServerConfig) Start() error {
} }
func (c *STDServerConfig) startWatcher() error { func (c *STDServerConfig) startWatcher() error {
watcher, err := fsnotify.NewWatcher() var watchPath []string
if c.certificatePath != "" {
watchPath = append(watchPath, c.certificatePath)
}
if c.keyPath != "" {
watchPath = append(watchPath, c.keyPath)
}
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: watchPath,
Callback: func(path string) {
err := c.certificateUpdated(path)
if err != nil {
c.logger.Error(err)
}
},
})
if err != nil { if err != nil {
return err return err
} }
if c.certificatePath != "" {
err = watcher.Add(c.certificatePath)
if err != nil {
return err
}
}
if c.keyPath != "" {
err = watcher.Add(c.keyPath)
if err != nil {
return err
}
}
c.watcher = watcher c.watcher = watcher
go c.loopUpdate()
return nil return nil
} }
func (c *STDServerConfig) loopUpdate() { func (c *STDServerConfig) certificateUpdated(path string) error {
for { if path == c.certificatePath {
select {
case event, ok := <-c.watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadKeyPair()
if err != nil {
c.logger.Error(E.Cause(err, "reload TLS key pair"))
}
case err, ok := <-c.watcher.Errors:
if !ok {
return
}
c.logger.Error(E.Cause(err, "fsnotify error"))
}
}
}
func (c *STDServerConfig) reloadKeyPair() error {
if c.certificatePath != "" {
certificate, err := os.ReadFile(c.certificatePath) certificate, err := os.ReadFile(c.certificatePath)
if err != nil { if err != nil {
return E.Cause(err, "reload certificate from ", c.certificatePath) return E.Cause(err, "reload certificate from ", c.certificatePath)
} }
c.certificate = certificate c.certificate = certificate
} } else if path == c.keyPath {
if c.keyPath != "" {
key, err := os.ReadFile(c.keyPath) key, err := os.ReadFile(c.keyPath)
if err != nil { if err != nil {
return E.Cause(err, "reload key from ", c.keyPath) return E.Cause(err, "reload key from ", c.keyPath)

View file

@ -11,6 +11,7 @@ const (
) )
const ( const (
RuleSetTypeInline = "inline"
RuleSetTypeLocal = "local" RuleSetTypeLocal = "local"
RuleSetTypeRemote = "remote" RuleSetTypeRemote = "remote"
RuleSetFormatSource = "source" RuleSetFormatSource = "source"

View file

@ -1,48 +1,56 @@
---
icon: material/new-box
---
!!! quote "Changes in sing-box 1.10.0"
:material-plus: `type: inline`
# rule-set # rule-set
!!! question "Since sing-box 1.8.0" !!! question "Since sing-box 1.8.0"
### Structure ### Structure
```json === "Inline"
{
"type": "",
"tag": "",
"format": "",
... // Typed Fields !!! question "Since sing-box 1.10.0"
}
```
#### Local Structure ```json
{
"type": "inline", // optional
"tag": "",
"rules": []
}
```
```json === "Local File"
{
"type": "local",
... ```json
{
"type": "local",
"tag": "",
"format": "source", // or binary
"path": ""
}
```
"path": "" === "Remote File"
}
```
#### Remote Structure !!! info ""
!!! info "" Remote rule-set will be cached if `experimental.cache_file.enabled`.
Remote rule-set will be cached if `experimental.cache_file.enabled`. ```json
{
```json "type": "remote",
{ "tag": "",
"type": "remote", "format": "source", // or binary
"url": "",
..., "download_detour": "", // optional
"update_interval": "" // optional
"url": "", }
"download_detour": "", ```
"update_interval": ""
}
```
### Fields ### Fields
@ -58,11 +66,23 @@ Type of rule-set, `local` or `remote`.
Tag of rule-set. Tag of rule-set.
### Inline Fields
!!! question "Since sing-box 1.10.0"
#### rules
==Required==
List of [Headless Rule](./headless-rule.md/).
### Local or Remote Fields
#### format #### format
==Required== ==Required==
Format of rule-set, `source` or `binary`. Format of rule-set file, `source` or `binary`.
### Local Fields ### Local Fields
@ -70,6 +90,10 @@ Format of rule-set, `source` or `binary`.
==Required== ==Required==
!!! note ""
Will be automatically reloaded if file modified since sing-box 1.10.0.
File path of rule-set. File path of rule-set.
### Remote Fields ### Remote Fields

View file

@ -178,6 +178,10 @@ The server certificate line array, in PEM format.
#### certificate_path #### certificate_path
!!! note ""
Will be automatically reloaded if file modified.
The path to the server certificate, in PEM format. The path to the server certificate, in PEM format.
#### key #### key
@ -190,6 +194,10 @@ The server private key line array, in PEM format.
==Server only== ==Server only==
!!! note ""
Will be automatically reloaded if file modified.
The path to the server private key, in PEM format. The path to the server private key, in PEM format.
## Custom TLS support ## Custom TLS support
@ -266,6 +274,10 @@ ECH key line array, in PEM format.
==Server only== ==Server only==
!!! note ""
Will be automatically reloaded if file modified.
The path to ECH key, in PEM format. The path to ECH key, in PEM format.
#### config #### config
@ -398,7 +410,3 @@ A hexadecimal string with zero to eight digits.
The maximum time difference between the server and the client. The maximum time difference between the server and the client.
Check disabled if empty. Check disabled if empty.
### Reload
For server configuration, certificate, key and ECH key will be automatically reloaded if modified.

View file

@ -176,12 +176,20 @@ TLS 版本值:
#### certificate_path #### certificate_path
!!! note ""
文件更改时将自动重新加载。
服务器 PEM 证书路径。 服务器 PEM 证书路径。
#### key #### key
==仅服务器== ==仅服务器==
!!! note ""
文件更改时将自动重新加载。
服务器 PEM 私钥行数组。 服务器 PEM 私钥行数组。
#### key_path #### key_path
@ -258,6 +266,10 @@ ECH PEM 密钥行数组
==仅服务器== ==仅服务器==
!!! note ""
文件更改时将自动重新加载。
ECH PEM 密钥路径 ECH PEM 密钥路径
#### config #### config
@ -384,7 +396,3 @@ ACME DNS01 验证字段。如果配置,将禁用其他验证方法。
服务器与和客户端之间允许的最大时间差。 服务器与和客户端之间允许的最大时间差。
默认禁用检查。 默认禁用检查。
### 重载
对于服务器配置,如果修改,证书和密钥将自动重新加载。

4
go.mod
View file

@ -7,7 +7,6 @@ require (
github.com/caddyserver/certmagic v0.20.0 github.com/caddyserver/certmagic v0.20.0
github.com/cloudflare/circl v1.3.7 github.com/cloudflare/circl v1.3.7
github.com/cretz/bine v0.2.0 github.com/cretz/bine v0.2.0
github.com/fsnotify/fsnotify v1.7.0
github.com/go-chi/chi/v5 v5.0.12 github.com/go-chi/chi/v5 v5.0.12
github.com/go-chi/cors v1.2.1 github.com/go-chi/cors v1.2.1
github.com/go-chi/render v1.0.3 github.com/go-chi/render v1.0.3
@ -22,6 +21,7 @@ require (
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a
github.com/sagernet/cloudflare-tls v0.0.0-20231208171750-a4483c1b7cd1 github.com/sagernet/cloudflare-tls v0.0.0-20231208171750-a4483c1b7cd1
github.com/sagernet/fswatch v0.1.1
github.com/sagernet/gomobile v0.1.3 github.com/sagernet/gomobile v0.1.3
github.com/sagernet/gvisor v0.0.0-20240428053021-e691de28565f github.com/sagernet/gvisor v0.0.0-20240428053021-e691de28565f
github.com/sagernet/quic-go v0.45.1-beta.2 github.com/sagernet/quic-go v0.45.1-beta.2
@ -59,6 +59,7 @@ require (
github.com/ajg/form v1.5.1 // indirect github.com/ajg/form v1.5.1 // indirect
github.com/andybalholm/brotli v1.0.6 // indirect github.com/andybalholm/brotli v1.0.6 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gaukas/godicttls v0.0.4 // indirect github.com/gaukas/godicttls v0.0.4 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
@ -81,7 +82,6 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
github.com/sagernet/fswatch v0.1.1 // indirect
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a // indirect github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a // indirect
github.com/sagernet/nftables v0.3.0-beta.4 // indirect github.com/sagernet/nftables v0.3.0-beta.4 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect

View file

@ -17,6 +17,7 @@ type _RuleSet struct {
Type string `json:"type"` Type string `json:"type"`
Tag string `json:"tag"` Tag string `json:"tag"`
Format string `json:"format"` Format string `json:"format"`
InlineOptions PlainRuleSet `json:"-"`
LocalOptions LocalRuleSet `json:"-"` LocalOptions LocalRuleSet `json:"-"`
RemoteOptions RemoteRuleSet `json:"-"` RemoteOptions RemoteRuleSet `json:"-"`
} }
@ -26,6 +27,9 @@ type RuleSet _RuleSet
func (r RuleSet) MarshalJSON() ([]byte, error) { func (r RuleSet) MarshalJSON() ([]byte, error) {
var v any var v any
switch r.Type { switch r.Type {
case "", C.RuleSetTypeInline:
r.Type = ""
v = r.InlineOptions
case C.RuleSetTypeLocal: case C.RuleSetTypeLocal:
v = r.LocalOptions v = r.LocalOptions
case C.RuleSetTypeRemote: case C.RuleSetTypeRemote:
@ -44,21 +48,26 @@ func (r *RuleSet) UnmarshalJSON(bytes []byte) error {
if r.Tag == "" { if r.Tag == "" {
return E.New("missing tag") return E.New("missing tag")
} }
switch r.Format { if r.Type != C.RuleSetTypeInline {
case "": switch r.Format {
return E.New("missing format") case "":
case C.RuleSetFormatSource, C.RuleSetFormatBinary: return E.New("missing format")
default: case C.RuleSetFormatSource, C.RuleSetFormatBinary:
return E.New("unknown rule-set format: " + r.Format) default:
return E.New("unknown rule-set format: " + r.Format)
}
} else {
r.Format = ""
} }
var v any var v any
switch r.Type { switch r.Type {
case "", C.RuleSetTypeInline:
r.Type = C.RuleSetTypeInline
v = &r.InlineOptions
case C.RuleSetTypeLocal: case C.RuleSetTypeLocal:
v = &r.LocalOptions v = &r.LocalOptions
case C.RuleSetTypeRemote: case C.RuleSetTypeRemote:
v = &r.RemoteOptions v = &r.RemoteOptions
case "":
return E.New("missing type")
default: default:
return E.New("unknown rule-set type: " + r.Type) return E.New("unknown rule-set type: " + r.Type)
} }
@ -214,15 +223,13 @@ func (r *PlainRuleSetCompat) UnmarshalJSON(bytes []byte) error {
return nil return nil
} }
func (r PlainRuleSetCompat) Upgrade() PlainRuleSet { func (r PlainRuleSetCompat) Upgrade() (PlainRuleSet, error) {
var result PlainRuleSet
switch r.Version { switch r.Version {
case C.RuleSetVersion1, C.RuleSetVersion2: case C.RuleSetVersion1, C.RuleSetVersion2:
result = r.Options
default: default:
panic("unknown rule-set version: " + F.ToString(r.Version)) return PlainRuleSet{}, E.New("unknown rule-set version: " + F.ToString(r.Version))
} }
return result return r.Options, nil
} }
type PlainRuleSet struct { type PlainRuleSet struct {

View file

@ -20,8 +20,8 @@ import (
func NewRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) (adapter.RuleSet, error) { func NewRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) (adapter.RuleSet, error) {
switch options.Type { switch options.Type {
case C.RuleSetTypeLocal: case C.RuleSetTypeInline, C.RuleSetTypeLocal, "":
return NewLocalRuleSet(router, options) return NewLocalRuleSet(router, logger, options)
case C.RuleSetTypeRemote: case C.RuleSetTypeRemote:
return NewRemoteRuleSet(ctx, router, logger, options), nil return NewRemoteRuleSet(ctx, router, logger, options), nil
default: default:

View file

@ -3,8 +3,10 @@ package route
import ( import (
"context" "context"
"os" "os"
"path/filepath"
"strings" "strings"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/srs" "github.com/sagernet/sing-box/common/srs"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
@ -14,6 +16,7 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"go4.org/netipx" "go4.org/netipx"
@ -22,50 +25,55 @@ import (
var _ adapter.RuleSet = (*LocalRuleSet)(nil) var _ adapter.RuleSet = (*LocalRuleSet)(nil)
type LocalRuleSet struct { type LocalRuleSet struct {
tag string router adapter.Router
rules []adapter.HeadlessRule logger logger.Logger
metadata adapter.RuleSetMetadata tag string
refs atomic.Int32 rules []adapter.HeadlessRule
metadata adapter.RuleSetMetadata
fileFormat string
watcher *fswatch.Watcher
refs atomic.Int32
} }
func NewLocalRuleSet(router adapter.Router, options option.RuleSet) (*LocalRuleSet, error) { func NewLocalRuleSet(router adapter.Router, logger logger.Logger, options option.RuleSet) (*LocalRuleSet, error) {
var plainRuleSet option.PlainRuleSet ruleSet := &LocalRuleSet{
switch options.Format { router: router,
case C.RuleSetFormatSource, "": logger: logger,
content, err := os.ReadFile(options.LocalOptions.Path) tag: options.Tag,
if err != nil { fileFormat: options.Format,
return nil, err
}
compat, err := json.UnmarshalExtended[option.PlainRuleSetCompat](content)
if err != nil {
return nil, err
}
plainRuleSet = compat.Upgrade()
case C.RuleSetFormatBinary:
setFile, err := os.Open(options.LocalOptions.Path)
if err != nil {
return nil, err
}
plainRuleSet, err = srs.Read(setFile, false)
if err != nil {
return nil, err
}
default:
return nil, E.New("unknown rule-set format: ", options.Format)
} }
rules := make([]adapter.HeadlessRule, len(plainRuleSet.Rules)) if options.Type == C.RuleSetTypeInline {
var err error if len(options.InlineOptions.Rules) == 0 {
for i, ruleOptions := range plainRuleSet.Rules { return nil, E.New("empty inline rule-set")
rules[i], err = NewHeadlessRule(router, ruleOptions) }
err := ruleSet.reloadRules(options.InlineOptions.Rules)
if err != nil { if err != nil {
return nil, E.Cause(err, "parse rule_set.rules.[", i, "]") return nil, err
}
} else {
err := ruleSet.reloadFile(options.LocalOptions.Path)
if err != nil {
return nil, err
} }
} }
var metadata adapter.RuleSetMetadata if options.Type == C.RuleSetTypeLocal {
metadata.ContainsProcessRule = hasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule) var watcher *fswatch.Watcher
metadata.ContainsWIFIRule = hasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule) filePath, _ := filepath.Abs(options.LocalOptions.Path)
metadata.ContainsIPCIDRRule = hasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule) watcher, err := fswatch.NewWatcher(fswatch.Options{
return &LocalRuleSet{tag: options.Tag, rules: rules, metadata: metadata}, nil Path: []string{filePath},
Callback: func(path string) {
uErr := ruleSet.reloadFile(path)
if uErr != nil {
logger.Error(E.Cause(uErr, "reload rule-set ", options.Tag))
}
},
})
if err != nil {
return nil, err
}
ruleSet.watcher = watcher
}
return ruleSet, nil
} }
func (s *LocalRuleSet) Name() string { func (s *LocalRuleSet) Name() string {
@ -77,6 +85,61 @@ func (s *LocalRuleSet) String() string {
} }
func (s *LocalRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error { func (s *LocalRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error {
if s.watcher != nil {
err := s.watcher.Start()
if err != nil {
s.logger.Error(E.Cause(err, "watch rule-set file"))
}
}
return nil
}
func (s *LocalRuleSet) reloadFile(path string) error {
var plainRuleSet option.PlainRuleSet
switch s.fileFormat {
case C.RuleSetFormatSource, "":
content, err := os.ReadFile(path)
if err != nil {
return err
}
compat, err := json.UnmarshalExtended[option.PlainRuleSetCompat](content)
if err != nil {
return err
}
plainRuleSet, err = compat.Upgrade()
if err != nil {
return err
}
case C.RuleSetFormatBinary:
setFile, err := os.Open(path)
if err != nil {
return err
}
plainRuleSet, err = srs.Read(setFile, false)
if err != nil {
return err
}
default:
return E.New("unknown rule-set format: ", s.fileFormat)
}
return s.reloadRules(plainRuleSet.Rules)
}
func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error {
rules := make([]adapter.HeadlessRule, len(headlessRules))
var err error
for i, ruleOptions := range headlessRules {
rules[i], err = NewHeadlessRule(s.router, ruleOptions)
if err != nil {
return E.Cause(err, "parse rule_set.rules.[", i, "]")
}
}
var metadata adapter.RuleSetMetadata
metadata.ContainsProcessRule = hasHeadlessRule(headlessRules, isProcessHeadlessRule)
metadata.ContainsWIFIRule = hasHeadlessRule(headlessRules, isWIFIHeadlessRule)
metadata.ContainsIPCIDRRule = hasHeadlessRule(headlessRules, isIPCIDRHeadlessRule)
s.rules = rules
s.metadata = metadata
return nil return nil
} }
@ -117,7 +180,7 @@ func (s *LocalRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetU
func (s *LocalRuleSet) Close() error { func (s *LocalRuleSet) Close() error {
s.rules = nil s.rules = nil
return nil return common.Close(common.PtrOrNil(s.watcher))
} }
func (s *LocalRuleSet) Match(metadata *adapter.InboundContext) bool { func (s *LocalRuleSet) Match(metadata *adapter.InboundContext) bool {

View file

@ -168,7 +168,10 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error {
if err != nil { if err != nil {
return err return err
} }
plainRuleSet = compat.Upgrade() plainRuleSet, err = compat.Upgrade()
if err != nil {
return err
}
case C.RuleSetFormatBinary: case C.RuleSetFormatBinary:
plainRuleSet, err = srs.Read(bytes.NewReader(content), false) plainRuleSet, err = srs.Read(bytes.NewReader(content), false)
if err != nil { if err != nil {