mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-09 18:43:12 +00:00
preserve exact header casing when using httpupgrade (#3427)
* preserve exact header casing when using httpupgrade * fix capitalization of websocket * oops, we dont need net/url either * restore old codepath when there are no headers
This commit is contained in:
parent
be29cc39d7
commit
980236f2b6
|
@ -65,23 +65,69 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings *
|
||||||
requestURL.Scheme = "http"
|
requestURL.Scheme = "http"
|
||||||
}
|
}
|
||||||
|
|
||||||
requestURL.Host = dest.NetAddr()
|
var req *http.Request = nil
|
||||||
requestURL.Path = transportConfiguration.GetNormalizedPath()
|
|
||||||
req := &http.Request{
|
|
||||||
Method: http.MethodGet,
|
|
||||||
URL: &requestURL,
|
|
||||||
Host: transportConfiguration.Host,
|
|
||||||
Header: make(http.Header),
|
|
||||||
}
|
|
||||||
for key, value := range transportConfiguration.Header {
|
|
||||||
req.Header.Add(key, value)
|
|
||||||
}
|
|
||||||
req.Header.Set("Connection", "upgrade")
|
|
||||||
req.Header.Set("Upgrade", "websocket")
|
|
||||||
|
|
||||||
err = req.Write(conn)
|
if len(transportConfiguration.Header) == 0 {
|
||||||
if err != nil {
|
requestURL.Host = dest.NetAddr()
|
||||||
return nil, err
|
requestURL.Path = transportConfiguration.GetNormalizedPath()
|
||||||
|
req = &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: &requestURL,
|
||||||
|
Host: transportConfiguration.Host,
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Connection", "upgrade")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
|
||||||
|
err = req.Write(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var headersBuilder strings.Builder
|
||||||
|
|
||||||
|
headersBuilder.WriteString("GET ")
|
||||||
|
headersBuilder.WriteString(transportConfiguration.GetNormalizedPath())
|
||||||
|
headersBuilder.WriteString(" HTTP/1.1\r\n")
|
||||||
|
hasConnectionHeader := false
|
||||||
|
hasUpgradeHeader := false
|
||||||
|
hasHostHeader := false
|
||||||
|
for key, value := range transportConfiguration.Header {
|
||||||
|
if strings.ToLower(key) == "connection" {
|
||||||
|
hasConnectionHeader = true
|
||||||
|
}
|
||||||
|
if strings.ToLower(key) == "upgrade" {
|
||||||
|
hasUpgradeHeader = true
|
||||||
|
}
|
||||||
|
if strings.ToLower(key) == "host" {
|
||||||
|
hasHostHeader = true
|
||||||
|
}
|
||||||
|
headersBuilder.WriteString(key)
|
||||||
|
headersBuilder.WriteString(": ")
|
||||||
|
headersBuilder.WriteString(value)
|
||||||
|
headersBuilder.WriteString("\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasConnectionHeader {
|
||||||
|
headersBuilder.WriteString("Connection: upgrade\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasUpgradeHeader {
|
||||||
|
headersBuilder.WriteString("Upgrade: websocket\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasHostHeader {
|
||||||
|
headersBuilder.WriteString("Host: ")
|
||||||
|
headersBuilder.WriteString(transportConfiguration.Host)
|
||||||
|
headersBuilder.WriteString("\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
headersBuilder.WriteString("\r\n")
|
||||||
|
_, err = conn.Write([]byte(headersBuilder.String()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
connRF := &ConnRF{
|
connRF := &ConnRF{
|
||||||
|
|
|
@ -72,6 +72,65 @@ func Test_listenHTTPUpgradeAndDial(t *testing.T) {
|
||||||
common.Must(listen.Close())
|
common.Must(listen.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_listenHTTPUpgradeAndDialWithHeaders(t *testing.T) {
|
||||||
|
listenPort := tcp.PickPort()
|
||||||
|
listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
||||||
|
ProtocolName: "httpupgrade",
|
||||||
|
ProtocolSettings: &Config{
|
||||||
|
Path: "httpupgrade",
|
||||||
|
Header: map[string]string{
|
||||||
|
"User-Agent": "Mozilla",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, func(conn stat.Connection) {
|
||||||
|
go func(c stat.Connection) {
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
var b [1024]byte
|
||||||
|
_, err := c.Read(b[:])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
common.Must2(c.Write([]byte("Response")))
|
||||||
|
}(conn)
|
||||||
|
})
|
||||||
|
common.Must(err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
streamSettings := &internet.MemoryStreamConfig{
|
||||||
|
ProtocolName: "httpupgrade",
|
||||||
|
ProtocolSettings: &Config{Path: "httpupgrade"},
|
||||||
|
}
|
||||||
|
conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
|
||||||
|
|
||||||
|
common.Must(err)
|
||||||
|
_, err = conn.Write([]byte("Test connection 1"))
|
||||||
|
common.Must(err)
|
||||||
|
|
||||||
|
var b [1024]byte
|
||||||
|
n, err := conn.Read(b[:])
|
||||||
|
common.Must(err)
|
||||||
|
if string(b[:n]) != "Response" {
|
||||||
|
t.Error("response: ", string(b[:n]))
|
||||||
|
}
|
||||||
|
|
||||||
|
common.Must(conn.Close())
|
||||||
|
<-time.After(time.Second * 5)
|
||||||
|
conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
|
||||||
|
common.Must(err)
|
||||||
|
_, err = conn.Write([]byte("Test connection 2"))
|
||||||
|
common.Must(err)
|
||||||
|
n, err = conn.Read(b[:])
|
||||||
|
common.Must(err)
|
||||||
|
if string(b[:n]) != "Response" {
|
||||||
|
t.Error("response: ", string(b[:n]))
|
||||||
|
}
|
||||||
|
common.Must(conn.Close())
|
||||||
|
|
||||||
|
common.Must(listen.Close())
|
||||||
|
}
|
||||||
|
|
||||||
func TestDialWithRemoteAddr(t *testing.T) {
|
func TestDialWithRemoteAddr(t *testing.T) {
|
||||||
listenPort := tcp.PickPort()
|
listenPort := tcp.PickPort()
|
||||||
listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
||||||
|
@ -150,4 +209,4 @@ func Test_listenHTTPUpgradeAndDial_TLS(t *testing.T) {
|
||||||
if !end.Before(start.Add(time.Second * 5)) {
|
if !end.Before(start.Add(time.Second * 5)) {
|
||||||
t.Error("end: ", end, " start: ", start)
|
t.Error("end: ", end, " start: ", start)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue