diff --git a/transport/internet/browser_dialer/dialer.go b/transport/internet/browser_dialer/dialer.go index 3f85f6a6..1991284d 100644 --- a/transport/internet/browser_dialer/dialer.go +++ b/transport/internet/browser_dialer/dialer.go @@ -5,6 +5,7 @@ import ( "context" _ "embed" "encoding/base64" + "encoding/json" "net/http" "time" @@ -17,6 +18,12 @@ import ( //go:embed dialer.html var webpage []byte +type task struct { + Method string `json:"method"` + URL string `json:"url"` + Extra any `json:"extra,omitempty"` +} + var conns chan *websocket.Conn var upgrader = &websocket.Upgrader{ @@ -55,23 +62,69 @@ func HasBrowserDialer() bool { return conns != nil } +type webSocketExtra struct { + Protocol string `json:"protocol,omitempty"` +} + func DialWS(uri string, ed []byte) (*websocket.Conn, error) { - data := []byte("WS " + uri) - if ed != nil { - data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...) + task := task{ + Method: "WS", + URL: uri, } - return dialRaw(data) + if ed != nil { + task.Extra = webSocketExtra{ + Protocol: base64.RawURLEncoding.EncodeToString(ed), + } + } + + return dialTask(task) } -func DialGet(uri string) (*websocket.Conn, error) { - data := []byte("GET " + uri) - return dialRaw(data) +type httpExtra struct { + Referrer string `json:"referrer,omitempty"` + Headers map[string]string `json:"headers,omitempty"` } -func DialPost(uri string, payload []byte) error { - data := []byte("POST " + uri) - conn, err := dialRaw(data) +func httpExtraFromHeaders(headers http.Header) *httpExtra { + if len(headers) == 0 { + return nil + } + + extra := httpExtra{} + if referrer := headers.Get("Referer"); referrer != "" { + extra.Referrer = referrer + headers.Del("Referer") + } + + if len(headers) > 0 { + extra.Headers = make(map[string]string) + for header := range headers { + extra.Headers[header] = headers.Get(header) + } + } + + return &extra +} + +func DialGet(uri string, headers http.Header) (*websocket.Conn, error) { + task := task{ + Method: "GET", + URL: uri, + Extra: httpExtraFromHeaders(headers), + } + + return dialTask(task) +} + +func DialPost(uri string, headers http.Header, payload []byte) error { + task := task{ + Method: "POST", + URL: uri, + Extra: httpExtraFromHeaders(headers), + } + + conn, err := dialTask(task) if err != nil { return err } @@ -90,7 +143,12 @@ func DialPost(uri string, payload []byte) error { return nil } -func dialRaw(data []byte) (*websocket.Conn, error) { +func dialTask(task task) (*websocket.Conn, error) { + data, err := json.Marshal(task) + if err != nil { + return nil, err + } + var conn *websocket.Conn for { conn = <-conns @@ -100,7 +158,7 @@ func dialRaw(data []byte) (*websocket.Conn, error) { break } } - err := CheckOK(conn) + err = CheckOK(conn) if err != nil { return nil, err } diff --git a/transport/internet/browser_dialer/dialer.html b/transport/internet/browser_dialer/dialer.html index 558db627..c62135ae 100644 --- a/transport/internet/browser_dialer/dialer.html +++ b/transport/internet/browser_dialer/dialer.html @@ -14,10 +14,28 @@ let upstreamGetCount = 0; let upstreamWsCount = 0; let upstreamPostCount = 0; + + function prepareRequestInit(extra) { + const requestInit = {}; + if (extra.referrer) { + // note: we have to strip the protocol and host part. + // Browsers disallow that, and will reset the value to current page if attempted. + const referrer = URL.parse(extra.referrer); + requestInit.referrer = referrer.pathname + referrer.search + referrer.hash; + requestInit.referrerPolicy = "unsafe-url"; + } + + if (extra.headers) { + requestInit.headers = extra.headers; + } + + return requestInit; + } + let check = function () { if (clientIdleCount > 0) { return; - }; + } clientIdleCount += 1; console.log("Prepare", url); let ws = new WebSocket(url); @@ -29,12 +47,12 @@ // double-checking that this continues to work ws.onmessage = function (event) { clientIdleCount -= 1; - let [method, url, protocol] = event.data.split(" "); - switch (method) { + let task = JSON.parse(event.data); + switch (task.method) { case "WS": { upstreamWsCount += 1; - console.log("Dial WS", url, protocol); - const wss = new WebSocket(url, protocol); + console.log("Dial WS", task.url, task.extra.protocol); + const wss = new WebSocket(task.url, task.extra.protocol); wss.binaryType = "arraybuffer"; let opened = false; ws.onmessage = function (event) { @@ -60,10 +78,12 @@ wss.close() }; break; - }; + } case "GET": { (async () => { - console.log("Dial GET", url); + const requestInit = prepareRequestInit(task.extra); + + console.log("Dial GET", task.url); ws.send("ok"); const controller = new AbortController(); @@ -83,58 +103,62 @@ ws.onclose = (event) => { try { reader && reader.cancel(); - } catch(e) {}; + } catch(e) {} try { controller.abort(); - } catch(e) {}; + } catch(e) {} }; try { upstreamGetCount += 1; - const response = await fetch(url, {signal: controller.signal}); + + requestInit.signal = controller.signal; + const response = await fetch(task.url, requestInit); const body = await response.body; reader = body.getReader(); while (true) { const { done, value } = await reader.read(); - ws.send(value); + if (value) ws.send(value); // don't send back "undefined" string when received nothing if (done) break; - }; + } } finally { upstreamGetCount -= 1; console.log("Dial GET DONE, remaining: ", upstreamGetCount); ws.close(); - }; + } })(); break; - }; + } case "POST": { upstreamPostCount += 1; - console.log("Dial POST", url); + + const requestInit = prepareRequestInit(task.extra); + requestInit.method = "POST"; + + console.log("Dial POST", task.url); ws.send("ok"); ws.onmessage = async (event) => { try { - const response = await fetch( - url, - {method: "POST", body: event.data} - ); + requestInit.body = event.data; + const response = await fetch(task.url, requestInit); if (response.ok) { ws.send("ok"); } else { console.error("bad status code"); ws.send("fail"); - }; + } } finally { upstreamPostCount -= 1; console.log("Dial POST DONE, remaining: ", upstreamPostCount); ws.close(); - }; + } }; break; - }; - }; + } + } check(); }; diff --git a/transport/internet/splithttp/browser_client.go b/transport/internet/splithttp/browser_client.go index d5d3d942..f4c9becd 100644 --- a/transport/internet/splithttp/browser_client.go +++ b/transport/internet/splithttp/browser_client.go @@ -5,13 +5,15 @@ import ( "io" gonet "net" + "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/transport/internet/browser_dialer" "github.com/xtls/xray-core/transport/internet/websocket" ) -// implements splithttp.DialerClient in terms of browser dialer -// has no fields because everything is global state :O) -type BrowserDialerClient struct{} +// BrowserDialerClient implements splithttp.DialerClient in terms of browser dialer +type BrowserDialerClient struct { + transportConfig *Config +} func (c *BrowserDialerClient) IsClosed() bool { panic("not implemented yet") @@ -19,10 +21,10 @@ func (c *BrowserDialerClient) IsClosed() bool { func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, body io.Reader, uploadOnly bool) (io.ReadCloser, gonet.Addr, gonet.Addr, error) { if body != nil { - panic("not implemented yet") + return nil, nil, nil, errors.New("bidirectional streaming for browser dialer not implemented yet") } - conn, err := browser_dialer.DialGet(url) + conn, err := browser_dialer.DialGet(url, c.transportConfig.GetRequestHeader()) dummyAddr := &gonet.IPAddr{} if err != nil { return nil, dummyAddr, dummyAddr, err @@ -37,7 +39,7 @@ func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, body i return err } - err = browser_dialer.DialPost(url, bytes) + err = browser_dialer.DialPost(url, c.transportConfig.GetRequestHeader(), bytes) if err != nil { return err } diff --git a/transport/internet/splithttp/config.go b/transport/internet/splithttp/config.go index a76bf0e4..5bcd4865 100644 --- a/transport/internet/splithttp/config.go +++ b/transport/internet/splithttp/config.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "math/big" "net/http" + "net/url" "strings" "github.com/xtls/xray-core/common" @@ -11,6 +12,8 @@ import ( "github.com/xtls/xray-core/transport/internet" ) +const paddingQuery = "x_padding" + func (c *Config) GetNormalizedPath() string { pathAndQuery := strings.SplitN(c.Path, "?", 2) path := pathAndQuery[0] @@ -39,11 +42,6 @@ func (c *Config) GetNormalizedQuery() string { } query += "x_version=" + core.Version() - paddingLen := c.GetNormalizedXPaddingBytes().rand() - if paddingLen > 0 { - query += "&x_padding=" + strings.Repeat("0", int(paddingLen)) - } - return query } @@ -53,6 +51,28 @@ func (c *Config) GetRequestHeader() http.Header { header.Add(k, v) } + paddingLen := c.GetNormalizedXPaddingBytes().rand() + if paddingLen > 0 { + query, err := url.ParseQuery(c.GetNormalizedQuery()) + if err != nil { + query = url.Values{} + } + // https://www.rfc-editor.org/rfc/rfc7541.html#appendix-B + // h2's HPACK Header Compression feature employs a huffman encoding using a static table. + // 'X' is assigned an 8 bit code, so HPACK compression won't change actual padding length on the wire. + // https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.2-2 + // h3's similar QPACK feature uses the same huffman table. + query.Set(paddingQuery, strings.Repeat("X", int(paddingLen))) + + referrer := url.URL{ + Scheme: "https", // maybe http actually, but this part is not being checked + Host: c.Host, + Path: c.GetNormalizedPath(), + RawQuery: query.Encode(), + } + + header.Set("Referer", referrer.String()) + } return header } @@ -63,7 +83,7 @@ func (c *Config) WriteResponseHeader(writer http.ResponseWriter) { writer.Header().Set("X-Version", core.Version()) paddingLen := c.GetNormalizedXPaddingBytes().rand() if paddingLen > 0 { - writer.Header().Set("X-Padding", strings.Repeat("0", int(paddingLen))) + writer.Header().Set("X-Padding", strings.Repeat("X", int(paddingLen))) } } diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 22f854cd..6a276e8f 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -53,8 +53,8 @@ var ( func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *XmuxClient) { realityConfig := reality.ConfigFromStreamSettings(streamSettings) - if browser_dialer.HasBrowserDialer() && realityConfig != nil { - return &BrowserDialerClient{}, nil + if browser_dialer.HasBrowserDialer() && realityConfig == nil { + return &BrowserDialerClient{transportConfig: streamSettings.ProtocolSettings.(*Config)}, nil } globalDialerAccess.Lock() @@ -367,15 +367,18 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me }, } + var err error if mode == "stream-one" { requestURL.Path = transportConfiguration.GetNormalizedPath() if xmuxClient != nil { xmuxClient.LeftRequests.Add(-1) } - conn.reader, conn.remoteAddr, conn.localAddr, _ = httpClient.OpenStream(ctx, requestURL.String(), reader, false) + conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient.OpenStream(ctx, requestURL.String(), reader, false) + if err != nil { // browser dialer only + return nil, err + } return stat.Connection(&conn), nil } else { // stream-down - var err error if xmuxClient2 != nil { xmuxClient2.LeftRequests.Add(-1) } @@ -388,7 +391,10 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me if xmuxClient != nil { xmuxClient.LeftRequests.Add(-1) } - httpClient.OpenStream(ctx, requestURL.String(), reader, true) + _, _, _, err = httpClient.OpenStream(ctx, requestURL.String(), reader, true) + if err != nil { // browser dialer only + return nil, err + } return stat.Connection(&conn), nil } @@ -428,8 +434,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me // can reassign Path (potentially concurrently) url := requestURL url.Path += "/" + strconv.FormatInt(seq, 10) - // reassign query to get different padding - url.RawQuery = transportConfiguration.GetNormalizedQuery() seq += 1 diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 0d8c20da..e5465822 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -7,6 +7,7 @@ import ( "io" gonet "net" "net/http" + "net/url" "strconv" "strings" "sync" @@ -110,9 +111,23 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } validRange := h.config.GetNormalizedXPaddingBytes() - x_padding := int32(len(request.URL.Query().Get("x_padding"))) - if validRange.To > 0 && (x_padding < validRange.From || x_padding > validRange.To) { - errors.LogInfo(context.Background(), "invalid x_padding length:", x_padding) + paddingLength := -1 + + if referrerPadding := request.Header.Get("Referer"); referrerPadding != "" { + // Browser dialer cannot control the host part of referrer header, so only check the query + if referrerURL, err := url.Parse(referrerPadding); err == nil { + if query := referrerURL.Query(); query.Has(paddingQuery) { + paddingLength = len(query.Get(paddingQuery)) + } + } + } + + if paddingLength == -1 { + paddingLength = len(request.URL.Query().Get(paddingQuery)) + } + + if validRange.To > 0 && (int32(paddingLength) < validRange.From || int32(paddingLength) > validRange.To) { + errors.LogInfo(context.Background(), "invalid x_padding length:", int32(paddingLength)) writer.WriteHeader(http.StatusBadRequest) return } @@ -185,10 +200,10 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req return } - payload, err := io.ReadAll(request.Body) + payload, err := io.ReadAll(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1)) if len(payload) > scMaxEachPostBytes { - errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request had size ", len(payload), ". Adjust scMaxEachPostBytes on the server to be at least as large as client.") + errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.") writer.WriteHeader(http.StatusRequestEntityTooLarge) return }