mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-22 00:21:30 +00:00
Add TUIC protocol
This commit is contained in:
parent
0b14dc3228
commit
917420e79a
|
@ -21,6 +21,7 @@ const (
|
|||
TypeShadowTLS = "shadowtls"
|
||||
TypeShadowsocksR = "shadowsocksr"
|
||||
TypeVLESS = "vless"
|
||||
TypeTUIC = "tuic"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -62,6 +63,8 @@ func ProxyDisplayName(proxyType string) string {
|
|||
return "ShadowsocksR"
|
||||
case TypeVLESS:
|
||||
return "VLESS"
|
||||
case TypeTUIC:
|
||||
return "TUIC"
|
||||
case TypeSelector:
|
||||
return "Selector"
|
||||
case TypeURLTest:
|
||||
|
|
|
@ -44,6 +44,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o
|
|||
return NewShadowTLS(ctx, router, logger, options.Tag, options.ShadowTLSOptions)
|
||||
case C.TypeVLESS:
|
||||
return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions)
|
||||
case C.TypeTUIC:
|
||||
return NewTUIC(ctx, router, logger, options.Tag, options.TUICOptions)
|
||||
default:
|
||||
return nil, E.New("unknown inbound type: ", options.Type)
|
||||
}
|
||||
|
|
|
@ -153,6 +153,17 @@ func (a *myInboundAdapter) createMetadata(conn net.Conn, metadata adapter.Inboun
|
|||
return metadata
|
||||
}
|
||||
|
||||
func (a *myInboundAdapter) createPacketMetadata(conn N.PacketConn, metadata adapter.InboundContext) adapter.InboundContext {
|
||||
metadata.Inbound = a.tag
|
||||
metadata.InboundType = a.protocol
|
||||
metadata.InboundDetour = a.listenOptions.Detour
|
||||
metadata.InboundOptions = a.listenOptions.InboundOptions
|
||||
if !metadata.Destination.IsValid() {
|
||||
metadata.Destination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
|
||||
func (a *myInboundAdapter) newError(err error) {
|
||||
a.logger.Error(err)
|
||||
}
|
||||
|
|
114
inbound/tuic.go
Normal file
114
inbound/tuic.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
//go:build with_quic
|
||||
|
||||
package inbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/tuic"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/auth"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
var _ adapter.Inbound = (*TUIC)(nil)
|
||||
|
||||
type TUIC struct {
|
||||
myInboundAdapter
|
||||
server *tuic.Server
|
||||
tlsConfig tls.ServerConfig
|
||||
}
|
||||
|
||||
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (*TUIC, error) {
|
||||
options.UDPFragmentDefault = true
|
||||
if options.TLS == nil || !options.TLS.Enabled {
|
||||
return nil, C.ErrTLSRequired
|
||||
}
|
||||
tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawConfig, err := tlsConfig.Config()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var users []tuic.User
|
||||
for index, user := range options.Users {
|
||||
if user.UUID == "" {
|
||||
return nil, E.New("missing uuid for user ", index)
|
||||
}
|
||||
userUUID, err := uuid.FromString(user.UUID)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "invalid uuid for user ", index)
|
||||
}
|
||||
users = append(users, tuic.User{Name: user.Name, UUID: userUUID, Password: user.Password})
|
||||
}
|
||||
inbound := &TUIC{
|
||||
myInboundAdapter: myInboundAdapter{
|
||||
protocol: C.TypeTUIC,
|
||||
network: []string{N.NetworkUDP},
|
||||
ctx: ctx,
|
||||
router: router,
|
||||
logger: logger,
|
||||
tag: tag,
|
||||
listenOptions: options.ListenOptions,
|
||||
},
|
||||
}
|
||||
server, err := tuic.NewServer(tuic.ServerOptions{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
TLSConfig: rawConfig,
|
||||
Users: users,
|
||||
CongestionControl: options.CongestionControl,
|
||||
AuthTimeout: time.Duration(options.AuthTimeout),
|
||||
ZeroRTTHandshake: options.ZeroRTTHandshake,
|
||||
Heartbeat: time.Duration(options.Heartbeat),
|
||||
Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbound.server = server
|
||||
return inbound, nil
|
||||
}
|
||||
|
||||
func (h *TUIC) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
ctx = log.ContextWithNewID(ctx)
|
||||
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||
metadata = h.createMetadata(conn, metadata)
|
||||
metadata.User, _ = auth.UserFromContext[string](ctx)
|
||||
return h.router.RouteConnection(ctx, conn, metadata)
|
||||
}
|
||||
|
||||
func (h *TUIC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||
ctx = log.ContextWithNewID(ctx)
|
||||
metadata = h.createPacketMetadata(conn, metadata)
|
||||
metadata.User, _ = auth.UserFromContext[string](ctx)
|
||||
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
|
||||
return h.router.RoutePacketConnection(ctx, conn, metadata)
|
||||
}
|
||||
|
||||
func (h *TUIC) Start() error {
|
||||
packetConn, err := h.myInboundAdapter.ListenUDP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.server.Start(packetConn)
|
||||
}
|
||||
|
||||
func (h *TUIC) Close() error {
|
||||
return common.Close(
|
||||
&h.myInboundAdapter,
|
||||
common.PtrOrNil(h.server),
|
||||
)
|
||||
}
|
16
inbound/tuic_stub.go
Normal file
16
inbound/tuic_stub.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
//go:build !with_quic
|
||||
|
||||
package inbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (adapter.Inbound, error) {
|
||||
return nil, C.ErrQUICNotIncluded
|
||||
}
|
|
@ -23,6 +23,7 @@ type _Inbound struct {
|
|||
HysteriaOptions HysteriaInboundOptions `json:"-"`
|
||||
ShadowTLSOptions ShadowTLSInboundOptions `json:"-"`
|
||||
VLESSOptions VLESSInboundOptions `json:"-"`
|
||||
TUICOptions TUICInboundOptions `json:"-"`
|
||||
}
|
||||
|
||||
type Inbound _Inbound
|
||||
|
@ -58,6 +59,8 @@ func (h Inbound) MarshalJSON() ([]byte, error) {
|
|||
v = h.ShadowTLSOptions
|
||||
case C.TypeVLESS:
|
||||
v = h.VLESSOptions
|
||||
case C.TypeTUIC:
|
||||
v = h.TUICOptions
|
||||
default:
|
||||
return nil, E.New("unknown inbound type: ", h.Type)
|
||||
}
|
||||
|
@ -99,6 +102,8 @@ func (h *Inbound) UnmarshalJSON(bytes []byte) error {
|
|||
v = &h.ShadowTLSOptions
|
||||
case C.TypeVLESS:
|
||||
v = &h.VLESSOptions
|
||||
case C.TypeTUIC:
|
||||
v = &h.TUICOptions
|
||||
default:
|
||||
return E.New("unknown inbound type: ", h.Type)
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ type _Outbound struct {
|
|||
ShadowTLSOptions ShadowTLSOutboundOptions `json:"-"`
|
||||
ShadowsocksROptions ShadowsocksROutboundOptions `json:"-"`
|
||||
VLESSOptions VLESSOutboundOptions `json:"-"`
|
||||
TUICOptions TUICOutboundOptions `json:"-"`
|
||||
SelectorOptions SelectorOutboundOptions `json:"-"`
|
||||
URLTestOptions URLTestOutboundOptions `json:"-"`
|
||||
}
|
||||
|
@ -60,6 +61,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) {
|
|||
v = h.ShadowsocksROptions
|
||||
case C.TypeVLESS:
|
||||
v = h.VLESSOptions
|
||||
case C.TypeTUIC:
|
||||
v = h.TUICOptions
|
||||
case C.TypeSelector:
|
||||
v = h.SelectorOptions
|
||||
case C.TypeURLTest:
|
||||
|
@ -105,6 +108,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error {
|
|||
v = &h.ShadowsocksROptions
|
||||
case C.TypeVLESS:
|
||||
v = &h.VLESSOptions
|
||||
case C.TypeTUIC:
|
||||
v = &h.TUICOptions
|
||||
case C.TypeSelector:
|
||||
v = &h.SelectorOptions
|
||||
case C.TypeURLTest:
|
||||
|
|
30
option/tuic.go
Normal file
30
option/tuic.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package option
|
||||
|
||||
type TUICInboundOptions struct {
|
||||
ListenOptions
|
||||
Users []TUICUser `json:"users,omitempty"`
|
||||
CongestionControl string `json:"congestion_control,omitempty"`
|
||||
AuthTimeout Duration `json:"auth_timeout,omitempty"`
|
||||
ZeroRTTHandshake bool `json:"zero_rtt_handshake,omitempty"`
|
||||
Heartbeat Duration `json:"heartbeat,omitempty"`
|
||||
TLS *InboundTLSOptions `json:"tls,omitempty"`
|
||||
}
|
||||
|
||||
type TUICUser struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
UUID string `json:"uuid,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
}
|
||||
|
||||
type TUICOutboundOptions struct {
|
||||
DialerOptions
|
||||
ServerOptions
|
||||
UUID string `json:"uuid,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
CongestionControl string `json:"congestion_control,omitempty"`
|
||||
UDPRelayMode string `json:"udp_relay_mode,omitempty"`
|
||||
ZeroRTTHandshake bool `json:"zero_rtt_handshake,omitempty"`
|
||||
Heartbeat Duration `json:"heartbeat,omitempty"`
|
||||
Network NetworkList `json:"network,omitempty"`
|
||||
TLS *OutboundTLSOptions `json:"tls,omitempty"`
|
||||
}
|
|
@ -51,6 +51,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, t
|
|||
return NewShadowsocksR(ctx, router, logger, tag, options.ShadowsocksROptions)
|
||||
case C.TypeVLESS:
|
||||
return NewVLESS(ctx, router, logger, tag, options.VLESSOptions)
|
||||
case C.TypeTUIC:
|
||||
return NewTUIC(ctx, router, logger, tag, options.TUICOptions)
|
||||
case C.TypeSelector:
|
||||
return NewSelector(router, logger, tag, options.SelectorOptions)
|
||||
case C.TypeURLTest:
|
||||
|
|
123
outbound/tuic.go
Normal file
123
outbound/tuic.go
Normal file
|
@ -0,0 +1,123 @@
|
|||
//go:build with_quic
|
||||
|
||||
package outbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/tuic"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
_ adapter.Outbound = (*TUIC)(nil)
|
||||
_ adapter.InterfaceUpdateListener = (*TUIC)(nil)
|
||||
)
|
||||
|
||||
type TUIC struct {
|
||||
myOutboundAdapter
|
||||
client *tuic.Client
|
||||
}
|
||||
|
||||
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICOutboundOptions) (*TUIC, error) {
|
||||
options.UDPFragmentDefault = true
|
||||
if options.TLS == nil || !options.TLS.Enabled {
|
||||
return nil, C.ErrTLSRequired
|
||||
}
|
||||
abstractTLSConfig, err := tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig, err := abstractTLSConfig.Config()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userUUID, err := uuid.FromString(options.UUID)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "invalid uuid")
|
||||
}
|
||||
var udpStream bool
|
||||
switch options.UDPRelayMode {
|
||||
case "native":
|
||||
case "quic":
|
||||
udpStream = true
|
||||
}
|
||||
client, err := tuic.NewClient(tuic.ClientOptions{
|
||||
Context: ctx,
|
||||
Dialer: dialer.New(router, options.DialerOptions),
|
||||
ServerAddress: options.ServerOptions.Build(),
|
||||
TLSConfig: tlsConfig,
|
||||
UUID: userUUID,
|
||||
Password: options.Password,
|
||||
CongestionControl: options.CongestionControl,
|
||||
UDPStream: udpStream,
|
||||
ZeroRTTHandshake: options.ZeroRTTHandshake,
|
||||
Heartbeat: time.Duration(options.Heartbeat),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TUIC{
|
||||
myOutboundAdapter: myOutboundAdapter{
|
||||
protocol: C.TypeTUIC,
|
||||
network: options.Network.Build(),
|
||||
router: router,
|
||||
logger: logger,
|
||||
tag: tag,
|
||||
dependencies: withDialerDependency(options.DialerOptions),
|
||||
},
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *TUIC) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
h.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
return h.client.DialConn(ctx, destination)
|
||||
case N.NetworkUDP:
|
||||
conn, err := h.ListenPacket(ctx, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bufio.NewBindPacketConn(conn, destination), nil
|
||||
default:
|
||||
return nil, E.New("unsupported network: ", network)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TUIC) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
return h.client.ListenPacket(ctx)
|
||||
}
|
||||
|
||||
func (h *TUIC) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
return NewConnection(ctx, h, conn, metadata)
|
||||
}
|
||||
|
||||
func (h *TUIC) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||
return NewPacketConnection(ctx, h, conn, metadata)
|
||||
}
|
||||
|
||||
func (h *TUIC) InterfaceUpdated() {
|
||||
_ = h.client.CloseWithError(E.New("network changed"))
|
||||
}
|
||||
|
||||
func (h *TUIC) Close() error {
|
||||
return h.client.CloseWithError(os.ErrClosed)
|
||||
}
|
16
outbound/tuic_stub.go
Normal file
16
outbound/tuic_stub.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
//go:build !with_quic
|
||||
|
||||
package outbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICOutboundOptions) (adapter.Outbound, error) {
|
||||
return nil, C.ErrQUICNotIncluded
|
||||
}
|
|
@ -38,6 +38,8 @@ const (
|
|||
ImageShadowsocksR = "teddysun/shadowsocks-r:latest"
|
||||
ImageXRayCore = "teddysun/xray:latest"
|
||||
ImageShadowsocksLegacy = "mritd/shadowsocks:latest"
|
||||
ImageTUICServer = ""
|
||||
ImageTUICClient = ""
|
||||
)
|
||||
|
||||
var allImages = []string{
|
||||
|
@ -53,6 +55,8 @@ var allImages = []string{
|
|||
ImageShadowsocksR,
|
||||
ImageXRayCore,
|
||||
ImageShadowsocksLegacy,
|
||||
// ImageTUICServer,
|
||||
// ImageTUICClient,
|
||||
}
|
||||
|
||||
var localIP = netip.MustParseAddr("127.0.0.1")
|
||||
|
|
14
test/config/tuic-client.json
Normal file
14
test/config/tuic-client.json
Normal file
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"relay": {
|
||||
"server": "127.0.0.1:10000",
|
||||
"uuid": "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
|
||||
"password": "tuic",
|
||||
"certificates": [
|
||||
"/etc/tuic/ca.pem"
|
||||
]
|
||||
},
|
||||
"local": {
|
||||
"server": "127.0.0.1:10001"
|
||||
},
|
||||
"log_level": "debug"
|
||||
}
|
9
test/config/tuic-server.json
Normal file
9
test/config/tuic-server.json
Normal file
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"server": "[::]:10000",
|
||||
"users": {
|
||||
"FE35D05B-8803-45C4-BAE6-723AD2CD5D3D": "tuic"
|
||||
},
|
||||
"certificate": "/etc/tuic/cert.pem",
|
||||
"private_key": "/etc/tuic/key.pem",
|
||||
"log_level": "debug"
|
||||
}
|
178
test/tuic_test.go
Normal file
178
test/tuic_test.go
Normal file
|
@ -0,0 +1,178 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
func TestTUICSelf(t *testing.T) {
|
||||
t.Run("self", func(t *testing.T) {
|
||||
testTUICSelf(t, false, false)
|
||||
})
|
||||
t.Run("self-udp-stream", func(t *testing.T) {
|
||||
testTUICSelf(t, true, false)
|
||||
})
|
||||
t.Run("self-early", func(t *testing.T) {
|
||||
testTUICSelf(t, false, true)
|
||||
})
|
||||
}
|
||||
|
||||
func testTUICSelf(t *testing.T, udpStream bool, zeroRTTHandshake bool) {
|
||||
_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
|
||||
var udpRelayMode string
|
||||
if udpStream {
|
||||
udpRelayMode = "quic"
|
||||
}
|
||||
startInstance(t, option.Options{
|
||||
Inbounds: []option.Inbound{
|
||||
{
|
||||
Type: C.TypeMixed,
|
||||
Tag: "mixed-in",
|
||||
MixedOptions: option.HTTPMixedInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
|
||||
ListenPort: clientPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.TypeTUIC,
|
||||
TUICOptions: option.TUICInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
|
||||
ListenPort: serverPort,
|
||||
},
|
||||
Users: []option.TUICUser{{
|
||||
UUID: uuid.Nil.String(),
|
||||
}},
|
||||
ZeroRTTHandshake: zeroRTTHandshake,
|
||||
TLS: &option.InboundTLSOptions{
|
||||
Enabled: true,
|
||||
ServerName: "example.org",
|
||||
CertificatePath: certPem,
|
||||
KeyPath: keyPem,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Outbounds: []option.Outbound{
|
||||
{
|
||||
Type: C.TypeDirect,
|
||||
},
|
||||
{
|
||||
Type: C.TypeTUIC,
|
||||
Tag: "tuic-out",
|
||||
TUICOptions: option.TUICOutboundOptions{
|
||||
ServerOptions: option.ServerOptions{
|
||||
Server: "127.0.0.1",
|
||||
ServerPort: serverPort,
|
||||
},
|
||||
UUID: uuid.Nil.String(),
|
||||
UDPRelayMode: udpRelayMode,
|
||||
ZeroRTTHandshake: zeroRTTHandshake,
|
||||
TLS: &option.OutboundTLSOptions{
|
||||
Enabled: true,
|
||||
ServerName: "example.org",
|
||||
CertificatePath: certPem,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Route: &option.RouteOptions{
|
||||
Rules: []option.Rule{
|
||||
{
|
||||
DefaultOptions: option.DefaultRule{
|
||||
Inbound: []string{"mixed-in"},
|
||||
Outbound: "tuic-out",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
testSuit(t, clientPort, testPort)
|
||||
}
|
||||
|
||||
func TestTUICInbound(t *testing.T) {
|
||||
caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
|
||||
startInstance(t, option.Options{
|
||||
Inbounds: []option.Inbound{
|
||||
{
|
||||
Type: C.TypeTUIC,
|
||||
TUICOptions: option.TUICInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
|
||||
ListenPort: serverPort,
|
||||
},
|
||||
Users: []option.TUICUser{{
|
||||
UUID: "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
|
||||
Password: "tuic",
|
||||
}},
|
||||
TLS: &option.InboundTLSOptions{
|
||||
Enabled: true,
|
||||
ServerName: "example.org",
|
||||
CertificatePath: certPem,
|
||||
KeyPath: keyPem,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
startDockerContainer(t, DockerOptions{
|
||||
Image: ImageTUICClient,
|
||||
Ports: []uint16{serverPort, clientPort},
|
||||
Bind: map[string]string{
|
||||
"tuic-client.json": "/etc/tuic/config.json",
|
||||
caPem: "/etc/tuic/ca.pem",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestTUICOutbound(t *testing.T) {
|
||||
_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
|
||||
startDockerContainer(t, DockerOptions{
|
||||
Image: ImageTUICServer,
|
||||
Ports: []uint16{testPort},
|
||||
Bind: map[string]string{
|
||||
"tuic-server.json": "/etc/tuic/config.json",
|
||||
certPem: "/etc/tuic/cert.pem",
|
||||
keyPem: "/etc/tuic/key.pem",
|
||||
},
|
||||
})
|
||||
startInstance(t, option.Options{
|
||||
Inbounds: []option.Inbound{
|
||||
{
|
||||
Type: C.TypeMixed,
|
||||
MixedOptions: option.HTTPMixedInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
|
||||
ListenPort: clientPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Outbounds: []option.Outbound{
|
||||
{
|
||||
Type: C.TypeTUIC,
|
||||
TUICOptions: option.TUICOutboundOptions{
|
||||
ServerOptions: option.ServerOptions{
|
||||
Server: "127.0.0.1",
|
||||
ServerPort: serverPort,
|
||||
},
|
||||
UUID: "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
|
||||
Password: "tuic",
|
||||
TLS: &option.OutboundTLSOptions{
|
||||
Enabled: true,
|
||||
ServerName: "example.org",
|
||||
CertificatePath: certPem,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
testSuit(t, clientPort, testPort)
|
||||
}
|
10
transport/tuic/address.go
Normal file
10
transport/tuic/address.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package tuic
|
||||
|
||||
import M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
var addressSerializer = M.NewSerializer(
|
||||
M.AddressFamilyByte(0x00, M.AddressFamilyFqdn),
|
||||
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
|
||||
M.AddressFamilyByte(0x02, M.AddressFamilyIPv6),
|
||||
M.AddressFamilyByte(0xff, M.AddressFamilyEmpty),
|
||||
)
|
322
transport/tuic/client.go
Normal file
322
transport/tuic/client.go
Normal file
|
@ -0,0 +1,322 @@
|
|||
package tuic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing-box/common/baderror"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
type ClientOptions struct {
|
||||
Context context.Context
|
||||
Dialer N.Dialer
|
||||
ServerAddress M.Socksaddr
|
||||
TLSConfig *tls.Config
|
||||
UUID uuid.UUID
|
||||
Password string
|
||||
CongestionControl string
|
||||
UDPStream bool
|
||||
ZeroRTTHandshake bool
|
||||
Heartbeat time.Duration
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
ctx context.Context
|
||||
dialer N.Dialer
|
||||
serverAddr M.Socksaddr
|
||||
tlsConfig *tls.Config
|
||||
quicConfig *quic.Config
|
||||
uuid uuid.UUID
|
||||
password string
|
||||
congestionControl string
|
||||
udpStream bool
|
||||
zeroRTTHandshake bool
|
||||
heartbeat time.Duration
|
||||
|
||||
connAccess sync.RWMutex
|
||||
conn *clientQUICConnection
|
||||
}
|
||||
|
||||
func NewClient(options ClientOptions) (*Client, error) {
|
||||
if options.Heartbeat == 0 {
|
||||
options.Heartbeat = 10 * time.Second
|
||||
}
|
||||
quicConfig := &quic.Config{
|
||||
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
|
||||
MaxDatagramFrameSize: 1400,
|
||||
EnableDatagrams: true,
|
||||
MaxIncomingUniStreams: 1 << 60,
|
||||
}
|
||||
switch options.CongestionControl {
|
||||
case "":
|
||||
options.CongestionControl = "cubic"
|
||||
case "cubic", "new_reno", "bbr":
|
||||
default:
|
||||
return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
|
||||
}
|
||||
return &Client{
|
||||
ctx: options.Context,
|
||||
dialer: options.Dialer,
|
||||
serverAddr: options.ServerAddress,
|
||||
tlsConfig: options.TLSConfig,
|
||||
quicConfig: quicConfig,
|
||||
uuid: options.UUID,
|
||||
password: options.Password,
|
||||
congestionControl: options.CongestionControl,
|
||||
udpStream: options.UDPStream,
|
||||
zeroRTTHandshake: options.ZeroRTTHandshake,
|
||||
heartbeat: options.Heartbeat,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
|
||||
conn := c.conn
|
||||
if conn != nil && conn.active() {
|
||||
return conn, nil
|
||||
}
|
||||
c.connAccess.Lock()
|
||||
defer c.connAccess.Unlock()
|
||||
conn = c.conn
|
||||
if conn != nil && conn.active() {
|
||||
return conn, nil
|
||||
}
|
||||
conn, err := c.offerNew(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
|
||||
udpConn, err := c.dialer.DialContext(ctx, "udp", c.serverAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var quicConn quic.Connection
|
||||
if c.zeroRTTHandshake {
|
||||
quicConn, err = quic.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
|
||||
} else {
|
||||
quicConn, err = quic.Dial(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
|
||||
}
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
return nil, E.Cause(err, "open connection")
|
||||
}
|
||||
setCongestion(c.ctx, quicConn, c.congestionControl)
|
||||
conn := &clientQUICConnection{
|
||||
quicConn: quicConn,
|
||||
rawConn: udpConn,
|
||||
connDone: make(chan struct{}),
|
||||
udpConnMap: make(map[uint16]*udpPacketConn),
|
||||
}
|
||||
go func() {
|
||||
hErr := c.clientHandshake(quicConn)
|
||||
if hErr != nil {
|
||||
conn.closeWithError(hErr)
|
||||
}
|
||||
}()
|
||||
if c.udpStream {
|
||||
go c.loopUniStreams(conn)
|
||||
}
|
||||
go c.loopMessages(conn)
|
||||
go c.loopHeartbeats(conn)
|
||||
c.conn = conn
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Client) clientHandshake(conn quic.Connection) error {
|
||||
authStream, err := conn.OpenUniStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer authStream.Close()
|
||||
handshakeState := conn.ConnectionState().TLS
|
||||
tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
authRequest := buf.NewSize(AuthenticateLen)
|
||||
authRequest.WriteByte(Version)
|
||||
authRequest.WriteByte(CommandAuthenticate)
|
||||
authRequest.Write(c.uuid[:])
|
||||
authRequest.Write(tuicAuthToken)
|
||||
return common.Error(authStream.Write(authRequest.Bytes()))
|
||||
}
|
||||
|
||||
func (c *Client) loopHeartbeats(conn *clientQUICConnection) {
|
||||
ticker := time.NewTicker(c.heartbeat)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-conn.connDone:
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := conn.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
|
||||
if err != nil {
|
||||
conn.closeWithError(E.Cause(err, "send heartbeat"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
|
||||
conn, err := c.offer(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream, err := conn.quicConn.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &clientConn{
|
||||
parent: conn,
|
||||
stream: stream,
|
||||
destination: destination,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
|
||||
conn, err := c.offer(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var sessionID uint16
|
||||
clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() {
|
||||
conn.udpAccess.Lock()
|
||||
delete(conn.udpConnMap, sessionID)
|
||||
conn.udpAccess.Unlock()
|
||||
})
|
||||
conn.udpAccess.Lock()
|
||||
sessionID = conn.udpSessionID
|
||||
conn.udpSessionID++
|
||||
conn.udpConnMap[sessionID] = clientPacketConn
|
||||
conn.udpAccess.Unlock()
|
||||
clientPacketConn.sessionID = sessionID
|
||||
return clientPacketConn, nil
|
||||
}
|
||||
|
||||
func (c *Client) CloseWithError(err error) error {
|
||||
conn := c.conn
|
||||
if conn != nil {
|
||||
conn.closeWithError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type clientQUICConnection struct {
|
||||
quicConn quic.Connection
|
||||
rawConn io.Closer
|
||||
closeOnce sync.Once
|
||||
connDone chan struct{}
|
||||
connErr error
|
||||
udpAccess sync.RWMutex
|
||||
udpConnMap map[uint16]*udpPacketConn
|
||||
udpSessionID uint16
|
||||
}
|
||||
|
||||
func (c *clientQUICConnection) active() bool {
|
||||
select {
|
||||
case <-c.quicConn.Context().Done():
|
||||
return false
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-c.connDone:
|
||||
return false
|
||||
default:
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *clientQUICConnection) closeWithError(err error) {
|
||||
c.closeOnce.Do(func() {
|
||||
c.connErr = err
|
||||
close(c.connDone)
|
||||
_ = c.quicConn.CloseWithError(0, "")
|
||||
_ = c.rawConn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
type clientConn struct {
|
||||
parent *clientQUICConnection
|
||||
stream quic.Stream
|
||||
destination M.Socksaddr
|
||||
requestWritten bool
|
||||
}
|
||||
|
||||
func (c *clientConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.stream.Read(b)
|
||||
return n, baderror.WrapQUIC(err)
|
||||
}
|
||||
|
||||
func (c *clientConn) Write(b []byte) (n int, err error) {
|
||||
if !c.requestWritten {
|
||||
request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b))
|
||||
request.WriteByte(Version)
|
||||
request.WriteByte(CommandConnect)
|
||||
addressSerializer.WriteAddrPort(request, c.destination)
|
||||
request.Write(b)
|
||||
_, err = c.stream.Write(request.Bytes())
|
||||
if err != nil {
|
||||
c.parent.closeWithError(E.Cause(err, "create new connection"))
|
||||
return 0, baderror.WrapQUIC(err)
|
||||
}
|
||||
c.requestWritten = true
|
||||
return len(b), nil
|
||||
}
|
||||
n, err = c.stream.Write(b)
|
||||
return n, baderror.WrapQUIC(err)
|
||||
}
|
||||
|
||||
func (c *clientConn) Close() error {
|
||||
stream := c.stream
|
||||
if stream == nil {
|
||||
return nil
|
||||
}
|
||||
stream.CancelRead(0)
|
||||
return stream.Close()
|
||||
}
|
||||
|
||||
func (c *clientConn) LocalAddr() net.Addr {
|
||||
return M.Socksaddr{}
|
||||
}
|
||||
|
||||
func (c *clientConn) RemoteAddr() net.Addr {
|
||||
return c.destination
|
||||
}
|
||||
|
||||
func (c *clientConn) SetDeadline(t time.Time) error {
|
||||
if c.stream == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
return c.stream.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *clientConn) SetReadDeadline(t time.Time) error {
|
||||
if c.stream == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
return c.stream.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *clientConn) SetWriteDeadline(t time.Time) error {
|
||||
if c.stream == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
return c.stream.SetWriteDeadline(t)
|
||||
}
|
110
transport/tuic/client_packet.go
Normal file
110
transport/tuic/client_packet.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
package tuic
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func (c *Client) loopMessages(conn *clientQUICConnection) {
|
||||
for {
|
||||
message, err := conn.quicConn.ReceiveMessage(c.ctx)
|
||||
if err != nil {
|
||||
conn.closeWithError(E.Cause(err, "receive message"))
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
hErr := c.handleMessage(conn, message)
|
||||
if hErr != nil {
|
||||
conn.closeWithError(E.Cause(hErr, "handle message"))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
|
||||
if len(data) < 2 {
|
||||
return E.New("invalid message")
|
||||
}
|
||||
if data[0] != Version {
|
||||
return E.New("unknown version ", data[0])
|
||||
}
|
||||
switch data[1] {
|
||||
case CommandPacket:
|
||||
message := udpMessagePool.Get().(*udpMessage)
|
||||
err := decodeUDPMessage(message, data[2:])
|
||||
if err != nil {
|
||||
message.release()
|
||||
return E.Cause(err, "decode UDP message")
|
||||
}
|
||||
conn.handleUDPMessage(message)
|
||||
return nil
|
||||
case CommandHeartbeat:
|
||||
return nil
|
||||
default:
|
||||
return E.New("unknown command ", data[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) loopUniStreams(conn *clientQUICConnection) {
|
||||
for {
|
||||
stream, err := conn.quicConn.AcceptUniStream(c.ctx)
|
||||
if err != nil {
|
||||
conn.closeWithError(E.Cause(err, "handle uni stream"))
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
hErr := c.handleUniStream(conn, stream)
|
||||
if hErr != nil {
|
||||
conn.closeWithError(hErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.ReceiveStream) error {
|
||||
defer stream.CancelRead(0)
|
||||
buffer := buf.NewPacket()
|
||||
defer buffer.Release()
|
||||
_, err := buffer.ReadAtLeastFrom(stream, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
version, _ := buffer.ReadByte()
|
||||
if version != Version {
|
||||
return E.New("unknown version ", version)
|
||||
}
|
||||
command, _ := buffer.ReadByte()
|
||||
if command != CommandPacket {
|
||||
return E.New("unknown command ", command)
|
||||
}
|
||||
reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream)
|
||||
message := udpMessagePool.Get().(*udpMessage)
|
||||
err = readUDPMessage(message, reader)
|
||||
if err != nil {
|
||||
message.release()
|
||||
return err
|
||||
}
|
||||
conn.handleUDPMessage(message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) {
|
||||
c.udpAccess.RLock()
|
||||
udpConn, loaded := c.udpConnMap[message.sessionID]
|
||||
c.udpAccess.RUnlock()
|
||||
if !loaded {
|
||||
message.releaseMessage()
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-udpConn.ctx.Done():
|
||||
message.releaseMessage()
|
||||
return
|
||||
default:
|
||||
}
|
||||
udpConn.inputPacket(message)
|
||||
}
|
46
transport/tuic/congestion.go
Normal file
46
transport/tuic/congestion.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package tuic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing-box/transport/tuic/congestion"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
)
|
||||
|
||||
func setCongestion(ctx context.Context, connection quic.Connection, congestionName string) {
|
||||
timeFunc := ntp.TimeFuncFromContext(ctx)
|
||||
if timeFunc == nil {
|
||||
timeFunc = time.Now
|
||||
}
|
||||
switch congestionName {
|
||||
case "cubic":
|
||||
connection.SetCongestionControl(
|
||||
congestion.NewCubicSender(
|
||||
congestion.DefaultClock{TimeFunc: timeFunc},
|
||||
congestion.GetInitialPacketSize(connection.RemoteAddr()),
|
||||
false,
|
||||
nil,
|
||||
),
|
||||
)
|
||||
case "new_reno":
|
||||
connection.SetCongestionControl(
|
||||
congestion.NewCubicSender(
|
||||
congestion.DefaultClock{TimeFunc: timeFunc},
|
||||
congestion.GetInitialPacketSize(connection.RemoteAddr()),
|
||||
true,
|
||||
nil,
|
||||
),
|
||||
)
|
||||
case "bbr":
|
||||
connection.SetCongestionControl(
|
||||
congestion.NewBBRSender(
|
||||
congestion.DefaultClock{},
|
||||
congestion.GetInitialPacketSize(connection.RemoteAddr()),
|
||||
congestion.InitialCongestionWindow*congestion.InitialMaxDatagramSize,
|
||||
congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
3
transport/tuic/congestion/README.md
Normal file
3
transport/tuic/congestion/README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
# congestion
|
||||
|
||||
mod from https://github.com/MetaCubeX/Clash.Meta/tree/53f9e1ee7104473da2b4ff5da29965563084482d/transport/tuic/congestion
|
25
transport/tuic/congestion/bandwidth.go
Normal file
25
transport/tuic/congestion/bandwidth.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
)
|
||||
|
||||
// Bandwidth of a connection
|
||||
type Bandwidth uint64
|
||||
|
||||
const infBandwidth Bandwidth = math.MaxUint64
|
||||
|
||||
const (
|
||||
// BitsPerSecond is 1 bit per second
|
||||
BitsPerSecond Bandwidth = 1
|
||||
// BytesPerSecond is 1 byte per second
|
||||
BytesPerSecond = 8 * BitsPerSecond
|
||||
)
|
||||
|
||||
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
|
||||
func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth {
|
||||
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
|
||||
}
|
374
transport/tuic/congestion/bandwidth_sampler.go
Normal file
374
transport/tuic/congestion/bandwidth_sampler.go
Normal file
|
@ -0,0 +1,374 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
)
|
||||
|
||||
var InfiniteBandwidth = Bandwidth(math.MaxUint64)
|
||||
|
||||
// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned
|
||||
// to the caller when the packet is acked or lost.
|
||||
type SendTimeState struct {
|
||||
// Whether other states in this object is valid.
|
||||
isValid bool
|
||||
// Whether the sender is app limited at the time the packet was sent.
|
||||
// App limited bandwidth sample might be artificially low because the sender
|
||||
// did not have enough data to send in order to saturate the link.
|
||||
isAppLimited bool
|
||||
// Total number of sent bytes at the time the packet was sent.
|
||||
// Includes the packet itself.
|
||||
totalBytesSent congestion.ByteCount
|
||||
// Total number of acked bytes at the time the packet was sent.
|
||||
totalBytesAcked congestion.ByteCount
|
||||
// Total number of lost bytes at the time the packet was sent.
|
||||
totalBytesLost congestion.ByteCount
|
||||
}
|
||||
|
||||
// ConnectionStateOnSentPacket represents the information about a sent packet
|
||||
// and the state of the connection at the moment the packet was sent,
|
||||
// specifically the information about the most recently acknowledged packet at
|
||||
// that moment.
|
||||
type ConnectionStateOnSentPacket struct {
|
||||
packetNumber congestion.PacketNumber
|
||||
// Time at which the packet is sent.
|
||||
sendTime time.Time
|
||||
// Size of the packet.
|
||||
size congestion.ByteCount
|
||||
// The value of |totalBytesSentAtLastAckedPacket| at the time the
|
||||
// packet was sent.
|
||||
totalBytesSentAtLastAckedPacket congestion.ByteCount
|
||||
// The value of |lastAckedPacketSentTime| at the time the packet was
|
||||
// sent.
|
||||
lastAckedPacketSentTime time.Time
|
||||
// The value of |lastAckedPacketAckTime| at the time the packet was
|
||||
// sent.
|
||||
lastAckedPacketAckTime time.Time
|
||||
// Send time states that are returned to the congestion controller when the
|
||||
// packet is acked or lost.
|
||||
sendTimeState SendTimeState
|
||||
}
|
||||
|
||||
// BandwidthSample
|
||||
type BandwidthSample struct {
|
||||
// The bandwidth at that particular sample. Zero if no valid bandwidth sample
|
||||
// is available.
|
||||
bandwidth Bandwidth
|
||||
// The RTT measurement at this particular sample. Zero if no RTT sample is
|
||||
// available. Does not correct for delayed ack time.
|
||||
rtt time.Duration
|
||||
// States captured when the packet was sent.
|
||||
stateAtSend SendTimeState
|
||||
}
|
||||
|
||||
func NewBandwidthSample() *BandwidthSample {
|
||||
return &BandwidthSample{
|
||||
// FIXME: the default value of original code is zero.
|
||||
rtt: InfiniteRTT,
|
||||
}
|
||||
}
|
||||
|
||||
// BandwidthSampler keeps track of sent and acknowledged packets and outputs a
|
||||
// bandwidth sample for every packet acknowledged. The samples are taken for
|
||||
// individual packets, and are not filtered; the consumer has to filter the
|
||||
// bandwidth samples itself. In certain cases, the sampler will locally severely
|
||||
// underestimate the bandwidth, hence a maximum filter with a size of at least
|
||||
// one RTT is recommended.
|
||||
//
|
||||
// This class bases its samples on the slope of two curves: the number of bytes
|
||||
// sent over time, and the number of bytes acknowledged as received over time.
|
||||
// It produces a sample of both slopes for every packet that gets acknowledged,
|
||||
// based on a slope between two points on each of the corresponding curves. Note
|
||||
// that due to the packet loss, the number of bytes on each curve might get
|
||||
// further and further away from each other, meaning that it is not feasible to
|
||||
// compare byte values coming from different curves with each other.
|
||||
//
|
||||
// The obvious points for measuring slope sample are the ones corresponding to
|
||||
// the packet that was just acknowledged. Let us denote them as S_1 (point at
|
||||
// which the current packet was sent) and A_1 (point at which the current packet
|
||||
// was acknowledged). However, taking a slope requires two points on each line,
|
||||
// so estimating bandwidth requires picking a packet in the past with respect to
|
||||
// which the slope is measured.
|
||||
//
|
||||
// For that purpose, BandwidthSampler always keeps track of the most recently
|
||||
// acknowledged packet, and records it together with every outgoing packet.
|
||||
// When a packet gets acknowledged (A_1), it has not only information about when
|
||||
// it itself was sent (S_1), but also the information about the latest
|
||||
// acknowledged packet right before it was sent (S_0 and A_0).
|
||||
//
|
||||
// Based on that data, send and ack rate are estimated as:
|
||||
//
|
||||
// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0))
|
||||
// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0))
|
||||
//
|
||||
// Here, the ack rate is intuitively the rate we want to treat as bandwidth.
|
||||
// However, in certain cases (e.g. ack compression) the ack rate at a point may
|
||||
// end up higher than the rate at which the data was originally sent, which is
|
||||
// not indicative of the real bandwidth. Hence, we use the send rate as an upper
|
||||
// bound, and the sample value is
|
||||
//
|
||||
// rate_sample = min(send_rate, ack_rate)
|
||||
//
|
||||
// An important edge case handled by the sampler is tracking the app-limited
|
||||
// samples. There are multiple meaning of "app-limited" used interchangeably,
|
||||
// hence it is important to understand and to be able to distinguish between
|
||||
// them.
|
||||
//
|
||||
// Meaning 1: connection state. The connection is said to be app-limited when
|
||||
// there is no outstanding data to send. This means that certain bandwidth
|
||||
// samples in the future would not be an accurate indication of the link
|
||||
// capacity, and it is important to inform consumer about that. Whenever
|
||||
// connection becomes app-limited, the sampler is notified via OnAppLimited()
|
||||
// method.
|
||||
//
|
||||
// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth
|
||||
// sampler becomes notified about the connection being app-limited, it enters
|
||||
// app-limited phase. In that phase, all *sent* packets are marked as
|
||||
// app-limited. Note that the connection itself does not have to be
|
||||
// app-limited during the app-limited phase, and in fact it will not be
|
||||
// (otherwise how would it send packets?). The boolean flag below indicates
|
||||
// whether the sampler is in that phase.
|
||||
//
|
||||
// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is
|
||||
// sent during the app-limited phase, the resulting sample related to the
|
||||
// packet will be marked as app-limited.
|
||||
//
|
||||
// With the terminology issue out of the way, let us consider the question of
|
||||
// what kind of situation it addresses.
|
||||
//
|
||||
// Consider a scenario where we first send packets 1 to 20 at a regular
|
||||
// bandwidth, and then immediately run out of data. After a few seconds, we send
|
||||
// packets 21 to 60, and only receive ack for 21 between sending packets 40 and
|
||||
// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0
|
||||
// we use to compute the slope is going to be packet 20, a few seconds apart
|
||||
// from the current packet, hence the resulting estimate would be extremely low
|
||||
// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21,
|
||||
// meaning that the bandwidth sample would exclude the quiescence.
|
||||
//
|
||||
// Based on the analysis of that scenario, we implement the following rule: once
|
||||
// OnAppLimited() is called, all sent packets will produce app-limited samples
|
||||
// up until an ack for a packet that was sent after OnAppLimited() was called.
|
||||
// Note that while the scenario above is not the only scenario when the
|
||||
// connection is app-limited, the approach works in other cases too.
|
||||
type BandwidthSampler struct {
|
||||
// The total number of congestion controlled bytes sent during the connection.
|
||||
totalBytesSent congestion.ByteCount
|
||||
// The total number of congestion controlled bytes which were acknowledged.
|
||||
totalBytesAcked congestion.ByteCount
|
||||
// The total number of congestion controlled bytes which were lost.
|
||||
totalBytesLost congestion.ByteCount
|
||||
// The value of |totalBytesSent| at the time the last acknowledged packet
|
||||
// was sent. Valid only when |lastAckedPacketSentTime| is valid.
|
||||
totalBytesSentAtLastAckedPacket congestion.ByteCount
|
||||
// The time at which the last acknowledged packet was sent. Set to
|
||||
// QuicTime::Zero() if no valid timestamp is available.
|
||||
lastAckedPacketSentTime time.Time
|
||||
// The time at which the most recent packet was acknowledged.
|
||||
lastAckedPacketAckTime time.Time
|
||||
// The most recently sent packet.
|
||||
lastSendPacket congestion.PacketNumber
|
||||
// Indicates whether the bandwidth sampler is currently in an app-limited
|
||||
// phase.
|
||||
isAppLimited bool
|
||||
// The packet that will be acknowledged after this one will cause the sampler
|
||||
// to exit the app-limited phase.
|
||||
endOfAppLimitedPhase congestion.PacketNumber
|
||||
// Record of the connection state at the point where each packet in flight was
|
||||
// sent, indexed by the packet number.
|
||||
connectionStats *ConnectionStates
|
||||
}
|
||||
|
||||
func NewBandwidthSampler() *BandwidthSampler {
|
||||
return &BandwidthSampler{
|
||||
connectionStats: &ConnectionStates{
|
||||
stats: make(map[congestion.PacketNumber]*ConnectionStateOnSentPacket),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OnPacketSent Inputs the sent packet information into the sampler. Assumes that all
|
||||
// packets are sent in order. The information about the packet will not be
|
||||
// released from the sampler until it the packet is either acknowledged or
|
||||
// declared lost.
|
||||
func (s *BandwidthSampler) OnPacketSent(sentTime time.Time, lastSentPacket congestion.PacketNumber, sentBytes, bytesInFlight congestion.ByteCount, hasRetransmittableData bool) {
|
||||
s.lastSendPacket = lastSentPacket
|
||||
|
||||
if !hasRetransmittableData {
|
||||
return
|
||||
}
|
||||
|
||||
s.totalBytesSent += sentBytes
|
||||
|
||||
// If there are no packets in flight, the time at which the new transmission
|
||||
// opens can be treated as the A_0 point for the purpose of bandwidth
|
||||
// sampling. This underestimates bandwidth to some extent, and produces some
|
||||
// artificially low samples for most packets in flight, but it provides with
|
||||
// samples at important points where we would not have them otherwise, most
|
||||
// importantly at the beginning of the connection.
|
||||
if bytesInFlight == 0 {
|
||||
s.lastAckedPacketAckTime = sentTime
|
||||
s.totalBytesSentAtLastAckedPacket = s.totalBytesSent
|
||||
|
||||
// In this situation ack compression is not a concern, set send rate to
|
||||
// effectively infinite.
|
||||
s.lastAckedPacketSentTime = sentTime
|
||||
}
|
||||
|
||||
s.connectionStats.Insert(lastSentPacket, sentTime, sentBytes, s)
|
||||
}
|
||||
|
||||
// OnPacketAcked Notifies the sampler that the |lastAckedPacket| is acknowledged. Returns a
|
||||
// bandwidth sample. If no bandwidth sample is available,
|
||||
// QuicBandwidth::Zero() is returned.
|
||||
func (s *BandwidthSampler) OnPacketAcked(ackTime time.Time, lastAckedPacket congestion.PacketNumber) *BandwidthSample {
|
||||
sentPacketState := s.connectionStats.Get(lastAckedPacket)
|
||||
if sentPacketState == nil {
|
||||
return NewBandwidthSample()
|
||||
}
|
||||
|
||||
sample := s.onPacketAckedInner(ackTime, lastAckedPacket, sentPacketState)
|
||||
s.connectionStats.Remove(lastAckedPacket)
|
||||
|
||||
return sample
|
||||
}
|
||||
|
||||
// onPacketAckedInner Handles the actual bandwidth calculations, whereas the outer method handles
|
||||
// retrieving and removing |sentPacket|.
|
||||
func (s *BandwidthSampler) onPacketAckedInner(ackTime time.Time, lastAckedPacket congestion.PacketNumber, sentPacket *ConnectionStateOnSentPacket) *BandwidthSample {
|
||||
s.totalBytesAcked += sentPacket.size
|
||||
|
||||
s.totalBytesSentAtLastAckedPacket = sentPacket.sendTimeState.totalBytesSent
|
||||
s.lastAckedPacketSentTime = sentPacket.sendTime
|
||||
s.lastAckedPacketAckTime = ackTime
|
||||
|
||||
// Exit app-limited phase once a packet that was sent while the connection is
|
||||
// not app-limited is acknowledged.
|
||||
if s.isAppLimited && lastAckedPacket > s.endOfAppLimitedPhase {
|
||||
s.isAppLimited = false
|
||||
}
|
||||
|
||||
// There might have been no packets acknowledged at the moment when the
|
||||
// current packet was sent. In that case, there is no bandwidth sample to
|
||||
// make.
|
||||
if sentPacket.lastAckedPacketSentTime.IsZero() {
|
||||
return NewBandwidthSample()
|
||||
}
|
||||
|
||||
// Infinite rate indicates that the sampler is supposed to discard the
|
||||
// current send rate sample and use only the ack rate.
|
||||
sendRate := InfiniteBandwidth
|
||||
if sentPacket.sendTime.After(sentPacket.lastAckedPacketSentTime) {
|
||||
sendRate = BandwidthFromDelta(sentPacket.sendTimeState.totalBytesSent-sentPacket.totalBytesSentAtLastAckedPacket, sentPacket.sendTime.Sub(sentPacket.lastAckedPacketSentTime))
|
||||
}
|
||||
|
||||
// During the slope calculation, ensure that ack time of the current packet is
|
||||
// always larger than the time of the previous packet, otherwise division by
|
||||
// zero or integer underflow can occur.
|
||||
if !ackTime.After(sentPacket.lastAckedPacketAckTime) {
|
||||
// TODO(wub): Compare this code count before and after fixing clock jitter
|
||||
// issue.
|
||||
// if sentPacket.lastAckedPacketAckTime.Equal(sentPacket.sendTime) {
|
||||
// This is the 1st packet after quiescense.
|
||||
// QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 1, 2);
|
||||
// } else {
|
||||
// QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 2, 2);
|
||||
// }
|
||||
|
||||
return NewBandwidthSample()
|
||||
}
|
||||
|
||||
ackRate := BandwidthFromDelta(s.totalBytesAcked-sentPacket.sendTimeState.totalBytesAcked,
|
||||
ackTime.Sub(sentPacket.lastAckedPacketAckTime))
|
||||
|
||||
// Note: this sample does not account for delayed acknowledgement time. This
|
||||
// means that the RTT measurements here can be artificially high, especially
|
||||
// on low bandwidth connections.
|
||||
sample := &BandwidthSample{
|
||||
bandwidth: minBandwidth(sendRate, ackRate),
|
||||
rtt: ackTime.Sub(sentPacket.sendTime),
|
||||
}
|
||||
|
||||
SentPacketToSendTimeState(sentPacket, &sample.stateAtSend)
|
||||
return sample
|
||||
}
|
||||
|
||||
// OnPacketLost Informs the sampler that a packet is considered lost and it should no
|
||||
// longer keep track of it.
|
||||
func (s *BandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber) SendTimeState {
|
||||
ok, sentPacket := s.connectionStats.Remove(packetNumber)
|
||||
sendTimeState := SendTimeState{
|
||||
isValid: ok,
|
||||
}
|
||||
if sentPacket != nil {
|
||||
s.totalBytesLost += sentPacket.size
|
||||
SentPacketToSendTimeState(sentPacket, &sendTimeState)
|
||||
}
|
||||
|
||||
return sendTimeState
|
||||
}
|
||||
|
||||
// OnAppLimited Informs the sampler that the connection is currently app-limited, causing
|
||||
// the sampler to enter the app-limited phase. The phase will expire by
|
||||
// itself.
|
||||
func (s *BandwidthSampler) OnAppLimited() {
|
||||
s.isAppLimited = true
|
||||
s.endOfAppLimitedPhase = s.lastSendPacket
|
||||
}
|
||||
|
||||
// SentPacketToSendTimeState Copy a subset of the (private) ConnectionStateOnSentPacket to the (public)
|
||||
// SendTimeState. Always set send_time_state->is_valid to true.
|
||||
func SentPacketToSendTimeState(sentPacket *ConnectionStateOnSentPacket, sendTimeState *SendTimeState) {
|
||||
sendTimeState.isAppLimited = sentPacket.sendTimeState.isAppLimited
|
||||
sendTimeState.totalBytesSent = sentPacket.sendTimeState.totalBytesSent
|
||||
sendTimeState.totalBytesAcked = sentPacket.sendTimeState.totalBytesAcked
|
||||
sendTimeState.totalBytesLost = sentPacket.sendTimeState.totalBytesLost
|
||||
sendTimeState.isValid = true
|
||||
}
|
||||
|
||||
// ConnectionStates Record of the connection state at the point where each packet in flight was
|
||||
// sent, indexed by the packet number.
|
||||
// FIXME: using LinkedList replace map to fast remove all the packets lower than the specified packet number.
|
||||
type ConnectionStates struct {
|
||||
stats map[congestion.PacketNumber]*ConnectionStateOnSentPacket
|
||||
}
|
||||
|
||||
func (s *ConnectionStates) Insert(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) bool {
|
||||
if _, ok := s.stats[packetNumber]; ok {
|
||||
return false
|
||||
}
|
||||
|
||||
s.stats[packetNumber] = NewConnectionStateOnSentPacket(packetNumber, sentTime, bytes, sampler)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *ConnectionStates) Get(packetNumber congestion.PacketNumber) *ConnectionStateOnSentPacket {
|
||||
return s.stats[packetNumber]
|
||||
}
|
||||
|
||||
func (s *ConnectionStates) Remove(packetNumber congestion.PacketNumber) (bool, *ConnectionStateOnSentPacket) {
|
||||
state, ok := s.stats[packetNumber]
|
||||
if ok {
|
||||
delete(s.stats, packetNumber)
|
||||
}
|
||||
return ok, state
|
||||
}
|
||||
|
||||
func NewConnectionStateOnSentPacket(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) *ConnectionStateOnSentPacket {
|
||||
return &ConnectionStateOnSentPacket{
|
||||
packetNumber: packetNumber,
|
||||
sendTime: sentTime,
|
||||
size: bytes,
|
||||
lastAckedPacketSentTime: sampler.lastAckedPacketSentTime,
|
||||
lastAckedPacketAckTime: sampler.lastAckedPacketAckTime,
|
||||
totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket,
|
||||
sendTimeState: SendTimeState{
|
||||
isValid: true,
|
||||
isAppLimited: sampler.isAppLimited,
|
||||
totalBytesSent: sampler.totalBytesSent,
|
||||
totalBytesAcked: sampler.totalBytesAcked,
|
||||
totalBytesLost: sampler.totalBytesLost,
|
||||
},
|
||||
}
|
||||
}
|
1000
transport/tuic/congestion/bbr_sender.go
Normal file
1000
transport/tuic/congestion/bbr_sender.go
Normal file
File diff suppressed because it is too large
Load diff
20
transport/tuic/congestion/clock.go
Normal file
20
transport/tuic/congestion/clock.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package congestion
|
||||
|
||||
import "time"
|
||||
|
||||
// A Clock returns the current time
|
||||
type Clock interface {
|
||||
Now() time.Time
|
||||
}
|
||||
|
||||
// DefaultClock implements the Clock interface using the Go stdlib clock.
|
||||
type DefaultClock struct {
|
||||
TimeFunc func() time.Time
|
||||
}
|
||||
|
||||
var _ Clock = DefaultClock{}
|
||||
|
||||
// Now gets the current time
|
||||
func (c DefaultClock) Now() time.Time {
|
||||
return c.TimeFunc()
|
||||
}
|
213
transport/tuic/congestion/cubic.go
Normal file
213
transport/tuic/congestion/cubic.go
Normal file
|
@ -0,0 +1,213 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
)
|
||||
|
||||
// This cubic implementation is based on the one found in Chromiums's QUIC
|
||||
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
|
||||
|
||||
// Constants based on TCP defaults.
|
||||
// The following constants are in 2^10 fractions of a second instead of ms to
|
||||
// allow a 10 shift right to divide.
|
||||
|
||||
// 1024*1024^3 (first 1024 is from 0.100^3)
|
||||
// where 0.100 is 100 ms which is the scaling round trip time.
|
||||
const (
|
||||
cubeScale = 40
|
||||
cubeCongestionWindowScale = 410
|
||||
cubeFactor congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
|
||||
// TODO: when re-enabling cubic, make sure to use the actual packet size here
|
||||
maxDatagramSize = congestion.ByteCount(InitialPacketSizeIPv4)
|
||||
)
|
||||
|
||||
const defaultNumConnections = 1
|
||||
|
||||
// Default Cubic backoff factor
|
||||
const beta float32 = 0.7
|
||||
|
||||
// Additional backoff factor when loss occurs in the concave part of the Cubic
|
||||
// curve. This additional backoff factor is expected to give up bandwidth to
|
||||
// new concurrent flows and speed up convergence.
|
||||
const betaLastMax float32 = 0.85
|
||||
|
||||
// Cubic implements the cubic algorithm from TCP
|
||||
type Cubic struct {
|
||||
clock Clock
|
||||
|
||||
// Number of connections to simulate.
|
||||
numConnections int
|
||||
|
||||
// Time when this cycle started, after last loss event.
|
||||
epoch time.Time
|
||||
|
||||
// Max congestion window used just before last loss event.
|
||||
// Note: to improve fairness to other streams an additional back off is
|
||||
// applied to this value if the new value is below our latest value.
|
||||
lastMaxCongestionWindow congestion.ByteCount
|
||||
|
||||
// Number of acked bytes since the cycle started (epoch).
|
||||
ackedBytesCount congestion.ByteCount
|
||||
|
||||
// TCP Reno equivalent congestion window in packets.
|
||||
estimatedTCPcongestionWindow congestion.ByteCount
|
||||
|
||||
// Origin point of cubic function.
|
||||
originPointCongestionWindow congestion.ByteCount
|
||||
|
||||
// Time to origin point of cubic function in 2^10 fractions of a second.
|
||||
timeToOriginPoint uint32
|
||||
|
||||
// Last congestion window in packets computed by cubic function.
|
||||
lastTargetCongestionWindow congestion.ByteCount
|
||||
}
|
||||
|
||||
// NewCubic returns a new Cubic instance
|
||||
func NewCubic(clock Clock) *Cubic {
|
||||
c := &Cubic{
|
||||
clock: clock,
|
||||
numConnections: defaultNumConnections,
|
||||
}
|
||||
c.Reset()
|
||||
return c
|
||||
}
|
||||
|
||||
// Reset is called after a timeout to reset the cubic state
|
||||
func (c *Cubic) Reset() {
|
||||
c.epoch = time.Time{}
|
||||
c.lastMaxCongestionWindow = 0
|
||||
c.ackedBytesCount = 0
|
||||
c.estimatedTCPcongestionWindow = 0
|
||||
c.originPointCongestionWindow = 0
|
||||
c.timeToOriginPoint = 0
|
||||
c.lastTargetCongestionWindow = 0
|
||||
}
|
||||
|
||||
func (c *Cubic) alpha() float32 {
|
||||
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
|
||||
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
|
||||
// We derive the equivalent alpha for an N-connection emulation as:
|
||||
b := c.beta()
|
||||
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
|
||||
}
|
||||
|
||||
func (c *Cubic) beta() float32 {
|
||||
// kNConnectionBeta is the backoff factor after loss for our N-connection
|
||||
// emulation, which emulates the effective backoff of an ensemble of N
|
||||
// TCP-Reno connections on a single loss event. The effective multiplier is
|
||||
// computed as:
|
||||
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
func (c *Cubic) betaLastMax() float32 {
|
||||
// betaLastMax is the additional backoff factor after loss for our
|
||||
// N-connection emulation, which emulates the additional backoff of
|
||||
// an ensemble of N TCP-Reno connections on a single loss event. The
|
||||
// effective multiplier is computed as:
|
||||
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
||||
// the available congestion window. Resets Cubic state during quiescence.
|
||||
func (c *Cubic) OnApplicationLimited() {
|
||||
// When sender is not using the available congestion window, the window does
|
||||
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
|
||||
// using the entire window during the time since the beginning of the current
|
||||
// "epoch" (the end of the last loss recovery period). Since
|
||||
// application-limited periods break this assumption, we reset the epoch when
|
||||
// in such a period. This reset effectively freezes congestion window growth
|
||||
// through application-limited periods and allows Cubic growth to continue
|
||||
// when the entire window is being used.
|
||||
c.epoch = time.Time{}
|
||||
}
|
||||
|
||||
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
||||
// a loss event. Returns the new congestion window in packets. The new
|
||||
// congestion window is a multiplicative decrease of our current window.
|
||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount {
|
||||
if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
|
||||
// We never reached the old max, so assume we are competing with another
|
||||
// flow. Use our extra back off factor to allow the other flow to go up.
|
||||
c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
|
||||
} else {
|
||||
c.lastMaxCongestionWindow = currentCongestionWindow
|
||||
}
|
||||
c.epoch = time.Time{} // Reset time.
|
||||
return congestion.ByteCount(float32(currentCongestionWindow) * c.beta())
|
||||
}
|
||||
|
||||
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
||||
// Returns the new congestion window in packets. The new congestion window
|
||||
// follows a cubic function that depends on the time passed since last
|
||||
// packet loss.
|
||||
func (c *Cubic) CongestionWindowAfterAck(
|
||||
ackedBytes congestion.ByteCount,
|
||||
currentCongestionWindow congestion.ByteCount,
|
||||
delayMin time.Duration,
|
||||
eventTime time.Time,
|
||||
) congestion.ByteCount {
|
||||
c.ackedBytesCount += ackedBytes
|
||||
|
||||
if c.epoch.IsZero() {
|
||||
// First ACK after a loss event.
|
||||
c.epoch = eventTime // Start of epoch.
|
||||
c.ackedBytesCount = ackedBytes // Reset count.
|
||||
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
||||
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
||||
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
||||
c.timeToOriginPoint = 0
|
||||
c.originPointCongestionWindow = currentCongestionWindow
|
||||
} else {
|
||||
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
||||
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
||||
}
|
||||
}
|
||||
|
||||
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
||||
// the round trip time in account. This is done to allow us to use shift as a
|
||||
// divide operator.
|
||||
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
|
||||
|
||||
// Right-shifts of negative, signed numbers have implementation-dependent
|
||||
// behavior, so force the offset to be positive, as is done in the kernel.
|
||||
offset := int64(c.timeToOriginPoint) - elapsedTime
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
|
||||
var targetCongestionWindow congestion.ByteCount
|
||||
if elapsedTime > int64(c.timeToOriginPoint) {
|
||||
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
||||
} else {
|
||||
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
||||
}
|
||||
// Limit the CWND increase to half the acked bytes.
|
||||
targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
||||
|
||||
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
||||
// time we ack an estimated tcp window of bytes. For small
|
||||
// congestion windows (less than 25), the formula below will
|
||||
// increase slightly slower than linearly per estimated tcp window
|
||||
// of bytes.
|
||||
c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
|
||||
c.ackedBytesCount = 0
|
||||
|
||||
// We have a new cubic congestion window.
|
||||
c.lastTargetCongestionWindow = targetCongestionWindow
|
||||
|
||||
// Compute target congestion_window based on cubic target and estimated TCP
|
||||
// congestion_window, use highest (fastest).
|
||||
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
||||
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
||||
}
|
||||
return targetCongestionWindow
|
||||
}
|
||||
|
||||
// SetNumConnections sets the number of emulated connections
|
||||
func (c *Cubic) SetNumConnections(n int) {
|
||||
c.numConnections = n
|
||||
}
|
318
transport/tuic/congestion/cubic_sender.go
Normal file
318
transport/tuic/congestion/cubic_sender.go
Normal file
|
@ -0,0 +1,318 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
"github.com/sagernet/quic-go/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
maxBurstPackets = 3
|
||||
renoBeta = 0.7 // Reno backoff factor.
|
||||
minCongestionWindowPackets = 2
|
||||
initialCongestionWindow = 32
|
||||
)
|
||||
|
||||
const (
|
||||
InvalidPacketNumber congestion.PacketNumber = -1
|
||||
MaxCongestionWindowPackets = 20000
|
||||
MaxByteCount = congestion.ByteCount(1<<62 - 1)
|
||||
)
|
||||
|
||||
type cubicSender struct {
|
||||
hybridSlowStart HybridSlowStart
|
||||
rttStats congestion.RTTStatsProvider
|
||||
cubic *Cubic
|
||||
pacer *pacer
|
||||
clock Clock
|
||||
|
||||
reno bool
|
||||
|
||||
// Track the largest packet that has been sent.
|
||||
largestSentPacketNumber congestion.PacketNumber
|
||||
|
||||
// Track the largest packet that has been acked.
|
||||
largestAckedPacketNumber congestion.PacketNumber
|
||||
|
||||
// Track the largest packet number outstanding when a CWND cutback occurs.
|
||||
largestSentAtLastCutback congestion.PacketNumber
|
||||
|
||||
// Whether the last loss event caused us to exit slowstart.
|
||||
// Used for stats collection of slowstartPacketsLost
|
||||
lastCutbackExitedSlowstart bool
|
||||
|
||||
// Congestion window in bytes.
|
||||
congestionWindow congestion.ByteCount
|
||||
|
||||
// Slow start congestion window in bytes, aka ssthresh.
|
||||
slowStartThreshold congestion.ByteCount
|
||||
|
||||
// ACK counter for the Reno implementation.
|
||||
numAckedPackets uint64
|
||||
|
||||
initialCongestionWindow congestion.ByteCount
|
||||
initialMaxCongestionWindow congestion.ByteCount
|
||||
|
||||
maxDatagramSize congestion.ByteCount
|
||||
|
||||
lastState logging.CongestionState
|
||||
tracer logging.ConnectionTracer
|
||||
}
|
||||
|
||||
var _ congestion.CongestionControl = &cubicSender{}
|
||||
|
||||
// NewCubicSender makes a new cubic sender
|
||||
func NewCubicSender(
|
||||
clock Clock,
|
||||
initialMaxDatagramSize congestion.ByteCount,
|
||||
reno bool,
|
||||
tracer logging.ConnectionTracer,
|
||||
) *cubicSender {
|
||||
return newCubicSender(
|
||||
clock,
|
||||
reno,
|
||||
initialMaxDatagramSize,
|
||||
initialCongestionWindow*initialMaxDatagramSize,
|
||||
MaxCongestionWindowPackets*initialMaxDatagramSize,
|
||||
tracer,
|
||||
)
|
||||
}
|
||||
|
||||
func newCubicSender(
|
||||
clock Clock,
|
||||
reno bool,
|
||||
initialMaxDatagramSize,
|
||||
initialCongestionWindow,
|
||||
initialMaxCongestionWindow congestion.ByteCount,
|
||||
tracer logging.ConnectionTracer,
|
||||
) *cubicSender {
|
||||
c := &cubicSender{
|
||||
largestSentPacketNumber: InvalidPacketNumber,
|
||||
largestAckedPacketNumber: InvalidPacketNumber,
|
||||
largestSentAtLastCutback: InvalidPacketNumber,
|
||||
initialCongestionWindow: initialCongestionWindow,
|
||||
initialMaxCongestionWindow: initialMaxCongestionWindow,
|
||||
congestionWindow: initialCongestionWindow,
|
||||
slowStartThreshold: MaxByteCount,
|
||||
cubic: NewCubic(clock),
|
||||
clock: clock,
|
||||
reno: reno,
|
||||
tracer: tracer,
|
||||
maxDatagramSize: initialMaxDatagramSize,
|
||||
}
|
||||
c.pacer = newPacer(c.BandwidthEstimate)
|
||||
if c.tracer != nil {
|
||||
c.lastState = logging.CongestionStateSlowStart
|
||||
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
|
||||
c.rttStats = provider
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time {
|
||||
return c.pacer.TimeUntilSend()
|
||||
}
|
||||
|
||||
func (c *cubicSender) HasPacingBudget(now time.Time) bool {
|
||||
return c.pacer.Budget(now) >= c.maxDatagramSize
|
||||
}
|
||||
|
||||
func (c *cubicSender) maxCongestionWindow() congestion.ByteCount {
|
||||
return c.maxDatagramSize * MaxCongestionWindowPackets
|
||||
}
|
||||
|
||||
func (c *cubicSender) minCongestionWindow() congestion.ByteCount {
|
||||
return c.maxDatagramSize * minCongestionWindowPackets
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketSent(
|
||||
sentTime time.Time,
|
||||
_ congestion.ByteCount,
|
||||
packetNumber congestion.PacketNumber,
|
||||
bytes congestion.ByteCount,
|
||||
isRetransmittable bool,
|
||||
) {
|
||||
c.pacer.SentPacket(sentTime, bytes)
|
||||
if !isRetransmittable {
|
||||
return
|
||||
}
|
||||
c.largestSentPacketNumber = packetNumber
|
||||
c.hybridSlowStart.OnPacketSent(packetNumber)
|
||||
}
|
||||
|
||||
func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
||||
return bytesInFlight < c.GetCongestionWindow()
|
||||
}
|
||||
|
||||
func (c *cubicSender) InRecovery() bool {
|
||||
return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
|
||||
}
|
||||
|
||||
func (c *cubicSender) InSlowStart() bool {
|
||||
return c.GetCongestionWindow() < c.slowStartThreshold
|
||||
}
|
||||
|
||||
func (c *cubicSender) GetCongestionWindow() congestion.ByteCount {
|
||||
return c.congestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) MaybeExitSlowStart() {
|
||||
if c.InSlowStart() &&
|
||||
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
|
||||
// exit slow start
|
||||
c.slowStartThreshold = c.congestionWindow
|
||||
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketAcked(
|
||||
ackedPacketNumber congestion.PacketNumber,
|
||||
ackedBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||
if c.InRecovery() {
|
||||
return
|
||||
}
|
||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
||||
if c.InSlowStart() {
|
||||
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketLost(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
|
||||
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
||||
// already sent should be treated as a single loss event, since it's expected.
|
||||
if packetNumber <= c.largestSentAtLastCutback {
|
||||
return
|
||||
}
|
||||
c.lastCutbackExitedSlowstart = c.InSlowStart()
|
||||
c.maybeTraceStateChange(logging.CongestionStateRecovery)
|
||||
|
||||
if c.reno {
|
||||
c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta)
|
||||
} else {
|
||||
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
||||
}
|
||||
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
|
||||
c.congestionWindow = minCwnd
|
||||
}
|
||||
c.slowStartThreshold = c.congestionWindow
|
||||
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
||||
// reset packet count from congestion avoidance mode. We start
|
||||
// counting again when we're out of recovery.
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
|
||||
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
||||
// represents, but quic has a separate ack for each packet.
|
||||
func (c *cubicSender) maybeIncreaseCwnd(
|
||||
_ congestion.PacketNumber,
|
||||
ackedBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
// Do not increase the congestion window unless the sender is close to using
|
||||
// the current window.
|
||||
if !c.isCwndLimited(priorInFlight) {
|
||||
c.cubic.OnApplicationLimited()
|
||||
c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
|
||||
return
|
||||
}
|
||||
if c.congestionWindow >= c.maxCongestionWindow() {
|
||||
return
|
||||
}
|
||||
if c.InSlowStart() {
|
||||
// TCP slow start, exponential growth, increase by one for each ACK.
|
||||
c.congestionWindow += c.maxDatagramSize
|
||||
c.maybeTraceStateChange(logging.CongestionStateSlowStart)
|
||||
return
|
||||
}
|
||||
// Congestion avoidance
|
||||
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
||||
if c.reno {
|
||||
// Classic Reno congestion avoidance.
|
||||
c.numAckedPackets++
|
||||
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
|
||||
c.congestionWindow += c.maxDatagramSize
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
} else {
|
||||
c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool {
|
||||
congestionWindow := c.GetCongestionWindow()
|
||||
if bytesInFlight >= congestionWindow {
|
||||
return true
|
||||
}
|
||||
availableBytes := congestionWindow - bytesInFlight
|
||||
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
|
||||
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
|
||||
}
|
||||
|
||||
// BandwidthEstimate returns the current bandwidth estimate
|
||||
func (c *cubicSender) BandwidthEstimate() Bandwidth {
|
||||
if c.rttStats == nil {
|
||||
return infBandwidth
|
||||
}
|
||||
srtt := c.rttStats.SmoothedRTT()
|
||||
if srtt == 0 {
|
||||
// If we haven't measured an rtt, the bandwidth estimate is unknown.
|
||||
return infBandwidth
|
||||
}
|
||||
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
|
||||
}
|
||||
|
||||
// OnRetransmissionTimeout is called on an retransmission timeout
|
||||
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||
c.largestSentAtLastCutback = InvalidPacketNumber
|
||||
if !packetsRetransmitted {
|
||||
return
|
||||
}
|
||||
c.hybridSlowStart.Restart()
|
||||
c.cubic.Reset()
|
||||
c.slowStartThreshold = c.congestionWindow / 2
|
||||
c.congestionWindow = c.minCongestionWindow()
|
||||
}
|
||||
|
||||
// OnConnectionMigration is called when the connection is migrated (?)
|
||||
func (c *cubicSender) OnConnectionMigration() {
|
||||
c.hybridSlowStart.Restart()
|
||||
c.largestSentPacketNumber = InvalidPacketNumber
|
||||
c.largestAckedPacketNumber = InvalidPacketNumber
|
||||
c.largestSentAtLastCutback = InvalidPacketNumber
|
||||
c.lastCutbackExitedSlowstart = false
|
||||
c.cubic.Reset()
|
||||
c.numAckedPackets = 0
|
||||
c.congestionWindow = c.initialCongestionWindow
|
||||
c.slowStartThreshold = c.initialMaxCongestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
|
||||
if c.tracer == nil || new == c.lastState {
|
||||
return
|
||||
}
|
||||
c.tracer.UpdatedCongestionState(new)
|
||||
c.lastState = new
|
||||
}
|
||||
|
||||
func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) {
|
||||
if s < c.maxDatagramSize {
|
||||
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
|
||||
}
|
||||
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
|
||||
c.maxDatagramSize = s
|
||||
if cwndIsMinCwnd {
|
||||
c.congestionWindow = c.minCongestionWindow()
|
||||
}
|
||||
c.pacer.SetMaxDatagramSize(s)
|
||||
}
|
112
transport/tuic/congestion/hybrid_slow_start.go
Normal file
112
transport/tuic/congestion/hybrid_slow_start.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
)
|
||||
|
||||
// Note(pwestin): the magic clamping numbers come from the original code in
|
||||
// tcp_cubic.c.
|
||||
const hybridStartLowWindow = congestion.ByteCount(16)
|
||||
|
||||
// Number of delay samples for detecting the increase of delay.
|
||||
const hybridStartMinSamples = uint32(8)
|
||||
|
||||
// Exit slow start if the min rtt has increased by more than 1/8th.
|
||||
const hybridStartDelayFactorExp = 3 // 2^3 = 8
|
||||
// The original paper specifies 2 and 8ms, but those have changed over time.
|
||||
const (
|
||||
hybridStartDelayMinThresholdUs = int64(4000)
|
||||
hybridStartDelayMaxThresholdUs = int64(16000)
|
||||
)
|
||||
|
||||
// HybridSlowStart implements the TCP hybrid slow start algorithm
|
||||
type HybridSlowStart struct {
|
||||
endPacketNumber congestion.PacketNumber
|
||||
lastSentPacketNumber congestion.PacketNumber
|
||||
started bool
|
||||
currentMinRTT time.Duration
|
||||
rttSampleCount uint32
|
||||
hystartFound bool
|
||||
}
|
||||
|
||||
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
|
||||
func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) {
|
||||
s.endPacketNumber = lastSent
|
||||
s.currentMinRTT = 0
|
||||
s.rttSampleCount = 0
|
||||
s.started = true
|
||||
}
|
||||
|
||||
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
|
||||
func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool {
|
||||
return s.endPacketNumber < ack
|
||||
}
|
||||
|
||||
// ShouldExitSlowStart should be called on every new ack frame, since a new
|
||||
// RTT measurement can be made then.
|
||||
// rtt: the RTT for this ack packet.
|
||||
// minRTT: is the lowest delay (RTT) we have seen during the session.
|
||||
// congestionWindow: the congestion window in packets.
|
||||
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool {
|
||||
if !s.started {
|
||||
// Time to start the hybrid slow start.
|
||||
s.StartReceiveRound(s.lastSentPacketNumber)
|
||||
}
|
||||
if s.hystartFound {
|
||||
return true
|
||||
}
|
||||
// Second detection parameter - delay increase detection.
|
||||
// Compare the minimum delay (s.currentMinRTT) of the current
|
||||
// burst of packets relative to the minimum delay during the session.
|
||||
// Note: we only look at the first few(8) packets in each burst, since we
|
||||
// only want to compare the lowest RTT of the burst relative to previous
|
||||
// bursts.
|
||||
s.rttSampleCount++
|
||||
if s.rttSampleCount <= hybridStartMinSamples {
|
||||
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
|
||||
s.currentMinRTT = latestRTT
|
||||
}
|
||||
}
|
||||
// We only need to check this once per round.
|
||||
if s.rttSampleCount == hybridStartMinSamples {
|
||||
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
|
||||
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
|
||||
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
|
||||
minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
|
||||
minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
|
||||
|
||||
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
|
||||
s.hystartFound = true
|
||||
}
|
||||
}
|
||||
// Exit from slow start if the cwnd is greater than 16 and
|
||||
// increasing delay is found.
|
||||
return congestionWindow >= hybridStartLowWindow && s.hystartFound
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet was sent
|
||||
func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) {
|
||||
s.lastSentPacketNumber = packetNumber
|
||||
}
|
||||
|
||||
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
|
||||
// the round when the final packet of the burst is received and start it on
|
||||
// the next incoming ack.
|
||||
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) {
|
||||
if s.IsEndOfRound(ackedPacketNumber) {
|
||||
s.started = false
|
||||
}
|
||||
}
|
||||
|
||||
// Started returns true if started
|
||||
func (s *HybridSlowStart) Started() bool {
|
||||
return s.started
|
||||
}
|
||||
|
||||
// Restart the slow start phase
|
||||
func (s *HybridSlowStart) Restart() {
|
||||
s.started = false
|
||||
s.hystartFound = false
|
||||
}
|
72
transport/tuic/congestion/minmax.go
Normal file
72
transport/tuic/congestion/minmax.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
// InfDuration is a duration of infinite length
|
||||
const InfDuration = time.Duration(math.MaxInt64)
|
||||
|
||||
func Max[T constraints.Ordered](a, b T) T {
|
||||
if a < b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func Min[T constraints.Ordered](a, b T) T {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// MinNonZeroDuration return the minimum duration that's not zero.
|
||||
func MinNonZeroDuration(a, b time.Duration) time.Duration {
|
||||
if a == 0 {
|
||||
return b
|
||||
}
|
||||
if b == 0 {
|
||||
return a
|
||||
}
|
||||
return Min(a, b)
|
||||
}
|
||||
|
||||
// AbsDuration returns the absolute value of a time duration
|
||||
func AbsDuration(d time.Duration) time.Duration {
|
||||
if d >= 0 {
|
||||
return d
|
||||
}
|
||||
return -d
|
||||
}
|
||||
|
||||
// MinTime returns the earlier time
|
||||
func MinTime(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// MinNonZeroTime returns the earlist time that is not time.Time{}
|
||||
// If both a and b are time.Time{}, it returns time.Time{}
|
||||
func MinNonZeroTime(a, b time.Time) time.Time {
|
||||
if a.IsZero() {
|
||||
return b
|
||||
}
|
||||
if b.IsZero() {
|
||||
return a
|
||||
}
|
||||
return MinTime(a, b)
|
||||
}
|
||||
|
||||
// MaxTime returns the later time
|
||||
func MaxTime(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
81
transport/tuic/congestion/pacer.go
Normal file
81
transport/tuic/congestion/pacer.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
)
|
||||
|
||||
const (
|
||||
initialMaxDatagramSize = congestion.ByteCount(1252)
|
||||
MinPacingDelay = time.Millisecond
|
||||
TimerGranularity = time.Millisecond
|
||||
maxBurstSizePackets = 10
|
||||
)
|
||||
|
||||
// The pacer implements a token bucket pacing algorithm.
|
||||
type pacer struct {
|
||||
budgetAtLastSent congestion.ByteCount
|
||||
maxDatagramSize congestion.ByteCount
|
||||
lastSentTime time.Time
|
||||
getAdjustedBandwidth func() uint64 // in bytes/s
|
||||
}
|
||||
|
||||
func newPacer(getBandwidth func() Bandwidth) *pacer {
|
||||
p := &pacer{
|
||||
maxDatagramSize: initialMaxDatagramSize,
|
||||
getAdjustedBandwidth: func() uint64 {
|
||||
// Bandwidth is in bits/s. We need the value in bytes/s.
|
||||
bw := uint64(getBandwidth() / BytesPerSecond)
|
||||
// Use a slightly higher value than the actual measured bandwidth.
|
||||
// RTT variations then won't result in under-utilization of the congestion window.
|
||||
// Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire,
|
||||
// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
|
||||
return bw * 5 / 4
|
||||
},
|
||||
}
|
||||
p.budgetAtLastSent = p.maxBurstSize()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
|
||||
budget := p.Budget(sendTime)
|
||||
if size > budget {
|
||||
p.budgetAtLastSent = 0
|
||||
} else {
|
||||
p.budgetAtLastSent = budget - size
|
||||
}
|
||||
p.lastSentTime = sendTime
|
||||
}
|
||||
|
||||
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
|
||||
if p.lastSentTime.IsZero() {
|
||||
return p.maxBurstSize()
|
||||
}
|
||||
budget := p.budgetAtLastSent + (congestion.ByteCount(p.getAdjustedBandwidth())*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||
return Min(p.maxBurstSize(), budget)
|
||||
}
|
||||
|
||||
func (p *pacer) maxBurstSize() congestion.ByteCount {
|
||||
return Max(
|
||||
congestion.ByteCount(uint64((MinPacingDelay+TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9,
|
||||
maxBurstSizePackets*p.maxDatagramSize,
|
||||
)
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
// It returns the zero value of time.Time if a packet can be sent immediately.
|
||||
func (p *pacer) TimeUntilSend() time.Time {
|
||||
if p.budgetAtLastSent >= p.maxDatagramSize {
|
||||
return time.Time{}
|
||||
}
|
||||
return p.lastSentTime.Add(Max(
|
||||
MinPacingDelay,
|
||||
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond,
|
||||
))
|
||||
}
|
||||
|
||||
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
|
||||
p.maxDatagramSize = s
|
||||
}
|
132
transport/tuic/congestion/windowed_filter.go
Normal file
132
transport/tuic/congestion/windowed_filter.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
package congestion
|
||||
|
||||
// WindowedFilter Use the following to construct a windowed filter object of type T.
|
||||
// For example, a min filter using QuicTime as the time type:
|
||||
//
|
||||
// WindowedFilter<T, MinFilter<T>, QuicTime, QuicTime::Delta> ObjectName;
|
||||
//
|
||||
// A max filter using 64-bit integers as the time type:
|
||||
//
|
||||
// WindowedFilter<T, MaxFilter<T>, uint64_t, int64_t> ObjectName;
|
||||
//
|
||||
// Specifically, this template takes four arguments:
|
||||
// 1. T -- type of the measurement that is being filtered.
|
||||
// 2. Compare -- MinFilter<T> or MaxFilter<T>, depending on the type of filter
|
||||
// desired.
|
||||
// 3. TimeT -- the type used to represent timestamps.
|
||||
// 4. TimeDeltaT -- the type used to represent continuous time intervals between
|
||||
// two timestamps. Has to be the type of (a - b) if both |a| and |b| are
|
||||
// of type TimeT.
|
||||
type WindowedFilter struct {
|
||||
// Time length of window.
|
||||
windowLength int64
|
||||
estimates []Sample
|
||||
comparator func(int64, int64) bool
|
||||
}
|
||||
|
||||
type Sample struct {
|
||||
sample int64
|
||||
time int64
|
||||
}
|
||||
|
||||
// Compares two values and returns true if the first is greater than or equal
|
||||
// to the second.
|
||||
func MaxFilter(a, b int64) bool {
|
||||
return a >= b
|
||||
}
|
||||
|
||||
// Compares two values and returns true if the first is less than or equal
|
||||
// to the second.
|
||||
func MinFilter(a, b int64) bool {
|
||||
return a <= b
|
||||
}
|
||||
|
||||
func NewWindowedFilter(windowLength int64, comparator func(int64, int64) bool) *WindowedFilter {
|
||||
return &WindowedFilter{
|
||||
windowLength: windowLength,
|
||||
estimates: make([]Sample, 3),
|
||||
comparator: comparator,
|
||||
}
|
||||
}
|
||||
|
||||
// Changes the window length. Does not update any current samples.
|
||||
func (f *WindowedFilter) SetWindowLength(windowLength int64) {
|
||||
f.windowLength = windowLength
|
||||
}
|
||||
|
||||
func (f *WindowedFilter) GetBest() int64 {
|
||||
return f.estimates[0].sample
|
||||
}
|
||||
|
||||
func (f *WindowedFilter) GetSecondBest() int64 {
|
||||
return f.estimates[1].sample
|
||||
}
|
||||
|
||||
func (f *WindowedFilter) GetThirdBest() int64 {
|
||||
return f.estimates[2].sample
|
||||
}
|
||||
|
||||
func (f *WindowedFilter) Update(sample int64, time int64) {
|
||||
if f.estimates[0].time == 0 || f.comparator(sample, f.estimates[0].sample) || (time-f.estimates[2].time) > f.windowLength {
|
||||
f.Reset(sample, time)
|
||||
return
|
||||
}
|
||||
|
||||
if f.comparator(sample, f.estimates[1].sample) {
|
||||
f.estimates[1].sample = sample
|
||||
f.estimates[1].time = time
|
||||
f.estimates[2].sample = sample
|
||||
f.estimates[2].time = time
|
||||
} else if f.comparator(sample, f.estimates[2].sample) {
|
||||
f.estimates[2].sample = sample
|
||||
f.estimates[2].time = time
|
||||
}
|
||||
|
||||
// Expire and update estimates as necessary.
|
||||
if time-f.estimates[0].time > f.windowLength {
|
||||
// The best estimate hasn't been updated for an entire window, so promote
|
||||
// second and third best estimates.
|
||||
f.estimates[0].sample = f.estimates[1].sample
|
||||
f.estimates[0].time = f.estimates[1].time
|
||||
f.estimates[1].sample = f.estimates[2].sample
|
||||
f.estimates[1].time = f.estimates[2].time
|
||||
f.estimates[2].sample = sample
|
||||
f.estimates[2].time = time
|
||||
// Need to iterate one more time. Check if the new best estimate is
|
||||
// outside the window as well, since it may also have been recorded a
|
||||
// long time ago. Don't need to iterate once more since we cover that
|
||||
// case at the beginning of the method.
|
||||
if time-f.estimates[0].time > f.windowLength {
|
||||
f.estimates[0].sample = f.estimates[1].sample
|
||||
f.estimates[0].time = f.estimates[1].time
|
||||
f.estimates[1].sample = f.estimates[2].sample
|
||||
f.estimates[1].time = f.estimates[2].time
|
||||
}
|
||||
return
|
||||
}
|
||||
if f.estimates[1].sample == f.estimates[0].sample && time-f.estimates[1].time > f.windowLength>>2 {
|
||||
// A quarter of the window has passed without a better sample, so the
|
||||
// second-best estimate is taken from the second quarter of the window.
|
||||
f.estimates[1].sample = sample
|
||||
f.estimates[1].time = time
|
||||
f.estimates[2].sample = sample
|
||||
f.estimates[2].time = time
|
||||
return
|
||||
}
|
||||
|
||||
if f.estimates[2].sample == f.estimates[1].sample && time-f.estimates[2].time > f.windowLength>>1 {
|
||||
// We've passed a half of the window without a better estimate, so take
|
||||
// a third-best estimate from the second half of the window.
|
||||
f.estimates[2].sample = sample
|
||||
f.estimates[2].time = time
|
||||
}
|
||||
}
|
||||
|
||||
func (f *WindowedFilter) Reset(newSample int64, newTime int64) {
|
||||
f.estimates[0].sample = newSample
|
||||
f.estimates[0].time = newTime
|
||||
f.estimates[1].sample = newSample
|
||||
f.estimates[1].time = newTime
|
||||
f.estimates[2].sample = newSample
|
||||
f.estimates[2].time = newTime
|
||||
}
|
497
transport/tuic/packet.go
Normal file
497
transport/tuic/packet.go
Normal file
|
@ -0,0 +1,497 @@
|
|||
package tuic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/cache"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
var udpMessagePool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(udpMessage)
|
||||
},
|
||||
}
|
||||
|
||||
func releaseMessages(messages []*udpMessage) {
|
||||
for _, message := range messages {
|
||||
if message != nil {
|
||||
*message = udpMessage{}
|
||||
udpMessagePool.Put(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type udpMessage struct {
|
||||
sessionID uint16
|
||||
packetID uint16
|
||||
fragmentTotal uint8
|
||||
fragmentID uint8
|
||||
destination M.Socksaddr
|
||||
dataLength uint16
|
||||
data *buf.Buffer
|
||||
}
|
||||
|
||||
func (m *udpMessage) release() {
|
||||
*m = udpMessage{}
|
||||
udpMessagePool.Put(m)
|
||||
}
|
||||
|
||||
func (m *udpMessage) releaseMessage() {
|
||||
m.data.Release()
|
||||
m.release()
|
||||
}
|
||||
|
||||
func (m *udpMessage) pack() *buf.Buffer {
|
||||
buffer := buf.NewSize(m.headerSize() + m.data.Len())
|
||||
common.Must(
|
||||
buffer.WriteByte(Version),
|
||||
buffer.WriteByte(CommandPacket),
|
||||
binary.Write(buffer, binary.BigEndian, m.sessionID),
|
||||
binary.Write(buffer, binary.BigEndian, m.packetID),
|
||||
binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
|
||||
binary.Write(buffer, binary.BigEndian, m.fragmentID),
|
||||
binary.Write(buffer, binary.BigEndian, uint16(m.data.Len())),
|
||||
addressSerializer.WriteAddrPort(buffer, m.destination),
|
||||
common.Error(buffer.Write(m.data.Bytes())),
|
||||
)
|
||||
return buffer
|
||||
}
|
||||
|
||||
func (m *udpMessage) headerSize() int {
|
||||
return 2 + 10 + addressSerializer.AddrPortLen(m.destination)
|
||||
}
|
||||
|
||||
func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
|
||||
if message.data.Len() <= maxPacketSize {
|
||||
return []*udpMessage{message}
|
||||
}
|
||||
var fragments []*udpMessage
|
||||
originPacket := message.data.Bytes()
|
||||
udpMTU := maxPacketSize - message.headerSize()
|
||||
for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
|
||||
fragment := udpMessagePool.Get().(*udpMessage)
|
||||
*fragment = *message
|
||||
if remaining > udpMTU {
|
||||
fragment.data = buf.As(originPacket[:udpMTU])
|
||||
originPacket = originPacket[udpMTU:]
|
||||
} else {
|
||||
fragment.data = buf.As(originPacket)
|
||||
originPacket = nil
|
||||
}
|
||||
fragments = append(fragments, fragment)
|
||||
}
|
||||
fragmentTotal := uint16(len(fragments))
|
||||
for index, fragment := range fragments {
|
||||
fragment.fragmentID = uint8(index)
|
||||
fragment.fragmentTotal = uint8(fragmentTotal)
|
||||
if index > 0 {
|
||||
fragment.destination = M.Socksaddr{}
|
||||
}
|
||||
}
|
||||
return fragments
|
||||
}
|
||||
|
||||
type udpPacketConn struct {
|
||||
ctx context.Context
|
||||
cancel common.ContextCancelCauseFunc
|
||||
sessionID uint16
|
||||
quicConn quic.Connection
|
||||
data chan *udpMessage
|
||||
udpStream bool
|
||||
udpMTU int
|
||||
packetId atomic.Uint32
|
||||
closeOnce sync.Once
|
||||
isServer bool
|
||||
defragger *udpDefragger
|
||||
onDestroy func()
|
||||
}
|
||||
|
||||
func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn {
|
||||
ctx, cancel := common.ContextWithCancelCause(ctx)
|
||||
return &udpPacketConn{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
quicConn: quicConn,
|
||||
data: make(chan *udpMessage, 64),
|
||||
udpStream: udpStream,
|
||||
isServer: isServer,
|
||||
defragger: newUDPDefragger(),
|
||||
onDestroy: onDestroy,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case p := <-c.data:
|
||||
buffer = p.data
|
||||
destination = p.destination
|
||||
p.release()
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case p := <-c.data:
|
||||
_, err = buffer.ReadOnceFrom(p.data)
|
||||
destination = p.destination
|
||||
p.releaseMessage()
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return M.Socksaddr{}, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case p := <-c.data:
|
||||
_, err = newBuffer().ReadOnceFrom(p.data)
|
||||
destination = p.destination
|
||||
p.releaseMessage()
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return M.Socksaddr{}, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case pkt := <-c.data:
|
||||
n = copy(p, pkt.data.Bytes())
|
||||
if pkt.destination.IsFqdn() {
|
||||
addr = pkt.destination
|
||||
} else {
|
||||
addr = pkt.destination.UDPAddr()
|
||||
}
|
||||
pkt.releaseMessage()
|
||||
return n, addr, nil
|
||||
case <-c.ctx.Done():
|
||||
return 0, nil, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
defer buffer.Release()
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return net.ErrClosed
|
||||
default:
|
||||
}
|
||||
if buffer.Len() > 0xffff {
|
||||
return quic.ErrMessageTooLarge(0xffff)
|
||||
}
|
||||
packetId := c.packetId.Add(1)
|
||||
if packetId > math.MaxUint16 {
|
||||
c.packetId.Store(0)
|
||||
packetId = 0
|
||||
}
|
||||
message := udpMessagePool.Get().(*udpMessage)
|
||||
*message = udpMessage{
|
||||
sessionID: c.sessionID,
|
||||
packetID: uint16(packetId),
|
||||
fragmentTotal: 1,
|
||||
destination: destination,
|
||||
data: buffer,
|
||||
}
|
||||
defer message.releaseMessage()
|
||||
var err error
|
||||
if !c.udpStream && c.udpMTU > 0 && buffer.Len() > c.udpMTU {
|
||||
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
||||
} else {
|
||||
err = c.writePacket(message)
|
||||
}
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var tooLargeErr quic.ErrMessageTooLarge
|
||||
if !errors.As(err, &tooLargeErr) {
|
||||
return err
|
||||
}
|
||||
c.udpMTU = int(tooLargeErr)
|
||||
return c.writePackets(fragUDPMessage(message, c.udpMTU))
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return 0, net.ErrClosed
|
||||
default:
|
||||
}
|
||||
if len(p) > 0xffff {
|
||||
return 0, quic.ErrMessageTooLarge(0xffff)
|
||||
}
|
||||
packetId := c.packetId.Add(1)
|
||||
if packetId > math.MaxUint16 {
|
||||
c.packetId.Store(0)
|
||||
packetId = 0
|
||||
}
|
||||
message := udpMessagePool.Get().(*udpMessage)
|
||||
*message = udpMessage{
|
||||
sessionID: c.sessionID,
|
||||
packetID: uint16(packetId),
|
||||
fragmentTotal: 1,
|
||||
destination: M.SocksaddrFromNet(addr),
|
||||
data: buf.As(p),
|
||||
}
|
||||
if c.udpMTU > 0 && len(p) > c.udpMTU {
|
||||
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
||||
if err == nil {
|
||||
return len(p), nil
|
||||
}
|
||||
} else {
|
||||
err = c.writePacket(message)
|
||||
}
|
||||
if err == nil {
|
||||
return len(p), nil
|
||||
}
|
||||
var tooLargeErr quic.ErrMessageTooLarge
|
||||
if !errors.As(err, &tooLargeErr) {
|
||||
return
|
||||
}
|
||||
c.udpMTU = int(tooLargeErr)
|
||||
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
||||
if err == nil {
|
||||
return len(p), nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) inputPacket(message *udpMessage) {
|
||||
if message.fragmentTotal <= 1 {
|
||||
select {
|
||||
case c.data <- message:
|
||||
default:
|
||||
}
|
||||
} else {
|
||||
newMessage := c.defragger.feed(message)
|
||||
if newMessage != nil {
|
||||
select {
|
||||
case c.data <- newMessage:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
|
||||
defer releaseMessages(messages)
|
||||
for _, message := range messages {
|
||||
err := c.writePacket(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) writePacket(message *udpMessage) error {
|
||||
if !c.udpStream {
|
||||
buffer := message.pack()
|
||||
err := c.quicConn.SendMessage(buffer.Bytes())
|
||||
buffer.Release()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
stream, err := c.quicConn.OpenUniStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer := message.pack()
|
||||
_, err = stream.Write(buffer.Bytes())
|
||||
buffer.Release()
|
||||
stream.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
c.closeWithError(os.ErrClosed)
|
||||
c.onDestroy()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) closeWithError(err error) {
|
||||
c.cancel(err)
|
||||
if !c.isServer {
|
||||
buffer := buf.NewSize(4)
|
||||
defer buffer.Release()
|
||||
buffer.WriteByte(Version)
|
||||
buffer.WriteByte(CommandDissociate)
|
||||
binary.Write(buffer, binary.BigEndian, c.sessionID)
|
||||
sendStream, openErr := c.quicConn.OpenUniStream()
|
||||
if openErr != nil {
|
||||
return
|
||||
}
|
||||
defer sendStream.Close()
|
||||
sendStream.Write(buffer.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) LocalAddr() net.Addr {
|
||||
return c.quicConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) SetDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
type udpDefragger struct {
|
||||
packetMap *cache.LruCache[uint16, *packetItem]
|
||||
}
|
||||
|
||||
func newUDPDefragger() *udpDefragger {
|
||||
return &udpDefragger{
|
||||
packetMap: cache.New(
|
||||
cache.WithAge[uint16, *packetItem](10),
|
||||
cache.WithUpdateAgeOnGet[uint16, *packetItem](),
|
||||
cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
|
||||
releaseMessages(value.messages)
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
type packetItem struct {
|
||||
access sync.Mutex
|
||||
messages []*udpMessage
|
||||
count uint8
|
||||
}
|
||||
|
||||
func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
|
||||
if m.fragmentTotal <= 1 {
|
||||
return m
|
||||
}
|
||||
if m.fragmentID >= m.fragmentTotal {
|
||||
return nil
|
||||
}
|
||||
item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
|
||||
item.access.Lock()
|
||||
defer item.access.Unlock()
|
||||
if int(m.fragmentTotal) != len(item.messages) {
|
||||
releaseMessages(item.messages)
|
||||
item.messages = make([]*udpMessage, m.fragmentTotal)
|
||||
item.count = 1
|
||||
item.messages[m.fragmentID] = m
|
||||
return nil
|
||||
}
|
||||
if item.messages[m.fragmentID] != nil {
|
||||
return nil
|
||||
}
|
||||
item.messages[m.fragmentID] = m
|
||||
item.count++
|
||||
if int(item.count) != len(item.messages) {
|
||||
return nil
|
||||
}
|
||||
newMessage := udpMessagePool.Get().(*udpMessage)
|
||||
*newMessage = *item.messages[0]
|
||||
if m.dataLength > 0 {
|
||||
newMessage.data = buf.NewSize(int(m.dataLength))
|
||||
for _, message := range item.messages {
|
||||
newMessage.data.Write(message.data.Bytes())
|
||||
message.releaseMessage()
|
||||
}
|
||||
item.messages = nil
|
||||
return newMessage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newPacketItem() *packetItem {
|
||||
return new(packetItem)
|
||||
}
|
||||
|
||||
func readUDPMessage(message *udpMessage, reader io.Reader) error {
|
||||
err := binary.Read(reader, binary.BigEndian, &message.sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.packetID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.dataLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
message.destination, err = addressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
message.data = buf.NewSize(int(message.dataLength))
|
||||
_, err = message.data.ReadFullFrom(reader, message.data.FreeLen())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeUDPMessage(message *udpMessage, data []byte) error {
|
||||
reader := bytes.NewReader(data)
|
||||
err := binary.Read(reader, binary.BigEndian, &message.sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.packetID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.dataLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
message.destination, err = addressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reader.Len() != int(message.dataLength) {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
message.data = buf.As(data[len(data)-reader.Len():])
|
||||
return nil
|
||||
}
|
15
transport/tuic/protocol.go
Normal file
15
transport/tuic/protocol.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package tuic
|
||||
|
||||
const (
|
||||
Version = 5
|
||||
)
|
||||
|
||||
const (
|
||||
CommandAuthenticate = iota
|
||||
CommandConnect
|
||||
CommandPacket
|
||||
CommandDissociate
|
||||
CommandHeartbeat
|
||||
)
|
||||
|
||||
const AuthenticateLen = 2 + 16 + 32
|
434
transport/tuic/server.go
Normal file
434
transport/tuic/server.go
Normal file
|
@ -0,0 +1,434 @@
|
|||
package tuic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing-box/common/baderror"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/auth"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
type ServerOptions struct {
|
||||
Context context.Context
|
||||
Logger logger.Logger
|
||||
TLSConfig *tls.Config
|
||||
Users []User
|
||||
CongestionControl string
|
||||
AuthTimeout time.Duration
|
||||
ZeroRTTHandshake bool
|
||||
Heartbeat time.Duration
|
||||
Handler ServerHandler
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
UUID uuid.UUID
|
||||
Password string
|
||||
}
|
||||
|
||||
type ServerHandler interface {
|
||||
N.TCPConnectionHandler
|
||||
N.UDPConnectionHandler
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
logger logger.Logger
|
||||
tlsConfig *tls.Config
|
||||
heartbeat time.Duration
|
||||
quicConfig *quic.Config
|
||||
userMap map[uuid.UUID]User
|
||||
congestionControl string
|
||||
authTimeout time.Duration
|
||||
handler ServerHandler
|
||||
|
||||
quicListener io.Closer
|
||||
}
|
||||
|
||||
func NewServer(options ServerOptions) (*Server, error) {
|
||||
if options.AuthTimeout == 0 {
|
||||
options.AuthTimeout = 3 * time.Second
|
||||
}
|
||||
if options.Heartbeat == 0 {
|
||||
options.Heartbeat = 10 * time.Second
|
||||
}
|
||||
quicConfig := &quic.Config{
|
||||
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
|
||||
MaxDatagramFrameSize: 1400,
|
||||
EnableDatagrams: true,
|
||||
Allow0RTT: options.ZeroRTTHandshake,
|
||||
MaxIncomingStreams: 1 << 60,
|
||||
MaxIncomingUniStreams: 1 << 60,
|
||||
}
|
||||
switch options.CongestionControl {
|
||||
case "":
|
||||
options.CongestionControl = "cubic"
|
||||
case "cubic", "new_reno", "bbr":
|
||||
default:
|
||||
return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
|
||||
}
|
||||
if len(options.Users) == 0 {
|
||||
return nil, E.New("missing users")
|
||||
}
|
||||
userMap := make(map[uuid.UUID]User)
|
||||
for _, user := range options.Users {
|
||||
userMap[user.UUID] = user
|
||||
}
|
||||
return &Server{
|
||||
ctx: options.Context,
|
||||
logger: options.Logger,
|
||||
tlsConfig: options.TLSConfig,
|
||||
heartbeat: options.Heartbeat,
|
||||
quicConfig: quicConfig,
|
||||
userMap: userMap,
|
||||
congestionControl: options.CongestionControl,
|
||||
authTimeout: options.AuthTimeout,
|
||||
handler: options.Handler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Start(conn net.PacketConn) error {
|
||||
if !s.quicConfig.Allow0RTT {
|
||||
listener, err := quic.Listen(conn, s.tlsConfig, s.quicConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.quicListener = listener
|
||||
go func() {
|
||||
for {
|
||||
connection, hErr := listener.Accept(s.ctx)
|
||||
if hErr != nil {
|
||||
if strings.Contains(hErr.Error(), "server closed") {
|
||||
s.logger.Debug(E.Cause(hErr, "listener closed"))
|
||||
} else {
|
||||
s.logger.Error(E.Cause(hErr, "listener closed"))
|
||||
}
|
||||
return
|
||||
}
|
||||
go s.handleConnection(connection)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
listener, err := quic.ListenEarly(conn, s.tlsConfig, s.quicConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.quicListener = listener
|
||||
go func() {
|
||||
for {
|
||||
connection, hErr := listener.Accept(s.ctx)
|
||||
if hErr != nil {
|
||||
if strings.Contains(hErr.Error(), "server closed") {
|
||||
s.logger.Debug(E.Cause(hErr, "listener closed"))
|
||||
} else {
|
||||
s.logger.Error(E.Cause(hErr, "listener closed"))
|
||||
}
|
||||
return
|
||||
}
|
||||
go s.handleConnection(connection)
|
||||
}
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
return common.Close(
|
||||
s.quicListener,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(connection quic.Connection) {
|
||||
setCongestion(s.ctx, connection, s.congestionControl)
|
||||
session := &serverSession{
|
||||
Server: s,
|
||||
ctx: s.ctx,
|
||||
quicConn: connection,
|
||||
source: M.SocksaddrFromNet(connection.RemoteAddr()),
|
||||
connDone: make(chan struct{}),
|
||||
authDone: make(chan struct{}),
|
||||
udpConnMap: make(map[uint16]*udpPacketConn),
|
||||
}
|
||||
session.handle()
|
||||
}
|
||||
|
||||
type serverSession struct {
|
||||
*Server
|
||||
ctx context.Context
|
||||
quicConn quic.Connection
|
||||
source M.Socksaddr
|
||||
connAccess sync.Mutex
|
||||
connDone chan struct{}
|
||||
connErr error
|
||||
authDone chan struct{}
|
||||
authUser *User
|
||||
udpAccess sync.RWMutex
|
||||
udpConnMap map[uint16]*udpPacketConn
|
||||
}
|
||||
|
||||
func (s *serverSession) handle() {
|
||||
if s.ctx.Done() != nil {
|
||||
go func() {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
s.closeWithError(s.ctx.Err())
|
||||
case <-s.connDone:
|
||||
}
|
||||
}()
|
||||
}
|
||||
go s.loopUniStreams()
|
||||
go s.loopStreams()
|
||||
go s.loopMessages()
|
||||
go s.handleAuthTimeout()
|
||||
go s.loopHeartbeats()
|
||||
}
|
||||
|
||||
func (s *serverSession) loopUniStreams() {
|
||||
for {
|
||||
uniStream, err := s.quicConn.AcceptUniStream(s.ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
err = s.handleUniStream(uniStream)
|
||||
if err != nil {
|
||||
s.closeWithError(E.Cause(err, "handle uni stream"))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
|
||||
defer stream.CancelRead(0)
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
_, err := buffer.ReadAtLeastFrom(stream, 2)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read request")
|
||||
}
|
||||
version := buffer.Byte(0)
|
||||
if version != Version {
|
||||
return E.New("unknown version ", buffer.Byte(0))
|
||||
}
|
||||
command := buffer.Byte(1)
|
||||
switch command {
|
||||
case CommandAuthenticate:
|
||||
select {
|
||||
case <-s.authDone:
|
||||
return E.New("authentication: multiple authentication requests")
|
||||
default:
|
||||
}
|
||||
if buffer.Len() < AuthenticateLen {
|
||||
_, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len())
|
||||
if err != nil {
|
||||
return E.Cause(err, "authentication: read request")
|
||||
}
|
||||
}
|
||||
userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16))
|
||||
user, loaded := s.userMap[userUUID]
|
||||
if !loaded {
|
||||
return E.New("authentication: unknown user ", userUUID)
|
||||
}
|
||||
handshakeState := s.quicConn.ConnectionState().TLS
|
||||
tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32)
|
||||
if err != nil {
|
||||
return E.Cause(err, "authentication: export keying material")
|
||||
}
|
||||
if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) {
|
||||
return E.New("authentication: token mismatch")
|
||||
}
|
||||
s.authUser = &user
|
||||
close(s.authDone)
|
||||
return nil
|
||||
case CommandPacket:
|
||||
select {
|
||||
case <-s.connDone:
|
||||
return s.connErr
|
||||
case <-s.authDone:
|
||||
}
|
||||
message := udpMessagePool.Get().(*udpMessage)
|
||||
err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream))
|
||||
if err != nil {
|
||||
message.release()
|
||||
return err
|
||||
}
|
||||
s.handleUDPMessage(message, true)
|
||||
return nil
|
||||
case CommandDissociate:
|
||||
select {
|
||||
case <-s.connDone:
|
||||
return s.connErr
|
||||
case <-s.authDone:
|
||||
}
|
||||
if buffer.Len() > 4 {
|
||||
return E.New("invalid dissociate message")
|
||||
}
|
||||
var sessionID uint16
|
||||
err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.udpAccess.RLock()
|
||||
udpConn, loaded := s.udpConnMap[sessionID]
|
||||
s.udpAccess.RUnlock()
|
||||
if loaded {
|
||||
udpConn.closeWithError(E.New("remote closed"))
|
||||
s.udpAccess.Lock()
|
||||
delete(s.udpConnMap, sessionID)
|
||||
s.udpAccess.Unlock()
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return E.New("unknown command ", command)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleAuthTimeout() {
|
||||
select {
|
||||
case <-s.connDone:
|
||||
case <-s.authDone:
|
||||
case <-time.After(s.authTimeout):
|
||||
s.closeWithError(E.New("authentication timeout"))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) loopStreams() {
|
||||
for {
|
||||
stream, err := s.quicConn.AcceptStream(s.ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
err = s.handleStream(stream)
|
||||
if err != nil {
|
||||
stream.CancelRead(0)
|
||||
stream.Close()
|
||||
s.logger.Error(E.Cause(err, "handle stream request"))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleStream(stream quic.Stream) error {
|
||||
buffer := buf.NewSize(2 + M.MaxSocksaddrLength)
|
||||
defer buffer.Release()
|
||||
_, err := buffer.ReadAtLeastFrom(stream, 2)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read request")
|
||||
}
|
||||
version, _ := buffer.ReadByte()
|
||||
if version != Version {
|
||||
return E.New("unknown version ", buffer.Byte(0))
|
||||
}
|
||||
command, _ := buffer.ReadByte()
|
||||
if command != CommandConnect {
|
||||
return E.New("unsupported stream command ", command)
|
||||
}
|
||||
destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream))
|
||||
if err != nil {
|
||||
return E.Cause(err, "read request destination")
|
||||
}
|
||||
select {
|
||||
case <-s.connDone:
|
||||
return s.connErr
|
||||
case <-s.authDone:
|
||||
}
|
||||
var conn net.Conn = &serverConn{
|
||||
Stream: stream,
|
||||
destination: destination,
|
||||
}
|
||||
if buffer.IsEmpty() {
|
||||
buffer.Release()
|
||||
} else {
|
||||
conn = bufio.NewCachedConn(conn, buffer)
|
||||
}
|
||||
ctx := s.ctx
|
||||
if s.authUser.Name != "" {
|
||||
ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
|
||||
}
|
||||
_ = s.handler.NewConnection(ctx, conn, M.Metadata{
|
||||
Source: s.source,
|
||||
Destination: destination,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serverSession) loopHeartbeats() {
|
||||
ticker := time.NewTicker(s.heartbeat)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.connDone:
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
|
||||
if err != nil {
|
||||
s.closeWithError(E.Cause(err, "send heartbeat"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) closeWithError(err error) {
|
||||
s.connAccess.Lock()
|
||||
defer s.connAccess.Unlock()
|
||||
select {
|
||||
case <-s.connDone:
|
||||
return
|
||||
default:
|
||||
s.connErr = err
|
||||
close(s.connDone)
|
||||
}
|
||||
if E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug(E.Cause(err, "connection failed"))
|
||||
} else {
|
||||
s.logger.Error(E.Cause(err, "connection failed"))
|
||||
}
|
||||
_ = s.quicConn.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
quic.Stream
|
||||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *serverConn) Read(p []byte) (n int, err error) {
|
||||
n, err = c.Stream.Read(p)
|
||||
return n, baderror.WrapQUIC(err)
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(p []byte) (n int, err error) {
|
||||
n, err = c.Stream.Write(p)
|
||||
return n, baderror.WrapQUIC(err)
|
||||
}
|
||||
|
||||
func (c *serverConn) LocalAddr() net.Addr {
|
||||
return c.destination
|
||||
}
|
||||
|
||||
func (c *serverConn) RemoteAddr() net.Addr {
|
||||
return M.Socksaddr{}
|
||||
}
|
||||
|
||||
func (c *serverConn) Close() error {
|
||||
c.Stream.CancelRead(0)
|
||||
return c.Stream.Close()
|
||||
}
|
73
transport/tuic/server_packet.go
Normal file
73
transport/tuic/server_packet.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package tuic
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func (s *serverSession) loopMessages() {
|
||||
select {
|
||||
case <-s.connDone:
|
||||
return
|
||||
case <-s.authDone:
|
||||
}
|
||||
for {
|
||||
message, err := s.quicConn.ReceiveMessage(s.ctx)
|
||||
if err != nil {
|
||||
s.closeWithError(E.Cause(err, "receive message"))
|
||||
return
|
||||
}
|
||||
hErr := s.handleMessage(message)
|
||||
if hErr != nil {
|
||||
s.closeWithError(E.Cause(hErr, "handle message"))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleMessage(data []byte) error {
|
||||
if len(data) < 2 {
|
||||
return E.New("invalid message")
|
||||
}
|
||||
if data[0] != Version {
|
||||
return E.New("unknown version ", data[0])
|
||||
}
|
||||
switch data[1] {
|
||||
case CommandPacket:
|
||||
message := udpMessagePool.Get().(*udpMessage)
|
||||
err := decodeUDPMessage(message, data[2:])
|
||||
if err != nil {
|
||||
message.release()
|
||||
return E.Cause(err, "decode UDP message")
|
||||
}
|
||||
s.handleUDPMessage(message, false)
|
||||
return nil
|
||||
case CommandHeartbeat:
|
||||
return nil
|
||||
default:
|
||||
return E.New("unknown command ", data[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) {
|
||||
s.udpAccess.RLock()
|
||||
udpConn, loaded := s.udpConnMap[message.sessionID]
|
||||
s.udpAccess.RUnlock()
|
||||
if !loaded || common.Done(udpConn.ctx) {
|
||||
udpConn = newUDPPacketConn(s.ctx, s.quicConn, udpStream, true, func() {
|
||||
s.udpAccess.Lock()
|
||||
delete(s.udpConnMap, message.sessionID)
|
||||
s.udpAccess.Unlock()
|
||||
})
|
||||
udpConn.sessionID = message.sessionID
|
||||
s.udpAccess.Lock()
|
||||
s.udpConnMap[message.sessionID] = udpConn
|
||||
s.udpAccess.Unlock()
|
||||
go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{
|
||||
Source: s.source,
|
||||
Destination: message.destination,
|
||||
})
|
||||
}
|
||||
udpConn.inputPacket(message)
|
||||
}
|
Loading…
Reference in a new issue