From 4d23773a25dfa084c076fbfe6e23a78a43dfdb1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 9 Nov 2023 16:59:44 +0800 Subject: [PATCH] Migrate to gobwas/ws --- common/dialer/detour.go | 10 ++- experimental/clashapi/api_meta.go | 17 ++-- experimental/clashapi/connections.go | 9 +- experimental/clashapi/server.go | 38 ++++----- go.mod | 4 +- go.sum | 9 +- transport/v2ray/transport.go | 2 +- transport/v2rayhttp/client.go | 2 +- transport/v2raywebsocket/client.go | 94 ++++++++++++--------- transport/v2raywebsocket/conn.go | 122 +++++++++++++++++---------- transport/v2raywebsocket/mask.go | 6 -- transport/v2raywebsocket/server.go | 13 +-- transport/v2raywebsocket/writer.go | 27 ++---- 13 files changed, 192 insertions(+), 161 deletions(-) delete mode 100644 transport/v2raywebsocket/mask.go diff --git a/common/dialer/detour.go b/common/dialer/detour.go index 81600913..ff484da2 100644 --- a/common/dialer/detour.go +++ b/common/dialer/detour.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/bufio/deadline" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -44,7 +45,14 @@ func (d *DetourDialer) DialContext(ctx context.Context, network string, destinat if err != nil { return nil, err } - return dialer.DialContext(ctx, network, destination) + conn, err := dialer.DialContext(ctx, network, destination) + if err != nil { + return nil, err + } + if deadline.NeedAdditionalReadDeadline(conn) { + conn = deadline.NewConn(conn) + } + return conn, nil } func (d *DetourDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { diff --git a/experimental/clashapi/api_meta.go b/experimental/clashapi/api_meta.go index bfdee1b8..876f9869 100644 --- a/experimental/clashapi/api_meta.go +++ b/experimental/clashapi/api_meta.go @@ -2,12 +2,14 @@ package clashapi import ( "bytes" + "net" "net/http" "time" "github.com/sagernet/sing-box/common/json" "github.com/sagernet/sing-box/experimental/clashapi/trafficontrol" - "github.com/sagernet/websocket" + "github.com/sagernet/ws" + "github.com/sagernet/ws/wsutil" "github.com/go-chi/chi/v5" "github.com/go-chi/render" @@ -27,16 +29,16 @@ type Memory struct { func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - var wsConn *websocket.Conn - if websocket.IsWebSocketUpgrade(r) { + var conn net.Conn + if r.Header.Get("Upgrade") == "websocket" { var err error - wsConn, err = upgrader.Upgrade(w, r, nil) + conn, _, _, err = ws.UpgradeHTTP(r, w) if err != nil { return } } - if wsConn == nil { + if conn == nil { w.Header().Set("Content-Type", "application/json") render.Status(r, http.StatusOK) } @@ -63,13 +65,12 @@ func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r }); err != nil { break } - if wsConn == nil { + if conn == nil { _, err = w.Write(buf.Bytes()) w.(http.Flusher).Flush() } else { - err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes()) + err = wsutil.WriteServerText(conn, buf.Bytes()) } - if err != nil { break } diff --git a/experimental/clashapi/connections.go b/experimental/clashapi/connections.go index 94cfb9a3..042bdd36 100644 --- a/experimental/clashapi/connections.go +++ b/experimental/clashapi/connections.go @@ -9,7 +9,8 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/json" "github.com/sagernet/sing-box/experimental/clashapi/trafficontrol" - "github.com/sagernet/websocket" + "github.com/sagernet/ws" + "github.com/sagernet/ws/wsutil" "github.com/go-chi/chi/v5" "github.com/go-chi/render" @@ -25,13 +26,13 @@ func connectionRouter(router adapter.Router, trafficManager *trafficontrol.Manag func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - if !websocket.IsWebSocketUpgrade(r) { + if r.Header.Get("Upgrade") != "websocket" { snapshot := trafficManager.Snapshot() render.JSON(w, r, snapshot) return } - conn, err := upgrader.Upgrade(w, r, nil) + conn, _, _, err := ws.UpgradeHTTP(r, w) if err != nil { return } @@ -56,7 +57,7 @@ func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseW if err := json.NewEncoder(buf).Encode(snapshot); err != nil { return err } - return conn.WriteMessage(websocket.TextMessage, buf.Bytes()) + return wsutil.WriteServerText(conn, buf.Bytes()) } if err = sendSnapshot(); err != nil { diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index 9f4b0f7c..6a3d6f66 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -25,7 +25,8 @@ import ( N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/service" "github.com/sagernet/sing/service/filemanager" - "github.com/sagernet/websocket" + "github.com/sagernet/ws" + "github.com/sagernet/ws/wsutil" "github.com/go-chi/chi/v5" "github.com/go-chi/cors" @@ -314,7 +315,7 @@ func authentication(serverSecret string) func(next http.Handler) http.Handler { } // Browser websocket not support custom header - if websocket.IsWebSocketUpgrade(r) && r.URL.Query().Get("token") != "" { + if r.Header.Get("Upgrade") == "websocket" && r.URL.Query().Get("token") != "" { token := r.URL.Query().Get("token") if token != serverSecret { render.Status(r, http.StatusUnauthorized) @@ -351,12 +352,6 @@ func hello(redirect bool) func(w http.ResponseWriter, r *http.Request) { } } -var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, -} - type Traffic struct { Up int64 `json:"up"` Down int64 `json:"down"` @@ -364,16 +359,17 @@ type Traffic struct { func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - var wsConn *websocket.Conn - if websocket.IsWebSocketUpgrade(r) { + var conn net.Conn + if r.Header.Get("Upgrade") == "websocket" { var err error - wsConn, err = upgrader.Upgrade(w, r, nil) + conn, _, _, err = ws.UpgradeHTTP(r, w) if err != nil { return } + defer conn.Close() } - if wsConn == nil { + if conn == nil { w.Header().Set("Content-Type", "application/json") render.Status(r, http.StatusOK) } @@ -392,11 +388,11 @@ func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, break } - if wsConn == nil { + if conn == nil { _, err = w.Write(buf.Bytes()) w.(http.Flusher).Flush() } else { - err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes()) + err = wsutil.WriteServerText(conn, buf.Bytes()) } if err != nil { @@ -432,16 +428,16 @@ func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *ht } defer logFactory.UnSubscribe(subscription) - var wsConn *websocket.Conn - if websocket.IsWebSocketUpgrade(r) { - var err error - wsConn, err = upgrader.Upgrade(w, r, nil) + var conn net.Conn + if r.Header.Get("Upgrade") == "websocket" { + conn, _, _, err = ws.UpgradeHTTP(r, w) if err != nil { return } + defer conn.Close() } - if wsConn == nil { + if conn == nil { w.Header().Set("Content-Type", "application/json") render.Status(r, http.StatusOK) } @@ -465,11 +461,11 @@ func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *ht if err != nil { break } - if wsConn == nil { + if conn == nil { _, err = w.Write(buf.Bytes()) w.(http.Flusher).Flush() } else { - err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes()) + err = wsutil.WriteServerText(conn, buf.Bytes()) } if err != nil { diff --git a/go.mod b/go.mod index 0106737a..f26202f9 100644 --- a/go.mod +++ b/go.mod @@ -38,8 +38,8 @@ require ( github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6 github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2 - github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e github.com/sagernet/wireguard-go v0.0.0-20230807125731-5d4a7ef2dc5f + github.com/sagernet/ws v0.0.0-20231030053741-7d481eb31bed github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 go.uber.org/zap v1.26.0 @@ -61,6 +61,8 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect diff --git a/go.sum b/go.sum index 10b96b00..d3c77974 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,10 @@ github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gofrs/uuid/v5 v5.0.0 h1:p544++a97kEL+svbcFbCQVM9KFu0Yo25UoISXGNNH9M= github.com/gofrs/uuid/v5 v5.0.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -134,10 +138,10 @@ github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6 h1:Px+hN4Vzgx+iCGV github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6/go.mod h1:zovq6vTvEM6ECiqE3Eeb9rpIylPpamPcmrJ9tv0Bt0M= github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2 h1:kDUqhc9Vsk5HJuhfIATJ8oQwBmpOZJuozQG7Vk88lL4= github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2/go.mod h1:JKQMZq/O2qnZjdrt+B57olmfgEmLtY9iiSIEYtWvoSM= -github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e h1:7uw2njHFGE+VpWamge6o56j2RWk4omF6uLKKxMmcWvs= -github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e/go.mod h1:45TUl8+gH4SIKr4ykREbxKWTxkDlSzFENzctB1dVRRY= github.com/sagernet/wireguard-go v0.0.0-20230807125731-5d4a7ef2dc5f h1:Kvo8w8Y9lzFGB/7z09MJ3TR99TFtfI/IuY87Ygcycho= github.com/sagernet/wireguard-go v0.0.0-20230807125731-5d4a7ef2dc5f/go.mod h1:mySs0abhpc/gLlvhoq7HP1RzOaRmIXVeZGCh++zoApk= +github.com/sagernet/ws v0.0.0-20231030053741-7d481eb31bed h1:90a510OeE9siSJoYsI8nSjPmA+u5ROMDts/ZkdNsuXY= +github.com/sagernet/ws v0.0.0-20231030053741-7d481eb31bed/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA= github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg= github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= @@ -189,6 +193,7 @@ golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/transport/v2ray/transport.go b/transport/v2ray/transport.go index 3481c852..9dfee281 100644 --- a/transport/v2ray/transport.go +++ b/transport/v2ray/transport.go @@ -50,7 +50,7 @@ func NewClientTransport(ctx context.Context, dialer N.Dialer, serverAddr M.Socks case C.V2RayTransportTypeGRPC: return NewGRPCClient(ctx, dialer, serverAddr, options.GRPCOptions, tlsConfig) case C.V2RayTransportTypeWebsocket: - return v2raywebsocket.NewClient(ctx, dialer, serverAddr, options.WebsocketOptions, tlsConfig), nil + return v2raywebsocket.NewClient(ctx, dialer, serverAddr, options.WebsocketOptions, tlsConfig) case C.V2RayTransportTypeQUIC: if tlsConfig == nil { return nil, C.ErrTLSRequired diff --git a/transport/v2rayhttp/client.go b/transport/v2rayhttp/client.go index d5e8a1f6..f280eeef 100644 --- a/transport/v2rayhttp/client.go +++ b/transport/v2rayhttp/client.go @@ -81,7 +81,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt uri.Path = options.Path err := sHTTP.URLSetPath(&uri, options.Path) if err != nil { - return nil, E.New("failed to set path: " + err.Error()) + return nil, E.Cause(err, "parse path") } client.url = &uri return client, nil diff --git a/transport/v2raywebsocket/client.go b/transport/v2raywebsocket/client.go index 9f14f676..54c4df0b 100644 --- a/transport/v2raywebsocket/client.go +++ b/transport/v2raywebsocket/client.go @@ -5,58 +5,37 @@ import ( "net" "net/http" "net/url" + "strings" "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" sHTTP "github.com/sagernet/sing/protocol/http" - "github.com/sagernet/websocket" + "github.com/sagernet/ws" ) var _ adapter.V2RayClientTransport = (*Client)(nil) type Client struct { - dialer *websocket.Dialer + dialer N.Dialer + tlsConfig tls.Config + serverAddr M.Socksaddr requestURL url.URL - requestURLString string headers http.Header maxEarlyData uint32 earlyDataHeaderName string } -func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) adapter.V2RayClientTransport { - wsDialer := &websocket.Dialer{ - ReadBufferSize: 4 * 1024, - WriteBufferSize: 4 * 1024, - HandshakeTimeout: time.Second * 8, - } +func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) { if tlsConfig != nil { if len(tlsConfig.NextProtos()) == 0 { tlsConfig.SetNextProtos([]string{"http/1.1"}) } - wsDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - if err != nil { - return nil, err - } - tlsConn, err := tls.ClientHandshake(ctx, conn, tlsConfig) - if err != nil { - return nil, err - } - return &deadConn{tlsConn}, nil - } - } else { - wsDialer.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - if err != nil { - return nil, err - } - return &deadConn{conn}, nil - } } var requestURL url.URL if tlsConfig == nil { @@ -68,37 +47,68 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt requestURL.Path = options.Path err := sHTTP.URLSetPath(&requestURL, options.Path) if err != nil { - return nil + return nil, E.Cause(err, "parse path") + } + if !strings.HasPrefix(requestURL.Path, "/") { + requestURL.Path = "/" + requestURL.Path } headers := make(http.Header) for key, value := range options.Headers { headers[key] = value + if key == "Host" { + if len(value) > 1 { + return nil, E.New("multiple Host headers") + } + requestURL.Host = value[0] + } + } + if headers.Get("User-Agent") == "" { + headers.Set("User-Agent", "Go-http-client/1.1") } return &Client{ - wsDialer, + dialer, + tlsConfig, + serverAddr, requestURL, - requestURL.String(), headers, options.MaxEarlyData, options.EarlyDataHeaderName, + }, nil +} + +func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers http.Header) (*WebsocketConn, error) { + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr) + if err != nil { + return nil, err } + if c.tlsConfig != nil { + conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig) + if err != nil { + return nil, err + } + } + conn.SetDeadline(time.Now().Add(C.TCPTimeout)) + var protocols []string + if protocolHeader := headers.Get("Sec-WebSocket-Protocol"); protocolHeader != "" { + protocols = []string{protocolHeader} + headers.Del("Sec-WebSocket-Protocol") + } + reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(conn, requestURL) + conn.SetDeadline(time.Time{}) + if err != nil { + return nil, err + } + return NewConn(conn, reader, nil, ws.StateClientSide), nil } func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { if c.maxEarlyData <= 0 { - conn, response, err := c.dialer.DialContext(ctx, c.requestURLString, c.headers) - if err == nil { - return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil + conn, err := c.dialContext(ctx, &c.requestURL, c.headers) + if err != nil { + return nil, err } - return nil, wrapDialError(response, err) + return conn, nil } else { return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil } } - -func wrapDialError(response *http.Response, err error) error { - if response == nil { - return err - } - return E.Extend(err, "HTTP ", response.StatusCode, " ", response.Status) -} diff --git a/transport/v2raywebsocket/conn.go b/transport/v2raywebsocket/conn.go index 6400b118..5b51e5c5 100644 --- a/transport/v2raywebsocket/conn.go +++ b/transport/v2raywebsocket/conn.go @@ -1,11 +1,11 @@ package v2raywebsocket import ( + "bufio" "context" "encoding/base64" "io" "net" - "net/http" "os" "sync" "time" @@ -13,50 +13,96 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/debug" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/websocket" + "github.com/sagernet/ws" + "github.com/sagernet/ws/wsutil" ) type WebsocketConn struct { - *websocket.Conn + net.Conn *Writer - remoteAddr net.Addr - reader io.Reader + state ws.State + reader *wsutil.Reader + controlHandler wsutil.FrameHandlerFunc + remoteAddr net.Addr } -func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn { +func NewConn(conn net.Conn, br *bufio.Reader, remoteAddr net.Addr, state ws.State) *WebsocketConn { + controlHandler := wsutil.ControlFrameHandler(conn, state) + var reader io.Reader + if br != nil && br.Buffered() > 0 { + reader = br + } else { + reader = conn + } return &WebsocketConn{ - Conn: wsConn, - remoteAddr: remoteAddr, - Writer: NewWriter(wsConn, true), + Conn: conn, + state: state, + reader: &wsutil.Reader{ + Source: reader, + State: state, + SkipHeaderCheck: !debug.Enabled, + OnIntermediate: controlHandler, + }, + controlHandler: controlHandler, + remoteAddr: remoteAddr, + Writer: NewWriter(conn, state), } } func (c *WebsocketConn) Close() error { - err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout)) - if err != nil { - return c.Conn.Close() + c.Conn.SetWriteDeadline(time.Now().Add(C.TCPTimeout)) + frame := ws.NewCloseFrame(ws.NewCloseFrameBody( + ws.StatusNormalClosure, "", + )) + if c.state == ws.StateClientSide { + frame = ws.MaskFrameInPlace(frame) } + ws.WriteFrame(c.Conn, frame) + c.Conn.Close() return nil } func (c *WebsocketConn) Read(b []byte) (n int, err error) { + var header ws.Header for { - if c.reader == nil { - _, c.reader, err = c.NextReader() + n, err = c.reader.Read(b) + if n > 0 { + err = nil + return + } + if !E.IsMulti(err, io.EOF, wsutil.ErrNoFrameAdvance) { + return + } + header, err = c.reader.NextFrame() + if err != nil { + return + } + if header.OpCode.IsControl() { + err = c.controlHandler(header, c.reader) if err != nil { - err = wrapError(err) return } - } - n, err = c.reader.Read(b) - if E.IsMulti(err, io.EOF) { - c.reader = nil continue } - err = wrapError(err) + if header.OpCode&ws.OpBinary == 0 { + err = c.reader.Discard() + if err != nil { + return + } + continue + } + } +} + +func (c *WebsocketConn) Write(p []byte) (n int, err error) { + err = wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, p) + if err != nil { return } + n = len(p) + return } func (c *WebsocketConn) RemoteAddr() net.Addr { @@ -83,11 +129,7 @@ func (c *WebsocketConn) NeedAdditionalReadDeadline() bool { } func (c *WebsocketConn) Upstream() any { - return c.Conn.NetConn() -} - -func (c *WebsocketConn) UpstreamWriter() any { - return c.Writer + return c.Conn } type EarlyWebsocketConn struct { @@ -113,8 +155,7 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error { var ( earlyData []byte lateData []byte - conn *websocket.Conn - response *http.Response + conn *WebsocketConn err error ) if len(content) > int(c.maxEarlyData) { @@ -128,23 +169,26 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error { if c.earlyDataHeaderName == "" { requestURL := c.requestURL requestURL.Path += earlyDataString - conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers) + conn, err = c.dialContext(c.ctx, &requestURL, c.headers) } else { headers := c.headers.Clone() headers.Set(c.earlyDataHeaderName, earlyDataString) - conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers) + conn, err = c.dialContext(c.ctx, &c.requestURL, headers) } } else { - conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers) + conn, err = c.dialContext(c.ctx, &c.requestURL, c.headers) } if err != nil { - return wrapDialError(response, err) + return err } - c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)} if len(lateData) > 0 { - _, err = c.conn.Write(lateData) + _, err = conn.Write(lateData) + if err != nil { + return err + } } - return err + c.conn = conn + return nil } func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { @@ -230,13 +274,3 @@ func (c *EarlyWebsocketConn) Upstream() any { func (c *EarlyWebsocketConn) LazyHeadroom() bool { return c.conn == nil } - -func wrapError(err error) error { - if websocket.IsCloseError(err, websocket.CloseNormalClosure) { - return io.EOF - } - if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { - return net.ErrClosed - } - return err -} diff --git a/transport/v2raywebsocket/mask.go b/transport/v2raywebsocket/mask.go deleted file mode 100644 index 01ea8437..00000000 --- a/transport/v2raywebsocket/mask.go +++ /dev/null @@ -1,6 +0,0 @@ -package v2raywebsocket - -import _ "unsafe" - -//go:linkname maskBytes github.com/sagernet/websocket.maskBytes -func maskBytes(key [4]byte, pos int, b []byte) int diff --git a/transport/v2raywebsocket/server.go b/transport/v2raywebsocket/server.go index 9d8bc69a..ae6e15f3 100644 --- a/transport/v2raywebsocket/server.go +++ b/transport/v2raywebsocket/server.go @@ -20,7 +20,7 @@ import ( N "github.com/sagernet/sing/common/network" aTLS "github.com/sagernet/sing/common/tls" sHttp "github.com/sagernet/sing/protocol/http" - "github.com/sagernet/websocket" + "github.com/sagernet/ws" ) var _ adapter.V2RayServerTransport = (*Server)(nil) @@ -58,13 +58,6 @@ func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsCon return server, nil } -var upgrader = websocket.Upgrader{ - HandshakeTimeout: C.TCPTimeout, - CheckOrigin: func(r *http.Request) bool { - return true - }, -} - func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if s.maxEarlyData == 0 || s.earlyDataHeaderName != "" { if request.URL.Path != s.path { @@ -95,14 +88,14 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { s.invalidRequest(writer, request, http.StatusBadRequest, E.Cause(err, "decode early data")) return } - wsConn, err := upgrader.Upgrade(writer, request, nil) + wsConn, reader, _, err := ws.UpgradeHTTP(request, writer) if err != nil { s.invalidRequest(writer, request, 0, E.Cause(err, "upgrade websocket connection")) return } var metadata M.Metadata metadata.Source = sHttp.SourceAddress(request) - conn = NewServerConn(wsConn, metadata.Source.TCPAddr()) + conn = NewConn(wsConn, reader.Reader, metadata.Source.TCPAddr(), ws.StateServerSide) if len(earlyData) > 0 { conn = bufio.NewCachedConn(conn, buf.As(earlyData)) } diff --git a/transport/v2raywebsocket/writer.go b/transport/v2raywebsocket/writer.go index fbb61f0f..5bd0d0a1 100644 --- a/transport/v2raywebsocket/writer.go +++ b/transport/v2raywebsocket/writer.go @@ -2,36 +2,27 @@ package v2raywebsocket import ( "encoding/binary" + "io" "math/rand" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/websocket" + "github.com/sagernet/ws" ) type Writer struct { - *websocket.Conn writer N.ExtendedWriter isServer bool } -func NewWriter(conn *websocket.Conn, isServer bool) *Writer { +func NewWriter(writer io.Writer, state ws.State) *Writer { return &Writer{ - conn, - bufio.NewExtendedWriter(conn.NetConn()), - isServer, + bufio.NewExtendedWriter(writer), + state == ws.StateServerSide, } } -func (w *Writer) Write(p []byte) (n int, err error) { - err = w.Conn.WriteMessage(websocket.BinaryMessage, p) - if err != nil { - return - } - return len(p), nil -} - func (w *Writer) WriteBuffer(buffer *buf.Buffer) error { var payloadBitLength int dataLen := buffer.Len() @@ -52,7 +43,7 @@ func (w *Writer) WriteBuffer(buffer *buf.Buffer) error { } header := buffer.ExtendHeader(headerLen) - header[0] = websocket.BinaryMessage | 1<<7 + header[0] = byte(ws.OpBinary) | 0x80 if w.isServer { header[1] = 0 } else { @@ -72,16 +63,12 @@ func (w *Writer) WriteBuffer(buffer *buf.Buffer) error { if !w.isServer { maskKey := rand.Uint32() binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey) - maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data) + ws.Cipher(data, *(*[4]byte)(header[1+payloadBitLength:]), 0) } return w.writer.WriteBuffer(buffer) } -func (w *Writer) Upstream() any { - return w.Conn.NetConn() -} - func (w *Writer) FrontHeadroom() int { return 14 }