diff --git a/adapter/router.go b/adapter/router.go index 619c1110..5dc4de53 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -2,13 +2,17 @@ package adapter import ( "context" + "net" "net/http" "net/netip" + "sync" "github.com/sagernet/sing-box/common/geoip" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-dns" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/control" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/service" @@ -98,7 +102,7 @@ type DNSRule interface { type RuleSet interface { Name() string - StartContext(ctx context.Context, startContext RuleSetStartContext) error + StartContext(ctx context.Context, startContext *HTTPStartContext) error PostStart() error Metadata() RuleSetMetadata ExtractIPSet() []*netipx.IPSet @@ -118,10 +122,42 @@ type RuleSetMetadata struct { ContainsWIFIRule bool ContainsIPCIDRRule bool } +type HTTPStartContext struct { + access sync.Mutex + httpClientCache map[string]*http.Client +} -type RuleSetStartContext interface { - HTTPClient(detour string, dialer N.Dialer) *http.Client - Close() +func NewHTTPStartContext() *HTTPStartContext { + return &HTTPStartContext{ + httpClientCache: make(map[string]*http.Client), + } +} + +func (c *HTTPStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client { + c.access.Lock() + defer c.access.Unlock() + if httpClient, loaded := c.httpClientCache[detour]; loaded { + return httpClient + } + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: C.TCPTimeout, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + c.httpClientCache[detour] = httpClient + return httpClient +} + +func (c *HTTPStartContext) Close() { + c.access.Lock() + defer c.access.Unlock() + for _, client := range c.httpClientCache { + client.CloseIdleConnections() + } } type InterfaceUpdateListener interface { diff --git a/route/router.go b/route/router.go index c8fe94be..aac5f603 100644 --- a/route/router.go +++ b/route/router.go @@ -659,14 +659,15 @@ func (r *Router) Close() error { func (r *Router) PostStart() error { monitor := taskmonitor.New(r.logger, C.StopTimeout) + var cacheContext *adapter.HTTPStartContext if len(r.ruleSets) > 0 { monitor.Start("initialize rule-set") - ruleSetStartContext := NewRuleSetStartContext() + cacheContext = adapter.NewHTTPStartContext() var ruleSetStartGroup task.Group for i, ruleSet := range r.ruleSets { ruleSetInPlace := ruleSet ruleSetStartGroup.Append0(func(ctx context.Context) error { - err := ruleSetInPlace.StartContext(ctx, ruleSetStartContext) + err := ruleSetInPlace.StartContext(ctx, cacheContext) if err != nil { return E.Cause(err, "initialize rule-set[", i, "]") } @@ -680,7 +681,9 @@ func (r *Router) PostStart() error { if err != nil { return err } - ruleSetStartContext.Close() + } + if cacheContext != nil { + cacheContext.Close() } needFindProcess := r.needFindProcess needWIFIState := r.needWIFIState diff --git a/route/rule_set.go b/route/rule_set.go index 39d51e6f..a2c6d0c1 100644 --- a/route/rule_set.go +++ b/route/rule_set.go @@ -2,9 +2,6 @@ package route import ( "context" - "net" - "net/http" - "sync" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" @@ -12,8 +9,6 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" "go4.org/netipx" ) @@ -46,43 +41,3 @@ func extractIPSetFromRule(rawRule adapter.HeadlessRule) []*netipx.IPSet { panic("unexpected rule type") } } - -var _ adapter.RuleSetStartContext = (*RuleSetStartContext)(nil) - -type RuleSetStartContext struct { - access sync.Mutex - httpClientCache map[string]*http.Client -} - -func NewRuleSetStartContext() *RuleSetStartContext { - return &RuleSetStartContext{ - httpClientCache: make(map[string]*http.Client), - } -} - -func (c *RuleSetStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client { - c.access.Lock() - defer c.access.Unlock() - if httpClient, loaded := c.httpClientCache[detour]; loaded { - return httpClient - } - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: C.TCPTimeout, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - c.httpClientCache[detour] = httpClient - return httpClient -} - -func (c *RuleSetStartContext) Close() { - c.access.Lock() - defer c.access.Unlock() - for _, client := range c.httpClientCache { - client.CloseIdleConnections() - } -} diff --git a/route/rule_set_local.go b/route/rule_set_local.go index 893842d5..4485fbad 100644 --- a/route/rule_set_local.go +++ b/route/rule_set_local.go @@ -58,7 +58,6 @@ func NewLocalRuleSet(ctx context.Context, router adapter.Router, logger logger.L } } if options.Type == C.RuleSetTypeLocal { - var watcher *fswatch.Watcher filePath, _ := filepath.Abs(options.LocalOptions.Path) watcher, err := fswatch.NewWatcher(fswatch.Options{ Path: []string{filePath}, @@ -85,7 +84,7 @@ func (s *LocalRuleSet) String() string { return strings.Join(F.MapToString(s.rules), " ") } -func (s *LocalRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error { +func (s *LocalRuleSet) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { if s.watcher != nil { err := s.watcher.Start() if err != nil { diff --git a/route/rule_set_remote.go b/route/rule_set_remote.go index 03662ee4..11541609 100644 --- a/route/rule_set_remote.go +++ b/route/rule_set_remote.go @@ -45,6 +45,7 @@ type RemoteRuleSet struct { lastUpdated time.Time lastEtag string updateTicker *time.Ticker + cacheFile adapter.CacheFile pauseManager pause.Manager callbackAccess sync.Mutex callbacks list.List[adapter.RuleSetUpdateCallback] @@ -78,7 +79,8 @@ func (s *RemoteRuleSet) String() string { return strings.Join(F.MapToString(s.rules), " ") } -func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error { +func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { + s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx) var dialer N.Dialer if s.options.RemoteOptions.DownloadDetour != "" { outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour) @@ -94,9 +96,8 @@ func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext adapter.R dialer = outbound } s.dialer = dialer - cacheFile := service.FromContext[adapter.CacheFile](s.ctx) - if cacheFile != nil { - if savedSet := cacheFile.LoadRuleSet(s.options.Tag); savedSet != nil { + if s.cacheFile != nil { + if savedSet := s.cacheFile.LoadRuleSet(s.options.Tag); savedSet != nil { err := s.loadBytes(savedSet.Content) if err != nil { return E.Cause(err, "restore cached rule-set") @@ -226,7 +227,7 @@ func (s *RemoteRuleSet) loopUpdate() { } } -func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.RuleSetStartContext) error { +func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext *adapter.HTTPStartContext) error { s.logger.Debug("updating rule-set ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL) var httpClient *http.Client if startContext != nil { @@ -257,12 +258,11 @@ func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.Rule case http.StatusOK: case http.StatusNotModified: s.lastUpdated = time.Now() - cacheFile := service.FromContext[adapter.CacheFile](s.ctx) - if cacheFile != nil { - savedRuleSet := cacheFile.LoadRuleSet(s.options.Tag) + if s.cacheFile != nil { + savedRuleSet := s.cacheFile.LoadRuleSet(s.options.Tag) if savedRuleSet != nil { savedRuleSet.LastUpdated = s.lastUpdated - err = cacheFile.SaveRuleSet(s.options.Tag, savedRuleSet) + err = s.cacheFile.SaveRuleSet(s.options.Tag, savedRuleSet) if err != nil { s.logger.Error("save rule-set updated time: ", err) return nil @@ -290,9 +290,8 @@ func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.Rule s.lastEtag = eTagHeader } s.lastUpdated = time.Now() - cacheFile := service.FromContext[adapter.CacheFile](s.ctx) - if cacheFile != nil { - err = cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{ + if s.cacheFile != nil { + err = s.cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{ LastUpdated: s.lastUpdated, Content: content, LastEtag: s.lastEtag,