Fix WS reading X-Forwarded-For & Add tests (#3546)

Fixes https://github.com/XTLS/Xray-core/issues/3545

---------

Co-authored-by: mmmray <142015632+mmmray@users.noreply.github.com>
This commit is contained in:
风扇滑翔翼 2024-07-17 18:40:25 +08:00 committed by GitHub
parent 9e6d7a3cb0
commit a7e198e1e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 15 additions and 11 deletions

View file

@ -151,7 +151,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
return return
} }
_, err = c.Write([]byte("Response")) _, err = c.Write([]byte(c.RemoteAddr().String()))
common.Must(err) common.Must(err)
}(conn) }(conn)
}) })
@ -169,7 +169,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
var b [1024]byte var b [1024]byte
n, err := conn.Read(b[:]) n, err := conn.Read(b[:])
common.Must(err) common.Must(err)
if string(b[:n]) != "Response" { if string(b[:n]) != "1.1.1.1:0" {
t.Error("response: ", string(b[:n])) t.Error("response: ", string(b[:n]))
} }

View file

@ -96,7 +96,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
return return
} }
_, err = c.Write([]byte("Response")) _, err = c.Write([]byte(c.RemoteAddr().String()))
common.Must(err) common.Must(err)
}(conn) }(conn)
}) })
@ -113,7 +113,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
var b [1024]byte var b [1024]byte
n, _ := conn.Read(b[:]) n, _ := conn.Read(b[:])
if string(b[:n]) != "Response" { if string(b[:n]) != "1.1.1.1:0" {
t.Error("response: ", string(b[:n])) t.Error("response: ", string(b[:n]))
} }

View file

@ -14,14 +14,18 @@ import (
var _ buf.Writer = (*connection)(nil) var _ buf.Writer = (*connection)(nil)
// connection is a wrapper for net.Conn over WebSocket connection. // connection is a wrapper for net.Conn over WebSocket connection.
// remoteAddr is used to pass "virtual" remote IP addresses in X-Forwarded-For.
// so we shouldn't directly read it form conn.
type connection struct { type connection struct {
conn *websocket.Conn conn *websocket.Conn
reader io.Reader reader io.Reader
remoteAddr net.Addr
} }
func NewConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection { func NewConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection {
return &connection{ return &connection{
conn: conn, conn: conn,
remoteAddr: remoteAddr,
reader: extraReader, reader: extraReader,
} }
} }
@ -90,7 +94,7 @@ func (c *connection) LocalAddr() net.Addr {
} }
func (c *connection) RemoteAddr() net.Addr { func (c *connection) RemoteAddr() net.Addr {
return c.conn.RemoteAddr() return c.remoteAddr
} }
func (c *connection) SetDeadline(t time.Time) error { func (c *connection) SetDeadline(t time.Time) error {

View file

@ -91,7 +91,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
return return
} }
_, err = c.Write([]byte("Response")) _, err = c.Write([]byte(c.RemoteAddr().String()))
common.Must(err) common.Must(err)
}(conn) }(conn)
}) })
@ -109,7 +109,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
var b [1024]byte var b [1024]byte
n, err := conn.Read(b[:]) n, err := conn.Read(b[:])
common.Must(err) common.Must(err)
if string(b[:n]) != "Response" { if string(b[:n]) != "1.1.1.1:0" {
t.Error("response: ", string(b[:n])) t.Error("response: ", string(b[:n]))
} }