Update vision protocol

This commit is contained in:
世界 2023-02-27 15:07:15 +08:00
parent 5ce3ddee9b
commit e4bff0460d
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
6 changed files with 92 additions and 53 deletions

View file

@ -50,7 +50,7 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg
ctx: ctx, ctx: ctx,
users: options.Users, users: options.Users,
} }
service := vless.NewService[int](adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound)) service := vless.NewService[int](logger, adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound))
service.UpdateUsers(common.MapIndexed(inbound.users, func(index int, _ option.VLESSUser) int { service.UpdateUsers(common.MapIndexed(inbound.users, func(index int, _ option.VLESSUser) int {
return index return index
}), common.Map(inbound.users, func(it option.VLESSUser) string { }), common.Map(inbound.users, func(it option.VLESSUser) string {

View file

@ -67,7 +67,7 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg
default: default:
return nil, E.New("unknown packet encoding: ", options.PacketEncoding) return nil, E.New("unknown packet encoding: ", options.PacketEncoding)
} }
outbound.client, err = vless.NewClient(options.UUID, options.Flow) outbound.client, err = vless.NewClient(options.UUID, options.Flow, logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -9,6 +9,7 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -16,11 +17,12 @@ import (
) )
type Client struct { type Client struct {
key [16]byte key [16]byte
flow string flow string
logger logger.Logger
} }
func NewClient(userId string, flow string) (*Client, error) { func NewClient(userId string, flow string, logger logger.Logger) (*Client, error) {
user := uuid.FromStringOrNil(userId) user := uuid.FromStringOrNil(userId)
if user == uuid.Nil { if user == uuid.Nil {
user = uuid.NewV5(user, userId) user = uuid.NewV5(user, userId)
@ -30,12 +32,12 @@ func NewClient(userId string, flow string) (*Client, error) {
default: default:
return nil, E.New("unsupported flow: " + flow) return nil, E.New("unsupported flow: " + flow)
} }
return &Client{user, flow}, nil return &Client{user, flow, logger}, nil
} }
func (c *Client) prepareConn(conn net.Conn) (net.Conn, error) { func (c *Client) prepareConn(conn net.Conn) (net.Conn, error) {
if c.flow == FlowVision { if c.flow == FlowVision {
vConn, err := NewVisionConn(conn, c.key) vConn, err := NewVisionConn(conn, c.key, c.logger)
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize vision") return nil, E.Cause(err, "initialize vision")
} }

View file

@ -11,6 +11,10 @@ var (
tlsClientHandShakeStart = []byte{0x16, 0x03} tlsClientHandShakeStart = []byte{0x16, 0x03}
tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03} tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03}
tlsApplicationDataStart = []byte{0x17, 0x03, 0x03} tlsApplicationDataStart = []byte{0x17, 0x03, 0x03}
commandPaddingContinue byte = 0
commandPaddingEnd byte = 1
commandPaddingDirect byte = 2
) )
var tls13CipherSuiteDic = map[uint16]string{ var tls13CipherSuiteDic = map[uint16]string{

View file

@ -11,6 +11,7 @@ import (
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -19,6 +20,7 @@ import (
type Service[T any] struct { type Service[T any] struct {
userMap map[[16]byte]T userMap map[[16]byte]T
logger logger.Logger
handler Handler handler Handler
} }
@ -28,8 +30,9 @@ type Handler interface {
E.Handler E.Handler
} }
func NewService[T any](handler Handler) *Service[T] { func NewService[T any](logger logger.Logger, handler Handler) *Service[T] {
return &Service[T]{ return &Service[T]{
logger: logger,
handler: handler, handler: handler,
} }
} }
@ -64,7 +67,7 @@ func (s *Service[T]) NewConnection(ctx context.Context, conn net.Conn, metadata
switch request.Flow { switch request.Flow {
case "": case "":
case FlowVision: case FlowVision:
protocolConn, err = NewVisionConn(conn, request.UUID) protocolConn, err = NewVisionConn(conn, request.UUID, s.logger)
if err != nil { if err != nil {
return E.Cause(err, "initialize vision") return E.Cause(err, "initialize vision")
} }

View file

@ -16,6 +16,7 @@ import (
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
@ -37,26 +38,27 @@ type VisionConn struct {
input *bytes.Reader input *bytes.Reader
rawInput *bytes.Buffer rawInput *bytes.Buffer
netConn net.Conn netConn net.Conn
logger logger.Logger
userUUID [16]byte userUUID [16]byte
isTLS bool isTLS bool
numberOfPacketToFilter int numberOfPacketToFilter int
isTLS12orAbove bool isTLS12orAbove bool
remainingServerHello int32 remainingServerHello int32
cipher uint16 cipher uint16
enableXTLS bool enableXTLS bool
filterTlsApplicationData bool isPadding bool
directWrite bool directWrite bool
writeUUID bool writeUUID bool
filterUUID bool withinPaddingBuffers bool
remainingContent int remainingContent int
remainingPadding int remainingPadding int
currentCommand int currentCommand int
directRead bool directRead bool
remainingReader io.Reader remainingReader io.Reader
} }
func NewVisionConn(conn net.Conn, userUUID [16]byte) (*VisionConn, error) { func NewVisionConn(conn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) {
var ( var (
loaded bool loaded bool
reflectType reflect.Type reflectType reflect.Type
@ -75,19 +77,21 @@ func NewVisionConn(conn net.Conn, userUUID [16]byte) (*VisionConn, error) {
input, _ := reflectType.FieldByName("input") input, _ := reflectType.FieldByName("input")
rawInput, _ := reflectType.FieldByName("rawInput") rawInput, _ := reflectType.FieldByName("rawInput")
return &VisionConn{ return &VisionConn{
Conn: conn, Conn: conn,
writer: bufio.NewVectorisedWriter(conn), writer: bufio.NewVectorisedWriter(conn),
input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)), input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)), rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
netConn: netConn, netConn: netConn,
userUUID: userUUID, logger: logger,
numberOfPacketToFilter: 8,
remainingServerHello: -1, userUUID: userUUID,
filterTlsApplicationData: true, numberOfPacketToFilter: 8,
writeUUID: true, remainingServerHello: -1,
filterUUID: true, isPadding: true,
remainingContent: -1, writeUUID: true,
remainingPadding: -1, withinPaddingBuffers: true,
remainingContent: -1,
remainingPadding: -1,
}, nil }, nil
} }
@ -97,6 +101,7 @@ func (c *VisionConn) Read(p []byte) (n int, err error) {
if err == io.EOF { if err == io.EOF {
c.remainingReader = nil c.remainingReader = nil
if n > 0 { if n > 0 {
err = nil
return return
} }
} }
@ -109,13 +114,15 @@ func (c *VisionConn) Read(p []byte) (n int, err error) {
return return
} }
buffer := p[:n] buffer := p[:n]
if c.filterUUID && (c.isTLS || c.numberOfPacketToFilter > 0) { if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 {
buffers := c.unPadding(buffer) buffers := c.unPadding(buffer)
if c.remainingContent == 0 && c.remainingPadding == 0 { if c.remainingContent == 0 && c.remainingPadding == 0 {
if c.currentCommand == 1 { if c.currentCommand == 1 {
c.filterUUID = false c.withinPaddingBuffers = false
c.remainingContent = -1
c.remainingPadding = -1
} else if c.currentCommand == 2 { } else if c.currentCommand == 2 {
c.filterUUID = false c.withinPaddingBuffers = false
c.directRead = true c.directRead = true
inputBuffer, err := io.ReadAll(c.input) inputBuffer, err := io.ReadAll(c.input)
@ -130,9 +137,17 @@ func (c *VisionConn) Read(p []byte) (n int, err error) {
} }
buffers = append(buffers, rawInputBuffer) buffers = append(buffers, rawInputBuffer)
} else if c.currentCommand != 0 {
c.logger.Trace("XtlsRead readV")
} else if c.currentCommand == 0 {
c.withinPaddingBuffers = true
} else {
return 0, E.New("unknown command ", c.currentCommand) return 0, E.New("unknown command ", c.currentCommand)
} }
} else if c.remainingContent > 0 || c.remainingPadding > 0 {
c.withinPaddingBuffers = true
} else {
c.withinPaddingBuffers = false
} }
if c.numberOfPacketToFilter > 0 { if c.numberOfPacketToFilter > 0 {
c.filterTLS(buffers) c.filterTLS(buffers)
@ -151,27 +166,27 @@ func (c *VisionConn) Write(p []byte) (n int, err error) {
if c.numberOfPacketToFilter > 0 { if c.numberOfPacketToFilter > 0 {
c.filterTLS([][]byte{p}) c.filterTLS([][]byte{p})
} }
if c.isTLS && c.filterTlsApplicationData { if c.isPadding {
inputLen := len(p) inputLen := len(p)
buffers := reshapeBuffer(p) buffers := reshapeBuffer(p)
var specIndex int var specIndex int
for i, buffer := range buffers { for i, buffer := range buffers {
if buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) { if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
var command byte = 1 var command byte = commandPaddingEnd
if c.enableXTLS { if c.enableXTLS {
c.directWrite = true c.directWrite = true
specIndex = i specIndex = i
command = 2 command = commandPaddingDirect
} }
c.filterTlsApplicationData = false c.isPadding = false
buffers[i] = c.padding(buffer, command) buffers[i] = c.padding(buffer, command)
break break
} else if !c.isTLS12orAbove && c.numberOfPacketToFilter == 0 { } else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 {
c.filterTlsApplicationData = false c.isPadding = false
buffers[i] = c.padding(buffer, 0x01) buffers[i] = c.padding(buffer, commandPaddingEnd)
break break
} }
buffers[i] = c.padding(buffer, 0x00) buffers[i] = c.padding(buffer, commandPaddingContinue)
} }
if c.directWrite { if c.directWrite {
encryptedBuffer := buffers[:specIndex+1] encryptedBuffer := buffers[:specIndex+1]
@ -181,6 +196,7 @@ func (c *VisionConn) Write(p []byte) (n int, err error) {
} }
buffers = buffers[specIndex+1:] buffers = buffers[specIndex+1:]
c.writer = bufio.NewVectorisedWriter(c.netConn) c.writer = bufio.NewVectorisedWriter(c.netConn)
c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers))
time.Sleep(5 * time.Millisecond) // wtf time.Sleep(5 * time.Millisecond) // wtf
} }
err = c.writer.WriteVectorised(buffers) err = c.writer.WriteVectorised(buffers)
@ -209,10 +225,13 @@ func (c *VisionConn) filterTLS(buffers [][]byte) {
sessionIdLen := int32(buffer[43]) sessionIdLen := int32(buffer[43])
cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3] cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3]
c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1]) c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
} else {
c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello)
} }
} }
} else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 { } else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 {
c.isTLS = true c.isTLS = true
c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer))
} }
} }
if c.remainingServerHello > 0 { if c.remainingServerHello > 0 {
@ -226,13 +245,18 @@ func (c *VisionConn) filterTLS(buffers [][]byte) {
if ok && cipher != "TLS_AES_128_CCM_8_SHA256" { if ok && cipher != "TLS_AES_128_CCM_8_SHA256" {
c.enableXTLS = true c.enableXTLS = true
} }
c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS)
c.numberOfPacketToFilter = 0 c.numberOfPacketToFilter = 0
return return
} else if c.remainingServerHello == 0 { } else if c.remainingServerHello == 0 {
c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer))
c.numberOfPacketToFilter = 0 c.numberOfPacketToFilter = 0
return return
} }
} }
if c.numberOfPacketToFilter == 0 {
c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer))
}
} }
} }
@ -242,9 +266,12 @@ func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
if buffer != nil { if buffer != nil {
contentLen = buffer.Len() contentLen = buffer.Len()
} }
if contentLen < 900 { if contentLen < 900 && c.isTLS {
l, _ := rand.Int(rand.Reader, big.NewInt(500)) l, _ := rand.Int(rand.Reader, big.NewInt(500))
paddingLen = int(l.Int64()) + 900 - contentLen paddingLen = int(l.Int64()) + 900 - contentLen
} else {
l, _ := rand.Int(rand.Reader, big.NewInt(256))
paddingLen = int(l.Int64())
} }
newBuffer := buf.New() newBuffer := buf.New()
if c.writeUUID { if c.writeUUID {
@ -257,6 +284,7 @@ func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
buffer.Release() buffer.Release()
} }
newBuffer.Extend(paddingLen) newBuffer.Extend(paddingLen)
c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command)
return newBuffer return newBuffer
} }
@ -267,6 +295,7 @@ func (c *VisionConn) unPadding(buffer []byte) [][]byte {
bufferIndex = 16 bufferIndex = 16
c.remainingContent = 0 c.remainingContent = 0
c.remainingPadding = 0 c.remainingPadding = 0
c.currentCommand = 0
} }
} }
if c.remainingContent == -1 && c.remainingPadding == -1 { if c.remainingContent == -1 && c.remainingPadding == -1 {
@ -284,6 +313,7 @@ func (c *VisionConn) unPadding(buffer []byte) [][]byte {
c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2]) c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2])
c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4]) c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4])
bufferIndex += 5 bufferIndex += 5
c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand)
} }
} else if c.remainingContent > 0 { } else if c.remainingContent > 0 {
end := c.remainingContent end := c.remainingContent