Add UDP timeout route option

This commit is contained in:
世界 2024-11-24 14:45:40 +08:00
parent c18c7d807f
commit 7c75d2b9d7
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
13 changed files with 143 additions and 223 deletions

View file

@ -70,6 +70,8 @@ type InboundContext struct {
InboundOptions option.InboundOptions InboundOptions option.InboundOptions
UDPDisableDomainUnmapping bool UDPDisableDomainUnmapping bool
UDPConnect bool UDPConnect bool
UDPTimeout time.Duration
NetworkStrategy C.NetworkStrategy NetworkStrategy C.NetworkStrategy
NetworkType []C.InterfaceType NetworkType []C.InterfaceType
FallbackNetworkType []C.InterfaceType FallbackNetworkType []C.InterfaceType

View file

@ -1,157 +0,0 @@
package outbound
import (
"context"
"net"
"net/netip"
"os"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
defer conn.Close()
ctx = adapter.WithContext(ctx, &metadata)
var outConn net.Conn
var err error
if len(metadata.DestinationAddresses) > 0 {
outConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
} else {
outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
}
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
err = N.ReportConnHandshakeSuccess(conn, outConn)
if err != nil {
outConn.Close()
return err
}
return CopyEarlyConn(ctx, conn, outConn)
}
func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error {
defer conn.Close()
ctx = adapter.WithContext(ctx, &metadata)
var (
outPacketConn net.PacketConn
outConn net.Conn
destinationAddress netip.Addr
err error
)
if metadata.UDPConnect {
if len(metadata.DestinationAddresses) > 0 {
if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer {
outConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
} else {
outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
}
} else {
outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
}
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
outPacketConn = bufio.NewUnbindPacketConn(outConn)
connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr())
if connRemoteAddr != metadata.Destination.Addr {
destinationAddress = connRemoteAddr
}
} else {
if len(metadata.DestinationAddresses) > 0 {
outPacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
} else {
outPacketConn, err = this.ListenPacket(ctx, metadata.Destination)
}
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
}
err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn)
if err != nil {
outPacketConn.Close()
return err
}
if destinationAddress.IsValid() {
var originDestination M.Socksaddr
if metadata.RouteOriginalDestination.IsValid() {
originDestination = metadata.RouteOriginalDestination
} else {
originDestination = metadata.Destination
}
if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
if metadata.UDPDisableDomainUnmapping {
outPacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
} else {
outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
}
}
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
natConn.UpdateDestination(destinationAddress)
}
}
switch metadata.Protocol {
case C.ProtocolSTUN:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.STUNTimeout)
case C.ProtocolQUIC:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.QUICTimeout)
case C.ProtocolDNS:
ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
}
return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn))
}
func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error {
if cachedReader, isCached := conn.(N.CachedReader); isCached {
payload := cachedReader.ReadCached()
if payload != nil && !payload.IsEmpty() {
_, err := serverConn.Write(payload.Bytes())
payload.Release()
if err != nil {
serverConn.Close()
return err
}
return bufio.CopyConn(ctx, conn, serverConn)
}
}
if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](serverConn); isEarlyConn && earlyConn.NeedHandshake() {
payload := buf.NewPacket()
err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
if err != os.ErrInvalid {
if err != nil {
payload.Release()
serverConn.Close()
return err
}
_, err = payload.ReadOnceFrom(conn)
if err != nil && !E.IsTimeout(err) {
payload.Release()
serverConn.Close()
return E.Cause(err, "read payload")
}
err = conn.SetReadDeadline(time.Time{})
if err != nil {
payload.Release()
serverConn.Close()
return err
}
}
_, err = serverConn.Write(payload.Bytes())
payload.Release()
if err != nil {
serverConn.Close()
return N.ReportHandshakeFailure(conn, err)
}
}
return bufio.CopyConn(ctx, conn, serverConn)
}

View file

@ -10,6 +10,7 @@ const (
ProtocolDTLS = "dtls" ProtocolDTLS = "dtls"
ProtocolSSH = "ssh" ProtocolSSH = "ssh"
ProtocolRDP = "rdp" ProtocolRDP = "rdp"
ProtocolNTP = "ntp"
) )
const ( const (

View file

@ -9,8 +9,6 @@ const (
TCPTimeout = 15 * time.Second TCPTimeout = 15 * time.Second
ReadPayloadTimeout = 300 * time.Millisecond ReadPayloadTimeout = 300 * time.Millisecond
DNSTimeout = 10 * time.Second DNSTimeout = 10 * time.Second
QUICTimeout = 30 * time.Second
STUNTimeout = 15 * time.Second
UDPTimeout = 5 * time.Minute UDPTimeout = 5 * time.Minute
DefaultURLTestInterval = 3 * time.Minute DefaultURLTestInterval = 3 * time.Minute
DefaultURLTestIdleTimeout = 30 * time.Minute DefaultURLTestIdleTimeout = 30 * time.Minute
@ -19,3 +17,18 @@ const (
FatalStopTimeout = 10 * time.Second FatalStopTimeout = 10 * time.Second
FakeIPMetadataSaveInterval = 10 * time.Second FakeIPMetadataSaveInterval = 10 * time.Second
) )
var PortProtocols = map[uint16]string{
53: ProtocolDNS,
123: ProtocolNTP,
3478: ProtocolSTUN,
443: ProtocolQUIC,
}
var ProtocolTimeouts = map[string]time.Duration{
ProtocolDNS: 10 * time.Second,
ProtocolNTP: 10 * time.Second,
ProtocolSTUN: 10 * time.Second,
ProtocolQUIC: 30 * time.Second,
ProtocolDTLS: 30 * time.Second,
}

View file

@ -41,7 +41,8 @@ See `route-options` fields below.
"network_strategy": "", "network_strategy": "",
"fallback_delay": "", "fallback_delay": "",
"udp_disable_domain_unmapping": false, "udp_disable_domain_unmapping": false,
"udp_connect": false "udp_connect": false,
"udp_timeout": ""
} }
``` ```
@ -86,6 +87,28 @@ do not support receiving UDP packets with domain addresses, such as Surge.
If enabled, attempts to connect UDP connection to the destination instead of listen. If enabled, attempts to connect UDP connection to the destination instead of listen.
#### udp_timeout
Timeout for UDP connections.
Setting a larger value than the UDP timeout in inbounds will have no effect.
Default value for protocol sniffed connections:
| Timeout | Protocol |
|---------|----------------------|
| `10s` | `dns`, `ntp`, `stun` |
| `30s` | `quic`, `dtls` |
If no protocol is sniffed, the following ports will be recognized as protocols by default:
| Port | Protocol |
|------|----------|
| 53 | `dns` |
| 123 | `ntp` |
| 443 | `quic` |
| 3478 | `stun` |
### reject ### reject
```json ```json

View file

@ -37,7 +37,8 @@ icon: material/new-box
"network_strategy": "", "network_strategy": "",
"fallback_delay": "", "fallback_delay": "",
"udp_disable_domain_unmapping": false, "udp_disable_domain_unmapping": false,
"udp_connect": false "udp_connect": false,
"udp_timeout": ""
} }
``` ```
@ -84,6 +85,28 @@ icon: material/new-box
如果启用,将尝试将 UDP 连接 connect 到目标而不是 listen。 如果启用,将尝试将 UDP 连接 connect 到目标而不是 listen。
#### udp_timeout
UDP 连接超时时间。
设置比入站 UDP 超时更大的值将无效。
已探测协议连接的默认值:
| 超时 | 协议 |
|-------|----------------------|
| `10s` | `dns`, `ntp`, `stun` |
| `30s` | `quic`, `dtls` |
如果没有探测到协议,以下端口将默认识别为协议:
| 端口 | 协议 |
|------|--------|
| 53 | `dns` |
| 123 | `ntp` |
| 443 | `quic` |
| 3478 | `stun` |
### reject ### reject
```json ```json

View file

@ -150,6 +150,7 @@ type RawRouteOptionsActionOptions struct {
UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"`
UDPConnect bool `json:"udp_connect,omitempty"` UDPConnect bool `json:"udp_connect,omitempty"`
UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"`
} }
type RouteOptionsActionOptions RawRouteOptionsActionOptions type RouteOptionsActionOptions RawRouteOptionsActionOptions

View file

@ -14,7 +14,7 @@ type WireGuardEndpointOptions struct {
PrivateKey string `json:"private_key"` PrivateKey string `json:"private_key"`
ListenPort uint16 `json:"listen_port,omitempty"` ListenPort uint16 `json:"listen_port,omitempty"`
Peers []WireGuardPeer `json:"peers,omitempty"` Peers []WireGuardPeer `json:"peers,omitempty"`
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"` UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"`
Workers int `json:"workers,omitempty"` Workers int `json:"workers,omitempty"`
DialerOptions DialerOptions
} }

View file

@ -10,6 +10,7 @@ import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/atomic"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -21,17 +22,22 @@ func RegisterSelector(registry *outbound.Registry) {
outbound.Register[option.SelectorOutboundOptions](registry, C.TypeSelector, NewSelector) outbound.Register[option.SelectorOutboundOptions](registry, C.TypeSelector, NewSelector)
} }
var _ adapter.OutboundGroup = (*Selector)(nil) var (
_ adapter.OutboundGroup = (*Selector)(nil)
_ adapter.ConnectionHandlerEx = (*Selector)(nil)
_ adapter.PacketConnectionHandlerEx = (*Selector)(nil)
)
type Selector struct { type Selector struct {
outbound.Adapter outbound.Adapter
ctx context.Context ctx context.Context
outboundManager adapter.OutboundManager outbound adapter.OutboundManager
connection adapter.ConnectionManager
logger logger.ContextLogger logger logger.ContextLogger
tags []string tags []string
defaultTag string defaultTag string
outbounds map[string]adapter.Outbound outbounds map[string]adapter.Outbound
selected adapter.Outbound selected atomic.TypedValue[adapter.Outbound]
interruptGroup *interrupt.Group interruptGroup *interrupt.Group
interruptExternalConnections bool interruptExternalConnections bool
} }
@ -40,7 +46,8 @@ func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextL
outbound := &Selector{ outbound := &Selector{
Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds),
ctx: ctx, ctx: ctx,
outboundManager: service.FromContext[adapter.OutboundManager](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx),
connection: service.FromContext[adapter.ConnectionManager](ctx),
logger: logger, logger: logger,
tags: options.Outbounds, tags: options.Outbounds,
defaultTag: options.Default, defaultTag: options.Default,
@ -55,15 +62,16 @@ func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextL
} }
func (s *Selector) Network() []string { func (s *Selector) Network() []string {
if s.selected == nil { selected := s.selected.Load()
if selected == nil {
return []string{N.NetworkTCP, N.NetworkUDP} return []string{N.NetworkTCP, N.NetworkUDP}
} }
return s.selected.Network() return selected.Network()
} }
func (s *Selector) Start() error { func (s *Selector) Start() error {
for i, tag := range s.tags { for i, tag := range s.tags {
detour, loaded := s.outboundManager.Outbound(tag) detour, loaded := s.outbound.Outbound(tag)
if !loaded { if !loaded {
return E.New("outbound ", i, " not found: ", tag) return E.New("outbound ", i, " not found: ", tag)
} }
@ -77,7 +85,7 @@ func (s *Selector) Start() error {
if selected != "" { if selected != "" {
detour, loaded := s.outbounds[selected] detour, loaded := s.outbounds[selected]
if loaded { if loaded {
s.selected = detour s.selected.Store(detour)
return nil return nil
} }
} }
@ -89,16 +97,16 @@ func (s *Selector) Start() error {
if !loaded { if !loaded {
return E.New("default outbound not found: ", s.defaultTag) return E.New("default outbound not found: ", s.defaultTag)
} }
s.selected = detour s.selected.Store(detour)
return nil return nil
} }
s.selected = s.outbounds[s.tags[0]] s.selected.Store(s.outbounds[s.tags[0]])
return nil return nil
} }
func (s *Selector) Now() string { func (s *Selector) Now() string {
selected := s.selected selected := s.selected.Load()
if selected == nil { if selected == nil {
return s.tags[0] return s.tags[0]
} }
@ -114,10 +122,9 @@ func (s *Selector) SelectOutbound(tag string) bool {
if !loaded { if !loaded {
return false return false
} }
if s.selected == detour { if s.selected.Swap(detour) == detour {
return true return true
} }
s.selected = detour
if s.Tag() != "" { if s.Tag() != "" {
cacheFile := service.FromContext[adapter.CacheFile](s.ctx) cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
if cacheFile != nil { if cacheFile != nil {
@ -132,7 +139,7 @@ func (s *Selector) SelectOutbound(tag string) bool {
} }
func (s *Selector) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (s *Selector) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
conn, err := s.selected.DialContext(ctx, network, destination) conn, err := s.selected.Load().DialContext(ctx, network, destination)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -140,32 +147,30 @@ func (s *Selector) DialContext(ctx context.Context, network string, destination
} }
func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
conn, err := s.selected.ListenPacket(ctx, destination) conn, err := s.selected.Load().ListenPacket(ctx, destination)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
} }
// TODO func (s *Selector) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
// Deprecated
func (s *Selector) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
ctx = interrupt.ContextWithIsExternalConnection(ctx) ctx = interrupt.ContextWithIsExternalConnection(ctx)
if legacyHandler, ok := s.selected.(adapter.ConnectionHandler); ok { selected := s.selected.Load()
return legacyHandler.NewConnection(ctx, conn, metadata) if outboundHandler, isHandler := selected.(adapter.ConnectionHandlerEx); isHandler {
outboundHandler.NewConnectionEx(ctx, conn, metadata, onClose)
} else { } else {
return outbound.NewConnection(ctx, s.selected, conn, metadata) s.connection.NewConnection(ctx, selected, conn, metadata, onClose)
} }
} }
// TODO func (s *Selector) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
// Deprecated
func (s *Selector) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
ctx = interrupt.ContextWithIsExternalConnection(ctx) ctx = interrupt.ContextWithIsExternalConnection(ctx)
if legacyHandler, ok := s.selected.(adapter.PacketConnectionHandler); ok { selected := s.selected.Load()
return legacyHandler.NewPacketConnection(ctx, conn, metadata) if outboundHandler, isHandler := selected.(adapter.PacketConnectionHandlerEx); isHandler {
outboundHandler.NewPacketConnectionEx(ctx, conn, metadata, onClose)
} else { } else {
return outbound.NewPacketConnection(ctx, s.selected, conn, metadata) s.connection.NewPacketConnection(ctx, selected, conn, metadata, onClose)
} }
} }

View file

@ -36,7 +36,8 @@ type URLTest struct {
outbound.Adapter outbound.Adapter
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
outboundManager adapter.OutboundManager outbound adapter.OutboundManager
connection adapter.ConnectionManager
logger log.ContextLogger logger log.ContextLogger
tags []string tags []string
link string link string
@ -52,7 +53,8 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo
Adapter: outbound.NewAdapter(C.TypeURLTest, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeURLTest, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
ctx: ctx, ctx: ctx,
router: router, router: router,
outboundManager: service.FromContext[adapter.OutboundManager](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx),
connection: service.FromContext[adapter.ConnectionManager](ctx),
logger: logger, logger: logger,
tags: options.Outbounds, tags: options.Outbounds,
link: options.URL, link: options.URL,
@ -70,13 +72,13 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo
func (s *URLTest) Start() error { func (s *URLTest) Start() error {
outbounds := make([]adapter.Outbound, 0, len(s.tags)) outbounds := make([]adapter.Outbound, 0, len(s.tags))
for i, tag := range s.tags { for i, tag := range s.tags {
detour, loaded := s.outboundManager.Outbound(tag) detour, loaded := s.outbound.Outbound(tag)
if !loaded { if !loaded {
return E.New("outbound ", i, " not found: ", tag) return E.New("outbound ", i, " not found: ", tag)
} }
outbounds = append(outbounds, detour) outbounds = append(outbounds, detour)
} }
group, err := NewURLTestGroup(s.ctx, s.outboundManager, s.logger, outbounds, s.link, s.interval, s.tolerance, s.idleTimeout, s.interruptExternalConnections) group, err := NewURLTestGroup(s.ctx, s.outbound, s.logger, outbounds, s.link, s.interval, s.tolerance, s.idleTimeout, s.interruptExternalConnections)
if err != nil { if err != nil {
return err return err
} }
@ -160,18 +162,14 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne
return nil, err return nil, err
} }
// TODO func (s *URLTest) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
// Deprecated
func (s *URLTest) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
ctx = interrupt.ContextWithIsExternalConnection(ctx) ctx = interrupt.ContextWithIsExternalConnection(ctx)
return outbound.NewConnection(ctx, s, conn, metadata) s.connection.NewConnection(ctx, s, conn, metadata, onClose)
} }
// TODO func (s *URLTest) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
// Deprecated
func (s *URLTest) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
ctx = interrupt.ContextWithIsExternalConnection(ctx) ctx = interrupt.ContextWithIsExternalConnection(ctx)
return outbound.NewPacketConnection(ctx, s, conn, metadata) s.connection.NewPacketConnection(ctx, s, conn, metadata, onClose)
} }
func (s *URLTest) InterfaceUpdated() { func (s *URLTest) InterfaceUpdated() {

View file

@ -6,11 +6,14 @@ import (
"net" "net"
"net/netip" "net/netip"
"sync/atomic" "sync/atomic"
"time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -199,6 +202,21 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
natConn.UpdateDestination(destinationAddress) natConn.UpdateDestination(destinationAddress)
} }
} }
var udpTimeout time.Duration
if metadata.UDPTimeout > 0 {
udpTimeout = metadata.UDPTimeout
} else {
protocol := metadata.Protocol
if protocol == "" {
protocol = C.PortProtocols[metadata.Destination.Port]
}
if protocol != "" {
udpTimeout = C.ProtocolTimeouts[protocol]
}
}
if udpTimeout > 0 {
ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
}
destination := bufio.NewPacketConn(remotePacketConn) destination := bufio.NewPacketConn(remotePacketConn)
if ctx.Done() != nil { if ctx.Done() != nil {
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn)) onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))

View file

@ -132,23 +132,11 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if r.tracker != nil { if r.tracker != nil {
conn = r.tracker.RoutedConnection(ctx, conn, metadata, selectedRule, selectedOutbound) conn = r.tracker.RoutedConnection(ctx, conn, metadata, selectedRule, selectedOutbound)
} }
legacyOutbound, isLegacy := selectedOutbound.(adapter.ConnectionHandler) if outboundHandler, isHandler := selectedOutbound.(adapter.ConnectionHandlerEx); isHandler {
if isLegacy { outboundHandler.NewConnectionEx(ctx, conn, metadata, onClose)
err = legacyOutbound.NewConnection(ctx, conn, metadata)
if err != nil {
conn.Close()
if onClose != nil {
onClose(err)
}
return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
} else { } else {
if onClose != nil {
onClose(nil)
}
}
return nil
}
r.connection.NewConnection(ctx, selectedOutbound, conn, metadata, onClose) r.connection.NewConnection(ctx, selectedOutbound, conn, metadata, onClose)
}
return nil return nil
} }
@ -440,6 +428,9 @@ match:
if routeOptions.UDPConnect { if routeOptions.UDPConnect {
metadata.UDPConnect = true metadata.UDPConnect = true
} }
if routeOptions.UDPTimeout > 0 {
metadata.UDPTimeout = routeOptions.UDPTimeout
}
} }
switch action := currentRule.Action().(type) { switch action := currentRule.Action().(type) {
case *rule.RuleActionSniff: case *rule.RuleActionSniff:

View file

@ -47,6 +47,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
FallbackDelay: time.Duration(action.RouteOptionsOptions.FallbackDelay), FallbackDelay: time.Duration(action.RouteOptionsOptions.FallbackDelay),
UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping, UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping,
UDPConnect: action.RouteOptionsOptions.UDPConnect, UDPConnect: action.RouteOptionsOptions.UDPConnect,
UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout),
}, nil }, nil
case C.RuleActionTypeDirect: case C.RuleActionTypeDirect:
directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions)) directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions))
@ -152,6 +153,7 @@ type RuleActionRouteOptions struct {
FallbackDelay time.Duration FallbackDelay time.Duration
UDPDisableDomainUnmapping bool UDPDisableDomainUnmapping bool
UDPConnect bool UDPConnect bool
UDPTimeout time.Duration
} }
func (r *RuleActionRouteOptions) Type() string { func (r *RuleActionRouteOptions) Type() string {