mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-01-31 13:16:53 +00:00
116 lines
3 KiB
Go
116 lines
3 KiB
Go
|
package transport
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/sagernet/sing-box/adapter"
|
||
|
"github.com/sagernet/sing-box/common/tls"
|
||
|
C "github.com/sagernet/sing-box/constant"
|
||
|
"github.com/sagernet/sing-box/dns"
|
||
|
"github.com/sagernet/sing-box/log"
|
||
|
"github.com/sagernet/sing-box/option"
|
||
|
"github.com/sagernet/sing/common"
|
||
|
E "github.com/sagernet/sing/common/exceptions"
|
||
|
"github.com/sagernet/sing/common/logger"
|
||
|
M "github.com/sagernet/sing/common/metadata"
|
||
|
N "github.com/sagernet/sing/common/network"
|
||
|
"github.com/sagernet/sing/common/x/list"
|
||
|
|
||
|
mDNS "github.com/miekg/dns"
|
||
|
)
|
||
|
|
||
|
// Compile-time assertion that *TLSTransport satisfies adapter.DNSTransport.
var _ adapter.DNSTransport = (*TLSTransport)(nil)
|
||
|
|
||
|
func RegisterTLS(registry *dns.TransportRegistry) {
|
||
|
dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS)
|
||
|
}
|
||
|
|
||
|
type TLSTransport struct {
|
||
|
dns.TransportAdapter
|
||
|
logger logger.ContextLogger
|
||
|
dialer N.Dialer
|
||
|
serverAddr M.Socksaddr
|
||
|
tlsConfig tls.Config
|
||
|
access sync.Mutex
|
||
|
connections list.List[*tlsDNSConn]
|
||
|
}
|
||
|
|
||
|
type tlsDNSConn struct {
|
||
|
tls.Conn
|
||
|
queryId uint16
|
||
|
}
|
||
|
|
||
|
func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
|
||
|
transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
tlsOptions := common.PtrValueOrDefault(options.TLS)
|
||
|
tlsOptions.Enabled = true
|
||
|
tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
serverAddr := options.ServerOptions.Build()
|
||
|
if serverAddr.Port == 0 {
|
||
|
serverAddr.Port = 853
|
||
|
}
|
||
|
return &TLSTransport{
|
||
|
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTLS, tag, options.RemoteDNSServerOptions),
|
||
|
logger: logger,
|
||
|
dialer: transportDialer,
|
||
|
serverAddr: serverAddr,
|
||
|
tlsConfig: tlsConfig,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (t *TLSTransport) Reset() {
|
||
|
t.access.Lock()
|
||
|
defer t.access.Unlock()
|
||
|
for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
|
||
|
connection.Value.Close()
|
||
|
}
|
||
|
t.connections.Init()
|
||
|
}
|
||
|
|
||
|
func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||
|
t.access.Lock()
|
||
|
conn := t.connections.PopFront()
|
||
|
t.access.Unlock()
|
||
|
if conn != nil {
|
||
|
response, err := t.exchange(message, conn)
|
||
|
if err == nil {
|
||
|
return response, nil
|
||
|
}
|
||
|
}
|
||
|
tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
tlsConn, err := tls.ClientHandshake(ctx, tcpConn, t.tlsConfig)
|
||
|
if err != nil {
|
||
|
tcpConn.Close()
|
||
|
return nil, err
|
||
|
}
|
||
|
return t.exchange(message, &tlsDNSConn{Conn: tlsConn})
|
||
|
}
|
||
|
|
||
|
func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
|
||
|
conn.queryId++
|
||
|
err := WriteMessage(conn, conn.queryId, message)
|
||
|
if err != nil {
|
||
|
conn.Close()
|
||
|
return nil, E.Cause(err, "write request")
|
||
|
}
|
||
|
response, err := ReadMessage(conn)
|
||
|
if err != nil {
|
||
|
conn.Close()
|
||
|
return nil, E.Cause(err, "read response")
|
||
|
}
|
||
|
t.access.Lock()
|
||
|
t.connections.PushBack(conn)
|
||
|
t.access.Unlock()
|
||
|
return response, nil
|
||
|
}
|