mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-27 02:51:36 +00:00
refactor: connection manager
This commit is contained in:
parent
8610018f3b
commit
7dbc105f89
15
adapter/connections.go
Normal file
15
adapter/connections.go
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectionManager interface {
|
||||||
|
Start() error
|
||||||
|
Close() error
|
||||||
|
NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata InboundContext, onClose N.CloseHandlerFunc)
|
||||||
|
NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata InboundContext, onClose N.CloseHandlerFunc)
|
||||||
|
}
|
8
box.go
8
box.go
|
@ -36,9 +36,10 @@ type Box struct {
|
||||||
logFactory log.Factory
|
logFactory log.Factory
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
network *route.NetworkManager
|
network *route.NetworkManager
|
||||||
router *route.Router
|
|
||||||
inbound *inbound.Manager
|
inbound *inbound.Manager
|
||||||
outbound *outbound.Manager
|
outbound *outbound.Manager
|
||||||
|
connection *route.ConnectionManager
|
||||||
|
router *route.Router
|
||||||
services []adapter.LifecycleService
|
services []adapter.LifecycleService
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
}
|
}
|
||||||
|
@ -128,6 +129,8 @@ func New(options Options) (*Box, error) {
|
||||||
return nil, E.Cause(err, "initialize network manager")
|
return nil, E.Cause(err, "initialize network manager")
|
||||||
}
|
}
|
||||||
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
|
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
|
||||||
|
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
|
||||||
|
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
|
||||||
router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS))
|
router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "initialize router")
|
return nil, E.Cause(err, "initialize router")
|
||||||
|
@ -238,9 +241,10 @@ func New(options Options) (*Box, error) {
|
||||||
}
|
}
|
||||||
return &Box{
|
return &Box{
|
||||||
network: networkManager,
|
network: networkManager,
|
||||||
router: router,
|
|
||||||
inbound: inboundManager,
|
inbound: inboundManager,
|
||||||
outbound: outboundManager,
|
outbound: outboundManager,
|
||||||
|
connection: connectionManager,
|
||||||
|
router: router,
|
||||||
createdAt: createdAt,
|
createdAt: createdAt,
|
||||||
logFactory: logFactory,
|
logFactory: logFactory,
|
||||||
logger: logFactory.Logger(),
|
logger: logFactory.Logger(),
|
||||||
|
|
|
@ -83,7 +83,7 @@ func (i *Inbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
|
||||||
destination = i.overrideDestination
|
destination = i.overrideDestination
|
||||||
case 2:
|
case 2:
|
||||||
destination = i.overrideDestination
|
destination = i.overrideDestination
|
||||||
destination.Port = source.Port
|
destination.Port = i.listener.UDPAddr().Port
|
||||||
case 3:
|
case 3:
|
||||||
destination = source
|
destination = source
|
||||||
destination.Port = i.overrideDestination.Port
|
destination.Port = i.overrideDestination.Port
|
||||||
|
|
330
route/conn.go
Normal file
330
route/conn.go
Normal file
|
@ -0,0 +1,330 @@
|
||||||
|
package route
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
|
||||||
|
|
||||||
|
type ConnectionManager struct {
|
||||||
|
logger logger.ContextLogger
|
||||||
|
monitor *ConnectionMonitor
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
|
||||||
|
return &ConnectionManager{
|
||||||
|
logger: logger,
|
||||||
|
monitor: NewConnectionMonitor(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionManager) Start() error {
|
||||||
|
return m.monitor.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionManager) Close() error {
|
||||||
|
return m.monitor.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||||
|
ctx = adapter.WithContext(ctx, &metadata)
|
||||||
|
var (
|
||||||
|
remoteConn net.Conn
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if len(metadata.DestinationAddresses) > 0 {
|
||||||
|
remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
|
||||||
|
} else {
|
||||||
|
remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||||
|
m.logger.ErrorContext(ctx, "open outbound connection: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = N.ReportConnHandshakeSuccess(conn, remoteConn)
|
||||||
|
if err != nil {
|
||||||
|
remoteConn.Close()
|
||||||
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||||
|
m.logger.ErrorContext(ctx, "report handshake success: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var done atomic.Bool
|
||||||
|
if ctx.Done() != nil {
|
||||||
|
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
|
||||||
|
}
|
||||||
|
go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
|
||||||
|
go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
||||||
|
originSource := source
|
||||||
|
var readCounters, writeCounters []N.CountFunc
|
||||||
|
for {
|
||||||
|
source, readCounters = N.UnwrapCountReader(source, readCounters)
|
||||||
|
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
|
||||||
|
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
||||||
|
cachedBuffer := cachedSrc.ReadCached()
|
||||||
|
if cachedBuffer != nil {
|
||||||
|
if !cachedBuffer.IsEmpty() {
|
||||||
|
dataLen := cachedBuffer.Len()
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
_, err := destination.Write(cachedBuffer.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
|
||||||
|
cachedBuffer.Release()
|
||||||
|
if done.Swap(true) {
|
||||||
|
if onClose != nil {
|
||||||
|
onClose(err)
|
||||||
|
}
|
||||||
|
common.Close(source, destination)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cachedBuffer.Release()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
dstDuplex bool
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
_, err = bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
||||||
|
if _, dstDuplex = common.Cast[N.WriteCloser](destination); dstDuplex && err == nil {
|
||||||
|
N.CloseWrite(destination)
|
||||||
|
} else {
|
||||||
|
common.Close(destination)
|
||||||
|
}
|
||||||
|
if done.Swap(true) {
|
||||||
|
if onClose != nil {
|
||||||
|
onClose(err)
|
||||||
|
}
|
||||||
|
common.Close(source, destination)
|
||||||
|
}
|
||||||
|
if !direction {
|
||||||
|
if err == nil {
|
||||||
|
m.logger.DebugContext(ctx, "connection upload finished")
|
||||||
|
} else if !E.IsClosedOrCanceled(err) {
|
||||||
|
m.logger.ErrorContext(ctx, "connection upload closed: ", err)
|
||||||
|
} else {
|
||||||
|
m.logger.TraceContext(ctx, "connection upload closed")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err == nil {
|
||||||
|
m.logger.DebugContext(ctx, "connection download finished")
|
||||||
|
} else if !E.IsClosedOrCanceled(err) {
|
||||||
|
m.logger.ErrorContext(ctx, "connection download closed: ", err)
|
||||||
|
} else {
|
||||||
|
m.logger.TraceContext(ctx, "connection download closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||||
|
ctx = adapter.WithContext(ctx, &metadata)
|
||||||
|
var (
|
||||||
|
remotePacketConn net.PacketConn
|
||||||
|
remoteConn net.Conn
|
||||||
|
destinationAddress netip.Addr
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if metadata.UDPConnect {
|
||||||
|
if len(metadata.DestinationAddresses) > 0 {
|
||||||
|
if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer {
|
||||||
|
remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
|
||||||
|
} else {
|
||||||
|
remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||||
|
m.logger.ErrorContext(ctx, "open outbound packet connection: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
|
||||||
|
connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
|
||||||
|
if connRemoteAddr != metadata.Destination.Addr {
|
||||||
|
destinationAddress = connRemoteAddr
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if len(metadata.DestinationAddresses) > 0 {
|
||||||
|
remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
|
||||||
|
} else {
|
||||||
|
remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||||
|
m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
remotePacketConn.Close()
|
||||||
|
m.logger.ErrorContext(ctx, "report handshake success: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if destinationAddress.IsValid() {
|
||||||
|
var originDestination M.Socksaddr
|
||||||
|
if metadata.RouteOriginalDestination.IsValid() {
|
||||||
|
originDestination = metadata.RouteOriginalDestination
|
||||||
|
} else {
|
||||||
|
originDestination = metadata.Destination
|
||||||
|
}
|
||||||
|
if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
|
||||||
|
if metadata.UDPDisableDomainUnmapping {
|
||||||
|
remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
|
||||||
|
} else {
|
||||||
|
remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
|
||||||
|
natConn.UpdateDestination(destinationAddress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
destination := bufio.NewPacketConn(remotePacketConn)
|
||||||
|
if ctx.Done() != nil {
|
||||||
|
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
|
||||||
|
}
|
||||||
|
var done atomic.Bool
|
||||||
|
go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
|
||||||
|
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
||||||
|
_, err := bufio.CopyPacket(destination, source)
|
||||||
|
/*var readCounters, writeCounters []N.CountFunc
|
||||||
|
var cachedPackets []*N.PacketBuffer
|
||||||
|
originSource := source
|
||||||
|
for {
|
||||||
|
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
|
||||||
|
destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters)
|
||||||
|
if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
|
||||||
|
packet := cachedReader.ReadCachedPacket()
|
||||||
|
if packet != nil {
|
||||||
|
cachedPackets = append(cachedPackets, packet)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
var handled bool
|
||||||
|
if natConn, isNatConn := source.(udpnat.Conn); isNatConn {
|
||||||
|
natConn.SetHandler(&udpHijacker{
|
||||||
|
ctx: ctx,
|
||||||
|
logger: m.logger,
|
||||||
|
source: natConn,
|
||||||
|
destination: destination,
|
||||||
|
direction: direction,
|
||||||
|
readCounters: readCounters,
|
||||||
|
writeCounters: writeCounters,
|
||||||
|
done: done,
|
||||||
|
onClose: onClose,
|
||||||
|
})
|
||||||
|
handled = true
|
||||||
|
}
|
||||||
|
if cachedPackets != nil {
|
||||||
|
_, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters)
|
||||||
|
if err != nil {
|
||||||
|
common.Close(source, destination)
|
||||||
|
m.logger.ErrorContext(ctx, "packet upload payload: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if handled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/
|
||||||
|
if !direction {
|
||||||
|
if E.IsClosedOrCanceled(err) {
|
||||||
|
m.logger.TraceContext(ctx, "packet upload closed")
|
||||||
|
} else {
|
||||||
|
m.logger.DebugContext(ctx, "packet upload closed: ", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if E.IsClosedOrCanceled(err) {
|
||||||
|
m.logger.TraceContext(ctx, "packet download closed")
|
||||||
|
} else {
|
||||||
|
m.logger.DebugContext(ctx, "packet download closed: ", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !done.Swap(true) {
|
||||||
|
common.Close(source, destination)
|
||||||
|
if onClose != nil {
|
||||||
|
onClose(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*type udpHijacker struct {
|
||||||
|
ctx context.Context
|
||||||
|
logger logger.ContextLogger
|
||||||
|
source io.Closer
|
||||||
|
destination N.PacketWriter
|
||||||
|
direction bool
|
||||||
|
readCounters []N.CountFunc
|
||||||
|
writeCounters []N.CountFunc
|
||||||
|
done *atomic.Bool
|
||||||
|
onClose N.CloseHandlerFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
|
||||||
|
dataLen := buffer.Len()
|
||||||
|
for _, counter := range u.readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
err := u.destination.WritePacket(buffer, source)
|
||||||
|
if err != nil {
|
||||||
|
common.Close(u.source, u.destination)
|
||||||
|
u.logger.DebugContext(u.ctx, "packet upload closed: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, counter := range u.writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpHijacker) Close() error {
|
||||||
|
var err error
|
||||||
|
if !u.done.Swap(true) {
|
||||||
|
err = common.Close(u.source, u.destination)
|
||||||
|
if u.onClose != nil {
|
||||||
|
u.onClose(net.ErrClosed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if u.direction {
|
||||||
|
u.logger.TraceContext(u.ctx, "packet download closed")
|
||||||
|
} else {
|
||||||
|
u.logger.TraceContext(u.ctx, "packet upload closed")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpHijacker) Upstream() any {
|
||||||
|
return u.destination
|
||||||
|
}
|
||||||
|
*/
|
124
route/conn_monitor.go
Normal file
124
route/conn_monitor.go
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
package route
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/common/x/list"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectionMonitor struct {
|
||||||
|
access sync.RWMutex
|
||||||
|
reloadChan chan struct{}
|
||||||
|
connections list.List[*monitorEntry]
|
||||||
|
}
|
||||||
|
|
||||||
|
type monitorEntry struct {
|
||||||
|
ctx context.Context
|
||||||
|
closer io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnectionMonitor() *ConnectionMonitor {
|
||||||
|
return &ConnectionMonitor{
|
||||||
|
reloadChan: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionMonitor) Add(ctx context.Context, closer io.Closer) N.CloseHandlerFunc {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
element := m.connections.PushBack(&monitorEntry{
|
||||||
|
ctx: ctx,
|
||||||
|
closer: closer,
|
||||||
|
})
|
||||||
|
select {
|
||||||
|
case <-m.reloadChan:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
select {
|
||||||
|
case m.reloadChan <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return func(it error) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
m.connections.Remove(element)
|
||||||
|
select {
|
||||||
|
case <-m.reloadChan:
|
||||||
|
default:
|
||||||
|
select {
|
||||||
|
case m.reloadChan <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionMonitor) Start() error {
|
||||||
|
go m.monitor()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionMonitor) Close() error {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
close(m.reloadChan)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionMonitor) monitor() {
|
||||||
|
var (
|
||||||
|
selectCases []reflect.SelectCase
|
||||||
|
elements []*list.Element[*monitorEntry]
|
||||||
|
)
|
||||||
|
rootCase := reflect.SelectCase{
|
||||||
|
Dir: reflect.SelectRecv,
|
||||||
|
Chan: reflect.ValueOf(m.reloadChan),
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
m.access.RLock()
|
||||||
|
if m.connections.Len() == 0 {
|
||||||
|
m.access.RUnlock()
|
||||||
|
if _, loaded := <-m.reloadChan; !loaded {
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(elements) < m.connections.Len() {
|
||||||
|
elements = make([]*list.Element[*monitorEntry], 0, m.connections.Len())
|
||||||
|
}
|
||||||
|
if len(selectCases) < m.connections.Len()+1 {
|
||||||
|
selectCases = make([]reflect.SelectCase, 0, m.connections.Len()+1)
|
||||||
|
}
|
||||||
|
selectCases = selectCases[:1]
|
||||||
|
selectCases[0] = rootCase
|
||||||
|
for element := m.connections.Front(); element != nil; element = element.Next() {
|
||||||
|
elements = append(elements, element)
|
||||||
|
selectCases = append(selectCases, reflect.SelectCase{
|
||||||
|
Dir: reflect.SelectRecv,
|
||||||
|
Chan: reflect.ValueOf(element.Value.ctx.Done()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
m.access.RUnlock()
|
||||||
|
selected, _, loaded := reflect.Select(selectCases)
|
||||||
|
if selected == 0 {
|
||||||
|
if !loaded {
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
element := elements[selected-1]
|
||||||
|
m.access.Lock()
|
||||||
|
m.connections.Remove(element)
|
||||||
|
m.access.Unlock()
|
||||||
|
element.Value.closer.Close() // maybe go close
|
||||||
|
}
|
||||||
|
}
|
43
route/conn_monitor_test.go
Normal file
43
route/conn_monitor_test.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package route_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/route"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMonitor(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var closer myCloser
|
||||||
|
closer.Add(1)
|
||||||
|
monitor := route.NewConnectionMonitor()
|
||||||
|
require.NoError(t, monitor.Start())
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
monitor.Add(ctx, &closer)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
closer.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second + 100*time.Millisecond):
|
||||||
|
t.Fatal("timeout")
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
require.NoError(t, monitor.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
type myCloser struct {
|
||||||
|
sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *myCloser) Close() error {
|
||||||
|
c.Done()
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -30,15 +30,15 @@ func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata ad
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) {
|
func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) {
|
||||||
if uConn, isUDPNAT2 := conn.(*udpnat.Conn); isUDPNAT2 {
|
if natConn, isNatConn := conn.(udpnat.Conn); isNatConn {
|
||||||
metadata.Destination = M.Socksaddr{}
|
metadata.Destination = M.Socksaddr{}
|
||||||
for _, packet := range packetBuffers {
|
for _, packet := range packetBuffers {
|
||||||
buffer := packet.Buffer
|
buffer := packet.Buffer
|
||||||
destination := packet.Destination
|
destination := packet.Destination
|
||||||
N.PutPacketBuffer(packet)
|
N.PutPacketBuffer(packet)
|
||||||
go ExchangeDNSPacket(ctx, r, uConn, buffer, metadata, destination)
|
go ExchangeDNSPacket(ctx, r, natConn, buffer, metadata, destination)
|
||||||
}
|
}
|
||||||
uConn.SetHandler(&dnsHijacker{
|
natConn.SetHandler(&dnsHijacker{
|
||||||
router: r,
|
router: r,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
|
|
@ -145,13 +145,13 @@ func (r *Router) downloadGeoIPDatabase(savePath string) error {
|
||||||
r.logger.Info("downloading geoip database")
|
r.logger.Info("downloading geoip database")
|
||||||
var detour adapter.Outbound
|
var detour adapter.Outbound
|
||||||
if r.geoIPOptions.DownloadDetour != "" {
|
if r.geoIPOptions.DownloadDetour != "" {
|
||||||
outbound, loaded := r.outboundManager.Outbound(r.geoIPOptions.DownloadDetour)
|
outbound, loaded := r.outbound.Outbound(r.geoIPOptions.DownloadDetour)
|
||||||
if !loaded {
|
if !loaded {
|
||||||
return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour)
|
return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour)
|
||||||
}
|
}
|
||||||
detour = outbound
|
detour = outbound
|
||||||
} else {
|
} else {
|
||||||
detour = r.outboundManager.Default()
|
detour = r.outbound.Default()
|
||||||
}
|
}
|
||||||
|
|
||||||
if parentDir := filepath.Dir(savePath); parentDir != "" {
|
if parentDir := filepath.Dir(savePath); parentDir != "" {
|
||||||
|
@ -200,13 +200,13 @@ func (r *Router) downloadGeositeDatabase(savePath string) error {
|
||||||
r.logger.Info("downloading geosite database")
|
r.logger.Info("downloading geosite database")
|
||||||
var detour adapter.Outbound
|
var detour adapter.Outbound
|
||||||
if r.geositeOptions.DownloadDetour != "" {
|
if r.geositeOptions.DownloadDetour != "" {
|
||||||
outbound, loaded := r.outboundManager.Outbound(r.geositeOptions.DownloadDetour)
|
outbound, loaded := r.outbound.Outbound(r.geositeOptions.DownloadDetour)
|
||||||
if !loaded {
|
if !loaded {
|
||||||
return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour)
|
return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour)
|
||||||
}
|
}
|
||||||
detour = outbound
|
detour = outbound
|
||||||
} else {
|
} else {
|
||||||
detour = r.outboundManager.Default()
|
detour = r.outbound.Default()
|
||||||
}
|
}
|
||||||
|
|
||||||
if parentDir := filepath.Dir(savePath); parentDir != "" {
|
if parentDir := filepath.Dir(savePath); parentDir != "" {
|
||||||
|
|
|
@ -48,6 +48,7 @@ type NetworkManager struct {
|
||||||
powerListener winpowrprof.EventListener
|
powerListener winpowrprof.EventListener
|
||||||
pauseManager pause.Manager
|
pauseManager pause.Manager
|
||||||
platformInterface platform.Interface
|
platformInterface platform.Interface
|
||||||
|
inboundManager adapter.InboundManager
|
||||||
outboundManager adapter.OutboundManager
|
outboundManager adapter.OutboundManager
|
||||||
wifiState adapter.WIFIState
|
wifiState adapter.WIFIState
|
||||||
started bool
|
started bool
|
||||||
|
@ -357,6 +358,13 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
|
||||||
func (r *NetworkManager) ResetNetwork() {
|
func (r *NetworkManager) ResetNetwork() {
|
||||||
conntrack.Close()
|
conntrack.Close()
|
||||||
|
|
||||||
|
for _, inbound := range r.inboundManager.Inbounds() {
|
||||||
|
listener, isListener := inbound.(adapter.InterfaceUpdateListener)
|
||||||
|
if isListener {
|
||||||
|
listener.InterfaceUpdated()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, outbound := range r.outboundManager.Outbounds() {
|
for _, outbound := range r.outboundManager.Outbounds() {
|
||||||
listener, isListener := outbound.(adapter.InterfaceUpdateListener)
|
listener, isListener := outbound.(adapter.InterfaceUpdateListener)
|
||||||
if isListener {
|
if isListener {
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/adapter/outbound"
|
|
||||||
"github.com/sagernet/sing-box/common/conntrack"
|
"github.com/sagernet/sing-box/common/conntrack"
|
||||||
"github.com/sagernet/sing-box/common/process"
|
"github.com/sagernet/sing-box/common/process"
|
||||||
"github.com/sagernet/sing-box/common/sniff"
|
"github.com/sagernet/sing-box/common/sniff"
|
||||||
|
@ -58,7 +57,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
|
||||||
if metadata.LastInbound == metadata.InboundDetour {
|
if metadata.LastInbound == metadata.InboundDetour {
|
||||||
return E.New("routing loop on detour: ", metadata.InboundDetour)
|
return E.New("routing loop on detour: ", metadata.InboundDetour)
|
||||||
}
|
}
|
||||||
detour, loaded := r.inboundManager.Get(metadata.InboundDetour)
|
detour, loaded := r.inbound.Get(metadata.InboundDetour)
|
||||||
if !loaded {
|
if !loaded {
|
||||||
return E.New("inbound detour not found: ", metadata.InboundDetour)
|
return E.New("inbound detour not found: ", metadata.InboundDetour)
|
||||||
}
|
}
|
||||||
|
@ -96,7 +95,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
|
||||||
switch action := selectedRule.Action().(type) {
|
switch action := selectedRule.Action().(type) {
|
||||||
case *rule.RuleActionRoute:
|
case *rule.RuleActionRoute:
|
||||||
var loaded bool
|
var loaded bool
|
||||||
selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound)
|
selectedOutbound, loaded = r.outbound.Outbound(action.Outbound)
|
||||||
if !loaded {
|
if !loaded {
|
||||||
buf.ReleaseMulti(buffers)
|
buf.ReleaseMulti(buffers)
|
||||||
return E.New("outbound not found: ", action.Outbound)
|
return E.New("outbound not found: ", action.Outbound)
|
||||||
|
@ -118,7 +117,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if selectedRule == nil {
|
if selectedRule == nil {
|
||||||
defaultOutbound := r.outboundManager.Default()
|
defaultOutbound := r.outbound.Default()
|
||||||
if !common.Contains(defaultOutbound.Network(), N.NetworkTCP) {
|
if !common.Contains(defaultOutbound.Network(), N.NetworkTCP) {
|
||||||
buf.ReleaseMulti(buffers)
|
buf.ReleaseMulti(buffers)
|
||||||
return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag())
|
return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag())
|
||||||
|
@ -148,19 +147,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// TODO
|
r.connection.NewConnection(ctx, selectedOutbound, conn, metadata, onClose)
|
||||||
err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata)
|
|
||||||
if err != nil {
|
|
||||||
conn.Close()
|
|
||||||
if onClose != nil {
|
|
||||||
onClose(err)
|
|
||||||
}
|
|
||||||
return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
|
|
||||||
} else {
|
|
||||||
if onClose != nil {
|
|
||||||
onClose(nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,7 +186,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
|
||||||
if metadata.LastInbound == metadata.InboundDetour {
|
if metadata.LastInbound == metadata.InboundDetour {
|
||||||
return E.New("routing loop on detour: ", metadata.InboundDetour)
|
return E.New("routing loop on detour: ", metadata.InboundDetour)
|
||||||
}
|
}
|
||||||
detour, loaded := r.inboundManager.Get(metadata.InboundDetour)
|
detour, loaded := r.inbound.Get(metadata.InboundDetour)
|
||||||
if !loaded {
|
if !loaded {
|
||||||
return E.New("inbound detour not found: ", metadata.InboundDetour)
|
return E.New("inbound detour not found: ", metadata.InboundDetour)
|
||||||
}
|
}
|
||||||
|
@ -233,7 +220,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
|
||||||
switch action := selectedRule.Action().(type) {
|
switch action := selectedRule.Action().(type) {
|
||||||
case *rule.RuleActionRoute:
|
case *rule.RuleActionRoute:
|
||||||
var loaded bool
|
var loaded bool
|
||||||
selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound)
|
selectedOutbound, loaded = r.outbound.Outbound(action.Outbound)
|
||||||
if !loaded {
|
if !loaded {
|
||||||
N.ReleaseMultiPacketBuffer(packetBuffers)
|
N.ReleaseMultiPacketBuffer(packetBuffers)
|
||||||
return E.New("outbound not found: ", action.Outbound)
|
return E.New("outbound not found: ", action.Outbound)
|
||||||
|
@ -252,7 +239,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if selectedRule == nil || selectReturn {
|
if selectedRule == nil || selectReturn {
|
||||||
defaultOutbound := r.outboundManager.Default()
|
defaultOutbound := r.outbound.Default()
|
||||||
if !common.Contains(defaultOutbound.Network(), N.NetworkUDP) {
|
if !common.Contains(defaultOutbound.Network(), N.NetworkUDP) {
|
||||||
N.ReleaseMultiPacketBuffer(packetBuffers)
|
N.ReleaseMultiPacketBuffer(packetBuffers)
|
||||||
return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag())
|
return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag())
|
||||||
|
@ -278,12 +265,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// TODO
|
r.connection.NewPacketConnection(ctx, selectedOutbound, conn, metadata, onClose)
|
||||||
err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata)
|
|
||||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -450,8 +432,12 @@ match:
|
||||||
}
|
}
|
||||||
metadata.NetworkStrategy = routeOptions.NetworkStrategy
|
metadata.NetworkStrategy = routeOptions.NetworkStrategy
|
||||||
metadata.FallbackDelay = routeOptions.FallbackDelay
|
metadata.FallbackDelay = routeOptions.FallbackDelay
|
||||||
metadata.UDPDisableDomainUnmapping = routeOptions.UDPDisableDomainUnmapping
|
if routeOptions.UDPDisableDomainUnmapping {
|
||||||
metadata.UDPConnect = routeOptions.UDPConnect
|
metadata.UDPDisableDomainUnmapping = true
|
||||||
|
}
|
||||||
|
if routeOptions.UDPConnect {
|
||||||
|
metadata.UDPConnect = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
switch action := currentRule.Action().(type) {
|
switch action := currentRule.Action().(type) {
|
||||||
case *rule.RuleActionSniff:
|
case *rule.RuleActionSniff:
|
||||||
|
|
|
@ -38,9 +38,10 @@ type Router struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
dnsLogger log.ContextLogger
|
dnsLogger log.ContextLogger
|
||||||
inboundManager adapter.InboundManager
|
inbound adapter.InboundManager
|
||||||
outboundManager adapter.OutboundManager
|
outbound adapter.OutboundManager
|
||||||
networkManager adapter.NetworkManager
|
connection adapter.ConnectionManager
|
||||||
|
network adapter.NetworkManager
|
||||||
rules []adapter.Rule
|
rules []adapter.Rule
|
||||||
needGeoIPDatabase bool
|
needGeoIPDatabase bool
|
||||||
needGeositeDatabase bool
|
needGeositeDatabase bool
|
||||||
|
@ -74,9 +75,10 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
logger: logFactory.NewLogger("router"),
|
logger: logFactory.NewLogger("router"),
|
||||||
dnsLogger: logFactory.NewLogger("dns"),
|
dnsLogger: logFactory.NewLogger("dns"),
|
||||||
inboundManager: service.FromContext[adapter.InboundManager](ctx),
|
inbound: service.FromContext[adapter.InboundManager](ctx),
|
||||||
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
|
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||||
networkManager: service.FromContext[adapter.NetworkManager](ctx),
|
connection: service.FromContext[adapter.ConnectionManager](ctx),
|
||||||
|
network: service.FromContext[adapter.NetworkManager](ctx),
|
||||||
rules: make([]adapter.Rule, 0, len(options.Rules)),
|
rules: make([]adapter.Rule, 0, len(options.Rules)),
|
||||||
dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
|
dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
|
||||||
ruleSetMap: make(map[string]adapter.RuleSet),
|
ruleSetMap: make(map[string]adapter.RuleSet),
|
||||||
|
@ -260,7 +262,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
Name: "local",
|
Name: "local",
|
||||||
Address: "local",
|
Address: "local",
|
||||||
Dialer: common.Must1(dialer.NewDefault(router.networkManager, option.DialerOptions{})),
|
Dialer: common.Must1(dialer.NewDefault(router.network, option.DialerOptions{})),
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
defaultTransport = transports[0]
|
defaultTransport = transports[0]
|
||||||
|
@ -405,7 +407,7 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||||
monitor.Start("initialize process searcher")
|
monitor.Start("initialize process searcher")
|
||||||
searcher, err := process.NewSearcher(process.Config{
|
searcher, err := process.NewSearcher(process.Config{
|
||||||
Logger: r.logger,
|
Logger: r.logger,
|
||||||
PackageManager: r.networkManager.PackageManager(),
|
PackageManager: r.network.PackageManager(),
|
||||||
})
|
})
|
||||||
monitor.Finish()
|
monitor.Finish()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -507,7 +509,7 @@ func (r *Router) SetTracker(tracker adapter.ConnectionTracker) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) ResetNetwork() {
|
func (r *Router) ResetNetwork() {
|
||||||
r.networkManager.ResetNetwork()
|
r.network.ResetNetwork()
|
||||||
for _, transport := range r.transports {
|
for _, transport := range r.transports {
|
||||||
transport.Reset()
|
transport.Reset()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue