diff --git a/transport/v2raygrpclite/server.go b/transport/v2raygrpclite/server.go index 1a298b7d..b45c690b 100644 --- a/transport/v2raygrpclite/server.go +++ b/transport/v2raygrpclite/server.go @@ -12,6 +12,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/v2rayhttp" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -87,8 +88,9 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { writer.WriteHeader(http.StatusOK) var metadata M.Metadata metadata.Source = sHttp.SourceAddress(request) - conn := newGunConn(request.Body, writer, writer.(http.Flusher)) + conn := v2rayhttp.NewHTTP2Wrapper(newGunConn(request.Body, writer, writer.(http.Flusher))) s.handler.NewConnection(request.Context(), conn, metadata) + conn.CloseWrapper() } func (s *Server) badRequest(request *http.Request, err error) { diff --git a/transport/v2rayhttp/conn.go b/transport/v2rayhttp/conn.go index 49cc1080..89d8baa8 100644 --- a/transport/v2rayhttp/conn.go +++ b/transport/v2rayhttp/conn.go @@ -5,10 +5,14 @@ import ( "net" "net/http" "os" + "sync" "time" "github.com/sagernet/sing-box/common/baderror" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + N "github.com/sagernet/sing/common/network" ) type HTTPConn struct { @@ -105,3 +109,43 @@ func (c *ServerHTTPConn) Write(b []byte) (n int, err error) { } return } + +type HTTP2ConnWrapper struct { + N.ExtendedConn + access sync.Mutex + closed bool +} + +func NewHTTP2Wrapper(conn net.Conn) *HTTP2ConnWrapper { + return &HTTP2ConnWrapper{ + ExtendedConn: bufio.NewExtendedConn(conn), + } +} + +func (w *HTTP2ConnWrapper) Write(p []byte) (n int, err error) { + w.access.Lock() + defer w.access.Unlock() + if w.closed { + return 0, net.ErrClosed + } + return w.ExtendedConn.Write(p) +} + +func (w *HTTP2ConnWrapper) WriteBuffer(buffer *buf.Buffer) error { + w.access.Lock() + defer w.access.Unlock() + if w.closed { + return net.ErrClosed + } + return w.ExtendedConn.WriteBuffer(buffer) +} + +func (w *HTTP2ConnWrapper) CloseWrapper() { + w.access.Lock() + defer w.access.Unlock() + w.closed = true +} + +func (w *HTTP2ConnWrapper) Upstream() any { + return w.ExtendedConn +} diff --git a/transport/v2rayhttp/server.go b/transport/v2rayhttp/server.go index 369eb5c3..1777ca80 100644 --- a/transport/v2rayhttp/server.go +++ b/transport/v2rayhttp/server.go @@ -120,11 +120,12 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { } s.handler.NewConnection(request.Context(), conn, metadata) } else { - conn := &ServerHTTPConn{ + conn := NewHTTP2Wrapper(&ServerHTTPConn{ newHTTPConn(request.Body, writer), writer.(http.Flusher), - } + }) s.handler.NewConnection(request.Context(), conn, metadata) + conn.CloseWrapper() } }