diff --git a/outbound/dns.go b/outbound/dns.go index df32a019..b18b901e 100644 --- a/outbound/dns.go +++ b/outbound/dns.go @@ -46,8 +46,8 @@ func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.Pa } func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + metadata.Destination = M.Socksaddr{} defer conn.Close() - ctx = adapter.WithContext(ctx, &metadata) for { err := d.handleConnection(ctx, conn, metadata) if err != nil { @@ -98,6 +98,7 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap } func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + metadata.Destination = M.Socksaddr{} var reader N.PacketReader = conn var counters []N.CountFunc var cachedPackets []*N.PacketBuffer @@ -111,14 +112,11 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada } } if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created { - readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ - MTU: dns.FixedPacketSize, - }) + readWaiter.InitializeReadWaiter(N.ReadWaitOptions{}) return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata) } break } - ctx = adapter.WithContext(ctx, &metadata) fastClose, cancel := common.ContextWithCancelCause(ctx) timeout := canceler.New(fastClose, cancel, C.DNSTimeout) var group task.Group @@ -167,15 +165,11 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada return err } timeout.Update() - responseBuffer := buf.NewPacket() - responseBuffer.Resize(1024, 0) - n, err := response.PackBuffer(responseBuffer.FreeBytes()) + responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024) if err != nil { cancel(err) - responseBuffer.Release() return err } - responseBuffer.Truncate(len(n)) err = conn.WritePacket(responseBuffer, destination) if err != nil { cancel(err) @@ -241,16 +235,11 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa return err } timeout.Update() - response = truncateDNSMessage(response, 512) // TODO: add an option to custom UDP buffer size - responseBuffer := buf.NewSize(dns.FixedPacketSize) - responseBuffer.Resize(1024, 0) - n, err := response.PackBuffer(responseBuffer.FreeBytes()) + responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024) if err != nil { cancel(err) - responseBuffer.Release() return err } - responseBuffer.Truncate(len(n)) err = conn.WritePacket(responseBuffer, destination) if err != nil { cancel(err) @@ -264,22 +253,3 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa }) return group.Run(fastClose) } - -func truncateDNSMessage(response *mDNS.Msg, maxLen int) *mDNS.Msg { - responseLen := response.Len() - if responseLen <= maxLen { - return response - } - newResponse := *response - response = &newResponse - for len(response.Answer) > 0 && responseLen > maxLen { - response.Answer = response.Answer[:len(response.Answer)-1] - response.Truncated = true - responseLen = response.Len() - } - if responseLen > maxLen { - response.Ns = nil - response.Extra = nil - } - return response -}