From 22ea878fe99f0cbf1bc4991e4e1c529eadd36a20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 23 Sep 2022 17:21:24 +0800 Subject: [PATCH] Improve websocket writer --- docs/features.md | 2 +- transport/v2raywebsocket/client.go | 2 +- transport/v2raywebsocket/conn.go | 74 +++++++++++++++++++++++++----- transport/v2raywebsocket/mask.go | 6 +++ transport/v2raywebsocket/server.go | 5 +- transport/v2raywebsocket/writer.go | 73 +++++++++++++++++++++++++++++ 6 files changed, 144 insertions(+), 18 deletions(-) create mode 100644 transport/v2raywebsocket/mask.go create mode 100644 transport/v2raywebsocket/writer.go diff --git a/docs/features.md b/docs/features.md index 71feed89..84185bca 100644 --- a/docs/features.md +++ b/docs/features.md @@ -110,7 +110,7 @@ | / | TCP | HTTP | H2 TLS | WebSocket TLS | gRPC TLS | |--------------------|:---------:|:---------:|:---------:|:-------------:|:---------:| | v2ray-core (5.1.0) | 7.86 GBps | 2.86 Gbps | 1.83 Gbps | 2.36 Gbps | 2.43 Gbps | -| sing-box | 7.96 Gbps | 8.09 Gbps | 6.11 Gbps | 2.69 Gbps | 6.35 Gbps | +| sing-box | 7.96 Gbps | 8.09 Gbps | 6.11 Gbps | 8.02 Gbps | 6.35 Gbps | #### License diff --git a/transport/v2raywebsocket/client.go b/transport/v2raywebsocket/client.go index db194a32..f449ff53 100644 --- a/transport/v2raywebsocket/client.go +++ b/transport/v2raywebsocket/client.go @@ -74,7 +74,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { if c.maxEarlyData <= 0 { conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers) if err == nil { - return &WebsocketConn{Conn: conn}, nil + return &WebsocketConn{Conn: conn, Writer: &Writer{conn, false}}, nil } return nil, wrapDialError(response, err) } else { diff --git a/transport/v2raywebsocket/conn.go b/transport/v2raywebsocket/conn.go index 9985a65d..455e1312 100644 --- a/transport/v2raywebsocket/conn.go +++ b/transport/v2raywebsocket/conn.go @@ -10,16 +10,26 @@ import ( "time" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/websocket" ) type WebsocketConn struct { *websocket.Conn + *Writer remoteAddr net.Addr reader io.Reader } +func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn { + return &WebsocketConn{ + Conn: wsConn, + remoteAddr: remoteAddr, + Writer: &Writer{wsConn, true}, + } +} + func (c *WebsocketConn) Close() error { err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout)) if err != nil { @@ -47,14 +57,6 @@ func (c *WebsocketConn) Read(b []byte) (n int, err error) { } } -func (c *WebsocketConn) Write(b []byte) (n int, err error) { - err = wrapError(c.WriteMessage(websocket.BinaryMessage, b)) - if err != nil { - return - } - return len(b), nil -} - func (c *WebsocketConn) RemoteAddr() net.Addr { if c.remoteAddr != nil { return c.remoteAddr @@ -66,6 +68,10 @@ func (c *WebsocketConn) SetDeadline(t time.Time) error { return os.ErrInvalid } +func (c *WebsocketConn) FrontHeadroom() int { + return frontHeadroom +} + type EarlyWebsocketConn struct { *Client ctx context.Context @@ -90,9 +96,9 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { conn *websocket.Conn response *http.Response ) - if len(earlyData) > int(c.maxEarlyData) { - earlyData = earlyData[:c.maxEarlyData] - lateData = lateData[c.maxEarlyData:] + if len(b) > int(c.maxEarlyData) { + earlyData = b[:c.maxEarlyData] + lateData = b[c.maxEarlyData:] } else { earlyData = b } @@ -111,7 +117,7 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { if err != nil { return 0, wrapDialError(response, err) } - c.conn = &WebsocketConn{Conn: conn} + c.conn = &WebsocketConn{Conn: conn, Writer: &Writer{conn, false}} close(c.create) if len(lateData) > 0 { _, err = c.conn.Write(lateData) @@ -122,6 +128,46 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error { + if c.conn != nil { + return c.conn.WriteBuffer(buffer) + } + var ( + earlyData []byte + lateData []byte + conn *websocket.Conn + response *http.Response + err error + ) + if buffer.Len() > int(c.maxEarlyData) { + earlyData = buffer.Bytes()[:c.maxEarlyData] + lateData = buffer.Bytes()[c.maxEarlyData:] + } else { + earlyData = buffer.Bytes() + } + if len(earlyData) > 0 { + earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData) + if c.earlyDataHeaderName == "" { + conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers) + } else { + headers := c.headers.Clone() + headers.Set(c.earlyDataHeaderName, earlyDataString) + conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers) + } + } else { + conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers) + } + if err != nil { + return wrapDialError(response, err) + } + c.conn = &WebsocketConn{Conn: conn, Writer: &Writer{conn, false}} + close(c.create) + if len(lateData) > 0 { + _, err = c.conn.Write(lateData) + } + return err +} + func (c *EarlyWebsocketConn) Close() error { if c.conn == nil { return nil @@ -164,6 +210,10 @@ func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } +func (c *EarlyWebsocketConn) FrontHeadroom() int { + return frontHeadroom +} + func wrapError(err error) error { if websocket.IsCloseError(err, websocket.CloseNormalClosure) { return io.EOF diff --git a/transport/v2raywebsocket/mask.go b/transport/v2raywebsocket/mask.go new file mode 100644 index 00000000..01ea8437 --- /dev/null +++ b/transport/v2raywebsocket/mask.go @@ -0,0 +1,6 @@ +package v2raywebsocket + +import _ "unsafe" + +//go:linkname maskBytes github.com/sagernet/websocket.maskBytes +func maskBytes(key [4]byte, pos int, b []byte) int diff --git a/transport/v2raywebsocket/server.go b/transport/v2raywebsocket/server.go index 453a47c7..32a01bf8 100644 --- a/transport/v2raywebsocket/server.go +++ b/transport/v2raywebsocket/server.go @@ -108,10 +108,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { } var metadata M.Metadata metadata.Source = sHttp.SourceAddress(request) - conn = &WebsocketConn{ - Conn: wsConn, - remoteAddr: metadata.Source.TCPAddr(), - } + conn = NewServerConn(wsConn, metadata.Source.TCPAddr()) if len(earlyData) > 0 { conn = bufio.NewCachedConn(conn, buf.As(earlyData)) } diff --git a/transport/v2raywebsocket/writer.go b/transport/v2raywebsocket/writer.go new file mode 100644 index 00000000..ba0b145e --- /dev/null +++ b/transport/v2raywebsocket/writer.go @@ -0,0 +1,73 @@ +package v2raywebsocket + +import ( + "encoding/binary" + "math/rand" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/websocket" +) + +const frontHeadroom = 14 + +type Writer struct { + *websocket.Conn + isServer bool +} + +func (w *Writer) Write(p []byte) (n int, err error) { + err = w.Conn.WriteMessage(websocket.BinaryMessage, p) + if err != nil { + return + } + return len(p), nil +} + +func (w *Writer) WriteBuffer(buffer *buf.Buffer) error { + defer buffer.Release() + + var payloadBitLength int + dataLen := buffer.Len() + data := buffer.Bytes() + if dataLen < 126 { + payloadBitLength = 1 + } else if dataLen < 65536 { + payloadBitLength = 3 + } else { + payloadBitLength = 9 + } + + var headerLen int + headerLen += 1 // FIN / RSV / OPCODE + headerLen += payloadBitLength + if !w.isServer { + headerLen += 4 // MASK KEY + } + + header := buffer.ExtendHeader(headerLen) + header[0] = websocket.BinaryMessage | 1<<7 + if w.isServer { + header[1] = 0 + } else { + header[1] = 1 << 7 + } + + if dataLen < 126 { + header[1] |= byte(dataLen) + } else if dataLen < 65536 { + header[1] |= 126 + binary.BigEndian.PutUint16(header[2:], uint16(dataLen)) + } else { + header[1] |= 127 + binary.BigEndian.PutUint64(header[2:], uint64(dataLen)) + } + + if !w.isServer { + maskKey := rand.Uint32() + binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey) + maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data) + } + + return common.Error(w.Conn.NetConn().Write(buffer.Bytes())) +}