Fix v2ray websocket transport

This commit is contained in:
世界 2023-07-11 15:12:26 +08:00
parent 120dae4eed
commit d74abbd20e
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 57 additions and 72 deletions

View file

@ -21,7 +21,8 @@ var _ adapter.V2RayClientTransport = (*Client)(nil)
type Client struct {
dialer *websocket.Dialer
uri string
requestURL url.URL
requestURLString string
headers http.Header
maxEarlyData uint32
earlyDataHeaderName string
@ -57,15 +58,15 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
return &deadConn{conn}, nil
}
}
var uri url.URL
var requestURL url.URL
if tlsConfig == nil {
uri.Scheme = "ws"
requestURL.Scheme = "ws"
} else {
uri.Scheme = "wss"
requestURL.Scheme = "wss"
}
uri.Host = serverAddr.String()
uri.Path = options.Path
err := sHTTP.URLSetPath(&uri, options.Path)
requestURL.Host = serverAddr.String()
requestURL.Path = options.Path
err := sHTTP.URLSetPath(&requestURL, options.Path)
if err != nil {
return nil
}
@ -75,7 +76,8 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
}
return &Client{
wsDialer,
uri.String(),
requestURL,
requestURL.String(),
headers,
options.MaxEarlyData,
options.EarlyDataHeaderName,
@ -84,7 +86,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
if c.maxEarlyData <= 0 {
conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers)
conn, response, err := c.dialer.DialContext(ctx, c.requestURLString, c.headers)
if err == nil {
return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil
}

View file

@ -94,51 +94,64 @@ type EarlyWebsocketConn struct {
ctx context.Context
conn *WebsocketConn
create chan struct{}
err error
}
func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
if c.conn == nil {
<-c.create
if c.err != nil {
return 0, c.err
}
}
return c.conn.Read(b)
}
func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
var (
earlyData []byte
lateData []byte
conn *websocket.Conn
response *http.Response
err error
)
if len(content) > int(c.maxEarlyData) {
earlyData = content[:c.maxEarlyData]
lateData = content[c.maxEarlyData:]
} else {
earlyData = content
}
if len(earlyData) > 0 {
earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
if c.earlyDataHeaderName == "" {
requestURL := c.requestURL
requestURL.Path += earlyDataString
conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers)
} else {
headers := c.headers.Clone()
headers.Set(c.earlyDataHeaderName, earlyDataString)
conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers)
}
} else {
conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers)
}
if err != nil {
return wrapDialError(response, err)
}
c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
if len(lateData) > 0 {
_, err = c.conn.Write(lateData)
}
return err
}
func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
if c.conn != nil {
return c.conn.Write(b)
}
var (
earlyData []byte
lateData []byte
conn *websocket.Conn
response *http.Response
)
if len(b) > int(c.maxEarlyData) {
earlyData = b[:c.maxEarlyData]
lateData = b[c.maxEarlyData:]
} else {
earlyData = b
}
if len(earlyData) > 0 {
earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
if c.earlyDataHeaderName == "" {
conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
} else {
headers := c.headers.Clone()
headers.Set(c.earlyDataHeaderName, earlyDataString)
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
}
} else {
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
}
if err != nil {
return 0, wrapDialError(response, err)
}
c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
err = c.writeRequest(b)
c.err = err
close(c.create)
if len(lateData) > 0 {
_, err = c.conn.Write(lateData)
}
if err != nil {
return
}
@ -149,39 +162,9 @@ func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
if c.conn != nil {
return c.conn.WriteBuffer(buffer)
}
var (
earlyData []byte
lateData []byte
conn *websocket.Conn
response *http.Response
err error
)
if buffer.Len() > int(c.maxEarlyData) {
earlyData = buffer.Bytes()[:c.maxEarlyData]
lateData = buffer.Bytes()[c.maxEarlyData:]
} else {
earlyData = buffer.Bytes()
}
if len(earlyData) > 0 {
earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
if c.earlyDataHeaderName == "" {
conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
} else {
headers := c.headers.Clone()
headers.Set(c.earlyDataHeaderName, earlyDataString)
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
}
} else {
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
}
if err != nil {
return wrapDialError(response, err)
}
c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
err := c.writeRequest(buffer.Bytes())
c.err = err
close(c.create)
if len(lateData) > 0 {
_, err = c.conn.Write(lateData)
}
return err
}