diff --git a/box.go b/box.go index 5a6e3d55..434f0655 100644 --- a/box.go +++ b/box.go @@ -117,6 +117,7 @@ func New(options Options) (*Box, error) { ctx, router, logFactory.NewLogger(F.ToString("outbound/", outboundOptions.Type, "[", tag, "]")), + tag, outboundOptions) if err != nil { return nil, E.Cause(err, "parse outbound[", i, "]") @@ -124,7 +125,7 @@ func New(options Options) (*Box, error) { outbounds = append(outbounds, out) } err = router.Initialize(inbounds, outbounds, func() adapter.Outbound { - out, oErr := outbound.New(ctx, router, logFactory.NewLogger("outbound/direct"), option.Outbound{Type: "direct", Tag: "default"}) + out, oErr := outbound.New(ctx, router, logFactory.NewLogger("outbound/direct"), "direct", option.Outbound{Type: "direct", Tag: "default"}) common.Must(oErr) outbounds = append(outbounds, out) return out diff --git a/go.mod b/go.mod index f1b028fb..7c3175c1 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/sagernet/quic-go v0.0.0-20230202071646-a8c8afb18b32 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 github.com/sagernet/sing v0.2.2-0.20230407053809-308e421e33c2 - github.com/sagernet/sing-dns v0.1.5-0.20230407055526-2a27418e7855 + github.com/sagernet/sing-dns v0.1.5-0.20230408004833-5adaf486d440 github.com/sagernet/sing-shadowsocks v0.2.0 github.com/sagernet/sing-shadowtls v0.1.0 github.com/sagernet/sing-tun v0.1.4-0.20230326080954-8848c0e4cbab diff --git a/go.sum b/go.sum index 7774e2c5..2fba38f2 100644 --- a/go.sum +++ b/go.sum @@ -113,8 +113,8 @@ github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2 github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= github.com/sagernet/sing v0.2.2-0.20230407053809-308e421e33c2 h1:VjeHDxEgpB2fqK5G16yBvtLacibvg3h2MsIjal0UXH0= github.com/sagernet/sing v0.2.2-0.20230407053809-308e421e33c2/go.mod h1:9uHswk2hITw8leDbiLS/xn0t9nzBcbePxzm9PJhwdlw= -github.com/sagernet/sing-dns v0.1.5-0.20230407055526-2a27418e7855 h1:a3W2X1n5C/oYGp/Dd26eoymME3iXN8TJq7LZtO2MSUY= -github.com/sagernet/sing-dns v0.1.5-0.20230407055526-2a27418e7855/go.mod h1:69PNSHyEmXdjf6C+bXBOdr2GQnPeEyWjIzo/MV8gmz8= +github.com/sagernet/sing-dns v0.1.5-0.20230408004833-5adaf486d440 h1:VH8/BcOVuApHtS+vKP+khxlGRcXH7KKhgkTDtNynqSQ= +github.com/sagernet/sing-dns v0.1.5-0.20230408004833-5adaf486d440/go.mod h1:69PNSHyEmXdjf6C+bXBOdr2GQnPeEyWjIzo/MV8gmz8= github.com/sagernet/sing-shadowsocks v0.2.0 h1:ILDWL7pwWfkPLEbviE/MyCgfjaBmJY/JVVY+5jhSb58= github.com/sagernet/sing-shadowsocks v0.2.0/go.mod h1:ysYzszRLpNzJSorvlWRMuzU6Vchsp7sd52q+JNY4axw= github.com/sagernet/sing-shadowtls v0.1.0 h1:05MYce8aR5xfKIn+y7xRFsdKhKt44QZTSEQW+lG5IWQ= diff --git a/include/dhcp_stub.go b/include/dhcp_stub.go index fe175d07..c57aa430 100644 --- a/include/dhcp_stub.go +++ b/include/dhcp_stub.go @@ -12,7 +12,7 @@ import ( ) func init() { - dns.RegisterTransport([]string{"dhcp"}, func(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { + dns.RegisterTransport([]string{"dhcp"}, func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { return nil, E.New(`DHCP is not included in this build, rebuild with -tags with_dhcp`) }) } diff --git a/include/quic_stub.go b/include/quic_stub.go index 18c49b48..682eb536 100644 --- a/include/quic_stub.go +++ b/include/quic_stub.go @@ -19,7 +19,7 @@ import ( const WithQUIC = false func init() { - dns.RegisterTransport([]string{"quic", "h3"}, func(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { + dns.RegisterTransport([]string{"quic", "h3"}, func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { return nil, C.ErrQUICNotIncluded }) v2ray.RegisterQUICConstructor( diff --git a/outbound/builder.go b/outbound/builder.go index 6795e127..f032d83b 100644 --- a/outbound/builder.go +++ b/outbound/builder.go @@ -10,50 +10,51 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) -func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.Outbound) (adapter.Outbound, error) { +func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Outbound) (adapter.Outbound, error) { var metadata *adapter.InboundContext - if options.Tag != "" { + if tag != "" { ctx, metadata = adapter.AppendContext(ctx) - metadata.Outbound = options.Tag + metadata.Outbound = tag } if options.Type == "" { return nil, E.New("missing outbound type") } + ctx = ContextWithTag(ctx, tag) switch options.Type { case C.TypeDirect: - return NewDirect(router, logger, options.Tag, options.DirectOptions) + return NewDirect(router, logger, tag, options.DirectOptions) case C.TypeBlock: - return NewBlock(logger, options.Tag), nil + return NewBlock(logger, tag), nil case C.TypeDNS: - return NewDNS(router, options.Tag), nil + return NewDNS(router, tag), nil case C.TypeSocks: - return NewSocks(router, logger, options.Tag, options.SocksOptions) + return NewSocks(router, logger, tag, options.SocksOptions) case C.TypeHTTP: - return NewHTTP(router, logger, options.Tag, options.HTTPOptions) + return NewHTTP(router, logger, tag, options.HTTPOptions) case C.TypeShadowsocks: - return NewShadowsocks(ctx, router, logger, options.Tag, options.ShadowsocksOptions) + return NewShadowsocks(ctx, router, logger, tag, options.ShadowsocksOptions) case C.TypeVMess: - return NewVMess(ctx, router, logger, options.Tag, options.VMessOptions) + return NewVMess(ctx, router, logger, tag, options.VMessOptions) case C.TypeTrojan: - return NewTrojan(ctx, router, logger, options.Tag, options.TrojanOptions) + return NewTrojan(ctx, router, logger, tag, options.TrojanOptions) case C.TypeWireGuard: - return NewWireGuard(ctx, router, logger, options.Tag, options.WireGuardOptions) + return NewWireGuard(ctx, router, logger, tag, options.WireGuardOptions) case C.TypeHysteria: - return NewHysteria(ctx, router, logger, options.Tag, options.HysteriaOptions) + return NewHysteria(ctx, router, logger, tag, options.HysteriaOptions) case C.TypeTor: - return NewTor(ctx, router, logger, options.Tag, options.TorOptions) + return NewTor(ctx, router, logger, tag, options.TorOptions) case C.TypeSSH: - return NewSSH(ctx, router, logger, options.Tag, options.SSHOptions) + return NewSSH(ctx, router, logger, tag, options.SSHOptions) case C.TypeShadowTLS: - return NewShadowTLS(ctx, router, logger, options.Tag, options.ShadowTLSOptions) + return NewShadowTLS(ctx, router, logger, tag, options.ShadowTLSOptions) case C.TypeShadowsocksR: - return NewShadowsocksR(ctx, router, logger, options.Tag, options.ShadowsocksROptions) + return NewShadowsocksR(ctx, router, logger, tag, options.ShadowsocksROptions) case C.TypeVLESS: - return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions) + return NewVLESS(ctx, router, logger, tag, options.VLESSOptions) case C.TypeSelector: - return NewSelector(router, logger, options.Tag, options.SelectorOptions) + return NewSelector(router, logger, tag, options.SelectorOptions) case C.TypeURLTest: - return NewURLTest(router, logger, options.Tag, options.URLTestOptions) + return NewURLTest(router, logger, tag, options.URLTestOptions) default: return nil, E.New("unknown outbound type: ", options.Type) } diff --git a/outbound/lookback.go b/outbound/lookback.go new file mode 100644 index 00000000..aeb7451d --- /dev/null +++ b/outbound/lookback.go @@ -0,0 +1,14 @@ +package outbound + +import "context" + +type outboundTagKey struct{} + +func ContextWithTag(ctx context.Context, outboundTag string) context.Context { + return context.WithValue(ctx, outboundTagKey{}, outboundTag) +} + +func TagFromContext(ctx context.Context) (string, bool) { + value, loaded := ctx.Value(outboundTagKey{}).(string) + return value, loaded +} diff --git a/route/router.go b/route/router.go index 380ec5e5..dc778f66 100644 --- a/route/router.go +++ b/route/router.go @@ -26,6 +26,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/ntp" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/outbound" "github.com/sagernet/sing-dns" "github.com/sagernet/sing-tun" "github.com/sagernet/sing-vmess" @@ -218,7 +219,7 @@ func NewRouter( } } } - transport, err := dns.CreateTransport(ctx, logFactory.NewLogger(F.ToString("dns/transport[", tag, "]")), detour, server.Address) + transport, err := dns.CreateTransport(tag, ctx, logFactory.NewLogger(F.ToString("dns/transport[", tag, "]")), detour, server.Address) if err != nil { return nil, E.Cause(err, "parse dns server[", tag, "]") } @@ -258,7 +259,7 @@ func NewRouter( } if defaultTransport == nil { if len(transports) == 0 { - transports = append(transports, dns.NewLocalTransport(N.SystemDialer)) + transports = append(transports, dns.NewLocalTransport("local", N.SystemDialer)) } defaultTransport = transports[0] } @@ -660,9 +661,11 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad metadata.DestinationAddresses = addresses r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") } - matchedRule, detour := r.match(ctx, &metadata, r.defaultOutboundForConnection) + ctx, matchedRule, detour, err := r.match(ctx, &metadata, r.defaultOutboundForConnection) + if err != nil { + return err + } if !common.Contains(detour.Network(), N.NetworkTCP) { - conn.Close() return E.New("missing supported outbound, closing connection") } if r.clashServer != nil { @@ -738,9 +741,11 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m metadata.DestinationAddresses = addresses r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") } - matchedRule, detour := r.match(ctx, &metadata, r.defaultOutboundForPacketConnection) + ctx, matchedRule, detour, err := r.match(ctx, &metadata, r.defaultOutboundForPacketConnection) + if err != nil { + return err + } if !common.Contains(detour.Network(), N.NetworkUDP) { - conn.Close() return E.New("missing supported outbound, closing packet connection") } if r.clashServer != nil { @@ -756,7 +761,18 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m return detour.NewPacketConnection(ctx, conn, metadata) } -func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) { +func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (context.Context, adapter.Rule, adapter.Outbound, error) { + matchRule, matchOutbound := r.match0(ctx, metadata, defaultOutbound) + if contextOutbound, loaded := outbound.TagFromContext(ctx); loaded { + if contextOutbound == matchOutbound.Tag() { + return nil, nil, nil, E.New("connection loopback in outbound/", matchOutbound.Type(), "[", matchOutbound.Tag(), "]") + } + } + ctx = outbound.ContextWithTag(ctx, matchOutbound.Tag()) + return ctx, matchRule, matchOutbound, nil +} + +func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) { if r.processSearcher != nil { var originDestination netip.AddrPort if metadata.OriginDestination.IsValid() { diff --git a/transport/dhcp/server.go b/transport/dhcp/server.go index e341b9c6..427017a9 100644 --- a/transport/dhcp/server.go +++ b/transport/dhcp/server.go @@ -35,6 +35,7 @@ func init() { } type Transport struct { + name string ctx context.Context router adapter.Router logger logger.Logger @@ -46,7 +47,7 @@ type Transport struct { updatedAt time.Time } -func NewTransport(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { +func NewTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { linkURL, err := url.Parse(link) if err != nil { return nil, err @@ -59,6 +60,7 @@ func NewTransport(ctx context.Context, logger logger.ContextLogger, dialer N.Dia return nil, E.New("missing router in context") } transport := &Transport{ + name: name, ctx: ctx, router: router, logger: logger, @@ -68,6 +70,10 @@ func NewTransport(ctx context.Context, logger logger.ContextLogger, dialer N.Dia return transport, nil } +func (t *Transport) Name() string { + return t.name +} + func (t *Transport) Start() error { err := t.fetchServers() if err != nil { @@ -247,7 +253,7 @@ func (t *Transport) recreateServers(iface *net.Interface, serverAddrs []netip.Ad }) var transports []dns.Transport for _, serverAddr := range serverAddrs { - serverTransport, err := dns.NewUDPTransport(t.ctx, serverDialer, M.Socksaddr{Addr: serverAddr, Port: 53}) + serverTransport, err := dns.NewUDPTransport(t.name, t.ctx, serverDialer, M.Socksaddr{Addr: serverAddr, Port: 53}) if err != nil { return err }