package tuic import ( "bytes" "context" "crypto/tls" "encoding/binary" "io" "net" "runtime" "strings" "sync" "time" "github.com/sagernet/quic-go" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/baderror" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/gofrs/uuid/v5" ) type ServerOptions struct { Context context.Context Logger logger.Logger TLSConfig *tls.Config Users []User CongestionControl string AuthTimeout time.Duration ZeroRTTHandshake bool Heartbeat time.Duration Handler ServerHandler } type User struct { Name string UUID uuid.UUID Password string } type ServerHandler interface { N.TCPConnectionHandler N.UDPConnectionHandler } type Server struct { ctx context.Context logger logger.Logger tlsConfig *tls.Config heartbeat time.Duration quicConfig *quic.Config userMap map[uuid.UUID]User congestionControl string authTimeout time.Duration handler ServerHandler quicListener io.Closer } func NewServer(options ServerOptions) (*Server, error) { if options.AuthTimeout == 0 { options.AuthTimeout = 3 * time.Second } if options.Heartbeat == 0 { options.Heartbeat = 10 * time.Second } quicConfig := &quic.Config{ DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), MaxDatagramFrameSize: 1400, EnableDatagrams: true, Allow0RTT: options.ZeroRTTHandshake, MaxIncomingStreams: 1 << 60, MaxIncomingUniStreams: 1 << 60, } switch options.CongestionControl { case "": options.CongestionControl = "cubic" case "cubic", "new_reno", "bbr": default: return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl) } if len(options.Users) == 0 { return nil, E.New("missing users") } userMap := make(map[uuid.UUID]User) for _, user := range options.Users { userMap[user.UUID] = user } return &Server{ ctx: options.Context, logger: options.Logger, tlsConfig: options.TLSConfig, heartbeat: options.Heartbeat, quicConfig: quicConfig, userMap: userMap, congestionControl: options.CongestionControl, authTimeout: options.AuthTimeout, handler: options.Handler, }, nil } func (s *Server) Start(conn net.PacketConn) error { if !s.quicConfig.Allow0RTT { listener, err := quic.Listen(conn, s.tlsConfig, s.quicConfig) if err != nil { return err } s.quicListener = listener go func() { for { connection, hErr := listener.Accept(s.ctx) if hErr != nil { if strings.Contains(hErr.Error(), "server closed") { s.logger.Debug(E.Cause(hErr, "listener closed")) } else { s.logger.Error(E.Cause(hErr, "listener closed")) } return } go s.handleConnection(connection) } }() } else { listener, err := quic.ListenEarly(conn, s.tlsConfig, s.quicConfig) if err != nil { return err } s.quicListener = listener go func() { for { connection, hErr := listener.Accept(s.ctx) if hErr != nil { if strings.Contains(hErr.Error(), "server closed") { s.logger.Debug(E.Cause(hErr, "listener closed")) } else { s.logger.Error(E.Cause(hErr, "listener closed")) } return } go s.handleConnection(connection) } }() } return nil } func (s *Server) Close() error { return common.Close( s.quicListener, ) } func (s *Server) handleConnection(connection quic.Connection) { setCongestion(s.ctx, connection, s.congestionControl) session := &serverSession{ Server: s, ctx: s.ctx, quicConn: connection, source: M.SocksaddrFromNet(connection.RemoteAddr()), connDone: make(chan struct{}), authDone: make(chan struct{}), udpConnMap: make(map[uint16]*udpPacketConn), } session.handle() } type serverSession struct { *Server ctx context.Context quicConn quic.Connection source M.Socksaddr connAccess sync.Mutex connDone chan struct{} connErr error authDone chan struct{} authUser *User udpAccess sync.RWMutex udpConnMap map[uint16]*udpPacketConn } func (s *serverSession) handle() { if s.ctx.Done() != nil { go func() { select { case <-s.ctx.Done(): s.closeWithError(s.ctx.Err()) case <-s.connDone: } }() } go s.loopUniStreams() go s.loopStreams() go s.loopMessages() go s.handleAuthTimeout() go s.loopHeartbeats() } func (s *serverSession) loopUniStreams() { for { uniStream, err := s.quicConn.AcceptUniStream(s.ctx) if err != nil { return } go func() { err = s.handleUniStream(uniStream) if err != nil { s.closeWithError(E.Cause(err, "handle uni stream")) } }() } } func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error { defer stream.CancelRead(0) buffer := buf.New() defer buffer.Release() _, err := buffer.ReadAtLeastFrom(stream, 2) if err != nil { return E.Cause(err, "read request") } version := buffer.Byte(0) if version != Version { return E.New("unknown version ", buffer.Byte(0)) } command := buffer.Byte(1) switch command { case CommandAuthenticate: select { case <-s.authDone: return E.New("authentication: multiple authentication requests") default: } if buffer.Len() < AuthenticateLen { _, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len()) if err != nil { return E.Cause(err, "authentication: read request") } } userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16)) user, loaded := s.userMap[userUUID] if !loaded { return E.New("authentication: unknown user ", userUUID) } handshakeState := s.quicConn.ConnectionState().TLS tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32) if err != nil { return E.Cause(err, "authentication: export keying material") } if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) { return E.New("authentication: token mismatch") } s.authUser = &user close(s.authDone) return nil case CommandPacket: select { case <-s.connDone: return s.connErr case <-s.authDone: } message := allocMessage() err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream)) if err != nil { message.release() return err } s.handleUDPMessage(message, true) return nil case CommandDissociate: select { case <-s.connDone: return s.connErr case <-s.authDone: } if buffer.Len() > 4 { return E.New("invalid dissociate message") } var sessionID uint16 err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID) if err != nil { return err } s.udpAccess.RLock() udpConn, loaded := s.udpConnMap[sessionID] s.udpAccess.RUnlock() if loaded { udpConn.closeWithError(E.New("remote closed")) s.udpAccess.Lock() delete(s.udpConnMap, sessionID) s.udpAccess.Unlock() } return nil default: return E.New("unknown command ", command) } } func (s *serverSession) handleAuthTimeout() { select { case <-s.connDone: case <-s.authDone: case <-time.After(s.authTimeout): s.closeWithError(E.New("authentication timeout")) } } func (s *serverSession) loopStreams() { for { stream, err := s.quicConn.AcceptStream(s.ctx) if err != nil { return } go func() { err = s.handleStream(stream) if err != nil { stream.CancelRead(0) stream.Close() s.logger.Error(E.Cause(err, "handle stream request")) } }() } } func (s *serverSession) handleStream(stream quic.Stream) error { buffer := buf.NewSize(2 + M.MaxSocksaddrLength) defer buffer.Release() _, err := buffer.ReadAtLeastFrom(stream, 2) if err != nil { return E.Cause(err, "read request") } version, _ := buffer.ReadByte() if version != Version { return E.New("unknown version ", buffer.Byte(0)) } command, _ := buffer.ReadByte() if command != CommandConnect { return E.New("unsupported stream command ", command) } destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream)) if err != nil { return E.Cause(err, "read request destination") } select { case <-s.connDone: return s.connErr case <-s.authDone: } var conn net.Conn = &serverConn{ Stream: stream, destination: destination, } if buffer.IsEmpty() { buffer.Release() } else { conn = bufio.NewCachedConn(conn, buffer) } ctx := s.ctx if s.authUser.Name != "" { ctx = auth.ContextWithUser(s.ctx, s.authUser.Name) } _ = s.handler.NewConnection(ctx, conn, M.Metadata{ Source: s.source, Destination: destination, }) return nil } func (s *serverSession) loopHeartbeats() { ticker := time.NewTicker(s.heartbeat) defer ticker.Stop() for { select { case <-s.connDone: return case <-ticker.C: err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat}) if err != nil { s.closeWithError(E.Cause(err, "send heartbeat")) } } } } func (s *serverSession) closeWithError(err error) { s.connAccess.Lock() defer s.connAccess.Unlock() select { case <-s.connDone: return default: s.connErr = err close(s.connDone) } if E.IsClosedOrCanceled(err) { s.logger.Debug(E.Cause(err, "connection failed")) } else { s.logger.Error(E.Cause(err, "connection failed")) } _ = s.quicConn.CloseWithError(0, "") } type serverConn struct { quic.Stream destination M.Socksaddr } func (c *serverConn) Read(p []byte) (n int, err error) { n, err = c.Stream.Read(p) return n, baderror.WrapQUIC(err) } func (c *serverConn) Write(p []byte) (n int, err error) { n, err = c.Stream.Write(p) return n, baderror.WrapQUIC(err) } func (c *serverConn) LocalAddr() net.Addr { return c.destination } func (c *serverConn) RemoteAddr() net.Addr { return M.Socksaddr{} } func (c *serverConn) Close() error { c.Stream.CancelRead(0) return c.Stream.Close() }