From 01f6e70bc55ee59cec896663af02e4ff0d2cd59e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 1 Dec 2023 12:24:29 +0800 Subject: [PATCH] Fix deadline usage --- common/dialer/detour.go | 10 +--------- go.mod | 2 +- go.sum | 4 ++-- transport/v2raywebsocket/client.go | 28 ++++++++++++++++++++++++---- transport/v2raywebsocket/conn.go | 11 ++--------- transport/v2raywebsocket/server.go | 4 ++-- 6 files changed, 32 insertions(+), 27 deletions(-) diff --git a/common/dialer/detour.go b/common/dialer/detour.go index ff484da2..81600913 100644 --- a/common/dialer/detour.go +++ b/common/dialer/detour.go @@ -6,7 +6,6 @@ import ( "sync" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing/common/bufio/deadline" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -45,14 +44,7 @@ func (d *DetourDialer) DialContext(ctx context.Context, network string, destinat if err != nil { return nil, err } - conn, err := dialer.DialContext(ctx, network, destination) - if err != nil { - return nil, err - } - if deadline.NeedAdditionalReadDeadline(conn) { - conn = deadline.NewConn(conn) - } - return conn, nil + return dialer.DialContext(ctx, network, destination) } func (d *DetourDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { diff --git a/go.mod b/go.mod index 7efe241e..0d238773 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930 github.com/sagernet/quic-go v0.40.0 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.2.18-0.20231124125253-2dcabf4bfcbc + github.com/sagernet/sing v0.2.18-0.20231201054122-bca74039ead5 github.com/sagernet/sing-dns v0.1.11 github.com/sagernet/sing-mux v0.1.5-0.20231109075101-6b086ed6bb07 github.com/sagernet/sing-quic v0.1.5-0.20231123150216-00957d136203 diff --git a/go.sum b/go.sum index 4db37c0b..fb2fe317 100644 --- a/go.sum +++ b/go.sum @@ -110,8 +110,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= -github.com/sagernet/sing v0.2.18-0.20231124125253-2dcabf4bfcbc h1:vESVuxHgbd2EzHxd+TYTpNACIEGBOhp5n3KG7bgbcws= -github.com/sagernet/sing v0.2.18-0.20231124125253-2dcabf4bfcbc/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= +github.com/sagernet/sing v0.2.18-0.20231201054122-bca74039ead5 h1:luykfsWNqFh9sdLXlkCQtkuzLUPRd3BMsdQJt0REB1g= +github.com/sagernet/sing v0.2.18-0.20231201054122-bca74039ead5/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= github.com/sagernet/sing-dns v0.1.11 h1:PPrMCVVrAeR3f5X23I+cmvacXJ+kzuyAsBiWyUKhGSE= github.com/sagernet/sing-dns v0.1.11/go.mod h1:zJ/YjnYB61SYE+ubMcMqVdpaSvsyQ2iShQGO3vuLvvE= github.com/sagernet/sing-mux v0.1.5-0.20231109075101-6b086ed6bb07 h1:ncKb5tVOsCQgCsv6UpsA0jinbNb5OQ5GMPJlyQP3EHM= diff --git a/transport/v2raywebsocket/client.go b/transport/v2raywebsocket/client.go index 54c4df0b..7fda40cc 100644 --- a/transport/v2raywebsocket/client.go +++ b/transport/v2raywebsocket/client.go @@ -12,6 +12,9 @@ import ( "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/bufio/deadline" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -87,18 +90,35 @@ func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers h return nil, err } } - conn.SetDeadline(time.Now().Add(C.TCPTimeout)) + var deadlineConn net.Conn + if deadline.NeedAdditionalReadDeadline(conn) { + deadlineConn = deadline.NewConn(conn) + } else { + deadlineConn = conn + } + err = deadlineConn.SetDeadline(time.Now().Add(C.TCPTimeout)) + if err != nil { + return nil, E.Cause(err, "set read deadline") + } var protocols []string if protocolHeader := headers.Get("Sec-WebSocket-Protocol"); protocolHeader != "" { protocols = []string{protocolHeader} headers.Del("Sec-WebSocket-Protocol") } - reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(conn, requestURL) - conn.SetDeadline(time.Time{}) + reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(deadlineConn, requestURL) + deadlineConn.SetDeadline(time.Time{}) if err != nil { return nil, err } - return NewConn(conn, reader, nil, ws.StateClientSide), nil + if reader != nil { + buffer := buf.NewSize(reader.Buffered()) + _, err = buffer.ReadFullFrom(reader, buffer.Len()) + if err != nil { + return nil, err + } + conn = bufio.NewCachedConn(conn, buffer) + } + return NewConn(conn, nil, ws.StateClientSide), nil } func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { diff --git a/transport/v2raywebsocket/conn.go b/transport/v2raywebsocket/conn.go index 5b51e5c5..8f06b118 100644 --- a/transport/v2raywebsocket/conn.go +++ b/transport/v2raywebsocket/conn.go @@ -1,7 +1,6 @@ package v2raywebsocket import ( - "bufio" "context" "encoding/base64" "io" @@ -28,19 +27,13 @@ type WebsocketConn struct { remoteAddr net.Addr } -func NewConn(conn net.Conn, br *bufio.Reader, remoteAddr net.Addr, state ws.State) *WebsocketConn { +func NewConn(conn net.Conn, remoteAddr net.Addr, state ws.State) *WebsocketConn { controlHandler := wsutil.ControlFrameHandler(conn, state) - var reader io.Reader - if br != nil && br.Buffered() > 0 { - reader = br - } else { - reader = conn - } return &WebsocketConn{ Conn: conn, state: state, reader: &wsutil.Reader{ - Source: reader, + Source: conn, State: state, SkipHeaderCheck: !debug.Enabled, OnIntermediate: controlHandler, diff --git a/transport/v2raywebsocket/server.go b/transport/v2raywebsocket/server.go index ae6e15f3..db078675 100644 --- a/transport/v2raywebsocket/server.go +++ b/transport/v2raywebsocket/server.go @@ -88,14 +88,14 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { s.invalidRequest(writer, request, http.StatusBadRequest, E.Cause(err, "decode early data")) return } - wsConn, reader, _, err := ws.UpgradeHTTP(request, writer) + wsConn, _, _, err := ws.UpgradeHTTP(request, writer) if err != nil { s.invalidRequest(writer, request, 0, E.Cause(err, "upgrade websocket connection")) return } var metadata M.Metadata metadata.Source = sHttp.SourceAddress(request) - conn = NewConn(wsConn, reader.Reader, metadata.Source.TCPAddr(), ws.StateServerSide) + conn = NewConn(wsConn, metadata.Source.TCPAddr(), ws.StateServerSide) if len(earlyData) > 0 { conn = bufio.NewCachedConn(conn, buf.As(earlyData)) }