Add TUIC protocol

This commit is contained in:
世界 2023-07-23 14:42:19 +08:00
parent 0b14dc3228
commit 917420e79a
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
34 changed files with 4389 additions and 0 deletions

View file

@ -21,6 +21,7 @@ const (
TypeShadowTLS = "shadowtls"
TypeShadowsocksR = "shadowsocksr"
TypeVLESS = "vless"
TypeTUIC = "tuic"
)
const (
@ -62,6 +63,8 @@ func ProxyDisplayName(proxyType string) string {
return "ShadowsocksR"
case TypeVLESS:
return "VLESS"
case TypeTUIC:
return "TUIC"
case TypeSelector:
return "Selector"
case TypeURLTest:

View file

@ -44,6 +44,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o
return NewShadowTLS(ctx, router, logger, options.Tag, options.ShadowTLSOptions)
case C.TypeVLESS:
return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions)
case C.TypeTUIC:
return NewTUIC(ctx, router, logger, options.Tag, options.TUICOptions)
default:
return nil, E.New("unknown inbound type: ", options.Type)
}

View file

@ -153,6 +153,17 @@ func (a *myInboundAdapter) createMetadata(conn net.Conn, metadata adapter.Inboun
return metadata
}
func (a *myInboundAdapter) createPacketMetadata(conn N.PacketConn, metadata adapter.InboundContext) adapter.InboundContext {
metadata.Inbound = a.tag
metadata.InboundType = a.protocol
metadata.InboundDetour = a.listenOptions.Detour
metadata.InboundOptions = a.listenOptions.InboundOptions
if !metadata.Destination.IsValid() {
metadata.Destination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
}
return metadata
}
func (a *myInboundAdapter) newError(err error) {
a.logger.Error(err)
}

114
inbound/tuic.go Normal file
View file

@ -0,0 +1,114 @@
//go:build with_quic
package inbound
import (
"context"
"net"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/tuic"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/gofrs/uuid/v5"
)
var _ adapter.Inbound = (*TUIC)(nil)
type TUIC struct {
myInboundAdapter
server *tuic.Server
tlsConfig tls.ServerConfig
}
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (*TUIC, error) {
options.UDPFragmentDefault = true
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired
}
tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
rawConfig, err := tlsConfig.Config()
if err != nil {
return nil, err
}
var users []tuic.User
for index, user := range options.Users {
if user.UUID == "" {
return nil, E.New("missing uuid for user ", index)
}
userUUID, err := uuid.FromString(user.UUID)
if err != nil {
return nil, E.Cause(err, "invalid uuid for user ", index)
}
users = append(users, tuic.User{Name: user.Name, UUID: userUUID, Password: user.Password})
}
inbound := &TUIC{
myInboundAdapter: myInboundAdapter{
protocol: C.TypeTUIC,
network: []string{N.NetworkUDP},
ctx: ctx,
router: router,
logger: logger,
tag: tag,
listenOptions: options.ListenOptions,
},
}
server, err := tuic.NewServer(tuic.ServerOptions{
Context: ctx,
Logger: logger,
TLSConfig: rawConfig,
Users: users,
CongestionControl: options.CongestionControl,
AuthTimeout: time.Duration(options.AuthTimeout),
ZeroRTTHandshake: options.ZeroRTTHandshake,
Heartbeat: time.Duration(options.Heartbeat),
Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil),
})
if err != nil {
return nil, err
}
inbound.server = server
return inbound, nil
}
func (h *TUIC) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
ctx = log.ContextWithNewID(ctx)
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
metadata = h.createMetadata(conn, metadata)
metadata.User, _ = auth.UserFromContext[string](ctx)
return h.router.RouteConnection(ctx, conn, metadata)
}
func (h *TUIC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
ctx = log.ContextWithNewID(ctx)
metadata = h.createPacketMetadata(conn, metadata)
metadata.User, _ = auth.UserFromContext[string](ctx)
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
return h.router.RoutePacketConnection(ctx, conn, metadata)
}
func (h *TUIC) Start() error {
packetConn, err := h.myInboundAdapter.ListenUDP()
if err != nil {
return err
}
return h.server.Start(packetConn)
}
func (h *TUIC) Close() error {
return common.Close(
&h.myInboundAdapter,
common.PtrOrNil(h.server),
)
}

16
inbound/tuic_stub.go Normal file
View file

@ -0,0 +1,16 @@
//go:build !with_quic
package inbound
import (
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (adapter.Inbound, error) {
return nil, C.ErrQUICNotIncluded
}

View file

@ -23,6 +23,7 @@ type _Inbound struct {
HysteriaOptions HysteriaInboundOptions `json:"-"`
ShadowTLSOptions ShadowTLSInboundOptions `json:"-"`
VLESSOptions VLESSInboundOptions `json:"-"`
TUICOptions TUICInboundOptions `json:"-"`
}
type Inbound _Inbound
@ -58,6 +59,8 @@ func (h Inbound) MarshalJSON() ([]byte, error) {
v = h.ShadowTLSOptions
case C.TypeVLESS:
v = h.VLESSOptions
case C.TypeTUIC:
v = h.TUICOptions
default:
return nil, E.New("unknown inbound type: ", h.Type)
}
@ -99,6 +102,8 @@ func (h *Inbound) UnmarshalJSON(bytes []byte) error {
v = &h.ShadowTLSOptions
case C.TypeVLESS:
v = &h.VLESSOptions
case C.TypeTUIC:
v = &h.TUICOptions
default:
return E.New("unknown inbound type: ", h.Type)
}

View file

@ -23,6 +23,7 @@ type _Outbound struct {
ShadowTLSOptions ShadowTLSOutboundOptions `json:"-"`
ShadowsocksROptions ShadowsocksROutboundOptions `json:"-"`
VLESSOptions VLESSOutboundOptions `json:"-"`
TUICOptions TUICOutboundOptions `json:"-"`
SelectorOptions SelectorOutboundOptions `json:"-"`
URLTestOptions URLTestOutboundOptions `json:"-"`
}
@ -60,6 +61,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) {
v = h.ShadowsocksROptions
case C.TypeVLESS:
v = h.VLESSOptions
case C.TypeTUIC:
v = h.TUICOptions
case C.TypeSelector:
v = h.SelectorOptions
case C.TypeURLTest:
@ -105,6 +108,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error {
v = &h.ShadowsocksROptions
case C.TypeVLESS:
v = &h.VLESSOptions
case C.TypeTUIC:
v = &h.TUICOptions
case C.TypeSelector:
v = &h.SelectorOptions
case C.TypeURLTest:

30
option/tuic.go Normal file
View file

@ -0,0 +1,30 @@
package option
type TUICInboundOptions struct {
ListenOptions
Users []TUICUser `json:"users,omitempty"`
CongestionControl string `json:"congestion_control,omitempty"`
AuthTimeout Duration `json:"auth_timeout,omitempty"`
ZeroRTTHandshake bool `json:"zero_rtt_handshake,omitempty"`
Heartbeat Duration `json:"heartbeat,omitempty"`
TLS *InboundTLSOptions `json:"tls,omitempty"`
}
type TUICUser struct {
Name string `json:"name,omitempty"`
UUID string `json:"uuid,omitempty"`
Password string `json:"password,omitempty"`
}
type TUICOutboundOptions struct {
DialerOptions
ServerOptions
UUID string `json:"uuid,omitempty"`
Password string `json:"password,omitempty"`
CongestionControl string `json:"congestion_control,omitempty"`
UDPRelayMode string `json:"udp_relay_mode,omitempty"`
ZeroRTTHandshake bool `json:"zero_rtt_handshake,omitempty"`
Heartbeat Duration `json:"heartbeat,omitempty"`
Network NetworkList `json:"network,omitempty"`
TLS *OutboundTLSOptions `json:"tls,omitempty"`
}

View file

@ -51,6 +51,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, t
return NewShadowsocksR(ctx, router, logger, tag, options.ShadowsocksROptions)
case C.TypeVLESS:
return NewVLESS(ctx, router, logger, tag, options.VLESSOptions)
case C.TypeTUIC:
return NewTUIC(ctx, router, logger, tag, options.TUICOptions)
case C.TypeSelector:
return NewSelector(router, logger, tag, options.SelectorOptions)
case C.TypeURLTest:

123
outbound/tuic.go Normal file
View file

@ -0,0 +1,123 @@
//go:build with_quic
package outbound
import (
"context"
"net"
"os"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/tuic"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/gofrs/uuid/v5"
)
var (
_ adapter.Outbound = (*TUIC)(nil)
_ adapter.InterfaceUpdateListener = (*TUIC)(nil)
)
type TUIC struct {
myOutboundAdapter
client *tuic.Client
}
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICOutboundOptions) (*TUIC, error) {
options.UDPFragmentDefault = true
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired
}
abstractTLSConfig, err := tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
tlsConfig, err := abstractTLSConfig.Config()
if err != nil {
return nil, err
}
userUUID, err := uuid.FromString(options.UUID)
if err != nil {
return nil, E.Cause(err, "invalid uuid")
}
var udpStream bool
switch options.UDPRelayMode {
case "native":
case "quic":
udpStream = true
}
client, err := tuic.NewClient(tuic.ClientOptions{
Context: ctx,
Dialer: dialer.New(router, options.DialerOptions),
ServerAddress: options.ServerOptions.Build(),
TLSConfig: tlsConfig,
UUID: userUUID,
Password: options.Password,
CongestionControl: options.CongestionControl,
UDPStream: udpStream,
ZeroRTTHandshake: options.ZeroRTTHandshake,
Heartbeat: time.Duration(options.Heartbeat),
})
if err != nil {
return nil, err
}
return &TUIC{
myOutboundAdapter: myOutboundAdapter{
protocol: C.TypeTUIC,
network: options.Network.Build(),
router: router,
logger: logger,
tag: tag,
dependencies: withDialerDependency(options.DialerOptions),
},
client: client,
}, nil
}
func (h *TUIC) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
return h.client.DialConn(ctx, destination)
case N.NetworkUDP:
conn, err := h.ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return bufio.NewBindPacketConn(conn, destination), nil
default:
return nil, E.New("unsupported network: ", network)
}
}
func (h *TUIC) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
return h.client.ListenPacket(ctx)
}
func (h *TUIC) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return NewConnection(ctx, h, conn, metadata)
}
func (h *TUIC) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return NewPacketConnection(ctx, h, conn, metadata)
}
func (h *TUIC) InterfaceUpdated() {
_ = h.client.CloseWithError(E.New("network changed"))
}
func (h *TUIC) Close() error {
return h.client.CloseWithError(os.ErrClosed)
}

16
outbound/tuic_stub.go Normal file
View file

@ -0,0 +1,16 @@
//go:build !with_quic
package outbound
import (
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICOutboundOptions) (adapter.Outbound, error) {
return nil, C.ErrQUICNotIncluded
}

View file

@ -38,6 +38,8 @@ const (
ImageShadowsocksR = "teddysun/shadowsocks-r:latest"
ImageXRayCore = "teddysun/xray:latest"
ImageShadowsocksLegacy = "mritd/shadowsocks:latest"
ImageTUICServer = ""
ImageTUICClient = ""
)
var allImages = []string{
@ -53,6 +55,8 @@ var allImages = []string{
ImageShadowsocksR,
ImageXRayCore,
ImageShadowsocksLegacy,
// ImageTUICServer,
// ImageTUICClient,
}
var localIP = netip.MustParseAddr("127.0.0.1")

View file

@ -0,0 +1,14 @@
{
"relay": {
"server": "127.0.0.1:10000",
"uuid": "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
"password": "tuic",
"certificates": [
"/etc/tuic/ca.pem"
]
},
"local": {
"server": "127.0.0.1:10001"
},
"log_level": "debug"
}

View file

@ -0,0 +1,9 @@
{
"server": "[::]:10000",
"users": {
"FE35D05B-8803-45C4-BAE6-723AD2CD5D3D": "tuic"
},
"certificate": "/etc/tuic/cert.pem",
"private_key": "/etc/tuic/key.pem",
"log_level": "debug"
}

178
test/tuic_test.go Normal file
View file

@ -0,0 +1,178 @@
package main
import (
"net/netip"
"testing"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/gofrs/uuid/v5"
)
func TestTUICSelf(t *testing.T) {
t.Run("self", func(t *testing.T) {
testTUICSelf(t, false, false)
})
t.Run("self-udp-stream", func(t *testing.T) {
testTUICSelf(t, true, false)
})
t.Run("self-early", func(t *testing.T) {
testTUICSelf(t, false, true)
})
}
func testTUICSelf(t *testing.T, udpStream bool, zeroRTTHandshake bool) {
_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
var udpRelayMode string
if udpStream {
udpRelayMode = "quic"
}
startInstance(t, option.Options{
Inbounds: []option.Inbound{
{
Type: C.TypeMixed,
Tag: "mixed-in",
MixedOptions: option.HTTPMixedInboundOptions{
ListenOptions: option.ListenOptions{
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
ListenPort: clientPort,
},
},
},
{
Type: C.TypeTUIC,
TUICOptions: option.TUICInboundOptions{
ListenOptions: option.ListenOptions{
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
ListenPort: serverPort,
},
Users: []option.TUICUser{{
UUID: uuid.Nil.String(),
}},
ZeroRTTHandshake: zeroRTTHandshake,
TLS: &option.InboundTLSOptions{
Enabled: true,
ServerName: "example.org",
CertificatePath: certPem,
KeyPath: keyPem,
},
},
},
},
Outbounds: []option.Outbound{
{
Type: C.TypeDirect,
},
{
Type: C.TypeTUIC,
Tag: "tuic-out",
TUICOptions: option.TUICOutboundOptions{
ServerOptions: option.ServerOptions{
Server: "127.0.0.1",
ServerPort: serverPort,
},
UUID: uuid.Nil.String(),
UDPRelayMode: udpRelayMode,
ZeroRTTHandshake: zeroRTTHandshake,
TLS: &option.OutboundTLSOptions{
Enabled: true,
ServerName: "example.org",
CertificatePath: certPem,
},
},
},
},
Route: &option.RouteOptions{
Rules: []option.Rule{
{
DefaultOptions: option.DefaultRule{
Inbound: []string{"mixed-in"},
Outbound: "tuic-out",
},
},
},
},
})
testSuit(t, clientPort, testPort)
}
func TestTUICInbound(t *testing.T) {
caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
startInstance(t, option.Options{
Inbounds: []option.Inbound{
{
Type: C.TypeTUIC,
TUICOptions: option.TUICInboundOptions{
ListenOptions: option.ListenOptions{
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
ListenPort: serverPort,
},
Users: []option.TUICUser{{
UUID: "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
Password: "tuic",
}},
TLS: &option.InboundTLSOptions{
Enabled: true,
ServerName: "example.org",
CertificatePath: certPem,
KeyPath: keyPem,
},
},
},
},
})
startDockerContainer(t, DockerOptions{
Image: ImageTUICClient,
Ports: []uint16{serverPort, clientPort},
Bind: map[string]string{
"tuic-client.json": "/etc/tuic/config.json",
caPem: "/etc/tuic/ca.pem",
},
})
}
func TestTUICOutbound(t *testing.T) {
_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
startDockerContainer(t, DockerOptions{
Image: ImageTUICServer,
Ports: []uint16{testPort},
Bind: map[string]string{
"tuic-server.json": "/etc/tuic/config.json",
certPem: "/etc/tuic/cert.pem",
keyPem: "/etc/tuic/key.pem",
},
})
startInstance(t, option.Options{
Inbounds: []option.Inbound{
{
Type: C.TypeMixed,
MixedOptions: option.HTTPMixedInboundOptions{
ListenOptions: option.ListenOptions{
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
ListenPort: clientPort,
},
},
},
},
Outbounds: []option.Outbound{
{
Type: C.TypeTUIC,
TUICOptions: option.TUICOutboundOptions{
ServerOptions: option.ServerOptions{
Server: "127.0.0.1",
ServerPort: serverPort,
},
UUID: "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
Password: "tuic",
TLS: &option.OutboundTLSOptions{
Enabled: true,
ServerName: "example.org",
CertificatePath: certPem,
},
},
},
},
})
testSuit(t, clientPort, testPort)
}

10
transport/tuic/address.go Normal file
View file

@ -0,0 +1,10 @@
package tuic
import M "github.com/sagernet/sing/common/metadata"
var addressSerializer = M.NewSerializer(
M.AddressFamilyByte(0x00, M.AddressFamilyFqdn),
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
M.AddressFamilyByte(0x02, M.AddressFamilyIPv6),
M.AddressFamilyByte(0xff, M.AddressFamilyEmpty),
)

322
transport/tuic/client.go Normal file
View file

@ -0,0 +1,322 @@
package tuic
import (
"context"
"crypto/tls"
"io"
"net"
"os"
"runtime"
"sync"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/common/baderror"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/gofrs/uuid/v5"
)
type ClientOptions struct {
Context context.Context
Dialer N.Dialer
ServerAddress M.Socksaddr
TLSConfig *tls.Config
UUID uuid.UUID
Password string
CongestionControl string
UDPStream bool
ZeroRTTHandshake bool
Heartbeat time.Duration
}
type Client struct {
ctx context.Context
dialer N.Dialer
serverAddr M.Socksaddr
tlsConfig *tls.Config
quicConfig *quic.Config
uuid uuid.UUID
password string
congestionControl string
udpStream bool
zeroRTTHandshake bool
heartbeat time.Duration
connAccess sync.RWMutex
conn *clientQUICConnection
}
func NewClient(options ClientOptions) (*Client, error) {
if options.Heartbeat == 0 {
options.Heartbeat = 10 * time.Second
}
quicConfig := &quic.Config{
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
MaxDatagramFrameSize: 1400,
EnableDatagrams: true,
MaxIncomingUniStreams: 1 << 60,
}
switch options.CongestionControl {
case "":
options.CongestionControl = "cubic"
case "cubic", "new_reno", "bbr":
default:
return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
}
return &Client{
ctx: options.Context,
dialer: options.Dialer,
serverAddr: options.ServerAddress,
tlsConfig: options.TLSConfig,
quicConfig: quicConfig,
uuid: options.UUID,
password: options.Password,
congestionControl: options.CongestionControl,
udpStream: options.UDPStream,
zeroRTTHandshake: options.ZeroRTTHandshake,
heartbeat: options.Heartbeat,
}, nil
}
func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
conn := c.conn
if conn != nil && conn.active() {
return conn, nil
}
c.connAccess.Lock()
defer c.connAccess.Unlock()
conn = c.conn
if conn != nil && conn.active() {
return conn, nil
}
conn, err := c.offerNew(ctx)
if err != nil {
return nil, err
}
return conn, nil
}
func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
udpConn, err := c.dialer.DialContext(ctx, "udp", c.serverAddr)
if err != nil {
return nil, err
}
var quicConn quic.Connection
if c.zeroRTTHandshake {
quicConn, err = quic.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
} else {
quicConn, err = quic.Dial(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
}
if err != nil {
udpConn.Close()
return nil, E.Cause(err, "open connection")
}
setCongestion(c.ctx, quicConn, c.congestionControl)
conn := &clientQUICConnection{
quicConn: quicConn,
rawConn: udpConn,
connDone: make(chan struct{}),
udpConnMap: make(map[uint16]*udpPacketConn),
}
go func() {
hErr := c.clientHandshake(quicConn)
if hErr != nil {
conn.closeWithError(hErr)
}
}()
if c.udpStream {
go c.loopUniStreams(conn)
}
go c.loopMessages(conn)
go c.loopHeartbeats(conn)
c.conn = conn
return conn, nil
}
func (c *Client) clientHandshake(conn quic.Connection) error {
authStream, err := conn.OpenUniStream()
if err != nil {
return err
}
defer authStream.Close()
handshakeState := conn.ConnectionState().TLS
tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32)
if err != nil {
return err
}
authRequest := buf.NewSize(AuthenticateLen)
authRequest.WriteByte(Version)
authRequest.WriteByte(CommandAuthenticate)
authRequest.Write(c.uuid[:])
authRequest.Write(tuicAuthToken)
return common.Error(authStream.Write(authRequest.Bytes()))
}
func (c *Client) loopHeartbeats(conn *clientQUICConnection) {
ticker := time.NewTicker(c.heartbeat)
defer ticker.Stop()
for {
select {
case <-conn.connDone:
return
case <-ticker.C:
err := conn.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
if err != nil {
conn.closeWithError(E.Cause(err, "send heartbeat"))
}
}
}
}
func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
conn, err := c.offer(ctx)
if err != nil {
return nil, err
}
stream, err := conn.quicConn.OpenStream()
if err != nil {
return nil, err
}
return &clientConn{
parent: conn,
stream: stream,
destination: destination,
}, nil
}
func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
conn, err := c.offer(ctx)
if err != nil {
return nil, err
}
var sessionID uint16
clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() {
conn.udpAccess.Lock()
delete(conn.udpConnMap, sessionID)
conn.udpAccess.Unlock()
})
conn.udpAccess.Lock()
sessionID = conn.udpSessionID
conn.udpSessionID++
conn.udpConnMap[sessionID] = clientPacketConn
conn.udpAccess.Unlock()
clientPacketConn.sessionID = sessionID
return clientPacketConn, nil
}
func (c *Client) CloseWithError(err error) error {
conn := c.conn
if conn != nil {
conn.closeWithError(err)
}
return nil
}
type clientQUICConnection struct {
quicConn quic.Connection
rawConn io.Closer
closeOnce sync.Once
connDone chan struct{}
connErr error
udpAccess sync.RWMutex
udpConnMap map[uint16]*udpPacketConn
udpSessionID uint16
}
func (c *clientQUICConnection) active() bool {
select {
case <-c.quicConn.Context().Done():
return false
default:
}
select {
case <-c.connDone:
return false
default:
}
return true
}
func (c *clientQUICConnection) closeWithError(err error) {
c.closeOnce.Do(func() {
c.connErr = err
close(c.connDone)
_ = c.quicConn.CloseWithError(0, "")
_ = c.rawConn.Close()
})
}
type clientConn struct {
parent *clientQUICConnection
stream quic.Stream
destination M.Socksaddr
requestWritten bool
}
func (c *clientConn) Read(b []byte) (n int, err error) {
n, err = c.stream.Read(b)
return n, baderror.WrapQUIC(err)
}
func (c *clientConn) Write(b []byte) (n int, err error) {
if !c.requestWritten {
request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b))
request.WriteByte(Version)
request.WriteByte(CommandConnect)
addressSerializer.WriteAddrPort(request, c.destination)
request.Write(b)
_, err = c.stream.Write(request.Bytes())
if err != nil {
c.parent.closeWithError(E.Cause(err, "create new connection"))
return 0, baderror.WrapQUIC(err)
}
c.requestWritten = true
return len(b), nil
}
n, err = c.stream.Write(b)
return n, baderror.WrapQUIC(err)
}
func (c *clientConn) Close() error {
stream := c.stream
if stream == nil {
return nil
}
stream.CancelRead(0)
return stream.Close()
}
func (c *clientConn) LocalAddr() net.Addr {
return M.Socksaddr{}
}
func (c *clientConn) RemoteAddr() net.Addr {
return c.destination
}
func (c *clientConn) SetDeadline(t time.Time) error {
if c.stream == nil {
return os.ErrInvalid
}
return c.stream.SetDeadline(t)
}
func (c *clientConn) SetReadDeadline(t time.Time) error {
if c.stream == nil {
return os.ErrInvalid
}
return c.stream.SetReadDeadline(t)
}
func (c *clientConn) SetWriteDeadline(t time.Time) error {
if c.stream == nil {
return os.ErrInvalid
}
return c.stream.SetWriteDeadline(t)
}

View file

@ -0,0 +1,110 @@
package tuic
import (
"io"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
)
func (c *Client) loopMessages(conn *clientQUICConnection) {
for {
message, err := conn.quicConn.ReceiveMessage(c.ctx)
if err != nil {
conn.closeWithError(E.Cause(err, "receive message"))
return
}
go func() {
hErr := c.handleMessage(conn, message)
if hErr != nil {
conn.closeWithError(E.Cause(hErr, "handle message"))
}
}()
}
}
func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
if len(data) < 2 {
return E.New("invalid message")
}
if data[0] != Version {
return E.New("unknown version ", data[0])
}
switch data[1] {
case CommandPacket:
message := udpMessagePool.Get().(*udpMessage)
err := decodeUDPMessage(message, data[2:])
if err != nil {
message.release()
return E.Cause(err, "decode UDP message")
}
conn.handleUDPMessage(message)
return nil
case CommandHeartbeat:
return nil
default:
return E.New("unknown command ", data[0])
}
}
func (c *Client) loopUniStreams(conn *clientQUICConnection) {
for {
stream, err := conn.quicConn.AcceptUniStream(c.ctx)
if err != nil {
conn.closeWithError(E.Cause(err, "handle uni stream"))
return
}
go func() {
hErr := c.handleUniStream(conn, stream)
if hErr != nil {
conn.closeWithError(hErr)
}
}()
}
}
func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.ReceiveStream) error {
defer stream.CancelRead(0)
buffer := buf.NewPacket()
defer buffer.Release()
_, err := buffer.ReadAtLeastFrom(stream, 2)
if err != nil {
return err
}
version, _ := buffer.ReadByte()
if version != Version {
return E.New("unknown version ", version)
}
command, _ := buffer.ReadByte()
if command != CommandPacket {
return E.New("unknown command ", command)
}
reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream)
message := udpMessagePool.Get().(*udpMessage)
err = readUDPMessage(message, reader)
if err != nil {
message.release()
return err
}
conn.handleUDPMessage(message)
return nil
}
func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) {
c.udpAccess.RLock()
udpConn, loaded := c.udpConnMap[message.sessionID]
c.udpAccess.RUnlock()
if !loaded {
message.releaseMessage()
return
}
select {
case <-udpConn.ctx.Done():
message.releaseMessage()
return
default:
}
udpConn.inputPacket(message)
}

View file

@ -0,0 +1,46 @@
package tuic
import (
"context"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/transport/tuic/congestion"
"github.com/sagernet/sing/common/ntp"
)
func setCongestion(ctx context.Context, connection quic.Connection, congestionName string) {
timeFunc := ntp.TimeFuncFromContext(ctx)
if timeFunc == nil {
timeFunc = time.Now
}
switch congestionName {
case "cubic":
connection.SetCongestionControl(
congestion.NewCubicSender(
congestion.DefaultClock{TimeFunc: timeFunc},
congestion.GetInitialPacketSize(connection.RemoteAddr()),
false,
nil,
),
)
case "new_reno":
connection.SetCongestionControl(
congestion.NewCubicSender(
congestion.DefaultClock{TimeFunc: timeFunc},
congestion.GetInitialPacketSize(connection.RemoteAddr()),
true,
nil,
),
)
case "bbr":
connection.SetCongestionControl(
congestion.NewBBRSender(
congestion.DefaultClock{},
congestion.GetInitialPacketSize(connection.RemoteAddr()),
congestion.InitialCongestionWindow*congestion.InitialMaxDatagramSize,
congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize,
),
)
}
}

View file

@ -0,0 +1,3 @@
# congestion
mod from https://github.com/MetaCubeX/Clash.Meta/tree/53f9e1ee7104473da2b4ff5da29965563084482d/transport/tuic/congestion

View file

@ -0,0 +1,25 @@
package congestion
import (
"math"
"time"
"github.com/sagernet/quic-go/congestion"
)
// Bandwidth of a connection
type Bandwidth uint64
const infBandwidth Bandwidth = math.MaxUint64
const (
// BitsPerSecond is 1 bit per second
BitsPerSecond Bandwidth = 1
// BytesPerSecond is 1 byte per second
BytesPerSecond = 8 * BitsPerSecond
)
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth {
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
}

View file

@ -0,0 +1,374 @@
package congestion
import (
"math"
"time"
"github.com/sagernet/quic-go/congestion"
)
var InfiniteBandwidth = Bandwidth(math.MaxUint64)
// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned
// to the caller when the packet is acked or lost.
type SendTimeState struct {
// Whether other states in this object is valid.
isValid bool
// Whether the sender is app limited at the time the packet was sent.
// App limited bandwidth sample might be artificially low because the sender
// did not have enough data to send in order to saturate the link.
isAppLimited bool
// Total number of sent bytes at the time the packet was sent.
// Includes the packet itself.
totalBytesSent congestion.ByteCount
// Total number of acked bytes at the time the packet was sent.
totalBytesAcked congestion.ByteCount
// Total number of lost bytes at the time the packet was sent.
totalBytesLost congestion.ByteCount
}
// ConnectionStateOnSentPacket represents the information about a sent packet
// and the state of the connection at the moment the packet was sent,
// specifically the information about the most recently acknowledged packet at
// that moment.
type ConnectionStateOnSentPacket struct {
packetNumber congestion.PacketNumber
// Time at which the packet is sent.
sendTime time.Time
// Size of the packet.
size congestion.ByteCount
// The value of |totalBytesSentAtLastAckedPacket| at the time the
// packet was sent.
totalBytesSentAtLastAckedPacket congestion.ByteCount
// The value of |lastAckedPacketSentTime| at the time the packet was
// sent.
lastAckedPacketSentTime time.Time
// The value of |lastAckedPacketAckTime| at the time the packet was
// sent.
lastAckedPacketAckTime time.Time
// Send time states that are returned to the congestion controller when the
// packet is acked or lost.
sendTimeState SendTimeState
}
// BandwidthSample
type BandwidthSample struct {
// The bandwidth at that particular sample. Zero if no valid bandwidth sample
// is available.
bandwidth Bandwidth
// The RTT measurement at this particular sample. Zero if no RTT sample is
// available. Does not correct for delayed ack time.
rtt time.Duration
// States captured when the packet was sent.
stateAtSend SendTimeState
}
func NewBandwidthSample() *BandwidthSample {
return &BandwidthSample{
// FIXME: the default value of original code is zero.
rtt: InfiniteRTT,
}
}
// BandwidthSampler keeps track of sent and acknowledged packets and outputs a
// bandwidth sample for every packet acknowledged. The samples are taken for
// individual packets, and are not filtered; the consumer has to filter the
// bandwidth samples itself. In certain cases, the sampler will locally severely
// underestimate the bandwidth, hence a maximum filter with a size of at least
// one RTT is recommended.
//
// This class bases its samples on the slope of two curves: the number of bytes
// sent over time, and the number of bytes acknowledged as received over time.
// It produces a sample of both slopes for every packet that gets acknowledged,
// based on a slope between two points on each of the corresponding curves. Note
// that due to the packet loss, the number of bytes on each curve might get
// further and further away from each other, meaning that it is not feasible to
// compare byte values coming from different curves with each other.
//
// The obvious points for measuring slope sample are the ones corresponding to
// the packet that was just acknowledged. Let us denote them as S_1 (point at
// which the current packet was sent) and A_1 (point at which the current packet
// was acknowledged). However, taking a slope requires two points on each line,
// so estimating bandwidth requires picking a packet in the past with respect to
// which the slope is measured.
//
// For that purpose, BandwidthSampler always keeps track of the most recently
// acknowledged packet, and records it together with every outgoing packet.
// When a packet gets acknowledged (A_1), it has not only information about when
// it itself was sent (S_1), but also the information about the latest
// acknowledged packet right before it was sent (S_0 and A_0).
//
// Based on that data, send and ack rate are estimated as:
//
// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0))
// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0))
//
// Here, the ack rate is intuitively the rate we want to treat as bandwidth.
// However, in certain cases (e.g. ack compression) the ack rate at a point may
// end up higher than the rate at which the data was originally sent, which is
// not indicative of the real bandwidth. Hence, we use the send rate as an upper
// bound, and the sample value is
//
// rate_sample = min(send_rate, ack_rate)
//
// An important edge case handled by the sampler is tracking the app-limited
// samples. There are multiple meaning of "app-limited" used interchangeably,
// hence it is important to understand and to be able to distinguish between
// them.
//
// Meaning 1: connection state. The connection is said to be app-limited when
// there is no outstanding data to send. This means that certain bandwidth
// samples in the future would not be an accurate indication of the link
// capacity, and it is important to inform consumer about that. Whenever
// connection becomes app-limited, the sampler is notified via OnAppLimited()
// method.
//
// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth
// sampler becomes notified about the connection being app-limited, it enters
// app-limited phase. In that phase, all *sent* packets are marked as
// app-limited. Note that the connection itself does not have to be
// app-limited during the app-limited phase, and in fact it will not be
// (otherwise how would it send packets?). The boolean flag below indicates
// whether the sampler is in that phase.
//
// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is
// sent during the app-limited phase, the resulting sample related to the
// packet will be marked as app-limited.
//
// With the terminology issue out of the way, let us consider the question of
// what kind of situation it addresses.
//
// Consider a scenario where we first send packets 1 to 20 at a regular
// bandwidth, and then immediately run out of data. After a few seconds, we send
// packets 21 to 60, and only receive ack for 21 between sending packets 40 and
// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0
// we use to compute the slope is going to be packet 20, a few seconds apart
// from the current packet, hence the resulting estimate would be extremely low
// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21,
// meaning that the bandwidth sample would exclude the quiescence.
//
// Based on the analysis of that scenario, we implement the following rule: once
// OnAppLimited() is called, all sent packets will produce app-limited samples
// up until an ack for a packet that was sent after OnAppLimited() was called.
// Note that while the scenario above is not the only scenario when the
// connection is app-limited, the approach works in other cases too.
type BandwidthSampler struct {
// The total number of congestion controlled bytes sent during the connection.
totalBytesSent congestion.ByteCount
// The total number of congestion controlled bytes which were acknowledged.
totalBytesAcked congestion.ByteCount
// The total number of congestion controlled bytes which were lost.
totalBytesLost congestion.ByteCount
// The value of |totalBytesSent| at the time the last acknowledged packet
// was sent. Valid only when |lastAckedPacketSentTime| is valid.
totalBytesSentAtLastAckedPacket congestion.ByteCount
// The time at which the last acknowledged packet was sent. Set to
// QuicTime::Zero() if no valid timestamp is available.
lastAckedPacketSentTime time.Time
// The time at which the most recent packet was acknowledged.
lastAckedPacketAckTime time.Time
// The most recently sent packet.
lastSendPacket congestion.PacketNumber
// Indicates whether the bandwidth sampler is currently in an app-limited
// phase.
isAppLimited bool
// The packet that will be acknowledged after this one will cause the sampler
// to exit the app-limited phase.
endOfAppLimitedPhase congestion.PacketNumber
// Record of the connection state at the point where each packet in flight was
// sent, indexed by the packet number.
connectionStats *ConnectionStates
}
func NewBandwidthSampler() *BandwidthSampler {
return &BandwidthSampler{
connectionStats: &ConnectionStates{
stats: make(map[congestion.PacketNumber]*ConnectionStateOnSentPacket),
},
}
}
// OnPacketSent Inputs the sent packet information into the sampler. Assumes that all
// packets are sent in order. The information about the packet will not be
// released from the sampler until it the packet is either acknowledged or
// declared lost.
func (s *BandwidthSampler) OnPacketSent(sentTime time.Time, lastSentPacket congestion.PacketNumber, sentBytes, bytesInFlight congestion.ByteCount, hasRetransmittableData bool) {
s.lastSendPacket = lastSentPacket
if !hasRetransmittableData {
return
}
s.totalBytesSent += sentBytes
// If there are no packets in flight, the time at which the new transmission
// opens can be treated as the A_0 point for the purpose of bandwidth
// sampling. This underestimates bandwidth to some extent, and produces some
// artificially low samples for most packets in flight, but it provides with
// samples at important points where we would not have them otherwise, most
// importantly at the beginning of the connection.
if bytesInFlight == 0 {
s.lastAckedPacketAckTime = sentTime
s.totalBytesSentAtLastAckedPacket = s.totalBytesSent
// In this situation ack compression is not a concern, set send rate to
// effectively infinite.
s.lastAckedPacketSentTime = sentTime
}
s.connectionStats.Insert(lastSentPacket, sentTime, sentBytes, s)
}
// OnPacketAcked Notifies the sampler that the |lastAckedPacket| is acknowledged. Returns a
// bandwidth sample. If no bandwidth sample is available,
// QuicBandwidth::Zero() is returned.
func (s *BandwidthSampler) OnPacketAcked(ackTime time.Time, lastAckedPacket congestion.PacketNumber) *BandwidthSample {
sentPacketState := s.connectionStats.Get(lastAckedPacket)
if sentPacketState == nil {
return NewBandwidthSample()
}
sample := s.onPacketAckedInner(ackTime, lastAckedPacket, sentPacketState)
s.connectionStats.Remove(lastAckedPacket)
return sample
}
// onPacketAckedInner Handles the actual bandwidth calculations, whereas the outer method handles
// retrieving and removing |sentPacket|.
func (s *BandwidthSampler) onPacketAckedInner(ackTime time.Time, lastAckedPacket congestion.PacketNumber, sentPacket *ConnectionStateOnSentPacket) *BandwidthSample {
s.totalBytesAcked += sentPacket.size
s.totalBytesSentAtLastAckedPacket = sentPacket.sendTimeState.totalBytesSent
s.lastAckedPacketSentTime = sentPacket.sendTime
s.lastAckedPacketAckTime = ackTime
// Exit app-limited phase once a packet that was sent while the connection is
// not app-limited is acknowledged.
if s.isAppLimited && lastAckedPacket > s.endOfAppLimitedPhase {
s.isAppLimited = false
}
// There might have been no packets acknowledged at the moment when the
// current packet was sent. In that case, there is no bandwidth sample to
// make.
if sentPacket.lastAckedPacketSentTime.IsZero() {
return NewBandwidthSample()
}
// Infinite rate indicates that the sampler is supposed to discard the
// current send rate sample and use only the ack rate.
sendRate := InfiniteBandwidth
if sentPacket.sendTime.After(sentPacket.lastAckedPacketSentTime) {
sendRate = BandwidthFromDelta(sentPacket.sendTimeState.totalBytesSent-sentPacket.totalBytesSentAtLastAckedPacket, sentPacket.sendTime.Sub(sentPacket.lastAckedPacketSentTime))
}
// During the slope calculation, ensure that ack time of the current packet is
// always larger than the time of the previous packet, otherwise division by
// zero or integer underflow can occur.
if !ackTime.After(sentPacket.lastAckedPacketAckTime) {
// TODO(wub): Compare this code count before and after fixing clock jitter
// issue.
// if sentPacket.lastAckedPacketAckTime.Equal(sentPacket.sendTime) {
// This is the 1st packet after quiescense.
// QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 1, 2);
// } else {
// QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 2, 2);
// }
return NewBandwidthSample()
}
ackRate := BandwidthFromDelta(s.totalBytesAcked-sentPacket.sendTimeState.totalBytesAcked,
ackTime.Sub(sentPacket.lastAckedPacketAckTime))
// Note: this sample does not account for delayed acknowledgement time. This
// means that the RTT measurements here can be artificially high, especially
// on low bandwidth connections.
sample := &BandwidthSample{
bandwidth: minBandwidth(sendRate, ackRate),
rtt: ackTime.Sub(sentPacket.sendTime),
}
SentPacketToSendTimeState(sentPacket, &sample.stateAtSend)
return sample
}
// OnPacketLost Informs the sampler that a packet is considered lost and it should no
// longer keep track of it.
func (s *BandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber) SendTimeState {
ok, sentPacket := s.connectionStats.Remove(packetNumber)
sendTimeState := SendTimeState{
isValid: ok,
}
if sentPacket != nil {
s.totalBytesLost += sentPacket.size
SentPacketToSendTimeState(sentPacket, &sendTimeState)
}
return sendTimeState
}
// OnAppLimited Informs the sampler that the connection is currently app-limited, causing
// the sampler to enter the app-limited phase. The phase will expire by
// itself.
func (s *BandwidthSampler) OnAppLimited() {
s.isAppLimited = true
s.endOfAppLimitedPhase = s.lastSendPacket
}
// SentPacketToSendTimeState Copy a subset of the (private) ConnectionStateOnSentPacket to the (public)
// SendTimeState. Always set send_time_state->is_valid to true.
func SentPacketToSendTimeState(sentPacket *ConnectionStateOnSentPacket, sendTimeState *SendTimeState) {
sendTimeState.isAppLimited = sentPacket.sendTimeState.isAppLimited
sendTimeState.totalBytesSent = sentPacket.sendTimeState.totalBytesSent
sendTimeState.totalBytesAcked = sentPacket.sendTimeState.totalBytesAcked
sendTimeState.totalBytesLost = sentPacket.sendTimeState.totalBytesLost
sendTimeState.isValid = true
}
// ConnectionStates Record of the connection state at the point where each packet in flight was
// sent, indexed by the packet number.
// FIXME: using LinkedList replace map to fast remove all the packets lower than the specified packet number.
type ConnectionStates struct {
stats map[congestion.PacketNumber]*ConnectionStateOnSentPacket
}
func (s *ConnectionStates) Insert(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) bool {
if _, ok := s.stats[packetNumber]; ok {
return false
}
s.stats[packetNumber] = NewConnectionStateOnSentPacket(packetNumber, sentTime, bytes, sampler)
return true
}
func (s *ConnectionStates) Get(packetNumber congestion.PacketNumber) *ConnectionStateOnSentPacket {
return s.stats[packetNumber]
}
func (s *ConnectionStates) Remove(packetNumber congestion.PacketNumber) (bool, *ConnectionStateOnSentPacket) {
state, ok := s.stats[packetNumber]
if ok {
delete(s.stats, packetNumber)
}
return ok, state
}
func NewConnectionStateOnSentPacket(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) *ConnectionStateOnSentPacket {
return &ConnectionStateOnSentPacket{
packetNumber: packetNumber,
sendTime: sentTime,
size: bytes,
lastAckedPacketSentTime: sampler.lastAckedPacketSentTime,
lastAckedPacketAckTime: sampler.lastAckedPacketAckTime,
totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket,
sendTimeState: SendTimeState{
isValid: true,
isAppLimited: sampler.isAppLimited,
totalBytesSent: sampler.totalBytesSent,
totalBytesAcked: sampler.totalBytesAcked,
totalBytesLost: sampler.totalBytesLost,
},
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,20 @@
package congestion
import "time"
// A Clock returns the current time
type Clock interface {
Now() time.Time
}
// DefaultClock implements the Clock interface using the Go stdlib clock.
type DefaultClock struct {
TimeFunc func() time.Time
}
var _ Clock = DefaultClock{}
// Now gets the current time
func (c DefaultClock) Now() time.Time {
return c.TimeFunc()
}

View file

@ -0,0 +1,213 @@
package congestion
import (
"math"
"time"
"github.com/sagernet/quic-go/congestion"
)
// This cubic implementation is based on the one found in Chromiums's QUIC
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
// Constants based on TCP defaults.
// The following constants are in 2^10 fractions of a second instead of ms to
// allow a 10 shift right to divide.
// 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling round trip time.
const (
cubeScale = 40
cubeCongestionWindowScale = 410
cubeFactor congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
// TODO: when re-enabling cubic, make sure to use the actual packet size here
maxDatagramSize = congestion.ByteCount(InitialPacketSizeIPv4)
)
const defaultNumConnections = 1
// Default Cubic backoff factor
const beta float32 = 0.7
// Additional backoff factor when loss occurs in the concave part of the Cubic
// curve. This additional backoff factor is expected to give up bandwidth to
// new concurrent flows and speed up convergence.
const betaLastMax float32 = 0.85
// Cubic implements the cubic algorithm from TCP
type Cubic struct {
clock Clock
// Number of connections to simulate.
numConnections int
// Time when this cycle started, after last loss event.
epoch time.Time
// Max congestion window used just before last loss event.
// Note: to improve fairness to other streams an additional back off is
// applied to this value if the new value is below our latest value.
lastMaxCongestionWindow congestion.ByteCount
// Number of acked bytes since the cycle started (epoch).
ackedBytesCount congestion.ByteCount
// TCP Reno equivalent congestion window in packets.
estimatedTCPcongestionWindow congestion.ByteCount
// Origin point of cubic function.
originPointCongestionWindow congestion.ByteCount
// Time to origin point of cubic function in 2^10 fractions of a second.
timeToOriginPoint uint32
// Last congestion window in packets computed by cubic function.
lastTargetCongestionWindow congestion.ByteCount
}
// NewCubic returns a new Cubic instance
func NewCubic(clock Clock) *Cubic {
c := &Cubic{
clock: clock,
numConnections: defaultNumConnections,
}
c.Reset()
return c
}
// Reset is called after a timeout to reset the cubic state
func (c *Cubic) Reset() {
c.epoch = time.Time{}
c.lastMaxCongestionWindow = 0
c.ackedBytesCount = 0
c.estimatedTCPcongestionWindow = 0
c.originPointCongestionWindow = 0
c.timeToOriginPoint = 0
c.lastTargetCongestionWindow = 0
}
func (c *Cubic) alpha() float32 {
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
// We derive the equivalent alpha for an N-connection emulation as:
b := c.beta()
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
}
func (c *Cubic) beta() float32 {
// kNConnectionBeta is the backoff factor after loss for our N-connection
// emulation, which emulates the effective backoff of an ensemble of N
// TCP-Reno connections on a single loss event. The effective multiplier is
// computed as:
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
}
func (c *Cubic) betaLastMax() float32 {
// betaLastMax is the additional backoff factor after loss for our
// N-connection emulation, which emulates the additional backoff of
// an ensemble of N TCP-Reno connections on a single loss event. The
// effective multiplier is computed as:
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
}
// OnApplicationLimited is called on ack arrival when sender is unable to use
// the available congestion window. Resets Cubic state during quiescence.
func (c *Cubic) OnApplicationLimited() {
// When sender is not using the available congestion window, the window does
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
// using the entire window during the time since the beginning of the current
// "epoch" (the end of the last loss recovery period). Since
// application-limited periods break this assumption, we reset the epoch when
// in such a period. This reset effectively freezes congestion window growth
// through application-limited periods and allows Cubic growth to continue
// when the entire window is being used.
c.epoch = time.Time{}
}
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
// a loss event. Returns the new congestion window in packets. The new
// congestion window is a multiplicative decrease of our current window.
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount {
if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
// We never reached the old max, so assume we are competing with another
// flow. Use our extra back off factor to allow the other flow to go up.
c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
} else {
c.lastMaxCongestionWindow = currentCongestionWindow
}
c.epoch = time.Time{} // Reset time.
return congestion.ByteCount(float32(currentCongestionWindow) * c.beta())
}
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
// Returns the new congestion window in packets. The new congestion window
// follows a cubic function that depends on the time passed since last
// packet loss.
func (c *Cubic) CongestionWindowAfterAck(
ackedBytes congestion.ByteCount,
currentCongestionWindow congestion.ByteCount,
delayMin time.Duration,
eventTime time.Time,
) congestion.ByteCount {
c.ackedBytesCount += ackedBytes
if c.epoch.IsZero() {
// First ACK after a loss event.
c.epoch = eventTime // Start of epoch.
c.ackedBytesCount = ackedBytes // Reset count.
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
c.estimatedTCPcongestionWindow = currentCongestionWindow
if c.lastMaxCongestionWindow <= currentCongestionWindow {
c.timeToOriginPoint = 0
c.originPointCongestionWindow = currentCongestionWindow
} else {
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
c.originPointCongestionWindow = c.lastMaxCongestionWindow
}
}
// Change the time unit from microseconds to 2^10 fractions per second. Take
// the round trip time in account. This is done to allow us to use shift as a
// divide operator.
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
// Right-shifts of negative, signed numbers have implementation-dependent
// behavior, so force the offset to be positive, as is done in the kernel.
offset := int64(c.timeToOriginPoint) - elapsedTime
if offset < 0 {
offset = -offset
}
deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
var targetCongestionWindow congestion.ByteCount
if elapsedTime > int64(c.timeToOriginPoint) {
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
} else {
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
}
// Limit the CWND increase to half the acked bytes.
targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
// Increase the window by approximately Alpha * 1 MSS of bytes every
// time we ack an estimated tcp window of bytes. For small
// congestion windows (less than 25), the formula below will
// increase slightly slower than linearly per estimated tcp window
// of bytes.
c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
c.ackedBytesCount = 0
// We have a new cubic congestion window.
c.lastTargetCongestionWindow = targetCongestionWindow
// Compute target congestion_window based on cubic target and estimated TCP
// congestion_window, use highest (fastest).
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
targetCongestionWindow = c.estimatedTCPcongestionWindow
}
return targetCongestionWindow
}
// SetNumConnections sets the number of emulated connections
func (c *Cubic) SetNumConnections(n int) {
c.numConnections = n
}

View file

@ -0,0 +1,318 @@
package congestion
import (
"fmt"
"time"
"github.com/sagernet/quic-go/congestion"
"github.com/sagernet/quic-go/logging"
)
const (
maxBurstPackets = 3
renoBeta = 0.7 // Reno backoff factor.
minCongestionWindowPackets = 2
initialCongestionWindow = 32
)
const (
InvalidPacketNumber congestion.PacketNumber = -1
MaxCongestionWindowPackets = 20000
MaxByteCount = congestion.ByteCount(1<<62 - 1)
)
type cubicSender struct {
hybridSlowStart HybridSlowStart
rttStats congestion.RTTStatsProvider
cubic *Cubic
pacer *pacer
clock Clock
reno bool
// Track the largest packet that has been sent.
largestSentPacketNumber congestion.PacketNumber
// Track the largest packet that has been acked.
largestAckedPacketNumber congestion.PacketNumber
// Track the largest packet number outstanding when a CWND cutback occurs.
largestSentAtLastCutback congestion.PacketNumber
// Whether the last loss event caused us to exit slowstart.
// Used for stats collection of slowstartPacketsLost
lastCutbackExitedSlowstart bool
// Congestion window in bytes.
congestionWindow congestion.ByteCount
// Slow start congestion window in bytes, aka ssthresh.
slowStartThreshold congestion.ByteCount
// ACK counter for the Reno implementation.
numAckedPackets uint64
initialCongestionWindow congestion.ByteCount
initialMaxCongestionWindow congestion.ByteCount
maxDatagramSize congestion.ByteCount
lastState logging.CongestionState
tracer logging.ConnectionTracer
}
var _ congestion.CongestionControl = &cubicSender{}
// NewCubicSender makes a new cubic sender
func NewCubicSender(
clock Clock,
initialMaxDatagramSize congestion.ByteCount,
reno bool,
tracer logging.ConnectionTracer,
) *cubicSender {
return newCubicSender(
clock,
reno,
initialMaxDatagramSize,
initialCongestionWindow*initialMaxDatagramSize,
MaxCongestionWindowPackets*initialMaxDatagramSize,
tracer,
)
}
func newCubicSender(
clock Clock,
reno bool,
initialMaxDatagramSize,
initialCongestionWindow,
initialMaxCongestionWindow congestion.ByteCount,
tracer logging.ConnectionTracer,
) *cubicSender {
c := &cubicSender{
largestSentPacketNumber: InvalidPacketNumber,
largestAckedPacketNumber: InvalidPacketNumber,
largestSentAtLastCutback: InvalidPacketNumber,
initialCongestionWindow: initialCongestionWindow,
initialMaxCongestionWindow: initialMaxCongestionWindow,
congestionWindow: initialCongestionWindow,
slowStartThreshold: MaxByteCount,
cubic: NewCubic(clock),
clock: clock,
reno: reno,
tracer: tracer,
maxDatagramSize: initialMaxDatagramSize,
}
c.pacer = newPacer(c.BandwidthEstimate)
if c.tracer != nil {
c.lastState = logging.CongestionStateSlowStart
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
}
return c
}
func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
c.rttStats = provider
}
// TimeUntilSend returns when the next packet should be sent.
func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time {
return c.pacer.TimeUntilSend()
}
func (c *cubicSender) HasPacingBudget(now time.Time) bool {
return c.pacer.Budget(now) >= c.maxDatagramSize
}
func (c *cubicSender) maxCongestionWindow() congestion.ByteCount {
return c.maxDatagramSize * MaxCongestionWindowPackets
}
func (c *cubicSender) minCongestionWindow() congestion.ByteCount {
return c.maxDatagramSize * minCongestionWindowPackets
}
func (c *cubicSender) OnPacketSent(
sentTime time.Time,
_ congestion.ByteCount,
packetNumber congestion.PacketNumber,
bytes congestion.ByteCount,
isRetransmittable bool,
) {
c.pacer.SentPacket(sentTime, bytes)
if !isRetransmittable {
return
}
c.largestSentPacketNumber = packetNumber
c.hybridSlowStart.OnPacketSent(packetNumber)
}
func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool {
return bytesInFlight < c.GetCongestionWindow()
}
func (c *cubicSender) InRecovery() bool {
return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
}
func (c *cubicSender) InSlowStart() bool {
return c.GetCongestionWindow() < c.slowStartThreshold
}
func (c *cubicSender) GetCongestionWindow() congestion.ByteCount {
return c.congestionWindow
}
func (c *cubicSender) MaybeExitSlowStart() {
if c.InSlowStart() &&
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
// exit slow start
c.slowStartThreshold = c.congestionWindow
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
}
}
func (c *cubicSender) OnPacketAcked(
ackedPacketNumber congestion.PacketNumber,
ackedBytes congestion.ByteCount,
priorInFlight congestion.ByteCount,
eventTime time.Time,
) {
c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() {
return
}
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
if c.InSlowStart() {
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
}
}
func (c *cubicSender) OnPacketLost(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback {
return
}
c.lastCutbackExitedSlowstart = c.InSlowStart()
c.maybeTraceStateChange(logging.CongestionStateRecovery)
if c.reno {
c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta)
} else {
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
}
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
c.congestionWindow = minCwnd
}
c.slowStartThreshold = c.congestionWindow
c.largestSentAtLastCutback = c.largestSentPacketNumber
// reset packet count from congestion avoidance mode. We start
// counting again when we're out of recovery.
c.numAckedPackets = 0
}
// Called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but quic has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(
_ congestion.PacketNumber,
ackedBytes congestion.ByteCount,
priorInFlight congestion.ByteCount,
eventTime time.Time,
) {
// Do not increase the congestion window unless the sender is close to using
// the current window.
if !c.isCwndLimited(priorInFlight) {
c.cubic.OnApplicationLimited()
c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
return
}
if c.congestionWindow >= c.maxCongestionWindow() {
return
}
if c.InSlowStart() {
// TCP slow start, exponential growth, increase by one for each ACK.
c.congestionWindow += c.maxDatagramSize
c.maybeTraceStateChange(logging.CongestionStateSlowStart)
return
}
// Congestion avoidance
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
if c.reno {
// Classic Reno congestion avoidance.
c.numAckedPackets++
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
c.congestionWindow += c.maxDatagramSize
c.numAckedPackets = 0
}
} else {
c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
}
}
func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool {
congestionWindow := c.GetCongestionWindow()
if bytesInFlight >= congestionWindow {
return true
}
availableBytes := congestionWindow - bytesInFlight
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
}
// BandwidthEstimate returns the current bandwidth estimate
func (c *cubicSender) BandwidthEstimate() Bandwidth {
if c.rttStats == nil {
return infBandwidth
}
srtt := c.rttStats.SmoothedRTT()
if srtt == 0 {
// If we haven't measured an rtt, the bandwidth estimate is unknown.
return infBandwidth
}
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
}
// OnRetransmissionTimeout is called on an retransmission timeout
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
c.largestSentAtLastCutback = InvalidPacketNumber
if !packetsRetransmitted {
return
}
c.hybridSlowStart.Restart()
c.cubic.Reset()
c.slowStartThreshold = c.congestionWindow / 2
c.congestionWindow = c.minCongestionWindow()
}
// OnConnectionMigration is called when the connection is migrated (?)
func (c *cubicSender) OnConnectionMigration() {
c.hybridSlowStart.Restart()
c.largestSentPacketNumber = InvalidPacketNumber
c.largestAckedPacketNumber = InvalidPacketNumber
c.largestSentAtLastCutback = InvalidPacketNumber
c.lastCutbackExitedSlowstart = false
c.cubic.Reset()
c.numAckedPackets = 0
c.congestionWindow = c.initialCongestionWindow
c.slowStartThreshold = c.initialMaxCongestionWindow
}
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
if c.tracer == nil || new == c.lastState {
return
}
c.tracer.UpdatedCongestionState(new)
c.lastState = new
}
func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) {
if s < c.maxDatagramSize {
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
}
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
c.maxDatagramSize = s
if cwndIsMinCwnd {
c.congestionWindow = c.minCongestionWindow()
}
c.pacer.SetMaxDatagramSize(s)
}

View file

@ -0,0 +1,112 @@
package congestion
import (
"time"
"github.com/sagernet/quic-go/congestion"
)
// Note(pwestin): the magic clamping numbers come from the original code in
// tcp_cubic.c.
const hybridStartLowWindow = congestion.ByteCount(16)
// Number of delay samples for detecting the increase of delay.
const hybridStartMinSamples = uint32(8)
// Exit slow start if the min rtt has increased by more than 1/8th.
const hybridStartDelayFactorExp = 3 // 2^3 = 8
// The original paper specifies 2 and 8ms, but those have changed over time.
const (
hybridStartDelayMinThresholdUs = int64(4000)
hybridStartDelayMaxThresholdUs = int64(16000)
)
// HybridSlowStart implements the TCP hybrid slow start algorithm
type HybridSlowStart struct {
endPacketNumber congestion.PacketNumber
lastSentPacketNumber congestion.PacketNumber
started bool
currentMinRTT time.Duration
rttSampleCount uint32
hystartFound bool
}
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) {
s.endPacketNumber = lastSent
s.currentMinRTT = 0
s.rttSampleCount = 0
s.started = true
}
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool {
return s.endPacketNumber < ack
}
// ShouldExitSlowStart should be called on every new ack frame, since a new
// RTT measurement can be made then.
// rtt: the RTT for this ack packet.
// minRTT: is the lowest delay (RTT) we have seen during the session.
// congestionWindow: the congestion window in packets.
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool {
if !s.started {
// Time to start the hybrid slow start.
s.StartReceiveRound(s.lastSentPacketNumber)
}
if s.hystartFound {
return true
}
// Second detection parameter - delay increase detection.
// Compare the minimum delay (s.currentMinRTT) of the current
// burst of packets relative to the minimum delay during the session.
// Note: we only look at the first few(8) packets in each burst, since we
// only want to compare the lowest RTT of the burst relative to previous
// bursts.
s.rttSampleCount++
if s.rttSampleCount <= hybridStartMinSamples {
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
s.currentMinRTT = latestRTT
}
}
// We only need to check this once per round.
if s.rttSampleCount == hybridStartMinSamples {
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
s.hystartFound = true
}
}
// Exit from slow start if the cwnd is greater than 16 and
// increasing delay is found.
return congestionWindow >= hybridStartLowWindow && s.hystartFound
}
// OnPacketSent is called when a packet was sent
func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) {
s.lastSentPacketNumber = packetNumber
}
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
// the round when the final packet of the burst is received and start it on
// the next incoming ack.
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) {
if s.IsEndOfRound(ackedPacketNumber) {
s.started = false
}
}
// Started returns true if started
func (s *HybridSlowStart) Started() bool {
return s.started
}
// Restart the slow start phase
func (s *HybridSlowStart) Restart() {
s.started = false
s.hystartFound = false
}

View file

@ -0,0 +1,72 @@
package congestion
import (
"math"
"time"
"golang.org/x/exp/constraints"
)
// InfDuration is a duration of infinite length
const InfDuration = time.Duration(math.MaxInt64)
func Max[T constraints.Ordered](a, b T) T {
if a < b {
return b
}
return a
}
func Min[T constraints.Ordered](a, b T) T {
if a < b {
return a
}
return b
}
// MinNonZeroDuration return the minimum duration that's not zero.
func MinNonZeroDuration(a, b time.Duration) time.Duration {
if a == 0 {
return b
}
if b == 0 {
return a
}
return Min(a, b)
}
// AbsDuration returns the absolute value of a time duration
func AbsDuration(d time.Duration) time.Duration {
if d >= 0 {
return d
}
return -d
}
// MinTime returns the earlier time
func MinTime(a, b time.Time) time.Time {
if a.After(b) {
return b
}
return a
}
// MinNonZeroTime returns the earlist time that is not time.Time{}
// If both a and b are time.Time{}, it returns time.Time{}
func MinNonZeroTime(a, b time.Time) time.Time {
if a.IsZero() {
return b
}
if b.IsZero() {
return a
}
return MinTime(a, b)
}
// MaxTime returns the later time
func MaxTime(a, b time.Time) time.Time {
if a.After(b) {
return a
}
return b
}

View file

@ -0,0 +1,81 @@
package congestion
import (
"math"
"time"
"github.com/sagernet/quic-go/congestion"
)
const (
initialMaxDatagramSize = congestion.ByteCount(1252)
MinPacingDelay = time.Millisecond
TimerGranularity = time.Millisecond
maxBurstSizePackets = 10
)
// The pacer implements a token bucket pacing algorithm.
type pacer struct {
budgetAtLastSent congestion.ByteCount
maxDatagramSize congestion.ByteCount
lastSentTime time.Time
getAdjustedBandwidth func() uint64 // in bytes/s
}
func newPacer(getBandwidth func() Bandwidth) *pacer {
p := &pacer{
maxDatagramSize: initialMaxDatagramSize,
getAdjustedBandwidth: func() uint64 {
// Bandwidth is in bits/s. We need the value in bytes/s.
bw := uint64(getBandwidth() / BytesPerSecond)
// Use a slightly higher value than the actual measured bandwidth.
// RTT variations then won't result in under-utilization of the congestion window.
// Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire,
// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
return bw * 5 / 4
},
}
p.budgetAtLastSent = p.maxBurstSize()
return p
}
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
budget := p.Budget(sendTime)
if size > budget {
p.budgetAtLastSent = 0
} else {
p.budgetAtLastSent = budget - size
}
p.lastSentTime = sendTime
}
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
if p.lastSentTime.IsZero() {
return p.maxBurstSize()
}
budget := p.budgetAtLastSent + (congestion.ByteCount(p.getAdjustedBandwidth())*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
return Min(p.maxBurstSize(), budget)
}
func (p *pacer) maxBurstSize() congestion.ByteCount {
return Max(
congestion.ByteCount(uint64((MinPacingDelay+TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9,
maxBurstSizePackets*p.maxDatagramSize,
)
}
// TimeUntilSend returns when the next packet should be sent.
// It returns the zero value of time.Time if a packet can be sent immediately.
func (p *pacer) TimeUntilSend() time.Time {
if p.budgetAtLastSent >= p.maxDatagramSize {
return time.Time{}
}
return p.lastSentTime.Add(Max(
MinPacingDelay,
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond,
))
}
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
p.maxDatagramSize = s
}

View file

@ -0,0 +1,132 @@
package congestion
// WindowedFilter Use the following to construct a windowed filter object of type T.
// For example, a min filter using QuicTime as the time type:
//
// WindowedFilter<T, MinFilter<T>, QuicTime, QuicTime::Delta> ObjectName;
//
// A max filter using 64-bit integers as the time type:
//
// WindowedFilter<T, MaxFilter<T>, uint64_t, int64_t> ObjectName;
//
// Specifically, this template takes four arguments:
// 1. T -- type of the measurement that is being filtered.
// 2. Compare -- MinFilter<T> or MaxFilter<T>, depending on the type of filter
// desired.
// 3. TimeT -- the type used to represent timestamps.
// 4. TimeDeltaT -- the type used to represent continuous time intervals between
// two timestamps. Has to be the type of (a - b) if both |a| and |b| are
// of type TimeT.
type WindowedFilter struct {
// Time length of window.
windowLength int64
estimates []Sample
comparator func(int64, int64) bool
}
type Sample struct {
sample int64
time int64
}
// Compares two values and returns true if the first is greater than or equal
// to the second.
func MaxFilter(a, b int64) bool {
return a >= b
}
// Compares two values and returns true if the first is less than or equal
// to the second.
func MinFilter(a, b int64) bool {
return a <= b
}
func NewWindowedFilter(windowLength int64, comparator func(int64, int64) bool) *WindowedFilter {
return &WindowedFilter{
windowLength: windowLength,
estimates: make([]Sample, 3),
comparator: comparator,
}
}
// Changes the window length. Does not update any current samples.
func (f *WindowedFilter) SetWindowLength(windowLength int64) {
f.windowLength = windowLength
}
func (f *WindowedFilter) GetBest() int64 {
return f.estimates[0].sample
}
func (f *WindowedFilter) GetSecondBest() int64 {
return f.estimates[1].sample
}
func (f *WindowedFilter) GetThirdBest() int64 {
return f.estimates[2].sample
}
func (f *WindowedFilter) Update(sample int64, time int64) {
if f.estimates[0].time == 0 || f.comparator(sample, f.estimates[0].sample) || (time-f.estimates[2].time) > f.windowLength {
f.Reset(sample, time)
return
}
if f.comparator(sample, f.estimates[1].sample) {
f.estimates[1].sample = sample
f.estimates[1].time = time
f.estimates[2].sample = sample
f.estimates[2].time = time
} else if f.comparator(sample, f.estimates[2].sample) {
f.estimates[2].sample = sample
f.estimates[2].time = time
}
// Expire and update estimates as necessary.
if time-f.estimates[0].time > f.windowLength {
// The best estimate hasn't been updated for an entire window, so promote
// second and third best estimates.
f.estimates[0].sample = f.estimates[1].sample
f.estimates[0].time = f.estimates[1].time
f.estimates[1].sample = f.estimates[2].sample
f.estimates[1].time = f.estimates[2].time
f.estimates[2].sample = sample
f.estimates[2].time = time
// Need to iterate one more time. Check if the new best estimate is
// outside the window as well, since it may also have been recorded a
// long time ago. Don't need to iterate once more since we cover that
// case at the beginning of the method.
if time-f.estimates[0].time > f.windowLength {
f.estimates[0].sample = f.estimates[1].sample
f.estimates[0].time = f.estimates[1].time
f.estimates[1].sample = f.estimates[2].sample
f.estimates[1].time = f.estimates[2].time
}
return
}
if f.estimates[1].sample == f.estimates[0].sample && time-f.estimates[1].time > f.windowLength>>2 {
// A quarter of the window has passed without a better sample, so the
// second-best estimate is taken from the second quarter of the window.
f.estimates[1].sample = sample
f.estimates[1].time = time
f.estimates[2].sample = sample
f.estimates[2].time = time
return
}
if f.estimates[2].sample == f.estimates[1].sample && time-f.estimates[2].time > f.windowLength>>1 {
// We've passed a half of the window without a better estimate, so take
// a third-best estimate from the second half of the window.
f.estimates[2].sample = sample
f.estimates[2].time = time
}
}
func (f *WindowedFilter) Reset(newSample int64, newTime int64) {
f.estimates[0].sample = newSample
f.estimates[0].time = newTime
f.estimates[1].sample = newSample
f.estimates[1].time = newTime
f.estimates[2].sample = newSample
f.estimates[2].time = newTime
}

497
transport/tuic/packet.go Normal file
View file

@ -0,0 +1,497 @@
package tuic
import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
"math"
"net"
"os"
"sync"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
M "github.com/sagernet/sing/common/metadata"
)
var udpMessagePool = sync.Pool{
New: func() interface{} {
return new(udpMessage)
},
}
func releaseMessages(messages []*udpMessage) {
for _, message := range messages {
if message != nil {
*message = udpMessage{}
udpMessagePool.Put(message)
}
}
}
type udpMessage struct {
sessionID uint16
packetID uint16
fragmentTotal uint8
fragmentID uint8
destination M.Socksaddr
dataLength uint16
data *buf.Buffer
}
func (m *udpMessage) release() {
*m = udpMessage{}
udpMessagePool.Put(m)
}
func (m *udpMessage) releaseMessage() {
m.data.Release()
m.release()
}
func (m *udpMessage) pack() *buf.Buffer {
buffer := buf.NewSize(m.headerSize() + m.data.Len())
common.Must(
buffer.WriteByte(Version),
buffer.WriteByte(CommandPacket),
binary.Write(buffer, binary.BigEndian, m.sessionID),
binary.Write(buffer, binary.BigEndian, m.packetID),
binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
binary.Write(buffer, binary.BigEndian, m.fragmentID),
binary.Write(buffer, binary.BigEndian, uint16(m.data.Len())),
addressSerializer.WriteAddrPort(buffer, m.destination),
common.Error(buffer.Write(m.data.Bytes())),
)
return buffer
}
func (m *udpMessage) headerSize() int {
return 2 + 10 + addressSerializer.AddrPortLen(m.destination)
}
func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
if message.data.Len() <= maxPacketSize {
return []*udpMessage{message}
}
var fragments []*udpMessage
originPacket := message.data.Bytes()
udpMTU := maxPacketSize - message.headerSize()
for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
fragment := udpMessagePool.Get().(*udpMessage)
*fragment = *message
if remaining > udpMTU {
fragment.data = buf.As(originPacket[:udpMTU])
originPacket = originPacket[udpMTU:]
} else {
fragment.data = buf.As(originPacket)
originPacket = nil
}
fragments = append(fragments, fragment)
}
fragmentTotal := uint16(len(fragments))
for index, fragment := range fragments {
fragment.fragmentID = uint8(index)
fragment.fragmentTotal = uint8(fragmentTotal)
if index > 0 {
fragment.destination = M.Socksaddr{}
}
}
return fragments
}
type udpPacketConn struct {
ctx context.Context
cancel common.ContextCancelCauseFunc
sessionID uint16
quicConn quic.Connection
data chan *udpMessage
udpStream bool
udpMTU int
packetId atomic.Uint32
closeOnce sync.Once
isServer bool
defragger *udpDefragger
onDestroy func()
}
func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn {
ctx, cancel := common.ContextWithCancelCause(ctx)
return &udpPacketConn{
ctx: ctx,
cancel: cancel,
quicConn: quicConn,
data: make(chan *udpMessage, 64),
udpStream: udpStream,
isServer: isServer,
defragger: newUDPDefragger(),
onDestroy: onDestroy,
}
}
func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case p := <-c.data:
buffer = p.data
destination = p.destination
p.release()
return
case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe
}
}
func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
_, err = buffer.ReadOnceFrom(p.data)
destination = p.destination
p.releaseMessage()
return
case <-c.ctx.Done():
return M.Socksaddr{}, io.ErrClosedPipe
}
}
func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
_, err = newBuffer().ReadOnceFrom(p.data)
destination = p.destination
p.releaseMessage()
return
case <-c.ctx.Done():
return M.Socksaddr{}, io.ErrClosedPipe
}
}
func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case pkt := <-c.data:
n = copy(p, pkt.data.Bytes())
if pkt.destination.IsFqdn() {
addr = pkt.destination
} else {
addr = pkt.destination.UDPAddr()
}
pkt.releaseMessage()
return n, addr, nil
case <-c.ctx.Done():
return 0, nil, io.ErrClosedPipe
}
}
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
select {
case <-c.ctx.Done():
return net.ErrClosed
default:
}
if buffer.Len() > 0xffff {
return quic.ErrMessageTooLarge(0xffff)
}
packetId := c.packetId.Add(1)
if packetId > math.MaxUint16 {
c.packetId.Store(0)
packetId = 0
}
message := udpMessagePool.Get().(*udpMessage)
*message = udpMessage{
sessionID: c.sessionID,
packetID: uint16(packetId),
fragmentTotal: 1,
destination: destination,
data: buffer,
}
defer message.releaseMessage()
var err error
if !c.udpStream && c.udpMTU > 0 && buffer.Len() > c.udpMTU {
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
} else {
err = c.writePacket(message)
}
if err == nil {
return nil
}
var tooLargeErr quic.ErrMessageTooLarge
if !errors.As(err, &tooLargeErr) {
return err
}
c.udpMTU = int(tooLargeErr)
return c.writePackets(fragUDPMessage(message, c.udpMTU))
}
func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
select {
case <-c.ctx.Done():
return 0, net.ErrClosed
default:
}
if len(p) > 0xffff {
return 0, quic.ErrMessageTooLarge(0xffff)
}
packetId := c.packetId.Add(1)
if packetId > math.MaxUint16 {
c.packetId.Store(0)
packetId = 0
}
message := udpMessagePool.Get().(*udpMessage)
*message = udpMessage{
sessionID: c.sessionID,
packetID: uint16(packetId),
fragmentTotal: 1,
destination: M.SocksaddrFromNet(addr),
data: buf.As(p),
}
if c.udpMTU > 0 && len(p) > c.udpMTU {
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
if err == nil {
return len(p), nil
}
} else {
err = c.writePacket(message)
}
if err == nil {
return len(p), nil
}
var tooLargeErr quic.ErrMessageTooLarge
if !errors.As(err, &tooLargeErr) {
return
}
c.udpMTU = int(tooLargeErr)
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
if err == nil {
return len(p), nil
}
return
}
func (c *udpPacketConn) inputPacket(message *udpMessage) {
if message.fragmentTotal <= 1 {
select {
case c.data <- message:
default:
}
} else {
newMessage := c.defragger.feed(message)
if newMessage != nil {
select {
case c.data <- newMessage:
default:
}
}
}
}
func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
defer releaseMessages(messages)
for _, message := range messages {
err := c.writePacket(message)
if err != nil {
return err
}
}
return nil
}
func (c *udpPacketConn) writePacket(message *udpMessage) error {
if !c.udpStream {
buffer := message.pack()
err := c.quicConn.SendMessage(buffer.Bytes())
buffer.Release()
if err != nil {
return err
}
} else {
stream, err := c.quicConn.OpenUniStream()
if err != nil {
return err
}
buffer := message.pack()
_, err = stream.Write(buffer.Bytes())
buffer.Release()
stream.Close()
if err != nil {
return err
}
}
return nil
}
func (c *udpPacketConn) Close() error {
c.closeOnce.Do(func() {
c.closeWithError(os.ErrClosed)
c.onDestroy()
})
return nil
}
func (c *udpPacketConn) closeWithError(err error) {
c.cancel(err)
if !c.isServer {
buffer := buf.NewSize(4)
defer buffer.Release()
buffer.WriteByte(Version)
buffer.WriteByte(CommandDissociate)
binary.Write(buffer, binary.BigEndian, c.sessionID)
sendStream, openErr := c.quicConn.OpenUniStream()
if openErr != nil {
return
}
defer sendStream.Close()
sendStream.Write(buffer.Bytes())
}
}
func (c *udpPacketConn) LocalAddr() net.Addr {
return c.quicConn.LocalAddr()
}
func (c *udpPacketConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
type udpDefragger struct {
packetMap *cache.LruCache[uint16, *packetItem]
}
func newUDPDefragger() *udpDefragger {
return &udpDefragger{
packetMap: cache.New(
cache.WithAge[uint16, *packetItem](10),
cache.WithUpdateAgeOnGet[uint16, *packetItem](),
cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
releaseMessages(value.messages)
}),
),
}
}
type packetItem struct {
access sync.Mutex
messages []*udpMessage
count uint8
}
func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
if m.fragmentTotal <= 1 {
return m
}
if m.fragmentID >= m.fragmentTotal {
return nil
}
item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
item.access.Lock()
defer item.access.Unlock()
if int(m.fragmentTotal) != len(item.messages) {
releaseMessages(item.messages)
item.messages = make([]*udpMessage, m.fragmentTotal)
item.count = 1
item.messages[m.fragmentID] = m
return nil
}
if item.messages[m.fragmentID] != nil {
return nil
}
item.messages[m.fragmentID] = m
item.count++
if int(item.count) != len(item.messages) {
return nil
}
newMessage := udpMessagePool.Get().(*udpMessage)
*newMessage = *item.messages[0]
if m.dataLength > 0 {
newMessage.data = buf.NewSize(int(m.dataLength))
for _, message := range item.messages {
newMessage.data.Write(message.data.Bytes())
message.releaseMessage()
}
item.messages = nil
return newMessage
}
return nil
}
func newPacketItem() *packetItem {
return new(packetItem)
}
func readUDPMessage(message *udpMessage, reader io.Reader) error {
err := binary.Read(reader, binary.BigEndian, &message.sessionID)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.packetID)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.dataLength)
if err != nil {
return err
}
message.destination, err = addressSerializer.ReadAddrPort(reader)
if err != nil {
return err
}
message.data = buf.NewSize(int(message.dataLength))
_, err = message.data.ReadFullFrom(reader, message.data.FreeLen())
if err != nil {
return err
}
return nil
}
func decodeUDPMessage(message *udpMessage, data []byte) error {
reader := bytes.NewReader(data)
err := binary.Read(reader, binary.BigEndian, &message.sessionID)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.packetID)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.dataLength)
if err != nil {
return err
}
message.destination, err = addressSerializer.ReadAddrPort(reader)
if err != nil {
return err
}
if reader.Len() != int(message.dataLength) {
return io.ErrUnexpectedEOF
}
message.data = buf.As(data[len(data)-reader.Len():])
return nil
}

View file

@ -0,0 +1,15 @@
package tuic
const (
Version = 5
)
const (
CommandAuthenticate = iota
CommandConnect
CommandPacket
CommandDissociate
CommandHeartbeat
)
const AuthenticateLen = 2 + 16 + 32

434
transport/tuic/server.go Normal file
View file

@ -0,0 +1,434 @@
package tuic
import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"io"
"net"
"runtime"
"strings"
"sync"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/common/baderror"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/gofrs/uuid/v5"
)
type ServerOptions struct {
Context context.Context
Logger logger.Logger
TLSConfig *tls.Config
Users []User
CongestionControl string
AuthTimeout time.Duration
ZeroRTTHandshake bool
Heartbeat time.Duration
Handler ServerHandler
}
type User struct {
Name string
UUID uuid.UUID
Password string
}
type ServerHandler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
}
type Server struct {
ctx context.Context
logger logger.Logger
tlsConfig *tls.Config
heartbeat time.Duration
quicConfig *quic.Config
userMap map[uuid.UUID]User
congestionControl string
authTimeout time.Duration
handler ServerHandler
quicListener io.Closer
}
func NewServer(options ServerOptions) (*Server, error) {
if options.AuthTimeout == 0 {
options.AuthTimeout = 3 * time.Second
}
if options.Heartbeat == 0 {
options.Heartbeat = 10 * time.Second
}
quicConfig := &quic.Config{
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
MaxDatagramFrameSize: 1400,
EnableDatagrams: true,
Allow0RTT: options.ZeroRTTHandshake,
MaxIncomingStreams: 1 << 60,
MaxIncomingUniStreams: 1 << 60,
}
switch options.CongestionControl {
case "":
options.CongestionControl = "cubic"
case "cubic", "new_reno", "bbr":
default:
return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
}
if len(options.Users) == 0 {
return nil, E.New("missing users")
}
userMap := make(map[uuid.UUID]User)
for _, user := range options.Users {
userMap[user.UUID] = user
}
return &Server{
ctx: options.Context,
logger: options.Logger,
tlsConfig: options.TLSConfig,
heartbeat: options.Heartbeat,
quicConfig: quicConfig,
userMap: userMap,
congestionControl: options.CongestionControl,
authTimeout: options.AuthTimeout,
handler: options.Handler,
}, nil
}
func (s *Server) Start(conn net.PacketConn) error {
if !s.quicConfig.Allow0RTT {
listener, err := quic.Listen(conn, s.tlsConfig, s.quicConfig)
if err != nil {
return err
}
s.quicListener = listener
go func() {
for {
connection, hErr := listener.Accept(s.ctx)
if hErr != nil {
if strings.Contains(hErr.Error(), "server closed") {
s.logger.Debug(E.Cause(hErr, "listener closed"))
} else {
s.logger.Error(E.Cause(hErr, "listener closed"))
}
return
}
go s.handleConnection(connection)
}
}()
} else {
listener, err := quic.ListenEarly(conn, s.tlsConfig, s.quicConfig)
if err != nil {
return err
}
s.quicListener = listener
go func() {
for {
connection, hErr := listener.Accept(s.ctx)
if hErr != nil {
if strings.Contains(hErr.Error(), "server closed") {
s.logger.Debug(E.Cause(hErr, "listener closed"))
} else {
s.logger.Error(E.Cause(hErr, "listener closed"))
}
return
}
go s.handleConnection(connection)
}
}()
}
return nil
}
func (s *Server) Close() error {
return common.Close(
s.quicListener,
)
}
func (s *Server) handleConnection(connection quic.Connection) {
setCongestion(s.ctx, connection, s.congestionControl)
session := &serverSession{
Server: s,
ctx: s.ctx,
quicConn: connection,
source: M.SocksaddrFromNet(connection.RemoteAddr()),
connDone: make(chan struct{}),
authDone: make(chan struct{}),
udpConnMap: make(map[uint16]*udpPacketConn),
}
session.handle()
}
type serverSession struct {
*Server
ctx context.Context
quicConn quic.Connection
source M.Socksaddr
connAccess sync.Mutex
connDone chan struct{}
connErr error
authDone chan struct{}
authUser *User
udpAccess sync.RWMutex
udpConnMap map[uint16]*udpPacketConn
}
func (s *serverSession) handle() {
if s.ctx.Done() != nil {
go func() {
select {
case <-s.ctx.Done():
s.closeWithError(s.ctx.Err())
case <-s.connDone:
}
}()
}
go s.loopUniStreams()
go s.loopStreams()
go s.loopMessages()
go s.handleAuthTimeout()
go s.loopHeartbeats()
}
func (s *serverSession) loopUniStreams() {
for {
uniStream, err := s.quicConn.AcceptUniStream(s.ctx)
if err != nil {
return
}
go func() {
err = s.handleUniStream(uniStream)
if err != nil {
s.closeWithError(E.Cause(err, "handle uni stream"))
}
}()
}
}
func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
defer stream.CancelRead(0)
buffer := buf.New()
defer buffer.Release()
_, err := buffer.ReadAtLeastFrom(stream, 2)
if err != nil {
return E.Cause(err, "read request")
}
version := buffer.Byte(0)
if version != Version {
return E.New("unknown version ", buffer.Byte(0))
}
command := buffer.Byte(1)
switch command {
case CommandAuthenticate:
select {
case <-s.authDone:
return E.New("authentication: multiple authentication requests")
default:
}
if buffer.Len() < AuthenticateLen {
_, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len())
if err != nil {
return E.Cause(err, "authentication: read request")
}
}
userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16))
user, loaded := s.userMap[userUUID]
if !loaded {
return E.New("authentication: unknown user ", userUUID)
}
handshakeState := s.quicConn.ConnectionState().TLS
tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32)
if err != nil {
return E.Cause(err, "authentication: export keying material")
}
if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) {
return E.New("authentication: token mismatch")
}
s.authUser = &user
close(s.authDone)
return nil
case CommandPacket:
select {
case <-s.connDone:
return s.connErr
case <-s.authDone:
}
message := udpMessagePool.Get().(*udpMessage)
err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream))
if err != nil {
message.release()
return err
}
s.handleUDPMessage(message, true)
return nil
case CommandDissociate:
select {
case <-s.connDone:
return s.connErr
case <-s.authDone:
}
if buffer.Len() > 4 {
return E.New("invalid dissociate message")
}
var sessionID uint16
err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID)
if err != nil {
return err
}
s.udpAccess.RLock()
udpConn, loaded := s.udpConnMap[sessionID]
s.udpAccess.RUnlock()
if loaded {
udpConn.closeWithError(E.New("remote closed"))
s.udpAccess.Lock()
delete(s.udpConnMap, sessionID)
s.udpAccess.Unlock()
}
return nil
default:
return E.New("unknown command ", command)
}
}
func (s *serverSession) handleAuthTimeout() {
select {
case <-s.connDone:
case <-s.authDone:
case <-time.After(s.authTimeout):
s.closeWithError(E.New("authentication timeout"))
}
}
func (s *serverSession) loopStreams() {
for {
stream, err := s.quicConn.AcceptStream(s.ctx)
if err != nil {
return
}
go func() {
err = s.handleStream(stream)
if err != nil {
stream.CancelRead(0)
stream.Close()
s.logger.Error(E.Cause(err, "handle stream request"))
}
}()
}
}
func (s *serverSession) handleStream(stream quic.Stream) error {
buffer := buf.NewSize(2 + M.MaxSocksaddrLength)
defer buffer.Release()
_, err := buffer.ReadAtLeastFrom(stream, 2)
if err != nil {
return E.Cause(err, "read request")
}
version, _ := buffer.ReadByte()
if version != Version {
return E.New("unknown version ", buffer.Byte(0))
}
command, _ := buffer.ReadByte()
if command != CommandConnect {
return E.New("unsupported stream command ", command)
}
destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream))
if err != nil {
return E.Cause(err, "read request destination")
}
select {
case <-s.connDone:
return s.connErr
case <-s.authDone:
}
var conn net.Conn = &serverConn{
Stream: stream,
destination: destination,
}
if buffer.IsEmpty() {
buffer.Release()
} else {
conn = bufio.NewCachedConn(conn, buffer)
}
ctx := s.ctx
if s.authUser.Name != "" {
ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
}
_ = s.handler.NewConnection(ctx, conn, M.Metadata{
Source: s.source,
Destination: destination,
})
return nil
}
func (s *serverSession) loopHeartbeats() {
ticker := time.NewTicker(s.heartbeat)
defer ticker.Stop()
for {
select {
case <-s.connDone:
return
case <-ticker.C:
err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
if err != nil {
s.closeWithError(E.Cause(err, "send heartbeat"))
}
}
}
}
func (s *serverSession) closeWithError(err error) {
s.connAccess.Lock()
defer s.connAccess.Unlock()
select {
case <-s.connDone:
return
default:
s.connErr = err
close(s.connDone)
}
if E.IsClosedOrCanceled(err) {
s.logger.Debug(E.Cause(err, "connection failed"))
} else {
s.logger.Error(E.Cause(err, "connection failed"))
}
_ = s.quicConn.CloseWithError(0, "")
}
type serverConn struct {
quic.Stream
destination M.Socksaddr
}
func (c *serverConn) Read(p []byte) (n int, err error) {
n, err = c.Stream.Read(p)
return n, baderror.WrapQUIC(err)
}
func (c *serverConn) Write(p []byte) (n int, err error) {
n, err = c.Stream.Write(p)
return n, baderror.WrapQUIC(err)
}
func (c *serverConn) LocalAddr() net.Addr {
return c.destination
}
func (c *serverConn) RemoteAddr() net.Addr {
return M.Socksaddr{}
}
func (c *serverConn) Close() error {
c.Stream.CancelRead(0)
return c.Stream.Close()
}

View file

@ -0,0 +1,73 @@
package tuic
import (
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
func (s *serverSession) loopMessages() {
select {
case <-s.connDone:
return
case <-s.authDone:
}
for {
message, err := s.quicConn.ReceiveMessage(s.ctx)
if err != nil {
s.closeWithError(E.Cause(err, "receive message"))
return
}
hErr := s.handleMessage(message)
if hErr != nil {
s.closeWithError(E.Cause(hErr, "handle message"))
return
}
}
}
func (s *serverSession) handleMessage(data []byte) error {
if len(data) < 2 {
return E.New("invalid message")
}
if data[0] != Version {
return E.New("unknown version ", data[0])
}
switch data[1] {
case CommandPacket:
message := udpMessagePool.Get().(*udpMessage)
err := decodeUDPMessage(message, data[2:])
if err != nil {
message.release()
return E.Cause(err, "decode UDP message")
}
s.handleUDPMessage(message, false)
return nil
case CommandHeartbeat:
return nil
default:
return E.New("unknown command ", data[0])
}
}
func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) {
s.udpAccess.RLock()
udpConn, loaded := s.udpConnMap[message.sessionID]
s.udpAccess.RUnlock()
if !loaded || common.Done(udpConn.ctx) {
udpConn = newUDPPacketConn(s.ctx, s.quicConn, udpStream, true, func() {
s.udpAccess.Lock()
delete(s.udpConnMap, message.sessionID)
s.udpAccess.Unlock()
})
udpConn.sessionID = message.sessionID
s.udpAccess.Lock()
s.udpConnMap[message.sessionID] = udpConn
s.udpAccess.Unlock()
go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{
Source: s.source,
Destination: message.destination,
})
}
udpConn.inputPacket(message)
}