URLTest improvements

This commit is contained in:
世界 2023-04-13 16:11:46 +08:00
parent 1fbe7c54bf
commit b491c350ae
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -13,6 +13,7 @@ import (
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/batch" "github.com/sagernet/sing/common/batch"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -20,8 +21,9 @@ import (
) )
var ( var (
_ adapter.Outbound = (*URLTest)(nil) _ adapter.Outbound = (*URLTest)(nil)
_ adapter.OutboundGroup = (*URLTest)(nil) _ adapter.OutboundGroup = (*URLTest)(nil)
_ adapter.InterfaceUpdateListener = (*URLTest)(nil)
) )
type URLTest struct { type URLTest struct {
@ -71,7 +73,8 @@ func (s *URLTest) Start() error {
outbounds = append(outbounds, detour) outbounds = append(outbounds, detour)
} }
s.group = NewURLTestGroup(s.ctx, s.router, s.logger, outbounds, s.link, s.interval, s.tolerance) s.group = NewURLTestGroup(s.ctx, s.router, s.logger, outbounds, s.link, s.interval, s.tolerance)
return s.group.Start() go s.group.CheckOutbounds(false)
return nil
} }
func (s *URLTest) Close() error { func (s *URLTest) Close() error {
@ -93,6 +96,7 @@ func (s *URLTest) URLTest(ctx context.Context, link string) (map[string]uint16,
} }
func (s *URLTest) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (s *URLTest) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
s.group.Start()
outbound := s.group.Select(network) outbound := s.group.Select(network)
conn, err := outbound.DialContext(ctx, network, destination) conn, err := outbound.DialContext(ctx, network, destination)
if err == nil { if err == nil {
@ -104,6 +108,7 @@ func (s *URLTest) DialContext(ctx context.Context, network string, destination M
} }
func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
s.group.Start()
outbound := s.group.Select(N.NetworkUDP) outbound := s.group.Select(N.NetworkUDP)
conn, err := outbound.ListenPacket(ctx, destination) conn, err := outbound.ListenPacket(ctx, destination)
if err == nil { if err == nil {
@ -122,6 +127,11 @@ func (s *URLTest) NewPacketConnection(ctx context.Context, conn N.PacketConn, me
return NewPacketConnection(ctx, s, conn, metadata) return NewPacketConnection(ctx, s, conn, metadata)
} }
func (s *URLTest) InterfaceUpdated() error {
go s.group.CheckOutbounds(true)
return nil
}
type URLTestGroup struct { type URLTestGroup struct {
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
@ -131,7 +141,9 @@ type URLTestGroup struct {
interval time.Duration interval time.Duration
tolerance uint16 tolerance uint16
history *urltest.HistoryStorage history *urltest.HistoryStorage
checking atomic.Bool
access sync.Mutex
ticker *time.Ticker ticker *time.Ticker
close chan struct{} close chan struct{}
} }
@ -162,13 +174,23 @@ func NewURLTestGroup(ctx context.Context, router adapter.Router, logger log.Logg
} }
} }
func (g *URLTestGroup) Start() error { func (g *URLTestGroup) Start() {
if g.ticker != nil {
return
}
g.access.Lock()
defer g.access.Unlock()
if g.ticker != nil {
return
}
g.ticker = time.NewTicker(g.interval) g.ticker = time.NewTicker(g.interval)
go g.loopCheck() go g.loopCheck()
return nil
} }
func (g *URLTestGroup) Close() error { func (g *URLTestGroup) Close() error {
if g.ticker == nil {
return nil
}
g.ticker.Stop() g.ticker.Stop()
close(g.close) close(g.close)
return nil return nil
@ -228,25 +250,33 @@ func (g *URLTestGroup) Fallback(used adapter.Outbound) []adapter.Outbound {
} }
func (g *URLTestGroup) loopCheck() { func (g *URLTestGroup) loopCheck() {
go g.checkOutbounds() go g.CheckOutbounds(true)
for { for {
select { select {
case <-g.close: case <-g.close:
return return
case <-g.ticker.C: case <-g.ticker.C:
g.checkOutbounds() g.CheckOutbounds(false)
} }
} }
} }
func (g *URLTestGroup) checkOutbounds() { func (g *URLTestGroup) CheckOutbounds(force bool) {
_, _ = g.URLTest(g.ctx, g.link) _, _ = g.urlTest(g.ctx, g.link, force)
} }
func (g *URLTestGroup) URLTest(ctx context.Context, link string) (map[string]uint16, error) { func (g *URLTestGroup) URLTest(ctx context.Context, link string) (map[string]uint16, error) {
return g.urlTest(ctx, link, false)
}
func (g *URLTestGroup) urlTest(ctx context.Context, link string, force bool) (map[string]uint16, error) {
result := make(map[string]uint16)
if g.checking.Swap(true) {
return result, nil
}
defer g.checking.Store(false)
b, _ := batch.New(ctx, batch.WithConcurrencyNum[any](10)) b, _ := batch.New(ctx, batch.WithConcurrencyNum[any](10))
checked := make(map[string]bool) checked := make(map[string]bool)
result := make(map[string]uint16)
var resultAccess sync.Mutex var resultAccess sync.Mutex
for _, detour := range g.outbounds { for _, detour := range g.outbounds {
tag := detour.Tag() tag := detour.Tag()
@ -255,7 +285,7 @@ func (g *URLTestGroup) URLTest(ctx context.Context, link string) (map[string]uin
continue continue
} }
history := g.history.LoadURLTestHistory(realTag) history := g.history.LoadURLTestHistory(realTag)
if history != nil && time.Now().Sub(history.Time) < g.interval { if !force && history != nil && time.Now().Sub(history.Time) < g.interval {
continue continue
} }
checked[realTag] = true checked[realTag] = true