diff --git a/adapter/inbound.go b/adapter/inbound.go index f32b804d..bcf3ea5f 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -96,3 +96,12 @@ func ExtendContext(ctx context.Context) (context.Context, *InboundContext) { } return WithContext(ctx, &newMetadata), &newMetadata } + +func OverrideContext(ctx context.Context) context.Context { + if metadata := ContextFrom(ctx); metadata != nil { + var newMetadata InboundContext + newMetadata = *metadata + return WithContext(ctx, &newMetadata) + } + return ctx +} diff --git a/common/mux/client.go b/common/mux/client.go index 6f201dea..bf103355 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -1,11 +1,16 @@ package mux import ( + "context" + "net" + + "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-mux" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) @@ -30,7 +35,7 @@ func NewClientWithOptions(dialer N.Dialer, logger logger.Logger, options option. } } return mux.NewClient(mux.Options{ - Dialer: dialer, + Dialer: &clientDialer{dialer}, Logger: logger, Protocol: options.Protocol, MaxConnections: options.MaxConnections, @@ -40,3 +45,15 @@ func NewClientWithOptions(dialer N.Dialer, logger logger.Logger, options option. Brutal: brutalOptions, }) } + +type clientDialer struct { + N.Dialer +} + +func (d *clientDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + return d.Dialer.DialContext(adapter.OverrideContext(ctx), network, destination) +} + +func (d *clientDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return d.Dialer.ListenPacket(adapter.OverrideContext(ctx), destination) +}