Fix naive padding

This commit is contained in:
世界 2022-08-24 10:21:56 +08:00
parent 22aa0c2f40
commit f87baf08d3
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 10 additions and 6 deletions

View file

@ -162,6 +162,7 @@ func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if hijacker, isHijacker := writer.(http.Hijacker); isHijacker { if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
conn, _, err := hijacker.Hijack() conn, _, err := hijacker.Hijack()
if err != nil { if err != nil {
n.badRequest(ctx, request, E.New("hijack failed"))
return return
} }
n.newConnection(ctx, &naiveH1Conn{Conn: conn}, source, destination) n.newConnection(ctx, &naiveH1Conn{Conn: conn}, source, destination)
@ -245,7 +246,7 @@ func (c *naiveH1Conn) read(p []byte) (n int, err error) {
if err != nil { if err != nil {
return return
} }
c.readRemaining = 0 c.paddingRemaining = 0
} }
if c.readPadding < kFirstPaddings { if c.readPadding < kFirstPaddings {
paddingHdr := p[:3] paddingHdr := p[:3]
@ -352,14 +353,15 @@ func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes()))) return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
} }
func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) { // FIXME
/*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
if c.readPadding < kFirstPaddings { if c.readPadding < kFirstPaddings {
n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding) n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
} else { } else {
n, err = bufio.Copy(w, c.Conn) n, err = bufio.Copy(w, c.Conn)
} }
return n, wrapHttpError(err) return n, wrapHttpError(err)
} }*/
func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) { func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {
@ -415,7 +417,7 @@ func (c *naiveH2Conn) read(p []byte) (n int, err error) {
if err != nil { if err != nil {
return return
} }
c.readRemaining = 0 c.paddingRemaining = 0
} }
if c.readPadding < kFirstPaddings { if c.readPadding < kFirstPaddings {
paddingHdr := p[:3] paddingHdr := p[:3]
@ -529,14 +531,15 @@ func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error {
return wrapHttpError(err) return wrapHttpError(err)
} }
func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) { // FIXME
/*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
if c.readPadding < kFirstPaddings { if c.readPadding < kFirstPaddings {
n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding) n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
} else { } else {
n, err = bufio.Copy(w, c.reader) n, err = bufio.Copy(w, c.reader)
} }
return n, wrapHttpError(err) return n, wrapHttpError(err)
} }*/
func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) { func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {

View file

@ -57,6 +57,7 @@ func testTCP(t *testing.T, clientPort uint16, testPort uint16) {
return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort)) return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort))
} }
require.NoError(t, testPingPongWithConn(t, testPort, dialTCP)) require.NoError(t, testPingPongWithConn(t, testPort, dialTCP))
require.NoError(t, testLargeDataWithConn(t, testPort, dialTCP))
} }
func testSuitQUIC(t *testing.T, clientPort uint16, testPort uint16) { func testSuitQUIC(t *testing.T, clientPort uint16, testPort uint16) {