refactor: Modular inbound/outbound manager

This commit is contained in:
世界 2024-11-09 21:16:11 +08:00
parent 0b0f74bc0f
commit 08f9ec63fb
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
69 changed files with 1186 additions and 754 deletions

View file

@ -15,7 +15,7 @@ import (
type ClashServer interface { type ClashServer interface {
Service Service
PreStarter LegacyPreStarter
Mode() string Mode() string
ModeList() []string ModeList() []string
HistoryStorage() *urltest.HistoryStorage HistoryStorage() *urltest.HistoryStorage
@ -25,7 +25,7 @@ type ClashServer interface {
type CacheFile interface { type CacheFile interface {
Service Service
PreStarter LegacyPreStarter
StoreFakeIP() bool StoreFakeIP() bool
FakeIPStorage FakeIPStorage

View file

@ -28,7 +28,15 @@ type UDPInjectableInbound interface {
type InboundRegistry interface { type InboundRegistry interface {
option.InboundOptionsRegistry option.InboundOptionsRegistry
CreateInbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Inbound, error) Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, inboundType string, options any) (Inbound, error)
}
type InboundManager interface {
NewService
Inbounds() []Inbound
Get(tag string) (Inbound, bool)
Remove(tag string) error
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, inboundType string, options any) error
} }
type InboundContext struct { type InboundContext struct {

143
adapter/inbound/manager.go Normal file
View file

@ -0,0 +1,143 @@
package inbound
import (
"context"
"os"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
var _ adapter.InboundManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.InboundRegistry
access sync.Mutex
started bool
stage adapter.StartStage
inbounds []adapter.Inbound
inboundByTag map[string]adapter.Inbound
}
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry) *Manager {
return &Manager{
logger: logger,
registry: registry,
inboundByTag: make(map[string]adapter.Inbound),
}
}
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
for _, inbound := range m.inbounds {
err := adapter.LegacyStart(inbound, stage)
if err != nil {
return E.Cause(err, stage.Action(), " inbound/", inbound.Type(), "[", inbound.Tag(), "]")
}
}
return nil
}
func (m *Manager) Close() error {
m.access.Lock()
if !m.started {
panic("not started")
}
m.started = false
inbounds := m.inbounds
m.inbounds = nil
m.access.Unlock()
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, inbound := range inbounds {
monitor.Start("close inbound/", inbound.Type(), "[", inbound.Tag(), "]")
err = E.Append(err, inbound.Close(), func(err error) error {
return E.Cause(err, "close inbound/", inbound.Type(), "[", inbound.Tag(), "]")
})
monitor.Finish()
}
return nil
}
func (m *Manager) Inbounds() []adapter.Inbound {
m.access.Lock()
defer m.access.Unlock()
return m.inbounds
}
func (m *Manager) Get(tag string) (adapter.Inbound, bool) {
m.access.Lock()
defer m.access.Unlock()
inbound, found := m.inboundByTag[tag]
return inbound, found
}
func (m *Manager) Remove(tag string) error {
m.access.Lock()
inbound, found := m.inboundByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.inboundByTag, tag)
index := common.Index(m.inbounds, func(it adapter.Inbound) bool {
return it == inbound
})
if index == -1 {
panic("invalid inbound index")
}
m.inbounds = append(m.inbounds[:index], m.inbounds[index+1:]...)
started := m.started
m.access.Unlock()
if started {
return inbound.Close()
}
return nil
}
func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error {
inbound, err := m.registry.Create(ctx, router, logger, tag, outboundType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(inbound, stage)
if err != nil {
return E.Cause(err, stage.Action(), " inbound/", inbound.Type(), "[", inbound.Tag(), "]")
}
}
}
if existsInbound, loaded := m.inboundByTag[tag]; loaded {
if m.started {
err = existsInbound.Close()
if err != nil {
return E.Cause(err, "close inbound/", existsInbound.Type(), "[", existsInbound.Tag(), "]")
}
}
existsIndex := common.Index(m.inbounds, func(it adapter.Inbound) bool {
return it == existsInbound
})
if existsIndex == -1 {
panic("invalid inbound index")
}
m.inbounds = append(m.inbounds[:existsIndex], m.inbounds[existsIndex+1:]...)
}
m.inbounds = append(m.inbounds, inbound)
m.inboundByTag[tag] = inbound
return nil
}

View file

@ -15,8 +15,12 @@ type ConstructorFunc[T any] func(ctx context.Context, router adapter.Router, log
func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) { func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) {
registry.register(outboundType, func() any { registry.register(outboundType, func() any {
return new(Options) return new(Options)
}, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options any) (adapter.Inbound, error) { }, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, rawOptions any) (adapter.Inbound, error) {
return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options.(*Options))) var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options))
}) })
} }
@ -30,39 +34,39 @@ type (
type Registry struct { type Registry struct {
access sync.Mutex access sync.Mutex
optionsType map[string]optionsConstructorFunc optionsType map[string]optionsConstructorFunc
constructors map[string]constructorFunc constructor map[string]constructorFunc
} }
func NewRegistry() *Registry { func NewRegistry() *Registry {
return &Registry{ return &Registry{
optionsType: make(map[string]optionsConstructorFunc), optionsType: make(map[string]optionsConstructorFunc),
constructors: make(map[string]constructorFunc), constructor: make(map[string]constructorFunc),
} }
} }
func (r *Registry) CreateOptions(outboundType string) (any, bool) { func (m *Registry) CreateOptions(outboundType string) (any, bool) {
r.access.Lock() m.access.Lock()
defer r.access.Unlock() defer m.access.Unlock()
optionsConstructor, loaded := r.optionsType[outboundType] optionsConstructor, loaded := m.optionsType[outboundType]
if !loaded { if !loaded {
return nil, false return nil, false
} }
return optionsConstructor(), true return optionsConstructor(), true
} }
func (r *Registry) CreateInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Inbound, error) { func (m *Registry) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Inbound, error) {
r.access.Lock() m.access.Lock()
defer r.access.Unlock() defer m.access.Unlock()
constructor, loaded := r.constructors[outboundType] constructor, loaded := m.constructor[outboundType]
if !loaded { if !loaded {
return nil, E.New("outbound type not found: " + outboundType) return nil, E.New("outbound type not found: " + outboundType)
} }
return constructor(ctx, router, logger, tag, options) return constructor(ctx, router, logger, tag, options)
} }
func (r *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) { func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
r.access.Lock() m.access.Lock()
defer r.access.Unlock() defer m.access.Unlock()
r.optionsType[outboundType] = optionsConstructor m.optionsType[outboundType] = optionsConstructor
r.constructors[outboundType] = constructor m.constructor[outboundType] = constructor
} }

41
adapter/lifecycle.go Normal file
View file

@ -0,0 +1,41 @@
package adapter
type StartStage uint8
const (
StartStateInitialize StartStage = iota
StartStateStart
StartStatePostStart
StartStateStarted
)
var ListStartStages = []StartStage{
StartStateInitialize,
StartStateStart,
StartStatePostStart,
StartStateStarted,
}
func (s StartStage) Action() string {
switch s {
case StartStateInitialize:
return "initialize"
case StartStateStart:
return "start"
case StartStatePostStart:
return "post-start"
case StartStateStarted:
return "start-after-started"
default:
panic("unknown stage")
}
}
type NewService interface {
NewStarter
Close() error
}
type NewStarter interface {
Start(stage StartStage) error
}

View file

@ -0,0 +1,33 @@
package adapter
type LegacyPreStarter interface {
PreStart() error
}
type LegacyPostStarter interface {
PostStart() error
}
func LegacyStart(starter any, stage StartStage) error {
switch stage {
case StartStateInitialize:
if preStarter, isPreStarter := starter.(interface {
PreStart() error
}); isPreStarter {
return preStarter.PreStart()
}
case StartStateStart:
if starter, isStarter := starter.(interface {
Start() error
}); isStarter {
return starter.Start()
}
case StartStatePostStart:
if postStarter, isPostStarter := starter.(interface {
PostStart() error
}); isPostStarter {
return postStarter.PostStart()
}
}
return nil
}

View file

@ -22,3 +22,12 @@ type OutboundRegistry interface {
option.OutboundOptionsRegistry option.OutboundOptionsRegistry
CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error) CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error)
} }
type OutboundManager interface {
NewService
Outbounds() []Outbound
Outbound(tag string) (Outbound, bool)
Default() Outbound
Remove(tag string) error
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) error
}

265
adapter/outbound/manager.go Normal file
View file

@ -0,0 +1,265 @@
package outbound
import (
"context"
"io"
"os"
"strings"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
var _ adapter.OutboundManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.OutboundRegistry
defaultTag string
access sync.Mutex
started bool
stage adapter.StartStage
outbounds []adapter.Outbound
outboundByTag map[string]adapter.Outbound
dependByTag map[string][]string
defaultOutbound adapter.Outbound
defaultOutboundFallback adapter.Outbound
}
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, defaultTag string) *Manager {
return &Manager{
logger: logger,
registry: registry,
defaultTag: defaultTag,
outboundByTag: make(map[string]adapter.Outbound),
dependByTag: make(map[string][]string),
}
}
func (m *Manager) Initialize(defaultOutboundFallback adapter.Outbound) {
m.defaultOutboundFallback = defaultOutboundFallback
}
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
if stage == adapter.StartStateStart {
m.startOutbounds()
} else {
for _, outbound := range m.outbounds {
err := adapter.LegacyStart(outbound, stage)
if err != nil {
return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
}
}
}
return nil
}
func (m *Manager) startOutbounds() error {
monitor := taskmonitor.New(m.logger, C.StartTimeout)
started := make(map[string]bool)
for {
canContinue := false
startOne:
for _, outboundToStart := range m.outbounds {
outboundTag := outboundToStart.Tag()
if started[outboundTag] {
continue
}
dependencies := outboundToStart.Dependencies()
for _, dependency := range dependencies {
if !started[dependency] {
continue startOne
}
}
started[outboundTag] = true
canContinue = true
if starter, isStarter := outboundToStart.(interface {
Start() error
}); isStarter {
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
err := starter.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
}
}
}
if len(started) == len(m.outbounds) {
break
}
if canContinue {
continue
}
currentOutbound := common.Find(m.outbounds, func(it adapter.Outbound) bool {
return !started[it.Tag()]
})
var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
return !started[it]
})
if common.Contains(oTree, problemOutboundTag) {
return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
}
problemOutbound := m.outboundByTag[problemOutboundTag]
if problemOutbound == nil {
return E.New("dependency[", problemOutboundTag, "] not found for outbound[", oCurrent.Tag(), "]")
}
return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
}
return lintOutbound([]string{currentOutbound.Tag()}, currentOutbound)
}
return nil
}
func (m *Manager) Close() error {
monitor := taskmonitor.New(m.logger, C.StopTimeout)
m.access.Lock()
if !m.started {
panic("not started")
}
m.started = false
outbounds := m.outbounds
m.outbounds = nil
m.access.Unlock()
var err error
for _, outbound := range outbounds {
if closer, isCloser := outbound.(io.Closer); isCloser {
monitor.Start("close outbound/", outbound.Type(), "[", outbound.Tag(), "]")
err = E.Append(err, closer.Close(), func(err error) error {
return E.Cause(err, "close outbound/", outbound.Type(), "[", outbound.Tag(), "]")
})
monitor.Finish()
}
}
return nil
}
func (m *Manager) Outbounds() []adapter.Outbound {
m.access.Lock()
defer m.access.Unlock()
return m.outbounds
}
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
m.access.Lock()
defer m.access.Unlock()
outbound, found := m.outboundByTag[tag]
return outbound, found
}
func (m *Manager) Default() adapter.Outbound {
m.access.Lock()
defer m.access.Unlock()
if m.defaultOutbound != nil {
return m.defaultOutbound
} else {
return m.defaultOutboundFallback
}
}
func (m *Manager) Remove(tag string) error {
m.access.Lock()
outbound, found := m.outboundByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.outboundByTag, tag)
index := common.Index(m.outbounds, func(it adapter.Outbound) bool {
return it == outbound
})
if index == -1 {
panic("invalid inbound index")
}
m.outbounds = append(m.outbounds[:index], m.outbounds[index+1:]...)
started := m.started
if m.defaultOutbound == outbound {
if len(m.outbounds) > 0 {
m.defaultOutbound = m.outbounds[0]
m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag())
} else {
m.defaultOutbound = nil
}
}
dependBy := m.dependByTag[tag]
if len(dependBy) > 0 {
return E.New("outbound[", tag, "] is depended by ", strings.Join(dependBy, ", "))
}
dependencies := outbound.Dependencies()
for _, dependency := range dependencies {
if len(m.dependByTag[dependency]) == 1 {
delete(m.dependByTag, dependency)
} else {
m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
return it != tag
})
}
}
m.access.Unlock()
if started {
return common.Close(outbound)
}
return nil
}
func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, inboundType string, options any) error {
if tag == "" {
return os.ErrInvalid
}
outbound, err := m.registry.CreateOutbound(ctx, router, logger, tag, inboundType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(outbound, stage)
if err != nil {
return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
}
}
}
if existsOutbound, loaded := m.outboundByTag[tag]; loaded {
if m.started {
err = common.Close(existsOutbound)
if err != nil {
return E.Cause(err, "close outbound/", existsOutbound.Type(), "[", existsOutbound.Tag(), "]")
}
}
existsIndex := common.Index(m.outbounds, func(it adapter.Outbound) bool {
return it == existsOutbound
})
if existsIndex == -1 {
panic("invalid inbound index")
}
m.outbounds = append(m.outbounds[:existsIndex], m.outbounds[existsIndex+1:]...)
}
m.outbounds = append(m.outbounds, outbound)
m.outboundByTag[tag] = outbound
dependencies := outbound.Dependencies()
for _, dependency := range dependencies {
m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
}
if tag == m.defaultTag || (m.defaultTag == "" && m.defaultOutbound == nil) {
m.defaultOutbound = outbound
if m.started {
m.logger.Info("updated default outbound to ", outbound.Tag())
}
}
return nil
}

View file

@ -15,8 +15,12 @@ type ConstructorFunc[T any] func(ctx context.Context, router adapter.Router, log
func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) { func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) {
registry.register(outboundType, func() any { registry.register(outboundType, func() any {
return new(Options) return new(Options)
}, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options any) (adapter.Outbound, error) { }, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, rawOptions any) (adapter.Outbound, error) {
return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options.(*Options))) var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options))
}) })
} }

View file

@ -1,9 +1 @@
package adapter package adapter
type PreStarter interface {
PreStart() error
}
type PostStarter interface {
PostStart() error
}

View file

@ -15,21 +15,13 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
mdns "github.com/miekg/dns" mdns "github.com/miekg/dns"
"go4.org/netipx" "go4.org/netipx"
) )
type Router interface { type Router interface {
Service NewService
PreStarter
PostStarter
Cleanup() error
Outbounds() []Outbound
Outbound(tag string) (Outbound, bool)
DefaultOutbound(network string) (Outbound, error)
FakeIPStore() FakeIPStore FakeIPStore() FakeIPStore
@ -84,14 +76,6 @@ type ConnectionRouterEx interface {
RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata InboundContext, onClose N.CloseHandlerFunc) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata InboundContext, onClose N.CloseHandlerFunc)
} }
func ContextWithRouter(ctx context.Context, router Router) context.Context {
return service.ContextWith(ctx, router)
}
func RouterFromContext(ctx context.Context) Router {
return service.FromContext[Router](ctx)
}
type RuleSet interface { type RuleSet interface {
Name() string Name() string
StartContext(ctx context.Context, startContext *HTTPStartContext) error StartContext(ctx context.Context, startContext *HTTPStartContext) error

134
box.go
View file

@ -9,6 +9,8 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/taskmonitor" "github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental" "github.com/sagernet/sing-box/experimental"
@ -30,8 +32,8 @@ var _ adapter.Service = (*Box)(nil)
type Box struct { type Box struct {
createdAt time.Time createdAt time.Time
router adapter.Router router adapter.Router
inbounds []adapter.Inbound inbound *inbound.Manager
outbounds []adapter.Outbound outbound *outbound.Manager
logFactory log.Factory logFactory log.Factory
logger log.ContextLogger logger log.ContextLogger
preServices1 map[string]adapter.Service preServices1 map[string]adapter.Service
@ -66,6 +68,7 @@ func New(options Options) (*Box, error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
ctx = service.ContextWithDefaultRegistry(ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
if inboundRegistry == nil { if inboundRegistry == nil {
return nil, E.New("missing inbound registry in context") return nil, E.New("missing inbound registry in context")
@ -74,7 +77,6 @@ func New(options Options) (*Box, error) {
if outboundRegistry == nil { if outboundRegistry == nil {
return nil, E.New("missing outbound registry in context") return nil, E.New("missing outbound registry in context")
} }
ctx = service.ContextWithDefaultRegistry(ctx)
ctx = pause.WithDefaultManager(ctx) ctx = pause.WithDefaultManager(ctx)
experimentalOptions := common.PtrValueOrDefault(options.Experimental) experimentalOptions := common.PtrValueOrDefault(options.Experimental)
applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug)) applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug))
@ -106,10 +108,15 @@ func New(options Options) (*Box, error) {
if err != nil { if err != nil {
return nil, E.Cause(err, "create log factory") return nil, E.Cause(err, "create log factory")
} }
routeOptions := common.PtrValueOrDefault(options.Route)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound-manager"), inboundRegistry)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound-manager"), outboundRegistry, routeOptions.Final)
ctx = service.ContextWith[adapter.InboundManager](ctx, inboundManager)
ctx = service.ContextWith[adapter.OutboundManager](ctx, outboundManager)
router, err := route.NewRouter( router, err := route.NewRouter(
ctx, ctx,
logFactory, logFactory,
common.PtrValueOrDefault(options.Route), routeOptions,
common.PtrValueOrDefault(options.DNS), common.PtrValueOrDefault(options.DNS),
common.PtrValueOrDefault(options.NTP), common.PtrValueOrDefault(options.NTP),
options.Inbounds, options.Inbounds,
@ -118,15 +125,13 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "parse route options") return nil, E.Cause(err, "parse route options")
} }
for i, inboundOptions := range options.Inbounds { for i, inboundOptions := range options.Inbounds {
var currentInbound adapter.Inbound
var tag string var tag string
if inboundOptions.Tag != "" { if inboundOptions.Tag != "" {
tag = inboundOptions.Tag tag = inboundOptions.Tag
} else { } else {
tag = F.ToString(i) tag = F.ToString(i)
} }
currentInbound, err = inboundRegistry.CreateInbound( err = inboundManager.Create(ctx,
ctx,
router, router,
logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")), logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")),
tag, tag,
@ -134,12 +139,10 @@ func New(options Options) (*Box, error) {
inboundOptions.Options, inboundOptions.Options,
) )
if err != nil { if err != nil {
return nil, E.Cause(err, "parse inbound[", i, "]") return nil, E.Cause(err, "initialize inbound[", i, "]")
} }
inbounds = append(inbounds, currentInbound)
} }
for i, outboundOptions := range options.Outbounds { for i, outboundOptions := range options.Outbounds {
var currentOutbound adapter.Outbound
var tag string var tag string
if outboundOptions.Tag != "" { if outboundOptions.Tag != "" {
tag = outboundOptions.Tag tag = outboundOptions.Tag
@ -153,7 +156,7 @@ func New(options Options) (*Box, error) {
Outbound: tag, Outbound: tag,
}) })
} }
currentOutbound, err = outboundRegistry.CreateOutbound( err = outboundManager.Create(
outboundCtx, outboundCtx,
router, router,
logFactory.NewLogger(F.ToString("outbound/", outboundOptions.Type, "[", tag, "]")), logFactory.NewLogger(F.ToString("outbound/", outboundOptions.Type, "[", tag, "]")),
@ -162,16 +165,18 @@ func New(options Options) (*Box, error) {
outboundOptions.Options, outboundOptions.Options,
) )
if err != nil { if err != nil {
return nil, E.Cause(err, "parse outbound[", i, "]") return nil, E.Cause(err, "initialize outbound[", i, "]")
} }
outbounds = append(outbounds, currentOutbound)
} }
err = router.Initialize(inbounds, outbounds, func() adapter.Outbound { outboundManager.Initialize(common.Must1(
defaultOutbound, cErr := direct.NewOutbound(ctx, router, logFactory.NewLogger("outbound/direct"), "direct", option.DirectOutboundOptions{}) direct.NewOutbound(
common.Must(cErr) ctx,
outbounds = append(outbounds, defaultOutbound) router,
return defaultOutbound logFactory.NewLogger("outbound/direct"),
}) "direct",
option.DirectOutboundOptions{},
),
))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -195,7 +200,7 @@ func New(options Options) (*Box, error) {
if needClashAPI { if needClashAPI {
clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI) clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI)
clashAPIOptions.ModeList = experimental.CalculateClashModeList(options.Options) clashAPIOptions.ModeList = experimental.CalculateClashModeList(options.Options)
clashServer, err := experimental.NewClashServer(ctx, router, logFactory.(log.ObservableFactory), clashAPIOptions) clashServer, err := experimental.NewClashServer(ctx, logFactory.(log.ObservableFactory), clashAPIOptions)
if err != nil { if err != nil {
return nil, E.Cause(err, "create clash api server") return nil, E.Cause(err, "create clash api server")
} }
@ -212,8 +217,8 @@ func New(options Options) (*Box, error) {
} }
return &Box{ return &Box{
router: router, router: router,
inbounds: inbounds, inbound: inboundManager,
outbounds: outbounds, outbound: outboundManager,
createdAt: createdAt, createdAt: createdAt,
logFactory: logFactory, logFactory: logFactory,
logger: logFactory.Logger(), logger: logFactory.Logger(),
@ -271,7 +276,7 @@ func (s *Box) preStart() error {
return E.Cause(err, "start logger") return E.Cause(err, "start logger")
} }
for serviceName, service := range s.preServices1 { for serviceName, service := range s.preServices1 {
if preService, isPreService := service.(adapter.PreStarter); isPreService { if preService, isPreService := service.(adapter.LegacyPreStarter); isPreService {
monitor.Start("pre-start ", serviceName) monitor.Start("pre-start ", serviceName)
err := preService.PreStart() err := preService.PreStart()
monitor.Finish() monitor.Finish()
@ -281,7 +286,7 @@ func (s *Box) preStart() error {
} }
} }
for serviceName, service := range s.preServices2 { for serviceName, service := range s.preServices2 {
if preService, isPreService := service.(adapter.PreStarter); isPreService { if preService, isPreService := service.(adapter.LegacyPreStarter); isPreService {
monitor.Start("pre-start ", serviceName) monitor.Start("pre-start ", serviceName)
err := preService.PreStart() err := preService.PreStart()
monitor.Finish() monitor.Finish()
@ -290,15 +295,15 @@ func (s *Box) preStart() error {
} }
} }
} }
err = s.router.PreStart() err = s.router.Start(adapter.StartStateInitialize)
if err != nil { if err != nil {
return E.Cause(err, "pre-start router") return E.Cause(err, "initialize router")
} }
err = s.startOutbounds() err = s.outbound.Start(adapter.StartStateStart)
if err != nil { if err != nil {
return err return err
} }
return s.router.Start() return s.router.Start(adapter.StartStateStart)
} }
func (s *Box) start() error { func (s *Box) start() error {
@ -318,52 +323,39 @@ func (s *Box) start() error {
return E.Cause(err, "start ", serviceName) return E.Cause(err, "start ", serviceName)
} }
} }
for i, in := range s.inbounds { err = s.inbound.Start(adapter.StartStateStart)
var tag string
if in.Tag() == "" {
tag = F.ToString(i)
} else {
tag = in.Tag()
}
err = in.Start()
if err != nil {
return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]")
}
}
err = s.postStart()
if err != nil { if err != nil {
return err return err
} }
return s.router.Cleanup()
}
func (s *Box) postStart() error {
for serviceName, service := range s.postServices { for serviceName, service := range s.postServices {
err := service.Start() err := service.Start()
if err != nil { if err != nil {
return E.Cause(err, "start ", serviceName) return E.Cause(err, "start ", serviceName)
} }
} }
// TODO: reorganize ALL start order err = s.outbound.Start(adapter.StartStatePostStart)
for _, out := range s.outbounds {
if lateOutbound, isLateOutbound := out.(adapter.PostStarter); isLateOutbound {
err := lateOutbound.PostStart()
if err != nil {
return E.Cause(err, "post-start outbound/", out.Tag())
}
}
}
err := s.router.PostStart()
if err != nil { if err != nil {
return err return err
} }
for _, in := range s.inbounds { err = s.router.Start(adapter.StartStatePostStart)
if lateInbound, isLateInbound := in.(adapter.PostStarter); isLateInbound {
err = lateInbound.PostStart()
if err != nil { if err != nil {
return E.Cause(err, "post-start inbound/", in.Tag()) return err
} }
err = s.inbound.Start(adapter.StartStatePostStart)
if err != nil {
return err
} }
err = s.router.Start(adapter.StartStateStarted)
if err != nil {
return err
}
err = s.outbound.Start(adapter.StartStateStarted)
if err != nil {
return err
}
err = s.inbound.Start(adapter.StartStateStarted)
if err != nil {
return err
} }
return nil return nil
} }
@ -384,20 +376,8 @@ func (s *Box) Close() error {
}) })
monitor.Finish() monitor.Finish()
} }
for i, in := range s.inbounds { errors = E.Errors(errors, s.inbound.Close())
monitor.Start("close inbound/", in.Type(), "[", i, "]") errors = E.Errors(errors, s.outbound.Close())
errors = E.Append(errors, in.Close(), func(err error) error {
return E.Cause(err, "close inbound/", in.Type(), "[", i, "]")
})
monitor.Finish()
}
for i, out := range s.outbounds {
monitor.Start("close outbound/", out.Type(), "[", i, "]")
errors = E.Append(errors, common.Close(out), func(err error) error {
return E.Cause(err, "close outbound/", out.Type(), "[", i, "]")
})
monitor.Finish()
}
monitor.Start("close router") monitor.Start("close router")
if err := common.Close(s.router); err != nil { if err := common.Close(s.router); err != nil {
errors = E.Append(errors, err, func(err error) error { errors = E.Append(errors, err, func(err error) error {
@ -427,6 +407,14 @@ func (s *Box) Close() error {
return errors return errors
} }
func (s *Box) Inbound() adapter.InboundManager {
return s.inbound
}
func (s *Box) Outbound() adapter.OutboundManager {
return s.outbound
}
func (s *Box) Router() adapter.Router { func (s *Box) Router() adapter.Router {
return s.router return s.router
} }

View file

@ -1,85 +0,0 @@
package box
import (
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
func (s *Box) startOutbounds() error {
monitor := taskmonitor.New(s.logger, C.StartTimeout)
outboundTags := make(map[adapter.Outbound]string)
outbounds := make(map[string]adapter.Outbound)
for i, outboundToStart := range s.outbounds {
var outboundTag string
if outboundToStart.Tag() == "" {
outboundTag = F.ToString(i)
} else {
outboundTag = outboundToStart.Tag()
}
if _, exists := outbounds[outboundTag]; exists {
return E.New("outbound tag ", outboundTag, " duplicated")
}
outboundTags[outboundToStart] = outboundTag
outbounds[outboundTag] = outboundToStart
}
started := make(map[string]bool)
for {
canContinue := false
startOne:
for _, outboundToStart := range s.outbounds {
outboundTag := outboundTags[outboundToStart]
if started[outboundTag] {
continue
}
dependencies := outboundToStart.Dependencies()
for _, dependency := range dependencies {
if !started[dependency] {
continue startOne
}
}
started[outboundTag] = true
canContinue = true
if starter, isStarter := outboundToStart.(interface {
Start() error
}); isStarter {
monitor.Start("initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]")
err := starter.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]")
}
}
}
if len(started) == len(s.outbounds) {
break
}
if canContinue {
continue
}
currentOutbound := common.Find(s.outbounds, func(it adapter.Outbound) bool {
return !started[outboundTags[it]]
})
var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
return !started[it]
})
if common.Contains(oTree, problemOutboundTag) {
return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
}
problemOutbound := outbounds[problemOutboundTag]
if problemOutbound == nil {
return E.New("dependency[", problemOutboundTag, "] not found for outbound[", outboundTags[oCurrent], "]")
}
return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
}
return lintOutbound([]string{outboundTags[currentOutbound]}, currentOutbound)
}
return nil
}

View file

@ -41,11 +41,11 @@ func createPreStartedClient() (*box.Box, error) {
return instance, nil return instance, nil
} }
func createDialer(instance *box.Box, network string, outboundTag string) (N.Dialer, error) { func createDialer(instance *box.Box, outboundTag string) (N.Dialer, error) {
if outboundTag == "" { if outboundTag == "" {
return instance.Router().DefaultOutbound(N.NetworkName(network)) return instance.Outbound().Default(), nil
} else { } else {
outbound, loaded := instance.Router().Outbound(outboundTag) outbound, loaded := instance.Outbound().Outbound(outboundTag)
if !loaded { if !loaded {
return nil, E.New("outbound not found: ", outboundTag) return nil, E.New("outbound not found: ", outboundTag)
} }

View file

@ -45,7 +45,7 @@ func connect(address string) error {
return err return err
} }
defer instance.Close() defer instance.Close()
dialer, err := createDialer(instance, commandConnectFlagNetwork, commandToolsFlagOutbound) dialer, err := createDialer(instance, commandToolsFlagOutbound)
if err != nil { if err != nil {
return err return err
} }

View file

@ -48,7 +48,7 @@ func fetch(args []string) error {
httpClient = &http.Client{ httpClient = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer, err := createDialer(instance, network, commandToolsFlagOutbound) dialer, err := createDialer(instance, commandToolsFlagOutbound)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -16,7 +16,7 @@ import (
) )
func initializeHTTP3Client(instance *box.Box) error { func initializeHTTP3Client(instance *box.Box) error {
dialer, err := createDialer(instance, N.NetworkUDP, commandToolsFlagOutbound) dialer, err := createDialer(instance, commandToolsFlagOutbound)
if err != nil { if err != nil {
return err return err
} }

View file

@ -9,7 +9,6 @@ import (
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/ntp"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -45,7 +44,7 @@ func syncTime() error {
if err != nil { if err != nil {
return err return err
} }
dialer, err := createDialer(instance, N.NetworkUDP, commandToolsFlagOutbound) dialer, err := createDialer(instance, commandToolsFlagOutbound)
if err != nil { if err != nil {
return err return err
} }

View file

@ -12,15 +12,15 @@ import (
) )
type DetourDialer struct { type DetourDialer struct {
router adapter.Router outboundManager adapter.OutboundManager
detour string detour string
dialer N.Dialer dialer N.Dialer
initOnce sync.Once initOnce sync.Once
initErr error initErr error
} }
func NewDetour(router adapter.Router, detour string) N.Dialer { func NewDetour(outboundManager adapter.OutboundManager, detour string) N.Dialer {
return &DetourDialer{router: router, detour: detour} return &DetourDialer{outboundManager: outboundManager, detour: detour}
} }
func (d *DetourDialer) Start() error { func (d *DetourDialer) Start() error {
@ -31,7 +31,7 @@ func (d *DetourDialer) Start() error {
func (d *DetourDialer) Dialer() (N.Dialer, error) { func (d *DetourDialer) Dialer() (N.Dialer, error) {
d.initOnce.Do(func() { d.initOnce.Do(func() {
var loaded bool var loaded bool
d.dialer, loaded = d.router.Outbound(d.detour) d.dialer, loaded = d.outboundManager.Outbound(d.detour)
if !loaded { if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour) d.initErr = E.New("outbound detour not found: ", d.detour)
} }

View file

@ -1,21 +1,22 @@
package dialer package dialer
import ( import (
"context"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
) )
func New(router adapter.Router, options option.DialerOptions) (N.Dialer, error) { func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) {
router := service.FromContext[adapter.Router](ctx)
if options.IsWireGuardListener { if options.IsWireGuardListener {
return NewDefault(router, options) return NewDefault(router, options)
} }
if router == nil {
return NewDefault(nil, options)
}
var ( var (
dialer N.Dialer dialer N.Dialer
err error err error
@ -26,7 +27,14 @@ func New(router adapter.Router, options option.DialerOptions) (N.Dialer, error)
return nil, err return nil, err
} }
} else { } else {
dialer = NewDetour(router, options.Detour) outboundManager := service.FromContext[adapter.OutboundManager](ctx)
if outboundManager == nil {
return nil, E.New("missing outbound manager")
}
dialer = NewDetour(outboundManager, options.Detour)
}
if router == nil {
return NewDefault(router, options)
} }
if options.Detour == "" { if options.Detour == "" {
dialer = NewResolveDialer( dialer = NewResolveDialer(

View file

@ -9,30 +9,22 @@ import (
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
type RouterDialer struct { type DefaultOutboundDialer struct {
router adapter.Router outboundManager adapter.OutboundManager
} }
func NewRouter(router adapter.Router) N.Dialer { func NewDefaultOutbound(outboundManager adapter.OutboundManager) N.Dialer {
return &RouterDialer{router: router} return &DefaultOutboundDialer{outboundManager: outboundManager}
} }
func (d *RouterDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *DefaultOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
dialer, err := d.router.DefaultOutbound(network) return d.outboundManager.Default().DialContext(ctx, network, destination)
if err != nil {
return nil, err
}
return dialer.DialContext(ctx, network, destination)
} }
func (d *RouterDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *DefaultOutboundDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
dialer, err := d.router.DefaultOutbound(N.NetworkUDP) return d.outboundManager.Default().ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return dialer.ListenPacket(ctx, destination)
} }
func (d *RouterDialer) Upstream() any { func (d *DefaultOutboundDialer) Upstream() any {
return d.router return d.outboundManager.Default()
} }

View file

@ -12,6 +12,7 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/shell" "github.com/sagernet/sing/common/shell"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
) )
type DarwinSystemProxy struct { type DarwinSystemProxy struct {
@ -24,7 +25,7 @@ type DarwinSystemProxy struct {
} }
func NewSystemProxy(ctx context.Context, serverAddr M.Socksaddr, supportSOCKS bool) (*DarwinSystemProxy, error) { func NewSystemProxy(ctx context.Context, serverAddr M.Socksaddr, supportSOCKS bool) (*DarwinSystemProxy, error) {
interfaceMonitor := adapter.RouterFromContext(ctx).InterfaceMonitor() interfaceMonitor := service.FromContext[adapter.Router](ctx).InterfaceMonitor()
if interfaceMonitor == nil { if interfaceMonitor == nil {
return nil, E.New("missing interface monitor") return nil, E.New("missing interface monitor")
} }

View file

@ -19,6 +19,7 @@ import (
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/service"
mDNS "github.com/miekg/dns" mDNS "github.com/miekg/dns"
) )
@ -213,7 +214,7 @@ func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverNam
}, },
}, },
} }
response, err := adapter.RouterFromContext(ctx).Exchange(ctx, message) response, err := service.FromContext[adapter.Router](ctx).Exchange(ctx, message)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/sagernet/reality" "github.com/sagernet/reality"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
@ -102,7 +101,7 @@ func NewRealityServer(ctx context.Context, logger log.Logger, options option.Inb
tlsConfig.ShortIds[shortID] = true tlsConfig.ShortIds[shortID] = true
} }
handshakeDialer, err := dialer.New(adapter.RouterFromContext(ctx), options.Reality.Handshake.DialerOptions) handshakeDialer, err := dialer.New(ctx, options.Reality.Handshake.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -12,7 +12,7 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
) )
type ClashServerConstructor = func(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) type ClashServerConstructor = func(ctx context.Context, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error)
var clashServerConstructor ClashServerConstructor var clashServerConstructor ClashServerConstructor
@ -20,11 +20,11 @@ func RegisterClashServerConstructor(constructor ClashServerConstructor) {
clashServerConstructor = constructor clashServerConstructor = constructor
} }
func NewClashServer(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { func NewClashServer(ctx context.Context, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) {
if clashServerConstructor == nil { if clashServerConstructor == nil {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }
return clashServerConstructor(ctx, router, logFactory, options) return clashServerConstructor(ctx, logFactory, options)
} }
func CalculateClashModeList(options option.Options) []string { func CalculateClashModeList(options option.Options) []string {

View file

@ -23,7 +23,7 @@ func groupRouter(server *Server) http.Handler {
r := chi.NewRouter() r := chi.NewRouter()
r.Get("/", getGroups(server)) r.Get("/", getGroups(server))
r.Route("/{name}", func(r chi.Router) { r.Route("/{name}", func(r chi.Router) {
r.Use(parseProxyName, findProxyByName(server.router)) r.Use(parseProxyName, findProxyByName(server))
r.Get("/", getGroup(server)) r.Get("/", getGroup(server))
r.Get("/delay", getGroupDelay(server)) r.Get("/delay", getGroupDelay(server))
}) })
@ -32,7 +32,7 @@ func groupRouter(server *Server) http.Handler {
func getGroups(server *Server) func(w http.ResponseWriter, r *http.Request) { func getGroups(server *Server) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
groups := common.Map(common.Filter(server.router.Outbounds(), func(it adapter.Outbound) bool { groups := common.Map(common.Filter(server.outboundManager.Outbounds(), func(it adapter.Outbound) bool {
_, isGroup := it.(adapter.OutboundGroup) _, isGroup := it.(adapter.OutboundGroup)
return isGroup return isGroup
}), func(it adapter.Outbound) *badjson.JSONObject { }), func(it adapter.Outbound) *badjson.JSONObject {
@ -86,7 +86,7 @@ func getGroupDelay(server *Server) func(w http.ResponseWriter, r *http.Request)
result, err = urlTestGroup.URLTest(ctx) result, err = urlTestGroup.URLTest(ctx)
} else { } else {
outbounds := common.FilterNotNil(common.Map(outboundGroup.All(), func(it string) adapter.Outbound { outbounds := common.FilterNotNil(common.Map(outboundGroup.All(), func(it string) adapter.Outbound {
itOutbound, _ := server.router.Outbound(it) itOutbound, _ := server.outboundManager.Outbound(it)
return itOutbound return itOutbound
})) }))
b, _ := batch.New(ctx, batch.WithConcurrencyNum[any](10)) b, _ := batch.New(ctx, batch.WithConcurrencyNum[any](10))
@ -100,7 +100,7 @@ func getGroupDelay(server *Server) func(w http.ResponseWriter, r *http.Request)
continue continue
} }
checked[realTag] = true checked[realTag] = true
p, loaded := server.router.Outbound(realTag) p, loaded := server.outboundManager.Outbound(realTag)
if !loaded { if !loaded {
continue continue
} }

View file

@ -23,10 +23,10 @@ import (
func proxyRouter(server *Server, router adapter.Router) http.Handler { func proxyRouter(server *Server, router adapter.Router) http.Handler {
r := chi.NewRouter() r := chi.NewRouter()
r.Get("/", getProxies(server, router)) r.Get("/", getProxies(server))
r.Route("/{name}", func(r chi.Router) { r.Route("/{name}", func(r chi.Router) {
r.Use(parseProxyName, findProxyByName(router)) r.Use(parseProxyName, findProxyByName(server))
r.Get("/", getProxy(server)) r.Get("/", getProxy(server))
r.Get("/delay", getProxyDelay(server)) r.Get("/delay", getProxyDelay(server))
r.Put("/", updateProxy) r.Put("/", updateProxy)
@ -42,11 +42,11 @@ func parseProxyName(next http.Handler) http.Handler {
}) })
} }
func findProxyByName(router adapter.Router) func(next http.Handler) http.Handler { func findProxyByName(server *Server) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
name := r.Context().Value(CtxKeyProxyName).(string) name := r.Context().Value(CtxKeyProxyName).(string)
proxy, exist := router.Outbound(name) proxy, exist := server.outboundManager.Outbound(name)
if !exist { if !exist {
render.Status(r, http.StatusNotFound) render.Status(r, http.StatusNotFound)
render.JSON(w, r, ErrNotFound) render.JSON(w, r, ErrNotFound)
@ -83,10 +83,10 @@ func proxyInfo(server *Server, detour adapter.Outbound) *badjson.JSONObject {
return &info return &info
} }
func getProxies(server *Server, router adapter.Router) func(w http.ResponseWriter, r *http.Request) { func getProxies(server *Server) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var proxyMap badjson.JSONObject var proxyMap badjson.JSONObject
outbounds := common.Filter(router.Outbounds(), func(detour adapter.Outbound) bool { outbounds := common.Filter(server.outboundManager.Outbounds(), func(detour adapter.Outbound) bool {
return detour.Tag() != "" return detour.Tag() != ""
}) })
@ -100,12 +100,7 @@ func getProxies(server *Server, router adapter.Router) func(w http.ResponseWrite
allProxies = append(allProxies, detour.Tag()) allProxies = append(allProxies, detour.Tag())
} }
var defaultTag string defaultTag := server.outboundManager.Default().Tag()
if defaultOutbound, err := router.DefaultOutbound(N.NetworkTCP); err == nil {
defaultTag = defaultOutbound.Tag()
} else {
defaultTag = allProxies[0]
}
sort.SliceStable(allProxies, func(i, j int) bool { sort.SliceStable(allProxies, func(i, j int) bool {
return allProxies[i] == defaultTag return allProxies[i] == defaultTag

View file

@ -42,6 +42,7 @@ var _ adapter.ClashServer = (*Server)(nil)
type Server struct { type Server struct {
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
outboundManager adapter.OutboundManager
logger log.Logger logger log.Logger
httpServer *http.Server httpServer *http.Server
trafficManager *trafficontrol.Manager trafficManager *trafficontrol.Manager
@ -56,12 +57,13 @@ type Server struct {
externalUIDownloadDetour string externalUIDownloadDetour string
} }
func NewServer(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { func NewServer(ctx context.Context, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) {
trafficManager := trafficontrol.NewManager() trafficManager := trafficontrol.NewManager()
chiRouter := chi.NewRouter() chiRouter := chi.NewRouter()
server := &Server{ s := &Server{
ctx: ctx, ctx: ctx,
router: router, router: service.FromContext[adapter.Router](ctx),
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logFactory.NewLogger("clash-api"), logger: logFactory.NewLogger("clash-api"),
httpServer: &http.Server{ httpServer: &http.Server{
Addr: options.ExternalController, Addr: options.ExternalController,
@ -73,18 +75,18 @@ func NewServer(ctx context.Context, router adapter.Router, logFactory log.Observ
externalUIDownloadURL: options.ExternalUIDownloadURL, externalUIDownloadURL: options.ExternalUIDownloadURL,
externalUIDownloadDetour: options.ExternalUIDownloadDetour, externalUIDownloadDetour: options.ExternalUIDownloadDetour,
} }
server.urlTestHistory = service.PtrFromContext[urltest.HistoryStorage](ctx) s.urlTestHistory = service.PtrFromContext[urltest.HistoryStorage](ctx)
if server.urlTestHistory == nil { if s.urlTestHistory == nil {
server.urlTestHistory = urltest.NewHistoryStorage() s.urlTestHistory = urltest.NewHistoryStorage()
} }
defaultMode := "Rule" defaultMode := "Rule"
if options.DefaultMode != "" { if options.DefaultMode != "" {
defaultMode = options.DefaultMode defaultMode = options.DefaultMode
} }
if !common.Contains(server.modeList, defaultMode) { if !common.Contains(s.modeList, defaultMode) {
server.modeList = append([]string{defaultMode}, server.modeList...) s.modeList = append([]string{defaultMode}, s.modeList...)
} }
server.mode = defaultMode s.mode = defaultMode
//goland:noinspection GoDeprecation //goland:noinspection GoDeprecation
//nolint:staticcheck //nolint:staticcheck
if options.StoreMode || options.StoreSelected || options.StoreFakeIP || options.CacheFile != "" || options.CacheID != "" { if options.StoreMode || options.StoreSelected || options.StoreFakeIP || options.CacheFile != "" || options.CacheID != "" {
@ -108,30 +110,30 @@ func NewServer(ctx context.Context, router adapter.Router, logFactory log.Observ
r.Get("/logs", getLogs(logFactory)) r.Get("/logs", getLogs(logFactory))
r.Get("/traffic", traffic(trafficManager)) r.Get("/traffic", traffic(trafficManager))
r.Get("/version", version) r.Get("/version", version)
r.Mount("/configs", configRouter(server, logFactory)) r.Mount("/configs", configRouter(s, logFactory))
r.Mount("/proxies", proxyRouter(server, router)) r.Mount("/proxies", proxyRouter(s, s.router))
r.Mount("/rules", ruleRouter(router)) r.Mount("/rules", ruleRouter(s.router))
r.Mount("/connections", connectionRouter(router, trafficManager)) r.Mount("/connections", connectionRouter(s.router, trafficManager))
r.Mount("/providers/proxies", proxyProviderRouter()) r.Mount("/providers/proxies", proxyProviderRouter())
r.Mount("/providers/rules", ruleProviderRouter()) r.Mount("/providers/rules", ruleProviderRouter())
r.Mount("/script", scriptRouter()) r.Mount("/script", scriptRouter())
r.Mount("/profile", profileRouter()) r.Mount("/profile", profileRouter())
r.Mount("/cache", cacheRouter(ctx)) r.Mount("/cache", cacheRouter(ctx))
r.Mount("/dns", dnsRouter(router)) r.Mount("/dns", dnsRouter(s.router))
server.setupMetaAPI(r) s.setupMetaAPI(r)
}) })
if options.ExternalUI != "" { if options.ExternalUI != "" {
server.externalUI = filemanager.BasePath(ctx, os.ExpandEnv(options.ExternalUI)) s.externalUI = filemanager.BasePath(ctx, os.ExpandEnv(options.ExternalUI))
chiRouter.Group(func(r chi.Router) { chiRouter.Group(func(r chi.Router) {
fs := http.StripPrefix("/ui", http.FileServer(http.Dir(server.externalUI))) fs := http.StripPrefix("/ui", http.FileServer(http.Dir(s.externalUI)))
r.Get("/ui", http.RedirectHandler("/ui/", http.StatusTemporaryRedirect).ServeHTTP) r.Get("/ui", http.RedirectHandler("/ui/", http.StatusTemporaryRedirect).ServeHTTP)
r.Get("/ui/*", func(w http.ResponseWriter, r *http.Request) { r.Get("/ui/*", func(w http.ResponseWriter, r *http.Request) {
fs.ServeHTTP(w, r) fs.ServeHTTP(w, r)
}) })
}) })
} }
return server, nil return s, nil
} }
func (s *Server) PreStart() error { func (s *Server) PreStart() error {
@ -235,12 +237,12 @@ func (s *Server) TrafficManager() *trafficontrol.Manager {
} }
func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) { func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) {
tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.router, matchedRule) tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule)
return tracker, tracker return tracker, tracker
} }
func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) { func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) {
tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.router, matchedRule) tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule)
return tracker, tracker return tracker, tracker
} }

View file

@ -15,7 +15,6 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service/filemanager" "github.com/sagernet/sing/service/filemanager"
) )
@ -45,16 +44,13 @@ func (s *Server) downloadExternalUI() error {
s.logger.Info("downloading external ui") s.logger.Info("downloading external ui")
var detour adapter.Outbound var detour adapter.Outbound
if s.externalUIDownloadDetour != "" { if s.externalUIDownloadDetour != "" {
outbound, loaded := s.router.Outbound(s.externalUIDownloadDetour) outbound, loaded := s.outboundManager.Outbound(s.externalUIDownloadDetour)
if !loaded { if !loaded {
return E.New("detour outbound not found: ", s.externalUIDownloadDetour) return E.New("detour outbound not found: ", s.externalUIDownloadDetour)
} }
detour = outbound detour = outbound
} else { } else {
outbound, err := s.router.DefaultOutbound(N.NetworkTCP) outbound := s.outboundManager.Default()
if err != nil {
return err
}
detour = outbound detour = outbound
} }
httpClient := &http.Client{ httpClient := &http.Client{

View file

@ -124,7 +124,7 @@ func (tt *TCPConn) WriterReplaceable() bool {
return true return true
} }
func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, router adapter.Router, rule adapter.Rule) *TCPConn { func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *TCPConn {
id, _ := uuid.NewV4() id, _ := uuid.NewV4()
var ( var (
chain []string chain []string
@ -138,11 +138,11 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundCont
} }
if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction { if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction {
next = routeAction.Outbound next = routeAction.Outbound
} else if defaultOutbound, err := router.DefaultOutbound(N.NetworkTCP); err == nil { } else {
next = defaultOutbound.Tag() next = outboundManager.Default().Tag()
} }
for { for {
detour, loaded := router.Outbound(next) detour, loaded := outboundManager.Outbound(next)
if !loaded { if !loaded {
break break
} }
@ -213,7 +213,7 @@ func (ut *UDPConn) WriterReplaceable() bool {
return true return true
} }
func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, router adapter.Router, rule adapter.Rule) *UDPConn { func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *UDPConn {
id, _ := uuid.NewV4() id, _ := uuid.NewV4()
var ( var (
chain []string chain []string
@ -227,11 +227,11 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.Inbound
} }
if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction { if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction {
next = routeAction.Outbound next = routeAction.Outbound
} else if defaultOutbound, err := router.DefaultOutbound(N.NetworkUDP); err == nil { } else {
next = defaultOutbound.Tag() next = outboundManager.Default().Tag()
} }
for { for {
detour, loaded := router.Outbound(next) detour, loaded := outboundManager.Outbound(next)
if !loaded { if !loaded {
break break
} }

View file

@ -76,21 +76,6 @@ func (i *Inbound) Close() error {
return i.listener.Close() return i.listener.Close()
} }
func (i *Inbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
switch i.overrideOption {
case 1:
metadata.Destination = i.overrideDestination
case 2:
destination := i.overrideDestination
destination.Port = metadata.Destination.Port
metadata.Destination = destination
case 3:
metadata.Destination.Port = i.overrideDestination.Port
}
i.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
return i.router.RouteConnection(ctx, conn, metadata)
}
func (i *Inbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) { func (i *Inbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
var destination M.Socksaddr var destination M.Socksaddr
switch i.overrideOption { switch i.overrideOption {
@ -107,9 +92,21 @@ func (i *Inbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
} }
func (i *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (i *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
switch i.overrideOption {
case 1:
metadata.Destination = i.overrideDestination
case 2:
destination := i.overrideDestination
destination.Port = metadata.Destination.Port
metadata.Destination = destination
case 3:
metadata.Destination.Port = i.overrideDestination.Port
}
i.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) i.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
metadata.Inbound = i.Tag() metadata.Inbound = i.Tag()
metadata.InboundType = i.Type() metadata.InboundType = i.Type()
metadata.InboundDetour = i.listener.ListenOptions().Detour
metadata.InboundOptions = i.listener.ListenOptions().InboundOptions
i.router.RouteConnectionEx(ctx, conn, metadata, onClose) i.router.RouteConnectionEx(ctx, conn, metadata, onClose)
} }
@ -119,6 +116,8 @@ func (i *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
var metadata adapter.InboundContext var metadata adapter.InboundContext
metadata.Inbound = i.Tag() metadata.Inbound = i.Tag()
metadata.InboundType = i.Type() metadata.InboundType = i.Type()
metadata.InboundDetour = i.listener.ListenOptions().Detour
metadata.InboundOptions = i.listener.ListenOptions().InboundOptions
metadata.Source = source metadata.Source = source
metadata.Destination = destination metadata.Destination = destination
metadata.OriginDestination = i.listener.UDPAddr() metadata.OriginDestination = i.listener.UDPAddr()

View file

@ -12,7 +12,7 @@ import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns" dns "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
@ -39,7 +39,7 @@ type Outbound struct {
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (adapter.Outbound, error) {
options.UDPFragmentDefault = true options.UDPFragmentDefault = true
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -26,7 +26,7 @@ var _ adapter.OutboundGroup = (*Selector)(nil)
type Selector struct { type Selector struct {
outbound.Adapter outbound.Adapter
ctx context.Context ctx context.Context
router adapter.Router outboundManager adapter.OutboundManager
logger logger.ContextLogger logger logger.ContextLogger
tags []string tags []string
defaultTag string defaultTag string
@ -40,7 +40,7 @@ func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextL
outbound := &Selector{ outbound := &Selector{
Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds),
ctx: ctx, ctx: ctx,
router: router, outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logger, logger: logger,
tags: options.Outbounds, tags: options.Outbounds,
defaultTag: options.Default, defaultTag: options.Default,
@ -63,7 +63,7 @@ func (s *Selector) Network() []string {
func (s *Selector) Start() error { func (s *Selector) Start() error {
for i, tag := range s.tags { for i, tag := range s.tags {
detour, loaded := s.router.Outbound(tag) detour, loaded := s.outboundManager.Outbound(tag)
if !loaded { if !loaded {
return E.New("outbound ", i, " not found: ", tag) return E.New("outbound ", i, " not found: ", tag)
} }

View file

@ -36,6 +36,7 @@ type URLTest struct {
outbound.Adapter outbound.Adapter
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
outboundManager adapter.OutboundManager
logger log.ContextLogger logger log.ContextLogger
tags []string tags []string
link string link string
@ -51,6 +52,7 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo
Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds),
ctx: ctx, ctx: ctx,
router: router, router: router,
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logger, logger: logger,
tags: options.Outbounds, tags: options.Outbounds,
link: options.URL, link: options.URL,
@ -68,7 +70,7 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo
func (s *URLTest) Start() error { func (s *URLTest) Start() error {
outbounds := make([]adapter.Outbound, 0, len(s.tags)) outbounds := make([]adapter.Outbound, 0, len(s.tags))
for i, tag := range s.tags { for i, tag := range s.tags {
detour, loaded := s.router.Outbound(tag) detour, loaded := s.outboundManager.Outbound(tag)
if !loaded { if !loaded {
return E.New("outbound ", i, " not found: ", tag) return E.New("outbound ", i, " not found: ", tag)
} }
@ -77,6 +79,7 @@ func (s *URLTest) Start() error {
group, err := NewURLTestGroup( group, err := NewURLTestGroup(
s.ctx, s.ctx,
s.router, s.router,
s.outboundManager,
s.logger, s.logger,
outbounds, outbounds,
s.link, s.link,
@ -190,6 +193,7 @@ func (s *URLTest) InterfaceUpdated() {
type URLTestGroup struct { type URLTestGroup struct {
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
outboundManager adapter.OutboundManager
logger log.Logger logger log.Logger
outbounds []adapter.Outbound outbounds []adapter.Outbound
link string link string
@ -214,6 +218,7 @@ type URLTestGroup struct {
func NewURLTestGroup( func NewURLTestGroup(
ctx context.Context, ctx context.Context,
router adapter.Router, router adapter.Router,
outboundManager adapter.OutboundManager,
logger log.Logger, logger log.Logger,
outbounds []adapter.Outbound, outbounds []adapter.Outbound,
link string, link string,
@ -244,6 +249,7 @@ func NewURLTestGroup(
return &URLTestGroup{ return &URLTestGroup{
ctx: ctx, ctx: ctx,
router: router, router: router,
outboundManager: outboundManager,
logger: logger, logger: logger,
outbounds: outbounds, outbounds: outbounds,
link: link, link: link,
@ -385,7 +391,7 @@ func (g *URLTestGroup) urlTest(ctx context.Context, force bool) (map[string]uint
continue continue
} }
checked[realTag] = true checked[realTag] = true
p, loaded := g.router.Outbound(realTag) p, loaded := g.outboundManager.Outbound(realTag)
if !loaded { if !loaded {
continue continue
} }

View file

@ -73,7 +73,7 @@ func (h *Inbound) Start() error {
func (h *Inbound) Close() error { func (h *Inbound) Close() error {
return common.Close( return common.Close(
&h.listener, h.listener,
h.tlsConfig, h.tlsConfig,
) )
} }
@ -82,8 +82,12 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a
err := h.newConnection(ctx, conn, metadata, onClose) err := h.newConnection(ctx, conn, metadata, onClose)
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
if E.IsClosedOrCanceled(err) {
h.logger.DebugContext(ctx, "connection closed: ", err)
} else {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
} }
}
} }
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
@ -98,6 +102,10 @@ func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata ada
} }
func (h *Inbound) newUserConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (h *Inbound) newUserConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
user, loaded := auth.UserFromContext[string](ctx) user, loaded := auth.UserFromContext[string](ctx)
if !loaded { if !loaded {
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
@ -110,6 +118,10 @@ func (h *Inbound) newUserConnection(ctx context.Context, conn net.Conn, metadata
} }
func (h *Inbound) streamUserPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (h *Inbound) streamUserPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
user, loaded := auth.UserFromContext[string](ctx) user, loaded := auth.UserFromContext[string](ctx)
if !loaded { if !loaded {
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)

View file

@ -30,7 +30,7 @@ type Outbound struct {
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -17,6 +17,7 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
@ -85,7 +86,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
XPlusPassword: options.Obfs, XPlusPassword: options.Obfs,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
UDPTimeout: udpTimeout, UDPTimeout: udpTimeout,
Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil), Handler: inbound,
// Legacy options // Legacy options
@ -117,12 +118,16 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx) ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = h.Tag() metadata.Inbound = h.Tag()
metadata.InboundType = h.Type() metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.OriginDestination = h.listener.UDPAddr()
metadata.Source = source
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
userID, _ := auth.UserFromContext[int](ctx) userID, _ := auth.UserFromContext[int](ctx)
if userName := h.userNameList[userID]; userName != "" { if userName := h.userNameList[userID]; userName != "" {
@ -131,16 +136,19 @@ func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata ada
} else { } else {
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
} }
return h.router.RouteConnection(ctx, conn, metadata) h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx) ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = h.Tag() metadata.Inbound = h.Tag()
metadata.InboundType = h.Type() metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.OriginDestination = h.listener.UDPAddr() metadata.OriginDestination = h.listener.UDPAddr()
metadata.Source = source
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source) h.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source)
userID, _ := auth.UserFromContext[int](ctx) userID, _ := auth.UserFromContext[int](ctx)
if userName := h.userNameList[userID]; userName != "" { if userName := h.userNameList[userID]; userName != "" {
@ -149,7 +157,7 @@ func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, me
} else { } else {
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
} }
return h.router.RoutePacketConnection(ctx, conn, metadata) h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) Start() error { func (h *Inbound) Start() error {
@ -168,7 +176,7 @@ func (h *Inbound) Start() error {
func (h *Inbound) Close() error { func (h *Inbound) Close() error {
return common.Close( return common.Close(
&h.listener, h.listener,
h.tlsConfig, h.tlsConfig,
common.PtrOrNil(h.service), common.PtrOrNil(h.service),
) )

View file

@ -47,7 +47,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
if err != nil { if err != nil {
return nil, err return nil, err
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -20,6 +20,7 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
@ -108,7 +109,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
IgnoreClientBandwidth: options.IgnoreClientBandwidth, IgnoreClientBandwidth: options.IgnoreClientBandwidth,
UDPTimeout: udpTimeout, UDPTimeout: udpTimeout,
Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil), Handler: inbound,
MasqueradeHandler: masqueradeHandler, MasqueradeHandler: masqueradeHandler,
}) })
if err != nil { if err != nil {
@ -128,12 +129,16 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx) ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = h.Tag() metadata.Inbound = h.Tag()
metadata.InboundType = h.Type() metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.OriginDestination = h.listener.UDPAddr()
metadata.Source = source
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
userID, _ := auth.UserFromContext[int](ctx) userID, _ := auth.UserFromContext[int](ctx)
if userName := h.userNameList[userID]; userName != "" { if userName := h.userNameList[userID]; userName != "" {
@ -142,16 +147,19 @@ func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata ada
} else { } else {
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
} }
return h.router.RouteConnection(ctx, conn, metadata) h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx) ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = h.Tag() metadata.Inbound = h.Tag()
metadata.InboundType = h.Type() metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.OriginDestination = h.listener.UDPAddr() metadata.OriginDestination = h.listener.UDPAddr()
metadata.Source = source
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source) h.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source)
userID, _ := auth.UserFromContext[int](ctx) userID, _ := auth.UserFromContext[int](ctx)
if userName := h.userNameList[userID]; userName != "" { if userName := h.userNameList[userID]; userName != "" {
@ -160,7 +168,7 @@ func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, me
} else { } else {
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
} }
return h.router.RoutePacketConnection(ctx, conn, metadata) h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) Start() error { func (h *Inbound) Start() error {
@ -179,7 +187,7 @@ func (h *Inbound) Start() error {
func (h *Inbound) Close() error { func (h *Inbound) Close() error {
return common.Close( return common.Close(
&h.listener, h.listener,
h.tlsConfig, h.tlsConfig,
common.PtrOrNil(h.service), common.PtrOrNil(h.service),
) )

View file

@ -59,7 +59,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, E.New("unknown obfs type: ", options.Obfs.Type) return nil, E.New("unknown obfs type: ", options.Obfs.Type)
} }
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -66,8 +66,12 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a
err := h.newConnection(ctx, conn, metadata, onClose) err := h.newConnection(ctx, conn, metadata, onClose)
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
if E.IsClosedOrCanceled(err) {
h.logger.DebugContext(ctx, "connection closed: ", err)
} else {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
} }
}
} }
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
@ -85,6 +89,10 @@ func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata ada
} }
func (h *Inbound) newUserConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (h *Inbound) newUserConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
user, loaded := auth.UserFromContext[string](ctx) user, loaded := auth.UserFromContext[string](ctx)
if !loaded { if !loaded {
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
@ -97,6 +105,10 @@ func (h *Inbound) newUserConnection(ctx context.Context, conn net.Conn, metadata
} }
func (h *Inbound) streamUserPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (h *Inbound) streamUserPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
user, loaded := auth.UserFromContext[string](ctx) user, loaded := auth.UserFromContext[string](ctx)
if !loaded { if !loaded {
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)

View file

@ -59,6 +59,8 @@ func (h *Redirect) NewConnectionEx(ctx context.Context, conn net.Conn, metadata
} }
metadata.Inbound = h.Tag() metadata.Inbound = h.Tag()
metadata.InboundType = h.Type() metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.Destination = M.SocksaddrFromNetIP(destination) metadata.Destination = M.SocksaddrFromNetIP(destination)
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
h.router.RouteConnectionEx(ctx, conn, metadata, onClose) h.router.RouteConnectionEx(ctx, conn, metadata, onClose)

View file

@ -90,6 +90,8 @@ func (t *TProxy) Close() error {
} }
func (t *TProxy) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (t *TProxy) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
metadata.Inbound = t.Tag()
metadata.InboundType = t.Type()
metadata.Destination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap() metadata.Destination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
t.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) t.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
t.router.RouteConnectionEx(ctx, conn, metadata, onClose) t.router.RouteConnectionEx(ctx, conn, metadata, onClose)
@ -101,6 +103,8 @@ func (t *TProxy) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, s
var metadata adapter.InboundContext var metadata adapter.InboundContext
metadata.Inbound = t.Tag() metadata.Inbound = t.Tag()
metadata.InboundType = t.Type() metadata.InboundType = t.Type()
metadata.InboundDetour = t.listener.ListenOptions().Detour
metadata.InboundOptions = t.listener.ListenOptions().InboundOptions
metadata.Source = source metadata.Source = source
metadata.Destination = destination metadata.Destination = destination
metadata.OriginDestination = t.listener.UDPAddr() metadata.OriginDestination = t.listener.UDPAddr()

View file

@ -105,8 +105,12 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a
err := h.service.NewConnection(ctx, conn, adapter.UpstreamMetadata(metadata)) err := h.service.NewConnection(ctx, conn, adapter.UpstreamMetadata(metadata))
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
if E.IsClosedOrCanceled(err) {
h.logger.DebugContext(ctx, "connection closed: ", err)
} else {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
} }
}
} }
func (h *Inbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) { func (h *Inbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
@ -120,6 +124,8 @@ func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata ada
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
metadata.Inbound = h.Tag() metadata.Inbound = h.Tag()
metadata.InboundType = h.Type() metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
return h.router.RouteConnection(ctx, conn, metadata) return h.router.RouteConnection(ctx, conn, metadata)
} }

View file

@ -113,8 +113,12 @@ func (h *MultiInbound) NewConnectionEx(ctx context.Context, conn net.Conn, metad
err := h.service.NewConnection(ctx, conn, adapter.UpstreamMetadata(metadata)) err := h.service.NewConnection(ctx, conn, adapter.UpstreamMetadata(metadata))
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
if E.IsClosedOrCanceled(err) {
h.logger.DebugContext(ctx, "connection closed: ", err)
} else {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
} }
}
} }
func (h *MultiInbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) { func (h *MultiInbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
@ -138,6 +142,8 @@ func (h *MultiInbound) newConnection(ctx context.Context, conn net.Conn, metadat
h.logger.InfoContext(ctx, "[", user, "] inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "[", user, "] inbound connection to ", metadata.Destination)
metadata.Inbound = h.Tag() metadata.Inbound = h.Tag()
metadata.InboundType = h.Type() metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
return h.router.RouteConnection(ctx, conn, metadata) return h.router.RouteConnection(ctx, conn, metadata)
} }

View file

@ -98,8 +98,12 @@ func (h *RelayInbound) NewConnectionEx(ctx context.Context, conn net.Conn, metad
err := h.service.NewConnection(ctx, conn, adapter.UpstreamMetadata(metadata)) err := h.service.NewConnection(ctx, conn, adapter.UpstreamMetadata(metadata))
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
if E.IsClosedOrCanceled(err) {
h.logger.DebugContext(ctx, "connection closed: ", err)
} else {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
} }
}
} }
func (h *RelayInbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) { func (h *RelayInbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
@ -123,6 +127,8 @@ func (h *RelayInbound) newConnection(ctx context.Context, conn net.Conn, metadat
h.logger.InfoContext(ctx, "[", destination, "] inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "[", destination, "] inbound connection to ", metadata.Destination)
metadata.Inbound = h.Tag() metadata.Inbound = h.Tag()
metadata.InboundType = h.Type() metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
return h.router.RouteConnection(ctx, conn, metadata) return h.router.RouteConnection(ctx, conn, metadata)
} }

View file

@ -44,7 +44,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
if err != nil { if err != nil {
return nil, err return nil, err
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -16,6 +16,7 @@ import (
"github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
@ -46,7 +47,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
if options.Version > 1 { if options.Version > 1 {
handshakeForServerName = make(map[string]shadowtls.HandshakeConfig) handshakeForServerName = make(map[string]shadowtls.HandshakeConfig)
for serverName, serverOptions := range options.HandshakeForServerName { for serverName, serverOptions := range options.HandshakeForServerName {
handshakeDialer, err := dialer.New(router, serverOptions.DialerOptions) handshakeDialer, err := dialer.New(ctx, serverOptions.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -56,7 +57,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
} }
} }
} }
handshakeDialer, err := dialer.New(router, options.Handshake.DialerOptions) handshakeDialer, err := dialer.New(ctx, options.Handshake.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -72,7 +73,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
}, },
HandshakeForServerName: handshakeForServerName, HandshakeForServerName: handshakeForServerName,
StrictMode: options.StrictMode, StrictMode: options.StrictMode,
Handler: adapter.NewUpstreamContextHandler(inbound.newConnection, nil, nil), Handler: (*inboundHandler)(inbound),
Logger: logger, Logger: logger,
}) })
if err != nil { if err != nil {
@ -97,24 +98,33 @@ func (h *Inbound) Close() error {
return h.listener.Close() return h.listener.Close()
} }
func (h *Inbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
return h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata)) err := h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, metadata.Source, metadata.Destination, onClose)
N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil {
if E.IsClosedOrCanceled(err) {
h.logger.DebugContext(ctx, "connection closed: ", err)
} else {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
}
}
} }
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { type inboundHandler Inbound
func (h *inboundHandler) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
var metadata adapter.InboundContext
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.Source = source
metadata.Destination = destination
if userName, _ := auth.UserFromContext[string](ctx); userName != "" { if userName, _ := auth.UserFromContext[string](ctx); userName != "" {
metadata.User = userName metadata.User = userName
h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination)
} else { } else {
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
} }
return h.router.RouteConnection(ctx, conn, metadata) h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
err := h.NewConnection(ctx, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
}
} }

View file

@ -68,7 +68,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
tlsHandshakeFunc = shadowtls.DefaultTLSHandshakeFunc(options.Password, stdTLSConfig) tlsHandshakeFunc = shadowtls.DefaultTLSHandshakeFunc(options.Password, stdTLSConfig)
} }
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -62,11 +62,19 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a
err := socks.HandleConnectionEx(ctx, conn, std_bufio.NewReader(conn), h.authenticator, nil, adapter.NewUpstreamHandlerEx(metadata, h.newUserConnection, h.streamUserPacketConnection), metadata.Source, metadata.Destination, onClose) err := socks.HandleConnectionEx(ctx, conn, std_bufio.NewReader(conn), h.authenticator, nil, adapter.NewUpstreamHandlerEx(metadata, h.newUserConnection, h.streamUserPacketConnection), metadata.Source, metadata.Destination, onClose)
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
if E.IsClosedOrCanceled(err) {
h.logger.DebugContext(ctx, "connection closed: ", err)
} else {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
} }
}
} }
func (h *Inbound) newUserConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (h *Inbound) newUserConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
user, loaded := auth.UserFromContext[string](ctx) user, loaded := auth.UserFromContext[string](ctx)
if !loaded { if !loaded {
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
@ -79,6 +87,10 @@ func (h *Inbound) newUserConnection(ctx context.Context, conn net.Conn, metadata
} }
func (h *Inbound) streamUserPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (h *Inbound) streamUserPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
user, loaded := auth.UserFromContext[string](ctx) user, loaded := auth.UserFromContext[string](ctx)
if !loaded { if !loaded {
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)

View file

@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
if err != nil { if err != nil {
return nil, err return nil, err
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -49,7 +49,7 @@ type Outbound struct {
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -75,7 +75,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
} }
startConf.TorrcFile = torrcFile startConf.TorrcFile = torrcFile
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -149,7 +149,7 @@ func (h *Inbound) Start() error {
func (h *Inbound) Close() error { func (h *Inbound) Close() error {
return common.Close( return common.Close(
&h.listener, h.listener,
h.tlsConfig, h.tlsConfig,
h.transport, h.transport,
) )
@ -170,11 +170,19 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a
err := h.NewConnection(ctx, conn, metadata) err := h.NewConnection(ctx, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
if E.IsClosedOrCanceled(err) {
h.logger.DebugContext(ctx, "connection closed: ", err)
} else {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
} }
}
} }
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
userIndex, loaded := auth.UserFromContext[int](ctx) userIndex, loaded := auth.UserFromContext[int](ctx)
if !loaded { if !loaded {
return os.ErrInvalid return os.ErrInvalid
@ -207,12 +215,20 @@ func (h *Inbound) fallbackConnection(ctx context.Context, conn net.Conn, metadat
} }
fallbackAddr = h.fallbackAddr fallbackAddr = h.fallbackAddr
} }
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
h.logger.InfoContext(ctx, "fallback connection to ", fallbackAddr) h.logger.InfoContext(ctx, "fallback connection to ", fallbackAddr)
metadata.Destination = fallbackAddr metadata.Destination = fallbackAddr
return h.router.RouteConnection(ctx, conn, metadata) return h.router.RouteConnection(ctx, conn, metadata)
} }
func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
userIndex, loaded := auth.UserFromContext[int](ctx) userIndex, loaded := auth.UserFromContext[int](ctx)
if !loaded { if !loaded {
return os.ErrInvalid return os.ErrInvalid
@ -233,10 +249,6 @@ type inboundTransportHandler Inbound
func (h *inboundTransportHandler) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { func (h *inboundTransportHandler) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
var metadata adapter.InboundContext var metadata adapter.InboundContext
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.Source = source metadata.Source = source
metadata.Destination = destination metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)

View file

@ -38,7 +38,7 @@ type Outbound struct {
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrojanOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrojanOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -17,6 +17,7 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
@ -71,7 +72,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
ZeroRTTHandshake: options.ZeroRTTHandshake, ZeroRTTHandshake: options.ZeroRTTHandshake,
Heartbeat: time.Duration(options.Heartbeat), Heartbeat: time.Duration(options.Heartbeat),
UDPTimeout: udpTimeout, UDPTimeout: udpTimeout,
Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil), Handler: inbound,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -99,12 +100,16 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx) ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = h.Tag() metadata.Inbound = h.Tag()
metadata.InboundType = h.Type() metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.OriginDestination = h.listener.UDPAddr()
metadata.Source = source
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
userID, _ := auth.UserFromContext[int](ctx) userID, _ := auth.UserFromContext[int](ctx)
if userName := h.userNameList[userID]; userName != "" { if userName := h.userNameList[userID]; userName != "" {
@ -113,16 +118,19 @@ func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata ada
} else { } else {
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
} }
return h.router.RouteConnection(ctx, conn, metadata) h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx) ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = h.Tag() metadata.Inbound = h.Tag()
metadata.InboundType = h.Type() metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.OriginDestination = h.listener.UDPAddr() metadata.OriginDestination = h.listener.UDPAddr()
metadata.Source = source
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source) h.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source)
userID, _ := auth.UserFromContext[int](ctx) userID, _ := auth.UserFromContext[int](ctx)
if userName := h.userNameList[userID]; userName != "" { if userName := h.userNameList[userID]; userName != "" {
@ -131,7 +139,7 @@ func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, me
} else { } else {
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
} }
return h.router.RoutePacketConnection(ctx, conn, metadata) h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) Start() error { func (h *Inbound) Start() error {
@ -150,7 +158,7 @@ func (h *Inbound) Start() error {
func (h *Inbound) Close() error { func (h *Inbound) Close() error {
return common.Close( return common.Close(
&h.listener, h.listener,
h.tlsConfig, h.tlsConfig,
common.PtrOrNil(h.server), common.PtrOrNil(h.server),
) )

View file

@ -60,7 +60,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
case "quic": case "quic":
tuicUDPStream = true tuicUDPStream = true
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -129,7 +129,7 @@ func (h *Inbound) Start() error {
func (h *Inbound) Close() error { func (h *Inbound) Close() error {
return common.Close( return common.Close(
h.service, h.service,
&h.listener, h.listener,
h.tlsConfig, h.tlsConfig,
h.transport, h.transport,
) )
@ -150,11 +150,19 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a
err := h.NewConnection(ctx, conn, metadata) err := h.NewConnection(ctx, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
if E.IsClosedOrCanceled(err) {
h.logger.DebugContext(ctx, "connection closed: ", err)
} else {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
} }
}
} }
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
userIndex, loaded := auth.UserFromContext[int](ctx) userIndex, loaded := auth.UserFromContext[int](ctx)
if !loaded { if !loaded {
return os.ErrInvalid return os.ErrInvalid
@ -170,6 +178,10 @@ func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata ada
} }
func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
userIndex, loaded := auth.UserFromContext[int](ctx) userIndex, loaded := auth.UserFromContext[int](ctx)
if !loaded { if !loaded {
return os.ErrInvalid return os.ErrInvalid
@ -196,10 +208,6 @@ type inboundTransportHandler Inbound
func (h *inboundTransportHandler) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { func (h *inboundTransportHandler) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
var metadata adapter.InboundContext var metadata adapter.InboundContext
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.Source = source metadata.Source = source
metadata.Destination = destination metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)

View file

@ -41,7 +41,7 @@ type Outbound struct {
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VLESSOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VLESSOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -143,7 +143,7 @@ func (h *Inbound) Start() error {
func (h *Inbound) Close() error { func (h *Inbound) Close() error {
return common.Close( return common.Close(
h.service, h.service,
&h.listener, h.listener,
h.tlsConfig, h.tlsConfig,
h.transport, h.transport,
) )
@ -164,11 +164,19 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a
err := h.NewConnection(ctx, conn, metadata) err := h.NewConnection(ctx, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
if E.IsClosedOrCanceled(err) {
h.logger.DebugContext(ctx, "connection closed: ", err)
} else {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
} }
}
} }
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
userIndex, loaded := auth.UserFromContext[int](ctx) userIndex, loaded := auth.UserFromContext[int](ctx)
if !loaded { if !loaded {
return os.ErrInvalid return os.ErrInvalid
@ -184,6 +192,10 @@ func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata ada
} }
func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (h *Inbound) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
userIndex, loaded := auth.UserFromContext[int](ctx) userIndex, loaded := auth.UserFromContext[int](ctx)
if !loaded { if !loaded {
return os.ErrInvalid return os.ErrInvalid
@ -210,10 +222,6 @@ type inboundTransportHandler Inbound
func (h *inboundTransportHandler) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { func (h *inboundTransportHandler) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
var metadata adapter.InboundContext var metadata adapter.InboundContext
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.Source = source metadata.Source = source
metadata.Destination = destination metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)

View file

@ -41,7 +41,7 @@ type Outbound struct {
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VMessOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VMessOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -78,7 +78,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
options.IsWireGuardListener = true options.IsWireGuardListener = true
outbound.useStdNetBind = true outbound.useStdNetBind = true
} }
listener, err := dialer.New(router, options.DialerOptions) listener, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -77,7 +77,6 @@ func exchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, b
return err return err
} }
err = conn.WritePacket(responseBuffer, destination) err = conn.WritePacket(responseBuffer, destination)
responseBuffer.Release()
return err return err
} }

View file

@ -87,7 +87,7 @@ type Router struct {
v2rayServer adapter.V2RayServer v2rayServer adapter.V2RayServer
platformInterface platform.Interface platformInterface platform.Interface
needWIFIState bool needWIFIState bool
needPackageManager bool enforcePackageManager bool
wifiState adapter.WIFIState wifiState adapter.WIFIState
started bool started bool
} }
@ -123,7 +123,7 @@ func NewRouter(
pauseManager: service.FromContext[pause.Manager](ctx), pauseManager: service.FromContext[pause.Manager](ctx),
platformInterface: service.FromContext[platform.Interface](ctx), platformInterface: service.FromContext[platform.Interface](ctx),
needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule), needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule),
needPackageManager: common.Any(inbounds, func(inbound option.Inbound) bool { enforcePackageManager: common.Any(inbounds, func(inbound option.Inbound) bool {
if tunOptions, isTUN := inbound.Options.(*option.TunInboundOptions); isTUN && tunOptions.AutoRoute { if tunOptions, isTUN := inbound.Options.(*option.TunInboundOptions); isTUN && tunOptions.AutoRoute {
return true return true
} }
@ -191,7 +191,8 @@ func NewRouter(
transportTags[i] = tag transportTags[i] = tag
transportTagMap[tag] = true transportTagMap[tag] = true
} }
ctx = adapter.ContextWithRouter(ctx, router) ctx = service.ContextWith[adapter.Router](ctx, router)
outboundManager := service.FromContext[adapter.OutboundManager](ctx)
for { for {
lastLen := len(dummyTransportMap) lastLen := len(dummyTransportMap)
for i, server := range dnsOptions.Servers { for i, server := range dnsOptions.Servers {
@ -201,9 +202,9 @@ func NewRouter(
} }
var detour N.Dialer var detour N.Dialer
if server.Detour == "" { if server.Detour == "" {
detour = dialer.NewRouter(router) detour = dialer.NewDefaultOutbound(outboundManager)
} else { } else {
detour = dialer.NewDetour(router, server.Detour) detour = dialer.NewDetour(outboundManager, server.Detour)
} }
var serverProtocol string var serverProtocol string
switch server.Address { switch server.Address {
@ -327,7 +328,7 @@ func NewRouter(
} }
usePlatformDefaultInterfaceMonitor := router.platformInterface != nil && router.platformInterface.UsePlatformDefaultInterfaceMonitor() usePlatformDefaultInterfaceMonitor := router.platformInterface != nil && router.platformInterface.UsePlatformDefaultInterfaceMonitor()
needInterfaceMonitor := options.AutoDetectInterface || common.Any(inbounds, func(inbound option.Inbound) bool { enforceInterfaceMonitor := options.AutoDetectInterface || common.Any(inbounds, func(inbound option.Inbound) bool {
if httpMixedOptions, isHTTPMixed := inbound.Options.(*option.HTTPMixedInboundOptions); isHTTPMixed && httpMixedOptions.SetSystemProxy { if httpMixedOptions, isHTTPMixed := inbound.Options.(*option.HTTPMixedInboundOptions); isHTTPMixed && httpMixedOptions.SetSystemProxy {
return true return true
} }
@ -339,7 +340,7 @@ func NewRouter(
if !usePlatformDefaultInterfaceMonitor { if !usePlatformDefaultInterfaceMonitor {
networkMonitor, err := tun.NewNetworkUpdateMonitor(router.logger) networkMonitor, err := tun.NewNetworkUpdateMonitor(router.logger)
if !((err != nil && !needInterfaceMonitor) || errors.Is(err, os.ErrInvalid)) { if !((err != nil && !enforceInterfaceMonitor) || errors.Is(err, os.ErrInvalid)) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -365,7 +366,7 @@ func NewRouter(
} }
if ntpOptions.Enabled { if ntpOptions.Enabled {
ntpDialer, err := dialer.New(router, ntpOptions.DialerOptions) ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions)
if err != nil { if err != nil {
return nil, E.Cause(err, "create NTP service") return nil, E.Cause(err, "create NTP service")
} }
@ -383,73 +384,6 @@ func NewRouter(
return router, nil return router, nil
} }
func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error {
inboundByTag := make(map[string]adapter.Inbound)
for _, inbound := range inbounds {
inboundByTag[inbound.Tag()] = inbound
}
outboundByTag := make(map[string]adapter.Outbound)
for _, detour := range outbounds {
outboundByTag[detour.Tag()] = detour
}
var defaultOutboundForConnection adapter.Outbound
var defaultOutboundForPacketConnection adapter.Outbound
if r.defaultDetour != "" {
detour, loaded := outboundByTag[r.defaultDetour]
if !loaded {
return E.New("default detour not found: ", r.defaultDetour)
}
if common.Contains(detour.Network(), N.NetworkTCP) {
defaultOutboundForConnection = detour
}
if common.Contains(detour.Network(), N.NetworkUDP) {
defaultOutboundForPacketConnection = detour
}
}
if defaultOutboundForConnection == nil {
for _, detour := range outbounds {
if common.Contains(detour.Network(), N.NetworkTCP) {
defaultOutboundForConnection = detour
break
}
}
}
if defaultOutboundForPacketConnection == nil {
for _, detour := range outbounds {
if common.Contains(detour.Network(), N.NetworkUDP) {
defaultOutboundForPacketConnection = detour
break
}
}
}
if defaultOutboundForConnection == nil || defaultOutboundForPacketConnection == nil {
detour := defaultOutbound()
if defaultOutboundForConnection == nil {
defaultOutboundForConnection = detour
}
if defaultOutboundForPacketConnection == nil {
defaultOutboundForPacketConnection = detour
}
outbounds = append(outbounds, detour)
outboundByTag[detour.Tag()] = detour
}
r.inboundByTag = inboundByTag
r.outbounds = outbounds
r.defaultOutboundForConnection = defaultOutboundForConnection
r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection
r.outboundByTag = outboundByTag
for i, rule := range r.rules {
routeAction, isRoute := rule.Action().(*R.RuleActionRoute)
if !isRoute {
continue
}
if _, loaded := outboundByTag[routeAction.Outbound]; !loaded {
return E.New("outbound not found for rule[", i, "]: ", routeAction.Outbound)
}
}
return nil
}
func (r *Router) Outbounds() []adapter.Outbound { func (r *Router) Outbounds() []adapter.Outbound {
if !r.started { if !r.started {
return nil return nil
@ -457,8 +391,10 @@ func (r *Router) Outbounds() []adapter.Outbound {
return r.outbounds return r.outbounds
} }
func (r *Router) PreStart() error { func (r *Router) Start(stage adapter.StartStage) error {
monitor := taskmonitor.New(r.logger, C.StartTimeout) monitor := taskmonitor.New(r.logger, C.StartTimeout)
switch stage {
case adapter.StartStateInitialize:
if r.interfaceMonitor != nil { if r.interfaceMonitor != nil {
monitor.Start("initialize interface monitor") monitor.Start("initialize interface monitor")
err := r.interfaceMonitor.Start() err := r.interfaceMonitor.Start()
@ -483,11 +419,7 @@ func (r *Router) PreStart() error {
return err return err
} }
} }
return nil case adapter.StartStateStart:
}
func (r *Router) Start() error {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
if r.needGeoIPDatabase { if r.needGeoIPDatabase {
monitor.Start("initialize geoip database") monitor.Start("initialize geoip database")
err := r.prepareGeoIPDatabase() err := r.prepareGeoIPDatabase()
@ -557,7 +489,7 @@ func (r *Router) Start() error {
if err != nil { if err != nil {
return E.Cause(err, "create package manager") return E.Cause(err, "create package manager")
} }
if r.needPackageManager { if r.enforcePackageManager {
monitor.Start("start package manager") monitor.Start("start package manager")
err = packageManager.Start() err = packageManager.Start()
monitor.Finish() monitor.Finish()
@ -592,6 +524,108 @@ func (r *Router) Start() error {
return E.Cause(err, "initialize time service") return E.Cause(err, "initialize time service")
} }
} }
case adapter.StartStatePostStart:
var cacheContext *adapter.HTTPStartContext
if len(r.ruleSets) > 0 {
monitor.Start("initialize rule-set")
cacheContext = adapter.NewHTTPStartContext()
var ruleSetStartGroup task.Group
for i, ruleSet := range r.ruleSets {
ruleSetInPlace := ruleSet
ruleSetStartGroup.Append0(func(ctx context.Context) error {
err := ruleSetInPlace.StartContext(ctx, cacheContext)
if err != nil {
return E.Cause(err, "initialize rule-set[", i, "]")
}
return nil
})
}
ruleSetStartGroup.Concurrency(5)
ruleSetStartGroup.FastFail()
err := ruleSetStartGroup.Run(r.ctx)
monitor.Finish()
if err != nil {
return err
}
}
if cacheContext != nil {
cacheContext.Close()
}
needFindProcess := r.needFindProcess
needWIFIState := r.needWIFIState
for _, ruleSet := range r.ruleSets {
metadata := ruleSet.Metadata()
if metadata.ContainsProcessRule {
needFindProcess = true
}
if metadata.ContainsWIFIRule {
needWIFIState = true
}
}
if C.IsAndroid && r.platformInterface == nil && !r.enforcePackageManager {
if needFindProcess {
monitor.Start("start package manager")
err := r.packageManager.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start package manager")
}
} else {
r.packageManager = nil
}
}
if needFindProcess {
if r.platformInterface != nil {
r.processSearcher = r.platformInterface
} else {
monitor.Start("initialize process searcher")
searcher, err := process.NewSearcher(process.Config{
Logger: r.logger,
PackageManager: r.packageManager,
})
monitor.Finish()
if err != nil {
if err != os.ErrInvalid {
r.logger.Warn(E.Cause(err, "create process searcher"))
}
} else {
r.processSearcher = searcher
}
}
}
if needWIFIState && r.platformInterface != nil {
monitor.Start("initialize WIFI state")
r.needWIFIState = true
r.interfaceMonitor.RegisterCallback(func(_ int) {
r.updateWIFIState()
})
r.updateWIFIState()
monitor.Finish()
}
for i, rule := range r.rules {
monitor.Start("initialize rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize rule[", i, "]")
}
}
for _, ruleSet := range r.ruleSets {
monitor.Start("post start rule_set[", ruleSet.Name(), "]")
err := ruleSet.PostStart()
monitor.Finish()
if err != nil {
return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]")
}
}
r.started = true
return nil
case adapter.StartStateStarted:
for _, ruleSet := range r.ruleSetMap {
ruleSet.Cleanup()
}
runtime.GC()
}
return nil return nil
} }
@ -671,113 +705,6 @@ func (r *Router) Close() error {
return err return err
} }
func (r *Router) PostStart() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout)
var cacheContext *adapter.HTTPStartContext
if len(r.ruleSets) > 0 {
monitor.Start("initialize rule-set")
cacheContext = adapter.NewHTTPStartContext()
var ruleSetStartGroup task.Group
for i, ruleSet := range r.ruleSets {
ruleSetInPlace := ruleSet
ruleSetStartGroup.Append0(func(ctx context.Context) error {
err := ruleSetInPlace.StartContext(ctx, cacheContext)
if err != nil {
return E.Cause(err, "initialize rule-set[", i, "]")
}
return nil
})
}
ruleSetStartGroup.Concurrency(5)
ruleSetStartGroup.FastFail()
err := ruleSetStartGroup.Run(r.ctx)
monitor.Finish()
if err != nil {
return err
}
}
if cacheContext != nil {
cacheContext.Close()
}
needFindProcess := r.needFindProcess
needWIFIState := r.needWIFIState
for _, ruleSet := range r.ruleSets {
metadata := ruleSet.Metadata()
if metadata.ContainsProcessRule {
needFindProcess = true
}
if metadata.ContainsWIFIRule {
needWIFIState = true
}
}
if C.IsAndroid && r.platformInterface == nil && !r.needPackageManager {
if needFindProcess {
monitor.Start("start package manager")
err := r.packageManager.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start package manager")
}
} else {
r.packageManager = nil
}
}
if needFindProcess {
if r.platformInterface != nil {
r.processSearcher = r.platformInterface
} else {
monitor.Start("initialize process searcher")
searcher, err := process.NewSearcher(process.Config{
Logger: r.logger,
PackageManager: r.packageManager,
})
monitor.Finish()
if err != nil {
if err != os.ErrInvalid {
r.logger.Warn(E.Cause(err, "create process searcher"))
}
} else {
r.processSearcher = searcher
}
}
}
if needWIFIState && r.platformInterface != nil {
monitor.Start("initialize WIFI state")
r.needWIFIState = true
r.interfaceMonitor.RegisterCallback(func(_ int) {
r.updateWIFIState()
})
r.updateWIFIState()
monitor.Finish()
}
for i, rule := range r.rules {
monitor.Start("initialize rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize rule[", i, "]")
}
}
for _, ruleSet := range r.ruleSets {
monitor.Start("post start rule_set[", ruleSet.Name(), "]")
err := ruleSet.PostStart()
monitor.Finish()
if err != nil {
return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]")
}
}
r.started = true
return nil
}
func (r *Router) Cleanup() error {
for _, ruleSet := range r.ruleSetMap {
ruleSet.Cleanup()
}
runtime.GC()
return nil
}
func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { func (r *Router) Outbound(tag string) (adapter.Outbound, bool) {
outbound, loaded := r.outboundByTag[tag] outbound, loaded := r.outboundByTag[tag]
return outbound, loaded return outbound, loaded

View file

@ -22,7 +22,7 @@ import (
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) { func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) {
switch action.Action { switch action.Action {
case "": case "":
return nil, nil return nil, nil
@ -36,7 +36,7 @@ func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action op
UDPConnect: action.RouteOptionsOptions.UDPConnect, UDPConnect: action.RouteOptionsOptions.UDPConnect,
}, nil }, nil
case C.RuleActionTypeDirect: case C.RuleActionTypeDirect:
directDialer, err := dialer.New(router, option.DialerOptions(action.DirectOptions)) directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -52,7 +52,7 @@ type RuleItem interface {
} }
func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) {
action, err := NewRuleAction(router, logger, options.RuleAction) action, err := NewRuleAction(ctx, logger, options.RuleAction)
if err != nil { if err != nil {
return nil, E.Cause(err, "action") return nil, E.Cause(err, "action")
} }
@ -254,7 +254,7 @@ type LogicalRule struct {
} }
func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
action, err := NewRuleAction(router, logger, options.RuleAction) action, err := NewRuleAction(ctx, logger, options.RuleAction)
if err != nil { if err != nil {
return nil, E.Cause(err, "action") return nil, E.Cause(err, "action")
} }

View file

@ -36,6 +36,7 @@ type RemoteRuleSet struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
router adapter.Router router adapter.Router
outboundManager adapter.OutboundManager
logger logger.ContextLogger logger logger.ContextLogger
options option.RuleSet options option.RuleSet
metadata adapter.RuleSetMetadata metadata adapter.RuleSetMetadata
@ -64,6 +65,7 @@ func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
router: router, router: router,
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logger, logger: logger,
options: options, options: options,
updateInterval: updateInterval, updateInterval: updateInterval,
@ -83,17 +85,13 @@ func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext *adapter.
s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx) s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx)
var dialer N.Dialer var dialer N.Dialer
if s.options.RemoteOptions.DownloadDetour != "" { if s.options.RemoteOptions.DownloadDetour != "" {
outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour) outbound, loaded := s.outboundManager.Outbound(s.options.RemoteOptions.DownloadDetour)
if !loaded { if !loaded {
return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour) return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour)
} }
dialer = outbound dialer = outbound
} else { } else {
outbound, err := s.router.DefaultOutbound(N.NetworkTCP) dialer = s.outboundManager.Default()
if err != nil {
return err
}
dialer = outbound
} }
s.dialer = dialer s.dialer = dialer
if s.cacheFile != nil { if s.cacheFile != nil {

View file

@ -23,6 +23,7 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
mDNS "github.com/miekg/dns" mDNS "github.com/miekg/dns"
@ -53,7 +54,7 @@ func NewTransport(options dns.TransportOptions) (*Transport, error) {
if linkURL.Host == "" { if linkURL.Host == "" {
return nil, E.New("missing interface name for DHCP") return nil, E.New("missing interface name for DHCP")
} }
router := adapter.RouterFromContext(options.Context) router := service.FromContext[adapter.Router](options.Context)
if router == nil { if router == nil {
return nil, E.New("missing router in context") return nil, E.New("missing router in context")
} }

View file

@ -9,6 +9,7 @@ import (
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
mDNS "github.com/miekg/dns" mDNS "github.com/miekg/dns"
) )
@ -32,7 +33,7 @@ type Transport struct {
} }
func NewTransport(options dns.TransportOptions) (*Transport, error) { func NewTransport(options dns.TransportOptions) (*Transport, error) {
router := adapter.RouterFromContext(options.Context) router := service.FromContext[adapter.Router](options.Context)
if router == nil { if router == nil {
return nil, E.New("missing router in context") return nil, E.New("missing router in context")
} }