mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-21 16:11:31 +00:00
preserve exact header casing when using httpupgrade (#3427)
* preserve exact header casing when using httpupgrade * fix capitalization of websocket * oops, we dont need net/url either * restore old codepath when there are no headers
This commit is contained in:
parent
be29cc39d7
commit
980236f2b6
|
@ -65,23 +65,69 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings *
|
|||
requestURL.Scheme = "http"
|
||||
}
|
||||
|
||||
requestURL.Host = dest.NetAddr()
|
||||
requestURL.Path = transportConfiguration.GetNormalizedPath()
|
||||
req := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &requestURL,
|
||||
Host: transportConfiguration.Host,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
for key, value := range transportConfiguration.Header {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
req.Header.Set("Connection", "upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
var req *http.Request = nil
|
||||
|
||||
err = req.Write(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(transportConfiguration.Header) == 0 {
|
||||
requestURL.Host = dest.NetAddr()
|
||||
requestURL.Path = transportConfiguration.GetNormalizedPath()
|
||||
req = &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &requestURL,
|
||||
Host: transportConfiguration.Host,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
req.Header.Set("Connection", "upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
|
||||
err = req.Write(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
var headersBuilder strings.Builder
|
||||
|
||||
headersBuilder.WriteString("GET ")
|
||||
headersBuilder.WriteString(transportConfiguration.GetNormalizedPath())
|
||||
headersBuilder.WriteString(" HTTP/1.1\r\n")
|
||||
hasConnectionHeader := false
|
||||
hasUpgradeHeader := false
|
||||
hasHostHeader := false
|
||||
for key, value := range transportConfiguration.Header {
|
||||
if strings.ToLower(key) == "connection" {
|
||||
hasConnectionHeader = true
|
||||
}
|
||||
if strings.ToLower(key) == "upgrade" {
|
||||
hasUpgradeHeader = true
|
||||
}
|
||||
if strings.ToLower(key) == "host" {
|
||||
hasHostHeader = true
|
||||
}
|
||||
headersBuilder.WriteString(key)
|
||||
headersBuilder.WriteString(": ")
|
||||
headersBuilder.WriteString(value)
|
||||
headersBuilder.WriteString("\r\n")
|
||||
}
|
||||
|
||||
if !hasConnectionHeader {
|
||||
headersBuilder.WriteString("Connection: upgrade\r\n")
|
||||
}
|
||||
|
||||
if !hasUpgradeHeader {
|
||||
headersBuilder.WriteString("Upgrade: websocket\r\n")
|
||||
}
|
||||
|
||||
if !hasHostHeader {
|
||||
headersBuilder.WriteString("Host: ")
|
||||
headersBuilder.WriteString(transportConfiguration.Host)
|
||||
headersBuilder.WriteString("\r\n")
|
||||
}
|
||||
|
||||
headersBuilder.WriteString("\r\n")
|
||||
_, err = conn.Write([]byte(headersBuilder.String()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
connRF := &ConnRF{
|
||||
|
|
|
@ -72,6 +72,65 @@ func Test_listenHTTPUpgradeAndDial(t *testing.T) {
|
|||
common.Must(listen.Close())
|
||||
}
|
||||
|
||||
func Test_listenHTTPUpgradeAndDialWithHeaders(t *testing.T) {
|
||||
listenPort := tcp.PickPort()
|
||||
listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
||||
ProtocolName: "httpupgrade",
|
||||
ProtocolSettings: &Config{
|
||||
Path: "httpupgrade",
|
||||
Header: map[string]string{
|
||||
"User-Agent": "Mozilla",
|
||||
},
|
||||
},
|
||||
}, func(conn stat.Connection) {
|
||||
go func(c stat.Connection) {
|
||||
defer c.Close()
|
||||
|
||||
var b [1024]byte
|
||||
_, err := c.Read(b[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
common.Must2(c.Write([]byte("Response")))
|
||||
}(conn)
|
||||
})
|
||||
common.Must(err)
|
||||
|
||||
ctx := context.Background()
|
||||
streamSettings := &internet.MemoryStreamConfig{
|
||||
ProtocolName: "httpupgrade",
|
||||
ProtocolSettings: &Config{Path: "httpupgrade"},
|
||||
}
|
||||
conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
|
||||
|
||||
common.Must(err)
|
||||
_, err = conn.Write([]byte("Test connection 1"))
|
||||
common.Must(err)
|
||||
|
||||
var b [1024]byte
|
||||
n, err := conn.Read(b[:])
|
||||
common.Must(err)
|
||||
if string(b[:n]) != "Response" {
|
||||
t.Error("response: ", string(b[:n]))
|
||||
}
|
||||
|
||||
common.Must(conn.Close())
|
||||
<-time.After(time.Second * 5)
|
||||
conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
|
||||
common.Must(err)
|
||||
_, err = conn.Write([]byte("Test connection 2"))
|
||||
common.Must(err)
|
||||
n, err = conn.Read(b[:])
|
||||
common.Must(err)
|
||||
if string(b[:n]) != "Response" {
|
||||
t.Error("response: ", string(b[:n]))
|
||||
}
|
||||
common.Must(conn.Close())
|
||||
|
||||
common.Must(listen.Close())
|
||||
}
|
||||
|
||||
func TestDialWithRemoteAddr(t *testing.T) {
|
||||
listenPort := tcp.PickPort()
|
||||
listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
||||
|
@ -150,4 +209,4 @@ func Test_listenHTTPUpgradeAndDial_TLS(t *testing.T) {
|
|||
if !end.Before(start.Add(time.Second * 5)) {
|
||||
t.Error("end: ", end, " start: ", start)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue