Improve rule actions

This commit is contained in:
世界 2024-11-06 17:30:40 +08:00
parent 2864dd01f2
commit c9b24396d3
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
10 changed files with 414 additions and 138 deletions

View file

@ -57,7 +57,9 @@ type InboundContext struct {
// Deprecated // Deprecated
InboundOptions option.InboundOptions InboundOptions option.InboundOptions
UDPDisableDomainUnmapping bool UDPDisableDomainUnmapping bool
DNSServer string UDPConnect bool
DNSServer string
DestinationAddresses []netip.Addr DestinationAddresses []netip.Addr
SourceGeoIPCode string SourceGeoIPCode string

View file

@ -25,12 +25,13 @@ const (
) )
const ( const (
RuleActionTypeRoute = "route" RuleActionTypeRoute = "route"
RuleActionTypeReturn = "return" RuleActionTypeRouteOptions = "route-options"
RuleActionTypeReject = "reject" RuleActionTypeDirect = "direct"
RuleActionTypeHijackDNS = "hijack-dns" RuleActionTypeReject = "reject"
RuleActionTypeSniff = "sniff" RuleActionTypeHijackDNS = "hijack-dns"
RuleActionTypeResolve = "resolve" RuleActionTypeSniff = "sniff"
RuleActionTypeResolve = "resolve"
) )
const ( const (

View file

@ -109,7 +109,7 @@ type DefaultRule struct {
RuleAction RuleAction
} }
func (r *DefaultRule) MarshalJSON() ([]byte, error) { func (r DefaultRule) MarshalJSON() ([]byte, error) {
return badjson.MarshallObjects(r.RawDefaultRule, r.RuleAction) return badjson.MarshallObjects(r.RawDefaultRule, r.RuleAction)
} }
@ -128,27 +128,27 @@ func (r *DefaultRule) IsValid() bool {
return !reflect.DeepEqual(r, defaultValue) return !reflect.DeepEqual(r, defaultValue)
} }
type _LogicalRule struct { type RawLogicalRule struct {
Mode string `json:"mode"` Mode string `json:"mode"`
Rules []Rule `json:"rules,omitempty"` Rules []Rule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"` Invert bool `json:"invert,omitempty"`
} }
type LogicalRule struct { type LogicalRule struct {
_LogicalRule RawLogicalRule
RuleAction RuleAction
} }
func (r *LogicalRule) MarshalJSON() ([]byte, error) { func (r LogicalRule) MarshalJSON() ([]byte, error) {
return badjson.MarshallObjects(r._LogicalRule, r.RuleAction) return badjson.MarshallObjects(r.RawLogicalRule, r.RuleAction)
} }
func (r *LogicalRule) UnmarshalJSON(data []byte) error { func (r *LogicalRule) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, &r._LogicalRule) err := json.Unmarshal(data, &r.RawLogicalRule)
if err != nil { if err != nil {
return err return err
} }
return badjson.UnmarshallExcluded(data, &r._LogicalRule, &r.RuleAction) return badjson.UnmarshallExcluded(data, &r.RawLogicalRule, &r.RuleAction)
} }
func (r *LogicalRule) IsValid() bool { func (r *LogicalRule) IsValid() bool {

View file

@ -1,30 +1,41 @@
package option package option
import ( import (
"fmt"
"time"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
dns "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson" "github.com/sagernet/sing/common/json/badjson"
) )
type _RuleAction struct { type _RuleAction struct {
Action string `json:"action,omitempty"` Action string `json:"action,omitempty"`
RouteOptions RouteActionOptions `json:"-"` RouteOptions RouteActionOptions `json:"-"`
RejectOptions RejectActionOptions `json:"-"` RouteOptionsOptions RouteOptionsActionOptions `json:"-"`
SniffOptions RouteActionSniff `json:"-"` DirectOptions DirectActionOptions `json:"-"`
ResolveOptions RouteActionResolve `json:"-"` RejectOptions RejectActionOptions `json:"-"`
SniffOptions RouteActionSniff `json:"-"`
ResolveOptions RouteActionResolve `json:"-"`
} }
type RuleAction _RuleAction type RuleAction _RuleAction
func (r RuleAction) MarshalJSON() ([]byte, error) { func (r RuleAction) MarshalJSON() ([]byte, error) {
if r.Action == "" {
return json.Marshal(struct{}{})
}
var v any var v any
switch r.Action { switch r.Action {
case C.RuleActionTypeRoute: case C.RuleActionTypeRoute:
r.Action = "" r.Action = ""
v = r.RouteOptions v = r.RouteOptions
case C.RuleActionTypeReturn: case C.RuleActionTypeRouteOptions:
v = nil v = r.RouteOptionsOptions
case C.RuleActionTypeDirect:
v = r.DirectOptions
case C.RuleActionTypeReject: case C.RuleActionTypeReject:
v = r.RejectOptions v = r.RejectOptions
case C.RuleActionTypeHijackDNS: case C.RuleActionTypeHijackDNS:
@ -52,8 +63,10 @@ func (r *RuleAction) UnmarshalJSON(data []byte) error {
case "", C.RuleActionTypeRoute: case "", C.RuleActionTypeRoute:
r.Action = C.RuleActionTypeRoute r.Action = C.RuleActionTypeRoute
v = &r.RouteOptions v = &r.RouteOptions
case C.RuleActionTypeReturn: case C.RuleActionTypeRouteOptions:
v = nil v = &r.RouteOptionsOptions
case C.RuleActionTypeDirect:
v = &r.DirectOptions
case C.RuleActionTypeReject: case C.RuleActionTypeReject:
v = &r.RejectOptions v = &r.RejectOptions
case C.RuleActionTypeHijackDNS: case C.RuleActionTypeHijackDNS:
@ -73,29 +86,30 @@ func (r *RuleAction) UnmarshalJSON(data []byte) error {
} }
type _DNSRuleAction struct { type _DNSRuleAction struct {
Action string `json:"action,omitempty"` Action string `json:"action,omitempty"`
RouteOptions DNSRouteActionOptions `json:"-"` RouteOptions DNSRouteActionOptions `json:"-"`
RejectOptions RejectActionOptions `json:"-"` RouteOptionsOptions DNSRouteOptionsActionOptions `json:"-"`
RejectOptions RejectActionOptions `json:"-"`
} }
type DNSRuleAction _DNSRuleAction type DNSRuleAction _DNSRuleAction
func (r DNSRuleAction) MarshalJSON() ([]byte, error) { func (r DNSRuleAction) MarshalJSON() ([]byte, error) {
if r.Action == "" {
return json.Marshal(struct{}{})
}
var v any var v any
switch r.Action { switch r.Action {
case C.RuleActionTypeRoute: case C.RuleActionTypeRoute:
r.Action = "" r.Action = ""
v = r.RouteOptions v = r.RouteOptions
case C.RuleActionTypeReturn: case C.RuleActionTypeRouteOptions:
v = nil v = r.RouteOptionsOptions
case C.RuleActionTypeReject: case C.RuleActionTypeReject:
v = r.RejectOptions v = r.RejectOptions
default: default:
return nil, E.New("unknown DNS rule action: " + r.Action) return nil, E.New("unknown DNS rule action: " + r.Action)
} }
if v == nil {
return badjson.MarshallObjects((_DNSRuleAction)(r))
}
return badjson.MarshallObjects((_DNSRuleAction)(r), v) return badjson.MarshallObjects((_DNSRuleAction)(r), v)
} }
@ -109,8 +123,8 @@ func (r *DNSRuleAction) UnmarshalJSON(data []byte) error {
case "", C.RuleActionTypeRoute: case "", C.RuleActionTypeRoute:
r.Action = C.RuleActionTypeRoute r.Action = C.RuleActionTypeRoute
v = &r.RouteOptions v = &r.RouteOptions
case C.RuleActionTypeReturn: case C.RuleActionTypeRouteOptions:
v = nil v = &r.RouteOptionsOptions
case C.RuleActionTypeReject: case C.RuleActionTypeReject:
v = &r.RejectOptions v = &r.RejectOptions
default: default:
@ -123,18 +137,136 @@ func (r *DNSRuleAction) UnmarshalJSON(data []byte) error {
return badjson.UnmarshallExcluded(data, (*_DNSRuleAction)(r), v) return badjson.UnmarshallExcluded(data, (*_DNSRuleAction)(r), v)
} }
type RouteActionOptions struct { type _RouteActionOptions struct {
Outbound string `json:"outbound"` Outbound string `json:"outbound,omitempty"`
UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"`
} }
type DNSRouteActionOptions struct { type RouteActionOptions _RouteActionOptions
Server string `json:"server"`
func (r *RouteActionOptions) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, (*_RouteActionOptions)(r))
if err != nil {
return err
}
if r.Outbound == "" {
return E.New("missing outbound")
}
return nil
}
type _RouteOptionsActionOptions struct {
UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"`
UDPConnect bool `json:"udp_connect,omitempty"`
}
type RouteOptionsActionOptions _RouteOptionsActionOptions
func (r *RouteOptionsActionOptions) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, (*_RouteOptionsActionOptions)(r))
if err != nil {
return err
}
if *r == (RouteOptionsActionOptions{}) {
return E.New("empty route option action")
}
return nil
}
type _DNSRouteActionOptions struct {
Server string `json:"server,omitempty"`
// Deprecated: Use DNSRouteOptionsActionOptions instead.
DisableCache bool `json:"disable_cache,omitempty"`
// Deprecated: Use DNSRouteOptionsActionOptions instead.
RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"`
// Deprecated: Use DNSRouteOptionsActionOptions instead.
ClientSubnet *AddrPrefix `json:"client_subnet,omitempty"`
}
type DNSRouteActionOptions _DNSRouteActionOptions
func (r *DNSRouteActionOptions) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, (*_DNSRouteActionOptions)(r))
if err != nil {
return err
}
if r.Server == "" {
return E.New("missing server")
}
return nil
}
type _DNSRouteOptionsActionOptions struct {
DisableCache bool `json:"disable_cache,omitempty"` DisableCache bool `json:"disable_cache,omitempty"`
RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"`
ClientSubnet *AddrPrefix `json:"client_subnet,omitempty"` ClientSubnet *AddrPrefix `json:"client_subnet,omitempty"`
} }
type DNSRouteOptionsActionOptions _DNSRouteOptionsActionOptions
func (r *DNSRouteOptionsActionOptions) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, (*_DNSRouteOptionsActionOptions)(r))
if err != nil {
return err
}
if *r == (DNSRouteOptionsActionOptions{}) {
return E.New("empty DNS route option action")
}
return nil
}
type _DirectActionOptions DialerOptions
type DirectActionOptions _DirectActionOptions
func (d DirectActionOptions) Descriptions() []string {
var descriptions []string
if d.BindInterface != "" {
descriptions = append(descriptions, "bind_interface="+d.BindInterface)
}
if d.Inet4BindAddress != nil {
descriptions = append(descriptions, "inet4_bind_address="+d.Inet4BindAddress.Build().String())
}
if d.Inet6BindAddress != nil {
descriptions = append(descriptions, "inet6_bind_address="+d.Inet6BindAddress.Build().String())
}
if d.RoutingMark != 0 {
descriptions = append(descriptions, "routing_mark="+fmt.Sprintf("0x%x", d.RoutingMark))
}
if d.ReuseAddr {
descriptions = append(descriptions, "reuse_addr")
}
if d.ConnectTimeout != 0 {
descriptions = append(descriptions, "connect_timeout="+time.Duration(d.ConnectTimeout).String())
}
if d.TCPFastOpen {
descriptions = append(descriptions, "tcp_fast_open")
}
if d.TCPMultiPath {
descriptions = append(descriptions, "tcp_multi_path")
}
if d.UDPFragment != nil {
descriptions = append(descriptions, "udp_fragment="+fmt.Sprint(*d.UDPFragment))
}
if d.DomainStrategy != DomainStrategy(dns.DomainStrategyAsIS) {
descriptions = append(descriptions, "domain_strategy="+d.DomainStrategy.String())
}
if d.FallbackDelay != 0 {
descriptions = append(descriptions, "fallback_delay="+time.Duration(d.FallbackDelay).String())
}
return descriptions
}
func (d *DirectActionOptions) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, (*_DirectActionOptions)(d))
if err != nil {
return err
}
if d.Detour != "" {
return E.New("detour is not available in the current context")
}
return nil
}
type _RejectActionOptions struct { type _RejectActionOptions struct {
Method string `json:"method,omitempty"` Method string `json:"method,omitempty"`
NoDrop bool `json:"no_drop,omitempty"` NoDrop bool `json:"no_drop,omitempty"`
@ -155,7 +287,7 @@ func (r *RejectActionOptions) UnmarshalJSON(bytes []byte) error {
return E.New("unknown reject method: " + r.Method) return E.New("unknown reject method: " + r.Method)
} }
if r.Method == C.RuleActionRejectMethodDrop && r.NoDrop { if r.Method == C.RuleActionRejectMethodDrop && r.NoDrop {
return E.New("no_drop is not allowed when method is drop") return E.New("no_drop is not available in current context")
} }
return nil return nil
} }

View file

@ -111,7 +111,7 @@ type DefaultDNSRule struct {
DNSRuleAction DNSRuleAction
} }
func (r *DefaultDNSRule) MarshalJSON() ([]byte, error) { func (r DefaultDNSRule) MarshalJSON() ([]byte, error) {
return badjson.MarshallObjects(r.RawDefaultDNSRule, r.DNSRuleAction) return badjson.MarshallObjects(r.RawDefaultDNSRule, r.DNSRuleAction)
} }
@ -123,34 +123,34 @@ func (r *DefaultDNSRule) UnmarshalJSON(data []byte) error {
return badjson.UnmarshallExcluded(data, &r.RawDefaultDNSRule, &r.DNSRuleAction) return badjson.UnmarshallExcluded(data, &r.RawDefaultDNSRule, &r.DNSRuleAction)
} }
func (r *DefaultDNSRule) IsValid() bool { func (r DefaultDNSRule) IsValid() bool {
var defaultValue DefaultDNSRule var defaultValue DefaultDNSRule
defaultValue.Invert = r.Invert defaultValue.Invert = r.Invert
defaultValue.DNSRuleAction = r.DNSRuleAction defaultValue.DNSRuleAction = r.DNSRuleAction
return !reflect.DeepEqual(r, defaultValue) return !reflect.DeepEqual(r, defaultValue)
} }
type _LogicalDNSRule struct { type RawLogicalDNSRule struct {
Mode string `json:"mode"` Mode string `json:"mode"`
Rules []DNSRule `json:"rules,omitempty"` Rules []DNSRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"` Invert bool `json:"invert,omitempty"`
} }
type LogicalDNSRule struct { type LogicalDNSRule struct {
_LogicalDNSRule RawLogicalDNSRule
DNSRuleAction DNSRuleAction
} }
func (r *LogicalDNSRule) MarshalJSON() ([]byte, error) { func (r LogicalDNSRule) MarshalJSON() ([]byte, error) {
return badjson.MarshallObjects(r._LogicalDNSRule, r.DNSRuleAction) return badjson.MarshallObjects(r.RawLogicalDNSRule, r.DNSRuleAction)
} }
func (r *LogicalDNSRule) UnmarshalJSON(data []byte) error { func (r *LogicalDNSRule) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, &r._LogicalDNSRule) err := json.Unmarshal(data, &r.RawLogicalDNSRule)
if err != nil { if err != nil {
return err return err
} }
return badjson.UnmarshallExcluded(data, &r._LogicalDNSRule, &r.DNSRuleAction) return badjson.UnmarshallExcluded(data, &r.RawLogicalDNSRule, &r.DNSRuleAction)
} }
func (r *LogicalDNSRule) IsValid() bool { func (r *LogicalDNSRule) IsValid() bool {

View file

@ -43,6 +43,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net
go func() error { go func() error {
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
if err != nil { if err != nil {
conn.Close()
return err return err
} }
responseBuffer := buf.NewPacket() responseBuffer := buf.NewPacket()

View file

@ -87,23 +87,34 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if deadline.NeedAdditionalReadDeadline(conn) { if deadline.NeedAdditionalReadDeadline(conn) {
conn = deadline.NewConn(conn) conn = deadline.NewConn(conn)
} }
selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil, -1) selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil)
if err != nil { if err != nil {
return err return err
} }
var selectedOutbound adapter.Outbound var (
var selectReturn bool // selectedOutbound adapter.Outbound
selectedDialer N.Dialer
selectedTag string
selectedDescription string
)
if selectedRule != nil { if selectedRule != nil {
switch action := selectedRule.Action().(type) { switch action := selectedRule.Action().(type) {
case *rule.RuleActionRoute: case *rule.RuleActionRoute:
var loaded bool selectedOutbound, loaded := r.Outbound(action.Outbound)
selectedOutbound, loaded = r.Outbound(action.Outbound)
if !loaded { if !loaded {
buf.ReleaseMulti(buffers) buf.ReleaseMulti(buffers)
return E.New("outbound not found: ", action.Outbound) return E.New("outbound not found: ", action.Outbound)
} }
case *rule.RuleActionReturn: if !common.Contains(selectedOutbound.Network(), N.NetworkTCP) {
selectReturn = true buf.ReleaseMulti(buffers)
return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag())
}
selectedDialer = selectedOutbound
selectedTag = selectedOutbound.Tag()
selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
case *rule.RuleActionDirect:
selectedDialer = action.Dialer
selectedDescription = action.String()
case *rule.RuleActionReject: case *rule.RuleActionReject:
buf.ReleaseMulti(buffers) buf.ReleaseMulti(buffers)
N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx)) N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx))
@ -116,17 +127,16 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
return nil return nil
} }
} }
if selectedRule == nil || selectReturn { if selectedRule == nil {
if r.defaultOutboundForConnection == nil { if r.defaultOutboundForConnection == nil {
buf.ReleaseMulti(buffers) buf.ReleaseMulti(buffers)
return E.New("missing default outbound with TCP support") return E.New("missing default outbound with TCP support")
} }
selectedOutbound = r.defaultOutboundForConnection selectedDialer = r.defaultOutboundForConnection
} selectedTag = r.defaultOutboundForConnection.Tag()
if !common.Contains(selectedOutbound.Network(), N.NetworkTCP) { selectedDescription = F.ToString("outbound/", r.defaultOutboundForConnection.Type(), "[", r.defaultOutboundForConnection.Tag(), "]")
buf.ReleaseMulti(buffers)
return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag())
} }
for _, buffer := range buffers { for _, buffer := range buffers {
conn = bufio.NewCachedConn(conn, buffer) conn = bufio.NewCachedConn(conn, buffer)
} }
@ -137,10 +147,10 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
} }
if r.v2rayServer != nil { if r.v2rayServer != nil {
if statsService := r.v2rayServer.StatsService(); statsService != nil { if statsService := r.v2rayServer.StatsService(); statsService != nil {
conn = statsService.RoutedConnection(metadata.Inbound, selectedOutbound.Tag(), metadata.User, conn) conn = statsService.RoutedConnection(metadata.Inbound, selectedTag, metadata.User, conn)
} }
} }
legacyOutbound, isLegacy := selectedOutbound.(adapter.ConnectionHandler) legacyOutbound, isLegacy := selectedDialer.(adapter.ConnectionHandler)
if isLegacy { if isLegacy {
err = legacyOutbound.NewConnection(ctx, conn, metadata) err = legacyOutbound.NewConnection(ctx, conn, metadata)
if err != nil { if err != nil {
@ -148,7 +158,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if onClose != nil { if onClose != nil {
onClose(err) onClose(err)
} }
return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") return E.Cause(err, selectedDescription)
} else { } else {
if onClose != nil { if onClose != nil {
onClose(nil) onClose(nil)
@ -157,13 +167,13 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
return nil return nil
} }
// TODO // TODO
err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata) err = outbound.NewConnection(ctx, selectedDialer, conn, metadata)
if err != nil { if err != nil {
conn.Close() conn.Close()
if onClose != nil { if onClose != nil {
onClose(err) onClose(err)
} }
return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") return E.Cause(err, selectedDescription)
} else { } else {
if onClose != nil { if onClose != nil {
onClose(nil) onClose(nil)
@ -231,24 +241,34 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn)) conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn))
}*/ }*/
selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1) selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn)
if err != nil { if err != nil {
return err return err
} }
var selectedOutbound adapter.Outbound var (
selectedDialer N.Dialer
selectedTag string
selectedDescription string
)
var selectReturn bool var selectReturn bool
if selectedRule != nil { if selectedRule != nil {
switch action := selectedRule.Action().(type) { switch action := selectedRule.Action().(type) {
case *rule.RuleActionRoute: case *rule.RuleActionRoute:
var loaded bool selectedOutbound, loaded := r.Outbound(action.Outbound)
selectedOutbound, loaded = r.Outbound(action.Outbound)
if !loaded { if !loaded {
N.ReleaseMultiPacketBuffer(packetBuffers) N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("outbound not found: ", action.Outbound) return E.New("outbound not found: ", action.Outbound)
} }
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) {
case *rule.RuleActionReturn: N.ReleaseMultiPacketBuffer(packetBuffers)
selectReturn = true return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
}
selectedDialer = selectedOutbound
selectedTag = selectedOutbound.Tag()
selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
case *rule.RuleActionDirect:
selectedDialer = action.Dialer
selectedDescription = action.String()
case *rule.RuleActionReject: case *rule.RuleActionReject:
N.ReleaseMultiPacketBuffer(packetBuffers) N.ReleaseMultiPacketBuffer(packetBuffers)
N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx)) N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx))
@ -263,11 +283,9 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
N.ReleaseMultiPacketBuffer(packetBuffers) N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("missing default outbound with UDP support") return E.New("missing default outbound with UDP support")
} }
selectedOutbound = r.defaultOutboundForPacketConnection selectedDialer = r.defaultOutboundForPacketConnection
} selectedTag = r.defaultOutboundForPacketConnection.Tag()
if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) { selectedDescription = F.ToString("outbound/", r.defaultOutboundForPacketConnection.Type(), "[", r.defaultOutboundForPacketConnection.Tag(), "]")
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
} }
for _, buffer := range packetBuffers { for _, buffer := range packetBuffers {
conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination) conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination)
@ -280,32 +298,32 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
} }
if r.v2rayServer != nil { if r.v2rayServer != nil {
if statsService := r.v2rayServer.StatsService(); statsService != nil { if statsService := r.v2rayServer.StatsService(); statsService != nil {
conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedOutbound.Tag(), metadata.User, conn) conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedTag, metadata.User, conn)
} }
} }
if metadata.FakeIP { if metadata.FakeIP {
conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination) conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
} }
legacyOutbound, isLegacy := selectedOutbound.(adapter.PacketConnectionHandler) legacyOutbound, isLegacy := selectedDialer.(adapter.PacketConnectionHandler)
if isLegacy { if isLegacy {
err = legacyOutbound.NewPacketConnection(ctx, conn, metadata) err = legacyOutbound.NewPacketConnection(ctx, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") return E.Cause(err, selectedDescription)
} }
return nil return nil
} }
// TODO // TODO
err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata) err = outbound.NewPacketConnection(ctx, selectedDialer, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil { if err != nil {
return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") return E.Cause(err, selectedDescription)
} }
return nil return nil
} }
func (r *Router) PreMatch(metadata adapter.InboundContext) error { func (r *Router) PreMatch(metadata adapter.InboundContext) error {
selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1) selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil)
if err != nil { if err != nil {
return err return err
} }
@ -321,7 +339,7 @@ func (r *Router) PreMatch(metadata adapter.InboundContext) error {
func (r *Router) matchRule( func (r *Router) matchRule(
ctx context.Context, metadata *adapter.InboundContext, preMatch bool, ctx context.Context, metadata *adapter.InboundContext, preMatch bool,
inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int, inputConn net.Conn, inputPacketConn N.PacketConn,
) ( ) (
selectedRule adapter.Rule, selectedRuleIndex int, selectedRule adapter.Rule, selectedRuleIndex int,
buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error, buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error,
@ -416,24 +434,10 @@ func (r *Router) matchRule(
} }
match: match:
for ruleIndex < len(r.rules) { for currentRuleIndex, currentRule := range r.rules {
rules := r.rules metadata.ResetRuleCache()
if ruleIndex != -1 { if !currentRule.Match(metadata) {
rules = rules[ruleIndex+1:] continue
}
var (
currentRule adapter.Rule
currentRuleIndex int
matched bool
)
for currentRuleIndex, currentRule = range rules {
if currentRule.Match(metadata) {
matched = true
break
}
}
if !matched {
break
} }
if !preMatch { if !preMatch {
ruleDescription := currentRule.String() ruleDescription := currentRule.String()
@ -444,7 +448,7 @@ match:
} }
} else { } else {
switch currentRule.Action().Type() { switch currentRule.Action().Type() {
case C.RuleActionTypeReject, C.RuleActionTypeResolve: case C.RuleActionTypeReject:
ruleDescription := currentRule.String() ruleDescription := currentRule.String()
if ruleDescription != "" { if ruleDescription != "" {
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action()) r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
@ -454,6 +458,12 @@ match:
} }
} }
switch action := currentRule.Action().(type) { switch action := currentRule.Action().(type) {
case *rule.RuleActionRoute:
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
metadata.UDPConnect = action.UDPConnect
case *rule.RuleActionRouteOptions:
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
metadata.UDPConnect = action.UDPConnect
case *rule.RuleActionSniff: case *rule.RuleActionSniff:
if !preMatch { if !preMatch {
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn) newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
@ -476,12 +486,16 @@ match:
if fatalErr != nil { if fatalErr != nil {
return return
} }
default: }
actionType := currentRule.Action().Type()
if actionType == C.RuleActionTypeRoute ||
actionType == C.RuleActionTypeReject ||
actionType == C.RuleActionTypeHijackDNS ||
(actionType == C.RuleActionTypeSniff && preMatch) {
selectedRule = currentRule selectedRule = currentRule
selectedRuleIndex = currentRuleIndex selectedRuleIndex = currentRuleIndex
break match break match
} }
ruleIndex = currentRuleIndex
} }
if !preMatch && metadata.Destination.Addr.IsUnspecified() { if !preMatch && metadata.Destination.Addr.IsUnspecified() {
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn) newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn)

View file

@ -8,8 +8,10 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
R "github.com/sagernet/sing-box/route/rule" R "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
tun "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/cache" "github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
@ -48,38 +50,63 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int,
if ruleIndex != -1 { if ruleIndex != -1 {
dnsRules = dnsRules[ruleIndex+1:] dnsRules = dnsRules[ruleIndex+1:]
} }
for currentRuleIndex, rule := range dnsRules { for currentRuleIndex, currentRule := range dnsRules {
if rule.WithAddressLimit() && !isAddressQuery { if currentRule.WithAddressLimit() && !isAddressQuery {
continue continue
} }
metadata.ResetRuleCache() metadata.ResetRuleCache()
if rule.Match(metadata) { if currentRule.Match(metadata) {
displayRuleIndex := currentRuleIndex displayRuleIndex := currentRuleIndex
if displayRuleIndex != -1 { if displayRuleIndex != -1 {
displayRuleIndex += displayRuleIndex + 1 displayRuleIndex += displayRuleIndex + 1
} }
if routeAction, isRoute := rule.Action().(*R.RuleActionDNSRoute); isRoute { ruleDescription := currentRule.String()
transport, loaded := r.transportMap[routeAction.Server] if ruleDescription != "" {
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] ", currentRule, " => ", currentRule.Action())
} else {
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
}
switch action := currentRule.Action().(type) {
case *R.RuleActionDNSRoute:
transport, loaded := r.transportMap[action.Server]
if !loaded { if !loaded {
r.dnsLogger.ErrorContext(ctx, "transport not found: ", routeAction.Server) r.dnsLogger.ErrorContext(ctx, "transport not found: ", action.Server)
continue continue
} }
_, isFakeIP := transport.(adapter.FakeIPTransport) _, isFakeIP := transport.(adapter.FakeIPTransport)
if isFakeIP && !allowFakeIP { if isFakeIP && !allowFakeIP {
continue continue
} }
options.DisableCache = isFakeIP || routeAction.DisableCache if isFakeIP || action.DisableCache {
options.RewriteTTL = routeAction.RewriteTTL options.DisableCache = true
options.ClientSubnet = routeAction.ClientSubnet }
if action.RewriteTTL != nil {
options.RewriteTTL = action.RewriteTTL
}
if action.ClientSubnet.IsValid() {
options.ClientSubnet = action.ClientSubnet
}
if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded { if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded {
options.Strategy = domainStrategy options.Strategy = domainStrategy
} else { } else {
options.Strategy = r.defaultDomainStrategy options.Strategy = r.defaultDomainStrategy
} }
r.dnsLogger.DebugContext(ctx, "match[", displayRuleIndex, "] ", rule.String(), " => ", rule.Action()) r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
return transport, options, rule, currentRuleIndex return transport, options, currentRule, currentRuleIndex
} else { case *R.RuleActionDNSRouteOptions:
return nil, options, rule, currentRuleIndex if action.DisableCache {
options.DisableCache = true
}
if action.RewriteTTL != nil {
options.RewriteTTL = action.RewriteTTL
}
if action.ClientSubnet.IsValid() {
options.ClientSubnet = action.ClientSubnet
}
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
case *R.RuleActionReject:
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
return nil, options, currentRule, currentRuleIndex
} }
} }
} }
@ -127,6 +154,17 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er
dnsCtx := adapter.OverrideContext(ctx) dnsCtx := adapter.OverrideContext(ctx)
var addressLimit bool var addressLimit bool
transport, options, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message)) transport, options, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message))
if rule != nil {
switch action := rule.Action().(type) {
case *R.RuleActionReject:
switch action.Method {
case C.RuleActionRejectMethodDefault:
return dns.FixedResponse(message.Id, message.Question[0], nil, 0), nil
case C.RuleActionRejectMethodDrop:
return nil, tun.ErrDrop
}
}
}
if rule != nil && rule.WithAddressLimit() { if rule != nil && rule.WithAddressLimit() {
addressLimit = true addressLimit = true
response, err = r.dnsClient.ExchangeWithResponseCheck(dnsCtx, transport, message, options, func(response *mDNS.Msg) bool { response, err = r.dnsClient.ExchangeWithResponseCheck(dnsCtx, transport, message, options, func(response *mDNS.Msg) bool {
@ -238,6 +276,17 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS
if strategy != dns.DomainStrategyAsIS { if strategy != dns.DomainStrategyAsIS {
options.Strategy = strategy options.Strategy = strategy
} }
if rule != nil {
switch action := rule.Action().(type) {
case *R.RuleActionReject:
switch action.Method {
case C.RuleActionRejectMethodDefault:
return nil, nil
case C.RuleActionRejectMethodDrop:
return nil, tun.ErrDrop
}
}
}
if rule != nil && rule.WithAddressLimit() { if rule != nil && rule.WithAddressLimit() {
addressLimit = true addressLimit = true
responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool { responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool {

View file

@ -5,9 +5,11 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"sync" "sync"
"syscall"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/sniff" "github.com/sagernet/sing-box/common/sniff"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
@ -17,19 +19,42 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
) )
func NewRuleAction(logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) { func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) {
switch action.Action { switch action.Action {
case "":
return nil, nil
case C.RuleActionTypeRoute: case C.RuleActionTypeRoute:
return &RuleActionRoute{ return &RuleActionRoute{
Outbound: action.RouteOptions.Outbound, Outbound: action.RouteOptions.Outbound,
UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping, }, nil
case C.RuleActionTypeRouteOptions:
return &RuleActionRouteOptions{
UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping,
UDPConnect: action.RouteOptionsOptions.UDPConnect,
}, nil
case C.RuleActionTypeDirect:
directDialer, err := dialer.New(router, option.DialerOptions(action.DirectOptions))
if err != nil {
return nil, err
}
var description string
descriptions := action.DirectOptions.Descriptions()
switch len(descriptions) {
case 0:
case 1:
description = F.ToString("(", descriptions[0], ")")
case 2:
description = F.ToString("(", descriptions[0], ",", descriptions[1], ")")
default:
description = F.ToString("(", descriptions[0], ",", descriptions[1], ",...)")
}
return &RuleActionDirect{
Dialer: directDialer,
description: description,
}, nil }, nil
case C.RuleActionTypeReturn:
return &RuleActionReturn{}, nil
case C.RuleActionTypeReject: case C.RuleActionTypeReject:
return &RuleActionReject{ return &RuleActionReject{
Method: action.RejectOptions.Method, Method: action.RejectOptions.Method,
@ -56,6 +81,8 @@ func NewRuleAction(logger logger.ContextLogger, action option.RuleAction) (adapt
func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction) adapter.RuleAction { func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction) adapter.RuleAction {
switch action.Action { switch action.Action {
case "":
return nil
case C.RuleActionTypeRoute: case C.RuleActionTypeRoute:
return &RuleActionDNSRoute{ return &RuleActionDNSRoute{
Server: action.RouteOptions.Server, Server: action.RouteOptions.Server,
@ -63,8 +90,12 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction)
RewriteTTL: action.RouteOptions.RewriteTTL, RewriteTTL: action.RouteOptions.RewriteTTL,
ClientSubnet: action.RouteOptions.ClientSubnet.Build(), ClientSubnet: action.RouteOptions.ClientSubnet.Build(),
} }
case C.RuleActionTypeReturn: case C.RuleActionTypeRouteOptions:
return &RuleActionReturn{} return &RuleActionDNSRouteOptions{
DisableCache: action.RouteOptionsOptions.DisableCache,
RewriteTTL: action.RouteOptionsOptions.RewriteTTL,
ClientSubnet: action.RouteOptionsOptions.ClientSubnet.Build(),
}
case C.RuleActionTypeReject: case C.RuleActionTypeReject:
return &RuleActionReject{ return &RuleActionReject{
Method: action.RejectOptions.Method, Method: action.RejectOptions.Method,
@ -77,8 +108,7 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction)
} }
type RuleActionRoute struct { type RuleActionRoute struct {
Outbound string Outbound string
UDPDisableDomainUnmapping bool
} }
func (r *RuleActionRoute) Type() string { func (r *RuleActionRoute) Type() string {
@ -89,6 +119,26 @@ func (r *RuleActionRoute) String() string {
return F.ToString("route(", r.Outbound, ")") return F.ToString("route(", r.Outbound, ")")
} }
type RuleActionRouteOptions struct {
UDPDisableDomainUnmapping bool
UDPConnect bool
}
func (r *RuleActionRouteOptions) Type() string {
return C.RuleActionTypeRouteOptions
}
func (r *RuleActionRouteOptions) String() string {
var descriptions []string
if r.UDPDisableDomainUnmapping {
descriptions = append(descriptions, "udp-disable-domain-unmapping")
}
if r.UDPConnect {
descriptions = append(descriptions, "udp-connect")
}
return F.ToString("route-options(", strings.Join(descriptions, ","), ")")
}
type RuleActionDNSRoute struct { type RuleActionDNSRoute struct {
Server string Server string
DisableCache bool DisableCache bool
@ -104,14 +154,41 @@ func (r *RuleActionDNSRoute) String() string {
return F.ToString("route(", r.Server, ")") return F.ToString("route(", r.Server, ")")
} }
type RuleActionReturn struct{} type RuleActionDNSRouteOptions struct {
DisableCache bool
func (r *RuleActionReturn) Type() string { RewriteTTL *uint32
return C.RuleActionTypeReturn ClientSubnet netip.Prefix
} }
func (r *RuleActionReturn) String() string { func (r *RuleActionDNSRouteOptions) Type() string {
return "return" return C.RuleActionTypeRouteOptions
}
func (r *RuleActionDNSRouteOptions) String() string {
var descriptions []string
if r.DisableCache {
descriptions = append(descriptions, "disable-cache")
}
if r.RewriteTTL != nil {
descriptions = append(descriptions, F.ToString("rewrite-ttl(", *r.RewriteTTL, ")"))
}
if r.ClientSubnet.IsValid() {
descriptions = append(descriptions, F.ToString("client-subnet(", r.ClientSubnet, ")"))
}
return F.ToString("route-options(", strings.Join(descriptions, ","), ")")
}
type RuleActionDirect struct {
Dialer N.Dialer
description string
}
func (r *RuleActionDirect) Type() string {
return C.RuleActionTypeDirect
}
func (r *RuleActionDirect) String() string {
return "direct" + r.description
} }
type RuleActionReject struct { type RuleActionReject struct {
@ -137,7 +214,7 @@ func (r *RuleActionReject) Error(ctx context.Context) error {
var returnErr error var returnErr error
switch r.Method { switch r.Method {
case C.RuleActionRejectMethodDefault: case C.RuleActionRejectMethodDefault:
returnErr = unix.ECONNREFUSED returnErr = syscall.ECONNREFUSED
case C.RuleActionRejectMethodDrop: case C.RuleActionRejectMethodDrop:
return tun.ErrDrop return tun.ErrDrop
default: default:

View file

@ -52,7 +52,7 @@ type RuleItem interface {
} }
func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) {
action, err := NewRuleAction(logger, options.RuleAction) action, err := NewRuleAction(router, logger, options.RuleAction)
if err != nil { if err != nil {
return nil, E.Cause(err, "action") return nil, E.Cause(err, "action")
} }
@ -254,7 +254,7 @@ type LogicalRule struct {
} }
func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
action, err := NewRuleAction(logger, options.RuleAction) action, err := NewRuleAction(router, logger, options.RuleAction)
if err != nil { if err != nil {
return nil, E.Cause(err, "action") return nil, E.Cause(err, "action")
} }