diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index bcb33937..805a674e 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -146,6 +146,7 @@ func (c *TCPConfig) Build() (proto.Message, error) { } type WebSocketConfig struct { + Host string `json:"host"` Path string `json:"path"` Headers map[string]string `json:"headers"` AcceptProxyProtocol bool `json:"acceptProxyProtocol"` @@ -154,10 +155,6 @@ type WebSocketConfig struct { // Build implements Buildable. func (c *WebSocketConfig) Build() (proto.Message, error) { path := c.Path - header := make(map[string]string); - for key, value := range c.Headers { - header[key] = value; - } var ed uint32 if u, err := url.Parse(path); err == nil { if q := u.Query(); q.Get("ed") != "" { @@ -168,9 +165,18 @@ func (c *WebSocketConfig) Build() (proto.Message, error) { path = u.String() } } + // If http host is not set in the Host field, but in headers field, we add it to Host Field here. + // If we don't do that, http host will be overwritten as address. + // Host priority: Host field > headers field > address. + if c.Host == "" && c.Headers["host"] != "" { + c.Host = c.Headers["host"] + } else if c.Host == "" && c.Headers["Host"] != "" { + c.Host = c.Headers["Host"] + } config := &websocket.Config{ Path: path, - Header: header, + Host: c.Host, + Header: c.Headers, AcceptProxyProtocol: c.AcceptProxyProtocol, Ed: ed, } @@ -178,8 +184,8 @@ func (c *WebSocketConfig) Build() (proto.Message, error) { } type HttpUpgradeConfig struct { - Path string `json:"path"` Host string `json:"host"` + Path string `json:"path"` Headers map[string]string `json:"headers"` AcceptProxyProtocol bool `json:"acceptProxyProtocol"` } diff --git a/transport/internet/websocket/config.go b/transport/internet/websocket/config.go index 2e3cfea5..8cb4a855 100644 --- a/transport/internet/websocket/config.go +++ b/transport/internet/websocket/config.go @@ -25,6 +25,7 @@ func (c *Config) GetRequestHeader() http.Header { for k, v := range c.Header { header.Add(k, v) } + header.Set("Host", c.Host) return header } diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index b3fe9ff4..8a860c30 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -21,6 +21,7 @@ import ( ) type requestHandler struct { + host string path string ln *Listener } @@ -37,6 +38,10 @@ var upgrader = &websocket.Upgrader{ } func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if len(h.host) > 0 && request.Host != h.host { + writer.WriteHeader(http.StatusNotFound) + return + } if request.URL.Path != h.path { writer.WriteHeader(http.StatusNotFound) return @@ -125,6 +130,7 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet l.server = http.Server{ Handler: &requestHandler{ + host: wsSettings.Host, path: wsSettings.GetNormalizedPath(), ln: l, },