Fix v2ray websocket transport

This commit is contained in:
世界 2023-07-11 15:12:26 +08:00
parent 120dae4eed
commit d74abbd20e
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 57 additions and 72 deletions

View file

@ -21,7 +21,8 @@ var _ adapter.V2RayClientTransport = (*Client)(nil)
type Client struct { type Client struct {
dialer *websocket.Dialer dialer *websocket.Dialer
uri string requestURL url.URL
requestURLString string
headers http.Header headers http.Header
maxEarlyData uint32 maxEarlyData uint32
earlyDataHeaderName string earlyDataHeaderName string
@ -57,15 +58,15 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
return &deadConn{conn}, nil return &deadConn{conn}, nil
} }
} }
var uri url.URL var requestURL url.URL
if tlsConfig == nil { if tlsConfig == nil {
uri.Scheme = "ws" requestURL.Scheme = "ws"
} else { } else {
uri.Scheme = "wss" requestURL.Scheme = "wss"
} }
uri.Host = serverAddr.String() requestURL.Host = serverAddr.String()
uri.Path = options.Path requestURL.Path = options.Path
err := sHTTP.URLSetPath(&uri, options.Path) err := sHTTP.URLSetPath(&requestURL, options.Path)
if err != nil { if err != nil {
return nil return nil
} }
@ -75,7 +76,8 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
} }
return &Client{ return &Client{
wsDialer, wsDialer,
uri.String(), requestURL,
requestURL.String(),
headers, headers,
options.MaxEarlyData, options.MaxEarlyData,
options.EarlyDataHeaderName, options.EarlyDataHeaderName,
@ -84,7 +86,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
if c.maxEarlyData <= 0 { if c.maxEarlyData <= 0 {
conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers) conn, response, err := c.dialer.DialContext(ctx, c.requestURLString, c.headers)
if err == nil { if err == nil {
return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil
} }

View file

@ -94,51 +94,64 @@ type EarlyWebsocketConn struct {
ctx context.Context ctx context.Context
conn *WebsocketConn conn *WebsocketConn
create chan struct{} create chan struct{}
err error
} }
func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) { func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
if c.conn == nil { if c.conn == nil {
<-c.create <-c.create
if c.err != nil {
return 0, c.err
}
} }
return c.conn.Read(b) return c.conn.Read(b)
} }
func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
var (
earlyData []byte
lateData []byte
conn *websocket.Conn
response *http.Response
err error
)
if len(content) > int(c.maxEarlyData) {
earlyData = content[:c.maxEarlyData]
lateData = content[c.maxEarlyData:]
} else {
earlyData = content
}
if len(earlyData) > 0 {
earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
if c.earlyDataHeaderName == "" {
requestURL := c.requestURL
requestURL.Path += earlyDataString
conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers)
} else {
headers := c.headers.Clone()
headers.Set(c.earlyDataHeaderName, earlyDataString)
conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers)
}
} else {
conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers)
}
if err != nil {
return wrapDialError(response, err)
}
c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
if len(lateData) > 0 {
_, err = c.conn.Write(lateData)
}
return err
}
func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
if c.conn != nil { if c.conn != nil {
return c.conn.Write(b) return c.conn.Write(b)
} }
var ( err = c.writeRequest(b)
earlyData []byte c.err = err
lateData []byte
conn *websocket.Conn
response *http.Response
)
if len(b) > int(c.maxEarlyData) {
earlyData = b[:c.maxEarlyData]
lateData = b[c.maxEarlyData:]
} else {
earlyData = b
}
if len(earlyData) > 0 {
earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
if c.earlyDataHeaderName == "" {
conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
} else {
headers := c.headers.Clone()
headers.Set(c.earlyDataHeaderName, earlyDataString)
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
}
} else {
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
}
if err != nil {
return 0, wrapDialError(response, err)
}
c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
close(c.create) close(c.create)
if len(lateData) > 0 {
_, err = c.conn.Write(lateData)
}
if err != nil { if err != nil {
return return
} }
@ -149,39 +162,9 @@ func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
if c.conn != nil { if c.conn != nil {
return c.conn.WriteBuffer(buffer) return c.conn.WriteBuffer(buffer)
} }
var ( err := c.writeRequest(buffer.Bytes())
earlyData []byte c.err = err
lateData []byte
conn *websocket.Conn
response *http.Response
err error
)
if buffer.Len() > int(c.maxEarlyData) {
earlyData = buffer.Bytes()[:c.maxEarlyData]
lateData = buffer.Bytes()[c.maxEarlyData:]
} else {
earlyData = buffer.Bytes()
}
if len(earlyData) > 0 {
earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
if c.earlyDataHeaderName == "" {
conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
} else {
headers := c.headers.Clone()
headers.Set(c.earlyDataHeaderName, earlyDataString)
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
}
} else {
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
}
if err != nil {
return wrapDialError(response, err)
}
c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
close(c.create) close(c.create)
if len(lateData) > 0 {
_, err = c.conn.Write(lateData)
}
return err return err
} }