diff --git a/inbound/vless.go b/inbound/vless.go index 24c7d6fe..2de8ce6d 100644 --- a/inbound/vless.go +++ b/inbound/vless.go @@ -50,7 +50,7 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg ctx: ctx, users: options.Users, } - service := vless.NewService[int](adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound)) + service := vless.NewService[int](logger, adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound)) service.UpdateUsers(common.MapIndexed(inbound.users, func(index int, _ option.VLESSUser) int { return index }), common.Map(inbound.users, func(it option.VLESSUser) string { diff --git a/outbound/vless.go b/outbound/vless.go index 31d4979f..996c8759 100644 --- a/outbound/vless.go +++ b/outbound/vless.go @@ -67,7 +67,7 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg default: return nil, E.New("unknown packet encoding: ", options.PacketEncoding) } - outbound.client, err = vless.NewClient(options.UUID, options.Flow) + outbound.client, err = vless.NewClient(options.UUID, options.Flow, logger) if err != nil { return nil, err } diff --git a/transport/vless/client.go b/transport/vless/client.go index dd70a2df..0c3bec34 100644 --- a/transport/vless/client.go +++ b/transport/vless/client.go @@ -9,6 +9,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -16,11 +17,12 @@ import ( ) type Client struct { - key [16]byte - flow string + key [16]byte + flow string + logger logger.Logger } -func NewClient(userId string, flow string) (*Client, error) { +func NewClient(userId string, flow string, logger logger.Logger) (*Client, error) { user := uuid.FromStringOrNil(userId) if user == uuid.Nil { user = uuid.NewV5(user, userId) @@ -30,12 +32,12 @@ func NewClient(userId string, flow string) (*Client, error) { default: return nil, E.New("unsupported flow: " + flow) } - return &Client{user, flow}, nil + return &Client{user, flow, logger}, nil } func (c *Client) prepareConn(conn net.Conn) (net.Conn, error) { if c.flow == FlowVision { - vConn, err := NewVisionConn(conn, c.key) + vConn, err := NewVisionConn(conn, c.key, c.logger) if err != nil { return nil, E.Cause(err, "initialize vision") } diff --git a/transport/vless/constant.go b/transport/vless/constant.go index 20085019..5602eef4 100644 --- a/transport/vless/constant.go +++ b/transport/vless/constant.go @@ -11,6 +11,10 @@ var ( tlsClientHandShakeStart = []byte{0x16, 0x03} tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03} tlsApplicationDataStart = []byte{0x17, 0x03, 0x03} + + commandPaddingContinue byte = 0 + commandPaddingEnd byte = 1 + commandPaddingDirect byte = 2 ) var tls13CipherSuiteDic = map[uint16]string{ diff --git a/transport/vless/service.go b/transport/vless/service.go index 84fd629e..b77f9702 100644 --- a/transport/vless/service.go +++ b/transport/vless/service.go @@ -11,6 +11,7 @@ import ( "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -19,6 +20,7 @@ import ( type Service[T any] struct { userMap map[[16]byte]T + logger logger.Logger handler Handler } @@ -28,8 +30,9 @@ type Handler interface { E.Handler } -func NewService[T any](handler Handler) *Service[T] { +func NewService[T any](logger logger.Logger, handler Handler) *Service[T] { return &Service[T]{ + logger: logger, handler: handler, } } @@ -64,7 +67,7 @@ func (s *Service[T]) NewConnection(ctx context.Context, conn net.Conn, metadata switch request.Flow { case "": case FlowVision: - protocolConn, err = NewVisionConn(conn, request.UUID) + protocolConn, err = NewVisionConn(conn, request.UUID, s.logger) if err != nil { return E.Cause(err, "initialize vision") } diff --git a/transport/vless/vision.go b/transport/vless/vision.go index 742e8355..39593f3a 100644 --- a/transport/vless/vision.go +++ b/transport/vless/vision.go @@ -16,6 +16,7 @@ import ( "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" N "github.com/sagernet/sing/common/network" ) @@ -37,26 +38,27 @@ type VisionConn struct { input *bytes.Reader rawInput *bytes.Buffer netConn net.Conn + logger logger.Logger - userUUID [16]byte - isTLS bool - numberOfPacketToFilter int - isTLS12orAbove bool - remainingServerHello int32 - cipher uint16 - enableXTLS bool - filterTlsApplicationData bool - directWrite bool - writeUUID bool - filterUUID bool - remainingContent int - remainingPadding int - currentCommand int - directRead bool - remainingReader io.Reader + userUUID [16]byte + isTLS bool + numberOfPacketToFilter int + isTLS12orAbove bool + remainingServerHello int32 + cipher uint16 + enableXTLS bool + isPadding bool + directWrite bool + writeUUID bool + withinPaddingBuffers bool + remainingContent int + remainingPadding int + currentCommand int + directRead bool + remainingReader io.Reader } -func NewVisionConn(conn net.Conn, userUUID [16]byte) (*VisionConn, error) { +func NewVisionConn(conn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) { var ( loaded bool reflectType reflect.Type @@ -75,19 +77,21 @@ func NewVisionConn(conn net.Conn, userUUID [16]byte) (*VisionConn, error) { input, _ := reflectType.FieldByName("input") rawInput, _ := reflectType.FieldByName("rawInput") return &VisionConn{ - Conn: conn, - writer: bufio.NewVectorisedWriter(conn), - input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)), - rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)), - netConn: netConn, - userUUID: userUUID, - numberOfPacketToFilter: 8, - remainingServerHello: -1, - filterTlsApplicationData: true, - writeUUID: true, - filterUUID: true, - remainingContent: -1, - remainingPadding: -1, + Conn: conn, + writer: bufio.NewVectorisedWriter(conn), + input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)), + rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)), + netConn: netConn, + logger: logger, + + userUUID: userUUID, + numberOfPacketToFilter: 8, + remainingServerHello: -1, + isPadding: true, + writeUUID: true, + withinPaddingBuffers: true, + remainingContent: -1, + remainingPadding: -1, }, nil } @@ -97,6 +101,7 @@ func (c *VisionConn) Read(p []byte) (n int, err error) { if err == io.EOF { c.remainingReader = nil if n > 0 { + err = nil return } } @@ -109,13 +114,15 @@ func (c *VisionConn) Read(p []byte) (n int, err error) { return } buffer := p[:n] - if c.filterUUID && (c.isTLS || c.numberOfPacketToFilter > 0) { + if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 { buffers := c.unPadding(buffer) if c.remainingContent == 0 && c.remainingPadding == 0 { if c.currentCommand == 1 { - c.filterUUID = false + c.withinPaddingBuffers = false + c.remainingContent = -1 + c.remainingPadding = -1 } else if c.currentCommand == 2 { - c.filterUUID = false + c.withinPaddingBuffers = false c.directRead = true inputBuffer, err := io.ReadAll(c.input) @@ -130,9 +137,17 @@ func (c *VisionConn) Read(p []byte) (n int, err error) { } buffers = append(buffers, rawInputBuffer) - } else if c.currentCommand != 0 { + + c.logger.Trace("XtlsRead readV") + } else if c.currentCommand == 0 { + c.withinPaddingBuffers = true + } else { return 0, E.New("unknown command ", c.currentCommand) } + } else if c.remainingContent > 0 || c.remainingPadding > 0 { + c.withinPaddingBuffers = true + } else { + c.withinPaddingBuffers = false } if c.numberOfPacketToFilter > 0 { c.filterTLS(buffers) @@ -151,27 +166,27 @@ func (c *VisionConn) Write(p []byte) (n int, err error) { if c.numberOfPacketToFilter > 0 { c.filterTLS([][]byte{p}) } - if c.isTLS && c.filterTlsApplicationData { + if c.isPadding { inputLen := len(p) buffers := reshapeBuffer(p) var specIndex int for i, buffer := range buffers { - if buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) { - var command byte = 1 + if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) { + var command byte = commandPaddingEnd if c.enableXTLS { c.directWrite = true specIndex = i - command = 2 + command = commandPaddingDirect } - c.filterTlsApplicationData = false + c.isPadding = false buffers[i] = c.padding(buffer, command) break - } else if !c.isTLS12orAbove && c.numberOfPacketToFilter == 0 { - c.filterTlsApplicationData = false - buffers[i] = c.padding(buffer, 0x01) + } else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 { + c.isPadding = false + buffers[i] = c.padding(buffer, commandPaddingEnd) break } - buffers[i] = c.padding(buffer, 0x00) + buffers[i] = c.padding(buffer, commandPaddingContinue) } if c.directWrite { encryptedBuffer := buffers[:specIndex+1] @@ -181,6 +196,7 @@ func (c *VisionConn) Write(p []byte) (n int, err error) { } buffers = buffers[specIndex+1:] c.writer = bufio.NewVectorisedWriter(c.netConn) + c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers)) time.Sleep(5 * time.Millisecond) // wtf } err = c.writer.WriteVectorised(buffers) @@ -209,10 +225,13 @@ func (c *VisionConn) filterTLS(buffers [][]byte) { sessionIdLen := int32(buffer[43]) cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3] c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1]) + } else { + c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello) } } } else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 { c.isTLS = true + c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer)) } } if c.remainingServerHello > 0 { @@ -226,13 +245,18 @@ func (c *VisionConn) filterTLS(buffers [][]byte) { if ok && cipher != "TLS_AES_128_CCM_8_SHA256" { c.enableXTLS = true } + c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS) c.numberOfPacketToFilter = 0 return } else if c.remainingServerHello == 0 { + c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer)) c.numberOfPacketToFilter = 0 return } } + if c.numberOfPacketToFilter == 0 { + c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer)) + } } } @@ -242,9 +266,12 @@ func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer { if buffer != nil { contentLen = buffer.Len() } - if contentLen < 900 { + if contentLen < 900 && c.isTLS { l, _ := rand.Int(rand.Reader, big.NewInt(500)) paddingLen = int(l.Int64()) + 900 - contentLen + } else { + l, _ := rand.Int(rand.Reader, big.NewInt(256)) + paddingLen = int(l.Int64()) } newBuffer := buf.New() if c.writeUUID { @@ -257,6 +284,7 @@ func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer { buffer.Release() } newBuffer.Extend(paddingLen) + c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command) return newBuffer } @@ -267,6 +295,7 @@ func (c *VisionConn) unPadding(buffer []byte) [][]byte { bufferIndex = 16 c.remainingContent = 0 c.remainingPadding = 0 + c.currentCommand = 0 } } if c.remainingContent == -1 && c.remainingPadding == -1 { @@ -284,6 +313,7 @@ func (c *VisionConn) unPadding(buffer []byte) [][]byte { c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2]) c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4]) bufferIndex += 5 + c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand) } } else if c.remainingContent > 0 { end := c.remainingContent