Fix read DNS message

This commit is contained in:
世界 2022-08-26 13:13:44 +08:00
parent 9ac31d0233
commit c5e38203eb
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 53 additions and 46 deletions

View file

@ -22,7 +22,7 @@ func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter.
if err != nil { if err != nil {
return nil, err return nil, err
} }
if length > 512 { if length == 0 {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }
_buffer := buf.StackNewSize(int(length)) _buffer := buf.StackNewSize(int(length))

View file

@ -3,13 +3,13 @@ package outbound
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"io"
"net" "net"
"os" "os"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/canceler" "github.com/sagernet/sing-box/common/canceler"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -47,53 +47,60 @@ func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.Pa
func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
defer conn.Close() defer conn.Close()
ctx = adapter.WithContext(ctx, &metadata) ctx = adapter.WithContext(ctx, &metadata)
_buffer := buf.StackNewSize(1024) for {
err := d.handleConnection(ctx, conn, metadata)
if err != nil {
return err
}
}
}
func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
var queryLength uint16
err := binary.Read(conn, binary.BigEndian, &queryLength)
if err != nil {
return err
}
if queryLength == 0 {
return dns.RCodeFormatError
}
_buffer := buf.StackNewSize(int(queryLength))
defer common.KeepAlive(_buffer) defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()
for { _, err = buffer.ReadFullFrom(conn, int(queryLength))
var queryLength uint16 if err != nil {
err := binary.Read(conn, binary.BigEndian, &queryLength) return err
if err != nil {
return err
}
if queryLength > 1024 {
return io.ErrShortBuffer
}
buffer.FullReset()
_, err = buffer.ReadFullFrom(conn, int(queryLength))
if err != nil {
return err
}
var message dnsmessage.Message
err = message.Unpack(buffer.Bytes())
if err != nil {
return err
}
if len(message.Questions) > 0 {
question := message.Questions[0]
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
}
go func() error {
response, err := d.router.Exchange(ctx, &message)
if err != nil {
return err
}
_responseBuffer := buf.StackNewPacket()
defer common.KeepAlive(_responseBuffer)
responseBuffer := common.Dup(_responseBuffer)
defer responseBuffer.Release()
responseBuffer.Resize(2, 0)
n, err := response.AppendPack(responseBuffer.Index(0))
if err != nil {
return err
}
responseBuffer.Truncate(len(n))
binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
_, err = conn.Write(responseBuffer.Bytes())
return err
}()
} }
var message dnsmessage.Message
err = message.Unpack(buffer.Bytes())
if err != nil {
return err
}
if len(message.Questions) > 0 {
question := message.Questions[0]
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
}
go func() error {
response, err := d.router.Exchange(ctx, &message)
if err != nil {
return err
}
_responseBuffer := buf.StackNewPacket()
defer common.KeepAlive(_responseBuffer)
responseBuffer := common.Dup(_responseBuffer)
defer responseBuffer.Release()
responseBuffer.Resize(2, 0)
n, err := response.AppendPack(responseBuffer.Index(0))
if err != nil {
return err
}
responseBuffer.Truncate(len(n))
binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
_, err = conn.Write(responseBuffer.Bytes())
return err
}()
return nil
} }
func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
@ -103,7 +110,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
var group task.Group var group task.Group
group.Append0(func(ctx context.Context) error { group.Append0(func(ctx context.Context) error {
defer cancel() defer cancel()
_buffer := buf.StackNewSize(1024) _buffer := buf.StackNewSize(dns.FixedPacketSize)
defer common.KeepAlive(_buffer) defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()