From a594edda137c8886ad8510b3101976057c7f70a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 24 Nov 2024 14:45:40 +0800 Subject: [PATCH] Add UDP timeout route option --- adapter/inbound.go | 10 +- adapter/outbound/default.go | 157 --------------------- constant/protocol.go | 1 + constant/timeout.go | 17 ++- docs/configuration/route/rule_action.md | 25 +++- docs/configuration/route/rule_action.zh.md | 25 +++- option/rule_action.go | 5 +- option/wireguard.go | 2 +- protocol/dns/outbound.go | 15 +- protocol/group/selector.go | 59 ++++---- protocol/group/urltest.go | 22 ++- route/conn.go | 18 +++ route/route.go | 36 ++--- route/rule/rule_action.go | 2 + 14 files changed, 155 insertions(+), 239 deletions(-) delete mode 100644 adapter/outbound/default.go diff --git a/adapter/inbound.go b/adapter/inbound.go index fd9c5405..f5d5c95b 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -70,10 +70,12 @@ type InboundContext struct { InboundOptions option.InboundOptions UDPDisableDomainUnmapping bool UDPConnect bool - NetworkStrategy C.NetworkStrategy - NetworkType []C.InterfaceType - FallbackNetworkType []C.InterfaceType - FallbackDelay time.Duration + UDPTimeout time.Duration + + NetworkStrategy C.NetworkStrategy + NetworkType []C.InterfaceType + FallbackNetworkType []C.InterfaceType + FallbackDelay time.Duration DNSServer string diff --git a/adapter/outbound/default.go b/adapter/outbound/default.go deleted file mode 100644 index 573673f2..00000000 --- a/adapter/outbound/default.go +++ /dev/null @@ -1,157 +0,0 @@ -package outbound - -import ( - "context" - "net" - "net/netip" - "os" - "time" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/dialer" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - "github.com/sagernet/sing/common/canceler" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" -) - -func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error { - defer conn.Close() - ctx = adapter.WithContext(ctx, &metadata) - var outConn net.Conn - var err error - if len(metadata.DestinationAddresses) > 0 { - outConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) - } else { - outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) - } - if err != nil { - return N.ReportHandshakeFailure(conn, err) - } - err = N.ReportConnHandshakeSuccess(conn, outConn) - if err != nil { - outConn.Close() - return err - } - return CopyEarlyConn(ctx, conn, outConn) -} - -func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error { - defer conn.Close() - ctx = adapter.WithContext(ctx, &metadata) - var ( - outPacketConn net.PacketConn - outConn net.Conn - destinationAddress netip.Addr - err error - ) - if metadata.UDPConnect { - if len(metadata.DestinationAddresses) > 0 { - if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer { - outConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) - } else { - outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) - } - } else { - outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) - } - if err != nil { - return N.ReportHandshakeFailure(conn, err) - } - outPacketConn = bufio.NewUnbindPacketConn(outConn) - connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr()) - if connRemoteAddr != metadata.Destination.Addr { - destinationAddress = connRemoteAddr - } - } else { - if len(metadata.DestinationAddresses) > 0 { - outPacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) - } else { - outPacketConn, err = this.ListenPacket(ctx, metadata.Destination) - } - if err != nil { - return N.ReportHandshakeFailure(conn, err) - } - } - err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn) - if err != nil { - outPacketConn.Close() - return err - } - if destinationAddress.IsValid() { - var originDestination M.Socksaddr - if metadata.RouteOriginalDestination.IsValid() { - originDestination = metadata.RouteOriginalDestination - } else { - originDestination = metadata.Destination - } - if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) { - if metadata.UDPDisableDomainUnmapping { - outPacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) - } else { - outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) - } - } - if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { - natConn.UpdateDestination(destinationAddress) - } - } - switch metadata.Protocol { - case C.ProtocolSTUN: - ctx, conn = canceler.NewPacketConn(ctx, conn, C.STUNTimeout) - case C.ProtocolQUIC: - ctx, conn = canceler.NewPacketConn(ctx, conn, C.QUICTimeout) - case C.ProtocolDNS: - ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout) - } - return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn)) -} - -func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error { - if cachedReader, isCached := conn.(N.CachedReader); isCached { - payload := cachedReader.ReadCached() - if payload != nil && !payload.IsEmpty() { - _, err := serverConn.Write(payload.Bytes()) - payload.Release() - if err != nil { - serverConn.Close() - return err - } - return bufio.CopyConn(ctx, conn, serverConn) - } - } - if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](serverConn); isEarlyConn && earlyConn.NeedHandshake() { - payload := buf.NewPacket() - err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout)) - if err != os.ErrInvalid { - if err != nil { - payload.Release() - serverConn.Close() - return err - } - _, err = payload.ReadOnceFrom(conn) - if err != nil && !E.IsTimeout(err) { - payload.Release() - serverConn.Close() - return E.Cause(err, "read payload") - } - err = conn.SetReadDeadline(time.Time{}) - if err != nil { - payload.Release() - serverConn.Close() - return err - } - } - _, err = serverConn.Write(payload.Bytes()) - payload.Release() - if err != nil { - serverConn.Close() - return N.ReportHandshakeFailure(conn, err) - } - } - return bufio.CopyConn(ctx, conn, serverConn) -} diff --git a/constant/protocol.go b/constant/protocol.go index 14854089..dbe16e51 100644 --- a/constant/protocol.go +++ b/constant/protocol.go @@ -10,6 +10,7 @@ const ( ProtocolDTLS = "dtls" ProtocolSSH = "ssh" ProtocolRDP = "rdp" + ProtocolNTP = "ntp" ) const ( diff --git a/constant/timeout.go b/constant/timeout.go index 67ae6f66..3b5a452b 100644 --- a/constant/timeout.go +++ b/constant/timeout.go @@ -9,8 +9,6 @@ const ( TCPTimeout = 15 * time.Second ReadPayloadTimeout = 300 * time.Millisecond DNSTimeout = 10 * time.Second - QUICTimeout = 30 * time.Second - STUNTimeout = 15 * time.Second UDPTimeout = 5 * time.Minute DefaultURLTestInterval = 3 * time.Minute DefaultURLTestIdleTimeout = 30 * time.Minute @@ -19,3 +17,18 @@ const ( FatalStopTimeout = 10 * time.Second FakeIPMetadataSaveInterval = 10 * time.Second ) + +var PortProtocols = map[uint16]string{ + 53: ProtocolDNS, + 123: ProtocolNTP, + 3478: ProtocolSTUN, + 443: ProtocolQUIC, +} + +var ProtocolTimeouts = map[string]time.Duration{ + ProtocolDNS: 10 * time.Second, + ProtocolNTP: 10 * time.Second, + ProtocolSTUN: 10 * time.Second, + ProtocolQUIC: 30 * time.Second, + ProtocolDTLS: 30 * time.Second, +} diff --git a/docs/configuration/route/rule_action.md b/docs/configuration/route/rule_action.md index 63e2b00b..fae52e85 100644 --- a/docs/configuration/route/rule_action.md +++ b/docs/configuration/route/rule_action.md @@ -41,7 +41,8 @@ See `route-options` fields below. "network_strategy": "", "fallback_delay": "", "udp_disable_domain_unmapping": false, - "udp_connect": false + "udp_connect": false, + "udp_timeout": "" } ``` @@ -86,6 +87,28 @@ do not support receiving UDP packets with domain addresses, such as Surge. If enabled, attempts to connect UDP connection to the destination instead of listen. +#### udp_timeout + +Timeout for UDP connections. + +Setting a larger value than the UDP timeout in inbounds will have no effect. + +Default value for protocol sniffed connections: + +| Timeout | Protocol | +|---------|----------------------| +| `10s` | `dns`, `ntp`, `stun` | +| `30s` | `quic`, `dtls` | + +If no protocol is sniffed, the following ports will be recognized as protocols by default: + +| Port | Protocol | +|------|----------| +| 53 | `dns` | +| 123 | `ntp` | +| 443 | `quic` | +| 3478 | `stun` | + ### reject ```json diff --git a/docs/configuration/route/rule_action.zh.md b/docs/configuration/route/rule_action.zh.md index 7959fced..2f558f4e 100644 --- a/docs/configuration/route/rule_action.zh.md +++ b/docs/configuration/route/rule_action.zh.md @@ -37,7 +37,8 @@ icon: material/new-box "network_strategy": "", "fallback_delay": "", "udp_disable_domain_unmapping": false, - "udp_connect": false + "udp_connect": false, + "udp_timeout": "" } ``` @@ -84,6 +85,28 @@ icon: material/new-box 如果启用,将尝试将 UDP 连接 connect 到目标而不是 listen。 +#### udp_timeout + +UDP 连接超时时间。 + +设置比入站 UDP 超时更大的值将无效。 + +已探测协议连接的默认值: + +| 超时 | 协议 | +|-------|----------------------| +| `10s` | `dns`, `ntp`, `stun` | +| `30s` | `quic`, `dtls` | + +如果没有探测到协议,以下端口将默认识别为协议: + +| 端口 | 协议 | +|------|--------| +| 53 | `dns` | +| 123 | `ntp` | +| 443 | `quic` | +| 3478 | `stun` | + ### reject ```json diff --git a/option/rule_action.go b/option/rule_action.go index ce3b92d9..29c5a0c3 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -148,8 +148,9 @@ type RawRouteOptionsActionOptions struct { NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"` FallbackDelay uint32 `json:"fallback_delay,omitempty"` - UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` - UDPConnect bool `json:"udp_connect,omitempty"` + UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` + UDPConnect bool `json:"udp_connect,omitempty"` + UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"` } type RouteOptionsActionOptions RawRouteOptionsActionOptions diff --git a/option/wireguard.go b/option/wireguard.go index b9860d11..43d3139c 100644 --- a/option/wireguard.go +++ b/option/wireguard.go @@ -14,7 +14,7 @@ type WireGuardEndpointOptions struct { PrivateKey string `json:"private_key"` ListenPort uint16 `json:"listen_port,omitempty"` Peers []WireGuardPeer `json:"peers,omitempty"` - UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"` + UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"` Workers int `json:"workers,omitempty"` DialerOptions } diff --git a/protocol/dns/outbound.go b/protocol/dns/outbound.go index 3c493f80..5f06557b 100644 --- a/protocol/dns/outbound.go +++ b/protocol/dns/outbound.go @@ -42,20 +42,21 @@ func (d *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n return nil, os.ErrInvalid } -// Deprecated -func (d *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { +func (d *Outbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { metadata.Destination = M.Socksaddr{} - defer conn.Close() for { conn.SetReadDeadline(time.Now().Add(C.DNSTimeout)) err := HandleStreamDNSRequest(ctx, d.router, conn, metadata) if err != nil { - return err + conn.Close() + if onClose != nil { + onClose(err) + } + return } } } -// Deprecated -func (d *Outbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { - return NewDNSPacketConnection(ctx, d.router, conn, nil, metadata) +func (d *Outbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + NewDNSPacketConnection(ctx, d.router, conn, nil, metadata) } diff --git a/protocol/group/selector.go b/protocol/group/selector.go index 0bb3cd66..9806e033 100644 --- a/protocol/group/selector.go +++ b/protocol/group/selector.go @@ -10,6 +10,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/atomic" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -21,17 +22,22 @@ func RegisterSelector(registry *outbound.Registry) { outbound.Register[option.SelectorOutboundOptions](registry, C.TypeSelector, NewSelector) } -var _ adapter.OutboundGroup = (*Selector)(nil) +var ( + _ adapter.OutboundGroup = (*Selector)(nil) + _ adapter.ConnectionHandlerEx = (*Selector)(nil) + _ adapter.PacketConnectionHandlerEx = (*Selector)(nil) +) type Selector struct { outbound.Adapter ctx context.Context - outboundManager adapter.OutboundManager + outbound adapter.OutboundManager + connection adapter.ConnectionManager logger logger.ContextLogger tags []string defaultTag string outbounds map[string]adapter.Outbound - selected adapter.Outbound + selected atomic.TypedValue[adapter.Outbound] interruptGroup *interrupt.Group interruptExternalConnections bool } @@ -40,7 +46,8 @@ func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextL outbound := &Selector{ Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds), ctx: ctx, - outboundManager: service.FromContext[adapter.OutboundManager](ctx), + outbound: service.FromContext[adapter.OutboundManager](ctx), + connection: service.FromContext[adapter.ConnectionManager](ctx), logger: logger, tags: options.Outbounds, defaultTag: options.Default, @@ -55,15 +62,16 @@ func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextL } func (s *Selector) Network() []string { - if s.selected == nil { + selected := s.selected.Load() + if selected == nil { return []string{N.NetworkTCP, N.NetworkUDP} } - return s.selected.Network() + return selected.Network() } func (s *Selector) Start() error { for i, tag := range s.tags { - detour, loaded := s.outboundManager.Outbound(tag) + detour, loaded := s.outbound.Outbound(tag) if !loaded { return E.New("outbound ", i, " not found: ", tag) } @@ -77,7 +85,7 @@ func (s *Selector) Start() error { if selected != "" { detour, loaded := s.outbounds[selected] if loaded { - s.selected = detour + s.selected.Store(detour) return nil } } @@ -89,16 +97,16 @@ func (s *Selector) Start() error { if !loaded { return E.New("default outbound not found: ", s.defaultTag) } - s.selected = detour + s.selected.Store(detour) return nil } - s.selected = s.outbounds[s.tags[0]] + s.selected.Store(s.outbounds[s.tags[0]]) return nil } func (s *Selector) Now() string { - selected := s.selected + selected := s.selected.Load() if selected == nil { return s.tags[0] } @@ -114,10 +122,9 @@ func (s *Selector) SelectOutbound(tag string) bool { if !loaded { return false } - if s.selected == detour { + if s.selected.Swap(detour) == detour { return true } - s.selected = detour if s.Tag() != "" { cacheFile := service.FromContext[adapter.CacheFile](s.ctx) if cacheFile != nil { @@ -132,7 +139,7 @@ func (s *Selector) SelectOutbound(tag string) bool { } func (s *Selector) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - conn, err := s.selected.DialContext(ctx, network, destination) + conn, err := s.selected.Load().DialContext(ctx, network, destination) if err != nil { return nil, err } @@ -140,32 +147,30 @@ func (s *Selector) DialContext(ctx context.Context, network string, destination } func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - conn, err := s.selected.ListenPacket(ctx, destination) + conn, err := s.selected.Load().ListenPacket(ctx, destination) if err != nil { return nil, err } return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil } -// TODO -// Deprecated -func (s *Selector) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { +func (s *Selector) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) - if legacyHandler, ok := s.selected.(adapter.ConnectionHandler); ok { - return legacyHandler.NewConnection(ctx, conn, metadata) + selected := s.selected.Load() + if outboundHandler, isHandler := selected.(adapter.ConnectionHandlerEx); isHandler { + outboundHandler.NewConnectionEx(ctx, conn, metadata, onClose) } else { - return outbound.NewConnection(ctx, s.selected, conn, metadata) + s.connection.NewConnection(ctx, selected, conn, metadata, onClose) } } -// TODO -// Deprecated -func (s *Selector) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { +func (s *Selector) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) - if legacyHandler, ok := s.selected.(adapter.PacketConnectionHandler); ok { - return legacyHandler.NewPacketConnection(ctx, conn, metadata) + selected := s.selected.Load() + if outboundHandler, isHandler := selected.(adapter.PacketConnectionHandlerEx); isHandler { + outboundHandler.NewPacketConnectionEx(ctx, conn, metadata, onClose) } else { - return outbound.NewPacketConnection(ctx, s.selected, conn, metadata) + s.connection.NewPacketConnection(ctx, selected, conn, metadata, onClose) } } diff --git a/protocol/group/urltest.go b/protocol/group/urltest.go index fcada7dc..564c2373 100644 --- a/protocol/group/urltest.go +++ b/protocol/group/urltest.go @@ -36,7 +36,8 @@ type URLTest struct { outbound.Adapter ctx context.Context router adapter.Router - outboundManager adapter.OutboundManager + outbound adapter.OutboundManager + connection adapter.ConnectionManager logger log.ContextLogger tags []string link string @@ -52,7 +53,8 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo Adapter: outbound.NewAdapter(C.TypeURLTest, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds), ctx: ctx, router: router, - outboundManager: service.FromContext[adapter.OutboundManager](ctx), + outbound: service.FromContext[adapter.OutboundManager](ctx), + connection: service.FromContext[adapter.ConnectionManager](ctx), logger: logger, tags: options.Outbounds, link: options.URL, @@ -70,13 +72,13 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo func (s *URLTest) Start() error { outbounds := make([]adapter.Outbound, 0, len(s.tags)) for i, tag := range s.tags { - detour, loaded := s.outboundManager.Outbound(tag) + detour, loaded := s.outbound.Outbound(tag) if !loaded { return E.New("outbound ", i, " not found: ", tag) } outbounds = append(outbounds, detour) } - group, err := NewURLTestGroup(s.ctx, s.outboundManager, s.logger, outbounds, s.link, s.interval, s.tolerance, s.idleTimeout, s.interruptExternalConnections) + group, err := NewURLTestGroup(s.ctx, s.outbound, s.logger, outbounds, s.link, s.interval, s.tolerance, s.idleTimeout, s.interruptExternalConnections) if err != nil { return err } @@ -160,18 +162,14 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne return nil, err } -// TODO -// Deprecated -func (s *URLTest) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { +func (s *URLTest) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) - return outbound.NewConnection(ctx, s, conn, metadata) + s.connection.NewConnection(ctx, s, conn, metadata, onClose) } -// TODO -// Deprecated -func (s *URLTest) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { +func (s *URLTest) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) - return outbound.NewPacketConnection(ctx, s, conn, metadata) + s.connection.NewPacketConnection(ctx, s, conn, metadata, onClose) } func (s *URLTest) InterfaceUpdated() { diff --git a/route/conn.go b/route/conn.go index 594379cc..4a2192e0 100644 --- a/route/conn.go +++ b/route/conn.go @@ -6,11 +6,14 @@ import ( "net" "net/netip" "sync/atomic" + "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -208,6 +211,21 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial natConn.UpdateDestination(destinationAddress) } } + var udpTimeout time.Duration + if metadata.UDPTimeout > 0 { + udpTimeout = metadata.UDPTimeout + } else { + protocol := metadata.Protocol + if protocol == "" { + protocol = C.PortProtocols[metadata.Destination.Port] + } + if protocol != "" { + udpTimeout = C.ProtocolTimeouts[protocol] + } + } + if udpTimeout > 0 { + ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout) + } destination := bufio.NewPacketConn(remotePacketConn) var done atomic.Bool if ctx.Done() != nil { diff --git a/route/route.go b/route/route.go index 67eb2c69..05e22c25 100644 --- a/route/route.go +++ b/route/route.go @@ -132,23 +132,11 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad if r.tracker != nil { conn = r.tracker.RoutedConnection(ctx, conn, metadata, selectedRule, selectedOutbound) } - legacyOutbound, isLegacy := selectedOutbound.(adapter.ConnectionHandler) - if isLegacy { - err = legacyOutbound.NewConnection(ctx, conn, metadata) - if err != nil { - conn.Close() - if onClose != nil { - onClose(err) - } - return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")) - } else { - if onClose != nil { - onClose(nil) - } - } - return nil + if outboundHandler, isHandler := selectedOutbound.(adapter.ConnectionHandlerEx); isHandler { + outboundHandler.NewConnectionEx(ctx, conn, metadata, onClose) + } else { + r.connection.NewConnection(ctx, selectedOutbound, conn, metadata, onClose) } - r.connection.NewConnection(ctx, selectedOutbound, conn, metadata, onClose) return nil } @@ -258,16 +246,11 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m if metadata.FakeIP { conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination) } - legacyOutbound, isLegacy := selectedOutbound.(adapter.PacketConnectionHandler) - if isLegacy { - err = legacyOutbound.NewPacketConnection(ctx, conn, metadata) - N.CloseOnHandshakeFailure(conn, onClose, err) - if err != nil { - return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")) - } - return nil + if outboundHandler, isHandler := selectedOutbound.(adapter.PacketConnectionHandlerEx); isHandler { + outboundHandler.NewPacketConnectionEx(ctx, conn, metadata, onClose) + } else { + r.connection.NewPacketConnection(ctx, selectedOutbound, conn, metadata, onClose) } - r.connection.NewPacketConnection(ctx, selectedOutbound, conn, metadata, onClose) return nil } @@ -440,6 +423,9 @@ match: if routeOptions.UDPConnect { metadata.UDPConnect = true } + if routeOptions.UDPTimeout > 0 { + metadata.UDPTimeout = routeOptions.UDPTimeout + } } switch action := currentRule.Action().(type) { case *rule.RuleActionSniff: diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index 1b4099c9..34354cc0 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -47,6 +47,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti FallbackDelay: time.Duration(action.RouteOptionsOptions.FallbackDelay), UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping, UDPConnect: action.RouteOptionsOptions.UDPConnect, + UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout), }, nil case C.RuleActionTypeDirect: directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions)) @@ -152,6 +153,7 @@ type RuleActionRouteOptions struct { FallbackDelay time.Duration UDPDisableDomainUnmapping bool UDPConnect bool + UDPTimeout time.Duration } func (r *RuleActionRouteOptions) Type() string {