mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-25 10:01:30 +00:00
Fix v2ray websocket transport
This commit is contained in:
parent
120dae4eed
commit
d74abbd20e
|
@ -21,7 +21,8 @@ var _ adapter.V2RayClientTransport = (*Client)(nil)
|
|||
|
||||
type Client struct {
|
||||
dialer *websocket.Dialer
|
||||
uri string
|
||||
requestURL url.URL
|
||||
requestURLString string
|
||||
headers http.Header
|
||||
maxEarlyData uint32
|
||||
earlyDataHeaderName string
|
||||
|
@ -57,15 +58,15 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
|||
return &deadConn{conn}, nil
|
||||
}
|
||||
}
|
||||
var uri url.URL
|
||||
var requestURL url.URL
|
||||
if tlsConfig == nil {
|
||||
uri.Scheme = "ws"
|
||||
requestURL.Scheme = "ws"
|
||||
} else {
|
||||
uri.Scheme = "wss"
|
||||
requestURL.Scheme = "wss"
|
||||
}
|
||||
uri.Host = serverAddr.String()
|
||||
uri.Path = options.Path
|
||||
err := sHTTP.URLSetPath(&uri, options.Path)
|
||||
requestURL.Host = serverAddr.String()
|
||||
requestURL.Path = options.Path
|
||||
err := sHTTP.URLSetPath(&requestURL, options.Path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -75,7 +76,8 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
|||
}
|
||||
return &Client{
|
||||
wsDialer,
|
||||
uri.String(),
|
||||
requestURL,
|
||||
requestURL.String(),
|
||||
headers,
|
||||
options.MaxEarlyData,
|
||||
options.EarlyDataHeaderName,
|
||||
|
@ -84,7 +86,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
|||
|
||||
func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
if c.maxEarlyData <= 0 {
|
||||
conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers)
|
||||
conn, response, err := c.dialer.DialContext(ctx, c.requestURLString, c.headers)
|
||||
if err == nil {
|
||||
return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil
|
||||
}
|
||||
|
|
|
@ -94,51 +94,64 @@ type EarlyWebsocketConn struct {
|
|||
ctx context.Context
|
||||
conn *WebsocketConn
|
||||
create chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
|
||||
if c.conn == nil {
|
||||
<-c.create
|
||||
if c.err != nil {
|
||||
return 0, c.err
|
||||
}
|
||||
}
|
||||
return c.conn.Read(b)
|
||||
}
|
||||
|
||||
func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
|
||||
var (
|
||||
earlyData []byte
|
||||
lateData []byte
|
||||
conn *websocket.Conn
|
||||
response *http.Response
|
||||
err error
|
||||
)
|
||||
if len(content) > int(c.maxEarlyData) {
|
||||
earlyData = content[:c.maxEarlyData]
|
||||
lateData = content[c.maxEarlyData:]
|
||||
} else {
|
||||
earlyData = content
|
||||
}
|
||||
if len(earlyData) > 0 {
|
||||
earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
|
||||
if c.earlyDataHeaderName == "" {
|
||||
requestURL := c.requestURL
|
||||
requestURL.Path += earlyDataString
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers)
|
||||
} else {
|
||||
headers := c.headers.Clone()
|
||||
headers.Set(c.earlyDataHeaderName, earlyDataString)
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers)
|
||||
}
|
||||
} else {
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers)
|
||||
}
|
||||
if err != nil {
|
||||
return wrapDialError(response, err)
|
||||
}
|
||||
c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
|
||||
if len(lateData) > 0 {
|
||||
_, err = c.conn.Write(lateData)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
|
||||
if c.conn != nil {
|
||||
return c.conn.Write(b)
|
||||
}
|
||||
var (
|
||||
earlyData []byte
|
||||
lateData []byte
|
||||
conn *websocket.Conn
|
||||
response *http.Response
|
||||
)
|
||||
if len(b) > int(c.maxEarlyData) {
|
||||
earlyData = b[:c.maxEarlyData]
|
||||
lateData = b[c.maxEarlyData:]
|
||||
} else {
|
||||
earlyData = b
|
||||
}
|
||||
if len(earlyData) > 0 {
|
||||
earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
|
||||
if c.earlyDataHeaderName == "" {
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
|
||||
} else {
|
||||
headers := c.headers.Clone()
|
||||
headers.Set(c.earlyDataHeaderName, earlyDataString)
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
|
||||
}
|
||||
} else {
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, wrapDialError(response, err)
|
||||
}
|
||||
c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
|
||||
err = c.writeRequest(b)
|
||||
c.err = err
|
||||
close(c.create)
|
||||
if len(lateData) > 0 {
|
||||
_, err = c.conn.Write(lateData)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -149,39 +162,9 @@ func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
|
|||
if c.conn != nil {
|
||||
return c.conn.WriteBuffer(buffer)
|
||||
}
|
||||
var (
|
||||
earlyData []byte
|
||||
lateData []byte
|
||||
conn *websocket.Conn
|
||||
response *http.Response
|
||||
err error
|
||||
)
|
||||
if buffer.Len() > int(c.maxEarlyData) {
|
||||
earlyData = buffer.Bytes()[:c.maxEarlyData]
|
||||
lateData = buffer.Bytes()[c.maxEarlyData:]
|
||||
} else {
|
||||
earlyData = buffer.Bytes()
|
||||
}
|
||||
if len(earlyData) > 0 {
|
||||
earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
|
||||
if c.earlyDataHeaderName == "" {
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
|
||||
} else {
|
||||
headers := c.headers.Clone()
|
||||
headers.Set(c.earlyDataHeaderName, earlyDataString)
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
|
||||
}
|
||||
} else {
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
|
||||
}
|
||||
if err != nil {
|
||||
return wrapDialError(response, err)
|
||||
}
|
||||
c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
|
||||
err := c.writeRequest(buffer.Bytes())
|
||||
c.err = err
|
||||
close(c.create)
|
||||
if len(lateData) > 0 {
|
||||
_, err = c.conn.Write(lateData)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue