Refactor Vision reader writer

- Vision now use traffic states to capture two-way info about a connection
- XTLS is de-couple with Vision, it only read traffic states to switch to direct copy mode
- fix a edge case error when Vision unpadding read 5 command bytes
This commit is contained in:
yuhan6665 2023-09-02 11:37:50 -04:00
parent efd32b0fb2
commit d6d225c698
5 changed files with 440 additions and 374 deletions

View file

@ -6,10 +6,14 @@
package proxy package proxy
import ( import (
"bytes"
"context" "context"
"crypto/rand"
gotls "crypto/tls" gotls "crypto/tls"
"io" "io"
"math/big"
"runtime" "runtime"
"strconv"
"github.com/pires/go-proxyproto" "github.com/pires/go-proxyproto"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
@ -27,6 +31,30 @@ import (
"github.com/xtls/xray-core/transport/internet/tls" "github.com/xtls/xray-core/transport/internet/tls"
) )
var (
Tls13SupportedVersions = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04}
TlsClientHandShakeStart = []byte{0x16, 0x03}
TlsServerHandShakeStart = []byte{0x16, 0x03, 0x03}
TlsApplicationDataStart = []byte{0x17, 0x03, 0x03}
Tls13CipherSuiteDic = map[uint16]string{
0x1301: "TLS_AES_128_GCM_SHA256",
0x1302: "TLS_AES_256_GCM_SHA384",
0x1303: "TLS_CHACHA20_POLY1305_SHA256",
0x1304: "TLS_AES_128_CCM_SHA256",
0x1305: "TLS_AES_128_CCM_8_SHA256",
}
)
const (
TlsHandshakeTypeClientHello byte = 0x01
TlsHandshakeTypeServerHello byte = 0x02
CommandPaddingContinue byte = 0x00
CommandPaddingEnd byte = 0x01
CommandPaddingDirect byte = 0x02
)
// An Inbound processes inbound connections. // An Inbound processes inbound connections.
type Inbound interface { type Inbound interface {
// Network returns a list of networks that this inbound supports. Connections with not-supported networks will not be passed into Process(). // Network returns a list of networks that this inbound supports. Connections with not-supported networks will not be passed into Process().
@ -59,6 +87,364 @@ type GetOutbound interface {
GetOutbound() Outbound GetOutbound() Outbound
} }
// TrafficState is used to track uplink and downlink of one connection
// It is used by XTLS to determine if switch to raw copy mode, It is used by Vision to calculate padding
type TrafficState struct {
UserUUID []byte
NumberOfPacketToFilter int
EnableXtls bool
IsTLS12orAbove bool
IsTLS bool
Cipher uint16
RemainingServerHello int32
// reader link state
WithinPaddingBuffers bool
ReaderSwitchToDirectCopy bool
RemainingCommand int32
RemainingContent int32
RemainingPadding int32
CurrentCommand int
// write link state
IsPadding bool
WriterSwitchToDirectCopy bool
}
func NewTrafficState(userUUID []byte) *TrafficState {
return &TrafficState{
UserUUID: userUUID,
NumberOfPacketToFilter: 8,
EnableXtls: false,
IsTLS12orAbove: false,
IsTLS: false,
Cipher: 0,
RemainingServerHello: -1,
WithinPaddingBuffers: true,
ReaderSwitchToDirectCopy: false,
RemainingCommand: -1,
RemainingContent: -1,
RemainingPadding: -1,
CurrentCommand: 0,
IsPadding: true,
WriterSwitchToDirectCopy: false,
}
}
// VisionReader is used to read xtls vision protocol
// Note Vision probably only make sense as the inner most layer of reader, since it need assess traffic state from origin proxy traffic
type VisionReader struct {
buf.Reader
trafficState *TrafficState
ctx context.Context
}
func NewVisionReader(reader buf.Reader, state *TrafficState, context context.Context) *VisionReader {
return &VisionReader{
Reader: reader,
trafficState: state,
ctx: context,
}
}
func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
buffer, err := w.Reader.ReadMultiBuffer()
if !buffer.IsEmpty() {
if w.trafficState.WithinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 {
mb2 := make(buf.MultiBuffer, 0, len(buffer))
for _, b := range buffer {
newbuffer := XtlsUnpadding(b, w.trafficState, w.ctx)
if newbuffer.Len() > 0 {
mb2 = append(mb2, newbuffer)
}
}
buffer = mb2
if w.trafficState.RemainingContent == 0 && w.trafficState.RemainingPadding == 0 {
if w.trafficState.CurrentCommand == 1 {
w.trafficState.WithinPaddingBuffers = false
} else if w.trafficState.CurrentCommand == 2 {
w.trafficState.WithinPaddingBuffers = false
w.trafficState.ReaderSwitchToDirectCopy = true
} else if w.trafficState.CurrentCommand == 0 {
w.trafficState.WithinPaddingBuffers = true
} else {
newError("XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len()).WriteToLog(session.ExportIDToError(w.ctx))
}
} else if w.trafficState.RemainingContent > 0 || w.trafficState.RemainingPadding > 0 {
w.trafficState.WithinPaddingBuffers = true
} else {
w.trafficState.WithinPaddingBuffers = false
}
}
if w.trafficState.NumberOfPacketToFilter > 0 {
XtlsFilterTls(buffer, w.trafficState, w.ctx)
}
}
return buffer, err
}
// VisionWriter is used to write xtls vision protocol
// Note Vision probably only make sense as the inner most layer of writer, since it need assess traffic state from origin proxy traffic
type VisionWriter struct {
buf.Writer
trafficState *TrafficState
ctx context.Context
writeOnceUserUUID []byte
}
func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Context) *VisionWriter {
w := make([]byte, len(state.UserUUID))
copy(w, state.UserUUID)
return &VisionWriter{
Writer: writer,
trafficState: state,
ctx: context,
writeOnceUserUUID: w,
}
}
func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if w.trafficState.NumberOfPacketToFilter > 0 {
XtlsFilterTls(mb, w.trafficState, w.ctx)
}
if w.trafficState.IsPadding {
if len(mb) == 1 && mb[0] == nil {
mb[0] = XtlsPadding(nil, CommandPaddingContinue, &w.writeOnceUserUUID, true, w.ctx) // we do a long padding to hide vless header
return w.Writer.WriteMultiBuffer(mb)
}
mb = ReshapeMultiBuffer(w.ctx, mb)
longPadding := w.trafficState.IsTLS
for i, b := range mb {
if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) {
if w.trafficState.EnableXtls {
w.trafficState.WriterSwitchToDirectCopy = true
}
var command byte = CommandPaddingContinue
if i == len(mb) - 1 {
command = CommandPaddingEnd
if w.trafficState.EnableXtls {
command = CommandPaddingDirect
}
}
mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, true, w.ctx)
w.trafficState.IsPadding = false // padding going to end
longPadding = false
continue
} else if !w.trafficState.IsTLS12orAbove && w.trafficState.NumberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early
w.trafficState.IsPadding = false
mb[i] = XtlsPadding(b, CommandPaddingEnd, &w.writeOnceUserUUID, longPadding, w.ctx)
break
}
var command byte = CommandPaddingContinue
if i == len(mb) - 1 && !w.trafficState.IsPadding {
command = CommandPaddingEnd
if w.trafficState.EnableXtls {
command = CommandPaddingDirect
}
}
mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, longPadding, w.ctx)
}
}
return w.Writer.WriteMultiBuffer(mb)
}
// ReshapeMultiBuffer prepare multi buffer for padding stucture (max 21 bytes)
func ReshapeMultiBuffer(ctx context.Context, buffer buf.MultiBuffer) buf.MultiBuffer {
needReshape := 0
for _, b := range buffer {
if b.Len() >= buf.Size-21 {
needReshape += 1
}
}
if needReshape == 0 {
return buffer
}
mb2 := make(buf.MultiBuffer, 0, len(buffer)+needReshape)
toPrint := ""
for i, buffer1 := range buffer {
if buffer1.Len() >= buf.Size-21 {
index := int32(bytes.LastIndex(buffer1.Bytes(), TlsApplicationDataStart))
if index <= 0 || index > buf.Size-21 {
index = buf.Size / 2
}
buffer2 := buf.New()
buffer2.Write(buffer1.BytesFrom(index))
buffer1.Resize(0, index)
mb2 = append(mb2, buffer1, buffer2)
toPrint += " " + strconv.Itoa(int(buffer1.Len())) + " " + strconv.Itoa(int(buffer2.Len()))
} else {
mb2 = append(mb2, buffer1)
toPrint += " " + strconv.Itoa(int(buffer1.Len()))
}
buffer[i] = nil
}
buffer = buffer[:0]
newError("ReshapeMultiBuffer ", toPrint).WriteToLog(session.ExportIDToError(ctx))
return mb2
}
// XtlsPadding add padding to eliminate length siganature during tls handshake
func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, longPadding bool, ctx context.Context) *buf.Buffer {
var contentLen int32 = 0
var paddingLen int32 = 0
if b != nil {
contentLen = b.Len()
}
if contentLen < 900 && longPadding {
l, err := rand.Int(rand.Reader, big.NewInt(500))
if err != nil {
newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
paddingLen = int32(l.Int64()) + 900 - contentLen
} else {
l, err := rand.Int(rand.Reader, big.NewInt(256))
if err != nil {
newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
paddingLen = int32(l.Int64())
}
if paddingLen > buf.Size-21-contentLen {
paddingLen = buf.Size - 21 - contentLen
}
newbuffer := buf.New()
if userUUID != nil {
newbuffer.Write(*userUUID)
*userUUID = nil
}
newbuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)})
if b != nil {
newbuffer.Write(b.Bytes())
b.Release()
b = nil
}
newbuffer.Extend(paddingLen)
newError("XtlsPadding ", contentLen, " ", paddingLen, " ", command).WriteToLog(session.ExportIDToError(ctx))
return newbuffer
}
// XtlsUnpadding remove padding and parse command
func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buffer {
if s.RemainingCommand == -1 && s.RemainingContent == -1 && s.RemainingPadding == -1 { // inital state
if b.Len() >= 21 && bytes.Equal(s.UserUUID, b.BytesTo(16)) {
b.Advance(16)
s.RemainingCommand = 5
} else {
return b
}
}
newbuffer := buf.New()
for b.Len() > 0 {
if s.RemainingCommand > 0 {
data, err := b.ReadByte()
if err != nil {
return newbuffer
}
switch s.RemainingCommand {
case 5:
s.CurrentCommand = int(data)
case 4:
s.RemainingContent = int32(data)<<8
case 3:
s.RemainingContent = s.RemainingContent | int32(data)
case 2:
s.RemainingPadding = int32(data)<<8
case 1:
s.RemainingPadding = s.RemainingPadding | int32(data)
newError("Xtls Unpadding new block, content ", s.RemainingContent, " padding ", s.RemainingPadding, " command ", s.CurrentCommand).WriteToLog(session.ExportIDToError(ctx))
}
s.RemainingCommand--
} else if s.RemainingContent > 0 {
len := s.RemainingContent
if b.Len() < len {
len = b.Len()
}
data, err := b.ReadBytes(len)
if err != nil {
return newbuffer
}
newbuffer.Write(data)
s.RemainingContent -= len
} else { // remainingPadding > 0
len := s.RemainingPadding
if b.Len() < len {
len = b.Len()
}
b.Advance(len)
s.RemainingPadding -= len
}
if s.RemainingCommand <= 0 && s.RemainingContent <= 0 && s.RemainingPadding <= 0 { // this block done
if s.CurrentCommand == 0 {
s.RemainingCommand = 5
} else {
s.RemainingCommand = -1 // set to initial state
s.RemainingContent = -1
s.RemainingPadding = -1
if b.Len() > 0 { // shouldn't happen
newbuffer.Write(b.Bytes())
}
break
}
}
}
b.Release()
b = nil
return newbuffer
}
// XtlsFilterTls filter and recognize tls 1.3 and other info
func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx context.Context) {
for _, b := range buffer {
if b == nil {
continue
}
trafficState.NumberOfPacketToFilter--
if b.Len() >= 6 {
startsBytes := b.BytesTo(6)
if bytes.Equal(TlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == TlsHandshakeTypeServerHello {
trafficState.RemainingServerHello = (int32(startsBytes[3])<<8 | int32(startsBytes[4])) + 5
trafficState.IsTLS12orAbove = true
trafficState.IsTLS = true
if b.Len() >= 79 && trafficState.RemainingServerHello >= 79 {
sessionIdLen := int32(b.Byte(43))
cipherSuite := b.BytesRange(43+sessionIdLen+1, 43+sessionIdLen+3)
trafficState.Cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
} else {
newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", trafficState.RemainingServerHello).WriteToLog(session.ExportIDToError(ctx))
}
} else if bytes.Equal(TlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == TlsHandshakeTypeClientHello {
trafficState.IsTLS = true
newError("XtlsFilterTls found tls client hello! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
}
}
if trafficState.RemainingServerHello > 0 {
end := trafficState.RemainingServerHello
if end > b.Len() {
end = b.Len()
}
trafficState.RemainingServerHello -= b.Len()
if bytes.Contains(b.BytesTo(end), Tls13SupportedVersions) {
v, ok := Tls13CipherSuiteDic[trafficState.Cipher]
if !ok {
v = "Old cipher: " + strconv.FormatUint(uint64(trafficState.Cipher), 16)
} else if v != "TLS_AES_128_CCM_8_SHA256" {
trafficState.EnableXtls = true
}
newError("XtlsFilterTls found tls 1.3! ", b.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx))
trafficState.NumberOfPacketToFilter = 0
return
} else if trafficState.RemainingServerHello <= 0 {
newError("XtlsFilterTls found tls 1.2! ", b.Len()).WriteToLog(session.ExportIDToError(ctx))
trafficState.NumberOfPacketToFilter = 0
return
}
newError("XtlsFilterTls inconclusive server hello ", b.Len(), " ", trafficState.RemainingServerHello).WriteToLog(session.ExportIDToError(ctx))
}
if trafficState.NumberOfPacketToFilter <= 0 {
newError("XtlsFilterTls stop filtering", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
}
}
}
// UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it // UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it
func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) { func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
var readCounter, writerCounter stats.Counter var readCounter, writerCounter stats.Counter

View file

@ -1,10 +1,12 @@
package encoding package encoding
import ( import (
"context"
"io" "io"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
@ -58,14 +60,19 @@ func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) {
} }
// EncodeBodyAddons returns a Writer that auto-encrypt content written by caller. // EncodeBodyAddons returns a Writer that auto-encrypt content written by caller.
func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, addons *Addons) buf.Writer { func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, context context.Context) buf.Writer {
switch addons.Flow {
default:
if request.Command == protocol.RequestCommandUDP { if request.Command == protocol.RequestCommandUDP {
return NewMultiLengthPacketWriter(writer.(buf.Writer)) w := writer.(buf.Writer)
if requestAddons.Flow == vless.XRV {
w = proxy.NewVisionWriter(w, state, context)
} }
return NewMultiLengthPacketWriter(w)
} }
return buf.NewWriter(writer) w := buf.NewWriter(writer)
if requestAddons.Flow == vless.XRV {
w = proxy.NewVisionWriter(w, state, context)
}
return w
} }
// DecodeBodyAddons returns a Reader from which caller can fetch decrypted body. // DecodeBodyAddons returns a Reader from which caller can fetch decrypted body.

View file

@ -5,11 +5,7 @@ package encoding
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/rand"
"io" "io"
"math/big"
"strconv"
"time"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
@ -26,30 +22,6 @@ const (
Version = byte(0) Version = byte(0)
) )
var (
tls13SupportedVersions = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04}
tlsClientHandShakeStart = []byte{0x16, 0x03}
tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03}
tlsApplicationDataStart = []byte{0x17, 0x03, 0x03}
Tls13CipherSuiteDic = map[uint16]string{
0x1301: "TLS_AES_128_GCM_SHA256",
0x1302: "TLS_AES_256_GCM_SHA384",
0x1303: "TLS_CHACHA20_POLY1305_SHA256",
0x1304: "TLS_AES_128_CCM_SHA256",
0x1305: "TLS_AES_128_CCM_8_SHA256",
}
)
const (
tlsHandshakeTypeClientHello byte = 0x01
tlsHandshakeTypeServerHello byte = 0x02
CommandPaddingContinue byte = 0x00
CommandPaddingEnd byte = 0x01
CommandPaddingDirect byte = 0x02
)
var addrParser = protocol.NewAddressParser( var addrParser = protocol.NewAddressParser(
protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4), protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain), protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
@ -202,18 +174,11 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A
} }
// XtlsRead filter and read xtls protocol // XtlsRead filter and read xtls protocol
func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ctx context.Context) error {
ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool,
isTLS12orAbove *bool, isTLS *bool, cipher *uint16, remainingServerHello *int32,
) error {
err := func() error { err := func() error {
withinPaddingBuffers := true visionReader := proxy.NewVisionReader(reader, trafficState, ctx)
shouldSwitchToDirectCopy := false
var remainingContent int32 = -1
var remainingPadding int32 = -1
currentCommand := 0
for { for {
if shouldSwitchToDirectCopy { if trafficState.ReaderSwitchToDirectCopy {
var writerConn net.Conn var writerConn net.Conn
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil {
writerConn = inbound.Conn writerConn = inbound.Conn
@ -223,18 +188,10 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
} }
return proxy.CopyRawConnIfExist(ctx, conn, writerConn, writer, timer) return proxy.CopyRawConnIfExist(ctx, conn, writerConn, writer, timer)
} }
buffer, err := reader.ReadMultiBuffer() buffer, err := visionReader.ReadMultiBuffer()
if !buffer.IsEmpty() { if !buffer.IsEmpty() {
if withinPaddingBuffers || *numberOfPacketToFilter > 0 { timer.Update()
buffer = XtlsUnpadding(ctx, buffer, userUUID, &remainingContent, &remainingPadding, &currentCommand) if trafficState.ReaderSwitchToDirectCopy {
if remainingContent == 0 && remainingPadding == 0 {
if currentCommand == 1 {
withinPaddingBuffers = false
remainingContent = -1
remainingPadding = -1 // set to initial state to parse the next padding
} else if currentCommand == 2 {
withinPaddingBuffers = false
shouldSwitchToDirectCopy = true
// XTLS Vision processes struct TLS Conn's input and rawInput // XTLS Vision processes struct TLS Conn's input and rawInput
if inputBuffer, err := buf.ReadFrom(input); err == nil { if inputBuffer, err := buf.ReadFrom(input); err == nil {
if !inputBuffer.IsEmpty() { if !inputBuffer.IsEmpty() {
@ -246,21 +203,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
buffer, _ = buf.MergeMulti(buffer, rawInputBuffer) buffer, _ = buf.MergeMulti(buffer, rawInputBuffer)
} }
} }
} else if currentCommand == 0 {
withinPaddingBuffers = true
} else {
newError("XtlsRead unknown command ", currentCommand, buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
} }
} else if remainingContent > 0 || remainingPadding > 0 {
withinPaddingBuffers = true
} else {
withinPaddingBuffers = false
}
}
if *numberOfPacketToFilter > 0 {
XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx)
}
timer.Update()
if werr := writer.WriteMultiBuffer(buffer); werr != nil { if werr := writer.WriteMultiBuffer(buffer); werr != nil {
return werr return werr
} }
@ -277,59 +220,19 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
} }
// XtlsWrite filter and write xtls protocol // XtlsWrite filter and write xtls protocol
func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ctx context.Context) error {
ctx context.Context, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool,
cipher *uint16, remainingServerHello *int32,
) error {
err := func() error { err := func() error {
var ct stats.Counter var ct stats.Counter
isPadding := true
shouldSwitchToDirectCopy := false
for { for {
buffer, err := reader.ReadMultiBuffer() buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() { if trafficState.WriterSwitchToDirectCopy {
if *numberOfPacketToFilter > 0 {
XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx)
}
if isPadding {
buffer = ReshapeMultiBuffer(ctx, buffer)
var xtlsSpecIndex int
for i, b := range buffer {
if *isTLS && b.Len() >= 6 && bytes.Equal(tlsApplicationDataStart, b.BytesTo(3)) {
var command byte = CommandPaddingEnd
if *enableXtls {
shouldSwitchToDirectCopy = true
xtlsSpecIndex = i
command = CommandPaddingDirect
}
isPadding = false
buffer[i] = XtlsPadding(b, command, nil, *isTLS, ctx)
break
} else if !*isTLS12orAbove && *numberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early
isPadding = false
buffer[i] = XtlsPadding(b, CommandPaddingEnd, nil, *isTLS, ctx)
break
}
buffer[i] = XtlsPadding(b, CommandPaddingContinue, nil, *isTLS, ctx)
}
if shouldSwitchToDirectCopy {
encryptBuffer, directBuffer := buf.SplitMulti(buffer, xtlsSpecIndex+1)
if !encryptBuffer.IsEmpty() {
timer.Update()
if werr := writer.WriteMultiBuffer(encryptBuffer); werr != nil {
return werr
}
}
time.Sleep(5 * time.Millisecond) // for some device, the first xtls direct packet fails without this delay
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.CanSpliceCopy == 2 { if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter
} }
buffer = directBuffer
rawConn, _, writerCounter := proxy.UnwrapRawConn(conn) rawConn, _, writerCounter := proxy.UnwrapRawConn(conn)
writer = buf.NewWriter(rawConn) writer = buf.NewWriter(rawConn)
ct = writerCounter ct = writerCounter
} trafficState.WriterSwitchToDirectCopy = false
} }
if !buffer.IsEmpty() { if !buffer.IsEmpty() {
if ct != nil { if ct != nil {
@ -340,7 +243,6 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
return werr return werr
} }
} }
}
if err != nil { if err != nil {
return err return err
} }
@ -351,201 +253,3 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
} }
return nil return nil
} }
// XtlsFilterTls filter and recognize tls 1.3 and other info
func XtlsFilterTls(buffer buf.MultiBuffer, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool,
cipher *uint16, remainingServerHello *int32, ctx context.Context,
) {
for _, b := range buffer {
*numberOfPacketToFilter--
if b.Len() >= 6 {
startsBytes := b.BytesTo(6)
if bytes.Equal(tlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == tlsHandshakeTypeServerHello {
*remainingServerHello = (int32(startsBytes[3])<<8 | int32(startsBytes[4])) + 5
*isTLS12orAbove = true
*isTLS = true
if b.Len() >= 79 && *remainingServerHello >= 79 {
sessionIdLen := int32(b.Byte(43))
cipherSuite := b.BytesRange(43+sessionIdLen+1, 43+sessionIdLen+3)
*cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
} else {
newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx))
}
} else if bytes.Equal(tlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == tlsHandshakeTypeClientHello {
*isTLS = true
newError("XtlsFilterTls found tls client hello! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
}
}
if *remainingServerHello > 0 {
end := *remainingServerHello
if end > b.Len() {
end = b.Len()
}
*remainingServerHello -= b.Len()
if bytes.Contains(b.BytesTo(end), tls13SupportedVersions) {
v, ok := Tls13CipherSuiteDic[*cipher]
if !ok {
v = "Old cipher: " + strconv.FormatUint(uint64(*cipher), 16)
} else if v != "TLS_AES_128_CCM_8_SHA256" {
*enableXtls = true
}
newError("XtlsFilterTls found tls 1.3! ", b.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx))
*numberOfPacketToFilter = 0
return
} else if *remainingServerHello <= 0 {
newError("XtlsFilterTls found tls 1.2! ", b.Len()).WriteToLog(session.ExportIDToError(ctx))
*numberOfPacketToFilter = 0
return
}
newError("XtlsFilterTls inconclusive server hello ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx))
}
if *numberOfPacketToFilter <= 0 {
newError("XtlsFilterTls stop filtering", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
}
}
}
// ReshapeMultiBuffer prepare multi buffer for padding stucture (max 21 bytes)
func ReshapeMultiBuffer(ctx context.Context, buffer buf.MultiBuffer) buf.MultiBuffer {
needReshape := 0
for _, b := range buffer {
if b.Len() >= buf.Size-21 {
needReshape += 1
}
}
if needReshape == 0 {
return buffer
}
mb2 := make(buf.MultiBuffer, 0, len(buffer)+needReshape)
toPrint := ""
for i, buffer1 := range buffer {
if buffer1.Len() >= buf.Size-21 {
index := int32(bytes.LastIndex(buffer1.Bytes(), tlsApplicationDataStart))
if index <= 0 || index > buf.Size-21 {
index = buf.Size / 2
}
buffer2 := buf.New()
buffer2.Write(buffer1.BytesFrom(index))
buffer1.Resize(0, index)
mb2 = append(mb2, buffer1, buffer2)
toPrint += " " + strconv.Itoa(int(buffer1.Len())) + " " + strconv.Itoa(int(buffer2.Len()))
} else {
mb2 = append(mb2, buffer1)
toPrint += " " + strconv.Itoa(int(buffer1.Len()))
}
buffer[i] = nil
}
buffer = buffer[:0]
newError("ReshapeMultiBuffer ", toPrint).WriteToLog(session.ExportIDToError(ctx))
return mb2
}
// XtlsPadding add padding to eliminate length siganature during tls handshake
func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, longPadding bool, ctx context.Context) *buf.Buffer {
var contentLen int32 = 0
var paddingLen int32 = 0
if b != nil {
contentLen = b.Len()
}
if contentLen < 900 && longPadding {
l, err := rand.Int(rand.Reader, big.NewInt(500))
if err != nil {
newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
paddingLen = int32(l.Int64()) + 900 - contentLen
} else {
l, err := rand.Int(rand.Reader, big.NewInt(256))
if err != nil {
newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
paddingLen = int32(l.Int64())
}
if paddingLen > buf.Size-21-contentLen {
paddingLen = buf.Size - 21 - contentLen
}
newbuffer := buf.New()
if userUUID != nil {
newbuffer.Write(*userUUID)
*userUUID = nil
}
newbuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)})
if b != nil {
newbuffer.Write(b.Bytes())
b.Release()
b = nil
}
newbuffer.Extend(paddingLen)
newError("XtlsPadding ", contentLen, " ", paddingLen, " ", command).WriteToLog(session.ExportIDToError(ctx))
return newbuffer
}
// XtlsUnpadding remove padding and parse command
func XtlsUnpadding(ctx context.Context, buffer buf.MultiBuffer, userUUID []byte, remainingContent *int32, remainingPadding *int32, currentCommand *int) buf.MultiBuffer {
posindex := 0
var posByte int32 = 0
if *remainingContent == -1 && *remainingPadding == -1 {
for i, b := range buffer {
if b.Len() >= 21 && bytes.Equal(userUUID, b.BytesTo(16)) {
posindex = i
posByte = 16
*remainingContent = 0
*remainingPadding = 0
*currentCommand = 0
break
}
}
}
if *remainingContent == -1 && *remainingPadding == -1 {
return buffer
}
mb2 := make(buf.MultiBuffer, 0, len(buffer))
for i := 0; i < posindex; i++ {
newbuffer := buf.New()
newbuffer.Write(buffer[i].Bytes())
mb2 = append(mb2, newbuffer)
}
for i := posindex; i < len(buffer); i++ {
b := buffer[i]
for posByte < b.Len() {
if *remainingContent <= 0 && *remainingPadding <= 0 {
if *currentCommand == 1 { // possible buffer after padding, no need to worry about xtls (command 2)
len := b.Len() - posByte
newbuffer := buf.New()
newbuffer.Write(b.BytesRange(posByte, posByte+len))
mb2 = append(mb2, newbuffer)
posByte += len
} else {
paddingInfo := b.BytesRange(posByte, posByte+5)
*currentCommand = int(paddingInfo[0])
*remainingContent = int32(paddingInfo[1])<<8 | int32(paddingInfo[2])
*remainingPadding = int32(paddingInfo[3])<<8 | int32(paddingInfo[4])
newError("Xtls Unpadding new block", i, " ", posByte, " content ", *remainingContent, " padding ", *remainingPadding, " ", paddingInfo[0]).WriteToLog(session.ExportIDToError(ctx))
posByte += 5
}
} else if *remainingContent > 0 {
len := *remainingContent
if b.Len() < posByte+*remainingContent {
len = b.Len() - posByte
}
newbuffer := buf.New()
newbuffer.Write(b.BytesRange(posByte, posByte+len))
mb2 = append(mb2, newbuffer)
*remainingContent -= len
posByte += len
} else { // remainingPadding > 0
len := *remainingPadding
if b.Len() < posByte+*remainingPadding {
len = b.Len() - posByte
}
*remainingPadding -= len
posByte += len
}
if posByte == b.Len() {
posByte = 0
break
}
}
}
buf.ReleaseMulti(buffer)
return mb2
}

View file

@ -28,6 +28,7 @@ import (
feature_inbound "github.com/xtls/xray-core/features/inbound" feature_inbound "github.com/xtls/xray-core/features/inbound"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless"
"github.com/xtls/xray-core/proxy/vless/encoding" "github.com/xtls/xray-core/proxy/vless/encoding"
"github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/reality"
@ -510,13 +511,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
serverReader := link.Reader // .(*pipe.Reader) serverReader := link.Reader // .(*pipe.Reader)
serverWriter := link.Writer // .(*pipe.Writer) serverWriter := link.Writer // .(*pipe.Writer)
enableXtls := false trafficState := proxy.NewTrafficState(account.ID.Bytes())
isTLS12orAbove := false
isTLS := false
var cipher uint16 = 0
var remainingServerHello int32 = -1
numberOfPacketToFilter := 8
postRequest := func() error { postRequest := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
@ -527,8 +522,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, ctx1, account.ID.Bytes(), err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, ctx1)
&numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer
err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@ -550,19 +544,11 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
} }
// default: clientWriter := bufferWriter // default: clientWriter := bufferWriter
clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, responseAddons) clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx)
userUUID := account.ID.Bytes()
multiBuffer, err1 := serverReader.ReadMultiBuffer() multiBuffer, err1 := serverReader.ReadMultiBuffer()
if err1 != nil { if err1 != nil {
return err1 // ... return err1 // ...
} }
if requestAddons.Flow == vless.XRV {
encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx)
multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer)
for i, b := range multiBuffer {
multiBuffer[i] = encoding.XtlsPadding(b, encoding.CommandPaddingContinue, &userUUID, isTLS, ctx)
}
}
if err := clientWriter.WriteMultiBuffer(multiBuffer); err != nil { if err := clientWriter.WriteMultiBuffer(multiBuffer); err != nil {
return err // ... return err // ...
} }
@ -573,8 +559,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
var err error var err error
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, ctx, &numberOfPacketToFilter, err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, ctx)
&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer
err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))

View file

@ -22,6 +22,7 @@ import (
"github.com/xtls/xray-core/common/xudp" "github.com/xtls/xray-core/common/xudp"
"github.com/xtls/xray-core/core" "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless"
"github.com/xtls/xray-core/proxy/vless/encoding" "github.com/xtls/xray-core/proxy/vless/encoding"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
@ -183,13 +184,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
clientReader := link.Reader // .(*pipe.Reader) clientReader := link.Reader // .(*pipe.Reader)
clientWriter := link.Writer // .(*pipe.Writer) clientWriter := link.Writer // .(*pipe.Writer)
enableXtls := false trafficState := proxy.NewTrafficState(account.ID.Bytes())
isTLS12orAbove := false
isTLS := false
var cipher uint16 = 0
var remainingServerHello int32 = -1
numberOfPacketToFilter := 8
if request.Command == protocol.RequestCommandUDP && h.cone && request.Port != 53 && request.Port != 443 { if request.Command == protocol.RequestCommandUDP && h.cone && request.Port != 53 && request.Port != 443 {
request.Command = protocol.RequestCommandMux request.Command = protocol.RequestCommandMux
request.Address = net.DomainAddress("v1.mux.cool") request.Address = net.DomainAddress("v1.mux.cool")
@ -205,22 +200,14 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
} }
// default: serverWriter := bufferWriter // default: serverWriter := bufferWriter
serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons) serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx)
if request.Command == protocol.RequestCommandMux && request.Port == 666 { if request.Command == protocol.RequestCommandMux && request.Port == 666 {
serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx)) serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx))
} }
userUUID := account.ID.Bytes()
timeoutReader, ok := clientReader.(buf.TimeoutReader) timeoutReader, ok := clientReader.(buf.TimeoutReader)
if ok { if ok {
multiBuffer, err1 := timeoutReader.ReadMultiBufferTimeout(time.Millisecond * 500) multiBuffer, err1 := timeoutReader.ReadMultiBufferTimeout(time.Millisecond * 500)
if err1 == nil { if err1 == nil {
if requestAddons.Flow == vless.XRV {
encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx)
multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer)
for i, b := range multiBuffer {
multiBuffer[i] = encoding.XtlsPadding(b, encoding.CommandPaddingContinue, &userUUID, isTLS, ctx)
}
}
if err := serverWriter.WriteMultiBuffer(multiBuffer); err != nil { if err := serverWriter.WriteMultiBuffer(multiBuffer); err != nil {
return err // ... return err // ...
} }
@ -228,10 +215,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return err1 return err1
} else if requestAddons.Flow == vless.XRV { } else if requestAddons.Flow == vless.XRV {
mb := make(buf.MultiBuffer, 1) mb := make(buf.MultiBuffer, 1)
mb[0] = encoding.XtlsPadding(nil, encoding.CommandPaddingContinue, &userUUID, true, ctx) // we do a long padding to hide vless header
newError("Insert padding with empty content to camouflage VLESS header ", mb.Len()).WriteToLog(session.ExportIDToError(ctx)) newError("Insert padding with empty content to camouflage VLESS header ", mb.Len()).WriteToLog(session.ExportIDToError(ctx))
if err := serverWriter.WriteMultiBuffer(mb); err != nil { if err := serverWriter.WriteMultiBuffer(mb); err != nil {
return err return err // ...
} }
} }
} else { } else {
@ -254,8 +240,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
} }
} }
ctx1 := session.ContextWithOutbound(ctx, nil) // TODO enable splice ctx1 := session.ContextWithOutbound(ctx, nil) // TODO enable splice
err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, ctx1, &numberOfPacketToFilter, err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ctx1)
&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer
err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@ -286,8 +271,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
} }
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, ctx, account.ID.Bytes(), err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ctx)
&numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer
err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))