diff --git a/inbound/dns.go b/inbound/dns.go index 899bfcd7..47889589 100644 --- a/inbound/dns.go +++ b/inbound/dns.go @@ -16,6 +16,7 @@ import ( ) func NewDNSConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn net.Conn, metadata adapter.InboundContext) error { + ctx = adapter.WithContext(ctx, &metadata) _buffer := buf.StackNewSize(1024) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) @@ -44,32 +45,38 @@ func NewDNSConnection(ctx context.Context, router adapter.Router, logger log.Log metadata.Domain = string(question.Name.Data[:question.Name.Length-1]) logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source) } - response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message) - if err != nil { + go func() error { + response, err := router.Exchange(ctx, &message) + if err != nil { + return err + } + _responseBuffer := buf.StackNewSize(1024) + defer common.KeepAlive(_responseBuffer) + responseBuffer := common.Dup(_responseBuffer) + defer responseBuffer.Release() + responseBuffer.Resize(2, 0) + n, err := response.AppendPack(responseBuffer.Index(0)) + if err != nil { + return err + } + responseBuffer.Truncate(len(n)) + binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n))) + _, err = conn.Write(responseBuffer.Bytes()) return err - } - buffer.FullReset() - responseBuffer, err := response.AppendPack(buffer.Index(0)) - if err != nil { - return err - } - err = binary.Write(conn, binary.BigEndian, uint16(len(responseBuffer))) - if err != nil { - return err - } - _, err = conn.Write(responseBuffer) - if err != nil { - return err - } + }() } } func NewDNSPacketConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn N.PacketConn, metadata adapter.InboundContext) error { + ctx = adapter.WithContext(ctx, &metadata) + _buffer := buf.StackNewSize(1024) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() for { - buffer := buf.StackNewSize(1024) + buffer.FullReset() destination, err := conn.ReadPacket(buffer) if err != nil { - buffer.Release() return err } var message dnsmessage.Message @@ -83,18 +90,20 @@ func NewDNSPacketConnection(ctx context.Context, router adapter.Router, logger l logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source) } go func() error { - defer buffer.Release() - response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message) + response, err := router.Exchange(ctx, &message) if err != nil { return err } - buffer.FullReset() - responseBuffer, err := response.AppendPack(buffer.Index(0)) + _responseBuffer := buf.StackNewSize(1024) + defer common.KeepAlive(_responseBuffer) + responseBuffer := common.Dup(_responseBuffer) + defer responseBuffer.Release() + n, err := response.AppendPack(responseBuffer.Index(0)) if err != nil { return err } - buffer.Truncate(len(responseBuffer)) - err = conn.WritePacket(buffer, destination) + responseBuffer.Truncate(len(n)) + err = conn.WritePacket(responseBuffer, destination) return err }() }