From e83331c2d9755762c963c42b9f0c0c7e27b1f627 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 10 Nov 2024 16:46:59 +0800 Subject: [PATCH] refactor: Extract services form router --- adapter/experimental.go | 30 +-- adapter/inbound.go | 2 +- adapter/inbound/manager.go | 4 +- adapter/lifecycle.go | 35 ++- adapter/lifecycle_legacy.go | 34 ++- adapter/network.go | 2 +- adapter/outbound.go | 2 +- adapter/outbound/manager.go | 4 +- adapter/router.go | 13 +- box.go | 246 +++++++----------- experimental/cachefile/cache.go | 21 +- experimental/clashapi/server.go | 81 +++--- .../clashapi/trafficontrol/tracker.go | 34 +-- experimental/libbox/command_clash_mode.go | 14 +- .../libbox/command_close_connection.go | 6 +- experimental/libbox/command_connections.go | 6 +- experimental/libbox/command_server.go | 2 +- experimental/libbox/command_status.go | 10 +- experimental/libbox/service.go | 10 +- experimental/v2rayapi/server.go | 11 +- experimental/v2rayapi/stats.go | 12 +- protocol/group/urltest.go | 29 +-- route/route.go | 77 ++---- route/router.go | 66 +---- route/rule/rule_default.go | 2 +- route/rule/rule_dns.go | 2 +- route/rule/rule_item_clash_mode.go | 23 +- 27 files changed, 314 insertions(+), 464 deletions(-) diff --git a/adapter/experimental.go b/adapter/experimental.go index bee24c4f..f22ff9b2 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -4,28 +4,28 @@ import ( "bytes" "context" "encoding/binary" - "net" "time" "github.com/sagernet/sing-box/common/urltest" "github.com/sagernet/sing-dns" - N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/varbin" ) type ClashServer interface { - Service - LegacyPreStarter + LifecycleService + ConnectionTracker Mode() string ModeList() []string HistoryStorage() *urltest.HistoryStorage - RoutedConnection(ctx context.Context, conn net.Conn, metadata InboundContext, matchedRule Rule) (net.Conn, Tracker) - RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext, matchedRule Rule) (N.PacketConn, Tracker) +} + +type V2RayServer interface { + LifecycleService + StatsService() ConnectionTracker } type CacheFile interface { - Service - LegacyPreStarter + LifecycleService StoreFakeIP() bool FakeIPStorage @@ -94,10 +94,6 @@ func (s *SavedRuleSet) UnmarshalBinary(data []byte) error { return nil } -type Tracker interface { - Leave() -} - type OutboundGroup interface { Outbound Now() string @@ -115,13 +111,3 @@ func OutboundTag(detour Outbound) string { } return detour.Tag() } - -type V2RayServer interface { - Service - StatsService() V2RayStatsService -} - -type V2RayStatsService interface { - RoutedConnection(inbound string, outbound string, user string, conn net.Conn) net.Conn - RoutedPacketConnection(inbound string, outbound string, user string, conn N.PacketConn) N.PacketConn -} diff --git a/adapter/inbound.go b/adapter/inbound.go index d80e59f7..7932237d 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -32,7 +32,7 @@ type InboundRegistry interface { } type InboundManager interface { - NewService + Lifecycle Inbounds() []Inbound Get(tag string) (Inbound, bool) Remove(tag string) error diff --git a/adapter/inbound/manager.go b/adapter/inbound/manager.go index 69a3ad46..d2be0f36 100644 --- a/adapter/inbound/manager.go +++ b/adapter/inbound/manager.go @@ -44,7 +44,7 @@ func (m *Manager) Start(stage adapter.StartStage) error { for _, inbound := range m.inbounds { err := adapter.LegacyStart(inbound, stage) if err != nil { - return E.Cause(err, stage.Action(), " inbound/", inbound.Type(), "[", inbound.Tag(), "]") + return E.Cause(err, stage, " inbound/", inbound.Type(), "[", inbound.Tag(), "]") } } return nil @@ -118,7 +118,7 @@ func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log. for _, stage := range adapter.ListStartStages { err = adapter.LegacyStart(inbound, stage) if err != nil { - return E.Cause(err, stage.Action(), " inbound/", inbound.Type(), "[", inbound.Tag(), "]") + return E.Cause(err, stage, " inbound/", inbound.Type(), "[", inbound.Tag(), "]") } } } diff --git a/adapter/lifecycle.go b/adapter/lifecycle.go index 85de425d..aff9fadb 100644 --- a/adapter/lifecycle.go +++ b/adapter/lifecycle.go @@ -1,5 +1,7 @@ package adapter +import E "github.com/sagernet/sing/common/exceptions" + type StartStage uint8 const ( @@ -16,7 +18,7 @@ var ListStartStages = []StartStage{ StartStateStarted, } -func (s StartStage) Action() string { +func (s StartStage) String() string { switch s { case StartStateInitialize: return "initialize" @@ -25,17 +27,38 @@ func (s StartStage) Action() string { case StartStatePostStart: return "post-start" case StartStateStarted: - return "start-after-started" + return "finish-start" default: panic("unknown stage") } } -type NewService interface { - NewStarter +type Lifecycle interface { + Start(stage StartStage) error Close() error } -type NewStarter interface { - Start(stage StartStage) error +type LifecycleService interface { + Name() string + Lifecycle +} + +func Start(stage StartStage, services ...Lifecycle) error { + for _, service := range services { + err := service.Start(stage) + if err != nil { + return err + } + } + return nil +} + +func StartNamed(stage StartStage, services []LifecycleService) error { + for _, service := range services { + err := service.Start(stage) + if err != nil { + return E.Cause(err, stage.String(), " ", service.Name()) + } + } + return nil } diff --git a/adapter/lifecycle_legacy.go b/adapter/lifecycle_legacy.go index 5968131b..0c8c75da 100644 --- a/adapter/lifecycle_legacy.go +++ b/adapter/lifecycle_legacy.go @@ -1,13 +1,5 @@ package adapter -type LegacyPreStarter interface { - PreStart() error -} - -type LegacyPostStarter interface { - PostStart() error -} - func LegacyStart(starter any, stage StartStage) error { switch stage { case StartStateInitialize: @@ -22,7 +14,7 @@ func LegacyStart(starter any, stage StartStage) error { }); isStarter { return starter.Start() } - case StartStatePostStart: + case StartStateStarted: if postStarter, isPostStarter := starter.(interface { PostStart() error }); isPostStarter { @@ -31,3 +23,27 @@ func LegacyStart(starter any, stage StartStage) error { } return nil } + +type lifecycleServiceWrapper struct { + Service + name string +} + +func NewLifecycleService(service Service, name string) LifecycleService { + return &lifecycleServiceWrapper{ + Service: service, + name: name, + } +} + +func (l *lifecycleServiceWrapper) Name() string { + return l.name +} + +func (l *lifecycleServiceWrapper) Start(stage StartStage) error { + return LegacyStart(l.Service, stage) +} + +func (l *lifecycleServiceWrapper) Close() error { + return l.Service.Close() +} diff --git a/adapter/network.go b/adapter/network.go index 0ce27411..533bfced 100644 --- a/adapter/network.go +++ b/adapter/network.go @@ -6,7 +6,7 @@ import ( ) type NetworkManager interface { - NewService + Lifecycle InterfaceFinder() control.InterfaceFinder UpdateInterfaces() error DefaultInterface() string diff --git a/adapter/outbound.go b/adapter/outbound.go index b170398a..2c2b1091 100644 --- a/adapter/outbound.go +++ b/adapter/outbound.go @@ -24,7 +24,7 @@ type OutboundRegistry interface { } type OutboundManager interface { - NewService + Lifecycle Outbounds() []Outbound Outbound(tag string) (Outbound, bool) Default() Outbound diff --git a/adapter/outbound/manager.go b/adapter/outbound/manager.go index 10a89a1c..84a105c5 100644 --- a/adapter/outbound/manager.go +++ b/adapter/outbound/manager.go @@ -61,7 +61,7 @@ func (m *Manager) Start(stage adapter.StartStage) error { for _, outbound := range outbounds { err := adapter.LegacyStart(outbound, stage) if err != nil { - return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]") + return E.Cause(err, stage, " outbound/", outbound.Type(), "[", outbound.Tag(), "]") } } } @@ -234,7 +234,7 @@ func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log. for _, stage := range adapter.ListStartStages { err = adapter.LegacyStart(outbound, stage) if err != nil { - return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]") + return E.Cause(err, stage, " outbound/", outbound.Type(), "[", outbound.Tag(), "]") } } } diff --git a/adapter/router.go b/adapter/router.go index 40a461a7..6dd39357 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -19,7 +19,7 @@ import ( ) type Router interface { - NewService + Lifecycle FakeIPStore() FakeIPStore @@ -38,15 +38,16 @@ type Router interface { ClearDNSCache() Rules() []Rule - ClashServer() ClashServer - SetClashServer(server ClashServer) - - V2RayServer() V2RayServer - SetV2RayServer(server V2RayServer) + SetTracker(tracker ConnectionTracker) ResetNetwork() } +type ConnectionTracker interface { + RoutedConnection(ctx context.Context, conn net.Conn, metadata InboundContext, matchedRule Rule, matchOutbound Outbound) net.Conn + RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext, matchedRule Rule, matchOutbound Outbound) N.PacketConn +} + // Deprecated: Use ConnectionRouterEx instead. type ConnectionRouter interface { RouteConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error diff --git a/box.go b/box.go index 8eac0dfa..3b69617f 100644 --- a/box.go +++ b/box.go @@ -11,6 +11,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/taskmonitor" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/experimental" @@ -23,6 +24,7 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/service" "github.com/sagernet/sing/service/pause" ) @@ -30,17 +32,15 @@ import ( var _ adapter.Service = (*Box)(nil) type Box struct { - createdAt time.Time - router adapter.Router - inbound *inbound.Manager - outbound *outbound.Manager - network *route.NetworkManager - logFactory log.Factory - logger log.ContextLogger - preServices1 map[string]adapter.Service - preServices2 map[string]adapter.Service - postServices map[string]adapter.Service - done chan struct{} + createdAt time.Time + logFactory log.Factory + logger log.ContextLogger + network *route.NetworkManager + router *route.Router + inbound *inbound.Manager + outbound *outbound.Manager + services []adapter.LifecycleService + done chan struct{} } type Options struct { @@ -49,7 +49,11 @@ type Options struct { PlatformLogWriter log.PlatformWriter } -func Context(ctx context.Context, inboundRegistry adapter.InboundRegistry, outboundRegistry adapter.OutboundRegistry) context.Context { +func Context( + ctx context.Context, + inboundRegistry adapter.InboundRegistry, + outboundRegistry adapter.OutboundRegistry, +) context.Context { if service.FromContext[option.InboundOptionsRegistry](ctx) == nil || service.FromContext[adapter.InboundRegistry](ctx) == nil { ctx = service.ContextWith[option.InboundOptionsRegistry](ctx, inboundRegistry) @@ -70,14 +74,17 @@ func New(options Options) (*Box, error) { ctx = context.Background() } ctx = service.ContextWithDefaultRegistry(ctx) + inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) if inboundRegistry == nil { return nil, E.New("missing inbound registry in context") } + outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) if outboundRegistry == nil { return nil, E.New("missing outbound registry in context") } + ctx = pause.WithDefaultManager(ctx) experimentalOptions := common.PtrValueOrDefault(options.Experimental) applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug)) @@ -109,17 +116,19 @@ func New(options Options) (*Box, error) { if err != nil { return nil, E.Cause(err, "create log factory") } + routeOptions := common.PtrValueOrDefault(options.Route) inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry) outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, routeOptions.Final) - ctx = service.ContextWith[adapter.InboundManager](ctx, inboundManager) - ctx = service.ContextWith[adapter.OutboundManager](ctx, outboundManager) + service.MustRegister[adapter.InboundManager](ctx, inboundManager) + service.MustRegister[adapter.OutboundManager](ctx, outboundManager) + networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions) if err != nil { return nil, E.Cause(err, "initialize network manager") } - ctx = service.ContextWith[adapter.NetworkManager](ctx, networkManager) - router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS), common.PtrValueOrDefault(options.NTP)) + service.MustRegister[adapter.NetworkManager](ctx, networkManager) + router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS)) if err != nil { return nil, E.Cause(err, "initialize router") } @@ -182,47 +191,61 @@ func New(options Options) (*Box, error) { return nil, E.Cause(err, "initialize platform interface") } } - preServices1 := make(map[string]adapter.Service) - preServices2 := make(map[string]adapter.Service) - postServices := make(map[string]adapter.Service) + var services []adapter.LifecycleService if needCacheFile { - cacheFile := service.FromContext[adapter.CacheFile](ctx) - if cacheFile == nil { - cacheFile = cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile)) - service.MustRegister[adapter.CacheFile](ctx, cacheFile) - } - preServices1["cache file"] = cacheFile + cacheFile := cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile)) + service.MustRegister[adapter.CacheFile](ctx, cacheFile) + services = append(services, cacheFile) } if needClashAPI { clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI) clashAPIOptions.ModeList = experimental.CalculateClashModeList(options.Options) clashServer, err := experimental.NewClashServer(ctx, logFactory.(log.ObservableFactory), clashAPIOptions) if err != nil { - return nil, E.Cause(err, "create clash api server") + return nil, E.Cause(err, "create clash-server") } - router.SetClashServer(clashServer) - preServices2["clash api"] = clashServer + router.SetTracker(clashServer) + service.MustRegister[adapter.ClashServer](ctx, clashServer) + services = append(services, clashServer) } if needV2RayAPI { v2rayServer, err := experimental.NewV2RayServer(logFactory.NewLogger("v2ray-api"), common.PtrValueOrDefault(experimentalOptions.V2RayAPI)) if err != nil { - return nil, E.Cause(err, "create v2ray api server") + return nil, E.Cause(err, "create v2ray-server") } - router.SetV2RayServer(v2rayServer) - preServices2["v2ray api"] = v2rayServer + if v2rayServer.StatsService() != nil { + router.SetTracker(v2rayServer.StatsService()) + services = append(services, v2rayServer) + service.MustRegister[adapter.V2RayServer](ctx, v2rayServer) + } + } + ntpOptions := common.PtrValueOrDefault(options.NTP) + if ntpOptions.Enabled { + ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions) + if err != nil { + return nil, E.Cause(err, "create NTP service") + } + timeService := ntp.NewService(ntp.Options{ + Context: ctx, + Dialer: ntpDialer, + Logger: logFactory.NewLogger("ntp"), + Server: ntpOptions.ServerOptions.Build(), + Interval: time.Duration(ntpOptions.Interval), + WriteToSystem: ntpOptions.WriteToSystem, + }) + service.MustRegister[ntp.TimeService](ctx, timeService) + services = append(services, adapter.NewLifecycleService(timeService, "ntp service")) } return &Box{ - router: router, - inbound: inboundManager, - outbound: outboundManager, - network: networkManager, - createdAt: createdAt, - logFactory: logFactory, - logger: logFactory.Logger(), - preServices1: preServices1, - preServices2: preServices2, - postServices: postServices, - done: make(chan struct{}), + network: networkManager, + router: router, + inbound: inboundManager, + outbound: outboundManager, + createdAt: createdAt, + logFactory: logFactory, + logger: logFactory.Logger(), + services: services, + done: make(chan struct{}), }, nil } @@ -272,43 +295,19 @@ func (s *Box) preStart() error { if err != nil { return E.Cause(err, "start logger") } - for serviceName, service := range s.preServices1 { - if preService, isPreService := service.(adapter.LegacyPreStarter); isPreService { - monitor.Start("pre-start ", serviceName) - err := preService.PreStart() - monitor.Finish() - if err != nil { - return E.Cause(err, "pre-start ", serviceName) - } - } - } - for serviceName, service := range s.preServices2 { - if preService, isPreService := service.(adapter.LegacyPreStarter); isPreService { - monitor.Start("pre-start ", serviceName) - err := preService.PreStart() - monitor.Finish() - if err != nil { - return E.Cause(err, "pre-start ", serviceName) - } - } - } - err = s.network.Start(adapter.StartStateInitialize) - if err != nil { - return E.Cause(err, "initialize network manager") - } - err = s.router.Start(adapter.StartStateInitialize) - if err != nil { - return E.Cause(err, "initialize router") - } - err = s.outbound.Start(adapter.StartStateStart) + err = adapter.StartNamed(adapter.StartStateInitialize, s.services) // cache-file clash-api v2ray-api if err != nil { return err } - err = s.network.Start(adapter.StartStateStart) + err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound) if err != nil { return err } - return s.router.Start(adapter.StartStateStart) + err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.router) + if err != nil { + return err + } + return nil } func (s *Box) start() error { @@ -316,57 +315,27 @@ func (s *Box) start() error { if err != nil { return err } - for serviceName, service := range s.preServices1 { - err = service.Start() - if err != nil { - return E.Cause(err, "start ", serviceName) - } - } - for serviceName, service := range s.preServices2 { - err = service.Start() - if err != nil { - return E.Cause(err, "start ", serviceName) - } + err = adapter.StartNamed(adapter.StartStateStart, s.services) + if err != nil { + return err } err = s.inbound.Start(adapter.StartStateStart) if err != nil { return err } - for serviceName, service := range s.postServices { - err := service.Start() - if err != nil { - return E.Cause(err, "start ", serviceName) - } - } - err = s.outbound.Start(adapter.StartStatePostStart) + err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound) if err != nil { return err } - err = s.network.Start(adapter.StartStatePostStart) + err = adapter.StartNamed(adapter.StartStatePostStart, s.services) if err != nil { return err } - err = s.router.Start(adapter.StartStatePostStart) + err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound) if err != nil { return err } - err = s.inbound.Start(adapter.StartStatePostStart) - if err != nil { - return err - } - err = s.network.Start(adapter.StartStateStarted) - if err != nil { - return err - } - err = s.router.Start(adapter.StartStateStarted) - if err != nil { - return err - } - err = s.outbound.Start(adapter.StartStateStarted) - if err != nil { - return err - } - err = s.inbound.Start(adapter.StartStateStarted) + err = adapter.StartNamed(adapter.StartStateStarted, s.services) if err != nil { return err } @@ -380,47 +349,18 @@ func (s *Box) Close() error { default: close(s.done) } - monitor := taskmonitor.New(s.logger, C.StopTimeout) - var errors error - for serviceName, service := range s.postServices { - monitor.Start("close ", serviceName) - errors = E.Append(errors, service.Close(), func(err error) error { - return E.Cause(err, "close ", serviceName) - }) - monitor.Finish() - } - errors = E.Errors(errors, s.inbound.Close()) - errors = E.Errors(errors, s.outbound.Close()) - errors = E.Errors(errors, s.network.Close()) - errors = E.Errors(errors, s.router.Close()) - for serviceName, service := range s.preServices1 { - monitor.Start("close ", serviceName) - errors = E.Append(errors, service.Close(), func(err error) error { - return E.Cause(err, "close ", serviceName) - }) - monitor.Finish() - } - for serviceName, service := range s.preServices2 { - monitor.Start("close ", serviceName) - errors = E.Append(errors, service.Close(), func(err error) error { - return E.Cause(err, "close ", serviceName) - }) - monitor.Finish() - } - if err := common.Close(s.logFactory); err != nil { - errors = E.Append(errors, err, func(err error) error { - return E.Cause(err, "close logger") + err := common.Close( + s.inbound, s.outbound, s.router, s.network, + ) + for _, lifecycleService := range s.services { + err = E.Append(err, lifecycleService.Close(), func(err error) error { + return E.Cause(err, "close ", lifecycleService.Name()) }) } - return errors -} - -func (s *Box) Inbound() adapter.InboundManager { - return s.inbound -} - -func (s *Box) Outbound() adapter.OutboundManager { - return s.outbound + err = E.Append(err, s.logFactory.Close(), func(err error) error { + return E.Cause(err, "close logger") + }) + return err } func (s *Box) Network() adapter.NetworkManager { @@ -430,3 +370,11 @@ func (s *Box) Network() adapter.NetworkManager { func (s *Box) Router() adapter.Router { return s.router } + +func (s *Box) Inbound() adapter.InboundManager { + return s.inbound +} + +func (s *Box) Outbound() adapter.OutboundManager { + return s.outbound +} diff --git a/experimental/cachefile/cache.go b/experimental/cachefile/cache.go index 1027588f..498b9474 100644 --- a/experimental/cachefile/cache.go +++ b/experimental/cachefile/cache.go @@ -93,7 +93,18 @@ func New(ctx context.Context, options option.CacheFileOptions) *CacheFile { } } -func (c *CacheFile) start() error { +func (c *CacheFile) Name() string { + return "cache-file" +} + +func (c *CacheFile) Dependencies() []string { + return nil +} + +func (c *CacheFile) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateInitialize { + return nil + } const fileMode = 0o666 options := bbolt.Options{Timeout: time.Second} var ( @@ -151,14 +162,6 @@ func (c *CacheFile) start() error { return nil } -func (c *CacheFile) PreStart() error { - return c.start() -} - -func (c *CacheFile) Start() error { - return nil -} - func (c *CacheFile) Close() error { if c.DB == nil { return nil diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index 9592ab98..a9422b43 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -136,45 +136,50 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op return s, nil } -func (s *Server) PreStart() error { - cacheFile := service.FromContext[adapter.CacheFile](s.ctx) - if cacheFile != nil { - mode := cacheFile.LoadMode() - if common.Any(s.modeList, func(it string) bool { - return strings.EqualFold(it, mode) - }) { - s.mode = mode - } - } - return nil +func (s *Server) Name() string { + return "clash server" } -func (s *Server) Start() error { - if s.externalController { - s.checkAndDownloadExternalUI() - var ( - listener net.Listener - err error - ) - for i := 0; i < 3; i++ { - listener, err = net.Listen("tcp", s.httpServer.Addr) - if runtime.GOOS == "android" && errors.Is(err, syscall.EADDRINUSE) { - time.Sleep(100 * time.Millisecond) - continue +func (s *Server) Start(stage adapter.StartStage) error { + switch stage { + case adapter.StartStateStart: + cacheFile := service.FromContext[adapter.CacheFile](s.ctx) + if cacheFile != nil { + mode := cacheFile.LoadMode() + if common.Any(s.modeList, func(it string) bool { + return strings.EqualFold(it, mode) + }) { + s.mode = mode } - break } - if err != nil { - return E.Cause(err, "external controller listen error") - } - s.logger.Info("restful api listening at ", listener.Addr()) - go func() { - err = s.httpServer.Serve(listener) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - s.logger.Error("external controller serve error: ", err) + case adapter.StartStateStarted: + if s.externalController { + s.checkAndDownloadExternalUI() + var ( + listener net.Listener + err error + ) + for i := 0; i < 3; i++ { + listener, err = net.Listen("tcp", s.httpServer.Addr) + if runtime.GOOS == "android" && errors.Is(err, syscall.EADDRINUSE) { + time.Sleep(100 * time.Millisecond) + continue + } + break } - }() + if err != nil { + return E.Cause(err, "external controller listen error") + } + s.logger.Info("restful api listening at ", listener.Addr()) + go func() { + err = s.httpServer.Serve(listener) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("external controller serve error: ", err) + } + }() + } } + return nil } @@ -236,14 +241,12 @@ func (s *Server) TrafficManager() *trafficontrol.Manager { return s.trafficManager } -func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) { - tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule) - return tracker, tracker +func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn { + return trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule, matchOutbound) } -func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) { - tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule) - return tracker, tracker +func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn { + return trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule, matchOutbound) } func authentication(serverSecret string) func(next http.Handler) http.Handler { diff --git a/experimental/clashapi/trafficontrol/tracker.go b/experimental/clashapi/trafficontrol/tracker.go index df5437fa..e324be20 100644 --- a/experimental/clashapi/trafficontrol/tracker.go +++ b/experimental/clashapi/trafficontrol/tracker.go @@ -5,7 +5,6 @@ import ( "time" "github.com/sagernet/sing-box/adapter" - R "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/bufio" @@ -88,7 +87,6 @@ func (t TrackerMetadata) MarshalJSON() ([]byte, error) { } type Tracker interface { - adapter.Tracker Metadata() TrackerMetadata Close() error } @@ -108,10 +106,6 @@ func (tt *TCPConn) Close() error { return tt.ExtendedConn.Close() } -func (tt *TCPConn) Leave() { - tt.manager.Leave(tt) -} - func (tt *TCPConn) Upstream() any { return tt.ExtendedConn } @@ -124,7 +118,7 @@ func (tt *TCPConn) WriterReplaceable() bool { return true } -func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *TCPConn { +func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, matchRule adapter.Rule, matchOutbound adapter.Outbound) *TCPConn { id, _ := uuid.NewV4() var ( chain []string @@ -132,12 +126,8 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundCont outbound string outboundType string ) - var action adapter.RuleAction - if rule != nil { - action = rule.Action() - } - if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction { - next = routeAction.Outbound + if matchOutbound != nil { + next = matchOutbound.Tag() } else { next = outboundManager.Default().Tag() } @@ -172,7 +162,7 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundCont Upload: upload, Download: download, Chain: common.Reverse(chain), - Rule: rule, + Rule: matchRule, Outbound: outbound, OutboundType: outboundType, }, @@ -197,10 +187,6 @@ func (ut *UDPConn) Close() error { return ut.PacketConn.Close() } -func (ut *UDPConn) Leave() { - ut.manager.Leave(ut) -} - func (ut *UDPConn) Upstream() any { return ut.PacketConn } @@ -213,7 +199,7 @@ func (ut *UDPConn) WriterReplaceable() bool { return true } -func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *UDPConn { +func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, matchRule adapter.Rule, matchOutbound adapter.Outbound) *UDPConn { id, _ := uuid.NewV4() var ( chain []string @@ -221,12 +207,8 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.Inbound outbound string outboundType string ) - var action adapter.RuleAction - if rule != nil { - action = rule.Action() - } - if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction { - next = routeAction.Outbound + if matchOutbound != nil { + next = matchOutbound.Tag() } else { next = outboundManager.Default().Tag() } @@ -261,7 +243,7 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.Inbound Upload: upload, Download: download, Chain: common.Reverse(chain), - Rule: rule, + Rule: matchRule, Outbound: outbound, OutboundType: outboundType, }, diff --git a/experimental/libbox/command_clash_mode.go b/experimental/libbox/command_clash_mode.go index 1b6eb470..af69047f 100644 --- a/experimental/libbox/command_clash_mode.go +++ b/experimental/libbox/command_clash_mode.go @@ -38,11 +38,7 @@ func (s *CommandServer) handleSetClashMode(conn net.Conn) error { if service == nil { return writeError(conn, E.New("service not ready")) } - clashServer := service.instance.Router().ClashServer() - if clashServer == nil { - return writeError(conn, E.New("Clash API disabled")) - } - clashServer.(*clashapi.Server).SetMode(newMode) + service.clashServer.(*clashapi.Server).SetMode(newMode) return writeError(conn, nil) } @@ -69,18 +65,14 @@ func (s *CommandServer) handleModeConn(conn net.Conn) error { return ctx.Err() } } - clashServer := s.service.instance.Router().ClashServer() - if clashServer == nil { - return binary.Write(conn, binary.BigEndian, uint16(0)) - } - err := writeClashModeList(conn, clashServer) + err := writeClashModeList(conn, s.service.clashServer) if err != nil { return err } for { select { case <-s.modeUpdate: - err = varbin.Write(conn, binary.BigEndian, clashServer.Mode()) + err = varbin.Write(conn, binary.BigEndian, s.service.clashServer.Mode()) if err != nil { return err } diff --git a/experimental/libbox/command_close_connection.go b/experimental/libbox/command_close_connection.go index a2b05e56..46f7023f 100644 --- a/experimental/libbox/command_close_connection.go +++ b/experimental/libbox/command_close_connection.go @@ -45,11 +45,7 @@ func (s *CommandServer) handleCloseConnection(conn net.Conn) error { if service == nil { return writeError(conn, E.New("service not ready")) } - clashServer := service.instance.Router().ClashServer() - if clashServer == nil { - return writeError(conn, E.New("Clash API disabled")) - } - targetConn := clashServer.(*clashapi.Server).TrafficManager().Connection(uuid.FromStringOrNil(connId)) + targetConn := service.clashServer.(*clashapi.Server).TrafficManager().Connection(uuid.FromStringOrNil(connId)) if targetConn == nil { return writeError(conn, E.New("connection already closed")) } diff --git a/experimental/libbox/command_connections.go b/experimental/libbox/command_connections.go index b51c7352..39d9303c 100644 --- a/experimental/libbox/command_connections.go +++ b/experimental/libbox/command_connections.go @@ -49,11 +49,7 @@ func (s *CommandServer) handleConnectionsConn(conn net.Conn) error { for { service := s.service if service != nil { - clashServer := service.instance.Router().ClashServer() - if clashServer == nil { - return E.New("Clash API disabled") - } - trafficManager = clashServer.(*clashapi.Server).TrafficManager() + trafficManager = service.clashServer.(*clashapi.Server).TrafficManager() break } select { diff --git a/experimental/libbox/command_server.go b/experimental/libbox/command_server.go index 26b4aa79..798a52bd 100644 --- a/experimental/libbox/command_server.go +++ b/experimental/libbox/command_server.go @@ -60,7 +60,7 @@ func NewCommandServer(handler CommandServerHandler, maxLines int32) *CommandServ func (s *CommandServer) SetService(newService *BoxService) { if newService != nil { service.PtrFromContext[urltest.HistoryStorage](newService.ctx).SetHook(s.urlTestUpdate) - newService.instance.Router().ClashServer().(*clashapi.Server).SetModeUpdateHook(s.modeUpdate) + newService.clashServer.(*clashapi.Server).SetModeUpdateHook(s.modeUpdate) } s.service = newService s.notifyURLTestUpdate() diff --git a/experimental/libbox/command_status.go b/experimental/libbox/command_status.go index a6280d0f..f8709ef0 100644 --- a/experimental/libbox/command_status.go +++ b/experimental/libbox/command_status.go @@ -31,12 +31,10 @@ func (s *CommandServer) readStatus() StatusMessage { message.ConnectionsOut = int32(conntrack.Count()) if s.service != nil { - if clashServer := s.service.instance.Router().ClashServer(); clashServer != nil { - message.TrafficAvailable = true - trafficManager := clashServer.(*clashapi.Server).TrafficManager() - message.UplinkTotal, message.DownlinkTotal = trafficManager.Total() - message.ConnectionsIn = int32(trafficManager.ConnectionsLen()) - } + message.TrafficAvailable = true + trafficManager := s.service.clashServer.(*clashapi.Server).TrafficManager() + message.UplinkTotal, message.DownlinkTotal = trafficManager.Total() + message.ConnectionsIn = int32(trafficManager.ConnectionsLen()) } return message diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index 42a5129c..b44abc54 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -34,17 +34,18 @@ import ( type BoxService struct { ctx context.Context cancel context.CancelFunc - instance *box.Box - pauseManager pause.Manager urlTestHistoryStorage *urltest.HistoryStorage + instance *box.Box + clashServer adapter.ClashServer + pauseManager pause.Manager servicePauseFields } func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) { ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()) - ctx = service.ContextWith[deprecated.Manager](ctx, new(deprecatedManager)) ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID) + service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager)) options, err := parseConfig(ctx, configContent) if err != nil { return nil, err @@ -54,7 +55,7 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box urlTestHistoryStorage := urltest.NewHistoryStorage() ctx = service.ContextWithPtr(ctx, urlTestHistoryStorage) platformWrapper := &platformInterfaceWrapper{iif: platformInterface, useProcFS: platformInterface.UseProcFS()} - ctx = service.ContextWith[platform.Interface](ctx, platformWrapper) + service.MustRegister[platform.Interface](ctx, platformWrapper) instance, err := box.New(box.Options{ Context: ctx, Options: options, @@ -71,6 +72,7 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box instance: instance, urlTestHistoryStorage: urlTestHistoryStorage, pauseManager: service.FromContext[pause.Manager](ctx), + clashServer: service.FromContext[adapter.ClashServer](ctx), }, nil } diff --git a/experimental/v2rayapi/server.go b/experimental/v2rayapi/server.go index 8b4b4385..8ebae1c4 100644 --- a/experimental/v2rayapi/server.go +++ b/experimental/v2rayapi/server.go @@ -44,7 +44,14 @@ func NewServer(logger log.Logger, options option.V2RayAPIOptions) (adapter.V2Ray return server, nil } -func (s *Server) Start() error { +func (s *Server) Name() string { + return "v2ray server" +} + +func (s *Server) Start(stage adapter.StartStage) error { + if stage != adapter.StartStatePostStart { + return nil + } listener, err := net.Listen("tcp", s.listen) if err != nil { return err @@ -70,6 +77,6 @@ func (s *Server) Close() error { ) } -func (s *Server) StatsService() adapter.V2RayStatsService { +func (s *Server) StatsService() adapter.ConnectionTracker { return s.statsService } diff --git a/experimental/v2rayapi/stats.go b/experimental/v2rayapi/stats.go index 38b9a301..6c44518f 100644 --- a/experimental/v2rayapi/stats.go +++ b/experimental/v2rayapi/stats.go @@ -22,7 +22,7 @@ func init() { } var ( - _ adapter.V2RayStatsService = (*StatsService)(nil) + _ adapter.ConnectionTracker = (*StatsService)(nil) _ StatsServiceServer = (*StatsService)(nil) ) @@ -60,7 +60,10 @@ func NewStatsService(options option.V2RayStatsServiceOptions) *StatsService { } } -func (s *StatsService) RoutedConnection(inbound string, outbound string, user string, conn net.Conn) net.Conn { +func (s *StatsService) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn { + inbound := metadata.Inbound + user := metadata.User + outbound := matchOutbound.Tag() var readCounter []*atomic.Int64 var writeCounter []*atomic.Int64 countInbound := inbound != "" && s.inbounds[inbound] @@ -86,7 +89,10 @@ func (s *StatsService) RoutedConnection(inbound string, outbound string, user st return bufio.NewInt64CounterConn(conn, readCounter, writeCounter) } -func (s *StatsService) RoutedPacketConnection(inbound string, outbound string, user string, conn N.PacketConn) N.PacketConn { +func (s *StatsService) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn { + inbound := metadata.Inbound + user := metadata.User + outbound := matchOutbound.Tag() var readCounter []*atomic.Int64 var writeCounter []*atomic.Int64 countInbound := inbound != "" && s.inbounds[inbound] diff --git a/protocol/group/urltest.go b/protocol/group/urltest.go index 4d76a31c..f1a84b50 100644 --- a/protocol/group/urltest.go +++ b/protocol/group/urltest.go @@ -76,18 +76,7 @@ func (s *URLTest) Start() error { } outbounds = append(outbounds, detour) } - group, err := NewURLTestGroup( - s.ctx, - s.router, - s.outboundManager, - s.logger, - outbounds, - s.link, - s.interval, - s.tolerance, - s.idleTimeout, - s.interruptExternalConnections, - ) + group, err := NewURLTestGroup(s.ctx, s.outboundManager, s.logger, outbounds, s.link, s.interval, s.tolerance, s.idleTimeout, s.interruptExternalConnections) if err != nil { return err } @@ -215,18 +204,7 @@ type URLTestGroup struct { lastActive atomic.TypedValue[time.Time] } -func NewURLTestGroup( - ctx context.Context, - router adapter.Router, - outboundManager adapter.OutboundManager, - logger log.Logger, - outbounds []adapter.Outbound, - link string, - interval time.Duration, - tolerance uint16, - idleTimeout time.Duration, - interruptExternalConnections bool, -) (*URLTestGroup, error) { +func NewURLTestGroup(ctx context.Context, outboundManager adapter.OutboundManager, logger log.Logger, outbounds []adapter.Outbound, link string, interval time.Duration, tolerance uint16, idleTimeout time.Duration, interruptExternalConnections bool) (*URLTestGroup, error) { if interval == 0 { interval = C.DefaultURLTestInterval } @@ -241,14 +219,13 @@ func NewURLTestGroup( } var history *urltest.HistoryStorage if history = service.PtrFromContext[urltest.HistoryStorage](ctx); history != nil { - } else if clashServer := router.ClashServer(); clashServer != nil { + } else if clashServer := service.FromContext[adapter.ClashServer](ctx); clashServer != nil { history = clashServer.HistoryStorage() } else { history = urltest.NewHistoryStorage() } return &URLTestGroup{ ctx: ctx, - router: router, outboundManager: outboundManager, logger: logger, outbounds: outbounds, diff --git a/route/route.go b/route/route.go index d55b4ec1..1c4da4b7 100644 --- a/route/route.go +++ b/route/route.go @@ -91,16 +91,12 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad if err != nil { return err } - var ( - // selectedOutbound adapter.Outbound - selectedDialer N.Dialer - selectedTag string - selectedDescription string - ) + var selectedOutbound adapter.Outbound if selectedRule != nil { switch action := selectedRule.Action().(type) { case *rule.RuleActionRoute: - selectedOutbound, loaded := r.outboundManager.Outbound(action.Outbound) + var loaded bool + selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound) if !loaded { buf.ReleaseMulti(buffers) return E.New("outbound not found: ", action.Outbound) @@ -109,12 +105,6 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad buf.ReleaseMulti(buffers) return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag()) } - selectedDialer = selectedOutbound - selectedTag = selectedOutbound.Tag() - selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") - case *rule.RuleActionDirect: - selectedDialer = action.Dialer - selectedDescription = action.String() case *rule.RuleActionReject: buf.ReleaseMulti(buffers) N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx)) @@ -133,25 +123,16 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad buf.ReleaseMulti(buffers) return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag()) } - selectedDialer = defaultOutbound - selectedTag = defaultOutbound.Tag() - selectedDescription = F.ToString("outbound/", defaultOutbound.Type(), "[", defaultOutbound.Tag(), "]") + selectedOutbound = defaultOutbound } for _, buffer := range buffers { conn = bufio.NewCachedConn(conn, buffer) } - if r.clashServer != nil { - trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, selectedRule) - defer tracker.Leave() - conn = trackerConn + if r.tracker != nil { + conn = r.tracker.RoutedConnection(ctx, conn, metadata, selectedRule, selectedOutbound) } - if r.v2rayServer != nil { - if statsService := r.v2rayServer.StatsService(); statsService != nil { - conn = statsService.RoutedConnection(metadata.Inbound, selectedTag, metadata.User, conn) - } - } - legacyOutbound, isLegacy := selectedDialer.(adapter.ConnectionHandler) + legacyOutbound, isLegacy := selectedOutbound.(adapter.ConnectionHandler) if isLegacy { err = legacyOutbound.NewConnection(ctx, conn, metadata) if err != nil { @@ -159,7 +140,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad if onClose != nil { onClose(err) } - return E.Cause(err, selectedDescription) + return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")) } else { if onClose != nil { onClose(nil) @@ -168,13 +149,13 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad return nil } // TODO - err = outbound.NewConnection(ctx, selectedDialer, conn, metadata) + err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata) if err != nil { conn.Close() if onClose != nil { onClose(err) } - return E.Cause(err, selectedDescription) + return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")) } else { if onClose != nil { onClose(nil) @@ -246,16 +227,13 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m if err != nil { return err } - var ( - selectedDialer N.Dialer - selectedTag string - selectedDescription string - ) + var selectedOutbound adapter.Outbound var selectReturn bool if selectedRule != nil { switch action := selectedRule.Action().(type) { case *rule.RuleActionRoute: - selectedOutbound, loaded := r.outboundManager.Outbound(action.Outbound) + var loaded bool + selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound) if !loaded { N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("outbound not found: ", action.Outbound) @@ -264,12 +242,6 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag()) } - selectedDialer = selectedOutbound - selectedTag = selectedOutbound.Tag() - selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") - case *rule.RuleActionDirect: - selectedDialer = action.Dialer - selectedDescription = action.String() case *rule.RuleActionReject: N.ReleaseMultiPacketBuffer(packetBuffers) N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx)) @@ -285,41 +257,32 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag()) } - selectedDialer = defaultOutbound - selectedTag = defaultOutbound.Tag() - selectedDescription = F.ToString("outbound/", defaultOutbound.Type(), "[", defaultOutbound.Tag(), "]") + selectedOutbound = defaultOutbound } for _, buffer := range packetBuffers { conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination) N.PutPacketBuffer(buffer) } - if r.clashServer != nil { - trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, selectedRule) - defer tracker.Leave() - conn = trackerConn - } - if r.v2rayServer != nil { - if statsService := r.v2rayServer.StatsService(); statsService != nil { - conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedTag, metadata.User, conn) - } + if r.tracker != nil { + conn = r.tracker.RoutedPacketConnection(ctx, conn, metadata, selectedRule, selectedOutbound) } if metadata.FakeIP { conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination) } - legacyOutbound, isLegacy := selectedDialer.(adapter.PacketConnectionHandler) + legacyOutbound, isLegacy := selectedOutbound.(adapter.PacketConnectionHandler) if isLegacy { err = legacyOutbound.NewPacketConnection(ctx, conn, metadata) N.CloseOnHandshakeFailure(conn, onClose, err) if err != nil { - return E.Cause(err, selectedDescription) + return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")) } return nil } // TODO - err = outbound.NewPacketConnection(ctx, selectedDialer, conn, metadata) + err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata) N.CloseOnHandshakeFailure(conn, onClose, err) if err != nil { - return E.Cause(err, selectedDescription) + return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")) } return nil } diff --git a/route/router.go b/route/router.go index 62b37447..1f760a86 100644 --- a/route/router.go +++ b/route/router.go @@ -27,7 +27,6 @@ import ( F "github.com/sagernet/sing/common/format" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/task" "github.com/sagernet/sing/service" "github.com/sagernet/sing/service/pause" @@ -63,16 +62,14 @@ type Router struct { dnsReverseMapping *DNSReverseMapping fakeIPStore adapter.FakeIPStore processSearcher process.Searcher - timeService *ntp.Service pauseManager pause.Manager - clashServer adapter.ClashServer - v2rayServer adapter.V2RayServer + tracker adapter.ConnectionTracker platformInterface platform.Interface needWIFIState bool started bool } -func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions, ntpOptions option.NTPOptions) (*Router, error) { +func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) (*Router, error) { router := &Router{ ctx: ctx, logger: logFactory.NewLogger("router"), @@ -94,7 +91,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route platformInterface: service.FromContext[platform.Interface](ctx), needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule), } - ctx = service.ContextWith[adapter.Router](ctx, router) + service.MustRegister[adapter.Router](ctx, router) router.dnsClient = dns.NewClient(dns.ClientOptions{ DisableCache: dnsOptions.DNSClientOptions.DisableCache, DisableExpire: dnsOptions.DNSClientOptions.DisableExpire, @@ -290,23 +287,6 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route } router.fakeIPStore = fakeip.NewStore(ctx, router.logger, inet4Range, inet6Range) } - - if ntpOptions.Enabled { - ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions) - if err != nil { - return nil, E.Cause(err, "create NTP service") - } - timeService := ntp.NewService(ntp.Options{ - Context: ctx, - Dialer: ntpDialer, - Logger: logFactory.NewLogger("ntp"), - Server: ntpOptions.ServerOptions.Build(), - Interval: time.Duration(ntpOptions.Interval), - WriteToSystem: ntpOptions.WriteToSystem, - }) - service.MustRegister[ntp.TimeService](ctx, timeService) - router.timeService = timeService - } return router, nil } @@ -380,14 +360,6 @@ func (r *Router) Start(stage adapter.StartStage) error { return E.Cause(err, "initialize DNS server[", i, "]") } } - if r.timeService != nil { - monitor.Start("initialize time service") - err := r.timeService.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "initialize time service") - } - } case adapter.StartStatePostStart: var cacheContext *adapter.HTTPStartContext if len(r.ruleSets) > 0 { @@ -502,13 +474,6 @@ func (r *Router) Close() error { }) monitor.Finish() } - if r.timeService != nil { - monitor.Start("close time service") - err = E.Append(err, r.timeService.Close(), func(err error) error { - return E.Cause(err, "close time service") - }) - monitor.Finish() - } if r.fakeIPStore != nil { monitor.Start("close fakeip store") err = E.Append(err, r.fakeIPStore.Close(), func(err error) error { @@ -536,29 +501,8 @@ func (r *Router) Rules() []adapter.Rule { return r.rules } -func (r *Router) ClashServer() adapter.ClashServer { - return r.clashServer -} - -func (r *Router) SetClashServer(server adapter.ClashServer) { - r.clashServer = server -} - -func (r *Router) V2RayServer() adapter.V2RayServer { - return r.v2rayServer -} - -func (r *Router) SetV2RayServer(server adapter.V2RayServer) { - r.v2rayServer = server -} - -func (r *Router) NewError(ctx context.Context, err error) { - common.Close(err) - if E.IsClosedOrCanceled(err) { - r.logger.DebugContext(ctx, "connection closed: ", err) - return - } - r.logger.ErrorContext(ctx, err) +func (r *Router) SetTracker(tracker adapter.ConnectionTracker) { + r.tracker = tracker } func (r *Router) ResetNetwork() { diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index 33a8e16c..12d9e96a 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -219,7 +219,7 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio rule.allItems = append(rule.allItems, item) } if options.ClashMode != "" { - item := NewClashModeItem(router, options.ClashMode) + item := NewClashModeItem(ctx, options.ClashMode) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index df5f3f33..1ec652b2 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -216,7 +216,7 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op rule.allItems = append(rule.allItems, item) } if options.ClashMode != "" { - item := NewClashModeItem(router, options.ClashMode) + item := NewClashModeItem(ctx, options.ClashMode) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } diff --git a/route/rule/rule_item_clash_mode.go b/route/rule/rule_item_clash_mode.go index aa5126cb..fe2347a0 100644 --- a/route/rule/rule_item_clash_mode.go +++ b/route/rule/rule_item_clash_mode.go @@ -1,31 +1,38 @@ package rule import ( + "context" "strings" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/service" ) var _ RuleItem = (*ClashModeItem)(nil) type ClashModeItem struct { - router adapter.Router - mode string + ctx context.Context + clashServer adapter.ClashServer + mode string } -func NewClashModeItem(router adapter.Router, mode string) *ClashModeItem { +func NewClashModeItem(ctx context.Context, mode string) *ClashModeItem { return &ClashModeItem{ - router: router, - mode: mode, + ctx: ctx, + mode: mode, } } +func (r *ClashModeItem) Start() error { + r.clashServer = service.FromContext[adapter.ClashServer](r.ctx) + return nil +} + func (r *ClashModeItem) Match(metadata *adapter.InboundContext) bool { - clashServer := r.router.ClashServer() - if clashServer == nil { + if r.clashServer == nil { return false } - return strings.EqualFold(clashServer.Mode(), r.mode) + return strings.EqualFold(r.clashServer.Mode(), r.mode) } func (r *ClashModeItem) String() string {