Rename HTTP start context

This commit is contained in:
世界 2024-10-25 22:24:19 +08:00
parent 6ed9a06394
commit 327bb35ddd
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 58 additions and 66 deletions

View file

@ -2,13 +2,17 @@ package adapter
import ( import (
"context" "context"
"net"
"net/http" "net/http"
"net/netip" "net/netip"
"sync"
"github.com/sagernet/sing-box/common/geoip" "github.com/sagernet/sing-box/common/geoip"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/control"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
@ -98,7 +102,7 @@ type DNSRule interface {
type RuleSet interface { type RuleSet interface {
Name() string Name() string
StartContext(ctx context.Context, startContext RuleSetStartContext) error StartContext(ctx context.Context, startContext *HTTPStartContext) error
PostStart() error PostStart() error
Metadata() RuleSetMetadata Metadata() RuleSetMetadata
ExtractIPSet() []*netipx.IPSet ExtractIPSet() []*netipx.IPSet
@ -118,10 +122,42 @@ type RuleSetMetadata struct {
ContainsWIFIRule bool ContainsWIFIRule bool
ContainsIPCIDRRule bool ContainsIPCIDRRule bool
} }
type HTTPStartContext struct {
access sync.Mutex
httpClientCache map[string]*http.Client
}
type RuleSetStartContext interface { func NewHTTPStartContext() *HTTPStartContext {
HTTPClient(detour string, dialer N.Dialer) *http.Client return &HTTPStartContext{
Close() httpClientCache: make(map[string]*http.Client),
}
}
func (c *HTTPStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client {
c.access.Lock()
defer c.access.Unlock()
if httpClient, loaded := c.httpClientCache[detour]; loaded {
return httpClient
}
httpClient := &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
},
}
c.httpClientCache[detour] = httpClient
return httpClient
}
func (c *HTTPStartContext) Close() {
c.access.Lock()
defer c.access.Unlock()
for _, client := range c.httpClientCache {
client.CloseIdleConnections()
}
} }
type InterfaceUpdateListener interface { type InterfaceUpdateListener interface {

View file

@ -659,14 +659,15 @@ func (r *Router) Close() error {
func (r *Router) PostStart() error { func (r *Router) PostStart() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout) monitor := taskmonitor.New(r.logger, C.StopTimeout)
var cacheContext *adapter.HTTPStartContext
if len(r.ruleSets) > 0 { if len(r.ruleSets) > 0 {
monitor.Start("initialize rule-set") monitor.Start("initialize rule-set")
ruleSetStartContext := NewRuleSetStartContext() cacheContext = adapter.NewHTTPStartContext()
var ruleSetStartGroup task.Group var ruleSetStartGroup task.Group
for i, ruleSet := range r.ruleSets { for i, ruleSet := range r.ruleSets {
ruleSetInPlace := ruleSet ruleSetInPlace := ruleSet
ruleSetStartGroup.Append0(func(ctx context.Context) error { ruleSetStartGroup.Append0(func(ctx context.Context) error {
err := ruleSetInPlace.StartContext(ctx, ruleSetStartContext) err := ruleSetInPlace.StartContext(ctx, cacheContext)
if err != nil { if err != nil {
return E.Cause(err, "initialize rule-set[", i, "]") return E.Cause(err, "initialize rule-set[", i, "]")
} }
@ -680,7 +681,9 @@ func (r *Router) PostStart() error {
if err != nil { if err != nil {
return err return err
} }
ruleSetStartContext.Close() }
if cacheContext != nil {
cacheContext.Close()
} }
needFindProcess := r.needFindProcess needFindProcess := r.needFindProcess
needWIFIState := r.needWIFIState needWIFIState := r.needWIFIState

View file

@ -2,9 +2,6 @@ package route
import ( import (
"context" "context"
"net"
"net/http"
"sync"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
@ -12,8 +9,6 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"go4.org/netipx" "go4.org/netipx"
) )
@ -46,43 +41,3 @@ func extractIPSetFromRule(rawRule adapter.HeadlessRule) []*netipx.IPSet {
panic("unexpected rule type") panic("unexpected rule type")
} }
} }
var _ adapter.RuleSetStartContext = (*RuleSetStartContext)(nil)
type RuleSetStartContext struct {
access sync.Mutex
httpClientCache map[string]*http.Client
}
func NewRuleSetStartContext() *RuleSetStartContext {
return &RuleSetStartContext{
httpClientCache: make(map[string]*http.Client),
}
}
func (c *RuleSetStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client {
c.access.Lock()
defer c.access.Unlock()
if httpClient, loaded := c.httpClientCache[detour]; loaded {
return httpClient
}
httpClient := &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
},
}
c.httpClientCache[detour] = httpClient
return httpClient
}
func (c *RuleSetStartContext) Close() {
c.access.Lock()
defer c.access.Unlock()
for _, client := range c.httpClientCache {
client.CloseIdleConnections()
}
}

View file

@ -58,7 +58,6 @@ func NewLocalRuleSet(ctx context.Context, router adapter.Router, logger logger.L
} }
} }
if options.Type == C.RuleSetTypeLocal { if options.Type == C.RuleSetTypeLocal {
var watcher *fswatch.Watcher
filePath, _ := filepath.Abs(options.LocalOptions.Path) filePath, _ := filepath.Abs(options.LocalOptions.Path)
watcher, err := fswatch.NewWatcher(fswatch.Options{ watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: []string{filePath}, Path: []string{filePath},
@ -85,7 +84,7 @@ func (s *LocalRuleSet) String() string {
return strings.Join(F.MapToString(s.rules), " ") return strings.Join(F.MapToString(s.rules), " ")
} }
func (s *LocalRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error { func (s *LocalRuleSet) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
if s.watcher != nil { if s.watcher != nil {
err := s.watcher.Start() err := s.watcher.Start()
if err != nil { if err != nil {

View file

@ -45,6 +45,7 @@ type RemoteRuleSet struct {
lastUpdated time.Time lastUpdated time.Time
lastEtag string lastEtag string
updateTicker *time.Ticker updateTicker *time.Ticker
cacheFile adapter.CacheFile
pauseManager pause.Manager pauseManager pause.Manager
callbackAccess sync.Mutex callbackAccess sync.Mutex
callbacks list.List[adapter.RuleSetUpdateCallback] callbacks list.List[adapter.RuleSetUpdateCallback]
@ -78,7 +79,8 @@ func (s *RemoteRuleSet) String() string {
return strings.Join(F.MapToString(s.rules), " ") return strings.Join(F.MapToString(s.rules), " ")
} }
func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error { func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx)
var dialer N.Dialer var dialer N.Dialer
if s.options.RemoteOptions.DownloadDetour != "" { if s.options.RemoteOptions.DownloadDetour != "" {
outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour) outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour)
@ -94,9 +96,8 @@ func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext adapter.R
dialer = outbound dialer = outbound
} }
s.dialer = dialer s.dialer = dialer
cacheFile := service.FromContext[adapter.CacheFile](s.ctx) if s.cacheFile != nil {
if cacheFile != nil { if savedSet := s.cacheFile.LoadRuleSet(s.options.Tag); savedSet != nil {
if savedSet := cacheFile.LoadRuleSet(s.options.Tag); savedSet != nil {
err := s.loadBytes(savedSet.Content) err := s.loadBytes(savedSet.Content)
if err != nil { if err != nil {
return E.Cause(err, "restore cached rule-set") return E.Cause(err, "restore cached rule-set")
@ -226,7 +227,7 @@ func (s *RemoteRuleSet) loopUpdate() {
} }
} }
func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.RuleSetStartContext) error { func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext *adapter.HTTPStartContext) error {
s.logger.Debug("updating rule-set ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL) s.logger.Debug("updating rule-set ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL)
var httpClient *http.Client var httpClient *http.Client
if startContext != nil { if startContext != nil {
@ -257,12 +258,11 @@ func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.Rule
case http.StatusOK: case http.StatusOK:
case http.StatusNotModified: case http.StatusNotModified:
s.lastUpdated = time.Now() s.lastUpdated = time.Now()
cacheFile := service.FromContext[adapter.CacheFile](s.ctx) if s.cacheFile != nil {
if cacheFile != nil { savedRuleSet := s.cacheFile.LoadRuleSet(s.options.Tag)
savedRuleSet := cacheFile.LoadRuleSet(s.options.Tag)
if savedRuleSet != nil { if savedRuleSet != nil {
savedRuleSet.LastUpdated = s.lastUpdated savedRuleSet.LastUpdated = s.lastUpdated
err = cacheFile.SaveRuleSet(s.options.Tag, savedRuleSet) err = s.cacheFile.SaveRuleSet(s.options.Tag, savedRuleSet)
if err != nil { if err != nil {
s.logger.Error("save rule-set updated time: ", err) s.logger.Error("save rule-set updated time: ", err)
return nil return nil
@ -290,9 +290,8 @@ func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.Rule
s.lastEtag = eTagHeader s.lastEtag = eTagHeader
} }
s.lastUpdated = time.Now() s.lastUpdated = time.Now()
cacheFile := service.FromContext[adapter.CacheFile](s.ctx) if s.cacheFile != nil {
if cacheFile != nil { err = s.cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{
err = cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{
LastUpdated: s.lastUpdated, LastUpdated: s.lastUpdated,
Content: content, Content: content,
LastEtag: s.lastEtag, LastEtag: s.lastEtag,