From 4432cc2253eef3d8645c7d2ee8f9527a5237c94c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 10 Jul 2022 07:52:33 +0800 Subject: [PATCH] Fix quic sniff irl --- common/sniff/quic.go | 55 +++++++++++++++++++++++++++++++++---------- common/sniff/sniff.go | 1 + 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/common/sniff/quic.go b/common/sniff/quic.go index e7e7f71f..07bcbd6c 100644 --- a/common/sniff/quic.go +++ b/common/sniff/quic.go @@ -12,6 +12,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/sniff/internal/qtls" C "github.com/sagernet/sing-box/constant" + E "github.com/sagernet/sing/common/exceptions" "golang.org/x/crypto/hkdf" ) @@ -25,7 +26,7 @@ func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContex } if typeByte&0x80 == 0 || typeByte&0x40 == 0 { - return nil, os.ErrInvalid + return nil, E.New("bad type byte") } var versionNumber uint32 err = binary.Read(reader, binary.BigEndian, &versionNumber) @@ -33,13 +34,22 @@ func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContex return nil, err } if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 { - return nil, os.ErrInvalid + return nil, E.New("bad version") } - if (typeByte&0x30)>>4 == 0x0 { - } else if (typeByte&0x30)>>4 != 0x01 { - // 0-rtt + if versionNumber == qtls.Version2 { + if (typeByte&0x30)>>4 == 0b01 { + } else if (typeByte&0x30)>>4 != 0b10 { + // 0-rtt + } else { + return nil, E.New("bad packet type") + } } else { - return nil, os.ErrInvalid + if (typeByte&0x30)>>4 == 0x0 { + } else if (typeByte&0x30)>>4 != 0x01 { + // 0-rtt + } else { + return nil, E.New("bad packet type") + } } destConnIDLen, err := reader.ReadByte() @@ -47,6 +57,10 @@ func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContex return nil, err } + if destConnIDLen == 0 || destConnIDLen > 20 { + return nil, E.New("bad destination connection id length") + } + destConnID := make([]byte, destConnIDLen) _, err = io.ReadFull(reader, destConnID) if err != nil { @@ -79,7 +93,7 @@ func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContex } hdrLen := int(reader.Size()) - reader.Len() - if hdrLen != len(packet)-int(packetLen) { + if hdrLen+int(packetLen) > len(packet) { return nil, os.ErrInvalid } @@ -126,17 +140,25 @@ func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContex newPacket[hdrLen+i] ^= mask[i+1] } packetNumberLength := newPacket[0]&0x3 + 1 - if packetNumberLength != 1 { + if hdrLen+int(packetNumberLength) > int(packetLen)+hdrLen { return nil, os.ErrInvalid } - packetNumber := newPacket[hdrLen] - if err != nil { - return nil, err + var packetNumber uint32 + switch packetNumberLength { + case 1: + packetNumber = uint32(newPacket[hdrLen]) + case 2: + packetNumber = uint32(binary.BigEndian.Uint16(newPacket[hdrLen:])) + case 3: + packetNumber = uint32(newPacket[hdrLen+2]) | uint32(newPacket[hdrLen+1])<<8 | uint32(newPacket[hdrLen])<<16 + case 4: + packetNumber = binary.BigEndian.Uint32(newPacket[hdrLen:]) + default: + return nil, E.New("bad packet number length") } if packetNumber != 0 { - return nil, os.ErrInvalid + return nil, E.New("bad packet number: ", packetNumber) } - extHdrLen := hdrLen + int(packetNumberLength) copy(newPacket[extHdrLen:hdrLen+4], packet[extHdrLen:]) data := newPacket[extHdrLen : int(packetLen)+hdrLen] @@ -166,6 +188,13 @@ func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContex if err != nil { return nil, err } + for frameType == 0x0 { + // skip padding + frameType, err = decryptedReader.ReadByte() + if err != nil { + return nil, err + } + } if frameType != 0x6 { // not crypto frame return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, nil diff --git a/common/sniff/sniff.go b/common/sniff/sniff.go index 12055bed..c33cded9 100644 --- a/common/sniff/sniff.go +++ b/common/sniff/sniff.go @@ -28,6 +28,7 @@ func PeekPacket(ctx context.Context, packet []byte, sniffers ...PacketSniffer) ( for _, sniffer := range sniffers { sniffMetadata, err := sniffer(ctx, packet) if err != nil { + println(err.Error()) return nil, err } return sniffMetadata, nil