XTLS Vision: Use separate uplink/downlink flag for direct copy (#4329)

Fixes https://github.com/XTLS/Xray-core/issues/4033
This commit is contained in:
yuhan6665 2025-01-27 15:44:33 -05:00 committed by GitHub
parent 7b59379d73
commit 03131c72db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 44 additions and 24 deletions

View file

@ -110,7 +110,8 @@ type TrafficState struct {
// reader link state // reader link state
WithinPaddingBuffers bool WithinPaddingBuffers bool
ReaderSwitchToDirectCopy bool DownlinkReaderDirectCopy bool
UplinkReaderDirectCopy bool
RemainingCommand int32 RemainingCommand int32
RemainingContent int32 RemainingContent int32
RemainingPadding int32 RemainingPadding int32
@ -118,7 +119,8 @@ type TrafficState struct {
// write link state // write link state
IsPadding bool IsPadding bool
WriterSwitchToDirectCopy bool DownlinkWriterDirectCopy bool
UplinkWriterDirectCopy bool
} }
func NewTrafficState(userUUID []byte) *TrafficState { func NewTrafficState(userUUID []byte) *TrafficState {
@ -131,13 +133,15 @@ func NewTrafficState(userUUID []byte) *TrafficState {
Cipher: 0, Cipher: 0,
RemainingServerHello: -1, RemainingServerHello: -1,
WithinPaddingBuffers: true, WithinPaddingBuffers: true,
ReaderSwitchToDirectCopy: false, DownlinkReaderDirectCopy: false,
UplinkReaderDirectCopy: false,
RemainingCommand: -1, RemainingCommand: -1,
RemainingContent: -1, RemainingContent: -1,
RemainingPadding: -1, RemainingPadding: -1,
CurrentCommand: 0, CurrentCommand: 0,
IsPadding: true, IsPadding: true,
WriterSwitchToDirectCopy: false, DownlinkWriterDirectCopy: false,
UplinkWriterDirectCopy: false,
} }
} }
@ -147,13 +151,15 @@ type VisionReader struct {
buf.Reader buf.Reader
trafficState *TrafficState trafficState *TrafficState
ctx context.Context ctx context.Context
isUplink bool
} }
func NewVisionReader(reader buf.Reader, state *TrafficState, context context.Context) *VisionReader { func NewVisionReader(reader buf.Reader, state *TrafficState, isUplink bool, context context.Context) *VisionReader {
return &VisionReader{ return &VisionReader{
Reader: reader, Reader: reader,
trafficState: state, trafficState: state,
ctx: context, ctx: context,
isUplink: isUplink,
} }
} }
@ -175,7 +181,11 @@ func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
w.trafficState.WithinPaddingBuffers = false w.trafficState.WithinPaddingBuffers = false
} else if w.trafficState.CurrentCommand == 2 { } else if w.trafficState.CurrentCommand == 2 {
w.trafficState.WithinPaddingBuffers = false w.trafficState.WithinPaddingBuffers = false
w.trafficState.ReaderSwitchToDirectCopy = true if w.isUplink {
w.trafficState.UplinkReaderDirectCopy = true
} else {
w.trafficState.DownlinkReaderDirectCopy = true
}
} else { } else {
errors.LogInfo(w.ctx, "XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len()) errors.LogInfo(w.ctx, "XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len())
} }
@ -194,9 +204,10 @@ type VisionWriter struct {
trafficState *TrafficState trafficState *TrafficState
ctx context.Context ctx context.Context
writeOnceUserUUID []byte writeOnceUserUUID []byte
isUplink bool
} }
func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Context) *VisionWriter { func NewVisionWriter(writer buf.Writer, state *TrafficState, isUplink bool, context context.Context) *VisionWriter {
w := make([]byte, len(state.UserUUID)) w := make([]byte, len(state.UserUUID))
copy(w, state.UserUUID) copy(w, state.UserUUID)
return &VisionWriter{ return &VisionWriter{
@ -204,6 +215,7 @@ func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Con
trafficState: state, trafficState: state,
ctx: context, ctx: context,
writeOnceUserUUID: w, writeOnceUserUUID: w,
isUplink: isUplink,
} }
} }
@ -221,7 +233,11 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
for i, b := range mb { for i, b := range mb {
if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) { if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) {
if w.trafficState.EnableXtls { if w.trafficState.EnableXtls {
w.trafficState.WriterSwitchToDirectCopy = true if w.isUplink {
w.trafficState.UplinkWriterDirectCopy = true
} else {
w.trafficState.DownlinkWriterDirectCopy = true
}
} }
var command byte = CommandPaddingContinue var command byte = CommandPaddingContinue
if i == len(mb)-1 { if i == len(mb)-1 {

View file

@ -61,13 +61,13 @@ func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) {
} }
// EncodeBodyAddons returns a Writer that auto-encrypt content written by caller. // EncodeBodyAddons returns a Writer that auto-encrypt content written by caller.
func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, context context.Context) buf.Writer { func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, isUplink bool, context context.Context) buf.Writer {
if request.Command == protocol.RequestCommandUDP { if request.Command == protocol.RequestCommandUDP {
return NewMultiLengthPacketWriter(writer.(buf.Writer)) return NewMultiLengthPacketWriter(writer.(buf.Writer))
} }
w := buf.NewWriter(writer) w := buf.NewWriter(writer)
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
w = proxy.NewVisionWriter(w, state, context) w = proxy.NewVisionWriter(w, state, isUplink, context)
} }
return w return w
} }

View file

@ -172,10 +172,10 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A
} }
// XtlsRead filter and read xtls protocol // XtlsRead filter and read xtls protocol
func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error {
err := func() error { err := func() error {
for { for {
if trafficState.ReaderSwitchToDirectCopy { if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy {
var writerConn net.Conn var writerConn net.Conn
var inTimer *signal.ActivityTimer var inTimer *signal.ActivityTimer
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil {
@ -193,7 +193,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer,
buffer, err := reader.ReadMultiBuffer() buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() { if !buffer.IsEmpty() {
timer.Update() timer.Update()
if trafficState.ReaderSwitchToDirectCopy { if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy {
// XTLS Vision processes struct TLS Conn's input and rawInput // XTLS Vision processes struct TLS Conn's input and rawInput
if inputBuffer, err := buf.ReadFrom(input); err == nil { if inputBuffer, err := buf.ReadFrom(input); err == nil {
if !inputBuffer.IsEmpty() { if !inputBuffer.IsEmpty() {
@ -222,12 +222,12 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer,
} }
// XtlsWrite filter and write xtls protocol // XtlsWrite filter and write xtls protocol
func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error {
err := func() error { err := func() error {
var ct stats.Counter var ct stats.Counter
for { for {
buffer, err := reader.ReadMultiBuffer() buffer, err := reader.ReadMultiBuffer()
if trafficState.WriterSwitchToDirectCopy { if isUplink && trafficState.UplinkWriterDirectCopy || !isUplink && trafficState.DownlinkWriterDirectCopy {
if inbound := session.InboundFromContext(ctx); inbound != nil { if inbound := session.InboundFromContext(ctx); inbound != nil {
if inbound.CanSpliceCopy == 2 { if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1 inbound.CanSpliceCopy = 1
@ -239,7 +239,11 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
rawConn, _, writerCounter := proxy.UnwrapRawConn(conn) rawConn, _, writerCounter := proxy.UnwrapRawConn(conn)
writer = buf.NewWriter(rawConn) writer = buf.NewWriter(rawConn)
ct = writerCounter ct = writerCounter
trafficState.WriterSwitchToDirectCopy = false if isUplink {
trafficState.UplinkWriterDirectCopy = false
} else {
trafficState.DownlinkWriterDirectCopy = false
}
} }
if !buffer.IsEmpty() { if !buffer.IsEmpty() {
if ct != nil { if ct != nil {

View file

@ -538,8 +538,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
clientReader = proxy.NewVisionReader(clientReader, trafficState, ctx1) clientReader = proxy.NewVisionReader(clientReader, trafficState, true, ctx1)
err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, ctx1) err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, true, ctx1)
} else { } else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer
err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@ -561,7 +561,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
} }
// default: clientWriter := bufferWriter // default: clientWriter := bufferWriter
clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx) clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, false, ctx)
multiBuffer, err1 := serverReader.ReadMultiBuffer() multiBuffer, err1 := serverReader.ReadMultiBuffer()
if err1 != nil { if err1 != nil {
return err1 // ... return err1 // ...
@ -576,7 +576,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
var err error var err error
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, ctx) err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, false, ctx)
} else { } else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer
err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))

View file

@ -194,7 +194,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
} }
// default: serverWriter := bufferWriter // default: serverWriter := bufferWriter
serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx) serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, true, ctx)
if request.Command == protocol.RequestCommandMux && request.Port == 666 { if request.Command == protocol.RequestCommandMux && request.Port == 666 {
serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx)) serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx))
} }
@ -234,7 +234,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
} }
} }
ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, ctx1) err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, true, ctx1)
} else { } else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer
err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@ -261,7 +261,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
// default: serverReader := buf.NewReader(conn) // default: serverReader := buf.NewReader(conn)
serverReader := encoding.DecodeBodyAddons(conn, request, responseAddons) serverReader := encoding.DecodeBodyAddons(conn, request, responseAddons)
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
serverReader = proxy.NewVisionReader(serverReader, trafficState, ctx) serverReader = proxy.NewVisionReader(serverReader, trafficState, false, ctx)
} }
if request.Command == protocol.RequestCommandMux && request.Port == 666 { if request.Command == protocol.RequestCommandMux && request.Port == 666 {
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
@ -272,7 +272,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
} }
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, ctx) err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, false, ctx)
} else { } else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer
err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))