Fix read DNS message

This commit is contained in:
世界 2022-08-26 13:13:44 +08:00
parent 9ac31d0233
commit c5e38203eb
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 53 additions and 46 deletions

View file

@ -22,7 +22,7 @@ func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter.
if err != nil {
return nil, err
}
if length > 512 {
if length == 0 {
return nil, os.ErrInvalid
}
_buffer := buf.StackNewSize(int(length))

View file

@ -3,13 +3,13 @@ package outbound
import (
"context"
"encoding/binary"
"io"
"net"
"os"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/canceler"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
@ -47,53 +47,60 @@ func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.Pa
func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
defer conn.Close()
ctx = adapter.WithContext(ctx, &metadata)
_buffer := buf.StackNewSize(1024)
for {
err := d.handleConnection(ctx, conn, metadata)
if err != nil {
return err
}
}
}
func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
var queryLength uint16
err := binary.Read(conn, binary.BigEndian, &queryLength)
if err != nil {
return err
}
if queryLength == 0 {
return dns.RCodeFormatError
}
_buffer := buf.StackNewSize(int(queryLength))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
for {
var queryLength uint16
err := binary.Read(conn, binary.BigEndian, &queryLength)
if err != nil {
return err
}
if queryLength > 1024 {
return io.ErrShortBuffer
}
buffer.FullReset()
_, err = buffer.ReadFullFrom(conn, int(queryLength))
if err != nil {
return err
}
var message dnsmessage.Message
err = message.Unpack(buffer.Bytes())
if err != nil {
return err
}
if len(message.Questions) > 0 {
question := message.Questions[0]
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
}
go func() error {
response, err := d.router.Exchange(ctx, &message)
if err != nil {
return err
}
_responseBuffer := buf.StackNewPacket()
defer common.KeepAlive(_responseBuffer)
responseBuffer := common.Dup(_responseBuffer)
defer responseBuffer.Release()
responseBuffer.Resize(2, 0)
n, err := response.AppendPack(responseBuffer.Index(0))
if err != nil {
return err
}
responseBuffer.Truncate(len(n))
binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
_, err = conn.Write(responseBuffer.Bytes())
return err
}()
_, err = buffer.ReadFullFrom(conn, int(queryLength))
if err != nil {
return err
}
var message dnsmessage.Message
err = message.Unpack(buffer.Bytes())
if err != nil {
return err
}
if len(message.Questions) > 0 {
question := message.Questions[0]
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
}
go func() error {
response, err := d.router.Exchange(ctx, &message)
if err != nil {
return err
}
_responseBuffer := buf.StackNewPacket()
defer common.KeepAlive(_responseBuffer)
responseBuffer := common.Dup(_responseBuffer)
defer responseBuffer.Release()
responseBuffer.Resize(2, 0)
n, err := response.AppendPack(responseBuffer.Index(0))
if err != nil {
return err
}
responseBuffer.Truncate(len(n))
binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
_, err = conn.Write(responseBuffer.Bytes())
return err
}()
return nil
}
func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
@ -103,7 +110,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
var group task.Group
group.Append0(func(ctx context.Context) error {
defer cancel()
_buffer := buf.StackNewSize(1024)
_buffer := buf.StackNewSize(dns.FixedPacketSize)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()