Add high-level boundary management interface

I rewrote the router's boundary management part to implement dynamic management
from a high-level box interface. This also includes a number of changes I made
in the process of rewriting some messy parts, such as the Outbound tree
bottom-top starter.
This commit is contained in:
beryll1um 2024-10-14 02:15:29 +02:00 committed by Vitalii
parent 0b42f26af8
commit 99840ba0f4
13 changed files with 503 additions and 352 deletions

View file

@ -18,14 +18,28 @@ import (
)
type Router interface {
Service
AddOutbound(outbound Outbound) error
AddInbound(inbound Inbound) error
RemoveOutbound(tag string) error
RemoveInbound(tag string) error
PreStarter
StartOutbounds() error
Service
StartInbounds() error
PostStarter
Cleanup() error
DefaultOutbound(network string) (Outbound, error)
Outbounds() []Outbound
Outbound(tag string) (Outbound, bool)
DefaultOutbound(network string) (Outbound, error)
Inbound(tag string) (Inbound, bool)
FakeIPStore() FakeIPStore

225
box.go
View file

@ -29,16 +29,16 @@ import (
var _ adapter.Service = (*Box)(nil)
type Box struct {
createdAt time.Time
router adapter.Router
inbounds []adapter.Inbound
outbounds []adapter.Outbound
logFactory log.Factory
logger log.ContextLogger
preServices1 map[string]adapter.Service
preServices2 map[string]adapter.Service
postServices map[string]adapter.Service
done chan struct{}
createdAt time.Time
router adapter.Router
logFactory log.Factory
logger log.ContextLogger
preServices1 map[string]adapter.Service
preServices2 map[string]adapter.Service
postServices map[string]adapter.Service
platformInterface platform.Interface
ctx context.Context
done chan struct{}
}
type Options struct {
@ -97,57 +97,6 @@ func New(options Options) (*Box, error) {
if err != nil {
return nil, E.Cause(err, "parse route options")
}
inbounds := make([]adapter.Inbound, 0, len(options.Inbounds))
outbounds := make([]adapter.Outbound, 0, len(options.Outbounds))
for i, inboundOptions := range options.Inbounds {
var in adapter.Inbound
var tag string
if inboundOptions.Tag != "" {
tag = inboundOptions.Tag
} else {
tag = F.ToString(i)
}
in, err = inbound.New(
ctx,
router,
logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")),
tag,
inboundOptions,
options.PlatformInterface,
)
if err != nil {
return nil, E.Cause(err, "parse inbound[", i, "]")
}
inbounds = append(inbounds, in)
}
for i, outboundOptions := range options.Outbounds {
var out adapter.Outbound
var tag string
if outboundOptions.Tag != "" {
tag = outboundOptions.Tag
} else {
tag = F.ToString(i)
}
out, err = outbound.New(
ctx,
router,
logFactory.NewLogger(F.ToString("outbound/", outboundOptions.Type, "[", tag, "]")),
tag,
outboundOptions)
if err != nil {
return nil, E.Cause(err, "parse outbound[", i, "]")
}
outbounds = append(outbounds, out)
}
err = router.Initialize(inbounds, outbounds, func() adapter.Outbound {
out, oErr := outbound.New(ctx, router, logFactory.NewLogger("outbound/direct"), "direct", option.Outbound{Type: "direct", Tag: "default"})
common.Must(oErr)
outbounds = append(outbounds, out)
return out
})
if err != nil {
return nil, err
}
if options.PlatformInterface != nil {
err = options.PlatformInterface.Initialize(ctx, router)
if err != nil {
@ -183,18 +132,35 @@ func New(options Options) (*Box, error) {
router.SetV2RayServer(v2rayServer)
preServices2["v2ray api"] = v2rayServer
}
return &Box{
router: router,
inbounds: inbounds,
outbounds: outbounds,
createdAt: createdAt,
logFactory: logFactory,
logger: logFactory.Logger(),
preServices1: preServices1,
preServices2: preServices2,
postServices: postServices,
done: make(chan struct{}),
}, nil
box := &Box{
router: router,
createdAt: createdAt,
logFactory: logFactory,
logger: logFactory.Logger(),
preServices1: preServices1,
preServices2: preServices2,
postServices: postServices,
platformInterface: options.PlatformInterface,
ctx: ctx,
done: make(chan struct{}),
}
for i, outOpts := range options.Outbounds {
if outOpts.Tag == "" {
outOpts.Tag = F.ToString(i)
}
if err := box.AddOutbound(outOpts); err != nil {
return nil, E.Cause(err, "create outbound")
}
}
for i, inOpts := range options.Inbounds {
if inOpts.Tag == "" {
inOpts.Tag = F.ToString(i)
}
if err := box.AddInbound(inOpts); err != nil {
return nil, E.Cause(err, "create inbound")
}
}
return box, nil
}
func (s *Box) PreStart() error {
@ -263,12 +229,10 @@ func (s *Box) preStart() error {
}
}
}
err = s.router.PreStart()
if err != nil {
if err := s.router.PreStart(); err != nil {
return E.Cause(err, "pre-start router")
}
err = s.startOutbounds()
if err != nil {
if err := s.router.StartOutbounds(); err != nil {
return err
}
return s.router.Start()
@ -291,20 +255,10 @@ func (s *Box) start() error {
return E.Cause(err, "start ", serviceName)
}
}
for i, in := range s.inbounds {
var tag string
if in.Tag() == "" {
tag = F.ToString(i)
} else {
tag = in.Tag()
}
err = in.Start()
if err != nil {
return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]")
}
if err := s.router.StartInbounds(); err != nil {
return E.Cause(err, "start inbounds")
}
err = s.postStart()
if err != nil {
if err = s.postStart(); err != nil {
return err
}
return s.router.Cleanup()
@ -317,26 +271,8 @@ func (s *Box) postStart() error {
return E.Cause(err, "start ", serviceName)
}
}
// TODO: reorganize ALL start order
for _, out := range s.outbounds {
if lateOutbound, isLateOutbound := out.(adapter.PostStarter); isLateOutbound {
err := lateOutbound.PostStart()
if err != nil {
return E.Cause(err, "post-start outbound/", out.Tag())
}
}
}
err := s.router.PostStart()
if err != nil {
return err
}
for _, in := range s.inbounds {
if lateInbound, isLateInbound := in.(adapter.PostStarter); isLateInbound {
err = lateInbound.PostStart()
if err != nil {
return E.Cause(err, "post-start inbound/", in.Tag())
}
}
if err := s.router.PostStart(); err != nil {
return E.Cause(err, "post-start")
}
return nil
}
@ -357,20 +293,6 @@ func (s *Box) Close() error {
})
monitor.Finish()
}
for i, in := range s.inbounds {
monitor.Start("close inbound/", in.Type(), "[", i, "]")
errors = E.Append(errors, in.Close(), func(err error) error {
return E.Cause(err, "close inbound/", in.Type(), "[", i, "]")
})
monitor.Finish()
}
for i, out := range s.outbounds {
monitor.Start("close outbound/", out.Type(), "[", i, "]")
errors = E.Append(errors, common.Close(out), func(err error) error {
return E.Cause(err, "close outbound/", out.Type(), "[", i, "]")
})
monitor.Finish()
}
monitor.Start("close router")
if err := common.Close(s.router); err != nil {
errors = E.Append(errors, err, func(err error) error {
@ -403,3 +325,58 @@ func (s *Box) Close() error {
func (s *Box) Router() adapter.Router {
return s.router
}
func (s *Box) AddOutbound(option option.Outbound) error {
if option.Tag == "" {
return E.New("empty tag")
}
out, err := outbound.New(
s.ctx,
s.router,
s.logFactory.NewLogger(F.ToString("outbound/", option.Type, "[", option.Tag, "]")),
option.Tag,
option,
)
if err != nil {
return E.Cause(err, "parse addited outbound")
}
if err := s.router.AddOutbound(out); err != nil {
return E.Cause(err, "outbound/", option.Type, "[", option.Tag, "]")
}
return nil
}
func (s *Box) AddInbound(option option.Inbound) error {
if option.Tag == "" {
return E.New("empty tag")
}
in, err := inbound.New(
s.ctx,
s.router,
s.logFactory.NewLogger(F.ToString("inbound/", option.Type, "[", option.Tag, "]")),
option.Tag,
option,
s.platformInterface,
)
if err != nil {
return E.Cause(err, "parse addited inbound")
}
if err := s.router.AddInbound(in); err != nil {
return E.Cause(err, "inbound/", option.Type, "[", option.Tag, "]")
}
return nil
}
func (s *Box) RemoveOutbound(tag string) error {
if err := s.router.RemoveOutbound(tag); err != nil {
return E.Cause(err, "outbound[", tag, "]")
}
return nil
}
func (s *Box) RemoveInbound(tag string) error {
if err := s.router.RemoveInbound(tag); err != nil {
return E.Cause(err, "inbound[", tag, "]")
}
return nil
}

View file

@ -1,85 +0,0 @@
package box
import (
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
func (s *Box) startOutbounds() error {
monitor := taskmonitor.New(s.logger, C.StartTimeout)
outboundTags := make(map[adapter.Outbound]string)
outbounds := make(map[string]adapter.Outbound)
for i, outboundToStart := range s.outbounds {
var outboundTag string
if outboundToStart.Tag() == "" {
outboundTag = F.ToString(i)
} else {
outboundTag = outboundToStart.Tag()
}
if _, exists := outbounds[outboundTag]; exists {
return E.New("outbound tag ", outboundTag, " duplicated")
}
outboundTags[outboundToStart] = outboundTag
outbounds[outboundTag] = outboundToStart
}
started := make(map[string]bool)
for {
canContinue := false
startOne:
for _, outboundToStart := range s.outbounds {
outboundTag := outboundTags[outboundToStart]
if started[outboundTag] {
continue
}
dependencies := outboundToStart.Dependencies()
for _, dependency := range dependencies {
if !started[dependency] {
continue startOne
}
}
started[outboundTag] = true
canContinue = true
if starter, isStarter := outboundToStart.(interface {
Start() error
}); isStarter {
monitor.Start("initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]")
err := starter.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]")
}
}
}
if len(started) == len(s.outbounds) {
break
}
if canContinue {
continue
}
currentOutbound := common.Find(s.outbounds, func(it adapter.Outbound) bool {
return !started[outboundTags[it]]
})
var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
return !started[it]
})
if common.Contains(oTree, problemOutboundTag) {
return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
}
problemOutbound := outbounds[problemOutboundTag]
if problemOutbound == nil {
return E.New("dependency[", problemOutboundTag, "] not found for outbound[", outboundTags[oCurrent], "]")
}
return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
}
return lintOutbound([]string{outboundTags[currentOutbound]}, currentOutbound)
}
return nil
}

View file

@ -12,7 +12,7 @@ import (
)
type Searcher interface {
FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error)
FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*Info, error)
}
var ErrNotFound = E.New("process not found")
@ -29,8 +29,8 @@ type Info struct {
UserId int32
}
func FindProcessInfo(searcher Searcher, ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
info, err := searcher.FindProcessInfo(ctx, network, source, destination)
func FindProcessInfo(searcher Searcher, ctx context.Context, network string, source netip.AddrPort) (*Info, error) {
info, err := searcher.FindProcessInfo(ctx, network, source)
if err != nil {
return nil, err
}

View file

@ -19,8 +19,8 @@ func NewSearcher(config Config) (Searcher, error) {
return &linuxSearcher{config.Logger}, nil
}
func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
inode, uid, err := resolveSocketByNetlink(network, source, destination)
func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*Info, error) {
inode, uid, err := resolveSocketByNetlink(network, source)
if err != nil {
return nil, err
}

View file

@ -36,7 +36,7 @@ const (
pathProc = "/proc"
)
func resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) {
func resolveSocketByNetlink(network string, source netip.AddrPort) (inode, uid uint32, err error) {
var family uint8
var protocol uint8

View file

@ -94,7 +94,7 @@ func (s *platformInterfaceStub) ReadWIFIState() adapter.WIFIState {
return adapter.WIFIState{}
}
func (s *platformInterfaceStub) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*process.Info, error) {
func (s *platformInterfaceStub) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*process.Info, error) {
return nil, os.ErrInvalid
}

View file

@ -30,7 +30,7 @@ func init() {
}
}
func ResolveSocketByProcSearch(network string, source, _ netip.AddrPort) int32 {
func ResolveSocketByProcSearch(network string, source netip.AddrPort) int32 {
if netIndexOfLocal < 0 || netIndexOfUid < 0 {
return -1
}

View file

@ -10,7 +10,7 @@ type PlatformInterface interface {
OpenTun(options TunOptions) (int32, error)
WriteLog(message string)
UseProcFS() bool
FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32, destinationAddress string, destinationPort int32) (int32, error)
FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32) (int32, error)
PackageNameByUid(uid int32) (string, error)
UIDByPackageName(packageName string) (int32, error)
UsePlatformDefaultInterfaceMonitor() bool

View file

@ -203,10 +203,10 @@ func (w *platformInterfaceWrapper) ReadWIFIState() adapter.WIFIState {
return (adapter.WIFIState)(*wifiState)
}
func (w *platformInterfaceWrapper) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*process.Info, error) {
func (w *platformInterfaceWrapper) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort) (*process.Info, error) {
var uid int32
if w.useProcFS {
uid = procfs.ResolveSocketByProcSearch(network, source, destination)
uid = procfs.ResolveSocketByProcSearch(network, source)
if uid == -1 {
return nil, E.New("procfs: not found")
}
@ -221,7 +221,7 @@ func (w *platformInterfaceWrapper) FindProcessInfo(ctx context.Context, network
return nil, E.New("unknown network: ", network)
}
var err error
uid, err = w.iif.FindConnectionOwner(ipProtocol, source.Addr().String(), int32(source.Port()), destination.Addr().String(), int32(destination.Port()))
uid, err = w.iif.FindConnectionOwner(ipProtocol, source.Addr().String(), int32(source.Port()))
if err != nil {
return nil, err
}

View file

@ -115,6 +115,12 @@ func (a *myInboundAdapter) Start() error {
func (a *myInboundAdapter) Close() error {
a.inShutdown.Store(true)
if a.tcpListener != nil {
a.logger.Info("tcp server closed at ", a.tcpListener.Addr())
}
if a.udpConn != nil {
a.logger.Info("udp server closed at ", a.udpConn.LocalAddr())
}
var err error
if a.systemProxy != nil && a.systemProxy.IsEnabled() {
err = a.systemProxy.Disable()

View file

@ -10,6 +10,7 @@ import (
"os/user"
"runtime"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
@ -50,11 +51,15 @@ import (
var _ adapter.Router = (*Router)(nil)
type Router struct {
ctx context.Context
logger log.ContextLogger
dnsLogger log.ContextLogger
ctx context.Context
logger log.ContextLogger
dnsLogger log.ContextLogger
// Currently this is responsible for protecting inbound and outbound dynamic
// control. I'm not sure if it can be separated because I haven't delved
// into the logic yet to make sure they don't interfere with each other.
// To research, may improve performance on some high-load setups.
boundary sync.RWMutex
inboundByTag map[string]adapter.Inbound
outbounds []adapter.Outbound
outboundByTag map[string]adapter.Outbound
rules []adapter.Rule
defaultDetour string
@ -113,6 +118,7 @@ func NewRouter(
ctx: ctx,
logger: logFactory.NewLogger("router"),
dnsLogger: logFactory.NewLogger("dns"),
inboundByTag: make(map[string]adapter.Inbound),
outboundByTag: make(map[string]adapter.Outbound),
rules: make([]adapter.Rule, 0, len(options.Rules)),
dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
@ -373,76 +379,6 @@ func NewRouter(
return router, nil
}
func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error {
inboundByTag := make(map[string]adapter.Inbound)
for _, inbound := range inbounds {
inboundByTag[inbound.Tag()] = inbound
}
outboundByTag := make(map[string]adapter.Outbound)
for _, detour := range outbounds {
outboundByTag[detour.Tag()] = detour
}
var defaultOutboundForConnection adapter.Outbound
var defaultOutboundForPacketConnection adapter.Outbound
if r.defaultDetour != "" {
detour, loaded := outboundByTag[r.defaultDetour]
if !loaded {
return E.New("default detour not found: ", r.defaultDetour)
}
if common.Contains(detour.Network(), N.NetworkTCP) {
defaultOutboundForConnection = detour
}
if common.Contains(detour.Network(), N.NetworkUDP) {
defaultOutboundForPacketConnection = detour
}
}
if defaultOutboundForConnection == nil {
for _, detour := range outbounds {
if common.Contains(detour.Network(), N.NetworkTCP) {
defaultOutboundForConnection = detour
break
}
}
}
if defaultOutboundForPacketConnection == nil {
for _, detour := range outbounds {
if common.Contains(detour.Network(), N.NetworkUDP) {
defaultOutboundForPacketConnection = detour
break
}
}
}
if defaultOutboundForConnection == nil || defaultOutboundForPacketConnection == nil {
detour := defaultOutbound()
if defaultOutboundForConnection == nil {
defaultOutboundForConnection = detour
}
if defaultOutboundForPacketConnection == nil {
defaultOutboundForPacketConnection = detour
}
outbounds = append(outbounds, detour)
outboundByTag[detour.Tag()] = detour
}
r.inboundByTag = inboundByTag
r.outbounds = outbounds
r.defaultOutboundForConnection = defaultOutboundForConnection
r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection
r.outboundByTag = outboundByTag
for i, rule := range r.rules {
if _, loaded := outboundByTag[rule.Outbound()]; !loaded {
return E.New("outbound not found for rule[", i, "]: ", rule.Outbound())
}
}
return nil
}
func (r *Router) Outbounds() []adapter.Outbound {
if !r.started {
return nil
}
return r.outbounds
}
func (r *Router) PreStart() error {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
if r.interfaceMonitor != nil {
@ -581,9 +517,191 @@ func (r *Router) Start() error {
return nil
}
func (r *Router) Cleanup() error {
for _, ruleSet := range r.ruleSetMap {
ruleSet.Cleanup()
}
runtime.GC()
return nil
}
func (r *Router) AddOutbound(out adapter.Outbound) error {
r.boundary.Lock()
defer r.boundary.Unlock()
if _, ok := r.outboundByTag[out.Tag()]; ok {
return E.New("duplication of tag")
}
if r.defaultDetour == "" || r.defaultDetour == out.Tag() {
if r.defaultOutboundForConnection == nil {
if common.Contains(out.Network(), N.NetworkTCP) {
r.defaultOutboundForConnection = out
}
}
if r.defaultOutboundForPacketConnection == nil {
if common.Contains(out.Network(), N.NetworkUDP) {
r.defaultOutboundForPacketConnection = out
}
}
}
if r.started {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
monitor.Start("initialize outbound/", out.Type(), "[", out.Tag(), "]")
defer monitor.Finish()
if startable, isStartable := out.(interface{ Start() error }); isStartable {
if err := startable.Start(); err != nil {
return E.Cause(err, "start")
}
}
if err := postStartOutbound(out); err != nil {
return E.Cause(err, "post start")
}
}
r.outboundByTag[out.Tag()] = out
return nil
}
func (r *Router) AddInbound(in adapter.Inbound) error {
r.boundary.Lock()
defer r.boundary.Unlock()
if _, ok := r.inboundByTag[in.Tag()]; ok {
return E.New("duplication of tag")
}
if r.started {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
monitor.Start("initialize inbound/", in.Type(), "[", in.Tag(), "]")
defer monitor.Finish()
if err := in.Start(); err != nil {
return E.Cause(err, "start")
}
if err := postStartInbound(in); err != nil {
return E.Cause(err, "post-start")
}
}
r.inboundByTag[in.Tag()] = in
return nil
}
func (r *Router) RemoveOutbound(tag string) error {
r.boundary.Lock()
defer r.boundary.Unlock()
out, ok := r.outboundByTag[tag]
if !ok {
return E.New("unknown tag")
}
delete(r.outboundByTag, tag)
if out == r.defaultOutboundForConnection {
r.defaultOutboundForConnection = nil
}
if out == r.defaultOutboundForPacketConnection {
r.defaultOutboundForPacketConnection = nil
}
if r.defaultDetour == "" {
for _, out := range r.outboundByTag {
if r.defaultOutboundForConnection == nil {
if common.Contains(out.Network(), N.NetworkTCP) {
r.defaultOutboundForConnection = out
}
if common.Contains(out.Network(), N.NetworkUDP) {
r.defaultOutboundForPacketConnection = out
}
if r.defaultOutboundForConnection != nil && r.defaultOutboundForPacketConnection != nil {
break
}
}
}
}
if r.started {
if err := common.Close(out); err != nil {
return E.Cause(err, "close")
}
}
return nil
}
func (r *Router) RemoveInbound(tag string) error {
r.boundary.Lock()
defer r.boundary.Unlock()
in, ok := r.inboundByTag[tag]
if !ok {
return E.New("unknown tag")
}
delete(r.inboundByTag, tag)
if r.started {
if err := in.Close(); err != nil {
return E.Cause(err, "close")
}
}
return nil
}
func (r *Router) StartOutbounds() error {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
startedTags := make(map[string]struct{})
for tag, out := range r.outboundByTag {
if err := (&OutboundStarter{
outboundByTag: r.outboundByTag,
startedTags: startedTags,
monitor: monitor,
}).Start(tag, make(map[string]struct{})); err != nil {
return E.Cause(err, "start outbound/", out.Type(), "[", tag, "]")
}
}
return nil
}
func (r *Router) StartInbounds() error {
for tag, in := range r.inboundByTag {
if err := in.Start(); err != nil {
return E.Cause(err, "start inbound/", in.Type(), "[", tag, "]")
}
}
return nil
}
func (r *Router) closeBounds(monitor *taskmonitor.Monitor) error {
r.boundary.Lock()
defer r.boundary.Unlock()
var err error
for tag, in := range r.inboundByTag {
monitor.Start("close inbound/", in.Type(), "[", tag, "]")
err = E.Append(err, in.Close(), func(err error) error {
return E.Cause(err, "close inbound/", in.Type(), "[", tag, "]")
})
monitor.Finish()
}
for tag, out := range r.outboundByTag {
monitor.Start("close outbound/", out.Type(), "[", tag, "]")
err = E.Append(err, common.Close(out), func(err error) error {
return E.Cause(err, "close outbound/", out.Type(), "[", tag, "]")
})
monitor.Finish()
}
return err
}
func (r *Router) Close() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout)
var err error
err := r.closeBounds(monitor)
for i, rule := range r.rules {
monitor.Start("close rule[", i, "]")
err = E.Append(err, rule.Close(), func(err error) error {
@ -654,10 +772,35 @@ func (r *Router) Close() error {
})
monitor.Finish()
}
r.started = false
return err
}
func postStartOutbound(out adapter.Outbound) error {
if lateOutbound, isLateOutbound := out.(adapter.PostStarter); isLateOutbound {
if err := lateOutbound.PostStart(); err != nil {
return E.Cause(err, "outbound/", out.Type(), "[", out.Tag(), "]")
}
}
return nil
}
func postStartInbound(in adapter.Inbound) error {
if lateInbound, isLateInbound := in.(adapter.PostStarter); isLateInbound {
if err := lateInbound.PostStart(); err != nil {
return E.Cause(err, "inbound/", in.Type(), "[", in.Tag(), "]")
}
}
return nil
}
func (r *Router) PostStart() error {
// TODO: reorganize ALL start order
for _, out := range r.outboundByTag {
if err := postStartOutbound(out); err != nil {
return err
}
}
monitor := taskmonitor.New(r.logger, C.StopTimeout)
if len(r.ruleSets) > 0 {
monitor.Start("initialize rule-set")
@ -749,35 +892,58 @@ func (r *Router) PostStart() error {
return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]")
}
}
for _, in := range r.inboundByTag {
if err := postStartInbound(in); err != nil {
return err
}
}
r.started = true
return nil
}
func (r *Router) Cleanup() error {
for _, ruleSet := range r.ruleSetMap {
ruleSet.Cleanup()
}
runtime.GC()
return nil
}
func (r *Router) Outbound(tag string) (adapter.Outbound, bool) {
outbound, loaded := r.outboundByTag[tag]
return outbound, loaded
}
func (r *Router) DefaultOutbound(network string) (adapter.Outbound, error) {
if network == N.NetworkTCP {
r.boundary.RLock()
defer r.boundary.RUnlock()
switch network {
case N.NetworkTCP:
if r.defaultOutboundForConnection == nil {
return nil, E.New("missing default outbound for TCP connections")
}
return r.defaultOutboundForConnection, nil
} else {
case N.NetworkUDP:
if r.defaultOutboundForPacketConnection == nil {
return nil, E.New("missing default outbound for UDP connections")
}
return r.defaultOutboundForPacketConnection, nil
}
return nil, E.New("wrong network type provided")
}
func (r *Router) Outbounds() []adapter.Outbound {
if !r.started {
return nil
}
r.boundary.RLock()
defer r.boundary.RUnlock()
res := make([]adapter.Outbound, 0, len(r.outboundByTag))
for _, out := range r.outboundByTag {
res = append(res, out)
}
return res
}
func (r *Router) Outbound(tag string) (adapter.Outbound, bool) {
r.boundary.RLock()
defer r.boundary.RUnlock()
outbound, loaded := r.outboundByTag[tag]
return outbound, loaded
}
func (r *Router) Inbound(tag string) (adapter.Inbound, bool) {
r.boundary.RLock()
defer r.boundary.RUnlock()
inbound, loaded := r.inboundByTag[tag]
return inbound, loaded
}
func (r *Router) FakeIPStore() adapter.FakeIPStore {
@ -802,8 +968,8 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
if metadata.LastInbound == metadata.InboundDetour {
return E.New("routing loop on detour: ", metadata.InboundDetour)
}
detour := r.inboundByTag[metadata.InboundDetour]
if detour == nil {
detour, ok := r.Inbound(metadata.InboundDetour)
if !ok {
return E.New("inbound detour not found: ", metadata.InboundDetour)
}
injectable, isInjectable := detour.(adapter.InjectableInbound)
@ -908,15 +1074,27 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
} else if metadata.Destination.IsIPv6() {
metadata.IPVersion = 6
}
ctx, matchedRule, detour, err := r.match(ctx, &metadata, r.defaultOutboundForConnection)
if err != nil {
return err
rule, detour := r.ruleByMetadata(ctx, &metadata)
if rule == nil {
var err error
detour, err = r.DefaultOutbound(N.NetworkTCP)
if err != nil {
return E.New("missing supported outbound, closing packet connection")
}
}
if tag, loaded := outbound.TagFromContext(ctx); loaded {
if tag == detour.Tag() {
return E.New("connection loopback in outbound/", detour.Type(), "[", detour.Tag(), "]")
}
}
if !common.Contains(detour.Network(), N.NetworkTCP) {
return E.New("missing supported outbound, closing connection")
return E.New("missing support of network type by outbound, closing packet connection")
}
ctx = outbound.ContextWithTag(ctx, detour.Tag())
if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, matchedRule)
trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, rule)
defer tracker.Leave()
conn = trackerConn
}
@ -936,8 +1114,8 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
if metadata.LastInbound == metadata.InboundDetour {
return E.New("routing loop on detour: ", metadata.InboundDetour)
}
detour := r.inboundByTag[metadata.InboundDetour]
if detour == nil {
detour, ok := r.Inbound(metadata.InboundDetour)
if !ok {
return E.New("inbound detour not found: ", metadata.InboundDetour)
}
injectable, isInjectable := detour.(adapter.InjectableInbound)
@ -1082,15 +1260,27 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
} else if metadata.Destination.IsIPv6() {
metadata.IPVersion = 6
}
ctx, matchedRule, detour, err := r.match(ctx, &metadata, r.defaultOutboundForPacketConnection)
if err != nil {
return err
rule, detour := r.ruleByMetadata(ctx, &metadata)
if rule == nil {
var err error
detour, err = r.DefaultOutbound(N.NetworkUDP)
if err != nil {
return E.New("missing supported outbound, closing packet connection")
}
}
if tag, loaded := outbound.TagFromContext(ctx); loaded {
if tag == detour.Tag() {
return E.New("connection loopback in outbound/", detour.Type(), "[", detour.Tag(), "]")
}
}
if !common.Contains(detour.Network(), N.NetworkUDP) {
return E.New("missing supported outbound, closing packet connection")
return E.New("missing support of network type by outbound, closing packet connection")
}
ctx = outbound.ContextWithTag(ctx, detour.Tag())
if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, matchedRule)
trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, rule)
defer tracker.Leave()
conn = trackerConn
}
@ -1105,26 +1295,9 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
return detour.NewPacketConnection(ctx, conn, metadata)
}
func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (context.Context, adapter.Rule, adapter.Outbound, error) {
matchRule, matchOutbound := r.match0(ctx, metadata, defaultOutbound)
if contextOutbound, loaded := outbound.TagFromContext(ctx); loaded {
if contextOutbound == matchOutbound.Tag() {
return nil, nil, nil, E.New("connection loopback in outbound/", matchOutbound.Type(), "[", matchOutbound.Tag(), "]")
}
}
ctx = outbound.ContextWithTag(ctx, matchOutbound.Tag())
return ctx, matchRule, matchOutbound, nil
}
func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
func (r *Router) processInfoByMetadata(ctx context.Context, metadata *adapter.InboundContext) *process.Info {
if r.processSearcher != nil {
var originDestination netip.AddrPort
if metadata.OriginDestination.IsValid() {
originDestination = metadata.OriginDestination.AddrPort()
} else if metadata.Destination.IsIP() {
originDestination = metadata.Destination.AddrPort()
}
processInfo, err := process.FindProcessInfo(r.processSearcher, ctx, metadata.Network, metadata.Source.AddrPort(), originDestination)
processInfo, err := process.FindProcessInfo(r.processSearcher, ctx, metadata.Network, metadata.Source.AddrPort())
if err != nil {
r.logger.InfoContext(ctx, "failed to search process: ", err)
} else {
@ -1145,21 +1318,26 @@ func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, d
r.logger.InfoContext(ctx, "found user id: ", processInfo.UserId)
}
}
metadata.ProcessInfo = processInfo
return processInfo
}
}
return nil
}
func (r *Router) ruleByMetadata(ctx context.Context, metadata *adapter.InboundContext) (adapter.Rule, adapter.Outbound) {
metadata.ProcessInfo = r.processInfoByMetadata(ctx, metadata)
for i, rule := range r.rules {
metadata.ResetRuleCache()
if rule.Match(metadata) {
detour := rule.Outbound()
r.logger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour)
r.logger.DebugContext(ctx, "rule[", i, "] ", rule.String(), " => ", detour)
if outbound, loaded := r.Outbound(detour); loaded {
return rule, outbound
}
r.logger.ErrorContext(ctx, "outbound not found: ", detour)
r.logger.ErrorContext(ctx, "not found outbound[", detour, "]")
}
}
return nil, defaultOutbound
return nil, nil
}
func (r *Router) InterfaceFinder() control.InterfaceFinder {
@ -1306,8 +1484,8 @@ func (r *Router) notifyNetworkUpdate(event int) {
func (r *Router) ResetNetwork() error {
conntrack.Close()
for _, outbound := range r.outbounds {
listener, isListener := outbound.(adapter.InterfaceUpdateListener)
for _, out := range r.Outbounds() {
listener, isListener := out.(adapter.InterfaceUpdateListener)
if isListener {
listener.InterfaceUpdated()
}
@ -1316,6 +1494,7 @@ func (r *Router) ResetNetwork() error {
for _, transport := range r.transports {
transport.Reset()
}
return nil
}

View file

@ -0,0 +1,60 @@
package route
import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
E "github.com/sagernet/sing/common/exceptions"
)
type OutboundStarter struct {
outboundByTag map[string]adapter.Outbound
startedTags map[string]struct{}
monitor *taskmonitor.Monitor
}
func (s *OutboundStarter) Start(tag string, pathIncludesTags map[string]struct{}) error {
adapter := s.outboundByTag[tag]
if adapter == nil {
return E.New("dependency[", tag, "] is not found")
}
// The outbound may have been started by another subtree in the previous,
// we don't need to start it again.
if _, ok := s.startedTags[tag]; ok {
return nil
}
// If we detected the repetition of the tags in scope of tree evaluation,
// the circular dependency is found, as it grows from bottom to top.
if _, ok := pathIncludesTags[tag]; ok {
return E.New("circular dependency related with outbound/", adapter.Type(), "[", tag, "]")
}
// This required to be done only if that outbound isn't already started,
// because some dependencies may come to the same root,
// but they aren't circular.
pathIncludesTags[tag] = struct{}{}
// Next, we are recursively starting all dependencies of the current
// outbound and repeating the cycle.
for _, dependencyTag := range adapter.Dependencies() {
if err := s.Start(dependencyTag, pathIncludesTags); err != nil {
return err
}
}
// Anyway, it will be finished soon, nothing will happen if I'll include
// Startable interface typecasting too.
s.monitor.Start("initialize outbound/", adapter.Type(), "[", tag, "]")
defer s.monitor.Finish()
// After the evaluation of entire tree let's begin to start all
// the outbounds!
if startable, isStartable := adapter.(interface{ Start() error }); isStartable {
if err := startable.Start(); err != nil {
return E.Cause(err, "initialize outbound/", adapter.Type(), "[", tag, "]")
}
}
return nil
}