mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-22 08:31:30 +00:00
Migrate QUIC wrapper and protocol implementations to library
This commit is contained in:
parent
1d6d3edec5
commit
bd7adcbb7e
|
@ -1,120 +0,0 @@
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
|
||||||
"github.com/sagernet/quic-go/http3"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
aTLS "github.com/sagernet/sing/common/tls"
|
|
||||||
)
|
|
||||||
|
|
||||||
type QUICConfig interface {
|
|
||||||
Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.Connection, error)
|
|
||||||
DialEarly(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.EarlyConnection, error)
|
|
||||||
CreateTransport(conn net.PacketConn, quicConnPtr *quic.EarlyConnection, serverAddr M.Socksaddr, quicConfig *quic.Config, enableDatagrams bool) http.RoundTripper
|
|
||||||
}
|
|
||||||
|
|
||||||
type QUICServerConfig interface {
|
|
||||||
Listen(conn net.PacketConn, config *quic.Config) (QUICListener, error)
|
|
||||||
ListenEarly(conn net.PacketConn, config *quic.Config) (QUICEarlyListener, error)
|
|
||||||
ConfigureHTTP3()
|
|
||||||
}
|
|
||||||
|
|
||||||
type QUICListener interface {
|
|
||||||
Accept(ctx context.Context) (quic.Connection, error)
|
|
||||||
Close() error
|
|
||||||
Addr() net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
type QUICEarlyListener interface {
|
|
||||||
Accept(ctx context.Context) (quic.EarlyConnection, error)
|
|
||||||
Close() error
|
|
||||||
Addr() net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config aTLS.Config, quicConfig *quic.Config) (quic.Connection, error) {
|
|
||||||
if quicTLSConfig, isQUICConfig := config.(QUICConfig); isQUICConfig {
|
|
||||||
return quicTLSConfig.Dial(ctx, conn, addr, quicConfig)
|
|
||||||
}
|
|
||||||
tlsConfig, err := config.Config()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return quic.Dial(ctx, conn, addr, tlsConfig, quicConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func DialEarly(ctx context.Context, conn net.PacketConn, addr net.Addr, config aTLS.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
|
|
||||||
if quicTLSConfig, isQUICConfig := config.(QUICConfig); isQUICConfig {
|
|
||||||
return quicTLSConfig.DialEarly(ctx, conn, addr, quicConfig)
|
|
||||||
}
|
|
||||||
tlsConfig, err := config.Config()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return quic.DialEarly(ctx, conn, addr, tlsConfig, quicConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateTransport(conn net.PacketConn, quicConnPtr *quic.EarlyConnection, serverAddr M.Socksaddr, config aTLS.Config, quicConfig *quic.Config, enableDatagrams bool) (http.RoundTripper, error) {
|
|
||||||
if quicTLSConfig, isQUICConfig := config.(QUICConfig); isQUICConfig {
|
|
||||||
return quicTLSConfig.CreateTransport(conn, quicConnPtr, serverAddr, quicConfig, enableDatagrams), nil
|
|
||||||
}
|
|
||||||
tlsConfig, err := config.Config()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &http3.RoundTripper{
|
|
||||||
TLSClientConfig: tlsConfig,
|
|
||||||
QuicConfig: quicConfig,
|
|
||||||
EnableDatagrams: enableDatagrams,
|
|
||||||
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
|
||||||
quicConn, err := quic.DialEarly(ctx, conn, serverAddr.UDPAddr(), tlsCfg, cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
*quicConnPtr = quicConn
|
|
||||||
return quicConn, nil
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Listen(conn net.PacketConn, config aTLS.ServerConfig, quicConfig *quic.Config) (QUICListener, error) {
|
|
||||||
if quicTLSConfig, isQUICConfig := config.(QUICServerConfig); isQUICConfig {
|
|
||||||
return quicTLSConfig.Listen(conn, quicConfig)
|
|
||||||
}
|
|
||||||
tlsConfig, err := config.Config()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return quic.Listen(conn, tlsConfig, quicConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ListenEarly(conn net.PacketConn, config aTLS.ServerConfig, quicConfig *quic.Config) (QUICEarlyListener, error) {
|
|
||||||
if quicTLSConfig, isQUICConfig := config.(QUICServerConfig); isQUICConfig {
|
|
||||||
return quicTLSConfig.ListenEarly(conn, quicConfig)
|
|
||||||
}
|
|
||||||
tlsConfig, err := config.Config()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return quic.ListenEarly(conn, tlsConfig, quicConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ConfigureHTTP3(config aTLS.ServerConfig) error {
|
|
||||||
if len(config.NextProtos()) == 0 {
|
|
||||||
config.SetNextProtos([]string{http3.NextProtoH3})
|
|
||||||
}
|
|
||||||
if quicTLSConfig, isQUICConfig := config.(QUICServerConfig); isQUICConfig {
|
|
||||||
quicTLSConfig.ConfigureHTTP3()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
tlsConfig, err := config.Config()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
http3.ConfigureTLSConfig(tlsConfig)
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -10,13 +10,13 @@ import (
|
||||||
"github.com/sagernet/cloudflare-tls"
|
"github.com/sagernet/cloudflare-tls"
|
||||||
"github.com/sagernet/quic-go/ech"
|
"github.com/sagernet/quic-go/ech"
|
||||||
"github.com/sagernet/quic-go/http3_ech"
|
"github.com/sagernet/quic-go/http3_ech"
|
||||||
"github.com/sagernet/sing-box/common/qtls"
|
"github.com/sagernet/sing-quic"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
_ qtls.QUICConfig = (*echClientConfig)(nil)
|
_ qtls.Config = (*echClientConfig)(nil)
|
||||||
_ qtls.QUICServerConfig = (*echServerConfig)(nil)
|
_ qtls.ServerConfig = (*echServerConfig)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *echClientConfig) Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.Connection, error) {
|
func (c *echClientConfig) Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.Connection, error) {
|
||||||
|
@ -43,11 +43,11 @@ func (c *echClientConfig) CreateTransport(conn net.PacketConn, quicConnPtr *quic
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *echServerConfig) Listen(conn net.PacketConn, config *quic.Config) (qtls.QUICListener, error) {
|
func (c *echServerConfig) Listen(conn net.PacketConn, config *quic.Config) (qtls.Listener, error) {
|
||||||
return quic.Listen(conn, c.config, config)
|
return quic.Listen(conn, c.config, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *echServerConfig) ListenEarly(conn net.PacketConn, config *quic.Config) (qtls.QUICEarlyListener, error) {
|
func (c *echServerConfig) ListenEarly(conn net.PacketConn, config *quic.Config) (qtls.EarlyListener, error) {
|
||||||
return quic.ListenEarly(conn, c.config, config)
|
return quic.ListenEarly(conn, c.config, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -29,6 +29,7 @@ require (
|
||||||
github.com/sagernet/sing v0.2.10-0.20230912050851-1453c7c8c20d
|
github.com/sagernet/sing v0.2.10-0.20230912050851-1453c7c8c20d
|
||||||
github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b
|
github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b
|
||||||
github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400
|
github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400
|
||||||
|
github.com/sagernet/sing-quic v0.0.0-20230915093242-b55f3531e703
|
||||||
github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0
|
github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0
|
||||||
github.com/sagernet/sing-shadowsocks2 v0.1.4-0.20230907005906-5d2917b29248
|
github.com/sagernet/sing-shadowsocks2 v0.1.4-0.20230907005906-5d2917b29248
|
||||||
github.com/sagernet/sing-shadowtls v0.1.4
|
github.com/sagernet/sing-shadowtls v0.1.4
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -118,6 +118,8 @@ github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b h1:m/UWg2voyb9
|
||||||
github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b/go.mod h1:Kg98PBJEg/08jsNFtmZWmPomhskn9Ausn50ecNm4M+8=
|
github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b/go.mod h1:Kg98PBJEg/08jsNFtmZWmPomhskn9Ausn50ecNm4M+8=
|
||||||
github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400 h1:LtpYd5c5AJtUSxjyH4KjUS8HT+2XgilyozjbCq/x3EM=
|
github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400 h1:LtpYd5c5AJtUSxjyH4KjUS8HT+2XgilyozjbCq/x3EM=
|
||||||
github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400/go.mod h1:TKxqIvfQQgd36jp2tzsPavGjYTVZilV+atip1cssjIY=
|
github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400/go.mod h1:TKxqIvfQQgd36jp2tzsPavGjYTVZilV+atip1cssjIY=
|
||||||
|
github.com/sagernet/sing-quic v0.0.0-20230915093242-b55f3531e703 h1:BbJZ5RkY3jQk5P9G5Ra0VhmDNKdT0aIP1FszEDyQL+o=
|
||||||
|
github.com/sagernet/sing-quic v0.0.0-20230915093242-b55f3531e703/go.mod h1:Mh5Senu4XDuX+RxSPQEoUB0j6kVmGais2h62Cnfj6Xk=
|
||||||
github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0 h1:9wHYWxH+fcs01PM2+DylA8LNNY3ElnZykQo9rysng8U=
|
github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0 h1:9wHYWxH+fcs01PM2+DylA8LNNY3ElnZykQo9rysng8U=
|
||||||
github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0/go.mod h1:80fNKP0wnqlu85GZXV1H1vDPC/2t+dQbFggOw4XuFUM=
|
github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0/go.mod h1:80fNKP0wnqlu85GZXV1H1vDPC/2t+dQbFggOw4XuFUM=
|
||||||
github.com/sagernet/sing-shadowsocks2 v0.1.4-0.20230907005906-5d2917b29248 h1:JTFfy/LDmVFEK4KZJEujmC1iO8+aoF4unYhhZZRzRq4=
|
github.com/sagernet/sing-shadowsocks2 v0.1.4-0.20230907005906-5d2917b29248 h1:JTFfy/LDmVFEK4KZJEujmC1iO8+aoF4unYhhZZRzRq4=
|
||||||
|
|
|
@ -9,12 +9,12 @@ import (
|
||||||
"github.com/sagernet/quic-go"
|
"github.com/sagernet/quic-go"
|
||||||
"github.com/sagernet/quic-go/congestion"
|
"github.com/sagernet/quic-go/congestion"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/qtls"
|
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/transport/hysteria"
|
"github.com/sagernet/sing-box/transport/hysteria"
|
||||||
|
"github.com/sagernet/sing-quic"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/auth"
|
"github.com/sagernet/sing/common/auth"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
@ -36,7 +36,7 @@ type Hysteria struct {
|
||||||
xplusKey []byte
|
xplusKey []byte
|
||||||
sendBPS uint64
|
sendBPS uint64
|
||||||
recvBPS uint64
|
recvBPS uint64
|
||||||
listener qtls.QUICListener
|
listener qtls.Listener
|
||||||
udpAccess sync.RWMutex
|
udpAccess sync.RWMutex
|
||||||
udpSessionId uint32
|
udpSessionId uint32
|
||||||
udpSessions map[uint32]chan *hysteria.UDPMessage
|
udpSessions map[uint32]chan *hysteria.UDPMessage
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/transport/hysteria2"
|
"github.com/sagernet/sing-quic/hysteria2"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/auth"
|
"github.com/sagernet/sing/common/auth"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
@ -26,7 +26,8 @@ var _ adapter.Inbound = (*Hysteria2)(nil)
|
||||||
type Hysteria2 struct {
|
type Hysteria2 struct {
|
||||||
myInboundAdapter
|
myInboundAdapter
|
||||||
tlsConfig tls.ServerConfig
|
tlsConfig tls.ServerConfig
|
||||||
server *hysteria2.Server
|
service *hysteria2.Service[int]
|
||||||
|
userNameList []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (*Hysteria2, error) {
|
func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (*Hysteria2, error) {
|
||||||
|
@ -84,16 +85,13 @@ func NewHysteria2(ctx context.Context, router adapter.Router, logger log.Context
|
||||||
},
|
},
|
||||||
tlsConfig: tlsConfig,
|
tlsConfig: tlsConfig,
|
||||||
}
|
}
|
||||||
server, err := hysteria2.NewServer(hysteria2.ServerOptions{
|
service, err := hysteria2.NewService[int](hysteria2.ServiceOptions{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
SendBPS: uint64(options.UpMbps * 1024 * 1024),
|
SendBPS: uint64(options.UpMbps * 1024 * 1024),
|
||||||
ReceiveBPS: uint64(options.DownMbps * 1024 * 1024),
|
ReceiveBPS: uint64(options.DownMbps * 1024 * 1024),
|
||||||
SalamanderPassword: salamanderPassword,
|
SalamanderPassword: salamanderPassword,
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
Users: common.Map(options.Users, func(it option.Hysteria2User) hysteria2.User {
|
|
||||||
return hysteria2.User(it)
|
|
||||||
}),
|
|
||||||
IgnoreClientBandwidth: options.IgnoreClientBandwidth,
|
IgnoreClientBandwidth: options.IgnoreClientBandwidth,
|
||||||
Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil),
|
Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil),
|
||||||
MasqueradeHandler: masqueradeHandler,
|
MasqueradeHandler: masqueradeHandler,
|
||||||
|
@ -101,7 +99,17 @@ func NewHysteria2(ctx context.Context, router adapter.Router, logger log.Context
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
inbound.server = server
|
userList := make([]int, 0, len(options.Users))
|
||||||
|
userNameList := make([]string, 0, len(options.Users))
|
||||||
|
userPasswordList := make([]string, 0, len(options.Users))
|
||||||
|
for index, user := range options.Users {
|
||||||
|
userList = append(userList, index)
|
||||||
|
userNameList = append(userNameList, user.Name)
|
||||||
|
userPasswordList = append(userPasswordList, user.Password)
|
||||||
|
}
|
||||||
|
service.UpdateUsers(userList, userPasswordList)
|
||||||
|
inbound.service = service
|
||||||
|
inbound.userNameList = userNameList
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,14 +117,20 @@ func (h *Hysteria2) newConnection(ctx context.Context, conn net.Conn, metadata a
|
||||||
ctx = log.ContextWithNewID(ctx)
|
ctx = log.ContextWithNewID(ctx)
|
||||||
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||||
metadata = h.createMetadata(conn, metadata)
|
metadata = h.createMetadata(conn, metadata)
|
||||||
metadata.User, _ = auth.UserFromContext[string](ctx)
|
userID, _ := auth.UserFromContext[int](ctx)
|
||||||
|
if userName := h.userNameList[userID]; userName != "" {
|
||||||
|
metadata.User = userName
|
||||||
|
}
|
||||||
return h.router.RouteConnection(ctx, conn, metadata)
|
return h.router.RouteConnection(ctx, conn, metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hysteria2) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
func (h *Hysteria2) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||||
ctx = log.ContextWithNewID(ctx)
|
ctx = log.ContextWithNewID(ctx)
|
||||||
metadata = h.createPacketMetadata(conn, metadata)
|
metadata = h.createPacketMetadata(conn, metadata)
|
||||||
metadata.User, _ = auth.UserFromContext[string](ctx)
|
userID, _ := auth.UserFromContext[int](ctx)
|
||||||
|
if userName := h.userNameList[userID]; userName != "" {
|
||||||
|
metadata.User = userName
|
||||||
|
}
|
||||||
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
|
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
|
||||||
return h.router.RoutePacketConnection(ctx, conn, metadata)
|
return h.router.RoutePacketConnection(ctx, conn, metadata)
|
||||||
}
|
}
|
||||||
|
@ -132,13 +146,13 @@ func (h *Hysteria2) Start() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return h.server.Start(packetConn)
|
return h.service.Start(packetConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hysteria2) Close() error {
|
func (h *Hysteria2) Close() error {
|
||||||
return common.Close(
|
return common.Close(
|
||||||
&h.myInboundAdapter,
|
&h.myInboundAdapter,
|
||||||
h.tlsConfig,
|
h.tlsConfig,
|
||||||
common.PtrOrNil(h.server),
|
common.PtrOrNil(h.service),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ package inbound
|
||||||
import (
|
import (
|
||||||
"github.com/sagernet/quic-go"
|
"github.com/sagernet/quic-go"
|
||||||
"github.com/sagernet/quic-go/http3"
|
"github.com/sagernet/quic-go/http3"
|
||||||
"github.com/sagernet/sing-box/common/qtls"
|
"github.com/sagernet/sing-quic"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/transport/tuic"
|
"github.com/sagernet/sing-quic/tuic"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/auth"
|
"github.com/sagernet/sing/common/auth"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
@ -25,8 +25,9 @@ var _ adapter.Inbound = (*TUIC)(nil)
|
||||||
|
|
||||||
type TUIC struct {
|
type TUIC struct {
|
||||||
myInboundAdapter
|
myInboundAdapter
|
||||||
server *tuic.Server
|
|
||||||
tlsConfig tls.ServerConfig
|
tlsConfig tls.ServerConfig
|
||||||
|
server *tuic.Service[int]
|
||||||
|
userNameList []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (*TUIC, error) {
|
func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (*TUIC, error) {
|
||||||
|
@ -38,17 +39,6 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var users []tuic.User
|
|
||||||
for index, user := range options.Users {
|
|
||||||
if user.UUID == "" {
|
|
||||||
return nil, E.New("missing uuid for user ", index)
|
|
||||||
}
|
|
||||||
userUUID, err := uuid.FromString(user.UUID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "invalid uuid for user ", index)
|
|
||||||
}
|
|
||||||
users = append(users, tuic.User{Name: user.Name, UUID: userUUID, Password: user.Password})
|
|
||||||
}
|
|
||||||
inbound := &TUIC{
|
inbound := &TUIC{
|
||||||
myInboundAdapter: myInboundAdapter{
|
myInboundAdapter: myInboundAdapter{
|
||||||
protocol: C.TypeTUIC,
|
protocol: C.TypeTUIC,
|
||||||
|
@ -60,11 +50,10 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge
|
||||||
listenOptions: options.ListenOptions,
|
listenOptions: options.ListenOptions,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
server, err := tuic.NewServer(tuic.ServerOptions{
|
service, err := tuic.NewService[int](tuic.ServiceOptions{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
Users: users,
|
|
||||||
CongestionControl: options.CongestionControl,
|
CongestionControl: options.CongestionControl,
|
||||||
AuthTimeout: time.Duration(options.AuthTimeout),
|
AuthTimeout: time.Duration(options.AuthTimeout),
|
||||||
ZeroRTTHandshake: options.ZeroRTTHandshake,
|
ZeroRTTHandshake: options.ZeroRTTHandshake,
|
||||||
|
@ -74,7 +63,26 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
inbound.server = server
|
var userList []int
|
||||||
|
var userNameList []string
|
||||||
|
var userUUIDList [][16]byte
|
||||||
|
var userPasswordList []string
|
||||||
|
for index, user := range options.Users {
|
||||||
|
if user.UUID == "" {
|
||||||
|
return nil, E.New("missing uuid for user ", index)
|
||||||
|
}
|
||||||
|
userUUID, err := uuid.FromString(user.UUID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "invalid uuid for user ", index)
|
||||||
|
}
|
||||||
|
userList = append(userList, index)
|
||||||
|
userNameList = append(userNameList, user.Name)
|
||||||
|
userUUIDList = append(userUUIDList, userUUID)
|
||||||
|
userPasswordList = append(userPasswordList, user.Password)
|
||||||
|
}
|
||||||
|
service.UpdateUsers(userList, userUUIDList, userPasswordList)
|
||||||
|
inbound.server = service
|
||||||
|
inbound.userNameList = userNameList
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,14 +90,20 @@ func (h *TUIC) newConnection(ctx context.Context, conn net.Conn, metadata adapte
|
||||||
ctx = log.ContextWithNewID(ctx)
|
ctx = log.ContextWithNewID(ctx)
|
||||||
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||||
metadata = h.createMetadata(conn, metadata)
|
metadata = h.createMetadata(conn, metadata)
|
||||||
metadata.User, _ = auth.UserFromContext[string](ctx)
|
userID, _ := auth.UserFromContext[int](ctx)
|
||||||
|
if userName := h.userNameList[userID]; userName != "" {
|
||||||
|
metadata.User = userName
|
||||||
|
}
|
||||||
return h.router.RouteConnection(ctx, conn, metadata)
|
return h.router.RouteConnection(ctx, conn, metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *TUIC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
func (h *TUIC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||||
ctx = log.ContextWithNewID(ctx)
|
ctx = log.ContextWithNewID(ctx)
|
||||||
metadata = h.createPacketMetadata(conn, metadata)
|
metadata = h.createPacketMetadata(conn, metadata)
|
||||||
metadata.User, _ = auth.UserFromContext[string](ctx)
|
userID, _ := auth.UserFromContext[int](ctx)
|
||||||
|
if userName := h.userNameList[userID]; userName != "" {
|
||||||
|
metadata.User = userName
|
||||||
|
}
|
||||||
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
|
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
|
||||||
return h.router.RoutePacketConnection(ctx, conn, metadata)
|
return h.router.RoutePacketConnection(ctx, conn, metadata)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,12 +11,12 @@ import (
|
||||||
"github.com/sagernet/quic-go/congestion"
|
"github.com/sagernet/quic-go/congestion"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
"github.com/sagernet/sing-box/common/qtls"
|
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/transport/hysteria"
|
"github.com/sagernet/sing-box/transport/hysteria"
|
||||||
|
"github.com/sagernet/sing-quic"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/transport/hysteria2"
|
"github.com/sagernet/sing-quic/hysteria2"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/transport/tuic"
|
"github.com/sagernet/sing-quic/tuic"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
|
@ -1,314 +0,0 @@
|
||||||
package hysteria2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
|
||||||
"github.com/sagernet/sing-box/common/qtls"
|
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
|
||||||
"github.com/sagernet/sing-box/transport/hysteria2/congestion"
|
|
||||||
"github.com/sagernet/sing-box/transport/hysteria2/internal/protocol"
|
|
||||||
tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion"
|
|
||||||
"github.com/sagernet/sing/common/baderror"
|
|
||||||
"github.com/sagernet/sing/common/bufio"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultStreamReceiveWindow = 8388608 // 8MB
|
|
||||||
defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB
|
|
||||||
defaultMaxIdleTimeout = 30 * time.Second
|
|
||||||
defaultKeepAlivePeriod = 10 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
type ClientOptions struct {
|
|
||||||
Context context.Context
|
|
||||||
Dialer N.Dialer
|
|
||||||
ServerAddress M.Socksaddr
|
|
||||||
SendBPS uint64
|
|
||||||
ReceiveBPS uint64
|
|
||||||
SalamanderPassword string
|
|
||||||
Password string
|
|
||||||
TLSConfig tls.Config
|
|
||||||
UDPDisabled bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type Client struct {
|
|
||||||
ctx context.Context
|
|
||||||
dialer N.Dialer
|
|
||||||
serverAddr M.Socksaddr
|
|
||||||
sendBPS uint64
|
|
||||||
receiveBPS uint64
|
|
||||||
salamanderPassword string
|
|
||||||
password string
|
|
||||||
tlsConfig tls.Config
|
|
||||||
quicConfig *quic.Config
|
|
||||||
udpDisabled bool
|
|
||||||
|
|
||||||
connAccess sync.RWMutex
|
|
||||||
conn *clientQUICConnection
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClient(options ClientOptions) (*Client, error) {
|
|
||||||
quicConfig := &quic.Config{
|
|
||||||
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
|
|
||||||
EnableDatagrams: true,
|
|
||||||
InitialStreamReceiveWindow: defaultStreamReceiveWindow,
|
|
||||||
MaxStreamReceiveWindow: defaultStreamReceiveWindow,
|
|
||||||
InitialConnectionReceiveWindow: defaultConnReceiveWindow,
|
|
||||||
MaxConnectionReceiveWindow: defaultConnReceiveWindow,
|
|
||||||
MaxIdleTimeout: defaultMaxIdleTimeout,
|
|
||||||
KeepAlivePeriod: defaultKeepAlivePeriod,
|
|
||||||
}
|
|
||||||
return &Client{
|
|
||||||
ctx: options.Context,
|
|
||||||
dialer: options.Dialer,
|
|
||||||
serverAddr: options.ServerAddress,
|
|
||||||
sendBPS: options.SendBPS,
|
|
||||||
receiveBPS: options.ReceiveBPS,
|
|
||||||
salamanderPassword: options.SalamanderPassword,
|
|
||||||
password: options.Password,
|
|
||||||
tlsConfig: options.TLSConfig,
|
|
||||||
quicConfig: quicConfig,
|
|
||||||
udpDisabled: options.UDPDisabled,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
|
|
||||||
conn := c.conn
|
|
||||||
if conn != nil && conn.active() {
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
c.connAccess.Lock()
|
|
||||||
defer c.connAccess.Unlock()
|
|
||||||
conn = c.conn
|
|
||||||
if conn != nil && conn.active() {
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
conn, err := c.offerNew(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
|
|
||||||
udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var packetConn net.PacketConn
|
|
||||||
packetConn = bufio.NewUnbindPacketConn(udpConn)
|
|
||||||
if c.salamanderPassword != "" {
|
|
||||||
packetConn = NewSalamanderConn(packetConn, []byte(c.salamanderPassword))
|
|
||||||
}
|
|
||||||
var quicConn quic.EarlyConnection
|
|
||||||
http3Transport, err := qtls.CreateTransport(packetConn, &quicConn, c.serverAddr, c.tlsConfig, c.quicConfig, true)
|
|
||||||
if err != nil {
|
|
||||||
udpConn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
request := &http.Request{
|
|
||||||
Method: http.MethodPost,
|
|
||||||
URL: &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: protocol.URLHost,
|
|
||||||
Path: protocol.URLPath,
|
|
||||||
},
|
|
||||||
Header: make(http.Header),
|
|
||||||
}
|
|
||||||
protocol.AuthRequestToHeader(request.Header, protocol.AuthRequest{Auth: c.password, Rx: c.receiveBPS})
|
|
||||||
response, err := http3Transport.RoundTrip(request.WithContext(ctx))
|
|
||||||
if err != nil {
|
|
||||||
if quicConn != nil {
|
|
||||||
quicConn.CloseWithError(0, "")
|
|
||||||
}
|
|
||||||
udpConn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if response.StatusCode != protocol.StatusAuthOK {
|
|
||||||
if quicConn != nil {
|
|
||||||
quicConn.CloseWithError(0, "")
|
|
||||||
}
|
|
||||||
udpConn.Close()
|
|
||||||
return nil, E.New("authentication failed, status code: ", response.StatusCode)
|
|
||||||
}
|
|
||||||
response.Body.Close()
|
|
||||||
authResponse := protocol.AuthResponseFromHeader(response.Header)
|
|
||||||
actualTx := authResponse.Rx
|
|
||||||
if actualTx == 0 || actualTx > c.sendBPS {
|
|
||||||
actualTx = c.sendBPS
|
|
||||||
}
|
|
||||||
if !authResponse.RxAuto && actualTx > 0 {
|
|
||||||
quicConn.SetCongestionControl(congestion.NewBrutalSender(actualTx))
|
|
||||||
} else {
|
|
||||||
quicConn.SetCongestionControl(tuicCongestion.NewBBRSender(
|
|
||||||
tuicCongestion.DefaultClock{},
|
|
||||||
tuicCongestion.GetInitialPacketSize(quicConn.RemoteAddr()),
|
|
||||||
tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
|
|
||||||
tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
conn := &clientQUICConnection{
|
|
||||||
quicConn: quicConn,
|
|
||||||
rawConn: udpConn,
|
|
||||||
connDone: make(chan struct{}),
|
|
||||||
udpDisabled: c.udpDisabled || !authResponse.UDPEnabled,
|
|
||||||
udpConnMap: make(map[uint32]*udpPacketConn),
|
|
||||||
}
|
|
||||||
if !c.udpDisabled {
|
|
||||||
go c.loopMessages(conn)
|
|
||||||
}
|
|
||||||
c.conn = conn
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
|
|
||||||
conn, err := c.offer(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stream, err := conn.quicConn.OpenStream()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &clientConn{
|
|
||||||
Stream: stream,
|
|
||||||
destination: destination,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
|
|
||||||
if c.udpDisabled {
|
|
||||||
return nil, os.ErrInvalid
|
|
||||||
}
|
|
||||||
conn, err := c.offer(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if conn.udpDisabled {
|
|
||||||
return nil, E.New("UDP disabled by server")
|
|
||||||
}
|
|
||||||
var sessionID uint32
|
|
||||||
clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, func() {
|
|
||||||
conn.udpAccess.Lock()
|
|
||||||
delete(conn.udpConnMap, sessionID)
|
|
||||||
conn.udpAccess.Unlock()
|
|
||||||
})
|
|
||||||
conn.udpAccess.Lock()
|
|
||||||
sessionID = conn.udpSessionID
|
|
||||||
conn.udpSessionID++
|
|
||||||
conn.udpConnMap[sessionID] = clientPacketConn
|
|
||||||
conn.udpAccess.Unlock()
|
|
||||||
clientPacketConn.sessionID = sessionID
|
|
||||||
return clientPacketConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) CloseWithError(err error) error {
|
|
||||||
conn := c.conn
|
|
||||||
if conn != nil {
|
|
||||||
conn.closeWithError(err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientQUICConnection struct {
|
|
||||||
quicConn quic.Connection
|
|
||||||
rawConn io.Closer
|
|
||||||
closeOnce sync.Once
|
|
||||||
connDone chan struct{}
|
|
||||||
connErr error
|
|
||||||
udpDisabled bool
|
|
||||||
udpAccess sync.RWMutex
|
|
||||||
udpConnMap map[uint32]*udpPacketConn
|
|
||||||
udpSessionID uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientQUICConnection) active() bool {
|
|
||||||
select {
|
|
||||||
case <-c.quicConn.Context().Done():
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-c.connDone:
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientQUICConnection) closeWithError(err error) {
|
|
||||||
c.closeOnce.Do(func() {
|
|
||||||
c.connErr = err
|
|
||||||
close(c.connDone)
|
|
||||||
c.quicConn.CloseWithError(0, "")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientConn struct {
|
|
||||||
quic.Stream
|
|
||||||
destination M.Socksaddr
|
|
||||||
requestWritten bool
|
|
||||||
responseRead bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) NeedHandshake() bool {
|
|
||||||
return !c.requestWritten
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) Read(p []byte) (n int, err error) {
|
|
||||||
if c.responseRead {
|
|
||||||
n, err = c.Stream.Read(p)
|
|
||||||
return n, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
status, errorMessage, err := protocol.ReadTCPResponse(c.Stream)
|
|
||||||
if err != nil {
|
|
||||||
return 0, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
if !status {
|
|
||||||
err = E.New("remote error: ", errorMessage)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.responseRead = true
|
|
||||||
n, err = c.Stream.Read(p)
|
|
||||||
return n, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
|
||||||
if !c.requestWritten {
|
|
||||||
buffer := protocol.WriteTCPRequest(c.destination.String(), p)
|
|
||||||
defer buffer.Release()
|
|
||||||
_, err = c.Stream.Write(buffer.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.requestWritten = true
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
n, err = c.Stream.Write(p)
|
|
||||||
return n, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) LocalAddr() net.Addr {
|
|
||||||
return M.Socksaddr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) RemoteAddr() net.Addr {
|
|
||||||
return M.Socksaddr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) Close() error {
|
|
||||||
c.Stream.CancelRead(0)
|
|
||||||
return c.Stream.Close()
|
|
||||||
}
|
|
|
@ -1,47 +0,0 @@
|
||||||
package hysteria2
|
|
||||||
|
|
||||||
import E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
|
|
||||||
func (c *Client) loopMessages(conn *clientQUICConnection) {
|
|
||||||
for {
|
|
||||||
message, err := conn.quicConn.ReceiveMessage(c.ctx)
|
|
||||||
if err != nil {
|
|
||||||
conn.closeWithError(E.Cause(err, "receive message"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
hErr := c.handleMessage(conn, message)
|
|
||||||
if hErr != nil {
|
|
||||||
conn.closeWithError(E.Cause(hErr, "handle message"))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
|
|
||||||
message := allocMessage()
|
|
||||||
err := decodeUDPMessage(message, data)
|
|
||||||
if err != nil {
|
|
||||||
message.release()
|
|
||||||
return E.Cause(err, "decode UDP message")
|
|
||||||
}
|
|
||||||
conn.handleUDPMessage(message)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) {
|
|
||||||
c.udpAccess.RLock()
|
|
||||||
udpConn, loaded := c.udpConnMap[message.sessionID]
|
|
||||||
c.udpAccess.RUnlock()
|
|
||||||
if !loaded {
|
|
||||||
message.releaseMessage()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-udpConn.ctx.Done():
|
|
||||||
message.releaseMessage()
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
udpConn.inputPacket(message)
|
|
||||||
}
|
|
|
@ -1,151 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go/congestion"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
initMaxDatagramSize = 1252
|
|
||||||
|
|
||||||
pktInfoSlotCount = 4
|
|
||||||
minSampleCount = 50
|
|
||||||
minAckRate = 0.8
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ congestion.CongestionControl = &BrutalSender{}
|
|
||||||
|
|
||||||
type BrutalSender struct {
|
|
||||||
rttStats congestion.RTTStatsProvider
|
|
||||||
bps congestion.ByteCount
|
|
||||||
maxDatagramSize congestion.ByteCount
|
|
||||||
pacer *pacer
|
|
||||||
|
|
||||||
pktInfoSlots [pktInfoSlotCount]pktInfo
|
|
||||||
ackRate float64
|
|
||||||
}
|
|
||||||
|
|
||||||
type pktInfo struct {
|
|
||||||
Timestamp int64
|
|
||||||
AckCount uint64
|
|
||||||
LossCount uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBrutalSender(bps uint64) *BrutalSender {
|
|
||||||
bs := &BrutalSender{
|
|
||||||
bps: congestion.ByteCount(bps),
|
|
||||||
maxDatagramSize: initMaxDatagramSize,
|
|
||||||
ackRate: 1,
|
|
||||||
}
|
|
||||||
bs.pacer = newPacer(func() congestion.ByteCount {
|
|
||||||
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
|
|
||||||
})
|
|
||||||
return bs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
|
|
||||||
b.rttStats = rttStats
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
|
|
||||||
return b.pacer.TimeUntilSend()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) HasPacingBudget(now time.Time) bool {
|
|
||||||
return b.pacer.Budget(now) >= b.maxDatagramSize
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
|
||||||
return bytesInFlight < b.GetCongestionWindow()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
|
|
||||||
rtt := b.rttStats.SmoothedRTT()
|
|
||||||
if rtt <= 0 {
|
|
||||||
return 10240
|
|
||||||
}
|
|
||||||
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
|
|
||||||
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
|
|
||||||
) {
|
|
||||||
b.pacer.SentPacket(sentTime, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
|
|
||||||
priorInFlight congestion.ByteCount, eventTime time.Time,
|
|
||||||
) {
|
|
||||||
currentTimestamp := eventTime.Unix()
|
|
||||||
slot := currentTimestamp % pktInfoSlotCount
|
|
||||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
|
||||||
b.pktInfoSlots[slot].AckCount++
|
|
||||||
} else {
|
|
||||||
// uninitialized slot or too old, reset
|
|
||||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
|
||||||
b.pktInfoSlots[slot].AckCount = 1
|
|
||||||
b.pktInfoSlots[slot].LossCount = 0
|
|
||||||
}
|
|
||||||
b.updateAckRate(currentTimestamp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount,
|
|
||||||
priorInFlight congestion.ByteCount,
|
|
||||||
) {
|
|
||||||
currentTimestamp := time.Now().Unix()
|
|
||||||
slot := currentTimestamp % pktInfoSlotCount
|
|
||||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
|
||||||
b.pktInfoSlots[slot].LossCount++
|
|
||||||
} else {
|
|
||||||
// uninitialized slot or too old, reset
|
|
||||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
|
||||||
b.pktInfoSlots[slot].AckCount = 0
|
|
||||||
b.pktInfoSlots[slot].LossCount = 1
|
|
||||||
}
|
|
||||||
b.updateAckRate(currentTimestamp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
|
|
||||||
b.maxDatagramSize = size
|
|
||||||
b.pacer.SetMaxDatagramSize(size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
|
|
||||||
minTimestamp := currentTimestamp - pktInfoSlotCount
|
|
||||||
var ackCount, lossCount uint64
|
|
||||||
for _, info := range b.pktInfoSlots {
|
|
||||||
if info.Timestamp < minTimestamp {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ackCount += info.AckCount
|
|
||||||
lossCount += info.LossCount
|
|
||||||
}
|
|
||||||
if ackCount+lossCount < minSampleCount {
|
|
||||||
b.ackRate = 1
|
|
||||||
}
|
|
||||||
rate := float64(ackCount) / float64(ackCount+lossCount)
|
|
||||||
if rate < minAckRate {
|
|
||||||
b.ackRate = minAckRate
|
|
||||||
}
|
|
||||||
b.ackRate = rate
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) InSlowStart() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) InRecovery() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BrutalSender) MaybeExitSlowStart() {}
|
|
||||||
|
|
||||||
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
|
|
||||||
|
|
||||||
func maxDuration(a, b time.Duration) time.Duration {
|
|
||||||
if a > b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
|
@ -1,86 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go/congestion"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
maxBurstPackets = 10
|
|
||||||
minPacingDelay = time.Millisecond
|
|
||||||
)
|
|
||||||
|
|
||||||
// The pacer implements a token bucket pacing algorithm.
|
|
||||||
type pacer struct {
|
|
||||||
budgetAtLastSent congestion.ByteCount
|
|
||||||
maxDatagramSize congestion.ByteCount
|
|
||||||
lastSentTime time.Time
|
|
||||||
getBandwidth func() congestion.ByteCount // in bytes/s
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPacer(getBandwidth func() congestion.ByteCount) *pacer {
|
|
||||||
p := &pacer{
|
|
||||||
budgetAtLastSent: maxBurstPackets * initMaxDatagramSize,
|
|
||||||
maxDatagramSize: initMaxDatagramSize,
|
|
||||||
getBandwidth: getBandwidth,
|
|
||||||
}
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
|
|
||||||
budget := p.Budget(sendTime)
|
|
||||||
if size > budget {
|
|
||||||
p.budgetAtLastSent = 0
|
|
||||||
} else {
|
|
||||||
p.budgetAtLastSent = budget - size
|
|
||||||
}
|
|
||||||
p.lastSentTime = sendTime
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
|
|
||||||
if p.lastSentTime.IsZero() {
|
|
||||||
return p.maxBurstSize()
|
|
||||||
}
|
|
||||||
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
|
||||||
return minByteCount(p.maxBurstSize(), budget)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pacer) maxBurstSize() congestion.ByteCount {
|
|
||||||
return maxByteCount(
|
|
||||||
congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
|
|
||||||
maxBurstPackets*p.maxDatagramSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TimeUntilSend returns when the next packet should be sent.
|
|
||||||
// It returns the zero value of time.Time if a packet can be sent immediately.
|
|
||||||
func (p *pacer) TimeUntilSend() time.Time {
|
|
||||||
if p.budgetAtLastSent >= p.maxDatagramSize {
|
|
||||||
return time.Time{}
|
|
||||||
}
|
|
||||||
return p.lastSentTime.Add(maxDuration(
|
|
||||||
minPacingDelay,
|
|
||||||
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
|
|
||||||
float64(p.getBandwidth())))*time.Nanosecond,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
|
|
||||||
p.maxDatagramSize = s
|
|
||||||
}
|
|
||||||
|
|
||||||
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
|
||||||
if a < b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
|
||||||
if a < b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
|
@ -1,68 +0,0 @@
|
||||||
package protocol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
URLHost = "hysteria"
|
|
||||||
URLPath = "/auth"
|
|
||||||
|
|
||||||
RequestHeaderAuth = "Hysteria-Auth"
|
|
||||||
ResponseHeaderUDPEnabled = "Hysteria-UDP"
|
|
||||||
CommonHeaderCCRX = "Hysteria-CC-RX"
|
|
||||||
CommonHeaderPadding = "Hysteria-Padding"
|
|
||||||
|
|
||||||
StatusAuthOK = 233
|
|
||||||
)
|
|
||||||
|
|
||||||
// AuthRequest is what client sends to server for authentication.
|
|
||||||
type AuthRequest struct {
|
|
||||||
Auth string
|
|
||||||
Rx uint64 // 0 = unknown, client asks server to use bandwidth detection
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthResponse is what server sends to client when authentication is passed.
|
|
||||||
type AuthResponse struct {
|
|
||||||
UDPEnabled bool
|
|
||||||
Rx uint64 // 0 = unlimited
|
|
||||||
RxAuto bool // true = server asks client to use bandwidth detection
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuthRequestFromHeader(h http.Header) AuthRequest {
|
|
||||||
rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
|
|
||||||
return AuthRequest{
|
|
||||||
Auth: h.Get(RequestHeaderAuth),
|
|
||||||
Rx: rx,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuthRequestToHeader(h http.Header, req AuthRequest) {
|
|
||||||
h.Set(RequestHeaderAuth, req.Auth)
|
|
||||||
h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10))
|
|
||||||
h.Set(CommonHeaderPadding, authRequestPadding.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuthResponseFromHeader(h http.Header) AuthResponse {
|
|
||||||
resp := AuthResponse{}
|
|
||||||
resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled))
|
|
||||||
rxStr := h.Get(CommonHeaderCCRX)
|
|
||||||
if rxStr == "auto" {
|
|
||||||
// Special case for server requesting client to use bandwidth detection
|
|
||||||
resp.RxAuto = true
|
|
||||||
} else {
|
|
||||||
resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64)
|
|
||||||
}
|
|
||||||
return resp
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuthResponseToHeader(h http.Header, resp AuthResponse) {
|
|
||||||
h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled))
|
|
||||||
if resp.RxAuto {
|
|
||||||
h.Set(CommonHeaderCCRX, "auto")
|
|
||||||
} else {
|
|
||||||
h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10))
|
|
||||||
}
|
|
||||||
h.Set(CommonHeaderPadding, authResponsePadding.String())
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
package protocol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
|
||||||
)
|
|
||||||
|
|
||||||
// padding specifies a half-open range [Min, Max).
|
|
||||||
type padding struct {
|
|
||||||
Min int
|
|
||||||
Max int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p padding) String() string {
|
|
||||||
n := p.Min + rand.Intn(p.Max-p.Min)
|
|
||||||
bs := make([]byte, n)
|
|
||||||
for i := range bs {
|
|
||||||
bs[i] = paddingChars[rand.Intn(len(paddingChars))]
|
|
||||||
}
|
|
||||||
return string(bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
authRequestPadding = padding{Min: 256, Max: 2048}
|
|
||||||
authResponsePadding = padding{Min: 256, Max: 2048}
|
|
||||||
tcpRequestPadding = padding{Min: 64, Max: 512}
|
|
||||||
tcpResponsePadding = padding{Min: 128, Max: 1024}
|
|
||||||
)
|
|
|
@ -1,266 +0,0 @@
|
||||||
package protocol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go/quicvarint"
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
"github.com/sagernet/sing/common/rw"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
FrameTypeTCPRequest = 0x401
|
|
||||||
|
|
||||||
// Max length values are for preventing DoS attacks
|
|
||||||
|
|
||||||
MaxAddressLength = 2048
|
|
||||||
MaxMessageLength = 2048
|
|
||||||
MaxPaddingLength = 4096
|
|
||||||
|
|
||||||
MaxUDPSize = 4096
|
|
||||||
|
|
||||||
maxVarInt1 = 63
|
|
||||||
maxVarInt2 = 16383
|
|
||||||
maxVarInt4 = 1073741823
|
|
||||||
maxVarInt8 = 4611686018427387903
|
|
||||||
)
|
|
||||||
|
|
||||||
// TCPRequest format:
|
|
||||||
// 0x401 (QUIC varint)
|
|
||||||
// Address length (QUIC varint)
|
|
||||||
// Address (bytes)
|
|
||||||
// Padding length (QUIC varint)
|
|
||||||
// Padding (bytes)
|
|
||||||
|
|
||||||
func ReadTCPRequest(r io.Reader) (string, error) {
|
|
||||||
bReader := quicvarint.NewReader(r)
|
|
||||||
addrLen, err := quicvarint.Read(bReader)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if addrLen == 0 || addrLen > MaxAddressLength {
|
|
||||||
return "", E.New("invalid address length")
|
|
||||||
}
|
|
||||||
addrBuf := make([]byte, addrLen)
|
|
||||||
_, err = io.ReadFull(r, addrBuf)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
paddingLen, err := quicvarint.Read(bReader)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if paddingLen > MaxPaddingLength {
|
|
||||||
return "", E.New("invalid padding length")
|
|
||||||
}
|
|
||||||
if paddingLen > 0 {
|
|
||||||
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return string(addrBuf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteTCPRequest(addr string, payload []byte) *buf.Buffer {
|
|
||||||
padding := tcpRequestPadding.String()
|
|
||||||
paddingLen := len(padding)
|
|
||||||
addrLen := len(addr)
|
|
||||||
sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
|
|
||||||
int(quicvarint.Len(uint64(addrLen))) + addrLen +
|
|
||||||
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
|
||||||
buffer := buf.NewSize(sz + len(payload))
|
|
||||||
bufferContent := buffer.Extend(sz)
|
|
||||||
i := varintPut(bufferContent, FrameTypeTCPRequest)
|
|
||||||
i += varintPut(bufferContent[i:], uint64(addrLen))
|
|
||||||
i += copy(bufferContent[i:], addr)
|
|
||||||
i += varintPut(bufferContent[i:], uint64(paddingLen))
|
|
||||||
copy(bufferContent[i:], padding)
|
|
||||||
buffer.Write(payload)
|
|
||||||
return buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
// TCPResponse format:
|
|
||||||
// Status (byte, 0=ok, 1=error)
|
|
||||||
// Message length (QUIC varint)
|
|
||||||
// Message (bytes)
|
|
||||||
// Padding length (QUIC varint)
|
|
||||||
// Padding (bytes)
|
|
||||||
|
|
||||||
func ReadTCPResponse(r io.Reader) (bool, string, error) {
|
|
||||||
var status [1]byte
|
|
||||||
if _, err := io.ReadFull(r, status[:]); err != nil {
|
|
||||||
return false, "", err
|
|
||||||
}
|
|
||||||
bReader := quicvarint.NewReader(r)
|
|
||||||
msg, err := ReadVString(bReader)
|
|
||||||
if err != nil {
|
|
||||||
return false, "", err
|
|
||||||
}
|
|
||||||
paddingLen, err := quicvarint.Read(bReader)
|
|
||||||
if err != nil {
|
|
||||||
return false, "", err
|
|
||||||
}
|
|
||||||
if paddingLen > MaxPaddingLength {
|
|
||||||
return false, "", E.New("invalid padding length")
|
|
||||||
}
|
|
||||||
if paddingLen > 0 {
|
|
||||||
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
|
||||||
if err != nil {
|
|
||||||
return false, "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return status[0] == 0, msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer {
|
|
||||||
padding := tcpResponsePadding.String()
|
|
||||||
paddingLen := len(padding)
|
|
||||||
msgLen := len(msg)
|
|
||||||
sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
|
|
||||||
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
|
||||||
buffer := buf.NewSize(sz + len(payload))
|
|
||||||
if ok {
|
|
||||||
buffer.WriteByte(0)
|
|
||||||
} else {
|
|
||||||
buffer.WriteByte(1)
|
|
||||||
}
|
|
||||||
WriteVString(buffer, msg)
|
|
||||||
WriteUVariant(buffer, uint64(paddingLen))
|
|
||||||
buffer.Extend(paddingLen)
|
|
||||||
buffer.Write(payload)
|
|
||||||
return buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
// UDPMessage format:
|
|
||||||
// Session ID (uint32 BE)
|
|
||||||
// Packet ID (uint16 BE)
|
|
||||||
// Fragment ID (uint8)
|
|
||||||
// Fragment count (uint8)
|
|
||||||
// Address length (QUIC varint)
|
|
||||||
// Address (bytes)
|
|
||||||
// Data...
|
|
||||||
|
|
||||||
type UDPMessage struct {
|
|
||||||
SessionID uint32 // 4
|
|
||||||
PacketID uint16 // 2
|
|
||||||
FragID uint8 // 1
|
|
||||||
FragCount uint8 // 1
|
|
||||||
Addr string // varint + bytes
|
|
||||||
Data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *UDPMessage) HeaderSize() int {
|
|
||||||
lAddr := len(m.Addr)
|
|
||||||
return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *UDPMessage) Size() int {
|
|
||||||
return m.HeaderSize() + len(m.Data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *UDPMessage) Serialize(buf []byte) int {
|
|
||||||
// Make sure the buffer is big enough
|
|
||||||
if len(buf) < m.Size() {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint32(buf, m.SessionID)
|
|
||||||
binary.BigEndian.PutUint16(buf[4:], m.PacketID)
|
|
||||||
buf[6] = m.FragID
|
|
||||||
buf[7] = m.FragCount
|
|
||||||
i := varintPut(buf[8:], uint64(len(m.Addr)))
|
|
||||||
i += copy(buf[8+i:], m.Addr)
|
|
||||||
i += copy(buf[8+i:], m.Data)
|
|
||||||
return 8 + i
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
|
|
||||||
m := &UDPMessage{}
|
|
||||||
buf := bytes.NewBuffer(msg)
|
|
||||||
if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
lAddr, err := quicvarint.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if lAddr == 0 || lAddr > MaxMessageLength {
|
|
||||||
return nil, E.New("invalid address length")
|
|
||||||
}
|
|
||||||
bs := buf.Bytes()
|
|
||||||
m.Addr = string(bs[:lAddr])
|
|
||||||
m.Data = bs[lAddr:]
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadVString(reader io.Reader) (string, error) {
|
|
||||||
length, err := quicvarint.Read(quicvarint.NewReader(reader))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
value, err := rw.ReadBytes(reader, int(length))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return string(value), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteVString(writer io.Writer, value string) error {
|
|
||||||
err := WriteUVariant(writer, uint64(len(value)))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return rw.WriteString(writer, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteUVariant(writer io.Writer, value uint64) error {
|
|
||||||
var b [8]byte
|
|
||||||
return common.Error(writer.Write(b[:varintPut(b[:], value)]))
|
|
||||||
}
|
|
||||||
|
|
||||||
// varintPut is like quicvarint.Append, but instead of appending to a slice,
|
|
||||||
// it writes to a fixed-size buffer. Returns the number of bytes written.
|
|
||||||
func varintPut(b []byte, i uint64) int {
|
|
||||||
if i <= maxVarInt1 {
|
|
||||||
b[0] = uint8(i)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
if i <= maxVarInt2 {
|
|
||||||
b[0] = uint8(i>>8) | 0x40
|
|
||||||
b[1] = uint8(i)
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
if i <= maxVarInt4 {
|
|
||||||
b[0] = uint8(i>>24) | 0x80
|
|
||||||
b[1] = uint8(i >> 16)
|
|
||||||
b[2] = uint8(i >> 8)
|
|
||||||
b[3] = uint8(i)
|
|
||||||
return 4
|
|
||||||
}
|
|
||||||
if i <= maxVarInt8 {
|
|
||||||
b[0] = uint8(i>>56) | 0xc0
|
|
||||||
b[1] = uint8(i >> 48)
|
|
||||||
b[2] = uint8(i >> 40)
|
|
||||||
b[3] = uint8(i >> 32)
|
|
||||||
b[4] = uint8(i >> 24)
|
|
||||||
b[5] = uint8(i >> 16)
|
|
||||||
b[6] = uint8(i >> 8)
|
|
||||||
b[7] = uint8(i)
|
|
||||||
return 8
|
|
||||||
}
|
|
||||||
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
|
|
||||||
}
|
|
|
@ -1,450 +0,0 @@
|
||||||
package hysteria2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
|
||||||
"github.com/sagernet/quic-go/quicvarint"
|
|
||||||
"github.com/sagernet/sing-box/transport/hysteria2/internal/protocol"
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/atomic"
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/cache"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
var udpMessagePool = sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
return new(udpMessage)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func allocMessage() *udpMessage {
|
|
||||||
message := udpMessagePool.Get().(*udpMessage)
|
|
||||||
message.referenced = true
|
|
||||||
return message
|
|
||||||
}
|
|
||||||
|
|
||||||
func releaseMessages(messages []*udpMessage) {
|
|
||||||
for _, message := range messages {
|
|
||||||
if message != nil {
|
|
||||||
message.release()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type udpMessage struct {
|
|
||||||
sessionID uint32
|
|
||||||
packetID uint16
|
|
||||||
fragmentID uint8
|
|
||||||
fragmentTotal uint8
|
|
||||||
destination string
|
|
||||||
data *buf.Buffer
|
|
||||||
referenced bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *udpMessage) release() {
|
|
||||||
if !m.referenced {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*m = udpMessage{}
|
|
||||||
udpMessagePool.Put(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *udpMessage) releaseMessage() {
|
|
||||||
m.data.Release()
|
|
||||||
m.release()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *udpMessage) pack() *buf.Buffer {
|
|
||||||
buffer := buf.NewSize(m.headerSize() + m.data.Len())
|
|
||||||
common.Must(
|
|
||||||
binary.Write(buffer, binary.BigEndian, m.sessionID),
|
|
||||||
binary.Write(buffer, binary.BigEndian, m.packetID),
|
|
||||||
binary.Write(buffer, binary.BigEndian, m.fragmentID),
|
|
||||||
binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
|
|
||||||
protocol.WriteVString(buffer, m.destination),
|
|
||||||
common.Error(buffer.Write(m.data.Bytes())),
|
|
||||||
)
|
|
||||||
return buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *udpMessage) headerSize() int {
|
|
||||||
return 8 + int(quicvarint.Len(uint64(len(m.destination)))) + len(m.destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
|
|
||||||
if message.data.Len() <= maxPacketSize {
|
|
||||||
return []*udpMessage{message}
|
|
||||||
}
|
|
||||||
var fragments []*udpMessage
|
|
||||||
originPacket := message.data.Bytes()
|
|
||||||
udpMTU := maxPacketSize - message.headerSize()
|
|
||||||
for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
|
|
||||||
fragment := allocMessage()
|
|
||||||
*fragment = *message
|
|
||||||
if remaining > udpMTU {
|
|
||||||
fragment.data = buf.As(originPacket[:udpMTU])
|
|
||||||
originPacket = originPacket[udpMTU:]
|
|
||||||
} else {
|
|
||||||
fragment.data = buf.As(originPacket)
|
|
||||||
originPacket = nil
|
|
||||||
}
|
|
||||||
fragments = append(fragments, fragment)
|
|
||||||
}
|
|
||||||
fragmentTotal := uint16(len(fragments))
|
|
||||||
for index, fragment := range fragments {
|
|
||||||
fragment.fragmentID = uint8(index)
|
|
||||||
fragment.fragmentTotal = uint8(fragmentTotal)
|
|
||||||
/*if index > 0 {
|
|
||||||
fragment.destination = ""
|
|
||||||
// not work in hysteria
|
|
||||||
}*/
|
|
||||||
}
|
|
||||||
return fragments
|
|
||||||
}
|
|
||||||
|
|
||||||
type udpPacketConn struct {
|
|
||||||
ctx context.Context
|
|
||||||
cancel common.ContextCancelCauseFunc
|
|
||||||
sessionID uint32
|
|
||||||
quicConn quic.Connection
|
|
||||||
data chan *udpMessage
|
|
||||||
udpMTU int
|
|
||||||
udpMTUTime time.Time
|
|
||||||
packetId atomic.Uint32
|
|
||||||
closeOnce sync.Once
|
|
||||||
defragger *udpDefragger
|
|
||||||
onDestroy func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn {
|
|
||||||
ctx, cancel := common.ContextWithCancelCause(ctx)
|
|
||||||
return &udpPacketConn{
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
quicConn: quicConn,
|
|
||||||
data: make(chan *udpMessage, 64),
|
|
||||||
defragger: newUDPDefragger(),
|
|
||||||
onDestroy: onDestroy,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
|
||||||
select {
|
|
||||||
case p := <-c.data:
|
|
||||||
buffer = p.data
|
|
||||||
destination = M.ParseSocksaddr(p.destination)
|
|
||||||
p.release()
|
|
||||||
return
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
|
||||||
select {
|
|
||||||
case p := <-c.data:
|
|
||||||
_, err = buffer.ReadOnceFrom(p.data)
|
|
||||||
destination = M.ParseSocksaddr(p.destination)
|
|
||||||
p.releaseMessage()
|
|
||||||
return
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
|
|
||||||
select {
|
|
||||||
case p := <-c.data:
|
|
||||||
_, err = newBuffer().ReadOnceFrom(p.data)
|
|
||||||
destination = M.ParseSocksaddr(p.destination)
|
|
||||||
p.releaseMessage()
|
|
||||||
return
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
||||||
select {
|
|
||||||
case pkt := <-c.data:
|
|
||||||
n = copy(p, pkt.data.Bytes())
|
|
||||||
destination := M.ParseSocksaddr(pkt.destination)
|
|
||||||
if destination.IsFqdn() {
|
|
||||||
addr = destination
|
|
||||||
} else {
|
|
||||||
addr = destination.UDPAddr()
|
|
||||||
}
|
|
||||||
pkt.releaseMessage()
|
|
||||||
return n, addr, nil
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return 0, nil, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) needFragment() bool {
|
|
||||||
nowTime := time.Now()
|
|
||||||
if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second {
|
|
||||||
c.udpMTUTime = nowTime
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|
||||||
defer buffer.Release()
|
|
||||||
select {
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return net.ErrClosed
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if buffer.Len() > 0xffff {
|
|
||||||
return quic.ErrMessageTooLarge(0xffff)
|
|
||||||
}
|
|
||||||
packetId := c.packetId.Add(1)
|
|
||||||
if packetId > math.MaxUint16 {
|
|
||||||
c.packetId.Store(0)
|
|
||||||
packetId = 0
|
|
||||||
}
|
|
||||||
message := allocMessage()
|
|
||||||
*message = udpMessage{
|
|
||||||
sessionID: c.sessionID,
|
|
||||||
packetID: uint16(packetId),
|
|
||||||
fragmentTotal: 1,
|
|
||||||
destination: destination.String(),
|
|
||||||
data: buffer,
|
|
||||||
}
|
|
||||||
defer message.releaseMessage()
|
|
||||||
var err error
|
|
||||||
if c.needFragment() && buffer.Len() > c.udpMTU {
|
|
||||||
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
|
||||||
} else {
|
|
||||||
err = c.writePacket(message)
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var tooLargeErr quic.ErrMessageTooLarge
|
|
||||||
if !errors.As(err, &tooLargeErr) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.udpMTU = int(tooLargeErr)
|
|
||||||
c.udpMTUTime = time.Now()
|
|
||||||
return c.writePackets(fragUDPMessage(message, c.udpMTU))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
||||||
select {
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return 0, net.ErrClosed
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if len(p) > 0xffff {
|
|
||||||
return 0, quic.ErrMessageTooLarge(0xffff)
|
|
||||||
}
|
|
||||||
packetId := c.packetId.Add(1)
|
|
||||||
if packetId > math.MaxUint16 {
|
|
||||||
c.packetId.Store(0)
|
|
||||||
packetId = 0
|
|
||||||
}
|
|
||||||
message := allocMessage()
|
|
||||||
*message = udpMessage{
|
|
||||||
sessionID: c.sessionID,
|
|
||||||
packetID: uint16(packetId),
|
|
||||||
fragmentTotal: 1,
|
|
||||||
destination: addr.String(),
|
|
||||||
data: buf.As(p),
|
|
||||||
}
|
|
||||||
if c.needFragment() && len(p) > c.udpMTU {
|
|
||||||
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
|
||||||
if err == nil {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = c.writePacket(message)
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
var tooLargeErr quic.ErrMessageTooLarge
|
|
||||||
if !errors.As(err, &tooLargeErr) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.udpMTU = int(tooLargeErr)
|
|
||||||
c.udpMTUTime = time.Now()
|
|
||||||
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
|
||||||
if err == nil {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) inputPacket(message *udpMessage) {
|
|
||||||
if message.fragmentTotal <= 1 {
|
|
||||||
select {
|
|
||||||
case c.data <- message:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
newMessage := c.defragger.feed(message)
|
|
||||||
if newMessage != nil {
|
|
||||||
select {
|
|
||||||
case c.data <- newMessage:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
|
|
||||||
defer releaseMessages(messages)
|
|
||||||
for _, message := range messages {
|
|
||||||
err := c.writePacket(message)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) writePacket(message *udpMessage) error {
|
|
||||||
buffer := message.pack()
|
|
||||||
defer buffer.Release()
|
|
||||||
return c.quicConn.SendMessage(buffer.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) Close() error {
|
|
||||||
c.closeOnce.Do(func() {
|
|
||||||
c.closeWithError(os.ErrClosed)
|
|
||||||
c.onDestroy()
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) closeWithError(err error) {
|
|
||||||
c.cancel(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) LocalAddr() net.Addr {
|
|
||||||
return c.quicConn.LocalAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) SetDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
type udpDefragger struct {
|
|
||||||
packetMap *cache.LruCache[uint16, *packetItem]
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUDPDefragger() *udpDefragger {
|
|
||||||
return &udpDefragger{
|
|
||||||
packetMap: cache.New(
|
|
||||||
cache.WithAge[uint16, *packetItem](10),
|
|
||||||
cache.WithUpdateAgeOnGet[uint16, *packetItem](),
|
|
||||||
cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
|
|
||||||
releaseMessages(value.messages)
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type packetItem struct {
|
|
||||||
access sync.Mutex
|
|
||||||
messages []*udpMessage
|
|
||||||
count uint8
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
|
|
||||||
if m.fragmentTotal <= 1 {
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
if m.fragmentID >= m.fragmentTotal {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
|
|
||||||
item.access.Lock()
|
|
||||||
defer item.access.Unlock()
|
|
||||||
if int(m.fragmentTotal) != len(item.messages) {
|
|
||||||
releaseMessages(item.messages)
|
|
||||||
item.messages = make([]*udpMessage, m.fragmentTotal)
|
|
||||||
item.count = 1
|
|
||||||
item.messages[m.fragmentID] = m
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if item.messages[m.fragmentID] != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
item.messages[m.fragmentID] = m
|
|
||||||
item.count++
|
|
||||||
if int(item.count) != len(item.messages) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
newMessage := allocMessage()
|
|
||||||
newMessage.sessionID = m.sessionID
|
|
||||||
newMessage.packetID = m.packetID
|
|
||||||
newMessage.destination = item.messages[0].destination
|
|
||||||
var finalLength int
|
|
||||||
for _, message := range item.messages {
|
|
||||||
finalLength += message.data.Len()
|
|
||||||
}
|
|
||||||
if finalLength > 0 {
|
|
||||||
newMessage.data = buf.NewSize(finalLength)
|
|
||||||
for _, message := range item.messages {
|
|
||||||
newMessage.data.Write(message.data.Bytes())
|
|
||||||
message.releaseMessage()
|
|
||||||
}
|
|
||||||
item.messages = nil
|
|
||||||
return newMessage
|
|
||||||
}
|
|
||||||
item.messages = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPacketItem() *packetItem {
|
|
||||||
return new(packetItem)
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeUDPMessage(message *udpMessage, data []byte) error {
|
|
||||||
reader := bytes.NewReader(data)
|
|
||||||
err := binary.Read(reader, binary.BigEndian, &message.sessionID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &message.packetID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
message.destination, err = protocol.ReadVString(reader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
message.data = buf.As(data[len(data)-reader.Len():])
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,106 +0,0 @@
|
||||||
package hysteria2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/bufio"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2b"
|
|
||||||
)
|
|
||||||
|
|
||||||
const salamanderSaltLen = 8
|
|
||||||
|
|
||||||
const ObfsTypeSalamander = "salamander"
|
|
||||||
|
|
||||||
type Salamander struct {
|
|
||||||
net.PacketConn
|
|
||||||
password []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSalamanderConn(conn net.PacketConn, password []byte) net.PacketConn {
|
|
||||||
writer, isVectorised := bufio.CreateVectorisedPacketWriter(conn)
|
|
||||||
if isVectorised {
|
|
||||||
return &VectorisedSalamander{
|
|
||||||
Salamander: Salamander{
|
|
||||||
PacketConn: conn,
|
|
||||||
password: password,
|
|
||||||
},
|
|
||||||
writer: writer,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return &Salamander{
|
|
||||||
PacketConn: conn,
|
|
||||||
password: password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Salamander) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
||||||
n, addr, err = s.PacketConn.ReadFrom(p)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if n <= salamanderSaltLen {
|
|
||||||
return 0, nil, E.New("salamander: packet too short")
|
|
||||||
}
|
|
||||||
key := blake2b.Sum256(append(s.password, p[:salamanderSaltLen]...))
|
|
||||||
for index, c := range p[salamanderSaltLen:n] {
|
|
||||||
p[index] = c ^ key[index%blake2b.Size256]
|
|
||||||
}
|
|
||||||
return n - salamanderSaltLen, addr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Salamander) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
||||||
buffer := buf.NewSize(len(p) + salamanderSaltLen)
|
|
||||||
defer buffer.Release()
|
|
||||||
buffer.WriteRandom(salamanderSaltLen)
|
|
||||||
key := blake2b.Sum256(append(s.password, buffer.Bytes()...))
|
|
||||||
for index, c := range p {
|
|
||||||
common.Must(buffer.WriteByte(c ^ key[index%blake2b.Size256]))
|
|
||||||
}
|
|
||||||
_, err = s.PacketConn.WriteTo(buffer.Bytes(), addr)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type VectorisedSalamander struct {
|
|
||||||
Salamander
|
|
||||||
writer N.VectorisedPacketWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *VectorisedSalamander) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
||||||
buffer := buf.NewSize(salamanderSaltLen)
|
|
||||||
buffer.WriteRandom(salamanderSaltLen)
|
|
||||||
key := blake2b.Sum256(append(s.password, buffer.Bytes()...))
|
|
||||||
for i := range p {
|
|
||||||
p[i] ^= key[i%blake2b.Size256]
|
|
||||||
}
|
|
||||||
err = s.writer.WriteVectorisedPacket([]*buf.Buffer{buffer, buf.As(p)}, M.SocksaddrFromNet(addr))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *VectorisedSalamander) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
|
|
||||||
header := buf.NewSize(salamanderSaltLen)
|
|
||||||
defer header.Release()
|
|
||||||
header.WriteRandom(salamanderSaltLen)
|
|
||||||
key := blake2b.Sum256(append(s.password, header.Bytes()...))
|
|
||||||
var bufferIndex int
|
|
||||||
for _, buffer := range buffers {
|
|
||||||
content := buffer.Bytes()
|
|
||||||
for index, c := range content {
|
|
||||||
content[bufferIndex+index] = c ^ key[bufferIndex+index%blake2b.Size256]
|
|
||||||
}
|
|
||||||
bufferIndex += len(content)
|
|
||||||
}
|
|
||||||
return s.writer.WriteVectorisedPacket(append([]*buf.Buffer{header}, buffers...), destination)
|
|
||||||
}
|
|
|
@ -1,344 +0,0 @@
|
||||||
package hysteria2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
|
||||||
"github.com/sagernet/quic-go/http3"
|
|
||||||
"github.com/sagernet/sing-box/common/qtls"
|
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
|
||||||
"github.com/sagernet/sing-box/transport/hysteria2/congestion"
|
|
||||||
"github.com/sagernet/sing-box/transport/hysteria2/internal/protocol"
|
|
||||||
tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion"
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/auth"
|
|
||||||
"github.com/sagernet/sing/common/baderror"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
"github.com/sagernet/sing/common/logger"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ServerOptions struct {
|
|
||||||
Context context.Context
|
|
||||||
Logger logger.Logger
|
|
||||||
SendBPS uint64
|
|
||||||
ReceiveBPS uint64
|
|
||||||
IgnoreClientBandwidth bool
|
|
||||||
SalamanderPassword string
|
|
||||||
TLSConfig tls.ServerConfig
|
|
||||||
Users []User
|
|
||||||
UDPDisabled bool
|
|
||||||
Handler ServerHandler
|
|
||||||
MasqueradeHandler http.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
Name string
|
|
||||||
Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ServerHandler interface {
|
|
||||||
N.TCPConnectionHandler
|
|
||||||
N.UDPConnectionHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
type Server struct {
|
|
||||||
ctx context.Context
|
|
||||||
logger logger.Logger
|
|
||||||
sendBPS uint64
|
|
||||||
receiveBPS uint64
|
|
||||||
ignoreClientBandwidth bool
|
|
||||||
salamanderPassword string
|
|
||||||
tlsConfig tls.ServerConfig
|
|
||||||
quicConfig *quic.Config
|
|
||||||
userMap map[string]User
|
|
||||||
udpDisabled bool
|
|
||||||
handler ServerHandler
|
|
||||||
masqueradeHandler http.Handler
|
|
||||||
quicListener io.Closer
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServer(options ServerOptions) (*Server, error) {
|
|
||||||
quicConfig := &quic.Config{
|
|
||||||
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
|
|
||||||
EnableDatagrams: !options.UDPDisabled,
|
|
||||||
MaxIncomingStreams: 1 << 60,
|
|
||||||
InitialStreamReceiveWindow: defaultStreamReceiveWindow,
|
|
||||||
MaxStreamReceiveWindow: defaultStreamReceiveWindow,
|
|
||||||
InitialConnectionReceiveWindow: defaultConnReceiveWindow,
|
|
||||||
MaxConnectionReceiveWindow: defaultConnReceiveWindow,
|
|
||||||
MaxIdleTimeout: defaultMaxIdleTimeout,
|
|
||||||
KeepAlivePeriod: defaultKeepAlivePeriod,
|
|
||||||
}
|
|
||||||
if len(options.Users) == 0 {
|
|
||||||
return nil, E.New("missing users")
|
|
||||||
}
|
|
||||||
userMap := make(map[string]User)
|
|
||||||
for _, user := range options.Users {
|
|
||||||
userMap[user.Password] = user
|
|
||||||
}
|
|
||||||
if options.MasqueradeHandler == nil {
|
|
||||||
options.MasqueradeHandler = http.NotFoundHandler()
|
|
||||||
}
|
|
||||||
return &Server{
|
|
||||||
ctx: options.Context,
|
|
||||||
logger: options.Logger,
|
|
||||||
sendBPS: options.SendBPS,
|
|
||||||
receiveBPS: options.ReceiveBPS,
|
|
||||||
ignoreClientBandwidth: options.IgnoreClientBandwidth,
|
|
||||||
salamanderPassword: options.SalamanderPassword,
|
|
||||||
tlsConfig: options.TLSConfig,
|
|
||||||
quicConfig: quicConfig,
|
|
||||||
userMap: userMap,
|
|
||||||
udpDisabled: options.UDPDisabled,
|
|
||||||
handler: options.Handler,
|
|
||||||
masqueradeHandler: options.MasqueradeHandler,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Start(conn net.PacketConn) error {
|
|
||||||
if s.salamanderPassword != "" {
|
|
||||||
conn = NewSalamanderConn(conn, []byte(s.salamanderPassword))
|
|
||||||
}
|
|
||||||
err := qtls.ConfigureHTTP3(s.tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.quicListener = listener
|
|
||||||
go s.loopConnections(listener)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Close() error {
|
|
||||||
return common.Close(
|
|
||||||
s.quicListener,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) loopConnections(listener qtls.QUICListener) {
|
|
||||||
for {
|
|
||||||
connection, err := listener.Accept(s.ctx)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "server closed") {
|
|
||||||
s.logger.Debug(E.Cause(err, "listener closed"))
|
|
||||||
} else {
|
|
||||||
s.logger.Error(E.Cause(err, "listener closed"))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go s.handleConnection(connection)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleConnection(connection quic.Connection) {
|
|
||||||
session := &serverSession{
|
|
||||||
Server: s,
|
|
||||||
ctx: s.ctx,
|
|
||||||
quicConn: connection,
|
|
||||||
source: M.SocksaddrFromNet(connection.RemoteAddr()),
|
|
||||||
connDone: make(chan struct{}),
|
|
||||||
udpConnMap: make(map[uint32]*udpPacketConn),
|
|
||||||
}
|
|
||||||
httpServer := http3.Server{
|
|
||||||
Handler: session,
|
|
||||||
StreamHijacker: session.handleStream0,
|
|
||||||
}
|
|
||||||
_ = httpServer.ServeQUICConn(connection)
|
|
||||||
_ = connection.CloseWithError(0, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
type serverSession struct {
|
|
||||||
*Server
|
|
||||||
ctx context.Context
|
|
||||||
quicConn quic.Connection
|
|
||||||
source M.Socksaddr
|
|
||||||
connAccess sync.Mutex
|
|
||||||
connDone chan struct{}
|
|
||||||
connErr error
|
|
||||||
authenticated bool
|
|
||||||
authUser *User
|
|
||||||
udpAccess sync.RWMutex
|
|
||||||
udpConnMap map[uint32]*udpPacketConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
|
|
||||||
if s.authenticated {
|
|
||||||
protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
|
|
||||||
UDPEnabled: !s.udpDisabled,
|
|
||||||
Rx: s.receiveBPS,
|
|
||||||
RxAuto: s.ignoreClientBandwidth,
|
|
||||||
})
|
|
||||||
w.WriteHeader(protocol.StatusAuthOK)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
request := protocol.AuthRequestFromHeader(r.Header)
|
|
||||||
user, loaded := s.userMap[request.Auth]
|
|
||||||
if !loaded {
|
|
||||||
s.masqueradeHandler.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.authUser = &user
|
|
||||||
s.authenticated = true
|
|
||||||
if !s.ignoreClientBandwidth && request.Rx > 0 {
|
|
||||||
var sendBps uint64
|
|
||||||
if s.sendBPS > 0 && s.sendBPS < request.Rx {
|
|
||||||
sendBps = s.sendBPS
|
|
||||||
} else {
|
|
||||||
sendBps = request.Rx
|
|
||||||
}
|
|
||||||
s.quicConn.SetCongestionControl(congestion.NewBrutalSender(sendBps))
|
|
||||||
} else {
|
|
||||||
s.quicConn.SetCongestionControl(tuicCongestion.NewBBRSender(
|
|
||||||
tuicCongestion.DefaultClock{},
|
|
||||||
tuicCongestion.GetInitialPacketSize(s.quicConn.RemoteAddr()),
|
|
||||||
tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
|
|
||||||
tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
|
|
||||||
UDPEnabled: !s.udpDisabled,
|
|
||||||
Rx: s.receiveBPS,
|
|
||||||
RxAuto: s.ignoreClientBandwidth,
|
|
||||||
})
|
|
||||||
w.WriteHeader(protocol.StatusAuthOK)
|
|
||||||
if s.ctx.Done() != nil {
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-s.ctx.Done():
|
|
||||||
s.closeWithError(s.ctx.Err())
|
|
||||||
case <-s.connDone:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
if !s.udpDisabled {
|
|
||||||
go s.loopMessages()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
s.masqueradeHandler.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) {
|
|
||||||
if !s.authenticated || err != nil {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
if frameType != protocol.FrameTypeTCPRequest {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
hErr := s.handleStream(stream)
|
|
||||||
stream.CancelRead(0)
|
|
||||||
stream.Close()
|
|
||||||
if hErr != nil {
|
|
||||||
stream.CancelRead(0)
|
|
||||||
stream.Close()
|
|
||||||
s.logger.Error(E.Cause(hErr, "handle stream request"))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) handleStream(stream quic.Stream) error {
|
|
||||||
destinationString, err := protocol.ReadTCPRequest(stream)
|
|
||||||
if err != nil {
|
|
||||||
return E.New("read TCP request")
|
|
||||||
}
|
|
||||||
ctx := s.ctx
|
|
||||||
if s.authUser.Name != "" {
|
|
||||||
ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
|
|
||||||
}
|
|
||||||
_ = s.handler.NewConnection(ctx, &serverConn{Stream: stream}, M.Metadata{
|
|
||||||
Source: s.source,
|
|
||||||
Destination: M.ParseSocksaddr(destinationString),
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) closeWithError(err error) {
|
|
||||||
s.connAccess.Lock()
|
|
||||||
defer s.connAccess.Unlock()
|
|
||||||
select {
|
|
||||||
case <-s.connDone:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
s.connErr = err
|
|
||||||
close(s.connDone)
|
|
||||||
}
|
|
||||||
if E.IsClosedOrCanceled(err) {
|
|
||||||
s.logger.Debug(E.Cause(err, "connection failed"))
|
|
||||||
} else {
|
|
||||||
s.logger.Error(E.Cause(err, "connection failed"))
|
|
||||||
}
|
|
||||||
_ = s.quicConn.CloseWithError(0, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
type serverConn struct {
|
|
||||||
quic.Stream
|
|
||||||
responseWritten bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) HandshakeFailure(err error) error {
|
|
||||||
if c.responseWritten {
|
|
||||||
return os.ErrClosed
|
|
||||||
}
|
|
||||||
c.responseWritten = true
|
|
||||||
buffer := protocol.WriteTCPResponse(false, err.Error(), nil)
|
|
||||||
defer buffer.Release()
|
|
||||||
return common.Error(c.Stream.Write(buffer.Bytes()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) HandshakeSuccess() error {
|
|
||||||
if c.responseWritten {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
c.responseWritten = true
|
|
||||||
buffer := protocol.WriteTCPResponse(true, "", nil)
|
|
||||||
defer buffer.Release()
|
|
||||||
return common.Error(c.Stream.Write(buffer.Bytes()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) Read(p []byte) (n int, err error) {
|
|
||||||
n, err = c.Stream.Read(p)
|
|
||||||
return n, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) Write(p []byte) (n int, err error) {
|
|
||||||
if !c.responseWritten {
|
|
||||||
c.responseWritten = true
|
|
||||||
buffer := protocol.WriteTCPResponse(true, "", p)
|
|
||||||
defer buffer.Release()
|
|
||||||
_, err = c.Stream.Write(buffer.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return 0, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
n, err = c.Stream.Write(p)
|
|
||||||
return n, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) LocalAddr() net.Addr {
|
|
||||||
return M.Socksaddr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) RemoteAddr() net.Addr {
|
|
||||||
return M.Socksaddr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) Close() error {
|
|
||||||
c.Stream.CancelRead(0)
|
|
||||||
return c.Stream.Close()
|
|
||||||
}
|
|
|
@ -1,55 +0,0 @@
|
||||||
package hysteria2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s *serverSession) loopMessages() {
|
|
||||||
for {
|
|
||||||
message, err := s.quicConn.ReceiveMessage(s.ctx)
|
|
||||||
if err != nil {
|
|
||||||
s.closeWithError(E.Cause(err, "receive message"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
hErr := s.handleMessage(message)
|
|
||||||
if hErr != nil {
|
|
||||||
s.closeWithError(E.Cause(hErr, "handle message"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) handleMessage(data []byte) error {
|
|
||||||
message := allocMessage()
|
|
||||||
err := decodeUDPMessage(message, data)
|
|
||||||
if err != nil {
|
|
||||||
message.release()
|
|
||||||
return E.Cause(err, "decode UDP message")
|
|
||||||
}
|
|
||||||
s.handleUDPMessage(message)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) handleUDPMessage(message *udpMessage) {
|
|
||||||
s.udpAccess.RLock()
|
|
||||||
udpConn, loaded := s.udpConnMap[message.sessionID]
|
|
||||||
s.udpAccess.RUnlock()
|
|
||||||
if !loaded || common.Done(udpConn.ctx) {
|
|
||||||
udpConn = newUDPPacketConn(s.ctx, s.quicConn, func() {
|
|
||||||
s.udpAccess.Lock()
|
|
||||||
delete(s.udpConnMap, message.sessionID)
|
|
||||||
s.udpAccess.Unlock()
|
|
||||||
})
|
|
||||||
udpConn.sessionID = message.sessionID
|
|
||||||
s.udpAccess.Lock()
|
|
||||||
s.udpConnMap[message.sessionID] = udpConn
|
|
||||||
s.udpAccess.Unlock()
|
|
||||||
go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{
|
|
||||||
Source: s.source,
|
|
||||||
Destination: M.ParseSocksaddr(message.destination),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
udpConn.inputPacket(message)
|
|
||||||
}
|
|
|
@ -1,10 +0,0 @@
|
||||||
package tuic
|
|
||||||
|
|
||||||
import M "github.com/sagernet/sing/common/metadata"
|
|
||||||
|
|
||||||
var addressSerializer = M.NewSerializer(
|
|
||||||
M.AddressFamilyByte(0x00, M.AddressFamilyFqdn),
|
|
||||||
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
|
|
||||||
M.AddressFamilyByte(0x02, M.AddressFamilyIPv6),
|
|
||||||
M.AddressFamilyByte(0xff, M.AddressFamilyEmpty),
|
|
||||||
)
|
|
|
@ -1,307 +0,0 @@
|
||||||
//go:build with_quic
|
|
||||||
|
|
||||||
package tuic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
|
||||||
"github.com/sagernet/sing-box/common/qtls"
|
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/baderror"
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/bufio"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
|
|
||||||
"github.com/gofrs/uuid/v5"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ClientOptions struct {
|
|
||||||
Context context.Context
|
|
||||||
Dialer N.Dialer
|
|
||||||
ServerAddress M.Socksaddr
|
|
||||||
TLSConfig tls.Config
|
|
||||||
UUID uuid.UUID
|
|
||||||
Password string
|
|
||||||
CongestionControl string
|
|
||||||
UDPStream bool
|
|
||||||
ZeroRTTHandshake bool
|
|
||||||
Heartbeat time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
type Client struct {
|
|
||||||
ctx context.Context
|
|
||||||
dialer N.Dialer
|
|
||||||
serverAddr M.Socksaddr
|
|
||||||
tlsConfig tls.Config
|
|
||||||
quicConfig *quic.Config
|
|
||||||
uuid uuid.UUID
|
|
||||||
password string
|
|
||||||
congestionControl string
|
|
||||||
udpStream bool
|
|
||||||
zeroRTTHandshake bool
|
|
||||||
heartbeat time.Duration
|
|
||||||
|
|
||||||
connAccess sync.RWMutex
|
|
||||||
conn *clientQUICConnection
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClient(options ClientOptions) (*Client, error) {
|
|
||||||
if options.Heartbeat == 0 {
|
|
||||||
options.Heartbeat = 10 * time.Second
|
|
||||||
}
|
|
||||||
quicConfig := &quic.Config{
|
|
||||||
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
|
|
||||||
MaxDatagramFrameSize: 1400,
|
|
||||||
EnableDatagrams: true,
|
|
||||||
MaxIncomingUniStreams: 1 << 60,
|
|
||||||
}
|
|
||||||
switch options.CongestionControl {
|
|
||||||
case "":
|
|
||||||
options.CongestionControl = "cubic"
|
|
||||||
case "cubic", "new_reno", "bbr":
|
|
||||||
default:
|
|
||||||
return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
|
|
||||||
}
|
|
||||||
return &Client{
|
|
||||||
ctx: options.Context,
|
|
||||||
dialer: options.Dialer,
|
|
||||||
serverAddr: options.ServerAddress,
|
|
||||||
tlsConfig: options.TLSConfig,
|
|
||||||
quicConfig: quicConfig,
|
|
||||||
uuid: options.UUID,
|
|
||||||
password: options.Password,
|
|
||||||
congestionControl: options.CongestionControl,
|
|
||||||
udpStream: options.UDPStream,
|
|
||||||
zeroRTTHandshake: options.ZeroRTTHandshake,
|
|
||||||
heartbeat: options.Heartbeat,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
|
|
||||||
conn := c.conn
|
|
||||||
if conn != nil && conn.active() {
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
c.connAccess.Lock()
|
|
||||||
defer c.connAccess.Unlock()
|
|
||||||
conn = c.conn
|
|
||||||
if conn != nil && conn.active() {
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
conn, err := c.offerNew(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
|
|
||||||
udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var quicConn quic.Connection
|
|
||||||
if c.zeroRTTHandshake {
|
|
||||||
quicConn, err = qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
|
|
||||||
} else {
|
|
||||||
quicConn, err = qtls.Dial(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
udpConn.Close()
|
|
||||||
return nil, E.Cause(err, "open connection")
|
|
||||||
}
|
|
||||||
setCongestion(c.ctx, quicConn, c.congestionControl)
|
|
||||||
conn := &clientQUICConnection{
|
|
||||||
quicConn: quicConn,
|
|
||||||
rawConn: udpConn,
|
|
||||||
connDone: make(chan struct{}),
|
|
||||||
udpConnMap: make(map[uint16]*udpPacketConn),
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
hErr := c.clientHandshake(quicConn)
|
|
||||||
if hErr != nil {
|
|
||||||
conn.closeWithError(hErr)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if c.udpStream {
|
|
||||||
go c.loopUniStreams(conn)
|
|
||||||
}
|
|
||||||
go c.loopMessages(conn)
|
|
||||||
go c.loopHeartbeats(conn)
|
|
||||||
c.conn = conn
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) clientHandshake(conn quic.Connection) error {
|
|
||||||
authStream, err := conn.OpenUniStream()
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "open handshake stream")
|
|
||||||
}
|
|
||||||
defer authStream.Close()
|
|
||||||
handshakeState := conn.ConnectionState()
|
|
||||||
tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32)
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "export keying material")
|
|
||||||
}
|
|
||||||
authRequest := buf.NewSize(AuthenticateLen)
|
|
||||||
authRequest.WriteByte(Version)
|
|
||||||
authRequest.WriteByte(CommandAuthenticate)
|
|
||||||
authRequest.Write(c.uuid[:])
|
|
||||||
authRequest.Write(tuicAuthToken)
|
|
||||||
return common.Error(authStream.Write(authRequest.Bytes()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) loopHeartbeats(conn *clientQUICConnection) {
|
|
||||||
ticker := time.NewTicker(c.heartbeat)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-conn.connDone:
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
err := conn.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
|
|
||||||
if err != nil {
|
|
||||||
conn.closeWithError(E.Cause(err, "send heartbeat"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
|
|
||||||
conn, err := c.offer(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stream, err := conn.quicConn.OpenStream()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &clientConn{
|
|
||||||
Stream: stream,
|
|
||||||
parent: conn,
|
|
||||||
destination: destination,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
|
|
||||||
conn, err := c.offer(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var sessionID uint16
|
|
||||||
clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() {
|
|
||||||
conn.udpAccess.Lock()
|
|
||||||
delete(conn.udpConnMap, sessionID)
|
|
||||||
conn.udpAccess.Unlock()
|
|
||||||
})
|
|
||||||
conn.udpAccess.Lock()
|
|
||||||
sessionID = conn.udpSessionID
|
|
||||||
conn.udpSessionID++
|
|
||||||
conn.udpConnMap[sessionID] = clientPacketConn
|
|
||||||
conn.udpAccess.Unlock()
|
|
||||||
clientPacketConn.sessionID = sessionID
|
|
||||||
return clientPacketConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) CloseWithError(err error) error {
|
|
||||||
conn := c.conn
|
|
||||||
if conn != nil {
|
|
||||||
conn.closeWithError(err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientQUICConnection struct {
|
|
||||||
quicConn quic.Connection
|
|
||||||
rawConn io.Closer
|
|
||||||
closeOnce sync.Once
|
|
||||||
connDone chan struct{}
|
|
||||||
connErr error
|
|
||||||
udpAccess sync.RWMutex
|
|
||||||
udpConnMap map[uint16]*udpPacketConn
|
|
||||||
udpSessionID uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientQUICConnection) active() bool {
|
|
||||||
select {
|
|
||||||
case <-c.quicConn.Context().Done():
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-c.connDone:
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientQUICConnection) closeWithError(err error) {
|
|
||||||
c.closeOnce.Do(func() {
|
|
||||||
c.connErr = err
|
|
||||||
close(c.connDone)
|
|
||||||
_ = c.quicConn.CloseWithError(0, "")
|
|
||||||
_ = c.rawConn.Close()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientConn struct {
|
|
||||||
quic.Stream
|
|
||||||
parent *clientQUICConnection
|
|
||||||
destination M.Socksaddr
|
|
||||||
requestWritten bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) NeedHandshake() bool {
|
|
||||||
return !c.requestWritten
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) Read(b []byte) (n int, err error) {
|
|
||||||
n, err = c.Stream.Read(b)
|
|
||||||
return n, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) Write(b []byte) (n int, err error) {
|
|
||||||
if !c.requestWritten {
|
|
||||||
request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b))
|
|
||||||
defer request.Release()
|
|
||||||
request.WriteByte(Version)
|
|
||||||
request.WriteByte(CommandConnect)
|
|
||||||
err = addressSerializer.WriteAddrPort(request, c.destination)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
request.Write(b)
|
|
||||||
_, err = c.Stream.Write(request.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
c.parent.closeWithError(E.Cause(err, "create new connection"))
|
|
||||||
return 0, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
c.requestWritten = true
|
|
||||||
return len(b), nil
|
|
||||||
}
|
|
||||||
n, err = c.Stream.Write(b)
|
|
||||||
return n, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) Close() error {
|
|
||||||
c.Stream.CancelRead(0)
|
|
||||||
return c.Stream.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) LocalAddr() net.Addr {
|
|
||||||
return M.Socksaddr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientConn) RemoteAddr() net.Addr {
|
|
||||||
return c.destination
|
|
||||||
}
|
|
|
@ -1,112 +0,0 @@
|
||||||
//go:build with_quic
|
|
||||||
|
|
||||||
package tuic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/bufio"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (c *Client) loopMessages(conn *clientQUICConnection) {
|
|
||||||
for {
|
|
||||||
message, err := conn.quicConn.ReceiveMessage(c.ctx)
|
|
||||||
if err != nil {
|
|
||||||
conn.closeWithError(E.Cause(err, "receive message"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
hErr := c.handleMessage(conn, message)
|
|
||||||
if hErr != nil {
|
|
||||||
conn.closeWithError(E.Cause(hErr, "handle message"))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
|
|
||||||
if len(data) < 2 {
|
|
||||||
return E.New("invalid message")
|
|
||||||
}
|
|
||||||
if data[0] != Version {
|
|
||||||
return E.New("unknown version ", data[0])
|
|
||||||
}
|
|
||||||
switch data[1] {
|
|
||||||
case CommandPacket:
|
|
||||||
message := allocMessage()
|
|
||||||
err := decodeUDPMessage(message, data[2:])
|
|
||||||
if err != nil {
|
|
||||||
message.release()
|
|
||||||
return E.Cause(err, "decode UDP message")
|
|
||||||
}
|
|
||||||
conn.handleUDPMessage(message)
|
|
||||||
return nil
|
|
||||||
case CommandHeartbeat:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return E.New("unknown command ", data[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) loopUniStreams(conn *clientQUICConnection) {
|
|
||||||
for {
|
|
||||||
stream, err := conn.quicConn.AcceptUniStream(c.ctx)
|
|
||||||
if err != nil {
|
|
||||||
conn.closeWithError(E.Cause(err, "handle uni stream"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
hErr := c.handleUniStream(conn, stream)
|
|
||||||
if hErr != nil {
|
|
||||||
conn.closeWithError(hErr)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.ReceiveStream) error {
|
|
||||||
defer stream.CancelRead(0)
|
|
||||||
buffer := buf.NewPacket()
|
|
||||||
defer buffer.Release()
|
|
||||||
_, err := buffer.ReadAtLeastFrom(stream, 2)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
version, _ := buffer.ReadByte()
|
|
||||||
if version != Version {
|
|
||||||
return E.New("unknown version ", version)
|
|
||||||
}
|
|
||||||
command, _ := buffer.ReadByte()
|
|
||||||
if command != CommandPacket {
|
|
||||||
return E.New("unknown command ", command)
|
|
||||||
}
|
|
||||||
reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream)
|
|
||||||
message := allocMessage()
|
|
||||||
err = readUDPMessage(message, reader)
|
|
||||||
if err != nil {
|
|
||||||
message.release()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
conn.handleUDPMessage(message)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) {
|
|
||||||
c.udpAccess.RLock()
|
|
||||||
udpConn, loaded := c.udpConnMap[message.sessionID]
|
|
||||||
c.udpAccess.RUnlock()
|
|
||||||
if !loaded {
|
|
||||||
message.releaseMessage()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-udpConn.ctx.Done():
|
|
||||||
message.releaseMessage()
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
udpConn.inputPacket(message)
|
|
||||||
}
|
|
|
@ -1,46 +0,0 @@
|
||||||
package tuic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
|
||||||
"github.com/sagernet/sing-box/transport/tuic/congestion"
|
|
||||||
"github.com/sagernet/sing/common/ntp"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setCongestion(ctx context.Context, connection quic.Connection, congestionName string) {
|
|
||||||
timeFunc := ntp.TimeFuncFromContext(ctx)
|
|
||||||
if timeFunc == nil {
|
|
||||||
timeFunc = time.Now
|
|
||||||
}
|
|
||||||
switch congestionName {
|
|
||||||
case "cubic":
|
|
||||||
connection.SetCongestionControl(
|
|
||||||
congestion.NewCubicSender(
|
|
||||||
congestion.DefaultClock{TimeFunc: timeFunc},
|
|
||||||
congestion.GetInitialPacketSize(connection.RemoteAddr()),
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
case "new_reno":
|
|
||||||
connection.SetCongestionControl(
|
|
||||||
congestion.NewCubicSender(
|
|
||||||
congestion.DefaultClock{TimeFunc: timeFunc},
|
|
||||||
congestion.GetInitialPacketSize(connection.RemoteAddr()),
|
|
||||||
true,
|
|
||||||
nil,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
case "bbr":
|
|
||||||
connection.SetCongestionControl(
|
|
||||||
congestion.NewBBRSender(
|
|
||||||
congestion.DefaultClock{},
|
|
||||||
congestion.GetInitialPacketSize(connection.RemoteAddr()),
|
|
||||||
congestion.InitialCongestionWindow*congestion.InitialMaxDatagramSize,
|
|
||||||
congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,3 +0,0 @@
|
||||||
# congestion
|
|
||||||
|
|
||||||
mod from https://github.com/MetaCubeX/Clash.Meta/tree/53f9e1ee7104473da2b4ff5da29965563084482d/transport/tuic/congestion
|
|
|
@ -1,25 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go/congestion"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Bandwidth of a connection
|
|
||||||
type Bandwidth uint64
|
|
||||||
|
|
||||||
const infBandwidth Bandwidth = math.MaxUint64
|
|
||||||
|
|
||||||
const (
|
|
||||||
// BitsPerSecond is 1 bit per second
|
|
||||||
BitsPerSecond Bandwidth = 1
|
|
||||||
// BytesPerSecond is 1 byte per second
|
|
||||||
BytesPerSecond = 8 * BitsPerSecond
|
|
||||||
)
|
|
||||||
|
|
||||||
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
|
|
||||||
func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth {
|
|
||||||
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
|
|
||||||
}
|
|
|
@ -1,374 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go/congestion"
|
|
||||||
)
|
|
||||||
|
|
||||||
var InfiniteBandwidth = Bandwidth(math.MaxUint64)
|
|
||||||
|
|
||||||
// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned
|
|
||||||
// to the caller when the packet is acked or lost.
|
|
||||||
type SendTimeState struct {
|
|
||||||
// Whether other states in this object is valid.
|
|
||||||
isValid bool
|
|
||||||
// Whether the sender is app limited at the time the packet was sent.
|
|
||||||
// App limited bandwidth sample might be artificially low because the sender
|
|
||||||
// did not have enough data to send in order to saturate the link.
|
|
||||||
isAppLimited bool
|
|
||||||
// Total number of sent bytes at the time the packet was sent.
|
|
||||||
// Includes the packet itself.
|
|
||||||
totalBytesSent congestion.ByteCount
|
|
||||||
// Total number of acked bytes at the time the packet was sent.
|
|
||||||
totalBytesAcked congestion.ByteCount
|
|
||||||
// Total number of lost bytes at the time the packet was sent.
|
|
||||||
totalBytesLost congestion.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectionStateOnSentPacket represents the information about a sent packet
|
|
||||||
// and the state of the connection at the moment the packet was sent,
|
|
||||||
// specifically the information about the most recently acknowledged packet at
|
|
||||||
// that moment.
|
|
||||||
type ConnectionStateOnSentPacket struct {
|
|
||||||
packetNumber congestion.PacketNumber
|
|
||||||
// Time at which the packet is sent.
|
|
||||||
sendTime time.Time
|
|
||||||
// Size of the packet.
|
|
||||||
size congestion.ByteCount
|
|
||||||
// The value of |totalBytesSentAtLastAckedPacket| at the time the
|
|
||||||
// packet was sent.
|
|
||||||
totalBytesSentAtLastAckedPacket congestion.ByteCount
|
|
||||||
// The value of |lastAckedPacketSentTime| at the time the packet was
|
|
||||||
// sent.
|
|
||||||
lastAckedPacketSentTime time.Time
|
|
||||||
// The value of |lastAckedPacketAckTime| at the time the packet was
|
|
||||||
// sent.
|
|
||||||
lastAckedPacketAckTime time.Time
|
|
||||||
// Send time states that are returned to the congestion controller when the
|
|
||||||
// packet is acked or lost.
|
|
||||||
sendTimeState SendTimeState
|
|
||||||
}
|
|
||||||
|
|
||||||
// BandwidthSample
|
|
||||||
type BandwidthSample struct {
|
|
||||||
// The bandwidth at that particular sample. Zero if no valid bandwidth sample
|
|
||||||
// is available.
|
|
||||||
bandwidth Bandwidth
|
|
||||||
// The RTT measurement at this particular sample. Zero if no RTT sample is
|
|
||||||
// available. Does not correct for delayed ack time.
|
|
||||||
rtt time.Duration
|
|
||||||
// States captured when the packet was sent.
|
|
||||||
stateAtSend SendTimeState
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBandwidthSample() *BandwidthSample {
|
|
||||||
return &BandwidthSample{
|
|
||||||
// FIXME: the default value of original code is zero.
|
|
||||||
rtt: InfiniteRTT,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BandwidthSampler keeps track of sent and acknowledged packets and outputs a
|
|
||||||
// bandwidth sample for every packet acknowledged. The samples are taken for
|
|
||||||
// individual packets, and are not filtered; the consumer has to filter the
|
|
||||||
// bandwidth samples itself. In certain cases, the sampler will locally severely
|
|
||||||
// underestimate the bandwidth, hence a maximum filter with a size of at least
|
|
||||||
// one RTT is recommended.
|
|
||||||
//
|
|
||||||
// This class bases its samples on the slope of two curves: the number of bytes
|
|
||||||
// sent over time, and the number of bytes acknowledged as received over time.
|
|
||||||
// It produces a sample of both slopes for every packet that gets acknowledged,
|
|
||||||
// based on a slope between two points on each of the corresponding curves. Note
|
|
||||||
// that due to the packet loss, the number of bytes on each curve might get
|
|
||||||
// further and further away from each other, meaning that it is not feasible to
|
|
||||||
// compare byte values coming from different curves with each other.
|
|
||||||
//
|
|
||||||
// The obvious points for measuring slope sample are the ones corresponding to
|
|
||||||
// the packet that was just acknowledged. Let us denote them as S_1 (point at
|
|
||||||
// which the current packet was sent) and A_1 (point at which the current packet
|
|
||||||
// was acknowledged). However, taking a slope requires two points on each line,
|
|
||||||
// so estimating bandwidth requires picking a packet in the past with respect to
|
|
||||||
// which the slope is measured.
|
|
||||||
//
|
|
||||||
// For that purpose, BandwidthSampler always keeps track of the most recently
|
|
||||||
// acknowledged packet, and records it together with every outgoing packet.
|
|
||||||
// When a packet gets acknowledged (A_1), it has not only information about when
|
|
||||||
// it itself was sent (S_1), but also the information about the latest
|
|
||||||
// acknowledged packet right before it was sent (S_0 and A_0).
|
|
||||||
//
|
|
||||||
// Based on that data, send and ack rate are estimated as:
|
|
||||||
//
|
|
||||||
// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0))
|
|
||||||
// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0))
|
|
||||||
//
|
|
||||||
// Here, the ack rate is intuitively the rate we want to treat as bandwidth.
|
|
||||||
// However, in certain cases (e.g. ack compression) the ack rate at a point may
|
|
||||||
// end up higher than the rate at which the data was originally sent, which is
|
|
||||||
// not indicative of the real bandwidth. Hence, we use the send rate as an upper
|
|
||||||
// bound, and the sample value is
|
|
||||||
//
|
|
||||||
// rate_sample = min(send_rate, ack_rate)
|
|
||||||
//
|
|
||||||
// An important edge case handled by the sampler is tracking the app-limited
|
|
||||||
// samples. There are multiple meaning of "app-limited" used interchangeably,
|
|
||||||
// hence it is important to understand and to be able to distinguish between
|
|
||||||
// them.
|
|
||||||
//
|
|
||||||
// Meaning 1: connection state. The connection is said to be app-limited when
|
|
||||||
// there is no outstanding data to send. This means that certain bandwidth
|
|
||||||
// samples in the future would not be an accurate indication of the link
|
|
||||||
// capacity, and it is important to inform consumer about that. Whenever
|
|
||||||
// connection becomes app-limited, the sampler is notified via OnAppLimited()
|
|
||||||
// method.
|
|
||||||
//
|
|
||||||
// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth
|
|
||||||
// sampler becomes notified about the connection being app-limited, it enters
|
|
||||||
// app-limited phase. In that phase, all *sent* packets are marked as
|
|
||||||
// app-limited. Note that the connection itself does not have to be
|
|
||||||
// app-limited during the app-limited phase, and in fact it will not be
|
|
||||||
// (otherwise how would it send packets?). The boolean flag below indicates
|
|
||||||
// whether the sampler is in that phase.
|
|
||||||
//
|
|
||||||
// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is
|
|
||||||
// sent during the app-limited phase, the resulting sample related to the
|
|
||||||
// packet will be marked as app-limited.
|
|
||||||
//
|
|
||||||
// With the terminology issue out of the way, let us consider the question of
|
|
||||||
// what kind of situation it addresses.
|
|
||||||
//
|
|
||||||
// Consider a scenario where we first send packets 1 to 20 at a regular
|
|
||||||
// bandwidth, and then immediately run out of data. After a few seconds, we send
|
|
||||||
// packets 21 to 60, and only receive ack for 21 between sending packets 40 and
|
|
||||||
// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0
|
|
||||||
// we use to compute the slope is going to be packet 20, a few seconds apart
|
|
||||||
// from the current packet, hence the resulting estimate would be extremely low
|
|
||||||
// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21,
|
|
||||||
// meaning that the bandwidth sample would exclude the quiescence.
|
|
||||||
//
|
|
||||||
// Based on the analysis of that scenario, we implement the following rule: once
|
|
||||||
// OnAppLimited() is called, all sent packets will produce app-limited samples
|
|
||||||
// up until an ack for a packet that was sent after OnAppLimited() was called.
|
|
||||||
// Note that while the scenario above is not the only scenario when the
|
|
||||||
// connection is app-limited, the approach works in other cases too.
|
|
||||||
type BandwidthSampler struct {
|
|
||||||
// The total number of congestion controlled bytes sent during the connection.
|
|
||||||
totalBytesSent congestion.ByteCount
|
|
||||||
// The total number of congestion controlled bytes which were acknowledged.
|
|
||||||
totalBytesAcked congestion.ByteCount
|
|
||||||
// The total number of congestion controlled bytes which were lost.
|
|
||||||
totalBytesLost congestion.ByteCount
|
|
||||||
// The value of |totalBytesSent| at the time the last acknowledged packet
|
|
||||||
// was sent. Valid only when |lastAckedPacketSentTime| is valid.
|
|
||||||
totalBytesSentAtLastAckedPacket congestion.ByteCount
|
|
||||||
// The time at which the last acknowledged packet was sent. Set to
|
|
||||||
// QuicTime::Zero() if no valid timestamp is available.
|
|
||||||
lastAckedPacketSentTime time.Time
|
|
||||||
// The time at which the most recent packet was acknowledged.
|
|
||||||
lastAckedPacketAckTime time.Time
|
|
||||||
// The most recently sent packet.
|
|
||||||
lastSendPacket congestion.PacketNumber
|
|
||||||
// Indicates whether the bandwidth sampler is currently in an app-limited
|
|
||||||
// phase.
|
|
||||||
isAppLimited bool
|
|
||||||
// The packet that will be acknowledged after this one will cause the sampler
|
|
||||||
// to exit the app-limited phase.
|
|
||||||
endOfAppLimitedPhase congestion.PacketNumber
|
|
||||||
// Record of the connection state at the point where each packet in flight was
|
|
||||||
// sent, indexed by the packet number.
|
|
||||||
connectionStats *ConnectionStates
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBandwidthSampler() *BandwidthSampler {
|
|
||||||
return &BandwidthSampler{
|
|
||||||
connectionStats: &ConnectionStates{
|
|
||||||
stats: make(map[congestion.PacketNumber]*ConnectionStateOnSentPacket),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnPacketSent Inputs the sent packet information into the sampler. Assumes that all
|
|
||||||
// packets are sent in order. The information about the packet will not be
|
|
||||||
// released from the sampler until it the packet is either acknowledged or
|
|
||||||
// declared lost.
|
|
||||||
func (s *BandwidthSampler) OnPacketSent(sentTime time.Time, lastSentPacket congestion.PacketNumber, sentBytes, bytesInFlight congestion.ByteCount, hasRetransmittableData bool) {
|
|
||||||
s.lastSendPacket = lastSentPacket
|
|
||||||
|
|
||||||
if !hasRetransmittableData {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.totalBytesSent += sentBytes
|
|
||||||
|
|
||||||
// If there are no packets in flight, the time at which the new transmission
|
|
||||||
// opens can be treated as the A_0 point for the purpose of bandwidth
|
|
||||||
// sampling. This underestimates bandwidth to some extent, and produces some
|
|
||||||
// artificially low samples for most packets in flight, but it provides with
|
|
||||||
// samples at important points where we would not have them otherwise, most
|
|
||||||
// importantly at the beginning of the connection.
|
|
||||||
if bytesInFlight == 0 {
|
|
||||||
s.lastAckedPacketAckTime = sentTime
|
|
||||||
s.totalBytesSentAtLastAckedPacket = s.totalBytesSent
|
|
||||||
|
|
||||||
// In this situation ack compression is not a concern, set send rate to
|
|
||||||
// effectively infinite.
|
|
||||||
s.lastAckedPacketSentTime = sentTime
|
|
||||||
}
|
|
||||||
|
|
||||||
s.connectionStats.Insert(lastSentPacket, sentTime, sentBytes, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnPacketAcked Notifies the sampler that the |lastAckedPacket| is acknowledged. Returns a
|
|
||||||
// bandwidth sample. If no bandwidth sample is available,
|
|
||||||
// QuicBandwidth::Zero() is returned.
|
|
||||||
func (s *BandwidthSampler) OnPacketAcked(ackTime time.Time, lastAckedPacket congestion.PacketNumber) *BandwidthSample {
|
|
||||||
sentPacketState := s.connectionStats.Get(lastAckedPacket)
|
|
||||||
if sentPacketState == nil {
|
|
||||||
return NewBandwidthSample()
|
|
||||||
}
|
|
||||||
|
|
||||||
sample := s.onPacketAckedInner(ackTime, lastAckedPacket, sentPacketState)
|
|
||||||
s.connectionStats.Remove(lastAckedPacket)
|
|
||||||
|
|
||||||
return sample
|
|
||||||
}
|
|
||||||
|
|
||||||
// onPacketAckedInner Handles the actual bandwidth calculations, whereas the outer method handles
|
|
||||||
// retrieving and removing |sentPacket|.
|
|
||||||
func (s *BandwidthSampler) onPacketAckedInner(ackTime time.Time, lastAckedPacket congestion.PacketNumber, sentPacket *ConnectionStateOnSentPacket) *BandwidthSample {
|
|
||||||
s.totalBytesAcked += sentPacket.size
|
|
||||||
|
|
||||||
s.totalBytesSentAtLastAckedPacket = sentPacket.sendTimeState.totalBytesSent
|
|
||||||
s.lastAckedPacketSentTime = sentPacket.sendTime
|
|
||||||
s.lastAckedPacketAckTime = ackTime
|
|
||||||
|
|
||||||
// Exit app-limited phase once a packet that was sent while the connection is
|
|
||||||
// not app-limited is acknowledged.
|
|
||||||
if s.isAppLimited && lastAckedPacket > s.endOfAppLimitedPhase {
|
|
||||||
s.isAppLimited = false
|
|
||||||
}
|
|
||||||
|
|
||||||
// There might have been no packets acknowledged at the moment when the
|
|
||||||
// current packet was sent. In that case, there is no bandwidth sample to
|
|
||||||
// make.
|
|
||||||
if sentPacket.lastAckedPacketSentTime.IsZero() {
|
|
||||||
return NewBandwidthSample()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Infinite rate indicates that the sampler is supposed to discard the
|
|
||||||
// current send rate sample and use only the ack rate.
|
|
||||||
sendRate := InfiniteBandwidth
|
|
||||||
if sentPacket.sendTime.After(sentPacket.lastAckedPacketSentTime) {
|
|
||||||
sendRate = BandwidthFromDelta(sentPacket.sendTimeState.totalBytesSent-sentPacket.totalBytesSentAtLastAckedPacket, sentPacket.sendTime.Sub(sentPacket.lastAckedPacketSentTime))
|
|
||||||
}
|
|
||||||
|
|
||||||
// During the slope calculation, ensure that ack time of the current packet is
|
|
||||||
// always larger than the time of the previous packet, otherwise division by
|
|
||||||
// zero or integer underflow can occur.
|
|
||||||
if !ackTime.After(sentPacket.lastAckedPacketAckTime) {
|
|
||||||
// TODO(wub): Compare this code count before and after fixing clock jitter
|
|
||||||
// issue.
|
|
||||||
// if sentPacket.lastAckedPacketAckTime.Equal(sentPacket.sendTime) {
|
|
||||||
// This is the 1st packet after quiescense.
|
|
||||||
// QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 1, 2);
|
|
||||||
// } else {
|
|
||||||
// QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 2, 2);
|
|
||||||
// }
|
|
||||||
|
|
||||||
return NewBandwidthSample()
|
|
||||||
}
|
|
||||||
|
|
||||||
ackRate := BandwidthFromDelta(s.totalBytesAcked-sentPacket.sendTimeState.totalBytesAcked,
|
|
||||||
ackTime.Sub(sentPacket.lastAckedPacketAckTime))
|
|
||||||
|
|
||||||
// Note: this sample does not account for delayed acknowledgement time. This
|
|
||||||
// means that the RTT measurements here can be artificially high, especially
|
|
||||||
// on low bandwidth connections.
|
|
||||||
sample := &BandwidthSample{
|
|
||||||
bandwidth: minBandwidth(sendRate, ackRate),
|
|
||||||
rtt: ackTime.Sub(sentPacket.sendTime),
|
|
||||||
}
|
|
||||||
|
|
||||||
SentPacketToSendTimeState(sentPacket, &sample.stateAtSend)
|
|
||||||
return sample
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnPacketLost Informs the sampler that a packet is considered lost and it should no
|
|
||||||
// longer keep track of it.
|
|
||||||
func (s *BandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber) SendTimeState {
|
|
||||||
ok, sentPacket := s.connectionStats.Remove(packetNumber)
|
|
||||||
sendTimeState := SendTimeState{
|
|
||||||
isValid: ok,
|
|
||||||
}
|
|
||||||
if sentPacket != nil {
|
|
||||||
s.totalBytesLost += sentPacket.size
|
|
||||||
SentPacketToSendTimeState(sentPacket, &sendTimeState)
|
|
||||||
}
|
|
||||||
|
|
||||||
return sendTimeState
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnAppLimited Informs the sampler that the connection is currently app-limited, causing
|
|
||||||
// the sampler to enter the app-limited phase. The phase will expire by
|
|
||||||
// itself.
|
|
||||||
func (s *BandwidthSampler) OnAppLimited() {
|
|
||||||
s.isAppLimited = true
|
|
||||||
s.endOfAppLimitedPhase = s.lastSendPacket
|
|
||||||
}
|
|
||||||
|
|
||||||
// SentPacketToSendTimeState Copy a subset of the (private) ConnectionStateOnSentPacket to the (public)
|
|
||||||
// SendTimeState. Always set send_time_state->is_valid to true.
|
|
||||||
func SentPacketToSendTimeState(sentPacket *ConnectionStateOnSentPacket, sendTimeState *SendTimeState) {
|
|
||||||
sendTimeState.isAppLimited = sentPacket.sendTimeState.isAppLimited
|
|
||||||
sendTimeState.totalBytesSent = sentPacket.sendTimeState.totalBytesSent
|
|
||||||
sendTimeState.totalBytesAcked = sentPacket.sendTimeState.totalBytesAcked
|
|
||||||
sendTimeState.totalBytesLost = sentPacket.sendTimeState.totalBytesLost
|
|
||||||
sendTimeState.isValid = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectionStates Record of the connection state at the point where each packet in flight was
|
|
||||||
// sent, indexed by the packet number.
|
|
||||||
// FIXME: using LinkedList replace map to fast remove all the packets lower than the specified packet number.
|
|
||||||
type ConnectionStates struct {
|
|
||||||
stats map[congestion.PacketNumber]*ConnectionStateOnSentPacket
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ConnectionStates) Insert(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) bool {
|
|
||||||
if _, ok := s.stats[packetNumber]; ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
s.stats[packetNumber] = NewConnectionStateOnSentPacket(packetNumber, sentTime, bytes, sampler)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ConnectionStates) Get(packetNumber congestion.PacketNumber) *ConnectionStateOnSentPacket {
|
|
||||||
return s.stats[packetNumber]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ConnectionStates) Remove(packetNumber congestion.PacketNumber) (bool, *ConnectionStateOnSentPacket) {
|
|
||||||
state, ok := s.stats[packetNumber]
|
|
||||||
if ok {
|
|
||||||
delete(s.stats, packetNumber)
|
|
||||||
}
|
|
||||||
return ok, state
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConnectionStateOnSentPacket(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) *ConnectionStateOnSentPacket {
|
|
||||||
return &ConnectionStateOnSentPacket{
|
|
||||||
packetNumber: packetNumber,
|
|
||||||
sendTime: sentTime,
|
|
||||||
size: bytes,
|
|
||||||
lastAckedPacketSentTime: sampler.lastAckedPacketSentTime,
|
|
||||||
lastAckedPacketAckTime: sampler.lastAckedPacketAckTime,
|
|
||||||
totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket,
|
|
||||||
sendTimeState: SendTimeState{
|
|
||||||
isValid: true,
|
|
||||||
isAppLimited: sampler.isAppLimited,
|
|
||||||
totalBytesSent: sampler.totalBytesSent,
|
|
||||||
totalBytesAcked: sampler.totalBytesAcked,
|
|
||||||
totalBytesLost: sampler.totalBytesLost,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,20 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
// A Clock returns the current time
|
|
||||||
type Clock interface {
|
|
||||||
Now() time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultClock implements the Clock interface using the Go stdlib clock.
|
|
||||||
type DefaultClock struct {
|
|
||||||
TimeFunc func() time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Clock = DefaultClock{}
|
|
||||||
|
|
||||||
// Now gets the current time
|
|
||||||
func (c DefaultClock) Now() time.Time {
|
|
||||||
return c.TimeFunc()
|
|
||||||
}
|
|
|
@ -1,213 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go/congestion"
|
|
||||||
)
|
|
||||||
|
|
||||||
// This cubic implementation is based on the one found in Chromiums's QUIC
|
|
||||||
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
|
|
||||||
|
|
||||||
// Constants based on TCP defaults.
|
|
||||||
// The following constants are in 2^10 fractions of a second instead of ms to
|
|
||||||
// allow a 10 shift right to divide.
|
|
||||||
|
|
||||||
// 1024*1024^3 (first 1024 is from 0.100^3)
|
|
||||||
// where 0.100 is 100 ms which is the scaling round trip time.
|
|
||||||
const (
|
|
||||||
cubeScale = 40
|
|
||||||
cubeCongestionWindowScale = 410
|
|
||||||
cubeFactor congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
|
|
||||||
// TODO: when re-enabling cubic, make sure to use the actual packet size here
|
|
||||||
maxDatagramSize = congestion.ByteCount(InitialPacketSizeIPv4)
|
|
||||||
)
|
|
||||||
|
|
||||||
const defaultNumConnections = 1
|
|
||||||
|
|
||||||
// Default Cubic backoff factor
|
|
||||||
const beta float32 = 0.7
|
|
||||||
|
|
||||||
// Additional backoff factor when loss occurs in the concave part of the Cubic
|
|
||||||
// curve. This additional backoff factor is expected to give up bandwidth to
|
|
||||||
// new concurrent flows and speed up convergence.
|
|
||||||
const betaLastMax float32 = 0.85
|
|
||||||
|
|
||||||
// Cubic implements the cubic algorithm from TCP
|
|
||||||
type Cubic struct {
|
|
||||||
clock Clock
|
|
||||||
|
|
||||||
// Number of connections to simulate.
|
|
||||||
numConnections int
|
|
||||||
|
|
||||||
// Time when this cycle started, after last loss event.
|
|
||||||
epoch time.Time
|
|
||||||
|
|
||||||
// Max congestion window used just before last loss event.
|
|
||||||
// Note: to improve fairness to other streams an additional back off is
|
|
||||||
// applied to this value if the new value is below our latest value.
|
|
||||||
lastMaxCongestionWindow congestion.ByteCount
|
|
||||||
|
|
||||||
// Number of acked bytes since the cycle started (epoch).
|
|
||||||
ackedBytesCount congestion.ByteCount
|
|
||||||
|
|
||||||
// TCP Reno equivalent congestion window in packets.
|
|
||||||
estimatedTCPcongestionWindow congestion.ByteCount
|
|
||||||
|
|
||||||
// Origin point of cubic function.
|
|
||||||
originPointCongestionWindow congestion.ByteCount
|
|
||||||
|
|
||||||
// Time to origin point of cubic function in 2^10 fractions of a second.
|
|
||||||
timeToOriginPoint uint32
|
|
||||||
|
|
||||||
// Last congestion window in packets computed by cubic function.
|
|
||||||
lastTargetCongestionWindow congestion.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCubic returns a new Cubic instance
|
|
||||||
func NewCubic(clock Clock) *Cubic {
|
|
||||||
c := &Cubic{
|
|
||||||
clock: clock,
|
|
||||||
numConnections: defaultNumConnections,
|
|
||||||
}
|
|
||||||
c.Reset()
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset is called after a timeout to reset the cubic state
|
|
||||||
func (c *Cubic) Reset() {
|
|
||||||
c.epoch = time.Time{}
|
|
||||||
c.lastMaxCongestionWindow = 0
|
|
||||||
c.ackedBytesCount = 0
|
|
||||||
c.estimatedTCPcongestionWindow = 0
|
|
||||||
c.originPointCongestionWindow = 0
|
|
||||||
c.timeToOriginPoint = 0
|
|
||||||
c.lastTargetCongestionWindow = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Cubic) alpha() float32 {
|
|
||||||
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
|
|
||||||
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
|
|
||||||
// We derive the equivalent alpha for an N-connection emulation as:
|
|
||||||
b := c.beta()
|
|
||||||
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Cubic) beta() float32 {
|
|
||||||
// kNConnectionBeta is the backoff factor after loss for our N-connection
|
|
||||||
// emulation, which emulates the effective backoff of an ensemble of N
|
|
||||||
// TCP-Reno connections on a single loss event. The effective multiplier is
|
|
||||||
// computed as:
|
|
||||||
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Cubic) betaLastMax() float32 {
|
|
||||||
// betaLastMax is the additional backoff factor after loss for our
|
|
||||||
// N-connection emulation, which emulates the additional backoff of
|
|
||||||
// an ensemble of N TCP-Reno connections on a single loss event. The
|
|
||||||
// effective multiplier is computed as:
|
|
||||||
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
|
||||||
// the available congestion window. Resets Cubic state during quiescence.
|
|
||||||
func (c *Cubic) OnApplicationLimited() {
|
|
||||||
// When sender is not using the available congestion window, the window does
|
|
||||||
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
|
|
||||||
// using the entire window during the time since the beginning of the current
|
|
||||||
// "epoch" (the end of the last loss recovery period). Since
|
|
||||||
// application-limited periods break this assumption, we reset the epoch when
|
|
||||||
// in such a period. This reset effectively freezes congestion window growth
|
|
||||||
// through application-limited periods and allows Cubic growth to continue
|
|
||||||
// when the entire window is being used.
|
|
||||||
c.epoch = time.Time{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
|
||||||
// a loss event. Returns the new congestion window in packets. The new
|
|
||||||
// congestion window is a multiplicative decrease of our current window.
|
|
||||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount {
|
|
||||||
if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
|
|
||||||
// We never reached the old max, so assume we are competing with another
|
|
||||||
// flow. Use our extra back off factor to allow the other flow to go up.
|
|
||||||
c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
|
|
||||||
} else {
|
|
||||||
c.lastMaxCongestionWindow = currentCongestionWindow
|
|
||||||
}
|
|
||||||
c.epoch = time.Time{} // Reset time.
|
|
||||||
return congestion.ByteCount(float32(currentCongestionWindow) * c.beta())
|
|
||||||
}
|
|
||||||
|
|
||||||
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
|
||||||
// Returns the new congestion window in packets. The new congestion window
|
|
||||||
// follows a cubic function that depends on the time passed since last
|
|
||||||
// packet loss.
|
|
||||||
func (c *Cubic) CongestionWindowAfterAck(
|
|
||||||
ackedBytes congestion.ByteCount,
|
|
||||||
currentCongestionWindow congestion.ByteCount,
|
|
||||||
delayMin time.Duration,
|
|
||||||
eventTime time.Time,
|
|
||||||
) congestion.ByteCount {
|
|
||||||
c.ackedBytesCount += ackedBytes
|
|
||||||
|
|
||||||
if c.epoch.IsZero() {
|
|
||||||
// First ACK after a loss event.
|
|
||||||
c.epoch = eventTime // Start of epoch.
|
|
||||||
c.ackedBytesCount = ackedBytes // Reset count.
|
|
||||||
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
|
||||||
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
|
||||||
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
|
||||||
c.timeToOriginPoint = 0
|
|
||||||
c.originPointCongestionWindow = currentCongestionWindow
|
|
||||||
} else {
|
|
||||||
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
|
||||||
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
|
||||||
// the round trip time in account. This is done to allow us to use shift as a
|
|
||||||
// divide operator.
|
|
||||||
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
|
|
||||||
|
|
||||||
// Right-shifts of negative, signed numbers have implementation-dependent
|
|
||||||
// behavior, so force the offset to be positive, as is done in the kernel.
|
|
||||||
offset := int64(c.timeToOriginPoint) - elapsedTime
|
|
||||||
if offset < 0 {
|
|
||||||
offset = -offset
|
|
||||||
}
|
|
||||||
|
|
||||||
deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
|
|
||||||
var targetCongestionWindow congestion.ByteCount
|
|
||||||
if elapsedTime > int64(c.timeToOriginPoint) {
|
|
||||||
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
|
||||||
} else {
|
|
||||||
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
|
||||||
}
|
|
||||||
// Limit the CWND increase to half the acked bytes.
|
|
||||||
targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
|
||||||
|
|
||||||
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
|
||||||
// time we ack an estimated tcp window of bytes. For small
|
|
||||||
// congestion windows (less than 25), the formula below will
|
|
||||||
// increase slightly slower than linearly per estimated tcp window
|
|
||||||
// of bytes.
|
|
||||||
c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
|
|
||||||
c.ackedBytesCount = 0
|
|
||||||
|
|
||||||
// We have a new cubic congestion window.
|
|
||||||
c.lastTargetCongestionWindow = targetCongestionWindow
|
|
||||||
|
|
||||||
// Compute target congestion_window based on cubic target and estimated TCP
|
|
||||||
// congestion_window, use highest (fastest).
|
|
||||||
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
|
||||||
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
|
||||||
}
|
|
||||||
return targetCongestionWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNumConnections sets the number of emulated connections
|
|
||||||
func (c *Cubic) SetNumConnections(n int) {
|
|
||||||
c.numConnections = n
|
|
||||||
}
|
|
|
@ -1,318 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go/congestion"
|
|
||||||
"github.com/sagernet/quic-go/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
maxBurstPackets = 3
|
|
||||||
renoBeta = 0.7 // Reno backoff factor.
|
|
||||||
minCongestionWindowPackets = 2
|
|
||||||
initialCongestionWindow = 32
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
InvalidPacketNumber congestion.PacketNumber = -1
|
|
||||||
MaxCongestionWindowPackets = 20000
|
|
||||||
MaxByteCount = congestion.ByteCount(1<<62 - 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
type cubicSender struct {
|
|
||||||
hybridSlowStart HybridSlowStart
|
|
||||||
rttStats congestion.RTTStatsProvider
|
|
||||||
cubic *Cubic
|
|
||||||
pacer *pacer
|
|
||||||
clock Clock
|
|
||||||
|
|
||||||
reno bool
|
|
||||||
|
|
||||||
// Track the largest packet that has been sent.
|
|
||||||
largestSentPacketNumber congestion.PacketNumber
|
|
||||||
|
|
||||||
// Track the largest packet that has been acked.
|
|
||||||
largestAckedPacketNumber congestion.PacketNumber
|
|
||||||
|
|
||||||
// Track the largest packet number outstanding when a CWND cutback occurs.
|
|
||||||
largestSentAtLastCutback congestion.PacketNumber
|
|
||||||
|
|
||||||
// Whether the last loss event caused us to exit slowstart.
|
|
||||||
// Used for stats collection of slowstartPacketsLost
|
|
||||||
lastCutbackExitedSlowstart bool
|
|
||||||
|
|
||||||
// Congestion window in bytes.
|
|
||||||
congestionWindow congestion.ByteCount
|
|
||||||
|
|
||||||
// Slow start congestion window in bytes, aka ssthresh.
|
|
||||||
slowStartThreshold congestion.ByteCount
|
|
||||||
|
|
||||||
// ACK counter for the Reno implementation.
|
|
||||||
numAckedPackets uint64
|
|
||||||
|
|
||||||
initialCongestionWindow congestion.ByteCount
|
|
||||||
initialMaxCongestionWindow congestion.ByteCount
|
|
||||||
|
|
||||||
maxDatagramSize congestion.ByteCount
|
|
||||||
|
|
||||||
lastState logging.CongestionState
|
|
||||||
tracer logging.ConnectionTracer
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ congestion.CongestionControl = &cubicSender{}
|
|
||||||
|
|
||||||
// NewCubicSender makes a new cubic sender
|
|
||||||
func NewCubicSender(
|
|
||||||
clock Clock,
|
|
||||||
initialMaxDatagramSize congestion.ByteCount,
|
|
||||||
reno bool,
|
|
||||||
tracer logging.ConnectionTracer,
|
|
||||||
) *cubicSender {
|
|
||||||
return newCubicSender(
|
|
||||||
clock,
|
|
||||||
reno,
|
|
||||||
initialMaxDatagramSize,
|
|
||||||
initialCongestionWindow*initialMaxDatagramSize,
|
|
||||||
MaxCongestionWindowPackets*initialMaxDatagramSize,
|
|
||||||
tracer,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newCubicSender(
|
|
||||||
clock Clock,
|
|
||||||
reno bool,
|
|
||||||
initialMaxDatagramSize,
|
|
||||||
initialCongestionWindow,
|
|
||||||
initialMaxCongestionWindow congestion.ByteCount,
|
|
||||||
tracer logging.ConnectionTracer,
|
|
||||||
) *cubicSender {
|
|
||||||
c := &cubicSender{
|
|
||||||
largestSentPacketNumber: InvalidPacketNumber,
|
|
||||||
largestAckedPacketNumber: InvalidPacketNumber,
|
|
||||||
largestSentAtLastCutback: InvalidPacketNumber,
|
|
||||||
initialCongestionWindow: initialCongestionWindow,
|
|
||||||
initialMaxCongestionWindow: initialMaxCongestionWindow,
|
|
||||||
congestionWindow: initialCongestionWindow,
|
|
||||||
slowStartThreshold: MaxByteCount,
|
|
||||||
cubic: NewCubic(clock),
|
|
||||||
clock: clock,
|
|
||||||
reno: reno,
|
|
||||||
tracer: tracer,
|
|
||||||
maxDatagramSize: initialMaxDatagramSize,
|
|
||||||
}
|
|
||||||
c.pacer = newPacer(c.BandwidthEstimate)
|
|
||||||
if c.tracer != nil {
|
|
||||||
c.lastState = logging.CongestionStateSlowStart
|
|
||||||
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
|
|
||||||
c.rttStats = provider
|
|
||||||
}
|
|
||||||
|
|
||||||
// TimeUntilSend returns when the next packet should be sent.
|
|
||||||
func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time {
|
|
||||||
return c.pacer.TimeUntilSend()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) HasPacingBudget(now time.Time) bool {
|
|
||||||
return c.pacer.Budget(now) >= c.maxDatagramSize
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) maxCongestionWindow() congestion.ByteCount {
|
|
||||||
return c.maxDatagramSize * MaxCongestionWindowPackets
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) minCongestionWindow() congestion.ByteCount {
|
|
||||||
return c.maxDatagramSize * minCongestionWindowPackets
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) OnPacketSent(
|
|
||||||
sentTime time.Time,
|
|
||||||
_ congestion.ByteCount,
|
|
||||||
packetNumber congestion.PacketNumber,
|
|
||||||
bytes congestion.ByteCount,
|
|
||||||
isRetransmittable bool,
|
|
||||||
) {
|
|
||||||
c.pacer.SentPacket(sentTime, bytes)
|
|
||||||
if !isRetransmittable {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.largestSentPacketNumber = packetNumber
|
|
||||||
c.hybridSlowStart.OnPacketSent(packetNumber)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
|
||||||
return bytesInFlight < c.GetCongestionWindow()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) InRecovery() bool {
|
|
||||||
return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) InSlowStart() bool {
|
|
||||||
return c.GetCongestionWindow() < c.slowStartThreshold
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) GetCongestionWindow() congestion.ByteCount {
|
|
||||||
return c.congestionWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) MaybeExitSlowStart() {
|
|
||||||
if c.InSlowStart() &&
|
|
||||||
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
|
|
||||||
// exit slow start
|
|
||||||
c.slowStartThreshold = c.congestionWindow
|
|
||||||
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) OnPacketAcked(
|
|
||||||
ackedPacketNumber congestion.PacketNumber,
|
|
||||||
ackedBytes congestion.ByteCount,
|
|
||||||
priorInFlight congestion.ByteCount,
|
|
||||||
eventTime time.Time,
|
|
||||||
) {
|
|
||||||
c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber)
|
|
||||||
if c.InRecovery() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
|
||||||
if c.InSlowStart() {
|
|
||||||
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) OnPacketLost(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
|
|
||||||
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
|
||||||
// already sent should be treated as a single loss event, since it's expected.
|
|
||||||
if packetNumber <= c.largestSentAtLastCutback {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.lastCutbackExitedSlowstart = c.InSlowStart()
|
|
||||||
c.maybeTraceStateChange(logging.CongestionStateRecovery)
|
|
||||||
|
|
||||||
if c.reno {
|
|
||||||
c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta)
|
|
||||||
} else {
|
|
||||||
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
|
||||||
}
|
|
||||||
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
|
|
||||||
c.congestionWindow = minCwnd
|
|
||||||
}
|
|
||||||
c.slowStartThreshold = c.congestionWindow
|
|
||||||
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
|
||||||
// reset packet count from congestion avoidance mode. We start
|
|
||||||
// counting again when we're out of recovery.
|
|
||||||
c.numAckedPackets = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
|
||||||
// represents, but quic has a separate ack for each packet.
|
|
||||||
func (c *cubicSender) maybeIncreaseCwnd(
|
|
||||||
_ congestion.PacketNumber,
|
|
||||||
ackedBytes congestion.ByteCount,
|
|
||||||
priorInFlight congestion.ByteCount,
|
|
||||||
eventTime time.Time,
|
|
||||||
) {
|
|
||||||
// Do not increase the congestion window unless the sender is close to using
|
|
||||||
// the current window.
|
|
||||||
if !c.isCwndLimited(priorInFlight) {
|
|
||||||
c.cubic.OnApplicationLimited()
|
|
||||||
c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if c.congestionWindow >= c.maxCongestionWindow() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if c.InSlowStart() {
|
|
||||||
// TCP slow start, exponential growth, increase by one for each ACK.
|
|
||||||
c.congestionWindow += c.maxDatagramSize
|
|
||||||
c.maybeTraceStateChange(logging.CongestionStateSlowStart)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Congestion avoidance
|
|
||||||
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
|
||||||
if c.reno {
|
|
||||||
// Classic Reno congestion avoidance.
|
|
||||||
c.numAckedPackets++
|
|
||||||
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
|
|
||||||
c.congestionWindow += c.maxDatagramSize
|
|
||||||
c.numAckedPackets = 0
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool {
|
|
||||||
congestionWindow := c.GetCongestionWindow()
|
|
||||||
if bytesInFlight >= congestionWindow {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
availableBytes := congestionWindow - bytesInFlight
|
|
||||||
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
|
|
||||||
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// BandwidthEstimate returns the current bandwidth estimate
|
|
||||||
func (c *cubicSender) BandwidthEstimate() Bandwidth {
|
|
||||||
if c.rttStats == nil {
|
|
||||||
return infBandwidth
|
|
||||||
}
|
|
||||||
srtt := c.rttStats.SmoothedRTT()
|
|
||||||
if srtt == 0 {
|
|
||||||
// If we haven't measured an rtt, the bandwidth estimate is unknown.
|
|
||||||
return infBandwidth
|
|
||||||
}
|
|
||||||
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnRetransmissionTimeout is called on an retransmission timeout
|
|
||||||
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
|
||||||
c.largestSentAtLastCutback = InvalidPacketNumber
|
|
||||||
if !packetsRetransmitted {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.hybridSlowStart.Restart()
|
|
||||||
c.cubic.Reset()
|
|
||||||
c.slowStartThreshold = c.congestionWindow / 2
|
|
||||||
c.congestionWindow = c.minCongestionWindow()
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnConnectionMigration is called when the connection is migrated (?)
|
|
||||||
func (c *cubicSender) OnConnectionMigration() {
|
|
||||||
c.hybridSlowStart.Restart()
|
|
||||||
c.largestSentPacketNumber = InvalidPacketNumber
|
|
||||||
c.largestAckedPacketNumber = InvalidPacketNumber
|
|
||||||
c.largestSentAtLastCutback = InvalidPacketNumber
|
|
||||||
c.lastCutbackExitedSlowstart = false
|
|
||||||
c.cubic.Reset()
|
|
||||||
c.numAckedPackets = 0
|
|
||||||
c.congestionWindow = c.initialCongestionWindow
|
|
||||||
c.slowStartThreshold = c.initialMaxCongestionWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
|
|
||||||
if c.tracer == nil || new == c.lastState {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.tracer.UpdatedCongestionState(new)
|
|
||||||
c.lastState = new
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) {
|
|
||||||
if s < c.maxDatagramSize {
|
|
||||||
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
|
|
||||||
}
|
|
||||||
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
|
|
||||||
c.maxDatagramSize = s
|
|
||||||
if cwndIsMinCwnd {
|
|
||||||
c.congestionWindow = c.minCongestionWindow()
|
|
||||||
}
|
|
||||||
c.pacer.SetMaxDatagramSize(s)
|
|
||||||
}
|
|
|
@ -1,112 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go/congestion"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Note(pwestin): the magic clamping numbers come from the original code in
|
|
||||||
// tcp_cubic.c.
|
|
||||||
const hybridStartLowWindow = congestion.ByteCount(16)
|
|
||||||
|
|
||||||
// Number of delay samples for detecting the increase of delay.
|
|
||||||
const hybridStartMinSamples = uint32(8)
|
|
||||||
|
|
||||||
// Exit slow start if the min rtt has increased by more than 1/8th.
|
|
||||||
const hybridStartDelayFactorExp = 3 // 2^3 = 8
|
|
||||||
// The original paper specifies 2 and 8ms, but those have changed over time.
|
|
||||||
const (
|
|
||||||
hybridStartDelayMinThresholdUs = int64(4000)
|
|
||||||
hybridStartDelayMaxThresholdUs = int64(16000)
|
|
||||||
)
|
|
||||||
|
|
||||||
// HybridSlowStart implements the TCP hybrid slow start algorithm
|
|
||||||
type HybridSlowStart struct {
|
|
||||||
endPacketNumber congestion.PacketNumber
|
|
||||||
lastSentPacketNumber congestion.PacketNumber
|
|
||||||
started bool
|
|
||||||
currentMinRTT time.Duration
|
|
||||||
rttSampleCount uint32
|
|
||||||
hystartFound bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
|
|
||||||
func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) {
|
|
||||||
s.endPacketNumber = lastSent
|
|
||||||
s.currentMinRTT = 0
|
|
||||||
s.rttSampleCount = 0
|
|
||||||
s.started = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
|
|
||||||
func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool {
|
|
||||||
return s.endPacketNumber < ack
|
|
||||||
}
|
|
||||||
|
|
||||||
// ShouldExitSlowStart should be called on every new ack frame, since a new
|
|
||||||
// RTT measurement can be made then.
|
|
||||||
// rtt: the RTT for this ack packet.
|
|
||||||
// minRTT: is the lowest delay (RTT) we have seen during the session.
|
|
||||||
// congestionWindow: the congestion window in packets.
|
|
||||||
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool {
|
|
||||||
if !s.started {
|
|
||||||
// Time to start the hybrid slow start.
|
|
||||||
s.StartReceiveRound(s.lastSentPacketNumber)
|
|
||||||
}
|
|
||||||
if s.hystartFound {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Second detection parameter - delay increase detection.
|
|
||||||
// Compare the minimum delay (s.currentMinRTT) of the current
|
|
||||||
// burst of packets relative to the minimum delay during the session.
|
|
||||||
// Note: we only look at the first few(8) packets in each burst, since we
|
|
||||||
// only want to compare the lowest RTT of the burst relative to previous
|
|
||||||
// bursts.
|
|
||||||
s.rttSampleCount++
|
|
||||||
if s.rttSampleCount <= hybridStartMinSamples {
|
|
||||||
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
|
|
||||||
s.currentMinRTT = latestRTT
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// We only need to check this once per round.
|
|
||||||
if s.rttSampleCount == hybridStartMinSamples {
|
|
||||||
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
|
|
||||||
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
|
|
||||||
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
|
|
||||||
minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
|
|
||||||
minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
|
|
||||||
|
|
||||||
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
|
|
||||||
s.hystartFound = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Exit from slow start if the cwnd is greater than 16 and
|
|
||||||
// increasing delay is found.
|
|
||||||
return congestionWindow >= hybridStartLowWindow && s.hystartFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnPacketSent is called when a packet was sent
|
|
||||||
func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) {
|
|
||||||
s.lastSentPacketNumber = packetNumber
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
|
|
||||||
// the round when the final packet of the burst is received and start it on
|
|
||||||
// the next incoming ack.
|
|
||||||
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) {
|
|
||||||
if s.IsEndOfRound(ackedPacketNumber) {
|
|
||||||
s.started = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Started returns true if started
|
|
||||||
func (s *HybridSlowStart) Started() bool {
|
|
||||||
return s.started
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restart the slow start phase
|
|
||||||
func (s *HybridSlowStart) Restart() {
|
|
||||||
s.started = false
|
|
||||||
s.hystartFound = false
|
|
||||||
}
|
|
|
@ -1,72 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/exp/constraints"
|
|
||||||
)
|
|
||||||
|
|
||||||
// InfDuration is a duration of infinite length
|
|
||||||
const InfDuration = time.Duration(math.MaxInt64)
|
|
||||||
|
|
||||||
func Max[T constraints.Ordered](a, b T) T {
|
|
||||||
if a < b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
func Min[T constraints.Ordered](a, b T) T {
|
|
||||||
if a < b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinNonZeroDuration return the minimum duration that's not zero.
|
|
||||||
func MinNonZeroDuration(a, b time.Duration) time.Duration {
|
|
||||||
if a == 0 {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
if b == 0 {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return Min(a, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AbsDuration returns the absolute value of a time duration
|
|
||||||
func AbsDuration(d time.Duration) time.Duration {
|
|
||||||
if d >= 0 {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
return -d
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinTime returns the earlier time
|
|
||||||
func MinTime(a, b time.Time) time.Time {
|
|
||||||
if a.After(b) {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinNonZeroTime returns the earlist time that is not time.Time{}
|
|
||||||
// If both a and b are time.Time{}, it returns time.Time{}
|
|
||||||
func MinNonZeroTime(a, b time.Time) time.Time {
|
|
||||||
if a.IsZero() {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
if b.IsZero() {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return MinTime(a, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaxTime returns the later time
|
|
||||||
func MaxTime(a, b time.Time) time.Time {
|
|
||||||
if a.After(b) {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
|
@ -1,81 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go/congestion"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
initialMaxDatagramSize = congestion.ByteCount(1252)
|
|
||||||
MinPacingDelay = time.Millisecond
|
|
||||||
TimerGranularity = time.Millisecond
|
|
||||||
maxBurstSizePackets = 10
|
|
||||||
)
|
|
||||||
|
|
||||||
// The pacer implements a token bucket pacing algorithm.
|
|
||||||
type pacer struct {
|
|
||||||
budgetAtLastSent congestion.ByteCount
|
|
||||||
maxDatagramSize congestion.ByteCount
|
|
||||||
lastSentTime time.Time
|
|
||||||
getAdjustedBandwidth func() uint64 // in bytes/s
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPacer(getBandwidth func() Bandwidth) *pacer {
|
|
||||||
p := &pacer{
|
|
||||||
maxDatagramSize: initialMaxDatagramSize,
|
|
||||||
getAdjustedBandwidth: func() uint64 {
|
|
||||||
// Bandwidth is in bits/s. We need the value in bytes/s.
|
|
||||||
bw := uint64(getBandwidth() / BytesPerSecond)
|
|
||||||
// Use a slightly higher value than the actual measured bandwidth.
|
|
||||||
// RTT variations then won't result in under-utilization of the congestion window.
|
|
||||||
// Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire,
|
|
||||||
// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
|
|
||||||
return bw * 5 / 4
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.budgetAtLastSent = p.maxBurstSize()
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
|
|
||||||
budget := p.Budget(sendTime)
|
|
||||||
if size > budget {
|
|
||||||
p.budgetAtLastSent = 0
|
|
||||||
} else {
|
|
||||||
p.budgetAtLastSent = budget - size
|
|
||||||
}
|
|
||||||
p.lastSentTime = sendTime
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
|
|
||||||
if p.lastSentTime.IsZero() {
|
|
||||||
return p.maxBurstSize()
|
|
||||||
}
|
|
||||||
budget := p.budgetAtLastSent + (congestion.ByteCount(p.getAdjustedBandwidth())*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
|
||||||
return Min(p.maxBurstSize(), budget)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pacer) maxBurstSize() congestion.ByteCount {
|
|
||||||
return Max(
|
|
||||||
congestion.ByteCount(uint64((MinPacingDelay+TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9,
|
|
||||||
maxBurstSizePackets*p.maxDatagramSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TimeUntilSend returns when the next packet should be sent.
|
|
||||||
// It returns the zero value of time.Time if a packet can be sent immediately.
|
|
||||||
func (p *pacer) TimeUntilSend() time.Time {
|
|
||||||
if p.budgetAtLastSent >= p.maxDatagramSize {
|
|
||||||
return time.Time{}
|
|
||||||
}
|
|
||||||
return p.lastSentTime.Add(Max(
|
|
||||||
MinPacingDelay,
|
|
||||||
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
|
|
||||||
p.maxDatagramSize = s
|
|
||||||
}
|
|
|
@ -1,132 +0,0 @@
|
||||||
package congestion
|
|
||||||
|
|
||||||
// WindowedFilter Use the following to construct a windowed filter object of type T.
|
|
||||||
// For example, a min filter using QuicTime as the time type:
|
|
||||||
//
|
|
||||||
// WindowedFilter<T, MinFilter<T>, QuicTime, QuicTime::Delta> ObjectName;
|
|
||||||
//
|
|
||||||
// A max filter using 64-bit integers as the time type:
|
|
||||||
//
|
|
||||||
// WindowedFilter<T, MaxFilter<T>, uint64_t, int64_t> ObjectName;
|
|
||||||
//
|
|
||||||
// Specifically, this template takes four arguments:
|
|
||||||
// 1. T -- type of the measurement that is being filtered.
|
|
||||||
// 2. Compare -- MinFilter<T> or MaxFilter<T>, depending on the type of filter
|
|
||||||
// desired.
|
|
||||||
// 3. TimeT -- the type used to represent timestamps.
|
|
||||||
// 4. TimeDeltaT -- the type used to represent continuous time intervals between
|
|
||||||
// two timestamps. Has to be the type of (a - b) if both |a| and |b| are
|
|
||||||
// of type TimeT.
|
|
||||||
type WindowedFilter struct {
|
|
||||||
// Time length of window.
|
|
||||||
windowLength int64
|
|
||||||
estimates []Sample
|
|
||||||
comparator func(int64, int64) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type Sample struct {
|
|
||||||
sample int64
|
|
||||||
time int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compares two values and returns true if the first is greater than or equal
|
|
||||||
// to the second.
|
|
||||||
func MaxFilter(a, b int64) bool {
|
|
||||||
return a >= b
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compares two values and returns true if the first is less than or equal
|
|
||||||
// to the second.
|
|
||||||
func MinFilter(a, b int64) bool {
|
|
||||||
return a <= b
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWindowedFilter(windowLength int64, comparator func(int64, int64) bool) *WindowedFilter {
|
|
||||||
return &WindowedFilter{
|
|
||||||
windowLength: windowLength,
|
|
||||||
estimates: make([]Sample, 3),
|
|
||||||
comparator: comparator,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Changes the window length. Does not update any current samples.
|
|
||||||
func (f *WindowedFilter) SetWindowLength(windowLength int64) {
|
|
||||||
f.windowLength = windowLength
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *WindowedFilter) GetBest() int64 {
|
|
||||||
return f.estimates[0].sample
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *WindowedFilter) GetSecondBest() int64 {
|
|
||||||
return f.estimates[1].sample
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *WindowedFilter) GetThirdBest() int64 {
|
|
||||||
return f.estimates[2].sample
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *WindowedFilter) Update(sample int64, time int64) {
|
|
||||||
if f.estimates[0].time == 0 || f.comparator(sample, f.estimates[0].sample) || (time-f.estimates[2].time) > f.windowLength {
|
|
||||||
f.Reset(sample, time)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if f.comparator(sample, f.estimates[1].sample) {
|
|
||||||
f.estimates[1].sample = sample
|
|
||||||
f.estimates[1].time = time
|
|
||||||
f.estimates[2].sample = sample
|
|
||||||
f.estimates[2].time = time
|
|
||||||
} else if f.comparator(sample, f.estimates[2].sample) {
|
|
||||||
f.estimates[2].sample = sample
|
|
||||||
f.estimates[2].time = time
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expire and update estimates as necessary.
|
|
||||||
if time-f.estimates[0].time > f.windowLength {
|
|
||||||
// The best estimate hasn't been updated for an entire window, so promote
|
|
||||||
// second and third best estimates.
|
|
||||||
f.estimates[0].sample = f.estimates[1].sample
|
|
||||||
f.estimates[0].time = f.estimates[1].time
|
|
||||||
f.estimates[1].sample = f.estimates[2].sample
|
|
||||||
f.estimates[1].time = f.estimates[2].time
|
|
||||||
f.estimates[2].sample = sample
|
|
||||||
f.estimates[2].time = time
|
|
||||||
// Need to iterate one more time. Check if the new best estimate is
|
|
||||||
// outside the window as well, since it may also have been recorded a
|
|
||||||
// long time ago. Don't need to iterate once more since we cover that
|
|
||||||
// case at the beginning of the method.
|
|
||||||
if time-f.estimates[0].time > f.windowLength {
|
|
||||||
f.estimates[0].sample = f.estimates[1].sample
|
|
||||||
f.estimates[0].time = f.estimates[1].time
|
|
||||||
f.estimates[1].sample = f.estimates[2].sample
|
|
||||||
f.estimates[1].time = f.estimates[2].time
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if f.estimates[1].sample == f.estimates[0].sample && time-f.estimates[1].time > f.windowLength>>2 {
|
|
||||||
// A quarter of the window has passed without a better sample, so the
|
|
||||||
// second-best estimate is taken from the second quarter of the window.
|
|
||||||
f.estimates[1].sample = sample
|
|
||||||
f.estimates[1].time = time
|
|
||||||
f.estimates[2].sample = sample
|
|
||||||
f.estimates[2].time = time
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if f.estimates[2].sample == f.estimates[1].sample && time-f.estimates[2].time > f.windowLength>>1 {
|
|
||||||
// We've passed a half of the window without a better estimate, so take
|
|
||||||
// a third-best estimate from the second half of the window.
|
|
||||||
f.estimates[2].sample = sample
|
|
||||||
f.estimates[2].time = time
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *WindowedFilter) Reset(newSample int64, newTime int64) {
|
|
||||||
f.estimates[0].sample = newSample
|
|
||||||
f.estimates[0].time = newTime
|
|
||||||
f.estimates[1].sample = newSample
|
|
||||||
f.estimates[1].time = newTime
|
|
||||||
f.estimates[2].sample = newSample
|
|
||||||
f.estimates[2].time = newTime
|
|
||||||
}
|
|
|
@ -1,532 +0,0 @@
|
||||||
package tuic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/atomic"
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/cache"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
var udpMessagePool = sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
return new(udpMessage)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func allocMessage() *udpMessage {
|
|
||||||
message := udpMessagePool.Get().(*udpMessage)
|
|
||||||
message.referenced = true
|
|
||||||
return message
|
|
||||||
}
|
|
||||||
|
|
||||||
func releaseMessages(messages []*udpMessage) {
|
|
||||||
for _, message := range messages {
|
|
||||||
if message != nil {
|
|
||||||
message.release()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type udpMessage struct {
|
|
||||||
sessionID uint16
|
|
||||||
packetID uint16
|
|
||||||
fragmentTotal uint8
|
|
||||||
fragmentID uint8
|
|
||||||
destination M.Socksaddr
|
|
||||||
data *buf.Buffer
|
|
||||||
referenced bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *udpMessage) release() {
|
|
||||||
if !m.referenced {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*m = udpMessage{}
|
|
||||||
udpMessagePool.Put(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *udpMessage) releaseMessage() {
|
|
||||||
m.data.Release()
|
|
||||||
m.release()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *udpMessage) pack() *buf.Buffer {
|
|
||||||
buffer := buf.NewSize(m.headerSize() + m.data.Len())
|
|
||||||
common.Must(
|
|
||||||
buffer.WriteByte(Version),
|
|
||||||
buffer.WriteByte(CommandPacket),
|
|
||||||
binary.Write(buffer, binary.BigEndian, m.sessionID),
|
|
||||||
binary.Write(buffer, binary.BigEndian, m.packetID),
|
|
||||||
binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
|
|
||||||
binary.Write(buffer, binary.BigEndian, m.fragmentID),
|
|
||||||
binary.Write(buffer, binary.BigEndian, uint16(m.data.Len())),
|
|
||||||
addressSerializer.WriteAddrPort(buffer, m.destination),
|
|
||||||
common.Error(buffer.Write(m.data.Bytes())),
|
|
||||||
)
|
|
||||||
return buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *udpMessage) headerSize() int {
|
|
||||||
return 10 + addressSerializer.AddrPortLen(m.destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
|
|
||||||
if message.data.Len() <= maxPacketSize {
|
|
||||||
return []*udpMessage{message}
|
|
||||||
}
|
|
||||||
var fragments []*udpMessage
|
|
||||||
originPacket := message.data.Bytes()
|
|
||||||
udpMTU := maxPacketSize - message.headerSize()
|
|
||||||
for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
|
|
||||||
fragment := allocMessage()
|
|
||||||
*fragment = *message
|
|
||||||
if remaining > udpMTU {
|
|
||||||
fragment.data = buf.As(originPacket[:udpMTU])
|
|
||||||
originPacket = originPacket[udpMTU:]
|
|
||||||
} else {
|
|
||||||
fragment.data = buf.As(originPacket)
|
|
||||||
originPacket = nil
|
|
||||||
}
|
|
||||||
fragments = append(fragments, fragment)
|
|
||||||
}
|
|
||||||
fragmentTotal := uint16(len(fragments))
|
|
||||||
for index, fragment := range fragments {
|
|
||||||
fragment.fragmentID = uint8(index)
|
|
||||||
fragment.fragmentTotal = uint8(fragmentTotal)
|
|
||||||
if index > 0 {
|
|
||||||
fragment.destination = M.Socksaddr{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fragments
|
|
||||||
}
|
|
||||||
|
|
||||||
type udpPacketConn struct {
|
|
||||||
ctx context.Context
|
|
||||||
cancel common.ContextCancelCauseFunc
|
|
||||||
sessionID uint16
|
|
||||||
quicConn quic.Connection
|
|
||||||
data chan *udpMessage
|
|
||||||
udpStream bool
|
|
||||||
udpMTU int
|
|
||||||
udpMTUTime time.Time
|
|
||||||
packetId atomic.Uint32
|
|
||||||
closeOnce sync.Once
|
|
||||||
isServer bool
|
|
||||||
defragger *udpDefragger
|
|
||||||
onDestroy func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn {
|
|
||||||
ctx, cancel := common.ContextWithCancelCause(ctx)
|
|
||||||
return &udpPacketConn{
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
quicConn: quicConn,
|
|
||||||
data: make(chan *udpMessage, 64),
|
|
||||||
udpStream: udpStream,
|
|
||||||
isServer: isServer,
|
|
||||||
defragger: newUDPDefragger(),
|
|
||||||
onDestroy: onDestroy,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
|
||||||
select {
|
|
||||||
case p := <-c.data:
|
|
||||||
buffer = p.data
|
|
||||||
destination = p.destination
|
|
||||||
p.release()
|
|
||||||
return
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
|
||||||
select {
|
|
||||||
case p := <-c.data:
|
|
||||||
_, err = buffer.ReadOnceFrom(p.data)
|
|
||||||
destination = p.destination
|
|
||||||
p.releaseMessage()
|
|
||||||
return
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
|
|
||||||
select {
|
|
||||||
case p := <-c.data:
|
|
||||||
_, err = newBuffer().ReadOnceFrom(p.data)
|
|
||||||
destination = p.destination
|
|
||||||
p.releaseMessage()
|
|
||||||
return
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
||||||
select {
|
|
||||||
case pkt := <-c.data:
|
|
||||||
n = copy(p, pkt.data.Bytes())
|
|
||||||
if pkt.destination.IsFqdn() {
|
|
||||||
addr = pkt.destination
|
|
||||||
} else {
|
|
||||||
addr = pkt.destination.UDPAddr()
|
|
||||||
}
|
|
||||||
pkt.releaseMessage()
|
|
||||||
return n, addr, nil
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return 0, nil, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) needFragment() bool {
|
|
||||||
nowTime := time.Now()
|
|
||||||
if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second {
|
|
||||||
c.udpMTUTime = nowTime
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|
||||||
defer buffer.Release()
|
|
||||||
select {
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return net.ErrClosed
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if buffer.Len() > 0xffff {
|
|
||||||
return quic.ErrMessageTooLarge(0xffff)
|
|
||||||
}
|
|
||||||
if !destination.IsValid() {
|
|
||||||
return E.New("invalid destination address")
|
|
||||||
}
|
|
||||||
packetId := c.packetId.Add(1)
|
|
||||||
if packetId > math.MaxUint16 {
|
|
||||||
c.packetId.Store(0)
|
|
||||||
packetId = 0
|
|
||||||
}
|
|
||||||
message := allocMessage()
|
|
||||||
*message = udpMessage{
|
|
||||||
sessionID: c.sessionID,
|
|
||||||
packetID: uint16(packetId),
|
|
||||||
fragmentTotal: 1,
|
|
||||||
destination: destination,
|
|
||||||
data: buffer,
|
|
||||||
}
|
|
||||||
defer message.releaseMessage()
|
|
||||||
var err error
|
|
||||||
if !c.udpStream && c.needFragment() && buffer.Len() > c.udpMTU {
|
|
||||||
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
|
||||||
} else {
|
|
||||||
err = c.writePacket(message)
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var tooLargeErr quic.ErrMessageTooLarge
|
|
||||||
if !errors.As(err, &tooLargeErr) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.udpMTU = int(tooLargeErr)
|
|
||||||
c.udpMTUTime = time.Now()
|
|
||||||
return c.writePackets(fragUDPMessage(message, c.udpMTU))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
||||||
select {
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return 0, net.ErrClosed
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if len(p) > 0xffff {
|
|
||||||
return 0, quic.ErrMessageTooLarge(0xffff)
|
|
||||||
}
|
|
||||||
destination := M.SocksaddrFromNet(addr)
|
|
||||||
if !destination.IsValid() {
|
|
||||||
return 0, E.New("invalid destination address")
|
|
||||||
}
|
|
||||||
packetId := c.packetId.Add(1)
|
|
||||||
if packetId > math.MaxUint16 {
|
|
||||||
c.packetId.Store(0)
|
|
||||||
packetId = 0
|
|
||||||
}
|
|
||||||
message := allocMessage()
|
|
||||||
*message = udpMessage{
|
|
||||||
sessionID: c.sessionID,
|
|
||||||
packetID: uint16(packetId),
|
|
||||||
fragmentTotal: 1,
|
|
||||||
destination: destination,
|
|
||||||
data: buf.As(p),
|
|
||||||
}
|
|
||||||
if !c.udpStream && c.needFragment() && len(p) > c.udpMTU {
|
|
||||||
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
|
||||||
if err == nil {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = c.writePacket(message)
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
var tooLargeErr quic.ErrMessageTooLarge
|
|
||||||
if !errors.As(err, &tooLargeErr) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.udpMTU = int(tooLargeErr)
|
|
||||||
c.udpMTUTime = time.Now()
|
|
||||||
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
|
|
||||||
if err == nil {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) inputPacket(message *udpMessage) {
|
|
||||||
if message.fragmentTotal <= 1 {
|
|
||||||
select {
|
|
||||||
case c.data <- message:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
newMessage := c.defragger.feed(message)
|
|
||||||
if newMessage != nil {
|
|
||||||
select {
|
|
||||||
case c.data <- newMessage:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
|
|
||||||
defer releaseMessages(messages)
|
|
||||||
for _, message := range messages {
|
|
||||||
err := c.writePacket(message)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) writePacket(message *udpMessage) error {
|
|
||||||
if !c.udpStream {
|
|
||||||
buffer := message.pack()
|
|
||||||
err := c.quicConn.SendMessage(buffer.Bytes())
|
|
||||||
buffer.Release()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
stream, err := c.quicConn.OpenUniStream()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
buffer := message.pack()
|
|
||||||
_, err = stream.Write(buffer.Bytes())
|
|
||||||
buffer.Release()
|
|
||||||
stream.Close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) Close() error {
|
|
||||||
c.closeOnce.Do(func() {
|
|
||||||
c.closeWithError(os.ErrClosed)
|
|
||||||
c.onDestroy()
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) closeWithError(err error) {
|
|
||||||
c.cancel(err)
|
|
||||||
if !c.isServer {
|
|
||||||
buffer := buf.NewSize(4)
|
|
||||||
defer buffer.Release()
|
|
||||||
buffer.WriteByte(Version)
|
|
||||||
buffer.WriteByte(CommandDissociate)
|
|
||||||
binary.Write(buffer, binary.BigEndian, c.sessionID)
|
|
||||||
sendStream, openErr := c.quicConn.OpenUniStream()
|
|
||||||
if openErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer sendStream.Close()
|
|
||||||
sendStream.Write(buffer.Bytes())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) LocalAddr() net.Addr {
|
|
||||||
return c.quicConn.LocalAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) SetDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
type udpDefragger struct {
|
|
||||||
packetMap *cache.LruCache[uint16, *packetItem]
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUDPDefragger() *udpDefragger {
|
|
||||||
return &udpDefragger{
|
|
||||||
packetMap: cache.New(
|
|
||||||
cache.WithAge[uint16, *packetItem](10),
|
|
||||||
cache.WithUpdateAgeOnGet[uint16, *packetItem](),
|
|
||||||
cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
|
|
||||||
releaseMessages(value.messages)
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type packetItem struct {
|
|
||||||
access sync.Mutex
|
|
||||||
messages []*udpMessage
|
|
||||||
count uint8
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
|
|
||||||
if m.fragmentTotal <= 1 {
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
if m.fragmentID >= m.fragmentTotal {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
|
|
||||||
item.access.Lock()
|
|
||||||
defer item.access.Unlock()
|
|
||||||
if int(m.fragmentTotal) != len(item.messages) {
|
|
||||||
releaseMessages(item.messages)
|
|
||||||
item.messages = make([]*udpMessage, m.fragmentTotal)
|
|
||||||
item.count = 1
|
|
||||||
item.messages[m.fragmentID] = m
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if item.messages[m.fragmentID] != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
item.messages[m.fragmentID] = m
|
|
||||||
item.count++
|
|
||||||
if int(item.count) != len(item.messages) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
newMessage := allocMessage()
|
|
||||||
*newMessage = *item.messages[0]
|
|
||||||
var dataLength uint16
|
|
||||||
for _, message := range item.messages {
|
|
||||||
dataLength += uint16(message.data.Len())
|
|
||||||
}
|
|
||||||
if dataLength > 0 {
|
|
||||||
newMessage.data = buf.NewSize(int(dataLength))
|
|
||||||
for _, message := range item.messages {
|
|
||||||
common.Must1(newMessage.data.Write(message.data.Bytes()))
|
|
||||||
message.releaseMessage()
|
|
||||||
}
|
|
||||||
item.messages = nil
|
|
||||||
return newMessage
|
|
||||||
}
|
|
||||||
item.messages = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPacketItem() *packetItem {
|
|
||||||
return new(packetItem)
|
|
||||||
}
|
|
||||||
|
|
||||||
func readUDPMessage(message *udpMessage, reader io.Reader) error {
|
|
||||||
err := binary.Read(reader, binary.BigEndian, &message.sessionID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &message.packetID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var dataLength uint16
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &dataLength)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
message.destination, err = addressSerializer.ReadAddrPort(reader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
message.data = buf.NewSize(int(dataLength))
|
|
||||||
_, err = message.data.ReadFullFrom(reader, message.data.FreeLen())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeUDPMessage(message *udpMessage, data []byte) error {
|
|
||||||
reader := bytes.NewReader(data)
|
|
||||||
err := binary.Read(reader, binary.BigEndian, &message.sessionID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &message.packetID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var dataLength uint16
|
|
||||||
err = binary.Read(reader, binary.BigEndian, &dataLength)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
message.destination, err = addressSerializer.ReadAddrPort(reader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if reader.Len() != int(dataLength) {
|
|
||||||
return io.ErrUnexpectedEOF
|
|
||||||
}
|
|
||||||
message.data = buf.As(data[len(data)-reader.Len():])
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,15 +0,0 @@
|
||||||
package tuic
|
|
||||||
|
|
||||||
const (
|
|
||||||
Version = 5
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
CommandAuthenticate = iota
|
|
||||||
CommandConnect
|
|
||||||
CommandPacket
|
|
||||||
CommandDissociate
|
|
||||||
CommandHeartbeat
|
|
||||||
)
|
|
||||||
|
|
||||||
const AuthenticateLen = 2 + 16 + 32
|
|
|
@ -1,437 +0,0 @@
|
||||||
//go:build with_quic
|
|
||||||
|
|
||||||
package tuic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
|
||||||
"github.com/sagernet/sing-box/common/qtls"
|
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/auth"
|
|
||||||
"github.com/sagernet/sing/common/baderror"
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/bufio"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
"github.com/sagernet/sing/common/logger"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
|
|
||||||
"github.com/gofrs/uuid/v5"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ServerOptions struct {
|
|
||||||
Context context.Context
|
|
||||||
Logger logger.Logger
|
|
||||||
TLSConfig tls.ServerConfig
|
|
||||||
Users []User
|
|
||||||
CongestionControl string
|
|
||||||
AuthTimeout time.Duration
|
|
||||||
ZeroRTTHandshake bool
|
|
||||||
Heartbeat time.Duration
|
|
||||||
Handler ServerHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
Name string
|
|
||||||
UUID uuid.UUID
|
|
||||||
Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ServerHandler interface {
|
|
||||||
N.TCPConnectionHandler
|
|
||||||
N.UDPConnectionHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
type Server struct {
|
|
||||||
ctx context.Context
|
|
||||||
logger logger.Logger
|
|
||||||
tlsConfig tls.ServerConfig
|
|
||||||
heartbeat time.Duration
|
|
||||||
quicConfig *quic.Config
|
|
||||||
userMap map[uuid.UUID]User
|
|
||||||
congestionControl string
|
|
||||||
authTimeout time.Duration
|
|
||||||
handler ServerHandler
|
|
||||||
|
|
||||||
quicListener io.Closer
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServer(options ServerOptions) (*Server, error) {
|
|
||||||
if options.AuthTimeout == 0 {
|
|
||||||
options.AuthTimeout = 3 * time.Second
|
|
||||||
}
|
|
||||||
if options.Heartbeat == 0 {
|
|
||||||
options.Heartbeat = 10 * time.Second
|
|
||||||
}
|
|
||||||
quicConfig := &quic.Config{
|
|
||||||
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
|
|
||||||
MaxDatagramFrameSize: 1400,
|
|
||||||
EnableDatagrams: true,
|
|
||||||
Allow0RTT: options.ZeroRTTHandshake,
|
|
||||||
MaxIncomingStreams: 1 << 60,
|
|
||||||
MaxIncomingUniStreams: 1 << 60,
|
|
||||||
}
|
|
||||||
switch options.CongestionControl {
|
|
||||||
case "":
|
|
||||||
options.CongestionControl = "cubic"
|
|
||||||
case "cubic", "new_reno", "bbr":
|
|
||||||
default:
|
|
||||||
return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
|
|
||||||
}
|
|
||||||
if len(options.Users) == 0 {
|
|
||||||
return nil, E.New("missing users")
|
|
||||||
}
|
|
||||||
userMap := make(map[uuid.UUID]User)
|
|
||||||
for _, user := range options.Users {
|
|
||||||
userMap[user.UUID] = user
|
|
||||||
}
|
|
||||||
return &Server{
|
|
||||||
ctx: options.Context,
|
|
||||||
logger: options.Logger,
|
|
||||||
tlsConfig: options.TLSConfig,
|
|
||||||
heartbeat: options.Heartbeat,
|
|
||||||
quicConfig: quicConfig,
|
|
||||||
userMap: userMap,
|
|
||||||
congestionControl: options.CongestionControl,
|
|
||||||
authTimeout: options.AuthTimeout,
|
|
||||||
handler: options.Handler,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Start(conn net.PacketConn) error {
|
|
||||||
if !s.quicConfig.Allow0RTT {
|
|
||||||
listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.quicListener = listener
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
connection, hErr := listener.Accept(s.ctx)
|
|
||||||
if hErr != nil {
|
|
||||||
if strings.Contains(hErr.Error(), "server closed") {
|
|
||||||
s.logger.Debug(E.Cause(hErr, "listener closed"))
|
|
||||||
} else {
|
|
||||||
s.logger.Error(E.Cause(hErr, "listener closed"))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go s.handleConnection(connection)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
} else {
|
|
||||||
listener, err := qtls.ListenEarly(conn, s.tlsConfig, s.quicConfig)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.quicListener = listener
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
connection, hErr := listener.Accept(s.ctx)
|
|
||||||
if hErr != nil {
|
|
||||||
if strings.Contains(hErr.Error(), "server closed") {
|
|
||||||
s.logger.Debug(E.Cause(hErr, "listener closed"))
|
|
||||||
} else {
|
|
||||||
s.logger.Error(E.Cause(hErr, "listener closed"))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go s.handleConnection(connection)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Close() error {
|
|
||||||
return common.Close(
|
|
||||||
s.quicListener,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleConnection(connection quic.Connection) {
|
|
||||||
setCongestion(s.ctx, connection, s.congestionControl)
|
|
||||||
session := &serverSession{
|
|
||||||
Server: s,
|
|
||||||
ctx: s.ctx,
|
|
||||||
quicConn: connection,
|
|
||||||
source: M.SocksaddrFromNet(connection.RemoteAddr()),
|
|
||||||
connDone: make(chan struct{}),
|
|
||||||
authDone: make(chan struct{}),
|
|
||||||
udpConnMap: make(map[uint16]*udpPacketConn),
|
|
||||||
}
|
|
||||||
session.handle()
|
|
||||||
}
|
|
||||||
|
|
||||||
type serverSession struct {
|
|
||||||
*Server
|
|
||||||
ctx context.Context
|
|
||||||
quicConn quic.Connection
|
|
||||||
source M.Socksaddr
|
|
||||||
connAccess sync.Mutex
|
|
||||||
connDone chan struct{}
|
|
||||||
connErr error
|
|
||||||
authDone chan struct{}
|
|
||||||
authUser *User
|
|
||||||
udpAccess sync.RWMutex
|
|
||||||
udpConnMap map[uint16]*udpPacketConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) handle() {
|
|
||||||
if s.ctx.Done() != nil {
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-s.ctx.Done():
|
|
||||||
s.closeWithError(s.ctx.Err())
|
|
||||||
case <-s.connDone:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
go s.loopUniStreams()
|
|
||||||
go s.loopStreams()
|
|
||||||
go s.loopMessages()
|
|
||||||
go s.handleAuthTimeout()
|
|
||||||
go s.loopHeartbeats()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) loopUniStreams() {
|
|
||||||
for {
|
|
||||||
uniStream, err := s.quicConn.AcceptUniStream(s.ctx)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
err = s.handleUniStream(uniStream)
|
|
||||||
if err != nil {
|
|
||||||
s.closeWithError(E.Cause(err, "handle uni stream"))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
|
|
||||||
defer stream.CancelRead(0)
|
|
||||||
buffer := buf.New()
|
|
||||||
defer buffer.Release()
|
|
||||||
_, err := buffer.ReadAtLeastFrom(stream, 2)
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "read request")
|
|
||||||
}
|
|
||||||
version := buffer.Byte(0)
|
|
||||||
if version != Version {
|
|
||||||
return E.New("unknown version ", buffer.Byte(0))
|
|
||||||
}
|
|
||||||
command := buffer.Byte(1)
|
|
||||||
switch command {
|
|
||||||
case CommandAuthenticate:
|
|
||||||
select {
|
|
||||||
case <-s.authDone:
|
|
||||||
return E.New("authentication: multiple authentication requests")
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if buffer.Len() < AuthenticateLen {
|
|
||||||
_, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len())
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "authentication: read request")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16))
|
|
||||||
user, loaded := s.userMap[userUUID]
|
|
||||||
if !loaded {
|
|
||||||
return E.New("authentication: unknown user ", userUUID)
|
|
||||||
}
|
|
||||||
handshakeState := s.quicConn.ConnectionState()
|
|
||||||
tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32)
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "authentication: export keying material")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) {
|
|
||||||
return E.New("authentication: token mismatch")
|
|
||||||
}
|
|
||||||
s.authUser = &user
|
|
||||||
close(s.authDone)
|
|
||||||
return nil
|
|
||||||
case CommandPacket:
|
|
||||||
select {
|
|
||||||
case <-s.connDone:
|
|
||||||
return s.connErr
|
|
||||||
case <-s.authDone:
|
|
||||||
}
|
|
||||||
message := allocMessage()
|
|
||||||
err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream))
|
|
||||||
if err != nil {
|
|
||||||
message.release()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.handleUDPMessage(message, true)
|
|
||||||
return nil
|
|
||||||
case CommandDissociate:
|
|
||||||
select {
|
|
||||||
case <-s.connDone:
|
|
||||||
return s.connErr
|
|
||||||
case <-s.authDone:
|
|
||||||
}
|
|
||||||
if buffer.Len() > 4 {
|
|
||||||
return E.New("invalid dissociate message")
|
|
||||||
}
|
|
||||||
var sessionID uint16
|
|
||||||
err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.udpAccess.RLock()
|
|
||||||
udpConn, loaded := s.udpConnMap[sessionID]
|
|
||||||
s.udpAccess.RUnlock()
|
|
||||||
if loaded {
|
|
||||||
udpConn.closeWithError(E.New("remote closed"))
|
|
||||||
s.udpAccess.Lock()
|
|
||||||
delete(s.udpConnMap, sessionID)
|
|
||||||
s.udpAccess.Unlock()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return E.New("unknown command ", command)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) handleAuthTimeout() {
|
|
||||||
select {
|
|
||||||
case <-s.connDone:
|
|
||||||
case <-s.authDone:
|
|
||||||
case <-time.After(s.authTimeout):
|
|
||||||
s.closeWithError(E.New("authentication timeout"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) loopStreams() {
|
|
||||||
for {
|
|
||||||
stream, err := s.quicConn.AcceptStream(s.ctx)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
err = s.handleStream(stream)
|
|
||||||
if err != nil {
|
|
||||||
stream.CancelRead(0)
|
|
||||||
stream.Close()
|
|
||||||
s.logger.Error(E.Cause(err, "handle stream request"))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) handleStream(stream quic.Stream) error {
|
|
||||||
buffer := buf.NewSize(2 + M.MaxSocksaddrLength)
|
|
||||||
defer buffer.Release()
|
|
||||||
_, err := buffer.ReadAtLeastFrom(stream, 2)
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "read request")
|
|
||||||
}
|
|
||||||
version, _ := buffer.ReadByte()
|
|
||||||
if version != Version {
|
|
||||||
return E.New("unknown version ", buffer.Byte(0))
|
|
||||||
}
|
|
||||||
command, _ := buffer.ReadByte()
|
|
||||||
if command != CommandConnect {
|
|
||||||
return E.New("unsupported stream command ", command)
|
|
||||||
}
|
|
||||||
destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream))
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "read request destination")
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-s.connDone:
|
|
||||||
return s.connErr
|
|
||||||
case <-s.authDone:
|
|
||||||
}
|
|
||||||
var conn net.Conn = &serverConn{
|
|
||||||
Stream: stream,
|
|
||||||
destination: destination,
|
|
||||||
}
|
|
||||||
if buffer.IsEmpty() {
|
|
||||||
buffer.Release()
|
|
||||||
} else {
|
|
||||||
conn = bufio.NewCachedConn(conn, buffer)
|
|
||||||
}
|
|
||||||
ctx := s.ctx
|
|
||||||
if s.authUser.Name != "" {
|
|
||||||
ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
|
|
||||||
}
|
|
||||||
_ = s.handler.NewConnection(ctx, conn, M.Metadata{
|
|
||||||
Source: s.source,
|
|
||||||
Destination: destination,
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) loopHeartbeats() {
|
|
||||||
ticker := time.NewTicker(s.heartbeat)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-s.connDone:
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
|
|
||||||
if err != nil {
|
|
||||||
s.closeWithError(E.Cause(err, "send heartbeat"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) closeWithError(err error) {
|
|
||||||
s.connAccess.Lock()
|
|
||||||
defer s.connAccess.Unlock()
|
|
||||||
select {
|
|
||||||
case <-s.connDone:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
s.connErr = err
|
|
||||||
close(s.connDone)
|
|
||||||
}
|
|
||||||
if E.IsClosedOrCanceled(err) {
|
|
||||||
s.logger.Debug(E.Cause(err, "connection failed"))
|
|
||||||
} else {
|
|
||||||
s.logger.Error(E.Cause(err, "connection failed"))
|
|
||||||
}
|
|
||||||
_ = s.quicConn.CloseWithError(0, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
type serverConn struct {
|
|
||||||
quic.Stream
|
|
||||||
destination M.Socksaddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) Read(p []byte) (n int, err error) {
|
|
||||||
n, err = c.Stream.Read(p)
|
|
||||||
return n, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) Write(p []byte) (n int, err error) {
|
|
||||||
n, err = c.Stream.Write(p)
|
|
||||||
return n, baderror.WrapQUIC(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) LocalAddr() net.Addr {
|
|
||||||
return c.destination
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) RemoteAddr() net.Addr {
|
|
||||||
return M.Socksaddr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *serverConn) Close() error {
|
|
||||||
c.Stream.CancelRead(0)
|
|
||||||
return c.Stream.Close()
|
|
||||||
}
|
|
|
@ -1,75 +0,0 @@
|
||||||
//go:build with_quic
|
|
||||||
|
|
||||||
package tuic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s *serverSession) loopMessages() {
|
|
||||||
select {
|
|
||||||
case <-s.connDone:
|
|
||||||
return
|
|
||||||
case <-s.authDone:
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
message, err := s.quicConn.ReceiveMessage(s.ctx)
|
|
||||||
if err != nil {
|
|
||||||
s.closeWithError(E.Cause(err, "receive message"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
hErr := s.handleMessage(message)
|
|
||||||
if hErr != nil {
|
|
||||||
s.closeWithError(E.Cause(hErr, "handle message"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) handleMessage(data []byte) error {
|
|
||||||
if len(data) < 2 {
|
|
||||||
return E.New("invalid message")
|
|
||||||
}
|
|
||||||
if data[0] != Version {
|
|
||||||
return E.New("unknown version ", data[0])
|
|
||||||
}
|
|
||||||
switch data[1] {
|
|
||||||
case CommandPacket:
|
|
||||||
message := allocMessage()
|
|
||||||
err := decodeUDPMessage(message, data[2:])
|
|
||||||
if err != nil {
|
|
||||||
message.release()
|
|
||||||
return E.Cause(err, "decode UDP message")
|
|
||||||
}
|
|
||||||
s.handleUDPMessage(message, false)
|
|
||||||
return nil
|
|
||||||
case CommandHeartbeat:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return E.New("unknown command ", data[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) {
|
|
||||||
s.udpAccess.RLock()
|
|
||||||
udpConn, loaded := s.udpConnMap[message.sessionID]
|
|
||||||
s.udpAccess.RUnlock()
|
|
||||||
if !loaded || common.Done(udpConn.ctx) {
|
|
||||||
udpConn = newUDPPacketConn(s.ctx, s.quicConn, udpStream, true, func() {
|
|
||||||
s.udpAccess.Lock()
|
|
||||||
delete(s.udpConnMap, message.sessionID)
|
|
||||||
s.udpAccess.Unlock()
|
|
||||||
})
|
|
||||||
udpConn.sessionID = message.sessionID
|
|
||||||
s.udpAccess.Lock()
|
|
||||||
s.udpConnMap[message.sessionID] = udpConn
|
|
||||||
s.udpAccess.Unlock()
|
|
||||||
go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{
|
|
||||||
Source: s.source,
|
|
||||||
Destination: message.destination,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
udpConn.inputPacket(message)
|
|
||||||
}
|
|
|
@ -9,11 +9,11 @@ import (
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
"github.com/sagernet/quic-go"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/qtls"
|
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/transport/hysteria"
|
"github.com/sagernet/sing-box/transport/hysteria"
|
||||||
|
"github.com/sagernet/sing-quic"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
|
|
@ -9,11 +9,11 @@ import (
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
"github.com/sagernet/quic-go"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/qtls"
|
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/transport/hysteria"
|
"github.com/sagernet/sing-box/transport/hysteria"
|
||||||
|
"github.com/sagernet/sing-quic"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
@ -27,7 +27,7 @@ type Server struct {
|
||||||
quicConfig *quic.Config
|
quicConfig *quic.Config
|
||||||
handler adapter.V2RayServerTransportHandler
|
handler adapter.V2RayServerTransportHandler
|
||||||
udpListener net.PacketConn
|
udpListener net.PacketConn
|
||||||
quicListener qtls.QUICListener
|
quicListener qtls.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(ctx context.Context, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) {
|
func NewServer(ctx context.Context, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) {
|
||||||
|
|
Loading…
Reference in a new issue