package transport

import (
	"context"
	"encoding/binary"
	"errors"
	"io"
	"math"

	"github.com/sagernet/sing-box/adapter"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/dns"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"

	mDNS "github.com/miekg/dns"
)

var _ adapter.DNSTransport = (*TCPTransport)(nil)

// errMessageTooLarge is returned by WriteMessage when the packed message does not
// fit in the 16-bit length prefix used for DNS over TCP.
var errMessageTooLarge = errors.New("dns message exceeds 65535-byte TCP length limit")

// RegisterTCP registers the plain TCP DNS transport (C.DNSTypeTCP) in registry,
// built by NewTCP from option.RemoteDNSServerOptions.
func RegisterTCP(registry *dns.TransportRegistry) {
	dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeTCP, NewTCP)
}

// TCPTransport is a DNS transport that sends each query over its own TCP
// connection using 2-byte length-prefixed framing.
type TCPTransport struct {
	dns.TransportAdapter
	dialer     N.Dialer
	serverAddr M.Socksaddr
}

// NewTCP builds a TCPTransport from options. The server port defaults to 53 when
// the configured port is 0. The logger parameter is not used by this transport.
func NewTCP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
	transportDialer, err := dns.NewRemoteDialer(ctx, options)
	if err != nil {
		return nil, err
	}
	serverAddr := options.ServerOptions.Build()
	if serverAddr.Port == 0 {
		serverAddr.Port = 53
	}
	return &TCPTransport{
		TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTCP, tag, options),
		dialer:           transportDialer,
		serverAddr:       serverAddr,
	}, nil
}

// Reset is a no-op: every Exchange dials and closes its own connection, so there
// is no cached connection state to drop.
func (t *TCPTransport) Reset() {
}

// Exchange dials a new TCP connection to the server, writes message (with its ID
// replaced by 0 on the wire) and returns the single response read back. The
// connection is closed before returning.
func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
	conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	// The dial honours ctx, but the subsequent write/read would not; without a
	// deadline a server that accepts and never answers blocks this call forever.
	if deadline, ok := ctx.Deadline(); ok {
		// Best effort: some wrapped conns may not support deadlines; the exchange
		// can still proceed without one.
		_ = conn.SetDeadline(deadline)
	}
	err = WriteMessage(conn, 0, message)
	if err != nil {
		return nil, err
	}
	return ReadMessage(conn)
}

// ReadMessage reads one DNS message framed with a 2-byte big-endian length
// prefix from reader and unpacks it.
//
// A declared length below 10 is rejected with mDNS.ErrShortRead. If Unpack fails,
// the (possibly partially filled) message is returned together with the error.
func ReadMessage(reader io.Reader) (*mDNS.Msg, error) {
	var responseLen uint16
	err := binary.Read(reader, binary.BigEndian, &responseLen)
	if err != nil {
		return nil, err
	}
	if responseLen < 10 {
		return nil, mDNS.ErrShortRead
	}
	buffer := buf.NewSize(int(responseLen))
	defer buffer.Release()
	_, err = buffer.ReadFullFrom(reader, int(responseLen))
	if err != nil {
		return nil, err
	}
	var message mDNS.Msg
	err = message.Unpack(buffer.Bytes())
	return &message, err
}

// WriteMessage packs a copy of message with ID messageId and name compression
// enabled, and writes it to writer prefixed by its 2-byte big-endian length.
// message itself is not modified.
func WriteMessage(writer io.Writer, messageId uint16, message *mDNS.Msg) error {
	exMessage := *message
	exMessage.Id = messageId
	// Size the buffer from the uncompressed length regardless of the caller's
	// Compress flag: PackBuffer needs room for the uncompressed message plus one
	// byte to pack in place; otherwise it allocates its own slice and the bytes
	// in buffer would not be the packed message.
	exMessage.Compress = false
	packLen := exMessage.Len() + 1
	exMessage.Compress = true

	buffer := buf.NewSize(2 + packLen)
	defer buffer.Release()
	// Placeholder prefix; the real length is only known after compression.
	common.Must(binary.Write(buffer, binary.BigEndian, uint16(0)))
	rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
	if err != nil {
		return err
	}
	if len(rawMessage) > math.MaxUint16 {
		return errMessageTooLarge
	}
	buffer.Truncate(2 + len(rawMessage))
	// Patch the prefix with the actual (compressed) size so the peer reads
	// exactly the bytes that follow.
	binary.BigEndian.PutUint16(buffer.Bytes(), uint16(len(rawMessage)))
	return common.Error(writer.Write(buffer.Bytes()))
}