Add high-level boundary management interface

I rewrote the router's boundary management part to implement dynamic management
from a high-level box interface. This also includes a number of changes I made
in the process of rewriting some messy parts, such as the Outbound tree
bottom-top starter.
This commit is contained in:
beryll1um 2024-10-14 02:15:29 +02:00 committed by Vitalii
parent 0b42f26af8
commit 99840ba0f4
13 changed files with 503 additions and 352 deletions

View file

@ -18,14 +18,28 @@ import (
) )
type Router interface { type Router interface {
Service AddOutbound(outbound Outbound) error
AddInbound(inbound Inbound) error
RemoveOutbound(tag string) error
RemoveInbound(tag string) error
PreStarter PreStarter
StartOutbounds() error
Service
StartInbounds() error
PostStarter PostStarter
Cleanup() error Cleanup() error
DefaultOutbound(network string) (Outbound, error)
Outbounds() []Outbound Outbounds() []Outbound
Outbound(tag string) (Outbound, bool) Outbound(tag string) (Outbound, bool)
DefaultOutbound(network string) (Outbound, error) Inbound(tag string) (Inbound, bool)
FakeIPStore() FakeIPStore FakeIPStore() FakeIPStore

225
box.go
View file

@ -29,16 +29,16 @@ import (
var _ adapter.Service = (*Box)(nil) var _ adapter.Service = (*Box)(nil)
type Box struct { type Box struct {
createdAt time.Time createdAt time.Time
router adapter.Router router adapter.Router
inbounds []adapter.Inbound logFactory log.Factory
outbounds []adapter.Outbound logger log.ContextLogger
logFactory log.Factory preServices1 map[string]adapter.Service
logger log.ContextLogger preServices2 map[string]adapter.Service
preServices1 map[string]adapter.Service postServices map[string]adapter.Service
preServices2 map[string]adapter.Service platformInterface platform.Interface
postServices map[string]adapter.Service ctx context.Context
done chan struct{} done chan struct{}
} }
type Options struct { type Options struct {
@ -97,57 +97,6 @@ func New(options Options) (*Box, error) {
if err != nil { if err != nil {
return nil, E.Cause(err, "parse route options") return nil, E.Cause(err, "parse route options")
} }
inbounds := make([]adapter.Inbound, 0, len(options.Inbounds))
outbounds := make([]adapter.Outbound, 0, len(options.Outbounds))
for i, inboundOptions := range options.Inbounds {
var in adapter.Inbound
var tag string
if inboundOptions.Tag != "" {
tag = inboundOptions.Tag
} else {
tag = F.ToString(i)
}
in, err = inbound.New(
ctx,
router,
logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")),
tag,
inboundOptions,
options.PlatformInterface,
)
if err != nil {
return nil, E.Cause(err, "parse inbound[", i, "]")
}
inbounds = append(inbounds, in)
}
for i, outboundOptions := range options.Outbounds {
var out adapter.Outbound
var tag string
if outboundOptions.Tag != "" {
tag = outboundOptions.Tag
} else {
tag = F.ToString(i)
}
out, err = outbound.New(
ctx,
router,
logFactory.NewLogger(F.ToString("outbound/", outboundOptions.Type, "[", tag, "]")),
tag,
outboundOptions)
if err != nil {
return nil, E.Cause(err, "parse outbound[", i, "]")
}
outbounds = append(outbounds, out)
}
err = router.Initialize(inbounds, outbounds, func() adapter.Outbound {
out, oErr := outbound.New(ctx, router, logFactory.NewLogger("outbound/direct"), "direct", option.Outbound{Type: "direct", Tag: "default"})
common.Must(oErr)
outbounds = append(outbounds, out)
return out
})
if err != nil {
return nil, err
}
if options.PlatformInterface != nil { if options.PlatformInterface != nil {
err = options.PlatformInterface.Initialize(ctx, router) err = options.PlatformInterface.Initialize(ctx, router)
if err != nil { if err != nil {
@ -183,18 +132,35 @@ func New(options Options) (*Box, error) {
router.SetV2RayServer(v2rayServer) router.SetV2RayServer(v2rayServer)
preServices2["v2ray api"] = v2rayServer preServices2["v2ray api"] = v2rayServer
} }
return &Box{ box := &Box{
router: router, router: router,
inbounds: inbounds, createdAt: createdAt,
outbounds: outbounds, logFactory: logFactory,
createdAt: createdAt, logger: logFactory.Logger(),
logFactory: logFactory, preServices1: preServices1,
logger: logFactory.Logger(), preServices2: preServices2,
preServices1: preServices1, postServices: postServices,
preServices2: preServices2, platformInterface: options.PlatformInterface,
postServices: postServices, ctx: ctx,
done: make(chan struct{}), done: make(chan struct{}),
}, nil }
for i, outOpts := range options.Outbounds {
if outOpts.Tag == "" {
outOpts.Tag = F.ToString(i)
}
if err := box.AddOutbound(outOpts); err != nil {
return nil, E.Cause(err, "create outbound")
}
}
for i, inOpts := range options.Inbounds {
if inOpts.Tag == "" {
inOpts.Tag = F.ToString(i)
}
if err := box.AddInbound(inOpts); err != nil {
return nil, E.Cause(err, "create inbound")
}
}
return box, nil
} }
func (s *Box) PreStart() error { func (s *Box) PreStart() error {
@ -263,12 +229,10 @@ func (s *Box) preStart() error {
} }
} }
} }
err = s.router.PreStart() if err := s.router.PreStart(); err != nil {
if err != nil {
return E.Cause(err, "pre-start router") return E.Cause(err, "pre-start router")
} }
err = s.startOutbounds() if err := s.router.StartOutbounds(); err != nil {
if err != nil {
return err return err
} }
return s.router.Start() return s.router.Start()
@ -291,20 +255,10 @@ func (s *Box) start() error {
return E.Cause(err, "start ", serviceName) return E.Cause(err, "start ", serviceName)
} }
} }
for i, in := range s.inbounds { if err := s.router.StartInbounds(); err != nil {
var tag string return E.Cause(err, "start inbounds")
if in.Tag() == "" {
tag = F.ToString(i)
} else {
tag = in.Tag()
}
err = in.Start()
if err != nil {
return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]")
}
} }
err = s.postStart() if err = s.postStart(); err != nil {
if err != nil {
return err return err
} }
return s.router.Cleanup() return s.router.Cleanup()
@ -317,26 +271,8 @@ func (s *Box) postStart() error {
return E.Cause(err, "start ", serviceName) return E.Cause(err, "start ", serviceName)
} }
} }
// TODO: reorganize ALL start order if err := s.router.PostStart(); err != nil {
for _, out := range s.outbounds { return E.Cause(err, "post-start")
if lateOutbound, isLateOutbound := out.(adapter.PostStarter); isLateOutbound {
err := lateOutbound.PostStart()
if err != nil {
return E.Cause(err, "post-start outbound/", out.Tag())
}
}
}
err := s.router.PostStart()
if err != nil {
return err
}
for _, in := range s.inbounds {
if lateInbound, isLateInbound := in.(adapter.PostStarter); isLateInbound {
err = lateInbound.PostStart()
if err != nil {
return E.Cause(err, "post-start inbound/", in.Tag())
}
}
} }
return nil return nil
} }
@ -357,20 +293,6 @@ func (s *Box) Close() error {
}) })
monitor.Finish() monitor.Finish()
} }
for i, in := range s.inbounds {
monitor.Start("close inbound/", in.Type(), "[", i, "]")
errors = E.Append(errors, in.Close(), func(err error) error {
return E.Cause(err, "close inbound/", in.Type(), "[", i, "]")
})
monitor.Finish()
}
for i, out := range s.outbounds {
monitor.Start("close outbound/", out.Type(), "[", i, "]")
errors = E.Append(errors, common.Close(out), func(err error) error {
return E.Cause(err, "close outbound/", out.Type(), "[", i, "]")
})
monitor.Finish()
}
monitor.Start("close router") monitor.Start("close router")
if err := common.Close(s.router); err != nil { if err := common.Close(s.router); err != nil {
errors = E.Append(errors, err, func(err error) error { errors = E.Append(errors, err, func(err error) error {
@ -403,3 +325,58 @@ func (s *Box) Close() error {
func (s *Box) Router() adapter.Router { func (s *Box) Router() adapter.Router {
return s.router return s.router
} }
func (s *Box) AddOutbound(option option.Outbound) error {
if option.Tag == "" {
return E.New("empty tag")
}
out, err := outbound.New(
s.ctx,
s.router,
s.logFactory.NewLogger(F.ToString("outbound/", option.Type, "[", option.Tag, "]")),
option.Tag,
option,
)
if err != nil {
return E.Cause(err, "parse addited outbound")
}
if err := s.router.AddOutbound(out); err != nil {
return E.Cause(err, "outbound/", option.Type, "[", option.Tag, "]")
}
return nil
}
func (s *Box) AddInbound(option option.Inbound) error {
if option.Tag == "" {
return E.New("empty tag")
}
in, err := inbound.New(
s.ctx,
s.router,
s.logFactory.NewLogger(F.ToString("inbound/", option.Type, "[", option.Tag, "]")),
option.Tag,
option,
s.platformInterface,
)
if err != nil {
return E.Cause(err, "parse addited inbound")
}
if err := s.router.AddInbound(in); err != nil {
return E.Cause(err, "inbound/", option.Type, "[", option.Tag, "]")
}
return nil
}
func (s *Box) RemoveOutbound(tag string) error {
if err := s.router.RemoveOutbound(tag); err != nil {
return E.Cause(err, "outbound[", tag, "]")
}
return nil
}
func (s *Box) RemoveInbound(tag string) error {
if err := s.router.RemoveInbound(tag); err != nil {
return E.Cause(err, "inbound[", tag, "]")
}
return nil
}

View file

@ -1,85 +0,0 @@
package box
import (
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
func (s *Box) startOutbounds() error {
monitor := taskmonitor.New(s.logger, C.StartTimeout)
outboundTags := make(map[adapter.Outbound]string)
outbounds := make(map[string]adapter.Outbound)
for i, outboundToStart := range s.outbounds {
var outboundTag string
if outboundToStart.Tag() == "" {
outboundTag = F.ToString(i)
} else {
outboundTag = outboundToStart.Tag()
}
if _, exists := outbounds[outboundTag]; exists {
return E.New("outbound tag ", outboundTag, " duplicated")
}
outboundTags[outboundToStart] = outboundTag
outbounds[outboundTag] = outboundToStart
}
started := make(map[string]bool)
for {
canContinue := false
startOne:
for _, outboundToStart := range s.outbounds {
outboundTag := outboundTags[outboundToStart]
if started[outboundTag] {
continue
}
dependencies := outboundToStart.Dependencies()
for _, dependency := range dependencies {
if !started[dependency] {
continue startOne
}
}
started[outboundTag] = true
canContinue = true
if starter, isStarter := outboundToStart.(interface {
Start() error
}); isStarter {
monitor.Start("initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]")
err := starter.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]")
}
}
}
if len(started) == len(s.outbounds) {
break
}
if canContinue {
continue
}
currentOutbound := common.Find(s.outbounds, func(it adapter.Outbound) bool {
return !started[outboundTags[it]]
})
var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
return !started[it]
})
if common.Contains(oTree, problemOutboundTag) {
return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
}
problemOutbound := outbounds[problemOutboundTag]
if problemOutbound == nil {
return E.New("dependency[", problemOutboundTag, "] not found for outbound[", outboundTags[oCurrent], "]")
}
return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
}
return lintOutbound([]string{outboundTags[currentOutbound]}, currentOutbound)
}
return nil
}

View file

@ -12,7 +12,7 @@ import (
) )
type Searcher interface { type Searcher interface {
FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*Info, error)
} }
var ErrNotFound = E.New("process not found") var ErrNotFound = E.New("process not found")
@ -29,8 +29,8 @@ type Info struct {
UserId int32 UserId int32
} }
func FindProcessInfo(searcher Searcher, ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) { func FindProcessInfo(searcher Searcher, ctx context.Context, network string, source netip.AddrPort) (*Info, error) {
info, err := searcher.FindProcessInfo(ctx, network, source, destination) info, err := searcher.FindProcessInfo(ctx, network, source)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -19,8 +19,8 @@ func NewSearcher(config Config) (Searcher, error) {
return &linuxSearcher{config.Logger}, nil return &linuxSearcher{config.Logger}, nil
} }
func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) { func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*Info, error) {
inode, uid, err := resolveSocketByNetlink(network, source, destination) inode, uid, err := resolveSocketByNetlink(network, source)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -36,7 +36,7 @@ const (
pathProc = "/proc" pathProc = "/proc"
) )
func resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) { func resolveSocketByNetlink(network string, source netip.AddrPort) (inode, uid uint32, err error) {
var family uint8 var family uint8
var protocol uint8 var protocol uint8

View file

@ -94,7 +94,7 @@ func (s *platformInterfaceStub) ReadWIFIState() adapter.WIFIState {
return adapter.WIFIState{} return adapter.WIFIState{}
} }
func (s *platformInterfaceStub) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*process.Info, error) { func (s *platformInterfaceStub) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*process.Info, error) {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }

View file

@ -30,7 +30,7 @@ func init() {
} }
} }
func ResolveSocketByProcSearch(network string, source, _ netip.AddrPort) int32 { func ResolveSocketByProcSearch(network string, source netip.AddrPort) int32 {
if netIndexOfLocal < 0 || netIndexOfUid < 0 { if netIndexOfLocal < 0 || netIndexOfUid < 0 {
return -1 return -1
} }

View file

@ -10,7 +10,7 @@ type PlatformInterface interface {
OpenTun(options TunOptions) (int32, error) OpenTun(options TunOptions) (int32, error)
WriteLog(message string) WriteLog(message string)
UseProcFS() bool UseProcFS() bool
FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32, destinationAddress string, destinationPort int32) (int32, error) FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32) (int32, error)
PackageNameByUid(uid int32) (string, error) PackageNameByUid(uid int32) (string, error)
UIDByPackageName(packageName string) (int32, error) UIDByPackageName(packageName string) (int32, error)
UsePlatformDefaultInterfaceMonitor() bool UsePlatformDefaultInterfaceMonitor() bool

View file

@ -203,10 +203,10 @@ func (w *platformInterfaceWrapper) ReadWIFIState() adapter.WIFIState {
return (adapter.WIFIState)(*wifiState) return (adapter.WIFIState)(*wifiState)
} }
func (w *platformInterfaceWrapper) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*process.Info, error) { func (w *platformInterfaceWrapper) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*process.Info, error) {
var uid int32 var uid int32
if w.useProcFS { if w.useProcFS {
uid = procfs.ResolveSocketByProcSearch(network, source, destination) uid = procfs.ResolveSocketByProcSearch(network, source)
if uid == -1 { if uid == -1 {
return nil, E.New("procfs: not found") return nil, E.New("procfs: not found")
} }
@ -221,7 +221,7 @@ func (w *platformInterfaceWrapper) FindProcessInfo(ctx context.Context, network
return nil, E.New("unknown network: ", network) return nil, E.New("unknown network: ", network)
} }
var err error var err error
uid, err = w.iif.FindConnectionOwner(ipProtocol, source.Addr().String(), int32(source.Port()), destination.Addr().String(), int32(destination.Port())) uid, err = w.iif.FindConnectionOwner(ipProtocol, source.Addr().String(), int32(source.Port()))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -115,6 +115,12 @@ func (a *myInboundAdapter) Start() error {
func (a *myInboundAdapter) Close() error { func (a *myInboundAdapter) Close() error {
a.inShutdown.Store(true) a.inShutdown.Store(true)
if a.tcpListener != nil {
a.logger.Info("tcp server closed at ", a.tcpListener.Addr())
}
if a.udpConn != nil {
a.logger.Info("udp server closed at ", a.udpConn.LocalAddr())
}
var err error var err error
if a.systemProxy != nil && a.systemProxy.IsEnabled() { if a.systemProxy != nil && a.systemProxy.IsEnabled() {
err = a.systemProxy.Disable() err = a.systemProxy.Disable()

View file

@ -10,6 +10,7 @@ import (
"os/user" "os/user"
"runtime" "runtime"
"strings" "strings"
"sync"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
@ -50,11 +51,15 @@ import (
var _ adapter.Router = (*Router)(nil) var _ adapter.Router = (*Router)(nil)
type Router struct { type Router struct {
ctx context.Context ctx context.Context
logger log.ContextLogger logger log.ContextLogger
dnsLogger log.ContextLogger dnsLogger log.ContextLogger
// Currently this is responsible for protecting inbound and outbound dynamic
// control. I'm not sure if it can be separated because I haven't delved
// into the logic yet to make sure they don't interfere with each other.
// To research, may improve performance on some high-load setups.
boundary sync.RWMutex
inboundByTag map[string]adapter.Inbound inboundByTag map[string]adapter.Inbound
outbounds []adapter.Outbound
outboundByTag map[string]adapter.Outbound outboundByTag map[string]adapter.Outbound
rules []adapter.Rule rules []adapter.Rule
defaultDetour string defaultDetour string
@ -113,6 +118,7 @@ func NewRouter(
ctx: ctx, ctx: ctx,
logger: logFactory.NewLogger("router"), logger: logFactory.NewLogger("router"),
dnsLogger: logFactory.NewLogger("dns"), dnsLogger: logFactory.NewLogger("dns"),
inboundByTag: make(map[string]adapter.Inbound),
outboundByTag: make(map[string]adapter.Outbound), outboundByTag: make(map[string]adapter.Outbound),
rules: make([]adapter.Rule, 0, len(options.Rules)), rules: make([]adapter.Rule, 0, len(options.Rules)),
dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
@ -373,76 +379,6 @@ func NewRouter(
return router, nil return router, nil
} }
func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error {
inboundByTag := make(map[string]adapter.Inbound)
for _, inbound := range inbounds {
inboundByTag[inbound.Tag()] = inbound
}
outboundByTag := make(map[string]adapter.Outbound)
for _, detour := range outbounds {
outboundByTag[detour.Tag()] = detour
}
var defaultOutboundForConnection adapter.Outbound
var defaultOutboundForPacketConnection adapter.Outbound
if r.defaultDetour != "" {
detour, loaded := outboundByTag[r.defaultDetour]
if !loaded {
return E.New("default detour not found: ", r.defaultDetour)
}
if common.Contains(detour.Network(), N.NetworkTCP) {
defaultOutboundForConnection = detour
}
if common.Contains(detour.Network(), N.NetworkUDP) {
defaultOutboundForPacketConnection = detour
}
}
if defaultOutboundForConnection == nil {
for _, detour := range outbounds {
if common.Contains(detour.Network(), N.NetworkTCP) {
defaultOutboundForConnection = detour
break
}
}
}
if defaultOutboundForPacketConnection == nil {
for _, detour := range outbounds {
if common.Contains(detour.Network(), N.NetworkUDP) {
defaultOutboundForPacketConnection = detour
break
}
}
}
if defaultOutboundForConnection == nil || defaultOutboundForPacketConnection == nil {
detour := defaultOutbound()
if defaultOutboundForConnection == nil {
defaultOutboundForConnection = detour
}
if defaultOutboundForPacketConnection == nil {
defaultOutboundForPacketConnection = detour
}
outbounds = append(outbounds, detour)
outboundByTag[detour.Tag()] = detour
}
r.inboundByTag = inboundByTag
r.outbounds = outbounds
r.defaultOutboundForConnection = defaultOutboundForConnection
r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection
r.outboundByTag = outboundByTag
for i, rule := range r.rules {
if _, loaded := outboundByTag[rule.Outbound()]; !loaded {
return E.New("outbound not found for rule[", i, "]: ", rule.Outbound())
}
}
return nil
}
func (r *Router) Outbounds() []adapter.Outbound {
if !r.started {
return nil
}
return r.outbounds
}
func (r *Router) PreStart() error { func (r *Router) PreStart() error {
monitor := taskmonitor.New(r.logger, C.StartTimeout) monitor := taskmonitor.New(r.logger, C.StartTimeout)
if r.interfaceMonitor != nil { if r.interfaceMonitor != nil {
@ -581,9 +517,191 @@ func (r *Router) Start() error {
return nil return nil
} }
func (r *Router) Cleanup() error {
for _, ruleSet := range r.ruleSetMap {
ruleSet.Cleanup()
}
runtime.GC()
return nil
}
func (r *Router) AddOutbound(out adapter.Outbound) error {
r.boundary.Lock()
defer r.boundary.Unlock()
if _, ok := r.outboundByTag[out.Tag()]; ok {
return E.New("duplication of tag")
}
if r.defaultDetour == "" || r.defaultDetour == out.Tag() {
if r.defaultOutboundForConnection == nil {
if common.Contains(out.Network(), N.NetworkTCP) {
r.defaultOutboundForConnection = out
}
}
if r.defaultOutboundForPacketConnection == nil {
if common.Contains(out.Network(), N.NetworkUDP) {
r.defaultOutboundForPacketConnection = out
}
}
}
if r.started {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
monitor.Start("initialize outbound/", out.Type(), "[", out.Tag(), "]")
defer monitor.Finish()
if startable, isStartable := out.(interface{ Start() error }); isStartable {
if err := startable.Start(); err != nil {
return E.Cause(err, "start")
}
}
if err := postStartOutbound(out); err != nil {
return E.Cause(err, "post start")
}
}
r.outboundByTag[out.Tag()] = out
return nil
}
func (r *Router) AddInbound(in adapter.Inbound) error {
r.boundary.Lock()
defer r.boundary.Unlock()
if _, ok := r.inboundByTag[in.Tag()]; ok {
return E.New("duplication of tag")
}
if r.started {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
monitor.Start("initialize inbound/", in.Type(), "[", in.Tag(), "]")
defer monitor.Finish()
if err := in.Start(); err != nil {
return E.Cause(err, "start")
}
if err := postStartInbound(in); err != nil {
return E.Cause(err, "post-start")
}
}
r.inboundByTag[in.Tag()] = in
return nil
}
func (r *Router) RemoveOutbound(tag string) error {
r.boundary.Lock()
defer r.boundary.Unlock()
out, ok := r.outboundByTag[tag]
if !ok {
return E.New("unknown tag")
}
delete(r.outboundByTag, tag)
if out == r.defaultOutboundForConnection {
r.defaultOutboundForConnection = nil
}
if out == r.defaultOutboundForPacketConnection {
r.defaultOutboundForPacketConnection = nil
}
if r.defaultDetour == "" {
for _, out := range r.outboundByTag {
if r.defaultOutboundForConnection == nil {
if common.Contains(out.Network(), N.NetworkTCP) {
r.defaultOutboundForConnection = out
}
if common.Contains(out.Network(), N.NetworkUDP) {
r.defaultOutboundForPacketConnection = out
}
if r.defaultOutboundForConnection != nil && r.defaultOutboundForPacketConnection != nil {
break
}
}
}
}
if r.started {
if err := common.Close(out); err != nil {
return E.Cause(err, "close")
}
}
return nil
}
func (r *Router) RemoveInbound(tag string) error {
r.boundary.Lock()
defer r.boundary.Unlock()
in, ok := r.inboundByTag[tag]
if !ok {
return E.New("unknown tag")
}
delete(r.inboundByTag, tag)
if r.started {
if err := in.Close(); err != nil {
return E.Cause(err, "close")
}
}
return nil
}
func (r *Router) StartOutbounds() error {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
startedTags := make(map[string]struct{})
for tag, out := range r.outboundByTag {
if err := (&OutboundStarter{
outboundByTag: r.outboundByTag,
startedTags: startedTags,
monitor: monitor,
}).Start(tag, make(map[string]struct{})); err != nil {
return E.Cause(err, "start outbound/", out.Type(), "[", tag, "]")
}
}
return nil
}
func (r *Router) StartInbounds() error {
for tag, in := range r.inboundByTag {
if err := in.Start(); err != nil {
return E.Cause(err, "start inbound/", in.Type(), "[", tag, "]")
}
}
return nil
}
func (r *Router) closeBounds(monitor *taskmonitor.Monitor) error {
r.boundary.Lock()
defer r.boundary.Unlock()
var err error
for tag, in := range r.inboundByTag {
monitor.Start("close inbound/", in.Type(), "[", tag, "]")
err = E.Append(err, in.Close(), func(err error) error {
return E.Cause(err, "close inbound/", in.Type(), "[", tag, "]")
})
monitor.Finish()
}
for tag, out := range r.outboundByTag {
monitor.Start("close outbound/", out.Type(), "[", tag, "]")
err = E.Append(err, common.Close(out), func(err error) error {
return E.Cause(err, "close outbound/", out.Type(), "[", tag, "]")
})
monitor.Finish()
}
return err
}
func (r *Router) Close() error { func (r *Router) Close() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout) monitor := taskmonitor.New(r.logger, C.StopTimeout)
var err error err := r.closeBounds(monitor)
for i, rule := range r.rules { for i, rule := range r.rules {
monitor.Start("close rule[", i, "]") monitor.Start("close rule[", i, "]")
err = E.Append(err, rule.Close(), func(err error) error { err = E.Append(err, rule.Close(), func(err error) error {
@ -654,10 +772,35 @@ func (r *Router) Close() error {
}) })
monitor.Finish() monitor.Finish()
} }
r.started = false
return err return err
} }
func postStartOutbound(out adapter.Outbound) error {
if lateOutbound, isLateOutbound := out.(adapter.PostStarter); isLateOutbound {
if err := lateOutbound.PostStart(); err != nil {
return E.Cause(err, "outbound/", out.Type(), "[", out.Tag(), "]")
}
}
return nil
}
func postStartInbound(in adapter.Inbound) error {
if lateInbound, isLateInbound := in.(adapter.PostStarter); isLateInbound {
if err := lateInbound.PostStart(); err != nil {
return E.Cause(err, "inbound/", in.Type(), "[", in.Tag(), "]")
}
}
return nil
}
func (r *Router) PostStart() error { func (r *Router) PostStart() error {
// TODO: reorganize ALL start order
for _, out := range r.outboundByTag {
if err := postStartOutbound(out); err != nil {
return err
}
}
monitor := taskmonitor.New(r.logger, C.StopTimeout) monitor := taskmonitor.New(r.logger, C.StopTimeout)
if len(r.ruleSets) > 0 { if len(r.ruleSets) > 0 {
monitor.Start("initialize rule-set") monitor.Start("initialize rule-set")
@ -749,35 +892,58 @@ func (r *Router) PostStart() error {
return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]") return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]")
} }
} }
for _, in := range r.inboundByTag {
if err := postStartInbound(in); err != nil {
return err
}
}
r.started = true r.started = true
return nil return nil
} }
func (r *Router) Cleanup() error {
for _, ruleSet := range r.ruleSetMap {
ruleSet.Cleanup()
}
runtime.GC()
return nil
}
func (r *Router) Outbound(tag string) (adapter.Outbound, bool) {
outbound, loaded := r.outboundByTag[tag]
return outbound, loaded
}
func (r *Router) DefaultOutbound(network string) (adapter.Outbound, error) { func (r *Router) DefaultOutbound(network string) (adapter.Outbound, error) {
if network == N.NetworkTCP { r.boundary.RLock()
defer r.boundary.RUnlock()
switch network {
case N.NetworkTCP:
if r.defaultOutboundForConnection == nil { if r.defaultOutboundForConnection == nil {
return nil, E.New("missing default outbound for TCP connections") return nil, E.New("missing default outbound for TCP connections")
} }
return r.defaultOutboundForConnection, nil return r.defaultOutboundForConnection, nil
} else { case N.NetworkUDP:
if r.defaultOutboundForPacketConnection == nil { if r.defaultOutboundForPacketConnection == nil {
return nil, E.New("missing default outbound for UDP connections") return nil, E.New("missing default outbound for UDP connections")
} }
return r.defaultOutboundForPacketConnection, nil return r.defaultOutboundForPacketConnection, nil
} }
return nil, E.New("wrong network type provided")
}
func (r *Router) Outbounds() []adapter.Outbound {
if !r.started {
return nil
}
r.boundary.RLock()
defer r.boundary.RUnlock()
res := make([]adapter.Outbound, 0, len(r.outboundByTag))
for _, out := range r.outboundByTag {
res = append(res, out)
}
return res
}
func (r *Router) Outbound(tag string) (adapter.Outbound, bool) {
r.boundary.RLock()
defer r.boundary.RUnlock()
outbound, loaded := r.outboundByTag[tag]
return outbound, loaded
}
func (r *Router) Inbound(tag string) (adapter.Inbound, bool) {
r.boundary.RLock()
defer r.boundary.RUnlock()
inbound, loaded := r.inboundByTag[tag]
return inbound, loaded
} }
func (r *Router) FakeIPStore() adapter.FakeIPStore { func (r *Router) FakeIPStore() adapter.FakeIPStore {
@ -802,8 +968,8 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
if metadata.LastInbound == metadata.InboundDetour { if metadata.LastInbound == metadata.InboundDetour {
return E.New("routing loop on detour: ", metadata.InboundDetour) return E.New("routing loop on detour: ", metadata.InboundDetour)
} }
detour := r.inboundByTag[metadata.InboundDetour] detour, ok := r.Inbound(metadata.InboundDetour)
if detour == nil { if !ok {
return E.New("inbound detour not found: ", metadata.InboundDetour) return E.New("inbound detour not found: ", metadata.InboundDetour)
} }
injectable, isInjectable := detour.(adapter.InjectableInbound) injectable, isInjectable := detour.(adapter.InjectableInbound)
@ -908,15 +1074,27 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
} else if metadata.Destination.IsIPv6() { } else if metadata.Destination.IsIPv6() {
metadata.IPVersion = 6 metadata.IPVersion = 6
} }
ctx, matchedRule, detour, err := r.match(ctx, &metadata, r.defaultOutboundForConnection)
if err != nil { rule, detour := r.ruleByMetadata(ctx, &metadata)
return err if rule == nil {
var err error
detour, err = r.DefaultOutbound(N.NetworkTCP)
if err != nil {
return E.New("missing supported outbound, closing packet connection")
}
}
if tag, loaded := outbound.TagFromContext(ctx); loaded {
if tag == detour.Tag() {
return E.New("connection loopback in outbound/", detour.Type(), "[", detour.Tag(), "]")
}
} }
if !common.Contains(detour.Network(), N.NetworkTCP) { if !common.Contains(detour.Network(), N.NetworkTCP) {
return E.New("missing supported outbound, closing connection") return E.New("missing support of network type by outbound, closing packet connection")
} }
ctx = outbound.ContextWithTag(ctx, detour.Tag())
if r.clashServer != nil { if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, matchedRule) trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, rule)
defer tracker.Leave() defer tracker.Leave()
conn = trackerConn conn = trackerConn
} }
@ -936,8 +1114,8 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
if metadata.LastInbound == metadata.InboundDetour { if metadata.LastInbound == metadata.InboundDetour {
return E.New("routing loop on detour: ", metadata.InboundDetour) return E.New("routing loop on detour: ", metadata.InboundDetour)
} }
detour := r.inboundByTag[metadata.InboundDetour] detour, ok := r.Inbound(metadata.InboundDetour)
if detour == nil { if !ok {
return E.New("inbound detour not found: ", metadata.InboundDetour) return E.New("inbound detour not found: ", metadata.InboundDetour)
} }
injectable, isInjectable := detour.(adapter.InjectableInbound) injectable, isInjectable := detour.(adapter.InjectableInbound)
@ -1082,15 +1260,27 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
} else if metadata.Destination.IsIPv6() { } else if metadata.Destination.IsIPv6() {
metadata.IPVersion = 6 metadata.IPVersion = 6
} }
ctx, matchedRule, detour, err := r.match(ctx, &metadata, r.defaultOutboundForPacketConnection)
if err != nil { rule, detour := r.ruleByMetadata(ctx, &metadata)
return err if rule == nil {
var err error
detour, err = r.DefaultOutbound(N.NetworkUDP)
if err != nil {
return E.New("missing supported outbound, closing packet connection")
}
}
if tag, loaded := outbound.TagFromContext(ctx); loaded {
if tag == detour.Tag() {
return E.New("connection loopback in outbound/", detour.Type(), "[", detour.Tag(), "]")
}
} }
if !common.Contains(detour.Network(), N.NetworkUDP) { if !common.Contains(detour.Network(), N.NetworkUDP) {
return E.New("missing supported outbound, closing packet connection") return E.New("missing support of network type by outbound, closing packet connection")
} }
ctx = outbound.ContextWithTag(ctx, detour.Tag())
if r.clashServer != nil { if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, matchedRule) trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, rule)
defer tracker.Leave() defer tracker.Leave()
conn = trackerConn conn = trackerConn
} }
@ -1105,26 +1295,9 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
return detour.NewPacketConnection(ctx, conn, metadata) return detour.NewPacketConnection(ctx, conn, metadata)
} }
func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (context.Context, adapter.Rule, adapter.Outbound, error) { func (r *Router) processInfoByMetadata(ctx context.Context, metadata *adapter.InboundContext) *process.Info {
matchRule, matchOutbound := r.match0(ctx, metadata, defaultOutbound)
if contextOutbound, loaded := outbound.TagFromContext(ctx); loaded {
if contextOutbound == matchOutbound.Tag() {
return nil, nil, nil, E.New("connection loopback in outbound/", matchOutbound.Type(), "[", matchOutbound.Tag(), "]")
}
}
ctx = outbound.ContextWithTag(ctx, matchOutbound.Tag())
return ctx, matchRule, matchOutbound, nil
}
func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
if r.processSearcher != nil { if r.processSearcher != nil {
var originDestination netip.AddrPort processInfo, err := process.FindProcessInfo(r.processSearcher, ctx, metadata.Network, metadata.Source.AddrPort())
if metadata.OriginDestination.IsValid() {
originDestination = metadata.OriginDestination.AddrPort()
} else if metadata.Destination.IsIP() {
originDestination = metadata.Destination.AddrPort()
}
processInfo, err := process.FindProcessInfo(r.processSearcher, ctx, metadata.Network, metadata.Source.AddrPort(), originDestination)
if err != nil { if err != nil {
r.logger.InfoContext(ctx, "failed to search process: ", err) r.logger.InfoContext(ctx, "failed to search process: ", err)
} else { } else {
@ -1145,21 +1318,26 @@ func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, d
r.logger.InfoContext(ctx, "found user id: ", processInfo.UserId) r.logger.InfoContext(ctx, "found user id: ", processInfo.UserId)
} }
} }
metadata.ProcessInfo = processInfo return processInfo
} }
} }
return nil
}
func (r *Router) ruleByMetadata(ctx context.Context, metadata *adapter.InboundContext) (adapter.Rule, adapter.Outbound) {
metadata.ProcessInfo = r.processInfoByMetadata(ctx, metadata)
for i, rule := range r.rules { for i, rule := range r.rules {
metadata.ResetRuleCache() metadata.ResetRuleCache()
if rule.Match(metadata) { if rule.Match(metadata) {
detour := rule.Outbound() detour := rule.Outbound()
r.logger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) r.logger.DebugContext(ctx, "rule[", i, "] ", rule.String(), " => ", detour)
if outbound, loaded := r.Outbound(detour); loaded { if outbound, loaded := r.Outbound(detour); loaded {
return rule, outbound return rule, outbound
} }
r.logger.ErrorContext(ctx, "outbound not found: ", detour) r.logger.ErrorContext(ctx, "not found outbound[", detour, "]")
} }
} }
return nil, defaultOutbound return nil, nil
} }
func (r *Router) InterfaceFinder() control.InterfaceFinder { func (r *Router) InterfaceFinder() control.InterfaceFinder {
@ -1306,8 +1484,8 @@ func (r *Router) notifyNetworkUpdate(event int) {
func (r *Router) ResetNetwork() error { func (r *Router) ResetNetwork() error {
conntrack.Close() conntrack.Close()
for _, outbound := range r.outbounds { for _, out := range r.Outbounds() {
listener, isListener := outbound.(adapter.InterfaceUpdateListener) listener, isListener := out.(adapter.InterfaceUpdateListener)
if isListener { if isListener {
listener.InterfaceUpdated() listener.InterfaceUpdated()
} }
@ -1316,6 +1494,7 @@ func (r *Router) ResetNetwork() error {
for _, transport := range r.transports { for _, transport := range r.transports {
transport.Reset() transport.Reset()
} }
return nil return nil
} }

View file

@ -0,0 +1,60 @@
package route
import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
E "github.com/sagernet/sing/common/exceptions"
)
type OutboundStarter struct {
outboundByTag map[string]adapter.Outbound
startedTags map[string]struct{}
monitor *taskmonitor.Monitor
}
func (s *OutboundStarter) Start(tag string, pathIncludesTags map[string]struct{}) error {
adapter := s.outboundByTag[tag]
if adapter == nil {
return E.New("dependency[", tag, "] is not found")
}
// The outbound may have been started by another subtree in the previous,
// we don't need to start it again.
if _, ok := s.startedTags[tag]; ok {
return nil
}
// If we detected the repetition of the tags in scope of tree evaluation,
// the circular dependency is found, as it grows from bottom to top.
if _, ok := pathIncludesTags[tag]; ok {
return E.New("circular dependency related with outbound/", adapter.Type(), "[", tag, "]")
}
// This required to be done only if that outbound isn't already started,
// because some dependencies may come to the same root,
// but they aren't circular.
pathIncludesTags[tag] = struct{}{}
// Next, we are recursively starting all dependencies of the current
// outbound and repeating the cycle.
for _, dependencyTag := range adapter.Dependencies() {
if err := s.Start(dependencyTag, pathIncludesTags); err != nil {
return err
}
}
// Anyway, it will be finished soon, nothing will happen if I'll include
// Startable interface typecasting too.
s.monitor.Start("initialize outbound/", adapter.Type(), "[", tag, "]")
defer s.monitor.Finish()
// After the evaluation of entire tree let's begin to start all
// the outbounds!
if startable, isStartable := adapter.(interface{ Start() error }); isStartable {
if err := startable.Start(); err != nil {
return E.Cause(err, "initialize outbound/", adapter.Type(), "[", tag, "]")
}
}
return nil
}