package rule import ( "bytes" "context" "io" "net" "net/http" "runtime" "strings" "sync" "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/srs" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/atomic" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/service" "github.com/sagernet/sing/service/pause" "go4.org/netipx" ) var _ adapter.RuleSet = (*RemoteRuleSet)(nil) type RemoteRuleSet struct { ctx context.Context cancel context.CancelFunc router adapter.Router outboundManager adapter.OutboundManager logger logger.ContextLogger options option.RuleSet metadata adapter.RuleSetMetadata updateInterval time.Duration dialer N.Dialer rules []adapter.HeadlessRule lastUpdated time.Time lastEtag string updateTicker *time.Ticker cacheFile adapter.CacheFile pauseManager pause.Manager callbackAccess sync.Mutex callbacks list.List[adapter.RuleSetUpdateCallback] refs atomic.Int32 } func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet { ctx, cancel := context.WithCancel(ctx) var updateInterval time.Duration if options.RemoteOptions.UpdateInterval > 0 { updateInterval = time.Duration(options.RemoteOptions.UpdateInterval) } else { updateInterval = 24 * time.Hour } return &RemoteRuleSet{ ctx: ctx, cancel: cancel, router: router, outboundManager: service.FromContext[adapter.OutboundManager](ctx), logger: logger, options: options, updateInterval: updateInterval, pauseManager: service.FromContext[pause.Manager](ctx), } } func (s *RemoteRuleSet) Name() string { return s.options.Tag } func (s *RemoteRuleSet) String() string { return strings.Join(F.MapToString(s.rules), " ") } func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx) var dialer N.Dialer if s.options.RemoteOptions.DownloadDetour != "" { outbound, loaded := s.outboundManager.Outbound(s.options.RemoteOptions.DownloadDetour) if !loaded { return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour) } dialer = outbound } else { dialer = s.outboundManager.Default() } s.dialer = dialer if s.cacheFile != nil { if savedSet := s.cacheFile.LoadRuleSet(s.options.Tag); savedSet != nil { err := s.loadBytes(savedSet.Content) if err != nil { return E.Cause(err, "restore cached rule-set") } s.lastUpdated = savedSet.LastUpdated s.lastEtag = savedSet.LastEtag } } if s.lastUpdated.IsZero() { err := s.fetchOnce(ctx, startContext) if err != nil { return E.Cause(err, "initial rule-set: ", s.options.Tag) } } s.updateTicker = time.NewTicker(s.updateInterval) return nil } func (s *RemoteRuleSet) PostStart() error { go s.loopUpdate() return nil } func (s *RemoteRuleSet) Metadata() adapter.RuleSetMetadata { return s.metadata } func (s *RemoteRuleSet) ExtractIPSet() []*netipx.IPSet { return common.FlatMap(s.rules, extractIPSetFromRule) } func (s *RemoteRuleSet) IncRef() { s.refs.Add(1) } func (s *RemoteRuleSet) DecRef() { if s.refs.Add(-1) < 0 { panic("rule-set: negative refs") } } func (s *RemoteRuleSet) Cleanup() { if s.refs.Load() == 0 { s.rules = nil } } func (s *RemoteRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] { s.callbackAccess.Lock() defer s.callbackAccess.Unlock() return s.callbacks.PushBack(callback) } func (s *RemoteRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) { s.callbackAccess.Lock() defer s.callbackAccess.Unlock() s.callbacks.Remove(element) } func (s *RemoteRuleSet) loadBytes(content []byte) error { var ( plainRuleSet option.PlainRuleSet err error ) switch s.options.Format { case C.RuleSetFormatSource: var compat option.PlainRuleSetCompat compat, err = json.UnmarshalExtended[option.PlainRuleSetCompat](content) if err != nil { return err } plainRuleSet, err = compat.Upgrade() if err != nil { return err } case C.RuleSetFormatBinary: plainRuleSet, err = srs.Read(bytes.NewReader(content), false) if err != nil { return err } default: return E.New("unknown rule-set format: ", s.options.Format) } rules := make([]adapter.HeadlessRule, len(plainRuleSet.Rules)) for i, ruleOptions := range plainRuleSet.Rules { rules[i], err = NewHeadlessRule(s.router, ruleOptions) if err != nil { return E.Cause(err, "parse rule_set.rules.[", i, "]") } } s.metadata.ContainsProcessRule = hasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule) s.metadata.ContainsWIFIRule = hasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule) s.metadata.ContainsIPCIDRRule = hasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule) s.rules = rules s.callbackAccess.Lock() callbacks := s.callbacks.Array() s.callbackAccess.Unlock() for _, callback := range callbacks { callback(s) } return nil } func (s *RemoteRuleSet) loopUpdate() { if time.Since(s.lastUpdated) > s.updateInterval { err := s.fetchOnce(s.ctx, nil) if err != nil { s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err) } else if s.refs.Load() == 0 { s.rules = nil } } for { runtime.GC() select { case <-s.ctx.Done(): return case <-s.updateTicker.C: s.pauseManager.WaitActive() err := s.fetchOnce(s.ctx, nil) if err != nil { s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err) } else if s.refs.Load() == 0 { s.rules = nil } } } } func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext *adapter.HTTPStartContext) error { s.logger.Debug("updating rule-set ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL) var httpClient *http.Client if startContext != nil { httpClient = startContext.HTTPClient(s.options.RemoteOptions.DownloadDetour, s.dialer) } else { httpClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: true, TLSHandshakeTimeout: C.TCPTimeout, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) }, }, } } request, err := http.NewRequest("GET", s.options.RemoteOptions.URL, nil) if err != nil { return err } if s.lastEtag != "" { request.Header.Set("If-None-Match", s.lastEtag) } response, err := httpClient.Do(request.WithContext(ctx)) if err != nil { return err } switch response.StatusCode { case http.StatusOK: case http.StatusNotModified: s.lastUpdated = time.Now() if s.cacheFile != nil { savedRuleSet := s.cacheFile.LoadRuleSet(s.options.Tag) if savedRuleSet != nil { savedRuleSet.LastUpdated = s.lastUpdated err = s.cacheFile.SaveRuleSet(s.options.Tag, savedRuleSet) if err != nil { s.logger.Error("save rule-set updated time: ", err) return nil } } } s.logger.Info("update rule-set ", s.options.Tag, ": not modified") return nil default: return E.New("unexpected status: ", response.Status) } content, err := io.ReadAll(response.Body) if err != nil { response.Body.Close() return err } err = s.loadBytes(content) if err != nil { response.Body.Close() return err } response.Body.Close() eTagHeader := response.Header.Get("Etag") if eTagHeader != "" { s.lastEtag = eTagHeader } s.lastUpdated = time.Now() if s.cacheFile != nil { err = s.cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{ LastUpdated: s.lastUpdated, Content: content, LastEtag: s.lastEtag, }) if err != nil { s.logger.Error("save rule-set cache: ", err) } } s.logger.Info("updated rule-set ", s.options.Tag) return nil } func (s *RemoteRuleSet) Close() error { s.rules = nil s.updateTicker.Stop() s.cancel() return nil } func (s *RemoteRuleSet) Match(metadata *adapter.InboundContext) bool { for _, rule := range s.rules { if rule.Match(metadata) { return true } } return false }