Add hysteria2 protocol

This commit is contained in:
世界 2023-08-31 20:07:32 +08:00
parent ee136f2f89
commit 3a4eeec8f2
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
23 changed files with 2389 additions and 0 deletions

View File

@ -22,6 +22,7 @@ const (
TypeShadowsocksR = "shadowsocksr"
TypeVLESS = "vless"
TypeTUIC = "tuic"
TypeHysteria2 = "hysteria2"
)
const (
@ -65,6 +66,8 @@ func ProxyDisplayName(proxyType string) string {
return "VLESS"
case TypeTUIC:
return "TUIC"
case TypeHysteria2:
return "Hysteria2"
case TypeSelector:
return "Selector"
case TypeURLTest:

View File

@ -46,6 +46,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o
return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions)
case C.TypeTUIC:
return NewTUIC(ctx, router, logger, options.Tag, options.TUICOptions)
case C.TypeHysteria2:
return NewHysteria2(ctx, router, logger, options.Tag, options.Hysteria2Options)
default:
return nil, E.New("unknown inbound type: ", options.Type)
}

144
inbound/hysteria2.go Normal file
View File

@ -0,0 +1,144 @@
//go:build with_quic
package inbound
import (
"context"
"net"
"net/http"
"net/http/httputil"
"net/url"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/hysteria2"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
)
var _ adapter.Inbound = (*Hysteria2)(nil)
type Hysteria2 struct {
myInboundAdapter
tlsConfig tls.ServerConfig
server *hysteria2.Server
}
func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (*Hysteria2, error) {
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired
}
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
var salamanderPassword string
if options.Obfs != nil {
if options.Obfs.Password == "" {
return nil, E.New("missing obfs password")
}
switch options.Obfs.Type {
case hysteria2.ObfsTypeSalamander:
salamanderPassword = options.Obfs.Password
default:
return nil, E.New("unknown obfs type: ", options.Obfs.Type)
}
}
var masqueradeHandler http.Handler
if options.Masquerade != "" {
masqueradeURL, err := url.Parse(options.Masquerade)
if err != nil {
return nil, E.Cause(err, "parse masquerade URL")
}
switch masqueradeURL.Scheme {
case "file":
masqueradeHandler = http.FileServer(http.Dir(masqueradeURL.Path))
case "http", "https":
masqueradeHandler = &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(masqueradeURL)
r.Out.Host = r.In.Host
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
w.WriteHeader(http.StatusBadGateway)
},
}
default:
return nil, E.New("unknown masquerade URL scheme: ", masqueradeURL.Scheme)
}
}
inbound := &Hysteria2{
myInboundAdapter: myInboundAdapter{
protocol: C.TypeHysteria2,
network: []string{N.NetworkUDP},
ctx: ctx,
router: router,
logger: logger,
tag: tag,
listenOptions: options.ListenOptions,
},
tlsConfig: tlsConfig,
}
server, err := hysteria2.NewServer(hysteria2.ServerOptions{
Context: ctx,
Logger: logger,
SendBPS: uint64(options.UpMbps * 1024 * 1024),
ReceiveBPS: uint64(options.DownMbps * 1024 * 1024),
SalamanderPassword: salamanderPassword,
TLSConfig: tlsConfig,
Users: common.Map(options.Users, func(it option.Hysteria2User) hysteria2.User {
return hysteria2.User(it)
}),
IgnoreClientBandwidth: options.IgnoreClientBandwidth,
Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil),
MasqueradeHandler: masqueradeHandler,
})
if err != nil {
return nil, err
}
inbound.server = server
return inbound, nil
}
func (h *Hysteria2) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
ctx = log.ContextWithNewID(ctx)
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
metadata = h.createMetadata(conn, metadata)
metadata.User, _ = auth.UserFromContext[string](ctx)
return h.router.RouteConnection(ctx, conn, metadata)
}
func (h *Hysteria2) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
ctx = log.ContextWithNewID(ctx)
metadata = h.createPacketMetadata(conn, metadata)
metadata.User, _ = auth.UserFromContext[string](ctx)
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
return h.router.RoutePacketConnection(ctx, conn, metadata)
}
func (h *Hysteria2) Start() error {
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {
return err
}
}
packetConn, err := h.myInboundAdapter.ListenUDP()
if err != nil {
return err
}
return h.server.Start(packetConn)
}
func (h *Hysteria2) Close() error {
return common.Close(
&h.myInboundAdapter,
h.tlsConfig,
common.PtrOrNil(h.server),
)
}

View File

@ -14,3 +14,7 @@ import (
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (adapter.Inbound, error) {
return nil, C.ErrQUICNotIncluded
}
func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (adapter.Inbound, error) {
return nil, C.ErrQUICNotIncluded
}

33
option/hysteria2.go Normal file
View File

@ -0,0 +1,33 @@
package option
type Hysteria2InboundOptions struct {
ListenOptions
UpMbps int `json:"up_mbps,omitempty"`
DownMbps int `json:"down_mbps,omitempty"`
Obfs *Hysteria2Obfs `json:"obfs,omitempty"`
Users []Hysteria2User `json:"users,omitempty"`
IgnoreClientBandwidth bool `json:"ignore_client_bandwidth,omitempty"`
TLS *InboundTLSOptions `json:"tls,omitempty"`
Masquerade string `json:"masquerade,omitempty"`
}
type Hysteria2Obfs struct {
Type string `json:"type,omitempty"`
Password string `json:"password,omitempty"`
}
type Hysteria2User struct {
Name string `json:"name,omitempty"`
Password string `json:"password,omitempty"`
}
type Hysteria2OutboundOptions struct {
DialerOptions
ServerOptions
UpMbps int `json:"up_mbps,omitempty"`
DownMbps int `json:"down_mbps,omitempty"`
Obfs *Hysteria2Obfs `json:"obfs,omitempty"`
Password string `json:"password,omitempty"`
Network NetworkList `json:"network,omitempty"`
TLS *OutboundTLSOptions `json:"tls,omitempty"`
}

View File

@ -24,6 +24,7 @@ type _Inbound struct {
ShadowTLSOptions ShadowTLSInboundOptions `json:"-"`
VLESSOptions VLESSInboundOptions `json:"-"`
TUICOptions TUICInboundOptions `json:"-"`
Hysteria2Options Hysteria2InboundOptions `json:"-"`
}
type Inbound _Inbound
@ -61,6 +62,8 @@ func (h Inbound) MarshalJSON() ([]byte, error) {
v = h.VLESSOptions
case C.TypeTUIC:
v = h.TUICOptions
case C.TypeHysteria2:
v = h.Hysteria2Options
default:
return nil, E.New("unknown inbound type: ", h.Type)
}
@ -104,6 +107,8 @@ func (h *Inbound) UnmarshalJSON(bytes []byte) error {
v = &h.VLESSOptions
case C.TypeTUIC:
v = &h.TUICOptions
case C.TypeHysteria2:
v = &h.Hysteria2Options
default:
return E.New("unknown inbound type: ", h.Type)
}

View File

@ -24,6 +24,7 @@ type _Outbound struct {
ShadowsocksROptions ShadowsocksROutboundOptions `json:"-"`
VLESSOptions VLESSOutboundOptions `json:"-"`
TUICOptions TUICOutboundOptions `json:"-"`
Hysteria2Options Hysteria2OutboundOptions `json:"-"`
SelectorOptions SelectorOutboundOptions `json:"-"`
URLTestOptions URLTestOutboundOptions `json:"-"`
}
@ -63,6 +64,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) {
v = h.VLESSOptions
case C.TypeTUIC:
v = h.TUICOptions
case C.TypeHysteria2:
v = h.Hysteria2Options
case C.TypeSelector:
v = h.SelectorOptions
case C.TypeURLTest:
@ -110,6 +113,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error {
v = &h.VLESSOptions
case C.TypeTUIC:
v = &h.TUICOptions
case C.TypeHysteria2:
v = &h.Hysteria2Options
case C.TypeSelector:
v = &h.SelectorOptions
case C.TypeURLTest:

View File

@ -53,6 +53,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, t
return NewVLESS(ctx, router, logger, tag, options.VLESSOptions)
case C.TypeTUIC:
return NewTUIC(ctx, router, logger, tag, options.TUICOptions)
case C.TypeHysteria2:
return NewHysteria2(ctx, router, logger, tag, options.Hysteria2Options)
case C.TypeSelector:
return NewSelector(router, logger, tag, options.SelectorOptions)
case C.TypeURLTest:

122
outbound/hysteria2.go Normal file
View File

@ -0,0 +1,122 @@
//go:build with_quic
package outbound
import (
"context"
"net"
"os"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/hysteria2"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var (
_ adapter.Outbound = (*TUIC)(nil)
_ adapter.InterfaceUpdateListener = (*TUIC)(nil)
)
type Hysteria2 struct {
myOutboundAdapter
client *hysteria2.Client
}
func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2OutboundOptions) (*Hysteria2, error) {
options.UDPFragmentDefault = true
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired
}
tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
var salamanderPassword string
if options.Obfs != nil {
if options.Obfs.Password == "" {
return nil, E.New("missing obfs password")
}
switch options.Obfs.Type {
case hysteria2.ObfsTypeSalamander:
salamanderPassword = options.Obfs.Password
default:
return nil, E.New("unknown obfs type: ", options.Obfs.Type)
}
}
outboundDialer, err := dialer.New(router, options.DialerOptions)
if err != nil {
return nil, err
}
networkList := options.Network.Build()
client, err := hysteria2.NewClient(hysteria2.ClientOptions{
Context: ctx,
Dialer: outboundDialer,
ServerAddress: options.ServerOptions.Build(),
SendBPS: uint64(options.UpMbps * 1024 * 1024),
ReceiveBPS: uint64(options.DownMbps * 1024 * 1024),
SalamanderPassword: salamanderPassword,
Password: options.Password,
TLSConfig: tlsConfig,
UDPDisabled: !common.Contains(networkList, N.NetworkUDP),
})
if err != nil {
return nil, err
}
return &Hysteria2{
myOutboundAdapter: myOutboundAdapter{
protocol: C.TypeHysteria2,
network: networkList,
router: router,
logger: logger,
tag: tag,
dependencies: withDialerDependency(options.DialerOptions),
},
client: client,
}, nil
}
func (h *Hysteria2) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
return h.client.DialConn(ctx, destination)
case N.NetworkUDP:
conn, err := h.ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return bufio.NewBindPacketConn(conn, destination), nil
default:
return nil, E.New("unsupported network: ", network)
}
}
func (h *Hysteria2) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
return h.client.ListenPacket(ctx)
}
func (h *Hysteria2) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return NewConnection(ctx, h, conn, metadata)
}
func (h *Hysteria2) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return NewPacketConnection(ctx, h, conn, metadata)
}
func (h *Hysteria2) InterfaceUpdated() error {
return h.client.CloseWithError(E.New("network changed"))
}
func (h *Hysteria2) Close() error {
return h.client.CloseWithError(os.ErrClosed)
}

View File

@ -14,3 +14,7 @@ import (
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (adapter.Outbound, error) {
return nil, C.ErrQUICNotIncluded
}
func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2OutboundOptions) (adapter.Outbound, error) {
return nil, C.ErrQUICNotIncluded
}

View File

@ -168,3 +168,81 @@ func TestECHQUIC(t *testing.T) {
})
testSuitLargeUDP(t, clientPort, testPort)
}
func TestECHHysteria2(t *testing.T) {
_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
echConfig, echKey := common.Must2(tls.ECHKeygenDefault("not.example.org", false))
startInstance(t, option.Options{
Inbounds: []option.Inbound{
{
Type: C.TypeMixed,
Tag: "mixed-in",
MixedOptions: option.HTTPMixedInboundOptions{
ListenOptions: option.ListenOptions{
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
ListenPort: clientPort,
},
},
},
{
Type: C.TypeHysteria2,
Hysteria2Options: option.Hysteria2InboundOptions{
ListenOptions: option.ListenOptions{
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
ListenPort: serverPort,
},
Users: []option.Hysteria2User{{
Password: "password",
}},
TLS: &option.InboundTLSOptions{
Enabled: true,
ServerName: "example.org",
CertificatePath: certPem,
KeyPath: keyPem,
ECH: &option.InboundECHOptions{
Enabled: true,
Key: []string{echKey},
},
},
},
},
},
Outbounds: []option.Outbound{
{
Type: C.TypeDirect,
},
{
Type: C.TypeHysteria2,
Tag: "hy2-out",
Hysteria2Options: option.Hysteria2OutboundOptions{
ServerOptions: option.ServerOptions{
Server: "127.0.0.1",
ServerPort: serverPort,
},
Password: "password",
TLS: &option.OutboundTLSOptions{
Enabled: true,
ServerName: "example.org",
CertificatePath: certPem,
ECH: &option.OutboundECHOptions{
Enabled: true,
Config: []string{echConfig},
},
},
},
},
},
Route: &option.RouteOptions{
Rules: []option.Rule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultRule{
Inbound: []string{"mixed-in"},
Outbound: "hy2-out",
},
},
},
},
})
testSuit(t, clientPort, testPort)
}

97
test/hysteria2_test.go Normal file
View File

@ -0,0 +1,97 @@
package main
import (
"net/netip"
"testing"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/hysteria2"
)
func TestHysteria2Self(t *testing.T) {
t.Run("self", func(t *testing.T) {
testHysteria2Self(t, "")
})
t.Run("self-salamander", func(t *testing.T) {
testHysteria2Self(t, "password")
})
}
func testHysteria2Self(t *testing.T, salamanderPassword string) {
_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
var obfs *option.Hysteria2Obfs
if salamanderPassword != "" {
obfs = &option.Hysteria2Obfs{
Type: hysteria2.ObfsTypeSalamander,
Password: salamanderPassword,
}
}
startInstance(t, option.Options{
Inbounds: []option.Inbound{
{
Type: C.TypeMixed,
Tag: "mixed-in",
MixedOptions: option.HTTPMixedInboundOptions{
ListenOptions: option.ListenOptions{
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
ListenPort: clientPort,
},
},
},
{
Type: C.TypeHysteria2,
Hysteria2Options: option.Hysteria2InboundOptions{
ListenOptions: option.ListenOptions{
Listen: option.NewListenAddress(netip.IPv4Unspecified()),
ListenPort: serverPort,
},
Obfs: obfs,
Users: []option.Hysteria2User{{
Password: "password",
}},
TLS: &option.InboundTLSOptions{
Enabled: true,
ServerName: "example.org",
CertificatePath: certPem,
KeyPath: keyPem,
},
},
},
},
Outbounds: []option.Outbound{
{
Type: C.TypeDirect,
},
{
Type: C.TypeHysteria2,
Tag: "hy2-out",
Hysteria2Options: option.Hysteria2OutboundOptions{
ServerOptions: option.ServerOptions{
Server: "127.0.0.1",
ServerPort: serverPort,
},
Obfs: obfs,
Password: "password",
TLS: &option.OutboundTLSOptions{
Enabled: true,
ServerName: "example.org",
CertificatePath: certPem,
},
},
},
},
Route: &option.RouteOptions{
Rules: []option.Rule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultRule{
Inbound: []string{"mixed-in"},
Outbound: "hy2-out",
},
},
},
},
})
testSuit(t, clientPort, testPort)
}

View File

@ -0,0 +1,306 @@
package hysteria2
import (
"context"
"io"
"net"
"net/http"
"net/url"
"os"
"runtime"
"sync"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/common/qtls"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/transport/hysteria2/congestion"
"github.com/sagernet/sing-box/transport/hysteria2/internal/protocol"
tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const (
defaultStreamReceiveWindow = 8388608 // 8MB
defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB
defaultMaxIdleTimeout = 30 * time.Second
defaultKeepAlivePeriod = 10 * time.Second
)
type ClientOptions struct {
Context context.Context
Dialer N.Dialer
ServerAddress M.Socksaddr
SendBPS uint64
ReceiveBPS uint64
SalamanderPassword string
Password string
TLSConfig tls.Config
UDPDisabled bool
}
type Client struct {
ctx context.Context
dialer N.Dialer
serverAddr M.Socksaddr
sendBPS uint64
receiveBPS uint64
salamanderPassword string
password string
tlsConfig tls.Config
quicConfig *quic.Config
udpDisabled bool
connAccess sync.RWMutex
conn *clientQUICConnection
}
func NewClient(options ClientOptions) (*Client, error) {
quicConfig := &quic.Config{
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
MaxDatagramFrameSize: 1400,
EnableDatagrams: true,
InitialStreamReceiveWindow: defaultStreamReceiveWindow,
MaxStreamReceiveWindow: defaultStreamReceiveWindow,
InitialConnectionReceiveWindow: defaultConnReceiveWindow,
MaxConnectionReceiveWindow: defaultConnReceiveWindow,
MaxIdleTimeout: defaultMaxIdleTimeout,
KeepAlivePeriod: defaultKeepAlivePeriod,
}
return &Client{
ctx: options.Context,
dialer: options.Dialer,
serverAddr: options.ServerAddress,
sendBPS: options.SendBPS,
receiveBPS: options.ReceiveBPS,
salamanderPassword: options.SalamanderPassword,
password: options.Password,
tlsConfig: options.TLSConfig,
quicConfig: quicConfig,
udpDisabled: options.UDPDisabled,
}, nil
}
func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
conn := c.conn
if conn != nil && conn.active() {
return conn, nil
}
c.connAccess.Lock()
defer c.connAccess.Unlock()
conn = c.conn
if conn != nil && conn.active() {
return conn, nil
}
conn, err := c.offerNew(ctx)
if err != nil {
return nil, err
}
return conn, nil
}
func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
udpConn, err := c.dialer.DialContext(ctx, "udp", c.serverAddr)
if err != nil {
return nil, err
}
var packetConn net.PacketConn
packetConn = bufio.NewUnbindPacketConn(udpConn)
if c.salamanderPassword != "" {
packetConn = NewSalamanderConn(packetConn, []byte(c.salamanderPassword))
}
var quicConn quic.EarlyConnection
http3Transport, err := qtls.CreateTransport(packetConn, &quicConn, c.serverAddr, c.tlsConfig, c.quicConfig, true)
if err != nil {
udpConn.Close()
return nil, err
}
request := &http.Request{
Method: http.MethodPost,
URL: &url.URL{
Scheme: "https",
Host: protocol.URLHost,
Path: protocol.URLPath,
},
Header: make(http.Header),
}
protocol.AuthRequestToHeader(request.Header, protocol.AuthRequest{Auth: c.password, Rx: c.receiveBPS})
response, err := http3Transport.RoundTrip(request)
if err != nil {
if quicConn != nil {
quicConn.CloseWithError(0, "")
}
udpConn.Close()
return nil, err
}
if response.StatusCode != protocol.StatusAuthOK {
if quicConn != nil {
quicConn.CloseWithError(0, "")
}
udpConn.Close()
return nil, E.New("authentication failed, status code: ", response.StatusCode)
}
response.Body.Close()
authResponse := protocol.AuthResponseFromHeader(response.Header)
actualTx := authResponse.Rx
if actualTx == 0 || actualTx > c.sendBPS {
actualTx = c.sendBPS
}
if !authResponse.RxAuto && actualTx > 0 {
quicConn.SetCongestionControl(congestion.NewBrutalSender(actualTx))
} else {
quicConn.SetCongestionControl(tuicCongestion.NewBBRSender(
tuicCongestion.DefaultClock{},
tuicCongestion.GetInitialPacketSize(quicConn.RemoteAddr()),
tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
))
}
conn := &clientQUICConnection{
quicConn: quicConn,
rawConn: udpConn,
connDone: make(chan struct{}),
udpDisabled: c.udpDisabled || !authResponse.UDPEnabled,
udpConnMap: make(map[uint32]*udpPacketConn),
}
if !c.udpDisabled {
go c.loopMessages(conn)
}
c.conn = conn
return conn, nil
}
func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
conn, err := c.offer(ctx)
if err != nil {
return nil, err
}
stream, err := conn.quicConn.OpenStream()
if err != nil {
return nil, err
}
return &clientConn{
Stream: stream,
destination: destination,
}, nil
}
func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
if c.udpDisabled {
return nil, os.ErrInvalid
}
conn, err := c.offer(ctx)
if err != nil {
return nil, err
}
if conn.udpDisabled {
return nil, E.New("UDP disabled by server")
}
var sessionID uint32
clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, func() {
conn.udpAccess.Lock()
delete(conn.udpConnMap, sessionID)
conn.udpAccess.Unlock()
})
conn.udpAccess.Lock()
sessionID = conn.udpSessionID
conn.udpSessionID++
conn.udpConnMap[sessionID] = clientPacketConn
conn.udpAccess.Unlock()
clientPacketConn.sessionID = sessionID
return clientPacketConn, nil
}
func (c *Client) CloseWithError(err error) error {
conn := c.conn
if conn != nil {
conn.closeWithError(err)
}
return nil
}
type clientQUICConnection struct {
quicConn quic.Connection
rawConn io.Closer
closeOnce sync.Once
connDone chan struct{}
connErr error
udpDisabled bool
udpAccess sync.RWMutex
udpConnMap map[uint32]*udpPacketConn
udpSessionID uint32
}
func (c *clientQUICConnection) active() bool {
select {
case <-c.quicConn.Context().Done():
return false
default:
}
select {
case <-c.connDone:
return false
default:
}
return true
}
func (c *clientQUICConnection) closeWithError(err error) {
c.closeOnce.Do(func() {
c.connErr = err
close(c.connDone)
c.quicConn.CloseWithError(0, "")
})
}
type clientConn struct {
quic.Stream
destination M.Socksaddr
requestWritten bool
responseRead bool
}
func (c *clientConn) NeedHandshake() bool {
return !c.requestWritten
}
func (c *clientConn) Read(p []byte) (n int, err error) {
if c.responseRead {
return c.Stream.Read(p)
}
status, errorMessage, err := protocol.ReadTCPResponse(c.Stream)
if err != nil {
return
}
if !status {
err = E.New("remote error: ", errorMessage)
return
}
c.responseRead = true
return c.Stream.Read(p)
}
func (c *clientConn) Write(p []byte) (n int, err error) {
if !c.requestWritten {
buffer := protocol.WriteTCPRequest(c.destination.String(), p)
defer buffer.Release()
_, err = c.Stream.Write(buffer.Bytes())
if err != nil {
return
}
c.requestWritten = true
return len(p), nil
}
return c.Stream.Write(p)
}
func (c *clientConn) LocalAddr() net.Addr {
return M.Socksaddr{}
}
func (c *clientConn) RemoteAddr() net.Addr {
return M.Socksaddr{}
}

View File

@ -0,0 +1,47 @@
package hysteria2
import E "github.com/sagernet/sing/common/exceptions"
func (c *Client) loopMessages(conn *clientQUICConnection) {
for {
message, err := conn.quicConn.ReceiveMessage(c.ctx)
if err != nil {
conn.closeWithError(E.Cause(err, "receive message"))
return
}
go func() {
hErr := c.handleMessage(conn, message)
if hErr != nil {
conn.closeWithError(E.Cause(hErr, "handle message"))
}
}()
}
}
func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
message := udpMessagePool.Get().(*udpMessage)
err := decodeUDPMessage(message, data)
if err != nil {
message.release()
return E.Cause(err, "decode UDP message")
}
conn.handleUDPMessage(message)
return nil
}
func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) {
c.udpAccess.RLock()
udpConn, loaded := c.udpConnMap[message.sessionID]
c.udpAccess.RUnlock()
if !loaded {
message.releaseMessage()
return
}
select {
case <-udpConn.ctx.Done():
message.releaseMessage()
return
default:
}
udpConn.inputPacket(message)
}

View File

@ -0,0 +1,151 @@
package congestion
import (
"time"
"github.com/sagernet/quic-go/congestion"
)
const (
initMaxDatagramSize = 1252
pktInfoSlotCount = 4
minSampleCount = 50
minAckRate = 0.8
)
var _ congestion.CongestionControl = &BrutalSender{}
type BrutalSender struct {
rttStats congestion.RTTStatsProvider
bps congestion.ByteCount
maxDatagramSize congestion.ByteCount
pacer *pacer
pktInfoSlots [pktInfoSlotCount]pktInfo
ackRate float64
}
type pktInfo struct {
Timestamp int64
AckCount uint64
LossCount uint64
}
func NewBrutalSender(bps uint64) *BrutalSender {
bs := &BrutalSender{
bps: congestion.ByteCount(bps),
maxDatagramSize: initMaxDatagramSize,
ackRate: 1,
}
bs.pacer = newPacer(func() congestion.ByteCount {
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
})
return bs
}
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
b.rttStats = rttStats
}
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
return b.pacer.TimeUntilSend()
}
func (b *BrutalSender) HasPacingBudget(now time.Time) bool {
return b.pacer.Budget(now) >= b.maxDatagramSize
}
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
return bytesInFlight < b.GetCongestionWindow()
}
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
rtt := b.rttStats.SmoothedRTT()
if rtt <= 0 {
return 10240
}
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
}
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
) {
b.pacer.SentPacket(sentTime, bytes)
}
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
priorInFlight congestion.ByteCount, eventTime time.Time,
) {
currentTimestamp := eventTime.Unix()
slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].AckCount++
} else {
// uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = 1
b.pktInfoSlots[slot].LossCount = 0
}
b.updateAckRate(currentTimestamp)
}
func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount,
priorInFlight congestion.ByteCount,
) {
currentTimestamp := time.Now().Unix()
slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].LossCount++
} else {
// uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = 0
b.pktInfoSlots[slot].LossCount = 1
}
b.updateAckRate(currentTimestamp)
}
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
b.maxDatagramSize = size
b.pacer.SetMaxDatagramSize(size)
}
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
minTimestamp := currentTimestamp - pktInfoSlotCount
var ackCount, lossCount uint64
for _, info := range b.pktInfoSlots {
if info.Timestamp < minTimestamp {
continue
}
ackCount += info.AckCount
lossCount += info.LossCount
}
if ackCount+lossCount < minSampleCount {
b.ackRate = 1
}
rate := float64(ackCount) / float64(ackCount+lossCount)
if rate < minAckRate {
b.ackRate = minAckRate
}
b.ackRate = rate
}
func (b *BrutalSender) InSlowStart() bool {
return false
}
func (b *BrutalSender) InRecovery() bool {
return false
}
func (b *BrutalSender) MaybeExitSlowStart() {}
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
func maxDuration(a, b time.Duration) time.Duration {
if a > b {
return a
}
return b
}

View File

@ -0,0 +1,86 @@
package congestion
import (
"math"
"time"
"github.com/sagernet/quic-go/congestion"
)
const (
maxBurstPackets = 10
minPacingDelay = time.Millisecond
)
// The pacer implements a token bucket pacing algorithm.
type pacer struct {
budgetAtLastSent congestion.ByteCount
maxDatagramSize congestion.ByteCount
lastSentTime time.Time
getBandwidth func() congestion.ByteCount // in bytes/s
}
func newPacer(getBandwidth func() congestion.ByteCount) *pacer {
p := &pacer{
budgetAtLastSent: maxBurstPackets * initMaxDatagramSize,
maxDatagramSize: initMaxDatagramSize,
getBandwidth: getBandwidth,
}
return p
}
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
budget := p.Budget(sendTime)
if size > budget {
p.budgetAtLastSent = 0
} else {
p.budgetAtLastSent = budget - size
}
p.lastSentTime = sendTime
}
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
if p.lastSentTime.IsZero() {
return p.maxBurstSize()
}
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
return minByteCount(p.maxBurstSize(), budget)
}
func (p *pacer) maxBurstSize() congestion.ByteCount {
return maxByteCount(
congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
maxBurstPackets*p.maxDatagramSize,
)
}
// TimeUntilSend returns when the next packet should be sent.
// It returns the zero value of time.Time if a packet can be sent immediately.
func (p *pacer) TimeUntilSend() time.Time {
if p.budgetAtLastSent >= p.maxDatagramSize {
return time.Time{}
}
return p.lastSentTime.Add(maxDuration(
minPacingDelay,
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
float64(p.getBandwidth())))*time.Nanosecond,
))
}
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
p.maxDatagramSize = s
}
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
if a < b {
return b
}
return a
}
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
if a < b {
return a
}
return b
}

View File

@ -0,0 +1,68 @@
package protocol
import (
"net/http"
"strconv"
)
const (
URLHost = "hysteria"
URLPath = "/auth"
RequestHeaderAuth = "Hysteria-Auth"
ResponseHeaderUDPEnabled = "Hysteria-UDP"
CommonHeaderCCRX = "Hysteria-CC-RX"
CommonHeaderPadding = "Hysteria-Padding"
StatusAuthOK = 233
)
// AuthRequest is what client sends to server for authentication.
type AuthRequest struct {
Auth string
Rx uint64 // 0 = unknown, client asks server to use bandwidth detection
}
// AuthResponse is what server sends to client when authentication is passed.
type AuthResponse struct {
UDPEnabled bool
Rx uint64 // 0 = unlimited
RxAuto bool // true = server asks client to use bandwidth detection
}
func AuthRequestFromHeader(h http.Header) AuthRequest {
rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
return AuthRequest{
Auth: h.Get(RequestHeaderAuth),
Rx: rx,
}
}
func AuthRequestToHeader(h http.Header, req AuthRequest) {
h.Set(RequestHeaderAuth, req.Auth)
h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10))
h.Set(CommonHeaderPadding, authRequestPadding.String())
}
func AuthResponseFromHeader(h http.Header) AuthResponse {
resp := AuthResponse{}
resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled))
rxStr := h.Get(CommonHeaderCCRX)
if rxStr == "auto" {
// Special case for server requesting client to use bandwidth detection
resp.RxAuto = true
} else {
resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64)
}
return resp
}
func AuthResponseToHeader(h http.Header, resp AuthResponse) {
h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled))
if resp.RxAuto {
h.Set(CommonHeaderCCRX, "auto")
} else {
h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10))
}
h.Set(CommonHeaderPadding, authResponsePadding.String())
}

View File

@ -0,0 +1,31 @@
package protocol
import (
"math/rand"
)
const (
paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
// padding specifies a half-open range [Min, Max).
type padding struct {
Min int
Max int
}
func (p padding) String() string {
n := p.Min + rand.Intn(p.Max-p.Min)
bs := make([]byte, n)
for i := range bs {
bs[i] = paddingChars[rand.Intn(len(paddingChars))]
}
return string(bs)
}
var (
authRequestPadding = padding{Min: 256, Max: 2048}
authResponsePadding = padding{Min: 256, Max: 2048}
tcpRequestPadding = padding{Min: 64, Max: 512}
tcpResponsePadding = padding{Min: 128, Max: 1024}
)

View File

@ -0,0 +1,266 @@
package protocol
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"github.com/sagernet/quic-go/quicvarint"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
)
const (
FrameTypeTCPRequest = 0x401
// Max length values are for preventing DoS attacks
MaxAddressLength = 2048
MaxMessageLength = 2048
MaxPaddingLength = 4096
MaxUDPSize = 4096
maxVarInt1 = 63
maxVarInt2 = 16383
maxVarInt4 = 1073741823
maxVarInt8 = 4611686018427387903
)
// TCPRequest format:
// 0x401 (QUIC varint)
// Address length (QUIC varint)
// Address (bytes)
// Padding length (QUIC varint)
// Padding (bytes)
func ReadTCPRequest(r io.Reader) (string, error) {
bReader := quicvarint.NewReader(r)
addrLen, err := quicvarint.Read(bReader)
if err != nil {
return "", err
}
if addrLen == 0 || addrLen > MaxAddressLength {
return "", E.New("invalid address length")
}
addrBuf := make([]byte, addrLen)
_, err = io.ReadFull(r, addrBuf)
if err != nil {
return "", err
}
paddingLen, err := quicvarint.Read(bReader)
if err != nil {
return "", err
}
if paddingLen > MaxPaddingLength {
return "", E.New("invalid padding length")
}
if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return "", err
}
}
return string(addrBuf), nil
}
func WriteTCPRequest(addr string, payload []byte) *buf.Buffer {
padding := tcpRequestPadding.String()
paddingLen := len(padding)
addrLen := len(addr)
sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
int(quicvarint.Len(uint64(addrLen))) + addrLen +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buffer := buf.NewSize(sz + len(payload))
bufferContent := buffer.Extend(sz)
i := varintPut(bufferContent, FrameTypeTCPRequest)
i += varintPut(bufferContent[i:], uint64(addrLen))
i += copy(bufferContent[i:], addr)
i += varintPut(bufferContent[i:], uint64(paddingLen))
copy(bufferContent[i:], padding)
buffer.Write(payload)
return buffer
}
// TCPResponse format:
// Status (byte, 0=ok, 1=error)
// Message length (QUIC varint)
// Message (bytes)
// Padding length (QUIC varint)
// Padding (bytes)
func ReadTCPResponse(r io.Reader) (bool, string, error) {
var status [1]byte
if _, err := io.ReadFull(r, status[:]); err != nil {
return false, "", err
}
bReader := quicvarint.NewReader(r)
msg, err := ReadVString(bReader)
if err != nil {
return false, "", err
}
paddingLen, err := quicvarint.Read(bReader)
if err != nil {
return false, "", err
}
if paddingLen > MaxPaddingLength {
return false, "", E.New("invalid padding length")
}
if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return false, "", err
}
}
return status[0] == 0, msg, nil
}
func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer {
padding := tcpResponsePadding.String()
paddingLen := len(padding)
msgLen := len(msg)
sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buffer := buf.NewSize(sz + len(payload))
if ok {
buffer.WriteByte(0)
} else {
buffer.WriteByte(1)
}
WriteVString(buffer, msg)
WriteUVariant(buffer, uint64(paddingLen))
buffer.Extend(paddingLen)
buffer.Write(payload)
return buffer
}
// UDPMessage format:
// Session ID (uint32 BE)
// Packet ID (uint16 BE)
// Fragment ID (uint8)
// Fragment count (uint8)
// Address length (QUIC varint)
// Address (bytes)
// Data...
type UDPMessage struct {
SessionID uint32 // 4
PacketID uint16 // 2
FragID uint8 // 1
FragCount uint8 // 1
Addr string // varint + bytes
Data []byte
}
func (m *UDPMessage) HeaderSize() int {
lAddr := len(m.Addr)
return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr
}
func (m *UDPMessage) Size() int {
return m.HeaderSize() + len(m.Data)
}
func (m *UDPMessage) Serialize(buf []byte) int {
// Make sure the buffer is big enough
if len(buf) < m.Size() {
return -1
}
binary.BigEndian.PutUint32(buf, m.SessionID)
binary.BigEndian.PutUint16(buf[4:], m.PacketID)
buf[6] = m.FragID
buf[7] = m.FragCount
i := varintPut(buf[8:], uint64(len(m.Addr)))
i += copy(buf[8+i:], m.Addr)
i += copy(buf[8+i:], m.Data)
return 8 + i
}
func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
m := &UDPMessage{}
buf := bytes.NewBuffer(msg)
if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil {
return nil, err
}
if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil {
return nil, err
}
if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil {
return nil, err
}
if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil {
return nil, err
}
lAddr, err := quicvarint.Read(buf)
if err != nil {
return nil, err
}
if lAddr == 0 || lAddr > MaxMessageLength {
return nil, E.New("invalid address length")
}
bs := buf.Bytes()
m.Addr = string(bs[:lAddr])
m.Data = bs[lAddr:]
return m, nil
}
func ReadVString(reader io.Reader) (string, error) {
length, err := quicvarint.Read(quicvarint.NewReader(reader))
if err != nil {
return "", err
}
value, err := rw.ReadBytes(reader, int(length))
if err != nil {
return "", err
}
return string(value), nil
}
func WriteVString(writer io.Writer, value string) error {
err := WriteUVariant(writer, uint64(len(value)))
if err != nil {
return err
}
return rw.WriteString(writer, value)
}
func WriteUVariant(writer io.Writer, value uint64) error {
var b [8]byte
return common.Error(writer.Write(b[:varintPut(b[:], value)]))
}
// varintPut is like quicvarint.Append, but instead of appending to a slice,
// it writes to a fixed-size buffer. Returns the number of bytes written.
func varintPut(b []byte, i uint64) int {
if i <= maxVarInt1 {
b[0] = uint8(i)
return 1
}
if i <= maxVarInt2 {
b[0] = uint8(i>>8) | 0x40
b[1] = uint8(i)
return 2
}
if i <= maxVarInt4 {
b[0] = uint8(i>>24) | 0x80
b[1] = uint8(i >> 16)
b[2] = uint8(i >> 8)
b[3] = uint8(i)
return 4
}
if i <= maxVarInt8 {
b[0] = uint8(i>>56) | 0xc0
b[1] = uint8(i >> 48)
b[2] = uint8(i >> 40)
b[3] = uint8(i >> 32)
b[4] = uint8(i >> 24)
b[5] = uint8(i >> 16)
b[6] = uint8(i >> 8)
b[7] = uint8(i)
return 8
}
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
}

View File

@ -0,0 +1,438 @@
package hysteria2
import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
"math"
"net"
"os"
"sync"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/quicvarint"
"github.com/sagernet/sing-box/transport/hysteria2/internal/protocol"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
M "github.com/sagernet/sing/common/metadata"
)
var udpMessagePool = sync.Pool{
New: func() interface{} {
return new(udpMessage)
},
}
func releaseMessages(messages []*udpMessage) {
for _, message := range messages {
if message != nil {
*message = udpMessage{}
udpMessagePool.Put(message)
}
}
}
type udpMessage struct {
sessionID uint32
packetID uint16
fragmentID uint8
fragmentTotal uint8
destination string
data *buf.Buffer
}
func (m *udpMessage) release() {
*m = udpMessage{}
udpMessagePool.Put(m)
}
func (m *udpMessage) releaseMessage() {
m.data.Release()
m.release()
}
func (m *udpMessage) pack() *buf.Buffer {
buffer := buf.NewSize(m.headerSize() + m.data.Len())
common.Must(
binary.Write(buffer, binary.BigEndian, m.sessionID),
binary.Write(buffer, binary.BigEndian, m.packetID),
binary.Write(buffer, binary.BigEndian, m.fragmentID),
binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
protocol.WriteVString(buffer, m.destination),
common.Error(buffer.Write(m.data.Bytes())),
)
return buffer
}
func (m *udpMessage) headerSize() int {
return 8 + int(quicvarint.Len(uint64(len(m.destination)))) + len(m.destination)
}
func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
if message.data.Len() <= maxPacketSize {
return []*udpMessage{message}
}
var fragments []*udpMessage
originPacket := message.data.Bytes()
udpMTU := maxPacketSize - message.headerSize()
for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
fragment := udpMessagePool.Get().(*udpMessage)
*fragment = *message
if remaining > udpMTU {
fragment.data = buf.As(originPacket[:udpMTU])
originPacket = originPacket[udpMTU:]
} else {
fragment.data = buf.As(originPacket)
originPacket = nil
}
fragments = append(fragments, fragment)
}
fragmentTotal := uint16(len(fragments))
for index, fragment := range fragments {
fragment.fragmentID = uint8(index)
fragment.fragmentTotal = uint8(fragmentTotal)
/*if index > 0 {
fragment.destination = ""
// not work in hysteria
}*/
}
return fragments
}
type udpPacketConn struct {
ctx context.Context
cancel common.ContextCancelCauseFunc
sessionID uint32
quicConn quic.Connection
data chan *udpMessage
udpMTU int
udpMTUTime time.Time
packetId atomic.Uint32
closeOnce sync.Once
defragger *udpDefragger
onDestroy func()
}
func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn {
ctx, cancel := common.ContextWithCancelCause(ctx)
return &udpPacketConn{
ctx: ctx,
cancel: cancel,
quicConn: quicConn,
data: make(chan *udpMessage, 64),
defragger: newUDPDefragger(),
onDestroy: onDestroy,
}
}
func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case p := <-c.data:
buffer = p.data
destination = M.ParseSocksaddr(p.destination)
p.release()
return
case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe
}
}
func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
_, err = buffer.ReadOnceFrom(p.data)
destination = M.ParseSocksaddr(p.destination)
p.releaseMessage()
return
case <-c.ctx.Done():
return M.Socksaddr{}, io.ErrClosedPipe
}
}
func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
_, err = newBuffer().ReadOnceFrom(p.data)
destination = M.ParseSocksaddr(p.destination)
p.releaseMessage()
return
case <-c.ctx.Done():
return M.Socksaddr{}, io.ErrClosedPipe
}
}
func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case pkt := <-c.data:
n = copy(p, pkt.data.Bytes())
destination := M.ParseSocksaddr(pkt.destination)
if destination.IsFqdn() {
addr = destination
} else {
addr = destination.UDPAddr()
}
pkt.releaseMessage()
return n, addr, nil
case <-c.ctx.Done():
return 0, nil, io.ErrClosedPipe
}
}
func (c *udpPacketConn) needFragment() bool {
nowTime := time.Now()
if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second {
c.udpMTUTime = nowTime
return true
}
return false
}
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
select {
case <-c.ctx.Done():
return net.ErrClosed
default:
}
if buffer.Len() > 0xffff {
return quic.ErrMessageTooLarge(0xffff)
}
packetId := c.packetId.Add(1)
if packetId > math.MaxUint16 {
c.packetId.Store(0)
packetId = 0
}
message := udpMessagePool.Get().(*udpMessage)
*message = udpMessage{
sessionID: c.sessionID,
packetID: uint16(packetId),
fragmentTotal: 1,
destination: destination.String(),
data: buffer,
}
defer message.releaseMessage()
var err error
if c.needFragment() && buffer.Len() > c.udpMTU {
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
} else {
err = c.writePacket(message)
}
if err == nil {
return nil
}
var tooLargeErr quic.ErrMessageTooLarge
if !errors.As(err, &tooLargeErr) {
return err
}
c.udpMTU = int(tooLargeErr)
c.udpMTUTime = time.Now()
return c.writePackets(fragUDPMessage(message, c.udpMTU))
}
func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
select {
case <-c.ctx.Done():
return 0, net.ErrClosed
default:
}
if len(p) > 0xffff {
return 0, quic.ErrMessageTooLarge(0xffff)
}
packetId := c.packetId.Add(1)
if packetId > math.MaxUint16 {
c.packetId.Store(0)
packetId = 0
}
message := udpMessagePool.Get().(*udpMessage)
*message = udpMessage{
sessionID: c.sessionID,
packetID: uint16(packetId),
fragmentTotal: 1,
destination: addr.String(),
data: buf.As(p),
}
if c.needFragment() && len(p) > c.udpMTU {
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
if err == nil {
return len(p), nil
}
} else {
err = c.writePacket(message)
}
if err == nil {
return len(p), nil
}
var tooLargeErr quic.ErrMessageTooLarge
if !errors.As(err, &tooLargeErr) {
return
}
c.udpMTU = int(tooLargeErr)
c.udpMTUTime = time.Now()
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
if err == nil {
return len(p), nil
}
return
}
func (c *udpPacketConn) inputPacket(message *udpMessage) {
if message.fragmentTotal <= 1 {
select {
case c.data <- message:
default:
}
} else {
newMessage := c.defragger.feed(message)
if newMessage != nil {
select {
case c.data <- newMessage:
default:
}
}
}
}
func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
defer releaseMessages(messages)
for _, message := range messages {
err := c.writePacket(message)
if err != nil {
return err
}
}
return nil
}
func (c *udpPacketConn) writePacket(message *udpMessage) error {
buffer := message.pack()
defer buffer.Release()
return c.quicConn.SendMessage(buffer.Bytes())
}
func (c *udpPacketConn) Close() error {
c.closeOnce.Do(func() {
c.closeWithError(os.ErrClosed)
c.onDestroy()
})
return nil
}
func (c *udpPacketConn) closeWithError(err error) {
c.cancel(err)
}
func (c *udpPacketConn) LocalAddr() net.Addr {
return c.quicConn.LocalAddr()
}
func (c *udpPacketConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
type udpDefragger struct {
packetMap *cache.LruCache[uint16, *packetItem]
}
func newUDPDefragger() *udpDefragger {
return &udpDefragger{
packetMap: cache.New(
cache.WithAge[uint16, *packetItem](10),
cache.WithUpdateAgeOnGet[uint16, *packetItem](),
cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
releaseMessages(value.messages)
}),
),
}
}
type packetItem struct {
access sync.Mutex
messages []*udpMessage
count uint8
}
func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
if m.fragmentTotal <= 1 {
return m
}
if m.fragmentID >= m.fragmentTotal {
return nil
}
item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
item.access.Lock()
defer item.access.Unlock()
if int(m.fragmentTotal) != len(item.messages) {
releaseMessages(item.messages)
item.messages = make([]*udpMessage, m.fragmentTotal)
item.count = 1
item.messages[m.fragmentID] = m
return nil
}
if item.messages[m.fragmentID] != nil {
return nil
}
item.messages[m.fragmentID] = m
item.count++
if int(item.count) != len(item.messages) {
return nil
}
newMessage := udpMessagePool.Get().(*udpMessage)
*newMessage = *item.messages[0]
var finalLength int
for _, message := range item.messages {
finalLength += message.data.Len()
}
if finalLength > 0 {
newMessage.data = buf.NewSize(finalLength)
for _, message := range item.messages {
newMessage.data.Write(message.data.Bytes())
message.releaseMessage()
}
item.messages = nil
return newMessage
}
return nil
}
func newPacketItem() *packetItem {
return new(packetItem)
}
func decodeUDPMessage(message *udpMessage, data []byte) error {
reader := bytes.NewReader(data)
err := binary.Read(reader, binary.BigEndian, &message.sessionID)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.packetID)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
if err != nil {
return err
}
message.destination, err = protocol.ReadVString(reader)
if err != nil {
return err
}
message.data = buf.As(data[len(data)-reader.Len():])
return nil
}

View File

@ -0,0 +1,106 @@
package hysteria2
import (
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.org/x/crypto/blake2b"
)
const salamanderSaltLen = 8
const ObfsTypeSalamander = "salamander"
type Salamander struct {
net.PacketConn
password []byte
}
func NewSalamanderConn(conn net.PacketConn, password []byte) net.PacketConn {
writer, isVectorised := bufio.CreateVectorisedPacketWriter(conn)
if isVectorised {
return &VectorisedSalamander{
Salamander: Salamander{
PacketConn: conn,
password: password,
},
writer: writer,
}
} else {
return &Salamander{
PacketConn: conn,
password: password,
}
}
}
func (s *Salamander) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = s.PacketConn.ReadFrom(p)
if err != nil {
return
}
if n <= salamanderSaltLen {
return 0, nil, E.New("salamander: packet too short")
}
key := blake2b.Sum256(append(s.password, p[:salamanderSaltLen]...))
for index, c := range p[salamanderSaltLen:n] {
p[index] = c ^ key[index%blake2b.Size256]
}
return n - salamanderSaltLen, addr, nil
}
func (s *Salamander) WriteTo(p []byte, addr net.Addr) (n int, err error) {
buffer := buf.NewSize(len(p) + salamanderSaltLen)
defer buffer.Release()
buffer.WriteRandom(salamanderSaltLen)
key := blake2b.Sum256(append(s.password, buffer.Bytes()...))
for index, c := range p {
common.Must(buffer.WriteByte(c ^ key[index%blake2b.Size256]))
}
_, err = s.PacketConn.WriteTo(buffer.Bytes(), addr)
if err != nil {
return
}
return len(p), nil
}
type VectorisedSalamander struct {
Salamander
writer N.VectorisedPacketWriter
}
func (s *VectorisedSalamander) WriteTo(p []byte, addr net.Addr) (n int, err error) {
buffer := buf.NewSize(salamanderSaltLen)
buffer.WriteRandom(salamanderSaltLen)
key := blake2b.Sum256(append(s.password, buffer.Bytes()...))
for i := range p {
p[i] ^= key[i%blake2b.Size256]
}
err = s.writer.WriteVectorisedPacket([]*buf.Buffer{buffer, buf.As(p)}, M.SocksaddrFromNet(addr))
if err != nil {
return
}
return len(p), nil
}
func (s *VectorisedSalamander) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
header := buf.NewSize(salamanderSaltLen)
defer header.Release()
header.WriteRandom(salamanderSaltLen)
key := blake2b.Sum256(append(s.password, header.Bytes()...))
var bufferIndex int
for _, buffer := range buffers {
content := buffer.Bytes()
for index, c := range content {
content[bufferIndex+index] = c ^ key[bufferIndex+index%blake2b.Size256]
}
bufferIndex += len(content)
}
return s.writer.WriteVectorisedPacket(append([]*buf.Buffer{header}, buffers...), destination)
}

View File

@ -0,0 +1,336 @@
package hysteria2
import (
"context"
"io"
"net"
"net/http"
"os"
"runtime"
"strings"
"sync"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3"
"github.com/sagernet/sing-box/common/baderror"
"github.com/sagernet/sing-box/common/qtls"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/transport/hysteria2/congestion"
"github.com/sagernet/sing-box/transport/hysteria2/internal/protocol"
tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type ServerOptions struct {
Context context.Context
Logger logger.Logger
SendBPS uint64
ReceiveBPS uint64
IgnoreClientBandwidth bool
SalamanderPassword string
TLSConfig tls.ServerConfig
Users []User
UDPDisabled bool
Handler ServerHandler
MasqueradeHandler http.Handler
}
type User struct {
Name string
Password string
}
type ServerHandler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
}
type Server struct {
ctx context.Context
logger logger.Logger
sendBPS uint64
receiveBPS uint64
ignoreClientBandwidth bool
salamanderPassword string
tlsConfig tls.ServerConfig
quicConfig *quic.Config
userMap map[string]User
udpDisabled bool
handler ServerHandler
masqueradeHandler http.Handler
quicListener io.Closer
}
func NewServer(options ServerOptions) (*Server, error) {
quicConfig := &quic.Config{
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
MaxDatagramFrameSize: 1400,
EnableDatagrams: !options.UDPDisabled,
MaxIncomingStreams: 1 << 60,
InitialStreamReceiveWindow: defaultStreamReceiveWindow,
MaxStreamReceiveWindow: defaultStreamReceiveWindow,
InitialConnectionReceiveWindow: defaultConnReceiveWindow,
MaxConnectionReceiveWindow: defaultConnReceiveWindow,
MaxIdleTimeout: defaultMaxIdleTimeout,
KeepAlivePeriod: defaultKeepAlivePeriod,
}
if len(options.Users) == 0 {
return nil, E.New("missing users")
}
userMap := make(map[string]User)
for _, user := range options.Users {
userMap[user.Password] = user
}
if options.MasqueradeHandler == nil {
options.MasqueradeHandler = http.NotFoundHandler()
}
return &Server{
ctx: options.Context,
logger: options.Logger,
sendBPS: options.SendBPS,
receiveBPS: options.ReceiveBPS,
ignoreClientBandwidth: options.IgnoreClientBandwidth,
salamanderPassword: options.SalamanderPassword,
tlsConfig: options.TLSConfig,
quicConfig: quicConfig,
userMap: userMap,
udpDisabled: options.UDPDisabled,
handler: options.Handler,
masqueradeHandler: options.MasqueradeHandler,
}, nil
}
func (s *Server) Start(conn net.PacketConn) error {
if s.salamanderPassword != "" {
conn = NewSalamanderConn(conn, []byte(s.salamanderPassword))
}
err := qtls.ConfigureHTTP3(s.tlsConfig)
if err != nil {
return err
}
listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
if err != nil {
return err
}
s.quicListener = listener
go s.loopConnections(listener)
return nil
}
func (s *Server) Close() error {
return common.Close(
s.quicListener,
)
}
func (s *Server) loopConnections(listener qtls.QUICListener) {
for {
connection, err := listener.Accept(s.ctx)
if err != nil {
if strings.Contains(err.Error(), "server closed") {
s.logger.Debug(E.Cause(err, "listener closed"))
} else {
s.logger.Error(E.Cause(err, "listener closed"))
}
return
}
go s.handleConnection(connection)
}
}
func (s *Server) handleConnection(connection quic.Connection) {
session := &serverSession{
Server: s,
ctx: s.ctx,
quicConn: connection,
source: M.SocksaddrFromNet(connection.RemoteAddr()),
connDone: make(chan struct{}),
udpConnMap: make(map[uint32]*udpPacketConn),
}
httpServer := http3.Server{
Handler: session,
StreamHijacker: session.handleStream0,
}
_ = httpServer.ServeQUICConn(connection)
_ = connection.CloseWithError(0, "")
}
type serverSession struct {
*Server
ctx context.Context
quicConn quic.Connection
source M.Socksaddr
connAccess sync.Mutex
connDone chan struct{}
connErr error
authenticated bool
authUser *User
udpAccess sync.RWMutex
udpConnMap map[uint32]*udpPacketConn
}
func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
if s.authenticated {
protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
UDPEnabled: !s.udpDisabled,
Rx: s.receiveBPS,
RxAuto: s.ignoreClientBandwidth,
})
w.WriteHeader(protocol.StatusAuthOK)
return
}
request := protocol.AuthRequestFromHeader(r.Header)
user, loaded := s.userMap[request.Auth]
if !loaded {
s.masqueradeHandler.ServeHTTP(w, r)
return
}
s.authUser = &user
s.authenticated = true
if !s.ignoreClientBandwidth && request.Rx > 0 {
var sendBps uint64
if s.sendBPS > 0 && s.sendBPS < request.Rx {
sendBps = s.sendBPS
} else {
sendBps = request.Rx
}
s.quicConn.SetCongestionControl(congestion.NewBrutalSender(sendBps))
} else {
s.quicConn.SetCongestionControl(tuicCongestion.NewBBRSender(
tuicCongestion.DefaultClock{},
tuicCongestion.GetInitialPacketSize(s.quicConn.RemoteAddr()),
tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
))
}
protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
UDPEnabled: !s.udpDisabled,
Rx: s.receiveBPS,
RxAuto: s.ignoreClientBandwidth,
})
w.WriteHeader(protocol.StatusAuthOK)
if s.ctx.Done() != nil {
go func() {
select {
case <-s.ctx.Done():
s.closeWithError(s.ctx.Err())
case <-s.connDone:
}
}()
}
if !s.udpDisabled {
go s.loopMessages()
}
} else {
s.masqueradeHandler.ServeHTTP(w, r)
}
}
func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) {
if !s.authenticated || err != nil {
return false, nil
}
if frameType != protocol.FrameTypeTCPRequest {
return false, nil
}
go func() {
hErr := s.handleStream(stream)
if hErr != nil {
stream.CancelRead(0)
stream.Close()
s.logger.Error(E.Cause(hErr, "handle stream request"))
}
}()
return true, nil
}
func (s *serverSession) handleStream(stream quic.Stream) error {
destinationString, err := protocol.ReadTCPRequest(stream)
if err != nil {
return E.New("read TCP request")
}
var conn net.Conn = &serverConn{
Stream: stream,
}
ctx := s.ctx
if s.authUser.Name != "" {
ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
}
_ = s.handler.NewConnection(ctx, conn, M.Metadata{
Source: s.source,
Destination: M.ParseSocksaddr(destinationString),
})
return nil
}
func (s *serverSession) closeWithError(err error) {
s.connAccess.Lock()
defer s.connAccess.Unlock()
select {
case <-s.connDone:
return
default:
s.connErr = err
close(s.connDone)
}
if E.IsClosedOrCanceled(err) {
s.logger.Debug(E.Cause(err, "connection failed"))
} else {
s.logger.Error(E.Cause(err, "connection failed"))
}
_ = s.quicConn.CloseWithError(0, "")
}
type serverConn struct {
quic.Stream
responseWritten bool
}
func (c *serverConn) HandshakeFailure(err error) error {
if c.responseWritten {
return os.ErrClosed
}
c.responseWritten = true
buffer := protocol.WriteTCPResponse(false, err.Error(), nil)
defer buffer.Release()
return common.Error(c.Stream.Write(buffer.Bytes()))
}
func (c *serverConn) Read(p []byte) (n int, err error) {
n, err = c.Stream.Read(p)
return n, baderror.WrapQUIC(err)
}
func (c *serverConn) Write(p []byte) (n int, err error) {
if !c.responseWritten {
c.responseWritten = true
buffer := protocol.WriteTCPResponse(true, "", p)
defer buffer.Release()
_, err = c.Stream.Write(buffer.Bytes())
if err != nil {
return 0, baderror.WrapQUIC(err)
}
return len(p), nil
}
n, err = c.Stream.Write(p)
return n, baderror.WrapQUIC(err)
}
func (c *serverConn) LocalAddr() net.Addr {
return M.Socksaddr{}
}
func (c *serverConn) RemoteAddr() net.Addr {
return M.Socksaddr{}
}
func (c *serverConn) Close() error {
c.Stream.CancelRead(0)
return c.Stream.Close()
}

View File

@ -0,0 +1,55 @@
package hysteria2
import (
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
func (s *serverSession) loopMessages() {
for {
message, err := s.quicConn.ReceiveMessage(s.ctx)
if err != nil {
s.closeWithError(E.Cause(err, "receive message"))
return
}
hErr := s.handleMessage(message)
if hErr != nil {
s.closeWithError(E.Cause(hErr, "handle message"))
return
}
}
}
func (s *serverSession) handleMessage(data []byte) error {
message := udpMessagePool.Get().(*udpMessage)
err := decodeUDPMessage(message, data)
if err != nil {
message.release()
return E.Cause(err, "decode UDP message")
}
s.handleUDPMessage(message)
return nil
}
func (s *serverSession) handleUDPMessage(message *udpMessage) {
s.udpAccess.RLock()
udpConn, loaded := s.udpConnMap[message.sessionID]
s.udpAccess.RUnlock()
if !loaded || common.Done(udpConn.ctx) {
udpConn = newUDPPacketConn(s.ctx, s.quicConn, func() {
s.udpAccess.Lock()
delete(s.udpConnMap, message.sessionID)
s.udpAccess.Unlock()
})
udpConn.sessionID = message.sessionID
s.udpAccess.Lock()
s.udpConnMap[message.sessionID] = udpConn
s.udpAccess.Unlock()
go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{
Source: s.source,
Destination: M.ParseSocksaddr(message.destination),
})
}
udpConn.inputPacket(message)
}