From ff209471d832b81c44b92f78a78bfef0b82d05a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 7 Sep 2023 19:26:45 +0800 Subject: [PATCH] Fix QUIC defragger --- common/baderror/baderror.go | 62 --------------------------------- transport/hysteria/wrap.go | 2 +- transport/tuic/client.go | 2 +- transport/tuic/client_packet.go | 4 +-- transport/tuic/packet.go | 22 ++++++++---- transport/tuic/server.go | 4 +-- transport/tuic/server_packet.go | 2 +- transport/v2raygrpc/conn.go | 2 +- transport/v2raygrpclite/conn.go | 2 +- transport/v2rayhttp/conn.go | 2 +- 10 files changed, 26 insertions(+), 78 deletions(-) delete mode 100644 common/baderror/baderror.go diff --git a/common/baderror/baderror.go b/common/baderror/baderror.go deleted file mode 100644 index 952dac8f..00000000 --- a/common/baderror/baderror.go +++ /dev/null @@ -1,62 +0,0 @@ -package baderror - -import ( - "context" - "io" - "net" - "strings" - - E "github.com/sagernet/sing/common/exceptions" -) - -func Contains(err error, msgList ...string) bool { - for _, msg := range msgList { - if strings.Contains(err.Error(), msg) { - return true - } - } - return false -} - -func WrapH2(err error) error { - if err == nil { - return nil - } - err = E.Unwrap(err) - if err == io.ErrUnexpectedEOF { - return io.EOF - } - if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") { - return net.ErrClosed - } - return err -} - -func WrapGRPC(err error) error { - // grpc uses stupid internal error types - if err == nil { - return nil - } - if Contains(err, "EOF") { - return io.EOF - } - if Contains(err, "Canceled") { - return context.Canceled - } - if Contains(err, - "the client connection is closing", - "server closed the stream without sending trailers") { - return net.ErrClosed - } - return err -} - -func WrapQUIC(err error) error { - if err == nil { - return nil - } - if Contains(err, "canceled by local with error code 0") { - return net.ErrClosed - } - return err -} diff --git a/transport/hysteria/wrap.go b/transport/hysteria/wrap.go index 8d2a63c8..e89ac95e 100644 --- a/transport/hysteria/wrap.go +++ b/transport/hysteria/wrap.go @@ -6,8 +6,8 @@ import ( "syscall" "github.com/sagernet/quic-go" - "github.com/sagernet/sing-box/common/baderror" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/baderror" ) type PacketConnWrapper struct { diff --git a/transport/tuic/client.go b/transport/tuic/client.go index f4a00019..841e61c9 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/client.go @@ -10,8 +10,8 @@ import ( "time" "github.com/sagernet/quic-go" - "github.com/sagernet/sing-box/common/baderror" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/baderror" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" diff --git a/transport/tuic/client_packet.go b/transport/tuic/client_packet.go index b4292e94..48da97a5 100644 --- a/transport/tuic/client_packet.go +++ b/transport/tuic/client_packet.go @@ -34,7 +34,7 @@ func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error { } switch data[1] { case CommandPacket: - message := udpMessagePool.Get().(*udpMessage) + message := allocMessage() err := decodeUDPMessage(message, data[2:]) if err != nil { message.release() @@ -82,7 +82,7 @@ func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.Receive return E.New("unknown command ", command) } reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream) - message := udpMessagePool.Get().(*udpMessage) + message := allocMessage() err = readUDPMessage(message, reader) if err != nil { message.release() diff --git a/transport/tuic/packet.go b/transport/tuic/packet.go index 0c7b10db..abc46206 100644 --- a/transport/tuic/packet.go +++ b/transport/tuic/packet.go @@ -27,11 +27,16 @@ var udpMessagePool = sync.Pool{ }, } +func allocMessage() *udpMessage { + message := udpMessagePool.Get().(*udpMessage) + message.referenced = true + return message +} + func releaseMessages(messages []*udpMessage) { for _, message := range messages { if message != nil { - *message = udpMessage{} - udpMessagePool.Put(message) + message.release() } } } @@ -43,9 +48,13 @@ type udpMessage struct { fragmentID uint8 destination M.Socksaddr data *buf.Buffer + referenced bool } func (m *udpMessage) release() { + if !m.referenced { + return + } *m = udpMessage{} udpMessagePool.Put(m) } @@ -83,7 +92,7 @@ func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage { originPacket := message.data.Bytes() udpMTU := maxPacketSize - message.headerSize() for remaining := len(originPacket); remaining > 0; remaining -= udpMTU { - fragment := udpMessagePool.Get().(*udpMessage) + fragment := allocMessage() *fragment = *message if remaining > udpMTU { fragment.data = buf.As(originPacket[:udpMTU]) @@ -214,7 +223,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) c.packetId.Store(0) packetId = 0 } - message := udpMessagePool.Get().(*udpMessage) + message := allocMessage() *message = udpMessage{ sessionID: c.sessionID, packetID: uint16(packetId), @@ -259,7 +268,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { c.packetId.Store(0) packetId = 0 } - message := udpMessagePool.Get().(*udpMessage) + message := allocMessage() *message = udpMessage{ sessionID: c.sessionID, packetID: uint16(packetId), @@ -431,7 +440,7 @@ func (d *udpDefragger) feed(m *udpMessage) *udpMessage { if int(item.count) != len(item.messages) { return nil } - newMessage := udpMessagePool.Get().(*udpMessage) + newMessage := allocMessage() *newMessage = *item.messages[0] var dataLength uint16 for _, message := range item.messages { @@ -446,6 +455,7 @@ func (d *udpDefragger) feed(m *udpMessage) *udpMessage { item.messages = nil return newMessage } + item.messages = nil return nil } diff --git a/transport/tuic/server.go b/transport/tuic/server.go index 4a40b44f..01a2644a 100644 --- a/transport/tuic/server.go +++ b/transport/tuic/server.go @@ -13,9 +13,9 @@ import ( "time" "github.com/sagernet/quic-go" - "github.com/sagernet/sing-box/common/baderror" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/baderror" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" @@ -264,7 +264,7 @@ func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error { return s.connErr case <-s.authDone: } - message := udpMessagePool.Get().(*udpMessage) + message := allocMessage() err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream)) if err != nil { message.release() diff --git a/transport/tuic/server_packet.go b/transport/tuic/server_packet.go index fba6118a..d05c7bf1 100644 --- a/transport/tuic/server_packet.go +++ b/transport/tuic/server_packet.go @@ -35,7 +35,7 @@ func (s *serverSession) handleMessage(data []byte) error { } switch data[1] { case CommandPacket: - message := udpMessagePool.Get().(*udpMessage) + message := allocMessage() err := decodeUDPMessage(message, data[2:]) if err != nil { message.release() diff --git a/transport/v2raygrpc/conn.go b/transport/v2raygrpc/conn.go index 821eac70..0fecbf33 100644 --- a/transport/v2raygrpc/conn.go +++ b/transport/v2raygrpc/conn.go @@ -5,8 +5,8 @@ import ( "os" "time" - "github.com/sagernet/sing-box/common/baderror" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/baderror" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/rw" ) diff --git a/transport/v2raygrpclite/conn.go b/transport/v2raygrpclite/conn.go index d2f36f42..c1f8fd2e 100644 --- a/transport/v2raygrpclite/conn.go +++ b/transport/v2raygrpclite/conn.go @@ -11,8 +11,8 @@ import ( "sync" "time" - "github.com/sagernet/sing-box/common/baderror" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/baderror" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" diff --git a/transport/v2rayhttp/conn.go b/transport/v2rayhttp/conn.go index 8fa4ee84..184bc8ae 100644 --- a/transport/v2rayhttp/conn.go +++ b/transport/v2rayhttp/conn.go @@ -10,8 +10,8 @@ import ( "sync" "time" - "github.com/sagernet/sing-box/common/baderror" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/baderror" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions"