Implement dns-hijack

This commit is contained in:
世界 2024-10-23 13:44:08 +08:00
parent 68a63e192f
commit abb1b14f3b
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
6 changed files with 245 additions and 96 deletions

View file

@ -1,6 +1,9 @@
package main package main
import ( import (
"errors"
"os"
"github.com/sagernet/sing-box" "github.com/sagernet/sing-box"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -23,7 +26,9 @@ func init() {
func createPreStartedClient() (*box.Box, error) { func createPreStartedClient() (*box.Box, error) {
options, err := readConfigAndMerge() options, err := readConfigAndMerge()
if err != nil { if err != nil {
return nil, err if !(errors.Is(err, os.ErrNotExist) && len(configDirectories) == 0 && len(configPaths) == 1) || configPaths[0] != "config.json" {
return nil, err
}
} }
instance, err := box.New(box.Options{Options: options}) instance, err := box.New(box.Options{Options: options})
if err != nil { if err != nil {

View file

@ -36,8 +36,9 @@ type TUN struct {
router adapter.Router router adapter.Router
logger log.ContextLogger logger log.ContextLogger
// Deprecated // Deprecated
inboundOptions option.InboundOptions inboundOptions option.InboundOptions
tunOptions tun.Options tunOptions tun.Options
// Deprecated
endpointIndependentNat bool endpointIndependentNat bool
udpTimeout time.Duration udpTimeout time.Duration
stack string stack string

View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"net" "net"
"os" "os"
"time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
@ -50,14 +51,15 @@ func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter
metadata.Destination = M.Socksaddr{} metadata.Destination = M.Socksaddr{}
defer conn.Close() defer conn.Close()
for { for {
err := d.handleConnection(ctx, conn, metadata) conn.SetReadDeadline(time.Now().Add(C.DNSTimeout))
err := HandleStreamDNSRequest(ctx, d.router, conn, metadata)
if err != nil { if err != nil {
return err return err
} }
} }
} }
func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net.Conn, metadata adapter.InboundContext) error {
var queryLength uint16 var queryLength uint16
err := binary.Read(conn, binary.BigEndian, &queryLength) err := binary.Read(conn, binary.BigEndian, &queryLength)
if err != nil { if err != nil {
@ -79,7 +81,7 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap
} }
metadataInQuery := metadata metadataInQuery := metadata
go func() error { go func() error {
response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
if err != nil { if err != nil {
return err return err
} }
@ -100,10 +102,14 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap
// Deprecated // Deprecated
func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return NewDNSPacketConnection(ctx, d.router, conn, nil, metadata)
}
func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error {
metadata.Destination = M.Socksaddr{} metadata.Destination = M.Socksaddr{}
var reader N.PacketReader = conn var reader N.PacketReader = conn
var counters []N.CountFunc var counters []N.CountFunc
var cachedPackets []*N.PacketBuffer cachedPackets = common.Reverse(cachedPackets)
for { for {
reader, counters = N.UnwrapCountPacketReader(reader, counters) reader, counters = N.UnwrapCountPacketReader(reader, counters)
if cachedReader, isCached := reader.(N.CachedPacketReader); isCached { if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
@ -115,7 +121,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
} }
if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created { if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{}) readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata) return newDNSPacketConnection(ctx, router, conn, readWaiter, counters, cachedPackets, metadata)
} }
break break
} }
@ -161,7 +167,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
} }
metadataInQuery := metadata metadataInQuery := metadata
go func() error { go func() error {
response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
if err != nil { if err != nil {
cancel(err) cancel(err)
return err return err
@ -186,7 +192,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
return group.Run(fastClose) return group.Run(fastClose)
} }
func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error { func newDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
fastClose, cancel := common.ContextWithCancelCause(ctx) fastClose, cancel := common.ContextWithCancelCause(ctx)
timeout := canceler.New(fastClose, cancel, C.DNSTimeout) timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
var group task.Group var group task.Group
@ -206,11 +212,12 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa
} }
err = message.Unpack(packet.Buffer.Bytes()) err = message.Unpack(packet.Buffer.Bytes())
packet.Buffer.Release() packet.Buffer.Release()
destination = packet.Destination
N.PutPacketBuffer(packet)
if err != nil { if err != nil {
cancel(err) cancel(err)
return err return err
} }
destination = packet.Destination
} else { } else {
buffer, destination, err = readWaiter.WaitReadPacket() buffer, destination, err = readWaiter.WaitReadPacket()
if err != nil { if err != nil {
@ -230,7 +237,7 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa
} }
metadataInQuery := metadata metadataInQuery := metadata
go func() error { go func() error {
response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
if err != nil { if err != nil {
cancel(err) cancel(err)
return err return err

93
route/dns.go Normal file
View file

@ -0,0 +1,93 @@
package route
import (
"context"
"errors"
"net"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/outbound"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat2"
mDNS "github.com/miekg/dns"
)
func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
metadata.Destination = M.Socksaddr{}
for {
conn.SetReadDeadline(time.Now().Add(C.DNSTimeout))
err := outbound.HandleStreamDNSRequest(ctx, r, conn, metadata)
if err != nil {
return err
}
}
}
func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) {
if uConn, isUDPNAT2 := conn.(*udpnat.Conn); isUDPNAT2 {
metadata.Destination = M.Socksaddr{}
for _, packet := range packetBuffers {
buffer := packet.Buffer
destination := packet.Destination
N.PutPacketBuffer(packet)
go ExchangeDNSPacket(ctx, r, uConn, buffer, metadata, destination)
}
uConn.SetHandler(&dnsHijacker{
router: r,
conn: conn,
ctx: ctx,
metadata: metadata,
})
return
}
err := outbound.NewDNSPacketConnection(ctx, r, conn, packetBuffers, metadata)
if err != nil && !E.IsClosedOrCanceled(err) {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection"))
}
}
func ExchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) {
err := exchangeDNSPacket(ctx, router, conn, buffer, metadata, destination)
if err != nil && !errors.Is(err, tun.ErrDrop) && !E.IsClosedOrCanceled(err) {
router.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection"))
}
}
func exchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) error {
var message mDNS.Msg
err := message.Unpack(buffer.Bytes())
buffer.Release()
if err != nil {
return E.Cause(err, "unpack request")
}
response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
if err != nil {
return err
}
responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
if err != nil {
return err
}
err = conn.WritePacket(responseBuffer, destination)
responseBuffer.Release()
return err
}
type dnsHijacker struct {
router *Router
conn N.PacketConn
ctx context.Context
metadata adapter.InboundContext
}
func (h *dnsHijacker) NewPacketEx(buffer *buf.Buffer, destination M.Socksaddr) {
go ExchangeDNSPacket(h.ctx, h.router, h.conn, buffer, h.metadata, destination)
}

View file

@ -88,7 +88,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if deadline.NeedAdditionalReadDeadline(conn) { if deadline.NeedAdditionalReadDeadline(conn) {
conn = deadline.NewConn(conn) conn = deadline.NewConn(conn)
} }
selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, false, conn, nil, -1) selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil, -1)
if err != nil { if err != nil {
return err return err
} }
@ -109,6 +109,12 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
buf.ReleaseMulti(buffers) buf.ReleaseMulti(buffers)
N.CloseOnHandshakeFailure(conn, onClose, action.Error()) N.CloseOnHandshakeFailure(conn, onClose, action.Error())
return nil return nil
case *rule.RuleActionHijackDNS:
for _, buffer := range buffers {
conn = bufio.NewCachedConn(conn, buffer)
}
r.hijackDNSStream(ctx, conn, metadata)
return nil
} }
} }
if selectedRule == nil || selectReturn { if selectedRule == nil || selectReturn {
@ -226,7 +232,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn)) conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn))
}*/ }*/
selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1) selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1)
if err != nil { if err != nil {
return err return err
} }
@ -238,32 +244,35 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
var loaded bool var loaded bool
selectedOutbound, loaded = r.Outbound(action.Outbound) selectedOutbound, loaded = r.Outbound(action.Outbound)
if !loaded { if !loaded {
buf.ReleaseMulti(buffers) N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("outbound not found: ", action.Outbound) return E.New("outbound not found: ", action.Outbound)
} }
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
case *rule.RuleActionReturn: case *rule.RuleActionReturn:
selectReturn = true selectReturn = true
case *rule.RuleActionReject: case *rule.RuleActionReject:
buf.ReleaseMulti(buffers) N.ReleaseMultiPacketBuffer(packetBuffers)
N.CloseOnHandshakeFailure(conn, onClose, syscall.ECONNREFUSED) N.CloseOnHandshakeFailure(conn, onClose, syscall.ECONNREFUSED)
return nil return nil
case *rule.RuleActionHijackDNS:
r.hijackDNSPacket(ctx, conn, packetBuffers, metadata)
return nil
} }
} }
if selectedRule == nil || selectReturn { if selectedRule == nil || selectReturn {
if r.defaultOutboundForPacketConnection == nil { if r.defaultOutboundForPacketConnection == nil {
buf.ReleaseMulti(buffers) N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("missing default outbound with UDP support") return E.New("missing default outbound with UDP support")
} }
selectedOutbound = r.defaultOutboundForPacketConnection selectedOutbound = r.defaultOutboundForPacketConnection
} }
if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) { if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) {
buf.ReleaseMulti(buffers) N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag()) return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
} }
for _, buffer := range buffers { for _, buffer := range packetBuffers {
// TODO: check if metadata.Destination == packet destination conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination)
conn = bufio.NewCachedPacketConn(conn, buffer, metadata.Destination) N.PutPacketBuffer(buffer)
} }
if r.clashServer != nil { if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, selectedRule) trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, selectedRule)
@ -297,7 +306,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
} }
func (r *Router) PreMatch(metadata adapter.InboundContext) error { func (r *Router) PreMatch(metadata adapter.InboundContext) error {
selectedRule, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1) selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1)
if err != nil { if err != nil {
return err return err
} }
@ -314,7 +323,10 @@ func (r *Router) PreMatch(metadata adapter.InboundContext) error {
func (r *Router) matchRule( func (r *Router) matchRule(
ctx context.Context, metadata *adapter.InboundContext, preMatch bool, ctx context.Context, metadata *adapter.InboundContext, preMatch bool,
inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int, inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int,
) (selectedRule adapter.Rule, selectedRuleIndex int, buffers []*buf.Buffer, fatalErr error) { ) (
selectedRule adapter.Rule, selectedRuleIndex int,
buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error,
) {
if r.processSearcher != nil && metadata.ProcessInfo == nil { if r.processSearcher != nil && metadata.ProcessInfo == nil {
var originDestination netip.AddrPort var originDestination netip.AddrPort
if metadata.OriginDestination.IsValid() { if metadata.OriginDestination.IsValid() {
@ -376,7 +388,7 @@ func (r *Router) matchRule(
//nolint:staticcheck //nolint:staticcheck
if metadata.InboundOptions != common.DefaultValue[option.InboundOptions]() { if metadata.InboundOptions != common.DefaultValue[option.InboundOptions]() {
if !preMatch && metadata.InboundOptions.SniffEnabled { if !preMatch && metadata.InboundOptions.SniffEnabled {
newBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{ newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{
OverrideDestination: metadata.InboundOptions.SniffOverrideDestination, OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
Timeout: time.Duration(metadata.InboundOptions.SniffTimeout), Timeout: time.Duration(metadata.InboundOptions.SniffTimeout),
}, inputConn, inputPacketConn) }, inputConn, inputPacketConn)
@ -384,7 +396,11 @@ func (r *Router) matchRule(
fatalErr = newErr fatalErr = newErr
return return
} }
buffers = append(buffers, newBuffers...) if newBuffer != nil {
buffers = []*buf.Buffer{newBuffer}
} else if len(newPackerBuffers) > 0 {
packetBuffers = newPackerBuffers
}
} }
if dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) != dns.DomainStrategyAsIS { if dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) != dns.DomainStrategyAsIS {
fatalErr = r.actionResolve(ctx, metadata, &rule.RuleActionResolve{ fatalErr = r.actionResolve(ctx, metadata, &rule.RuleActionResolve{
@ -421,22 +437,36 @@ match:
break break
} }
if !preMatch { if !preMatch {
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action()) ruleDescription := currentRule.String()
if ruleDescription != "" {
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
} else {
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action())
}
} else { } else {
switch currentRule.Action().Type() { switch currentRule.Action().Type() {
case C.RuleActionTypeReject, C.RuleActionTypeResolve: case C.RuleActionTypeReject, C.RuleActionTypeResolve:
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action()) ruleDescription := currentRule.String()
if ruleDescription != "" {
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
} else {
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] => ", currentRule.Action())
}
} }
} }
switch action := currentRule.Action().(type) { switch action := currentRule.Action().(type) {
case *rule.RuleActionSniff: case *rule.RuleActionSniff:
if !preMatch { if !preMatch {
newBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn) newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
if newErr != nil { if newErr != nil {
fatalErr = newErr fatalErr = newErr
return return
} }
buffers = append(buffers, newBuffers...) if newBuffer != nil {
buffers = append(buffers, newBuffer)
} else if len(newPacketBuffers) > 0 {
packetBuffers = append(packetBuffers, newPacketBuffers...)
}
} else { } else {
selectedRule = currentRule selectedRule = currentRule
selectedRuleIndex = currentRuleIndex selectedRuleIndex = currentRuleIndex
@ -455,12 +485,16 @@ match:
ruleIndex = currentRuleIndex ruleIndex = currentRuleIndex
} }
if !preMatch && metadata.Destination.Addr.IsUnspecified() { if !preMatch && metadata.Destination.Addr.IsUnspecified() {
newBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn) newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn)
if newErr != nil { if newErr != nil {
fatalErr = newErr fatalErr = newErr
return return
} }
buffers = append(buffers, newBuffers...) if newBuffer != nil {
buffers = append(buffers, newBuffer)
} else if len(newPacketBuffers) > 0 {
packetBuffers = append(packetBuffers, newPacketBuffers...)
}
} }
return return
} }
@ -468,18 +502,31 @@ match:
func (r *Router) actionSniff( func (r *Router) actionSniff(
ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff, ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff,
inputConn net.Conn, inputPacketConn N.PacketConn, inputConn net.Conn, inputPacketConn N.PacketConn,
) (buffers []*buf.Buffer, fatalErr error) { ) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) {
if sniff.Skip(metadata) { if sniff.Skip(metadata) {
return return
} else if inputConn != nil && len(action.StreamSniffers) > 0 { } else if inputConn != nil {
buffer := buf.NewPacket() sniffBuffer := buf.NewPacket()
var streamSniffers []sniff.StreamSniffer
if len(action.StreamSniffers) > 0 {
streamSniffers = action.StreamSniffers
} else {
streamSniffers = []sniff.StreamSniffer{
sniff.TLSClientHello,
sniff.HTTPHost,
sniff.StreamDomainNameQuery,
sniff.BitTorrent,
sniff.SSH,
sniff.RDP,
}
}
err := sniff.PeekStream( err := sniff.PeekStream(
ctx, ctx,
metadata, metadata,
inputConn, inputConn,
buffer, sniffBuffer,
action.Timeout, action.Timeout,
action.StreamSniffers..., streamSniffers...,
) )
if err == nil { if err == nil {
//goland:noinspection GoDeprecation //goland:noinspection GoDeprecation
@ -497,15 +544,15 @@ func (r *Router) actionSniff(
r.logger.DebugContext(ctx, "sniffed protocol: ", metadata.Protocol) r.logger.DebugContext(ctx, "sniffed protocol: ", metadata.Protocol)
} }
} }
if !buffer.IsEmpty() { if !sniffBuffer.IsEmpty() {
buffers = append(buffers, buffer) buffer = sniffBuffer
} else { } else {
buffer.Release() sniffBuffer.Release()
} }
} else if inputPacketConn != nil && len(action.PacketSniffers) > 0 { } else if inputPacketConn != nil {
for { for {
var ( var (
buffer = buf.NewPacket() sniffBuffer = buf.NewPacket()
destination M.Socksaddr destination M.Socksaddr
done = make(chan struct{}) done = make(chan struct{})
err error err error
@ -516,7 +563,7 @@ func (r *Router) actionSniff(
sniffTimeout = action.Timeout sniffTimeout = action.Timeout
} }
inputPacketConn.SetReadDeadline(time.Now().Add(sniffTimeout)) inputPacketConn.SetReadDeadline(time.Now().Add(sniffTimeout))
destination, err = inputPacketConn.ReadPacket(buffer) destination, err = inputPacketConn.ReadPacket(sniffBuffer)
inputPacketConn.SetReadDeadline(time.Time{}) inputPacketConn.SetReadDeadline(time.Time{})
close(done) close(done)
}() }()
@ -528,7 +575,7 @@ func (r *Router) actionSniff(
return return
} }
if err != nil { if err != nil {
buffer.Release() sniffBuffer.Release()
if !errors.Is(err, os.ErrDeadlineExceeded) { if !errors.Is(err, os.ErrDeadlineExceeded) {
fatalErr = err fatalErr = err
return return
@ -538,22 +585,40 @@ func (r *Router) actionSniff(
if metadata.Destination.Addr.IsUnspecified() { if metadata.Destination.Addr.IsUnspecified() {
metadata.Destination = destination metadata.Destination = destination
} }
if len(buffers) > 0 { if len(packetBuffers) > 0 {
err = sniff.PeekPacket( err = sniff.PeekPacket(
ctx, ctx,
metadata, metadata,
buffer.Bytes(), sniffBuffer.Bytes(),
sniff.QUICClientHello, sniff.QUICClientHello,
) )
} else { } else {
var packetSniffers []sniff.PacketSniffer
if len(action.PacketSniffers) > 0 {
packetSniffers = action.PacketSniffers
} else {
packetSniffers = []sniff.PacketSniffer{
sniff.DomainNameQuery,
sniff.QUICClientHello,
sniff.STUNMessage,
sniff.UTP,
sniff.UDPTracker,
sniff.DTLSRecord,
}
}
err = sniff.PeekPacket( err = sniff.PeekPacket(
ctx, metadata, ctx, metadata,
buffer.Bytes(), sniffBuffer.Bytes(),
action.PacketSniffers..., packetSniffers...,
) )
} }
buffers = append(buffers, buffer) packetBuffer := N.NewPacketBuffer()
if E.IsMulti(err, sniff.ErrClientHelloFragmented) && len(buffers) == 0 { *packetBuffer = N.PacketBuffer{
Buffer: sniffBuffer,
Destination: destination,
}
packetBuffers = append(packetBuffers, packetBuffer)
if E.IsMulti(err, sniff.ErrClientHelloFragmented) && len(packetBuffers) == 0 {
r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello") r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
continue continue
} }

View file

@ -162,53 +162,31 @@ func (r *RuleActionSniff) Type() string {
} }
func (r *RuleActionSniff) build() error { func (r *RuleActionSniff) build() error {
if len(r.StreamSniffers) > 0 || len(r.PacketSniffers) > 0 { for _, name := range r.snifferNames {
return nil switch name {
} case C.ProtocolTLS:
if len(r.snifferNames) > 0 { r.StreamSniffers = append(r.StreamSniffers, sniff.TLSClientHello)
for _, name := range r.snifferNames { case C.ProtocolHTTP:
switch name { r.StreamSniffers = append(r.StreamSniffers, sniff.HTTPHost)
case C.ProtocolTLS: case C.ProtocolQUIC:
r.StreamSniffers = append(r.StreamSniffers, sniff.TLSClientHello) r.PacketSniffers = append(r.PacketSniffers, sniff.QUICClientHello)
case C.ProtocolHTTP: case C.ProtocolDNS:
r.StreamSniffers = append(r.StreamSniffers, sniff.HTTPHost) r.StreamSniffers = append(r.StreamSniffers, sniff.StreamDomainNameQuery)
case C.ProtocolQUIC: r.PacketSniffers = append(r.PacketSniffers, sniff.DomainNameQuery)
r.PacketSniffers = append(r.PacketSniffers, sniff.QUICClientHello) case C.ProtocolSTUN:
case C.ProtocolDNS: r.PacketSniffers = append(r.PacketSniffers, sniff.STUNMessage)
r.StreamSniffers = append(r.StreamSniffers, sniff.StreamDomainNameQuery) case C.ProtocolBitTorrent:
r.PacketSniffers = append(r.PacketSniffers, sniff.DomainNameQuery) r.StreamSniffers = append(r.StreamSniffers, sniff.BitTorrent)
case C.ProtocolSTUN: r.PacketSniffers = append(r.PacketSniffers, sniff.UTP)
r.PacketSniffers = append(r.PacketSniffers, sniff.STUNMessage) r.PacketSniffers = append(r.PacketSniffers, sniff.UDPTracker)
case C.ProtocolBitTorrent: case C.ProtocolDTLS:
r.StreamSniffers = append(r.StreamSniffers, sniff.BitTorrent) r.PacketSniffers = append(r.PacketSniffers, sniff.DTLSRecord)
r.PacketSniffers = append(r.PacketSniffers, sniff.UTP) case C.ProtocolSSH:
r.PacketSniffers = append(r.PacketSniffers, sniff.UDPTracker) r.StreamSniffers = append(r.StreamSniffers, sniff.SSH)
case C.ProtocolDTLS: case C.ProtocolRDP:
r.PacketSniffers = append(r.PacketSniffers, sniff.DTLSRecord) r.StreamSniffers = append(r.StreamSniffers, sniff.RDP)
case C.ProtocolSSH: default:
r.StreamSniffers = append(r.StreamSniffers, sniff.SSH) return E.New("unknown sniffer: ", name)
case C.ProtocolRDP:
r.StreamSniffers = append(r.StreamSniffers, sniff.RDP)
default:
return E.New("unknown sniffer: ", name)
}
}
} else {
r.StreamSniffers = []sniff.StreamSniffer{
sniff.TLSClientHello,
sniff.HTTPHost,
sniff.StreamDomainNameQuery,
sniff.BitTorrent,
sniff.SSH,
sniff.RDP,
}
r.PacketSniffers = []sniff.PacketSniffer{
sniff.DomainNameQuery,
sniff.QUICClientHello,
sniff.STUNMessage,
sniff.UTP,
sniff.UDPTracker,
sniff.DTLSRecord,
} }
} }
return nil return nil