From 980236f2b63af4ccbdb82a11baa601cdea568444 Mon Sep 17 00:00:00 2001 From: mmmray <142015632+mmmray@users.noreply.github.com> Date: Thu, 6 Jun 2024 02:43:44 +0200 Subject: [PATCH] preserve exact header casing when using httpupgrade (#3427) * preserve exact header casing when using httpupgrade * fix capitalization of websocket * oops, we dont need net/url either * restore old codepath when there are no headers --- transport/internet/httpupgrade/dialer.go | 78 +++++++++++++++---- .../internet/httpupgrade/httpupgrade_test.go | 61 ++++++++++++++- 2 files changed, 122 insertions(+), 17 deletions(-) diff --git a/transport/internet/httpupgrade/dialer.go b/transport/internet/httpupgrade/dialer.go index 9f909503..b70af331 100644 --- a/transport/internet/httpupgrade/dialer.go +++ b/transport/internet/httpupgrade/dialer.go @@ -65,23 +65,69 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings * requestURL.Scheme = "http" } - requestURL.Host = dest.NetAddr() - requestURL.Path = transportConfiguration.GetNormalizedPath() - req := &http.Request{ - Method: http.MethodGet, - URL: &requestURL, - Host: transportConfiguration.Host, - Header: make(http.Header), - } - for key, value := range transportConfiguration.Header { - req.Header.Add(key, value) - } - req.Header.Set("Connection", "upgrade") - req.Header.Set("Upgrade", "websocket") + var req *http.Request = nil - err = req.Write(conn) - if err != nil { - return nil, err + if len(transportConfiguration.Header) == 0 { + requestURL.Host = dest.NetAddr() + requestURL.Path = transportConfiguration.GetNormalizedPath() + req = &http.Request{ + Method: http.MethodGet, + URL: &requestURL, + Host: transportConfiguration.Host, + Header: make(http.Header), + } + + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "websocket") + + err = req.Write(conn) + if err != nil { + return nil, err + } + } else { + var headersBuilder strings.Builder + + headersBuilder.WriteString("GET ") + headersBuilder.WriteString(transportConfiguration.GetNormalizedPath()) + headersBuilder.WriteString(" HTTP/1.1\r\n") + hasConnectionHeader := false + hasUpgradeHeader := false + hasHostHeader := false + for key, value := range transportConfiguration.Header { + if strings.ToLower(key) == "connection" { + hasConnectionHeader = true + } + if strings.ToLower(key) == "upgrade" { + hasUpgradeHeader = true + } + if strings.ToLower(key) == "host" { + hasHostHeader = true + } + headersBuilder.WriteString(key) + headersBuilder.WriteString(": ") + headersBuilder.WriteString(value) + headersBuilder.WriteString("\r\n") + } + + if !hasConnectionHeader { + headersBuilder.WriteString("Connection: upgrade\r\n") + } + + if !hasUpgradeHeader { + headersBuilder.WriteString("Upgrade: websocket\r\n") + } + + if !hasHostHeader { + headersBuilder.WriteString("Host: ") + headersBuilder.WriteString(transportConfiguration.Host) + headersBuilder.WriteString("\r\n") + } + + headersBuilder.WriteString("\r\n") + _, err = conn.Write([]byte(headersBuilder.String())) + if err != nil { + return nil, err + } } connRF := &ConnRF{ diff --git a/transport/internet/httpupgrade/httpupgrade_test.go b/transport/internet/httpupgrade/httpupgrade_test.go index d8d3ad84..991c37cb 100644 --- a/transport/internet/httpupgrade/httpupgrade_test.go +++ b/transport/internet/httpupgrade/httpupgrade_test.go @@ -72,6 +72,65 @@ func Test_listenHTTPUpgradeAndDial(t *testing.T) { common.Must(listen.Close()) } +func Test_listenHTTPUpgradeAndDialWithHeaders(t *testing.T) { + listenPort := tcp.PickPort() + listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ + ProtocolName: "httpupgrade", + ProtocolSettings: &Config{ + Path: "httpupgrade", + Header: map[string]string{ + "User-Agent": "Mozilla", + }, + }, + }, func(conn stat.Connection) { + go func(c stat.Connection) { + defer c.Close() + + var b [1024]byte + _, err := c.Read(b[:]) + if err != nil { + return + } + + common.Must2(c.Write([]byte("Response"))) + }(conn) + }) + common.Must(err) + + ctx := context.Background() + streamSettings := &internet.MemoryStreamConfig{ + ProtocolName: "httpupgrade", + ProtocolSettings: &Config{Path: "httpupgrade"}, + } + conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) + + common.Must(err) + _, err = conn.Write([]byte("Test connection 1")) + common.Must(err) + + var b [1024]byte + n, err := conn.Read(b[:]) + common.Must(err) + if string(b[:n]) != "Response" { + t.Error("response: ", string(b[:n])) + } + + common.Must(conn.Close()) + <-time.After(time.Second * 5) + conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) + common.Must(err) + _, err = conn.Write([]byte("Test connection 2")) + common.Must(err) + n, err = conn.Read(b[:]) + common.Must(err) + if string(b[:n]) != "Response" { + t.Error("response: ", string(b[:n])) + } + common.Must(conn.Close()) + + common.Must(listen.Close()) +} + func TestDialWithRemoteAddr(t *testing.T) { listenPort := tcp.PickPort() listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ @@ -150,4 +209,4 @@ func Test_listenHTTPUpgradeAndDial_TLS(t *testing.T) { if !end.Before(start.Add(time.Second * 5)) { t.Error("end: ", end, " start: ", start) } -} \ No newline at end of file +}