package transport import ( "context" "net" "os" "sync" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" mDNS "github.com/miekg/dns" ) var _ adapter.DNSTransport = (*UDPTransport)(nil) func RegisterUDP(registry *dns.TransportRegistry) { dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeUDP, NewUDP) } type UDPTransport struct { dns.TransportAdapter logger logger.ContextLogger dialer N.Dialer serverAddr M.Socksaddr udpSize int tcpTransport *TCPTransport access sync.Mutex conn *dnsConnection done chan struct{} } func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) { transportDialer, err := dns.NewRemoteDialer(ctx, options) if err != nil { return nil, err } serverAddr := options.ServerOptions.Build() if serverAddr.Port == 0 { serverAddr.Port = 53 } return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil } func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr) *UDPTransport { return &UDPTransport{ TransportAdapter: adapter, logger: logger, dialer: dialer, serverAddr: serverAddr, udpSize: 512, tcpTransport: &TCPTransport{ dialer: dialer, serverAddr: serverAddr, }, done: make(chan struct{}), } } func (t *UDPTransport) Reset() { t.access.Lock() defer t.access.Unlock() close(t.done) t.done = make(chan struct{}) } func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { response, err := t.exchange(ctx, message) if err != nil { return nil, err } if response.Truncated { t.logger.InfoContext(ctx, "response truncated, retrying with TCP") return t.tcpTransport.Exchange(ctx, message) } return response, nil } func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { conn, err := t.open(ctx) if err != nil { return nil, err } if edns0Opt := message.IsEdns0(); edns0Opt != nil { if udpSize := int(edns0Opt.UDPSize()); udpSize > t.udpSize { t.udpSize = udpSize } } buffer := buf.NewSize(1 + message.Len()) defer buffer.Release() exMessage := *message exMessage.Compress = true messageId := message.Id callback := &dnsCallback{ done: make(chan struct{}), } conn.access.Lock() conn.queryId++ exMessage.Id = conn.queryId conn.callbacks[exMessage.Id] = callback conn.access.Unlock() defer func() { conn.access.Lock() delete(conn.callbacks, messageId) conn.access.Unlock() callback.access.Lock() select { case <-callback.done: default: close(callback.done) } callback.access.Unlock() }() rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes()) if err != nil { return nil, err } _, err = conn.Write(rawMessage) if err != nil { conn.Close(err) return nil, err } select { case <-callback.done: callback.message.Id = messageId return callback.message, nil case <-conn.done: return nil, conn.err case <-t.done: return nil, os.ErrClosed case <-ctx.Done(): conn.Close(ctx.Err()) return nil, ctx.Err() } } func (t *UDPTransport) open(ctx context.Context) (*dnsConnection, error) { t.access.Lock() defer t.access.Unlock() conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) if err != nil { return nil, err } dnsConn := &dnsConnection{ Conn: conn, done: make(chan struct{}), callbacks: make(map[uint16]*dnsCallback), } go t.recvLoop(dnsConn) return dnsConn, nil } func (t *UDPTransport) recvLoop(conn *dnsConnection) { for { buffer := buf.NewSize(t.udpSize) _, err := buffer.ReadOnceFrom(conn) if err != nil { buffer.Release() conn.Close(err) return } var message mDNS.Msg err = message.Unpack(buffer.Bytes()) buffer.Release() if err != nil { conn.Close(err) return } conn.access.RLock() callback, loaded := conn.callbacks[message.Id] conn.access.RUnlock() if !loaded { continue } callback.access.Lock() select { case <-callback.done: default: callback.message = &message close(callback.done) } callback.access.Unlock() } } type dnsConnection struct { net.Conn access sync.RWMutex done chan struct{} closeOnce sync.Once err error queryId uint16 callbacks map[uint16]*dnsCallback } func (c *dnsConnection) Close(err error) { c.access.Lock() defer c.access.Unlock() c.closeOnce.Do(func() { close(c.done) c.err = err }) c.Conn.Close() } type dnsCallback struct { access sync.Mutex message *mDNS.Msg done chan struct{} }