From 58c4fd745ae83b56a693d4d0218d11aa0279cb97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 21 Mar 2023 21:36:17 +0800 Subject: [PATCH] Add L3 routing support --- adapter/inbound.go | 2 +- adapter/outbound.go | 6 + adapter/router.go | 11 + common/badjsonmerge/merge_test.go | 4 +- common/dialer/tfo.go | 7 +- docs/configuration/outbound/index.md | 8 +- docs/configuration/outbound/index.zh.md | 8 +- docs/configuration/route/index.md | 12 +- docs/configuration/route/index.zh.md | 14 +- docs/configuration/route/ip-rule.md | 205 +++++++++++++ docs/configuration/route/ip-rule.zh.md | 204 +++++++++++++ docs/configuration/route/rule.md | 16 +- docs/configuration/route/rule.zh.md | 16 +- docs/examples/index.md | 1 + docs/examples/index.zh.md | 1 + docs/examples/wireguard-direct.md | 90 ++++++ inbound/tun.go | 38 ++- mkdocs.yml | 2 + option/dns.go | 103 ------- option/route.go | 101 +------ option/rule.go | 101 +++++++ option/rule_dns.go | 107 +++++++ option/rule_ip.go | 120 ++++++++ option/wireguard.go | 1 + outbound/wireguard.go | 45 ++- route/router.go | 279 ++--------------- route/router_dns.go | 3 + route/router_geo_resources.go | 283 ++++++++++++++++++ route/router_ip.go | 66 ++++ route/rule_abstract.go | 203 +++++++++++++ route/{rule.go => rule_default.go} | 221 +------------- route/rule_dns.go | 223 ++------------ route/rule_ip.go | 189 ++++++++++++ ...le_auth_user.go => rule_item_auth_user.go} | 0 route/{rule_cidr.go => rule_item_cidr.go} | 0 ..._clash_mode.go => rule_item_clash_mode.go} | 0 route/{rule_domain.go => rule_item_domain.go} | 0 ...keyword.go => rule_item_domain_keyword.go} | 0 ...ain_regex.go => rule_item_domain_regex.go} | 0 route/{rule_geoip.go => rule_item_geoip.go} | 0 .../{rule_geosite.go => rule_item_geosite.go} | 0 .../{rule_inbound.go => rule_item_inbound.go} | 0 ...le_ipversion.go => rule_item_ipversion.go} | 0 route/rule_item_network.go | 42 +++ ...rule_outbound.go => rule_item_outbound.go} | 0 ...kage_name.go => rule_item_package_name.go} | 0 route/{rule_port.go => rule_item_port.go} | 0 ..._port_range.go => rule_item_port_range.go} | 0 ...cess_name.go => rule_item_process_name.go} | 0 ...cess_path.go => rule_item_process_path.go} | 0 ...rule_protocol.go => rule_item_protocol.go} | 0 ..._query_type.go => rule_item_query_type.go} | 0 route/{rule_user.go => rule_item_user.go} | 0 .../{rule_user_id.go => rule_item_user_id.go} | 0 route/rule_network.go | 23 -- transport/wireguard/device.go | 14 +- transport/wireguard/device_nat.go | 75 +++++ transport/wireguard/device_nat_gvisor.go | 27 ++ transport/wireguard/device_stack.go | 100 +++++-- transport/wireguard/device_stack_stub.go | 2 +- transport/wireguard/device_system.go | 35 ++- 61 files changed, 2043 insertions(+), 965 deletions(-) create mode 100644 docs/configuration/route/ip-rule.md create mode 100644 docs/configuration/route/ip-rule.zh.md create mode 100644 docs/examples/wireguard-direct.md create mode 100644 option/rule.go create mode 100644 option/rule_dns.go create mode 100644 option/rule_ip.go create mode 100644 route/router_geo_resources.go create mode 100644 route/router_ip.go create mode 100644 route/rule_abstract.go rename route/{rule.go => rule_default.go} (61%) create mode 100644 route/rule_ip.go rename route/{rule_auth_user.go => rule_item_auth_user.go} (100%) rename route/{rule_cidr.go => rule_item_cidr.go} (100%) rename route/{rule_clash_mode.go => rule_item_clash_mode.go} (100%) rename route/{rule_domain.go => rule_item_domain.go} (100%) rename route/{rule_domain_keyword.go => rule_item_domain_keyword.go} (100%) rename route/{rule_domain_regex.go => rule_item_domain_regex.go} (100%) rename route/{rule_geoip.go => rule_item_geoip.go} (100%) rename route/{rule_geosite.go => rule_item_geosite.go} (100%) rename route/{rule_inbound.go => rule_item_inbound.go} (100%) rename route/{rule_ipversion.go => rule_item_ipversion.go} (100%) create mode 100644 route/rule_item_network.go rename route/{rule_outbound.go => rule_item_outbound.go} (100%) rename route/{rule_package_name.go => rule_item_package_name.go} (100%) rename route/{rule_port.go => rule_item_port.go} (100%) rename route/{rule_port_range.go => rule_item_port_range.go} (100%) rename route/{rule_process_name.go => rule_item_process_name.go} (100%) rename route/{rule_process_path.go => rule_item_process_path.go} (100%) rename route/{rule_protocol.go => rule_item_protocol.go} (100%) rename route/{rule_query_type.go => rule_item_query_type.go} (100%) rename route/{rule_user.go => rule_item_user.go} (100%) rename route/{rule_user_id.go => rule_item_user_id.go} (100%) delete mode 100644 route/rule_network.go create mode 100644 transport/wireguard/device_nat.go create mode 100644 transport/wireguard/device_nat_gvisor.go diff --git a/adapter/inbound.go b/adapter/inbound.go index 6e478ba3..356a3200 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -27,7 +27,7 @@ type InjectableInbound interface { type InboundContext struct { Inbound string InboundType string - IPVersion int + IPVersion uint8 Network string Source M.Socksaddr Destination M.Socksaddr diff --git a/adapter/outbound.go b/adapter/outbound.go index a45c27fd..03c99d51 100644 --- a/adapter/outbound.go +++ b/adapter/outbound.go @@ -4,6 +4,7 @@ import ( "context" "net" + "github.com/sagernet/sing-tun" N "github.com/sagernet/sing/common/network" ) @@ -17,3 +18,8 @@ type Outbound interface { NewConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext) error } + +type IPOutbound interface { + Outbound + NewIPConnection(ctx context.Context, conn tun.RouteContext, metadata InboundContext) (tun.DirectDestination, error) +} diff --git a/adapter/router.go b/adapter/router.go index e1807747..3fb5a109 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -23,6 +23,9 @@ type Router interface { RouteConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext) error + RouteIPConnection(ctx context.Context, conn tun.RouteContext, metadata InboundContext) tun.RouteAction + + NatRequired(outbound string) bool GeoIPReader() *geoip.Reader LoadGeosite(code string) (Rule, error) @@ -39,7 +42,9 @@ type Router interface { NetworkMonitor() tun.NetworkUpdateMonitor InterfaceMonitor() tun.DefaultInterfaceMonitor PackageManager() tun.PackageManager + Rules() []Rule + IPRules() []IPRule TimeService @@ -76,6 +81,12 @@ type Rule interface { type DNSRule interface { Rule DisableCache() bool + RewriteTTL() *uint32 +} + +type IPRule interface { + Rule + Action() tun.ActionType } type InterfaceUpdateListener interface { diff --git a/common/badjsonmerge/merge_test.go b/common/badjsonmerge/merge_test.go index d1714cd6..be4481b5 100644 --- a/common/badjsonmerge/merge_test.go +++ b/common/badjsonmerge/merge_test.go @@ -21,7 +21,7 @@ func TestMergeJSON(t *testing.T) { { Type: C.RuleTypeDefault, DefaultOptions: option.DefaultRule{ - Network: N.NetworkTCP, + Network: []string{N.NetworkTCP}, Outbound: "direct", }, }, @@ -42,7 +42,7 @@ func TestMergeJSON(t *testing.T) { { Type: C.RuleTypeDefault, DefaultOptions: option.DefaultRule{ - Network: N.NetworkUDP, + Network: []string{N.NetworkUDP}, Outbound: "direct", }, }, diff --git a/common/dialer/tfo.go b/common/dialer/tfo.go index 0205daaf..d577560c 100644 --- a/common/dialer/tfo.go +++ b/common/dialer/tfo.go @@ -27,7 +27,12 @@ type slowOpenConn struct { func DialSlowContext(dialer *tfo.Dialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { if dialer.DisableTFO || N.NetworkName(network) != N.NetworkTCP { - return dialer.DialContext(ctx, network, destination.String(), nil) + switch N.NetworkName(network) { + case N.NetworkTCP, N.NetworkUDP: + return dialer.Dialer.DialContext(ctx, network, destination.String()) + default: + return dialer.Dialer.DialContext(ctx, network, destination.AddrString()) + } } return &slowOpenConn{ dialer: dialer, diff --git a/docs/configuration/outbound/index.md b/docs/configuration/outbound/index.md index a8a1874f..83320971 100644 --- a/docs/configuration/outbound/index.md +++ b/docs/configuration/outbound/index.md @@ -37,4 +37,10 @@ #### tag -The tag of the outbound. \ No newline at end of file +The tag of the outbound. + +### Features + +#### Outbounds that support IP connection + +* `WireGuard` diff --git a/docs/configuration/outbound/index.zh.md b/docs/configuration/outbound/index.zh.md index f9053356..e54a1d95 100644 --- a/docs/configuration/outbound/index.zh.md +++ b/docs/configuration/outbound/index.zh.md @@ -36,4 +36,10 @@ #### tag -出站的标签。 \ No newline at end of file +出站的标签。 + +### 特性 + +#### 支持 IP 连接的出站 + +* `WireGuard` diff --git a/docs/configuration/route/index.md b/docs/configuration/route/index.md index 7440f2bb..beabafbb 100644 --- a/docs/configuration/route/index.md +++ b/docs/configuration/route/index.md @@ -7,6 +7,7 @@ "route": { "geoip": {}, "geosite": {}, + "ip_rules": [], "rules": [], "final": "", "auto_detect_interface": false, @@ -19,11 +20,12 @@ ### Fields -| Key | Format | -|-----------|------------------------------| -| `geoip` | [GeoIP](./geoip) | -| `geosite` | [Geosite](./geosite) | -| `rules` | List of [Route Rule](./rule) | +| Key | Format | +|------------|------------------------------------| +| `geoip` | [GeoIP](./geoip) | +| `geosite` | [Geosite](./geosite) | +| `ip_rules` | List of [IP Route Rule](./ip-rule) | +| `rules` | List of [Route Rule](./rule) | #### final diff --git a/docs/configuration/route/index.zh.md b/docs/configuration/route/index.zh.md index e0bbe917..8525f7b0 100644 --- a/docs/configuration/route/index.zh.md +++ b/docs/configuration/route/index.zh.md @@ -7,6 +7,7 @@ "route": { "geoip": {}, "geosite": {}, + "ip_rules": [], "rules": [], "final": "", "auto_detect_interface": false, @@ -19,11 +20,12 @@ ### 字段 -| 键 | 格式 | -|-----------|----------------------| -| `geoip` | [GeoIP](./geoip) | -| `geosite` | [GeoSite](./geosite) | -| `rules` | 一组 [路由规则](./rule) | +| 键 | 格式 | +|------------|-------------------------| +| `geoip` | [GeoIP](./geoip) | +| `geosite` | [GeoSite](./geosite) | +| `ip_rules` | 一组 [IP 路由规则](./ip-rule) | +| `rules` | 一组 [路由规则](./rule) | #### final @@ -65,4 +67,4 @@ 默认为出站连接设置路由标记。 -如果设置了 `outbound.routing_mark` 设置,则不生效。 +如果设置了 `outbound.routing_mark` 设置,则不生效。 \ No newline at end of file diff --git a/docs/configuration/route/ip-rule.md b/docs/configuration/route/ip-rule.md new file mode 100644 index 00000000..352c39f8 --- /dev/null +++ b/docs/configuration/route/ip-rule.md @@ -0,0 +1,205 @@ +### Structure + +```json +{ + "route": { + "ip_rules": [ + { + "inbound": [ + "mixed-in" + ], + "ip_version": 6, + "network": [ + "tcp" + ], + "domain": [ + "test.com" + ], + "domain_suffix": [ + ".cn" + ], + "domain_keyword": [ + "test" + ], + "domain_regex": [ + "^stun\\..+" + ], + "geosite": [ + "cn" + ], + "source_geoip": [ + "private" + ], + "geoip": [ + "cn" + ], + "source_ip_cidr": [ + "10.0.0.0/24", + "192.168.0.1" + ], + "ip_cidr": [ + "10.0.0.0/24", + "192.168.0.1" + ], + "source_port": [ + 12345 + ], + "source_port_range": [ + "1000:2000", + ":3000", + "4000:" + ], + "port": [ + 80, + 443 + ], + "port_range": [ + "1000:2000", + ":3000", + "4000:" + ], + "invert": false, + "action": "direct", + "outbound": "wireguard" + }, + { + "type": "logical", + "mode": "and", + "rules": [], + "invert": false, + "action": "direct", + "outbound": "wireguard" + } + ] + } +} + +``` + +!!! note "" + + You can ignore the JSON Array [] tag when the content is only one item + +### Default Fields + +!!! note "" + + The default rule uses the following matching logic: + (`domain` || `domain_suffix` || `domain_keyword` || `domain_regex` || `geosite` || `geoip` || `ip_cidr`) && + (`port` || `port_range`) && + (`source_geoip` || `source_ip_cidr`) && + (`source_port` || `source_port_range`) && + `other fields` + +#### inbound + +Tags of [Inbound](/configuration/inbound). + +#### ip_version + +4 or 6. + +Not limited if empty. + +#### network + +Match network protocol. + +Available values: + +* `tcp` +* `udp` +* `icmpv4` +* `icmpv6` + +#### domain + +Match full domain. + +#### domain_suffix + +Match domain suffix. + +#### domain_keyword + +Match domain using keyword. + +#### domain_regex + +Match domain using regular expression. + +#### geosite + +Match geosite. + +#### source_geoip + +Match source geoip. + +#### geoip + +Match geoip. + +#### source_ip_cidr + +Match source ip cidr. + +#### ip_cidr + +Match ip cidr. + +#### source_port + +Match source port. + +#### source_port_range + +Match source port range. + +#### port + +Match port. + +#### port_range + +Match port range. + +#### invert + +Invert match result. + +#### action + +==Required== + +| Action | Description | +|--------|--------------------------------------------------------------------| +| return | Stop IP routing and assemble the connection to the transport layer | +| block | Block the connection | +| direct | Directly forward the connection | + +#### outbound + +==Required if action is direct== + +Tag of the target outbound. + +Only outbound which supports IP connection can be used, see [Outbounds that support IP connection](/configuration/outbound/#outbounds-that-support-ip-connection). + +### Logical Fields + +#### type + +`logical` + +#### mode + +==Required== + +`and` or `or` + +#### rules + +==Required== + +Included default rules. \ No newline at end of file diff --git a/docs/configuration/route/ip-rule.zh.md b/docs/configuration/route/ip-rule.zh.md new file mode 100644 index 00000000..d580086c --- /dev/null +++ b/docs/configuration/route/ip-rule.zh.md @@ -0,0 +1,204 @@ +### 结构 + +```json +{ + "route": { + "ip_rules": [ + { + "inbound": [ + "mixed-in" + ], + "ip_version": 6, + "network": [ + "tcp" + ], + "domain": [ + "test.com" + ], + "domain_suffix": [ + ".cn" + ], + "domain_keyword": [ + "test" + ], + "domain_regex": [ + "^stun\\..+" + ], + "geosite": [ + "cn" + ], + "source_geoip": [ + "private" + ], + "geoip": [ + "cn" + ], + "source_ip_cidr": [ + "10.0.0.0/24", + "192.168.0.1" + ], + "ip_cidr": [ + "10.0.0.0/24", + "192.168.0.1" + ], + "source_port": [ + 12345 + ], + "source_port_range": [ + "1000:2000", + ":3000", + "4000:" + ], + "port": [ + 80, + 443 + ], + "port_range": [ + "1000:2000", + ":3000", + "4000:" + ], + "invert": false, + "action": "direct", + "outbound": "wireguard" + }, + { + "type": "logical", + "mode": "and", + "rules": [], + "invert": false, + "action": "direct", + "outbound": "wireguard" + } + ] + } +} + +``` + +!!! note "" + + 当内容只有一项时,可以忽略 JSON 数组 [] 标签。 + +### Default Fields + +!!! note "" + + 默认规则使用以下匹配逻辑: + (`domain` || `domain_suffix` || `domain_keyword` || `domain_regex` || `geosite` || `geoip` || `ip_cidr`) && + (`port` || `port_range`) && + (`source_geoip` || `source_ip_cidr`) && + (`source_port` || `source_port_range`) && + `other fields` + +#### inbound + +[入站](/zh/configuration/inbound) 标签。 + +#### ip_version + +4 或 6。 + +默认不限制。 + +#### network + +匹配网络协议。 + +可用值: + +* `tcp` +* `udp` +* `icmpv4` +* `icmpv6` + +#### domain + +匹配完整域名。 + +#### domain_suffix + +匹配域名后缀。 + +#### domain_keyword + +匹配域名关键字。 + +#### domain_regex + +匹配域名正则表达式。 + +#### geosite + +匹配 GeoSite。 + +#### source_geoip + +匹配源 GeoIP。 + +#### geoip + +匹配 GeoIP。 + +#### source_ip_cidr + +匹配源 IP CIDR。 + +#### ip_cidr + +匹配 IP CIDR。 + +#### source_port + +匹配源端口。 + +#### source_port_range + +匹配源端口范围。 + +#### port + +匹配端口。 + +#### port_range + +匹配端口范围。 + +#### invert + +反选匹配结果。 + +#### action + +==必填== + +| Action | 描述 | +|--------|---------------------| +| return | 停止 IP 路由并将该连接组装到传输层 | +| block | 屏蔽该连接 | +| direct | 直接转发该连接 | + + +#### outbound + +==action 为 direct 则必填== + +目标出站的标签。 + +### 逻辑字段 + +#### type + +`logical` + +#### mode + +==必填== + +`and` 或 `or` + +#### rules + +==必填== + +包括的默认规则。 \ No newline at end of file diff --git a/docs/configuration/route/rule.md b/docs/configuration/route/rule.md index a838105c..3cee478d 100644 --- a/docs/configuration/route/rule.md +++ b/docs/configuration/route/rule.md @@ -9,7 +9,9 @@ "mixed-in" ], "ip_version": 6, - "network": "tcp", + "network": [ + "tcp" + ], "auth_user": [ "usera", "userb" @@ -244,18 +246,12 @@ Tag of the target outbound. #### mode +==Required== + `and` or `or` #### rules -Included default rules. - -#### invert - -Invert match result. - -#### outbound - ==Required== -Tag of the target outbound. +Included default rules. diff --git a/docs/configuration/route/rule.zh.md b/docs/configuration/route/rule.zh.md index fc7d5990..4a09ed8e 100644 --- a/docs/configuration/route/rule.zh.md +++ b/docs/configuration/route/rule.zh.md @@ -9,7 +9,9 @@ "mixed-in" ], "ip_version": 6, - "network": "tcp", + "network": [ + "tcp" + ], "auth_user": [ "usera", "userb" @@ -242,18 +244,12 @@ #### mode +==必填== + `and` 或 `or` #### rules -包括的默认规则。 - -#### invert - -反选匹配结果。 - -#### outbound - ==必填== -目标出站的标签。 +包括的默认规则。 \ No newline at end of file diff --git a/docs/examples/index.md b/docs/examples/index.md index ca2fa8e9..dbc6c42b 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -8,3 +8,4 @@ Configuration examples for sing-box. * [Shadowsocks](./shadowsocks) * [ShadowTLS](./shadowtls) * [Clash API](./clash-api) +* [WireGuard Direct](./wireguard-direct) diff --git a/docs/examples/index.zh.md b/docs/examples/index.zh.md index e4d17c38..1338081f 100644 --- a/docs/examples/index.zh.md +++ b/docs/examples/index.zh.md @@ -8,3 +8,4 @@ sing-box 的配置示例。 * [Shadowsocks](./shadowsocks) * [ShadowTLS](./shadowtls) * [Clash API](./clash-api) +* [WireGuard Direct](./wireguard-direct) diff --git a/docs/examples/wireguard-direct.md b/docs/examples/wireguard-direct.md new file mode 100644 index 00000000..98e5d575 --- /dev/null +++ b/docs/examples/wireguard-direct.md @@ -0,0 +1,90 @@ +# WireGuard Direct + +```json +{ + "dns": { + "servers": [ + { + "tag": "google", + "address": "tls://8.8.8.8" + }, + { + "tag": "local", + "address": "223.5.5.5", + "detour": "direct" + } + ], + "rules": [ + { + "geoip": "cn", + "server": "direct" + } + ], + "reverse_mapping": true + }, + "inbounds": [ + { + "type": "tun", + "tag": "tun", + "inet4_address": "172.19.0.1/30", + "auto_route": true, + "sniff": true, + "stack": "system" + } + ], + "outbounds": [ + { + "type": "wireguard", + "tag": "wg", + "server": "127.0.0.1", + "server_port": 2345, + "local_address": [ + "172.19.0.1/128" + ], + "private_key": "KLTnpPY03pig/WC3zR8U7VWmpANHPFh2/4pwICGJ5Fk=", + "peer_public_key": "uvNabcamf6Rs0vzmcw99jsjTJbxo6eWGOykSY66zsUk=" + }, + { + "type": "dns", + "tag": "dns" + }, + { + "type": "direct", + "tag": "direct" + }, + { + "type": "block", + "tag": "block" + } + ], + "route": { + "ip_rules": [ + { + "port": 53, + "action": "return" + }, + { + "geoip": "cn", + "geosite": "cn", + "action": "return" + }, + { + "action": "direct", + "outbound": "wg" + } + ], + "rules": [ + { + "protocol": "dns", + "outbound": "dns" + }, + { + "geoip": "cn", + "geosite": "cn", + "outbound": "direct" + } + ], + "auto_detect_interface": true + } +} +``` \ No newline at end of file diff --git a/inbound/tun.go b/inbound/tun.go index 8ea7b384..f18c61d2 100644 --- a/inbound/tun.go +++ b/inbound/tun.go @@ -19,7 +19,10 @@ import ( "github.com/sagernet/sing/common/ranges" ) -var _ adapter.Inbound = (*Tun)(nil) +var ( + _ adapter.Inbound = (*Tun)(nil) + _ tun.Router = (*Tun)(nil) +) type Tun struct { tag string @@ -38,10 +41,6 @@ type Tun struct { } func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunInboundOptions, platformInterface platform.Interface) (*Tun, error) { - tunName := options.InterfaceName - if tunName == "" { - tunName = tun.CalculateInterfaceName("") - } tunMTU := options.MTU if tunMTU == 0 { tunMTU = 9000 @@ -75,7 +74,7 @@ func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger logger: logger, inboundOptions: options.InboundOptions, tunOptions: tun.Options{ - Name: tunName, + Name: options.InterfaceName, MTU: tunMTU, Inet4Address: common.Map(options.Inet4Address, option.ListenPrefix.Build), Inet6Address: common.Map(options.Inet6Address, option.ListenPrefix.Build), @@ -141,12 +140,17 @@ func (t *Tun) Tag() string { func (t *Tun) Start() error { if C.IsAndroid && t.platformInterface == nil { + t.logger.Trace("building android rules") t.tunOptions.BuildAndroidRules(t.router.PackageManager(), t) } + if t.tunOptions.Name == "" { + t.tunOptions.Name = tun.CalculateInterfaceName("") + } var ( tunInterface tun.Tun err error ) + t.logger.Trace("opening interface") if t.platformInterface != nil { tunInterface, err = t.platformInterface.OpenTun(t.tunOptions, t.platformOptions) } else { @@ -155,7 +159,12 @@ func (t *Tun) Start() error { if err != nil { return E.Cause(err, "configure tun interface") } + t.logger.Trace("creating stack") t.tunIf = tunInterface + var tunRouter tun.Router + if len(t.router.IPRules()) > 0 { + tunRouter = t + } t.tunStack, err = tun.NewStack(t.stack, tun.StackOptions{ Context: t.ctx, Tun: tunInterface, @@ -165,6 +174,7 @@ func (t *Tun) Start() error { Inet6Address: t.tunOptions.Inet6Address, EndpointIndependentNat: t.endpointIndependentNat, UDPTimeout: t.udpTimeout, + Router: tunRouter, Handler: t, Logger: t.logger, UnderPlatform: t.platformInterface != nil, @@ -172,6 +182,7 @@ func (t *Tun) Start() error { if err != nil { return err } + t.logger.Trace("starting stack") err = t.tunStack.Start() if err != nil { return err @@ -187,6 +198,21 @@ func (t *Tun) Close() error { ) } +func (t *Tun) RouteConnection(session tun.RouteSession, conn tun.RouteContext) tun.RouteAction { + ctx := log.ContextWithNewID(t.ctx) + var metadata adapter.InboundContext + metadata.Inbound = t.tag + metadata.InboundType = C.TypeTun + metadata.IPVersion = session.IPVersion + metadata.Network = tun.NetworkName(session.Network) + metadata.Source = M.SocksaddrFromNetIP(session.Source) + metadata.Destination = M.SocksaddrFromNetIP(session.Destination) + metadata.InboundOptions = t.inboundOptions + t.logger.DebugContext(ctx, "incoming connection from ", metadata.Source) + t.logger.DebugContext(ctx, "incoming connection to ", metadata.Destination) + return t.router.RouteIPConnection(ctx, conn, metadata) +} + func (t *Tun) NewConnection(ctx context.Context, conn net.Conn, upstreamMetadata M.Metadata) error { ctx = log.ContextWithNewID(ctx) var metadata adapter.InboundContext diff --git a/mkdocs.yml b/mkdocs.yml index c458c47b..0843722e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -54,6 +54,7 @@ nav: - configuration/route/index.md - GeoIP: configuration/route/geoip.md - Geosite: configuration/route/geosite.md + - IP Route Rule: configuration/route/ip-rule.md - Route Rule: configuration/route/rule.md - Protocol Sniff: configuration/route/sniff.md - Experimental: @@ -169,6 +170,7 @@ plugins: DNS Rule: DNS 规则 Route: 路由 + IP Route Rule: IP 路由规则 Route Rule: 路由规则 Protocol Sniff: 协议探测 diff --git a/option/dns.go b/option/dns.go index d376f433..df8e7e70 100644 --- a/option/dns.go +++ b/option/dns.go @@ -1,14 +1,5 @@ package option -import ( - "reflect" - - "github.com/sagernet/sing-box/common/json" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" -) - type DNSOptions struct { Servers []DNSServerOptions `json:"servers,omitempty"` Rules []DNSRule `json:"rules,omitempty"` @@ -32,97 +23,3 @@ type DNSServerOptions struct { Strategy DomainStrategy `json:"strategy,omitempty"` Detour string `json:"detour,omitempty"` } - -type _DNSRule struct { - Type string `json:"type,omitempty"` - DefaultOptions DefaultDNSRule `json:"-"` - LogicalOptions LogicalDNSRule `json:"-"` -} - -type DNSRule _DNSRule - -func (r DNSRule) MarshalJSON() ([]byte, error) { - var v any - switch r.Type { - case C.RuleTypeDefault: - r.Type = "" - v = r.DefaultOptions - case C.RuleTypeLogical: - v = r.LogicalOptions - default: - return nil, E.New("unknown rule type: " + r.Type) - } - return MarshallObjects((_DNSRule)(r), v) -} - -func (r *DNSRule) UnmarshalJSON(bytes []byte) error { - err := json.Unmarshal(bytes, (*_DNSRule)(r)) - if err != nil { - return err - } - var v any - switch r.Type { - case "", C.RuleTypeDefault: - r.Type = C.RuleTypeDefault - v = &r.DefaultOptions - case C.RuleTypeLogical: - v = &r.LogicalOptions - default: - return E.New("unknown rule type: " + r.Type) - } - err = UnmarshallExcluded(bytes, (*_DNSRule)(r), v) - if err != nil { - return E.Cause(err, "dns route rule") - } - return nil -} - -type DefaultDNSRule struct { - Inbound Listable[string] `json:"inbound,omitempty"` - IPVersion int `json:"ip_version,omitempty"` - QueryType Listable[DNSQueryType] `json:"query_type,omitempty"` - Network string `json:"network,omitempty"` - AuthUser Listable[string] `json:"auth_user,omitempty"` - Protocol Listable[string] `json:"protocol,omitempty"` - Domain Listable[string] `json:"domain,omitempty"` - DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` - DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` - DomainRegex Listable[string] `json:"domain_regex,omitempty"` - Geosite Listable[string] `json:"geosite,omitempty"` - SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` - SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` - SourcePort Listable[uint16] `json:"source_port,omitempty"` - SourcePortRange Listable[string] `json:"source_port_range,omitempty"` - Port Listable[uint16] `json:"port,omitempty"` - PortRange Listable[string] `json:"port_range,omitempty"` - ProcessName Listable[string] `json:"process_name,omitempty"` - ProcessPath Listable[string] `json:"process_path,omitempty"` - PackageName Listable[string] `json:"package_name,omitempty"` - User Listable[string] `json:"user,omitempty"` - UserID Listable[int32] `json:"user_id,omitempty"` - Outbound Listable[string] `json:"outbound,omitempty"` - ClashMode string `json:"clash_mode,omitempty"` - Invert bool `json:"invert,omitempty"` - Server string `json:"server,omitempty"` - DisableCache bool `json:"disable_cache,omitempty"` -} - -func (r DefaultDNSRule) IsValid() bool { - var defaultValue DefaultDNSRule - defaultValue.Invert = r.Invert - defaultValue.Server = r.Server - defaultValue.DisableCache = r.DisableCache - return !reflect.DeepEqual(r, defaultValue) -} - -type LogicalDNSRule struct { - Mode string `json:"mode"` - Rules []DefaultDNSRule `json:"rules,omitempty"` - Invert bool `json:"invert,omitempty"` - Server string `json:"server,omitempty"` - DisableCache bool `json:"disable_cache,omitempty"` -} - -func (r LogicalDNSRule) IsValid() bool { - return len(r.Rules) > 0 && common.All(r.Rules, DefaultDNSRule.IsValid) -} diff --git a/option/route.go b/option/route.go index 308c4802..b32d4b3f 100644 --- a/option/route.go +++ b/option/route.go @@ -1,17 +1,9 @@ package option -import ( - "reflect" - - "github.com/sagernet/sing-box/common/json" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" -) - type RouteOptions struct { GeoIP *GeoIPOptions `json:"geoip,omitempty"` Geosite *GeositeOptions `json:"geosite,omitempty"` + IPRules []IPRule `json:"ip_rules,omitempty"` Rules []Rule `json:"rules,omitempty"` Final string `json:"final,omitempty"` FindProcess bool `json:"find_process,omitempty"` @@ -32,94 +24,3 @@ type GeositeOptions struct { DownloadURL string `json:"download_url,omitempty"` DownloadDetour string `json:"download_detour,omitempty"` } - -type _Rule struct { - Type string `json:"type,omitempty"` - DefaultOptions DefaultRule `json:"-"` - LogicalOptions LogicalRule `json:"-"` -} - -type Rule _Rule - -func (r Rule) MarshalJSON() ([]byte, error) { - var v any - switch r.Type { - case C.RuleTypeDefault: - r.Type = "" - v = r.DefaultOptions - case C.RuleTypeLogical: - v = r.LogicalOptions - default: - return nil, E.New("unknown rule type: " + r.Type) - } - return MarshallObjects((_Rule)(r), v) -} - -func (r *Rule) UnmarshalJSON(bytes []byte) error { - err := json.Unmarshal(bytes, (*_Rule)(r)) - if err != nil { - return err - } - var v any - switch r.Type { - case "", C.RuleTypeDefault: - r.Type = C.RuleTypeDefault - v = &r.DefaultOptions - case C.RuleTypeLogical: - v = &r.LogicalOptions - default: - return E.New("unknown rule type: " + r.Type) - } - err = UnmarshallExcluded(bytes, (*_Rule)(r), v) - if err != nil { - return E.Cause(err, "route rule") - } - return nil -} - -type DefaultRule struct { - Inbound Listable[string] `json:"inbound,omitempty"` - IPVersion int `json:"ip_version,omitempty"` - Network string `json:"network,omitempty"` - AuthUser Listable[string] `json:"auth_user,omitempty"` - Protocol Listable[string] `json:"protocol,omitempty"` - Domain Listable[string] `json:"domain,omitempty"` - DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` - DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` - DomainRegex Listable[string] `json:"domain_regex,omitempty"` - Geosite Listable[string] `json:"geosite,omitempty"` - SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` - GeoIP Listable[string] `json:"geoip,omitempty"` - SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` - IPCIDR Listable[string] `json:"ip_cidr,omitempty"` - SourcePort Listable[uint16] `json:"source_port,omitempty"` - SourcePortRange Listable[string] `json:"source_port_range,omitempty"` - Port Listable[uint16] `json:"port,omitempty"` - PortRange Listable[string] `json:"port_range,omitempty"` - ProcessName Listable[string] `json:"process_name,omitempty"` - ProcessPath Listable[string] `json:"process_path,omitempty"` - PackageName Listable[string] `json:"package_name,omitempty"` - User Listable[string] `json:"user,omitempty"` - UserID Listable[int32] `json:"user_id,omitempty"` - ClashMode string `json:"clash_mode,omitempty"` - Invert bool `json:"invert,omitempty"` - Outbound string `json:"outbound,omitempty"` -} - -func (r DefaultRule) IsValid() bool { - var defaultValue DefaultRule - defaultValue.Invert = r.Invert - defaultValue.Outbound = r.Outbound - return !reflect.DeepEqual(r, defaultValue) -} - -type LogicalRule struct { - Mode string `json:"mode"` - Rules []DefaultRule `json:"rules,omitempty"` - Invert bool `json:"invert,omitempty"` - Outbound string `json:"outbound,omitempty"` -} - -func (r LogicalRule) IsValid() bool { - return len(r.Rules) > 0 && common.All(r.Rules, DefaultRule.IsValid) -} diff --git a/option/rule.go b/option/rule.go new file mode 100644 index 00000000..f78a752d --- /dev/null +++ b/option/rule.go @@ -0,0 +1,101 @@ +package option + +import ( + "reflect" + + "github.com/sagernet/sing-box/common/json" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +type _Rule struct { + Type string `json:"type,omitempty"` + DefaultOptions DefaultRule `json:"-"` + LogicalOptions LogicalRule `json:"-"` +} + +type Rule _Rule + +func (r Rule) MarshalJSON() ([]byte, error) { + var v any + switch r.Type { + case C.RuleTypeDefault: + r.Type = "" + v = r.DefaultOptions + case C.RuleTypeLogical: + v = r.LogicalOptions + default: + return nil, E.New("unknown rule type: " + r.Type) + } + return MarshallObjects((_Rule)(r), v) +} + +func (r *Rule) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_Rule)(r)) + if err != nil { + return err + } + var v any + switch r.Type { + case "", C.RuleTypeDefault: + r.Type = C.RuleTypeDefault + v = &r.DefaultOptions + case C.RuleTypeLogical: + v = &r.LogicalOptions + default: + return E.New("unknown rule type: " + r.Type) + } + err = UnmarshallExcluded(bytes, (*_Rule)(r), v) + if err != nil { + return E.Cause(err, "route rule") + } + return nil +} + +type DefaultRule struct { + Inbound Listable[string] `json:"inbound,omitempty"` + IPVersion int `json:"ip_version,omitempty"` + Network Listable[string] `json:"network,omitempty"` + AuthUser Listable[string] `json:"auth_user,omitempty"` + Protocol Listable[string] `json:"protocol,omitempty"` + Domain Listable[string] `json:"domain,omitempty"` + DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` + DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` + DomainRegex Listable[string] `json:"domain_regex,omitempty"` + Geosite Listable[string] `json:"geosite,omitempty"` + SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` + GeoIP Listable[string] `json:"geoip,omitempty"` + SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` + IPCIDR Listable[string] `json:"ip_cidr,omitempty"` + SourcePort Listable[uint16] `json:"source_port,omitempty"` + SourcePortRange Listable[string] `json:"source_port_range,omitempty"` + Port Listable[uint16] `json:"port,omitempty"` + PortRange Listable[string] `json:"port_range,omitempty"` + ProcessName Listable[string] `json:"process_name,omitempty"` + ProcessPath Listable[string] `json:"process_path,omitempty"` + PackageName Listable[string] `json:"package_name,omitempty"` + User Listable[string] `json:"user,omitempty"` + UserID Listable[int32] `json:"user_id,omitempty"` + ClashMode string `json:"clash_mode,omitempty"` + Invert bool `json:"invert,omitempty"` + Outbound string `json:"outbound,omitempty"` +} + +func (r DefaultRule) IsValid() bool { + var defaultValue DefaultRule + defaultValue.Invert = r.Invert + defaultValue.Outbound = r.Outbound + return !reflect.DeepEqual(r, defaultValue) +} + +type LogicalRule struct { + Mode string `json:"mode"` + Rules []DefaultRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Outbound string `json:"outbound,omitempty"` +} + +func (r LogicalRule) IsValid() bool { + return len(r.Rules) > 0 && common.All(r.Rules, DefaultRule.IsValid) +} diff --git a/option/rule_dns.go b/option/rule_dns.go new file mode 100644 index 00000000..563d3085 --- /dev/null +++ b/option/rule_dns.go @@ -0,0 +1,107 @@ +package option + +import ( + "reflect" + + "github.com/sagernet/sing-box/common/json" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +type _DNSRule struct { + Type string `json:"type,omitempty"` + DefaultOptions DefaultDNSRule `json:"-"` + LogicalOptions LogicalDNSRule `json:"-"` +} + +type DNSRule _DNSRule + +func (r DNSRule) MarshalJSON() ([]byte, error) { + var v any + switch r.Type { + case C.RuleTypeDefault: + r.Type = "" + v = r.DefaultOptions + case C.RuleTypeLogical: + v = r.LogicalOptions + default: + return nil, E.New("unknown rule type: " + r.Type) + } + return MarshallObjects((_DNSRule)(r), v) +} + +func (r *DNSRule) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_DNSRule)(r)) + if err != nil { + return err + } + var v any + switch r.Type { + case "", C.RuleTypeDefault: + r.Type = C.RuleTypeDefault + v = &r.DefaultOptions + case C.RuleTypeLogical: + v = &r.LogicalOptions + default: + return E.New("unknown rule type: " + r.Type) + } + err = UnmarshallExcluded(bytes, (*_DNSRule)(r), v) + if err != nil { + return E.Cause(err, "dns route rule") + } + return nil +} + +type DefaultDNSRule struct { + Inbound Listable[string] `json:"inbound,omitempty"` + IPVersion int `json:"ip_version,omitempty"` + QueryType Listable[DNSQueryType] `json:"query_type,omitempty"` + Network Listable[string] `json:"network,omitempty"` + AuthUser Listable[string] `json:"auth_user,omitempty"` + Protocol Listable[string] `json:"protocol,omitempty"` + Domain Listable[string] `json:"domain,omitempty"` + DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` + DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` + DomainRegex Listable[string] `json:"domain_regex,omitempty"` + Geosite Listable[string] `json:"geosite,omitempty"` + SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` + SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` + SourcePort Listable[uint16] `json:"source_port,omitempty"` + SourcePortRange Listable[string] `json:"source_port_range,omitempty"` + Port Listable[uint16] `json:"port,omitempty"` + PortRange Listable[string] `json:"port_range,omitempty"` + ProcessName Listable[string] `json:"process_name,omitempty"` + ProcessPath Listable[string] `json:"process_path,omitempty"` + PackageName Listable[string] `json:"package_name,omitempty"` + User Listable[string] `json:"user,omitempty"` + UserID Listable[int32] `json:"user_id,omitempty"` + Outbound Listable[string] `json:"outbound,omitempty"` + ClashMode string `json:"clash_mode,omitempty"` + Invert bool `json:"invert,omitempty"` + Server string `json:"server,omitempty"` + DisableCache bool `json:"disable_cache,omitempty"` + RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` +} + +func (r DefaultDNSRule) IsValid() bool { + var defaultValue DefaultDNSRule + defaultValue.Invert = r.Invert + defaultValue.Server = r.Server + defaultValue.DisableCache = r.DisableCache + defaultValue.RewriteTTL = r.RewriteTTL + return !reflect.DeepEqual(r, defaultValue) +} + +type LogicalDNSRule struct { + Mode string `json:"mode"` + Rules []DefaultDNSRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Server string `json:"server,omitempty"` + DisableCache bool `json:"disable_cache,omitempty"` + RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` +} + +func (r LogicalDNSRule) IsValid() bool { + return len(r.Rules) > 0 && common.All(r.Rules, DefaultDNSRule.IsValid) +} diff --git a/option/rule_ip.go b/option/rule_ip.go new file mode 100644 index 00000000..2a2c9dca --- /dev/null +++ b/option/rule_ip.go @@ -0,0 +1,120 @@ +package option + +import ( + "reflect" + + "github.com/sagernet/sing-box/common/json" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +type _IPRule struct { + Type string `json:"type,omitempty"` + DefaultOptions DefaultIPRule `json:"-"` + LogicalOptions LogicalIPRule `json:"-"` +} + +type IPRule _IPRule + +func (r IPRule) MarshalJSON() ([]byte, error) { + var v any + switch r.Type { + case C.RuleTypeDefault: + r.Type = "" + v = r.DefaultOptions + case C.RuleTypeLogical: + v = r.LogicalOptions + default: + return nil, E.New("unknown rule type: " + r.Type) + } + return MarshallObjects((_IPRule)(r), v) +} + +func (r *IPRule) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_IPRule)(r)) + if err != nil { + return err + } + var v any + switch r.Type { + case "", C.RuleTypeDefault: + r.Type = C.RuleTypeDefault + v = &r.DefaultOptions + case C.RuleTypeLogical: + v = &r.LogicalOptions + default: + return E.New("unknown rule type: " + r.Type) + } + err = UnmarshallExcluded(bytes, (*_IPRule)(r), v) + if err != nil { + return E.Cause(err, "ip route rule") + } + return nil +} + +type DefaultIPRule struct { + Inbound Listable[string] `json:"inbound,omitempty"` + IPVersion int `json:"ip_version,omitempty"` + Network Listable[string] `json:"network,omitempty"` + Domain Listable[string] `json:"domain,omitempty"` + DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` + DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` + DomainRegex Listable[string] `json:"domain_regex,omitempty"` + Geosite Listable[string] `json:"geosite,omitempty"` + SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` + GeoIP Listable[string] `json:"geoip,omitempty"` + SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` + IPCIDR Listable[string] `json:"ip_cidr,omitempty"` + SourcePort Listable[uint16] `json:"source_port,omitempty"` + SourcePortRange Listable[string] `json:"source_port_range,omitempty"` + Port Listable[uint16] `json:"port,omitempty"` + PortRange Listable[string] `json:"port_range,omitempty"` + Invert bool `json:"invert,omitempty"` + Action RouteAction `json:"action,omitempty"` + Outbound string `json:"outbound,omitempty"` +} + +type RouteAction tun.ActionType + +func (a RouteAction) MarshalJSON() ([]byte, error) { + typeName, err := tun.ActionTypeName(tun.ActionType(a)) + if err != nil { + return nil, err + } + return json.Marshal(typeName) +} + +func (a *RouteAction) UnmarshalJSON(bytes []byte) error { + var value string + err := json.Unmarshal(bytes, &value) + if err != nil { + return err + } + actionType, err := tun.ParseActionType(value) + if err != nil { + return err + } + *a = RouteAction(actionType) + return nil +} + +func (r DefaultIPRule) IsValid() bool { + var defaultValue DefaultIPRule + defaultValue.Invert = r.Invert + defaultValue.Outbound = r.Outbound + return !reflect.DeepEqual(r, defaultValue) +} + +type LogicalIPRule struct { + Mode string `json:"mode"` + Rules []DefaultIPRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Action RouteAction `json:"action,omitempty"` + Outbound string `json:"outbound,omitempty"` +} + +func (r LogicalIPRule) IsValid() bool { + return len(r.Rules) > 0 && common.All(r.Rules, DefaultIPRule.IsValid) +} diff --git a/option/wireguard.go b/option/wireguard.go index ee6e1053..15639474 100644 --- a/option/wireguard.go +++ b/option/wireguard.go @@ -13,4 +13,5 @@ type WireGuardOutboundOptions struct { Workers int `json:"workers,omitempty"` MTU uint32 `json:"mtu,omitempty"` Network NetworkList `json:"network,omitempty"` + IPRewrite bool `json:"ip_rewrite,omitempty"` } diff --git a/outbound/wireguard.go b/outbound/wireguard.go index cdba9812..1878f46f 100644 --- a/outbound/wireguard.go +++ b/outbound/wireguard.go @@ -8,7 +8,9 @@ import ( "encoding/hex" "fmt" "net" + "os" "strings" + "syscall" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" @@ -26,7 +28,7 @@ import ( ) var ( - _ adapter.Outbound = (*WireGuard)(nil) + _ adapter.IPOutbound = (*WireGuard)(nil) _ adapter.InterfaceUpdateListener = (*WireGuard)(nil) ) @@ -34,6 +36,7 @@ type WireGuard struct { myOutboundAdapter bind *wireguard.ClientBind device *device.Device + natDevice wireguard.NatDevice tunDevice wireguard.Device } @@ -106,17 +109,25 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context if mtu == 0 { mtu = 1408 } - var wireTunDevice wireguard.Device + var tunDevice wireguard.Device var err error if !options.SystemInterface && tun.WithGVisor { - wireTunDevice, err = wireguard.NewStackDevice(localPrefixes, mtu) + tunDevice, err = wireguard.NewStackDevice(localPrefixes, mtu, options.IPRewrite) } else { - wireTunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, localPrefixes, mtu) + tunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, localPrefixes, mtu) } if err != nil { return nil, E.Cause(err, "create WireGuard device") } - wgDevice := device.NewDevice(wireTunDevice, outbound.bind, &device.Logger{ + natDevice, isNatDevice := tunDevice.(wireguard.NatDevice) + if !isNatDevice && router.NatRequired(tag) { + natDevice = wireguard.NewNATDevice(tunDevice, options.IPRewrite) + } + deviceInput := tunDevice + if natDevice != nil { + deviceInput = natDevice + } + wgDevice := device.NewDevice(deviceInput, outbound.bind, &device.Logger{ Verbosef: func(format string, args ...interface{}) { logger.Debug(fmt.Sprintf(strings.ToLower(format), args...)) }, @@ -132,7 +143,8 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context return nil, E.Cause(err, "setup wireguard") } outbound.device = wgDevice - outbound.tunDevice = wireTunDevice + outbound.natDevice = natDevice + outbound.tunDevice = tunDevice return outbound, nil } @@ -171,6 +183,27 @@ func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, return NewPacketConnection(ctx, w, conn, metadata) } +func (w *WireGuard) NewIPConnection(ctx context.Context, conn tun.RouteContext, metadata adapter.InboundContext) (tun.DirectDestination, error) { + if w.natDevice == nil { + return nil, os.ErrInvalid + } + session := tun.RouteSession{ + IPVersion: metadata.IPVersion, + Network: tun.NetworkFromName(metadata.Network), + Source: metadata.Source.AddrPort(), + Destination: metadata.Destination.AddrPort(), + } + switch session.Network { + case syscall.IPPROTO_TCP: + w.logger.InfoContext(ctx, "linked connection to ", metadata.Destination) + case syscall.IPPROTO_UDP: + w.logger.InfoContext(ctx, "linked packet connection to ", metadata.Destination) + default: + w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection to ", metadata.Destination.AddrString()) + } + return w.natDevice.CreateDestination(session, conn), nil +} + func (w *WireGuard) Start() error { return w.tunDevice.Start() } diff --git a/route/router.go b/route/router.go index 91013d41..5d739cf8 100644 --- a/route/router.go +++ b/route/router.go @@ -2,14 +2,11 @@ package route import ( "context" - "io" "net" - "net/http" "net/netip" "net/url" "os" "os/user" - "path/filepath" "strings" "time" @@ -38,7 +35,6 @@ import ( F "github.com/sagernet/sing/common/format" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/uot" ) @@ -73,6 +69,7 @@ type Router struct { outbounds []adapter.Outbound outboundByTag map[string]adapter.Outbound rules []adapter.Rule + ipRules []adapter.IPRule defaultDetour string defaultOutboundForConnection adapter.Outbound defaultOutboundForPacketConnection adapter.Outbound @@ -130,6 +127,7 @@ func NewRouter( dnsLogger: logFactory.NewLogger("dns"), outboundByTag: make(map[string]adapter.Outbound), rules: make([]adapter.Rule, 0, len(options.Rules)), + ipRules: make([]adapter.IPRule, 0, len(options.IPRules)), dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule), needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule), @@ -151,6 +149,13 @@ func NewRouter( } router.rules = append(router.rules, routeRule) } + for i, ipRuleOptions := range options.IPRules { + ipRule, err := NewIPRule(router, router.logger, ipRuleOptions) + if err != nil { + return nil, E.Cause(err, "parse ip rule[", i, "]") + } + router.ipRules = append(router.ipRules, ipRule) + } for i, dnsRuleOptions := range dnsOptions.Rules { dnsRule, err := NewDNSRule(router, router.logger, dnsRuleOptions) if err != nil { @@ -158,6 +163,7 @@ func NewRouter( } router.dnsRules = append(router.dnsRules, dnsRule) } + transports := make([]dns.Transport, len(dnsOptions.Servers)) dummyTransportMap := make(map[string]dns.Transport) transportMap := make(map[string]dns.Transport) @@ -547,27 +553,6 @@ func (r *Router) Close() error { return err } -func (r *Router) GeoIPReader() *geoip.Reader { - return r.geoIPReader -} - -func (r *Router) LoadGeosite(code string) (adapter.Rule, error) { - rule, cached := r.geositeCache[code] - if cached { - return rule, nil - } - items, err := r.geositeReader.Read(code) - if err != nil { - return nil, err - } - rule, err = NewDefaultRule(r, nil, geosite.Compile(items)) - if err != nil { - return nil, err - } - r.geositeCache[code] = rule - return rule, nil -} - func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { outbound, loaded := r.outboundByTag[tag] return outbound, loaded @@ -747,6 +732,13 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m } conn = bufio.NewCachedPacketConn(conn, buffer, destination) } + if r.dnsReverseMapping != nil && metadata.Domain == "" { + domain, loaded := r.dnsReverseMapping.Query(metadata.Destination.Addr) + if loaded { + metadata.Domain = domain + r.logger.DebugContext(ctx, "found reserve mapped domain: ", metadata.Domain) + } + } if metadata.Destination.IsFqdn() && dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) != dns.DomainStrategyAsIS { addresses, err := r.Lookup(adapter.WithContext(ctx, &metadata), metadata.Destination.Fqdn, dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)) if err != nil { @@ -866,6 +858,10 @@ func (r *Router) Rules() []adapter.Rule { return r.rules } +func (r *Router) IPRules() []adapter.IPRule { + return r.ipRules +} + func (r *Router) NetworkMonitor() tun.NetworkUpdateMonitor { return r.networkMonitor } @@ -901,239 +897,6 @@ func (r *Router) SetV2RayServer(server adapter.V2RayServer) { r.v2rayServer = server } -func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool { - for _, rule := range rules { - switch rule.Type { - case C.RuleTypeDefault: - if cond(rule.DefaultOptions) { - return true - } - case C.RuleTypeLogical: - for _, subRule := range rule.LogicalOptions.Rules { - if cond(subRule) { - return true - } - } - } - } - return false -} - -func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bool) bool { - for _, rule := range rules { - switch rule.Type { - case C.RuleTypeDefault: - if cond(rule.DefaultOptions) { - return true - } - case C.RuleTypeLogical: - for _, subRule := range rule.LogicalOptions.Rules { - if cond(subRule) { - return true - } - } - } - } - return false -} - -func isGeoIPRule(rule option.DefaultRule) bool { - return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode) -} - -func isGeoIPDNSRule(rule option.DefaultDNSRule) bool { - return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) -} - -func isGeositeRule(rule option.DefaultRule) bool { - return len(rule.Geosite) > 0 -} - -func isGeositeDNSRule(rule option.DefaultDNSRule) bool { - return len(rule.Geosite) > 0 -} - -func isProcessRule(rule option.DefaultRule) bool { - return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 -} - -func isProcessDNSRule(rule option.DefaultDNSRule) bool { - return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 -} - -func notPrivateNode(code string) bool { - return code != "private" -} - -func (r *Router) prepareGeoIPDatabase() error { - var geoPath string - if r.geoIPOptions.Path != "" { - geoPath = r.geoIPOptions.Path - } else { - geoPath = "geoip.db" - if foundPath, loaded := C.FindPath(geoPath); loaded { - geoPath = foundPath - } - } - geoPath = C.BasePath(geoPath) - if !rw.FileExists(geoPath) { - r.logger.Warn("geoip database not exists: ", geoPath) - var err error - for attempts := 0; attempts < 3; attempts++ { - err = r.downloadGeoIPDatabase(geoPath) - if err == nil { - break - } - r.logger.Error("download geoip database: ", err) - os.Remove(geoPath) - // time.Sleep(10 * time.Second) - } - if err != nil { - return err - } - } - geoReader, codes, err := geoip.Open(geoPath) - if err != nil { - return E.Cause(err, "open geoip database") - } - r.logger.Info("loaded geoip database: ", len(codes), " codes") - r.geoIPReader = geoReader - return nil -} - -func (r *Router) prepareGeositeDatabase() error { - var geoPath string - if r.geositeOptions.Path != "" { - geoPath = r.geositeOptions.Path - } else { - geoPath = "geosite.db" - if foundPath, loaded := C.FindPath(geoPath); loaded { - geoPath = foundPath - } - } - geoPath = C.BasePath(geoPath) - if !rw.FileExists(geoPath) { - r.logger.Warn("geosite database not exists: ", geoPath) - var err error - for attempts := 0; attempts < 3; attempts++ { - err = r.downloadGeositeDatabase(geoPath) - if err == nil { - break - } - r.logger.Error("download geosite database: ", err) - os.Remove(geoPath) - // time.Sleep(10 * time.Second) - } - if err != nil { - return err - } - } - geoReader, codes, err := geosite.Open(geoPath) - if err == nil { - r.logger.Info("loaded geosite database: ", len(codes), " codes") - r.geositeReader = geoReader - } else { - return E.Cause(err, "open geosite database") - } - return nil -} - -func (r *Router) downloadGeoIPDatabase(savePath string) error { - var downloadURL string - if r.geoIPOptions.DownloadURL != "" { - downloadURL = r.geoIPOptions.DownloadURL - } else { - downloadURL = "https://github.com/SagerNet/sing-geoip/releases/latest/download/geoip.db" - } - r.logger.Info("downloading geoip database") - var detour adapter.Outbound - if r.geoIPOptions.DownloadDetour != "" { - outbound, loaded := r.Outbound(r.geoIPOptions.DownloadDetour) - if !loaded { - return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour) - } - detour = outbound - } else { - detour = r.defaultOutboundForConnection - } - - if parentDir := filepath.Dir(savePath); parentDir != "" { - os.MkdirAll(parentDir, 0o755) - } - - saveFile, err := os.OpenFile(savePath, os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return E.Cause(err, "open output file: ", downloadURL) - } - defer saveFile.Close() - - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: 5 * time.Second, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - defer httpClient.CloseIdleConnections() - response, err := httpClient.Get(downloadURL) - if err != nil { - return err - } - defer response.Body.Close() - _, err = io.Copy(saveFile, response.Body) - return err -} - -func (r *Router) downloadGeositeDatabase(savePath string) error { - var downloadURL string - if r.geositeOptions.DownloadURL != "" { - downloadURL = r.geositeOptions.DownloadURL - } else { - downloadURL = "https://github.com/SagerNet/sing-geosite/releases/latest/download/geosite.db" - } - r.logger.Info("downloading geosite database") - var detour adapter.Outbound - if r.geositeOptions.DownloadDetour != "" { - outbound, loaded := r.Outbound(r.geositeOptions.DownloadDetour) - if !loaded { - return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour) - } - detour = outbound - } else { - detour = r.defaultOutboundForConnection - } - - if parentDir := filepath.Dir(savePath); parentDir != "" { - os.MkdirAll(parentDir, 0o755) - } - - saveFile, err := os.OpenFile(savePath, os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return E.Cause(err, "open output file: ", downloadURL) - } - defer saveFile.Close() - - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: 5 * time.Second, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - defer httpClient.CloseIdleConnections() - response, err := httpClient.Get(downloadURL) - if err != nil { - return err - } - defer response.Body.Close() - _, err = io.Copy(saveFile, response.Body) - return err -} - func (r *Router) OnPackagesUpdated(packages int, sharedUsers int) { r.logger.Info("updated packages list: ", packages, " packages, ", sharedUsers, " shared users") } diff --git a/route/router_dns.go b/route/router_dns.go index 11c02c87..d343fb8b 100644 --- a/route/router_dns.go +++ b/route/router_dns.go @@ -47,6 +47,9 @@ func (r *Router) matchDNS(ctx context.Context) (context.Context, dns.Transport, if rule.DisableCache() { ctx = dns.ContextWithDisableCache(ctx, true) } + if rewriteTTL := rule.RewriteTTL(); rewriteTTL != nil { + ctx = dns.ContextWithRewriteTTL(ctx, *rewriteTTL) + } detour := rule.Outbound() r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) if transport, loaded := r.transportMap[detour]; loaded { diff --git a/route/router_geo_resources.go b/route/router_geo_resources.go new file mode 100644 index 00000000..a72b4bad --- /dev/null +++ b/route/router_geo_resources.go @@ -0,0 +1,283 @@ +package route + +import ( + "context" + "io" + "net" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/geoip" + "github.com/sagernet/sing-box/common/geosite" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/rw" +) + +func (r *Router) GeoIPReader() *geoip.Reader { + return r.geoIPReader +} + +func (r *Router) LoadGeosite(code string) (adapter.Rule, error) { + rule, cached := r.geositeCache[code] + if cached { + return rule, nil + } + items, err := r.geositeReader.Read(code) + if err != nil { + return nil, err + } + rule, err = NewDefaultRule(r, nil, geosite.Compile(items)) + if err != nil { + return nil, err + } + r.geositeCache[code] = rule + return rule, nil +} + +func (r *Router) prepareGeoIPDatabase() error { + var geoPath string + if r.geoIPOptions.Path != "" { + geoPath = r.geoIPOptions.Path + } else { + geoPath = "geoip.db" + if foundPath, loaded := C.FindPath(geoPath); loaded { + geoPath = foundPath + } + } + geoPath = C.BasePath(geoPath) + if rw.FileExists(geoPath) { + geoReader, codes, err := geoip.Open(geoPath) + if err == nil { + r.logger.Info("loaded geoip database: ", len(codes), " codes") + r.geoIPReader = geoReader + return nil + } + } + if !rw.FileExists(geoPath) { + r.logger.Warn("geoip database not exists: ", geoPath) + var err error + for attempts := 0; attempts < 3; attempts++ { + err = r.downloadGeoIPDatabase(geoPath) + if err == nil { + break + } + r.logger.Error("download geoip database: ", err) + os.Remove(geoPath) + // time.Sleep(10 * time.Second) + } + if err != nil { + return err + } + } + geoReader, codes, err := geoip.Open(geoPath) + if err != nil { + return E.Cause(err, "open geoip database") + } + r.logger.Info("loaded geoip database: ", len(codes), " codes") + r.geoIPReader = geoReader + return nil +} + +func (r *Router) prepareGeositeDatabase() error { + var geoPath string + if r.geositeOptions.Path != "" { + geoPath = r.geositeOptions.Path + } else { + geoPath = "geosite.db" + if foundPath, loaded := C.FindPath(geoPath); loaded { + geoPath = foundPath + } + } + geoPath = C.BasePath(geoPath) + if !rw.FileExists(geoPath) { + r.logger.Warn("geosite database not exists: ", geoPath) + var err error + for attempts := 0; attempts < 3; attempts++ { + err = r.downloadGeositeDatabase(geoPath) + if err == nil { + break + } + r.logger.Error("download geosite database: ", err) + os.Remove(geoPath) + // time.Sleep(10 * time.Second) + } + if err != nil { + return err + } + } + geoReader, codes, err := geosite.Open(geoPath) + if err == nil { + r.logger.Info("loaded geosite database: ", len(codes), " codes") + r.geositeReader = geoReader + } else { + return E.Cause(err, "open geosite database") + } + return nil +} + +func (r *Router) downloadGeoIPDatabase(savePath string) error { + var downloadURL string + if r.geoIPOptions.DownloadURL != "" { + downloadURL = r.geoIPOptions.DownloadURL + } else { + downloadURL = "https://github.com/SagerNet/sing-geoip/releases/latest/download/geoip.db" + } + r.logger.Info("downloading geoip database") + var detour adapter.Outbound + if r.geoIPOptions.DownloadDetour != "" { + outbound, loaded := r.Outbound(r.geoIPOptions.DownloadDetour) + if !loaded { + return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour) + } + detour = outbound + } else { + detour = r.defaultOutboundForConnection + } + + if parentDir := filepath.Dir(savePath); parentDir != "" { + os.MkdirAll(parentDir, 0o755) + } + + saveFile, err := os.OpenFile(savePath, os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return E.Cause(err, "open output file: ", downloadURL) + } + defer saveFile.Close() + + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: 5 * time.Second, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + defer httpClient.CloseIdleConnections() + response, err := httpClient.Get(downloadURL) + if err != nil { + return err + } + defer response.Body.Close() + _, err = io.Copy(saveFile, response.Body) + return err +} + +func (r *Router) downloadGeositeDatabase(savePath string) error { + var downloadURL string + if r.geositeOptions.DownloadURL != "" { + downloadURL = r.geositeOptions.DownloadURL + } else { + downloadURL = "https://github.com/SagerNet/sing-geosite/releases/latest/download/geosite.db" + } + r.logger.Info("downloading geosite database") + var detour adapter.Outbound + if r.geositeOptions.DownloadDetour != "" { + outbound, loaded := r.Outbound(r.geositeOptions.DownloadDetour) + if !loaded { + return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour) + } + detour = outbound + } else { + detour = r.defaultOutboundForConnection + } + + if parentDir := filepath.Dir(savePath); parentDir != "" { + os.MkdirAll(parentDir, 0o755) + } + + saveFile, err := os.OpenFile(savePath, os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return E.Cause(err, "open output file: ", downloadURL) + } + defer saveFile.Close() + + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: 5 * time.Second, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + defer httpClient.CloseIdleConnections() + response, err := httpClient.Get(downloadURL) + if err != nil { + return err + } + defer response.Body.Close() + _, err = io.Copy(saveFile, response.Body) + return err +} + +func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool { + for _, rule := range rules { + switch rule.Type { + case C.RuleTypeDefault: + if cond(rule.DefaultOptions) { + return true + } + case C.RuleTypeLogical: + for _, subRule := range rule.LogicalOptions.Rules { + if cond(subRule) { + return true + } + } + } + } + return false +} + +func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bool) bool { + for _, rule := range rules { + switch rule.Type { + case C.RuleTypeDefault: + if cond(rule.DefaultOptions) { + return true + } + case C.RuleTypeLogical: + for _, subRule := range rule.LogicalOptions.Rules { + if cond(subRule) { + return true + } + } + } + } + return false +} + +func isGeoIPRule(rule option.DefaultRule) bool { + return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode) +} + +func isGeoIPDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) +} + +func isGeositeRule(rule option.DefaultRule) bool { + return len(rule.Geosite) > 0 +} + +func isGeositeDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.Geosite) > 0 +} + +func isProcessRule(rule option.DefaultRule) bool { + return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 +} + +func isProcessDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 +} + +func notPrivateNode(code string) bool { + return code != "private" +} diff --git a/route/router_ip.go b/route/router_ip.go new file mode 100644 index 00000000..5234b4f5 --- /dev/null +++ b/route/router_ip.go @@ -0,0 +1,66 @@ +package route + +import ( + "context" + "strings" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-dns" + "github.com/sagernet/sing-tun" + F "github.com/sagernet/sing/common/format" +) + +func (r *Router) RouteIPConnection(ctx context.Context, conn tun.RouteContext, metadata adapter.InboundContext) tun.RouteAction { + if r.dnsReverseMapping != nil && metadata.Domain == "" { + domain, loaded := r.dnsReverseMapping.Query(metadata.Destination.Addr) + if loaded { + metadata.Domain = domain + r.logger.DebugContext(ctx, "found reserve mapped domain: ", metadata.Domain) + } + } + if metadata.Destination.IsFqdn() && dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) != dns.DomainStrategyAsIS { + addresses, err := r.Lookup(adapter.WithContext(ctx, &metadata), metadata.Destination.Fqdn, dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)) + if err != nil { + r.logger.ErrorContext(ctx, err) + return (*tun.ActionReturn)(nil) + } + metadata.DestinationAddresses = addresses + r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") + } + for i, rule := range r.ipRules { + if rule.Match(&metadata) { + if rule.Action() == tun.ActionTypeBlock { + r.logger.InfoContext(ctx, "match[", i, "] ", rule.String(), " => block") + return (*tun.ActionBlock)(nil) + } + detour := rule.Outbound() + r.logger.InfoContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) + outbound, loaded := r.Outbound(detour) + if !loaded { + r.logger.ErrorContext(ctx, "outbound not found: ", detour) + break + } + ipOutbound, loaded := outbound.(adapter.IPOutbound) + if !loaded { + r.logger.ErrorContext(ctx, "outbound have no ip connection support: ", detour) + break + } + destination, err := ipOutbound.NewIPConnection(ctx, conn, metadata) + if err != nil { + r.logger.ErrorContext(ctx, err) + break + } + return &tun.ActionDirect{DirectDestination: destination} + } + } + return (*tun.ActionReturn)(nil) +} + +func (r *Router) NatRequired(outbound string) bool { + for _, ipRule := range r.ipRules { + if ipRule.Outbound() == outbound { + return true + } + } + return false +} diff --git a/route/rule_abstract.go b/route/rule_abstract.go new file mode 100644 index 00000000..be41401f --- /dev/null +++ b/route/rule_abstract.go @@ -0,0 +1,203 @@ +package route + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + F "github.com/sagernet/sing/common/format" +) + +type abstractDefaultRule struct { + items []RuleItem + sourceAddressItems []RuleItem + sourcePortItems []RuleItem + destinationAddressItems []RuleItem + destinationPortItems []RuleItem + allItems []RuleItem + invert bool + outbound string +} + +func (r *abstractDefaultRule) Type() string { + return C.RuleTypeDefault +} + +func (r *abstractDefaultRule) Start() error { + for _, item := range r.allItems { + err := common.Start(item) + if err != nil { + return err + } + } + return nil +} + +func (r *abstractDefaultRule) Close() error { + for _, item := range r.allItems { + err := common.Close(item) + if err != nil { + return err + } + } + return nil +} + +func (r *abstractDefaultRule) UpdateGeosite() error { + for _, item := range r.allItems { + if geositeItem, isSite := item.(*GeositeItem); isSite { + err := geositeItem.Update() + if err != nil { + return err + } + } + } + return nil +} + +func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool { + for _, item := range r.items { + if !item.Match(metadata) { + return r.invert + } + } + + if len(r.sourceAddressItems) > 0 { + var sourceAddressMatch bool + for _, item := range r.sourceAddressItems { + if item.Match(metadata) { + sourceAddressMatch = true + break + } + } + if !sourceAddressMatch { + return r.invert + } + } + + if len(r.sourcePortItems) > 0 { + var sourcePortMatch bool + for _, item := range r.sourcePortItems { + if item.Match(metadata) { + sourcePortMatch = true + break + } + } + if !sourcePortMatch { + return r.invert + } + } + + if len(r.destinationAddressItems) > 0 { + var destinationAddressMatch bool + for _, item := range r.destinationAddressItems { + if item.Match(metadata) { + destinationAddressMatch = true + break + } + } + if !destinationAddressMatch { + return r.invert + } + } + + if len(r.destinationPortItems) > 0 { + var destinationPortMatch bool + for _, item := range r.destinationPortItems { + if item.Match(metadata) { + destinationPortMatch = true + break + } + } + if !destinationPortMatch { + return r.invert + } + } + + return !r.invert +} + +func (r *abstractDefaultRule) Outbound() string { + return r.outbound +} + +func (r *abstractDefaultRule) String() string { + if !r.invert { + return strings.Join(F.MapToString(r.allItems), " ") + } else { + return "!(" + strings.Join(F.MapToString(r.allItems), " ") + ")" + } +} + +type abstractLogicalRule struct { + rules []adapter.Rule + mode string + invert bool + outbound string +} + +func (r *abstractLogicalRule) Type() string { + return C.RuleTypeLogical +} + +func (r *abstractLogicalRule) UpdateGeosite() error { + for _, rule := range r.rules { + err := rule.UpdateGeosite() + if err != nil { + return err + } + } + return nil +} + +func (r *abstractLogicalRule) Start() error { + for _, rule := range r.rules { + err := rule.Start() + if err != nil { + return err + } + } + return nil +} + +func (r *abstractLogicalRule) Close() error { + for _, rule := range r.rules { + err := rule.Close() + if err != nil { + return err + } + } + return nil +} + +func (r *abstractLogicalRule) Match(metadata *adapter.InboundContext) bool { + if r.mode == C.LogicalTypeAnd { + return common.All(r.rules, func(it adapter.Rule) bool { + return it.Match(metadata) + }) != r.invert + } else { + return common.Any(r.rules, func(it adapter.Rule) bool { + return it.Match(metadata) + }) != r.invert + } +} + +func (r *abstractLogicalRule) Outbound() string { + return r.outbound +} + +func (r *abstractLogicalRule) String() string { + var op string + switch r.mode { + case C.LogicalTypeAnd: + op = "&&" + case C.LogicalTypeOr: + op = "||" + } + if !r.invert { + return strings.Join(F.MapToString(r.rules), " "+op+" ") + } else { + return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" + } +} diff --git a/route/rule.go b/route/rule_default.go similarity index 61% rename from route/rule.go rename to route/rule_default.go index 6f2f3baa..01322c13 100644 --- a/route/rule.go +++ b/route/rule_default.go @@ -1,16 +1,11 @@ package route import ( - "strings" - "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - N "github.com/sagernet/sing/common/network" ) func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (adapter.Rule, error) { @@ -39,14 +34,7 @@ func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rul var _ adapter.Rule = (*DefaultRule)(nil) type DefaultRule struct { - items []RuleItem - sourceAddressItems []RuleItem - sourcePortItems []RuleItem - destinationAddressItems []RuleItem - destinationPortItems []RuleItem - allItems []RuleItem - invert bool - outbound string + abstractDefaultRule } type RuleItem interface { @@ -56,8 +44,10 @@ type RuleItem interface { func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { rule := &DefaultRule{ - invert: options.Invert, - outbound: options.Outbound, + abstractDefaultRule{ + invert: options.Invert, + outbound: options.Outbound, + }, } if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) @@ -74,15 +64,10 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt return nil, E.New("invalid ip version: ", options.IPVersion) } } - if options.Network != "" { - switch options.Network { - case N.NetworkTCP, N.NetworkUDP: - item := NewNetworkItem(options.Network) - rule.items = append(rule.items, item) - rule.allItems = append(rule.allItems, item) - default: - return nil, E.New("invalid network: ", options.Network) - } + if len(options.Network) > 0 { + item := NewNetworkItem(options.Network) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } if len(options.AuthUser) > 0 { item := NewAuthUserItem(options.AuthUser) @@ -202,130 +187,19 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt return rule, nil } -func (r *DefaultRule) Type() string { - return C.RuleTypeDefault -} - -func (r *DefaultRule) Start() error { - for _, item := range r.allItems { - err := common.Start(item) - if err != nil { - return err - } - } - return nil -} - -func (r *DefaultRule) Close() error { - for _, item := range r.allItems { - err := common.Close(item) - if err != nil { - return err - } - } - return nil -} - -func (r *DefaultRule) UpdateGeosite() error { - for _, item := range r.allItems { - if geositeItem, isSite := item.(*GeositeItem); isSite { - err := geositeItem.Update() - if err != nil { - return err - } - } - } - return nil -} - -func (r *DefaultRule) Match(metadata *adapter.InboundContext) bool { - for _, item := range r.items { - if !item.Match(metadata) { - return r.invert - } - } - - if len(r.sourceAddressItems) > 0 { - var sourceAddressMatch bool - for _, item := range r.sourceAddressItems { - if item.Match(metadata) { - sourceAddressMatch = true - break - } - } - if !sourceAddressMatch { - return r.invert - } - } - - if len(r.sourcePortItems) > 0 { - var sourcePortMatch bool - for _, item := range r.sourcePortItems { - if item.Match(metadata) { - sourcePortMatch = true - break - } - } - if !sourcePortMatch { - return r.invert - } - } - - if len(r.destinationAddressItems) > 0 { - var destinationAddressMatch bool - for _, item := range r.destinationAddressItems { - if item.Match(metadata) { - destinationAddressMatch = true - break - } - } - if !destinationAddressMatch { - return r.invert - } - } - - if len(r.destinationPortItems) > 0 { - var destinationPortMatch bool - for _, item := range r.destinationPortItems { - if item.Match(metadata) { - destinationPortMatch = true - break - } - } - if !destinationPortMatch { - return r.invert - } - } - - return !r.invert -} - -func (r *DefaultRule) Outbound() string { - return r.outbound -} - -func (r *DefaultRule) String() string { - if !r.invert { - return strings.Join(F.MapToString(r.allItems), " ") - } else { - return "!(" + strings.Join(F.MapToString(r.allItems), " ") + ")" - } -} - var _ adapter.Rule = (*LogicalRule)(nil) type LogicalRule struct { - mode string - rules []*DefaultRule - invert bool - outbound string + abstractLogicalRule } func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { r := &LogicalRule{ - rules: make([]*DefaultRule, len(options.Rules)), - invert: options.Invert, - outbound: options.Outbound, + abstractLogicalRule{ + rules: make([]adapter.Rule, len(options.Rules)), + invert: options.Invert, + outbound: options.Outbound, + }, } switch options.Mode { case C.LogicalTypeAnd: @@ -344,68 +218,3 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt } return r, nil } - -func (r *LogicalRule) Type() string { - return C.RuleTypeLogical -} - -func (r *LogicalRule) UpdateGeosite() error { - for _, rule := range r.rules { - err := rule.UpdateGeosite() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalRule) Start() error { - for _, rule := range r.rules { - err := rule.Start() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalRule) Close() error { - for _, rule := range r.rules { - err := rule.Close() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalRule) Match(metadata *adapter.InboundContext) bool { - if r.mode == C.LogicalTypeAnd { - return common.All(r.rules, func(it *DefaultRule) bool { - return it.Match(metadata) - }) != r.invert - } else { - return common.Any(r.rules, func(it *DefaultRule) bool { - return it.Match(metadata) - }) != r.invert - } -} - -func (r *LogicalRule) Outbound() string { - return r.outbound -} - -func (r *LogicalRule) String() string { - var op string - switch r.mode { - case C.LogicalTypeAnd: - op = "&&" - case C.LogicalTypeOr: - op = "||" - } - if !r.invert { - return strings.Join(F.MapToString(r.rules), " "+op+" ") - } else { - return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" - } -} diff --git a/route/rule_dns.go b/route/rule_dns.go index 3bfdb729..15e4b16f 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -1,16 +1,11 @@ package route import ( - "strings" - "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - N "github.com/sagernet/sing/common/network" ) func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.DNSRule, error) { @@ -39,22 +34,19 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option. var _ adapter.DNSRule = (*DefaultDNSRule)(nil) type DefaultDNSRule struct { - items []RuleItem - sourceAddressItems []RuleItem - sourcePortItems []RuleItem - destinationAddressItems []RuleItem - destinationPortItems []RuleItem - allItems []RuleItem - invert bool - outbound string - disableCache bool + abstractDefaultRule + disableCache bool + rewriteTTL *uint32 } func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { rule := &DefaultDNSRule{ - invert: options.Invert, - outbound: options.Server, + abstractDefaultRule: abstractDefaultRule{ + invert: options.Invert, + outbound: options.Server, + }, disableCache: options.DisableCache, + rewriteTTL: options.RewriteTTL, } if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) @@ -76,15 +68,10 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } - if options.Network != "" { - switch options.Network { - case N.NetworkTCP, N.NetworkUDP: - item := NewNetworkItem(options.Network) - rule.items = append(rule.items, item) - rule.allItems = append(rule.allItems, item) - default: - return nil, E.New("invalid network: ", options.Network) - } + if len(options.Network) > 0 { + item := NewNetworkItem(options.Network) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } if len(options.AuthUser) > 0 { item := NewAuthUserItem(options.AuthUser) @@ -196,132 +183,31 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options return rule, nil } -func (r *DefaultDNSRule) Type() string { - return C.RuleTypeDefault -} - -func (r *DefaultDNSRule) Start() error { - for _, item := range r.allItems { - err := common.Start(item) - if err != nil { - return err - } - } - return nil -} - -func (r *DefaultDNSRule) Close() error { - for _, item := range r.allItems { - err := common.Close(item) - if err != nil { - return err - } - } - return nil -} - -func (r *DefaultDNSRule) UpdateGeosite() error { - for _, item := range r.allItems { - if geositeItem, isSite := item.(*GeositeItem); isSite { - err := geositeItem.Update() - if err != nil { - return err - } - } - } - return nil -} - -func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool { - for _, item := range r.items { - if !item.Match(metadata) { - return r.invert - } - } - - if len(r.sourceAddressItems) > 0 { - var sourceAddressMatch bool - for _, item := range r.sourceAddressItems { - if item.Match(metadata) { - sourceAddressMatch = true - break - } - } - if !sourceAddressMatch { - return r.invert - } - } - - if len(r.sourcePortItems) > 0 { - var sourcePortMatch bool - for _, item := range r.sourcePortItems { - if item.Match(metadata) { - sourcePortMatch = true - break - } - } - if !sourcePortMatch { - return r.invert - } - } - - if len(r.destinationAddressItems) > 0 { - var destinationAddressMatch bool - for _, item := range r.destinationAddressItems { - if item.Match(metadata) { - destinationAddressMatch = true - break - } - } - if !destinationAddressMatch { - return r.invert - } - } - - if len(r.destinationPortItems) > 0 { - var destinationPortMatch bool - for _, item := range r.destinationPortItems { - if item.Match(metadata) { - destinationPortMatch = true - break - } - } - if !destinationPortMatch { - return r.invert - } - } - - return !r.invert -} - -func (r *DefaultDNSRule) Outbound() string { - return r.outbound -} - func (r *DefaultDNSRule) DisableCache() bool { return r.disableCache } -func (r *DefaultDNSRule) String() string { - return strings.Join(F.MapToString(r.allItems), " ") +func (r *DefaultDNSRule) RewriteTTL() *uint32 { + return r.rewriteTTL } var _ adapter.DNSRule = (*LogicalDNSRule)(nil) type LogicalDNSRule struct { - mode string - rules []*DefaultDNSRule - invert bool - outbound string + abstractLogicalRule disableCache bool + rewriteTTL *uint32 } func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { r := &LogicalDNSRule{ - rules: make([]*DefaultDNSRule, len(options.Rules)), - invert: options.Invert, - outbound: options.Server, + abstractLogicalRule: abstractLogicalRule{ + rules: make([]adapter.Rule, len(options.Rules)), + invert: options.Invert, + outbound: options.Server, + }, disableCache: options.DisableCache, + rewriteTTL: options.RewriteTTL, } switch options.Mode { case C.LogicalTypeAnd: @@ -341,71 +227,10 @@ func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options return r, nil } -func (r *LogicalDNSRule) Type() string { - return C.RuleTypeLogical -} - -func (r *LogicalDNSRule) UpdateGeosite() error { - for _, rule := range r.rules { - err := rule.UpdateGeosite() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalDNSRule) Start() error { - for _, rule := range r.rules { - err := rule.Start() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalDNSRule) Close() error { - for _, rule := range r.rules { - err := rule.Close() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool { - if r.mode == C.LogicalTypeAnd { - return common.All(r.rules, func(it *DefaultDNSRule) bool { - return it.Match(metadata) - }) != r.invert - } else { - return common.Any(r.rules, func(it *DefaultDNSRule) bool { - return it.Match(metadata) - }) != r.invert - } -} - -func (r *LogicalDNSRule) Outbound() string { - return r.outbound -} - func (r *LogicalDNSRule) DisableCache() bool { return r.disableCache } -func (r *LogicalDNSRule) String() string { - var op string - switch r.mode { - case C.LogicalTypeAnd: - op = "&&" - case C.LogicalTypeOr: - op = "||" - } - if !r.invert { - return strings.Join(F.MapToString(r.rules), " "+op+" ") - } else { - return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" - } +func (r *LogicalDNSRule) RewriteTTL() *uint32 { + return r.rewriteTTL } diff --git a/route/rule_ip.go b/route/rule_ip.go new file mode 100644 index 00000000..65e77cd0 --- /dev/null +++ b/route/rule_ip.go @@ -0,0 +1,189 @@ +package route + +import ( + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +func NewIPRule(router adapter.Router, logger log.ContextLogger, options option.IPRule) (adapter.IPRule, error) { + switch options.Type { + case "", C.RuleTypeDefault: + if !options.DefaultOptions.IsValid() { + return nil, E.New("missing conditions") + } + if common.IsEmpty(options.DefaultOptions.Action) { + return nil, E.New("missing action") + } + return NewDefaultIPRule(router, logger, options.DefaultOptions) + case C.RuleTypeLogical: + if !options.LogicalOptions.IsValid() { + return nil, E.New("missing conditions") + } + if common.IsEmpty(options.DefaultOptions.Action) { + return nil, E.New("missing action") + } + return NewLogicalIPRule(router, logger, options.LogicalOptions) + default: + return nil, E.New("unknown rule type: ", options.Type) + } +} + +var _ adapter.IPRule = (*DefaultIPRule)(nil) + +type DefaultIPRule struct { + abstractDefaultRule + action tun.ActionType +} + +func NewDefaultIPRule(router adapter.Router, logger log.ContextLogger, options option.DefaultIPRule) (*DefaultIPRule, error) { + rule := &DefaultIPRule{ + abstractDefaultRule: abstractDefaultRule{ + invert: options.Invert, + outbound: options.Outbound, + }, + action: tun.ActionType(options.Action), + } + if len(options.Inbound) > 0 { + item := NewInboundRule(options.Inbound) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if options.IPVersion > 0 { + switch options.IPVersion { + case 4, 6: + item := NewIPVersionItem(options.IPVersion == 6) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + default: + return nil, E.New("invalid ip version: ", options.IPVersion) + } + } + if len(options.Network) > 0 { + item := NewNetworkItem(options.Network) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 { + item := NewDomainItem(options.Domain, options.DomainSuffix) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.DomainKeyword) > 0 { + item := NewDomainKeywordItem(options.DomainKeyword) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.DomainRegex) > 0 { + item, err := NewDomainRegexItem(options.DomainRegex) + if err != nil { + return nil, E.Cause(err, "domain_regex") + } + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.Geosite) > 0 { + item := NewGeositeItem(router, logger, options.Geosite) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourceGeoIP) > 0 { + item := NewGeoIPItem(router, logger, true, options.SourceGeoIP) + rule.sourceAddressItems = append(rule.sourceAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.GeoIP) > 0 { + item := NewGeoIPItem(router, logger, false, options.GeoIP) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourceIPCIDR) > 0 { + item, err := NewIPCIDRItem(true, options.SourceIPCIDR) + if err != nil { + return nil, E.Cause(err, "source_ipcidr") + } + rule.sourceAddressItems = append(rule.sourceAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.IPCIDR) > 0 { + item, err := NewIPCIDRItem(false, options.IPCIDR) + if err != nil { + return nil, E.Cause(err, "ipcidr") + } + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourcePort) > 0 { + item := NewPortItem(true, options.SourcePort) + rule.sourcePortItems = append(rule.sourcePortItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourcePortRange) > 0 { + item, err := NewPortRangeItem(true, options.SourcePortRange) + if err != nil { + return nil, E.Cause(err, "source_port_range") + } + rule.sourcePortItems = append(rule.sourcePortItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.Port) > 0 { + item := NewPortItem(false, options.Port) + rule.destinationPortItems = append(rule.destinationPortItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.PortRange) > 0 { + item, err := NewPortRangeItem(false, options.PortRange) + if err != nil { + return nil, E.Cause(err, "port_range") + } + rule.destinationPortItems = append(rule.destinationPortItems, item) + rule.allItems = append(rule.allItems, item) + } + return rule, nil +} + +func (r *DefaultIPRule) Action() tun.ActionType { + return r.action +} + +var _ adapter.IPRule = (*LogicalIPRule)(nil) + +type LogicalIPRule struct { + abstractLogicalRule + action tun.ActionType +} + +func NewLogicalIPRule(router adapter.Router, logger log.ContextLogger, options option.LogicalIPRule) (*LogicalIPRule, error) { + r := &LogicalIPRule{ + abstractLogicalRule: abstractLogicalRule{ + rules: make([]adapter.Rule, len(options.Rules)), + invert: options.Invert, + outbound: options.Outbound, + }, + action: tun.ActionType(options.Action), + } + switch options.Mode { + case C.LogicalTypeAnd: + r.mode = C.LogicalTypeAnd + case C.LogicalTypeOr: + r.mode = C.LogicalTypeOr + default: + return nil, E.New("unknown logical mode: ", options.Mode) + } + for i, subRule := range options.Rules { + rule, err := NewDefaultIPRule(router, logger, subRule) + if err != nil { + return nil, E.Cause(err, "sub rule[", i, "]") + } + r.rules[i] = rule + } + return r, nil +} + +func (r *LogicalIPRule) Action() tun.ActionType { + return r.action +} diff --git a/route/rule_auth_user.go b/route/rule_item_auth_user.go similarity index 100% rename from route/rule_auth_user.go rename to route/rule_item_auth_user.go diff --git a/route/rule_cidr.go b/route/rule_item_cidr.go similarity index 100% rename from route/rule_cidr.go rename to route/rule_item_cidr.go diff --git a/route/rule_clash_mode.go b/route/rule_item_clash_mode.go similarity index 100% rename from route/rule_clash_mode.go rename to route/rule_item_clash_mode.go diff --git a/route/rule_domain.go b/route/rule_item_domain.go similarity index 100% rename from route/rule_domain.go rename to route/rule_item_domain.go diff --git a/route/rule_domain_keyword.go b/route/rule_item_domain_keyword.go similarity index 100% rename from route/rule_domain_keyword.go rename to route/rule_item_domain_keyword.go diff --git a/route/rule_domain_regex.go b/route/rule_item_domain_regex.go similarity index 100% rename from route/rule_domain_regex.go rename to route/rule_item_domain_regex.go diff --git a/route/rule_geoip.go b/route/rule_item_geoip.go similarity index 100% rename from route/rule_geoip.go rename to route/rule_item_geoip.go diff --git a/route/rule_geosite.go b/route/rule_item_geosite.go similarity index 100% rename from route/rule_geosite.go rename to route/rule_item_geosite.go diff --git a/route/rule_inbound.go b/route/rule_item_inbound.go similarity index 100% rename from route/rule_inbound.go rename to route/rule_item_inbound.go diff --git a/route/rule_ipversion.go b/route/rule_item_ipversion.go similarity index 100% rename from route/rule_ipversion.go rename to route/rule_item_ipversion.go diff --git a/route/rule_item_network.go b/route/rule_item_network.go new file mode 100644 index 00000000..fc54f425 --- /dev/null +++ b/route/rule_item_network.go @@ -0,0 +1,42 @@ +package route + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" + F "github.com/sagernet/sing/common/format" +) + +var _ RuleItem = (*NetworkItem)(nil) + +type NetworkItem struct { + networks []string + networkMap map[string]bool +} + +func NewNetworkItem(networks []string) *NetworkItem { + networkMap := make(map[string]bool) + for _, network := range networks { + networkMap[network] = true + } + return &NetworkItem{ + networks: networks, + networkMap: networkMap, + } +} + +func (r *NetworkItem) Match(metadata *adapter.InboundContext) bool { + return r.networkMap[metadata.Network] +} + +func (r *NetworkItem) String() string { + description := "network=" + + pLen := len(r.networks) + if pLen == 1 { + description += F.ToString(r.networks[0]) + } else { + description += "[" + strings.Join(F.MapToString(r.networks), " ") + "]" + } + return description +} diff --git a/route/rule_outbound.go b/route/rule_item_outbound.go similarity index 100% rename from route/rule_outbound.go rename to route/rule_item_outbound.go diff --git a/route/rule_package_name.go b/route/rule_item_package_name.go similarity index 100% rename from route/rule_package_name.go rename to route/rule_item_package_name.go diff --git a/route/rule_port.go b/route/rule_item_port.go similarity index 100% rename from route/rule_port.go rename to route/rule_item_port.go diff --git a/route/rule_port_range.go b/route/rule_item_port_range.go similarity index 100% rename from route/rule_port_range.go rename to route/rule_item_port_range.go diff --git a/route/rule_process_name.go b/route/rule_item_process_name.go similarity index 100% rename from route/rule_process_name.go rename to route/rule_item_process_name.go diff --git a/route/rule_process_path.go b/route/rule_item_process_path.go similarity index 100% rename from route/rule_process_path.go rename to route/rule_item_process_path.go diff --git a/route/rule_protocol.go b/route/rule_item_protocol.go similarity index 100% rename from route/rule_protocol.go rename to route/rule_item_protocol.go diff --git a/route/rule_query_type.go b/route/rule_item_query_type.go similarity index 100% rename from route/rule_query_type.go rename to route/rule_item_query_type.go diff --git a/route/rule_user.go b/route/rule_item_user.go similarity index 100% rename from route/rule_user.go rename to route/rule_item_user.go diff --git a/route/rule_user_id.go b/route/rule_item_user_id.go similarity index 100% rename from route/rule_user_id.go rename to route/rule_item_user_id.go diff --git a/route/rule_network.go b/route/rule_network.go deleted file mode 100644 index 0346cb13..00000000 --- a/route/rule_network.go +++ /dev/null @@ -1,23 +0,0 @@ -package route - -import ( - "github.com/sagernet/sing-box/adapter" -) - -var _ RuleItem = (*NetworkItem)(nil) - -type NetworkItem struct { - network string -} - -func NewNetworkItem(network string) *NetworkItem { - return &NetworkItem{network} -} - -func (r *NetworkItem) Match(metadata *adapter.InboundContext) bool { - return r.network == metadata.Network -} - -func (r *NetworkItem) String() string { - return "network=" + r.network -} diff --git a/transport/wireguard/device.go b/transport/wireguard/device.go index 14e04bf5..9fb750b0 100644 --- a/transport/wireguard/device.go +++ b/transport/wireguard/device.go @@ -1,13 +1,23 @@ package wireguard import ( + "net/netip" + + "github.com/sagernet/sing-tun" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/wireguard-go/tun" + wgTun "github.com/sagernet/wireguard-go/tun" ) type Device interface { - tun.Device + wgTun.Device N.Dialer Start() error + Inet4Address() netip.Addr + Inet6Address() netip.Addr // NewEndpoint() (stack.LinkEndpoint, error) } + +type NatDevice interface { + Device + CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination +} diff --git a/transport/wireguard/device_nat.go b/transport/wireguard/device_nat.go new file mode 100644 index 00000000..72201bb5 --- /dev/null +++ b/transport/wireguard/device_nat.go @@ -0,0 +1,75 @@ +package wireguard + +import ( + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/buf" +) + +var _ Device = (*natDeviceWrapper)(nil) + +type natDeviceWrapper struct { + Device + outbound chan *buf.Buffer + mapping *tun.NatMapping + writer *tun.NatWriter +} + +func NewNATDevice(upstream Device, ipRewrite bool) NatDevice { + wrapper := &natDeviceWrapper{ + Device: upstream, + outbound: make(chan *buf.Buffer, 256), + mapping: tun.NewNatMapping(ipRewrite), + } + if ipRewrite { + wrapper.writer = tun.NewNatWriter(upstream.Inet4Address(), upstream.Inet6Address()) + } + return wrapper +} + +func (d *natDeviceWrapper) Read(p []byte, offset int) (int, error) { + select { + case packet := <-d.outbound: + defer packet.Release() + return copy(p[offset:], packet.Bytes()), nil + default: + } + return d.Device.Read(p, offset) +} + +func (d *natDeviceWrapper) Write(p []byte, offset int) (int, error) { + packet := p[offset:] + handled, err := d.mapping.WritePacket(packet) + if handled { + return len(packet), err + } + return d.Device.Write(p, offset) +} + +func (d *natDeviceWrapper) CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination { + d.mapping.CreateSession(session, conn) + return &natDestinationWrapper{d, session} +} + +var _ tun.DirectDestination = (*natDestinationWrapper)(nil) + +type natDestinationWrapper struct { + device *natDeviceWrapper + session tun.RouteSession +} + +func (d *natDestinationWrapper) WritePacket(buffer *buf.Buffer) error { + if d.device.writer != nil { + d.device.writer.RewritePacket(buffer.Bytes()) + } + d.device.outbound <- buffer + return nil +} + +func (d *natDestinationWrapper) Close() error { + d.device.mapping.DeleteSession(d.session) + return nil +} + +func (d *natDestinationWrapper) Timeout() bool { + return false +} diff --git a/transport/wireguard/device_nat_gvisor.go b/transport/wireguard/device_nat_gvisor.go new file mode 100644 index 00000000..6c55ec96 --- /dev/null +++ b/transport/wireguard/device_nat_gvisor.go @@ -0,0 +1,27 @@ +//go:build with_gvisor + +package wireguard + +import ( + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +func (d *natDestinationWrapper) WritePacketBuffer(buffer *stack.PacketBuffer) error { + defer buffer.DecRef() + if d.device.writer != nil { + d.device.writer.RewritePacketBuffer(buffer) + } + var packetLen int + for _, slice := range buffer.AsSlices() { + packetLen += len(slice) + } + packet := buf.NewSize(packetLen) + for _, slice := range buffer.AsSlices() { + common.Must1(packet.Write(slice)) + } + d.device.outbound <- packet + return nil +} diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index b2981e36..56f8f4a5 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -8,10 +8,12 @@ import ( "net/netip" "os" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/wireguard-go/tun" + wgTun "github.com/sagernet/wireguard-go/tun" "gvisor.dev/gvisor/pkg/bufferv2" "gvisor.dev/gvisor/pkg/tcpip" @@ -25,33 +27,38 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -var _ Device = (*StackDevice)(nil) +var _ NatDevice = (*StackDevice)(nil) const defaultNIC tcpip.NICID = 1 type StackDevice struct { - stack *stack.Stack - mtu uint32 - events chan tun.Event - outbound chan *stack.PacketBuffer - done chan struct{} - dispatcher stack.NetworkDispatcher - addr4 tcpip.Address - addr6 tcpip.Address + stack *stack.Stack + mtu uint32 + events chan wgTun.Event + outbound chan *stack.PacketBuffer + packetOutbound chan *buf.Buffer + done chan struct{} + dispatcher stack.NetworkDispatcher + addr4 tcpip.Address + addr6 tcpip.Address + mapping *tun.NatMapping + writer *tun.NatWriter } -func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) { +func NewStackDevice(localAddresses []netip.Prefix, mtu uint32, ipRewrite bool) (*StackDevice, error) { ipStack := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, HandleLocal: true, }) tunDevice := &StackDevice{ - stack: ipStack, - mtu: mtu, - events: make(chan tun.Event, 1), - outbound: make(chan *stack.PacketBuffer, 256), - done: make(chan struct{}), + stack: ipStack, + mtu: mtu, + events: make(chan wgTun.Event, 1), + outbound: make(chan *stack.PacketBuffer, 256), + packetOutbound: make(chan *buf.Buffer, 256), + done: make(chan struct{}), + mapping: tun.NewNatMapping(ipRewrite), } err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) if err != nil { @@ -77,6 +84,9 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String()) } } + if ipRewrite { + tunDevice.writer = tun.NewNatWriter(tunDevice.Inet4Address(), tunDevice.Inet6Address()) + } sOpt := tcpip.TCPSACKEnabled(true) ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) cOpt := tcpip.CongestionControlOption("cubic") @@ -144,8 +154,16 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) return udpConn, nil } +func (w *StackDevice) Inet4Address() netip.Addr { + return M.AddrFromIP(net.IP(w.addr4)) +} + +func (w *StackDevice) Inet6Address() netip.Addr { + return M.AddrFromIP(net.IP(w.addr6)) +} + func (w *StackDevice) Start() error { - w.events <- tun.EventUp + w.events <- wgTun.EventUp return nil } @@ -165,6 +183,10 @@ func (w *StackDevice) Read(p []byte, offset int) (n int, err error) { n += copy(p[n:], slice) } return + case packet := <-w.packetOutbound: + defer packet.Release() + n = copy(p[offset:], packet.Bytes()) + return case <-w.done: return 0, os.ErrClosed } @@ -175,6 +197,10 @@ func (w *StackDevice) Write(p []byte, offset int) (n int, err error) { if len(p) == 0 { return } + handled, err := w.mapping.WritePacket(p) + if handled { + return len(p), err + } var networkProtocol tcpip.NetworkProtocolNumber switch header.IPVersion(p) { case header.IPv4Version: @@ -203,7 +229,7 @@ func (w *StackDevice) Name() (string, error) { return "sing-box", nil } -func (w *StackDevice) Events() chan tun.Event { +func (w *StackDevice) Events() chan wgTun.Event { return w.events } @@ -222,6 +248,44 @@ func (w *StackDevice) Close() error { return nil } +func (w *StackDevice) CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination { + w.mapping.CreateSession(session, conn) + return &stackNatDestination{ + device: w, + session: session, + } +} + +type stackNatDestination struct { + device *StackDevice + session tun.RouteSession +} + +func (d *stackNatDestination) WritePacket(buffer *buf.Buffer) error { + if d.device.writer != nil { + d.device.writer.RewritePacket(buffer.Bytes()) + } + d.device.packetOutbound <- buffer + return nil +} + +func (d *stackNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error { + if d.device.writer != nil { + d.device.writer.RewritePacketBuffer(buffer) + } + d.device.outbound <- buffer + return nil +} + +func (d *stackNatDestination) Close() error { + d.device.mapping.DeleteSession(d.session) + return nil +} + +func (d *stackNatDestination) Timeout() bool { + return false +} + var _ stack.LinkEndpoint = (*wireEndpoint)(nil) type wireEndpoint StackDevice diff --git a/transport/wireguard/device_stack_stub.go b/transport/wireguard/device_stack_stub.go index b383ab38..5d2fc1dc 100644 --- a/transport/wireguard/device_stack_stub.go +++ b/transport/wireguard/device_stack_stub.go @@ -8,6 +8,6 @@ import ( "github.com/sagernet/sing-tun" ) -func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) { +func NewStackDevice(localAddresses []netip.Prefix, mtu uint32, ipRewrite bool) (Device, error) { return nil, tun.ErrGVisorNotIncluded } diff --git a/transport/wireguard/device_system.go b/transport/wireguard/device_system.go index d4316422..faca3023 100644 --- a/transport/wireguard/device_system.go +++ b/transport/wireguard/device_system.go @@ -23,6 +23,8 @@ type SystemDevice struct { name string mtu int events chan wgTun.Event + addr4 netip.Addr + addr6 netip.Addr } /*func (w *SystemDevice) NewEndpoint() (stack.LinkEndpoint, error) { @@ -55,11 +57,24 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes if err != nil { return nil, err } + var inet4Address netip.Addr + var inet6Address netip.Addr + if len(inet4Addresses) > 0 { + inet4Address = inet4Addresses[0].Addr() + } + if len(inet6Addresses) > 0 { + inet6Address = inet6Addresses[0].Addr() + } return &SystemDevice{ - dialer.NewDefault(router, option.DialerOptions{ + dialer: dialer.NewDefault(router, option.DialerOptions{ BindInterface: interfaceName, }), - tunInterface, interfaceName, int(mtu), make(chan wgTun.Event), + device: tunInterface, + name: interfaceName, + mtu: int(mtu), + events: make(chan wgTun.Event), + addr4: inet4Address, + addr6: inet6Address, }, nil } @@ -71,6 +86,14 @@ func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr return w.dialer.ListenPacket(ctx, destination) } +func (w *SystemDevice) Inet4Address() netip.Addr { + return w.addr4 +} + +func (w *SystemDevice) Inet6Address() netip.Addr { + return w.addr6 +} + func (w *SystemDevice) Start() error { w.events <- wgTun.EventUp return nil @@ -80,12 +103,12 @@ func (w *SystemDevice) File() *os.File { return nil } -func (w *SystemDevice) Read(bytes []byte, index int) (int, error) { - return w.device.Read(bytes[index-tun.PacketOffset:]) +func (w *SystemDevice) Read(p []byte, offset int) (int, error) { + return w.device.Read(p[offset-tun.PacketOffset:]) } -func (w *SystemDevice) Write(bytes []byte, index int) (int, error) { - return w.device.Write(bytes[index:]) +func (w *SystemDevice) Write(p []byte, offset int) (int, error) { + return w.device.Write(p[offset:]) } func (w *SystemDevice) Flush() error {