From 017f53b5fcf51d19a277514b8df1e167f39e972b Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Mon, 13 May 2024 21:52:24 -0400 Subject: [PATCH] Add session context outbounds as slice (#3356) * Add session context outbounds as slice slice is needed for dialer proxy where two outbounds work on top of each other There are two sets of target addr for example It also enable Xtls to correctly do splice copy by checking both outbounds are ready to do direct copy * Fill outbound tag info * Splice now checks capalibility from all outbounds * Fix unit tests --- app/dispatcher/default.go | 22 ++--- app/dispatcher/fakednssniffer.go | 11 +-- app/proxyman/inbound/worker.go | 11 +-- app/proxyman/outbound/handler.go | 38 +++++---- app/proxyman/outbound/handler_test.go | 3 + app/reverse/portal.go | 12 +-- app/router/router_test.go | 20 +++-- common/mux/client.go | 14 ++-- common/mux/client_test.go | 8 +- common/session/context.go | 10 +-- common/session/session.go | 14 ++-- common/singbridge/dialer.go | 11 ++- features/routing/session/context.go | 4 +- proxy/blackhole/blackhole.go | 7 +- proxy/blackhole/blackhole_test.go | 6 +- proxy/dns/dns.go | 11 +-- proxy/dokodemo/dokodemo.go | 15 ++-- proxy/freedom/freedom.go | 14 ++-- proxy/http/client.go | 21 +++-- proxy/http/server.go | 2 +- proxy/loopback/loopback.go | 9 ++- proxy/proxy.go | 102 +++++++++++++++--------- proxy/shadowsocks/client.go | 14 ++-- proxy/shadowsocks/server.go | 2 +- proxy/shadowsocks_2022/inbound.go | 2 +- proxy/shadowsocks_2022/inbound_multi.go | 2 +- proxy/shadowsocks_2022/inbound_relay.go | 2 +- proxy/shadowsocks_2022/outbound.go | 11 +-- proxy/socks/client.go | 14 ++-- proxy/socks/server.go | 2 +- proxy/trojan/client.go | 14 ++-- proxy/trojan/server.go | 2 +- proxy/vless/encoding/encoding.go | 20 +++-- proxy/vless/inbound/inbound.go | 10 ++- proxy/vless/outbound/outbound.go | 24 +++--- proxy/vmess/inbound/inbound.go | 2 +- proxy/vmess/outbound/outbound.go | 14 ++-- proxy/wireguard/client.go | 14 ++-- proxy/wireguard/server.go | 8 +- transport/internet/dialer.go | 13 ++- transport/internet/grpc/dial.go | 2 +- transport/internet/http/dialer.go | 2 +- 42 files changed, 303 insertions(+), 236 deletions(-) diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index b8131b8f..26019bbe 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -218,11 +218,12 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin if !destination.IsValid() { panic("Dispatcher: Invalid destination.") } - ob := session.OutboundFromContext(ctx) - if ob == nil { - ob = &session.Outbound{} - ctx = session.ContextWithOutbound(ctx, ob) + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + outbounds = []*session.Outbound{{}} + ctx = session.ContextWithOutbounds(ctx, outbounds) } + ob := outbounds[len(outbounds) - 1] ob.OriginalTarget = destination ob.Target = destination content := session.ContentFromContext(ctx) @@ -274,11 +275,12 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De if !destination.IsValid() { return newError("Dispatcher: Invalid destination.") } - ob := session.OutboundFromContext(ctx) - if ob == nil { - ob = &session.Outbound{} - ctx = session.ContextWithOutbound(ctx, ob) + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + outbounds = []*session.Outbound{{}} + ctx = session.ContextWithOutbounds(ctx, outbounds) } + ob := outbounds[len(outbounds) - 1] ob.OriginalTarget = destination ob.Target = destination content := session.ContentFromContext(ctx) @@ -368,7 +370,8 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw return contentResult, contentErr } func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { - ob := session.OutboundFromContext(ctx) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() { proxied := hosts.LookupHosts(ob.Target.String()) if proxied != nil { @@ -425,6 +428,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. return } + ob.Tag = handler.Tag() if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil { if tag := handler.Tag(); tag != "" { if inTag == "" { diff --git a/app/dispatcher/fakednssniffer.go b/app/dispatcher/fakednssniffer.go index ad879daf..8d0804de 100644 --- a/app/dispatcher/fakednssniffer.go +++ b/app/dispatcher/fakednssniffer.go @@ -26,11 +26,12 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error) return protocolSnifferWithMetadata{}, errNotInit } return protocolSnifferWithMetadata{protocolSniffer: func(ctx context.Context, bytes []byte) (SniffResult, error) { - Target := session.OutboundFromContext(ctx).Target - if Target.Network == net.Network_TCP || Target.Network == net.Network_UDP { - domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(Target.Address) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if ob.Target.Network == net.Network_TCP || ob.Target.Network == net.Network_UDP { + domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(ob.Target.Address) if domainFromFakeDNS != "" { - newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", Target.Address.String()).WriteToLog(session.ExportIDToError(ctx)) + newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", ob.Target.Address.String()).WriteToLog(session.ExportIDToError(ctx)) return &fakeDNSSniffResult{domainName: domainFromFakeDNS}, nil } } @@ -38,7 +39,7 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error) if ipAddressInRangeValueI := ctx.Value(ipAddressInRange); ipAddressInRangeValueI != nil { ipAddressInRangeValue := ipAddressInRangeValueI.(*ipAddressInRangeOpt) if fkr0, ok := fakeDNSEngine.(dns.FakeDNSEngineRev0); ok { - inPool := fkr0.IsIPInIPPool(Target.Address) + inPool := fkr0.IsIPInIPPool(ob.Target.Address) ipAddressInRangeValue.addressInRange = &inPool } } diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 1fe86655..9a6499f1 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -60,7 +60,7 @@ func (w *tcpWorker) callback(conn stat.Connection) { sid := session.NewID() ctx = session.ContextWithID(ctx, sid) - var outbound = &session.Outbound{} + outbounds := []*session.Outbound{{}} if w.recvOrigDest { var dest net.Destination switch getTProxyType(w.stream) { @@ -75,10 +75,10 @@ func (w *tcpWorker) callback(conn stat.Connection) { dest = net.DestinationFromAddr(conn.LocalAddr()) } if dest.IsValid() { - outbound.Target = dest + outbounds[0].Target = dest } } - ctx = session.ContextWithOutbound(ctx, outbound) + ctx = session.ContextWithOutbounds(ctx, outbounds) if w.uplinkCounter != nil || w.downlinkCounter != nil { conn = &stat.CounterConnection{ @@ -309,9 +309,10 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest ctx = session.ContextWithID(ctx, sid) if originalDest.IsValid() { - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + outbounds := []*session.Outbound{{ Target: originalDest, - }) + }} + ctx = session.ContextWithOutbounds(ctx, outbounds) } ctx = session.ContextWithInbound(ctx, &session.Inbound{ Source: source, diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 792ac249..4262c76a 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -169,10 +169,11 @@ func (h *Handler) Tag() string { // Dispatch implements proxy.Outbound.Dispatch. func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { - outbound := session.OutboundFromContext(ctx) - if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address { - link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address} - link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address} + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address { + link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address} + link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address} } if h.mux != nil { test := func(err error) { @@ -183,7 +184,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { common.Interrupt(link.Writer) } } - if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 { + if ob.Target.Network == net.Network_UDP && ob.Target.Port == 443 { switch h.udp443 { case "reject": test(newError("XUDP rejected UDP/443 traffic").AtInfo()) @@ -192,7 +193,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { goto out } } - if h.xudp != nil && outbound.Target.Network == net.Network_UDP { + if h.xudp != nil && ob.Target.Network == net.Network_UDP { if !h.xudp.Enabled { goto out } @@ -243,10 +244,11 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti handler := h.outboundManager.GetHandler(tag) if handler != nil { newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx)) - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + outbounds := session.OutboundsFromContext(ctx) + ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{ Target: dest, - }) - + Tag: tag, + })) // add another outbound in session ctx opts := pipe.OptionsFromContext(ctx) uplinkReader, uplinkWriter := pipe.New(opts...) downlinkReader, downlinkWriter := pipe.New(opts...) @@ -266,15 +268,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti } if h.senderSettings.Via != nil { - outbound := session.OutboundFromContext(ctx) - if outbound == nil { - outbound = new(session.Outbound) - ctx = session.ContextWithOutbound(ctx, outbound) - } + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] if h.senderSettings.ViaCidr == "" { - outbound.Gateway = h.senderSettings.Via.AsAddress() + ob.Gateway = h.senderSettings.Via.AsAddress() } else { //Get a random address. - outbound.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr) + ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr) } } } @@ -285,10 +284,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti conn, err := internet.Dial(ctx, dest, h.streamSettings) conn = h.getStatCouterConnection(conn) - outbound := session.OutboundFromContext(ctx) - if outbound != nil { - outbound.Conn = conn - } + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + ob.Conn = conn return conn, err } diff --git a/app/proxyman/outbound/handler_test.go b/app/proxyman/outbound/handler_test.go index e5b67308..3f7ef28e 100644 --- a/app/proxyman/outbound/handler_test.go +++ b/app/proxyman/outbound/handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/xtls/xray-core/app/stats" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/common/session" core "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/proxy/freedom" @@ -44,6 +45,7 @@ func TestOutboundWithoutStatCounter(t *testing.T) { v, _ := core.New(config) v.AddFeature((outbound.Manager)(new(Manager))) ctx := context.WithValue(context.Background(), xrayKey, v) + ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}}) h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{ Tag: "tag", ProxySettings: serial.ToTypedMessage(&freedom.Config{}), @@ -73,6 +75,7 @@ func TestOutboundWithStatCounter(t *testing.T) { v, _ := core.New(config) v.AddFeature((outbound.Manager)(new(Manager))) ctx := context.WithValue(context.Background(), xrayKey, v) + ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}}) h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{ Tag: "tag", ProxySettings: serial.ToTypedMessage(&freedom.Config{}), diff --git a/app/reverse/portal.go b/app/reverse/portal.go index fb0b6930..456de550 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -62,12 +62,13 @@ func (p *Portal) Close() error { } func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error { - outboundMeta := session.OutboundFromContext(ctx) - if outboundMeta == nil { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if ob == nil { return newError("outbound metadata not found").AtError() } - if isDomain(outboundMeta.Target, p.domain) { + if isDomain(ob.Target, p.domain) { muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{}) if err != nil { return newError("failed to create mux client worker").Base(err).AtWarning() @@ -206,9 +207,10 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { downlinkReader, downlinkWriter := pipe.New(opt...) ctx := context.Background() - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + outbounds := []*session.Outbound{{ Target: net.UDPDestination(net.DomainAddress(internalDomain), 0), - }) + }} + ctx = session.ContextWithOutbounds(ctx, outbounds) f := client.Dispatch(ctx, &transport.Link{ Reader: uplinkReader, Writer: downlinkWriter, diff --git a/app/router/router_test.go b/app/router/router_test.go index 4c6bfc63..2c33aae1 100644 --- a/app/router/router_test.go +++ b/app/router/router_test.go @@ -45,7 +45,9 @@ func TestSimpleRouter(t *testing.T) { HandlerSelector: mockHs, }, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -86,7 +88,9 @@ func TestSimpleBalancer(t *testing.T) { HandlerSelector: mockHs, }, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -174,7 +178,9 @@ func TestIPOnDemand(t *testing.T) { r := new(Router) common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -213,7 +219,9 @@ func TestIPIfNonMatchDomain(t *testing.T) { r := new(Router) common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -247,7 +255,9 @@ func TestIPIfNonMatchIP(t *testing.T) { r := new(Router) common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.LocalHostIP, 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { diff --git a/common/mux/client.go b/common/mux/client.go index 88621be0..2537f02b 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -148,9 +148,10 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) { } go func(p proxy.Outbound, d internet.Dialer, c common.Closable) { - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{ + outbounds := []*session.Outbound{{ Target: net.TCPDestination(muxCoolAddress, muxCoolPort), - }) + }} + ctx := session.ContextWithOutbounds(context.Background(), outbounds) ctx, cancel := context.WithCancel(ctx) if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil { @@ -242,17 +243,18 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error { } func fetchInput(ctx context.Context, s *Session, output buf.Writer) { - dest := session.OutboundFromContext(ctx).Target + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] transferType := protocol.TransferTypeStream - if dest.Network == net.Network_UDP { + if ob.Target.Network == net.Network_UDP { transferType = protocol.TransferTypePacket } s.transferType = transferType - writer := NewWriter(s.ID, dest, output, transferType, xudp.GetGlobalID(ctx)) + writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx)) defer s.Close(false) defer writer.Close() - newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx)) + newError("dispatching request to ", ob.Target).WriteToLog(session.ExportIDToError(ctx)) if err := writeFirstPayload(s.input, writer); err != nil { newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) writer.hasError = true diff --git a/common/mux/client_test.go b/common/mux/client_test.go index 7837a86e..9626e2a2 100644 --- a/common/mux/client_test.go +++ b/common/mux/client_test.go @@ -86,9 +86,9 @@ func TestClientWorkerClose(t *testing.T) { } tr1, tw1 := pipe.New(pipe.WithoutSizeLimit()) - ctx1 := session.ContextWithOutbound(context.Background(), &session.Outbound{ + ctx1 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80), - }) + }}) common.Must(manager.Dispatch(ctx1, &transport.Link{ Reader: tr1, Writer: tw1, @@ -103,9 +103,9 @@ func TestClientWorkerClose(t *testing.T) { } tr2, tw2 := pipe.New(pipe.WithoutSizeLimit()) - ctx2 := session.ContextWithOutbound(context.Background(), &session.Outbound{ + ctx2 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80), - }) + }}) common.Must(manager.Dispatch(ctx2, &transport.Link{ Reader: tr2, Writer: tw2, diff --git a/common/session/context.go b/common/session/context.go index 87586169..fc37bd72 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -51,13 +51,13 @@ func InboundFromContext(ctx context.Context) *Inbound { return nil } -func ContextWithOutbound(ctx context.Context, outbound *Outbound) context.Context { - return context.WithValue(ctx, outboundSessionKey, outbound) +func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Context { + return context.WithValue(ctx, outboundSessionKey, outbounds) } -func OutboundFromContext(ctx context.Context) *Outbound { - if outbound, ok := ctx.Value(outboundSessionKey).(*Outbound); ok { - return outbound +func OutboundsFromContext(ctx context.Context) []*Outbound { + if outbounds, ok := ctx.Value(outboundSessionKey).([]*Outbound); ok { + return outbounds } return nil } diff --git a/common/session/session.go b/common/session/session.go index 38ffa7bd..d8ab1ec4 100644 --- a/common/session/session.go +++ b/common/session/session.go @@ -50,18 +50,11 @@ type Inbound struct { Conn net.Conn // Timer of the inbound buf copier. May be nil. Timer *signal.ActivityTimer - // CanSpliceCopy is a property for this connection, set by both inbound and outbound + // CanSpliceCopy is a property for this connection // 1 = can, 2 = after processing protocol info should be able to, 3 = cannot CanSpliceCopy int } -func(i *Inbound) SetCanSpliceCopy(canSpliceCopy int) int { - if canSpliceCopy > i.CanSpliceCopy { - i.CanSpliceCopy = canSpliceCopy - } - return i.CanSpliceCopy -} - // Outbound is the metadata of an outbound connection. type Outbound struct { // Target address of the outbound connection. @@ -70,10 +63,15 @@ type Outbound struct { RouteTarget net.Destination // Gateway address Gateway net.Address + // Tag of the outbound proxy that handles the connection. + Tag string // Name of the outbound proxy that handles the connection. Name string // Conn is actually internet.Connection. May be nil. It is currently nil for outbound with proxySettings Conn net.Conn + // CanSpliceCopy is a property for this connection + // 1 = can, 2 = after processing protocol info should be able to, 3 = cannot + CanSpliceCopy int } // SniffingRequest controls the behavior of content sniffing. diff --git a/common/singbridge/dialer.go b/common/singbridge/dialer.go index 896c97fe..6be83036 100644 --- a/common/singbridge/dialer.go +++ b/common/singbridge/dialer.go @@ -43,9 +43,14 @@ func NewOutboundDialer(outbound proxy.Outbound, dialer internet.Dialer) *XrayOut } func (d *XrayOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ - Target: ToDestination(destination, ToNetwork(network)), - }) + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + outbounds = []*session.Outbound{{}} + ctx = session.ContextWithOutbounds(ctx, outbounds) + } + ob := outbounds[len(outbounds) - 1] + ob.Target = ToDestination(destination, ToNetwork(network)) + opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)} uplinkReader, uplinkWriter := pipe.New(opts...) downlinkReader, downlinkWriter := pipe.New(opts...) diff --git a/features/routing/session/context.go b/features/routing/session/context.go index c900219d..3c9764b3 100644 --- a/features/routing/session/context.go +++ b/features/routing/session/context.go @@ -124,9 +124,11 @@ func (ctx *Context) GetSkipDNSResolve() bool { // AsRoutingContext creates a context from context.context with session info. func AsRoutingContext(ctx context.Context) routing.Context { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] return &Context{ Inbound: session.InboundFromContext(ctx), - Outbound: session.OutboundFromContext(ctx), + Outbound: ob, Content: session.ContentFromContext(ctx), } } diff --git a/proxy/blackhole/blackhole.go b/proxy/blackhole/blackhole.go index 4b819417..23c9c291 100644 --- a/proxy/blackhole/blackhole.go +++ b/proxy/blackhole/blackhole.go @@ -31,10 +31,9 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements OutboundHandler.Dispatch(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound != nil { - outbound.Name = "blackhole" - } + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + ob.Name = "blackhole" nBytes := h.response.WriteTo(link.Writer) if nBytes > 0 { diff --git a/proxy/blackhole/blackhole_test.go b/proxy/blackhole/blackhole_test.go index 8e487e0c..6a9cb8e8 100644 --- a/proxy/blackhole/blackhole_test.go +++ b/proxy/blackhole/blackhole_test.go @@ -7,13 +7,15 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/proxy/blackhole" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/pipe" ) func TestBlackholeHTTPResponse(t *testing.T) { - handler, err := blackhole.New(context.Background(), &blackhole.Config{ + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{}}) + handler, err := blackhole.New(ctx, &blackhole.Config{ Response: serial.ToTypedMessage(&blackhole.HTTPResponse{}), }) common.Must(err) @@ -32,7 +34,7 @@ func TestBlackholeHTTPResponse(t *testing.T) { Reader: reader, Writer: writer, } - common.Must(handler.Process(context.Background(), &link, nil)) + common.Must(handler.Process(ctx, &link, nil)) common.Must(rerr) if mb.IsEmpty() { t.Error("expect http response, but nothing") diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 2cf21a42..86790f76 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -96,15 +96,16 @@ func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage. // Process implements proxy.Outbound. func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("invalid outbound") } - outbound.Name = "dns" + ob.Name = "dns" - srcNetwork := outbound.Target.Network + srcNetwork := ob.Target.Network - dest := outbound.Target + dest := ob.Target if h.server.Network != net.Network_Unknown { dest.Network = h.server.Network } diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 1c59fe62..5a07df5c 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -86,10 +86,15 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st destinationOverridden := false if d.config.FollowRedirect { - if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() { - dest = outbound.Target - destinationOverridden = true - } else if handshake, ok := conn.(hasHandshakeAddressContext); ok { + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) > 0 { + ob := outbounds[len(outbounds) - 1] + if ob.Target.IsValid() { + dest = ob.Target + destinationOverridden = true + } + } + if handshake, ok := conn.(hasHandshakeAddressContext); ok && !destinationOverridden { addr := handshake.HandshakeAddressContext(ctx) if addr != nil { dest.Address = addr @@ -103,7 +108,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st inbound := session.InboundFromContext(ctx) inbound.Name = "dokodemo-door" - inbound.SetCanSpliceCopy(1) + inbound.CanSpliceCopy = 1 inbound.User = &protocol.MemoryUser{ Level: d.config.UserLevel, } diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 0176929c..9e6afc9d 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -106,16 +106,16 @@ func isValidAddress(addr *net.IPOrDomain) bool { // Process implements proxy.Outbound. func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "freedom" + ob.Name = "freedom" + ob.CanSpliceCopy = 1 inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(1) - } - destination := outbound.Target + + destination := ob.Target UDPOverride := net.UDPDestination(nil, 0) if h.config.DestinationOverride != nil { server := h.config.DestinationOverride.Server diff --git a/proxy/http/client.go b/proxy/http/client.go index 72060c4d..80a0328a 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -69,16 +69,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements proxy.Outbound.Process. We first create a socket tunnel via HTTP CONNECT method, then redirect all inbound traffic to that tunnel. func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "http" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(2) - } - target := outbound.Target + ob.Name = "http" + ob.CanSpliceCopy = 2 + target := ob.Target targetAddr := target.NetAddr() if target.Network == net.Network_UDP { @@ -175,9 +173,10 @@ func fillRequestHeader(ctx context.Context, header []*Header) ([]*Header, error) } inbound := session.InboundFromContext(ctx) - outbound := session.OutboundFromContext(ctx) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] - if inbound == nil || outbound == nil { + if inbound == nil || ob == nil { return nil, newError("missing inbound or outbound metadata from context") } @@ -186,7 +185,7 @@ func fillRequestHeader(ctx context.Context, header []*Header) ([]*Header, error) Target net.Destination }{ Source: inbound.Source, - Target: outbound.Target, + Target: ob.Target, } filled := make([]*Header, len(header)) diff --git a/proxy/http/server.go b/proxy/http/server.go index 511d9b08..a7df317d 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -85,7 +85,7 @@ type readerOnly struct { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "http" - inbound.SetCanSpliceCopy(2) + inbound.CanSpliceCopy = 2 inbound.User = &protocol.MemoryUser{ Level: s.config.UserLevel, } diff --git a/proxy/loopback/loopback.go b/proxy/loopback/loopback.go index 30c39bd9..f3be5a95 100644 --- a/proxy/loopback/loopback.go +++ b/proxy/loopback/loopback.go @@ -22,12 +22,13 @@ type Loopback struct { } func (l *Loopback) Process(ctx context.Context, link *transport.Link, _ internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "loopback" - destination := outbound.Target + ob.Name = "loopback" + destination := ob.Target newError("opening connection to ", destination).WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/proxy.go b/proxy/proxy.go index 6a5a1798..2507d029 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -474,45 +474,73 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net readerConn, readCounter, _ := UnwrapRawConn(readerConn) writerConn, _, writeCounter := UnwrapRawConn(writerConn) reader := buf.NewReader(readerConn) - if inbound := session.InboundFromContext(ctx); inbound != nil { - if tc, ok := writerConn.(*net.TCPConn); ok && readerConn != nil && writerConn != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { - for inbound.CanSpliceCopy != 3 { - if inbound.CanSpliceCopy == 1 { - newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx)) - statWriter, _ := writer.(*dispatcher.SizeStatWriter) - //runtime.Gosched() // necessary - time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice - w, err := tc.ReadFrom(readerConn) - if readCounter != nil { - readCounter.Add(w) // outbound stats - } - if writeCounter != nil { - writeCounter.Add(w) // inbound stats - } - if statWriter != nil { - statWriter.Counter.Add(w) // user stats - } - if err != nil && errors.Cause(err) != io.EOF { - return err - } - return nil - } - buffer, err := reader.ReadMultiBuffer() - if !buffer.IsEmpty() { - if readCounter != nil { - readCounter.Add(int64(buffer.Len())) - } - timer.Update() - if werr := writer.WriteMultiBuffer(buffer); werr != nil { - return werr - } - } - if err != nil { - return err - } - } + if runtime.GOOS != "linux" && runtime.GOOS != "android" { + return readV(ctx, reader, writer, timer, readCounter) + } + tc, ok := writerConn.(*net.TCPConn) + if !ok || readerConn == nil || writerConn == nil { + return readV(ctx, reader, writer, timer, readCounter) + } + inbound := session.InboundFromContext(ctx) + if inbound == nil || inbound.CanSpliceCopy == 3 { + return readV(ctx, reader, writer, timer, readCounter) + } + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + return readV(ctx, reader, writer, timer, readCounter) + } + for _, ob := range outbounds { + if ob.CanSpliceCopy == 3 { + return readV(ctx, reader, writer, timer, readCounter) } } + + for { + inbound := session.InboundFromContext(ctx) + outbounds := session.OutboundsFromContext(ctx) + var splice = inbound.CanSpliceCopy == 1 + for _, ob := range outbounds { + if ob.CanSpliceCopy != 1 { + splice = false + } + } + if splice { + newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx)) + statWriter, _ := writer.(*dispatcher.SizeStatWriter) + //runtime.Gosched() // necessary + time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice + w, err := tc.ReadFrom(readerConn) + if readCounter != nil { + readCounter.Add(w) // outbound stats + } + if writeCounter != nil { + writeCounter.Add(w) // inbound stats + } + if statWriter != nil { + statWriter.Counter.Add(w) // user stats + } + if err != nil && errors.Cause(err) != io.EOF { + return err + } + return nil + } + buffer, err := reader.ReadMultiBuffer() + if !buffer.IsEmpty() { + if readCounter != nil { + readCounter.Add(int64(buffer.Len())) + } + timer.Update() + if werr := writer.WriteMultiBuffer(buffer); werr != nil { + return werr + } + } + if err != nil { + return err + } + } +} + +func readV(ctx context.Context, reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, readCounter stats.Counter) error { newError("CopyRawConn readv").WriteToLog(session.ExportIDToError(ctx)) if err := buf.Copy(reader, writer, buf.UpdateActivity(timer), buf.AddToStatCounter(readCounter)); err != nil { return newError("failed to process response").Base(err) diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 57d8f81c..8ebe7631 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -49,16 +49,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements OutboundHandler.Process(). func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "shadowsocks" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } - destination := outbound.Target + ob.Name = "shadowsocks" + ob.CanSpliceCopy = 3 + destination := ob.Target network := destination.Network var server *protocol.ServerSpec diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 2975ba70..8253506a 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -73,7 +73,7 @@ func (s *Server) Network() []net.Network { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 switch network { case net.Network_TCP: diff --git a/proxy/shadowsocks_2022/inbound.go b/proxy/shadowsocks_2022/inbound.go index 00314c90..f1eb76a5 100644 --- a/proxy/shadowsocks_2022/inbound.go +++ b/proxy/shadowsocks_2022/inbound.go @@ -66,7 +66,7 @@ func (i *Inbound) Network() []net.Network { func (i *Inbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/inbound_multi.go b/proxy/shadowsocks_2022/inbound_multi.go index df837894..f80ec6d1 100644 --- a/proxy/shadowsocks_2022/inbound_multi.go +++ b/proxy/shadowsocks_2022/inbound_multi.go @@ -155,7 +155,7 @@ func (i *MultiUserInbound) Network() []net.Network { func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022-multi" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/inbound_relay.go b/proxy/shadowsocks_2022/inbound_relay.go index 7317f8dd..1c4b8248 100644 --- a/proxy/shadowsocks_2022/inbound_relay.go +++ b/proxy/shadowsocks_2022/inbound_relay.go @@ -87,7 +87,7 @@ func (i *RelayInbound) Network() []net.Network { func (i *RelayInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022-relay" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/outbound.go b/proxy/shadowsocks_2022/outbound.go index bc1eb556..cac9a91b 100644 --- a/proxy/shadowsocks_2022/outbound.go +++ b/proxy/shadowsocks_2022/outbound.go @@ -65,15 +65,16 @@ func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer int inbound := session.InboundFromContext(ctx) if inbound != nil { inboundConn = inbound.Conn - inbound.SetCanSpliceCopy(3) } - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "shadowsocks-2022" - destination := outbound.Target + ob.Name = "shadowsocks-2022" + ob.CanSpliceCopy = 3 + destination := ob.Target network := destination.Network newError("tunneling request to ", destination, " via ", o.server.NetAddr()).WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 82591be4..b283eb65 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -57,17 +57,15 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements proxy.Outbound.Process. func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "socks" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(2) - } + ob.Name = "socks" + ob.CanSpliceCopy = 2 // Destination of the inner request. - destination := outbound.Target + destination := ob.Target // Outbound server. var server *protocol.ServerSpec diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 2f789757..0109d5b4 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -65,7 +65,7 @@ func (s *Server) Network() []net.Network { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "socks" - inbound.SetCanSpliceCopy(2) + inbound.CanSpliceCopy = 2 inbound.User = &protocol.MemoryUser{ Level: s.config.UserLevel, } diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go index d6b95fc0..3a4d838a 100644 --- a/proxy/trojan/client.go +++ b/proxy/trojan/client.go @@ -50,16 +50,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements OutboundHandler.Process(). func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "trojan" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } - destination := outbound.Target + ob.Name = "trojan" + ob.CanSpliceCopy = 3 + destination := ob.Target network := destination.Network var server *protocol.ServerSpec diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index 5c3fcd91..bc52c2b1 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -215,7 +215,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con inbound := session.InboundFromContext(ctx) inbound.Name = "trojan" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 inbound.User = user sessionPolicy = s.policyManager.ForLevel(user.Level) diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index 5956389a..2976be74 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -174,15 +174,18 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A } // XtlsRead filter and read xtls protocol -func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ctx context.Context) error { +func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { err := func() error { for { if trafficState.ReaderSwitchToDirectCopy { var writerConn net.Conn - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil && ob != nil { writerConn = inbound.Conn if inbound.CanSpliceCopy == 2 { - inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter + inbound.CanSpliceCopy = 1 + } + if ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change + ob.CanSpliceCopy = 1 } } return proxy.CopyRawConnIfExist(ctx, conn, writerConn, writer, timer) @@ -219,14 +222,19 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater } // XtlsWrite filter and write xtls protocol -func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ctx context.Context) error { +func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { err := func() error { var ct stats.Counter for { buffer, err := reader.ReadMultiBuffer() if trafficState.WriterSwitchToDirectCopy { - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.CanSpliceCopy == 2 { - inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter + if inbound := session.InboundFromContext(ctx); inbound != nil && ob != nil { + if inbound.CanSpliceCopy == 2 { + inbound.CanSpliceCopy = 1 + } + if ob.CanSpliceCopy == 2 { + ob.CanSpliceCopy = 1 + } } rawConn, _, writerCounter := proxy.UnwrapRawConn(conn) writer = buf.NewWriter(rawConn) diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 0ffa61d2..7d2dd507 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -449,7 +449,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s switch requestAddons.Flow { case vless.XRV: if account.Flow == requestAddons.Flow { - inbound.SetCanSpliceCopy(2) + inbound.CanSpliceCopy = 2 switch request.Command { case protocol.RequestCommandUDP: return newError(requestAddons.Flow + " doesn't support UDP").AtWarning() @@ -479,7 +479,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s return newError(account.ID.String() + " is not able to use " + requestAddons.Flow).AtWarning() } case "": - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 if account.Flow == vless.XRV && (request.Command == protocol.RequestCommandTCP || isMuxAndNotXUDP(request, first)) { return newError(account.ID.String() + " is not able to use \"\". Note that the pure TLS proxy has certain TLS in TLS characters.").AtWarning() } @@ -523,7 +523,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if requestAddons.Flow == vless.XRV { ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice clientReader = proxy.NewVisionReader(clientReader, trafficState, ctx1) - err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, ctx1) + err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -560,7 +560,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s var err error if requestAddons.Flow == vless.XRV { - err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, ctx) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, ob, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index a9368813..bf98253b 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -70,12 +70,12 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements proxy.Outbound.Process(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified").AtError() } - outbound.Name = "vless" - inbound := session.InboundFromContext(ctx) + ob.Name = "vless" var rec *protocol.ServerSpec var conn stat.Connection @@ -96,7 +96,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte if statConn, ok := iConn.(*stat.CounterConnection); ok { iConn = statConn.Connection } - target := outbound.Target + target := ob.Target newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).AtInfo().WriteToLog(session.ExportIDToError(ctx)) command := protocol.RequestCommandTCP @@ -130,9 +130,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte requestAddons.Flow = requestAddons.Flow[:16] fallthrough case vless.XRV: - if inbound != nil { - inbound.SetCanSpliceCopy(2) - } + ob.CanSpliceCopy = 2 switch request.Command { case protocol.RequestCommandUDP: if !allowUDP443 && request.Port == 443 { @@ -161,9 +159,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset)) } default: - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } + ob.CanSpliceCopy = 3 } var newCtx context.Context @@ -238,8 +234,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, utlsConn.ConnectionState().Version).AtWarning() } } - ctx1 := session.ContextWithOutbound(ctx, nil) // TODO enable splice - err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ctx1) + ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice + err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -277,7 +273,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } if requestAddons.Flow == vless.XRV { - err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ctx) + err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 679ea5da..f5340f20 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -257,7 +257,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s inbound := session.InboundFromContext(ctx) inbound.Name = "vmess" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 inbound.User = request.User sessionPolicy = h.policyManager.ForLevel(request.User.Level) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index c3c55d95..8f102dbb 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -60,15 +60,13 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements proxy.Outbound.Process(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified").AtError() } - outbound.Name = "vmess" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } + ob.Name = "vmess" + ob.CanSpliceCopy = 3 var rec *protocol.ServerSpec var conn stat.Connection @@ -87,7 +85,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } defer conn.Close() - target := outbound.Target + target := ob.Target newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).WriteToLog(session.ExportIDToError(ctx)) command := protocol.RequestCommandTCP diff --git a/proxy/wireguard/client.go b/proxy/wireguard/client.go index 4136525e..00a6fa51 100644 --- a/proxy/wireguard/client.go +++ b/proxy/wireguard/client.go @@ -127,22 +127,20 @@ func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { // Process implements OutboundHandler.Dispatch(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "wireguard" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } + ob.Name = "wireguard" + ob.CanSpliceCopy = 3 if err := h.processWireGuard(dialer); err != nil { return err } // Destination of the inner request. - destination := outbound.Target + destination := ob.Target command := protocol.RequestCommandTCP if destination.Network == net.Network_UDP { command = protocol.RequestCommandUDP diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index bdb4e801..3d3b584c 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -79,13 +79,15 @@ func (*Server) Network() []net.Network { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "wireguard" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] s.info = routingInfo{ ctx: core.ToBackgroundDetachedContext(ctx), dispatcher: dispatcher, inboundTag: session.InboundFromContext(ctx), - outboundTag: session.OutboundFromContext(ctx), + outboundTag: ob, contentTag: session.ContentFromContext(ctx), } @@ -145,7 +147,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { ctx = session.ContextWithInbound(ctx, s.info.inboundTag) } if s.info.outboundTag != nil { - ctx = session.ContextWithOutbound(ctx, s.info.outboundTag) + ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{s.info.outboundTag}) } if s.info.contentTag != nil { ctx = session.ContextWithContent(ctx, s.info.contentTag) diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 3d5d046f..ffa868a3 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -112,7 +112,12 @@ func canLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn { newError("redirecting request " + dst.String() + " to " + obt).WriteToLog(session.ExportIDToError(ctx)) h := obm.GetHandler(obt) - ctx = session.ContextWithOutbound(ctx, &session.Outbound{Target: dst, Gateway: nil}) + outbounds := session.OutboundsFromContext(ctx) + ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{ + Target: dst, + Gateway: nil, + Tag: obt, + })) // add another outbound in session ctx if h != nil { ur, uw := pipe.New(pipe.OptionsFromContext(ctx)...) dr, dw := pipe.New(pipe.OptionsFromContext(ctx)...) @@ -131,8 +136,10 @@ func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn { // DialSystem calls system dialer to create a network connection. func DialSystem(ctx context.Context, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) { var src net.Address - if outbound := session.OutboundFromContext(ctx); outbound != nil { - src = outbound.Gateway + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) > 0 { + ob := outbounds[len(outbounds) - 1] + src = ob.Gateway } if sockopt == nil { return effectiveSystemDialer.Dial(ctx, src, dest, sockopt) diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index 5d5789b4..a4b03ced 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -118,7 +118,7 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in address := net.ParseAddress(rawHost) gctx = session.ContextWithID(gctx, session.IDFromContext(ctx)) - gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx)) + gctx = session.ContextWithOutbounds(gctx, session.OutboundsFromContext(ctx)) gctx = session.ContextWithTimeoutOnly(gctx, true) c, err := internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt) diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index acccd0b7..0148658c 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -68,7 +68,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in address := net.ParseAddress(rawHost) hctx = session.ContextWithID(hctx, session.IDFromContext(ctx)) - hctx = session.ContextWithOutbound(hctx, session.OutboundFromContext(ctx)) + hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx)) hctx = session.ContextWithTimeoutOnly(hctx, true) pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)