mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-01-30 12:46:53 +00:00
289 lines
8.4 KiB
Go
289 lines
8.4 KiB
Go
package dns
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/sagernet/sing-box/adapter"
|
|
"github.com/sagernet/sing-box/common/taskmonitor"
|
|
C "github.com/sagernet/sing-box/constant"
|
|
"github.com/sagernet/sing-box/log"
|
|
"github.com/sagernet/sing/common"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
"github.com/sagernet/sing/common/logger"
|
|
)
|
|
|
|
// Compile-time check that *TransportManager implements adapter.DNSTransportManager.
var (
	_ adapter.DNSTransportManager = (*TransportManager)(nil)
)
|
|
|
|
// TransportManager owns the registered DNS transports (servers). It creates
// them through a registry and keeps them indexed by tag. It also tracks which
// transport is the default server and which one is the fakeip transport, and
// it starts and closes them.
type TransportManager struct {
	// logger receives task-monitor output and default-server change messages.
	logger log.ContextLogger
	// registry builds new transports in Create.
	registry adapter.DNSTransportRegistry
	// outbound is stored at construction.
	// NOTE(review): no method shown here reads it — confirm it is still needed.
	outbound adapter.OutboundManager
	// defaultTag is the tag Create promotes to default server. When empty, the
	// first transport created while no default is set becomes the default.
	defaultTag string
	// access is held by the methods below while they read or write the
	// remaining fields.
	access sync.RWMutex
	// started is set by Start and cleared by Close. stage is the last stage
	// passed to Start.
	started bool
	stage   adapter.StartStage
	// transports holds the registered transports in the order Create appended them.
	transports []adapter.DNSTransport
	// transportByTag indexes the registered transports by tag.
	transportByTag map[string]adapter.DNSTransport
	// dependByTag maps a dependency tag to the tags of the transports that
	// list it in Dependencies(). Remove refuses to drop a tag that still has
	// entries here.
	dependByTag map[string][]string
	// defaultTransport is the current default server, or nil.
	defaultTransport adapter.DNSTransport
	// defaultTransportFallback is set by Initialize. Default returns it when
	// defaultTransport is nil.
	defaultTransportFallback adapter.DNSTransport
	// fakeIPTransport is the fakeip transport registered through Create, or nil.
	fakeIPTransport adapter.FakeIPTransport
}
|
|
|
|
func NewTransportManager(logger logger.ContextLogger, registry adapter.DNSTransportRegistry, outbound adapter.OutboundManager, defaultTag string) *TransportManager {
|
|
return &TransportManager{
|
|
logger: logger,
|
|
registry: registry,
|
|
outbound: outbound,
|
|
defaultTag: defaultTag,
|
|
transportByTag: make(map[string]adapter.DNSTransport),
|
|
dependByTag: make(map[string][]string),
|
|
}
|
|
}
|
|
|
|
func (m *TransportManager) Initialize(defaultTransportFallback adapter.DNSTransport) {
|
|
m.defaultTransportFallback = defaultTransportFallback
|
|
}
|
|
|
|
func (m *TransportManager) Start(stage adapter.StartStage) error {
|
|
m.access.Lock()
|
|
if m.started && m.stage >= stage {
|
|
panic("already started")
|
|
}
|
|
m.started = true
|
|
m.stage = stage
|
|
outbounds := m.transports
|
|
m.access.Unlock()
|
|
if stage == adapter.StartStateStart {
|
|
return m.startTransports(m.transports)
|
|
} else {
|
|
for _, outbound := range outbounds {
|
|
err := adapter.LegacyStart(outbound, stage)
|
|
if err != nil {
|
|
return E.Cause(err, stage, " dns/", outbound.Type(), "[", outbound.Tag(), "]")
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// startTransports starts the given transports in dependency order.
//
// Each pass over transports processes every transport whose tag is not yet
// marked started and whose Dependencies() tags are all marked started. Tags
// marked earlier in the same pass count. A transport that implements
// adapter.Lifecycle is started with adapter.StartStateStart under a task
// monitor created with C.StartTimeout. The first Start error is returned.
//
// Passes repeat until every tag has been marked. If a pass makes no progress,
// the first unmarked transport's chain of unmarked dependencies is followed.
// The result is either a circular-dependency error or a "dependency not found"
// error.
//
// NOTE(review): completion is detected with len(started) == len(transports),
// which assumes the tags in transports are unique — confirm callers guarantee this.
func (m *TransportManager) startTransports(transports []adapter.DNSTransport) error {
	monitor := taskmonitor.New(m.logger, C.StartTimeout)
	// started marks, by tag, every transport already processed in this call:
	// started, or passed over because it does not implement adapter.Lifecycle.
	started := make(map[string]bool)
	for {
		// canContinue reports whether this pass processed at least one transport.
		canContinue := false
	startOne:
		for _, transportToStart := range transports {
			transportTag := transportToStart.Tag()
			if started[transportTag] {
				continue
			}
			dependencies := transportToStart.Dependencies()
			for _, dependency := range dependencies {
				if !started[dependency] {
					// A dependency is still pending; retry this transport on a later pass.
					continue startOne
				}
			}
			// Marked before Start runs; a Start error aborts the whole call below.
			started[transportTag] = true
			canContinue = true
			if starter, isStarter := transportToStart.(adapter.Lifecycle); isStarter {
				monitor.Start("start dns/", transportToStart.Type(), "[", transportTag, "]")
				err := starter.Start(adapter.StartStateStart)
				monitor.Finish()
				if err != nil {
					return E.Cause(err, "start dns/", transportToStart.Type(), "[", transportTag, "]")
				}
			}
		}
		if len(started) == len(transports) {
			break
		}
		if canContinue {
			continue
		}
		// No progress was made: every remaining transport waits on an unmarked
		// dependency. Start the diagnostic walk from the first unmarked transport.
		currentTransport := common.Find(transports, func(it adapter.DNSTransport) bool {
			return !started[it.Tag()]
		})
		// lintTransport follows the first unmarked dependency of oCurrent and
		// records the visited tags in oTree. It returns a circular-dependency
		// error when a tag repeats, and a not-found error when no transport is
		// registered under the tag.
		var lintTransport func(oTree []string, oCurrent adapter.DNSTransport) error
		lintTransport = func(oTree []string, oCurrent adapter.DNSTransport) error {
			problemTransportTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
				return !started[it]
			})
			if common.Contains(oTree, problemTransportTag) {
				return E.New("circular server dependency: ", strings.Join(oTree, " -> "), " -> ", problemTransportTag)
			}
			// NOTE(review): this only reads transportByTag; RLock would suffice.
			m.access.Lock()
			problemTransport := m.transportByTag[problemTransportTag]
			m.access.Unlock()
			if problemTransport == nil {
				return E.New("dependency[", problemTransportTag, "] not found for server[", oCurrent.Tag(), "]")
			}
			return lintTransport(append(oTree, problemTransportTag), problemTransport)
		}
		return lintTransport([]string{currentTransport.Tag()}, currentTransport)
	}
	return nil
}
|
|
|
|
func (m *TransportManager) Close() error {
|
|
monitor := taskmonitor.New(m.logger, C.StopTimeout)
|
|
m.access.Lock()
|
|
if !m.started {
|
|
m.access.Unlock()
|
|
return nil
|
|
}
|
|
m.started = false
|
|
transports := m.transports
|
|
m.transports = nil
|
|
m.access.Unlock()
|
|
var err error
|
|
for _, transport := range transports {
|
|
if closer, isCloser := transport.(io.Closer); isCloser {
|
|
monitor.Start("close server/", transport.Type(), "[", transport.Tag(), "]")
|
|
err = E.Append(err, closer.Close(), func(err error) error {
|
|
return E.Cause(err, "close server/", transport.Type(), "[", transport.Tag(), "]")
|
|
})
|
|
monitor.Finish()
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *TransportManager) Transports() []adapter.DNSTransport {
|
|
m.access.RLock()
|
|
defer m.access.RUnlock()
|
|
return m.transports
|
|
}
|
|
|
|
func (m *TransportManager) Transport(tag string) (adapter.DNSTransport, bool) {
|
|
m.access.RLock()
|
|
outbound, found := m.transportByTag[tag]
|
|
m.access.RUnlock()
|
|
return outbound, found
|
|
}
|
|
|
|
func (m *TransportManager) Default() adapter.DNSTransport {
|
|
m.access.RLock()
|
|
defer m.access.RUnlock()
|
|
if m.defaultTransport != nil {
|
|
return m.defaultTransport
|
|
} else {
|
|
return m.defaultTransportFallback
|
|
}
|
|
}
|
|
|
|
func (m *TransportManager) FakeIP() adapter.FakeIPTransport {
|
|
m.access.RLock()
|
|
defer m.access.RUnlock()
|
|
return m.fakeIPTransport
|
|
}
|
|
|
|
func (m *TransportManager) Remove(tag string) error {
|
|
m.access.Lock()
|
|
defer m.access.Unlock()
|
|
transport, found := m.transportByTag[tag]
|
|
if !found {
|
|
return os.ErrInvalid
|
|
}
|
|
delete(m.transportByTag, tag)
|
|
index := common.Index(m.transports, func(it adapter.DNSTransport) bool {
|
|
return it == transport
|
|
})
|
|
if index == -1 {
|
|
panic("invalid inbound index")
|
|
}
|
|
m.transports = append(m.transports[:index], m.transports[index+1:]...)
|
|
started := m.started
|
|
if m.defaultTransport == transport {
|
|
if len(m.transports) > 0 {
|
|
nextTransport := m.transports[0]
|
|
if nextTransport.Type() != C.DNSTypeFakeIP {
|
|
return E.New("default server cannot be fakeip")
|
|
}
|
|
m.defaultTransport = nextTransport
|
|
m.logger.Info("updated default server to ", m.defaultTransport.Tag())
|
|
} else {
|
|
m.defaultTransport = nil
|
|
}
|
|
}
|
|
dependBy := m.dependByTag[tag]
|
|
if len(dependBy) > 0 {
|
|
return E.New("server[", tag, "] is depended by ", strings.Join(dependBy, ", "))
|
|
}
|
|
dependencies := transport.Dependencies()
|
|
for _, dependency := range dependencies {
|
|
if len(m.dependByTag[dependency]) == 1 {
|
|
delete(m.dependByTag, dependency)
|
|
} else {
|
|
m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
|
|
return it != tag
|
|
})
|
|
}
|
|
}
|
|
if started {
|
|
transport.Reset()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *TransportManager) Create(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) error {
|
|
if tag == "" {
|
|
return os.ErrInvalid
|
|
}
|
|
transport, err := m.registry.CreateDNSTransport(ctx, logger, tag, transportType, options)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m.access.Lock()
|
|
defer m.access.Unlock()
|
|
if m.started {
|
|
for _, stage := range adapter.ListStartStages {
|
|
err = adapter.LegacyStart(transport, stage)
|
|
if err != nil {
|
|
return E.Cause(err, stage, " dns/", transport.Type(), "[", transport.Tag(), "]")
|
|
}
|
|
}
|
|
}
|
|
if existsTransport, loaded := m.transportByTag[tag]; loaded {
|
|
if m.started {
|
|
err = common.Close(existsTransport)
|
|
if err != nil {
|
|
return E.Cause(err, "close dns/", existsTransport.Type(), "[", existsTransport.Tag(), "]")
|
|
}
|
|
}
|
|
existsIndex := common.Index(m.transports, func(it adapter.DNSTransport) bool {
|
|
return it == existsTransport
|
|
})
|
|
if existsIndex == -1 {
|
|
panic("invalid inbound index")
|
|
}
|
|
m.transports = append(m.transports[:existsIndex], m.transports[existsIndex+1:]...)
|
|
}
|
|
m.transports = append(m.transports, transport)
|
|
m.transportByTag[tag] = transport
|
|
dependencies := transport.Dependencies()
|
|
for _, dependency := range dependencies {
|
|
m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
|
|
}
|
|
if tag == m.defaultTag || (m.defaultTag == "" && m.defaultTransport == nil) {
|
|
if transport.Type() == C.DNSTypeFakeIP {
|
|
return E.New("default server cannot be fakeip")
|
|
}
|
|
m.defaultTransport = transport
|
|
if m.started {
|
|
m.logger.Info("updated default server to ", transport.Tag())
|
|
}
|
|
}
|
|
if transport.Type() == C.DNSTypeFakeIP {
|
|
if m.fakeIPTransport != nil {
|
|
return E.New("multiple fakeip server are not supported")
|
|
}
|
|
m.fakeIPTransport = transport.(adapter.FakeIPTransport)
|
|
}
|
|
return nil
|
|
}
|