From 009037b1e2a4dd6c302b04ee0e5e92be3a911310 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Sun, 26 Jan 2025 12:17:17 -0500 Subject: [PATCH] Add separate uplink/downlink flag for direct copy For each connection, xtls need 4 flags for uplink/downlink reader/writer to decide when it switch to direct copy. In the past, there were only one for read and one for write. If service has xtls inbound and xtls outbound, the two flags may be corrupted by signal from different directions. --- proxy/proxy.go | 32 ++++++++++++++++++++++++-------- proxy/vless/encoding/addons.go | 4 ++-- proxy/vless/encoding/encoding.go | 16 ++++++++++------ proxy/vless/inbound/inbound.go | 8 ++++---- proxy/vless/outbound/outbound.go | 8 ++++---- 5 files changed, 44 insertions(+), 24 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 1e4c69f5..a3d3fccb 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -110,7 +110,8 @@ type TrafficState struct { // reader link state WithinPaddingBuffers bool - ReaderSwitchToDirectCopy bool + DownlinkReaderDirectCopy bool + UplinkReaderDirectCopy bool RemainingCommand int32 RemainingContent int32 RemainingPadding int32 @@ -118,7 +119,8 @@ type TrafficState struct { // write link state IsPadding bool - WriterSwitchToDirectCopy bool + DownlinkWriterDirectCopy bool + UplinkWriterDirectCopy bool } func NewTrafficState(userUUID []byte) *TrafficState { @@ -131,13 +133,15 @@ func NewTrafficState(userUUID []byte) *TrafficState { Cipher: 0, RemainingServerHello: -1, WithinPaddingBuffers: true, - ReaderSwitchToDirectCopy: false, + DownlinkReaderDirectCopy: false, + UplinkReaderDirectCopy: false, RemainingCommand: -1, RemainingContent: -1, RemainingPadding: -1, CurrentCommand: 0, IsPadding: true, - WriterSwitchToDirectCopy: false, + DownlinkWriterDirectCopy: false, + UplinkWriterDirectCopy: false, } } @@ -147,13 +151,15 @@ type VisionReader struct { buf.Reader trafficState *TrafficState ctx context.Context + isUplink bool } -func NewVisionReader(reader buf.Reader, state *TrafficState, context context.Context) *VisionReader { +func NewVisionReader(reader buf.Reader, state *TrafficState, isUplink bool, context context.Context) *VisionReader { return &VisionReader{ Reader: reader, trafficState: state, ctx: context, + isUplink: isUplink, } } @@ -175,7 +181,11 @@ func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) { w.trafficState.WithinPaddingBuffers = false } else if w.trafficState.CurrentCommand == 2 { w.trafficState.WithinPaddingBuffers = false - w.trafficState.ReaderSwitchToDirectCopy = true + if w.isUplink { + w.trafficState.UplinkReaderDirectCopy = true + } else { + w.trafficState.DownlinkReaderDirectCopy = true + } } else { errors.LogInfo(w.ctx, "XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len()) } @@ -194,9 +204,10 @@ type VisionWriter struct { trafficState *TrafficState ctx context.Context writeOnceUserUUID []byte + isUplink bool } -func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Context) *VisionWriter { +func NewVisionWriter(writer buf.Writer, state *TrafficState, isUplink bool, context context.Context) *VisionWriter { w := make([]byte, len(state.UserUUID)) copy(w, state.UserUUID) return &VisionWriter{ @@ -204,6 +215,7 @@ func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Con trafficState: state, ctx: context, writeOnceUserUUID: w, + isUplink: isUplink, } } @@ -221,7 +233,11 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { for i, b := range mb { if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) { if w.trafficState.EnableXtls { - w.trafficState.WriterSwitchToDirectCopy = true + if w.isUplink { + w.trafficState.UplinkWriterDirectCopy = true + } else { + w.trafficState.DownlinkWriterDirectCopy = true + } } var command byte = CommandPaddingContinue if i == len(mb)-1 { diff --git a/proxy/vless/encoding/addons.go b/proxy/vless/encoding/addons.go index 1bf1817d..4474e3c9 100644 --- a/proxy/vless/encoding/addons.go +++ b/proxy/vless/encoding/addons.go @@ -61,13 +61,13 @@ func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) { } // EncodeBodyAddons returns a Writer that auto-encrypt content written by caller. -func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, context context.Context) buf.Writer { +func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, isUplink bool, context context.Context) buf.Writer { if request.Command == protocol.RequestCommandUDP { return NewMultiLengthPacketWriter(writer.(buf.Writer)) } w := buf.NewWriter(writer) if requestAddons.Flow == vless.XRV { - w = proxy.NewVisionWriter(w, state, context) + w = proxy.NewVisionWriter(w, state, isUplink, context) } return w } diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index 8b067a96..3fce3290 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -172,10 +172,10 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A } // XtlsRead filter and read xtls protocol -func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { +func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error { err := func() error { for { - if trafficState.ReaderSwitchToDirectCopy { + if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy { var writerConn net.Conn var inTimer *signal.ActivityTimer if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { @@ -193,7 +193,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, buffer, err := reader.ReadMultiBuffer() if !buffer.IsEmpty() { timer.Update() - if trafficState.ReaderSwitchToDirectCopy { + if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy { // XTLS Vision processes struct TLS Conn's input and rawInput if inputBuffer, err := buf.ReadFrom(input); err == nil { if !inputBuffer.IsEmpty() { @@ -222,12 +222,12 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, } // XtlsWrite filter and write xtls protocol -func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { +func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error { err := func() error { var ct stats.Counter for { buffer, err := reader.ReadMultiBuffer() - if trafficState.WriterSwitchToDirectCopy { + if isUplink && trafficState.UplinkWriterDirectCopy || !isUplink && trafficState.DownlinkWriterDirectCopy { if inbound := session.InboundFromContext(ctx); inbound != nil { if inbound.CanSpliceCopy == 2 { inbound.CanSpliceCopy = 1 @@ -239,7 +239,11 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate rawConn, _, writerCounter := proxy.UnwrapRawConn(conn) writer = buf.NewWriter(rawConn) ct = writerCounter - trafficState.WriterSwitchToDirectCopy = false + if isUplink { + trafficState.UplinkWriterDirectCopy = false + } else { + trafficState.DownlinkWriterDirectCopy = false + } } if !buffer.IsEmpty() { if ct != nil { diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index a2415a44..1da2e091 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -538,8 +538,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if requestAddons.Flow == vless.XRV { ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice - clientReader = proxy.NewVisionReader(clientReader, trafficState, ctx1) - err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, ctx1) + clientReader = proxy.NewVisionReader(clientReader, trafficState, true, ctx1) + err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, true, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -561,7 +561,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s } // default: clientWriter := bufferWriter - clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx) + clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, false, ctx) multiBuffer, err1 := serverReader.ReadMultiBuffer() if err1 != nil { return err1 // ... @@ -576,7 +576,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s var err error if requestAddons.Flow == vless.XRV { - err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, ctx) + err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, false, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index ed9e07dc..e1a727eb 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -194,7 +194,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } // default: serverWriter := bufferWriter - serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx) + serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, true, ctx) if request.Command == protocol.RequestCommandMux && request.Port == 666 { serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx)) } @@ -234,7 +234,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } } ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice - err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, ctx1) + err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, true, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -261,7 +261,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte // default: serverReader := buf.NewReader(conn) serverReader := encoding.DecodeBodyAddons(conn, request, responseAddons) if requestAddons.Flow == vless.XRV { - serverReader = proxy.NewVisionReader(serverReader, trafficState, ctx) + serverReader = proxy.NewVisionReader(serverReader, trafficState, false, ctx) } if request.Command == protocol.RequestCommandMux && request.Port == 666 { if requestAddons.Flow == vless.XRV { @@ -272,7 +272,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } if requestAddons.Flow == vless.XRV { - err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, ctx) + err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, false, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))