diff --git a/common/sniff/dns.go b/common/sniff/dns.go index ebb9d928..7ff541d1 100644 --- a/common/sniff/dns.go +++ b/common/sniff/dns.go @@ -22,7 +22,7 @@ func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter. if err != nil { return nil, err } - if length > 512 { + if length == 0 { return nil, os.ErrInvalid } _buffer := buf.StackNewSize(int(length)) diff --git a/outbound/dns.go b/outbound/dns.go index 07806f20..83d96a82 100644 --- a/outbound/dns.go +++ b/outbound/dns.go @@ -3,13 +3,13 @@ package outbound import ( "context" "encoding/binary" - "io" "net" "os" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/canceler" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" @@ -47,53 +47,60 @@ func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.Pa func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { defer conn.Close() ctx = adapter.WithContext(ctx, &metadata) - _buffer := buf.StackNewSize(1024) + for { + err := d.handleConnection(ctx, conn, metadata) + if err != nil { + return err + } + } +} + +func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + var queryLength uint16 + err := binary.Read(conn, binary.BigEndian, &queryLength) + if err != nil { + return err + } + if queryLength == 0 { + return dns.RCodeFormatError + } + _buffer := buf.StackNewSize(int(queryLength)) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() - for { - var queryLength uint16 - err := binary.Read(conn, binary.BigEndian, &queryLength) - if err != nil { - return err - } - if queryLength > 1024 { - return io.ErrShortBuffer - } - buffer.FullReset() - _, err = buffer.ReadFullFrom(conn, int(queryLength)) - if err != nil { - return err - } - var message dnsmessage.Message - err = message.Unpack(buffer.Bytes()) - if err != nil { - return err - } - if len(message.Questions) > 0 { - question := message.Questions[0] - metadata.Domain = string(question.Name.Data[:question.Name.Length-1]) - } - go func() error { - response, err := d.router.Exchange(ctx, &message) - if err != nil { - return err - } - _responseBuffer := buf.StackNewPacket() - defer common.KeepAlive(_responseBuffer) - responseBuffer := common.Dup(_responseBuffer) - defer responseBuffer.Release() - responseBuffer.Resize(2, 0) - n, err := response.AppendPack(responseBuffer.Index(0)) - if err != nil { - return err - } - responseBuffer.Truncate(len(n)) - binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n))) - _, err = conn.Write(responseBuffer.Bytes()) - return err - }() + _, err = buffer.ReadFullFrom(conn, int(queryLength)) + if err != nil { + return err } + var message dnsmessage.Message + err = message.Unpack(buffer.Bytes()) + if err != nil { + return err + } + if len(message.Questions) > 0 { + question := message.Questions[0] + metadata.Domain = string(question.Name.Data[:question.Name.Length-1]) + } + go func() error { + response, err := d.router.Exchange(ctx, &message) + if err != nil { + return err + } + _responseBuffer := buf.StackNewPacket() + defer common.KeepAlive(_responseBuffer) + responseBuffer := common.Dup(_responseBuffer) + defer responseBuffer.Release() + responseBuffer.Resize(2, 0) + n, err := response.AppendPack(responseBuffer.Index(0)) + if err != nil { + return err + } + responseBuffer.Truncate(len(n)) + binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n))) + _, err = conn.Write(responseBuffer.Bytes()) + return err + }() + return nil } func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { @@ -103,7 +110,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada var group task.Group group.Append0(func(ctx context.Context) error { defer cancel() - _buffer := buf.StackNewSize(1024) + _buffer := buf.StackNewSize(dns.FixedPacketSize) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release()