diff --git a/transport/v2rayhttp/client.go b/transport/v2rayhttp/client.go index b0acde33..56616bd7 100644 --- a/transport/v2rayhttp/client.go +++ b/transport/v2rayhttp/client.go @@ -74,13 +74,13 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { if !c.http2 { - return c.dialHTTP() + return c.dialHTTP(ctx) } else { - return c.dialHTTP2() + return c.dialHTTP2(ctx) } } -func (c *Client) dialHTTP() (net.Conn, error) { +func (c *Client) dialHTTP(ctx context.Context) (net.Conn, error) { conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr) if err != nil { return nil, err @@ -92,6 +92,7 @@ func (c *Client) dialHTTP() (net.Conn, error) { Proto: "HTTP/1.1", Header: c.headers.Clone(), } + request = request.WithContext(ctx) switch hostLen := len(c.host); hostLen { case 0: case 1: @@ -114,7 +115,7 @@ func (c *Client) dialHTTP() (net.Conn, error) { return conn, nil } -func (c *Client) dialHTTP2() (net.Conn, error) { +func (c *Client) dialHTTP2(ctx context.Context) (net.Conn, error) { pipeInReader, pipeInWriter := io.Pipe() request := &http.Request{ Method: c.method, @@ -124,6 +125,7 @@ func (c *Client) dialHTTP2() (net.Conn, error) { Proto: "HTTP/2", Header: c.headers.Clone(), } + request = request.WithContext(ctx) switch hostLen := len(c.host); hostLen { case 0: case 1: diff --git a/transport/v2rayhttp/server.go b/transport/v2rayhttp/server.go index 92d2d3e6..f82f92b6 100644 --- a/transport/v2rayhttp/server.go +++ b/transport/v2rayhttp/server.go @@ -90,9 +90,8 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { } writer.WriteHeader(http.StatusOK) - if f, ok := writer.(http.Flusher); ok { - f.Flush() - } + writer.(http.Flusher).Flush() + var metadata M.Metadata metadata.Source = sHttp.SourceAddress(request) if h, ok := writer.(http.Hijacker); ok {