2022-07-06 10:48:11 +00:00
|
|
|
package dns
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"os"
|
|
|
|
|
2022-07-08 15:03:57 +00:00
|
|
|
"github.com/sagernet/sing-box/adapter"
|
|
|
|
"github.com/sagernet/sing-box/log"
|
2022-07-06 10:48:11 +00:00
|
|
|
"github.com/sagernet/sing/common"
|
|
|
|
"github.com/sagernet/sing/common/buf"
|
|
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
|
|
N "github.com/sagernet/sing/common/network"
|
|
|
|
"github.com/sagernet/sing/common/task"
|
|
|
|
|
|
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
|
|
)
|
|
|
|
|
2022-07-06 15:11:48 +00:00
|
|
|
// Compile-time check that *UDPTransport satisfies adapter.DNSTransport.
var _ adapter.DNSTransport = (*UDPTransport)(nil)
|
2022-07-06 10:48:11 +00:00
|
|
|
|
|
|
|
type UDPTransport struct {
|
2022-07-06 15:11:48 +00:00
|
|
|
myTransportAdapter
|
2022-07-06 10:48:11 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func NewUDPTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, destination M.Socksaddr) *UDPTransport {
|
|
|
|
return &UDPTransport{
|
2022-07-06 15:11:48 +00:00
|
|
|
myTransportAdapter{
|
|
|
|
ctx: ctx,
|
|
|
|
dialer: dialer,
|
|
|
|
logger: logger,
|
|
|
|
destination: destination,
|
|
|
|
done: make(chan struct{}),
|
|
|
|
},
|
2022-07-06 10:48:11 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (t *UDPTransport) offer() (*dnsConnection, error) {
|
|
|
|
t.access.RLock()
|
|
|
|
connection := t.connection
|
|
|
|
t.access.RUnlock()
|
|
|
|
if connection != nil {
|
|
|
|
select {
|
|
|
|
case <-connection.done:
|
|
|
|
default:
|
|
|
|
return connection, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
t.access.Lock()
|
|
|
|
connection = t.connection
|
|
|
|
if connection != nil {
|
|
|
|
select {
|
|
|
|
case <-connection.done:
|
|
|
|
default:
|
|
|
|
t.access.Unlock()
|
|
|
|
return connection, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
tcpConn, err := t.dialer.DialContext(t.ctx, "udp", t.destination)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
connection = &dnsConnection{
|
|
|
|
Conn: tcpConn,
|
|
|
|
done: make(chan struct{}),
|
|
|
|
callbacks: make(map[uint16]chan *dnsmessage.Message),
|
|
|
|
}
|
|
|
|
t.connection = connection
|
|
|
|
t.access.Unlock()
|
|
|
|
go t.newConnection(connection)
|
|
|
|
return connection, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (t *UDPTransport) newConnection(conn *dnsConnection) {
|
|
|
|
defer close(conn.done)
|
|
|
|
defer conn.Close()
|
2022-07-08 04:58:43 +00:00
|
|
|
err := task.Any(t.ctx, func(ctx context.Context) error {
|
2022-07-06 10:48:11 +00:00
|
|
|
return t.loopIn(conn)
|
2022-07-08 04:58:43 +00:00
|
|
|
}, func(ctx context.Context) error {
|
2022-07-06 10:48:11 +00:00
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
return nil
|
|
|
|
case <-t.done:
|
|
|
|
return os.ErrClosed
|
|
|
|
}
|
|
|
|
})
|
|
|
|
conn.err = err
|
|
|
|
if err != nil {
|
2022-07-06 15:39:17 +00:00
|
|
|
t.logger.Debug("connection closed: ", err)
|
2022-07-06 10:48:11 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (t *UDPTransport) loopIn(conn *dnsConnection) error {
|
|
|
|
_buffer := buf.StackNewSize(1024)
|
|
|
|
defer common.KeepAlive(_buffer)
|
|
|
|
buffer := common.Dup(_buffer)
|
|
|
|
defer buffer.Release()
|
|
|
|
for {
|
|
|
|
buffer.FullReset()
|
|
|
|
_, err := buffer.ReadFrom(conn)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
var message dnsmessage.Message
|
|
|
|
err = message.Unpack(buffer.Bytes())
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
conn.access.Lock()
|
|
|
|
callback, loaded := conn.callbacks[message.ID]
|
|
|
|
if loaded {
|
|
|
|
delete(conn.callbacks, message.ID)
|
|
|
|
}
|
|
|
|
conn.access.Unlock()
|
|
|
|
if !loaded {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
callback <- &message
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (t *UDPTransport) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
|
|
|
|
var connection *dnsConnection
|
|
|
|
err := task.Run(ctx, func() error {
|
|
|
|
var innerErr error
|
|
|
|
connection, innerErr = t.offer()
|
|
|
|
return innerErr
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
connection.access.Lock()
|
|
|
|
connection.queryId++
|
|
|
|
message.ID = connection.queryId
|
|
|
|
callback := make(chan *dnsmessage.Message)
|
|
|
|
connection.callbacks[message.ID] = callback
|
|
|
|
connection.access.Unlock()
|
|
|
|
_buffer := buf.StackNewSize(1024)
|
|
|
|
defer common.KeepAlive(_buffer)
|
|
|
|
buffer := common.Dup(_buffer)
|
|
|
|
defer buffer.Release()
|
|
|
|
rawMessage, err := message.AppendPack(buffer.Index(0))
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
buffer.Truncate(len(rawMessage))
|
|
|
|
err = task.Run(ctx, func() error {
|
|
|
|
return common.Error(connection.Write(buffer.Bytes()))
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
select {
|
|
|
|
case response := <-callback:
|
|
|
|
return response, nil
|
|
|
|
case <-connection.done:
|
|
|
|
return nil, connection.err
|
|
|
|
case <-ctx.Done():
|
|
|
|
return nil, ctx.Err()
|
|
|
|
}
|
|
|
|
}
|