diff --git a/constant/proxy.go b/constant/proxy.go index 7fc12d36..2b9d8945 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -21,6 +21,7 @@ const ( TypeShadowTLS = "shadowtls" TypeShadowsocksR = "shadowsocksr" TypeVLESS = "vless" + TypeTUIC = "tuic" ) const ( @@ -62,6 +63,8 @@ func ProxyDisplayName(proxyType string) string { return "ShadowsocksR" case TypeVLESS: return "VLESS" + case TypeTUIC: + return "TUIC" case TypeSelector: return "Selector" case TypeURLTest: diff --git a/inbound/builder.go b/inbound/builder.go index 5243e6b8..4cd466af 100644 --- a/inbound/builder.go +++ b/inbound/builder.go @@ -44,6 +44,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o return NewShadowTLS(ctx, router, logger, options.Tag, options.ShadowTLSOptions) case C.TypeVLESS: return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions) + case C.TypeTUIC: + return NewTUIC(ctx, router, logger, options.Tag, options.TUICOptions) default: return nil, E.New("unknown inbound type: ", options.Type) } diff --git a/inbound/default.go b/inbound/default.go index 28f8cee2..9ddfc915 100644 --- a/inbound/default.go +++ b/inbound/default.go @@ -153,6 +153,17 @@ func (a *myInboundAdapter) createMetadata(conn net.Conn, metadata adapter.Inboun return metadata } +func (a *myInboundAdapter) createPacketMetadata(conn N.PacketConn, metadata adapter.InboundContext) adapter.InboundContext { + metadata.Inbound = a.tag + metadata.InboundType = a.protocol + metadata.InboundDetour = a.listenOptions.Detour + metadata.InboundOptions = a.listenOptions.InboundOptions + if !metadata.Destination.IsValid() { + metadata.Destination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap() + } + return metadata +} + func (a *myInboundAdapter) newError(err error) { a.logger.Error(err) } diff --git a/inbound/tuic.go b/inbound/tuic.go new file mode 100644 index 00000000..f8ce24a0 --- /dev/null +++ b/inbound/tuic.go @@ -0,0 +1,114 @@ +//go:build with_quic + +package inbound + +import 
( + "context" + "net" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/tuic" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" + + "github.com/gofrs/uuid/v5" +) + +var _ adapter.Inbound = (*TUIC)(nil) + +type TUIC struct { + myInboundAdapter + server *tuic.Server + tlsConfig tls.ServerConfig +} + +func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (*TUIC, error) { + options.UDPFragmentDefault = true + if options.TLS == nil || !options.TLS.Enabled { + return nil, C.ErrTLSRequired + } + tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + rawConfig, err := tlsConfig.Config() + if err != nil { + return nil, err + } + var users []tuic.User + for index, user := range options.Users { + if user.UUID == "" { + return nil, E.New("missing uuid for user ", index) + } + userUUID, err := uuid.FromString(user.UUID) + if err != nil { + return nil, E.Cause(err, "invalid uuid for user ", index) + } + users = append(users, tuic.User{Name: user.Name, UUID: userUUID, Password: user.Password}) + } + inbound := &TUIC{ + myInboundAdapter: myInboundAdapter{ + protocol: C.TypeTUIC, + network: []string{N.NetworkUDP}, + ctx: ctx, + router: router, + logger: logger, + tag: tag, + listenOptions: options.ListenOptions, + }, + } + server, err := tuic.NewServer(tuic.ServerOptions{ + Context: ctx, + Logger: logger, + TLSConfig: rawConfig, + Users: users, + CongestionControl: options.CongestionControl, + AuthTimeout: time.Duration(options.AuthTimeout), + ZeroRTTHandshake: options.ZeroRTTHandshake, + Heartbeat: 
time.Duration(options.Heartbeat), + Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil), + }) + if err != nil { + return nil, err + } + inbound.server = server + return inbound, nil +} + +func (h *TUIC) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + ctx = log.ContextWithNewID(ctx) + h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) + metadata = h.createMetadata(conn, metadata) + metadata.User, _ = auth.UserFromContext[string](ctx) + return h.router.RouteConnection(ctx, conn, metadata) +} + +func (h *TUIC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + ctx = log.ContextWithNewID(ctx) + metadata = h.createPacketMetadata(conn, metadata) + metadata.User, _ = auth.UserFromContext[string](ctx) + h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) + return h.router.RoutePacketConnection(ctx, conn, metadata) +} + +func (h *TUIC) Start() error { + packetConn, err := h.myInboundAdapter.ListenUDP() + if err != nil { + return err + } + return h.server.Start(packetConn) +} + +func (h *TUIC) Close() error { + return common.Close( + &h.myInboundAdapter, + common.PtrOrNil(h.server), + ) +} diff --git a/inbound/tuic_stub.go b/inbound/tuic_stub.go new file mode 100644 index 00000000..bfd402ab --- /dev/null +++ b/inbound/tuic_stub.go @@ -0,0 +1,16 @@ +//go:build !with_quic + +package inbound + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" +) + +func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (adapter.Inbound, error) { + return nil, C.ErrQUICNotIncluded +} diff --git a/option/inbound.go b/option/inbound.go index ef56be90..b09b3a65 100644 --- 
a/option/inbound.go +++ b/option/inbound.go @@ -23,6 +23,7 @@ type _Inbound struct { HysteriaOptions HysteriaInboundOptions `json:"-"` ShadowTLSOptions ShadowTLSInboundOptions `json:"-"` VLESSOptions VLESSInboundOptions `json:"-"` + TUICOptions TUICInboundOptions `json:"-"` } type Inbound _Inbound @@ -58,6 +59,8 @@ func (h Inbound) MarshalJSON() ([]byte, error) { v = h.ShadowTLSOptions case C.TypeVLESS: v = h.VLESSOptions + case C.TypeTUIC: + v = h.TUICOptions default: return nil, E.New("unknown inbound type: ", h.Type) } @@ -99,6 +102,8 @@ func (h *Inbound) UnmarshalJSON(bytes []byte) error { v = &h.ShadowTLSOptions case C.TypeVLESS: v = &h.VLESSOptions + case C.TypeTUIC: + v = &h.TUICOptions default: return E.New("unknown inbound type: ", h.Type) } diff --git a/option/outbound.go b/option/outbound.go index 5b8eb936..ab7aa0eb 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -23,6 +23,7 @@ type _Outbound struct { ShadowTLSOptions ShadowTLSOutboundOptions `json:"-"` ShadowsocksROptions ShadowsocksROutboundOptions `json:"-"` VLESSOptions VLESSOutboundOptions `json:"-"` + TUICOptions TUICOutboundOptions `json:"-"` SelectorOptions SelectorOutboundOptions `json:"-"` URLTestOptions URLTestOutboundOptions `json:"-"` } @@ -60,6 +61,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) { v = h.ShadowsocksROptions case C.TypeVLESS: v = h.VLESSOptions + case C.TypeTUIC: + v = h.TUICOptions case C.TypeSelector: v = h.SelectorOptions case C.TypeURLTest: @@ -105,6 +108,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error { v = &h.ShadowsocksROptions case C.TypeVLESS: v = &h.VLESSOptions + case C.TypeTUIC: + v = &h.TUICOptions case C.TypeSelector: v = &h.SelectorOptions case C.TypeURLTest: diff --git a/option/tuic.go b/option/tuic.go new file mode 100644 index 00000000..98d48be2 --- /dev/null +++ b/option/tuic.go @@ -0,0 +1,30 @@ +package option + +type TUICInboundOptions struct { + ListenOptions + Users []TUICUser `json:"users,omitempty"` + CongestionControl 
string `json:"congestion_control,omitempty"` + AuthTimeout Duration `json:"auth_timeout,omitempty"` + ZeroRTTHandshake bool `json:"zero_rtt_handshake,omitempty"` + Heartbeat Duration `json:"heartbeat,omitempty"` + TLS *InboundTLSOptions `json:"tls,omitempty"` +} + +type TUICUser struct { + Name string `json:"name,omitempty"` + UUID string `json:"uuid,omitempty"` + Password string `json:"password,omitempty"` +} + +type TUICOutboundOptions struct { + DialerOptions + ServerOptions + UUID string `json:"uuid,omitempty"` + Password string `json:"password,omitempty"` + CongestionControl string `json:"congestion_control,omitempty"` + UDPRelayMode string `json:"udp_relay_mode,omitempty"` + ZeroRTTHandshake bool `json:"zero_rtt_handshake,omitempty"` + Heartbeat Duration `json:"heartbeat,omitempty"` + Network NetworkList `json:"network,omitempty"` + TLS *OutboundTLSOptions `json:"tls,omitempty"` +} diff --git a/outbound/builder.go b/outbound/builder.go index f32c5de6..1324fdbf 100644 --- a/outbound/builder.go +++ b/outbound/builder.go @@ -51,6 +51,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, t return NewShadowsocksR(ctx, router, logger, tag, options.ShadowsocksROptions) case C.TypeVLESS: return NewVLESS(ctx, router, logger, tag, options.VLESSOptions) + case C.TypeTUIC: + return NewTUIC(ctx, router, logger, tag, options.TUICOptions) case C.TypeSelector: return NewSelector(router, logger, tag, options.SelectorOptions) case C.TypeURLTest: diff --git a/outbound/tuic.go b/outbound/tuic.go new file mode 100644 index 00000000..42585ab3 --- /dev/null +++ b/outbound/tuic.go @@ -0,0 +1,123 @@ +//go:build with_quic + +package outbound + +import ( + "context" + "net" + "os" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + 
"github.com/sagernet/sing-box/transport/tuic" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/gofrs/uuid/v5" +) + +var ( + _ adapter.Outbound = (*TUIC)(nil) + _ adapter.InterfaceUpdateListener = (*TUIC)(nil) +) + +type TUIC struct { + myOutboundAdapter + client *tuic.Client +} + +func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICOutboundOptions) (*TUIC, error) { + options.UDPFragmentDefault = true + if options.TLS == nil || !options.TLS.Enabled { + return nil, C.ErrTLSRequired + } + abstractTLSConfig, err := tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + tlsConfig, err := abstractTLSConfig.Config() + if err != nil { + return nil, err + } + userUUID, err := uuid.FromString(options.UUID) + if err != nil { + return nil, E.Cause(err, "invalid uuid") + } + var udpStream bool + switch options.UDPRelayMode { + case "native": + case "quic": + udpStream = true + } + client, err := tuic.NewClient(tuic.ClientOptions{ + Context: ctx, + Dialer: dialer.New(router, options.DialerOptions), + ServerAddress: options.ServerOptions.Build(), + TLSConfig: tlsConfig, + UUID: userUUID, + Password: options.Password, + CongestionControl: options.CongestionControl, + UDPStream: udpStream, + ZeroRTTHandshake: options.ZeroRTTHandshake, + Heartbeat: time.Duration(options.Heartbeat), + }) + if err != nil { + return nil, err + } + return &TUIC{ + myOutboundAdapter: myOutboundAdapter{ + protocol: C.TypeTUIC, + network: options.Network.Build(), + router: router, + logger: logger, + tag: tag, + dependencies: withDialerDependency(options.DialerOptions), + }, + client: client, + }, nil +} + +func (h *TUIC) DialContext(ctx context.Context, network string, destination M.Socksaddr) 
(net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + h.logger.InfoContext(ctx, "outbound connection to ", destination) + return h.client.DialConn(ctx, destination) + case N.NetworkUDP: + conn, err := h.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + return bufio.NewBindPacketConn(conn, destination), nil + default: + return nil, E.New("unsupported network: ", network) + } +} + +func (h *TUIC) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + return h.client.ListenPacket(ctx) +} + +func (h *TUIC) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + return NewConnection(ctx, h, conn, metadata) +} + +func (h *TUIC) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return NewPacketConnection(ctx, h, conn, metadata) +} + +func (h *TUIC) InterfaceUpdated() { + _ = h.client.CloseWithError(E.New("network changed")) +} + +func (h *TUIC) Close() error { + return h.client.CloseWithError(os.ErrClosed) +} diff --git a/outbound/tuic_stub.go b/outbound/tuic_stub.go new file mode 100644 index 00000000..a6372c9e --- /dev/null +++ b/outbound/tuic_stub.go @@ -0,0 +1,16 @@ +//go:build !with_quic + +package outbound + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" +) + +func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICOutboundOptions) (adapter.Outbound, error) { + return nil, C.ErrQUICNotIncluded +} diff --git a/test/clash_test.go b/test/clash_test.go index 466ba5e8..7ea03ebc 100644 --- a/test/clash_test.go +++ b/test/clash_test.go @@ -38,6 +38,8 @@ const ( ImageShadowsocksR = "teddysun/shadowsocks-r:latest" ImageXRayCore = 
"teddysun/xray:latest" ImageShadowsocksLegacy = "mritd/shadowsocks:latest" + ImageTUICServer = "" + ImageTUICClient = "" ) var allImages = []string{ @@ -53,6 +55,8 @@ var allImages = []string{ ImageShadowsocksR, ImageXRayCore, ImageShadowsocksLegacy, + // ImageTUICServer, + // ImageTUICClient, } var localIP = netip.MustParseAddr("127.0.0.1") diff --git a/test/config/tuic-client.json b/test/config/tuic-client.json new file mode 100644 index 00000000..c1042b53 --- /dev/null +++ b/test/config/tuic-client.json @@ -0,0 +1,14 @@ +{ + "relay": { + "server": "127.0.0.1:10000", + "uuid": "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D", + "password": "tuic", + "certificates": [ + "/etc/tuic/ca.pem" + ] + }, + "local": { + "server": "127.0.0.1:10001" + }, + "log_level": "debug" +} \ No newline at end of file diff --git a/test/config/tuic-server.json b/test/config/tuic-server.json new file mode 100644 index 00000000..74e83eba --- /dev/null +++ b/test/config/tuic-server.json @@ -0,0 +1,9 @@ +{ + "server": "[::]:10000", + "users": { + "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D": "tuic" + }, + "certificate": "/etc/tuic/cert.pem", + "private_key": "/etc/tuic/key.pem", + "log_level": "debug" +} \ No newline at end of file diff --git a/test/tuic_test.go b/test/tuic_test.go new file mode 100644 index 00000000..3851d132 --- /dev/null +++ b/test/tuic_test.go @@ -0,0 +1,178 @@ +package main + +import ( + "net/netip" + "testing" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + + "github.com/gofrs/uuid/v5" +) + +func TestTUICSelf(t *testing.T) { + t.Run("self", func(t *testing.T) { + testTUICSelf(t, false, false) + }) + t.Run("self-udp-stream", func(t *testing.T) { + testTUICSelf(t, true, false) + }) + t.Run("self-early", func(t *testing.T) { + testTUICSelf(t, false, true) + }) +} + +func testTUICSelf(t *testing.T, udpStream bool, zeroRTTHandshake bool) { + _, certPem, keyPem := createSelfSignedCertificate(t, "example.org") + var udpRelayMode string + if udpStream 
{ + udpRelayMode = "quic" + } + startInstance(t, option.Options{ + Inbounds: []option.Inbound{ + { + Type: C.TypeMixed, + Tag: "mixed-in", + MixedOptions: option.HTTPMixedInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.NewListenAddress(netip.IPv4Unspecified()), + ListenPort: clientPort, + }, + }, + }, + { + Type: C.TypeTUIC, + TUICOptions: option.TUICInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.NewListenAddress(netip.IPv4Unspecified()), + ListenPort: serverPort, + }, + Users: []option.TUICUser{{ + UUID: uuid.Nil.String(), + }}, + ZeroRTTHandshake: zeroRTTHandshake, + TLS: &option.InboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + KeyPath: keyPem, + }, + }, + }, + }, + Outbounds: []option.Outbound{ + { + Type: C.TypeDirect, + }, + { + Type: C.TypeTUIC, + Tag: "tuic-out", + TUICOptions: option.TUICOutboundOptions{ + ServerOptions: option.ServerOptions{ + Server: "127.0.0.1", + ServerPort: serverPort, + }, + UUID: uuid.Nil.String(), + UDPRelayMode: udpRelayMode, + ZeroRTTHandshake: zeroRTTHandshake, + TLS: &option.OutboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + }, + }, + }, + }, + Route: &option.RouteOptions{ + Rules: []option.Rule{ + { + DefaultOptions: option.DefaultRule{ + Inbound: []string{"mixed-in"}, + Outbound: "tuic-out", + }, + }, + }, + }, + }) + testSuit(t, clientPort, testPort) +} + +func TestTUICInbound(t *testing.T) { + caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org") + startInstance(t, option.Options{ + Inbounds: []option.Inbound{ + { + Type: C.TypeTUIC, + TUICOptions: option.TUICInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.NewListenAddress(netip.IPv4Unspecified()), + ListenPort: serverPort, + }, + Users: []option.TUICUser{{ + UUID: "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D", + Password: "tuic", + }}, + TLS: &option.InboundTLSOptions{ + Enabled: true, + ServerName: 
"example.org", + CertificatePath: certPem, + KeyPath: keyPem, + }, + }, + }, + }, + }) + startDockerContainer(t, DockerOptions{ + Image: ImageTUICClient, + Ports: []uint16{serverPort, clientPort}, + Bind: map[string]string{ + "tuic-client.json": "/etc/tuic/config.json", + caPem: "/etc/tuic/ca.pem", + }, + }) +} + +func TestTUICOutbound(t *testing.T) { + _, certPem, keyPem := createSelfSignedCertificate(t, "example.org") + startDockerContainer(t, DockerOptions{ + Image: ImageTUICServer, + Ports: []uint16{testPort}, + Bind: map[string]string{ + "tuic-server.json": "/etc/tuic/config.json", + certPem: "/etc/tuic/cert.pem", + keyPem: "/etc/tuic/key.pem", + }, + }) + startInstance(t, option.Options{ + Inbounds: []option.Inbound{ + { + Type: C.TypeMixed, + MixedOptions: option.HTTPMixedInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.NewListenAddress(netip.IPv4Unspecified()), + ListenPort: clientPort, + }, + }, + }, + }, + Outbounds: []option.Outbound{ + { + Type: C.TypeTUIC, + TUICOptions: option.TUICOutboundOptions{ + ServerOptions: option.ServerOptions{ + Server: "127.0.0.1", + ServerPort: serverPort, + }, + UUID: "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D", + Password: "tuic", + TLS: &option.OutboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + }, + }, + }, + }, + }) + testSuit(t, clientPort, testPort) +} diff --git a/transport/tuic/address.go b/transport/tuic/address.go new file mode 100644 index 00000000..22b18fa9 --- /dev/null +++ b/transport/tuic/address.go @@ -0,0 +1,10 @@ +package tuic + +import M "github.com/sagernet/sing/common/metadata" + +var addressSerializer = M.NewSerializer( + M.AddressFamilyByte(0x00, M.AddressFamilyFqdn), + M.AddressFamilyByte(0x01, M.AddressFamilyIPv4), + M.AddressFamilyByte(0x02, M.AddressFamilyIPv6), + M.AddressFamilyByte(0xff, M.AddressFamilyEmpty), +) diff --git a/transport/tuic/client.go b/transport/tuic/client.go new file mode 100644 index 00000000..6f8203bf --- 
/dev/null +++ b/transport/tuic/client.go @@ -0,0 +1,322 @@ +package tuic + +import ( + "context" + "crypto/tls" + "io" + "net" + "os" + "runtime" + "sync" + "time" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing-box/common/baderror" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/gofrs/uuid/v5" +) + +type ClientOptions struct { + Context context.Context + Dialer N.Dialer + ServerAddress M.Socksaddr + TLSConfig *tls.Config + UUID uuid.UUID + Password string + CongestionControl string + UDPStream bool + ZeroRTTHandshake bool + Heartbeat time.Duration +} + +type Client struct { + ctx context.Context + dialer N.Dialer + serverAddr M.Socksaddr + tlsConfig *tls.Config + quicConfig *quic.Config + uuid uuid.UUID + password string + congestionControl string + udpStream bool + zeroRTTHandshake bool + heartbeat time.Duration + + connAccess sync.RWMutex + conn *clientQUICConnection +} + +func NewClient(options ClientOptions) (*Client, error) { + if options.Heartbeat == 0 { + options.Heartbeat = 10 * time.Second + } + quicConfig := &quic.Config{ + DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), + MaxDatagramFrameSize: 1400, + EnableDatagrams: true, + MaxIncomingUniStreams: 1 << 60, + } + switch options.CongestionControl { + case "": + options.CongestionControl = "cubic" + case "cubic", "new_reno", "bbr": + default: + return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl) + } + return &Client{ + ctx: options.Context, + dialer: options.Dialer, + serverAddr: options.ServerAddress, + tlsConfig: options.TLSConfig, + quicConfig: quicConfig, + uuid: options.UUID, + password: options.Password, + congestionControl: 
options.CongestionControl, + udpStream: options.UDPStream, + zeroRTTHandshake: options.ZeroRTTHandshake, + heartbeat: options.Heartbeat, + }, nil +} + +func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { + conn := c.conn + if conn != nil && conn.active() { + return conn, nil + } + c.connAccess.Lock() + defer c.connAccess.Unlock() + conn = c.conn + if conn != nil && conn.active() { + return conn, nil + } + conn, err := c.offerNew(ctx) + if err != nil { + return nil, err + } + return conn, nil +} + +func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { + udpConn, err := c.dialer.DialContext(ctx, "udp", c.serverAddr) + if err != nil { + return nil, err + } + var quicConn quic.Connection + if c.zeroRTTHandshake { + quicConn, err = quic.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig) + } else { + quicConn, err = quic.Dial(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig) + } + if err != nil { + udpConn.Close() + return nil, E.Cause(err, "open connection") + } + setCongestion(c.ctx, quicConn, c.congestionControl) + conn := &clientQUICConnection{ + quicConn: quicConn, + rawConn: udpConn, + connDone: make(chan struct{}), + udpConnMap: make(map[uint16]*udpPacketConn), + } + go func() { + hErr := c.clientHandshake(quicConn) + if hErr != nil { + conn.closeWithError(hErr) + } + }() + if c.udpStream { + go c.loopUniStreams(conn) + } + go c.loopMessages(conn) + go c.loopHeartbeats(conn) + c.conn = conn + return conn, nil +} + +func (c *Client) clientHandshake(conn quic.Connection) error { + authStream, err := conn.OpenUniStream() + if err != nil { + return err + } + defer authStream.Close() + handshakeState := conn.ConnectionState().TLS + tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32) + if err != nil { + return err + } + authRequest := buf.NewSize(AuthenticateLen) + 
authRequest.WriteByte(Version) + authRequest.WriteByte(CommandAuthenticate) + authRequest.Write(c.uuid[:]) + authRequest.Write(tuicAuthToken) + return common.Error(authStream.Write(authRequest.Bytes())) +} + +func (c *Client) loopHeartbeats(conn *clientQUICConnection) { + ticker := time.NewTicker(c.heartbeat) + defer ticker.Stop() + for { + select { + case <-conn.connDone: + return + case <-ticker.C: + err := conn.quicConn.SendMessage([]byte{Version, CommandHeartbeat}) + if err != nil { + conn.closeWithError(E.Cause(err, "send heartbeat")) + } + } + } +} + +func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { + conn, err := c.offer(ctx) + if err != nil { + return nil, err + } + stream, err := conn.quicConn.OpenStream() + if err != nil { + return nil, err + } + return &clientConn{ + parent: conn, + stream: stream, + destination: destination, + }, nil +} + +func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) { + conn, err := c.offer(ctx) + if err != nil { + return nil, err + } + var sessionID uint16 + clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() { + conn.udpAccess.Lock() + delete(conn.udpConnMap, sessionID) + conn.udpAccess.Unlock() + }) + conn.udpAccess.Lock() + sessionID = conn.udpSessionID + conn.udpSessionID++ + conn.udpConnMap[sessionID] = clientPacketConn + conn.udpAccess.Unlock() + clientPacketConn.sessionID = sessionID + return clientPacketConn, nil +} + +func (c *Client) CloseWithError(err error) error { + conn := c.conn + if conn != nil { + conn.closeWithError(err) + } + return nil +} + +type clientQUICConnection struct { + quicConn quic.Connection + rawConn io.Closer + closeOnce sync.Once + connDone chan struct{} + connErr error + udpAccess sync.RWMutex + udpConnMap map[uint16]*udpPacketConn + udpSessionID uint16 +} + +func (c *clientQUICConnection) active() bool { + select { + case <-c.quicConn.Context().Done(): + return false + default: + } + select { + 
case <-c.connDone: + return false + default: + } + return true +} + +func (c *clientQUICConnection) closeWithError(err error) { + c.closeOnce.Do(func() { + c.connErr = err + close(c.connDone) + _ = c.quicConn.CloseWithError(0, "") + _ = c.rawConn.Close() + }) +} + +type clientConn struct { + parent *clientQUICConnection + stream quic.Stream + destination M.Socksaddr + requestWritten bool +} + +func (c *clientConn) Read(b []byte) (n int, err error) { + n, err = c.stream.Read(b) + return n, baderror.WrapQUIC(err) +} + +func (c *clientConn) Write(b []byte) (n int, err error) { + if !c.requestWritten { + request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b)) + request.WriteByte(Version) + request.WriteByte(CommandConnect) + addressSerializer.WriteAddrPort(request, c.destination) + request.Write(b) + _, err = c.stream.Write(request.Bytes()) + if err != nil { + c.parent.closeWithError(E.Cause(err, "create new connection")) + return 0, baderror.WrapQUIC(err) + } + c.requestWritten = true + return len(b), nil + } + n, err = c.stream.Write(b) + return n, baderror.WrapQUIC(err) +} + +func (c *clientConn) Close() error { + stream := c.stream + if stream == nil { + return nil + } + stream.CancelRead(0) + return stream.Close() +} + +func (c *clientConn) LocalAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *clientConn) RemoteAddr() net.Addr { + return c.destination +} + +func (c *clientConn) SetDeadline(t time.Time) error { + if c.stream == nil { + return os.ErrInvalid + } + return c.stream.SetDeadline(t) +} + +func (c *clientConn) SetReadDeadline(t time.Time) error { + if c.stream == nil { + return os.ErrInvalid + } + return c.stream.SetReadDeadline(t) +} + +func (c *clientConn) SetWriteDeadline(t time.Time) error { + if c.stream == nil { + return os.ErrInvalid + } + return c.stream.SetWriteDeadline(t) +} diff --git a/transport/tuic/client_packet.go b/transport/tuic/client_packet.go new file mode 100644 index 00000000..b4292e94 --- /dev/null 
+++ b/transport/tuic/client_packet.go @@ -0,0 +1,110 @@ +package tuic + +import ( + "io" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" +) + +func (c *Client) loopMessages(conn *clientQUICConnection) { + for { + message, err := conn.quicConn.ReceiveMessage(c.ctx) + if err != nil { + conn.closeWithError(E.Cause(err, "receive message")) + return + } + go func() { + hErr := c.handleMessage(conn, message) + if hErr != nil { + conn.closeWithError(E.Cause(hErr, "handle message")) + } + }() + } +} + +func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error { + if len(data) < 2 { + return E.New("invalid message") + } + if data[0] != Version { + return E.New("unknown version ", data[0]) + } + switch data[1] { + case CommandPacket: + message := udpMessagePool.Get().(*udpMessage) + err := decodeUDPMessage(message, data[2:]) + if err != nil { + message.release() + return E.Cause(err, "decode UDP message") + } + conn.handleUDPMessage(message) + return nil + case CommandHeartbeat: + return nil + default: + return E.New("unknown command ", data[1]) + } +} + +func (c *Client) loopUniStreams(conn *clientQUICConnection) { + for { + stream, err := conn.quicConn.AcceptUniStream(c.ctx) + if err != nil { + conn.closeWithError(E.Cause(err, "handle uni stream")) + return + } + go func() { + hErr := c.handleUniStream(conn, stream) + if hErr != nil { + conn.closeWithError(hErr) + } + }() + } +} + +func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.ReceiveStream) error { + defer stream.CancelRead(0) + buffer := buf.NewPacket() + defer buffer.Release() + _, err := buffer.ReadAtLeastFrom(stream, 2) + if err != nil { + return err + } + version, _ := buffer.ReadByte() + if version != Version { + return E.New("unknown version ", version) + } + command, _ := buffer.ReadByte() + if command != CommandPacket { + return E.New("unknown command 
", command) + } + reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream) + message := udpMessagePool.Get().(*udpMessage) + err = readUDPMessage(message, reader) + if err != nil { + message.release() + return err + } + conn.handleUDPMessage(message) + return nil +} + +func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) { + c.udpAccess.RLock() + udpConn, loaded := c.udpConnMap[message.sessionID] + c.udpAccess.RUnlock() + if !loaded { + message.releaseMessage() + return + } + select { + case <-udpConn.ctx.Done(): + message.releaseMessage() + return + default: + } + udpConn.inputPacket(message) +} diff --git a/transport/tuic/congestion.go b/transport/tuic/congestion.go new file mode 100644 index 00000000..71f74838 --- /dev/null +++ b/transport/tuic/congestion.go @@ -0,0 +1,46 @@ +package tuic + +import ( + "context" + "time" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing-box/transport/tuic/congestion" + "github.com/sagernet/sing/common/ntp" +) + +func setCongestion(ctx context.Context, connection quic.Connection, congestionName string) { + timeFunc := ntp.TimeFuncFromContext(ctx) + if timeFunc == nil { + timeFunc = time.Now + } + switch congestionName { + case "cubic": + connection.SetCongestionControl( + congestion.NewCubicSender( + congestion.DefaultClock{TimeFunc: timeFunc}, + congestion.GetInitialPacketSize(connection.RemoteAddr()), + false, + nil, + ), + ) + case "new_reno": + connection.SetCongestionControl( + congestion.NewCubicSender( + congestion.DefaultClock{TimeFunc: timeFunc}, + congestion.GetInitialPacketSize(connection.RemoteAddr()), + true, + nil, + ), + ) + case "bbr": + connection.SetCongestionControl( + congestion.NewBBRSender( + congestion.DefaultClock{}, + congestion.GetInitialPacketSize(connection.RemoteAddr()), + congestion.InitialCongestionWindow*congestion.InitialMaxDatagramSize, + congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize, + ), + ) + } +} diff --git 
a/transport/tuic/congestion/README.md b/transport/tuic/congestion/README.md new file mode 100644 index 00000000..6aa0309d --- /dev/null +++ b/transport/tuic/congestion/README.md @@ -0,0 +1,3 @@ +# congestion + +mod from https://github.com/MetaCubeX/Clash.Meta/tree/53f9e1ee7104473da2b4ff5da29965563084482d/transport/tuic/congestion \ No newline at end of file diff --git a/transport/tuic/congestion/bandwidth.go b/transport/tuic/congestion/bandwidth.go new file mode 100644 index 00000000..23393bad --- /dev/null +++ b/transport/tuic/congestion/bandwidth.go @@ -0,0 +1,25 @@ +package congestion + +import ( + "math" + "time" + + "github.com/sagernet/quic-go/congestion" +) + +// Bandwidth of a connection +type Bandwidth uint64 + +const infBandwidth Bandwidth = math.MaxUint64 + +const ( + // BitsPerSecond is 1 bit per second + BitsPerSecond Bandwidth = 1 + // BytesPerSecond is 1 byte per second + BytesPerSecond = 8 * BitsPerSecond +) + +// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta +func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth { + return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond +} diff --git a/transport/tuic/congestion/bandwidth_sampler.go b/transport/tuic/congestion/bandwidth_sampler.go new file mode 100644 index 00000000..908f6e0d --- /dev/null +++ b/transport/tuic/congestion/bandwidth_sampler.go @@ -0,0 +1,374 @@ +package congestion + +import ( + "math" + "time" + + "github.com/sagernet/quic-go/congestion" +) + +var InfiniteBandwidth = Bandwidth(math.MaxUint64) + +// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned +// to the caller when the packet is acked or lost. +type SendTimeState struct { + // Whether other states in this object is valid. + isValid bool + // Whether the sender is app limited at the time the packet was sent. 
+ // App limited bandwidth sample might be artificially low because the sender + // did not have enough data to send in order to saturate the link. + isAppLimited bool + // Total number of sent bytes at the time the packet was sent. + // Includes the packet itself. + totalBytesSent congestion.ByteCount + // Total number of acked bytes at the time the packet was sent. + totalBytesAcked congestion.ByteCount + // Total number of lost bytes at the time the packet was sent. + totalBytesLost congestion.ByteCount +} + +// ConnectionStateOnSentPacket represents the information about a sent packet +// and the state of the connection at the moment the packet was sent, +// specifically the information about the most recently acknowledged packet at +// that moment. +type ConnectionStateOnSentPacket struct { + packetNumber congestion.PacketNumber + // Time at which the packet is sent. + sendTime time.Time + // Size of the packet. + size congestion.ByteCount + // The value of |totalBytesSentAtLastAckedPacket| at the time the + // packet was sent. + totalBytesSentAtLastAckedPacket congestion.ByteCount + // The value of |lastAckedPacketSentTime| at the time the packet was + // sent. + lastAckedPacketSentTime time.Time + // The value of |lastAckedPacketAckTime| at the time the packet was + // sent. + lastAckedPacketAckTime time.Time + // Send time states that are returned to the congestion controller when the + // packet is acked or lost. + sendTimeState SendTimeState +} + +// BandwidthSample +type BandwidthSample struct { + // The bandwidth at that particular sample. Zero if no valid bandwidth sample + // is available. + bandwidth Bandwidth + // The RTT measurement at this particular sample. Zero if no RTT sample is + // available. Does not correct for delayed ack time. + rtt time.Duration + // States captured when the packet was sent. 
+ stateAtSend SendTimeState +} + +func NewBandwidthSample() *BandwidthSample { + return &BandwidthSample{ + // FIXME: the default value of original code is zero. + rtt: InfiniteRTT, + } +} + +// BandwidthSampler keeps track of sent and acknowledged packets and outputs a +// bandwidth sample for every packet acknowledged. The samples are taken for +// individual packets, and are not filtered; the consumer has to filter the +// bandwidth samples itself. In certain cases, the sampler will locally severely +// underestimate the bandwidth, hence a maximum filter with a size of at least +// one RTT is recommended. +// +// This class bases its samples on the slope of two curves: the number of bytes +// sent over time, and the number of bytes acknowledged as received over time. +// It produces a sample of both slopes for every packet that gets acknowledged, +// based on a slope between two points on each of the corresponding curves. Note +// that due to the packet loss, the number of bytes on each curve might get +// further and further away from each other, meaning that it is not feasible to +// compare byte values coming from different curves with each other. +// +// The obvious points for measuring slope sample are the ones corresponding to +// the packet that was just acknowledged. Let us denote them as S_1 (point at +// which the current packet was sent) and A_1 (point at which the current packet +// was acknowledged). However, taking a slope requires two points on each line, +// so estimating bandwidth requires picking a packet in the past with respect to +// which the slope is measured. +// +// For that purpose, BandwidthSampler always keeps track of the most recently +// acknowledged packet, and records it together with every outgoing packet. +// When a packet gets acknowledged (A_1), it has not only information about when +// it itself was sent (S_1), but also the information about the latest +// acknowledged packet right before it was sent (S_0 and A_0). 
+// +// Based on that data, send and ack rate are estimated as: +// +// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0)) +// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0)) +// +// Here, the ack rate is intuitively the rate we want to treat as bandwidth. +// However, in certain cases (e.g. ack compression) the ack rate at a point may +// end up higher than the rate at which the data was originally sent, which is +// not indicative of the real bandwidth. Hence, we use the send rate as an upper +// bound, and the sample value is +// +// rate_sample = min(send_rate, ack_rate) +// +// An important edge case handled by the sampler is tracking the app-limited +// samples. There are multiple meaning of "app-limited" used interchangeably, +// hence it is important to understand and to be able to distinguish between +// them. +// +// Meaning 1: connection state. The connection is said to be app-limited when +// there is no outstanding data to send. This means that certain bandwidth +// samples in the future would not be an accurate indication of the link +// capacity, and it is important to inform consumer about that. Whenever +// connection becomes app-limited, the sampler is notified via OnAppLimited() +// method. +// +// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth +// sampler becomes notified about the connection being app-limited, it enters +// app-limited phase. In that phase, all *sent* packets are marked as +// app-limited. Note that the connection itself does not have to be +// app-limited during the app-limited phase, and in fact it will not be +// (otherwise how would it send packets?). The boolean flag below indicates +// whether the sampler is in that phase. +// +// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is +// sent during the app-limited phase, the resulting sample related to the +// packet will be marked as app-limited. 
+// +// With the terminology issue out of the way, let us consider the question of +// what kind of situation it addresses. +// +// Consider a scenario where we first send packets 1 to 20 at a regular +// bandwidth, and then immediately run out of data. After a few seconds, we send +// packets 21 to 60, and only receive ack for 21 between sending packets 40 and +// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0 +// we use to compute the slope is going to be packet 20, a few seconds apart +// from the current packet, hence the resulting estimate would be extremely low +// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21, +// meaning that the bandwidth sample would exclude the quiescence. +// +// Based on the analysis of that scenario, we implement the following rule: once +// OnAppLimited() is called, all sent packets will produce app-limited samples +// up until an ack for a packet that was sent after OnAppLimited() was called. +// Note that while the scenario above is not the only scenario when the +// connection is app-limited, the approach works in other cases too. +type BandwidthSampler struct { + // The total number of congestion controlled bytes sent during the connection. + totalBytesSent congestion.ByteCount + // The total number of congestion controlled bytes which were acknowledged. + totalBytesAcked congestion.ByteCount + // The total number of congestion controlled bytes which were lost. + totalBytesLost congestion.ByteCount + // The value of |totalBytesSent| at the time the last acknowledged packet + // was sent. Valid only when |lastAckedPacketSentTime| is valid. + totalBytesSentAtLastAckedPacket congestion.ByteCount + // The time at which the last acknowledged packet was sent. Set to + // QuicTime::Zero() if no valid timestamp is available. + lastAckedPacketSentTime time.Time + // The time at which the most recent packet was acknowledged. 
+ lastAckedPacketAckTime time.Time + // The most recently sent packet. + lastSendPacket congestion.PacketNumber + // Indicates whether the bandwidth sampler is currently in an app-limited + // phase. + isAppLimited bool + // The packet that will be acknowledged after this one will cause the sampler + // to exit the app-limited phase. + endOfAppLimitedPhase congestion.PacketNumber + // Record of the connection state at the point where each packet in flight was + // sent, indexed by the packet number. + connectionStats *ConnectionStates +} + +func NewBandwidthSampler() *BandwidthSampler { + return &BandwidthSampler{ + connectionStats: &ConnectionStates{ + stats: make(map[congestion.PacketNumber]*ConnectionStateOnSentPacket), + }, + } +} + +// OnPacketSent Inputs the sent packet information into the sampler. Assumes that all +// packets are sent in order. The information about the packet will not be +// released from the sampler until it the packet is either acknowledged or +// declared lost. +func (s *BandwidthSampler) OnPacketSent(sentTime time.Time, lastSentPacket congestion.PacketNumber, sentBytes, bytesInFlight congestion.ByteCount, hasRetransmittableData bool) { + s.lastSendPacket = lastSentPacket + + if !hasRetransmittableData { + return + } + + s.totalBytesSent += sentBytes + + // If there are no packets in flight, the time at which the new transmission + // opens can be treated as the A_0 point for the purpose of bandwidth + // sampling. This underestimates bandwidth to some extent, and produces some + // artificially low samples for most packets in flight, but it provides with + // samples at important points where we would not have them otherwise, most + // importantly at the beginning of the connection. + if bytesInFlight == 0 { + s.lastAckedPacketAckTime = sentTime + s.totalBytesSentAtLastAckedPacket = s.totalBytesSent + + // In this situation ack compression is not a concern, set send rate to + // effectively infinite. 
+ s.lastAckedPacketSentTime = sentTime + } + + s.connectionStats.Insert(lastSentPacket, sentTime, sentBytes, s) +} + +// OnPacketAcked Notifies the sampler that the |lastAckedPacket| is acknowledged. Returns a +// bandwidth sample. If no bandwidth sample is available, +// QuicBandwidth::Zero() is returned. +func (s *BandwidthSampler) OnPacketAcked(ackTime time.Time, lastAckedPacket congestion.PacketNumber) *BandwidthSample { + sentPacketState := s.connectionStats.Get(lastAckedPacket) + if sentPacketState == nil { + return NewBandwidthSample() + } + + sample := s.onPacketAckedInner(ackTime, lastAckedPacket, sentPacketState) + s.connectionStats.Remove(lastAckedPacket) + + return sample +} + +// onPacketAckedInner Handles the actual bandwidth calculations, whereas the outer method handles +// retrieving and removing |sentPacket|. +func (s *BandwidthSampler) onPacketAckedInner(ackTime time.Time, lastAckedPacket congestion.PacketNumber, sentPacket *ConnectionStateOnSentPacket) *BandwidthSample { + s.totalBytesAcked += sentPacket.size + + s.totalBytesSentAtLastAckedPacket = sentPacket.sendTimeState.totalBytesSent + s.lastAckedPacketSentTime = sentPacket.sendTime + s.lastAckedPacketAckTime = ackTime + + // Exit app-limited phase once a packet that was sent while the connection is + // not app-limited is acknowledged. + if s.isAppLimited && lastAckedPacket > s.endOfAppLimitedPhase { + s.isAppLimited = false + } + + // There might have been no packets acknowledged at the moment when the + // current packet was sent. In that case, there is no bandwidth sample to + // make. + if sentPacket.lastAckedPacketSentTime.IsZero() { + return NewBandwidthSample() + } + + // Infinite rate indicates that the sampler is supposed to discard the + // current send rate sample and use only the ack rate. 
+ sendRate := InfiniteBandwidth + if sentPacket.sendTime.After(sentPacket.lastAckedPacketSentTime) { + sendRate = BandwidthFromDelta(sentPacket.sendTimeState.totalBytesSent-sentPacket.totalBytesSentAtLastAckedPacket, sentPacket.sendTime.Sub(sentPacket.lastAckedPacketSentTime)) + } + + // During the slope calculation, ensure that ack time of the current packet is + // always larger than the time of the previous packet, otherwise division by + // zero or integer underflow can occur. + if !ackTime.After(sentPacket.lastAckedPacketAckTime) { + // TODO(wub): Compare this code count before and after fixing clock jitter + // issue. + // if sentPacket.lastAckedPacketAckTime.Equal(sentPacket.sendTime) { + // This is the 1st packet after quiescense. + // QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 1, 2); + // } else { + // QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 2, 2); + // } + + return NewBandwidthSample() + } + + ackRate := BandwidthFromDelta(s.totalBytesAcked-sentPacket.sendTimeState.totalBytesAcked, + ackTime.Sub(sentPacket.lastAckedPacketAckTime)) + + // Note: this sample does not account for delayed acknowledgement time. This + // means that the RTT measurements here can be artificially high, especially + // on low bandwidth connections. + sample := &BandwidthSample{ + bandwidth: minBandwidth(sendRate, ackRate), + rtt: ackTime.Sub(sentPacket.sendTime), + } + + SentPacketToSendTimeState(sentPacket, &sample.stateAtSend) + return sample +} + +// OnPacketLost Informs the sampler that a packet is considered lost and it should no +// longer keep track of it. 
+func (s *BandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber) SendTimeState { + ok, sentPacket := s.connectionStats.Remove(packetNumber) + sendTimeState := SendTimeState{ + isValid: ok, + } + if sentPacket != nil { + s.totalBytesLost += sentPacket.size + SentPacketToSendTimeState(sentPacket, &sendTimeState) + } + + return sendTimeState +} + +// OnAppLimited Informs the sampler that the connection is currently app-limited, causing +// the sampler to enter the app-limited phase. The phase will expire by +// itself. +func (s *BandwidthSampler) OnAppLimited() { + s.isAppLimited = true + s.endOfAppLimitedPhase = s.lastSendPacket +} + +// SentPacketToSendTimeState Copy a subset of the (private) ConnectionStateOnSentPacket to the (public) +// SendTimeState. Always set send_time_state->is_valid to true. +func SentPacketToSendTimeState(sentPacket *ConnectionStateOnSentPacket, sendTimeState *SendTimeState) { + sendTimeState.isAppLimited = sentPacket.sendTimeState.isAppLimited + sendTimeState.totalBytesSent = sentPacket.sendTimeState.totalBytesSent + sendTimeState.totalBytesAcked = sentPacket.sendTimeState.totalBytesAcked + sendTimeState.totalBytesLost = sentPacket.sendTimeState.totalBytesLost + sendTimeState.isValid = true +} + +// ConnectionStates Record of the connection state at the point where each packet in flight was +// sent, indexed by the packet number. +// FIXME: using LinkedList replace map to fast remove all the packets lower than the specified packet number. 
+type ConnectionStates struct { + stats map[congestion.PacketNumber]*ConnectionStateOnSentPacket +} + +func (s *ConnectionStates) Insert(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) bool { + if _, ok := s.stats[packetNumber]; ok { + return false + } + + s.stats[packetNumber] = NewConnectionStateOnSentPacket(packetNumber, sentTime, bytes, sampler) + return true +} + +func (s *ConnectionStates) Get(packetNumber congestion.PacketNumber) *ConnectionStateOnSentPacket { + return s.stats[packetNumber] +} + +func (s *ConnectionStates) Remove(packetNumber congestion.PacketNumber) (bool, *ConnectionStateOnSentPacket) { + state, ok := s.stats[packetNumber] + if ok { + delete(s.stats, packetNumber) + } + return ok, state +} + +func NewConnectionStateOnSentPacket(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) *ConnectionStateOnSentPacket { + return &ConnectionStateOnSentPacket{ + packetNumber: packetNumber, + sendTime: sentTime, + size: bytes, + lastAckedPacketSentTime: sampler.lastAckedPacketSentTime, + lastAckedPacketAckTime: sampler.lastAckedPacketAckTime, + totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket, + sendTimeState: SendTimeState{ + isValid: true, + isAppLimited: sampler.isAppLimited, + totalBytesSent: sampler.totalBytesSent, + totalBytesAcked: sampler.totalBytesAcked, + totalBytesLost: sampler.totalBytesLost, + }, + } +} diff --git a/transport/tuic/congestion/bbr_sender.go b/transport/tuic/congestion/bbr_sender.go new file mode 100644 index 00000000..34acc676 --- /dev/null +++ b/transport/tuic/congestion/bbr_sender.go @@ -0,0 +1,1000 @@ +package congestion + +// src from https://quiche.googlesource.com/quiche.git/+/66dea072431f94095dfc3dd2743cb94ef365f7ef/quic/core/congestion_control/bbr_sender.cc + +import ( + "fmt" + "math" + "math/rand" + "net" + "time" + + "github.com/sagernet/quic-go/congestion" +) + 
const (
	// InitialMaxDatagramSize is the default maximum packet size used in QUIC for congestion window computations in bytes.
	InitialMaxDatagramSize        = 1252
	InitialPacketSizeIPv4         = 1252
	InitialPacketSizeIPv6         = 1232
	InitialCongestionWindow       = 32
	DefaultBBRMaxCongestionWindow = 10000
)

// GetInitialPacketSize returns the maximum packet size to assume for addr
// before any path MTU information is available.
func GetInitialPacketSize(addr net.Addr) congestion.ByteCount {
	maxSize := congestion.ByteCount(1200)
	// If this is not a UDP address, we don't know anything about the MTU.
	// Use the minimum size of an Initial packet as the max packet size.
	if udpAddr, ok := addr.(*net.UDPAddr); ok {
		if udpAddr.IP.To4() != nil {
			maxSize = InitialPacketSizeIPv4
		} else {
			maxSize = InitialPacketSizeIPv6
		}
	}
	return congestion.ByteCount(maxSize)
}

var (

	// Default initial rtt used before any samples are received.
	InitialRtt = 100 * time.Millisecond

	// The gain used for the STARTUP, equal to 4*ln(2).
	DefaultHighGain = 2.77

	// The gain used in STARTUP after loss has been detected.
	// 1.5 is enough to allow for 25% exogenous loss and still observe a 25% growth
	// in measured bandwidth.
	StartupAfterLossGain = 1.5

	// The cycle of gains used during the PROBE_BW stage.
	PacingGain = []float64{1.25, 0.75, 1, 1, 1, 1, 1, 1}

	// The length of the gain cycle.
	GainCycleLength = len(PacingGain)

	// The size of the bandwidth filter window, in round-trips.
	BandwidthWindowSize = GainCycleLength + 2

	// The time after which the current min_rtt value expires.
	MinRttExpiry = 10 * time.Second

	// The minimum time the connection can spend in PROBE_RTT mode.
	ProbeRttTime = time.Millisecond * 200

	// If the bandwidth does not increase by the factor of |kStartupGrowthTarget|
	// within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection
	// will exit the STARTUP mode.
	StartupGrowthTarget                         = 1.25
	RoundTripsWithoutGrowthBeforeExitingStartup = int64(3)

	// Coefficient of target congestion window to use when basing PROBE_RTT on BDP.
	ModerateProbeRttMultiplier = 0.75

	// Coefficient to determine if a new RTT is sufficiently similar to min_rtt that
	// we don't need to enter PROBE_RTT.
	SimilarMinRttThreshold = 1.125

	// Congestion window gain for QUIC BBR during PROBE_BW phase.
	DefaultCongestionWindowGainConst = 2.0
)

// bbrMode enumerates the phases of the BBR state machine.
type bbrMode int

const (
	// Startup phase of the connection.
	STARTUP = iota
	// After achieving the highest possible bandwidth during the startup, lower
	// the pacing rate in order to drain the queue.
	DRAIN
	// Cruising mode.
	PROBE_BW
	// Temporarily slow down sending in order to empty the buffer and measure
	// the real minimum RTT.
	PROBE_RTT
)

// bbrRecoveryState describes how aggressively sending is limited while
// recovering from loss.
type bbrRecoveryState int

const (
	// Do not limit.
	NOT_IN_RECOVERY = iota

	// Allow an extra outstanding byte for each byte acknowledged.
	CONSERVATION

	// Allow two extra outstanding bytes for each byte acknowledged (slow
	// start).
	GROWTH
)

// bbrSender is a port of Chromium/quiche's BBR congestion controller (see the
// src comment at the top of this file) adapted to the quic-go
// congestion-control interface.
type bbrSender struct {
	mode          bbrMode
	clock         Clock
	rttStats      congestion.RTTStatsProvider
	bytesInFlight congestion.ByteCount
	// return total bytes of unacked packets.
	// GetBytesInFlight func() congestion.ByteCount
	// Bandwidth sampler provides BBR with the bandwidth measurements at
	// individual points.
	sampler *BandwidthSampler
	// The number of the round trips that have occurred during the connection.
	roundTripCount int64
	// The packet number of the most recently sent packet.
	lastSendPacket congestion.PacketNumber
	// Acknowledgement of any packet after |current_round_trip_end_| will cause
	// the round trip counter to advance.
	currentRoundTripEnd congestion.PacketNumber
	// The filter that tracks the maximum bandwidth over the multiple recent
	// round-trips.
	maxBandwidth *WindowedFilter
	// Tracks the maximum number of bytes acked faster than the sending rate.
	maxAckHeight *WindowedFilter
	// The time this aggregation started and the number of bytes acked during it.
	aggregationEpochStartTime time.Time
	aggregationEpochBytes     congestion.ByteCount
	// Minimum RTT estimate. Automatically expires within 10 seconds (and
	// triggers PROBE_RTT mode) if no new value is sampled during that period.
	minRtt time.Duration
	// The time at which the current value of |min_rtt_| was assigned.
	minRttTimestamp time.Time
	// The maximum allowed number of bytes in flight.
	congestionWindow congestion.ByteCount
	// The initial value of the |congestion_window_|.
	initialCongestionWindow congestion.ByteCount
	// The largest value the |congestion_window_| can achieve.
	initialMaxCongestionWindow congestion.ByteCount
	// The smallest value the |congestion_window_| can achieve.
	// minCongestionWindow congestion.ByteCount
	// The pacing gain applied during the STARTUP phase.
	highGain float64
	// The CWND gain applied during the STARTUP phase.
	highCwndGain float64
	// The pacing gain applied during the DRAIN phase.
	drainGain float64
	// The current pacing rate of the connection.
	pacingRate Bandwidth
	// The gain currently applied to the pacing rate.
	pacingGain float64
	// The gain currently applied to the congestion window.
	congestionWindowGain float64
	// The gain used for the congestion window during PROBE_BW. Latched from
	// quic_bbr_cwnd_gain flag.
	congestionWindowGainConst float64
	// The number of RTTs to stay in STARTUP mode. Defaults to 3.
	numStartupRtts int64
	// If true, exit startup if 1RTT has passed with no bandwidth increase and
	// the connection is in recovery.
	exitStartupOnLoss bool
	// Number of round-trips in PROBE_BW mode, used for determining the current
	// pacing gain cycle.
	cycleCurrentOffset int
	// The time at which the last pacing gain cycle was started.
	lastCycleStart time.Time
	// Indicates whether the connection has reached the full bandwidth mode.
	isAtFullBandwidth bool
	// Number of rounds during which there was no significant bandwidth increase.
	roundsWithoutBandwidthGain int64
	// The bandwidth compared to which the increase is measured.
	bandwidthAtLastRound Bandwidth
	// Set to true upon exiting quiescence.
	exitingQuiescence bool
	// Time at which PROBE_RTT has to be exited. Setting it to zero indicates
	// that the time is yet unknown as the number of packets in flight has not
	// reached the required value.
	exitProbeRttAt time.Time
	// Indicates whether a round-trip has passed since PROBE_RTT became active.
	probeRttRoundPassed bool
	// Indicates whether the most recent bandwidth sample was marked as
	// app-limited.
	lastSampleIsAppLimited bool
	// Indicates whether any non app-limited samples have been recorded.
	hasNoAppLimitedSample bool
	// Indicates app-limited calls should be ignored as long as there's
	// enough data inflight to see more bandwidth when necessary.
	flexibleAppLimited bool
	// Current state of recovery.
	recoveryState bbrRecoveryState
	// Receiving acknowledgement of a packet after |end_recovery_at_| will cause
	// BBR to exit the recovery mode. A value above zero indicates at least one
	// loss has been detected, so it must not be set back to zero.
	endRecoveryAt congestion.PacketNumber
	// A window used to limit the number of bytes in flight during loss recovery.
	recoveryWindow congestion.ByteCount
	// If true, consider all samples in recovery app-limited.
	isAppLimitedRecovery bool
	// When true, pace at 1.5x and disable packet conservation in STARTUP.
	slowerStartup bool
	// When true, disables packet conservation in STARTUP.
	rateBasedStartup bool
	// When non-zero, decreases the rate in STARTUP by the total number of bytes
	// lost in STARTUP divided by CWND.
	startupRateReductionMultiplier int64
	// Sum of bytes lost in STARTUP.
	startupBytesLost congestion.ByteCount
	// When true, add the most recent ack aggregation measurement during STARTUP.
	enableAckAggregationDuringStartup bool
	// When true, expire the windowed ack aggregation values in STARTUP when
	// bandwidth increases more than 25%.
	expireAckAggregationInStartup bool
	// If true, will not exit low gain mode until bytes_in_flight drops below BDP
	// or it's time for high gain mode.
	drainToTarget bool
	// If true, use a CWND of 0.75*BDP during probe_rtt instead of 4 packets.
	probeRttBasedOnBdp bool
	// If true, skip probe_rtt and update the timestamp of the existing min_rtt to
	// now if min_rtt over the last cycle is within 12.5% of the current min_rtt.
	// Even if the min_rtt is 12.5% too low, the 25% gain cycling and 2x CWND gain
	// should overcome an overly small min_rtt.
	probeRttSkippedIfSimilarRtt bool
	// If true, disable PROBE_RTT entirely as long as the connection was recently
	// app limited.
	probeRttDisabledIfAppLimited bool
	appLimitedSinceLastProbeRtt  bool
	minRttSinceLastProbeRtt      time.Duration
	// Latched value of --quic_always_get_bw_sample_when_acked.
	alwaysGetBwSampleWhenAcked bool

	pacer *pacer

	maxDatagramSize congestion.ByteCount
}

// NewBBRSender constructs a bbrSender in STARTUP mode with the given clock and
// window sizes, and wires up a pacer driven by its own bandwidth estimate.
// NOTE(review): the initialMaxCongestionWindow struct field is left at its
// zero value here (only recoveryWindow receives the argument) — confirm this
// matches the upstream intent.
func NewBBRSender(
	clock Clock,
	initialMaxDatagramSize,
	initialCongestionWindow,
	initialMaxCongestionWindow congestion.ByteCount,
) *bbrSender {
	b := &bbrSender{
		mode:                      STARTUP,
		clock:                     clock,
		sampler:                   NewBandwidthSampler(),
		maxBandwidth:              NewWindowedFilter(int64(BandwidthWindowSize), MaxFilter),
		maxAckHeight:              NewWindowedFilter(int64(BandwidthWindowSize), MaxFilter),
		congestionWindow:          initialCongestionWindow,
		initialCongestionWindow:   initialCongestionWindow,
		highGain:                  DefaultHighGain,
		highCwndGain:              DefaultHighGain,
		drainGain:                 1.0 / DefaultHighGain,
		pacingGain:                1.0,
		congestionWindowGain:      1.0,
		congestionWindowGainConst: DefaultCongestionWindowGainConst,
		numStartupRtts:            RoundTripsWithoutGrowthBeforeExitingStartup,
		recoveryState:             NOT_IN_RECOVERY,
		recoveryWindow:            initialMaxCongestionWindow,
		minRttSinceLastProbeRtt:   InfiniteRTT,
		maxDatagramSize:           initialMaxDatagramSize,
	}
	b.pacer = newPacer(b.BandwidthEstimate)
	return b
}

// maxCongestionWindow returns the absolute upper bound of the congestion
// window, scaled by the current max datagram size.
func (b *bbrSender) maxCongestionWindow() congestion.ByteCount {
	return b.maxDatagramSize * DefaultBBRMaxCongestionWindow
}

// minCongestionWindow returns the lower bound of the congestion window,
// scaled by the current max datagram size.
func (b *bbrSender) minCongestionWindow() congestion.ByteCount {
	return b.maxDatagramSize * b.initialCongestionWindow
}

func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
	b.rttStats = provider
}

// GetBytesInFlight returns the in-flight byte count most recently reported to
// this sender.
func (b *bbrSender) GetBytesInFlight() congestion.ByteCount {
	return b.bytesInFlight
}

// TimeUntilSend returns when the next packet should be sent.
func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
	b.bytesInFlight = bytesInFlight
	return b.pacer.TimeUntilSend()
}

// HasPacingBudget reports whether the pacer has accumulated enough budget to
// send at least one full-size datagram at time now.
func (b *bbrSender) HasPacingBudget(now time.Time) bool {
	return b.pacer.Budget(now) >= b.maxDatagramSize
}

// SetMaxDatagramSize updates the maximum datagram size. The size may only
// grow; if the congestion window was pinned at its minimum it is re-derived
// from the new size.
func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) {
	if s < b.maxDatagramSize {
		panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s))
	}
	cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow()
	b.maxDatagramSize = s
	if cwndIsMinCwnd {
		// Keep the window at the (now larger) minimum.
		b.congestionWindow = b.minCongestionWindow()
	}
	b.pacer.SetMaxDatagramSize(s)
}

// OnPacketSent feeds a newly sent packet into the pacer and the bandwidth
// sampler, and tracks quiescence exit and the ack-aggregation epoch start.
func (b *bbrSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool) {
	b.pacer.SentPacket(sentTime, bytes)
	b.lastSendPacket = packetNumber

	b.bytesInFlight = bytesInFlight
	if bytesInFlight == 0 && b.sampler.isAppLimited {
		b.exitingQuiescence = true
	}

	if b.aggregationEpochStartTime.IsZero() {
		b.aggregationEpochStartTime = sentTime
	}

	b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable)
}

// CanSend reports whether another packet may be sent given the current
// congestion window.
func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool {
	b.bytesInFlight = bytesInFlight
	return bytesInFlight < b.GetCongestionWindow()
}

// GetCongestionWindow returns the effective congestion window for the current
// mode: the PROBE_RTT window, the recovery-limited window, or the normal one.
func (b *bbrSender) GetCongestionWindow() congestion.ByteCount {
	if b.mode == PROBE_RTT {
		return b.ProbeRttCongestionWindow()
	}

	if b.InRecovery() && !(b.rateBasedStartup && b.mode == STARTUP) {
		return minByteCount(b.congestionWindow, b.recoveryWindow)
	}

	return b.congestionWindow
}

// MaybeExitSlowStart is a no-op; BBR leaves STARTUP via its own state machine.
func (b *bbrSender) MaybeExitSlowStart() {
}

// OnPacketAcked updates the BBR model with a newly acknowledged packet and
// then recalculates the pacing rate and congestion window.
func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, priorInFlight congestion.ByteCount, eventTime time.Time) {
	totalBytesAckedBefore := b.sampler.totalBytesAcked
	isRoundStart, minRttExpired := false, false
	lastAckedPacket := number

	isRoundStart = b.UpdateRoundTripCounter(lastAckedPacket)
	minRttExpired = b.UpdateBandwidthAndMinRtt(eventTime, number, ackedBytes)
	b.UpdateRecoveryState(false, isRoundStart)
	bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
	excessAcked := b.UpdateAckAggregationBytes(eventTime, bytesAcked)

	// Handle logic specific to STARTUP and DRAIN modes.
	if isRoundStart && !b.isAtFullBandwidth {
		b.CheckIfFullBandwidthReached()
	}
	b.MaybeExitStartupOrDrain(eventTime)

	// Handle logic specific to PROBE_RTT.
	b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)

	// After the model is updated, recalculate the pacing rate and congestion
	// window.
	b.CalculatePacingRate()
	b.CalculateCongestionWindow(bytesAcked, excessAcked)
	b.CalculateRecoveryWindow(bytesAcked, congestion.ByteCount(0))
}

// OnPacketLost updates the BBR model with a lost packet and recalculates the
// pacing rate and congestion/recovery windows.
// NOTE(review): isRoundStart/minRttExpired stay false here (no ack occurred),
// and eventTime is taken from time.Now() rather than the injected clock —
// confirm both match upstream behavior.
func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount, priorInFlight congestion.ByteCount) {
	eventTime := time.Now()
	totalBytesAckedBefore := b.sampler.totalBytesAcked
	isRoundStart, minRttExpired := false, false

	b.DiscardLostPackets(number, lostBytes)

	// Input the new data into the BBR model of the connection.
	var excessAcked congestion.ByteCount

	// Handle logic specific to PROBE_BW mode.
	if b.mode == PROBE_BW {
		b.UpdateGainCyclePhase(time.Now(), priorInFlight, true)
	}

	// Handle logic specific to STARTUP and DRAIN modes.
	b.MaybeExitStartupOrDrain(eventTime)

	// Handle logic specific to PROBE_RTT.
	b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)

	// Calculate number of packets acked and lost.
	bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
	bytesLost := lostBytes

	// After the model is updated, recalculate the pacing rate and congestion
	// window.
	b.CalculatePacingRate()
	b.CalculateCongestionWindow(bytesAcked, excessAcked)
	b.CalculateRecoveryWindow(bytesAcked, bytesLost)
}

//func (b *bbrSender) OnCongestionEvent(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets, lostPackets []*congestion.Packet) {
//	totalBytesAckedBefore := b.sampler.totalBytesAcked
//	isRoundStart, minRttExpired := false, false
//
//	if lostPackets != nil {
//		b.DiscardLostPackets(lostPackets)
//	}
//
//	// Input the new data into the BBR model of the connection.
//	var excessAcked congestion.ByteCount
//	if len(ackedPackets) > 0 {
//		lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber
//		isRoundStart = b.UpdateRoundTripCounter(lastAckedPacket)
//		minRttExpired = b.UpdateBandwidthAndMinRtt(eventTime, ackedPackets)
//		b.UpdateRecoveryState(lastAckedPacket, len(lostPackets) > 0, isRoundStart)
//		bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
//		excessAcked = b.UpdateAckAggregationBytes(eventTime, bytesAcked)
//	}
//
//	// Handle logic specific to PROBE_BW mode.
//	if b.mode == PROBE_BW {
//		b.UpdateGainCyclePhase(eventTime, priorInFlight, len(lostPackets) > 0)
//	}
//
//	// Handle logic specific to STARTUP and DRAIN modes.
//	if isRoundStart && !b.isAtFullBandwidth {
//		b.CheckIfFullBandwidthReached()
//	}
//	b.MaybeExitStartupOrDrain(eventTime)
//
//	// Handle logic specific to PROBE_RTT.
//	b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)
//
//	// Calculate number of packets acked and lost.
//	bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
//	bytesLost := congestion.ByteCount(0)
//	for _, packet := range lostPackets {
//		bytesLost += packet.Length
//	}
//
//	// After the model is updated, recalculate the pacing rate and congestion
//	// window.
+// b.CalculatePacingRate() +// b.CalculateCongestionWindow(bytesAcked, excessAcked) +// b.CalculateRecoveryWindow(bytesAcked, bytesLost) +//} + +//func (b *bbrSender) SetNumEmulatedConnections(n int) { +// +//} + +func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) { +} + +//func (b *bbrSender) OnConnectionMigration() { +// +//} + +//// Experiments +//func (b *bbrSender) SetSlowStartLargeReduction(enabled bool) { +// +//} + +//func (b *bbrSender) BandwidthEstimate() Bandwidth { +// return Bandwidth(b.maxBandwidth.GetBest()) +//} + +// BandwidthEstimate returns the current bandwidth estimate +func (b *bbrSender) BandwidthEstimate() Bandwidth { + if b.rttStats == nil { + return infBandwidth + } + srtt := b.rttStats.SmoothedRTT() + if srtt == 0 { + // If we haven't measured an rtt, the bandwidth estimate is unknown. + return infBandwidth + } + return BandwidthFromDelta(b.GetCongestionWindow(), srtt) +} + +//func (b *bbrSender) HybridSlowStart() *HybridSlowStart { +// return nil +//} + +//func (b *bbrSender) SlowstartThreshold() congestion.ByteCount { +// return 0 +//} + +//func (b *bbrSender) RenoBeta() float32 { +// return 0.0 +//} + +func (b *bbrSender) InRecovery() bool { + return b.recoveryState != NOT_IN_RECOVERY +} + +func (b *bbrSender) InSlowStart() bool { + return b.mode == STARTUP +} + +//func (b *bbrSender) ShouldSendProbingPacket() bool { +// if b.pacingGain <= 1 { +// return false +// } +// // TODO(b/77975811): If the pipe is highly under-utilized, consider not +// // sending a probing transmission, because the extra bandwidth is not needed. +// // If flexible_app_limited is enabled, check if the pipe is sufficiently full. +// if b.flexibleAppLimited { +// return !b.IsPipeSufficientlyFull() +// } else { +// return true +// } +//} + +//func (b *bbrSender) IsPipeSufficientlyFull() bool { +// // See if we need more bytes in flight to see more bandwidth. 
+// if b.mode == STARTUP { +// // STARTUP exits if it doesn't observe a 25% bandwidth increase, so the CWND +// // must be more than 25% above the target. +// return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(1.5) +// } +// if b.pacingGain > 1 { +// // Super-unity PROBE_BW doesn't exit until 1.25 * BDP is achieved. +// return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(b.pacingGain) +// } +// // If bytes_in_flight are above the target congestion window, it should be +// // possible to observe the same or more bandwidth if it's available. +// return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(1.1) +//} + +//func (b *bbrSender) SetFromConfig() { +// // TODO: not impl. +//} + +func (b *bbrSender) UpdateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool { + if b.currentRoundTripEnd == 0 || lastAckedPacket > b.currentRoundTripEnd { + b.currentRoundTripEnd = lastAckedPacket + b.roundTripCount++ + // if b.rttStats != nil && b.InSlowStart() { + // TODO: ++stats_->slowstart_num_rtts; + // } + return true + } + return false +} + +func (b *bbrSender) UpdateBandwidthAndMinRtt(now time.Time, number congestion.PacketNumber, ackedBytes congestion.ByteCount) bool { + sampleMinRtt := InfiniteRTT + + if !b.alwaysGetBwSampleWhenAcked && ackedBytes == 0 { + // Skip acked packets with 0 in flight bytes when updating bandwidth. + return false + } + bandwidthSample := b.sampler.OnPacketAcked(now, number) + if b.alwaysGetBwSampleWhenAcked && !bandwidthSample.stateAtSend.isValid { + // From the sampler's perspective, the packet has never been sent, or the + // packet has been acked or marked as lost previously. 
+ return false + } + b.lastSampleIsAppLimited = bandwidthSample.stateAtSend.isAppLimited + // has_non_app_limited_sample_ |= + // !bandwidth_sample.state_at_send.is_app_limited; + if !bandwidthSample.stateAtSend.isAppLimited { + b.hasNoAppLimitedSample = true + } + if bandwidthSample.rtt > 0 { + sampleMinRtt = minRtt(sampleMinRtt, bandwidthSample.rtt) + } + if !bandwidthSample.stateAtSend.isAppLimited || bandwidthSample.bandwidth > b.BandwidthEstimate() { + b.maxBandwidth.Update(int64(bandwidthSample.bandwidth), b.roundTripCount) + } + + // If none of the RTT samples are valid, return immediately. + if sampleMinRtt == InfiniteRTT { + return false + } + + b.minRttSinceLastProbeRtt = minRtt(b.minRttSinceLastProbeRtt, sampleMinRtt) + // Do not expire min_rtt if none was ever available. + minRttExpired := b.minRtt > 0 && (now.After(b.minRttTimestamp.Add(MinRttExpiry))) + if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 { + if minRttExpired && b.ShouldExtendMinRttExpiry() { + minRttExpired = false + } else { + b.minRtt = sampleMinRtt + } + b.minRttTimestamp = now + // Reset since_last_probe_rtt fields. + b.minRttSinceLastProbeRtt = InfiniteRTT + b.appLimitedSinceLastProbeRtt = false + } + + return minRttExpired +} + +func (b *bbrSender) ShouldExtendMinRttExpiry() bool { + if b.probeRttDisabledIfAppLimited && b.appLimitedSinceLastProbeRtt { + // Extend the current min_rtt if we've been app limited recently. + return true + } + + minRttIncreasedSinceLastProbe := b.minRttSinceLastProbeRtt > time.Duration(float64(b.minRtt)*SimilarMinRttThreshold) + if b.probeRttSkippedIfSimilarRtt && b.appLimitedSinceLastProbeRtt && !minRttIncreasedSinceLastProbe { + // Extend the current min_rtt if we've been app limited recently and an rtt + // has been measured in that time that's less than 12.5% more than the + // current min_rtt. 
+ return true + } + + return false +} + +func (b *bbrSender) DiscardLostPackets(number congestion.PacketNumber, lostBytes congestion.ByteCount) { + b.sampler.OnPacketLost(number) + if b.mode == STARTUP { + // if b.rttStats != nil { + // TODO: slow start. + // } + if b.startupRateReductionMultiplier != 0 { + b.startupBytesLost += lostBytes + } + } +} + +func (b *bbrSender) UpdateRecoveryState(hasLosses, isRoundStart bool) { + // Exit recovery when there are no losses for a round. + if !hasLosses { + b.endRecoveryAt = b.lastSendPacket + } + switch b.recoveryState { + case NOT_IN_RECOVERY: + // Enter conservation on the first loss. + if hasLosses { + b.recoveryState = CONSERVATION + // This will cause the |recovery_window_| to be set to the correct + // value in CalculateRecoveryWindow(). + b.recoveryWindow = 0 + // Since the conservation phase is meant to be lasting for a whole + // round, extend the current round as if it were started right now. + b.currentRoundTripEnd = b.lastSendPacket + if false && b.lastSampleIsAppLimited { + b.isAppLimitedRecovery = true + } + } + case CONSERVATION: + if isRoundStart { + b.recoveryState = GROWTH + } + fallthrough + case GROWTH: + // Exit recovery if appropriate. + if !hasLosses && b.lastSendPacket > b.endRecoveryAt { + b.recoveryState = NOT_IN_RECOVERY + b.isAppLimitedRecovery = false + } + } + + if b.recoveryState != NOT_IN_RECOVERY && b.isAppLimitedRecovery { + b.sampler.OnAppLimited() + } +} + +func (b *bbrSender) UpdateAckAggregationBytes(ackTime time.Time, ackedBytes congestion.ByteCount) congestion.ByteCount { + // Compute how many bytes are expected to be delivered, assuming max bandwidth + // is correct. + expectedAckedBytes := congestion.ByteCount(b.maxBandwidth.GetBest()) * + congestion.ByteCount((ackTime.Sub(b.aggregationEpochStartTime))) + // Reset the current aggregation epoch as soon as the ack arrival rate is less + // than or equal to the max bandwidth. 
+ if b.aggregationEpochBytes <= expectedAckedBytes { + // Reset to start measuring a new aggregation epoch. + b.aggregationEpochBytes = ackedBytes + b.aggregationEpochStartTime = ackTime + return 0 + } + // Compute how many extra bytes were delivered vs max bandwidth. + // Include the bytes most recently acknowledged to account for stretch acks. + b.aggregationEpochBytes += ackedBytes + b.maxAckHeight.Update(int64(b.aggregationEpochBytes-expectedAckedBytes), b.roundTripCount) + return b.aggregationEpochBytes - expectedAckedBytes +} + +func (b *bbrSender) UpdateGainCyclePhase(now time.Time, priorInFlight congestion.ByteCount, hasLosses bool) { + bytesInFlight := b.GetBytesInFlight() + // In most cases, the cycle is advanced after an RTT passes. + shouldAdvanceGainCycling := now.Sub(b.lastCycleStart) > b.GetMinRtt() + + // If the pacing gain is above 1.0, the connection is trying to probe the + // bandwidth by increasing the number of bytes in flight to at least + // pacing_gain * BDP. Make sure that it actually reaches the target, as long + // as there are no losses suggesting that the buffers are not able to hold + // that much. + if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.GetTargetCongestionWindow(b.pacingGain) { + shouldAdvanceGainCycling = false + } + // If pacing gain is below 1.0, the connection is trying to drain the extra + // queue which could have been incurred by probing prior to it. If the number + // of bytes in flight falls down to the estimated BDP value earlier, conclude + // that the queue has been successfully drained and exit this cycle early. + if b.pacingGain < 1.0 && bytesInFlight <= b.GetTargetCongestionWindow(1.0) { + shouldAdvanceGainCycling = true + } + + if shouldAdvanceGainCycling { + b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % GainCycleLength + b.lastCycleStart = now + // Stay in low gain mode until the target BDP is hit. + // Low gain mode will be exited immediately when the target BDP is achieved. 
+ if b.drainToTarget && b.pacingGain < 1.0 && PacingGain[b.cycleCurrentOffset] == 1.0 && + bytesInFlight > b.GetTargetCongestionWindow(1.0) { + return + } + b.pacingGain = PacingGain[b.cycleCurrentOffset] + } +} + +func (b *bbrSender) GetTargetCongestionWindow(gain float64) congestion.ByteCount { + bdp := congestion.ByteCount(b.GetMinRtt()) * congestion.ByteCount(b.BandwidthEstimate()) + congestionWindow := congestion.ByteCount(gain * float64(bdp)) + + // BDP estimate will be zero if no bandwidth samples are available yet. + if congestionWindow == 0 { + congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow)) + } + + return maxByteCount(congestionWindow, b.minCongestionWindow()) +} + +func (b *bbrSender) CheckIfFullBandwidthReached() { + if b.lastSampleIsAppLimited { + return + } + + target := Bandwidth(float64(b.bandwidthAtLastRound) * StartupGrowthTarget) + if b.BandwidthEstimate() >= target { + b.bandwidthAtLastRound = b.BandwidthEstimate() + b.roundsWithoutBandwidthGain = 0 + if b.expireAckAggregationInStartup { + // Expire old excess delivery measurements now that bandwidth increased. + b.maxAckHeight.Reset(0, b.roundTripCount) + } + return + } + b.roundsWithoutBandwidthGain++ + if b.roundsWithoutBandwidthGain >= b.numStartupRtts || (b.exitStartupOnLoss && b.InRecovery()) { + b.isAtFullBandwidth = true + } +} + +func (b *bbrSender) MaybeExitStartupOrDrain(now time.Time) { + if b.mode == STARTUP && b.isAtFullBandwidth { + b.OnExitStartup(now) + b.mode = DRAIN + b.pacingGain = b.drainGain + b.congestionWindowGain = b.highCwndGain + } + if b.mode == DRAIN && b.GetBytesInFlight() <= b.GetTargetCongestionWindow(1) { + b.EnterProbeBandwidthMode(now) + } +} + +func (b *bbrSender) EnterProbeBandwidthMode(now time.Time) { + b.mode = PROBE_BW + b.congestionWindowGain = b.congestionWindowGainConst + + // Pick a random offset for the gain cycle out of {0, 2..7} range. 
1 is + // excluded because in that case increased gain and decreased gain would not + // follow each other. + b.cycleCurrentOffset = rand.Int() % (GainCycleLength - 1) + if b.cycleCurrentOffset >= 1 { + b.cycleCurrentOffset += 1 + } + + b.lastCycleStart = now + b.pacingGain = PacingGain[b.cycleCurrentOffset] +} + +func (b *bbrSender) MaybeEnterOrExitProbeRtt(now time.Time, isRoundStart, minRttExpired bool) { + if minRttExpired && !b.exitingQuiescence && b.mode != PROBE_RTT { + if b.InSlowStart() { + b.OnExitStartup(now) + } + b.mode = PROBE_RTT + b.pacingGain = 1.0 + // Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight| + // is at the target small value. + b.exitProbeRttAt = time.Time{} + } + + if b.mode == PROBE_RTT { + b.sampler.OnAppLimited() + if b.exitProbeRttAt.IsZero() { + // If the window has reached the appropriate size, schedule exiting + // PROBE_RTT. The CWND during PROBE_RTT is kMinimumCongestionWindow, but + // we allow an extra packet since QUIC checks CWND before sending a + // packet. + if b.GetBytesInFlight() < b.ProbeRttCongestionWindow()+b.maxDatagramSize { + b.exitProbeRttAt = now.Add(ProbeRttTime) + b.probeRttRoundPassed = false + } + } else { + if isRoundStart { + b.probeRttRoundPassed = true + } + if !now.Before(b.exitProbeRttAt) && b.probeRttRoundPassed { + b.minRttTimestamp = now + if !b.isAtFullBandwidth { + b.EnterStartupMode(now) + } else { + b.EnterProbeBandwidthMode(now) + } + } + } + } + b.exitingQuiescence = false +} + +func (b *bbrSender) ProbeRttCongestionWindow() congestion.ByteCount { + if b.probeRttBasedOnBdp { + return b.GetTargetCongestionWindow(ModerateProbeRttMultiplier) + } else { + return b.minCongestionWindow() + } +} + +func (b *bbrSender) EnterStartupMode(now time.Time) { + // if b.rttStats != nil { + // TODO: slow start. 
+ // } + b.mode = STARTUP + b.pacingGain = b.highGain + b.congestionWindowGain = b.highCwndGain +} + +func (b *bbrSender) OnExitStartup(now time.Time) { + if b.rttStats == nil { + return + } + // TODO: slow start. +} + +func (b *bbrSender) CalculatePacingRate() { + if b.BandwidthEstimate() == 0 { + return + } + + targetRate := Bandwidth(b.pacingGain * float64(b.BandwidthEstimate())) + if b.isAtFullBandwidth { + b.pacingRate = targetRate + return + } + + // Pace at the rate of initial_window / RTT as soon as RTT measurements are + // available. + if b.pacingRate == 0 && b.rttStats.MinRTT() > 0 { + b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT()) + return + } + // Slow the pacing rate in STARTUP once loss has ever been detected. + hasEverDetectedLoss := b.endRecoveryAt > 0 + if b.slowerStartup && hasEverDetectedLoss && b.hasNoAppLimitedSample { + b.pacingRate = Bandwidth(StartupAfterLossGain * float64(b.BandwidthEstimate())) + return + } + + // Slow the pacing rate in STARTUP by the bytes_lost / CWND. + if b.startupRateReductionMultiplier != 0 && hasEverDetectedLoss && b.hasNoAppLimitedSample { + b.pacingRate = Bandwidth((1.0 - (float64(b.startupBytesLost) * float64(b.startupRateReductionMultiplier) / float64(b.congestionWindow))) * float64(targetRate)) + // Ensure the pacing rate doesn't drop below the startup growth target times + // the bandwidth estimate. + b.pacingRate = maxBandwidth(b.pacingRate, Bandwidth(StartupGrowthTarget*float64(b.BandwidthEstimate()))) + return + } + + // Do not decrease the pacing rate during startup. + b.pacingRate = maxBandwidth(b.pacingRate, targetRate) +} + +func (b *bbrSender) CalculateCongestionWindow(ackedBytes, excessAcked congestion.ByteCount) { + if b.mode == PROBE_RTT { + return + } + + targetWindow := b.GetTargetCongestionWindow(b.congestionWindowGain) + if b.isAtFullBandwidth { + // Add the max recently measured ack aggregation to CWND. 
+ targetWindow += congestion.ByteCount(b.maxAckHeight.GetBest()) + } else if b.enableAckAggregationDuringStartup { + // Add the most recent excess acked. Because CWND never decreases in + // STARTUP, this will automatically create a very localized max filter. + targetWindow += excessAcked + } + + // Instead of immediately setting the target CWND as the new one, BBR grows + // the CWND towards |target_window| by only increasing it |bytes_acked| at a + // time. + addBytesAcked := true || !b.InRecovery() + if b.isAtFullBandwidth { + b.congestionWindow = minByteCount(targetWindow, b.congestionWindow+ackedBytes) + } else if addBytesAcked && (b.congestionWindow < targetWindow || b.sampler.totalBytesAcked < b.initialCongestionWindow) { + // If the connection is not yet out of startup phase, do not decrease the + // window. + b.congestionWindow += ackedBytes + } + + // Enforce the limits on the congestion window. + b.congestionWindow = maxByteCount(b.congestionWindow, b.minCongestionWindow()) + b.congestionWindow = minByteCount(b.congestionWindow, b.maxCongestionWindow()) +} + +func (b *bbrSender) CalculateRecoveryWindow(ackedBytes, lostBytes congestion.ByteCount) { + if b.rateBasedStartup && b.mode == STARTUP { + return + } + + if b.recoveryState == NOT_IN_RECOVERY { + return + } + + // Set up the initial recovery window. + if b.recoveryWindow == 0 { + b.recoveryWindow = maxByteCount(b.GetBytesInFlight()+ackedBytes, b.minCongestionWindow()) + return + } + + // Remove losses from the recovery window, while accounting for a potential + // integer underflow. + if b.recoveryWindow >= lostBytes { + b.recoveryWindow -= lostBytes + } else { + b.recoveryWindow = congestion.ByteCount(b.maxDatagramSize) + } + // In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH, + // release additional |bytes_acked| to achieve a slow-start-like behavior. + if b.recoveryState == GROWTH { + b.recoveryWindow += ackedBytes + } + // Sanity checks. 
Ensure that we always allow to send at least an MSS or + // |bytes_acked| in response, whichever is larger. + b.recoveryWindow = maxByteCount(b.recoveryWindow, b.GetBytesInFlight()+ackedBytes) + b.recoveryWindow = maxByteCount(b.recoveryWindow, b.minCongestionWindow()) +} + +var _ congestion.CongestionControl = (*bbrSender)(nil) + +func (b *bbrSender) GetMinRtt() time.Duration { + if b.minRtt > 0 { + return b.minRtt + } else { + return InitialRtt + } +} + +func minRtt(a, b time.Duration) time.Duration { + if a < b { + return a + } else { + return b + } +} + +func minBandwidth(a, b Bandwidth) Bandwidth { + if a < b { + return a + } else { + return b + } +} + +func maxBandwidth(a, b Bandwidth) Bandwidth { + if a > b { + return a + } else { + return b + } +} + +func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount { + if a > b { + return a + } else { + return b + } +} + +func minByteCount(a, b congestion.ByteCount) congestion.ByteCount { + if a < b { + return a + } else { + return b + } +} + +var InfiniteRTT = time.Duration(math.MaxInt64) diff --git a/transport/tuic/congestion/clock.go b/transport/tuic/congestion/clock.go new file mode 100644 index 00000000..dc3ccdc5 --- /dev/null +++ b/transport/tuic/congestion/clock.go @@ -0,0 +1,20 @@ +package congestion + +import "time" + +// A Clock returns the current time +type Clock interface { + Now() time.Time +} + +// DefaultClock implements the Clock interface using the Go stdlib clock. 
+type DefaultClock struct { + TimeFunc func() time.Time +} + +var _ Clock = DefaultClock{} + +// Now gets the current time +func (c DefaultClock) Now() time.Time { + return c.TimeFunc() +} diff --git a/transport/tuic/congestion/cubic.go b/transport/tuic/congestion/cubic.go new file mode 100644 index 00000000..d437c540 --- /dev/null +++ b/transport/tuic/congestion/cubic.go @@ -0,0 +1,213 @@ +package congestion + +import ( + "math" + "time" + + "github.com/sagernet/quic-go/congestion" +) + +// This cubic implementation is based on the one found in Chromium's QUIC +// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}. + +// Constants based on TCP defaults. +// The following constants are in 2^10 fractions of a second instead of ms to +// allow a 10 shift right to divide. + +// 1024*1024^3 (first 1024 is from 0.100^3) +// where 0.100 is 100 ms which is the scaling round trip time. +const ( + cubeScale = 40 + cubeCongestionWindowScale = 410 + cubeFactor congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize + // TODO: when re-enabling cubic, make sure to use the actual packet size here + maxDatagramSize = congestion.ByteCount(InitialPacketSizeIPv4) +) + +const defaultNumConnections = 1 + +// Default Cubic backoff factor +const beta float32 = 0.7 + +// Additional backoff factor when loss occurs in the concave part of the Cubic +// curve. This additional backoff factor is expected to give up bandwidth to +// new concurrent flows and speed up convergence. +const betaLastMax float32 = 0.85 + +// Cubic implements the cubic algorithm from TCP +type Cubic struct { + clock Clock + + // Number of connections to simulate. + numConnections int + + // Time when this cycle started, after last loss event. + epoch time.Time + + // Max congestion window used just before last loss event. + // Note: to improve fairness to other streams an additional back off is + // applied to this value if the new value is below our latest value. 
+ lastMaxCongestionWindow congestion.ByteCount + + // Number of acked bytes since the cycle started (epoch). + ackedBytesCount congestion.ByteCount + + // TCP Reno equivalent congestion window in packets. + estimatedTCPcongestionWindow congestion.ByteCount + + // Origin point of cubic function. + originPointCongestionWindow congestion.ByteCount + + // Time to origin point of cubic function in 2^10 fractions of a second. + timeToOriginPoint uint32 + + // Last congestion window in packets computed by cubic function. + lastTargetCongestionWindow congestion.ByteCount +} + +// NewCubic returns a new Cubic instance +func NewCubic(clock Clock) *Cubic { + c := &Cubic{ + clock: clock, + numConnections: defaultNumConnections, + } + c.Reset() + return c +} + +// Reset is called after a timeout to reset the cubic state +func (c *Cubic) Reset() { + c.epoch = time.Time{} + c.lastMaxCongestionWindow = 0 + c.ackedBytesCount = 0 + c.estimatedTCPcongestionWindow = 0 + c.originPointCongestionWindow = 0 + c.timeToOriginPoint = 0 + c.lastTargetCongestionWindow = 0 +} + +func (c *Cubic) alpha() float32 { + // TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that + // beta here is a cwnd multiplier, and is equal to 1-beta from the paper. + // We derive the equivalent alpha for an N-connection emulation as: + b := c.beta() + return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b) +} + +func (c *Cubic) beta() float32 { + // kNConnectionBeta is the backoff factor after loss for our N-connection + // emulation, which emulates the effective backoff of an ensemble of N + // TCP-Reno connections on a single loss event. 
The effective multiplier is + // computed as: + return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) +} + +func (c *Cubic) betaLastMax() float32 { + // betaLastMax is the additional backoff factor after loss for our + // N-connection emulation, which emulates the additional backoff of + // an ensemble of N TCP-Reno connections on a single loss event. The + // effective multiplier is computed as: + return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) +} + +// OnApplicationLimited is called on ack arrival when sender is unable to use +// the available congestion window. Resets Cubic state during quiescence. +func (c *Cubic) OnApplicationLimited() { + // When sender is not using the available congestion window, the window does + // not grow. But to be RTT-independent, Cubic assumes that the sender has been + // using the entire window during the time since the beginning of the current + // "epoch" (the end of the last loss recovery period). Since + // application-limited periods break this assumption, we reset the epoch when + // in such a period. This reset effectively freezes congestion window growth + // through application-limited periods and allows Cubic growth to continue + // when the entire window is being used. + c.epoch = time.Time{} +} + +// CongestionWindowAfterPacketLoss computes a new congestion window to use after +// a loss event. Returns the new congestion window in packets. The new +// congestion window is a multiplicative decrease of our current window. +func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount { + if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow { + // We never reached the old max, so assume we are competing with another + // flow. Use our extra back off factor to allow the other flow to go up. 
+ c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) + } else { + c.lastMaxCongestionWindow = currentCongestionWindow + } + c.epoch = time.Time{} // Reset time. + return congestion.ByteCount(float32(currentCongestionWindow) * c.beta()) +} + +// CongestionWindowAfterAck computes a new congestion window to use after a received ACK. +// Returns the new congestion window in packets. The new congestion window +// follows a cubic function that depends on the time passed since last +// packet loss. +func (c *Cubic) CongestionWindowAfterAck( + ackedBytes congestion.ByteCount, + currentCongestionWindow congestion.ByteCount, + delayMin time.Duration, + eventTime time.Time, +) congestion.ByteCount { + c.ackedBytesCount += ackedBytes + + if c.epoch.IsZero() { + // First ACK after a loss event. + c.epoch = eventTime // Start of epoch. + c.ackedBytesCount = ackedBytes // Reset count. + // Reset estimated_tcp_congestion_window_ to be in sync with cubic. + c.estimatedTCPcongestionWindow = currentCongestionWindow + if c.lastMaxCongestionWindow <= currentCongestionWindow { + c.timeToOriginPoint = 0 + c.originPointCongestionWindow = currentCongestionWindow + } else { + c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) + c.originPointCongestionWindow = c.lastMaxCongestionWindow + } + } + + // Change the time unit from microseconds to 2^10 fractions per second. Take + // the round trip time in account. This is done to allow us to use shift as a + // divide operator. + elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000) + + // Right-shifts of negative, signed numbers have implementation-dependent + // behavior, so force the offset to be positive, as is done in the kernel. 
+ offset := int64(c.timeToOriginPoint) - elapsedTime + if offset < 0 { + offset = -offset + } + + deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale + var targetCongestionWindow congestion.ByteCount + if elapsedTime > int64(c.timeToOriginPoint) { + targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow + } else { + targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow + } + // Limit the CWND increase to half the acked bytes. + targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) + + // Increase the window by approximately Alpha * 1 MSS of bytes every + // time we ack an estimated tcp window of bytes. For small + // congestion windows (less than 25), the formula below will + // increase slightly slower than linearly per estimated tcp window + // of bytes. + c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow)) + c.ackedBytesCount = 0 + + // We have a new cubic congestion window. + c.lastTargetCongestionWindow = targetCongestionWindow + + // Compute target congestion_window based on cubic target and estimated TCP + // congestion_window, use highest (fastest). 
+ if targetCongestionWindow < c.estimatedTCPcongestionWindow { + targetCongestionWindow = c.estimatedTCPcongestionWindow + } + return targetCongestionWindow +} + +// SetNumConnections sets the number of emulated connections +func (c *Cubic) SetNumConnections(n int) { + c.numConnections = n +} diff --git a/transport/tuic/congestion/cubic_sender.go b/transport/tuic/congestion/cubic_sender.go new file mode 100644 index 00000000..fc97d17a --- /dev/null +++ b/transport/tuic/congestion/cubic_sender.go @@ -0,0 +1,318 @@ +package congestion + +import ( + "fmt" + "time" + + "github.com/sagernet/quic-go/congestion" + "github.com/sagernet/quic-go/logging" +) + +const ( + maxBurstPackets = 3 + renoBeta = 0.7 // Reno backoff factor. + minCongestionWindowPackets = 2 + initialCongestionWindow = 32 +) + +const ( + InvalidPacketNumber congestion.PacketNumber = -1 + MaxCongestionWindowPackets = 20000 + MaxByteCount = congestion.ByteCount(1<<62 - 1) +) + +type cubicSender struct { + hybridSlowStart HybridSlowStart + rttStats congestion.RTTStatsProvider + cubic *Cubic + pacer *pacer + clock Clock + + reno bool + + // Track the largest packet that has been sent. + largestSentPacketNumber congestion.PacketNumber + + // Track the largest packet that has been acked. + largestAckedPacketNumber congestion.PacketNumber + + // Track the largest packet number outstanding when a CWND cutback occurs. + largestSentAtLastCutback congestion.PacketNumber + + // Whether the last loss event caused us to exit slowstart. + // Used for stats collection of slowstartPacketsLost + lastCutbackExitedSlowstart bool + + // Congestion window in bytes. + congestionWindow congestion.ByteCount + + // Slow start congestion window in bytes, aka ssthresh. + slowStartThreshold congestion.ByteCount + + // ACK counter for the Reno implementation. 
+ numAckedPackets uint64 + + initialCongestionWindow congestion.ByteCount + initialMaxCongestionWindow congestion.ByteCount + + maxDatagramSize congestion.ByteCount + + lastState logging.CongestionState + tracer logging.ConnectionTracer +} + +var _ congestion.CongestionControl = &cubicSender{} + +// NewCubicSender makes a new cubic sender +func NewCubicSender( + clock Clock, + initialMaxDatagramSize congestion.ByteCount, + reno bool, + tracer logging.ConnectionTracer, +) *cubicSender { + return newCubicSender( + clock, + reno, + initialMaxDatagramSize, + initialCongestionWindow*initialMaxDatagramSize, + MaxCongestionWindowPackets*initialMaxDatagramSize, + tracer, + ) +} + +func newCubicSender( + clock Clock, + reno bool, + initialMaxDatagramSize, + initialCongestionWindow, + initialMaxCongestionWindow congestion.ByteCount, + tracer logging.ConnectionTracer, +) *cubicSender { + c := &cubicSender{ + largestSentPacketNumber: InvalidPacketNumber, + largestAckedPacketNumber: InvalidPacketNumber, + largestSentAtLastCutback: InvalidPacketNumber, + initialCongestionWindow: initialCongestionWindow, + initialMaxCongestionWindow: initialMaxCongestionWindow, + congestionWindow: initialCongestionWindow, + slowStartThreshold: MaxByteCount, + cubic: NewCubic(clock), + clock: clock, + reno: reno, + tracer: tracer, + maxDatagramSize: initialMaxDatagramSize, + } + c.pacer = newPacer(c.BandwidthEstimate) + if c.tracer != nil { + c.lastState = logging.CongestionStateSlowStart + c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) + } + return c +} + +func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { + c.rttStats = provider +} + +// TimeUntilSend returns when the next packet should be sent. 
+func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time { + return c.pacer.TimeUntilSend() +} + +func (c *cubicSender) HasPacingBudget(now time.Time) bool { + return c.pacer.Budget(now) >= c.maxDatagramSize +} + +func (c *cubicSender) maxCongestionWindow() congestion.ByteCount { + return c.maxDatagramSize * MaxCongestionWindowPackets +} + +func (c *cubicSender) minCongestionWindow() congestion.ByteCount { + return c.maxDatagramSize * minCongestionWindowPackets +} + +func (c *cubicSender) OnPacketSent( + sentTime time.Time, + _ congestion.ByteCount, + packetNumber congestion.PacketNumber, + bytes congestion.ByteCount, + isRetransmittable bool, +) { + c.pacer.SentPacket(sentTime, bytes) + if !isRetransmittable { + return + } + c.largestSentPacketNumber = packetNumber + c.hybridSlowStart.OnPacketSent(packetNumber) +} + +func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight < c.GetCongestionWindow() +} + +func (c *cubicSender) InRecovery() bool { + return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback +} + +func (c *cubicSender) InSlowStart() bool { + return c.GetCongestionWindow() < c.slowStartThreshold +} + +func (c *cubicSender) GetCongestionWindow() congestion.ByteCount { + return c.congestionWindow +} + +func (c *cubicSender) MaybeExitSlowStart() { + if c.InSlowStart() && + c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) { + // exit slow start + c.slowStartThreshold = c.congestionWindow + c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) + } +} + +func (c *cubicSender) OnPacketAcked( + ackedPacketNumber congestion.PacketNumber, + ackedBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, + eventTime time.Time, +) { + c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber) + if c.InRecovery() { + return + } + 
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime) + if c.InSlowStart() { + c.hybridSlowStart.OnPacketAcked(ackedPacketNumber) + } +} + +func (c *cubicSender) OnPacketLost(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { + // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets + // already sent should be treated as a single loss event, since it's expected. + if packetNumber <= c.largestSentAtLastCutback { + return + } + c.lastCutbackExitedSlowstart = c.InSlowStart() + c.maybeTraceStateChange(logging.CongestionStateRecovery) + + if c.reno { + c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta) + } else { + c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) + } + if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd { + c.congestionWindow = minCwnd + } + c.slowStartThreshold = c.congestionWindow + c.largestSentAtLastCutback = c.largestSentPacketNumber + // reset packet count from congestion avoidance mode. We start + // counting again when we're out of recovery. + c.numAckedPackets = 0 +} + +// Called when we receive an ack. Normal TCP tracks how many packets one ack +// represents, but quic has a separate ack for each packet. +func (c *cubicSender) maybeIncreaseCwnd( + _ congestion.PacketNumber, + ackedBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, + eventTime time.Time, +) { + // Do not increase the congestion window unless the sender is close to using + // the current window. + if !c.isCwndLimited(priorInFlight) { + c.cubic.OnApplicationLimited() + c.maybeTraceStateChange(logging.CongestionStateApplicationLimited) + return + } + if c.congestionWindow >= c.maxCongestionWindow() { + return + } + if c.InSlowStart() { + // TCP slow start, exponential growth, increase by one for each ACK. 
+ c.congestionWindow += c.maxDatagramSize + c.maybeTraceStateChange(logging.CongestionStateSlowStart) + return + } + // Congestion avoidance + c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) + if c.reno { + // Classic Reno congestion avoidance. + c.numAckedPackets++ + if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) { + c.congestionWindow += c.maxDatagramSize + c.numAckedPackets = 0 + } + } else { + c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) + } +} + +func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool { + congestionWindow := c.GetCongestionWindow() + if bytesInFlight >= congestionWindow { + return true + } + availableBytes := congestionWindow - bytesInFlight + slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2 + return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize +} + +// BandwidthEstimate returns the current bandwidth estimate +func (c *cubicSender) BandwidthEstimate() Bandwidth { + if c.rttStats == nil { + return infBandwidth + } + srtt := c.rttStats.SmoothedRTT() + if srtt == 0 { + // If we haven't measured an rtt, the bandwidth estimate is unknown. + return infBandwidth + } + return BandwidthFromDelta(c.GetCongestionWindow(), srtt) +} + +// OnRetransmissionTimeout is called on an retransmission timeout +func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { + c.largestSentAtLastCutback = InvalidPacketNumber + if !packetsRetransmitted { + return + } + c.hybridSlowStart.Restart() + c.cubic.Reset() + c.slowStartThreshold = c.congestionWindow / 2 + c.congestionWindow = c.minCongestionWindow() +} + +// OnConnectionMigration is called when the connection is migrated (?) 
+func (c *cubicSender) OnConnectionMigration() { + c.hybridSlowStart.Restart() + c.largestSentPacketNumber = InvalidPacketNumber + c.largestAckedPacketNumber = InvalidPacketNumber + c.largestSentAtLastCutback = InvalidPacketNumber + c.lastCutbackExitedSlowstart = false + c.cubic.Reset() + c.numAckedPackets = 0 + c.congestionWindow = c.initialCongestionWindow + c.slowStartThreshold = c.initialMaxCongestionWindow +} + +func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { + if c.tracer == nil || new == c.lastState { + return + } + c.tracer.UpdatedCongestionState(new) + c.lastState = new +} + +func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) { + if s < c.maxDatagramSize { + panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s)) + } + cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow() + c.maxDatagramSize = s + if cwndIsMinCwnd { + c.congestionWindow = c.minCongestionWindow() + } + c.pacer.SetMaxDatagramSize(s) +} diff --git a/transport/tuic/congestion/hybrid_slow_start.go b/transport/tuic/congestion/hybrid_slow_start.go new file mode 100644 index 00000000..eba8f7df --- /dev/null +++ b/transport/tuic/congestion/hybrid_slow_start.go @@ -0,0 +1,112 @@ +package congestion + +import ( + "time" + + "github.com/sagernet/quic-go/congestion" +) + +// Note(pwestin): the magic clamping numbers come from the original code in +// tcp_cubic.c. +const hybridStartLowWindow = congestion.ByteCount(16) + +// Number of delay samples for detecting the increase of delay. +const hybridStartMinSamples = uint32(8) + +// Exit slow start if the min rtt has increased by more than 1/8th. +const hybridStartDelayFactorExp = 3 // 2^3 = 8 +// The original paper specifies 2 and 8ms, but those have changed over time. 
+const ( + hybridStartDelayMinThresholdUs = int64(4000) + hybridStartDelayMaxThresholdUs = int64(16000) +) + +// HybridSlowStart implements the TCP hybrid slow start algorithm +type HybridSlowStart struct { + endPacketNumber congestion.PacketNumber + lastSentPacketNumber congestion.PacketNumber + started bool + currentMinRTT time.Duration + rttSampleCount uint32 + hystartFound bool +} + +// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase. +func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) { + s.endPacketNumber = lastSent + s.currentMinRTT = 0 + s.rttSampleCount = 0 + s.started = true +} + +// IsEndOfRound returns true if this ack is the last packet number of our current slow start round. +func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool { + return s.endPacketNumber < ack +} + +// ShouldExitSlowStart should be called on every new ack frame, since a new +// RTT measurement can be made then. +// rtt: the RTT for this ack packet. +// minRTT: is the lowest delay (RTT) we have seen during the session. +// congestionWindow: the congestion window in packets. +func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool { + if !s.started { + // Time to start the hybrid slow start. + s.StartReceiveRound(s.lastSentPacketNumber) + } + if s.hystartFound { + return true + } + // Second detection parameter - delay increase detection. + // Compare the minimum delay (s.currentMinRTT) of the current + // burst of packets relative to the minimum delay during the session. + // Note: we only look at the first few(8) packets in each burst, since we + // only want to compare the lowest RTT of the burst relative to previous + // bursts. 
+ s.rttSampleCount++ + if s.rttSampleCount <= hybridStartMinSamples { + if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT { + s.currentMinRTT = latestRTT + } + } + // We only need to check this once per round. + if s.rttSampleCount == hybridStartMinSamples { + // Divide minRTT by 8 to get a rtt increase threshold for exiting. + minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) + // Ensure the rtt threshold is never less than 2ms or more than 16ms. + minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) + minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond + + if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { + s.hystartFound = true + } + } + // Exit from slow start if the cwnd is greater than 16 and + // increasing delay is found. + return congestionWindow >= hybridStartLowWindow && s.hystartFound +} + +// OnPacketSent is called when a packet was sent +func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) { + s.lastSentPacketNumber = packetNumber +} + +// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end +// the round when the final packet of the burst is received and start it on +// the next incoming ack. 
+func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) { + if s.IsEndOfRound(ackedPacketNumber) { + s.started = false + } +} + +// Started returns true if started +func (s *HybridSlowStart) Started() bool { + return s.started +} + +// Restart the slow start phase +func (s *HybridSlowStart) Restart() { + s.started = false + s.hystartFound = false +} diff --git a/transport/tuic/congestion/minmax.go b/transport/tuic/congestion/minmax.go new file mode 100644 index 00000000..ed75072e --- /dev/null +++ b/transport/tuic/congestion/minmax.go @@ -0,0 +1,72 @@ +package congestion + +import ( + "math" + "time" + + "golang.org/x/exp/constraints" +) + +// InfDuration is a duration of infinite length +const InfDuration = time.Duration(math.MaxInt64) + +func Max[T constraints.Ordered](a, b T) T { + if a < b { + return b + } + return a +} + +func Min[T constraints.Ordered](a, b T) T { + if a < b { + return a + } + return b +} + +// MinNonZeroDuration return the minimum duration that's not zero. 
+func MinNonZeroDuration(a, b time.Duration) time.Duration { + if a == 0 { + return b + } + if b == 0 { + return a + } + return Min(a, b) +} + +// AbsDuration returns the absolute value of a time duration +func AbsDuration(d time.Duration) time.Duration { + if d >= 0 { + return d + } + return -d +} + +// MinTime returns the earlier time +func MinTime(a, b time.Time) time.Time { + if a.After(b) { + return b + } + return a +} + +// MinNonZeroTime returns the earlist time that is not time.Time{} +// If both a and b are time.Time{}, it returns time.Time{} +func MinNonZeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b + } + if b.IsZero() { + return a + } + return MinTime(a, b) +} + +// MaxTime returns the later time +func MaxTime(a, b time.Time) time.Time { + if a.After(b) { + return a + } + return b +} diff --git a/transport/tuic/congestion/pacer.go b/transport/tuic/congestion/pacer.go new file mode 100644 index 00000000..5d0f13f6 --- /dev/null +++ b/transport/tuic/congestion/pacer.go @@ -0,0 +1,81 @@ +package congestion + +import ( + "math" + "time" + + "github.com/sagernet/quic-go/congestion" +) + +const ( + initialMaxDatagramSize = congestion.ByteCount(1252) + MinPacingDelay = time.Millisecond + TimerGranularity = time.Millisecond + maxBurstSizePackets = 10 +) + +// The pacer implements a token bucket pacing algorithm. +type pacer struct { + budgetAtLastSent congestion.ByteCount + maxDatagramSize congestion.ByteCount + lastSentTime time.Time + getAdjustedBandwidth func() uint64 // in bytes/s +} + +func newPacer(getBandwidth func() Bandwidth) *pacer { + p := &pacer{ + maxDatagramSize: initialMaxDatagramSize, + getAdjustedBandwidth: func() uint64 { + // Bandwidth is in bits/s. We need the value in bytes/s. + bw := uint64(getBandwidth() / BytesPerSecond) + // Use a slightly higher value than the actual measured bandwidth. + // RTT variations then won't result in under-utilization of the congestion window. 
+ // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, + // provided the congestion window is fully utilized and acknowledgments arrive at regular intervals. + return bw * 5 / 4 + }, + } + p.budgetAtLastSent = p.maxBurstSize() + return p +} + +func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { + budget := p.Budget(sendTime) + if size > budget { + p.budgetAtLastSent = 0 + } else { + p.budgetAtLastSent = budget - size + } + p.lastSentTime = sendTime +} + +func (p *pacer) Budget(now time.Time) congestion.ByteCount { + if p.lastSentTime.IsZero() { + return p.maxBurstSize() + } + budget := p.budgetAtLastSent + (congestion.ByteCount(p.getAdjustedBandwidth())*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + return Min(p.maxBurstSize(), budget) +} + +func (p *pacer) maxBurstSize() congestion.ByteCount { + return Max( + congestion.ByteCount(uint64((MinPacingDelay+TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9, + maxBurstSizePackets*p.maxDatagramSize, + ) +} + +// TimeUntilSend returns when the next packet should be sent. +// It returns the zero value of time.Time if a packet can be sent immediately. +func (p *pacer) TimeUntilSend() time.Time { + if p.budgetAtLastSent >= p.maxDatagramSize { + return time.Time{} + } + return p.lastSentTime.Add(Max( + MinPacingDelay, + time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond, + )) +} + +func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) { + p.maxDatagramSize = s +} diff --git a/transport/tuic/congestion/windowed_filter.go b/transport/tuic/congestion/windowed_filter.go new file mode 100644 index 00000000..4da595b9 --- /dev/null +++ b/transport/tuic/congestion/windowed_filter.go @@ -0,0 +1,132 @@ +package congestion + +// WindowedFilter Use the following to construct a windowed filter object of type T. 
+// For example, a min filter using QuicTime as the time type: +// +// WindowedFilter, QuicTime, QuicTime::Delta> ObjectName; +// +// A max filter using 64-bit integers as the time type: +// +// WindowedFilter, uint64_t, int64_t> ObjectName; +// +// Specifically, this template takes four arguments: +// 1. T -- type of the measurement that is being filtered. +// 2. Compare -- MinFilter or MaxFilter, depending on the type of filter +// desired. +// 3. TimeT -- the type used to represent timestamps. +// 4. TimeDeltaT -- the type used to represent continuous time intervals between +// two timestamps. Has to be the type of (a - b) if both |a| and |b| are +// of type TimeT. +type WindowedFilter struct { + // Time length of window. + windowLength int64 + estimates []Sample + comparator func(int64, int64) bool +} + +type Sample struct { + sample int64 + time int64 +} + +// Compares two values and returns true if the first is greater than or equal +// to the second. +func MaxFilter(a, b int64) bool { + return a >= b +} + +// Compares two values and returns true if the first is less than or equal +// to the second. +func MinFilter(a, b int64) bool { + return a <= b +} + +func NewWindowedFilter(windowLength int64, comparator func(int64, int64) bool) *WindowedFilter { + return &WindowedFilter{ + windowLength: windowLength, + estimates: make([]Sample, 3), + comparator: comparator, + } +} + +// Changes the window length. Does not update any current samples. 
+func (f *WindowedFilter) SetWindowLength(windowLength int64) { + f.windowLength = windowLength +} + +func (f *WindowedFilter) GetBest() int64 { + return f.estimates[0].sample +} + +func (f *WindowedFilter) GetSecondBest() int64 { + return f.estimates[1].sample +} + +func (f *WindowedFilter) GetThirdBest() int64 { + return f.estimates[2].sample +} + +func (f *WindowedFilter) Update(sample int64, time int64) { + if f.estimates[0].time == 0 || f.comparator(sample, f.estimates[0].sample) || (time-f.estimates[2].time) > f.windowLength { + f.Reset(sample, time) + return + } + + if f.comparator(sample, f.estimates[1].sample) { + f.estimates[1].sample = sample + f.estimates[1].time = time + f.estimates[2].sample = sample + f.estimates[2].time = time + } else if f.comparator(sample, f.estimates[2].sample) { + f.estimates[2].sample = sample + f.estimates[2].time = time + } + + // Expire and update estimates as necessary. + if time-f.estimates[0].time > f.windowLength { + // The best estimate hasn't been updated for an entire window, so promote + // second and third best estimates. + f.estimates[0].sample = f.estimates[1].sample + f.estimates[0].time = f.estimates[1].time + f.estimates[1].sample = f.estimates[2].sample + f.estimates[1].time = f.estimates[2].time + f.estimates[2].sample = sample + f.estimates[2].time = time + // Need to iterate one more time. Check if the new best estimate is + // outside the window as well, since it may also have been recorded a + // long time ago. Don't need to iterate once more since we cover that + // case at the beginning of the method. 
+ if time-f.estimates[0].time > f.windowLength { + f.estimates[0].sample = f.estimates[1].sample + f.estimates[0].time = f.estimates[1].time + f.estimates[1].sample = f.estimates[2].sample + f.estimates[1].time = f.estimates[2].time + } + return + } + if f.estimates[1].sample == f.estimates[0].sample && time-f.estimates[1].time > f.windowLength>>2 { + // A quarter of the window has passed without a better sample, so the + // second-best estimate is taken from the second quarter of the window. + f.estimates[1].sample = sample + f.estimates[1].time = time + f.estimates[2].sample = sample + f.estimates[2].time = time + return + } + + if f.estimates[2].sample == f.estimates[1].sample && time-f.estimates[2].time > f.windowLength>>1 { + // We've passed a half of the window without a better estimate, so take + // a third-best estimate from the second half of the window. + f.estimates[2].sample = sample + f.estimates[2].time = time + } +} + +func (f *WindowedFilter) Reset(newSample int64, newTime int64) { + f.estimates[0].sample = newSample + f.estimates[0].time = newTime + f.estimates[1].sample = newSample + f.estimates[1].time = newTime + f.estimates[2].sample = newSample + f.estimates[2].time = newTime +} diff --git a/transport/tuic/packet.go b/transport/tuic/packet.go new file mode 100644 index 00000000..a3a0c35a --- /dev/null +++ b/transport/tuic/packet.go @@ -0,0 +1,497 @@ +package tuic + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "io" + "math" + "net" + "os" + "sync" + "time" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/cache" + M "github.com/sagernet/sing/common/metadata" +) + +var udpMessagePool = sync.Pool{ + New: func() interface{} { + return new(udpMessage) + }, +} + +func releaseMessages(messages []*udpMessage) { + for _, message := range messages { + if message != nil { + *message = udpMessage{} + 
udpMessagePool.Put(message) + } + } +} + +type udpMessage struct { + sessionID uint16 + packetID uint16 + fragmentTotal uint8 + fragmentID uint8 + destination M.Socksaddr + dataLength uint16 + data *buf.Buffer +} + +func (m *udpMessage) release() { + *m = udpMessage{} + udpMessagePool.Put(m) +} + +func (m *udpMessage) releaseMessage() { + m.data.Release() + m.release() +} + +func (m *udpMessage) pack() *buf.Buffer { + buffer := buf.NewSize(m.headerSize() + m.data.Len()) + common.Must( + buffer.WriteByte(Version), + buffer.WriteByte(CommandPacket), + binary.Write(buffer, binary.BigEndian, m.sessionID), + binary.Write(buffer, binary.BigEndian, m.packetID), + binary.Write(buffer, binary.BigEndian, m.fragmentTotal), + binary.Write(buffer, binary.BigEndian, m.fragmentID), + binary.Write(buffer, binary.BigEndian, uint16(m.data.Len())), + addressSerializer.WriteAddrPort(buffer, m.destination), + common.Error(buffer.Write(m.data.Bytes())), + ) + return buffer +} + +func (m *udpMessage) headerSize() int { + return 2 + 10 + addressSerializer.AddrPortLen(m.destination) +} + +func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage { + if message.data.Len() <= maxPacketSize { + return []*udpMessage{message} + } + var fragments []*udpMessage + originPacket := message.data.Bytes() + udpMTU := maxPacketSize - message.headerSize() + for remaining := len(originPacket); remaining > 0; remaining -= udpMTU { + fragment := udpMessagePool.Get().(*udpMessage) + *fragment = *message + if remaining > udpMTU { + fragment.data = buf.As(originPacket[:udpMTU]) + originPacket = originPacket[udpMTU:] + } else { + fragment.data = buf.As(originPacket) + originPacket = nil + } + fragments = append(fragments, fragment) + } + fragmentTotal := uint16(len(fragments)) + for index, fragment := range fragments { + fragment.fragmentID = uint8(index) + fragment.fragmentTotal = uint8(fragmentTotal) + if index > 0 { + fragment.destination = M.Socksaddr{} + } + } + return fragments +} + +type 
udpPacketConn struct { + ctx context.Context + cancel common.ContextCancelCauseFunc + sessionID uint16 + quicConn quic.Connection + data chan *udpMessage + udpStream bool + udpMTU int + packetId atomic.Uint32 + closeOnce sync.Once + isServer bool + defragger *udpDefragger + onDestroy func() +} + +func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn { + ctx, cancel := common.ContextWithCancelCause(ctx) + return &udpPacketConn{ + ctx: ctx, + cancel: cancel, + quicConn: quicConn, + data: make(chan *udpMessage, 64), + udpStream: udpStream, + isServer: isServer, + defragger: newUDPDefragger(), + onDestroy: onDestroy, + } +} + +func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case p := <-c.data: + buffer = p.data + destination = p.destination + p.release() + return + case <-c.ctx.Done(): + return nil, M.Socksaddr{}, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case p := <-c.data: + _, err = buffer.ReadOnceFrom(p.data) + destination = p.destination + p.releaseMessage() + return + case <-c.ctx.Done(): + return M.Socksaddr{}, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case p := <-c.data: + _, err = newBuffer().ReadOnceFrom(p.data) + destination = p.destination + p.releaseMessage() + return + case <-c.ctx.Done(): + return M.Socksaddr{}, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + select { + case pkt := <-c.data: + n = copy(p, pkt.data.Bytes()) + if pkt.destination.IsFqdn() { + addr = pkt.destination + } else { + addr = pkt.destination.UDPAddr() + } + pkt.releaseMessage() + return n, addr, nil + case <-c.ctx.Done(): + return 0, nil, io.ErrClosedPipe + } +} + +func (c 
*udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + defer buffer.Release() + select { + case <-c.ctx.Done(): + return net.ErrClosed + default: + } + if buffer.Len() > 0xffff { + return quic.ErrMessageTooLarge(0xffff) + } + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } + message := udpMessagePool.Get().(*udpMessage) + *message = udpMessage{ + sessionID: c.sessionID, + packetID: uint16(packetId), + fragmentTotal: 1, + destination: destination, + data: buffer, + } + defer message.releaseMessage() + var err error + if !c.udpStream && c.udpMTU > 0 && buffer.Len() > c.udpMTU { + err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + } else { + err = c.writePacket(message) + } + if err == nil { + return nil + } + var tooLargeErr quic.ErrMessageTooLarge + if !errors.As(err, &tooLargeErr) { + return err + } + c.udpMTU = int(tooLargeErr) + return c.writePackets(fragUDPMessage(message, c.udpMTU)) +} + +func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + select { + case <-c.ctx.Done(): + return 0, net.ErrClosed + default: + } + if len(p) > 0xffff { + return 0, quic.ErrMessageTooLarge(0xffff) + } + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } + message := udpMessagePool.Get().(*udpMessage) + *message = udpMessage{ + sessionID: c.sessionID, + packetID: uint16(packetId), + fragmentTotal: 1, + destination: M.SocksaddrFromNet(addr), + data: buf.As(p), + } + if c.udpMTU > 0 && len(p) > c.udpMTU { + err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + if err == nil { + return len(p), nil + } + } else { + err = c.writePacket(message) + } + if err == nil { + return len(p), nil + } + var tooLargeErr quic.ErrMessageTooLarge + if !errors.As(err, &tooLargeErr) { + return + } + c.udpMTU = int(tooLargeErr) + err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + if err == nil { + return len(p), nil + } + 
return +} + +func (c *udpPacketConn) inputPacket(message *udpMessage) { + if message.fragmentTotal <= 1 { + select { + case c.data <- message: + default: + } + } else { + newMessage := c.defragger.feed(message) + if newMessage != nil { + select { + case c.data <- newMessage: + default: + } + } + } +} + +func (c *udpPacketConn) writePackets(messages []*udpMessage) error { + defer releaseMessages(messages) + for _, message := range messages { + err := c.writePacket(message) + if err != nil { + return err + } + } + return nil +} + +func (c *udpPacketConn) writePacket(message *udpMessage) error { + if !c.udpStream { + buffer := message.pack() + err := c.quicConn.SendMessage(buffer.Bytes()) + buffer.Release() + if err != nil { + return err + } + } else { + stream, err := c.quicConn.OpenUniStream() + if err != nil { + return err + } + buffer := message.pack() + _, err = stream.Write(buffer.Bytes()) + buffer.Release() + stream.Close() + if err != nil { + return err + } + } + return nil +} + +func (c *udpPacketConn) Close() error { + c.closeOnce.Do(func() { + c.closeWithError(os.ErrClosed) + c.onDestroy() + }) + return nil +} + +func (c *udpPacketConn) closeWithError(err error) { + c.cancel(err) + if !c.isServer { + buffer := buf.NewSize(4) + defer buffer.Release() + buffer.WriteByte(Version) + buffer.WriteByte(CommandDissociate) + binary.Write(buffer, binary.BigEndian, c.sessionID) + sendStream, openErr := c.quicConn.OpenUniStream() + if openErr != nil { + return + } + defer sendStream.Close() + sendStream.Write(buffer.Bytes()) + } +} + +func (c *udpPacketConn) LocalAddr() net.Addr { + return c.quicConn.LocalAddr() +} + +func (c *udpPacketConn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *udpPacketConn) SetReadDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *udpPacketConn) SetWriteDeadline(t time.Time) error { + return os.ErrInvalid +} + +type udpDefragger struct { + packetMap *cache.LruCache[uint16, *packetItem] +} + +func 
newUDPDefragger() *udpDefragger { + return &udpDefragger{ + packetMap: cache.New( + cache.WithAge[uint16, *packetItem](10), + cache.WithUpdateAgeOnGet[uint16, *packetItem](), + cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) { + releaseMessages(value.messages) + }), + ), + } +} + +type packetItem struct { + access sync.Mutex + messages []*udpMessage + count uint8 +} + +func (d *udpDefragger) feed(m *udpMessage) *udpMessage { + if m.fragmentTotal <= 1 { + return m + } + if m.fragmentID >= m.fragmentTotal { + return nil + } + item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem) + item.access.Lock() + defer item.access.Unlock() + if int(m.fragmentTotal) != len(item.messages) { + releaseMessages(item.messages) + item.messages = make([]*udpMessage, m.fragmentTotal) + item.count = 1 + item.messages[m.fragmentID] = m + return nil + } + if item.messages[m.fragmentID] != nil { + return nil + } + item.messages[m.fragmentID] = m + item.count++ + if int(item.count) != len(item.messages) { + return nil + } + newMessage := udpMessagePool.Get().(*udpMessage) + *newMessage = *item.messages[0] + if m.dataLength > 0 { + newMessage.data = buf.NewSize(int(m.dataLength)) + for _, message := range item.messages { + newMessage.data.Write(message.data.Bytes()) + message.releaseMessage() + } + item.messages = nil + return newMessage + } + return nil +} + +func newPacketItem() *packetItem { + return new(packetItem) +} + +func readUDPMessage(message *udpMessage, reader io.Reader) error { + err := binary.Read(reader, binary.BigEndian, &message.sessionID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.packetID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.fragmentID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.dataLength) + 
if err != nil { + return err + } + message.destination, err = addressSerializer.ReadAddrPort(reader) + if err != nil { + return err + } + message.data = buf.NewSize(int(message.dataLength)) + _, err = message.data.ReadFullFrom(reader, message.data.FreeLen()) + if err != nil { + return err + } + return nil +} + +func decodeUDPMessage(message *udpMessage, data []byte) error { + reader := bytes.NewReader(data) + err := binary.Read(reader, binary.BigEndian, &message.sessionID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.packetID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.fragmentID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.dataLength) + if err != nil { + return err + } + message.destination, err = addressSerializer.ReadAddrPort(reader) + if err != nil { + return err + } + if reader.Len() != int(message.dataLength) { + return io.ErrUnexpectedEOF + } + message.data = buf.As(data[len(data)-reader.Len():]) + return nil +} diff --git a/transport/tuic/protocol.go b/transport/tuic/protocol.go new file mode 100644 index 00000000..1247516b --- /dev/null +++ b/transport/tuic/protocol.go @@ -0,0 +1,15 @@ +package tuic + +const ( + Version = 5 +) + +const ( + CommandAuthenticate = iota + CommandConnect + CommandPacket + CommandDissociate + CommandHeartbeat +) + +const AuthenticateLen = 2 + 16 + 32 diff --git a/transport/tuic/server.go b/transport/tuic/server.go new file mode 100644 index 00000000..4a40b44f --- /dev/null +++ b/transport/tuic/server.go @@ -0,0 +1,434 @@ +package tuic + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/binary" + "io" + "net" + "runtime" + "strings" + "sync" + "time" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing-box/common/baderror" + "github.com/sagernet/sing/common" + 
"github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/gofrs/uuid/v5" +) + +type ServerOptions struct { + Context context.Context + Logger logger.Logger + TLSConfig *tls.Config + Users []User + CongestionControl string + AuthTimeout time.Duration + ZeroRTTHandshake bool + Heartbeat time.Duration + Handler ServerHandler +} + +type User struct { + Name string + UUID uuid.UUID + Password string +} + +type ServerHandler interface { + N.TCPConnectionHandler + N.UDPConnectionHandler +} + +type Server struct { + ctx context.Context + logger logger.Logger + tlsConfig *tls.Config + heartbeat time.Duration + quicConfig *quic.Config + userMap map[uuid.UUID]User + congestionControl string + authTimeout time.Duration + handler ServerHandler + + quicListener io.Closer +} + +func NewServer(options ServerOptions) (*Server, error) { + if options.AuthTimeout == 0 { + options.AuthTimeout = 3 * time.Second + } + if options.Heartbeat == 0 { + options.Heartbeat = 10 * time.Second + } + quicConfig := &quic.Config{ + DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), + MaxDatagramFrameSize: 1400, + EnableDatagrams: true, + Allow0RTT: options.ZeroRTTHandshake, + MaxIncomingStreams: 1 << 60, + MaxIncomingUniStreams: 1 << 60, + } + switch options.CongestionControl { + case "": + options.CongestionControl = "cubic" + case "cubic", "new_reno", "bbr": + default: + return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl) + } + if len(options.Users) == 0 { + return nil, E.New("missing users") + } + userMap := make(map[uuid.UUID]User) + for _, user := range options.Users { + userMap[user.UUID] = user + } + return 
&Server{ + ctx: options.Context, + logger: options.Logger, + tlsConfig: options.TLSConfig, + heartbeat: options.Heartbeat, + quicConfig: quicConfig, + userMap: userMap, + congestionControl: options.CongestionControl, + authTimeout: options.AuthTimeout, + handler: options.Handler, + }, nil +} + +func (s *Server) Start(conn net.PacketConn) error { + if !s.quicConfig.Allow0RTT { + listener, err := quic.Listen(conn, s.tlsConfig, s.quicConfig) + if err != nil { + return err + } + s.quicListener = listener + go func() { + for { + connection, hErr := listener.Accept(s.ctx) + if hErr != nil { + if strings.Contains(hErr.Error(), "server closed") { + s.logger.Debug(E.Cause(hErr, "listener closed")) + } else { + s.logger.Error(E.Cause(hErr, "listener closed")) + } + return + } + go s.handleConnection(connection) + } + }() + } else { + listener, err := quic.ListenEarly(conn, s.tlsConfig, s.quicConfig) + if err != nil { + return err + } + s.quicListener = listener + go func() { + for { + connection, hErr := listener.Accept(s.ctx) + if hErr != nil { + if strings.Contains(hErr.Error(), "server closed") { + s.logger.Debug(E.Cause(hErr, "listener closed")) + } else { + s.logger.Error(E.Cause(hErr, "listener closed")) + } + return + } + go s.handleConnection(connection) + } + }() + } + return nil +} + +func (s *Server) Close() error { + return common.Close( + s.quicListener, + ) +} + +func (s *Server) handleConnection(connection quic.Connection) { + setCongestion(s.ctx, connection, s.congestionControl) + session := &serverSession{ + Server: s, + ctx: s.ctx, + quicConn: connection, + source: M.SocksaddrFromNet(connection.RemoteAddr()), + connDone: make(chan struct{}), + authDone: make(chan struct{}), + udpConnMap: make(map[uint16]*udpPacketConn), + } + session.handle() +} + +type serverSession struct { + *Server + ctx context.Context + quicConn quic.Connection + source M.Socksaddr + connAccess sync.Mutex + connDone chan struct{} + connErr error + authDone chan struct{} + authUser 
*User + udpAccess sync.RWMutex + udpConnMap map[uint16]*udpPacketConn +} + +func (s *serverSession) handle() { + if s.ctx.Done() != nil { + go func() { + select { + case <-s.ctx.Done(): + s.closeWithError(s.ctx.Err()) + case <-s.connDone: + } + }() + } + go s.loopUniStreams() + go s.loopStreams() + go s.loopMessages() + go s.handleAuthTimeout() + go s.loopHeartbeats() +} + +func (s *serverSession) loopUniStreams() { + for { + uniStream, err := s.quicConn.AcceptUniStream(s.ctx) + if err != nil { + return + } + go func() { + err = s.handleUniStream(uniStream) + if err != nil { + s.closeWithError(E.Cause(err, "handle uni stream")) + } + }() + } +} + +func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error { + defer stream.CancelRead(0) + buffer := buf.New() + defer buffer.Release() + _, err := buffer.ReadAtLeastFrom(stream, 2) + if err != nil { + return E.Cause(err, "read request") + } + version := buffer.Byte(0) + if version != Version { + return E.New("unknown version ", buffer.Byte(0)) + } + command := buffer.Byte(1) + switch command { + case CommandAuthenticate: + select { + case <-s.authDone: + return E.New("authentication: multiple authentication requests") + default: + } + if buffer.Len() < AuthenticateLen { + _, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len()) + if err != nil { + return E.Cause(err, "authentication: read request") + } + } + userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16)) + user, loaded := s.userMap[userUUID] + if !loaded { + return E.New("authentication: unknown user ", userUUID) + } + handshakeState := s.quicConn.ConnectionState().TLS + tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32) + if err != nil { + return E.Cause(err, "authentication: export keying material") + } + if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) { + return E.New("authentication: token mismatch") + } + s.authUser = &user + close(s.authDone) + return nil + case 
CommandPacket: + select { + case <-s.connDone: + return s.connErr + case <-s.authDone: + } + message := udpMessagePool.Get().(*udpMessage) + err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream)) + if err != nil { + message.release() + return err + } + s.handleUDPMessage(message, true) + return nil + case CommandDissociate: + select { + case <-s.connDone: + return s.connErr + case <-s.authDone: + } + if buffer.Len() > 4 { + return E.New("invalid dissociate message") + } + var sessionID uint16 + err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID) + if err != nil { + return err + } + s.udpAccess.RLock() + udpConn, loaded := s.udpConnMap[sessionID] + s.udpAccess.RUnlock() + if loaded { + udpConn.closeWithError(E.New("remote closed")) + s.udpAccess.Lock() + delete(s.udpConnMap, sessionID) + s.udpAccess.Unlock() + } + return nil + default: + return E.New("unknown command ", command) + } +} + +func (s *serverSession) handleAuthTimeout() { + select { + case <-s.connDone: + case <-s.authDone: + case <-time.After(s.authTimeout): + s.closeWithError(E.New("authentication timeout")) + } +} + +func (s *serverSession) loopStreams() { + for { + stream, err := s.quicConn.AcceptStream(s.ctx) + if err != nil { + return + } + go func() { + err = s.handleStream(stream) + if err != nil { + stream.CancelRead(0) + stream.Close() + s.logger.Error(E.Cause(err, "handle stream request")) + } + }() + } +} + +func (s *serverSession) handleStream(stream quic.Stream) error { + buffer := buf.NewSize(2 + M.MaxSocksaddrLength) + defer buffer.Release() + _, err := buffer.ReadAtLeastFrom(stream, 2) + if err != nil { + return E.Cause(err, "read request") + } + version, _ := buffer.ReadByte() + if version != Version { + return E.New("unknown version ", buffer.Byte(0)) + } + command, _ := buffer.ReadByte() + if command != CommandConnect { + return E.New("unsupported stream command ", command) + } + destination, err := 
addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream)) + if err != nil { + return E.Cause(err, "read request destination") + } + select { + case <-s.connDone: + return s.connErr + case <-s.authDone: + } + var conn net.Conn = &serverConn{ + Stream: stream, + destination: destination, + } + if buffer.IsEmpty() { + buffer.Release() + } else { + conn = bufio.NewCachedConn(conn, buffer) + } + ctx := s.ctx + if s.authUser.Name != "" { + ctx = auth.ContextWithUser(s.ctx, s.authUser.Name) + } + _ = s.handler.NewConnection(ctx, conn, M.Metadata{ + Source: s.source, + Destination: destination, + }) + return nil +} + +func (s *serverSession) loopHeartbeats() { + ticker := time.NewTicker(s.heartbeat) + defer ticker.Stop() + for { + select { + case <-s.connDone: + return + case <-ticker.C: + err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat}) + if err != nil { + s.closeWithError(E.Cause(err, "send heartbeat")) + } + } + } +} + +func (s *serverSession) closeWithError(err error) { + s.connAccess.Lock() + defer s.connAccess.Unlock() + select { + case <-s.connDone: + return + default: + s.connErr = err + close(s.connDone) + } + if E.IsClosedOrCanceled(err) { + s.logger.Debug(E.Cause(err, "connection failed")) + } else { + s.logger.Error(E.Cause(err, "connection failed")) + } + _ = s.quicConn.CloseWithError(0, "") +} + +type serverConn struct { + quic.Stream + destination M.Socksaddr +} + +func (c *serverConn) Read(p []byte) (n int, err error) { + n, err = c.Stream.Read(p) + return n, baderror.WrapQUIC(err) +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + n, err = c.Stream.Write(p) + return n, baderror.WrapQUIC(err) +} + +func (c *serverConn) LocalAddr() net.Addr { + return c.destination +} + +func (c *serverConn) RemoteAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *serverConn) Close() error { + c.Stream.CancelRead(0) + return c.Stream.Close() +} diff --git a/transport/tuic/server_packet.go b/transport/tuic/server_packet.go new file 
mode 100644 index 00000000..fba6118a --- /dev/null +++ b/transport/tuic/server_packet.go @@ -0,0 +1,73 @@ +package tuic + +import ( + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" +) + +func (s *serverSession) loopMessages() { + select { + case <-s.connDone: + return + case <-s.authDone: + } + for { + message, err := s.quicConn.ReceiveMessage(s.ctx) + if err != nil { + s.closeWithError(E.Cause(err, "receive message")) + return + } + hErr := s.handleMessage(message) + if hErr != nil { + s.closeWithError(E.Cause(hErr, "handle message")) + return + } + } +} + +func (s *serverSession) handleMessage(data []byte) error { + if len(data) < 2 { + return E.New("invalid message") + } + if data[0] != Version { + return E.New("unknown version ", data[0]) + } + switch data[1] { + case CommandPacket: + message := udpMessagePool.Get().(*udpMessage) + err := decodeUDPMessage(message, data[2:]) + if err != nil { + message.release() + return E.Cause(err, "decode UDP message") + } + s.handleUDPMessage(message, false) + return nil + case CommandHeartbeat: + return nil + default: + return E.New("unknown command ", data[0]) + } +} + +func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) { + s.udpAccess.RLock() + udpConn, loaded := s.udpConnMap[message.sessionID] + s.udpAccess.RUnlock() + if !loaded || common.Done(udpConn.ctx) { + udpConn = newUDPPacketConn(s.ctx, s.quicConn, udpStream, true, func() { + s.udpAccess.Lock() + delete(s.udpConnMap, message.sessionID) + s.udpAccess.Unlock() + }) + udpConn.sessionID = message.sessionID + s.udpAccess.Lock() + s.udpConnMap[message.sessionID] = udpConn + s.udpAccess.Unlock() + go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{ + Source: s.source, + Destination: message.destination, + }) + } + udpConn.inputPacket(message) +}