mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-25 10:01:30 +00:00
Add TUIC protocol
This commit is contained in:
parent
0b14dc3228
commit
917420e79a
|
@ -21,6 +21,7 @@ const (
|
||||||
TypeShadowTLS = "shadowtls"
|
TypeShadowTLS = "shadowtls"
|
||||||
TypeShadowsocksR = "shadowsocksr"
|
TypeShadowsocksR = "shadowsocksr"
|
||||||
TypeVLESS = "vless"
|
TypeVLESS = "vless"
|
||||||
|
TypeTUIC = "tuic"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -62,6 +63,8 @@ func ProxyDisplayName(proxyType string) string {
|
||||||
return "ShadowsocksR"
|
return "ShadowsocksR"
|
||||||
case TypeVLESS:
|
case TypeVLESS:
|
||||||
return "VLESS"
|
return "VLESS"
|
||||||
|
case TypeTUIC:
|
||||||
|
return "TUIC"
|
||||||
case TypeSelector:
|
case TypeSelector:
|
||||||
return "Selector"
|
return "Selector"
|
||||||
case TypeURLTest:
|
case TypeURLTest:
|
||||||
|
|
|
@ -44,6 +44,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o
|
||||||
return NewShadowTLS(ctx, router, logger, options.Tag, options.ShadowTLSOptions)
|
return NewShadowTLS(ctx, router, logger, options.Tag, options.ShadowTLSOptions)
|
||||||
case C.TypeVLESS:
|
case C.TypeVLESS:
|
||||||
return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions)
|
return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions)
|
||||||
|
case C.TypeTUIC:
|
||||||
|
return NewTUIC(ctx, router, logger, options.Tag, options.TUICOptions)
|
||||||
default:
|
default:
|
||||||
return nil, E.New("unknown inbound type: ", options.Type)
|
return nil, E.New("unknown inbound type: ", options.Type)
|
||||||
}
|
}
|
||||||
|
|
|
@ -153,6 +153,17 @@ func (a *myInboundAdapter) createMetadata(conn net.Conn, metadata adapter.Inboun
|
||||||
return metadata
|
return metadata
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *myInboundAdapter) createPacketMetadata(conn N.PacketConn, metadata adapter.InboundContext) adapter.InboundContext {
|
||||||
|
metadata.Inbound = a.tag
|
||||||
|
metadata.InboundType = a.protocol
|
||||||
|
metadata.InboundDetour = a.listenOptions.Detour
|
||||||
|
metadata.InboundOptions = a.listenOptions.InboundOptions
|
||||||
|
if !metadata.Destination.IsValid() {
|
||||||
|
metadata.Destination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
|
||||||
|
}
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
|
||||||
func (a *myInboundAdapter) newError(err error) {
|
func (a *myInboundAdapter) newError(err error) {
|
||||||
a.logger.Error(err)
|
a.logger.Error(err)
|
||||||
}
|
}
|
||||||
|
|
114
inbound/tuic.go
Normal file
114
inbound/tuic.go
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
//go:build with_quic
|
||||||
|
|
||||||
|
package inbound
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing-box/transport/tuic"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/auth"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
|
"github.com/gofrs/uuid/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ adapter.Inbound = (*TUIC)(nil)
|
||||||
|
|
||||||
|
type TUIC struct {
|
||||||
|
myInboundAdapter
|
||||||
|
server *tuic.Server
|
||||||
|
tlsConfig tls.ServerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (*TUIC, error) {
|
||||||
|
options.UDPFragmentDefault = true
|
||||||
|
if options.TLS == nil || !options.TLS.Enabled {
|
||||||
|
return nil, C.ErrTLSRequired
|
||||||
|
}
|
||||||
|
tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rawConfig, err := tlsConfig.Config()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var users []tuic.User
|
||||||
|
for index, user := range options.Users {
|
||||||
|
if user.UUID == "" {
|
||||||
|
return nil, E.New("missing uuid for user ", index)
|
||||||
|
}
|
||||||
|
userUUID, err := uuid.FromString(user.UUID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "invalid uuid for user ", index)
|
||||||
|
}
|
||||||
|
users = append(users, tuic.User{Name: user.Name, UUID: userUUID, Password: user.Password})
|
||||||
|
}
|
||||||
|
inbound := &TUIC{
|
||||||
|
myInboundAdapter: myInboundAdapter{
|
||||||
|
protocol: C.TypeTUIC,
|
||||||
|
network: []string{N.NetworkUDP},
|
||||||
|
ctx: ctx,
|
||||||
|
router: router,
|
||||||
|
logger: logger,
|
||||||
|
tag: tag,
|
||||||
|
listenOptions: options.ListenOptions,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server, err := tuic.NewServer(tuic.ServerOptions{
|
||||||
|
Context: ctx,
|
||||||
|
Logger: logger,
|
||||||
|
TLSConfig: rawConfig,
|
||||||
|
Users: users,
|
||||||
|
CongestionControl: options.CongestionControl,
|
||||||
|
AuthTimeout: time.Duration(options.AuthTimeout),
|
||||||
|
ZeroRTTHandshake: options.ZeroRTTHandshake,
|
||||||
|
Heartbeat: time.Duration(options.Heartbeat),
|
||||||
|
Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
inbound.server = server
|
||||||
|
return inbound, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TUIC) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||||
|
ctx = log.ContextWithNewID(ctx)
|
||||||
|
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||||
|
metadata = h.createMetadata(conn, metadata)
|
||||||
|
metadata.User, _ = auth.UserFromContext[string](ctx)
|
||||||
|
return h.router.RouteConnection(ctx, conn, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TUIC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||||
|
ctx = log.ContextWithNewID(ctx)
|
||||||
|
metadata = h.createPacketMetadata(conn, metadata)
|
||||||
|
metadata.User, _ = auth.UserFromContext[string](ctx)
|
||||||
|
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
|
||||||
|
return h.router.RoutePacketConnection(ctx, conn, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TUIC) Start() error {
|
||||||
|
packetConn, err := h.myInboundAdapter.ListenUDP()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return h.server.Start(packetConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TUIC) Close() error {
|
||||||
|
return common.Close(
|
||||||
|
&h.myInboundAdapter,
|
||||||
|
common.PtrOrNil(h.server),
|
||||||
|
)
|
||||||
|
}
|
16
inbound/tuic_stub.go
Normal file
16
inbound/tuic_stub.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
//go:build !with_quic
|
||||||
|
|
||||||
|
package inbound
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (adapter.Inbound, error) {
|
||||||
|
return nil, C.ErrQUICNotIncluded
|
||||||
|
}
|
|
@ -23,6 +23,7 @@ type _Inbound struct {
|
||||||
HysteriaOptions HysteriaInboundOptions `json:"-"`
|
HysteriaOptions HysteriaInboundOptions `json:"-"`
|
||||||
ShadowTLSOptions ShadowTLSInboundOptions `json:"-"`
|
ShadowTLSOptions ShadowTLSInboundOptions `json:"-"`
|
||||||
VLESSOptions VLESSInboundOptions `json:"-"`
|
VLESSOptions VLESSInboundOptions `json:"-"`
|
||||||
|
TUICOptions TUICInboundOptions `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Inbound _Inbound
|
type Inbound _Inbound
|
||||||
|
@ -58,6 +59,8 @@ func (h Inbound) MarshalJSON() ([]byte, error) {
|
||||||
v = h.ShadowTLSOptions
|
v = h.ShadowTLSOptions
|
||||||
case C.TypeVLESS:
|
case C.TypeVLESS:
|
||||||
v = h.VLESSOptions
|
v = h.VLESSOptions
|
||||||
|
case C.TypeTUIC:
|
||||||
|
v = h.TUICOptions
|
||||||
default:
|
default:
|
||||||
return nil, E.New("unknown inbound type: ", h.Type)
|
return nil, E.New("unknown inbound type: ", h.Type)
|
||||||
}
|
}
|
||||||
|
@ -99,6 +102,8 @@ func (h *Inbound) UnmarshalJSON(bytes []byte) error {
|
||||||
v = &h.ShadowTLSOptions
|
v = &h.ShadowTLSOptions
|
||||||
case C.TypeVLESS:
|
case C.TypeVLESS:
|
||||||
v = &h.VLESSOptions
|
v = &h.VLESSOptions
|
||||||
|
case C.TypeTUIC:
|
||||||
|
v = &h.TUICOptions
|
||||||
default:
|
default:
|
||||||
return E.New("unknown inbound type: ", h.Type)
|
return E.New("unknown inbound type: ", h.Type)
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ type _Outbound struct {
|
||||||
ShadowTLSOptions ShadowTLSOutboundOptions `json:"-"`
|
ShadowTLSOptions ShadowTLSOutboundOptions `json:"-"`
|
||||||
ShadowsocksROptions ShadowsocksROutboundOptions `json:"-"`
|
ShadowsocksROptions ShadowsocksROutboundOptions `json:"-"`
|
||||||
VLESSOptions VLESSOutboundOptions `json:"-"`
|
VLESSOptions VLESSOutboundOptions `json:"-"`
|
||||||
|
TUICOptions TUICOutboundOptions `json:"-"`
|
||||||
SelectorOptions SelectorOutboundOptions `json:"-"`
|
SelectorOptions SelectorOutboundOptions `json:"-"`
|
||||||
URLTestOptions URLTestOutboundOptions `json:"-"`
|
URLTestOptions URLTestOutboundOptions `json:"-"`
|
||||||
}
|
}
|
||||||
|
@ -60,6 +61,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) {
|
||||||
v = h.ShadowsocksROptions
|
v = h.ShadowsocksROptions
|
||||||
case C.TypeVLESS:
|
case C.TypeVLESS:
|
||||||
v = h.VLESSOptions
|
v = h.VLESSOptions
|
||||||
|
case C.TypeTUIC:
|
||||||
|
v = h.TUICOptions
|
||||||
case C.TypeSelector:
|
case C.TypeSelector:
|
||||||
v = h.SelectorOptions
|
v = h.SelectorOptions
|
||||||
case C.TypeURLTest:
|
case C.TypeURLTest:
|
||||||
|
@ -105,6 +108,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error {
|
||||||
v = &h.ShadowsocksROptions
|
v = &h.ShadowsocksROptions
|
||||||
case C.TypeVLESS:
|
case C.TypeVLESS:
|
||||||
v = &h.VLESSOptions
|
v = &h.VLESSOptions
|
||||||
|
case C.TypeTUIC:
|
||||||
|
v = &h.TUICOptions
|
||||||
case C.TypeSelector:
|
case C.TypeSelector:
|
||||||
v = &h.SelectorOptions
|
v = &h.SelectorOptions
|
||||||
case C.TypeURLTest:
|
case C.TypeURLTest:
|
||||||
|
|
30
option/tuic.go
Normal file
30
option/tuic.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package option
|
||||||
|
|
||||||
|
type TUICInboundOptions struct {
|
||||||
|
ListenOptions
|
||||||
|
Users []TUICUser `json:"users,omitempty"`
|
||||||
|
CongestionControl string `json:"congestion_control,omitempty"`
|
||||||
|
AuthTimeout Duration `json:"auth_timeout,omitempty"`
|
||||||
|
ZeroRTTHandshake bool `json:"zero_rtt_handshake,omitempty"`
|
||||||
|
Heartbeat Duration `json:"heartbeat,omitempty"`
|
||||||
|
TLS *InboundTLSOptions `json:"tls,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TUICUser struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
UUID string `json:"uuid,omitempty"`
|
||||||
|
Password string `json:"password,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TUICOutboundOptions struct {
|
||||||
|
DialerOptions
|
||||||
|
ServerOptions
|
||||||
|
UUID string `json:"uuid,omitempty"`
|
||||||
|
Password string `json:"password,omitempty"`
|
||||||
|
CongestionControl string `json:"congestion_control,omitempty"`
|
||||||
|
UDPRelayMode string `json:"udp_relay_mode,omitempty"`
|
||||||
|
ZeroRTTHandshake bool `json:"zero_rtt_handshake,omitempty"`
|
||||||
|
Heartbeat Duration `json:"heartbeat,omitempty"`
|
||||||
|
Network NetworkList `json:"network,omitempty"`
|
||||||
|
TLS *OutboundTLSOptions `json:"tls,omitempty"`
|
||||||
|
}
|
|
@ -51,6 +51,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, t
|
||||||
return NewShadowsocksR(ctx, router, logger, tag, options.ShadowsocksROptions)
|
return NewShadowsocksR(ctx, router, logger, tag, options.ShadowsocksROptions)
|
||||||
case C.TypeVLESS:
|
case C.TypeVLESS:
|
||||||
return NewVLESS(ctx, router, logger, tag, options.VLESSOptions)
|
return NewVLESS(ctx, router, logger, tag, options.VLESSOptions)
|
||||||
|
case C.TypeTUIC:
|
||||||
|
return NewTUIC(ctx, router, logger, tag, options.TUICOptions)
|
||||||
case C.TypeSelector:
|
case C.TypeSelector:
|
||||||
return NewSelector(router, logger, tag, options.SelectorOptions)
|
return NewSelector(router, logger, tag, options.SelectorOptions)
|
||||||
case C.TypeURLTest:
|
case C.TypeURLTest:
|
||||||
|
|
123
outbound/tuic.go
Normal file
123
outbound/tuic.go
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
//go:build with_quic
|
||||||
|
|
||||||
|
package outbound
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing-box/transport/tuic"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
|
"github.com/gofrs/uuid/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ adapter.Outbound = (*TUIC)(nil)
|
||||||
|
_ adapter.InterfaceUpdateListener = (*TUIC)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
type TUIC struct {
|
||||||
|
myOutboundAdapter
|
||||||
|
client *tuic.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICOutboundOptions) (*TUIC, error) {
|
||||||
|
options.UDPFragmentDefault = true
|
||||||
|
if options.TLS == nil || !options.TLS.Enabled {
|
||||||
|
return nil, C.ErrTLSRequired
|
||||||
|
}
|
||||||
|
abstractTLSConfig, err := tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConfig, err := abstractTLSConfig.Config()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userUUID, err := uuid.FromString(options.UUID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "invalid uuid")
|
||||||
|
}
|
||||||
|
var udpStream bool
|
||||||
|
switch options.UDPRelayMode {
|
||||||
|
case "native":
|
||||||
|
case "quic":
|
||||||
|
udpStream = true
|
||||||
|
}
|
||||||
|
client, err := tuic.NewClient(tuic.ClientOptions{
|
||||||
|
Context: ctx,
|
||||||
|
Dialer: dialer.New(router, options.DialerOptions),
|
||||||
|
ServerAddress: options.ServerOptions.Build(),
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
UUID: userUUID,
|
||||||
|
Password: options.Password,
|
||||||
|
CongestionControl: options.CongestionControl,
|
||||||
|
UDPStream: udpStream,
|
||||||
|
ZeroRTTHandshake: options.ZeroRTTHandshake,
|
||||||
|
Heartbeat: time.Duration(options.Heartbeat),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &TUIC{
|
||||||
|
myOutboundAdapter: myOutboundAdapter{
|
||||||
|
protocol: C.TypeTUIC,
|
||||||
|
network: options.Network.Build(),
|
||||||
|
router: router,
|
||||||
|
logger: logger,
|
||||||
|
tag: tag,
|
||||||
|
dependencies: withDialerDependency(options.DialerOptions),
|
||||||
|
},
|
||||||
|
client: client,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TUIC) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
|
switch N.NetworkName(network) {
|
||||||
|
case N.NetworkTCP:
|
||||||
|
h.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||||
|
return h.client.DialConn(ctx, destination)
|
||||||
|
case N.NetworkUDP:
|
||||||
|
conn, err := h.ListenPacket(ctx, destination)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return bufio.NewBindPacketConn(conn, destination), nil
|
||||||
|
default:
|
||||||
|
return nil, E.New("unsupported network: ", network)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TUIC) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
|
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||||
|
return h.client.ListenPacket(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TUIC) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||||
|
return NewConnection(ctx, h, conn, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TUIC) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||||
|
return NewPacketConnection(ctx, h, conn, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TUIC) InterfaceUpdated() {
|
||||||
|
_ = h.client.CloseWithError(E.New("network changed"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TUIC) Close() error {
|
||||||
|
return h.client.CloseWithError(os.ErrClosed)
|
||||||
|
}
|
16
outbound/tuic_stub.go
Normal file
16
outbound/tuic_stub.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
//go:build !with_quic
|
||||||
|
|
||||||
|
package outbound
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICOutboundOptions) (adapter.Outbound, error) {
|
||||||
|
return nil, C.ErrQUICNotIncluded
|
||||||
|
}
|
|
@ -38,6 +38,8 @@ const (
|
||||||
ImageShadowsocksR = "teddysun/shadowsocks-r:latest"
|
ImageShadowsocksR = "teddysun/shadowsocks-r:latest"
|
||||||
ImageXRayCore = "teddysun/xray:latest"
|
ImageXRayCore = "teddysun/xray:latest"
|
||||||
ImageShadowsocksLegacy = "mritd/shadowsocks:latest"
|
ImageShadowsocksLegacy = "mritd/shadowsocks:latest"
|
||||||
|
ImageTUICServer = ""
|
||||||
|
ImageTUICClient = ""
|
||||||
)
|
)
|
||||||
|
|
||||||
var allImages = []string{
|
var allImages = []string{
|
||||||
|
@ -53,6 +55,8 @@ var allImages = []string{
|
||||||
ImageShadowsocksR,
|
ImageShadowsocksR,
|
||||||
ImageXRayCore,
|
ImageXRayCore,
|
||||||
ImageShadowsocksLegacy,
|
ImageShadowsocksLegacy,
|
||||||
|
// ImageTUICServer,
|
||||||
|
// ImageTUICClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
var localIP = netip.MustParseAddr("127.0.0.1")
|
var localIP = netip.MustParseAddr("127.0.0.1")
|
||||||
|
|
14
test/config/tuic-client.json
Normal file
14
test/config/tuic-client.json
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
{
|
||||||
|
"relay": {
|
||||||
|
"server": "127.0.0.1:10000",
|
||||||
|
"uuid": "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
|
||||||
|
"password": "tuic",
|
||||||
|
"certificates": [
|
||||||
|
"/etc/tuic/ca.pem"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"local": {
|
||||||
|
"server": "127.0.0.1:10001"
|
||||||
|
},
|
||||||
|
"log_level": "debug"
|
||||||
|
}
|
9
test/config/tuic-server.json
Normal file
9
test/config/tuic-server.json
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
{
|
||||||
|
"server": "[::]:10000",
|
||||||
|
"users": {
|
||||||
|
"FE35D05B-8803-45C4-BAE6-723AD2CD5D3D": "tuic"
|
||||||
|
},
|
||||||
|
"certificate": "/etc/tuic/cert.pem",
|
||||||
|
"private_key": "/etc/tuic/key.pem",
|
||||||
|
"log_level": "debug"
|
||||||
|
}
|
178
test/tuic_test.go
Normal file
178
test/tuic_test.go
Normal file
|
@ -0,0 +1,178 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
|
||||||
|
"github.com/gofrs/uuid/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTUICSelf(t *testing.T) {
|
||||||
|
t.Run("self", func(t *testing.T) {
|
||||||
|
testTUICSelf(t, false, false)
|
||||||
|
})
|
||||||
|
t.Run("self-udp-stream", func(t *testing.T) {
|
||||||
|
testTUICSelf(t, true, false)
|
||||||
|
})
|
||||||
|
t.Run("self-early", func(t *testing.T) {
|
||||||
|
testTUICSelf(t, false, true)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testTUICSelf(t *testing.T, udpStream bool, zeroRTTHandshake bool) {
|
||||||
|
_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
|
||||||
|
var udpRelayMode string
|
||||||
|
if udpStream {
|
||||||
|
udpRelayMode = "quic"
|
||||||
|
}
|
||||||
|
startInstance(t, option.Options{
|
||||||
|
Inbounds: []option.Inbound{
|
||||||
|
{
|
||||||
|
Type: C.TypeMixed,
|
||||||
|
Tag: "mixed-in",
|
||||||
|
MixedOptions: option.HTTPMixedInboundOptions{
|
||||||
|
ListenOptions: option.ListenOptions{
|
||||||
|
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
|
||||||
|
ListenPort: clientPort,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: C.TypeTUIC,
|
||||||
|
TUICOptions: option.TUICInboundOptions{
|
||||||
|
ListenOptions: option.ListenOptions{
|
||||||
|
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
|
||||||
|
ListenPort: serverPort,
|
||||||
|
},
|
||||||
|
Users: []option.TUICUser{{
|
||||||
|
UUID: uuid.Nil.String(),
|
||||||
|
}},
|
||||||
|
ZeroRTTHandshake: zeroRTTHandshake,
|
||||||
|
TLS: &option.InboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "example.org",
|
||||||
|
CertificatePath: certPem,
|
||||||
|
KeyPath: keyPem,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Outbounds: []option.Outbound{
|
||||||
|
{
|
||||||
|
Type: C.TypeDirect,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: C.TypeTUIC,
|
||||||
|
Tag: "tuic-out",
|
||||||
|
TUICOptions: option.TUICOutboundOptions{
|
||||||
|
ServerOptions: option.ServerOptions{
|
||||||
|
Server: "127.0.0.1",
|
||||||
|
ServerPort: serverPort,
|
||||||
|
},
|
||||||
|
UUID: uuid.Nil.String(),
|
||||||
|
UDPRelayMode: udpRelayMode,
|
||||||
|
ZeroRTTHandshake: zeroRTTHandshake,
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "example.org",
|
||||||
|
CertificatePath: certPem,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Route: &option.RouteOptions{
|
||||||
|
Rules: []option.Rule{
|
||||||
|
{
|
||||||
|
DefaultOptions: option.DefaultRule{
|
||||||
|
Inbound: []string{"mixed-in"},
|
||||||
|
Outbound: "tuic-out",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
testSuit(t, clientPort, testPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTUICInbound(t *testing.T) {
|
||||||
|
caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
|
||||||
|
startInstance(t, option.Options{
|
||||||
|
Inbounds: []option.Inbound{
|
||||||
|
{
|
||||||
|
Type: C.TypeTUIC,
|
||||||
|
TUICOptions: option.TUICInboundOptions{
|
||||||
|
ListenOptions: option.ListenOptions{
|
||||||
|
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
|
||||||
|
ListenPort: serverPort,
|
||||||
|
},
|
||||||
|
Users: []option.TUICUser{{
|
||||||
|
UUID: "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
|
||||||
|
Password: "tuic",
|
||||||
|
}},
|
||||||
|
TLS: &option.InboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "example.org",
|
||||||
|
CertificatePath: certPem,
|
||||||
|
KeyPath: keyPem,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
startDockerContainer(t, DockerOptions{
|
||||||
|
Image: ImageTUICClient,
|
||||||
|
Ports: []uint16{serverPort, clientPort},
|
||||||
|
Bind: map[string]string{
|
||||||
|
"tuic-client.json": "/etc/tuic/config.json",
|
||||||
|
caPem: "/etc/tuic/ca.pem",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTUICOutbound(t *testing.T) {
|
||||||
|
_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
|
||||||
|
startDockerContainer(t, DockerOptions{
|
||||||
|
Image: ImageTUICServer,
|
||||||
|
Ports: []uint16{testPort},
|
||||||
|
Bind: map[string]string{
|
||||||
|
"tuic-server.json": "/etc/tuic/config.json",
|
||||||
|
certPem: "/etc/tuic/cert.pem",
|
||||||
|
keyPem: "/etc/tuic/key.pem",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
startInstance(t, option.Options{
|
||||||
|
Inbounds: []option.Inbound{
|
||||||
|
{
|
||||||
|
Type: C.TypeMixed,
|
||||||
|
MixedOptions: option.HTTPMixedInboundOptions{
|
||||||
|
ListenOptions: option.ListenOptions{
|
||||||
|
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
|
||||||
|
ListenPort: clientPort,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Outbounds: []option.Outbound{
|
||||||
|
{
|
||||||
|
Type: C.TypeTUIC,
|
||||||
|
TUICOptions: option.TUICOutboundOptions{
|
||||||
|
ServerOptions: option.ServerOptions{
|
||||||
|
Server: "127.0.0.1",
|
||||||
|
ServerPort: serverPort,
|
||||||
|
},
|
||||||
|
UUID: "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
|
||||||
|
Password: "tuic",
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "example.org",
|
||||||
|
CertificatePath: certPem,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
testSuit(t, clientPort, testPort)
|
||||||
|
}
|
10
transport/tuic/address.go
Normal file
10
transport/tuic/address.go
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
package tuic
|
||||||
|
|
||||||
|
import M "github.com/sagernet/sing/common/metadata"
|
||||||
|
|
||||||
|
var addressSerializer = M.NewSerializer(
|
||||||
|
M.AddressFamilyByte(0x00, M.AddressFamilyFqdn),
|
||||||
|
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
|
||||||
|
M.AddressFamilyByte(0x02, M.AddressFamilyIPv6),
|
||||||
|
M.AddressFamilyByte(0xff, M.AddressFamilyEmpty),
|
||||||
|
)
|
322
transport/tuic/client.go
Normal file
322
transport/tuic/client.go
Normal file
|
@ -0,0 +1,322 @@
|
||||||
|
package tuic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
|
"github.com/sagernet/sing-box/common/baderror"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
|
"github.com/gofrs/uuid/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClientOptions struct {
|
||||||
|
Context context.Context
|
||||||
|
Dialer N.Dialer
|
||||||
|
ServerAddress M.Socksaddr
|
||||||
|
TLSConfig *tls.Config
|
||||||
|
UUID uuid.UUID
|
||||||
|
Password string
|
||||||
|
CongestionControl string
|
||||||
|
UDPStream bool
|
||||||
|
ZeroRTTHandshake bool
|
||||||
|
Heartbeat time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
ctx context.Context
|
||||||
|
dialer N.Dialer
|
||||||
|
serverAddr M.Socksaddr
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
quicConfig *quic.Config
|
||||||
|
uuid uuid.UUID
|
||||||
|
password string
|
||||||
|
congestionControl string
|
||||||
|
udpStream bool
|
||||||
|
zeroRTTHandshake bool
|
||||||
|
heartbeat time.Duration
|
||||||
|
|
||||||
|
connAccess sync.RWMutex
|
||||||
|
conn *clientQUICConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(options ClientOptions) (*Client, error) {
|
||||||
|
if options.Heartbeat == 0 {
|
||||||
|
options.Heartbeat = 10 * time.Second
|
||||||
|
}
|
||||||
|
quicConfig := &quic.Config{
|
||||||
|
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
|
||||||
|
MaxDatagramFrameSize: 1400,
|
||||||
|
EnableDatagrams: true,
|
||||||
|
MaxIncomingUniStreams: 1 << 60,
|
||||||
|
}
|
||||||
|
switch options.CongestionControl {
|
||||||
|
case "":
|
||||||
|
options.CongestionControl = "cubic"
|
||||||
|
case "cubic", "new_reno", "bbr":
|
||||||
|
default:
|
||||||
|
return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
|
||||||
|
}
|
||||||
|
return &Client{
|
||||||
|
ctx: options.Context,
|
||||||
|
dialer: options.Dialer,
|
||||||
|
serverAddr: options.ServerAddress,
|
||||||
|
tlsConfig: options.TLSConfig,
|
||||||
|
quicConfig: quicConfig,
|
||||||
|
uuid: options.UUID,
|
||||||
|
password: options.Password,
|
||||||
|
congestionControl: options.CongestionControl,
|
||||||
|
udpStream: options.UDPStream,
|
||||||
|
zeroRTTHandshake: options.ZeroRTTHandshake,
|
||||||
|
heartbeat: options.Heartbeat,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
|
||||||
|
conn := c.conn
|
||||||
|
if conn != nil && conn.active() {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
c.connAccess.Lock()
|
||||||
|
defer c.connAccess.Unlock()
|
||||||
|
conn = c.conn
|
||||||
|
if conn != nil && conn.active() {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
conn, err := c.offerNew(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
|
||||||
|
udpConn, err := c.dialer.DialContext(ctx, "udp", c.serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var quicConn quic.Connection
|
||||||
|
if c.zeroRTTHandshake {
|
||||||
|
quicConn, err = quic.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
|
||||||
|
} else {
|
||||||
|
quicConn, err = quic.Dial(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
udpConn.Close()
|
||||||
|
return nil, E.Cause(err, "open connection")
|
||||||
|
}
|
||||||
|
setCongestion(c.ctx, quicConn, c.congestionControl)
|
||||||
|
conn := &clientQUICConnection{
|
||||||
|
quicConn: quicConn,
|
||||||
|
rawConn: udpConn,
|
||||||
|
connDone: make(chan struct{}),
|
||||||
|
udpConnMap: make(map[uint16]*udpPacketConn),
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
hErr := c.clientHandshake(quicConn)
|
||||||
|
if hErr != nil {
|
||||||
|
conn.closeWithError(hErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if c.udpStream {
|
||||||
|
go c.loopUniStreams(conn)
|
||||||
|
}
|
||||||
|
go c.loopMessages(conn)
|
||||||
|
go c.loopHeartbeats(conn)
|
||||||
|
c.conn = conn
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) clientHandshake(conn quic.Connection) error {
|
||||||
|
authStream, err := conn.OpenUniStream()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer authStream.Close()
|
||||||
|
handshakeState := conn.ConnectionState().TLS
|
||||||
|
tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
authRequest := buf.NewSize(AuthenticateLen)
|
||||||
|
authRequest.WriteByte(Version)
|
||||||
|
authRequest.WriteByte(CommandAuthenticate)
|
||||||
|
authRequest.Write(c.uuid[:])
|
||||||
|
authRequest.Write(tuicAuthToken)
|
||||||
|
return common.Error(authStream.Write(authRequest.Bytes()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) loopHeartbeats(conn *clientQUICConnection) {
|
||||||
|
ticker := time.NewTicker(c.heartbeat)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-conn.connDone:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
err := conn.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
|
||||||
|
if err != nil {
|
||||||
|
conn.closeWithError(E.Cause(err, "send heartbeat"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
|
||||||
|
conn, err := c.offer(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stream, err := conn.quicConn.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &clientConn{
|
||||||
|
parent: conn,
|
||||||
|
stream: stream,
|
||||||
|
destination: destination,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
|
||||||
|
conn, err := c.offer(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var sessionID uint16
|
||||||
|
clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() {
|
||||||
|
conn.udpAccess.Lock()
|
||||||
|
delete(conn.udpConnMap, sessionID)
|
||||||
|
conn.udpAccess.Unlock()
|
||||||
|
})
|
||||||
|
conn.udpAccess.Lock()
|
||||||
|
sessionID = conn.udpSessionID
|
||||||
|
conn.udpSessionID++
|
||||||
|
conn.udpConnMap[sessionID] = clientPacketConn
|
||||||
|
conn.udpAccess.Unlock()
|
||||||
|
clientPacketConn.sessionID = sessionID
|
||||||
|
return clientPacketConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) CloseWithError(err error) error {
|
||||||
|
conn := c.conn
|
||||||
|
if conn != nil {
|
||||||
|
conn.closeWithError(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientQUICConnection struct {
|
||||||
|
quicConn quic.Connection
|
||||||
|
rawConn io.Closer
|
||||||
|
closeOnce sync.Once
|
||||||
|
connDone chan struct{}
|
||||||
|
connErr error
|
||||||
|
udpAccess sync.RWMutex
|
||||||
|
udpConnMap map[uint16]*udpPacketConn
|
||||||
|
udpSessionID uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientQUICConnection) active() bool {
|
||||||
|
select {
|
||||||
|
case <-c.quicConn.Context().Done():
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-c.connDone:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientQUICConnection) closeWithError(err error) {
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.connErr = err
|
||||||
|
close(c.connDone)
|
||||||
|
_ = c.quicConn.CloseWithError(0, "")
|
||||||
|
_ = c.rawConn.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientConn struct {
|
||||||
|
parent *clientQUICConnection
|
||||||
|
stream quic.Stream
|
||||||
|
destination M.Socksaddr
|
||||||
|
requestWritten bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConn) Read(b []byte) (n int, err error) {
|
||||||
|
n, err = c.stream.Read(b)
|
||||||
|
return n, baderror.WrapQUIC(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConn) Write(b []byte) (n int, err error) {
|
||||||
|
if !c.requestWritten {
|
||||||
|
request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b))
|
||||||
|
request.WriteByte(Version)
|
||||||
|
request.WriteByte(CommandConnect)
|
||||||
|
addressSerializer.WriteAddrPort(request, c.destination)
|
||||||
|
request.Write(b)
|
||||||
|
_, err = c.stream.Write(request.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
c.parent.closeWithError(E.Cause(err, "create new connection"))
|
||||||
|
return 0, baderror.WrapQUIC(err)
|
||||||
|
}
|
||||||
|
c.requestWritten = true
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
n, err = c.stream.Write(b)
|
||||||
|
return n, baderror.WrapQUIC(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConn) Close() error {
|
||||||
|
stream := c.stream
|
||||||
|
if stream == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
stream.CancelRead(0)
|
||||||
|
return stream.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConn) LocalAddr() net.Addr {
|
||||||
|
return M.Socksaddr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConn) RemoteAddr() net.Addr {
|
||||||
|
return c.destination
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConn) SetDeadline(t time.Time) error {
|
||||||
|
if c.stream == nil {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
return c.stream.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConn) SetReadDeadline(t time.Time) error {
|
||||||
|
if c.stream == nil {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
return c.stream.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
if c.stream == nil {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
return c.stream.SetWriteDeadline(t)
|
||||||
|
}
|
110
transport/tuic/client_packet.go
Normal file
110
transport/tuic/client_packet.go
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
package tuic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *Client) loopMessages(conn *clientQUICConnection) {
|
||||||
|
for {
|
||||||
|
message, err := conn.quicConn.ReceiveMessage(c.ctx)
|
||||||
|
if err != nil {
|
||||||
|
conn.closeWithError(E.Cause(err, "receive message"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
hErr := c.handleMessage(conn, message)
|
||||||
|
if hErr != nil {
|
||||||
|
conn.closeWithError(E.Cause(hErr, "handle message"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
|
||||||
|
if len(data) < 2 {
|
||||||
|
return E.New("invalid message")
|
||||||
|
}
|
||||||
|
if data[0] != Version {
|
||||||
|
return E.New("unknown version ", data[0])
|
||||||
|
}
|
||||||
|
switch data[1] {
|
||||||
|
case CommandPacket:
|
||||||
|
message := udpMessagePool.Get().(*udpMessage)
|
||||||
|
err := decodeUDPMessage(message, data[2:])
|
||||||
|
if err != nil {
|
||||||
|
message.release()
|
||||||
|
return E.Cause(err, "decode UDP message")
|
||||||
|
}
|
||||||
|
conn.handleUDPMessage(message)
|
||||||
|
return nil
|
||||||
|
case CommandHeartbeat:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return E.New("unknown command ", data[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) loopUniStreams(conn *clientQUICConnection) {
|
||||||
|
for {
|
||||||
|
stream, err := conn.quicConn.AcceptUniStream(c.ctx)
|
||||||
|
if err != nil {
|
||||||
|
conn.closeWithError(E.Cause(err, "handle uni stream"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
hErr := c.handleUniStream(conn, stream)
|
||||||
|
if hErr != nil {
|
||||||
|
conn.closeWithError(hErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.ReceiveStream) error {
|
||||||
|
defer stream.CancelRead(0)
|
||||||
|
buffer := buf.NewPacket()
|
||||||
|
defer buffer.Release()
|
||||||
|
_, err := buffer.ReadAtLeastFrom(stream, 2)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
version, _ := buffer.ReadByte()
|
||||||
|
if version != Version {
|
||||||
|
return E.New("unknown version ", version)
|
||||||
|
}
|
||||||
|
command, _ := buffer.ReadByte()
|
||||||
|
if command != CommandPacket {
|
||||||
|
return E.New("unknown command ", command)
|
||||||
|
}
|
||||||
|
reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream)
|
||||||
|
message := udpMessagePool.Get().(*udpMessage)
|
||||||
|
err = readUDPMessage(message, reader)
|
||||||
|
if err != nil {
|
||||||
|
message.release()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn.handleUDPMessage(message)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) {
|
||||||
|
c.udpAccess.RLock()
|
||||||
|
udpConn, loaded := c.udpConnMap[message.sessionID]
|
||||||
|
c.udpAccess.RUnlock()
|
||||||
|
if !loaded {
|
||||||
|
message.releaseMessage()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-udpConn.ctx.Done():
|
||||||
|
message.releaseMessage()
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
udpConn.inputPacket(message)
|
||||||
|
}
|
46
transport/tuic/congestion.go
Normal file
46
transport/tuic/congestion.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package tuic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
|
"github.com/sagernet/sing-box/transport/tuic/congestion"
|
||||||
|
"github.com/sagernet/sing/common/ntp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setCongestion(ctx context.Context, connection quic.Connection, congestionName string) {
|
||||||
|
timeFunc := ntp.TimeFuncFromContext(ctx)
|
||||||
|
if timeFunc == nil {
|
||||||
|
timeFunc = time.Now
|
||||||
|
}
|
||||||
|
switch congestionName {
|
||||||
|
case "cubic":
|
||||||
|
connection.SetCongestionControl(
|
||||||
|
congestion.NewCubicSender(
|
||||||
|
congestion.DefaultClock{TimeFunc: timeFunc},
|
||||||
|
congestion.GetInitialPacketSize(connection.RemoteAddr()),
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
case "new_reno":
|
||||||
|
connection.SetCongestionControl(
|
||||||
|
congestion.NewCubicSender(
|
||||||
|
congestion.DefaultClock{TimeFunc: timeFunc},
|
||||||
|
congestion.GetInitialPacketSize(connection.RemoteAddr()),
|
||||||
|
true,
|
||||||
|
nil,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
case "bbr":
|
||||||
|
connection.SetCongestionControl(
|
||||||
|
congestion.NewBBRSender(
|
||||||
|
congestion.DefaultClock{},
|
||||||
|
congestion.GetInitialPacketSize(connection.RemoteAddr()),
|
||||||
|
congestion.InitialCongestionWindow*congestion.InitialMaxDatagramSize,
|
||||||
|
congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
3
transport/tuic/congestion/README.md
Normal file
3
transport/tuic/congestion/README.md
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
# congestion
|
||||||
|
|
||||||
|
mod from https://github.com/MetaCubeX/Clash.Meta/tree/53f9e1ee7104473da2b4ff5da29965563084482d/transport/tuic/congestion
|
25
transport/tuic/congestion/bandwidth.go
Normal file
25
transport/tuic/congestion/bandwidth.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go/congestion"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Bandwidth of a connection
|
||||||
|
type Bandwidth uint64
|
||||||
|
|
||||||
|
const infBandwidth Bandwidth = math.MaxUint64
|
||||||
|
|
||||||
|
const (
|
||||||
|
// BitsPerSecond is 1 bit per second
|
||||||
|
BitsPerSecond Bandwidth = 1
|
||||||
|
// BytesPerSecond is 1 byte per second
|
||||||
|
BytesPerSecond = 8 * BitsPerSecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
|
||||||
|
func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth {
|
||||||
|
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
|
||||||
|
}
|
374
transport/tuic/congestion/bandwidth_sampler.go
Normal file
374
transport/tuic/congestion/bandwidth_sampler.go
Normal file
|
@ -0,0 +1,374 @@
|
||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go/congestion"
|
||||||
|
)
|
||||||
|
|
||||||
|
var InfiniteBandwidth = Bandwidth(math.MaxUint64)
|
||||||
|
|
||||||
|
// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned
|
||||||
|
// to the caller when the packet is acked or lost.
|
||||||
|
type SendTimeState struct {
|
||||||
|
// Whether other states in this object is valid.
|
||||||
|
isValid bool
|
||||||
|
// Whether the sender is app limited at the time the packet was sent.
|
||||||
|
// App limited bandwidth sample might be artificially low because the sender
|
||||||
|
// did not have enough data to send in order to saturate the link.
|
||||||
|
isAppLimited bool
|
||||||
|
// Total number of sent bytes at the time the packet was sent.
|
||||||
|
// Includes the packet itself.
|
||||||
|
totalBytesSent congestion.ByteCount
|
||||||
|
// Total number of acked bytes at the time the packet was sent.
|
||||||
|
totalBytesAcked congestion.ByteCount
|
||||||
|
// Total number of lost bytes at the time the packet was sent.
|
||||||
|
totalBytesLost congestion.ByteCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionStateOnSentPacket represents the information about a sent packet
|
||||||
|
// and the state of the connection at the moment the packet was sent,
|
||||||
|
// specifically the information about the most recently acknowledged packet at
|
||||||
|
// that moment.
|
||||||
|
type ConnectionStateOnSentPacket struct {
|
||||||
|
packetNumber congestion.PacketNumber
|
||||||
|
// Time at which the packet is sent.
|
||||||
|
sendTime time.Time
|
||||||
|
// Size of the packet.
|
||||||
|
size congestion.ByteCount
|
||||||
|
// The value of |totalBytesSentAtLastAckedPacket| at the time the
|
||||||
|
// packet was sent.
|
||||||
|
totalBytesSentAtLastAckedPacket congestion.ByteCount
|
||||||
|
// The value of |lastAckedPacketSentTime| at the time the packet was
|
||||||
|
// sent.
|
||||||
|
lastAckedPacketSentTime time.Time
|
||||||
|
// The value of |lastAckedPacketAckTime| at the time the packet was
|
||||||
|
// sent.
|
||||||
|
lastAckedPacketAckTime time.Time
|
||||||
|
// Send time states that are returned to the congestion controller when the
|
||||||
|
// packet is acked or lost.
|
||||||
|
sendTimeState SendTimeState
|
||||||
|
}
|
||||||
|
|
||||||
|
// BandwidthSample
|
||||||
|
type BandwidthSample struct {
|
||||||
|
// The bandwidth at that particular sample. Zero if no valid bandwidth sample
|
||||||
|
// is available.
|
||||||
|
bandwidth Bandwidth
|
||||||
|
// The RTT measurement at this particular sample. Zero if no RTT sample is
|
||||||
|
// available. Does not correct for delayed ack time.
|
||||||
|
rtt time.Duration
|
||||||
|
// States captured when the packet was sent.
|
||||||
|
stateAtSend SendTimeState
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBandwidthSample() *BandwidthSample {
|
||||||
|
return &BandwidthSample{
|
||||||
|
// FIXME: the default value of original code is zero.
|
||||||
|
rtt: InfiniteRTT,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BandwidthSampler keeps track of sent and acknowledged packets and outputs a
|
||||||
|
// bandwidth sample for every packet acknowledged. The samples are taken for
|
||||||
|
// individual packets, and are not filtered; the consumer has to filter the
|
||||||
|
// bandwidth samples itself. In certain cases, the sampler will locally severely
|
||||||
|
// underestimate the bandwidth, hence a maximum filter with a size of at least
|
||||||
|
// one RTT is recommended.
|
||||||
|
//
|
||||||
|
// This class bases its samples on the slope of two curves: the number of bytes
|
||||||
|
// sent over time, and the number of bytes acknowledged as received over time.
|
||||||
|
// It produces a sample of both slopes for every packet that gets acknowledged,
|
||||||
|
// based on a slope between two points on each of the corresponding curves. Note
|
||||||
|
// that due to the packet loss, the number of bytes on each curve might get
|
||||||
|
// further and further away from each other, meaning that it is not feasible to
|
||||||
|
// compare byte values coming from different curves with each other.
|
||||||
|
//
|
||||||
|
// The obvious points for measuring slope sample are the ones corresponding to
|
||||||
|
// the packet that was just acknowledged. Let us denote them as S_1 (point at
|
||||||
|
// which the current packet was sent) and A_1 (point at which the current packet
|
||||||
|
// was acknowledged). However, taking a slope requires two points on each line,
|
||||||
|
// so estimating bandwidth requires picking a packet in the past with respect to
|
||||||
|
// which the slope is measured.
|
||||||
|
//
|
||||||
|
// For that purpose, BandwidthSampler always keeps track of the most recently
|
||||||
|
// acknowledged packet, and records it together with every outgoing packet.
|
||||||
|
// When a packet gets acknowledged (A_1), it has not only information about when
|
||||||
|
// it itself was sent (S_1), but also the information about the latest
|
||||||
|
// acknowledged packet right before it was sent (S_0 and A_0).
|
||||||
|
//
|
||||||
|
// Based on that data, send and ack rate are estimated as:
|
||||||
|
//
|
||||||
|
// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0))
|
||||||
|
// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0))
|
||||||
|
//
|
||||||
|
// Here, the ack rate is intuitively the rate we want to treat as bandwidth.
|
||||||
|
// However, in certain cases (e.g. ack compression) the ack rate at a point may
|
||||||
|
// end up higher than the rate at which the data was originally sent, which is
|
||||||
|
// not indicative of the real bandwidth. Hence, we use the send rate as an upper
|
||||||
|
// bound, and the sample value is
|
||||||
|
//
|
||||||
|
// rate_sample = min(send_rate, ack_rate)
|
||||||
|
//
|
||||||
|
// An important edge case handled by the sampler is tracking the app-limited
|
||||||
|
// samples. There are multiple meaning of "app-limited" used interchangeably,
|
||||||
|
// hence it is important to understand and to be able to distinguish between
|
||||||
|
// them.
|
||||||
|
//
|
||||||
|
// Meaning 1: connection state. The connection is said to be app-limited when
|
||||||
|
// there is no outstanding data to send. This means that certain bandwidth
|
||||||
|
// samples in the future would not be an accurate indication of the link
|
||||||
|
// capacity, and it is important to inform consumer about that. Whenever
|
||||||
|
// connection becomes app-limited, the sampler is notified via OnAppLimited()
|
||||||
|
// method.
|
||||||
|
//
|
||||||
|
// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth
|
||||||
|
// sampler becomes notified about the connection being app-limited, it enters
|
||||||
|
// app-limited phase. In that phase, all *sent* packets are marked as
|
||||||
|
// app-limited. Note that the connection itself does not have to be
|
||||||
|
// app-limited during the app-limited phase, and in fact it will not be
|
||||||
|
// (otherwise how would it send packets?). The boolean flag below indicates
|
||||||
|
// whether the sampler is in that phase.
|
||||||
|
//
|
||||||
|
// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is
|
||||||
|
// sent during the app-limited phase, the resulting sample related to the
|
||||||
|
// packet will be marked as app-limited.
|
||||||
|
//
|
||||||
|
// With the terminology issue out of the way, let us consider the question of
|
||||||
|
// what kind of situation it addresses.
|
||||||
|
//
|
||||||
|
// Consider a scenario where we first send packets 1 to 20 at a regular
|
||||||
|
// bandwidth, and then immediately run out of data. After a few seconds, we send
|
||||||
|
// packets 21 to 60, and only receive ack for 21 between sending packets 40 and
|
||||||
|
// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0
|
||||||
|
// we use to compute the slope is going to be packet 20, a few seconds apart
|
||||||
|
// from the current packet, hence the resulting estimate would be extremely low
|
||||||
|
// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21,
|
||||||
|
// meaning that the bandwidth sample would exclude the quiescence.
|
||||||
|
//
|
||||||
|
// Based on the analysis of that scenario, we implement the following rule: once
|
||||||
|
// OnAppLimited() is called, all sent packets will produce app-limited samples
|
||||||
|
// up until an ack for a packet that was sent after OnAppLimited() was called.
|
||||||
|
// Note that while the scenario above is not the only scenario when the
|
||||||
|
// connection is app-limited, the approach works in other cases too.
|
||||||
|
type BandwidthSampler struct {
|
||||||
|
// The total number of congestion controlled bytes sent during the connection.
|
||||||
|
totalBytesSent congestion.ByteCount
|
||||||
|
// The total number of congestion controlled bytes which were acknowledged.
|
||||||
|
totalBytesAcked congestion.ByteCount
|
||||||
|
// The total number of congestion controlled bytes which were lost.
|
||||||
|
totalBytesLost congestion.ByteCount
|
||||||
|
// The value of |totalBytesSent| at the time the last acknowledged packet
|
||||||
|
// was sent. Valid only when |lastAckedPacketSentTime| is valid.
|
||||||
|
totalBytesSentAtLastAckedPacket congestion.ByteCount
|
||||||
|
// The time at which the last acknowledged packet was sent. Set to
|
||||||
|
// QuicTime::Zero() if no valid timestamp is available.
|
||||||
|
lastAckedPacketSentTime time.Time
|
||||||
|
// The time at which the most recent packet was acknowledged.
|
||||||
|
lastAckedPacketAckTime time.Time
|
||||||
|
// The most recently sent packet.
|
||||||
|
lastSendPacket congestion.PacketNumber
|
||||||
|
// Indicates whether the bandwidth sampler is currently in an app-limited
|
||||||
|
// phase.
|
||||||
|
isAppLimited bool
|
||||||
|
// The packet that will be acknowledged after this one will cause the sampler
|
||||||
|
// to exit the app-limited phase.
|
||||||
|
endOfAppLimitedPhase congestion.PacketNumber
|
||||||
|
// Record of the connection state at the point where each packet in flight was
|
||||||
|
// sent, indexed by the packet number.
|
||||||
|
connectionStats *ConnectionStates
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBandwidthSampler() *BandwidthSampler {
|
||||||
|
return &BandwidthSampler{
|
||||||
|
connectionStats: &ConnectionStates{
|
||||||
|
stats: make(map[congestion.PacketNumber]*ConnectionStateOnSentPacket),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPacketSent Inputs the sent packet information into the sampler. Assumes that all
|
||||||
|
// packets are sent in order. The information about the packet will not be
|
||||||
|
// released from the sampler until it the packet is either acknowledged or
|
||||||
|
// declared lost.
|
||||||
|
func (s *BandwidthSampler) OnPacketSent(sentTime time.Time, lastSentPacket congestion.PacketNumber, sentBytes, bytesInFlight congestion.ByteCount, hasRetransmittableData bool) {
|
||||||
|
s.lastSendPacket = lastSentPacket
|
||||||
|
|
||||||
|
if !hasRetransmittableData {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.totalBytesSent += sentBytes
|
||||||
|
|
||||||
|
// If there are no packets in flight, the time at which the new transmission
|
||||||
|
// opens can be treated as the A_0 point for the purpose of bandwidth
|
||||||
|
// sampling. This underestimates bandwidth to some extent, and produces some
|
||||||
|
// artificially low samples for most packets in flight, but it provides with
|
||||||
|
// samples at important points where we would not have them otherwise, most
|
||||||
|
// importantly at the beginning of the connection.
|
||||||
|
if bytesInFlight == 0 {
|
||||||
|
s.lastAckedPacketAckTime = sentTime
|
||||||
|
s.totalBytesSentAtLastAckedPacket = s.totalBytesSent
|
||||||
|
|
||||||
|
// In this situation ack compression is not a concern, set send rate to
|
||||||
|
// effectively infinite.
|
||||||
|
s.lastAckedPacketSentTime = sentTime
|
||||||
|
}
|
||||||
|
|
||||||
|
s.connectionStats.Insert(lastSentPacket, sentTime, sentBytes, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPacketAcked Notifies the sampler that the |lastAckedPacket| is acknowledged. Returns a
|
||||||
|
// bandwidth sample. If no bandwidth sample is available,
|
||||||
|
// QuicBandwidth::Zero() is returned.
|
||||||
|
func (s *BandwidthSampler) OnPacketAcked(ackTime time.Time, lastAckedPacket congestion.PacketNumber) *BandwidthSample {
|
||||||
|
sentPacketState := s.connectionStats.Get(lastAckedPacket)
|
||||||
|
if sentPacketState == nil {
|
||||||
|
return NewBandwidthSample()
|
||||||
|
}
|
||||||
|
|
||||||
|
sample := s.onPacketAckedInner(ackTime, lastAckedPacket, sentPacketState)
|
||||||
|
s.connectionStats.Remove(lastAckedPacket)
|
||||||
|
|
||||||
|
return sample
|
||||||
|
}
|
||||||
|
|
||||||
|
// onPacketAckedInner Handles the actual bandwidth calculations, whereas the outer method handles
|
||||||
|
// retrieving and removing |sentPacket|.
|
||||||
|
func (s *BandwidthSampler) onPacketAckedInner(ackTime time.Time, lastAckedPacket congestion.PacketNumber, sentPacket *ConnectionStateOnSentPacket) *BandwidthSample {
|
||||||
|
s.totalBytesAcked += sentPacket.size
|
||||||
|
|
||||||
|
s.totalBytesSentAtLastAckedPacket = sentPacket.sendTimeState.totalBytesSent
|
||||||
|
s.lastAckedPacketSentTime = sentPacket.sendTime
|
||||||
|
s.lastAckedPacketAckTime = ackTime
|
||||||
|
|
||||||
|
// Exit app-limited phase once a packet that was sent while the connection is
|
||||||
|
// not app-limited is acknowledged.
|
||||||
|
if s.isAppLimited && lastAckedPacket > s.endOfAppLimitedPhase {
|
||||||
|
s.isAppLimited = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// There might have been no packets acknowledged at the moment when the
|
||||||
|
// current packet was sent. In that case, there is no bandwidth sample to
|
||||||
|
// make.
|
||||||
|
if sentPacket.lastAckedPacketSentTime.IsZero() {
|
||||||
|
return NewBandwidthSample()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infinite rate indicates that the sampler is supposed to discard the
|
||||||
|
// current send rate sample and use only the ack rate.
|
||||||
|
sendRate := InfiniteBandwidth
|
||||||
|
if sentPacket.sendTime.After(sentPacket.lastAckedPacketSentTime) {
|
||||||
|
sendRate = BandwidthFromDelta(sentPacket.sendTimeState.totalBytesSent-sentPacket.totalBytesSentAtLastAckedPacket, sentPacket.sendTime.Sub(sentPacket.lastAckedPacketSentTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
// During the slope calculation, ensure that ack time of the current packet is
|
||||||
|
// always larger than the time of the previous packet, otherwise division by
|
||||||
|
// zero or integer underflow can occur.
|
||||||
|
if !ackTime.After(sentPacket.lastAckedPacketAckTime) {
|
||||||
|
// TODO(wub): Compare this code count before and after fixing clock jitter
|
||||||
|
// issue.
|
||||||
|
// if sentPacket.lastAckedPacketAckTime.Equal(sentPacket.sendTime) {
|
||||||
|
// This is the 1st packet after quiescense.
|
||||||
|
// QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 1, 2);
|
||||||
|
// } else {
|
||||||
|
// QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 2, 2);
|
||||||
|
// }
|
||||||
|
|
||||||
|
return NewBandwidthSample()
|
||||||
|
}
|
||||||
|
|
||||||
|
ackRate := BandwidthFromDelta(s.totalBytesAcked-sentPacket.sendTimeState.totalBytesAcked,
|
||||||
|
ackTime.Sub(sentPacket.lastAckedPacketAckTime))
|
||||||
|
|
||||||
|
// Note: this sample does not account for delayed acknowledgement time. This
|
||||||
|
// means that the RTT measurements here can be artificially high, especially
|
||||||
|
// on low bandwidth connections.
|
||||||
|
sample := &BandwidthSample{
|
||||||
|
bandwidth: minBandwidth(sendRate, ackRate),
|
||||||
|
rtt: ackTime.Sub(sentPacket.sendTime),
|
||||||
|
}
|
||||||
|
|
||||||
|
SentPacketToSendTimeState(sentPacket, &sample.stateAtSend)
|
||||||
|
return sample
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPacketLost Informs the sampler that a packet is considered lost and it should no
|
||||||
|
// longer keep track of it.
|
||||||
|
func (s *BandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber) SendTimeState {
|
||||||
|
ok, sentPacket := s.connectionStats.Remove(packetNumber)
|
||||||
|
sendTimeState := SendTimeState{
|
||||||
|
isValid: ok,
|
||||||
|
}
|
||||||
|
if sentPacket != nil {
|
||||||
|
s.totalBytesLost += sentPacket.size
|
||||||
|
SentPacketToSendTimeState(sentPacket, &sendTimeState)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sendTimeState
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnAppLimited Informs the sampler that the connection is currently app-limited, causing
|
||||||
|
// the sampler to enter the app-limited phase. The phase will expire by
|
||||||
|
// itself.
|
||||||
|
func (s *BandwidthSampler) OnAppLimited() {
|
||||||
|
s.isAppLimited = true
|
||||||
|
s.endOfAppLimitedPhase = s.lastSendPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
// SentPacketToSendTimeState Copy a subset of the (private) ConnectionStateOnSentPacket to the (public)
|
||||||
|
// SendTimeState. Always set send_time_state->is_valid to true.
|
||||||
|
func SentPacketToSendTimeState(sentPacket *ConnectionStateOnSentPacket, sendTimeState *SendTimeState) {
|
||||||
|
sendTimeState.isAppLimited = sentPacket.sendTimeState.isAppLimited
|
||||||
|
sendTimeState.totalBytesSent = sentPacket.sendTimeState.totalBytesSent
|
||||||
|
sendTimeState.totalBytesAcked = sentPacket.sendTimeState.totalBytesAcked
|
||||||
|
sendTimeState.totalBytesLost = sentPacket.sendTimeState.totalBytesLost
|
||||||
|
sendTimeState.isValid = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionStates Record of the connection state at the point where each packet in flight was
|
||||||
|
// sent, indexed by the packet number.
|
||||||
|
// FIXME: using LinkedList replace map to fast remove all the packets lower than the specified packet number.
|
||||||
|
type ConnectionStates struct {
|
||||||
|
stats map[congestion.PacketNumber]*ConnectionStateOnSentPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConnectionStates) Insert(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) bool {
|
||||||
|
if _, ok := s.stats[packetNumber]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s.stats[packetNumber] = NewConnectionStateOnSentPacket(packetNumber, sentTime, bytes, sampler)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConnectionStates) Get(packetNumber congestion.PacketNumber) *ConnectionStateOnSentPacket {
|
||||||
|
return s.stats[packetNumber]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConnectionStates) Remove(packetNumber congestion.PacketNumber) (bool, *ConnectionStateOnSentPacket) {
|
||||||
|
state, ok := s.stats[packetNumber]
|
||||||
|
if ok {
|
||||||
|
delete(s.stats, packetNumber)
|
||||||
|
}
|
||||||
|
return ok, state
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnectionStateOnSentPacket(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) *ConnectionStateOnSentPacket {
|
||||||
|
return &ConnectionStateOnSentPacket{
|
||||||
|
packetNumber: packetNumber,
|
||||||
|
sendTime: sentTime,
|
||||||
|
size: bytes,
|
||||||
|
lastAckedPacketSentTime: sampler.lastAckedPacketSentTime,
|
||||||
|
lastAckedPacketAckTime: sampler.lastAckedPacketAckTime,
|
||||||
|
totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket,
|
||||||
|
sendTimeState: SendTimeState{
|
||||||
|
isValid: true,
|
||||||
|
isAppLimited: sampler.isAppLimited,
|
||||||
|
totalBytesSent: sampler.totalBytesSent,
|
||||||
|
totalBytesAcked: sampler.totalBytesAcked,
|
||||||
|
totalBytesLost: sampler.totalBytesLost,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
1000
transport/tuic/congestion/bbr_sender.go
Normal file
1000
transport/tuic/congestion/bbr_sender.go
Normal file
File diff suppressed because it is too large
Load diff
20
transport/tuic/congestion/clock.go
Normal file
20
transport/tuic/congestion/clock.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package congestion
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// A Clock returns the current time
|
||||||
|
type Clock interface {
|
||||||
|
Now() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultClock implements the Clock interface using the Go stdlib clock.
|
||||||
|
type DefaultClock struct {
|
||||||
|
TimeFunc func() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Clock = DefaultClock{}
|
||||||
|
|
||||||
|
// Now gets the current time
|
||||||
|
func (c DefaultClock) Now() time.Time {
|
||||||
|
return c.TimeFunc()
|
||||||
|
}
|
213
transport/tuic/congestion/cubic.go
Normal file
213
transport/tuic/congestion/cubic.go
Normal file
|
@ -0,0 +1,213 @@
|
||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go/congestion"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This cubic implementation is based on the one found in Chromiums's QUIC
|
||||||
|
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
|
||||||
|
|
||||||
|
// Constants based on TCP defaults.
|
||||||
|
// The following constants are in 2^10 fractions of a second instead of ms to
|
||||||
|
// allow a 10 shift right to divide.
|
||||||
|
|
||||||
|
// 1024*1024^3 (first 1024 is from 0.100^3)
|
||||||
|
// where 0.100 is 100 ms which is the scaling round trip time.
|
||||||
|
const (
|
||||||
|
cubeScale = 40
|
||||||
|
cubeCongestionWindowScale = 410
|
||||||
|
cubeFactor congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
|
||||||
|
// TODO: when re-enabling cubic, make sure to use the actual packet size here
|
||||||
|
maxDatagramSize = congestion.ByteCount(InitialPacketSizeIPv4)
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultNumConnections = 1
|
||||||
|
|
||||||
|
// Default Cubic backoff factor
|
||||||
|
const beta float32 = 0.7
|
||||||
|
|
||||||
|
// Additional backoff factor when loss occurs in the concave part of the Cubic
|
||||||
|
// curve. This additional backoff factor is expected to give up bandwidth to
|
||||||
|
// new concurrent flows and speed up convergence.
|
||||||
|
const betaLastMax float32 = 0.85
|
||||||
|
|
||||||
|
// Cubic implements the cubic algorithm from TCP
|
||||||
|
type Cubic struct {
|
||||||
|
clock Clock
|
||||||
|
|
||||||
|
// Number of connections to simulate.
|
||||||
|
numConnections int
|
||||||
|
|
||||||
|
// Time when this cycle started, after last loss event.
|
||||||
|
epoch time.Time
|
||||||
|
|
||||||
|
// Max congestion window used just before last loss event.
|
||||||
|
// Note: to improve fairness to other streams an additional back off is
|
||||||
|
// applied to this value if the new value is below our latest value.
|
||||||
|
lastMaxCongestionWindow congestion.ByteCount
|
||||||
|
|
||||||
|
// Number of acked bytes since the cycle started (epoch).
|
||||||
|
ackedBytesCount congestion.ByteCount
|
||||||
|
|
||||||
|
// TCP Reno equivalent congestion window in packets.
|
||||||
|
estimatedTCPcongestionWindow congestion.ByteCount
|
||||||
|
|
||||||
|
// Origin point of cubic function.
|
||||||
|
originPointCongestionWindow congestion.ByteCount
|
||||||
|
|
||||||
|
// Time to origin point of cubic function in 2^10 fractions of a second.
|
||||||
|
timeToOriginPoint uint32
|
||||||
|
|
||||||
|
// Last congestion window in packets computed by cubic function.
|
||||||
|
lastTargetCongestionWindow congestion.ByteCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCubic returns a new Cubic instance
|
||||||
|
func NewCubic(clock Clock) *Cubic {
|
||||||
|
c := &Cubic{
|
||||||
|
clock: clock,
|
||||||
|
numConnections: defaultNumConnections,
|
||||||
|
}
|
||||||
|
c.Reset()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset is called after a timeout to reset the cubic state
|
||||||
|
func (c *Cubic) Reset() {
|
||||||
|
c.epoch = time.Time{}
|
||||||
|
c.lastMaxCongestionWindow = 0
|
||||||
|
c.ackedBytesCount = 0
|
||||||
|
c.estimatedTCPcongestionWindow = 0
|
||||||
|
c.originPointCongestionWindow = 0
|
||||||
|
c.timeToOriginPoint = 0
|
||||||
|
c.lastTargetCongestionWindow = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cubic) alpha() float32 {
|
||||||
|
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
|
||||||
|
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
|
||||||
|
// We derive the equivalent alpha for an N-connection emulation as:
|
||||||
|
b := c.beta()
|
||||||
|
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cubic) beta() float32 {
|
||||||
|
// kNConnectionBeta is the backoff factor after loss for our N-connection
|
||||||
|
// emulation, which emulates the effective backoff of an ensemble of N
|
||||||
|
// TCP-Reno connections on a single loss event. The effective multiplier is
|
||||||
|
// computed as:
|
||||||
|
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cubic) betaLastMax() float32 {
|
||||||
|
// betaLastMax is the additional backoff factor after loss for our
|
||||||
|
// N-connection emulation, which emulates the additional backoff of
|
||||||
|
// an ensemble of N TCP-Reno connections on a single loss event. The
|
||||||
|
// effective multiplier is computed as:
|
||||||
|
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
||||||
|
// the available congestion window. Resets Cubic state during quiescence.
|
||||||
|
func (c *Cubic) OnApplicationLimited() {
|
||||||
|
// When sender is not using the available congestion window, the window does
|
||||||
|
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
|
||||||
|
// using the entire window during the time since the beginning of the current
|
||||||
|
// "epoch" (the end of the last loss recovery period). Since
|
||||||
|
// application-limited periods break this assumption, we reset the epoch when
|
||||||
|
// in such a period. This reset effectively freezes congestion window growth
|
||||||
|
// through application-limited periods and allows Cubic growth to continue
|
||||||
|
// when the entire window is being used.
|
||||||
|
c.epoch = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
||||||
|
// a loss event. Returns the new congestion window in packets. The new
|
||||||
|
// congestion window is a multiplicative decrease of our current window.
|
||||||
|
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount {
|
||||||
|
if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
|
||||||
|
// We never reached the old max, so assume we are competing with another
|
||||||
|
// flow. Use our extra back off factor to allow the other flow to go up.
|
||||||
|
c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
|
||||||
|
} else {
|
||||||
|
c.lastMaxCongestionWindow = currentCongestionWindow
|
||||||
|
}
|
||||||
|
c.epoch = time.Time{} // Reset time.
|
||||||
|
return congestion.ByteCount(float32(currentCongestionWindow) * c.beta())
|
||||||
|
}
|
||||||
|
|
||||||
|
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
||||||
|
// Returns the new congestion window in packets. The new congestion window
|
||||||
|
// follows a cubic function that depends on the time passed since last
|
||||||
|
// packet loss.
|
||||||
|
func (c *Cubic) CongestionWindowAfterAck(
|
||||||
|
ackedBytes congestion.ByteCount,
|
||||||
|
currentCongestionWindow congestion.ByteCount,
|
||||||
|
delayMin time.Duration,
|
||||||
|
eventTime time.Time,
|
||||||
|
) congestion.ByteCount {
|
||||||
|
c.ackedBytesCount += ackedBytes
|
||||||
|
|
||||||
|
if c.epoch.IsZero() {
|
||||||
|
// First ACK after a loss event.
|
||||||
|
c.epoch = eventTime // Start of epoch.
|
||||||
|
c.ackedBytesCount = ackedBytes // Reset count.
|
||||||
|
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
||||||
|
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
||||||
|
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
||||||
|
c.timeToOriginPoint = 0
|
||||||
|
c.originPointCongestionWindow = currentCongestionWindow
|
||||||
|
} else {
|
||||||
|
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
||||||
|
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
||||||
|
// the round trip time in account. This is done to allow us to use shift as a
|
||||||
|
// divide operator.
|
||||||
|
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
|
||||||
|
|
||||||
|
// Right-shifts of negative, signed numbers have implementation-dependent
|
||||||
|
// behavior, so force the offset to be positive, as is done in the kernel.
|
||||||
|
offset := int64(c.timeToOriginPoint) - elapsedTime
|
||||||
|
if offset < 0 {
|
||||||
|
offset = -offset
|
||||||
|
}
|
||||||
|
|
||||||
|
deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
|
||||||
|
var targetCongestionWindow congestion.ByteCount
|
||||||
|
if elapsedTime > int64(c.timeToOriginPoint) {
|
||||||
|
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
||||||
|
} else {
|
||||||
|
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
||||||
|
}
|
||||||
|
// Limit the CWND increase to half the acked bytes.
|
||||||
|
targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
||||||
|
|
||||||
|
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
||||||
|
// time we ack an estimated tcp window of bytes. For small
|
||||||
|
// congestion windows (less than 25), the formula below will
|
||||||
|
// increase slightly slower than linearly per estimated tcp window
|
||||||
|
// of bytes.
|
||||||
|
c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
|
||||||
|
c.ackedBytesCount = 0
|
||||||
|
|
||||||
|
// We have a new cubic congestion window.
|
||||||
|
c.lastTargetCongestionWindow = targetCongestionWindow
|
||||||
|
|
||||||
|
// Compute target congestion_window based on cubic target and estimated TCP
|
||||||
|
// congestion_window, use highest (fastest).
|
||||||
|
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
||||||
|
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
||||||
|
}
|
||||||
|
return targetCongestionWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNumConnections sets the number of emulated connections
|
||||||
|
func (c *Cubic) SetNumConnections(n int) {
|
||||||
|
c.numConnections = n
|
||||||
|
}
|
318
transport/tuic/congestion/cubic_sender.go
Normal file
318
transport/tuic/congestion/cubic_sender.go
Normal file
|
@ -0,0 +1,318 @@
|
||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go/congestion"
|
||||||
|
"github.com/sagernet/quic-go/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxBurstPackets = 3
|
||||||
|
renoBeta = 0.7 // Reno backoff factor.
|
||||||
|
minCongestionWindowPackets = 2
|
||||||
|
initialCongestionWindow = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
InvalidPacketNumber congestion.PacketNumber = -1
|
||||||
|
MaxCongestionWindowPackets = 20000
|
||||||
|
MaxByteCount = congestion.ByteCount(1<<62 - 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
type cubicSender struct {
|
||||||
|
hybridSlowStart HybridSlowStart
|
||||||
|
rttStats congestion.RTTStatsProvider
|
||||||
|
cubic *Cubic
|
||||||
|
pacer *pacer
|
||||||
|
clock Clock
|
||||||
|
|
||||||
|
reno bool
|
||||||
|
|
||||||
|
// Track the largest packet that has been sent.
|
||||||
|
largestSentPacketNumber congestion.PacketNumber
|
||||||
|
|
||||||
|
// Track the largest packet that has been acked.
|
||||||
|
largestAckedPacketNumber congestion.PacketNumber
|
||||||
|
|
||||||
|
// Track the largest packet number outstanding when a CWND cutback occurs.
|
||||||
|
largestSentAtLastCutback congestion.PacketNumber
|
||||||
|
|
||||||
|
// Whether the last loss event caused us to exit slowstart.
|
||||||
|
// Used for stats collection of slowstartPacketsLost
|
||||||
|
lastCutbackExitedSlowstart bool
|
||||||
|
|
||||||
|
// Congestion window in bytes.
|
||||||
|
congestionWindow congestion.ByteCount
|
||||||
|
|
||||||
|
// Slow start congestion window in bytes, aka ssthresh.
|
||||||
|
slowStartThreshold congestion.ByteCount
|
||||||
|
|
||||||
|
// ACK counter for the Reno implementation.
|
||||||
|
numAckedPackets uint64
|
||||||
|
|
||||||
|
initialCongestionWindow congestion.ByteCount
|
||||||
|
initialMaxCongestionWindow congestion.ByteCount
|
||||||
|
|
||||||
|
maxDatagramSize congestion.ByteCount
|
||||||
|
|
||||||
|
lastState logging.CongestionState
|
||||||
|
tracer logging.ConnectionTracer
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ congestion.CongestionControl = &cubicSender{}
|
||||||
|
|
||||||
|
// NewCubicSender makes a new cubic sender
|
||||||
|
func NewCubicSender(
|
||||||
|
clock Clock,
|
||||||
|
initialMaxDatagramSize congestion.ByteCount,
|
||||||
|
reno bool,
|
||||||
|
tracer logging.ConnectionTracer,
|
||||||
|
) *cubicSender {
|
||||||
|
return newCubicSender(
|
||||||
|
clock,
|
||||||
|
reno,
|
||||||
|
initialMaxDatagramSize,
|
||||||
|
initialCongestionWindow*initialMaxDatagramSize,
|
||||||
|
MaxCongestionWindowPackets*initialMaxDatagramSize,
|
||||||
|
tracer,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCubicSender(
|
||||||
|
clock Clock,
|
||||||
|
reno bool,
|
||||||
|
initialMaxDatagramSize,
|
||||||
|
initialCongestionWindow,
|
||||||
|
initialMaxCongestionWindow congestion.ByteCount,
|
||||||
|
tracer logging.ConnectionTracer,
|
||||||
|
) *cubicSender {
|
||||||
|
c := &cubicSender{
|
||||||
|
largestSentPacketNumber: InvalidPacketNumber,
|
||||||
|
largestAckedPacketNumber: InvalidPacketNumber,
|
||||||
|
largestSentAtLastCutback: InvalidPacketNumber,
|
||||||
|
initialCongestionWindow: initialCongestionWindow,
|
||||||
|
initialMaxCongestionWindow: initialMaxCongestionWindow,
|
||||||
|
congestionWindow: initialCongestionWindow,
|
||||||
|
slowStartThreshold: MaxByteCount,
|
||||||
|
cubic: NewCubic(clock),
|
||||||
|
clock: clock,
|
||||||
|
reno: reno,
|
||||||
|
tracer: tracer,
|
||||||
|
maxDatagramSize: initialMaxDatagramSize,
|
||||||
|
}
|
||||||
|
c.pacer = newPacer(c.BandwidthEstimate)
|
||||||
|
if c.tracer != nil {
|
||||||
|
c.lastState = logging.CongestionStateSlowStart
|
||||||
|
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
|
||||||
|
c.rttStats = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeUntilSend returns when the next packet should be sent.
|
||||||
|
func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time {
|
||||||
|
return c.pacer.TimeUntilSend()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) HasPacingBudget(now time.Time) bool {
|
||||||
|
return c.pacer.Budget(now) >= c.maxDatagramSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) maxCongestionWindow() congestion.ByteCount {
|
||||||
|
return c.maxDatagramSize * MaxCongestionWindowPackets
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) minCongestionWindow() congestion.ByteCount {
|
||||||
|
return c.maxDatagramSize * minCongestionWindowPackets
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) OnPacketSent(
|
||||||
|
sentTime time.Time,
|
||||||
|
_ congestion.ByteCount,
|
||||||
|
packetNumber congestion.PacketNumber,
|
||||||
|
bytes congestion.ByteCount,
|
||||||
|
isRetransmittable bool,
|
||||||
|
) {
|
||||||
|
c.pacer.SentPacket(sentTime, bytes)
|
||||||
|
if !isRetransmittable {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.largestSentPacketNumber = packetNumber
|
||||||
|
c.hybridSlowStart.OnPacketSent(packetNumber)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
||||||
|
return bytesInFlight < c.GetCongestionWindow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) InRecovery() bool {
|
||||||
|
return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) InSlowStart() bool {
|
||||||
|
return c.GetCongestionWindow() < c.slowStartThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) GetCongestionWindow() congestion.ByteCount {
|
||||||
|
return c.congestionWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) MaybeExitSlowStart() {
|
||||||
|
if c.InSlowStart() &&
|
||||||
|
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
|
||||||
|
// exit slow start
|
||||||
|
c.slowStartThreshold = c.congestionWindow
|
||||||
|
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) OnPacketAcked(
|
||||||
|
ackedPacketNumber congestion.PacketNumber,
|
||||||
|
ackedBytes congestion.ByteCount,
|
||||||
|
priorInFlight congestion.ByteCount,
|
||||||
|
eventTime time.Time,
|
||||||
|
) {
|
||||||
|
c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||||
|
if c.InRecovery() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
||||||
|
if c.InSlowStart() {
|
||||||
|
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) OnPacketLost(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
|
||||||
|
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
||||||
|
// already sent should be treated as a single loss event, since it's expected.
|
||||||
|
if packetNumber <= c.largestSentAtLastCutback {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.lastCutbackExitedSlowstart = c.InSlowStart()
|
||||||
|
c.maybeTraceStateChange(logging.CongestionStateRecovery)
|
||||||
|
|
||||||
|
if c.reno {
|
||||||
|
c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta)
|
||||||
|
} else {
|
||||||
|
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
||||||
|
}
|
||||||
|
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
|
||||||
|
c.congestionWindow = minCwnd
|
||||||
|
}
|
||||||
|
c.slowStartThreshold = c.congestionWindow
|
||||||
|
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
||||||
|
// reset packet count from congestion avoidance mode. We start
|
||||||
|
// counting again when we're out of recovery.
|
||||||
|
c.numAckedPackets = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
||||||
|
// represents, but quic has a separate ack for each packet.
|
||||||
|
func (c *cubicSender) maybeIncreaseCwnd(
|
||||||
|
_ congestion.PacketNumber,
|
||||||
|
ackedBytes congestion.ByteCount,
|
||||||
|
priorInFlight congestion.ByteCount,
|
||||||
|
eventTime time.Time,
|
||||||
|
) {
|
||||||
|
// Do not increase the congestion window unless the sender is close to using
|
||||||
|
// the current window.
|
||||||
|
if !c.isCwndLimited(priorInFlight) {
|
||||||
|
c.cubic.OnApplicationLimited()
|
||||||
|
c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.congestionWindow >= c.maxCongestionWindow() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.InSlowStart() {
|
||||||
|
// TCP slow start, exponential growth, increase by one for each ACK.
|
||||||
|
c.congestionWindow += c.maxDatagramSize
|
||||||
|
c.maybeTraceStateChange(logging.CongestionStateSlowStart)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Congestion avoidance
|
||||||
|
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
||||||
|
if c.reno {
|
||||||
|
// Classic Reno congestion avoidance.
|
||||||
|
c.numAckedPackets++
|
||||||
|
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
|
||||||
|
c.congestionWindow += c.maxDatagramSize
|
||||||
|
c.numAckedPackets = 0
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool {
|
||||||
|
congestionWindow := c.GetCongestionWindow()
|
||||||
|
if bytesInFlight >= congestionWindow {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
availableBytes := congestionWindow - bytesInFlight
|
||||||
|
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
|
||||||
|
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// BandwidthEstimate returns the current bandwidth estimate
|
||||||
|
func (c *cubicSender) BandwidthEstimate() Bandwidth {
|
||||||
|
if c.rttStats == nil {
|
||||||
|
return infBandwidth
|
||||||
|
}
|
||||||
|
srtt := c.rttStats.SmoothedRTT()
|
||||||
|
if srtt == 0 {
|
||||||
|
// If we haven't measured an rtt, the bandwidth estimate is unknown.
|
||||||
|
return infBandwidth
|
||||||
|
}
|
||||||
|
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnRetransmissionTimeout is called on an retransmission timeout
|
||||||
|
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||||
|
c.largestSentAtLastCutback = InvalidPacketNumber
|
||||||
|
if !packetsRetransmitted {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.hybridSlowStart.Restart()
|
||||||
|
c.cubic.Reset()
|
||||||
|
c.slowStartThreshold = c.congestionWindow / 2
|
||||||
|
c.congestionWindow = c.minCongestionWindow()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConnectionMigration is called when the connection is migrated (?)
|
||||||
|
func (c *cubicSender) OnConnectionMigration() {
|
||||||
|
c.hybridSlowStart.Restart()
|
||||||
|
c.largestSentPacketNumber = InvalidPacketNumber
|
||||||
|
c.largestAckedPacketNumber = InvalidPacketNumber
|
||||||
|
c.largestSentAtLastCutback = InvalidPacketNumber
|
||||||
|
c.lastCutbackExitedSlowstart = false
|
||||||
|
c.cubic.Reset()
|
||||||
|
c.numAckedPackets = 0
|
||||||
|
c.congestionWindow = c.initialCongestionWindow
|
||||||
|
c.slowStartThreshold = c.initialMaxCongestionWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
|
||||||
|
if c.tracer == nil || new == c.lastState {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.tracer.UpdatedCongestionState(new)
|
||||||
|
c.lastState = new
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) {
|
||||||
|
if s < c.maxDatagramSize {
|
||||||
|
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
|
||||||
|
}
|
||||||
|
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
|
||||||
|
c.maxDatagramSize = s
|
||||||
|
if cwndIsMinCwnd {
|
||||||
|
c.congestionWindow = c.minCongestionWindow()
|
||||||
|
}
|
||||||
|
c.pacer.SetMaxDatagramSize(s)
|
||||||
|
}
|
112
transport/tuic/congestion/hybrid_slow_start.go
Normal file
112
transport/tuic/congestion/hybrid_slow_start.go
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go/congestion"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Note(pwestin): the magic clamping numbers come from the original code in
|
||||||
|
// tcp_cubic.c.
|
||||||
|
const hybridStartLowWindow = congestion.ByteCount(16)
|
||||||
|
|
||||||
|
// Number of delay samples for detecting the increase of delay.
|
||||||
|
const hybridStartMinSamples = uint32(8)
|
||||||
|
|
||||||
|
// Exit slow start if the min rtt has increased by more than 1/8th.
|
||||||
|
const hybridStartDelayFactorExp = 3 // 2^3 = 8
|
||||||
|
// The original paper specifies 2 and 8ms, but those have changed over time.
|
||||||
|
const (
|
||||||
|
hybridStartDelayMinThresholdUs = int64(4000)
|
||||||
|
hybridStartDelayMaxThresholdUs = int64(16000)
|
||||||
|
)
|
||||||
|
|
||||||
|
// HybridSlowStart implements the TCP hybrid slow start algorithm
|
||||||
|
type HybridSlowStart struct {
|
||||||
|
endPacketNumber congestion.PacketNumber
|
||||||
|
lastSentPacketNumber congestion.PacketNumber
|
||||||
|
started bool
|
||||||
|
currentMinRTT time.Duration
|
||||||
|
rttSampleCount uint32
|
||||||
|
hystartFound bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
|
||||||
|
func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) {
|
||||||
|
s.endPacketNumber = lastSent
|
||||||
|
s.currentMinRTT = 0
|
||||||
|
s.rttSampleCount = 0
|
||||||
|
s.started = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
|
||||||
|
func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool {
|
||||||
|
return s.endPacketNumber < ack
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldExitSlowStart should be called on every new ack frame, since a new
|
||||||
|
// RTT measurement can be made then.
|
||||||
|
// rtt: the RTT for this ack packet.
|
||||||
|
// minRTT: is the lowest delay (RTT) we have seen during the session.
|
||||||
|
// congestionWindow: the congestion window in packets.
|
||||||
|
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool {
|
||||||
|
if !s.started {
|
||||||
|
// Time to start the hybrid slow start.
|
||||||
|
s.StartReceiveRound(s.lastSentPacketNumber)
|
||||||
|
}
|
||||||
|
if s.hystartFound {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Second detection parameter - delay increase detection.
|
||||||
|
// Compare the minimum delay (s.currentMinRTT) of the current
|
||||||
|
// burst of packets relative to the minimum delay during the session.
|
||||||
|
// Note: we only look at the first few(8) packets in each burst, since we
|
||||||
|
// only want to compare the lowest RTT of the burst relative to previous
|
||||||
|
// bursts.
|
||||||
|
s.rttSampleCount++
|
||||||
|
if s.rttSampleCount <= hybridStartMinSamples {
|
||||||
|
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
|
||||||
|
s.currentMinRTT = latestRTT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// We only need to check this once per round.
|
||||||
|
if s.rttSampleCount == hybridStartMinSamples {
|
||||||
|
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
|
||||||
|
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
|
||||||
|
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
|
||||||
|
minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
|
||||||
|
minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
|
||||||
|
|
||||||
|
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
|
||||||
|
s.hystartFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Exit from slow start if the cwnd is greater than 16 and
|
||||||
|
// increasing delay is found.
|
||||||
|
return congestionWindow >= hybridStartLowWindow && s.hystartFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPacketSent is called when a packet was sent
|
||||||
|
func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) {
|
||||||
|
s.lastSentPacketNumber = packetNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
|
||||||
|
// the round when the final packet of the burst is received and start it on
|
||||||
|
// the next incoming ack.
|
||||||
|
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) {
|
||||||
|
if s.IsEndOfRound(ackedPacketNumber) {
|
||||||
|
s.started = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Started returns true if started
|
||||||
|
func (s *HybridSlowStart) Started() bool {
|
||||||
|
return s.started
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restart the slow start phase
|
||||||
|
func (s *HybridSlowStart) Restart() {
|
||||||
|
s.started = false
|
||||||
|
s.hystartFound = false
|
||||||
|
}
|
72
transport/tuic/congestion/minmax.go
Normal file
72
transport/tuic/congestion/minmax.go
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/exp/constraints"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InfDuration is a duration of infinite length
|
||||||
|
const InfDuration = time.Duration(math.MaxInt64)
|
||||||
|
|
||||||
|
func Max[T constraints.Ordered](a, b T) T {
|
||||||
|
if a < b {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func Min[T constraints.Ordered](a, b T) T {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinNonZeroDuration return the minimum duration that's not zero.
|
||||||
|
func MinNonZeroDuration(a, b time.Duration) time.Duration {
|
||||||
|
if a == 0 {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
if b == 0 {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return Min(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AbsDuration returns the absolute value of a time duration
|
||||||
|
func AbsDuration(d time.Duration) time.Duration {
|
||||||
|
if d >= 0 {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
return -d
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinTime returns the earlier time
|
||||||
|
func MinTime(a, b time.Time) time.Time {
|
||||||
|
if a.After(b) {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinNonZeroTime returns the earlist time that is not time.Time{}
|
||||||
|
// If both a and b are time.Time{}, it returns time.Time{}
|
||||||
|
func MinNonZeroTime(a, b time.Time) time.Time {
|
||||||
|
if a.IsZero() {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
if b.IsZero() {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return MinTime(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxTime returns the later time
|
||||||
|
func MaxTime(a, b time.Time) time.Time {
|
||||||
|
if a.After(b) {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
81
transport/tuic/congestion/pacer.go
Normal file
81
transport/tuic/congestion/pacer.go
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go/congestion"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
initialMaxDatagramSize = congestion.ByteCount(1252)
|
||||||
|
MinPacingDelay = time.Millisecond
|
||||||
|
TimerGranularity = time.Millisecond
|
||||||
|
maxBurstSizePackets = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// The pacer implements a token bucket pacing algorithm.
|
||||||
|
type pacer struct {
|
||||||
|
budgetAtLastSent congestion.ByteCount
|
||||||
|
maxDatagramSize congestion.ByteCount
|
||||||
|
lastSentTime time.Time
|
||||||
|
getAdjustedBandwidth func() uint64 // in bytes/s
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPacer(getBandwidth func() Bandwidth) *pacer {
|
||||||
|
p := &pacer{
|
||||||
|
maxDatagramSize: initialMaxDatagramSize,
|
||||||
|
getAdjustedBandwidth: func() uint64 {
|
||||||
|
// Bandwidth is in bits/s. We need the value in bytes/s.
|
||||||
|
bw := uint64(getBandwidth() / BytesPerSecond)
|
||||||
|
// Use a slightly higher value than the actual measured bandwidth.
|
||||||
|
// RTT variations then won't result in under-utilization of the congestion window.
|
||||||
|
// Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire,
|
||||||
|
// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
|
||||||
|
return bw * 5 / 4
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.budgetAtLastSent = p.maxBurstSize()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
|
||||||
|
budget := p.Budget(sendTime)
|
||||||
|
if size > budget {
|
||||||
|
p.budgetAtLastSent = 0
|
||||||
|
} else {
|
||||||
|
p.budgetAtLastSent = budget - size
|
||||||
|
}
|
||||||
|
p.lastSentTime = sendTime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
|
||||||
|
if p.lastSentTime.IsZero() {
|
||||||
|
return p.maxBurstSize()
|
||||||
|
}
|
||||||
|
budget := p.budgetAtLastSent + (congestion.ByteCount(p.getAdjustedBandwidth())*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||||
|
return Min(p.maxBurstSize(), budget)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pacer) maxBurstSize() congestion.ByteCount {
|
||||||
|
return Max(
|
||||||
|
congestion.ByteCount(uint64((MinPacingDelay+TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9,
|
||||||
|
maxBurstSizePackets*p.maxDatagramSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeUntilSend returns when the next packet should be sent.
|
||||||
|
// It returns the zero value of time.Time if a packet can be sent immediately.
|
||||||
|
func (p *pacer) TimeUntilSend() time.Time {
|
||||||
|
if p.budgetAtLastSent >= p.maxDatagramSize {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
return p.lastSentTime.Add(Max(
|
||||||
|
MinPacingDelay,
|
||||||
|
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
|
||||||
|
p.maxDatagramSize = s
|
||||||
|
}
|
132
transport/tuic/congestion/windowed_filter.go
Normal file
132
transport/tuic/congestion/windowed_filter.go
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
package congestion
|
||||||
|
|
||||||
|
// WindowedFilter Use the following to construct a windowed filter object of type T.
|
||||||
|
// For example, a min filter using QuicTime as the time type:
|
||||||
|
//
|
||||||
|
// WindowedFilter<T, MinFilter<T>, QuicTime, QuicTime::Delta> ObjectName;
|
||||||
|
//
|
||||||
|
// A max filter using 64-bit integers as the time type:
|
||||||
|
//
|
||||||
|
// WindowedFilter<T, MaxFilter<T>, uint64_t, int64_t> ObjectName;
|
||||||
|
//
|
||||||
|
// Specifically, this template takes four arguments:
|
||||||
|
// 1. T -- type of the measurement that is being filtered.
|
||||||
|
// 2. Compare -- MinFilter<T> or MaxFilter<T>, depending on the type of filter
|
||||||
|
// desired.
|
||||||
|
// 3. TimeT -- the type used to represent timestamps.
|
||||||
|
// 4. TimeDeltaT -- the type used to represent continuous time intervals between
|
||||||
|
// two timestamps. Has to be the type of (a - b) if both |a| and |b| are
|
||||||
|
// of type TimeT.
|
||||||
|
type WindowedFilter struct {
|
||||||
|
// Time length of window.
|
||||||
|
windowLength int64
|
||||||
|
estimates []Sample
|
||||||
|
comparator func(int64, int64) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Sample struct {
|
||||||
|
sample int64
|
||||||
|
time int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares two values and returns true if the first is greater than or equal
|
||||||
|
// to the second.
|
||||||
|
func MaxFilter(a, b int64) bool {
|
||||||
|
return a >= b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares two values and returns true if the first is less than or equal
|
||||||
|
// to the second.
|
||||||
|
func MinFilter(a, b int64) bool {
|
||||||
|
return a <= b
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWindowedFilter(windowLength int64, comparator func(int64, int64) bool) *WindowedFilter {
|
||||||
|
return &WindowedFilter{
|
||||||
|
windowLength: windowLength,
|
||||||
|
estimates: make([]Sample, 3),
|
||||||
|
comparator: comparator,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Changes the window length. Does not update any current samples.
|
||||||
|
func (f *WindowedFilter) SetWindowLength(windowLength int64) {
|
||||||
|
f.windowLength = windowLength
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *WindowedFilter) GetBest() int64 {
|
||||||
|
return f.estimates[0].sample
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *WindowedFilter) GetSecondBest() int64 {
|
||||||
|
return f.estimates[1].sample
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *WindowedFilter) GetThirdBest() int64 {
|
||||||
|
return f.estimates[2].sample
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *WindowedFilter) Update(sample int64, time int64) {
|
||||||
|
if f.estimates[0].time == 0 || f.comparator(sample, f.estimates[0].sample) || (time-f.estimates[2].time) > f.windowLength {
|
||||||
|
f.Reset(sample, time)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.comparator(sample, f.estimates[1].sample) {
|
||||||
|
f.estimates[1].sample = sample
|
||||||
|
f.estimates[1].time = time
|
||||||
|
f.estimates[2].sample = sample
|
||||||
|
f.estimates[2].time = time
|
||||||
|
} else if f.comparator(sample, f.estimates[2].sample) {
|
||||||
|
f.estimates[2].sample = sample
|
||||||
|
f.estimates[2].time = time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expire and update estimates as necessary.
|
||||||
|
if time-f.estimates[0].time > f.windowLength {
|
||||||
|
// The best estimate hasn't been updated for an entire window, so promote
|
||||||
|
// second and third best estimates.
|
||||||
|
f.estimates[0].sample = f.estimates[1].sample
|
||||||
|
f.estimates[0].time = f.estimates[1].time
|
||||||
|
f.estimates[1].sample = f.estimates[2].sample
|
||||||
|
f.estimates[1].time = f.estimates[2].time
|
||||||
|
f.estimates[2].sample = sample
|
||||||
|
f.estimates[2].time = time
|
||||||
|
// Need to iterate one more time. Check if the new best estimate is
|
||||||
|
// outside the window as well, since it may also have been recorded a
|
||||||
|
// long time ago. Don't need to iterate once more since we cover that
|
||||||
|
// case at the beginning of the method.
|
||||||
|
if time-f.estimates[0].time > f.windowLength {
|
||||||
|
f.estimates[0].sample = f.estimates[1].sample
|
||||||
|
f.estimates[0].time = f.estimates[1].time
|
||||||
|
f.estimates[1].sample = f.estimates[2].sample
|
||||||
|
f.estimates[1].time = f.estimates[2].time
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f.estimates[1].sample == f.estimates[0].sample && time-f.estimates[1].time > f.windowLength>>2 {
|
||||||
|
// A quarter of the window has passed without a better sample, so the
|
||||||
|
// second-best estimate is taken from the second quarter of the window.
|
||||||
|
f.estimates[1].sample = sample
|
||||||
|
f.estimates[1].time = time
|
||||||
|
f.estimates[2].sample = sample
|
||||||
|
f.estimates[2].time = time
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.estimates[2].sample == f.estimates[1].sample && time-f.estimates[2].time > f.windowLength>>1 {
|
||||||
|
// We've passed a half of the window without a better estimate, so take
|
||||||
|
// a third-best estimate from the second half of the window.
|
||||||
|
f.estimates[2].sample = sample
|
||||||
|
f.estimates[2].time = time
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *WindowedFilter) Reset(newSample int64, newTime int64) {
|
||||||
|
f.estimates[0].sample = newSample
|
||||||
|
f.estimates[0].time = newTime
|
||||||
|
f.estimates[1].sample = newSample
|
||||||
|
f.estimates[1].time = newTime
|
||||||
|
f.estimates[2].sample = newSample
|
||||||
|
f.estimates[2].time = newTime
|
||||||
|
}
|
497
transport/tuic/packet.go
Normal file
497
transport/tuic/packet.go
Normal file
|
@ -0,0 +1,497 @@
|
||||||
|
package tuic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/atomic"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
"github.com/sagernet/sing/common/cache"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
var udpMessagePool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return new(udpMessage)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func releaseMessages(messages []*udpMessage) {
|
||||||
|
for _, message := range messages {
|
||||||
|
if message != nil {
|
||||||
|
*message = udpMessage{}
|
||||||
|
udpMessagePool.Put(message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpMessage struct {
|
||||||
|
sessionID uint16
|
||||||
|
packetID uint16
|
||||||
|
fragmentTotal uint8
|
||||||
|
fragmentID uint8
|
||||||
|
destination M.Socksaddr
|
||||||
|
dataLength uint16
|
||||||
|
data *buf.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpMessage) release() {
|
||||||
|
*m = udpMessage{}
|
||||||
|
udpMessagePool.Put(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpMessage) releaseMessage() {
|
||||||
|
m.data.Release()
|
||||||
|
m.release()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpMessage) pack() *buf.Buffer {
|
||||||
|
buffer := buf.NewSize(m.headerSize() + m.data.Len())
|
||||||
|
common.Must(
|
||||||
|
buffer.WriteByte(Version),
|
||||||
|
buffer.WriteByte(CommandPacket),
|
||||||
|
binary.Write(buffer, binary.BigEndian, m.sessionID),
|
||||||
|
binary.Write(buffer, binary.BigEndian, m.packetID),
|
||||||
|
binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
|
||||||
|
binary.Write(buffer, binary.BigEndian, m.fragmentID),
|
||||||
|
binary.Write(buffer, binary.BigEndian, uint16(m.data.Len())),
|
||||||
|
addressSerializer.WriteAddrPort(buffer, m.destination),
|
||||||
|
common.Error(buffer.Write(m.data.Bytes())),
|
||||||
|
)
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpMessage) headerSize() int {
|
||||||
|
return 2 + 10 + addressSerializer.AddrPortLen(m.destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
|
||||||
|
if message.data.Len() <= maxPacketSize {
|
||||||
|
return []*udpMessage{message}
|
||||||
|
}
|
||||||
|
var fragments []*udpMessage
|
||||||
|
originPacket := message.data.Bytes()
|
||||||
|
udpMTU := maxPacketSize - message.headerSize()
|
||||||
|
for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
|
||||||
|
fragment := udpMessagePool.Get().(*udpMessage)
|
||||||
|
*fragment = *message
|
||||||
|
if remaining > udpMTU {
|
||||||
|
fragment.data = buf.As(originPacket[:udpMTU])
|
||||||
|
originPacket = originPacket[udpMTU:]
|
||||||
|
} else {
|
||||||
|
fragment.data = buf.As(originPacket)
|
||||||
|
originPacket = nil
|
||||||
|
}
|
||||||
|
fragments = append(fragments, fragment)
|
||||||
|
}
|
||||||
|
fragmentTotal := uint16(len(fragments))
|
||||||
|
for index, fragment := range fragments {
|
||||||
|
fragment.fragmentID = uint8(index)
|
||||||
|
fragment.fragmentTotal = uint8(fragmentTotal)
|
||||||
|
if index > 0 {
|
||||||
|
fragment.destination = M.Socksaddr{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fragments
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpPacketConn struct {
|
||||||
|
ctx context.Context
|
||||||
|
cancel common.ContextCancelCauseFunc
|
||||||
|
sessionID uint16
|
||||||
|
quicConn quic.Connection
|
||||||
|
data chan *udpMessage
|
||||||
|
udpStream bool
|
||||||
|
udpMTU int
|
||||||
|
packetId atomic.Uint32
|
||||||
|
closeOnce sync.Once
|
||||||
|
isServer bool
|
||||||
|
defragger *udpDefragger
|
||||||
|
onDestroy func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn {
|
||||||
|
ctx, cancel := common.ContextWithCancelCause(ctx)
|
||||||
|
return &udpPacketConn{
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
quicConn: quicConn,
|
||||||
|
data: make(chan *udpMessage, 64),
|
||||||
|
udpStream: udpStream,
|
||||||
|
isServer: isServer,
|
||||||
|
defragger: newUDPDefragger(),
|
||||||
|
onDestroy: onDestroy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||||
|
select {
|
||||||
|
case p := <-c.data:
|
||||||
|
buffer = p.data
|
||||||
|
destination = p.destination
|
||||||
|
p.release()
|
||||||
|
return
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||||
|
select {
|
||||||
|
case p := <-c.data:
|
||||||
|
_, err = buffer.ReadOnceFrom(p.data)
|
||||||
|
destination = p.destination
|
||||||
|
p.releaseMessage()
|
||||||
|
return
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return M.Socksaddr{}, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||||
|
select {
|
||||||
|
case p := <-c.data:
|
||||||
|
_, err = newBuffer().ReadOnceFrom(p.data)
|
||||||
|
destination = p.destination
|
||||||
|
p.releaseMessage()
|
||||||
|
return
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return M.Socksaddr{}, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
|
select {
|
||||||
|
case pkt := <-c.data:
|
||||||
|
n = copy(p, pkt.data.Bytes())
|
||||||
|
if pkt.destination.IsFqdn() {
|
||||||
|
addr = pkt.destination
|
||||||
|
} else {
|
||||||
|
addr = pkt.destination.UDPAddr()
|
||||||
|
}
|
||||||
|
pkt.releaseMessage()
|
||||||
|
return n, addr, nil
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return 0, nil, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
|
defer buffer.Release()
|
||||||
|
select {
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return net.ErrClosed
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if buffer.Len() > 0xffff {
|
||||||
|
return quic.ErrMessageTooLarge(0xffff)
|
||||||
|
}
|
||||||
|
packetId := c.packetId.Add(1)
|
||||||
|
if packetId > math.MaxUint16 {
|
||||||
|
c.packetId.Store(0)
|
||||||
|
packetId = 0
|
||||||
|
}
|
||||||
|
message := udpMessagePool.Get().(*udpMessage)
|
||||||
|
*message = udpMessage{
|
||||||
|
sessionID: c.sessionID,
|
||||||
|
packetID: uint16(packetId),
|
||||||
|
fragmentTotal: 1,
|
||||||
|
destination: destination,
|
||||||
|
data: buffer,
|
||||||
|
}
|
||||||
|
defer message.releaseMessage()
|
||||||
|
var err error
|
||||||
|
if !c.udpStream && c.udpMTU > 0 && buffer.Len() > c.udpMTU {
|
||||||
|
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
||||||
|
} else {
|
||||||
|
err = c.writePacket(message)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var tooLargeErr quic.ErrMessageTooLarge
|
||||||
|
if !errors.As(err, &tooLargeErr) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.udpMTU = int(tooLargeErr)
|
||||||
|
return c.writePackets(fragUDPMessage(message, c.udpMTU))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if len(p) > 0xffff {
|
||||||
|
return 0, quic.ErrMessageTooLarge(0xffff)
|
||||||
|
}
|
||||||
|
packetId := c.packetId.Add(1)
|
||||||
|
if packetId > math.MaxUint16 {
|
||||||
|
c.packetId.Store(0)
|
||||||
|
packetId = 0
|
||||||
|
}
|
||||||
|
message := udpMessagePool.Get().(*udpMessage)
|
||||||
|
*message = udpMessage{
|
||||||
|
sessionID: c.sessionID,
|
||||||
|
packetID: uint16(packetId),
|
||||||
|
fragmentTotal: 1,
|
||||||
|
destination: M.SocksaddrFromNet(addr),
|
||||||
|
data: buf.As(p),
|
||||||
|
}
|
||||||
|
if c.udpMTU > 0 && len(p) > c.udpMTU {
|
||||||
|
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
||||||
|
if err == nil {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = c.writePacket(message)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
var tooLargeErr quic.ErrMessageTooLarge
|
||||||
|
if !errors.As(err, &tooLargeErr) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.udpMTU = int(tooLargeErr)
|
||||||
|
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
||||||
|
if err == nil {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) inputPacket(message *udpMessage) {
|
||||||
|
if message.fragmentTotal <= 1 {
|
||||||
|
select {
|
||||||
|
case c.data <- message:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
newMessage := c.defragger.feed(message)
|
||||||
|
if newMessage != nil {
|
||||||
|
select {
|
||||||
|
case c.data <- newMessage:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
|
||||||
|
defer releaseMessages(messages)
|
||||||
|
for _, message := range messages {
|
||||||
|
err := c.writePacket(message)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) writePacket(message *udpMessage) error {
|
||||||
|
if !c.udpStream {
|
||||||
|
buffer := message.pack()
|
||||||
|
err := c.quicConn.SendMessage(buffer.Bytes())
|
||||||
|
buffer.Release()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
stream, err := c.quicConn.OpenUniStream()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
buffer := message.pack()
|
||||||
|
_, err = stream.Write(buffer.Bytes())
|
||||||
|
buffer.Release()
|
||||||
|
stream.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) Close() error {
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.closeWithError(os.ErrClosed)
|
||||||
|
c.onDestroy()
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) closeWithError(err error) {
|
||||||
|
c.cancel(err)
|
||||||
|
if !c.isServer {
|
||||||
|
buffer := buf.NewSize(4)
|
||||||
|
defer buffer.Release()
|
||||||
|
buffer.WriteByte(Version)
|
||||||
|
buffer.WriteByte(CommandDissociate)
|
||||||
|
binary.Write(buffer, binary.BigEndian, c.sessionID)
|
||||||
|
sendStream, openErr := c.quicConn.OpenUniStream()
|
||||||
|
if openErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer sendStream.Close()
|
||||||
|
sendStream.Write(buffer.Bytes())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) LocalAddr() net.Addr {
|
||||||
|
return c.quicConn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) SetDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpDefragger struct {
|
||||||
|
packetMap *cache.LruCache[uint16, *packetItem]
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPDefragger() *udpDefragger {
|
||||||
|
return &udpDefragger{
|
||||||
|
packetMap: cache.New(
|
||||||
|
cache.WithAge[uint16, *packetItem](10),
|
||||||
|
cache.WithUpdateAgeOnGet[uint16, *packetItem](),
|
||||||
|
cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
|
||||||
|
releaseMessages(value.messages)
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type packetItem struct {
|
||||||
|
access sync.Mutex
|
||||||
|
messages []*udpMessage
|
||||||
|
count uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
|
||||||
|
if m.fragmentTotal <= 1 {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
if m.fragmentID >= m.fragmentTotal {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
|
||||||
|
item.access.Lock()
|
||||||
|
defer item.access.Unlock()
|
||||||
|
if int(m.fragmentTotal) != len(item.messages) {
|
||||||
|
releaseMessages(item.messages)
|
||||||
|
item.messages = make([]*udpMessage, m.fragmentTotal)
|
||||||
|
item.count = 1
|
||||||
|
item.messages[m.fragmentID] = m
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if item.messages[m.fragmentID] != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
item.messages[m.fragmentID] = m
|
||||||
|
item.count++
|
||||||
|
if int(item.count) != len(item.messages) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
newMessage := udpMessagePool.Get().(*udpMessage)
|
||||||
|
*newMessage = *item.messages[0]
|
||||||
|
if m.dataLength > 0 {
|
||||||
|
newMessage.data = buf.NewSize(int(m.dataLength))
|
||||||
|
for _, message := range item.messages {
|
||||||
|
newMessage.data.Write(message.data.Bytes())
|
||||||
|
message.releaseMessage()
|
||||||
|
}
|
||||||
|
item.messages = nil
|
||||||
|
return newMessage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPacketItem() *packetItem {
|
||||||
|
return new(packetItem)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readUDPMessage(message *udpMessage, reader io.Reader) error {
|
||||||
|
err := binary.Read(reader, binary.BigEndian, &message.sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.packetID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.dataLength)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
message.destination, err = addressSerializer.ReadAddrPort(reader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
message.data = buf.NewSize(int(message.dataLength))
|
||||||
|
_, err = message.data.ReadFullFrom(reader, message.data.FreeLen())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeUDPMessage(message *udpMessage, data []byte) error {
|
||||||
|
reader := bytes.NewReader(data)
|
||||||
|
err := binary.Read(reader, binary.BigEndian, &message.sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.packetID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.dataLength)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
message.destination, err = addressSerializer.ReadAddrPort(reader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if reader.Len() != int(message.dataLength) {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
message.data = buf.As(data[len(data)-reader.Len():])
|
||||||
|
return nil
|
||||||
|
}
|
15
transport/tuic/protocol.go
Normal file
15
transport/tuic/protocol.go
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
package tuic
|
||||||
|
|
||||||
|
const (
|
||||||
|
Version = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CommandAuthenticate = iota
|
||||||
|
CommandConnect
|
||||||
|
CommandPacket
|
||||||
|
CommandDissociate
|
||||||
|
CommandHeartbeat
|
||||||
|
)
|
||||||
|
|
||||||
|
const AuthenticateLen = 2 + 16 + 32
|
434
transport/tuic/server.go
Normal file
434
transport/tuic/server.go
Normal file
|
@ -0,0 +1,434 @@
|
||||||
|
package tuic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
|
"github.com/sagernet/sing-box/common/baderror"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/auth"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
|
"github.com/gofrs/uuid/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServerOptions struct {
|
||||||
|
Context context.Context
|
||||||
|
Logger logger.Logger
|
||||||
|
TLSConfig *tls.Config
|
||||||
|
Users []User
|
||||||
|
CongestionControl string
|
||||||
|
AuthTimeout time.Duration
|
||||||
|
ZeroRTTHandshake bool
|
||||||
|
Heartbeat time.Duration
|
||||||
|
Handler ServerHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
Name string
|
||||||
|
UUID uuid.UUID
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerHandler interface {
|
||||||
|
N.TCPConnectionHandler
|
||||||
|
N.UDPConnectionHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
ctx context.Context
|
||||||
|
logger logger.Logger
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
heartbeat time.Duration
|
||||||
|
quicConfig *quic.Config
|
||||||
|
userMap map[uuid.UUID]User
|
||||||
|
congestionControl string
|
||||||
|
authTimeout time.Duration
|
||||||
|
handler ServerHandler
|
||||||
|
|
||||||
|
quicListener io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer(options ServerOptions) (*Server, error) {
|
||||||
|
if options.AuthTimeout == 0 {
|
||||||
|
options.AuthTimeout = 3 * time.Second
|
||||||
|
}
|
||||||
|
if options.Heartbeat == 0 {
|
||||||
|
options.Heartbeat = 10 * time.Second
|
||||||
|
}
|
||||||
|
quicConfig := &quic.Config{
|
||||||
|
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
|
||||||
|
MaxDatagramFrameSize: 1400,
|
||||||
|
EnableDatagrams: true,
|
||||||
|
Allow0RTT: options.ZeroRTTHandshake,
|
||||||
|
MaxIncomingStreams: 1 << 60,
|
||||||
|
MaxIncomingUniStreams: 1 << 60,
|
||||||
|
}
|
||||||
|
switch options.CongestionControl {
|
||||||
|
case "":
|
||||||
|
options.CongestionControl = "cubic"
|
||||||
|
case "cubic", "new_reno", "bbr":
|
||||||
|
default:
|
||||||
|
return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
|
||||||
|
}
|
||||||
|
if len(options.Users) == 0 {
|
||||||
|
return nil, E.New("missing users")
|
||||||
|
}
|
||||||
|
userMap := make(map[uuid.UUID]User)
|
||||||
|
for _, user := range options.Users {
|
||||||
|
userMap[user.UUID] = user
|
||||||
|
}
|
||||||
|
return &Server{
|
||||||
|
ctx: options.Context,
|
||||||
|
logger: options.Logger,
|
||||||
|
tlsConfig: options.TLSConfig,
|
||||||
|
heartbeat: options.Heartbeat,
|
||||||
|
quicConfig: quicConfig,
|
||||||
|
userMap: userMap,
|
||||||
|
congestionControl: options.CongestionControl,
|
||||||
|
authTimeout: options.AuthTimeout,
|
||||||
|
handler: options.Handler,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Start(conn net.PacketConn) error {
|
||||||
|
if !s.quicConfig.Allow0RTT {
|
||||||
|
listener, err := quic.Listen(conn, s.tlsConfig, s.quicConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.quicListener = listener
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
connection, hErr := listener.Accept(s.ctx)
|
||||||
|
if hErr != nil {
|
||||||
|
if strings.Contains(hErr.Error(), "server closed") {
|
||||||
|
s.logger.Debug(E.Cause(hErr, "listener closed"))
|
||||||
|
} else {
|
||||||
|
s.logger.Error(E.Cause(hErr, "listener closed"))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go s.handleConnection(connection)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
listener, err := quic.ListenEarly(conn, s.tlsConfig, s.quicConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.quicListener = listener
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
connection, hErr := listener.Accept(s.ctx)
|
||||||
|
if hErr != nil {
|
||||||
|
if strings.Contains(hErr.Error(), "server closed") {
|
||||||
|
s.logger.Debug(E.Cause(hErr, "listener closed"))
|
||||||
|
} else {
|
||||||
|
s.logger.Error(E.Cause(hErr, "listener closed"))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go s.handleConnection(connection)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Close() error {
|
||||||
|
return common.Close(
|
||||||
|
s.quicListener,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleConnection(connection quic.Connection) {
|
||||||
|
setCongestion(s.ctx, connection, s.congestionControl)
|
||||||
|
session := &serverSession{
|
||||||
|
Server: s,
|
||||||
|
ctx: s.ctx,
|
||||||
|
quicConn: connection,
|
||||||
|
source: M.SocksaddrFromNet(connection.RemoteAddr()),
|
||||||
|
connDone: make(chan struct{}),
|
||||||
|
authDone: make(chan struct{}),
|
||||||
|
udpConnMap: make(map[uint16]*udpPacketConn),
|
||||||
|
}
|
||||||
|
session.handle()
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverSession struct {
|
||||||
|
*Server
|
||||||
|
ctx context.Context
|
||||||
|
quicConn quic.Connection
|
||||||
|
source M.Socksaddr
|
||||||
|
connAccess sync.Mutex
|
||||||
|
connDone chan struct{}
|
||||||
|
connErr error
|
||||||
|
authDone chan struct{}
|
||||||
|
authUser *User
|
||||||
|
udpAccess sync.RWMutex
|
||||||
|
udpConnMap map[uint16]*udpPacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverSession) handle() {
|
||||||
|
if s.ctx.Done() != nil {
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
s.closeWithError(s.ctx.Err())
|
||||||
|
case <-s.connDone:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
go s.loopUniStreams()
|
||||||
|
go s.loopStreams()
|
||||||
|
go s.loopMessages()
|
||||||
|
go s.handleAuthTimeout()
|
||||||
|
go s.loopHeartbeats()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverSession) loopUniStreams() {
|
||||||
|
for {
|
||||||
|
uniStream, err := s.quicConn.AcceptUniStream(s.ctx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
err = s.handleUniStream(uniStream)
|
||||||
|
if err != nil {
|
||||||
|
s.closeWithError(E.Cause(err, "handle uni stream"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
|
||||||
|
defer stream.CancelRead(0)
|
||||||
|
buffer := buf.New()
|
||||||
|
defer buffer.Release()
|
||||||
|
_, err := buffer.ReadAtLeastFrom(stream, 2)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read request")
|
||||||
|
}
|
||||||
|
version := buffer.Byte(0)
|
||||||
|
if version != Version {
|
||||||
|
return E.New("unknown version ", buffer.Byte(0))
|
||||||
|
}
|
||||||
|
command := buffer.Byte(1)
|
||||||
|
switch command {
|
||||||
|
case CommandAuthenticate:
|
||||||
|
select {
|
||||||
|
case <-s.authDone:
|
||||||
|
return E.New("authentication: multiple authentication requests")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if buffer.Len() < AuthenticateLen {
|
||||||
|
_, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len())
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "authentication: read request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16))
|
||||||
|
user, loaded := s.userMap[userUUID]
|
||||||
|
if !loaded {
|
||||||
|
return E.New("authentication: unknown user ", userUUID)
|
||||||
|
}
|
||||||
|
handshakeState := s.quicConn.ConnectionState().TLS
|
||||||
|
tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "authentication: export keying material")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) {
|
||||||
|
return E.New("authentication: token mismatch")
|
||||||
|
}
|
||||||
|
s.authUser = &user
|
||||||
|
close(s.authDone)
|
||||||
|
return nil
|
||||||
|
case CommandPacket:
|
||||||
|
select {
|
||||||
|
case <-s.connDone:
|
||||||
|
return s.connErr
|
||||||
|
case <-s.authDone:
|
||||||
|
}
|
||||||
|
message := udpMessagePool.Get().(*udpMessage)
|
||||||
|
err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream))
|
||||||
|
if err != nil {
|
||||||
|
message.release()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.handleUDPMessage(message, true)
|
||||||
|
return nil
|
||||||
|
case CommandDissociate:
|
||||||
|
select {
|
||||||
|
case <-s.connDone:
|
||||||
|
return s.connErr
|
||||||
|
case <-s.authDone:
|
||||||
|
}
|
||||||
|
if buffer.Len() > 4 {
|
||||||
|
return E.New("invalid dissociate message")
|
||||||
|
}
|
||||||
|
var sessionID uint16
|
||||||
|
err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.udpAccess.RLock()
|
||||||
|
udpConn, loaded := s.udpConnMap[sessionID]
|
||||||
|
s.udpAccess.RUnlock()
|
||||||
|
if loaded {
|
||||||
|
udpConn.closeWithError(E.New("remote closed"))
|
||||||
|
s.udpAccess.Lock()
|
||||||
|
delete(s.udpConnMap, sessionID)
|
||||||
|
s.udpAccess.Unlock()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return E.New("unknown command ", command)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverSession) handleAuthTimeout() {
|
||||||
|
select {
|
||||||
|
case <-s.connDone:
|
||||||
|
case <-s.authDone:
|
||||||
|
case <-time.After(s.authTimeout):
|
||||||
|
s.closeWithError(E.New("authentication timeout"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverSession) loopStreams() {
|
||||||
|
for {
|
||||||
|
stream, err := s.quicConn.AcceptStream(s.ctx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
err = s.handleStream(stream)
|
||||||
|
if err != nil {
|
||||||
|
stream.CancelRead(0)
|
||||||
|
stream.Close()
|
||||||
|
s.logger.Error(E.Cause(err, "handle stream request"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverSession) handleStream(stream quic.Stream) error {
|
||||||
|
buffer := buf.NewSize(2 + M.MaxSocksaddrLength)
|
||||||
|
defer buffer.Release()
|
||||||
|
_, err := buffer.ReadAtLeastFrom(stream, 2)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read request")
|
||||||
|
}
|
||||||
|
version, _ := buffer.ReadByte()
|
||||||
|
if version != Version {
|
||||||
|
return E.New("unknown version ", buffer.Byte(0))
|
||||||
|
}
|
||||||
|
command, _ := buffer.ReadByte()
|
||||||
|
if command != CommandConnect {
|
||||||
|
return E.New("unsupported stream command ", command)
|
||||||
|
}
|
||||||
|
destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read request destination")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-s.connDone:
|
||||||
|
return s.connErr
|
||||||
|
case <-s.authDone:
|
||||||
|
}
|
||||||
|
var conn net.Conn = &serverConn{
|
||||||
|
Stream: stream,
|
||||||
|
destination: destination,
|
||||||
|
}
|
||||||
|
if buffer.IsEmpty() {
|
||||||
|
buffer.Release()
|
||||||
|
} else {
|
||||||
|
conn = bufio.NewCachedConn(conn, buffer)
|
||||||
|
}
|
||||||
|
ctx := s.ctx
|
||||||
|
if s.authUser.Name != "" {
|
||||||
|
ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
|
||||||
|
}
|
||||||
|
_ = s.handler.NewConnection(ctx, conn, M.Metadata{
|
||||||
|
Source: s.source,
|
||||||
|
Destination: destination,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverSession) loopHeartbeats() {
|
||||||
|
ticker := time.NewTicker(s.heartbeat)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.connDone:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
|
||||||
|
if err != nil {
|
||||||
|
s.closeWithError(E.Cause(err, "send heartbeat"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverSession) closeWithError(err error) {
|
||||||
|
s.connAccess.Lock()
|
||||||
|
defer s.connAccess.Unlock()
|
||||||
|
select {
|
||||||
|
case <-s.connDone:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
s.connErr = err
|
||||||
|
close(s.connDone)
|
||||||
|
}
|
||||||
|
if E.IsClosedOrCanceled(err) {
|
||||||
|
s.logger.Debug(E.Cause(err, "connection failed"))
|
||||||
|
} else {
|
||||||
|
s.logger.Error(E.Cause(err, "connection failed"))
|
||||||
|
}
|
||||||
|
_ = s.quicConn.CloseWithError(0, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverConn struct {
|
||||||
|
quic.Stream
|
||||||
|
destination M.Socksaddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) Read(p []byte) (n int, err error) {
|
||||||
|
n, err = c.Stream.Read(p)
|
||||||
|
return n, baderror.WrapQUIC(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) Write(p []byte) (n int, err error) {
|
||||||
|
n, err = c.Stream.Write(p)
|
||||||
|
return n, baderror.WrapQUIC(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) LocalAddr() net.Addr {
|
||||||
|
return c.destination
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) RemoteAddr() net.Addr {
|
||||||
|
return M.Socksaddr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) Close() error {
|
||||||
|
c.Stream.CancelRead(0)
|
||||||
|
return c.Stream.Close()
|
||||||
|
}
|
73
transport/tuic/server_packet.go
Normal file
73
transport/tuic/server_packet.go
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
package tuic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *serverSession) loopMessages() {
|
||||||
|
select {
|
||||||
|
case <-s.connDone:
|
||||||
|
return
|
||||||
|
case <-s.authDone:
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
message, err := s.quicConn.ReceiveMessage(s.ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.closeWithError(E.Cause(err, "receive message"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hErr := s.handleMessage(message)
|
||||||
|
if hErr != nil {
|
||||||
|
s.closeWithError(E.Cause(hErr, "handle message"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverSession) handleMessage(data []byte) error {
|
||||||
|
if len(data) < 2 {
|
||||||
|
return E.New("invalid message")
|
||||||
|
}
|
||||||
|
if data[0] != Version {
|
||||||
|
return E.New("unknown version ", data[0])
|
||||||
|
}
|
||||||
|
switch data[1] {
|
||||||
|
case CommandPacket:
|
||||||
|
message := udpMessagePool.Get().(*udpMessage)
|
||||||
|
err := decodeUDPMessage(message, data[2:])
|
||||||
|
if err != nil {
|
||||||
|
message.release()
|
||||||
|
return E.Cause(err, "decode UDP message")
|
||||||
|
}
|
||||||
|
s.handleUDPMessage(message, false)
|
||||||
|
return nil
|
||||||
|
case CommandHeartbeat:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return E.New("unknown command ", data[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) {
|
||||||
|
s.udpAccess.RLock()
|
||||||
|
udpConn, loaded := s.udpConnMap[message.sessionID]
|
||||||
|
s.udpAccess.RUnlock()
|
||||||
|
if !loaded || common.Done(udpConn.ctx) {
|
||||||
|
udpConn = newUDPPacketConn(s.ctx, s.quicConn, udpStream, true, func() {
|
||||||
|
s.udpAccess.Lock()
|
||||||
|
delete(s.udpConnMap, message.sessionID)
|
||||||
|
s.udpAccess.Unlock()
|
||||||
|
})
|
||||||
|
udpConn.sessionID = message.sessionID
|
||||||
|
s.udpAccess.Lock()
|
||||||
|
s.udpConnMap[message.sessionID] = udpConn
|
||||||
|
s.udpAccess.Unlock()
|
||||||
|
go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{
|
||||||
|
Source: s.source,
|
||||||
|
Destination: message.destination,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
udpConn.inputPacket(message)
|
||||||
|
}
|
Loading…
Reference in a new issue