diff --git a/transport/v2raywebsocket/client.go b/transport/v2raywebsocket/client.go index 7a827881..9f14f676 100644 --- a/transport/v2raywebsocket/client.go +++ b/transport/v2raywebsocket/client.go @@ -21,7 +21,8 @@ var _ adapter.V2RayClientTransport = (*Client)(nil) type Client struct { dialer *websocket.Dialer - uri string + requestURL url.URL + requestURLString string headers http.Header maxEarlyData uint32 earlyDataHeaderName string @@ -57,15 +58,15 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt return &deadConn{conn}, nil } } - var uri url.URL + var requestURL url.URL if tlsConfig == nil { - uri.Scheme = "ws" + requestURL.Scheme = "ws" } else { - uri.Scheme = "wss" + requestURL.Scheme = "wss" } - uri.Host = serverAddr.String() - uri.Path = options.Path - err := sHTTP.URLSetPath(&uri, options.Path) + requestURL.Host = serverAddr.String() + requestURL.Path = options.Path + err := sHTTP.URLSetPath(&requestURL, options.Path) if err != nil { return nil } @@ -75,7 +76,8 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt } return &Client{ wsDialer, - uri.String(), + requestURL, + requestURL.String(), headers, options.MaxEarlyData, options.EarlyDataHeaderName, @@ -84,7 +86,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { if c.maxEarlyData <= 0 { - conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers) + conn, response, err := c.dialer.DialContext(ctx, c.requestURLString, c.headers) if err == nil { return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil } diff --git a/transport/v2raywebsocket/conn.go b/transport/v2raywebsocket/conn.go index dc3f4b52..7bf9df1d 100644 --- a/transport/v2raywebsocket/conn.go +++ b/transport/v2raywebsocket/conn.go @@ -94,51 +94,64 @@ type EarlyWebsocketConn struct { ctx context.Context conn *WebsocketConn create chan struct{} + err error } func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) { if c.conn == nil { <-c.create + if c.err != nil { + return 0, c.err + } } return c.conn.Read(b) } +func (c *EarlyWebsocketConn) writeRequest(content []byte) error { + var ( + earlyData []byte + lateData []byte + conn *websocket.Conn + response *http.Response + err error + ) + if len(content) > int(c.maxEarlyData) { + earlyData = content[:c.maxEarlyData] + lateData = content[c.maxEarlyData:] + } else { + earlyData = content + } + if len(earlyData) > 0 { + earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData) + if c.earlyDataHeaderName == "" { + requestURL := c.requestURL + requestURL.Path += earlyDataString + conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers) + } else { + headers := c.headers.Clone() + headers.Set(c.earlyDataHeaderName, earlyDataString) + conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers) + } + } else { + conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers) + } + if err != nil { + return wrapDialError(response, err) + } + c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)} + if len(lateData) > 0 { + _, err = c.conn.Write(lateData) + } + return err +} + func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { if c.conn != nil { return c.conn.Write(b) } - var ( - earlyData []byte - lateData []byte - conn *websocket.Conn - response *http.Response - ) - if len(b) > int(c.maxEarlyData) { - earlyData = b[:c.maxEarlyData] - lateData = b[c.maxEarlyData:] - } else { - earlyData = b - } - if len(earlyData) > 0 { - earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData) - if c.earlyDataHeaderName == "" { - conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers) - } else { - headers := c.headers.Clone() - headers.Set(c.earlyDataHeaderName, earlyDataString) - conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers) - } - } else { - conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers) - } - if err != nil { - return 0, wrapDialError(response, err) - } - c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)} + err = c.writeRequest(b) + c.err = err close(c.create) - if len(lateData) > 0 { - _, err = c.conn.Write(lateData) - } if err != nil { return } @@ -149,39 +162,9 @@ func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error { if c.conn != nil { return c.conn.WriteBuffer(buffer) } - var ( - earlyData []byte - lateData []byte - conn *websocket.Conn - response *http.Response - err error - ) - if buffer.Len() > int(c.maxEarlyData) { - earlyData = buffer.Bytes()[:c.maxEarlyData] - lateData = buffer.Bytes()[c.maxEarlyData:] - } else { - earlyData = buffer.Bytes() - } - if len(earlyData) > 0 { - earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData) - if c.earlyDataHeaderName == "" { - conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers) - } else { - headers := c.headers.Clone() - headers.Set(c.earlyDataHeaderName, earlyDataString) - conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers) - } - } else { - conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers) - } - if err != nil { - return wrapDialError(response, err) - } - c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)} + err := c.writeRequest(buffer.Bytes()) + c.err = err close(c.create) - if len(lateData) > 0 { - _, err = c.conn.Write(lateData) - } return err }