From 31c294d998081318b752bc8b884751cf22a2238d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 21 Oct 2023 12:00:00 +0800 Subject: [PATCH] Update BBR and Hysteria congestion control & Migrate legacy Hysteria protocol to library --- go.mod | 6 +- go.sum | 8 +- inbound/hysteria.go | 368 ++++++---------------- inbound/hysteria2.go | 3 +- outbound/hysteria.go | 322 ++++---------------- outbound/hysteria2.go | 2 +- test/go.mod | 10 +- test/go.sum | 20 +- test/hysteria_test.go | 4 +- transport/hysteria/frag.go | 65 ---- transport/hysteria/protocol.go | 539 --------------------------------- transport/hysteria/speed.go | 36 --- transport/hysteria/wrap.go | 68 ----- transport/hysteria/xplus.go | 118 -------- transport/v2rayquic/client.go | 3 +- transport/v2rayquic/server.go | 3 +- transport/v2rayquic/stream.go | 41 +++ 17 files changed, 214 insertions(+), 1402 deletions(-) delete mode 100644 transport/hysteria/frag.go delete mode 100644 transport/hysteria/protocol.go delete mode 100644 transport/hysteria/speed.go delete mode 100644 transport/hysteria/wrap.go delete mode 100644 transport/hysteria/xplus.go create mode 100644 transport/v2rayquic/stream.go diff --git a/go.mod b/go.mod index f0c48932..3152bfb5 100644 --- a/go.mod +++ b/go.mod @@ -24,12 +24,12 @@ require ( github.com/sagernet/cloudflare-tls v0.0.0-20230829051644-4a68352d0c4a github.com/sagernet/gomobile v0.0.0-20230915142329-c6740b6d2950 github.com/sagernet/gvisor v0.0.0-20230930141345-5fef6f2e17ab - github.com/sagernet/quic-go v0.0.0-20230919101909-0cc6c5dcecee + github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 github.com/sagernet/sing v0.2.16-0.20231021090846-8002db54c028 github.com/sagernet/sing-dns v0.1.10 github.com/sagernet/sing-mux v0.1.3 - github.com/sagernet/sing-quic v0.1.2 + github.com/sagernet/sing-quic v0.1.3-0.20231026034240-fa3d997246b6 github.com/sagernet/sing-shadowsocks v0.2.5 github.com/sagernet/sing-shadowsocks2 v0.1.4 github.com/sagernet/sing-shadowtls v0.1.4 @@ -45,7 +45,6 @@ require ( go.uber.org/zap v1.26.0 go4.org/netipx v0.0.0-20230824141953-6213f710f925 golang.org/x/crypto v0.14.0 - golang.org/x/exp v0.0.0-20231006140011-7918f672742d golang.org/x/net v0.17.0 golang.org/x/sys v0.13.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 @@ -85,6 +84,7 @@ require ( github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.uber.org/multierr v1.11.0 // indirect + golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/mod v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect golang.org/x/time v0.3.0 // indirect diff --git a/go.sum b/go.sum index dfc88fd0..8fa62347 100644 --- a/go.sum +++ b/go.sum @@ -104,8 +104,8 @@ github.com/sagernet/gvisor v0.0.0-20230930141345-5fef6f2e17ab h1:u+xQoi/Yc6bNUvT github.com/sagernet/gvisor v0.0.0-20230930141345-5fef6f2e17ab/go.mod h1:3akUhSHSVtLuJaYcW5JPepUraBOW06Ibz2HKwaK5rOk= github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE= github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= -github.com/sagernet/quic-go v0.0.0-20230919101909-0cc6c5dcecee h1:ykuhl9jCS638N+jw1vC9AvT9bbQn6xRNScP2FWPV9dM= -github.com/sagernet/quic-go v0.0.0-20230919101909-0cc6c5dcecee/go.mod h1:0CfhWwZAeXGYM9+Nkkw1zcQtFHQC8KWjbpeDv7pu8iw= +github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460 h1:dAe4OIJAtE0nHOzTHhAReQteh3+sa63rvXbuIpbeOTY= +github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460/go.mod h1:uJGpmJCOcMQqMlHKc3P1Vz6uygmpz4bPeVIoOhdVQnM= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= @@ -116,8 +116,8 @@ github.com/sagernet/sing-dns v0.1.10 h1:iIU7nRBlUYj+fF2TaktGIvRiTFFrHwSMedLQsvlT github.com/sagernet/sing-dns v0.1.10/go.mod h1:vtUimtf7Nq9EdvD5WTpfCr69KL1M7bcgOVKiYBiAY/c= github.com/sagernet/sing-mux v0.1.3 h1:fAf7PZa2A55mCeh0KKM02f1k2Y4vEmxuZZ/51ahkkLA= github.com/sagernet/sing-mux v0.1.3/go.mod h1:wGeIeiiFLx4HUM5LAg65wrNZ/X1muOimqK0PEhNbPi0= -github.com/sagernet/sing-quic v0.1.2 h1:+u9CRf0KHi5HgXmJ3eB0CtqpWXtF0lx2QlWq+ZFZ+XY= -github.com/sagernet/sing-quic v0.1.2/go.mod h1:H1TX0/y9UUM43wyaLQ+qjg2+o901ibYtwWX2rWG+a3o= +github.com/sagernet/sing-quic v0.1.3-0.20231026034240-fa3d997246b6 h1:w+TUbIZKZFSdf/AUa/y33kY9xaLeNGz/tBNcNhqpqfg= +github.com/sagernet/sing-quic v0.1.3-0.20231026034240-fa3d997246b6/go.mod h1:1M7xP4802K9Kz6BQ7LlA7UeCapWvWlH1Htmk2bAqkWc= github.com/sagernet/sing-shadowsocks v0.2.5 h1:qxIttos4xu6ii7MTVJYA8EFQR7Q3KG6xMqmLJIFtBaY= github.com/sagernet/sing-shadowsocks v0.2.5/go.mod h1:MGWGkcU2xW2G2mfArT9/QqpVLOGU+dBaahZCtPHdt7A= github.com/sagernet/sing-shadowsocks2 v0.1.4 h1:vht2M8t3m5DTgXR2j24KbYOygG5aOp+MUhpQnAux728= diff --git a/inbound/hysteria.go b/inbound/hysteria.go index b96327c1..29707f65 100644 --- a/inbound/hysteria.go +++ b/inbound/hysteria.go @@ -4,104 +4,38 @@ package inbound import ( "context" - "sync" + "net" - "github.com/sagernet/quic-go" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/humanize" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/transport/hysteria" - "github.com/sagernet/sing-quic" - hyCC "github.com/sagernet/sing-quic/hysteria2/congestion" + "github.com/sagernet/sing-quic/hysteria" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - - "golang.org/x/exp/slices" ) var _ adapter.Inbound = (*Hysteria)(nil) type Hysteria struct { myInboundAdapter - quicConfig *quic.Config tlsConfig tls.ServerConfig - authKey []string - authUser []string - xplusKey []byte - sendBPS uint64 - recvBPS uint64 - listener qtls.Listener - udpAccess sync.RWMutex - udpSessionId uint32 - udpSessions map[uint32]chan *hysteria.UDPMessage - udpDefragger hysteria.Defragger + service *hysteria.Service[int] + userNameList []string } func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) { options.UDPFragmentDefault = true - quicConfig := &quic.Config{ - InitialStreamReceiveWindow: options.ReceiveWindowConn, - MaxStreamReceiveWindow: options.ReceiveWindowConn, - InitialConnectionReceiveWindow: options.ReceiveWindowClient, - MaxConnectionReceiveWindow: options.ReceiveWindowClient, - MaxIncomingStreams: int64(options.MaxConnClient), - KeepAlivePeriod: hysteria.KeepAlivePeriod, - DisablePathMTUDiscovery: options.DisableMTUDiscovery || !(C.IsLinux || C.IsWindows), - EnableDatagrams: true, + if options.TLS == nil || !options.TLS.Enabled { + return nil, C.ErrTLSRequired } - if options.ReceiveWindowConn == 0 { - quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow - quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow - } - if options.ReceiveWindowClient == 0 { - quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow - quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow - } - if quicConfig.MaxIncomingStreams == 0 { - quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams - } - authKey := common.Map(options.Users, func(it option.HysteriaUser) string { - if len(it.Auth) > 0 { - return string(it.Auth) - } else { - return it.AuthString - } - }) - authUser := common.Map(options.Users, func(it option.HysteriaUser) string { - return it.Name - }) - var xplus []byte - if options.Obfs != "" { - xplus = []byte(options.Obfs) - } - var up, down uint64 - if len(options.Up) > 0 { - up = hysteria.StringToBps(options.Up) - if up == 0 { - return nil, E.New("invalid up speed format: ", options.Up) - } - } else { - up = uint64(options.UpMbps) * hysteria.MbpsToBps - } - if len(options.Down) > 0 { - down = hysteria.StringToBps(options.Down) - if down == 0 { - return nil, E.New("invalid down speed format: ", options.Down) - } - } else { - down = uint64(options.DownMbps) * hysteria.MbpsToBps - } - if up < hysteria.MinSpeedBPS { - return nil, E.New("invalid up speed") - } - if down < hysteria.MinSpeedBPS { - return nil, E.New("invalid down speed") + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err } inbound := &Hysteria{ myInboundAdapter: myInboundAdapter{ @@ -113,224 +47,108 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL tag: tag, listenOptions: options.ListenOptions, }, - quicConfig: quicConfig, - authKey: authKey, - authUser: authUser, - xplusKey: xplus, - sendBPS: up, - recvBPS: down, - udpSessions: make(map[uint32]chan *hysteria.UDPMessage), + tlsConfig: tlsConfig, } - if options.TLS == nil || !options.TLS.Enabled { - return nil, C.ErrTLSRequired + var sendBps, receiveBps uint64 + if len(options.Up) > 0 { + sendBps, err = humanize.ParseBytes(options.Up) + if err != nil { + return nil, E.Cause(err, "invalid up speed format: ", options.Up) + } + } else { + sendBps = uint64(options.UpMbps) * hysteria.MbpsToBps } - if len(options.TLS.ALPN) == 0 { - options.TLS.ALPN = []string{hysteria.DefaultALPN} + if len(options.Down) > 0 { + receiveBps, err = humanize.ParseBytes(options.Down) + if receiveBps == 0 { + return nil, E.New("invalid down speed format: ", options.Down) + } + } else { + receiveBps = uint64(options.DownMbps) * hysteria.MbpsToBps } - tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) + service, err := hysteria.NewService[int](hysteria.ServiceOptions{ + Context: ctx, + Logger: logger, + SendBPS: sendBps, + ReceiveBPS: receiveBps, + XPlusPassword: options.Obfs, + TLSConfig: tlsConfig, + Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil), + + // Legacy options + + ConnReceiveWindow: options.ReceiveWindowConn, + StreamReceiveWindow: options.ReceiveWindowClient, + MaxIncomingStreams: int64(options.MaxConnClient), + DisableMTUDiscovery: options.DisableMTUDiscovery, + }) if err != nil { return nil, err } - inbound.tlsConfig = tlsConfig + userList := make([]int, 0, len(options.Users)) + userNameList := make([]string, 0, len(options.Users)) + userPasswordList := make([]string, 0, len(options.Users)) + for index, user := range options.Users { + userList = append(userList, index) + userNameList = append(userNameList, user.Name) + var password string + if user.AuthString != "" { + password = user.AuthString + } else { + password = string(user.Auth) + } + userPasswordList = append(userPasswordList, password) + } + service.UpdateUsers(userList, userPasswordList) + inbound.service = service + inbound.userNameList = userNameList return inbound, nil } +func (h *Hysteria) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + ctx = log.ContextWithNewID(ctx) + metadata = h.createMetadata(conn, metadata) + userID, _ := auth.UserFromContext[int](ctx) + if userName := h.userNameList[userID]; userName != "" { + metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) + } + return h.router.RouteConnection(ctx, conn, metadata) +} + +func (h *Hysteria) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + ctx = log.ContextWithNewID(ctx) + metadata = h.createPacketMetadata(conn, metadata) + userID, _ := auth.UserFromContext[int](ctx) + if userName := h.userNameList[userID]; userName != "" { + metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound packet connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) + } + return h.router.RoutePacketConnection(ctx, conn, metadata) +} + func (h *Hysteria) Start() error { + if h.tlsConfig != nil { + err := h.tlsConfig.Start() + if err != nil { + return err + } + } packetConn, err := h.myInboundAdapter.ListenUDP() if err != nil { return err } - if len(h.xplusKey) > 0 { - packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey) - packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn} - } - err = h.tlsConfig.Start() - if err != nil { - return err - } - listener, err := qtls.Listen(packetConn, h.tlsConfig, h.quicConfig) - if err != nil { - return err - } - h.listener = listener - h.logger.Info("udp server started at ", listener.Addr()) - go h.acceptLoop() - return nil -} - -func (h *Hysteria) acceptLoop() { - for { - ctx := log.ContextWithNewID(h.ctx) - conn, err := h.listener.Accept(ctx) - if err != nil { - return - } - go func() { - hErr := h.accept(ctx, conn) - if hErr != nil { - conn.CloseWithError(0, "") - NewError(h.logger, ctx, E.Cause(hErr, "process connection from ", conn.RemoteAddr())) - } - }() - } -} - -func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error { - controlStream, err := conn.AcceptStream(ctx) - if err != nil { - return err - } - clientHello, err := hysteria.ReadClientHello(controlStream) - if err != nil { - return err - } - if len(h.authKey) > 0 { - userIndex := slices.Index(h.authKey, string(clientHello.Auth)) - if userIndex == -1 { - err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{ - Message: "wrong password", - }) - return E.Errors(E.New("wrong password: ", string(clientHello.Auth)), err) - } - user := h.authUser[userIndex] - if user == "" { - user = F.ToString(userIndex) - } else { - ctx = auth.ContextWithUser(ctx, user) - } - h.logger.InfoContext(ctx, "[", user, "] inbound connection from ", conn.RemoteAddr()) - } else { - h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr()) - } - h.logger.DebugContext(ctx, "peer send speed: ", clientHello.SendBPS/1024/1024, " MBps, peer recv speed: ", clientHello.RecvBPS/1024/1024, " MBps") - if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 { - return E.New("invalid rate from client") - } - serverSendBPS, serverRecvBPS := clientHello.RecvBPS, clientHello.SendBPS - if h.sendBPS > 0 && serverSendBPS > h.sendBPS { - serverSendBPS = h.sendBPS - } - if h.recvBPS > 0 && serverRecvBPS > h.recvBPS { - serverRecvBPS = h.recvBPS - } - err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{ - OK: true, - SendBPS: serverSendBPS, - RecvBPS: serverRecvBPS, - }) - if err != nil { - return err - } - conn.SetCongestionControl(hyCC.NewBrutalSender(serverSendBPS)) - go h.udpRecvLoop(conn) - for { - var stream quic.Stream - stream, err = conn.AcceptStream(ctx) - if err != nil { - return err - } - go func() { - hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream) - if hErr != nil { - stream.Close() - NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr())) - } - }() - } -} - -func (h *Hysteria) udpRecvLoop(conn quic.Connection) { - for { - packet, err := conn.ReceiveMessage(h.ctx) - if err != nil { - return - } - message, err := hysteria.ParseUDPMessage(packet) - if err != nil { - h.logger.Error("parse udp message: ", err) - continue - } - dfMsg := h.udpDefragger.Feed(message) - if dfMsg == nil { - continue - } - h.udpAccess.RLock() - ch, ok := h.udpSessions[dfMsg.SessionID] - if ok { - select { - case ch <- dfMsg: - // OK - default: - // Silently drop the message when the channel is full - } - } - h.udpAccess.RUnlock() - } -} - -func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error { - request, err := hysteria.ReadClientRequest(stream) - if err != nil { - return err - } - var metadata adapter.InboundContext - metadata.Inbound = h.tag - metadata.InboundType = C.TypeHysteria - metadata.InboundOptions = h.listenOptions.InboundOptions - metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap() - metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap() - metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port).Unwrap() - metadata.User, _ = auth.UserFromContext[string](ctx) - - if !request.UDP { - err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{ - OK: true, - }) - if err != nil { - return err - } - h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) - return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination, false), metadata) - } else { - h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) - var id uint32 - h.udpAccess.Lock() - id = h.udpSessionId - nCh := make(chan *hysteria.UDPMessage, 1024) - h.udpSessions[id] = nCh - h.udpSessionId += 1 - h.udpAccess.Unlock() - err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{ - OK: true, - UDPSessionID: id, - }) - if err != nil { - return err - } - packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error { - h.udpAccess.Lock() - if ch, ok := h.udpSessions[id]; ok { - close(ch) - delete(h.udpSessions, id) - } - h.udpAccess.Unlock() - return nil - })) - go packetConn.Hold() - return h.router.RoutePacketConnection(ctx, packetConn, metadata) - } + return h.service.Start(packetConn) } func (h *Hysteria) Close() error { - h.udpAccess.Lock() - for _, session := range h.udpSessions { - close(session) - } - h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage) - h.udpAccess.Unlock() return common.Close( &h.myInboundAdapter, - h.listener, h.tlsConfig, + common.PtrOrNil(h.service), ) } diff --git a/inbound/hysteria2.go b/inbound/hysteria2.go index fd650ae9..07b45af2 100644 --- a/inbound/hysteria2.go +++ b/inbound/hysteria2.go @@ -14,7 +14,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/transport/hysteria" + "github.com/sagernet/sing-quic/hysteria" "github.com/sagernet/sing-quic/hysteria2" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" @@ -32,6 +32,7 @@ type Hysteria2 struct { } func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (*Hysteria2, error) { + options.UDPFragmentDefault = true if options.TLS == nil || !options.TLS.Enabled { return nil, C.ErrTLSRequired } diff --git a/outbound/hysteria.go b/outbound/hysteria.go index ffdf61bb..8c130e33 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -5,18 +5,16 @@ package outbound import ( "context" "net" - "sync" + "os" - "github.com/sagernet/quic-go" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/common/humanize" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/transport/hysteria" - "github.com/sagernet/sing-quic" - hyCC "github.com/sagernet/sing-quic/hysteria2/congestion" + "github.com/sagernet/sing-quic/hysteria" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" @@ -25,27 +23,13 @@ import ( ) var ( - _ adapter.Outbound = (*Hysteria)(nil) - _ adapter.InterfaceUpdateListener = (*Hysteria)(nil) + _ adapter.Outbound = (*TUIC)(nil) + _ adapter.InterfaceUpdateListener = (*TUIC)(nil) ) type Hysteria struct { myOutboundAdapter - ctx context.Context - dialer N.Dialer - serverAddr M.Socksaddr - tlsConfig tls.Config - quicConfig *quic.Config - authKey []byte - xplusKey []byte - sendBPS uint64 - recvBPS uint64 - connAccess sync.Mutex - conn quic.Connection - rawConn net.Conn - udpAccess sync.RWMutex - udpSessions map[uint32]chan *hysteria.UDPMessage - udpDefragger hysteria.Defragger + client *hysteria.Client } func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) { @@ -57,252 +41,77 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } - if len(tlsConfig.NextProtos()) == 0 { - tlsConfig.SetNextProtos([]string{hysteria.DefaultALPN}) + outboundDialer, err := dialer.New(router, options.DialerOptions) + if err != nil { + return nil, err } - quicConfig := &quic.Config{ - InitialStreamReceiveWindow: options.ReceiveWindowConn, - MaxStreamReceiveWindow: options.ReceiveWindowConn, - InitialConnectionReceiveWindow: options.ReceiveWindow, - MaxConnectionReceiveWindow: options.ReceiveWindow, - KeepAlivePeriod: hysteria.KeepAlivePeriod, - DisablePathMTUDiscovery: options.DisableMTUDiscovery, - EnableDatagrams: true, - } - if options.ReceiveWindowConn == 0 { - quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow - quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow - } - if options.ReceiveWindow == 0 { - quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow - quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow - } - if quicConfig.MaxIncomingStreams == 0 { - quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams - } - var auth []byte - if len(options.Auth) > 0 { - auth = options.Auth + networkList := options.Network.Build() + var password string + if options.AuthString != "" { + password = options.AuthString } else { - auth = []byte(options.AuthString) + password = string(options.Auth) } - var xplus []byte - if options.Obfs != "" { - xplus = []byte(options.Obfs) - } - var up, down uint64 + var sendBps, receiveBps uint64 if len(options.Up) > 0 { - up = hysteria.StringToBps(options.Up) - if up == 0 { - return nil, E.New("invalid up speed format: ", options.Up) + sendBps, err = humanize.ParseBytes(options.Up) + if err != nil { + return nil, E.Cause(err, "invalid up speed format: ", options.Up) } } else { - up = uint64(options.UpMbps) * hysteria.MbpsToBps + sendBps = uint64(options.UpMbps) * hysteria.MbpsToBps } if len(options.Down) > 0 { - down = hysteria.StringToBps(options.Down) - if down == 0 { + receiveBps, err = humanize.ParseBytes(options.Down) + if receiveBps == 0 { return nil, E.New("invalid down speed format: ", options.Down) } } else { - down = uint64(options.DownMbps) * hysteria.MbpsToBps + receiveBps = uint64(options.DownMbps) * hysteria.MbpsToBps } - if up < hysteria.MinSpeedBPS { - return nil, E.New("invalid up speed") - } - if down < hysteria.MinSpeedBPS { - return nil, E.New("invalid down speed") - } - outboundDialer, err := dialer.New(router, options.DialerOptions) + client, err := hysteria.NewClient(hysteria.ClientOptions{ + Context: ctx, + Dialer: outboundDialer, + Logger: logger, + ServerAddress: options.ServerOptions.Build(), + SendBPS: sendBps, + ReceiveBPS: receiveBps, + XPlusPassword: options.Obfs, + Password: password, + TLSConfig: tlsConfig, + UDPDisabled: !common.Contains(networkList, N.NetworkUDP), + + ConnReceiveWindow: options.ReceiveWindowConn, + StreamReceiveWindow: options.ReceiveWindow, + DisableMTUDiscovery: options.DisableMTUDiscovery, + }) if err != nil { return nil, err } return &Hysteria{ myOutboundAdapter: myOutboundAdapter{ protocol: C.TypeHysteria, - network: options.Network.Build(), + network: networkList, router: router, logger: logger, tag: tag, dependencies: withDialerDependency(options.DialerOptions), }, - ctx: ctx, - dialer: outboundDialer, - serverAddr: options.ServerOptions.Build(), - tlsConfig: tlsConfig, - quicConfig: quicConfig, - authKey: auth, - xplusKey: xplus, - sendBPS: up, - recvBPS: down, + client: client, }, nil } -func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) { - conn := h.conn - if conn != nil && !common.Done(conn.Context()) { - return conn, nil - } - h.connAccess.Lock() - defer h.connAccess.Unlock() - h.udpAccess.Lock() - defer h.udpAccess.Unlock() - conn = h.conn - if conn != nil && !common.Done(conn.Context()) { - return conn, nil - } - common.Close(h.rawConn) - conn, err := h.offerNew(ctx) - if err != nil { - return nil, err - } - if common.Contains(h.network, N.NetworkUDP) { - for _, session := range h.udpSessions { - close(session) - } - h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage) - h.udpDefragger = hysteria.Defragger{} - go h.udpRecvLoop(conn) - } - return conn, nil -} - -func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) { - udpConn, err := h.dialer.DialContext(h.ctx, "udp", h.serverAddr) - if err != nil { - return nil, err - } - var packetConn net.PacketConn - packetConn = bufio.NewUnbindPacketConn(udpConn) - if h.xplusKey != nil { - packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey) - } - packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn} - quicConn, err := qtls.Dial(h.ctx, packetConn, udpConn.RemoteAddr(), h.tlsConfig, h.quicConfig) - if err != nil { - packetConn.Close() - return nil, err - } - controlStream, err := quicConn.OpenStreamSync(ctx) - if err != nil { - packetConn.Close() - return nil, err - } - err = hysteria.WriteClientHello(controlStream, hysteria.ClientHello{ - SendBPS: h.sendBPS, - RecvBPS: h.recvBPS, - Auth: h.authKey, - }) - if err != nil { - packetConn.Close() - return nil, err - } - serverHello, err := hysteria.ReadServerHello(controlStream) - if err != nil { - packetConn.Close() - return nil, err - } - if !serverHello.OK { - packetConn.Close() - return nil, E.New("remote error: ", serverHello.Message) - } - quicConn.SetCongestionControl(hyCC.NewBrutalSender(serverHello.RecvBPS)) - h.conn = quicConn - h.rawConn = udpConn - return quicConn, nil -} - -func (h *Hysteria) udpRecvLoop(conn quic.Connection) { - for { - packet, err := conn.ReceiveMessage(h.ctx) - if err != nil { - return - } - message, err := hysteria.ParseUDPMessage(packet) - if err != nil { - h.logger.Error("parse udp message: ", err) - continue - } - dfMsg := h.udpDefragger.Feed(message) - if dfMsg == nil { - continue - } - h.udpAccess.RLock() - ch, ok := h.udpSessions[dfMsg.SessionID] - if ok { - select { - case ch <- dfMsg: - // OK - default: - // Silently drop the message when the channel is full - } - } - h.udpAccess.RUnlock() - } -} - -func (h *Hysteria) InterfaceUpdated() { - h.Close() - return -} - -func (h *Hysteria) Close() error { - h.connAccess.Lock() - defer h.connAccess.Unlock() - h.udpAccess.Lock() - defer h.udpAccess.Unlock() - if h.conn != nil { - h.conn.CloseWithError(0, "") - h.rawConn.Close() - } - for _, session := range h.udpSessions { - close(session) - } - h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage) - return nil -} - -func (h *Hysteria) open(ctx context.Context, reconnect bool) (quic.Connection, quic.Stream, error) { - conn, err := h.offer(ctx) - if err != nil { - if nErr, ok := err.(net.Error); ok && !nErr.Temporary() && reconnect { - return h.open(ctx, false) - } - return nil, nil, err - } - stream, err := conn.OpenStream() - if err != nil { - if nErr, ok := err.(net.Error); ok && !nErr.Temporary() && reconnect { - return h.open(ctx, false) - } - return nil, nil, err - } - return conn, &hysteria.StreamWrapper{Stream: stream}, nil -} - func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { switch N.NetworkName(network) { case N.NetworkTCP: h.logger.InfoContext(ctx, "outbound connection to ", destination) - _, stream, err := h.open(ctx, true) - if err != nil { - return nil, err - } - err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{ - Host: destination.AddrString(), - Port: destination.Port, - }) - if err != nil { - stream.Close() - return nil, err - } - return hysteria.NewConn(stream, destination, true), nil + return h.client.DialConn(ctx, destination) case N.NetworkUDP: conn, err := h.ListenPacket(ctx, destination) if err != nil { return nil, err } - return conn.(*hysteria.PacketConn), nil + return bufio.NewBindPacketConn(conn, destination), nil default: return nil, E.New("unsupported network: ", network) } @@ -310,44 +119,7 @@ func (h *Hysteria) DialContext(ctx context.Context, network string, destination func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { h.logger.InfoContext(ctx, "outbound packet connection to ", destination) - conn, stream, err := h.open(ctx, true) - if err != nil { - return nil, err - } - err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{ - UDP: true, - Host: destination.AddrString(), - Port: destination.Port, - }) - if err != nil { - stream.Close() - return nil, err - } - var response *hysteria.ServerResponse - response, err = hysteria.ReadServerResponse(stream) - if err != nil { - stream.Close() - return nil, err - } - if !response.OK { - stream.Close() - return nil, E.New("remote error: ", response.Message) - } - h.udpAccess.Lock() - nCh := make(chan *hysteria.UDPMessage, 1024) - h.udpSessions[response.UDPSessionID] = nCh - h.udpAccess.Unlock() - packetConn := hysteria.NewPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error { - h.udpAccess.Lock() - if ch, ok := h.udpSessions[response.UDPSessionID]; ok { - close(ch) - delete(h.udpSessions, response.UDPSessionID) - } - h.udpAccess.Unlock() - return nil - })) - go packetConn.Hold() - return packetConn, nil + return h.client.ListenPacket(ctx, destination) } func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { @@ -357,3 +129,11 @@ func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata ad func (h *Hysteria) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { return NewPacketConnection(ctx, h, conn, metadata) } + +func (h *Hysteria) InterfaceUpdated() error { + return h.client.CloseWithError(E.New("network changed")) +} + +func (h *Hysteria) Close() error { + return h.client.CloseWithError(os.ErrClosed) +} diff --git a/outbound/hysteria2.go b/outbound/hysteria2.go index 120865a9..2998f948 100644 --- a/outbound/hysteria2.go +++ b/outbound/hysteria2.go @@ -13,7 +13,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/transport/hysteria" + "github.com/sagernet/sing-quic/hysteria" "github.com/sagernet/sing-quic/hysteria2" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" diff --git a/test/go.mod b/test/go.mod index fed2124a..50759f68 100644 --- a/test/go.mod +++ b/test/go.mod @@ -10,10 +10,10 @@ require ( github.com/docker/docker v24.0.6+incompatible github.com/docker/go-connections v0.4.0 github.com/gofrs/uuid/v5 v5.0.0 - github.com/sagernet/quic-go v0.0.0-20230919101909-0cc6c5dcecee - github.com/sagernet/sing v0.2.15 + github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460 + github.com/sagernet/sing v0.2.16-0.20231021090846-8002db54c028 github.com/sagernet/sing-dns v0.1.10 - github.com/sagernet/sing-quic v0.1.2 + github.com/sagernet/sing-quic v0.1.3-0.20231026034240-fa3d997246b6 github.com/sagernet/sing-shadowsocks v0.2.5 github.com/sagernet/sing-shadowsocks2 v0.1.4 github.com/spyzhov/ajson v0.9.0 @@ -70,12 +70,12 @@ require ( github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a // indirect github.com/sagernet/cloudflare-tls v0.0.0-20230829051644-4a68352d0c4a // indirect github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 // indirect - github.com/sagernet/gvisor v0.0.0-20230627031050-1ab0276e0dd2 // indirect + github.com/sagernet/gvisor v0.0.0-20230930141345-5fef6f2e17ab // indirect github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 // indirect github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 // indirect github.com/sagernet/sing-mux v0.1.3 // indirect github.com/sagernet/sing-shadowtls v0.1.4 // indirect - github.com/sagernet/sing-tun v0.1.16 // indirect + github.com/sagernet/sing-tun v0.1.17-0.20231026060825-efd9884154a6 // indirect github.com/sagernet/sing-vmess v0.1.8 // indirect github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 // indirect github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6 // indirect diff --git a/test/go.sum b/test/go.sum index d13690ad..597f5594 100644 --- a/test/go.sum +++ b/test/go.sum @@ -117,32 +117,32 @@ github.com/sagernet/cloudflare-tls v0.0.0-20230829051644-4a68352d0c4a h1:wZHruBx github.com/sagernet/cloudflare-tls v0.0.0-20230829051644-4a68352d0c4a/go.mod h1:dNV1ZP9y3qx5ltULeKaQZTZWTLHflgW5DES+Ses7cMI= github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 h1:5+m7c6AkmAylhauulqN/c5dnh8/KssrE9c93TQrXldA= github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61/go.mod h1:QUQ4RRHD6hGGHdFMEtR8T2P6GS6R3D/CXKdaYHKKXms= -github.com/sagernet/gvisor v0.0.0-20230627031050-1ab0276e0dd2 h1:dnkKrzapqtAwjTSWt6hdPrARORfoYvuUczynvRLrueo= -github.com/sagernet/gvisor v0.0.0-20230627031050-1ab0276e0dd2/go.mod h1:1JUiV7nGuf++YFm9eWZ8q2lrwHmhcUGzptMl/vL1+LA= +github.com/sagernet/gvisor v0.0.0-20230930141345-5fef6f2e17ab h1:u+xQoi/Yc6bNUvTfrDD6HhGRybn2lzrhf5vmS+wb4Ho= +github.com/sagernet/gvisor v0.0.0-20230930141345-5fef6f2e17ab/go.mod h1:3akUhSHSVtLuJaYcW5JPepUraBOW06Ibz2HKwaK5rOk= github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE= github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= -github.com/sagernet/quic-go v0.0.0-20230919101909-0cc6c5dcecee h1:ykuhl9jCS638N+jw1vC9AvT9bbQn6xRNScP2FWPV9dM= -github.com/sagernet/quic-go v0.0.0-20230919101909-0cc6c5dcecee/go.mod h1:0CfhWwZAeXGYM9+Nkkw1zcQtFHQC8KWjbpeDv7pu8iw= +github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460 h1:dAe4OIJAtE0nHOzTHhAReQteh3+sa63rvXbuIpbeOTY= +github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460/go.mod h1:uJGpmJCOcMQqMlHKc3P1Vz6uygmpz4bPeVIoOhdVQnM= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= -github.com/sagernet/sing v0.2.15 h1:PFwyiMzkyJkq+YGOVznJUsRVOT6EoVxRGIsllLuvHXA= -github.com/sagernet/sing v0.2.15/go.mod h1:AhNEHu0GXrpqkuzvTwvC8+j2cQUU/dh+zLEmq4C99pg= +github.com/sagernet/sing v0.2.16-0.20231021090846-8002db54c028 h1:6GbQt7SC9y5Imrq5jDMbXDSaNiMhJ8KBjhjtQRuqQvE= +github.com/sagernet/sing v0.2.16-0.20231021090846-8002db54c028/go.mod h1:AhNEHu0GXrpqkuzvTwvC8+j2cQUU/dh+zLEmq4C99pg= github.com/sagernet/sing-dns v0.1.10 h1:iIU7nRBlUYj+fF2TaktGIvRiTFFrHwSMedLQsvlTZCI= github.com/sagernet/sing-dns v0.1.10/go.mod h1:vtUimtf7Nq9EdvD5WTpfCr69KL1M7bcgOVKiYBiAY/c= github.com/sagernet/sing-mux v0.1.3 h1:fAf7PZa2A55mCeh0KKM02f1k2Y4vEmxuZZ/51ahkkLA= github.com/sagernet/sing-mux v0.1.3/go.mod h1:wGeIeiiFLx4HUM5LAg65wrNZ/X1muOimqK0PEhNbPi0= -github.com/sagernet/sing-quic v0.1.2 h1:+u9CRf0KHi5HgXmJ3eB0CtqpWXtF0lx2QlWq+ZFZ+XY= -github.com/sagernet/sing-quic v0.1.2/go.mod h1:H1TX0/y9UUM43wyaLQ+qjg2+o901ibYtwWX2rWG+a3o= +github.com/sagernet/sing-quic v0.1.3-0.20231026034240-fa3d997246b6 h1:w+TUbIZKZFSdf/AUa/y33kY9xaLeNGz/tBNcNhqpqfg= +github.com/sagernet/sing-quic v0.1.3-0.20231026034240-fa3d997246b6/go.mod h1:1M7xP4802K9Kz6BQ7LlA7UeCapWvWlH1Htmk2bAqkWc= github.com/sagernet/sing-shadowsocks v0.2.5 h1:qxIttos4xu6ii7MTVJYA8EFQR7Q3KG6xMqmLJIFtBaY= github.com/sagernet/sing-shadowsocks v0.2.5/go.mod h1:MGWGkcU2xW2G2mfArT9/QqpVLOGU+dBaahZCtPHdt7A= github.com/sagernet/sing-shadowsocks2 v0.1.4 h1:vht2M8t3m5DTgXR2j24KbYOygG5aOp+MUhpQnAux728= github.com/sagernet/sing-shadowsocks2 v0.1.4/go.mod h1:Mgdee99NxxNd5Zld3ixIs18yVs4x2dI2VTDDE1N14Wc= github.com/sagernet/sing-shadowtls v0.1.4 h1:aTgBSJEgnumzFenPvc+kbD9/W0PywzWevnVpEx6Tw3k= github.com/sagernet/sing-shadowtls v0.1.4/go.mod h1:F8NBgsY5YN2beQavdgdm1DPlhaKQlaL6lpDdcBglGK4= -github.com/sagernet/sing-tun v0.1.16 h1:RHXYIVg6uacvdfbYMiPEz9VX5uu6mNrvP7u9yAH3oNc= -github.com/sagernet/sing-tun v0.1.16/go.mod h1:S3q8GCjeyRniK+KLmo4XqKY0bS3x2UdKkKbqxT/Agl8= +github.com/sagernet/sing-tun v0.1.17-0.20231026060825-efd9884154a6 h1:4yEXBqQoUgXj7qPSLD6lr+z9/KfsvixO9JUA2i5xnM8= +github.com/sagernet/sing-tun v0.1.17-0.20231026060825-efd9884154a6/go.mod h1:w2+S+uWE94E/pQWSDdDdMIjwAEb645kuGPunr6ZllUg= github.com/sagernet/sing-vmess v0.1.8 h1:XVWad1RpTy9b5tPxdm5MCU8cGfrTGdR8qCq6HV2aCNc= github.com/sagernet/sing-vmess v0.1.8/go.mod h1:vhx32UNzTDUkNwOyIjcZQohre1CaytquC5mPplId8uA= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as= diff --git a/test/hysteria_test.go b/test/hysteria_test.go index df0ae3a2..664e5a7c 100644 --- a/test/hysteria_test.go +++ b/test/hysteria_test.go @@ -79,7 +79,7 @@ func TestHysteriaSelf(t *testing.T) { }, }, }) - testSuitSimple1(t, clientPort, testPort) + testSuit(t, clientPort, testPort) } func TestHysteriaInbound(t *testing.T) { @@ -118,7 +118,7 @@ func TestHysteriaInbound(t *testing.T) { caPem: "/etc/hysteria/ca.pem", }, }) - testSuitSimple1(t, clientPort, testPort) + testSuit(t, clientPort, testPort) } func TestHysteriaOutbound(t *testing.T) { diff --git a/transport/hysteria/frag.go b/transport/hysteria/frag.go deleted file mode 100644 index 721341f1..00000000 --- a/transport/hysteria/frag.go +++ /dev/null @@ -1,65 +0,0 @@ -package hysteria - -func FragUDPMessage(m UDPMessage, maxSize int) []UDPMessage { - if m.Size() <= maxSize { - return []UDPMessage{m} - } - fullPayload := m.Data - maxPayloadSize := maxSize - m.HeaderSize() - off := 0 - fragID := uint8(0) - fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up - var frags []UDPMessage - for off < len(fullPayload) { - payloadSize := len(fullPayload) - off - if payloadSize > maxPayloadSize { - payloadSize = maxPayloadSize - } - frag := m - frag.FragID = fragID - frag.FragCount = fragCount - frag.Data = fullPayload[off : off+payloadSize] - frags = append(frags, frag) - off += payloadSize - fragID++ - } - return frags -} - -type Defragger struct { - msgID uint16 - frags []*UDPMessage - count uint8 -} - -func (d *Defragger) Feed(m UDPMessage) *UDPMessage { - if m.FragCount <= 1 { - return &m - } - if m.FragID >= m.FragCount { - // wtf is this? - return nil - } - if m.MsgID != d.msgID { - // new message, clear previous state - d.msgID = m.MsgID - d.frags = make([]*UDPMessage, m.FragCount) - d.count = 1 - d.frags[m.FragID] = &m - } else if d.frags[m.FragID] == nil { - d.frags[m.FragID] = &m - d.count++ - if int(d.count) == len(d.frags) { - // all fragments received, assemble - var data []byte - for _, frag := range d.frags { - data = append(data, frag.Data...) - } - m.Data = data - m.FragID = 0 - m.FragCount = 1 - return &m - } - } - return nil -} diff --git a/transport/hysteria/protocol.go b/transport/hysteria/protocol.go deleted file mode 100644 index a338988f..00000000 --- a/transport/hysteria/protocol.go +++ /dev/null @@ -1,539 +0,0 @@ -package hysteria - -import ( - "bytes" - "encoding/binary" - "io" - "math/rand" - "net" - "os" - "time" - - "github.com/sagernet/quic-go" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" -) - -const ( - MbpsToBps = 125000 - MinSpeedBPS = 16384 - DefaultStreamReceiveWindow = 15728640 // 15 MB/s - DefaultConnectionReceiveWindow = 67108864 // 64 MB/s - DefaultMaxIncomingStreams = 1024 - DefaultALPN = "hysteria" - KeepAlivePeriod = 10 * time.Second -) - -const Version = 3 - -type ClientHello struct { - SendBPS uint64 - RecvBPS uint64 - Auth []byte -} - -func WriteClientHello(stream io.Writer, hello ClientHello) error { - var requestLen int - requestLen += 1 // version - requestLen += 8 // sendBPS - requestLen += 8 // recvBPS - requestLen += 2 // auth len - requestLen += len(hello.Auth) - request := buf.NewSize(requestLen) - defer request.Release() - common.Must( - request.WriteByte(Version), - binary.Write(request, binary.BigEndian, hello.SendBPS), - binary.Write(request, binary.BigEndian, hello.RecvBPS), - binary.Write(request, binary.BigEndian, uint16(len(hello.Auth))), - common.Error(request.Write(hello.Auth)), - ) - return common.Error(stream.Write(request.Bytes())) -} - -func ReadClientHello(reader io.Reader) (*ClientHello, error) { - var version uint8 - err := binary.Read(reader, binary.BigEndian, &version) - if err != nil { - return nil, err - } - if version != Version { - return nil, E.New("unsupported client version: ", version) - } - var clientHello ClientHello - err = binary.Read(reader, binary.BigEndian, &clientHello.SendBPS) - if err != nil { - return nil, err - } - err = binary.Read(reader, binary.BigEndian, &clientHello.RecvBPS) - if err != nil { - return nil, err - } - var authLen uint16 - err = binary.Read(reader, binary.BigEndian, &authLen) - if err != nil { - return nil, err - } - clientHello.Auth = make([]byte, authLen) - _, err = io.ReadFull(reader, clientHello.Auth) - if err != nil { - return nil, err - } - return &clientHello, nil -} - -type ServerHello struct { - OK bool - SendBPS uint64 - RecvBPS uint64 - Message string -} - -func ReadServerHello(stream io.Reader) (*ServerHello, error) { - var responseLen int - responseLen += 1 // ok - responseLen += 8 // sendBPS - responseLen += 8 // recvBPS - responseLen += 2 // message len - response := buf.NewSize(responseLen) - defer response.Release() - _, err := response.ReadFullFrom(stream, responseLen) - if err != nil { - return nil, err - } - var serverHello ServerHello - serverHello.OK = response.Byte(0) == 1 - serverHello.SendBPS = binary.BigEndian.Uint64(response.Range(1, 9)) - serverHello.RecvBPS = binary.BigEndian.Uint64(response.Range(9, 17)) - messageLen := binary.BigEndian.Uint16(response.Range(17, 19)) - if messageLen == 0 { - return &serverHello, nil - } - message := make([]byte, messageLen) - _, err = io.ReadFull(stream, message) - if err != nil { - return nil, err - } - serverHello.Message = string(message) - return &serverHello, nil -} - -func WriteServerHello(stream io.Writer, hello ServerHello) error { - var responseLen int - responseLen += 1 // ok - responseLen += 8 // sendBPS - responseLen += 8 // recvBPS - responseLen += 2 // message len - responseLen += len(hello.Message) - response := buf.NewSize(responseLen) - defer response.Release() - if hello.OK { - common.Must(response.WriteByte(1)) - } else { - common.Must(response.WriteByte(0)) - } - common.Must( - binary.Write(response, binary.BigEndian, hello.SendBPS), - binary.Write(response, binary.BigEndian, hello.RecvBPS), - binary.Write(response, binary.BigEndian, uint16(len(hello.Message))), - common.Error(response.WriteString(hello.Message)), - ) - return common.Error(stream.Write(response.Bytes())) -} - -type ClientRequest struct { - UDP bool - Host string - Port uint16 -} - -func ReadClientRequest(stream io.Reader) (*ClientRequest, error) { - var clientRequest ClientRequest - err := binary.Read(stream, binary.BigEndian, &clientRequest.UDP) - if err != nil { - return nil, err - } - var hostLen uint16 - err = binary.Read(stream, binary.BigEndian, &hostLen) - if err != nil { - return nil, err - } - host := make([]byte, hostLen) - _, err = io.ReadFull(stream, host) - if err != nil { - return nil, err - } - clientRequest.Host = string(host) - err = binary.Read(stream, binary.BigEndian, &clientRequest.Port) - if err != nil { - return nil, err - } - return &clientRequest, nil -} - -func WriteClientRequest(stream io.Writer, request ClientRequest) error { - var requestLen int - requestLen += 1 // udp - requestLen += 2 // host len - requestLen += len(request.Host) - requestLen += 2 // port - buffer := buf.NewSize(requestLen) - defer buffer.Release() - if request.UDP { - common.Must(buffer.WriteByte(1)) - } else { - common.Must(buffer.WriteByte(0)) - } - common.Must( - binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))), - common.Error(buffer.WriteString(request.Host)), - binary.Write(buffer, binary.BigEndian, request.Port), - ) - return common.Error(stream.Write(buffer.Bytes())) -} - -type ServerResponse struct { - OK bool - UDPSessionID uint32 - Message string -} - -func ReadServerResponse(stream io.Reader) (*ServerResponse, error) { - var responseLen int - responseLen += 1 // ok - responseLen += 4 // udp session id - responseLen += 2 // message len - response := buf.NewSize(responseLen) - defer response.Release() - _, err := response.ReadFullFrom(stream, responseLen) - if err != nil { - return nil, err - } - var serverResponse ServerResponse - serverResponse.OK = response.Byte(0) == 1 - serverResponse.UDPSessionID = binary.BigEndian.Uint32(response.Range(1, 5)) - messageLen := binary.BigEndian.Uint16(response.Range(5, 7)) - if messageLen == 0 { - return &serverResponse, nil - } - message := make([]byte, messageLen) - _, err = io.ReadFull(stream, message) - if err != nil { - return nil, err - } - serverResponse.Message = string(message) - return &serverResponse, nil -} - -func WriteServerResponse(stream io.Writer, response ServerResponse) error { - var responseLen int - responseLen += 1 // ok - responseLen += 4 // udp session id - responseLen += 2 // message len - responseLen += len(response.Message) - buffer := buf.NewSize(responseLen) - defer buffer.Release() - if response.OK { - common.Must(buffer.WriteByte(1)) - } else { - common.Must(buffer.WriteByte(0)) - } - common.Must( - binary.Write(buffer, binary.BigEndian, response.UDPSessionID), - binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))), - common.Error(buffer.WriteString(response.Message)), - ) - return common.Error(stream.Write(buffer.Bytes())) -} - -type UDPMessage struct { - SessionID uint32 - Host string - Port uint16 - MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented - FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented - FragCount uint8 // must be 1 when not fragmented - Data []byte -} - -func (m UDPMessage) HeaderSize() int { - return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2 -} - -func (m UDPMessage) Size() int { - return m.HeaderSize() + len(m.Data) -} - -func ParseUDPMessage(packet []byte) (message UDPMessage, err error) { - reader := bytes.NewReader(packet) - err = binary.Read(reader, binary.BigEndian, &message.SessionID) - if err != nil { - return - } - var hostLen uint16 - err = binary.Read(reader, binary.BigEndian, &hostLen) - if err != nil { - return - } - _, err = reader.Seek(int64(hostLen), io.SeekCurrent) - if err != nil { - return - } - if 6+int(hostLen) > len(packet) { - err = E.New("invalid host length") - return - } - message.Host = string(packet[6 : 6+hostLen]) - err = binary.Read(reader, binary.BigEndian, &message.Port) - if err != nil { - return - } - err = binary.Read(reader, binary.BigEndian, &message.MsgID) - if err != nil { - return - } - err = binary.Read(reader, binary.BigEndian, &message.FragID) - if err != nil { - return - } - err = binary.Read(reader, binary.BigEndian, &message.FragCount) - if err != nil { - return - } - var dataLen uint16 - err = binary.Read(reader, binary.BigEndian, &dataLen) - if err != nil { - return - } - if reader.Len() != int(dataLen) { - err = E.New("invalid data length") - } - dataOffset := int(reader.Size()) - reader.Len() - message.Data = packet[dataOffset:] - return -} - -func WriteUDPMessage(conn quic.Connection, message UDPMessage) error { - var messageLen int - messageLen += 4 // session id - messageLen += 2 // host len - messageLen += len(message.Host) - messageLen += 2 // port - messageLen += 2 // msg id - messageLen += 1 // frag id - messageLen += 1 // frag count - messageLen += 2 // data len - messageLen += len(message.Data) - buffer := buf.NewSize(messageLen) - defer buffer.Release() - err := writeUDPMessage(conn, message, buffer) - if errSize, ok := err.(quic.ErrMessageTooLarge); ok { - // need to frag - message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 - fragMsgs := FragUDPMessage(message, int(errSize)) - for _, fragMsg := range fragMsgs { - buffer.FullReset() - err = writeUDPMessage(conn, fragMsg, buffer) - if err != nil { - return err - } - } - return nil - } - return err -} - -func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffer) error { - common.Must( - binary.Write(buffer, binary.BigEndian, message.SessionID), - binary.Write(buffer, binary.BigEndian, uint16(len(message.Host))), - common.Error(buffer.WriteString(message.Host)), - binary.Write(buffer, binary.BigEndian, message.Port), - binary.Write(buffer, binary.BigEndian, message.MsgID), - binary.Write(buffer, binary.BigEndian, message.FragID), - binary.Write(buffer, binary.BigEndian, message.FragCount), - binary.Write(buffer, binary.BigEndian, uint16(len(message.Data))), - common.Error(buffer.Write(message.Data)), - ) - return conn.SendMessage(buffer.Bytes()) -} - -var _ net.Conn = (*Conn)(nil) - -type Conn struct { - quic.Stream - destination M.Socksaddr - needReadResponse bool -} - -func NewConn(stream quic.Stream, destination M.Socksaddr, isClient bool) *Conn { - return &Conn{ - Stream: stream, - destination: destination, - needReadResponse: isClient, - } -} - -func (c *Conn) Read(p []byte) (n int, err error) { - if c.needReadResponse { - var response *ServerResponse - response, err = ReadServerResponse(c.Stream) - if err != nil { - c.Close() - return - } - if !response.OK { - c.Close() - return 0, E.New("remote error: ", response.Message) - } - c.needReadResponse = false - } - return c.Stream.Read(p) -} - -func (c *Conn) LocalAddr() net.Addr { - return M.Socksaddr{} -} - -func (c *Conn) RemoteAddr() net.Addr { - return c.destination.TCPAddr() -} - -func (c *Conn) ReaderReplaceable() bool { - return !c.needReadResponse -} - -func (c *Conn) WriterReplaceable() bool { - return true -} - -func (c *Conn) Upstream() any { - return c.Stream -} - -type PacketConn struct { - session quic.Connection - stream quic.Stream - sessionId uint32 - destination M.Socksaddr - msgCh <-chan *UDPMessage - closer io.Closer -} - -func NewPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *PacketConn { - return &PacketConn{ - session: session, - stream: stream, - sessionId: sessionId, - destination: destination, - msgCh: msgCh, - closer: closer, - } -} - -func (c *PacketConn) Hold() { - // Hold the stream until it's closed - buf := make([]byte, 1024) - for { - _, err := c.stream.Read(buf) - if err != nil { - break - } - } - _ = c.Close() -} - -func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - msg := <-c.msgCh - if msg == nil { - err = net.ErrClosed - return - } - err = common.Error(buffer.Write(msg.Data)) - destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port).Unwrap() - return -} - -func (c *PacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - msg := <-c.msgCh - if msg == nil { - err = net.ErrClosed - return - } - buffer = buf.As(msg.Data) - destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port).Unwrap() - return -} - -func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - return WriteUDPMessage(c.session, UDPMessage{ - SessionID: c.sessionId, - Host: destination.AddrString(), - Port: destination.Port, - FragCount: 1, - Data: buffer.Bytes(), - }) -} - -func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - msg := <-c.msgCh - if msg == nil { - err = net.ErrClosed - return - } - n = copy(p, msg.Data) - destination := M.ParseSocksaddrHostPort(msg.Host, msg.Port) - if destination.IsFqdn() { - addr = destination - } else { - addr = destination.UDPAddr() - } - return -} - -func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - err = c.WritePacket(buf.As(p), M.SocksaddrFromNet(addr)) - if err == nil { - n = len(p) - } - return -} - -func (c *PacketConn) LocalAddr() net.Addr { - return M.Socksaddr{} -} - -func (c *PacketConn) RemoteAddr() net.Addr { - return c.destination.UDPAddr() -} - -func (c *PacketConn) SetDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *PacketConn) SetReadDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *PacketConn) SetWriteDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *PacketConn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *PacketConn) Read(b []byte) (n int, err error) { - n, _, err = c.ReadFrom(b) - return -} - -func (c *PacketConn) Write(b []byte) (n int, err error) { - return c.WriteTo(b, c.destination) -} - -func (c *PacketConn) Close() error { - return common.Close(c.stream, c.closer) -} diff --git a/transport/hysteria/speed.go b/transport/hysteria/speed.go deleted file mode 100644 index 161e0d58..00000000 --- a/transport/hysteria/speed.go +++ /dev/null @@ -1,36 +0,0 @@ -package hysteria - -import ( - "regexp" - "strconv" -) - -func StringToBps(s string) uint64 { - if s == "" { - return 0 - } - m := regexp.MustCompile(`^(\d+)\s*([KMGT]?)([Bb])ps$`).FindStringSubmatch(s) - if m == nil { - return 0 - } - var n uint64 - switch m[2] { - case "K": - n = 1 << 10 - case "M": - n = 1 << 20 - case "G": - n = 1 << 30 - case "T": - n = 1 << 40 - default: - n = 1 - } - v, _ := strconv.ParseUint(m[1], 10, 64) - n = v * n - if m[3] == "b" { - // Bits, need to convert to bytes - n = n >> 3 - } - return n -} diff --git a/transport/hysteria/wrap.go b/transport/hysteria/wrap.go deleted file mode 100644 index e89ac95e..00000000 --- a/transport/hysteria/wrap.go +++ /dev/null @@ -1,68 +0,0 @@ -package hysteria - -import ( - "net" - "os" - "syscall" - - "github.com/sagernet/quic-go" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/baderror" -) - -type PacketConnWrapper struct { - net.PacketConn -} - -func (c *PacketConnWrapper) SetReadBuffer(bytes int) error { - return common.MustCast[*net.UDPConn](c.PacketConn).SetReadBuffer(bytes) -} - -func (c *PacketConnWrapper) SetWriteBuffer(bytes int) error { - return common.MustCast[*net.UDPConn](c.PacketConn).SetWriteBuffer(bytes) -} - -func (c *PacketConnWrapper) SyscallConn() (syscall.RawConn, error) { - return common.MustCast[*net.UDPConn](c.PacketConn).SyscallConn() -} - -func (c *PacketConnWrapper) File() (f *os.File, err error) { - return common.MustCast[*net.UDPConn](c.PacketConn).File() -} - -func (c *PacketConnWrapper) Upstream() any { - return c.PacketConn -} - -type StreamWrapper struct { - Conn quic.Connection - quic.Stream -} - -func (s *StreamWrapper) Read(p []byte) (n int, err error) { - n, err = s.Stream.Read(p) - return n, baderror.WrapQUIC(err) -} - -func (s *StreamWrapper) Write(p []byte) (n int, err error) { - n, err = s.Stream.Write(p) - return n, baderror.WrapQUIC(err) -} - -func (s *StreamWrapper) LocalAddr() net.Addr { - return s.Conn.LocalAddr() -} - -func (s *StreamWrapper) RemoteAddr() net.Addr { - return s.Conn.RemoteAddr() -} - -func (s *StreamWrapper) Upstream() any { - return s.Stream -} - -func (s *StreamWrapper) Close() error { - s.CancelRead(0) - s.Stream.Close() - return nil -} diff --git a/transport/hysteria/xplus.go b/transport/hysteria/xplus.go deleted file mode 100644 index 14e0eaa8..00000000 --- a/transport/hysteria/xplus.go +++ /dev/null @@ -1,118 +0,0 @@ -package hysteria - -import ( - "crypto/sha256" - "math/rand" - "net" - "sync" - "time" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" -) - -const xplusSaltLen = 16 - -func NewXPlusPacketConn(conn net.PacketConn, key []byte) net.PacketConn { - vectorisedWriter, isVectorised := bufio.CreateVectorisedPacketWriter(conn) - if isVectorised { - return &VectorisedXPlusConn{ - XPlusPacketConn: XPlusPacketConn{ - PacketConn: conn, - key: key, - rand: rand.New(rand.NewSource(time.Now().UnixNano())), - }, - writer: vectorisedWriter, - } - } else { - return &XPlusPacketConn{ - PacketConn: conn, - key: key, - rand: rand.New(rand.NewSource(time.Now().UnixNano())), - } - } -} - -type XPlusPacketConn struct { - net.PacketConn - key []byte - randAccess sync.Mutex - rand *rand.Rand -} - -func (c *XPlusPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, addr, err = c.PacketConn.ReadFrom(p) - if err != nil { - return - } else if n < xplusSaltLen { - n = 0 - return - } - key := sha256.Sum256(append(c.key, p[:xplusSaltLen]...)) - for i := range p[xplusSaltLen:] { - p[i] = p[xplusSaltLen+i] ^ key[i%sha256.Size] - } - n -= xplusSaltLen - return -} - -func (c *XPlusPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - // can't use unsafe buffer on WriteTo - buffer := buf.NewSize(len(p) + xplusSaltLen) - defer buffer.Release() - salt := buffer.Extend(xplusSaltLen) - c.randAccess.Lock() - _, _ = c.rand.Read(salt) - c.randAccess.Unlock() - key := sha256.Sum256(append(c.key, salt...)) - for i := range p { - common.Must(buffer.WriteByte(p[i] ^ key[i%sha256.Size])) - } - return c.PacketConn.WriteTo(buffer.Bytes(), addr) -} - -func (c *XPlusPacketConn) Upstream() any { - return c.PacketConn -} - -type VectorisedXPlusConn struct { - XPlusPacketConn - writer N.VectorisedPacketWriter -} - -func (c *VectorisedXPlusConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - header := buf.NewSize(xplusSaltLen) - defer header.Release() - salt := header.Extend(xplusSaltLen) - c.randAccess.Lock() - _, _ = c.rand.Read(salt) - c.randAccess.Unlock() - key := sha256.Sum256(append(c.key, salt...)) - for i := range p { - p[i] ^= key[i%sha256.Size] - } - return bufio.WriteVectorisedPacket(c.writer, [][]byte{header.Bytes(), p}, M.SocksaddrFromNet(addr)) -} - -func (c *VectorisedXPlusConn) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { - header := buf.NewSize(xplusSaltLen) - defer header.Release() - salt := header.Extend(xplusSaltLen) - c.randAccess.Lock() - _, _ = c.rand.Read(salt) - c.randAccess.Unlock() - key := sha256.Sum256(append(c.key, salt...)) - var index int - for _, buffer := range buffers { - data := buffer.Bytes() - for i := range data { - data[i] ^= key[index%sha256.Size] - index++ - } - } - buffers = append([]*buf.Buffer{header}, buffers...) - return c.writer.WriteVectorisedPacket(buffers, destination) -} diff --git a/transport/v2rayquic/client.go b/transport/v2rayquic/client.go index c3345780..f5184615 100644 --- a/transport/v2rayquic/client.go +++ b/transport/v2rayquic/client.go @@ -12,7 +12,6 @@ import ( "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/transport/hysteria" "github.com/sagernet/sing-quic" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" @@ -93,7 +92,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { if err != nil { return nil, err } - return &hysteria.StreamWrapper{Conn: conn, Stream: stream}, nil + return &StreamWrapper{Conn: conn, Stream: stream}, nil } func (c *Client) Close() error { diff --git a/transport/v2rayquic/server.go b/transport/v2rayquic/server.go index 71960e58..0ef8d2a1 100644 --- a/transport/v2rayquic/server.go +++ b/transport/v2rayquic/server.go @@ -12,7 +12,6 @@ import ( "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/transport/hysteria" "github.com/sagernet/sing-quic" "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" @@ -86,7 +85,7 @@ func (s *Server) streamAcceptLoop(conn quic.Connection) error { if err != nil { return err } - go s.handler.NewConnection(conn.Context(), &hysteria.StreamWrapper{Conn: conn, Stream: stream}, M.Metadata{}) + go s.handler.NewConnection(conn.Context(), &StreamWrapper{Conn: conn, Stream: stream}, M.Metadata{}) } } diff --git a/transport/v2rayquic/stream.go b/transport/v2rayquic/stream.go new file mode 100644 index 00000000..d9c3beba --- /dev/null +++ b/transport/v2rayquic/stream.go @@ -0,0 +1,41 @@ +package v2rayquic + +import ( + "net" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing/common/baderror" +) + +type StreamWrapper struct { + Conn quic.Connection + quic.Stream +} + +func (s *StreamWrapper) Read(p []byte) (n int, err error) { + n, err = s.Stream.Read(p) + return n, baderror.WrapQUIC(err) +} + +func (s *StreamWrapper) Write(p []byte) (n int, err error) { + n, err = s.Stream.Write(p) + return n, baderror.WrapQUIC(err) +} + +func (s *StreamWrapper) LocalAddr() net.Addr { + return s.Conn.LocalAddr() +} + +func (s *StreamWrapper) RemoteAddr() net.Addr { + return s.Conn.RemoteAddr() +} + +func (s *StreamWrapper) Upstream() any { + return s.Stream +} + +func (s *StreamWrapper) Close() error { + s.CancelRead(0) + s.Stream.Close() + return nil +}