mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-12-01 19:06:43 +00:00
254 lines
7 KiB
Go
254 lines
7 KiB
Go
|
package tailscale
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"net"
|
||
|
"net/netip"
|
||
|
"net/url"
|
||
|
"os"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/sagernet/sing-box/adapter"
|
||
|
"github.com/sagernet/sing-box/common/dialer"
|
||
|
"github.com/sagernet/sing-box/option"
|
||
|
"github.com/sagernet/sing-dns"
|
||
|
"github.com/sagernet/sing/common"
|
||
|
E "github.com/sagernet/sing/common/exceptions"
|
||
|
M "github.com/sagernet/sing/common/metadata"
|
||
|
N "github.com/sagernet/sing/common/network"
|
||
|
"github.com/sagernet/sing/service"
|
||
|
"github.com/sagernet/tailscale/ipn"
|
||
|
nDNS "github.com/sagernet/tailscale/net/dns"
|
||
|
"github.com/sagernet/tailscale/wgengine/router"
|
||
|
"github.com/sagernet/tailscale/wgengine/wgcfg"
|
||
|
|
||
|
mDNS "github.com/miekg/dns"
|
||
|
"go4.org/netipx"
|
||
|
)
|
||
|
|
||
|
func init() {
|
||
|
dns.RegisterTransport([]string{"tailscale"}, func(options dns.TransportOptions) (dns.Transport, error) {
|
||
|
return NewDNSTransport(options)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
type DNSTransport struct {
|
||
|
endpointTag string
|
||
|
options dns.TransportOptions
|
||
|
network adapter.NetworkManager
|
||
|
endpointManager adapter.EndpointManager
|
||
|
endpoint *Endpoint
|
||
|
rawConfig *wgcfg.Config
|
||
|
rawDNSConfig **nDNS.Config
|
||
|
routePrefixes []netip.Prefix
|
||
|
dnsClient *dns.Client
|
||
|
routes map[string][]dns.Transport
|
||
|
hosts map[string][]netip.Addr
|
||
|
}
|
||
|
|
||
|
func NewDNSTransport(options dns.TransportOptions) (dns.Transport, error) {
|
||
|
linkURL, err := url.Parse(options.Address)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if linkURL.Host == "" {
|
||
|
return nil, E.New("missing tailscale outbound tag")
|
||
|
}
|
||
|
return &DNSTransport{
|
||
|
endpointTag: linkURL.Host,
|
||
|
options: options,
|
||
|
network: service.FromContext[adapter.NetworkManager](options.Context),
|
||
|
endpointManager: service.FromContext[adapter.EndpointManager](options.Context),
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (t *DNSTransport) Name() string {
|
||
|
return t.options.Name
|
||
|
}
|
||
|
|
||
|
func (t *DNSTransport) Start() error {
|
||
|
rawOutbound, loaded := t.endpointManager.Get(t.endpointTag)
|
||
|
if !loaded {
|
||
|
return E.New("endpoint not found: ", t.endpointTag)
|
||
|
}
|
||
|
tsOutbound, isTailscale := rawOutbound.(*Endpoint)
|
||
|
if !isTailscale {
|
||
|
return E.New("endpoint is not tailscale: ", t.endpointTag)
|
||
|
}
|
||
|
t.endpoint = tsOutbound
|
||
|
go tsOutbound.server.ExportLocalBackend().WatchNotifications(t.options.Context, ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) {
|
||
|
if roNotify.State != nil {
|
||
|
if *roNotify.State == ipn.Running {
|
||
|
err := t.updateDNSServers()
|
||
|
if err == nil {
|
||
|
t.options.Logger.Info("initialized")
|
||
|
}
|
||
|
return err != nil
|
||
|
}
|
||
|
}
|
||
|
if roNotify.LoginFinished != nil {
|
||
|
err := t.updateDNSServers()
|
||
|
if err == nil {
|
||
|
t.options.Logger.Info("initialized")
|
||
|
}
|
||
|
return err != nil
|
||
|
}
|
||
|
return true
|
||
|
})
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (t *DNSTransport) Reset() {
|
||
|
}
|
||
|
|
||
|
func (t *DNSTransport) updateDNSServers() error {
|
||
|
config, dnsConfig, routeConfig := t.endpoint.server.ExportLocalBackend().ExportConfig()
|
||
|
if config == nil || dnsConfig == nil {
|
||
|
return os.ErrInvalid
|
||
|
}
|
||
|
t.routePrefixes = buildRoutePrefixes(routeConfig)
|
||
|
directDialerOnce := sync.OnceValue(func() N.Dialer {
|
||
|
directDialer := common.Must1(dialer.NewDefault(t.network, option.DialerOptions{}))
|
||
|
return &DNSDialer{transport: t, fallbackDialer: directDialer}
|
||
|
})
|
||
|
routes := make(map[string][]dns.Transport)
|
||
|
for domain, resolvers := range dnsConfig.Routes {
|
||
|
var myResolvers []dns.Transport
|
||
|
for _, resolver := range resolvers {
|
||
|
myDialer := directDialerOnce()
|
||
|
if len(resolver.BootstrapResolution) > 0 {
|
||
|
bootstrapTransport := common.Must1(dns.CreateTransport(dns.TransportOptions{
|
||
|
Context: t.options.Context,
|
||
|
Logger: t.options.Logger,
|
||
|
Dialer: directDialerOnce(),
|
||
|
Address: resolver.BootstrapResolution[0].String(),
|
||
|
}))
|
||
|
myDialer = dns.NewDialerWrapper(myDialer, t.dnsClient, bootstrapTransport, dns.DomainStrategyPreferIPv4, 0)
|
||
|
}
|
||
|
transport, err := dns.CreateTransport(dns.TransportOptions{
|
||
|
Context: t.options.Context,
|
||
|
Logger: t.options.Logger,
|
||
|
Dialer: myDialer,
|
||
|
Address: resolver.Addr,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return E.Cause(err, "parse resolver: ", resolver.Addr)
|
||
|
}
|
||
|
myResolvers = append(myResolvers, transport)
|
||
|
}
|
||
|
routes[domain.WithTrailingDot()] = myResolvers
|
||
|
}
|
||
|
hosts := make(map[string][]netip.Addr)
|
||
|
for domain, addresses := range dnsConfig.Hosts {
|
||
|
hosts[domain.WithTrailingDot()] = addresses
|
||
|
}
|
||
|
t.routes = routes
|
||
|
t.hosts = hosts
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func buildRoutePrefixes(routeConfig *router.Config) []netip.Prefix {
|
||
|
var builder netipx.IPSetBuilder
|
||
|
for _, localAddr := range routeConfig.LocalAddrs {
|
||
|
builder.AddPrefix(localAddr)
|
||
|
}
|
||
|
for _, route := range routeConfig.Routes {
|
||
|
builder.AddPrefix(route)
|
||
|
}
|
||
|
for _, route := range routeConfig.LocalRoutes {
|
||
|
builder.AddPrefix(route)
|
||
|
}
|
||
|
for _, route := range routeConfig.SubnetRoutes {
|
||
|
builder.AddPrefix(route)
|
||
|
}
|
||
|
ipSet, err := builder.IPSet()
|
||
|
if err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
return ipSet.Prefixes()
|
||
|
}
|
||
|
|
||
|
func (t *DNSTransport) Close() error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (t *DNSTransport) Raw() bool {
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||
|
if len(message.Question) != 1 {
|
||
|
return nil, os.ErrInvalid
|
||
|
}
|
||
|
question := message.Question[0]
|
||
|
addresses, hostsLoaded := t.hosts[question.Name]
|
||
|
if hostsLoaded {
|
||
|
switch question.Qtype {
|
||
|
case mDNS.TypeA:
|
||
|
addresses4 := common.Filter(addresses, func(addr netip.Addr) bool {
|
||
|
return addr.Is4()
|
||
|
})
|
||
|
if len(addresses4) > 0 {
|
||
|
return dns.FixedResponse(message.Id, question, addresses4, dns.DefaultTTL), nil
|
||
|
}
|
||
|
case mDNS.TypeAAAA:
|
||
|
addresses6 := common.Filter(addresses, func(addr netip.Addr) bool {
|
||
|
return addr.Is6()
|
||
|
})
|
||
|
if len(addresses6) > 0 {
|
||
|
return dns.FixedResponse(message.Id, question, addresses6, dns.DefaultTTL), nil
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
for domainSuffix, transports := range t.routes {
|
||
|
if strings.HasSuffix(question.Name, domainSuffix) {
|
||
|
if len(transports) == 0 {
|
||
|
return &mDNS.Msg{
|
||
|
MsgHdr: mDNS.MsgHdr{
|
||
|
Id: message.Id,
|
||
|
Rcode: mDNS.RcodeNameError,
|
||
|
Response: true,
|
||
|
},
|
||
|
Question: []mDNS.Question{question},
|
||
|
}, nil
|
||
|
}
|
||
|
return transports[0].Exchange(ctx, message)
|
||
|
}
|
||
|
}
|
||
|
return nil, dns.RCodeNameError
|
||
|
}
|
||
|
|
||
|
func (t *DNSTransport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
|
||
|
return nil, os.ErrInvalid
|
||
|
}
|
||
|
|
||
|
type DNSDialer struct {
|
||
|
transport *DNSTransport
|
||
|
fallbackDialer N.Dialer
|
||
|
}
|
||
|
|
||
|
func (d *DNSDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||
|
if destination.IsFqdn() {
|
||
|
panic("invalid request here")
|
||
|
}
|
||
|
for _, prefix := range d.transport.routePrefixes {
|
||
|
if prefix.Contains(destination.Addr) {
|
||
|
return d.transport.endpoint.DialContext(ctx, network, destination)
|
||
|
}
|
||
|
}
|
||
|
return d.fallbackDialer.DialContext(ctx, network, destination)
|
||
|
}
|
||
|
|
||
|
func (d *DNSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||
|
if destination.IsFqdn() {
|
||
|
panic("invalid request here")
|
||
|
}
|
||
|
for _, prefix := range d.transport.routePrefixes {
|
||
|
if prefix.Contains(destination.Addr) {
|
||
|
return d.transport.endpoint.ListenPacket(ctx, destination)
|
||
|
}
|
||
|
}
|
||
|
return d.fallbackDialer.ListenPacket(ctx, destination)
|
||
|
}
|