From 762418ac164d304ca4c4c13fc9b310b817012e14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 14 Nov 2024 18:31:37 +0800 Subject: [PATCH] Add override destination to route options --- adapter/inbound.go | 7 ++-- adapter/outbound/default.go | 24 ++++++------- common/dialer/default_parallel_network.go | 44 ++++++++++++++++------- experimental/deprecated/constants.go | 10 ++++++ option/direct.go | 29 +++++++++++++-- option/rule_action.go | 25 ++++++------- route/route.go | 37 ++++++++++++++----- route/rule/rule_action.go | 7 ++++ 8 files changed, 132 insertions(+), 51 deletions(-) diff --git a/adapter/inbound.go b/adapter/inbound.go index 1093bac4..d3cc95b7 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -61,9 +61,10 @@ type InboundContext struct { // cache // Deprecated: implement in rule action - InboundDetour string - LastInbound string - OriginDestination M.Socksaddr + InboundDetour string + LastInbound string + OriginDestination M.Socksaddr + RouteOriginalDestination M.Socksaddr // Deprecated InboundOptions option.InboundOptions UDPDisableDomainUnmapping bool diff --git a/adapter/outbound/default.go b/adapter/outbound/default.go index 27cafb9d..573673f2 100644 --- a/adapter/outbound/default.go +++ b/adapter/outbound/default.go @@ -25,11 +25,7 @@ func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata a var outConn net.Conn var err error if len(metadata.DestinationAddresses) > 0 { - if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer { - outConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) - } else { - outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) - } + outConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) } @@ -73,11 +69,7 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, } } else { if len(metadata.DestinationAddresses) > 0 { - if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer { - outPacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, parallelDialer, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) - } else { - outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) - } + outPacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { outPacketConn, err = this.ListenPacket(ctx, metadata.Destination) } @@ -91,11 +83,17 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, return err } if destinationAddress.IsValid() { - if metadata.Destination.IsFqdn() { + var originDestination M.Socksaddr + if metadata.RouteOriginalDestination.IsValid() { + originDestination = metadata.RouteOriginalDestination + } else { + originDestination = metadata.Destination + } + if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) { if metadata.UDPDisableDomainUnmapping { - outPacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) + outPacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) } else { - outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) + outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) } } if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { diff --git a/common/dialer/default_parallel_network.go b/common/dialer/default_parallel_network.go index ea043dfd..5145656b 100644 --- a/common/dialer/default_parallel_network.go +++ b/common/dialer/default_parallel_network.go @@ -13,17 +13,27 @@ import ( N "github.com/sagernet/sing/common/network" ) -func DialSerialNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { +func DialSerialNetwork(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel { return parallelDialer.DialParallelNetwork(ctx, network, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) } var errors []error - for _, address := range destinationAddresses { - conn, err := dialer.DialParallelInterface(ctx, network, M.SocksaddrFrom(address, destination.Port), strategy, interfaceType, fallbackInterfaceType, fallbackDelay) - if err == nil { - return conn, nil + if parallelDialer, isParallel := dialer.(ParallelInterfaceDialer); isParallel { + for _, address := range destinationAddresses { + conn, err := parallelDialer.DialParallelInterface(ctx, network, M.SocksaddrFrom(address, destination.Port), strategy, interfaceType, fallbackInterfaceType, fallbackDelay) + if err == nil { + return conn, nil + } + errors = append(errors, err) + } + } else { + for _, address := range destinationAddresses { + conn, err := dialer.DialContext(ctx, network, M.SocksaddrFrom(address, destination.Port)) + if err == nil { + return conn, nil + } + errors = append(errors, err) } - errors = append(errors, err) } return nil, E.Errors(errors...) } @@ -106,17 +116,27 @@ func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, ne } } -func ListenSerialNetworkPacket(ctx context.Context, dialer ParallelInterfaceDialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) { +func ListenSerialNetworkPacket(ctx context.Context, dialer N.Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) { if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel { return parallelDialer.ListenSerialNetworkPacket(ctx, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) } var errors []error - for _, address := range destinationAddresses { - conn, err := dialer.ListenSerialInterfacePacket(ctx, M.SocksaddrFrom(address, destination.Port), strategy, interfaceType, fallbackInterfaceType, fallbackDelay) - if err == nil { - return conn, address, nil + if parallelDialer, isParallel := dialer.(ParallelInterfaceDialer); isParallel { + for _, address := range destinationAddresses { + conn, err := parallelDialer.ListenSerialInterfacePacket(ctx, M.SocksaddrFrom(address, destination.Port), strategy, interfaceType, fallbackInterfaceType, fallbackDelay) + if err == nil { + return conn, address, nil + } + errors = append(errors, err) + } + } else { + for _, address := range destinationAddresses { + conn, err := dialer.ListenPacket(ctx, M.SocksaddrFrom(address, destination.Port)) + if err == nil { + return conn, address, nil + } + errors = append(errors, err) } - errors = append(errors, err) } return nil, netip.Addr{}, E.Errors(errors...) } diff --git a/experimental/deprecated/constants.go b/experimental/deprecated/constants.go index bf848ffe..a62b89f8 100644 --- a/experimental/deprecated/constants.go +++ b/experimental/deprecated/constants.go @@ -100,6 +100,15 @@ var OptionInboundOptions = Note{ MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-legacy-special-outbounds-to-rule-actions", } +var OptionDestinationOverrideFields = Note{ + Name: "destination-override-fields", + Description: "destination override fields in direct outbound", + DeprecatedVersion: "1.11.0", + ScheduledVersion: "1.13.0", + EnvName: "DESTINATION_OVERRIDE_FIELDS", + MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-destination-override-fields-to-route-options", +} + var Options = []Note{ OptionBadMatchSource, OptionGEOIP, @@ -107,4 +116,5 @@ var Options = []Note{ OptionTUNAddressX, OptionSpecialOutbounds, OptionInboundOptions, + OptionDestinationOverrideFields, } diff --git a/option/direct.go b/option/direct.go index 8624846d..180ff0aa 100644 --- a/option/direct.go +++ b/option/direct.go @@ -1,5 +1,12 @@ package option +import ( + "context" + + "github.com/sagernet/sing-box/experimental/deprecated" + "github.com/sagernet/sing/common/json" +) + type DirectInboundOptions struct { ListenOptions Network NetworkList `json:"network,omitempty"` @@ -7,9 +14,25 @@ type DirectInboundOptions struct { OverridePort uint16 `json:"override_port,omitempty"` } -type DirectOutboundOptions struct { +type _DirectOutboundOptions struct { DialerOptions + // Deprecated: Use Route Action instead OverrideAddress string `json:"override_address,omitempty"` - OverridePort uint16 `json:"override_port,omitempty"` - ProxyProtocol uint8 `json:"proxy_protocol,omitempty"` + // Deprecated: Use Route Action instead + OverridePort uint16 `json:"override_port,omitempty"` + // Deprecated: removed + ProxyProtocol uint8 `json:"proxy_protocol,omitempty"` +} + +type DirectOutboundOptions _DirectOutboundOptions + +func (d *DirectOutboundOptions) UnmarshalJSONContext(ctx context.Context, content []byte) error { + err := json.UnmarshalDisallowUnknownFields(content, (*_DirectOutboundOptions)(d)) + if err != nil { + return err + } + if d.OverrideAddress != "" || d.OverridePort != 0 { + deprecated.Report(ctx, deprecated.OptionDestinationOverrideFields) + } + return nil } diff --git a/option/rule_action.go b/option/rule_action.go index 7c31ea7a..ce3b92d9 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -137,24 +137,25 @@ func (r *DNSRuleAction) UnmarshalJSONContext(ctx context.Context, data []byte) e } type RouteActionOptions struct { - Outbound string `json:"outbound,omitempty"` - NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"` - FallbackDelay uint32 `json:"fallback_delay,omitempty"` - UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` - UDPConnect bool `json:"udp_connect,omitempty"` + Outbound string `json:"outbound,omitempty"` + RawRouteOptionsActionOptions } -type _RouteOptionsActionOptions struct { - NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"` - FallbackDelay uint32 `json:"fallback_delay,omitempty"` - UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` - UDPConnect bool `json:"udp_connect,omitempty"` +type RawRouteOptionsActionOptions struct { + OverrideAddress string `json:"override_address,omitempty"` + OverridePort uint16 `json:"override_port,omitempty"` + + NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"` + FallbackDelay uint32 `json:"fallback_delay,omitempty"` + + UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` + UDPConnect bool `json:"udp_connect,omitempty"` } -type RouteOptionsActionOptions _RouteOptionsActionOptions +type RouteOptionsActionOptions RawRouteOptionsActionOptions func (r *RouteOptionsActionOptions) UnmarshalJSON(data []byte) error { - err := json.Unmarshal(data, (*_RouteOptionsActionOptions)(r)) + err := json.Unmarshal(data, (*RawRouteOptionsActionOptions)(r)) if err != nil { return err } diff --git a/route/route.go b/route/route.go index 051ee403..7d21d315 100644 --- a/route/route.go +++ b/route/route.go @@ -422,17 +422,38 @@ match: } } } + var routeOptions *rule.RuleActionRouteOptions switch action := currentRule.Action().(type) { case *rule.RuleActionRoute: - metadata.NetworkStrategy = action.NetworkStrategy - metadata.FallbackDelay = action.FallbackDelay - metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping - metadata.UDPConnect = action.UDPConnect + routeOptions = &action.RuleActionRouteOptions case *rule.RuleActionRouteOptions: - metadata.NetworkStrategy = action.NetworkStrategy - metadata.FallbackDelay = action.FallbackDelay - metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping - metadata.UDPConnect = action.UDPConnect + routeOptions = action + } + if routeOptions != nil { + // TODO: add nat + if (routeOptions.OverrideAddress.IsValid() || routeOptions.OverridePort > 0) && !metadata.RouteOriginalDestination.IsValid() { + metadata.RouteOriginalDestination = metadata.Destination + } + if routeOptions.OverrideAddress.IsValid() { + metadata.Destination = M.Socksaddr{ + Addr: routeOptions.OverrideAddress.Addr, + Port: metadata.Destination.Port, + Fqdn: routeOptions.OverrideAddress.Fqdn, + } + } + if routeOptions.OverridePort > 0 { + metadata.Destination = M.Socksaddr{ + Addr: metadata.Destination.Addr, + Port: routeOptions.OverridePort, + Fqdn: metadata.Destination.Fqdn, + } + } + metadata.NetworkStrategy = routeOptions.NetworkStrategy + metadata.FallbackDelay = routeOptions.FallbackDelay + metadata.UDPDisableDomainUnmapping = routeOptions.UDPDisableDomainUnmapping + metadata.UDPConnect = routeOptions.UDPConnect + } + switch action := currentRule.Action().(type) { case *rule.RuleActionSniff: if !preMatch { newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn) diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index 52e95207..1b4099c9 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -19,6 +19,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) @@ -30,6 +31,8 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti return &RuleActionRoute{ Outbound: action.RouteOptions.Outbound, RuleActionRouteOptions: RuleActionRouteOptions{ + OverrideAddress: M.ParseSocksaddrHostPort(action.RouteOptions.OverrideAddress, 0), + OverridePort: action.RouteOptions.OverridePort, NetworkStrategy: C.NetworkStrategy(action.RouteOptions.NetworkStrategy), FallbackDelay: time.Duration(action.RouteOptions.FallbackDelay), UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping, @@ -38,6 +41,8 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti }, nil case C.RuleActionTypeRouteOptions: return &RuleActionRouteOptions{ + OverrideAddress: M.ParseSocksaddrHostPort(action.RouteOptionsOptions.OverrideAddress, 0), + OverridePort: action.RouteOptionsOptions.OverridePort, NetworkStrategy: C.NetworkStrategy(action.RouteOptionsOptions.NetworkStrategy), FallbackDelay: time.Duration(action.RouteOptionsOptions.FallbackDelay), UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping, @@ -139,6 +144,8 @@ func (r *RuleActionRoute) String() string { } type RuleActionRouteOptions struct { + OverrideAddress M.Socksaddr + OverridePort uint16 NetworkStrategy C.NetworkStrategy NetworkType []C.InterfaceType FallbackNetworkType []C.InterfaceType