From f87baf08d3c9f91108334d988ce2df3c6895fe2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 24 Aug 2022 10:21:56 +0800 Subject: [PATCH] Fix naive padding --- inbound/naive.go | 15 +++++++++------ test/box_test.go | 1 + 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/inbound/naive.go b/inbound/naive.go index a2f4a063..cd4591cc 100644 --- a/inbound/naive.go +++ b/inbound/naive.go @@ -162,6 +162,7 @@ func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if hijacker, isHijacker := writer.(http.Hijacker); isHijacker { conn, _, err := hijacker.Hijack() if err != nil { + n.badRequest(ctx, request, E.New("hijack failed")) return } n.newConnection(ctx, &naiveH1Conn{Conn: conn}, source, destination) @@ -245,7 +246,7 @@ func (c *naiveH1Conn) read(p []byte) (n int, err error) { if err != nil { return } - c.readRemaining = 0 + c.paddingRemaining = 0 } if c.readPadding < kFirstPaddings { paddingHdr := p[:3] @@ -352,14 +353,15 @@ func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error { return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes()))) } -func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) { +// FIXME +/*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) { if c.readPadding < kFirstPaddings { n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding) } else { n, err = bufio.Copy(w, c.Conn) } return n, wrapHttpError(err) -} +}*/ func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) { if c.writePadding < kFirstPaddings { @@ -415,7 +417,7 @@ func (c *naiveH2Conn) read(p []byte) (n int, err error) { if err != nil { return } - c.readRemaining = 0 + c.paddingRemaining = 0 } if c.readPadding < kFirstPaddings { paddingHdr := p[:3] @@ -529,14 +531,15 @@ func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error { return wrapHttpError(err) } -func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) { +// FIXME +/*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) { if c.readPadding < kFirstPaddings { n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding) } else { n, err = bufio.Copy(w, c.reader) } return n, wrapHttpError(err) -} +}*/ func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) { if c.writePadding < kFirstPaddings { diff --git a/test/box_test.go b/test/box_test.go index a84023e7..aef5ea43 100644 --- a/test/box_test.go +++ b/test/box_test.go @@ -57,6 +57,7 @@ func testTCP(t *testing.T, clientPort uint16, testPort uint16) { return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort)) } require.NoError(t, testPingPongWithConn(t, testPort, dialTCP)) + require.NoError(t, testLargeDataWithConn(t, testPort, dialTCP)) } func testSuitQUIC(t *testing.T, clientPort uint16, testPort uint16) {