sing-box/route/rule/rule_set_remote.go
2024-11-03 19:47:09 +08:00

322 lines
8.3 KiB
Go

package rule
import (
"bytes"
"context"
"io"
"net"
"net/http"
"runtime"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/srs"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"go4.org/netipx"
)
// Compile-time assertion that *RemoteRuleSet satisfies adapter.RuleSet.
var _ adapter.RuleSet = (*RemoteRuleSet)(nil)
// RemoteRuleSet is a rule-set whose content is downloaded from
// options.RemoteOptions.URL, optionally persisted through an
// adapter.CacheFile, and refreshed on a ticker by a background goroutine.
type RemoteRuleSet struct {
// ctx is derived from the constructor's context; cancel stops the update loop.
ctx context.Context
cancel context.CancelFunc
// router resolves the download outbound and is passed to NewHeadlessRule.
router adapter.Router
logger logger.ContextLogger
options option.RuleSet
// metadata is recomputed on every successful loadBytes call.
metadata adapter.RuleSetMetadata
// updateInterval is the configured interval, or 24 hours when none is set.
updateInterval time.Duration
// dialer is selected in StartContext and used for HTTP downloads.
dialer N.Dialer
// rules holds the currently loaded rules; it is set to nil to release
// memory when no references are held.
rules []adapter.HeadlessRule
// lastUpdated and lastEtag describe the most recent successful fetch
// (or the state restored from the cache file).
lastUpdated time.Time
lastEtag string
// updateTicker is created at the end of StartContext; nil before that.
updateTicker *time.Ticker
// cacheFile is looked up from the service context in StartContext and may be nil.
cacheFile adapter.CacheFile
pauseManager pause.Manager
// callbackAccess guards callbacks.
callbackAccess sync.Mutex
callbacks list.List[adapter.RuleSetUpdateCallback]
// refs counts active users of the loaded rules (see IncRef/DecRef).
refs atomic.Int32
}
// NewRemoteRuleSet builds a RemoteRuleSet from the given options.
//
// The supplied context is wrapped in a cancellable child context that the
// rule-set owns. The update interval is taken from
// options.RemoteOptions.UpdateInterval when it is positive and otherwise
// defaults to 24 hours. No network or cache access happens here.
func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet {
	ctx, cancel := context.WithCancel(ctx)
	// Start from the default and override only for a positive configured value.
	interval := 24 * time.Hour
	if options.RemoteOptions.UpdateInterval > 0 {
		interval = time.Duration(options.RemoteOptions.UpdateInterval)
	}
	ruleSet := &RemoteRuleSet{
		ctx:            ctx,
		cancel:         cancel,
		router:         router,
		logger:         logger,
		options:        options,
		updateInterval: interval,
		pauseManager:   service.FromContext[pause.Manager](ctx),
	}
	return ruleSet
}
// Name returns the tag configured for this rule-set.
func (s *RemoteRuleSet) Name() string {
	tag := s.options.Tag
	return tag
}
// String renders every loaded rule as text and joins them with single spaces.
func (s *RemoteRuleSet) String() string {
	descriptions := F.MapToString(s.rules)
	return strings.Join(descriptions, " ")
}
// StartContext prepares the rule-set for use.
//
// It resolves the cache file from the service context, picks the dialer
// (the configured download_detour outbound, or the router's default TCP
// outbound), restores any cached content, performs an initial download when
// nothing was restored, and finally creates the update ticker.
//
// An error is returned when the detour outbound is missing, the default
// outbound cannot be resolved, the cached content fails to load, or the
// initial download fails.
func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
	s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx)

	// Choose the dialer; s.dialer is only assigned once resolution succeeded.
	if detourTag := s.options.RemoteOptions.DownloadDetour; detourTag != "" {
		detour, found := s.router.Outbound(detourTag)
		if !found {
			return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour)
		}
		s.dialer = detour
	} else {
		defaultOutbound, err := s.router.DefaultOutbound(N.NetworkTCP)
		if err != nil {
			return err
		}
		s.dialer = defaultOutbound
	}

	// Restore the previously cached rule-set, if there is one.
	if s.cacheFile != nil {
		cached := s.cacheFile.LoadRuleSet(s.options.Tag)
		if cached != nil {
			if err := s.loadBytes(cached.Content); err != nil {
				return E.Cause(err, "restore cached rule-set")
			}
			s.lastUpdated = cached.LastUpdated
			s.lastEtag = cached.LastEtag
		}
	}

	// A zero timestamp means nothing was restored, so download now.
	if s.lastUpdated.IsZero() {
		if err := s.fetchOnce(ctx, startContext); err != nil {
			return E.Cause(err, "initial rule-set: ", s.options.Tag)
		}
	}

	s.updateTicker = time.NewTicker(s.updateInterval)
	return nil
}
// PostStart launches the background update loop in its own goroutine and
// always returns nil.
func (s *RemoteRuleSet) PostStart() error {
	updateLoop := s.loopUpdate
	go updateLoop()
	return nil
}
// Metadata returns a copy of the metadata computed by the last successful load.
func (s *RemoteRuleSet) Metadata() adapter.RuleSetMetadata {
	current := s.metadata
	return current
}
// ExtractIPSet collects the IP sets of all loaded rules into one flat slice.
func (s *RemoteRuleSet) ExtractIPSet() []*netipx.IPSet {
	ipSets := common.FlatMap(s.rules, extractIPSetFromRule)
	return ipSets
}
// IncRef atomically increments the reference counter by one.
func (s *RemoteRuleSet) IncRef() {
	_ = s.refs.Add(1)
}
// DecRef atomically decrements the reference counter by one.
// It panics if the counter drops below zero, which indicates an unbalanced
// IncRef/DecRef pair.
func (s *RemoteRuleSet) DecRef() {
	remaining := s.refs.Add(-1)
	if remaining < 0 {
		panic("rule-set: negative refs")
	}
}
// Cleanup drops the loaded rules when the reference counter is zero;
// otherwise it does nothing.
func (s *RemoteRuleSet) Cleanup() {
	if s.refs.Load() != 0 {
		return
	}
	s.rules = nil
}
// RegisterCallback appends callback to the update-callback list while holding
// callbackAccess, and returns the list element that can later be passed to
// UnregisterCallback.
func (s *RemoteRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] {
	s.callbackAccess.Lock()
	defer s.callbackAccess.Unlock()
	element := s.callbacks.PushBack(callback)
	return element
}
// UnregisterCallback removes element from the update-callback list while
// holding callbackAccess.
func (s *RemoteRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) {
	guard := &s.callbackAccess
	guard.Lock()
	defer guard.Unlock()
	s.callbacks.Remove(element)
}
// loadBytes decodes content according to the configured format, builds the
// headless rules, and installs them together with refreshed metadata.
//
// Source-format content is parsed as JSON and upgraded; binary content is
// read through srs.Read. Any other format is rejected. State on the receiver
// is only modified once every rule has been built successfully. After the
// new rules are installed, every registered callback is invoked with the
// rule-set; the callback list is snapshotted under callbackAccess and the
// callbacks run after the lock is released.
func (s *RemoteRuleSet) loadBytes(content []byte) error {
	var (
		decoded option.PlainRuleSet
		err     error
	)
	switch s.options.Format {
	case C.RuleSetFormatSource:
		var legacy option.PlainRuleSetCompat
		if legacy, err = json.UnmarshalExtended[option.PlainRuleSetCompat](content); err != nil {
			return err
		}
		if decoded, err = legacy.Upgrade(); err != nil {
			return err
		}
	case C.RuleSetFormatBinary:
		if decoded, err = srs.Read(bytes.NewReader(content), false); err != nil {
			return err
		}
	default:
		return E.New("unknown rule-set format: ", s.options.Format)
	}

	built := make([]adapter.HeadlessRule, len(decoded.Rules))
	for index, ruleOptions := range decoded.Rules {
		built[index], err = NewHeadlessRule(s.router, ruleOptions)
		if err != nil {
			return E.Cause(err, "parse rule_set.rules.[", index, "]")
		}
	}

	s.metadata.ContainsProcessRule = hasHeadlessRule(decoded.Rules, isProcessHeadlessRule)
	s.metadata.ContainsWIFIRule = hasHeadlessRule(decoded.Rules, isWIFIHeadlessRule)
	s.metadata.ContainsIPCIDRRule = hasHeadlessRule(decoded.Rules, isIPCIDRHeadlessRule)
	s.rules = built

	// Snapshot the listeners so callbacks run without holding the lock.
	s.callbackAccess.Lock()
	listeners := s.callbacks.Array()
	s.callbackAccess.Unlock()
	for _, notify := range listeners {
		notify(s)
	}
	return nil
}
// loopUpdate is the background refresh loop started by PostStart.
//
// It first refreshes immediately when the last update is older than the
// update interval, then waits for ticker events until the rule-set context
// is cancelled. Before each tick-driven refresh it waits for the pause
// manager to report the active state. A failed fetch is logged; after a
// successful fetch the rules are dropped again if nothing references them.
// runtime.GC is invoked at the top of every loop iteration.
func (s *RemoteRuleSet) loopUpdate() {
	// refresh performs one fetch and releases the rules when unreferenced.
	refresh := func() {
		if err := s.fetchOnce(s.ctx, nil); err != nil {
			s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err)
			return
		}
		if s.refs.Load() == 0 {
			s.rules = nil
		}
	}

	if time.Since(s.lastUpdated) > s.updateInterval {
		refresh()
	}
	for {
		runtime.GC()
		select {
		case <-s.ctx.Done():
			return
		case <-s.updateTicker.C:
			s.pauseManager.WaitActive()
			refresh()
		}
	}
}
// fetchOnce downloads the rule-set from options.RemoteOptions.URL once.
//
// When startContext is non-nil its HTTP client is used; otherwise a one-off
// client that dials through s.dialer is created for this call. If an ETag
// from a previous fetch is known it is sent as If-None-Match.
//
//   - 200 OK: the body is read and loaded via loadBytes, the ETag and update
//     time are recorded, and the content is written to the cache file when
//     one is available (a cache write failure is only logged).
//   - 304 Not Modified: only the update time is refreshed, both in memory
//     and in the cached entry if one exists.
//   - any other status: an error containing the status line is returned.
//
// The response body is closed on every path, and idle connections of the
// one-off client are released before returning so that repeated updates do
// not accumulate keep-alive connections.
func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext *adapter.HTTPStartContext) error {
	s.logger.Debug("updating rule-set ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL)
	var httpClient *http.Client
	if startContext != nil {
		httpClient = startContext.HTTPClient(s.options.RemoteOptions.DownloadDetour, s.dialer)
	} else {
		httpClient = &http.Client{
			Transport: &http.Transport{
				ForceAttemptHTTP2:   true,
				TLSHandshakeTimeout: C.TCPTimeout,
				DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
					return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
				},
			},
		}
		// This client and its transport are discarded after the call, so
		// release its pooled connections instead of leaving them open.
		defer httpClient.CloseIdleConnections()
	}
	request, err := http.NewRequest("GET", s.options.RemoteOptions.URL, nil)
	if err != nil {
		return err
	}
	if s.lastEtag != "" {
		request.Header.Set("If-None-Match", s.lastEtag)
	}
	response, err := httpClient.Do(request.WithContext(ctx))
	if err != nil {
		return err
	}
	// Close the body on every exit path, including 304 and unexpected statuses.
	defer response.Body.Close()
	switch response.StatusCode {
	case http.StatusOK:
	case http.StatusNotModified:
		s.lastUpdated = time.Now()
		if s.cacheFile != nil {
			savedRuleSet := s.cacheFile.LoadRuleSet(s.options.Tag)
			if savedRuleSet != nil {
				savedRuleSet.LastUpdated = s.lastUpdated
				err = s.cacheFile.SaveRuleSet(s.options.Tag, savedRuleSet)
				if err != nil {
					s.logger.Error("save rule-set updated time: ", err)
					return nil
				}
			}
		}
		s.logger.Info("update rule-set ", s.options.Tag, ": not modified")
		return nil
	default:
		return E.New("unexpected status: ", response.Status)
	}
	content, err := io.ReadAll(response.Body)
	if err != nil {
		return err
	}
	err = s.loadBytes(content)
	if err != nil {
		return err
	}
	eTagHeader := response.Header.Get("Etag")
	if eTagHeader != "" {
		s.lastEtag = eTagHeader
	}
	s.lastUpdated = time.Now()
	if s.cacheFile != nil {
		err = s.cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{
			LastUpdated: s.lastUpdated,
			Content:     content,
			LastEtag:    s.lastEtag,
		})
		if err != nil {
			// Caching is best-effort; the freshly loaded rules stay active.
			s.logger.Error("save rule-set cache: ", err)
		}
	}
	s.logger.Info("updated rule-set ", s.options.Tag)
	return nil
}
// Close releases the loaded rules, stops the update ticker, and cancels the
// rule-set context so the update loop exits. It always returns nil.
//
// The ticker is only created at the end of StartContext, so it is nil when
// Close is called on a rule-set that was never started or whose start
// failed; that case is tolerated instead of panicking.
func (s *RemoteRuleSet) Close() error {
	s.rules = nil
	if s.updateTicker != nil {
		s.updateTicker.Stop()
	}
	s.cancel()
	return nil
}
// Match reports whether at least one loaded rule matches metadata.
// Rules are evaluated in order and evaluation stops at the first match.
func (s *RemoteRuleSet) Match(metadata *adapter.InboundContext) bool {
	// Work on a single snapshot of the slice, as a range loop would.
	loaded := s.rules
	for index := 0; index < len(loaded); index++ {
		if loaded[index].Match(metadata) {
			return true
		}
	}
	return false
}