mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-01-25 10:16:55 +00:00
333 lines
10 KiB
Go
333 lines
10 KiB
Go
|
package route
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"io"
|
||
|
"net"
|
||
|
"net/netip"
|
||
|
"sync/atomic"
|
||
|
|
||
|
"github.com/sagernet/sing-box/adapter"
|
||
|
"github.com/sagernet/sing-box/common/dialer"
|
||
|
"github.com/sagernet/sing/common"
|
||
|
"github.com/sagernet/sing/common/bufio"
|
||
|
E "github.com/sagernet/sing/common/exceptions"
|
||
|
"github.com/sagernet/sing/common/logger"
|
||
|
M "github.com/sagernet/sing/common/metadata"
|
||
|
N "github.com/sagernet/sing/common/network"
|
||
|
)
|
||
|
|
||
|
var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
|
||
|
|
||
|
type ConnectionManager struct {
|
||
|
logger logger.ContextLogger
|
||
|
monitor *ConnectionMonitor
|
||
|
}
|
||
|
|
||
|
func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
|
||
|
return &ConnectionManager{
|
||
|
logger: logger,
|
||
|
monitor: NewConnectionMonitor(),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *ConnectionManager) Start(stage adapter.StartStage) error {
|
||
|
if stage != adapter.StartStateInitialize {
|
||
|
return nil
|
||
|
}
|
||
|
return m.monitor.Start()
|
||
|
}
|
||
|
|
||
|
func (m *ConnectionManager) Close() error {
|
||
|
return m.monitor.Close()
|
||
|
}
|
||
|
|
||
|
func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||
|
ctx = adapter.WithContext(ctx, &metadata)
|
||
|
var (
|
||
|
remoteConn net.Conn
|
||
|
err error
|
||
|
)
|
||
|
if len(metadata.DestinationAddresses) > 0 {
|
||
|
remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
|
||
|
} else {
|
||
|
remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
|
||
|
}
|
||
|
if err != nil {
|
||
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||
|
m.logger.ErrorContext(ctx, "open outbound connection: ", err)
|
||
|
return
|
||
|
}
|
||
|
err = N.ReportConnHandshakeSuccess(conn, remoteConn)
|
||
|
if err != nil {
|
||
|
remoteConn.Close()
|
||
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||
|
m.logger.ErrorContext(ctx, "report handshake success: ", err)
|
||
|
return
|
||
|
}
|
||
|
var done atomic.Bool
|
||
|
if ctx.Done() != nil {
|
||
|
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
|
||
|
}
|
||
|
go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
|
||
|
go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
|
||
|
}
|
||
|
|
||
|
func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
||
|
originSource := source
|
||
|
originDestination := destination
|
||
|
var readCounters, writeCounters []N.CountFunc
|
||
|
for {
|
||
|
source, readCounters = N.UnwrapCountReader(source, readCounters)
|
||
|
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
|
||
|
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
||
|
cachedBuffer := cachedSrc.ReadCached()
|
||
|
if cachedBuffer != nil {
|
||
|
dataLen := cachedBuffer.Len()
|
||
|
_, err := destination.Write(cachedBuffer.Bytes())
|
||
|
cachedBuffer.Release()
|
||
|
if err != nil {
|
||
|
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
|
||
|
if done.Swap(true) {
|
||
|
if onClose != nil {
|
||
|
onClose(err)
|
||
|
}
|
||
|
}
|
||
|
common.Close(originSource, originDestination)
|
||
|
return
|
||
|
}
|
||
|
for _, counter := range readCounters {
|
||
|
counter(int64(dataLen))
|
||
|
}
|
||
|
for _, counter := range writeCounters {
|
||
|
counter(int64(dataLen))
|
||
|
}
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
||
|
if err != nil {
|
||
|
common.Close(originSource, originDestination)
|
||
|
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
|
||
|
err = duplexDst.CloseWrite()
|
||
|
if err != nil {
|
||
|
common.Close(originSource, originDestination)
|
||
|
}
|
||
|
} else {
|
||
|
common.Close(originDestination)
|
||
|
}
|
||
|
if done.Swap(true) {
|
||
|
if onClose != nil {
|
||
|
onClose(err)
|
||
|
}
|
||
|
common.Close(originSource, originDestination)
|
||
|
}
|
||
|
if !direction {
|
||
|
if err == nil {
|
||
|
m.logger.DebugContext(ctx, "connection upload finished")
|
||
|
} else if !E.IsClosedOrCanceled(err) {
|
||
|
m.logger.ErrorContext(ctx, "connection upload closed: ", err)
|
||
|
} else {
|
||
|
m.logger.TraceContext(ctx, "connection upload closed")
|
||
|
}
|
||
|
} else {
|
||
|
if err == nil {
|
||
|
m.logger.DebugContext(ctx, "connection download finished")
|
||
|
} else if !E.IsClosedOrCanceled(err) {
|
||
|
m.logger.ErrorContext(ctx, "connection download closed: ", err)
|
||
|
} else {
|
||
|
m.logger.TraceContext(ctx, "connection download closed")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||
|
ctx = adapter.WithContext(ctx, &metadata)
|
||
|
var (
|
||
|
remotePacketConn net.PacketConn
|
||
|
remoteConn net.Conn
|
||
|
destinationAddress netip.Addr
|
||
|
err error
|
||
|
)
|
||
|
if metadata.UDPConnect {
|
||
|
if len(metadata.DestinationAddresses) > 0 {
|
||
|
if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer {
|
||
|
remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
|
||
|
} else {
|
||
|
remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
|
||
|
}
|
||
|
} else {
|
||
|
remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
|
||
|
}
|
||
|
if err != nil {
|
||
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||
|
m.logger.ErrorContext(ctx, "open outbound packet connection: ", err)
|
||
|
return
|
||
|
}
|
||
|
remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
|
||
|
connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
|
||
|
if connRemoteAddr != metadata.Destination.Addr {
|
||
|
destinationAddress = connRemoteAddr
|
||
|
}
|
||
|
} else {
|
||
|
if len(metadata.DestinationAddresses) > 0 {
|
||
|
remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
|
||
|
} else {
|
||
|
remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
|
||
|
}
|
||
|
if err != nil {
|
||
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||
|
m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
|
||
|
if err != nil {
|
||
|
conn.Close()
|
||
|
remotePacketConn.Close()
|
||
|
m.logger.ErrorContext(ctx, "report handshake success: ", err)
|
||
|
return
|
||
|
}
|
||
|
if destinationAddress.IsValid() {
|
||
|
var originDestination M.Socksaddr
|
||
|
if metadata.RouteOriginalDestination.IsValid() {
|
||
|
originDestination = metadata.RouteOriginalDestination
|
||
|
} else {
|
||
|
originDestination = metadata.Destination
|
||
|
}
|
||
|
if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
|
||
|
if metadata.UDPDisableDomainUnmapping {
|
||
|
remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
|
||
|
} else {
|
||
|
remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
|
||
|
}
|
||
|
}
|
||
|
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
|
||
|
natConn.UpdateDestination(destinationAddress)
|
||
|
}
|
||
|
}
|
||
|
destination := bufio.NewPacketConn(remotePacketConn)
|
||
|
var done atomic.Bool
|
||
|
if ctx.Done() != nil {
|
||
|
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
|
||
|
}
|
||
|
go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
|
||
|
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
|
||
|
}
|
||
|
|
||
|
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
||
|
_, err := bufio.CopyPacket(destination, source)
|
||
|
/*var readCounters, writeCounters []N.CountFunc
|
||
|
var cachedPackets []*N.PacketBuffer
|
||
|
originSource := source
|
||
|
for {
|
||
|
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
|
||
|
destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters)
|
||
|
if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
|
||
|
packet := cachedReader.ReadCachedPacket()
|
||
|
if packet != nil {
|
||
|
cachedPackets = append(cachedPackets, packet)
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
var handled bool
|
||
|
if natConn, isNatConn := source.(udpnat.Conn); isNatConn {
|
||
|
natConn.SetHandler(&udpHijacker{
|
||
|
ctx: ctx,
|
||
|
logger: m.logger,
|
||
|
source: natConn,
|
||
|
destination: destination,
|
||
|
direction: direction,
|
||
|
readCounters: readCounters,
|
||
|
writeCounters: writeCounters,
|
||
|
done: done,
|
||
|
onClose: onClose,
|
||
|
})
|
||
|
handled = true
|
||
|
}
|
||
|
if cachedPackets != nil {
|
||
|
_, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters)
|
||
|
if err != nil {
|
||
|
common.Close(source, destination)
|
||
|
m.logger.ErrorContext(ctx, "packet upload payload: ", err)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
if handled {
|
||
|
return
|
||
|
}
|
||
|
_, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/
|
||
|
if !direction {
|
||
|
if E.IsClosedOrCanceled(err) {
|
||
|
m.logger.TraceContext(ctx, "packet upload closed")
|
||
|
} else {
|
||
|
m.logger.DebugContext(ctx, "packet upload closed: ", err)
|
||
|
}
|
||
|
} else {
|
||
|
if E.IsClosedOrCanceled(err) {
|
||
|
m.logger.TraceContext(ctx, "packet download closed")
|
||
|
} else {
|
||
|
m.logger.DebugContext(ctx, "packet download closed: ", err)
|
||
|
}
|
||
|
}
|
||
|
if !done.Swap(true) {
|
||
|
if onClose != nil {
|
||
|
onClose(err)
|
||
|
}
|
||
|
}
|
||
|
common.Close(source, destination)
|
||
|
}
|
||
|
|
||
|
/*type udpHijacker struct {
|
||
|
ctx context.Context
|
||
|
logger logger.ContextLogger
|
||
|
source io.Closer
|
||
|
destination N.PacketWriter
|
||
|
direction bool
|
||
|
readCounters []N.CountFunc
|
||
|
writeCounters []N.CountFunc
|
||
|
done *atomic.Bool
|
||
|
onClose N.CloseHandlerFunc
|
||
|
}
|
||
|
|
||
|
func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
|
||
|
dataLen := buffer.Len()
|
||
|
for _, counter := range u.readCounters {
|
||
|
counter(int64(dataLen))
|
||
|
}
|
||
|
err := u.destination.WritePacket(buffer, source)
|
||
|
if err != nil {
|
||
|
common.Close(u.source, u.destination)
|
||
|
u.logger.DebugContext(u.ctx, "packet upload closed: ", err)
|
||
|
return
|
||
|
}
|
||
|
for _, counter := range u.writeCounters {
|
||
|
counter(int64(dataLen))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (u *udpHijacker) Close() error {
|
||
|
var err error
|
||
|
if !u.done.Swap(true) {
|
||
|
err = common.Close(u.source, u.destination)
|
||
|
if u.onClose != nil {
|
||
|
u.onClose(net.ErrClosed)
|
||
|
}
|
||
|
}
|
||
|
if u.direction {
|
||
|
u.logger.TraceContext(u.ctx, "packet download closed")
|
||
|
} else {
|
||
|
u.logger.TraceContext(u.ctx, "packet upload closed")
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (u *udpHijacker) Upstream() any {
|
||
|
return u.destination
|
||
|
}
|
||
|
*/
|