SplitHTTP: Fix connection leaks and crashes (#3710)

This commit is contained in:
mmmray 2024-08-22 17:07:57 +02:00 committed by GitHub
parent 2be03c56cb
commit 83eef6bc1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 109 additions and 65 deletions

View file

@ -49,6 +49,8 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
var downResponse io.ReadCloser var downResponse io.ReadCloser
gotDownResponse := done.New() gotDownResponse := done.New()
ctx, ctxCancel := context.WithCancel(ctx)
go func() { go func() {
trace := &httptrace.ClientTrace{ trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) { GotConn: func(connInfo httptrace.GotConnInfo) {
@ -61,8 +63,10 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
// in case we hit an error, we want to unblock this part // in case we hit an error, we want to unblock this part
defer gotConn.Close() defer gotConn.Close()
ctx = httptrace.WithClientTrace(ctx, trace)
req, err := http.NewRequestWithContext( req, err := http.NewRequestWithContext(
httptrace.WithClientTrace(ctx, trace), ctx,
"GET", "GET",
baseURL, baseURL,
nil, nil,
@ -94,16 +98,17 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
gotDownResponse.Close() gotDownResponse.Close()
}() }()
if c.isH3 { if !c.isH3 {
gotConn.Close() // in quic-go, sometimes gotConn is never closed for the lifetime of
// the entire connection, and the download locks up
// https://github.com/quic-go/quic-go/issues/3342
// for other HTTP versions, we want to block Dial until we know the
// remote address of the server, for logging purposes
<-gotConn.Wait()
} }
// we want to block Dial until we know the remote address of the server,
// for logging purposes
<-gotConn.Wait()
lazyDownload := &LazyReader{ lazyDownload := &LazyReader{
CreateReader: func() (io.ReadCloser, error) { CreateReader: func() (io.Reader, error) {
<-gotDownResponse.Wait() <-gotDownResponse.Wait()
if downResponse == nil { if downResponse == nil {
return nil, errors.New("downResponse failed") return nil, errors.New("downResponse failed")
@ -112,7 +117,15 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
}, },
} }
return lazyDownload, remoteAddr, localAddr, nil // workaround for https://github.com/quic-go/quic-go/issues/2143 --
// always cancel request context so that Close cancels any Read.
// Should then match the behavior of http2 and http1.
reader := downloadBody{
lazyDownload,
ctxCancel,
}
return reader, remoteAddr, localAddr, nil
} }
func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error { func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
@ -172,3 +185,13 @@ func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string,
return nil return nil
} }
type downloadBody struct {
io.Reader
cancel context.CancelFunc
}
func (c downloadBody) Close() error {
c.cancel()
return nil
}

View file

@ -1,10 +1,8 @@
package splithttp package splithttp
import ( import (
"bytes"
"context" "context"
gotls "crypto/tls" gotls "crypto/tls"
"io"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@ -292,35 +290,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
return nil, err return nil, err
} }
lazyDownload := &LazyReader{ reader := &stripOkReader{ReadCloser: lazyRawDownload}
CreateReader: func() (io.ReadCloser, error) {
// skip "ok" response
trashHeader := []byte{0, 0}
_, err := io.ReadFull(lazyRawDownload, trashHeader)
if err != nil {
return nil, errors.New("failed to read initial response").Base(err)
}
if bytes.Equal(trashHeader, []byte("ok")) {
return lazyRawDownload, nil
}
// we read some garbage byte that may not have been "ok" at
// all. return a reader that replays what we have read so far
reader := io.MultiReader(
bytes.NewReader(trashHeader),
lazyRawDownload,
)
readCloser := struct {
io.Reader
io.Closer
}{
Reader: reader,
Closer: lazyRawDownload,
}
return readCloser, nil
},
}
writer := uploadWriter{ writer := uploadWriter{
uploadPipeWriter, uploadPipeWriter,
@ -329,7 +299,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
conn := splitConn{ conn := splitConn{
writer: writer, writer: writer,
reader: lazyDownload, reader: reader,
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
localAddr: localAddr, localAddr: localAddr,
} }

View file

@ -222,8 +222,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
h.ln.addConn(stat.Connection(&conn)) h.ln.addConn(stat.Connection(&conn))
// "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned." // "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
<-downloadDone.Wait() select {
case <-request.Context().Done():
case <-downloadDone.Wait():
}
conn.Close()
} else { } else {
writer.WriteHeader(http.StatusMethodNotAllowed) writer.WriteHeader(http.StatusMethodNotAllowed)
} }

View file

@ -3,18 +3,20 @@ package splithttp
import ( import (
"io" "io"
"sync" "sync"
"github.com/xtls/xray-core/common/errors"
) )
// Close is intentionally not supported by LazyReader because it's not clear
// how CreateReader should be aborted in case of Close. It's best to wrap
// LazyReader in another struct that handles Close correctly, or better, stop
// using LazyReader entirely.
type LazyReader struct { type LazyReader struct {
readerSync sync.Mutex readerSync sync.Mutex
CreateReader func() (io.ReadCloser, error) CreateReader func() (io.Reader, error)
reader io.ReadCloser reader io.Reader
readerError error readerError error
} }
func (r *LazyReader) getReader() (io.ReadCloser, error) { func (r *LazyReader) getReader() (io.Reader, error) {
r.readerSync.Lock() r.readerSync.Lock()
defer r.readerSync.Unlock() defer r.readerSync.Unlock()
if r.reader != nil { if r.reader != nil {
@ -43,17 +45,3 @@ func (r *LazyReader) Read(b []byte) (int, error) {
n, err := reader.Read(b) n, err := reader.Read(b)
return n, err return n, err
} }
func (r *LazyReader) Close() error {
r.readerSync.Lock()
defer r.readerSync.Unlock()
var err error
if r.reader != nil {
err = r.reader.Close()
r.reader = nil
r.readerError = errors.New("closed reader")
}
return err
}

View file

@ -248,6 +248,8 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
NextProtocol: []string{"h3"}, NextProtocol: []string{"h3"},
}, },
} }
serverClosed := false
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() { go func() {
defer conn.Close() defer conn.Close()
@ -258,10 +260,12 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
for { for {
b.Clear() b.Clear()
if _, err := b.ReadFrom(conn); err != nil { if _, err := b.ReadFrom(conn); err != nil {
return break
} }
common.Must2(conn.Write(b.Bytes())) common.Must2(conn.Write(b.Bytes()))
} }
serverClosed = true
}() }()
}) })
common.Must(err) common.Must(err)
@ -271,7 +275,6 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
common.Must(err) common.Must(err)
defer conn.Close()
const N = 1024 const N = 1024
b1 := make([]byte, N) b1 := make([]byte, N)
@ -294,6 +297,12 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
t.Error(r) t.Error(r)
} }
conn.Close()
time.Sleep(100 * time.Millisecond)
if !serverClosed {
t.Error("server did not get closed")
}
end := time.Now() end := time.Now()
if !end.Before(start.Add(time.Second * 5)) { if !end.Before(start.Add(time.Second * 5)) {
t.Error("end: ", end, " start: ", start) t.Error("end: ", end, " start: ", start)

View file

@ -0,0 +1,48 @@
package splithttp
import (
"bytes"
"io"
"github.com/xtls/xray-core/common/errors"
)
// in older versions of splithttp, the server would respond with `ok` to flush
// out HTTP response headers early. Response headers and a 200 OK were required
// to initiate the connection. Later versions of splithttp dropped this
// requirement, and in xray 1.8.24 the server stopped sending "ok" if it sees
// x_padding. For compatibility, we need to remove "ok" from the underlying
// reader if it exists, and otherwise forward the stream as-is.
type stripOkReader struct {
io.ReadCloser
firstDone bool
prefixRead []byte
}
func (r *stripOkReader) Read(b []byte) (int, error) {
if !r.firstDone {
r.firstDone = true
// skip "ok" response
prefixRead := []byte{0, 0}
_, err := io.ReadFull(r.ReadCloser, prefixRead)
if err != nil {
return 0, errors.New("failed to read initial response").Base(err)
}
if !bytes.Equal(prefixRead, []byte("ok")) {
// we read some garbage byte that may not have been "ok" at
// all. return a reader that replays what we have read so far
r.prefixRead = prefixRead
}
}
if len(r.prefixRead) > 0 {
n := copy(b, r.prefixRead)
r.prefixRead = r.prefixRead[n:]
return n, nil
}
n, err := r.ReadCloser.Read(b)
return n, err
}

View file

@ -51,8 +51,10 @@ func (h *uploadQueue) Close() error {
h.writeCloseMutex.Lock() h.writeCloseMutex.Lock()
defer h.writeCloseMutex.Unlock() defer h.writeCloseMutex.Unlock()
h.closed = true if !h.closed {
close(h.pushedPackets) h.closed = true
close(h.pushedPackets)
}
return nil return nil
} }