// Source: sing-box/protocol/tailscale/dns_transport.go
// Snapshot: 2024-11-22 22:59:26 +08:00
//
// 254 lines, 7 KiB, Go

package tailscale
import (
"context"
"net"
"net/netip"
"net/url"
"os"
"strings"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"github.com/sagernet/tailscale/ipn"
nDNS "github.com/sagernet/tailscale/net/dns"
"github.com/sagernet/tailscale/wgengine/router"
"github.com/sagernet/tailscale/wgengine/wgcfg"
mDNS "github.com/miekg/dns"
"go4.org/netipx"
)
// init registers NewDNSTransport with sing-dns under the name "tailscale",
// so DNS server addresses using that name are built by this package.
func init() {
	dns.RegisterTransport([]string{"tailscale"}, NewDNSTransport)
}
// DNSTransport is a sing-dns transport that answers queries from a Tailscale
// endpoint's exported DNS configuration (static hosts and split-DNS routes).
// It is created by NewDNSTransport and bound to its endpoint in Start.
type DNSTransport struct {
	// Host component of the configured address; names the endpoint Start
	// looks up in endpointManager.
	endpointTag string
	// Options the transport was constructed with (context, logger, name, ...).
	options dns.TransportOptions
	// Network manager used to build the direct dialer for resolvers.
	network adapter.NetworkManager
	// Registry in which the Tailscale endpoint is looked up.
	endpointManager adapter.EndpointManager

	// Tailscale endpoint resolved in Start.
	endpoint *Endpoint

	// Not referenced in this file.
	rawConfig    *wgcfg.Config
	rawDNSConfig **nDNS.Config

	// Address ranges routed through the endpoint; consulted by DNSDialer.
	routePrefixes []netip.Prefix
	// Handed to bootstrap dialer wrappers; never assigned in this file.
	dnsClient *dns.Client
	// Domain suffix (with trailing dot) -> resolvers for that suffix.
	routes map[string][]dns.Transport
	// FQDN (with trailing dot) -> static addresses.
	hosts map[string][]netip.Addr
}
// NewDNSTransport parses options.Address as a URL and keeps its host component
// as the tag of the Tailscale endpoint to use. The endpoint itself is not
// looked up here; Start does that. The network and endpoint managers are taken
// from options.Context.
//
// It returns the url.Parse error unchanged, or an error when the address has
// no host component.
func NewDNSTransport(options dns.TransportOptions) (dns.Transport, error) {
	serverURL, err := url.Parse(options.Address)
	switch {
	case err != nil:
		return nil, err
	case serverURL.Host == "":
		return nil, E.New("missing tailscale outbound tag")
	}
	ctx := options.Context
	transport := &DNSTransport{
		endpointTag: serverURL.Host,
		options:     options,
	}
	transport.network = service.FromContext[adapter.NetworkManager](ctx)
	transport.endpointManager = service.FromContext[adapter.EndpointManager](ctx)
	return transport, nil
}
// Name returns the name supplied in the transport options.
func (t *DNSTransport) Name() string {
	opts := &t.options
	return opts.Name
}
// Start binds the transport to the Tailscale endpoint named by endpointTag and
// launches a goroutine that watches the endpoint's local backend.
//
// Whenever the backend reports the Running state or a finished login, the
// watcher calls updateDNSServers.
//   - On success it logs "initialized" and stops watching.
//   - While the backend has not exported its configuration yet (updateDNSServers
//     returns the bare os.ErrInvalid sentinel), it logs at debug level and waits
//     for the next notification.
//   - Any other failure is logged as an error and the watcher also keeps going.
//
// Start returns as soon as the goroutine is launched, so the transport may not
// be initialized yet when it returns. It fails only when the endpoint is
// missing or is not a Tailscale endpoint.
//
// NOTE(review): because watching stops after the first successful update,
// later tailnet DNS changes are not picked up — confirm this is intended.
func (t *DNSTransport) Start() error {
	rawEndpoint, loaded := t.endpointManager.Get(t.endpointTag)
	if !loaded {
		return E.New("endpoint not found: ", t.endpointTag)
	}
	ep, isTailscale := rawEndpoint.(*Endpoint)
	if !isTailscale {
		return E.New("endpoint is not tailscale: ", t.endpointTag)
	}
	t.endpoint = ep
	backend := ep.server.ExportLocalBackend()
	go backend.WatchNotifications(t.options.Context, ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) {
		running := roNotify.State != nil && *roNotify.State == ipn.Running
		if !running && roNotify.LoginFinished == nil {
			return true
		}
		err := t.updateDNSServers()
		switch {
		case err == nil:
			t.options.Logger.Info("initialized")
			return false
		case err == os.ErrInvalid:
			// updateDNSServers returns this bare sentinel while the backend has
			// no complete configuration yet; wait for the next notification.
			t.options.Logger.Debug("tailscale DNS configuration not ready yet")
		default:
			t.options.Logger.Error(E.Cause(err, "update tailscale DNS servers"))
		}
		return true
	})
	return nil
}
// Reset is a no-op for this transport.
func (t *DNSTransport) Reset() {
	// Intentionally empty.
}
// updateDNSServers rebuilds the transport's DNS state from the Tailscale
// backend's exported configuration. It sets three fields:
//   - routePrefixes: address ranges routed via Tailscale (see
//     buildRoutePrefixes), consulted by DNSDialer.
//   - routes: one resolver transport per configured resolver, keyed by the
//     route's domain suffix with a trailing dot.
//   - hosts: static addresses keyed by FQDN with a trailing dot.
//
// All resolvers share a single DNSDialer, created lazily on first use.
// Resolvers that carry bootstrap addresses get their dialer wrapped with
// dns.NewDialerWrapper around a transport pointed at the first bootstrap
// address. That transport is presumably used to resolve the resolver's own
// hostname; confirm against sing-dns.
//
// Return values:
//   - the bare os.ErrInvalid (unwrapped) when the backend has not exported a
//     complete configuration yet;
//   - a wrapped error when a dialer or resolver cannot be created;
//   - nil on success.
//
// The new state is installed only after everything has been built, so a
// failed update leaves the previous state untouched.
//
// NOTE(review): t.dnsClient is never assigned in this file, so the bootstrap
// dialer wrapper receives a nil *dns.Client. Confirm sing-dns tolerates that.
// NOTE(review): the fields are replaced without synchronization while Exchange
// and DNSDialer may read them from other goroutines.
// NOTE(review): transports created before a failure are not closed.
func (t *DNSTransport) updateDNSServers() error {
	config, dnsConfig, routeConfig := t.endpoint.server.ExportLocalBackend().ExportConfig()
	if config == nil || dnsConfig == nil || routeConfig == nil {
		return os.ErrInvalid
	}
	routePrefixes := buildRoutePrefixes(routeConfig)
	// Built at most once, and only if some route actually lists a resolver.
	newDialer := sync.OnceValues(func() (N.Dialer, error) {
		directDialer, err := dialer.NewDefault(t.network, option.DialerOptions{})
		if err != nil {
			return nil, err
		}
		return &DNSDialer{transport: t, fallbackDialer: directDialer}, nil
	})
	routes := make(map[string][]dns.Transport, len(dnsConfig.Routes))
	for domain, resolvers := range dnsConfig.Routes {
		var myResolvers []dns.Transport
		for _, resolver := range resolvers {
			myDialer, err := newDialer()
			if err != nil {
				return E.Cause(err, "create DNS dialer")
			}
			if len(resolver.BootstrapResolution) > 0 {
				bootstrapAddress := resolver.BootstrapResolution[0].String()
				bootstrapTransport, err := dns.CreateTransport(dns.TransportOptions{
					Context: t.options.Context,
					Logger:  t.options.Logger,
					Dialer:  myDialer,
					Address: bootstrapAddress,
				})
				if err != nil {
					return E.Cause(err, "parse bootstrap resolver: ", bootstrapAddress)
				}
				myDialer = dns.NewDialerWrapper(myDialer, t.dnsClient, bootstrapTransport, dns.DomainStrategyPreferIPv4, 0)
			}
			transport, err := dns.CreateTransport(dns.TransportOptions{
				Context: t.options.Context,
				Logger:  t.options.Logger,
				Dialer:  myDialer,
				Address: resolver.Addr,
			})
			if err != nil {
				return E.Cause(err, "parse resolver: ", resolver.Addr)
			}
			myResolvers = append(myResolvers, transport)
		}
		routes[domain.WithTrailingDot()] = myResolvers
	}
	hosts := make(map[string][]netip.Addr, len(dnsConfig.Hosts))
	for domain, addresses := range dnsConfig.Hosts {
		hosts[domain.WithTrailingDot()] = addresses
	}
	t.routePrefixes = routePrefixes
	t.routes = routes
	t.hosts = hosts
	return nil
}
// buildRoutePrefixes merges every address range in the Tailscale router
// configuration into one IP set and returns that set as a minimal list of
// prefixes. The ranges come from LocalAddrs, Routes, LocalRoutes and
// SubnetRoutes. DNSDialer uses the result to decide which resolver addresses
// are dialed through the Tailscale endpoint.
//
// It returns nil when routeConfig is nil, or when the builder cannot produce an
// IP set. A nil result means no destination is sent via Tailscale.
func buildRoutePrefixes(routeConfig *router.Config) []netip.Prefix {
	if routeConfig == nil {
		return nil
	}
	var builder netipx.IPSetBuilder
	for _, group := range [][]netip.Prefix{
		routeConfig.LocalAddrs,
		routeConfig.Routes,
		routeConfig.LocalRoutes,
		routeConfig.SubnetRoutes,
	} {
		for _, prefix := range group {
			builder.AddPrefix(prefix)
		}
	}
	ipSet, err := builder.IPSet()
	if err != nil {
		return nil
	}
	return ipSet.Prefixes()
}
// Close always succeeds and releases nothing.
//
// NOTE(review): resolver transports built by updateDNSServers are not closed
// here. Confirm they need no cleanup.
func (t *DNSTransport) Close() error {
	var err error
	return err
}
// Raw reports true. Presumably this tells sing-dns to send full messages
// through Exchange, since Lookup is unsupported on this transport.
func (t *DNSTransport) Raw() bool {
	const raw = true
	return raw
}
// Exchange answers a DNS message that carries exactly one question. Any other
// question count yields os.ErrInvalid.
//
// Resolution order:
//  1. Tailscale hosts table. The name is looked up as asked, then lower-cased.
//     For A/AAAA queries on a listed name, the addresses of the requested
//     family are returned with dns.DefaultTTL. If the name is listed but has
//     no address of that family, an empty NOERROR (NODATA) reply is returned
//     rather than forwarding upstream, where the name could come back as
//     NXDOMAIN.
//  2. Split-DNS routes. The longest route suffix that covers the name on a
//     label boundary wins, compared case-insensitively. For example,
//     "corp.example.com." beats "example.com.", and "badexample.com." matches
//     neither. A route with no resolvers answers NXDOMAIN; otherwise the query
//     goes to the route's first resolver.
//  3. No matching route: dns.RCodeNameError is returned as the error.
//
// NOTE(review): t.hosts and t.routes are read without synchronization while
// updateDNSServers may replace them from the backend watcher goroutine.
func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
	if len(message.Question) != 1 {
		return nil, os.ErrInvalid
	}
	question := message.Question[0]
	lowerName := strings.ToLower(question.Name)

	addresses, hostsLoaded := t.hosts[question.Name]
	if !hostsLoaded {
		addresses, hostsLoaded = t.hosts[lowerName]
	}
	if hostsLoaded && (question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA) {
		wantIPv4 := question.Qtype == mDNS.TypeA
		matched := common.Filter(addresses, func(addr netip.Addr) bool {
			if wantIPv4 {
				return addr.Is4()
			}
			return addr.Is6()
		})
		if len(matched) > 0 {
			return dns.FixedResponse(message.Id, question, matched, dns.DefaultTTL), nil
		}
		// The host exists but has no address of this family: NODATA.
		return &mDNS.Msg{
			MsgHdr: mDNS.MsgHdr{
				Id:       message.Id,
				Rcode:    mDNS.RcodeSuccess,
				Response: true,
			},
			Question: []mDNS.Question{question},
		}, nil
	}

	// Map iteration order is random, so pick the most specific route
	// explicitly instead of taking the first suffix that happens to match.
	bestLength := -1
	var bestTransports []dns.Transport
	for suffix, transports := range t.routes {
		lowerSuffix := strings.ToLower(suffix)
		if len(lowerSuffix) > bestLength && matchesDNSSuffix(lowerName, lowerSuffix) {
			bestLength = len(lowerSuffix)
			bestTransports = transports
		}
	}
	if bestLength < 0 {
		return nil, dns.RCodeNameError
	}
	if len(bestTransports) == 0 {
		return &mDNS.Msg{
			MsgHdr: mDNS.MsgHdr{
				Id:       message.Id,
				Rcode:    mDNS.RcodeNameError,
				Response: true,
			},
			Question: []mDNS.Question{question},
		}, nil
	}
	return bestTransports[0].Exchange(ctx, message)
}

// matchesDNSSuffix reports whether the fully-qualified name lies within the
// zone suffix: it either equals suffix or ends in "."+suffix. The root suffix
// "." covers every name that ends in a dot. Both arguments are expected to be
// lower-cased names with trailing dots.
func matchesDNSSuffix(name, suffix string) bool {
	if suffix == "." {
		return strings.HasSuffix(name, ".")
	}
	return name == suffix || strings.HasSuffix(name, "."+suffix)
}
// Lookup is not supported by this transport and always returns os.ErrInvalid
// with no addresses. Queries must go through Exchange.
func (t *DNSTransport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
	var none []netip.Addr
	return none, os.ErrInvalid
}
// DNSDialer is the dialer handed to resolvers built by updateDNSServers. It
// sends destinations inside the transport's Tailscale route prefixes through
// the Tailscale endpoint, and everything else through fallbackDialer.
//
// Destinations must be IP addresses. A domain destination is rejected with an
// error, because the routing decision needs an address.
type DNSDialer struct {
	transport      *DNSTransport
	fallbackDialer N.Dialer
}

// viaTailscale reports whether addr falls inside any prefix the transport
// currently routes through the Tailscale endpoint.
func (d *DNSDialer) viaTailscale(addr netip.Addr) bool {
	for _, prefix := range d.transport.routePrefixes {
		if prefix.Contains(addr) {
			return true
		}
	}
	return false
}

// DialContext opens a connection to destination, through the Tailscale
// endpoint when the address is Tailscale-routed and through fallbackDialer
// otherwise.
func (d *DNSDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	if destination.IsFqdn() {
		// Previously a panic. A resolver configured by hostname with no
		// bootstrap address would reach this path and crash the process.
		return nil, E.New("tailscale DNS dialer: unexpected domain destination: ", destination)
	}
	if d.viaTailscale(destination.Addr) {
		return d.transport.endpoint.DialContext(ctx, network, destination)
	}
	return d.fallbackDialer.DialContext(ctx, network, destination)
}

// ListenPacket opens a packet connection toward destination, choosing the
// Tailscale endpoint or fallbackDialer the same way DialContext does.
func (d *DNSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	if destination.IsFqdn() {
		return nil, E.New("tailscale DNS dialer: unexpected domain destination: ", destination)
	}
	if d.viaTailscale(destination.Addr) {
		return d.transport.endpoint.ListenPacket(ctx, destination)
	}
	return d.fallbackDialer.ListenPacket(ctx, destination)
}