Add hijack_dns for tun

This commit is contained in:
世界 2022-07-10 09:15:01 +08:00
parent 638f8a52d1
commit 29f78248dc
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
15 changed files with 220 additions and 99 deletions

View file

@ -51,27 +51,6 @@ func (w *myUpstreamHandlerWrapper) NewError(ctx context.Context, err error) {
w.errorHandler.NewError(ctx, err)
}
var myContextType = (*MetadataContext)(nil)
type MetadataContext struct {
context.Context
Metadata InboundContext
}
func (c *MetadataContext) Value(key any) any {
if key == myContextType {
return c
}
return c.Context.Value(key)
}
func ContextWithMetadata(ctx context.Context, metadata InboundContext) context.Context {
return &MetadataContext{
Context: ctx,
Metadata: metadata,
}
}
func UpstreamMetadata(metadata InboundContext) M.Metadata {
return M.Metadata{
Source: metadata.Source,
@ -98,15 +77,15 @@ func NewUpstreamContextHandler(
}
func (w *myUpstreamContextHandlerWrapper) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
myCtx := ctx.Value(myContextType).(*MetadataContext)
myCtx.Metadata.Destination = metadata.Destination
return w.connectionHandler(ctx, conn, myCtx.Metadata)
myMetadata := ContextFrom(ctx)
myMetadata.Destination = metadata.Destination
return w.connectionHandler(ctx, conn, *myMetadata)
}
func (w *myUpstreamContextHandlerWrapper) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
myCtx := ctx.Value(myContextType).(*MetadataContext)
myCtx.Metadata.Destination = metadata.Destination
return w.packetHandler(ctx, conn, myCtx.Metadata)
myMetadata := ContextFrom(ctx)
myMetadata.Destination = metadata.Destination
return w.packetHandler(ctx, conn, *myMetadata)
}
func (w *myUpstreamContextHandlerWrapper) NewError(ctx context.Context, err error) {

View file

@ -21,7 +21,7 @@ func New(router adapter.Router, options option.DialerOptions) N.Dialer {
func NewOutbound(router adapter.Router, options option.OutboundDialerOptions) N.Dialer {
dialer := New(router, options.DialerOptions)
domainStrategy := C.DomainStrategy(options.DomainStrategy)
if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" && !C.CGO_ENABLED {
if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" {
fallbackDelay := time.Duration(options.FallbackDelay)
if fallbackDelay == 0 {
fallbackDelay = time.Millisecond * 300

View file

@ -32,6 +32,9 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina
if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination)
}
ctx, metadata := adapter.AppendContext(ctx)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == C.DomainStrategyAsIS {
@ -49,6 +52,9 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination)
}
ctx, metadata := adapter.AppendContext(ctx)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == C.DomainStrategyAsIS {

View file

@ -20,6 +20,9 @@ func NewTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, addre
return nil, err
}
host := serverURL.Hostname()
if host == "" {
host = address
}
port := serverURL.Port()
switch serverURL.Scheme {
case "tls":

View file

@ -87,7 +87,7 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) {
}
})
conn.err = err
if err != nil {
if err != nil && !E.IsClosed(err) {
t.logger.Debug("connection closed: ", err)
}
}

View file

@ -95,7 +95,7 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) {
}
})
conn.err = err
if err != nil {
if err != nil && !E.IsClosed(err) {
t.logger.Debug("connection closed: ", err)
}
}

View file

@ -8,6 +8,7 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
@ -83,7 +84,7 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) {
}
})
conn.err = err
if err != nil {
if err != nil && !E.IsClosed(err) {
t.logger.Debug("connection closed: ", err)
}
}

View file

@ -79,6 +79,6 @@ func (d *Direct) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.B
case 3:
metadata.Destination.Port = d.overrideDestination.Port
}
d.udpNat.NewPacketDirect(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), metadata.Source.AddrPort(), conn, buffer, adapter.UpstreamMetadata(metadata))
d.udpNat.NewPacketDirect(adapter.WithContext(log.ContextWithID(ctx), &metadata), metadata.Source.AddrPort(), conn, buffer, adapter.UpstreamMetadata(metadata))
return nil
}

107
inbound/dns.go Normal file
View file

@ -0,0 +1,107 @@
package inbound
import (
"context"
"encoding/binary"
"io"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/dns/dnsmessage"
)
func NewDNSConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn net.Conn, metadata adapter.InboundContext) error {
_buffer := buf.StackNewSize(1024)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
for {
var queryLength uint16
err := binary.Read(conn, binary.BigEndian, &queryLength)
if err != nil {
return err
}
if queryLength > 1024 {
return io.ErrShortBuffer
}
buffer.FullReset()
_, err = buffer.ReadFullFrom(conn, int(queryLength))
if err != nil {
return err
}
var message dnsmessage.Message
err = message.Unpack(buffer.Bytes())
if err != nil {
return err
}
if len(message.Questions) > 0 {
question := message.Questions[0]
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
}
response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
if err != nil {
return err
}
buffer.FullReset()
responseBuffer, err := response.AppendPack(buffer.Index(0))
if err != nil {
return err
}
err = binary.Write(conn, binary.BigEndian, uint16(len(responseBuffer)))
if err != nil {
return err
}
_, err = conn.Write(responseBuffer)
if err != nil {
return err
}
}
}
func NewDNSPacketConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn N.PacketConn, metadata adapter.InboundContext) error {
for {
buffer := buf.StackNewSize(1024)
destination, err := conn.ReadPacket(buffer)
if err != nil {
buffer.Release()
return err
}
var message dnsmessage.Message
err = message.Unpack(buffer.Bytes())
if err != nil {
return err
}
if len(message.Questions) > 0 {
question := message.Questions[0]
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
}
go func() error {
defer buffer.Release()
response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
if err != nil {
return err
}
buffer.FullReset()
responseBuffer, err := response.AppendPack(buffer.Index(0))
if err != nil {
return err
}
buffer.Truncate(len(responseBuffer))
err = conn.WritePacket(buffer, destination)
return err
}()
}
}
func formatDNSQuestion(question dnsmessage.Question) string {
domain := question.Name.String()
domain = domain[:len(domain)-1]
return string(question.Name.Data[:question.Name.Length-1]) + " " + question.Type.String()[4:] + " " + question.Class.String()[5:]
}

View file

@ -73,9 +73,9 @@ func newShadowsocks(ctx context.Context, router adapter.Router, logger log.Logge
}
func (h *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
}
func (h *Shadowsocks) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
}

View file

@ -68,11 +68,11 @@ func newShadowsocksMulti(ctx context.Context, router adapter.Router, logger log.
}
func (h *ShadowsocksMulti) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksMulti) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksMulti) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {

View file

@ -68,11 +68,11 @@ func newShadowsocksRelay(ctx context.Context, router adapter.Router, logger log.
}
func (h *ShadowsocksRelay) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksRelay) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksRelay) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {

View file

@ -20,6 +20,7 @@ import (
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
)
var _ adapter.Inbound = (*Tun)(nil)
@ -27,23 +28,42 @@ var _ adapter.Inbound = (*Tun)(nil)
type Tun struct {
tag string
ctx context.Context
router adapter.Router
logger log.Logger
options option.TunInboundOptions
ctx context.Context
router adapter.Router
logger log.Logger
inboundOptions option.InboundOptions
tunName string
tunMTU uint32
inet4Address netip.Prefix
inet6Address netip.Prefix
autoRoute bool
hijackDNS bool
tunName string
tunFd uintptr
tun *tun.GVisorTun
tunFd uintptr
tun *tun.GVisorTun
}
func NewTun(ctx context.Context, router adapter.Router, logger log.Logger, tag string, options option.TunInboundOptions) (*Tun, error) {
tunName := options.InterfaceName
if tunName == "" {
tunName = mkInterfaceName()
}
tunMTU := options.MTU
if tunMTU == 0 {
tunMTU = 1500
}
return &Tun{
tag: tag,
ctx: ctx,
router: router,
logger: logger,
options: options,
tag: tag,
ctx: ctx,
router: router,
logger: logger,
inboundOptions: options.InboundOptions,
tunName: tunName,
tunMTU: tunMTU,
inet4Address: netip.Prefix(options.Inet4Address),
inet6Address: netip.Prefix(options.Inet6Address),
autoRoute: options.AutoRoute,
hijackDNS: options.HijackDNS,
}, nil
}
@ -56,38 +76,26 @@ func (t *Tun) Tag() string {
}
func (t *Tun) Start() error {
tunName := t.options.InterfaceName
if tunName == "" {
tunName = mkInterfaceName()
}
var mtu uint32
if t.options.MTU != 0 {
mtu = t.options.MTU
} else {
mtu = 1500
}
tunFd, err := tun.Open(tunName)
tunFd, err := tun.Open(t.tunName)
if err != nil {
return E.Cause(err, "create tun interface")
}
err = tun.Configure(tunName, netip.Prefix(t.options.Inet4Address), netip.Prefix(t.options.Inet6Address), mtu, t.options.AutoRoute)
err = tun.Configure(t.tunName, t.inet4Address, t.inet6Address, t.tunMTU, t.autoRoute)
if err != nil {
return E.Cause(err, "configure tun interface")
}
t.tunName = tunName
t.tunFd = tunFd
t.tun = tun.NewGVisor(t.ctx, tunFd, mtu, t)
t.tun = tun.NewGVisor(t.ctx, tunFd, t.tunMTU, t)
err = t.tun.Start()
if err != nil {
return err
}
t.logger.Info("started at ", tunName)
t.logger.Info("started at ", t.tunName)
return nil
}
func (t *Tun) Close() error {
err := tun.UnConfigure(t.tunName, netip.Prefix(t.options.Inet4Address), netip.Prefix(t.options.Inet6Address), t.options.AutoRoute)
err := tun.UnConfigure(t.tunName, t.inet4Address, t.inet6Address, t.autoRoute)
if err != nil {
return err
}
@ -98,30 +106,40 @@ func (t *Tun) Close() error {
}
func (t *Tun) NewConnection(ctx context.Context, conn net.Conn, upstreamMetadata M.Metadata) error {
t.logger.WithContext(ctx).Info("inbound connection from ", upstreamMetadata.Source)
t.logger.WithContext(ctx).Info("inbound connection to ", upstreamMetadata.Destination)
var metadata adapter.InboundContext
metadata.Inbound = t.tag
metadata.Network = C.NetworkTCP
metadata.Source = upstreamMetadata.Source
metadata.Destination = upstreamMetadata.Destination
metadata.SniffEnabled = t.options.SniffEnabled
metadata.SniffOverrideDestination = t.options.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.options.DomainStrategy)
metadata.SniffEnabled = t.inboundOptions.SniffEnabled
metadata.SniffOverrideDestination = t.inboundOptions.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.inboundOptions.DomainStrategy)
if t.hijackDNS && upstreamMetadata.Destination.Port == 53 {
return task.Run(ctx, func() error {
return NewDNSConnection(ctx, t.router, t.logger, conn, metadata)
})
}
t.logger.WithContext(ctx).Info("inbound connection from ", metadata.Source)
t.logger.WithContext(ctx).Info("inbound connection to ", metadata.Destination)
return t.router.RouteConnection(ctx, conn, metadata)
}
func (t *Tun) NewPacketConnection(ctx context.Context, conn N.PacketConn, upstreamMetadata M.Metadata) error {
t.logger.WithContext(ctx).Info("inbound packet connection from ", upstreamMetadata.Source)
t.logger.WithContext(ctx).Info("inbound packet connection to ", upstreamMetadata.Destination)
var metadata adapter.InboundContext
metadata.Inbound = t.tag
metadata.Network = C.NetworkUDP
metadata.Source = upstreamMetadata.Source
metadata.Destination = upstreamMetadata.Destination
metadata.SniffEnabled = t.options.SniffEnabled
metadata.SniffOverrideDestination = t.options.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.options.DomainStrategy)
metadata.SniffEnabled = t.inboundOptions.SniffEnabled
metadata.SniffOverrideDestination = t.inboundOptions.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.inboundOptions.DomainStrategy)
if t.hijackDNS && upstreamMetadata.Destination.Port == 53 {
return task.Run(ctx, func() error {
return NewDNSPacketConnection(ctx, t.router, t.logger, conn, metadata)
})
}
t.logger.WithContext(ctx).Info("inbound packet connection from ", metadata.Source)
t.logger.WithContext(ctx).Info("inbound packet connection to ", metadata.Destination)
return t.router.RoutePacketConnection(ctx, conn, metadata)
}

View file

@ -144,10 +144,11 @@ type ShadowsocksDestination struct {
}
type TunInboundOptions struct {
InterfaceName string `json:"interface_name"`
MTU uint32 `json:"mtu,omitempty"`
Inet4Address ListenPrefix `json:"inet4_address"`
Inet6Address ListenPrefix `json:"inet6_address"`
AutoRoute bool `json:"auto_route"`
InterfaceName string `json:"interface_name,omitempty"`
MTU uint32 `json:"mtu,omitempty,omitempty"`
Inet4Address ListenPrefix `json:"inet4_address,omitempty"`
Inet6Address ListenPrefix `json:"inet6_address,omitempty"`
AutoRoute bool `json:"auto_route,omitempty"`
HijackDNS bool `json:"hijack_dns,omitempty"`
InboundOptions
}

View file

@ -9,6 +9,7 @@ import (
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"time"
@ -128,23 +129,28 @@ func NewRouter(ctx context.Context, logger log.Logger, options option.RouteOptio
} else {
detour = dialer.NewDetour(router, server.Detour)
}
serverURL, err := url.Parse(server.Address)
if err != nil {
return nil, err
}
serverAddress := serverURL.Hostname()
_, notIpAddress := netip.ParseAddr(serverAddress)
if server.AddressResolver != "" {
if !transportTagMap[server.AddressResolver] {
return nil, E.New("parse dns server[", tag, "]: address resolver not found: ", server.AddressResolver)
if server.Address != "local" {
serverURL, err := url.Parse(server.Address)
if err != nil {
return nil, err
}
if upstream, exists := dummyTransportMap[server.AddressResolver]; exists {
detour = dns.NewDialerWrapper(detour, C.DomainStrategy(server.AddressStrategy), router.dnsClient, upstream)
} else {
continue
serverAddress := serverURL.Hostname()
if serverAddress == "" {
serverAddress = server.Address
}
_, notIpAddress := netip.ParseAddr(serverAddress)
if server.AddressResolver != "" {
if !transportTagMap[server.AddressResolver] {
return nil, E.New("parse dns server[", tag, "]: address resolver not found: ", server.AddressResolver)
}
if upstream, exists := dummyTransportMap[server.AddressResolver]; exists {
detour = dns.NewDialerWrapper(detour, C.DomainStrategy(server.AddressStrategy), router.dnsClient, upstream)
} else {
continue
}
} else if notIpAddress != nil {
return nil, E.New("parse dns server[", tag, "]: missing address_resolver")
}
} else if notIpAddress != nil {
return nil, E.New("parse dns server[", tag, "]: missing address_resolver")
}
transport, err := dns.NewTransport(ctx, detour, logger, server.Address)
if err != nil {
@ -419,7 +425,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
}
func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
if metadata.SniffEnabled && metadata.Destination.Port == 443 {
if metadata.SniffEnabled {
_buffer := buf.StackNewPacket()
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
@ -489,7 +495,7 @@ func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, def
func (r *Router) matchDNS(ctx context.Context) adapter.DNSTransport {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
r.dnsLogger.WithContext(ctx).Warn("no context")
r.dnsLogger.WithContext(ctx).Warn("no context: ", reflect.TypeOf(ctx))
return r.defaultTransport
}
for i, rule := range r.dnsRules {