sing-box/transport/tuic/server.go
//go:build with_quic
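
// Package tuic implements the server side of the TUIC proxy protocol over QUIC.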
package tuic

import (
	"bytes"
	"context"
	"encoding/binary"
	"io"
	"net"
	"runtime"
	"strings"
	"sync"
	"time"

	"github.com/sagernet/quic-go"
	"github.com/sagernet/sing-box/common/qtls"
	"github.com/sagernet/sing-box/common/tls"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/auth"
	"github.com/sagernet/sing/common/baderror"
	"github.com/sagernet/sing/common/buf"
	"github.com/sagernet/sing/common/bufio"
	E "github.com/sagernet/sing/common/exceptions"
	"github.com/sagernet/sing/common/logger"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"

	"github.com/gofrs/uuid/v5"
)
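
// ServerOptions is the configuration passed to NewServer.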
type ServerOptions struct {
	Context           context.Context
	Logger            logger.Logger
	TLSConfig         tls.ServerConfig
	Users             []User
	CongestionControl string
	AuthTimeout       time.Duration
	ZeroRTTHandshake  bool
	Heartbeat         time.Duration
	Handler           ServerHandler
}
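
// User is a single TUIC credential: a UUID plus the password used to verify
// the client's authentication token.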
type User struct {
	Name     string
	UUID     uuid.UUID
	Password string
}
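
// ServerHandler receives the TCP and UDP connections accepted by the server.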
type ServerHandler interface {
	N.TCPConnectionHandler
	N.UDPConnectionHandler
}
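
// Server is a TUIC server; it owns the QUIC listener and dispatches accepted
// connections to the handler.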
type Server struct {
	ctx               context.Context
	logger            logger.Logger
	tlsConfig         tls.ServerConfig
	heartbeat         time.Duration
	quicConfig        *quic.Config
	userMap           map[uuid.UUID]User
	congestionControl string
	authTimeout       time.Duration
	handler           ServerHandler
	quicListener      io.Closer
}
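
// NewServer applies default timeouts, validates the congestion control
// algorithm and the user list, and builds a Server.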
func NewServer(options ServerOptions) (*Server, error) {
	if options.AuthTimeout == 0 {
		options.AuthTimeout = 3 * time.Second
	}
	if options.Heartbeat == 0 {
		options.Heartbeat = 10 * time.Second
	}
	quicConfig := &quic.Config{
		DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
		MaxDatagramFrameSize:    1400,
		EnableDatagrams:         true,
		Allow0RTT:               options.ZeroRTTHandshake,
		MaxIncomingStreams:      1 << 60,
		MaxIncomingUniStreams:   1 << 60,
	}
	switch options.CongestionControl {
	case "":
		options.CongestionControl = "cubic"
	case "cubic", "new_reno", "bbr":
	default:
		return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
	}
	if len(options.Users) == 0 {
		return nil, E.New("missing users")
	}
	userMap := make(map[uuid.UUID]User)
	for _, user := range options.Users {
		userMap[user.UUID] = user
	}
	return &Server{
		ctx:               options.Context,
		logger:            options.Logger,
		tlsConfig:         options.TLSConfig,
		heartbeat:         options.Heartbeat,
		quicConfig:        quicConfig,
		userMap:           userMap,
		congestionControl: options.CongestionControl,
		authTimeout:       options.AuthTimeout,
		handler:           options.Handler,
	}, nil
}
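
// Start begins accepting QUIC connections on the given packet connection.
// When 0-RTT is allowed it uses an early listener so that client data sent
// before the handshake completes can be accepted; otherwise a regular
// listener is used.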
func (s *Server) Start(conn net.PacketConn) error {
	if !s.quicConfig.Allow0RTT {
		listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
		if err != nil {
			return err
		}
		s.quicListener = listener
		go func() {
			for {
				connection, hErr := listener.Accept(s.ctx)
				if hErr != nil {
					if strings.Contains(hErr.Error(), "server closed") {
						s.logger.Debug(E.Cause(hErr, "listener closed"))
					} else {
						s.logger.Error(E.Cause(hErr, "listener closed"))
					}
					return
				}
				go s.handleConnection(connection)
			}
		}()
	} else {
		listener, err := qtls.ListenEarly(conn, s.tlsConfig, s.quicConfig)
		if err != nil {
			return err
		}
		s.quicListener = listener
		go func() {
			for {
				connection, hErr := listener.Accept(s.ctx)
				if hErr != nil {
					if strings.Contains(hErr.Error(), "server closed") {
						s.logger.Debug(E.Cause(hErr, "listener closed"))
					} else {
						s.logger.Error(E.Cause(hErr, "listener closed"))
					}
					return
				}
				go s.handleConnection(connection)
			}
		}()
	}
	return nil
}
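
// Close stops the QUIC listener.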
func (s *Server) Close() error {
	return common.Close(
		s.quicListener,
	)
}
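
// handleConnection applies the configured congestion controller to the
// connection and runs a serverSession for it.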
func (s *Server) handleConnection(connection quic.Connection) {
	setCongestion(s.ctx, connection, s.congestionControl)
	session := &serverSession{
		Server:     s,
		ctx:        s.ctx,
		quicConn:   connection,
		source:     M.SocksaddrFromNet(connection.RemoteAddr()),
		connDone:   make(chan struct{}),
		authDone:   make(chan struct{}),
		udpConnMap: make(map[uint16]*udpPacketConn),
	}
	session.handle()
}
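
// serverSession is the per-client connection state: authentication progress,
// the first error that closed the connection, and the active UDP sessions.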
type serverSession struct {
	*Server
	ctx        context.Context
	quicConn   quic.Connection
	source     M.Socksaddr
	connAccess sync.Mutex
	connDone   chan struct{}
	connErr    error
	authDone   chan struct{}
	authUser   *User
	udpAccess  sync.RWMutex
	udpConnMap map[uint16]*udpPacketConn
}
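
// handle launches the session loops: unidirectional streams, bidirectional
// streams, datagram messages, the authentication timeout and heartbeats.
// It also closes the session when the parent context is cancelled.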
func (s *serverSession) handle() {
	if s.ctx.Done() != nil {
		go func() {
			select {
			case <-s.ctx.Done():
				s.closeWithError(s.ctx.Err())
			case <-s.connDone:
			}
		}()
	}
	go s.loopUniStreams()
	go s.loopStreams()
	go s.loopMessages()
	go s.handleAuthTimeout()
	go s.loopHeartbeats()
}

func (s *serverSession) loopUniStreams() {
	for {
		uniStream, err := s.quicConn.AcceptUniStream(s.ctx)
		if err != nil {
			return
		}
		go func() {
			hErr := s.handleUniStream(uniStream)
			if hErr != nil {
				s.closeWithError(E.Cause(hErr, "handle uni stream"))
			}
		}()
	}
}
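
// handleUniStream reads one client-initiated unidirectional stream. The first
// byte is the protocol version, the second the command: authenticate, packet
// or dissociate.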
func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
	defer stream.CancelRead(0)
	buffer := buf.New()
	defer buffer.Release()
	_, err := buffer.ReadAtLeastFrom(stream, 2)
	if err != nil {
		return E.Cause(err, "read request")
	}
	version := buffer.Byte(0)
	if version != Version {
		return E.New("unknown version ", buffer.Byte(0))
	}
	command := buffer.Byte(1)
	switch command {
	case CommandAuthenticate:
		select {
		case <-s.authDone:
			return E.New("authentication: multiple authentication requests")
		default:
		}
		if buffer.Len() < AuthenticateLen {
			_, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len())
			if err != nil {
				return E.Cause(err, "authentication: read request")
			}
		}
		userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16))
		user, loaded := s.userMap[userUUID]
		if !loaded {
			return E.New("authentication: unknown user ", userUUID)
		}
		handshakeState := s.quicConn.ConnectionState()
		tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32)
		if err != nil {
			return E.Cause(err, "authentication: export keying material")
		}
		if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) {
			return E.New("authentication: token mismatch")
		}
		s.authUser = &user
		close(s.authDone)
		return nil
	case CommandPacket:
		select {
		case <-s.connDone:
			return s.connErr
		case <-s.authDone:
		}
		message := allocMessage()
		err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream))
		if err != nil {
			message.release()
			return err
		}
		s.handleUDPMessage(message, true)
		return nil
	case CommandDissociate:
		select {
		case <-s.connDone:
			return s.connErr
		case <-s.authDone:
		}
		if buffer.Len() > 4 {
			return E.New("invalid dissociate message")
		}
		var sessionID uint16
		err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID)
		if err != nil {
			return err
		}
		s.udpAccess.RLock()
		udpConn, loaded := s.udpConnMap[sessionID]
		s.udpAccess.RUnlock()
		if loaded {
			udpConn.closeWithError(E.New("remote closed"))
			s.udpAccess.Lock()
			delete(s.udpConnMap, sessionID)
			s.udpAccess.Unlock()
		}
		return nil
	default:
		return E.New("unknown command ", command)
	}
}
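
// handleAuthTimeout closes the connection if the client does not authenticate
// within the configured timeout.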
func (s *serverSession) handleAuthTimeout() {
	select {
	case <-s.connDone:
	case <-s.authDone:
	case <-time.After(s.authTimeout):
		s.closeWithError(E.New("authentication timeout"))
	}
}

func (s *serverSession) loopStreams() {
	for {
		stream, err := s.quicConn.AcceptStream(s.ctx)
		if err != nil {
			return
		}
		go func() {
			hErr := s.handleStream(stream)
			if hErr != nil {
				stream.CancelRead(0)
				stream.Close()
				s.logger.Error(E.Cause(hErr, "handle stream request"))
			}
		}()
	}
}
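
// handleStream serves one bidirectional stream carrying a CommandConnect
// request: it reads the destination address, waits for authentication and
// hands the connection to the handler.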
func (s *serverSession) handleStream(stream quic.Stream) error {
	buffer := buf.NewSize(2 + M.MaxSocksaddrLength)
	defer buffer.Release()
	_, err := buffer.ReadAtLeastFrom(stream, 2)
	if err != nil {
		return E.Cause(err, "read request")
	}
	version, _ := buffer.ReadByte()
	if version != Version {
		return E.New("unknown version ", version)
	}
	command, _ := buffer.ReadByte()
	if command != CommandConnect {
		return E.New("unsupported stream command ", command)
	}
	destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream))
	if err != nil {
		return E.Cause(err, "read request destination")
	}
	select {
	case <-s.connDone:
		return s.connErr
	case <-s.authDone:
	}
	var conn net.Conn = &serverConn{
		Stream:      stream,
		destination: destination,
	}
	if buffer.IsEmpty() {
		buffer.Release()
	} else {
		conn = bufio.NewCachedConn(conn, buffer)
	}
	ctx := s.ctx
	if s.authUser.Name != "" {
		ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
	}
	_ = s.handler.NewConnection(ctx, conn, M.Metadata{
		Source:      s.source,
		Destination: destination,
	})
	return nil
}
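
// loopHeartbeats sends a heartbeat datagram on every tick until the
// connection is done; a send failure closes the session.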
func (s *serverSession) loopHeartbeats() {
	ticker := time.NewTicker(s.heartbeat)
	defer ticker.Stop()
	for {
		select {
		case <-s.connDone:
			return
		case <-ticker.C:
			err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
			if err != nil {
				s.closeWithError(E.Cause(err, "send heartbeat"))
			}
		}
	}
}
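
// closeWithError records the first connection error, closes connDone exactly
// once, logs the error and closes the QUIC connection.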
func (s *serverSession) closeWithError(err error) {
	s.connAccess.Lock()
	defer s.connAccess.Unlock()
	select {
	case <-s.connDone:
		return
	default:
		s.connErr = err
		close(s.connDone)
	}
	if E.IsClosedOrCanceled(err) {
		s.logger.Debug(E.Cause(err, "connection failed"))
	} else {
		s.logger.Error(E.Cause(err, "connection failed"))
	}
	_ = s.quicConn.CloseWithError(0, "")
}
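
// serverConn wraps a quic.Stream as a net.Conn, translating quic-go errors
// via baderror and reporting the request destination as the local address.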
type serverConn struct {
	quic.Stream
	destination M.Socksaddr
}

func (c *serverConn) Read(p []byte) (n int, err error) {
	n, err = c.Stream.Read(p)
	return n, baderror.WrapQUIC(err)
}

func (c *serverConn) Write(p []byte) (n int, err error) {
	n, err = c.Stream.Write(p)
	return n, baderror.WrapQUIC(err)
}

func (c *serverConn) LocalAddr() net.Addr {
	return c.destination
}

func (c *serverConn) RemoteAddr() net.Addr {
	return M.Socksaddr{}
}

func (c *serverConn) Close() error {
	c.Stream.CancelRead(0)
	return c.Stream.Close()
}