refactor: Extract services form router

This commit is contained in:
世界 2024-11-10 16:46:59 +08:00
parent 547eb2da63
commit ac115e4f42
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
27 changed files with 315 additions and 465 deletions

View file

@ -4,28 +4,28 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/binary" "encoding/binary"
"net"
"time" "time"
"github.com/sagernet/sing-box/common/urltest" "github.com/sagernet/sing-box/common/urltest"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/varbin" "github.com/sagernet/sing/common/varbin"
) )
type ClashServer interface { type ClashServer interface {
Service LifecycleService
LegacyPreStarter ConnectionTracker
Mode() string Mode() string
ModeList() []string ModeList() []string
HistoryStorage() *urltest.HistoryStorage HistoryStorage() *urltest.HistoryStorage
RoutedConnection(ctx context.Context, conn net.Conn, metadata InboundContext, matchedRule Rule) (net.Conn, Tracker) }
RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext, matchedRule Rule) (N.PacketConn, Tracker)
type V2RayServer interface {
LifecycleService
StatsService() ConnectionTracker
} }
type CacheFile interface { type CacheFile interface {
Service LifecycleService
LegacyPreStarter
StoreFakeIP() bool StoreFakeIP() bool
FakeIPStorage FakeIPStorage
@ -94,10 +94,6 @@ func (s *SavedRuleSet) UnmarshalBinary(data []byte) error {
return nil return nil
} }
type Tracker interface {
Leave()
}
type OutboundGroup interface { type OutboundGroup interface {
Outbound Outbound
Now() string Now() string
@ -115,13 +111,3 @@ func OutboundTag(detour Outbound) string {
} }
return detour.Tag() return detour.Tag()
} }
type V2RayServer interface {
Service
StatsService() V2RayStatsService
}
type V2RayStatsService interface {
RoutedConnection(inbound string, outbound string, user string, conn net.Conn) net.Conn
RoutedPacketConnection(inbound string, outbound string, user string, conn N.PacketConn) N.PacketConn
}

View file

@ -32,7 +32,7 @@ type InboundRegistry interface {
} }
type InboundManager interface { type InboundManager interface {
NewService Lifecycle
Inbounds() []Inbound Inbounds() []Inbound
Get(tag string) (Inbound, bool) Get(tag string) (Inbound, bool)
Remove(tag string) error Remove(tag string) error

View file

@ -44,7 +44,7 @@ func (m *Manager) Start(stage adapter.StartStage) error {
for _, inbound := range m.inbounds { for _, inbound := range m.inbounds {
err := adapter.LegacyStart(inbound, stage) err := adapter.LegacyStart(inbound, stage)
if err != nil { if err != nil {
return E.Cause(err, stage.Action(), " inbound/", inbound.Type(), "[", inbound.Tag(), "]") return E.Cause(err, stage, " inbound/", inbound.Type(), "[", inbound.Tag(), "]")
} }
} }
return nil return nil
@ -118,7 +118,7 @@ func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.
for _, stage := range adapter.ListStartStages { for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(inbound, stage) err = adapter.LegacyStart(inbound, stage)
if err != nil { if err != nil {
return E.Cause(err, stage.Action(), " inbound/", inbound.Type(), "[", inbound.Tag(), "]") return E.Cause(err, stage, " inbound/", inbound.Type(), "[", inbound.Tag(), "]")
} }
} }
} }

View file

@ -1,5 +1,7 @@
package adapter package adapter
import E "github.com/sagernet/sing/common/exceptions"
type StartStage uint8 type StartStage uint8
const ( const (
@ -16,7 +18,7 @@ var ListStartStages = []StartStage{
StartStateStarted, StartStateStarted,
} }
func (s StartStage) Action() string { func (s StartStage) String() string {
switch s { switch s {
case StartStateInitialize: case StartStateInitialize:
return "initialize" return "initialize"
@ -25,17 +27,38 @@ func (s StartStage) Action() string {
case StartStatePostStart: case StartStatePostStart:
return "post-start" return "post-start"
case StartStateStarted: case StartStateStarted:
return "start-after-started" return "finish-start"
default: default:
panic("unknown stage") panic("unknown stage")
} }
} }
type NewService interface { type Lifecycle interface {
NewStarter Start(stage StartStage) error
Close() error Close() error
} }
type NewStarter interface { type LifecycleService interface {
Start(stage StartStage) error Name() string
Lifecycle
}
func Start(stage StartStage, services ...Lifecycle) error {
for _, service := range services {
err := service.Start(stage)
if err != nil {
return err
}
}
return nil
}
func StartNamed(stage StartStage, services []LifecycleService) error {
for _, service := range services {
err := service.Start(stage)
if err != nil {
return E.Cause(err, stage.String(), " ", service.Name())
}
}
return nil
} }

View file

@ -1,13 +1,5 @@
package adapter package adapter
type LegacyPreStarter interface {
PreStart() error
}
type LegacyPostStarter interface {
PostStart() error
}
func LegacyStart(starter any, stage StartStage) error { func LegacyStart(starter any, stage StartStage) error {
switch stage { switch stage {
case StartStateInitialize: case StartStateInitialize:
@ -22,7 +14,7 @@ func LegacyStart(starter any, stage StartStage) error {
}); isStarter { }); isStarter {
return starter.Start() return starter.Start()
} }
case StartStatePostStart: case StartStateStarted:
if postStarter, isPostStarter := starter.(interface { if postStarter, isPostStarter := starter.(interface {
PostStart() error PostStart() error
}); isPostStarter { }); isPostStarter {
@ -31,3 +23,27 @@ func LegacyStart(starter any, stage StartStage) error {
} }
return nil return nil
} }
type lifecycleServiceWrapper struct {
Service
name string
}
func NewLifecycleService(service Service, name string) LifecycleService {
return &lifecycleServiceWrapper{
Service: service,
name: name,
}
}
func (l *lifecycleServiceWrapper) Name() string {
return l.name
}
func (l *lifecycleServiceWrapper) Start(stage StartStage) error {
return LegacyStart(l.Service, stage)
}
func (l *lifecycleServiceWrapper) Close() error {
return l.Service.Close()
}

View file

@ -6,7 +6,7 @@ import (
) )
type NetworkManager interface { type NetworkManager interface {
NewService Lifecycle
InterfaceFinder() control.InterfaceFinder InterfaceFinder() control.InterfaceFinder
UpdateInterfaces() error UpdateInterfaces() error
DefaultInterface() string DefaultInterface() string

View file

@ -24,7 +24,7 @@ type OutboundRegistry interface {
} }
type OutboundManager interface { type OutboundManager interface {
NewService Lifecycle
Outbounds() []Outbound Outbounds() []Outbound
Outbound(tag string) (Outbound, bool) Outbound(tag string) (Outbound, bool)
Default() Outbound Default() Outbound

View file

@ -61,7 +61,7 @@ func (m *Manager) Start(stage adapter.StartStage) error {
for _, outbound := range outbounds { for _, outbound := range outbounds {
err := adapter.LegacyStart(outbound, stage) err := adapter.LegacyStart(outbound, stage)
if err != nil { if err != nil {
return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]") return E.Cause(err, stage, " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
} }
} }
} }
@ -234,7 +234,7 @@ func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.
for _, stage := range adapter.ListStartStages { for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(outbound, stage) err = adapter.LegacyStart(outbound, stage)
if err != nil { if err != nil {
return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]") return E.Cause(err, stage, " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
} }
} }
} }

View file

@ -19,7 +19,7 @@ import (
) )
type Router interface { type Router interface {
NewService Lifecycle
FakeIPStore() FakeIPStore FakeIPStore() FakeIPStore
@ -38,15 +38,16 @@ type Router interface {
ClearDNSCache() ClearDNSCache()
Rules() []Rule Rules() []Rule
ClashServer() ClashServer SetTracker(tracker ConnectionTracker)
SetClashServer(server ClashServer)
V2RayServer() V2RayServer
SetV2RayServer(server V2RayServer)
ResetNetwork() ResetNetwork()
} }
type ConnectionTracker interface {
RoutedConnection(ctx context.Context, conn net.Conn, metadata InboundContext, matchedRule Rule, matchOutbound Outbound) net.Conn
RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext, matchedRule Rule, matchOutbound Outbound) N.PacketConn
}
// Deprecated: Use ConnectionRouterEx instead. // Deprecated: Use ConnectionRouterEx instead.
type ConnectionRouter interface { type ConnectionRouter interface {
RouteConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error RouteConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error

216
box.go
View file

@ -11,6 +11,7 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/taskmonitor" "github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental" "github.com/sagernet/sing-box/experimental"
@ -23,6 +24,7 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause" "github.com/sagernet/sing/service/pause"
) )
@ -31,15 +33,13 @@ var _ adapter.Service = (*Box)(nil)
type Box struct { type Box struct {
createdAt time.Time createdAt time.Time
router adapter.Router
inbound *inbound.Manager
outbound *outbound.Manager
network *route.NetworkManager
logFactory log.Factory logFactory log.Factory
logger log.ContextLogger logger log.ContextLogger
preServices1 map[string]adapter.Service network *route.NetworkManager
preServices2 map[string]adapter.Service router *route.Router
postServices map[string]adapter.Service inbound *inbound.Manager
outbound *outbound.Manager
services []adapter.LifecycleService
done chan struct{} done chan struct{}
} }
@ -49,7 +49,11 @@ type Options struct {
PlatformLogWriter log.PlatformWriter PlatformLogWriter log.PlatformWriter
} }
func Context(ctx context.Context, inboundRegistry adapter.InboundRegistry, outboundRegistry adapter.OutboundRegistry) context.Context { func Context(
ctx context.Context,
inboundRegistry adapter.InboundRegistry,
outboundRegistry adapter.OutboundRegistry,
) context.Context {
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil || if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.InboundRegistry](ctx) == nil { service.FromContext[adapter.InboundRegistry](ctx) == nil {
ctx = service.ContextWith[option.InboundOptionsRegistry](ctx, inboundRegistry) ctx = service.ContextWith[option.InboundOptionsRegistry](ctx, inboundRegistry)
@ -70,14 +74,17 @@ func New(options Options) (*Box, error) {
ctx = context.Background() ctx = context.Background()
} }
ctx = service.ContextWithDefaultRegistry(ctx) ctx = service.ContextWithDefaultRegistry(ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
if inboundRegistry == nil { if inboundRegistry == nil {
return nil, E.New("missing inbound registry in context") return nil, E.New("missing inbound registry in context")
} }
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
if outboundRegistry == nil { if outboundRegistry == nil {
return nil, E.New("missing outbound registry in context") return nil, E.New("missing outbound registry in context")
} }
ctx = pause.WithDefaultManager(ctx) ctx = pause.WithDefaultManager(ctx)
experimentalOptions := common.PtrValueOrDefault(options.Experimental) experimentalOptions := common.PtrValueOrDefault(options.Experimental)
applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug)) applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug))
@ -109,17 +116,19 @@ func New(options Options) (*Box, error) {
if err != nil { if err != nil {
return nil, E.Cause(err, "create log factory") return nil, E.Cause(err, "create log factory")
} }
routeOptions := common.PtrValueOrDefault(options.Route) routeOptions := common.PtrValueOrDefault(options.Route)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry) inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, routeOptions.Final) outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, routeOptions.Final)
ctx = service.ContextWith[adapter.InboundManager](ctx, inboundManager) service.MustRegister[adapter.InboundManager](ctx, inboundManager)
ctx = service.ContextWith[adapter.OutboundManager](ctx, outboundManager) service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions) networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions)
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize network manager") return nil, E.Cause(err, "initialize network manager")
} }
ctx = service.ContextWith[adapter.NetworkManager](ctx, networkManager) service.MustRegister[adapter.NetworkManager](ctx, networkManager)
router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS), common.PtrValueOrDefault(options.NTP)) router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS))
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize router") return nil, E.Cause(err, "initialize router")
} }
@ -182,46 +191,60 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "initialize platform interface") return nil, E.Cause(err, "initialize platform interface")
} }
} }
preServices1 := make(map[string]adapter.Service) var services []adapter.LifecycleService
preServices2 := make(map[string]adapter.Service)
postServices := make(map[string]adapter.Service)
if needCacheFile { if needCacheFile {
cacheFile := service.FromContext[adapter.CacheFile](ctx) cacheFile := cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile))
if cacheFile == nil {
cacheFile = cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile))
service.MustRegister[adapter.CacheFile](ctx, cacheFile) service.MustRegister[adapter.CacheFile](ctx, cacheFile)
} services = append(services, cacheFile)
preServices1["cache file"] = cacheFile
} }
if needClashAPI { if needClashAPI {
clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI) clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI)
clashAPIOptions.ModeList = experimental.CalculateClashModeList(options.Options) clashAPIOptions.ModeList = experimental.CalculateClashModeList(options.Options)
clashServer, err := experimental.NewClashServer(ctx, logFactory.(log.ObservableFactory), clashAPIOptions) clashServer, err := experimental.NewClashServer(ctx, logFactory.(log.ObservableFactory), clashAPIOptions)
if err != nil { if err != nil {
return nil, E.Cause(err, "create clash api server") return nil, E.Cause(err, "create clash-server")
} }
router.SetClashServer(clashServer) router.SetTracker(clashServer)
preServices2["clash api"] = clashServer service.MustRegister[adapter.ClashServer](ctx, clashServer)
services = append(services, clashServer)
} }
if needV2RayAPI { if needV2RayAPI {
v2rayServer, err := experimental.NewV2RayServer(logFactory.NewLogger("v2ray-api"), common.PtrValueOrDefault(experimentalOptions.V2RayAPI)) v2rayServer, err := experimental.NewV2RayServer(logFactory.NewLogger("v2ray-api"), common.PtrValueOrDefault(experimentalOptions.V2RayAPI))
if err != nil { if err != nil {
return nil, E.Cause(err, "create v2ray api server") return nil, E.Cause(err, "create v2ray-server")
} }
router.SetV2RayServer(v2rayServer) if v2rayServer.StatsService() != nil {
preServices2["v2ray api"] = v2rayServer router.SetTracker(v2rayServer.StatsService())
services = append(services, v2rayServer)
service.MustRegister[adapter.V2RayServer](ctx, v2rayServer)
}
}
ntpOptions := common.PtrValueOrDefault(options.NTP)
if ntpOptions.Enabled {
ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions)
if err != nil {
return nil, E.Cause(err, "create NTP service")
}
timeService := ntp.NewService(ntp.Options{
Context: ctx,
Dialer: ntpDialer,
Logger: logFactory.NewLogger("ntp"),
Server: ntpOptions.ServerOptions.Build(),
Interval: time.Duration(ntpOptions.Interval),
WriteToSystem: ntpOptions.WriteToSystem,
})
service.MustRegister[ntp.TimeService](ctx, timeService)
services = append(services, adapter.NewLifecycleService(timeService, "ntp service"))
} }
return &Box{ return &Box{
network: networkManager,
router: router, router: router,
inbound: inboundManager, inbound: inboundManager,
outbound: outboundManager, outbound: outboundManager,
network: networkManager,
createdAt: createdAt, createdAt: createdAt,
logFactory: logFactory, logFactory: logFactory,
logger: logFactory.Logger(), logger: logFactory.Logger(),
preServices1: preServices1, services: services,
preServices2: preServices2,
postServices: postServices,
done: make(chan struct{}), done: make(chan struct{}),
}, nil }, nil
} }
@ -272,43 +295,19 @@ func (s *Box) preStart() error {
if err != nil { if err != nil {
return E.Cause(err, "start logger") return E.Cause(err, "start logger")
} }
for serviceName, service := range s.preServices1 { err = adapter.StartNamed(adapter.StartStateInitialize, s.services) // cache-file clash-api v2ray-api
if preService, isPreService := service.(adapter.LegacyPreStarter); isPreService {
monitor.Start("pre-start ", serviceName)
err := preService.PreStart()
monitor.Finish()
if err != nil {
return E.Cause(err, "pre-start ", serviceName)
}
}
}
for serviceName, service := range s.preServices2 {
if preService, isPreService := service.(adapter.LegacyPreStarter); isPreService {
monitor.Start("pre-start ", serviceName)
err := preService.PreStart()
monitor.Finish()
if err != nil {
return E.Cause(err, "pre-start ", serviceName)
}
}
}
err = s.network.Start(adapter.StartStateInitialize)
if err != nil {
return E.Cause(err, "initialize network manager")
}
err = s.router.Start(adapter.StartStateInitialize)
if err != nil {
return E.Cause(err, "initialize router")
}
err = s.outbound.Start(adapter.StartStateStart)
if err != nil { if err != nil {
return err return err
} }
err = s.network.Start(adapter.StartStateStart) err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound)
if err != nil { if err != nil {
return err return err
} }
return s.router.Start(adapter.StartStateStart) err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.router)
if err != nil {
return err
}
return nil
} }
func (s *Box) start() error { func (s *Box) start() error {
@ -316,57 +315,27 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
for serviceName, service := range s.preServices1 { err = adapter.StartNamed(adapter.StartStateStart, s.services)
err = service.Start()
if err != nil { if err != nil {
return E.Cause(err, "start ", serviceName) return err
}
}
for serviceName, service := range s.preServices2 {
err = service.Start()
if err != nil {
return E.Cause(err, "start ", serviceName)
}
} }
err = s.inbound.Start(adapter.StartStateStart) err = s.inbound.Start(adapter.StartStateStart)
if err != nil { if err != nil {
return err return err
} }
for serviceName, service := range s.postServices { err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound)
err := service.Start()
if err != nil {
return E.Cause(err, "start ", serviceName)
}
}
err = s.outbound.Start(adapter.StartStatePostStart)
if err != nil { if err != nil {
return err return err
} }
err = s.network.Start(adapter.StartStatePostStart) err = adapter.StartNamed(adapter.StartStatePostStart, s.services)
if err != nil { if err != nil {
return err return err
} }
err = s.router.Start(adapter.StartStatePostStart) err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound)
if err != nil { if err != nil {
return err return err
} }
err = s.inbound.Start(adapter.StartStatePostStart) err = adapter.StartNamed(adapter.StartStateStarted, s.services)
if err != nil {
return err
}
err = s.network.Start(adapter.StartStateStarted)
if err != nil {
return err
}
err = s.router.Start(adapter.StartStateStarted)
if err != nil {
return err
}
err = s.outbound.Start(adapter.StartStateStarted)
if err != nil {
return err
}
err = s.inbound.Start(adapter.StartStateStarted)
if err != nil { if err != nil {
return err return err
} }
@ -380,47 +349,18 @@ func (s *Box) Close() error {
default: default:
close(s.done) close(s.done)
} }
monitor := taskmonitor.New(s.logger, C.StopTimeout) err := common.Close(
var errors error s.inbound, s.outbound, s.router, s.network,
for serviceName, service := range s.postServices { )
monitor.Start("close ", serviceName) for _, lifecycleService := range s.services {
errors = E.Append(errors, service.Close(), func(err error) error { err = E.Append(err, lifecycleService.Close(), func(err error) error {
return E.Cause(err, "close ", serviceName) return E.Cause(err, "close ", lifecycleService.Name())
}) })
monitor.Finish()
} }
errors = E.Errors(errors, s.inbound.Close()) err = E.Append(err, s.logFactory.Close(), func(err error) error {
errors = E.Errors(errors, s.outbound.Close())
errors = E.Errors(errors, s.network.Close())
errors = E.Errors(errors, s.router.Close())
for serviceName, service := range s.preServices1 {
monitor.Start("close ", serviceName)
errors = E.Append(errors, service.Close(), func(err error) error {
return E.Cause(err, "close ", serviceName)
})
monitor.Finish()
}
for serviceName, service := range s.preServices2 {
monitor.Start("close ", serviceName)
errors = E.Append(errors, service.Close(), func(err error) error {
return E.Cause(err, "close ", serviceName)
})
monitor.Finish()
}
if err := common.Close(s.logFactory); err != nil {
errors = E.Append(errors, err, func(err error) error {
return E.Cause(err, "close logger") return E.Cause(err, "close logger")
}) })
} return err
return errors
}
func (s *Box) Inbound() adapter.InboundManager {
return s.inbound
}
func (s *Box) Outbound() adapter.OutboundManager {
return s.outbound
} }
func (s *Box) Network() adapter.NetworkManager { func (s *Box) Network() adapter.NetworkManager {
@ -430,3 +370,11 @@ func (s *Box) Network() adapter.NetworkManager {
func (s *Box) Router() adapter.Router { func (s *Box) Router() adapter.Router {
return s.router return s.router
} }
func (s *Box) Inbound() adapter.InboundManager {
return s.inbound
}
func (s *Box) Outbound() adapter.OutboundManager {
return s.outbound
}

View file

@ -93,7 +93,18 @@ func New(ctx context.Context, options option.CacheFileOptions) *CacheFile {
} }
} }
func (c *CacheFile) start() error { func (c *CacheFile) Name() string {
return "cache-file"
}
func (c *CacheFile) Dependencies() []string {
return nil
}
func (c *CacheFile) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateInitialize {
return nil
}
const fileMode = 0o666 const fileMode = 0o666
options := bbolt.Options{Timeout: time.Second} options := bbolt.Options{Timeout: time.Second}
var ( var (
@ -151,14 +162,6 @@ func (c *CacheFile) start() error {
return nil return nil
} }
func (c *CacheFile) PreStart() error {
return c.start()
}
func (c *CacheFile) Start() error {
return nil
}
func (c *CacheFile) Close() error { func (c *CacheFile) Close() error {
if c.DB == nil { if c.DB == nil {
return nil return nil

View file

@ -136,7 +136,13 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op
return s, nil return s, nil
} }
func (s *Server) PreStart() error { func (s *Server) Name() string {
return "clash server"
}
func (s *Server) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
cacheFile := service.FromContext[adapter.CacheFile](s.ctx) cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
if cacheFile != nil { if cacheFile != nil {
mode := cacheFile.LoadMode() mode := cacheFile.LoadMode()
@ -146,10 +152,7 @@ func (s *Server) PreStart() error {
s.mode = mode s.mode = mode
} }
} }
return nil case adapter.StartStateStarted:
}
func (s *Server) Start() error {
if s.externalController { if s.externalController {
s.checkAndDownloadExternalUI() s.checkAndDownloadExternalUI()
var ( var (
@ -175,6 +178,8 @@ func (s *Server) Start() error {
} }
}() }()
} }
}
return nil return nil
} }
@ -236,14 +241,12 @@ func (s *Server) TrafficManager() *trafficontrol.Manager {
return s.trafficManager return s.trafficManager
} }
func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) { func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn {
tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule) return trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule, matchOutbound)
return tracker, tracker
} }
func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) { func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn {
tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule) return trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule, matchOutbound)
return tracker, tracker
} }
func authentication(serverSecret string) func(next http.Handler) http.Handler { func authentication(serverSecret string) func(next http.Handler) http.Handler {

View file

@ -5,7 +5,6 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
R "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
@ -88,7 +87,6 @@ func (t TrackerMetadata) MarshalJSON() ([]byte, error) {
} }
type Tracker interface { type Tracker interface {
adapter.Tracker
Metadata() TrackerMetadata Metadata() TrackerMetadata
Close() error Close() error
} }
@ -108,10 +106,6 @@ func (tt *TCPConn) Close() error {
return tt.ExtendedConn.Close() return tt.ExtendedConn.Close()
} }
func (tt *TCPConn) Leave() {
tt.manager.Leave(tt)
}
func (tt *TCPConn) Upstream() any { func (tt *TCPConn) Upstream() any {
return tt.ExtendedConn return tt.ExtendedConn
} }
@ -124,7 +118,7 @@ func (tt *TCPConn) WriterReplaceable() bool {
return true return true
} }
func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *TCPConn { func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, matchRule adapter.Rule, matchOutbound adapter.Outbound) *TCPConn {
id, _ := uuid.NewV4() id, _ := uuid.NewV4()
var ( var (
chain []string chain []string
@ -132,12 +126,8 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundCont
outbound string outbound string
outboundType string outboundType string
) )
var action adapter.RuleAction if matchOutbound != nil {
if rule != nil { next = matchOutbound.Tag()
action = rule.Action()
}
if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction {
next = routeAction.Outbound
} else { } else {
next = outboundManager.Default().Tag() next = outboundManager.Default().Tag()
} }
@ -172,7 +162,7 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundCont
Upload: upload, Upload: upload,
Download: download, Download: download,
Chain: common.Reverse(chain), Chain: common.Reverse(chain),
Rule: rule, Rule: matchRule,
Outbound: outbound, Outbound: outbound,
OutboundType: outboundType, OutboundType: outboundType,
}, },
@ -197,10 +187,6 @@ func (ut *UDPConn) Close() error {
return ut.PacketConn.Close() return ut.PacketConn.Close()
} }
func (ut *UDPConn) Leave() {
ut.manager.Leave(ut)
}
func (ut *UDPConn) Upstream() any { func (ut *UDPConn) Upstream() any {
return ut.PacketConn return ut.PacketConn
} }
@ -213,7 +199,7 @@ func (ut *UDPConn) WriterReplaceable() bool {
return true return true
} }
func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *UDPConn { func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, matchRule adapter.Rule, matchOutbound adapter.Outbound) *UDPConn {
id, _ := uuid.NewV4() id, _ := uuid.NewV4()
var ( var (
chain []string chain []string
@ -221,12 +207,8 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.Inbound
outbound string outbound string
outboundType string outboundType string
) )
var action adapter.RuleAction if matchOutbound != nil {
if rule != nil { next = matchOutbound.Tag()
action = rule.Action()
}
if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction {
next = routeAction.Outbound
} else { } else {
next = outboundManager.Default().Tag() next = outboundManager.Default().Tag()
} }
@ -261,7 +243,7 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.Inbound
Upload: upload, Upload: upload,
Download: download, Download: download,
Chain: common.Reverse(chain), Chain: common.Reverse(chain),
Rule: rule, Rule: matchRule,
Outbound: outbound, Outbound: outbound,
OutboundType: outboundType, OutboundType: outboundType,
}, },

View file

@ -38,11 +38,7 @@ func (s *CommandServer) handleSetClashMode(conn net.Conn) error {
if service == nil { if service == nil {
return writeError(conn, E.New("service not ready")) return writeError(conn, E.New("service not ready"))
} }
clashServer := service.instance.Router().ClashServer() service.clashServer.(*clashapi.Server).SetMode(newMode)
if clashServer == nil {
return writeError(conn, E.New("Clash API disabled"))
}
clashServer.(*clashapi.Server).SetMode(newMode)
return writeError(conn, nil) return writeError(conn, nil)
} }
@ -69,18 +65,14 @@ func (s *CommandServer) handleModeConn(conn net.Conn) error {
return ctx.Err() return ctx.Err()
} }
} }
clashServer := s.service.instance.Router().ClashServer() err := writeClashModeList(conn, s.service.clashServer)
if clashServer == nil {
return binary.Write(conn, binary.BigEndian, uint16(0))
}
err := writeClashModeList(conn, clashServer)
if err != nil { if err != nil {
return err return err
} }
for { for {
select { select {
case <-s.modeUpdate: case <-s.modeUpdate:
err = varbin.Write(conn, binary.BigEndian, clashServer.Mode()) err = varbin.Write(conn, binary.BigEndian, s.service.clashServer.Mode())
if err != nil { if err != nil {
return err return err
} }

View file

@ -45,11 +45,7 @@ func (s *CommandServer) handleCloseConnection(conn net.Conn) error {
if service == nil { if service == nil {
return writeError(conn, E.New("service not ready")) return writeError(conn, E.New("service not ready"))
} }
clashServer := service.instance.Router().ClashServer() targetConn := service.clashServer.(*clashapi.Server).TrafficManager().Connection(uuid.FromStringOrNil(connId))
if clashServer == nil {
return writeError(conn, E.New("Clash API disabled"))
}
targetConn := clashServer.(*clashapi.Server).TrafficManager().Connection(uuid.FromStringOrNil(connId))
if targetConn == nil { if targetConn == nil {
return writeError(conn, E.New("connection already closed")) return writeError(conn, E.New("connection already closed"))
} }

View file

@ -49,11 +49,7 @@ func (s *CommandServer) handleConnectionsConn(conn net.Conn) error {
for { for {
service := s.service service := s.service
if service != nil { if service != nil {
clashServer := service.instance.Router().ClashServer() trafficManager = service.clashServer.(*clashapi.Server).TrafficManager()
if clashServer == nil {
return E.New("Clash API disabled")
}
trafficManager = clashServer.(*clashapi.Server).TrafficManager()
break break
} }
select { select {

View file

@ -60,7 +60,7 @@ func NewCommandServer(handler CommandServerHandler, maxLines int32) *CommandServ
func (s *CommandServer) SetService(newService *BoxService) { func (s *CommandServer) SetService(newService *BoxService) {
if newService != nil { if newService != nil {
service.PtrFromContext[urltest.HistoryStorage](newService.ctx).SetHook(s.urlTestUpdate) service.PtrFromContext[urltest.HistoryStorage](newService.ctx).SetHook(s.urlTestUpdate)
newService.instance.Router().ClashServer().(*clashapi.Server).SetModeUpdateHook(s.modeUpdate) newService.clashServer.(*clashapi.Server).SetModeUpdateHook(s.modeUpdate)
} }
s.service = newService s.service = newService
s.notifyURLTestUpdate() s.notifyURLTestUpdate()

View file

@ -31,14 +31,12 @@ func (s *CommandServer) readStatus() StatusMessage {
message.ConnectionsOut = int32(conntrack.Count()) message.ConnectionsOut = int32(conntrack.Count())
if s.service != nil { if s.service != nil {
if clashServer := s.service.instance.Router().ClashServer(); clashServer != nil {
message.TrafficAvailable = true message.TrafficAvailable = true
trafficManager := clashServer.(*clashapi.Server).TrafficManager() trafficManager := s.service.clashServer.(*clashapi.Server).TrafficManager()
message.Uplink, message.Downlink = trafficManager.Now() message.Uplink, message.Downlink = trafficManager.Now()
message.UplinkTotal, message.DownlinkTotal = trafficManager.Total() message.UplinkTotal, message.DownlinkTotal = trafficManager.Total()
message.ConnectionsIn = int32(trafficManager.ConnectionsLen()) message.ConnectionsIn = int32(trafficManager.ConnectionsLen())
} }
}
return message return message
} }

View file

@ -34,17 +34,18 @@ import (
type BoxService struct { type BoxService struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
instance *box.Box
pauseManager pause.Manager
urlTestHistoryStorage *urltest.HistoryStorage urlTestHistoryStorage *urltest.HistoryStorage
instance *box.Box
clashServer adapter.ClashServer
pauseManager pause.Manager
servicePauseFields servicePauseFields
} }
func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) { func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) {
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()) ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
ctx = service.ContextWith[deprecated.Manager](ctx, new(deprecatedManager))
ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID) ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID)
service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager))
options, err := parseConfig(ctx, configContent) options, err := parseConfig(ctx, configContent)
if err != nil { if err != nil {
return nil, err return nil, err
@ -54,7 +55,7 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box
urlTestHistoryStorage := urltest.NewHistoryStorage() urlTestHistoryStorage := urltest.NewHistoryStorage()
ctx = service.ContextWithPtr(ctx, urlTestHistoryStorage) ctx = service.ContextWithPtr(ctx, urlTestHistoryStorage)
platformWrapper := &platformInterfaceWrapper{iif: platformInterface, useProcFS: platformInterface.UseProcFS()} platformWrapper := &platformInterfaceWrapper{iif: platformInterface, useProcFS: platformInterface.UseProcFS()}
ctx = service.ContextWith[platform.Interface](ctx, platformWrapper) service.MustRegister[platform.Interface](ctx, platformWrapper)
instance, err := box.New(box.Options{ instance, err := box.New(box.Options{
Context: ctx, Context: ctx,
Options: options, Options: options,
@ -71,6 +72,7 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box
instance: instance, instance: instance,
urlTestHistoryStorage: urlTestHistoryStorage, urlTestHistoryStorage: urlTestHistoryStorage,
pauseManager: service.FromContext[pause.Manager](ctx), pauseManager: service.FromContext[pause.Manager](ctx),
clashServer: service.FromContext[adapter.ClashServer](ctx),
}, nil }, nil
} }

View file

@ -44,7 +44,14 @@ func NewServer(logger log.Logger, options option.V2RayAPIOptions) (adapter.V2Ray
return server, nil return server, nil
} }
func (s *Server) Start() error { func (s *Server) Name() string {
return "v2ray server"
}
func (s *Server) Start(stage adapter.StartStage) error {
if stage != adapter.StartStatePostStart {
return nil
}
listener, err := net.Listen("tcp", s.listen) listener, err := net.Listen("tcp", s.listen)
if err != nil { if err != nil {
return err return err
@ -70,6 +77,6 @@ func (s *Server) Close() error {
) )
} }
func (s *Server) StatsService() adapter.V2RayStatsService { func (s *Server) StatsService() adapter.ConnectionTracker {
return s.statsService return s.statsService
} }

View file

@ -22,7 +22,7 @@ func init() {
} }
var ( var (
_ adapter.V2RayStatsService = (*StatsService)(nil) _ adapter.ConnectionTracker = (*StatsService)(nil)
_ StatsServiceServer = (*StatsService)(nil) _ StatsServiceServer = (*StatsService)(nil)
) )
@ -60,7 +60,10 @@ func NewStatsService(options option.V2RayStatsServiceOptions) *StatsService {
} }
} }
func (s *StatsService) RoutedConnection(inbound string, outbound string, user string, conn net.Conn) net.Conn { func (s *StatsService) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn {
inbound := metadata.Inbound
user := metadata.User
outbound := matchOutbound.Tag()
var readCounter []*atomic.Int64 var readCounter []*atomic.Int64
var writeCounter []*atomic.Int64 var writeCounter []*atomic.Int64
countInbound := inbound != "" && s.inbounds[inbound] countInbound := inbound != "" && s.inbounds[inbound]
@ -86,7 +89,10 @@ func (s *StatsService) RoutedConnection(inbound string, outbound string, user st
return bufio.NewInt64CounterConn(conn, readCounter, writeCounter) return bufio.NewInt64CounterConn(conn, readCounter, writeCounter)
} }
func (s *StatsService) RoutedPacketConnection(inbound string, outbound string, user string, conn N.PacketConn) N.PacketConn { func (s *StatsService) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn {
inbound := metadata.Inbound
user := metadata.User
outbound := matchOutbound.Tag()
var readCounter []*atomic.Int64 var readCounter []*atomic.Int64
var writeCounter []*atomic.Int64 var writeCounter []*atomic.Int64
countInbound := inbound != "" && s.inbounds[inbound] countInbound := inbound != "" && s.inbounds[inbound]

View file

@ -76,18 +76,7 @@ func (s *URLTest) Start() error {
} }
outbounds = append(outbounds, detour) outbounds = append(outbounds, detour)
} }
group, err := NewURLTestGroup( group, err := NewURLTestGroup(s.ctx, s.outboundManager, s.logger, outbounds, s.link, s.interval, s.tolerance, s.idleTimeout, s.interruptExternalConnections)
s.ctx,
s.router,
s.outboundManager,
s.logger,
outbounds,
s.link,
s.interval,
s.tolerance,
s.idleTimeout,
s.interruptExternalConnections,
)
if err != nil { if err != nil {
return err return err
} }
@ -215,18 +204,7 @@ type URLTestGroup struct {
lastActive atomic.TypedValue[time.Time] lastActive atomic.TypedValue[time.Time]
} }
func NewURLTestGroup( func NewURLTestGroup(ctx context.Context, outboundManager adapter.OutboundManager, logger log.Logger, outbounds []adapter.Outbound, link string, interval time.Duration, tolerance uint16, idleTimeout time.Duration, interruptExternalConnections bool) (*URLTestGroup, error) {
ctx context.Context,
router adapter.Router,
outboundManager adapter.OutboundManager,
logger log.Logger,
outbounds []adapter.Outbound,
link string,
interval time.Duration,
tolerance uint16,
idleTimeout time.Duration,
interruptExternalConnections bool,
) (*URLTestGroup, error) {
if interval == 0 { if interval == 0 {
interval = C.DefaultURLTestInterval interval = C.DefaultURLTestInterval
} }
@ -241,14 +219,13 @@ func NewURLTestGroup(
} }
var history *urltest.HistoryStorage var history *urltest.HistoryStorage
if history = service.PtrFromContext[urltest.HistoryStorage](ctx); history != nil { if history = service.PtrFromContext[urltest.HistoryStorage](ctx); history != nil {
} else if clashServer := router.ClashServer(); clashServer != nil { } else if clashServer := service.FromContext[adapter.ClashServer](ctx); clashServer != nil {
history = clashServer.HistoryStorage() history = clashServer.HistoryStorage()
} else { } else {
history = urltest.NewHistoryStorage() history = urltest.NewHistoryStorage()
} }
return &URLTestGroup{ return &URLTestGroup{
ctx: ctx, ctx: ctx,
router: router,
outboundManager: outboundManager, outboundManager: outboundManager,
logger: logger, logger: logger,
outbounds: outbounds, outbounds: outbounds,

View file

@ -91,16 +91,12 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if err != nil { if err != nil {
return err return err
} }
var ( var selectedOutbound adapter.Outbound
// selectedOutbound adapter.Outbound
selectedDialer N.Dialer
selectedTag string
selectedDescription string
)
if selectedRule != nil { if selectedRule != nil {
switch action := selectedRule.Action().(type) { switch action := selectedRule.Action().(type) {
case *rule.RuleActionRoute: case *rule.RuleActionRoute:
selectedOutbound, loaded := r.outboundManager.Outbound(action.Outbound) var loaded bool
selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound)
if !loaded { if !loaded {
buf.ReleaseMulti(buffers) buf.ReleaseMulti(buffers)
return E.New("outbound not found: ", action.Outbound) return E.New("outbound not found: ", action.Outbound)
@ -109,12 +105,6 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
buf.ReleaseMulti(buffers) buf.ReleaseMulti(buffers)
return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag()) return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag())
} }
selectedDialer = selectedOutbound
selectedTag = selectedOutbound.Tag()
selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
case *rule.RuleActionDirect:
selectedDialer = action.Dialer
selectedDescription = action.String()
case *rule.RuleActionReject: case *rule.RuleActionReject:
buf.ReleaseMulti(buffers) buf.ReleaseMulti(buffers)
N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx)) N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx))
@ -133,25 +123,16 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
buf.ReleaseMulti(buffers) buf.ReleaseMulti(buffers)
return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag()) return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag())
} }
selectedDialer = defaultOutbound selectedOutbound = defaultOutbound
selectedTag = defaultOutbound.Tag()
selectedDescription = F.ToString("outbound/", defaultOutbound.Type(), "[", defaultOutbound.Tag(), "]")
} }
for _, buffer := range buffers { for _, buffer := range buffers {
conn = bufio.NewCachedConn(conn, buffer) conn = bufio.NewCachedConn(conn, buffer)
} }
if r.clashServer != nil { if r.tracker != nil {
trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, selectedRule) conn = r.tracker.RoutedConnection(ctx, conn, metadata, selectedRule, selectedOutbound)
defer tracker.Leave()
conn = trackerConn
} }
if r.v2rayServer != nil { legacyOutbound, isLegacy := selectedOutbound.(adapter.ConnectionHandler)
if statsService := r.v2rayServer.StatsService(); statsService != nil {
conn = statsService.RoutedConnection(metadata.Inbound, selectedTag, metadata.User, conn)
}
}
legacyOutbound, isLegacy := selectedDialer.(adapter.ConnectionHandler)
if isLegacy { if isLegacy {
err = legacyOutbound.NewConnection(ctx, conn, metadata) err = legacyOutbound.NewConnection(ctx, conn, metadata)
if err != nil { if err != nil {
@ -159,7 +140,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if onClose != nil { if onClose != nil {
onClose(err) onClose(err)
} }
return E.Cause(err, selectedDescription) return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
} else { } else {
if onClose != nil { if onClose != nil {
onClose(nil) onClose(nil)
@ -168,13 +149,13 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
return nil return nil
} }
// TODO // TODO
err = outbound.NewConnection(ctx, selectedDialer, conn, metadata) err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata)
if err != nil { if err != nil {
conn.Close() conn.Close()
if onClose != nil { if onClose != nil {
onClose(err) onClose(err)
} }
return E.Cause(err, selectedDescription) return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
} else { } else {
if onClose != nil { if onClose != nil {
onClose(nil) onClose(nil)
@ -246,16 +227,13 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
if err != nil { if err != nil {
return err return err
} }
var ( var selectedOutbound adapter.Outbound
selectedDialer N.Dialer
selectedTag string
selectedDescription string
)
var selectReturn bool var selectReturn bool
if selectedRule != nil { if selectedRule != nil {
switch action := selectedRule.Action().(type) { switch action := selectedRule.Action().(type) {
case *rule.RuleActionRoute: case *rule.RuleActionRoute:
selectedOutbound, loaded := r.outboundManager.Outbound(action.Outbound) var loaded bool
selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound)
if !loaded { if !loaded {
N.ReleaseMultiPacketBuffer(packetBuffers) N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("outbound not found: ", action.Outbound) return E.New("outbound not found: ", action.Outbound)
@ -264,12 +242,6 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
N.ReleaseMultiPacketBuffer(packetBuffers) N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag()) return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
} }
selectedDialer = selectedOutbound
selectedTag = selectedOutbound.Tag()
selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
case *rule.RuleActionDirect:
selectedDialer = action.Dialer
selectedDescription = action.String()
case *rule.RuleActionReject: case *rule.RuleActionReject:
N.ReleaseMultiPacketBuffer(packetBuffers) N.ReleaseMultiPacketBuffer(packetBuffers)
N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx)) N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx))
@ -285,41 +257,32 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
N.ReleaseMultiPacketBuffer(packetBuffers) N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag()) return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag())
} }
selectedDialer = defaultOutbound selectedOutbound = defaultOutbound
selectedTag = defaultOutbound.Tag()
selectedDescription = F.ToString("outbound/", defaultOutbound.Type(), "[", defaultOutbound.Tag(), "]")
} }
for _, buffer := range packetBuffers { for _, buffer := range packetBuffers {
conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination) conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination)
N.PutPacketBuffer(buffer) N.PutPacketBuffer(buffer)
} }
if r.clashServer != nil { if r.tracker != nil {
trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, selectedRule) conn = r.tracker.RoutedPacketConnection(ctx, conn, metadata, selectedRule, selectedOutbound)
defer tracker.Leave()
conn = trackerConn
}
if r.v2rayServer != nil {
if statsService := r.v2rayServer.StatsService(); statsService != nil {
conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedTag, metadata.User, conn)
}
} }
if metadata.FakeIP { if metadata.FakeIP {
conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination) conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
} }
legacyOutbound, isLegacy := selectedDialer.(adapter.PacketConnectionHandler) legacyOutbound, isLegacy := selectedOutbound.(adapter.PacketConnectionHandler)
if isLegacy { if isLegacy {
err = legacyOutbound.NewPacketConnection(ctx, conn, metadata) err = legacyOutbound.NewPacketConnection(ctx, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
return E.Cause(err, selectedDescription) return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
} }
return nil return nil
} }
// TODO // TODO
err = outbound.NewPacketConnection(ctx, selectedDialer, conn, metadata) err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
return E.Cause(err, selectedDescription) return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
} }
return nil return nil
} }

View file

@ -27,7 +27,6 @@ import (
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause" "github.com/sagernet/sing/service/pause"
@ -63,16 +62,14 @@ type Router struct {
dnsReverseMapping *DNSReverseMapping dnsReverseMapping *DNSReverseMapping
fakeIPStore adapter.FakeIPStore fakeIPStore adapter.FakeIPStore
processSearcher process.Searcher processSearcher process.Searcher
timeService *ntp.Service
pauseManager pause.Manager pauseManager pause.Manager
clashServer adapter.ClashServer tracker adapter.ConnectionTracker
v2rayServer adapter.V2RayServer
platformInterface platform.Interface platformInterface platform.Interface
needWIFIState bool needWIFIState bool
started bool started bool
} }
func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions, ntpOptions option.NTPOptions) (*Router, error) { func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) (*Router, error) {
router := &Router{ router := &Router{
ctx: ctx, ctx: ctx,
logger: logFactory.NewLogger("router"), logger: logFactory.NewLogger("router"),
@ -94,7 +91,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
platformInterface: service.FromContext[platform.Interface](ctx), platformInterface: service.FromContext[platform.Interface](ctx),
needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule), needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule),
} }
ctx = service.ContextWith[adapter.Router](ctx, router) service.MustRegister[adapter.Router](ctx, router)
router.dnsClient = dns.NewClient(dns.ClientOptions{ router.dnsClient = dns.NewClient(dns.ClientOptions{
DisableCache: dnsOptions.DNSClientOptions.DisableCache, DisableCache: dnsOptions.DNSClientOptions.DisableCache,
DisableExpire: dnsOptions.DNSClientOptions.DisableExpire, DisableExpire: dnsOptions.DNSClientOptions.DisableExpire,
@ -290,23 +287,6 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
} }
router.fakeIPStore = fakeip.NewStore(ctx, router.logger, inet4Range, inet6Range) router.fakeIPStore = fakeip.NewStore(ctx, router.logger, inet4Range, inet6Range)
} }
if ntpOptions.Enabled {
ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions)
if err != nil {
return nil, E.Cause(err, "create NTP service")
}
timeService := ntp.NewService(ntp.Options{
Context: ctx,
Dialer: ntpDialer,
Logger: logFactory.NewLogger("ntp"),
Server: ntpOptions.ServerOptions.Build(),
Interval: time.Duration(ntpOptions.Interval),
WriteToSystem: ntpOptions.WriteToSystem,
})
service.MustRegister[ntp.TimeService](ctx, timeService)
router.timeService = timeService
}
return router, nil return router, nil
} }
@ -380,14 +360,6 @@ func (r *Router) Start(stage adapter.StartStage) error {
return E.Cause(err, "initialize DNS server[", i, "]") return E.Cause(err, "initialize DNS server[", i, "]")
} }
} }
if r.timeService != nil {
monitor.Start("initialize time service")
err := r.timeService.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize time service")
}
}
case adapter.StartStatePostStart: case adapter.StartStatePostStart:
var cacheContext *adapter.HTTPStartContext var cacheContext *adapter.HTTPStartContext
if len(r.ruleSets) > 0 { if len(r.ruleSets) > 0 {
@ -502,13 +474,6 @@ func (r *Router) Close() error {
}) })
monitor.Finish() monitor.Finish()
} }
if r.timeService != nil {
monitor.Start("close time service")
err = E.Append(err, r.timeService.Close(), func(err error) error {
return E.Cause(err, "close time service")
})
monitor.Finish()
}
if r.fakeIPStore != nil { if r.fakeIPStore != nil {
monitor.Start("close fakeip store") monitor.Start("close fakeip store")
err = E.Append(err, r.fakeIPStore.Close(), func(err error) error { err = E.Append(err, r.fakeIPStore.Close(), func(err error) error {
@ -536,29 +501,8 @@ func (r *Router) Rules() []adapter.Rule {
return r.rules return r.rules
} }
func (r *Router) ClashServer() adapter.ClashServer { func (r *Router) SetTracker(tracker adapter.ConnectionTracker) {
return r.clashServer r.tracker = tracker
}
func (r *Router) SetClashServer(server adapter.ClashServer) {
r.clashServer = server
}
func (r *Router) V2RayServer() adapter.V2RayServer {
return r.v2rayServer
}
func (r *Router) SetV2RayServer(server adapter.V2RayServer) {
r.v2rayServer = server
}
func (r *Router) NewError(ctx context.Context, err error) {
common.Close(err)
if E.IsClosedOrCanceled(err) {
r.logger.DebugContext(ctx, "connection closed: ", err)
return
}
r.logger.ErrorContext(ctx, err)
} }
func (r *Router) ResetNetwork() { func (r *Router) ResetNetwork() {

View file

@ -219,7 +219,7 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }
if options.ClashMode != "" { if options.ClashMode != "" {
item := NewClashModeItem(router, options.ClashMode) item := NewClashModeItem(ctx, options.ClashMode)
rule.items = append(rule.items, item) rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }

View file

@ -216,7 +216,7 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }
if options.ClashMode != "" { if options.ClashMode != "" {
item := NewClashModeItem(router, options.ClashMode) item := NewClashModeItem(ctx, options.ClashMode)
rule.items = append(rule.items, item) rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }

View file

@ -1,31 +1,38 @@
package rule package rule
import ( import (
"context"
"strings" "strings"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/service"
) )
var _ RuleItem = (*ClashModeItem)(nil) var _ RuleItem = (*ClashModeItem)(nil)
type ClashModeItem struct { type ClashModeItem struct {
router adapter.Router ctx context.Context
clashServer adapter.ClashServer
mode string mode string
} }
func NewClashModeItem(router adapter.Router, mode string) *ClashModeItem { func NewClashModeItem(ctx context.Context, mode string) *ClashModeItem {
return &ClashModeItem{ return &ClashModeItem{
router: router, ctx: ctx,
mode: mode, mode: mode,
} }
} }
func (r *ClashModeItem) Start() error {
r.clashServer = service.FromContext[adapter.ClashServer](r.ctx)
return nil
}
func (r *ClashModeItem) Match(metadata *adapter.InboundContext) bool { func (r *ClashModeItem) Match(metadata *adapter.InboundContext) bool {
clashServer := r.router.ClashServer() if r.clashServer == nil {
if clashServer == nil {
return false return false
} }
return strings.EqualFold(clashServer.Mode(), r.mode) return strings.EqualFold(r.clashServer.Mode(), r.mode)
} }
func (r *ClashModeItem) String() string { func (r *ClashModeItem) String() string {