Fix QUIC defragger

This commit is contained in:
世界 2023-09-07 19:26:45 +08:00
parent 806f7d0a2b
commit ff209471d8
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
10 changed files with 26 additions and 78 deletions

View file

@ -1,62 +0,0 @@
package baderror
import (
"context"
"io"
"net"
"strings"
E "github.com/sagernet/sing/common/exceptions"
)
func Contains(err error, msgList ...string) bool {
for _, msg := range msgList {
if strings.Contains(err.Error(), msg) {
return true
}
}
return false
}
func WrapH2(err error) error {
if err == nil {
return nil
}
err = E.Unwrap(err)
if err == io.ErrUnexpectedEOF {
return io.EOF
}
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {
return net.ErrClosed
}
return err
}
func WrapGRPC(err error) error {
// grpc uses stupid internal error types
if err == nil {
return nil
}
if Contains(err, "EOF") {
return io.EOF
}
if Contains(err, "Canceled") {
return context.Canceled
}
if Contains(err,
"the client connection is closing",
"server closed the stream without sending trailers") {
return net.ErrClosed
}
return err
}
func WrapQUIC(err error) error {
if err == nil {
return nil
}
if Contains(err, "canceled by local with error code 0") {
return net.ErrClosed
}
return err
}

View file

@ -6,8 +6,8 @@ import (
"syscall"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/common/baderror"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/baderror"
)
type PacketConnWrapper struct {

View file

@ -10,8 +10,8 @@ import (
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/common/baderror"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/baderror"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"

View file

@ -34,7 +34,7 @@ func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
}
switch data[1] {
case CommandPacket:
message := udpMessagePool.Get().(*udpMessage)
message := allocMessage()
err := decodeUDPMessage(message, data[2:])
if err != nil {
message.release()
@ -82,7 +82,7 @@ func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.Receive
return E.New("unknown command ", command)
}
reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream)
message := udpMessagePool.Get().(*udpMessage)
message := allocMessage()
err = readUDPMessage(message, reader)
if err != nil {
message.release()

View file

@ -27,11 +27,16 @@ var udpMessagePool = sync.Pool{
},
}
func allocMessage() *udpMessage {
message := udpMessagePool.Get().(*udpMessage)
message.referenced = true
return message
}
func releaseMessages(messages []*udpMessage) {
for _, message := range messages {
if message != nil {
*message = udpMessage{}
udpMessagePool.Put(message)
message.release()
}
}
}
@ -43,9 +48,13 @@ type udpMessage struct {
fragmentID uint8
destination M.Socksaddr
data *buf.Buffer
referenced bool
}
func (m *udpMessage) release() {
if !m.referenced {
return
}
*m = udpMessage{}
udpMessagePool.Put(m)
}
@ -83,7 +92,7 @@ func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
originPacket := message.data.Bytes()
udpMTU := maxPacketSize - message.headerSize()
for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
fragment := udpMessagePool.Get().(*udpMessage)
fragment := allocMessage()
*fragment = *message
if remaining > udpMTU {
fragment.data = buf.As(originPacket[:udpMTU])
@ -214,7 +223,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
c.packetId.Store(0)
packetId = 0
}
message := udpMessagePool.Get().(*udpMessage)
message := allocMessage()
*message = udpMessage{
sessionID: c.sessionID,
packetID: uint16(packetId),
@ -259,7 +268,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
c.packetId.Store(0)
packetId = 0
}
message := udpMessagePool.Get().(*udpMessage)
message := allocMessage()
*message = udpMessage{
sessionID: c.sessionID,
packetID: uint16(packetId),
@ -431,7 +440,7 @@ func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
if int(item.count) != len(item.messages) {
return nil
}
newMessage := udpMessagePool.Get().(*udpMessage)
newMessage := allocMessage()
*newMessage = *item.messages[0]
var dataLength uint16
for _, message := range item.messages {
@ -446,6 +455,7 @@ func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
item.messages = nil
return newMessage
}
item.messages = nil
return nil
}

View file

@ -13,9 +13,9 @@ import (
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/common/baderror"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/baderror"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
@ -264,7 +264,7 @@ func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
return s.connErr
case <-s.authDone:
}
message := udpMessagePool.Get().(*udpMessage)
message := allocMessage()
err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream))
if err != nil {
message.release()

View file

@ -35,7 +35,7 @@ func (s *serverSession) handleMessage(data []byte) error {
}
switch data[1] {
case CommandPacket:
message := udpMessagePool.Get().(*udpMessage)
message := allocMessage()
err := decodeUDPMessage(message, data[2:])
if err != nil {
message.release()

View file

@ -5,8 +5,8 @@ import (
"os"
"time"
"github.com/sagernet/sing-box/common/baderror"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/baderror"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
)

View file

@ -11,8 +11,8 @@ import (
"sync"
"time"
"github.com/sagernet/sing-box/common/baderror"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/baderror"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"

View file

@ -10,8 +10,8 @@ import (
"sync"
"time"
"github.com/sagernet/sing-box/common/baderror"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/baderror"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"