Add hijack_dns for tun

This commit is contained in:
世界 2022-07-10 09:15:01 +08:00
parent 638f8a52d1
commit 29f78248dc
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
15 changed files with 220 additions and 99 deletions

View file

@ -51,27 +51,6 @@ func (w *myUpstreamHandlerWrapper) NewError(ctx context.Context, err error) {
w.errorHandler.NewError(ctx, err)
}
var myContextType = (*MetadataContext)(nil)
type MetadataContext struct {
context.Context
Metadata InboundContext
}
func (c *MetadataContext) Value(key any) any {
if key == myContextType {
return c
}
return c.Context.Value(key)
}
func ContextWithMetadata(ctx context.Context, metadata InboundContext) context.Context {
return &MetadataContext{
Context: ctx,
Metadata: metadata,
}
}
func UpstreamMetadata(metadata InboundContext) M.Metadata {
return M.Metadata{
Source: metadata.Source,
@ -98,15 +77,15 @@ func NewUpstreamContextHandler(
}
func (w *myUpstreamContextHandlerWrapper) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
myCtx := ctx.Value(myContextType).(*MetadataContext)
myCtx.Metadata.Destination = metadata.Destination
return w.connectionHandler(ctx, conn, myCtx.Metadata)
myMetadata := ContextFrom(ctx)
myMetadata.Destination = metadata.Destination
return w.connectionHandler(ctx, conn, *myMetadata)
}
func (w *myUpstreamContextHandlerWrapper) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
myCtx := ctx.Value(myContextType).(*MetadataContext)
myCtx.Metadata.Destination = metadata.Destination
return w.packetHandler(ctx, conn, myCtx.Metadata)
myMetadata := ContextFrom(ctx)
myMetadata.Destination = metadata.Destination
return w.packetHandler(ctx, conn, *myMetadata)
}
func (w *myUpstreamContextHandlerWrapper) NewError(ctx context.Context, err error) {

View file

@ -21,7 +21,7 @@ func New(router adapter.Router, options option.DialerOptions) N.Dialer {
func NewOutbound(router adapter.Router, options option.OutboundDialerOptions) N.Dialer {
dialer := New(router, options.DialerOptions)
domainStrategy := C.DomainStrategy(options.DomainStrategy)
if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" && !C.CGO_ENABLED {
if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" {
fallbackDelay := time.Duration(options.FallbackDelay)
if fallbackDelay == 0 {
fallbackDelay = time.Millisecond * 300

View file

@ -32,6 +32,9 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina
if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination)
}
ctx, metadata := adapter.AppendContext(ctx)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == C.DomainStrategyAsIS {
@ -49,6 +52,9 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination)
}
ctx, metadata := adapter.AppendContext(ctx)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == C.DomainStrategyAsIS {

View file

@ -20,6 +20,9 @@ func NewTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, addre
return nil, err
}
host := serverURL.Hostname()
if host == "" {
host = address
}
port := serverURL.Port()
switch serverURL.Scheme {
case "tls":

View file

@ -87,7 +87,7 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) {
}
})
conn.err = err
if err != nil {
if err != nil && !E.IsClosed(err) {
t.logger.Debug("connection closed: ", err)
}
}

View file

@ -95,7 +95,7 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) {
}
})
conn.err = err
if err != nil {
if err != nil && !E.IsClosed(err) {
t.logger.Debug("connection closed: ", err)
}
}

View file

@ -8,6 +8,7 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
@ -83,7 +84,7 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) {
}
})
conn.err = err
if err != nil {
if err != nil && !E.IsClosed(err) {
t.logger.Debug("connection closed: ", err)
}
}

View file

@ -79,6 +79,6 @@ func (d *Direct) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.B
case 3:
metadata.Destination.Port = d.overrideDestination.Port
}
d.udpNat.NewPacketDirect(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), metadata.Source.AddrPort(), conn, buffer, adapter.UpstreamMetadata(metadata))
d.udpNat.NewPacketDirect(adapter.WithContext(log.ContextWithID(ctx), &metadata), metadata.Source.AddrPort(), conn, buffer, adapter.UpstreamMetadata(metadata))
return nil
}

107
inbound/dns.go Normal file
View file

@ -0,0 +1,107 @@
package inbound
import (
"context"
"encoding/binary"
"io"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/dns/dnsmessage"
)
func NewDNSConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn net.Conn, metadata adapter.InboundContext) error {
_buffer := buf.StackNewSize(1024)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
for {
var queryLength uint16
err := binary.Read(conn, binary.BigEndian, &queryLength)
if err != nil {
return err
}
if queryLength > 1024 {
return io.ErrShortBuffer
}
buffer.FullReset()
_, err = buffer.ReadFullFrom(conn, int(queryLength))
if err != nil {
return err
}
var message dnsmessage.Message
err = message.Unpack(buffer.Bytes())
if err != nil {
return err
}
if len(message.Questions) > 0 {
question := message.Questions[0]
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
}
response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
if err != nil {
return err
}
buffer.FullReset()
responseBuffer, err := response.AppendPack(buffer.Index(0))
if err != nil {
return err
}
err = binary.Write(conn, binary.BigEndian, uint16(len(responseBuffer)))
if err != nil {
return err
}
_, err = conn.Write(responseBuffer)
if err != nil {
return err
}
}
}
func NewDNSPacketConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn N.PacketConn, metadata adapter.InboundContext) error {
for {
buffer := buf.StackNewSize(1024)
destination, err := conn.ReadPacket(buffer)
if err != nil {
buffer.Release()
return err
}
var message dnsmessage.Message
err = message.Unpack(buffer.Bytes())
if err != nil {
return err
}
if len(message.Questions) > 0 {
question := message.Questions[0]
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
}
go func() error {
defer buffer.Release()
response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
if err != nil {
return err
}
buffer.FullReset()
responseBuffer, err := response.AppendPack(buffer.Index(0))
if err != nil {
return err
}
buffer.Truncate(len(responseBuffer))
err = conn.WritePacket(buffer, destination)
return err
}()
}
}
func formatDNSQuestion(question dnsmessage.Question) string {
domain := question.Name.String()
domain = domain[:len(domain)-1]
return string(question.Name.Data[:question.Name.Length-1]) + " " + question.Type.String()[4:] + " " + question.Class.String()[5:]
}

View file

@ -73,9 +73,9 @@ func newShadowsocks(ctx context.Context, router adapter.Router, logger log.Logge
}
func (h *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
}
func (h *Shadowsocks) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
}

View file

@ -68,11 +68,11 @@ func newShadowsocksMulti(ctx context.Context, router adapter.Router, logger log.
}
func (h *ShadowsocksMulti) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksMulti) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksMulti) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {

View file

@ -68,11 +68,11 @@ func newShadowsocksRelay(ctx context.Context, router adapter.Router, logger log.
}
func (h *ShadowsocksRelay) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksRelay) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksRelay) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {

View file

@ -20,6 +20,7 @@ import (
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
)
var _ adapter.Inbound = (*Tun)(nil)
@ -30,20 +31,39 @@ type Tun struct {
ctx context.Context
router adapter.Router
logger log.Logger
options option.TunInboundOptions
inboundOptions option.InboundOptions
tunName string
tunMTU uint32
inet4Address netip.Prefix
inet6Address netip.Prefix
autoRoute bool
hijackDNS bool
tunFd uintptr
tun *tun.GVisorTun
}
func NewTun(ctx context.Context, router adapter.Router, logger log.Logger, tag string, options option.TunInboundOptions) (*Tun, error) {
tunName := options.InterfaceName
if tunName == "" {
tunName = mkInterfaceName()
}
tunMTU := options.MTU
if tunMTU == 0 {
tunMTU = 1500
}
return &Tun{
tag: tag,
ctx: ctx,
router: router,
logger: logger,
options: options,
inboundOptions: options.InboundOptions,
tunName: tunName,
tunMTU: tunMTU,
inet4Address: netip.Prefix(options.Inet4Address),
inet6Address: netip.Prefix(options.Inet6Address),
autoRoute: options.AutoRoute,
hijackDNS: options.HijackDNS,
}, nil
}
@ -56,38 +76,26 @@ func (t *Tun) Tag() string {
}
func (t *Tun) Start() error {
tunName := t.options.InterfaceName
if tunName == "" {
tunName = mkInterfaceName()
}
var mtu uint32
if t.options.MTU != 0 {
mtu = t.options.MTU
} else {
mtu = 1500
}
tunFd, err := tun.Open(tunName)
tunFd, err := tun.Open(t.tunName)
if err != nil {
return E.Cause(err, "create tun interface")
}
err = tun.Configure(tunName, netip.Prefix(t.options.Inet4Address), netip.Prefix(t.options.Inet6Address), mtu, t.options.AutoRoute)
err = tun.Configure(t.tunName, t.inet4Address, t.inet6Address, t.tunMTU, t.autoRoute)
if err != nil {
return E.Cause(err, "configure tun interface")
}
t.tunName = tunName
t.tunFd = tunFd
t.tun = tun.NewGVisor(t.ctx, tunFd, mtu, t)
t.tun = tun.NewGVisor(t.ctx, tunFd, t.tunMTU, t)
err = t.tun.Start()
if err != nil {
return err
}
t.logger.Info("started at ", tunName)
t.logger.Info("started at ", t.tunName)
return nil
}
func (t *Tun) Close() error {
err := tun.UnConfigure(t.tunName, netip.Prefix(t.options.Inet4Address), netip.Prefix(t.options.Inet6Address), t.options.AutoRoute)
err := tun.UnConfigure(t.tunName, t.inet4Address, t.inet6Address, t.autoRoute)
if err != nil {
return err
}
@ -98,30 +106,40 @@ func (t *Tun) Close() error {
}
func (t *Tun) NewConnection(ctx context.Context, conn net.Conn, upstreamMetadata M.Metadata) error {
t.logger.WithContext(ctx).Info("inbound connection from ", upstreamMetadata.Source)
t.logger.WithContext(ctx).Info("inbound connection to ", upstreamMetadata.Destination)
var metadata adapter.InboundContext
metadata.Inbound = t.tag
metadata.Network = C.NetworkTCP
metadata.Source = upstreamMetadata.Source
metadata.Destination = upstreamMetadata.Destination
metadata.SniffEnabled = t.options.SniffEnabled
metadata.SniffOverrideDestination = t.options.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.options.DomainStrategy)
metadata.SniffEnabled = t.inboundOptions.SniffEnabled
metadata.SniffOverrideDestination = t.inboundOptions.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.inboundOptions.DomainStrategy)
if t.hijackDNS && upstreamMetadata.Destination.Port == 53 {
return task.Run(ctx, func() error {
return NewDNSConnection(ctx, t.router, t.logger, conn, metadata)
})
}
t.logger.WithContext(ctx).Info("inbound connection from ", metadata.Source)
t.logger.WithContext(ctx).Info("inbound connection to ", metadata.Destination)
return t.router.RouteConnection(ctx, conn, metadata)
}
func (t *Tun) NewPacketConnection(ctx context.Context, conn N.PacketConn, upstreamMetadata M.Metadata) error {
t.logger.WithContext(ctx).Info("inbound packet connection from ", upstreamMetadata.Source)
t.logger.WithContext(ctx).Info("inbound packet connection to ", upstreamMetadata.Destination)
var metadata adapter.InboundContext
metadata.Inbound = t.tag
metadata.Network = C.NetworkUDP
metadata.Source = upstreamMetadata.Source
metadata.Destination = upstreamMetadata.Destination
metadata.SniffEnabled = t.options.SniffEnabled
metadata.SniffOverrideDestination = t.options.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.options.DomainStrategy)
metadata.SniffEnabled = t.inboundOptions.SniffEnabled
metadata.SniffOverrideDestination = t.inboundOptions.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.inboundOptions.DomainStrategy)
if t.hijackDNS && upstreamMetadata.Destination.Port == 53 {
return task.Run(ctx, func() error {
return NewDNSPacketConnection(ctx, t.router, t.logger, conn, metadata)
})
}
t.logger.WithContext(ctx).Info("inbound packet connection from ", metadata.Source)
t.logger.WithContext(ctx).Info("inbound packet connection to ", metadata.Destination)
return t.router.RoutePacketConnection(ctx, conn, metadata)
}

View file

@ -144,10 +144,11 @@ type ShadowsocksDestination struct {
}
type TunInboundOptions struct {
InterfaceName string `json:"interface_name"`
MTU uint32 `json:"mtu,omitempty"`
Inet4Address ListenPrefix `json:"inet4_address"`
Inet6Address ListenPrefix `json:"inet6_address"`
AutoRoute bool `json:"auto_route"`
InterfaceName string `json:"interface_name,omitempty"`
MTU uint32 `json:"mtu,omitempty,omitempty"`
Inet4Address ListenPrefix `json:"inet4_address,omitempty"`
Inet6Address ListenPrefix `json:"inet6_address,omitempty"`
AutoRoute bool `json:"auto_route,omitempty"`
HijackDNS bool `json:"hijack_dns,omitempty"`
InboundOptions
}

View file

@ -9,6 +9,7 @@ import (
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"time"
@ -128,11 +129,15 @@ func NewRouter(ctx context.Context, logger log.Logger, options option.RouteOptio
} else {
detour = dialer.NewDetour(router, server.Detour)
}
if server.Address != "local" {
serverURL, err := url.Parse(server.Address)
if err != nil {
return nil, err
}
serverAddress := serverURL.Hostname()
if serverAddress == "" {
serverAddress = server.Address
}
_, notIpAddress := netip.ParseAddr(serverAddress)
if server.AddressResolver != "" {
if !transportTagMap[server.AddressResolver] {
@ -146,6 +151,7 @@ func NewRouter(ctx context.Context, logger log.Logger, options option.RouteOptio
} else if notIpAddress != nil {
return nil, E.New("parse dns server[", tag, "]: missing address_resolver")
}
}
transport, err := dns.NewTransport(ctx, detour, logger, server.Address)
if err != nil {
return nil, E.Cause(err, "parse dns server[", tag, "]")
@ -419,7 +425,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
}
func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
if metadata.SniffEnabled && metadata.Destination.Port == 443 {
if metadata.SniffEnabled {
_buffer := buf.StackNewPacket()
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
@ -489,7 +495,7 @@ func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, def
func (r *Router) matchDNS(ctx context.Context) adapter.DNSTransport {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
r.dnsLogger.WithContext(ctx).Warn("no context")
r.dnsLogger.WithContext(ctx).Warn("no context: ", reflect.TypeOf(ctx))
return r.defaultTransport
}
for i, rule := range r.dnsRules {