preserve exact header casing when using httpupgrade (#3427)

* preserve exact header casing when using httpupgrade

* fix capitalization of websocket

* oops, we dont need net/url either

* restore old codepath when there are no headers
This commit is contained in:
mmmray 2024-06-06 02:43:44 +02:00 committed by GitHub
parent be29cc39d7
commit 980236f2b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 122 additions and 17 deletions

View file

@ -65,23 +65,69 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings *
requestURL.Scheme = "http"
}
requestURL.Host = dest.NetAddr()
requestURL.Path = transportConfiguration.GetNormalizedPath()
req := &http.Request{
Method: http.MethodGet,
URL: &requestURL,
Host: transportConfiguration.Host,
Header: make(http.Header),
}
for key, value := range transportConfiguration.Header {
req.Header.Add(key, value)
}
req.Header.Set("Connection", "upgrade")
req.Header.Set("Upgrade", "websocket")
var req *http.Request = nil
err = req.Write(conn)
if err != nil {
return nil, err
if len(transportConfiguration.Header) == 0 {
requestURL.Host = dest.NetAddr()
requestURL.Path = transportConfiguration.GetNormalizedPath()
req = &http.Request{
Method: http.MethodGet,
URL: &requestURL,
Host: transportConfiguration.Host,
Header: make(http.Header),
}
req.Header.Set("Connection", "upgrade")
req.Header.Set("Upgrade", "websocket")
err = req.Write(conn)
if err != nil {
return nil, err
}
} else {
var headersBuilder strings.Builder
headersBuilder.WriteString("GET ")
headersBuilder.WriteString(transportConfiguration.GetNormalizedPath())
headersBuilder.WriteString(" HTTP/1.1\r\n")
hasConnectionHeader := false
hasUpgradeHeader := false
hasHostHeader := false
for key, value := range transportConfiguration.Header {
if strings.ToLower(key) == "connection" {
hasConnectionHeader = true
}
if strings.ToLower(key) == "upgrade" {
hasUpgradeHeader = true
}
if strings.ToLower(key) == "host" {
hasHostHeader = true
}
headersBuilder.WriteString(key)
headersBuilder.WriteString(": ")
headersBuilder.WriteString(value)
headersBuilder.WriteString("\r\n")
}
if !hasConnectionHeader {
headersBuilder.WriteString("Connection: upgrade\r\n")
}
if !hasUpgradeHeader {
headersBuilder.WriteString("Upgrade: websocket\r\n")
}
if !hasHostHeader {
headersBuilder.WriteString("Host: ")
headersBuilder.WriteString(transportConfiguration.Host)
headersBuilder.WriteString("\r\n")
}
headersBuilder.WriteString("\r\n")
_, err = conn.Write([]byte(headersBuilder.String()))
if err != nil {
return nil, err
}
}
connRF := &ConnRF{

View file

@ -72,6 +72,65 @@ func Test_listenHTTPUpgradeAndDial(t *testing.T) {
common.Must(listen.Close())
}
func Test_listenHTTPUpgradeAndDialWithHeaders(t *testing.T) {
listenPort := tcp.PickPort()
listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
ProtocolName: "httpupgrade",
ProtocolSettings: &Config{
Path: "httpupgrade",
Header: map[string]string{
"User-Agent": "Mozilla",
},
},
}, func(conn stat.Connection) {
go func(c stat.Connection) {
defer c.Close()
var b [1024]byte
_, err := c.Read(b[:])
if err != nil {
return
}
common.Must2(c.Write([]byte("Response")))
}(conn)
})
common.Must(err)
ctx := context.Background()
streamSettings := &internet.MemoryStreamConfig{
ProtocolName: "httpupgrade",
ProtocolSettings: &Config{Path: "httpupgrade"},
}
conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
common.Must(err)
_, err = conn.Write([]byte("Test connection 1"))
common.Must(err)
var b [1024]byte
n, err := conn.Read(b[:])
common.Must(err)
if string(b[:n]) != "Response" {
t.Error("response: ", string(b[:n]))
}
common.Must(conn.Close())
<-time.After(time.Second * 5)
conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
common.Must(err)
_, err = conn.Write([]byte("Test connection 2"))
common.Must(err)
n, err = conn.Read(b[:])
common.Must(err)
if string(b[:n]) != "Response" {
t.Error("response: ", string(b[:n]))
}
common.Must(conn.Close())
common.Must(listen.Close())
}
func TestDialWithRemoteAddr(t *testing.T) {
listenPort := tcp.PickPort()
listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
@ -150,4 +209,4 @@ func Test_listenHTTPUpgradeAndDial_TLS(t *testing.T) {
if !end.Before(start.Add(time.Second * 5)) {
t.Error("end: ", end, " start: ", start)
}
}
}