mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-12-04 20:36:37 +00:00
Implement dns-hijack
This commit is contained in:
parent
642b71c237
commit
19d5bc4921
|
@ -1,6 +1,9 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box"
|
"github.com/sagernet/sing-box"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
@ -23,7 +26,9 @@ func init() {
|
||||||
func createPreStartedClient() (*box.Box, error) {
|
func createPreStartedClient() (*box.Box, error) {
|
||||||
options, err := readConfigAndMerge()
|
options, err := readConfigAndMerge()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
if !(errors.Is(err, os.ErrNotExist) && len(configDirectories) == 0 && len(configPaths) == 1) || configPaths[0] != "config.json" {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
instance, err := box.New(box.Options{Options: options})
|
instance, err := box.New(box.Options{Options: options})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -36,8 +36,9 @@ type TUN struct {
|
||||||
router adapter.Router
|
router adapter.Router
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
// Deprecated
|
// Deprecated
|
||||||
inboundOptions option.InboundOptions
|
inboundOptions option.InboundOptions
|
||||||
tunOptions tun.Options
|
tunOptions tun.Options
|
||||||
|
// Deprecated
|
||||||
endpointIndependentNat bool
|
endpointIndependentNat bool
|
||||||
udpTimeout time.Duration
|
udpTimeout time.Duration
|
||||||
stack string
|
stack string
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
@ -50,14 +51,15 @@ func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter
|
||||||
metadata.Destination = M.Socksaddr{}
|
metadata.Destination = M.Socksaddr{}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
for {
|
for {
|
||||||
err := d.handleConnection(ctx, conn, metadata)
|
conn.SetReadDeadline(time.Now().Add(C.DNSTimeout))
|
||||||
|
err := HandleStreamDNSRequest(ctx, d.router, conn, metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net.Conn, metadata adapter.InboundContext) error {
|
||||||
var queryLength uint16
|
var queryLength uint16
|
||||||
err := binary.Read(conn, binary.BigEndian, &queryLength)
|
err := binary.Read(conn, binary.BigEndian, &queryLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -79,7 +81,7 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap
|
||||||
}
|
}
|
||||||
metadataInQuery := metadata
|
metadataInQuery := metadata
|
||||||
go func() error {
|
go func() error {
|
||||||
response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
|
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -100,10 +102,14 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap
|
||||||
|
|
||||||
// Deprecated
|
// Deprecated
|
||||||
func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||||
|
return NewDNSPacketConnection(ctx, d.router, conn, nil, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error {
|
||||||
metadata.Destination = M.Socksaddr{}
|
metadata.Destination = M.Socksaddr{}
|
||||||
var reader N.PacketReader = conn
|
var reader N.PacketReader = conn
|
||||||
var counters []N.CountFunc
|
var counters []N.CountFunc
|
||||||
var cachedPackets []*N.PacketBuffer
|
cachedPackets = common.Reverse(cachedPackets)
|
||||||
for {
|
for {
|
||||||
reader, counters = N.UnwrapCountPacketReader(reader, counters)
|
reader, counters = N.UnwrapCountPacketReader(reader, counters)
|
||||||
if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
|
if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
|
||||||
|
@ -115,7 +121,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
|
||||||
}
|
}
|
||||||
if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
|
if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
|
||||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
|
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
|
||||||
return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata)
|
return newDNSPacketConnection(ctx, router, conn, readWaiter, counters, cachedPackets, metadata)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -161,7 +167,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
|
||||||
}
|
}
|
||||||
metadataInQuery := metadata
|
metadataInQuery := metadata
|
||||||
go func() error {
|
go func() error {
|
||||||
response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
|
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cancel(err)
|
cancel(err)
|
||||||
return err
|
return err
|
||||||
|
@ -186,7 +192,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
|
||||||
return group.Run(fastClose)
|
return group.Run(fastClose)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
|
func newDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
|
||||||
fastClose, cancel := common.ContextWithCancelCause(ctx)
|
fastClose, cancel := common.ContextWithCancelCause(ctx)
|
||||||
timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
|
timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
|
||||||
var group task.Group
|
var group task.Group
|
||||||
|
@ -206,11 +212,12 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa
|
||||||
}
|
}
|
||||||
err = message.Unpack(packet.Buffer.Bytes())
|
err = message.Unpack(packet.Buffer.Bytes())
|
||||||
packet.Buffer.Release()
|
packet.Buffer.Release()
|
||||||
|
destination = packet.Destination
|
||||||
|
N.PutPacketBuffer(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cancel(err)
|
cancel(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
destination = packet.Destination
|
|
||||||
} else {
|
} else {
|
||||||
buffer, destination, err = readWaiter.WaitReadPacket()
|
buffer, destination, err = readWaiter.WaitReadPacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -230,7 +237,7 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa
|
||||||
}
|
}
|
||||||
metadataInQuery := metadata
|
metadataInQuery := metadata
|
||||||
go func() error {
|
go func() error {
|
||||||
response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
|
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cancel(err)
|
cancel(err)
|
||||||
return err
|
return err
|
||||||
|
|
93
route/dns.go
Normal file
93
route/dns.go
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
package route
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/outbound"
|
||||||
|
"github.com/sagernet/sing-dns"
|
||||||
|
"github.com/sagernet/sing-tun"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/common/udpnat2"
|
||||||
|
|
||||||
|
mDNS "github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||||
|
metadata.Destination = M.Socksaddr{}
|
||||||
|
for {
|
||||||
|
conn.SetReadDeadline(time.Now().Add(C.DNSTimeout))
|
||||||
|
err := outbound.HandleStreamDNSRequest(ctx, r, conn, metadata)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) {
|
||||||
|
if uConn, isUDPNAT2 := conn.(*udpnat.Conn); isUDPNAT2 {
|
||||||
|
metadata.Destination = M.Socksaddr{}
|
||||||
|
for _, packet := range packetBuffers {
|
||||||
|
buffer := packet.Buffer
|
||||||
|
destination := packet.Destination
|
||||||
|
N.PutPacketBuffer(packet)
|
||||||
|
go ExchangeDNSPacket(ctx, r, uConn, buffer, metadata, destination)
|
||||||
|
}
|
||||||
|
uConn.SetHandler(&dnsHijacker{
|
||||||
|
router: r,
|
||||||
|
conn: conn,
|
||||||
|
ctx: ctx,
|
||||||
|
metadata: metadata,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := outbound.NewDNSPacketConnection(ctx, r, conn, packetBuffers, metadata)
|
||||||
|
if err != nil && !E.IsClosedOrCanceled(err) {
|
||||||
|
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) {
|
||||||
|
err := exchangeDNSPacket(ctx, router, conn, buffer, metadata, destination)
|
||||||
|
if err != nil && !errors.Is(err, tun.ErrDrop) && !E.IsClosedOrCanceled(err) {
|
||||||
|
router.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func exchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) error {
|
||||||
|
var message mDNS.Msg
|
||||||
|
err := message.Unpack(buffer.Bytes())
|
||||||
|
buffer.Release()
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "unpack request")
|
||||||
|
}
|
||||||
|
response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = conn.WritePacket(responseBuffer, destination)
|
||||||
|
responseBuffer.Release()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type dnsHijacker struct {
|
||||||
|
router *Router
|
||||||
|
conn N.PacketConn
|
||||||
|
ctx context.Context
|
||||||
|
metadata adapter.InboundContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *dnsHijacker) NewPacketEx(buffer *buf.Buffer, destination M.Socksaddr) {
|
||||||
|
go ExchangeDNSPacket(h.ctx, h.router, h.conn, buffer, h.metadata, destination)
|
||||||
|
}
|
139
route/route.go
139
route/route.go
|
@ -88,7 +88,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
|
||||||
if deadline.NeedAdditionalReadDeadline(conn) {
|
if deadline.NeedAdditionalReadDeadline(conn) {
|
||||||
conn = deadline.NewConn(conn)
|
conn = deadline.NewConn(conn)
|
||||||
}
|
}
|
||||||
selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, false, conn, nil, -1)
|
selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -109,6 +109,12 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
|
||||||
buf.ReleaseMulti(buffers)
|
buf.ReleaseMulti(buffers)
|
||||||
N.CloseOnHandshakeFailure(conn, onClose, action.Error())
|
N.CloseOnHandshakeFailure(conn, onClose, action.Error())
|
||||||
return nil
|
return nil
|
||||||
|
case *rule.RuleActionHijackDNS:
|
||||||
|
for _, buffer := range buffers {
|
||||||
|
conn = bufio.NewCachedConn(conn, buffer)
|
||||||
|
}
|
||||||
|
r.hijackDNSStream(ctx, conn, metadata)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if selectedRule == nil || selectReturn {
|
if selectedRule == nil || selectReturn {
|
||||||
|
@ -226,7 +232,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
|
||||||
conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn))
|
conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn))
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1)
|
selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -238,32 +244,35 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
|
||||||
var loaded bool
|
var loaded bool
|
||||||
selectedOutbound, loaded = r.Outbound(action.Outbound)
|
selectedOutbound, loaded = r.Outbound(action.Outbound)
|
||||||
if !loaded {
|
if !loaded {
|
||||||
buf.ReleaseMulti(buffers)
|
N.ReleaseMultiPacketBuffer(packetBuffers)
|
||||||
return E.New("outbound not found: ", action.Outbound)
|
return E.New("outbound not found: ", action.Outbound)
|
||||||
}
|
}
|
||||||
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
|
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
|
||||||
case *rule.RuleActionReturn:
|
case *rule.RuleActionReturn:
|
||||||
selectReturn = true
|
selectReturn = true
|
||||||
case *rule.RuleActionReject:
|
case *rule.RuleActionReject:
|
||||||
buf.ReleaseMulti(buffers)
|
N.ReleaseMultiPacketBuffer(packetBuffers)
|
||||||
N.CloseOnHandshakeFailure(conn, onClose, syscall.ECONNREFUSED)
|
N.CloseOnHandshakeFailure(conn, onClose, syscall.ECONNREFUSED)
|
||||||
return nil
|
return nil
|
||||||
|
case *rule.RuleActionHijackDNS:
|
||||||
|
r.hijackDNSPacket(ctx, conn, packetBuffers, metadata)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if selectedRule == nil || selectReturn {
|
if selectedRule == nil || selectReturn {
|
||||||
if r.defaultOutboundForPacketConnection == nil {
|
if r.defaultOutboundForPacketConnection == nil {
|
||||||
buf.ReleaseMulti(buffers)
|
N.ReleaseMultiPacketBuffer(packetBuffers)
|
||||||
return E.New("missing default outbound with UDP support")
|
return E.New("missing default outbound with UDP support")
|
||||||
}
|
}
|
||||||
selectedOutbound = r.defaultOutboundForPacketConnection
|
selectedOutbound = r.defaultOutboundForPacketConnection
|
||||||
}
|
}
|
||||||
if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) {
|
if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) {
|
||||||
buf.ReleaseMulti(buffers)
|
N.ReleaseMultiPacketBuffer(packetBuffers)
|
||||||
return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
|
return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
|
||||||
}
|
}
|
||||||
for _, buffer := range buffers {
|
for _, buffer := range packetBuffers {
|
||||||
// TODO: check if metadata.Destination == packet destination
|
conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination)
|
||||||
conn = bufio.NewCachedPacketConn(conn, buffer, metadata.Destination)
|
N.PutPacketBuffer(buffer)
|
||||||
}
|
}
|
||||||
if r.clashServer != nil {
|
if r.clashServer != nil {
|
||||||
trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, selectedRule)
|
trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, selectedRule)
|
||||||
|
@ -297,7 +306,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) PreMatch(metadata adapter.InboundContext) error {
|
func (r *Router) PreMatch(metadata adapter.InboundContext) error {
|
||||||
selectedRule, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1)
|
selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -314,7 +323,10 @@ func (r *Router) PreMatch(metadata adapter.InboundContext) error {
|
||||||
func (r *Router) matchRule(
|
func (r *Router) matchRule(
|
||||||
ctx context.Context, metadata *adapter.InboundContext, preMatch bool,
|
ctx context.Context, metadata *adapter.InboundContext, preMatch bool,
|
||||||
inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int,
|
inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int,
|
||||||
) (selectedRule adapter.Rule, selectedRuleIndex int, buffers []*buf.Buffer, fatalErr error) {
|
) (
|
||||||
|
selectedRule adapter.Rule, selectedRuleIndex int,
|
||||||
|
buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error,
|
||||||
|
) {
|
||||||
if r.processSearcher != nil && metadata.ProcessInfo == nil {
|
if r.processSearcher != nil && metadata.ProcessInfo == nil {
|
||||||
var originDestination netip.AddrPort
|
var originDestination netip.AddrPort
|
||||||
if metadata.OriginDestination.IsValid() {
|
if metadata.OriginDestination.IsValid() {
|
||||||
|
@ -376,7 +388,7 @@ func (r *Router) matchRule(
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
if metadata.InboundOptions != common.DefaultValue[option.InboundOptions]() {
|
if metadata.InboundOptions != common.DefaultValue[option.InboundOptions]() {
|
||||||
if !preMatch && metadata.InboundOptions.SniffEnabled {
|
if !preMatch && metadata.InboundOptions.SniffEnabled {
|
||||||
newBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{
|
newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{
|
||||||
OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
|
OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
|
||||||
Timeout: time.Duration(metadata.InboundOptions.SniffTimeout),
|
Timeout: time.Duration(metadata.InboundOptions.SniffTimeout),
|
||||||
}, inputConn, inputPacketConn)
|
}, inputConn, inputPacketConn)
|
||||||
|
@ -384,7 +396,11 @@ func (r *Router) matchRule(
|
||||||
fatalErr = newErr
|
fatalErr = newErr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
buffers = append(buffers, newBuffers...)
|
if newBuffer != nil {
|
||||||
|
buffers = []*buf.Buffer{newBuffer}
|
||||||
|
} else if len(newPackerBuffers) > 0 {
|
||||||
|
packetBuffers = newPackerBuffers
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) != dns.DomainStrategyAsIS {
|
if dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) != dns.DomainStrategyAsIS {
|
||||||
fatalErr = r.actionResolve(ctx, metadata, &rule.RuleActionResolve{
|
fatalErr = r.actionResolve(ctx, metadata, &rule.RuleActionResolve{
|
||||||
|
@ -421,22 +437,36 @@ match:
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if !preMatch {
|
if !preMatch {
|
||||||
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
|
ruleDescription := currentRule.String()
|
||||||
|
if ruleDescription != "" {
|
||||||
|
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
|
||||||
|
} else {
|
||||||
|
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action())
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
switch currentRule.Action().Type() {
|
switch currentRule.Action().Type() {
|
||||||
case C.RuleActionTypeReject, C.RuleActionTypeResolve:
|
case C.RuleActionTypeReject, C.RuleActionTypeResolve:
|
||||||
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
|
ruleDescription := currentRule.String()
|
||||||
|
if ruleDescription != "" {
|
||||||
|
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
|
||||||
|
} else {
|
||||||
|
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] => ", currentRule.Action())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch action := currentRule.Action().(type) {
|
switch action := currentRule.Action().(type) {
|
||||||
case *rule.RuleActionSniff:
|
case *rule.RuleActionSniff:
|
||||||
if !preMatch {
|
if !preMatch {
|
||||||
newBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
|
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
|
||||||
if newErr != nil {
|
if newErr != nil {
|
||||||
fatalErr = newErr
|
fatalErr = newErr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
buffers = append(buffers, newBuffers...)
|
if newBuffer != nil {
|
||||||
|
buffers = append(buffers, newBuffer)
|
||||||
|
} else if len(newPacketBuffers) > 0 {
|
||||||
|
packetBuffers = append(packetBuffers, newPacketBuffers...)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
selectedRule = currentRule
|
selectedRule = currentRule
|
||||||
selectedRuleIndex = currentRuleIndex
|
selectedRuleIndex = currentRuleIndex
|
||||||
|
@ -455,12 +485,16 @@ match:
|
||||||
ruleIndex = currentRuleIndex
|
ruleIndex = currentRuleIndex
|
||||||
}
|
}
|
||||||
if !preMatch && metadata.Destination.Addr.IsUnspecified() {
|
if !preMatch && metadata.Destination.Addr.IsUnspecified() {
|
||||||
newBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn)
|
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn)
|
||||||
if newErr != nil {
|
if newErr != nil {
|
||||||
fatalErr = newErr
|
fatalErr = newErr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
buffers = append(buffers, newBuffers...)
|
if newBuffer != nil {
|
||||||
|
buffers = append(buffers, newBuffer)
|
||||||
|
} else if len(newPacketBuffers) > 0 {
|
||||||
|
packetBuffers = append(packetBuffers, newPacketBuffers...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -468,18 +502,31 @@ match:
|
||||||
func (r *Router) actionSniff(
|
func (r *Router) actionSniff(
|
||||||
ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff,
|
ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff,
|
||||||
inputConn net.Conn, inputPacketConn N.PacketConn,
|
inputConn net.Conn, inputPacketConn N.PacketConn,
|
||||||
) (buffers []*buf.Buffer, fatalErr error) {
|
) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) {
|
||||||
if sniff.Skip(metadata) {
|
if sniff.Skip(metadata) {
|
||||||
return
|
return
|
||||||
} else if inputConn != nil && len(action.StreamSniffers) > 0 {
|
} else if inputConn != nil {
|
||||||
buffer := buf.NewPacket()
|
sniffBuffer := buf.NewPacket()
|
||||||
|
var streamSniffers []sniff.StreamSniffer
|
||||||
|
if len(action.StreamSniffers) > 0 {
|
||||||
|
streamSniffers = action.StreamSniffers
|
||||||
|
} else {
|
||||||
|
streamSniffers = []sniff.StreamSniffer{
|
||||||
|
sniff.TLSClientHello,
|
||||||
|
sniff.HTTPHost,
|
||||||
|
sniff.StreamDomainNameQuery,
|
||||||
|
sniff.BitTorrent,
|
||||||
|
sniff.SSH,
|
||||||
|
sniff.RDP,
|
||||||
|
}
|
||||||
|
}
|
||||||
err := sniff.PeekStream(
|
err := sniff.PeekStream(
|
||||||
ctx,
|
ctx,
|
||||||
metadata,
|
metadata,
|
||||||
inputConn,
|
inputConn,
|
||||||
buffer,
|
sniffBuffer,
|
||||||
action.Timeout,
|
action.Timeout,
|
||||||
action.StreamSniffers...,
|
streamSniffers...,
|
||||||
)
|
)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
//goland:noinspection GoDeprecation
|
//goland:noinspection GoDeprecation
|
||||||
|
@ -497,15 +544,15 @@ func (r *Router) actionSniff(
|
||||||
r.logger.DebugContext(ctx, "sniffed protocol: ", metadata.Protocol)
|
r.logger.DebugContext(ctx, "sniffed protocol: ", metadata.Protocol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !buffer.IsEmpty() {
|
if !sniffBuffer.IsEmpty() {
|
||||||
buffers = append(buffers, buffer)
|
buffer = sniffBuffer
|
||||||
} else {
|
} else {
|
||||||
buffer.Release()
|
sniffBuffer.Release()
|
||||||
}
|
}
|
||||||
} else if inputPacketConn != nil && len(action.PacketSniffers) > 0 {
|
} else if inputPacketConn != nil {
|
||||||
for {
|
for {
|
||||||
var (
|
var (
|
||||||
buffer = buf.NewPacket()
|
sniffBuffer = buf.NewPacket()
|
||||||
destination M.Socksaddr
|
destination M.Socksaddr
|
||||||
done = make(chan struct{})
|
done = make(chan struct{})
|
||||||
err error
|
err error
|
||||||
|
@ -516,7 +563,7 @@ func (r *Router) actionSniff(
|
||||||
sniffTimeout = action.Timeout
|
sniffTimeout = action.Timeout
|
||||||
}
|
}
|
||||||
inputPacketConn.SetReadDeadline(time.Now().Add(sniffTimeout))
|
inputPacketConn.SetReadDeadline(time.Now().Add(sniffTimeout))
|
||||||
destination, err = inputPacketConn.ReadPacket(buffer)
|
destination, err = inputPacketConn.ReadPacket(sniffBuffer)
|
||||||
inputPacketConn.SetReadDeadline(time.Time{})
|
inputPacketConn.SetReadDeadline(time.Time{})
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
@ -528,7 +575,7 @@ func (r *Router) actionSniff(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
sniffBuffer.Release()
|
||||||
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
fatalErr = err
|
fatalErr = err
|
||||||
return
|
return
|
||||||
|
@ -538,22 +585,40 @@ func (r *Router) actionSniff(
|
||||||
if metadata.Destination.Addr.IsUnspecified() {
|
if metadata.Destination.Addr.IsUnspecified() {
|
||||||
metadata.Destination = destination
|
metadata.Destination = destination
|
||||||
}
|
}
|
||||||
if len(buffers) > 0 {
|
if len(packetBuffers) > 0 {
|
||||||
err = sniff.PeekPacket(
|
err = sniff.PeekPacket(
|
||||||
ctx,
|
ctx,
|
||||||
metadata,
|
metadata,
|
||||||
buffer.Bytes(),
|
sniffBuffer.Bytes(),
|
||||||
sniff.QUICClientHello,
|
sniff.QUICClientHello,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
var packetSniffers []sniff.PacketSniffer
|
||||||
|
if len(action.PacketSniffers) > 0 {
|
||||||
|
packetSniffers = action.PacketSniffers
|
||||||
|
} else {
|
||||||
|
packetSniffers = []sniff.PacketSniffer{
|
||||||
|
sniff.DomainNameQuery,
|
||||||
|
sniff.QUICClientHello,
|
||||||
|
sniff.STUNMessage,
|
||||||
|
sniff.UTP,
|
||||||
|
sniff.UDPTracker,
|
||||||
|
sniff.DTLSRecord,
|
||||||
|
}
|
||||||
|
}
|
||||||
err = sniff.PeekPacket(
|
err = sniff.PeekPacket(
|
||||||
ctx, metadata,
|
ctx, metadata,
|
||||||
buffer.Bytes(),
|
sniffBuffer.Bytes(),
|
||||||
action.PacketSniffers...,
|
packetSniffers...,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
buffers = append(buffers, buffer)
|
packetBuffer := N.NewPacketBuffer()
|
||||||
if E.IsMulti(err, sniff.ErrClientHelloFragmented) && len(buffers) == 0 {
|
*packetBuffer = N.PacketBuffer{
|
||||||
|
Buffer: sniffBuffer,
|
||||||
|
Destination: destination,
|
||||||
|
}
|
||||||
|
packetBuffers = append(packetBuffers, packetBuffer)
|
||||||
|
if E.IsMulti(err, sniff.ErrClientHelloFragmented) && len(packetBuffers) == 0 {
|
||||||
r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
|
r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -162,53 +162,31 @@ func (r *RuleActionSniff) Type() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RuleActionSniff) build() error {
|
func (r *RuleActionSniff) build() error {
|
||||||
if len(r.StreamSniffers) > 0 || len(r.PacketSniffers) > 0 {
|
for _, name := range r.snifferNames {
|
||||||
return nil
|
switch name {
|
||||||
}
|
case C.ProtocolTLS:
|
||||||
if len(r.snifferNames) > 0 {
|
r.StreamSniffers = append(r.StreamSniffers, sniff.TLSClientHello)
|
||||||
for _, name := range r.snifferNames {
|
case C.ProtocolHTTP:
|
||||||
switch name {
|
r.StreamSniffers = append(r.StreamSniffers, sniff.HTTPHost)
|
||||||
case C.ProtocolTLS:
|
case C.ProtocolQUIC:
|
||||||
r.StreamSniffers = append(r.StreamSniffers, sniff.TLSClientHello)
|
r.PacketSniffers = append(r.PacketSniffers, sniff.QUICClientHello)
|
||||||
case C.ProtocolHTTP:
|
case C.ProtocolDNS:
|
||||||
r.StreamSniffers = append(r.StreamSniffers, sniff.HTTPHost)
|
r.StreamSniffers = append(r.StreamSniffers, sniff.StreamDomainNameQuery)
|
||||||
case C.ProtocolQUIC:
|
r.PacketSniffers = append(r.PacketSniffers, sniff.DomainNameQuery)
|
||||||
r.PacketSniffers = append(r.PacketSniffers, sniff.QUICClientHello)
|
case C.ProtocolSTUN:
|
||||||
case C.ProtocolDNS:
|
r.PacketSniffers = append(r.PacketSniffers, sniff.STUNMessage)
|
||||||
r.StreamSniffers = append(r.StreamSniffers, sniff.StreamDomainNameQuery)
|
case C.ProtocolBitTorrent:
|
||||||
r.PacketSniffers = append(r.PacketSniffers, sniff.DomainNameQuery)
|
r.StreamSniffers = append(r.StreamSniffers, sniff.BitTorrent)
|
||||||
case C.ProtocolSTUN:
|
r.PacketSniffers = append(r.PacketSniffers, sniff.UTP)
|
||||||
r.PacketSniffers = append(r.PacketSniffers, sniff.STUNMessage)
|
r.PacketSniffers = append(r.PacketSniffers, sniff.UDPTracker)
|
||||||
case C.ProtocolBitTorrent:
|
case C.ProtocolDTLS:
|
||||||
r.StreamSniffers = append(r.StreamSniffers, sniff.BitTorrent)
|
r.PacketSniffers = append(r.PacketSniffers, sniff.DTLSRecord)
|
||||||
r.PacketSniffers = append(r.PacketSniffers, sniff.UTP)
|
case C.ProtocolSSH:
|
||||||
r.PacketSniffers = append(r.PacketSniffers, sniff.UDPTracker)
|
r.StreamSniffers = append(r.StreamSniffers, sniff.SSH)
|
||||||
case C.ProtocolDTLS:
|
case C.ProtocolRDP:
|
||||||
r.PacketSniffers = append(r.PacketSniffers, sniff.DTLSRecord)
|
r.StreamSniffers = append(r.StreamSniffers, sniff.RDP)
|
||||||
case C.ProtocolSSH:
|
default:
|
||||||
r.StreamSniffers = append(r.StreamSniffers, sniff.SSH)
|
return E.New("unknown sniffer: ", name)
|
||||||
case C.ProtocolRDP:
|
|
||||||
r.StreamSniffers = append(r.StreamSniffers, sniff.RDP)
|
|
||||||
default:
|
|
||||||
return E.New("unknown sniffer: ", name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
r.StreamSniffers = []sniff.StreamSniffer{
|
|
||||||
sniff.TLSClientHello,
|
|
||||||
sniff.HTTPHost,
|
|
||||||
sniff.StreamDomainNameQuery,
|
|
||||||
sniff.BitTorrent,
|
|
||||||
sniff.SSH,
|
|
||||||
sniff.RDP,
|
|
||||||
}
|
|
||||||
r.PacketSniffers = []sniff.PacketSniffer{
|
|
||||||
sniff.DomainNameQuery,
|
|
||||||
sniff.QUICClientHello,
|
|
||||||
sniff.STUNMessage,
|
|
||||||
sniff.UTP,
|
|
||||||
sniff.UDPTracker,
|
|
||||||
sniff.DTLSRecord,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
Loading…
Reference in a new issue