mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-22 08:31:28 +00:00
SplitHTTP: Fix connection leaks and crashes (#3710)
This commit is contained in:
parent
2be03c56cb
commit
83eef6bc1f
|
@ -49,6 +49,8 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
|
||||||
var downResponse io.ReadCloser
|
var downResponse io.ReadCloser
|
||||||
gotDownResponse := done.New()
|
gotDownResponse := done.New()
|
||||||
|
|
||||||
|
ctx, ctxCancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
trace := &httptrace.ClientTrace{
|
trace := &httptrace.ClientTrace{
|
||||||
GotConn: func(connInfo httptrace.GotConnInfo) {
|
GotConn: func(connInfo httptrace.GotConnInfo) {
|
||||||
|
@ -61,8 +63,10 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
|
||||||
// in case we hit an error, we want to unblock this part
|
// in case we hit an error, we want to unblock this part
|
||||||
defer gotConn.Close()
|
defer gotConn.Close()
|
||||||
|
|
||||||
|
ctx = httptrace.WithClientTrace(ctx, trace)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(
|
req, err := http.NewRequestWithContext(
|
||||||
httptrace.WithClientTrace(ctx, trace),
|
ctx,
|
||||||
"GET",
|
"GET",
|
||||||
baseURL,
|
baseURL,
|
||||||
nil,
|
nil,
|
||||||
|
@ -94,16 +98,17 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
|
||||||
gotDownResponse.Close()
|
gotDownResponse.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if c.isH3 {
|
if !c.isH3 {
|
||||||
gotConn.Close()
|
// in quic-go, sometimes gotConn is never closed for the lifetime of
|
||||||
|
// the entire connection, and the download locks up
|
||||||
|
// https://github.com/quic-go/quic-go/issues/3342
|
||||||
|
// for other HTTP versions, we want to block Dial until we know the
|
||||||
|
// remote address of the server, for logging purposes
|
||||||
|
<-gotConn.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// we want to block Dial until we know the remote address of the server,
|
|
||||||
// for logging purposes
|
|
||||||
<-gotConn.Wait()
|
|
||||||
|
|
||||||
lazyDownload := &LazyReader{
|
lazyDownload := &LazyReader{
|
||||||
CreateReader: func() (io.ReadCloser, error) {
|
CreateReader: func() (io.Reader, error) {
|
||||||
<-gotDownResponse.Wait()
|
<-gotDownResponse.Wait()
|
||||||
if downResponse == nil {
|
if downResponse == nil {
|
||||||
return nil, errors.New("downResponse failed")
|
return nil, errors.New("downResponse failed")
|
||||||
|
@ -112,7 +117,15 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return lazyDownload, remoteAddr, localAddr, nil
|
// workaround for https://github.com/quic-go/quic-go/issues/2143 --
|
||||||
|
// always cancel request context so that Close cancels any Read.
|
||||||
|
// Should then match the behavior of http2 and http1.
|
||||||
|
reader := downloadBody{
|
||||||
|
lazyDownload,
|
||||||
|
ctxCancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
return reader, remoteAddr, localAddr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
|
func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
|
||||||
|
@ -172,3 +185,13 @@ func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string,
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type downloadBody struct {
|
||||||
|
io.Reader
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c downloadBody) Close() error {
|
||||||
|
c.cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
package splithttp
|
package splithttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
gotls "crypto/tls"
|
gotls "crypto/tls"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -292,35 +290,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
lazyDownload := &LazyReader{
|
reader := &stripOkReader{ReadCloser: lazyRawDownload}
|
||||||
CreateReader: func() (io.ReadCloser, error) {
|
|
||||||
// skip "ok" response
|
|
||||||
trashHeader := []byte{0, 0}
|
|
||||||
_, err := io.ReadFull(lazyRawDownload, trashHeader)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("failed to read initial response").Base(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if bytes.Equal(trashHeader, []byte("ok")) {
|
|
||||||
return lazyRawDownload, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// we read some garbage byte that may not have been "ok" at
|
|
||||||
// all. return a reader that replays what we have read so far
|
|
||||||
reader := io.MultiReader(
|
|
||||||
bytes.NewReader(trashHeader),
|
|
||||||
lazyRawDownload,
|
|
||||||
)
|
|
||||||
readCloser := struct {
|
|
||||||
io.Reader
|
|
||||||
io.Closer
|
|
||||||
}{
|
|
||||||
Reader: reader,
|
|
||||||
Closer: lazyRawDownload,
|
|
||||||
}
|
|
||||||
return readCloser, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
writer := uploadWriter{
|
writer := uploadWriter{
|
||||||
uploadPipeWriter,
|
uploadPipeWriter,
|
||||||
|
@ -329,7 +299,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
|
||||||
|
|
||||||
conn := splitConn{
|
conn := splitConn{
|
||||||
writer: writer,
|
writer: writer,
|
||||||
reader: lazyDownload,
|
reader: reader,
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
localAddr: localAddr,
|
localAddr: localAddr,
|
||||||
}
|
}
|
||||||
|
|
|
@ -222,8 +222,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
|
||||||
h.ln.addConn(stat.Connection(&conn))
|
h.ln.addConn(stat.Connection(&conn))
|
||||||
|
|
||||||
// "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
|
// "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
|
||||||
<-downloadDone.Wait()
|
select {
|
||||||
|
case <-request.Context().Done():
|
||||||
|
case <-downloadDone.Wait():
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
} else {
|
} else {
|
||||||
writer.WriteHeader(http.StatusMethodNotAllowed)
|
writer.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,18 +3,20 @@ package splithttp
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/xtls/xray-core/common/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Close is intentionally not supported by LazyReader because it's not clear
|
||||||
|
// how CreateReader should be aborted in case of Close. It's best to wrap
|
||||||
|
// LazyReader in another struct that handles Close correctly, or better, stop
|
||||||
|
// using LazyReader entirely.
|
||||||
type LazyReader struct {
|
type LazyReader struct {
|
||||||
readerSync sync.Mutex
|
readerSync sync.Mutex
|
||||||
CreateReader func() (io.ReadCloser, error)
|
CreateReader func() (io.Reader, error)
|
||||||
reader io.ReadCloser
|
reader io.Reader
|
||||||
readerError error
|
readerError error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *LazyReader) getReader() (io.ReadCloser, error) {
|
func (r *LazyReader) getReader() (io.Reader, error) {
|
||||||
r.readerSync.Lock()
|
r.readerSync.Lock()
|
||||||
defer r.readerSync.Unlock()
|
defer r.readerSync.Unlock()
|
||||||
if r.reader != nil {
|
if r.reader != nil {
|
||||||
|
@ -43,17 +45,3 @@ func (r *LazyReader) Read(b []byte) (int, error) {
|
||||||
n, err := reader.Read(b)
|
n, err := reader.Read(b)
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *LazyReader) Close() error {
|
|
||||||
r.readerSync.Lock()
|
|
||||||
defer r.readerSync.Unlock()
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if r.reader != nil {
|
|
||||||
err = r.reader.Close()
|
|
||||||
r.reader = nil
|
|
||||||
r.readerError = errors.New("closed reader")
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
|
@ -248,6 +248,8 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
|
||||||
NextProtocol: []string{"h3"},
|
NextProtocol: []string{"h3"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
serverClosed := false
|
||||||
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
|
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
|
||||||
go func() {
|
go func() {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
@ -258,10 +260,12 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
|
||||||
for {
|
for {
|
||||||
b.Clear()
|
b.Clear()
|
||||||
if _, err := b.ReadFrom(conn); err != nil {
|
if _, err := b.ReadFrom(conn); err != nil {
|
||||||
return
|
break
|
||||||
}
|
}
|
||||||
common.Must2(conn.Write(b.Bytes()))
|
common.Must2(conn.Write(b.Bytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
serverClosed = true
|
||||||
}()
|
}()
|
||||||
})
|
})
|
||||||
common.Must(err)
|
common.Must(err)
|
||||||
|
@ -271,7 +275,6 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
|
||||||
|
|
||||||
conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
|
conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
|
||||||
common.Must(err)
|
common.Must(err)
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
const N = 1024
|
const N = 1024
|
||||||
b1 := make([]byte, N)
|
b1 := make([]byte, N)
|
||||||
|
@ -294,6 +297,12 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
|
||||||
t.Error(r)
|
t.Error(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
if !serverClosed {
|
||||||
|
t.Error("server did not get closed")
|
||||||
|
}
|
||||||
|
|
||||||
end := time.Now()
|
end := time.Now()
|
||||||
if !end.Before(start.Add(time.Second * 5)) {
|
if !end.Before(start.Add(time.Second * 5)) {
|
||||||
t.Error("end: ", end, " start: ", start)
|
t.Error("end: ", end, " start: ", start)
|
||||||
|
|
48
transport/internet/splithttp/strip_ok_reader.go
Normal file
48
transport/internet/splithttp/strip_ok_reader.go
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
package splithttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/xtls/xray-core/common/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// in older versions of splithttp, the server would respond with `ok` to flush
|
||||||
|
// out HTTP response headers early. Response headers and a 200 OK were required
|
||||||
|
// to initiate the connection. Later versions of splithttp dropped this
|
||||||
|
// requirement, and in xray 1.8.24 the server stopped sending "ok" if it sees
|
||||||
|
// x_padding. For compatibility, we need to remove "ok" from the underlying
|
||||||
|
// reader if it exists, and otherwise forward the stream as-is.
|
||||||
|
type stripOkReader struct {
|
||||||
|
io.ReadCloser
|
||||||
|
firstDone bool
|
||||||
|
prefixRead []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stripOkReader) Read(b []byte) (int, error) {
|
||||||
|
if !r.firstDone {
|
||||||
|
r.firstDone = true
|
||||||
|
|
||||||
|
// skip "ok" response
|
||||||
|
prefixRead := []byte{0, 0}
|
||||||
|
_, err := io.ReadFull(r.ReadCloser, prefixRead)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.New("failed to read initial response").Base(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(prefixRead, []byte("ok")) {
|
||||||
|
// we read some garbage byte that may not have been "ok" at
|
||||||
|
// all. return a reader that replays what we have read so far
|
||||||
|
r.prefixRead = prefixRead
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.prefixRead) > 0 {
|
||||||
|
n := copy(b, r.prefixRead)
|
||||||
|
r.prefixRead = r.prefixRead[n:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := r.ReadCloser.Read(b)
|
||||||
|
return n, err
|
||||||
|
}
|
|
@ -51,8 +51,10 @@ func (h *uploadQueue) Close() error {
|
||||||
h.writeCloseMutex.Lock()
|
h.writeCloseMutex.Lock()
|
||||||
defer h.writeCloseMutex.Unlock()
|
defer h.writeCloseMutex.Unlock()
|
||||||
|
|
||||||
h.closed = true
|
if !h.closed {
|
||||||
close(h.pushedPackets)
|
h.closed = true
|
||||||
|
close(h.pushedPackets)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue