mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-22 00:21:30 +00:00
Update vision protocol
This commit is contained in:
parent
5ce3ddee9b
commit
e4bff0460d
|
@ -50,7 +50,7 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg
|
|||
ctx: ctx,
|
||||
users: options.Users,
|
||||
}
|
||||
service := vless.NewService[int](adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound))
|
||||
service := vless.NewService[int](logger, adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound))
|
||||
service.UpdateUsers(common.MapIndexed(inbound.users, func(index int, _ option.VLESSUser) int {
|
||||
return index
|
||||
}), common.Map(inbound.users, func(it option.VLESSUser) string {
|
||||
|
|
|
@ -67,7 +67,7 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg
|
|||
default:
|
||||
return nil, E.New("unknown packet encoding: ", options.PacketEncoding)
|
||||
}
|
||||
outbound.client, err = vless.NewClient(options.UUID, options.Flow)
|
||||
outbound.client, err = vless.NewClient(options.UUID, options.Flow, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
|
@ -16,11 +17,12 @@ import (
|
|||
)
|
||||
|
||||
type Client struct {
|
||||
key [16]byte
|
||||
flow string
|
||||
key [16]byte
|
||||
flow string
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewClient(userId string, flow string) (*Client, error) {
|
||||
func NewClient(userId string, flow string, logger logger.Logger) (*Client, error) {
|
||||
user := uuid.FromStringOrNil(userId)
|
||||
if user == uuid.Nil {
|
||||
user = uuid.NewV5(user, userId)
|
||||
|
@ -30,12 +32,12 @@ func NewClient(userId string, flow string) (*Client, error) {
|
|||
default:
|
||||
return nil, E.New("unsupported flow: " + flow)
|
||||
}
|
||||
return &Client{user, flow}, nil
|
||||
return &Client{user, flow, logger}, nil
|
||||
}
|
||||
|
||||
func (c *Client) prepareConn(conn net.Conn) (net.Conn, error) {
|
||||
if c.flow == FlowVision {
|
||||
vConn, err := NewVisionConn(conn, c.key)
|
||||
vConn, err := NewVisionConn(conn, c.key, c.logger)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "initialize vision")
|
||||
}
|
||||
|
|
|
@ -11,6 +11,10 @@ var (
|
|||
tlsClientHandShakeStart = []byte{0x16, 0x03}
|
||||
tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03}
|
||||
tlsApplicationDataStart = []byte{0x17, 0x03, 0x03}
|
||||
|
||||
commandPaddingContinue byte = 0
|
||||
commandPaddingEnd byte = 1
|
||||
commandPaddingDirect byte = 2
|
||||
)
|
||||
|
||||
var tls13CipherSuiteDic = map[uint16]string{
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
|
@ -19,6 +20,7 @@ import (
|
|||
|
||||
type Service[T any] struct {
|
||||
userMap map[[16]byte]T
|
||||
logger logger.Logger
|
||||
handler Handler
|
||||
}
|
||||
|
||||
|
@ -28,8 +30,9 @@ type Handler interface {
|
|||
E.Handler
|
||||
}
|
||||
|
||||
func NewService[T any](handler Handler) *Service[T] {
|
||||
func NewService[T any](logger logger.Logger, handler Handler) *Service[T] {
|
||||
return &Service[T]{
|
||||
logger: logger,
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
@ -64,7 +67,7 @@ func (s *Service[T]) NewConnection(ctx context.Context, conn net.Conn, metadata
|
|||
switch request.Flow {
|
||||
case "":
|
||||
case FlowVision:
|
||||
protocolConn, err = NewVisionConn(conn, request.UUID)
|
||||
protocolConn, err = NewVisionConn(conn, request.UUID, s.logger)
|
||||
if err != nil {
|
||||
return E.Cause(err, "initialize vision")
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
|
@ -37,26 +38,27 @@ type VisionConn struct {
|
|||
input *bytes.Reader
|
||||
rawInput *bytes.Buffer
|
||||
netConn net.Conn
|
||||
logger logger.Logger
|
||||
|
||||
userUUID [16]byte
|
||||
isTLS bool
|
||||
numberOfPacketToFilter int
|
||||
isTLS12orAbove bool
|
||||
remainingServerHello int32
|
||||
cipher uint16
|
||||
enableXTLS bool
|
||||
filterTlsApplicationData bool
|
||||
directWrite bool
|
||||
writeUUID bool
|
||||
filterUUID bool
|
||||
remainingContent int
|
||||
remainingPadding int
|
||||
currentCommand int
|
||||
directRead bool
|
||||
remainingReader io.Reader
|
||||
userUUID [16]byte
|
||||
isTLS bool
|
||||
numberOfPacketToFilter int
|
||||
isTLS12orAbove bool
|
||||
remainingServerHello int32
|
||||
cipher uint16
|
||||
enableXTLS bool
|
||||
isPadding bool
|
||||
directWrite bool
|
||||
writeUUID bool
|
||||
withinPaddingBuffers bool
|
||||
remainingContent int
|
||||
remainingPadding int
|
||||
currentCommand int
|
||||
directRead bool
|
||||
remainingReader io.Reader
|
||||
}
|
||||
|
||||
func NewVisionConn(conn net.Conn, userUUID [16]byte) (*VisionConn, error) {
|
||||
func NewVisionConn(conn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) {
|
||||
var (
|
||||
loaded bool
|
||||
reflectType reflect.Type
|
||||
|
@ -75,19 +77,21 @@ func NewVisionConn(conn net.Conn, userUUID [16]byte) (*VisionConn, error) {
|
|||
input, _ := reflectType.FieldByName("input")
|
||||
rawInput, _ := reflectType.FieldByName("rawInput")
|
||||
return &VisionConn{
|
||||
Conn: conn,
|
||||
writer: bufio.NewVectorisedWriter(conn),
|
||||
input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
|
||||
rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
|
||||
netConn: netConn,
|
||||
userUUID: userUUID,
|
||||
numberOfPacketToFilter: 8,
|
||||
remainingServerHello: -1,
|
||||
filterTlsApplicationData: true,
|
||||
writeUUID: true,
|
||||
filterUUID: true,
|
||||
remainingContent: -1,
|
||||
remainingPadding: -1,
|
||||
Conn: conn,
|
||||
writer: bufio.NewVectorisedWriter(conn),
|
||||
input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
|
||||
rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
|
||||
netConn: netConn,
|
||||
logger: logger,
|
||||
|
||||
userUUID: userUUID,
|
||||
numberOfPacketToFilter: 8,
|
||||
remainingServerHello: -1,
|
||||
isPadding: true,
|
||||
writeUUID: true,
|
||||
withinPaddingBuffers: true,
|
||||
remainingContent: -1,
|
||||
remainingPadding: -1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -97,6 +101,7 @@ func (c *VisionConn) Read(p []byte) (n int, err error) {
|
|||
if err == io.EOF {
|
||||
c.remainingReader = nil
|
||||
if n > 0 {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -109,13 +114,15 @@ func (c *VisionConn) Read(p []byte) (n int, err error) {
|
|||
return
|
||||
}
|
||||
buffer := p[:n]
|
||||
if c.filterUUID && (c.isTLS || c.numberOfPacketToFilter > 0) {
|
||||
if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 {
|
||||
buffers := c.unPadding(buffer)
|
||||
if c.remainingContent == 0 && c.remainingPadding == 0 {
|
||||
if c.currentCommand == 1 {
|
||||
c.filterUUID = false
|
||||
c.withinPaddingBuffers = false
|
||||
c.remainingContent = -1
|
||||
c.remainingPadding = -1
|
||||
} else if c.currentCommand == 2 {
|
||||
c.filterUUID = false
|
||||
c.withinPaddingBuffers = false
|
||||
c.directRead = true
|
||||
|
||||
inputBuffer, err := io.ReadAll(c.input)
|
||||
|
@ -130,9 +137,17 @@ func (c *VisionConn) Read(p []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
buffers = append(buffers, rawInputBuffer)
|
||||
} else if c.currentCommand != 0 {
|
||||
|
||||
c.logger.Trace("XtlsRead readV")
|
||||
} else if c.currentCommand == 0 {
|
||||
c.withinPaddingBuffers = true
|
||||
} else {
|
||||
return 0, E.New("unknown command ", c.currentCommand)
|
||||
}
|
||||
} else if c.remainingContent > 0 || c.remainingPadding > 0 {
|
||||
c.withinPaddingBuffers = true
|
||||
} else {
|
||||
c.withinPaddingBuffers = false
|
||||
}
|
||||
if c.numberOfPacketToFilter > 0 {
|
||||
c.filterTLS(buffers)
|
||||
|
@ -151,27 +166,27 @@ func (c *VisionConn) Write(p []byte) (n int, err error) {
|
|||
if c.numberOfPacketToFilter > 0 {
|
||||
c.filterTLS([][]byte{p})
|
||||
}
|
||||
if c.isTLS && c.filterTlsApplicationData {
|
||||
if c.isPadding {
|
||||
inputLen := len(p)
|
||||
buffers := reshapeBuffer(p)
|
||||
var specIndex int
|
||||
for i, buffer := range buffers {
|
||||
if buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
|
||||
var command byte = 1
|
||||
if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
|
||||
var command byte = commandPaddingEnd
|
||||
if c.enableXTLS {
|
||||
c.directWrite = true
|
||||
specIndex = i
|
||||
command = 2
|
||||
command = commandPaddingDirect
|
||||
}
|
||||
c.filterTlsApplicationData = false
|
||||
c.isPadding = false
|
||||
buffers[i] = c.padding(buffer, command)
|
||||
break
|
||||
} else if !c.isTLS12orAbove && c.numberOfPacketToFilter == 0 {
|
||||
c.filterTlsApplicationData = false
|
||||
buffers[i] = c.padding(buffer, 0x01)
|
||||
} else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 {
|
||||
c.isPadding = false
|
||||
buffers[i] = c.padding(buffer, commandPaddingEnd)
|
||||
break
|
||||
}
|
||||
buffers[i] = c.padding(buffer, 0x00)
|
||||
buffers[i] = c.padding(buffer, commandPaddingContinue)
|
||||
}
|
||||
if c.directWrite {
|
||||
encryptedBuffer := buffers[:specIndex+1]
|
||||
|
@ -181,6 +196,7 @@ func (c *VisionConn) Write(p []byte) (n int, err error) {
|
|||
}
|
||||
buffers = buffers[specIndex+1:]
|
||||
c.writer = bufio.NewVectorisedWriter(c.netConn)
|
||||
c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers))
|
||||
time.Sleep(5 * time.Millisecond) // wtf
|
||||
}
|
||||
err = c.writer.WriteVectorised(buffers)
|
||||
|
@ -209,10 +225,13 @@ func (c *VisionConn) filterTLS(buffers [][]byte) {
|
|||
sessionIdLen := int32(buffer[43])
|
||||
cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3]
|
||||
c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
|
||||
} else {
|
||||
c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello)
|
||||
}
|
||||
}
|
||||
} else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 {
|
||||
c.isTLS = true
|
||||
c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer))
|
||||
}
|
||||
}
|
||||
if c.remainingServerHello > 0 {
|
||||
|
@ -226,13 +245,18 @@ func (c *VisionConn) filterTLS(buffers [][]byte) {
|
|||
if ok && cipher != "TLS_AES_128_CCM_8_SHA256" {
|
||||
c.enableXTLS = true
|
||||
}
|
||||
c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS)
|
||||
c.numberOfPacketToFilter = 0
|
||||
return
|
||||
} else if c.remainingServerHello == 0 {
|
||||
c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer))
|
||||
c.numberOfPacketToFilter = 0
|
||||
return
|
||||
}
|
||||
}
|
||||
if c.numberOfPacketToFilter == 0 {
|
||||
c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -242,9 +266,12 @@ func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
|
|||
if buffer != nil {
|
||||
contentLen = buffer.Len()
|
||||
}
|
||||
if contentLen < 900 {
|
||||
if contentLen < 900 && c.isTLS {
|
||||
l, _ := rand.Int(rand.Reader, big.NewInt(500))
|
||||
paddingLen = int(l.Int64()) + 900 - contentLen
|
||||
} else {
|
||||
l, _ := rand.Int(rand.Reader, big.NewInt(256))
|
||||
paddingLen = int(l.Int64())
|
||||
}
|
||||
newBuffer := buf.New()
|
||||
if c.writeUUID {
|
||||
|
@ -257,6 +284,7 @@ func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
|
|||
buffer.Release()
|
||||
}
|
||||
newBuffer.Extend(paddingLen)
|
||||
c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command)
|
||||
return newBuffer
|
||||
}
|
||||
|
||||
|
@ -267,6 +295,7 @@ func (c *VisionConn) unPadding(buffer []byte) [][]byte {
|
|||
bufferIndex = 16
|
||||
c.remainingContent = 0
|
||||
c.remainingPadding = 0
|
||||
c.currentCommand = 0
|
||||
}
|
||||
}
|
||||
if c.remainingContent == -1 && c.remainingPadding == -1 {
|
||||
|
@ -284,6 +313,7 @@ func (c *VisionConn) unPadding(buffer []byte) [][]byte {
|
|||
c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2])
|
||||
c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4])
|
||||
bufferIndex += 5
|
||||
c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand)
|
||||
}
|
||||
} else if c.remainingContent > 0 {
|
||||
end := c.remainingContent
|
||||
|
|
Loading…
Reference in a new issue