diff --git a/common/sniff/stun.go b/common/sniff/stun.go new file mode 100644 index 00000000..66a72d7e --- /dev/null +++ b/common/sniff/stun.go @@ -0,0 +1,24 @@ +package sniff + +import ( + "context" + "encoding/binary" + "os" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" +) + +func STUNMessage(ctx context.Context, packet []byte) (*adapter.InboundContext, error) { + pLen := len(packet) + if pLen < 20 { + return nil, os.ErrInvalid + } + if binary.BigEndian.Uint32(packet[4:8]) != 0x2112A442 { + return nil, os.ErrInvalid + } + if len(packet) < 20+int(binary.BigEndian.Uint16(packet[2:4])) { + return nil, os.ErrInvalid + } + return &adapter.InboundContext{Protocol: C.ProtocolSTUN}, nil +} diff --git a/common/sniff/stun_test.go b/common/sniff/stun_test.go new file mode 100644 index 00000000..5fd9a18d --- /dev/null +++ b/common/sniff/stun_test.go @@ -0,0 +1,28 @@ +package sniff_test + +import ( + "context" + "encoding/hex" + "testing" + + "github.com/sagernet/sing-box/common/sniff" + C "github.com/sagernet/sing-box/constant" + + "github.com/stretchr/testify/require" +) + +func TestSniffSTUN(t *testing.T) { + packet, err := hex.DecodeString("000100002112a44224b1a025d0c180c484341306") + require.NoError(t, err) + metadata, err := sniff.STUNMessage(context.Background(), packet) + require.NoError(t, err) + require.Equal(t, metadata.Protocol, C.ProtocolSTUN) +} + +func FuzzSniffSTUN(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + if _, err := sniff.STUNMessage(context.Background(), data); err == nil { + t.Fail() + } + }) +} diff --git a/constant/protocol.go b/constant/protocol.go index 62a33f49..810c79ec 100644 --- a/constant/protocol.go +++ b/constant/protocol.go @@ -5,4 +5,5 @@ const ( ProtocolHTTP = "http" ProtocolQUIC = "quic" ProtocolDNS = "dns" + ProtocolSTUN = "stun" ) diff --git a/route/router.go b/route/router.go index 64e04af1..31d2c971 100644 --- a/route/router.go +++ b/route/router.go @@ -457,7 +457,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m if err != nil { return err } - sniffMetadata, err := sniff.PeekPacket(ctx, buffer.Bytes(), sniff.QUICClientHello) + sniffMetadata, err := sniff.PeekPacket(ctx, buffer.Bytes(), sniff.QUICClientHello, sniff.STUNMessage) originDestination := metadata.Destination if err == nil { metadata.Protocol = sniffMetadata.Protocol