From 99bbdb734703455d87319260d73df622db99ff19 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= <i@sekai.icu>
Date: Tue, 22 Oct 2024 21:28:22 +0800
Subject: [PATCH] Implement TCP and ICMP rejects

---
 adapter/router.go         |  1 +
 constant/rule.go          |  9 ++++--
 inbound/tun.go            | 12 ++++++--
 option/rule_action.go     | 24 +++++++++------
 route/route.go            | 62 +++++++++++++++++++++++++--------------
 route/rule/rule_action.go | 26 ++++++++++++++--
 6 files changed, 94 insertions(+), 40 deletions(-)

diff --git a/adapter/router.go b/adapter/router.go
index 134c9442..c9cd46e9 100644
--- a/adapter/router.go
+++ b/adapter/router.go
@@ -34,6 +34,7 @@ type Router interface {
 	FakeIPStore() FakeIPStore
 
 	ConnectionRouter
+	PreMatch(metadata InboundContext) error
 	ConnectionRouterEx
 
 	GeoIPReader() *geoip.Reader
diff --git a/constant/rule.go b/constant/rule.go
index ba74ec63..c7717376 100644
--- a/constant/rule.go
+++ b/constant/rule.go
@@ -34,7 +34,10 @@ const (
 )
 
 const (
-	RuleActionRejectMethodDefault         = "default"
-	RuleActionRejectMethodPortUnreachable = "port-unreachable"
-	RuleActionRejectMethodDrop            = "drop"
+	RuleActionRejectMethodDefault            = "default"
+	RuleActionRejectMethodReset              = "reset"
+	RuleActionRejectMethodNetworkUnreachable = "network-unreachable"
+	RuleActionRejectMethodHostUnreachable    = "host-unreachable"
+	RuleActionRejectMethodPortUnreachable    = "port-unreachable"
+	RuleActionRejectMethodDrop               = "drop"
 )
diff --git a/inbound/tun.go b/inbound/tun.go
index 0d856419..f04ae71b 100644
--- a/inbound/tun.go
+++ b/inbound/tun.go
@@ -404,9 +404,15 @@ func (t *TUN) Close() error {
 	)
 }
 
-func (t *TUN) PrepareConnection(source M.Socksaddr, destination M.Socksaddr) error {
-	// TODO: implement rejects
-	return nil
+func (t *TUN) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
+	return t.router.PreMatch(adapter.InboundContext{
+		Inbound:        t.tag,
+		InboundType:    C.TypeTun,
+		Network:        network,
+		Source:         source,
+		Destination:    destination,
+		InboundOptions: t.inboundOptions,
+	})
 }
 
 func (t *TUN) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
diff --git a/option/rule_action.go b/option/rule_action.go
index 4f0ec177..f446d81d 100644
--- a/option/rule_action.go
+++ b/option/rule_action.go
@@ -136,23 +136,29 @@ type DNSRouteActionOptions struct {
 	ClientSubnet *AddrPrefix `json:"client_subnet,omitempty"`
 }
 
-type RejectActionOptions struct {
-	Method RejectMethod `json:"method,omitempty"`
+type _RejectActionOptions struct {
+	Method string `json:"method,omitempty"`
 }
 
-type RejectMethod string
+type RejectActionOptions _RejectActionOptions
 
-func (m *RejectMethod) UnmarshalJSON(bytes []byte) error {
-	err := json.Unmarshal(bytes, (*string)(m))
+func (r *RejectActionOptions) UnmarshalJSON(bytes []byte) error {
+	err := json.Unmarshal(bytes, (*_RejectActionOptions)(r))
 	if err != nil {
 		return err
 	}
-	switch *m {
-	case C.RuleActionRejectMethodDefault, C.RuleActionRejectMethodPortUnreachable, C.RuleActionRejectMethodDrop:
-		return nil
+	switch r.Method {
+	case "", C.RuleActionRejectMethodDefault:
+		r.Method = C.RuleActionRejectMethodDefault
+	case C.RuleActionRejectMethodReset,
+		C.RuleActionRejectMethodNetworkUnreachable,
+		C.RuleActionRejectMethodHostUnreachable,
+		C.RuleActionRejectMethodPortUnreachable,
+		C.RuleActionRejectMethodDrop:
 	default:
-		return E.New("unknown reject method: " + *m)
+		return E.New("unknown reject method: " + r.Method)
 	}
+	return nil
 }
 
 type RouteActionSniff struct {
diff --git a/route/route.go b/route/route.go
index 86d4d95c..cecd0f2a 100644
--- a/route/route.go
+++ b/route/route.go
@@ -21,7 +21,6 @@ import (
 	"github.com/sagernet/sing-box/route/rule"
 	"github.com/sagernet/sing-dns"
 	"github.com/sagernet/sing-mux"
-	"github.com/sagernet/sing-tun"
 	"github.com/sagernet/sing-vmess"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
@@ -89,7 +88,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 	if deadline.NeedAdditionalReadDeadline(conn) {
 		conn = deadline.NewConn(conn)
 	}
-	selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, conn, nil, -1)
+	selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, false, conn, nil, -1)
 	if err != nil {
 		return err
 	}
@@ -108,16 +107,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 			selectReturn = true
 		case *rule.RuleActionReject:
 			buf.ReleaseMulti(buffers)
-			var rejectErr error
-			switch action.Method {
-			case C.RuleActionRejectMethodDefault:
-				rejectErr = os.ErrClosed
-			case C.RuleActionRejectMethodPortUnreachable:
-				rejectErr = syscall.ECONNREFUSED
-			case C.RuleActionRejectMethodDrop:
-				rejectErr = tun.ErrDrop
-			}
-			N.CloseOnHandshakeFailure(conn, onClose, rejectErr)
+			N.CloseOnHandshakeFailure(conn, onClose, action.Error())
 			return nil
 		}
 	}
@@ -236,7 +226,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
 		conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn))
 	}*/
 
-	selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, nil, conn, -1)
+	selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1)
 	if err != nil {
 		return err
 	}
@@ -306,8 +296,23 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
 	return nil
 }
 
+func (r *Router) PreMatch(metadata adapter.InboundContext) error {
+	selectedRule, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1)
+	if err != nil {
+		return err
+	}
+	if selectedRule == nil {
+		return nil
+	}
+	rejectAction, isReject := selectedRule.Action().(*rule.RuleActionReject)
+	if !isReject {
+		return nil
+	}
+	return rejectAction.Error()
+}
+
 func (r *Router) matchRule(
-	ctx context.Context, metadata *adapter.InboundContext,
+	ctx context.Context, metadata *adapter.InboundContext, preMatch bool,
 	inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int,
 ) (selectedRule adapter.Rule, selectedRuleIndex int, buffers []*buf.Buffer, fatalErr error) {
 	if r.processSearcher != nil && metadata.ProcessInfo == nil {
@@ -370,7 +375,7 @@ func (r *Router) matchRule(
 
 	//nolint:staticcheck
 	if metadata.InboundOptions != common.DefaultValue[option.InboundOptions]() {
-		if metadata.InboundOptions.SniffEnabled {
+		if !preMatch && metadata.InboundOptions.SniffEnabled {
 			newBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{
 				OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
 				Timeout:             time.Duration(metadata.InboundOptions.SniffTimeout),
@@ -415,15 +420,28 @@ match:
 		if !matched {
 			break
 		}
-		r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
+		if !preMatch {
+			r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
+		} else {
+			switch currentRule.Action().Type() {
+			case C.RuleActionTypeReject, C.RuleActionTypeResolve:
+				r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
+			}
+		}
 		switch action := currentRule.Action().(type) {
 		case *rule.RuleActionSniff:
-			newBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
-			if newErr != nil {
-				fatalErr = newErr
-				return
+			if !preMatch {
+				newBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
+				if newErr != nil {
+					fatalErr = newErr
+					return
+				}
+				buffers = append(buffers, newBuffers...)
+			} else {
+				selectedRule = currentRule
+				selectedRuleIndex = currentRuleIndex
+				break match
 			}
-			buffers = append(buffers, newBuffers...)
 		case *rule.RuleActionResolve:
 			fatalErr = r.actionResolve(ctx, metadata, action)
 			if fatalErr != nil {
@@ -436,7 +454,7 @@ match:
 		}
 		ruleIndex = currentRuleIndex
 	}
-	if metadata.Destination.Addr.IsUnspecified() {
+	if !preMatch && metadata.Destination.Addr.IsUnspecified() {
 		newBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn)
 		if newErr != nil {
 			fatalErr = newErr
diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go
index e85fc763..a157e94e 100644
--- a/route/rule/rule_action.go
+++ b/route/rule/rule_action.go
@@ -2,7 +2,9 @@ package rule
 
 import (
 	"net/netip"
+	"os"
 	"strings"
+	"syscall"
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
@@ -10,6 +12,7 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-dns"
+	"github.com/sagernet/sing-tun"
 	E "github.com/sagernet/sing/common/exceptions"
 	F "github.com/sagernet/sing/common/format"
 )
@@ -22,10 +25,10 @@ func NewRuleAction(action option.RuleAction) (adapter.RuleAction, error) {
 			UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping,
 		}, nil
 	case C.RuleActionTypeReturn:
-		return &RuleActionReject{}, nil
+		return &RuleActionReturn{}, nil
 	case C.RuleActionTypeReject:
 		return &RuleActionReject{
-			Method: string(action.RejectOptions.Method),
+			Method: action.RejectOptions.Method,
 		}, nil
 	case C.RuleActionTypeHijackDNS:
 		return &RuleActionHijackDNS{}, nil
@@ -58,7 +61,7 @@ func NewDNSRuleAction(action option.DNSRuleAction) adapter.RuleAction {
 		return &RuleActionReturn{}
 	case C.RuleActionTypeReject:
 		return &RuleActionReject{
-			Method: string(action.RejectOptions.Method),
+			Method: action.RejectOptions.Method,
 		}
 	default:
 		panic(F.ToString("unknown rule action: ", action.Action))
@@ -118,6 +121,23 @@ func (r *RuleActionReject) String() string {
 	return F.ToString("reject(", r.Method, ")")
 }
 
+func (r *RuleActionReject) Error() error {
+	switch r.Method {
+	case C.RuleActionRejectMethodReset:
+		return os.ErrClosed
+	case C.RuleActionRejectMethodNetworkUnreachable:
+		return syscall.ENETUNREACH
+	case C.RuleActionRejectMethodHostUnreachable:
+		return syscall.EHOSTUNREACH
+	case C.RuleActionRejectMethodDefault, C.RuleActionRejectMethodPortUnreachable:
+		return syscall.ECONNREFUSED
+	case C.RuleActionRejectMethodDrop:
+		return tun.ErrDrop
+	default:
+		panic(F.ToString("unknown reject method: ", r.Method))
+	}
+}
+
 type RuleActionHijackDNS struct{}
 
 func (r *RuleActionHijackDNS) Type() string {