//go:build with_quic

package inbound

import (
	"bytes"
	"context"
	"sync"

	"github.com/sagernet/quic-go"
	"github.com/sagernet/quic-go/congestion"
	"github.com/sagernet/sing-box/adapter"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing-box/transport/hysteria"
	"github.com/sagernet/sing-dns"
	"github.com/sagernet/sing/common"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
)

// Compile-time check that *Hysteria implements adapter.Inbound.
var _ adapter.Inbound = (*Hysteria)(nil)

// Hysteria is the hysteria protocol inbound: a QUIC listener (optionally
// wrapped in XPlus obfuscation) that authenticates clients against a shared
// key and carries proxied TCP streams and UDP sessions.
type Hysteria struct {
	myInboundAdapter

	// QUIC/TLS listener state; listener is set by Start.
	quicConfig *quic.Config
	tlsConfig  *TLSConfig
	listener   quic.Listener

	// Shared auth key expected in the client hello, and the optional XPlus
	// obfuscation key (empty disables obfuscation).
	authKey  []byte
	xplusKey []byte

	// Server-side caps applied to the rates negotiated with each client.
	sendBPS uint64
	recvBPS uint64

	// UDP session registry: udpSessions and udpSessionId are guarded by
	// udpAccess.
	udpAccess    sync.RWMutex
	udpSessionId uint32
	udpSessions  map[uint32]chan *hysteria.UDPMessage

	// Fragment reassembly state for UDP messages; the zero value is usable.
	udpDefragger hysteria.Defragger
}

// NewHysteria builds a Hysteria inbound from options. It validates the
// bandwidth limits and TLS settings and prepares the QUIC configuration; no
// socket is opened until Start is called.
//
// Zero ReceiveWindowConn, ReceiveWindowClient and MaxConnClient fall back to
// the hysteria package defaults. A non-empty Up/Down string takes precedence
// over UpMbps/DownMbps, and the resulting rates must not be below
// hysteria.MinSpeedBPS. TLS must be present and enabled, otherwise
// C.ErrTLSRequired is returned. When no ALPN is configured,
// hysteria.DefaultALPN is used.
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
	quicConfig := &quic.Config{
		InitialStreamReceiveWindow:     options.ReceiveWindowConn,
		MaxStreamReceiveWindow:         options.ReceiveWindowConn,
		InitialConnectionReceiveWindow: options.ReceiveWindowClient,
		MaxConnectionReceiveWindow:     options.ReceiveWindowClient,
		MaxIncomingStreams:             int64(options.MaxConnClient),
		KeepAlivePeriod:                hysteria.KeepAlivePeriod,
		DisablePathMTUDiscovery:        options.DisableMTUDiscovery || !(C.IsLinux || C.IsWindows),
		EnableDatagrams:                true,
	}
	if options.ReceiveWindowConn == 0 {
		quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
		quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
	}
	if options.ReceiveWindowClient == 0 {
		quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
		quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
	}
	if quicConfig.MaxIncomingStreams == 0 {
		quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams
	}
	// Raw Auth bytes win over AuthString.
	var auth []byte
	if len(options.Auth) > 0 {
		auth = options.Auth
	} else {
		auth = []byte(options.AuthString)
	}
	var xplus []byte
	if options.Obfs != "" {
		xplus = []byte(options.Obfs)
	}
	var up, down uint64
	if len(options.Up) > 0 {
		up = hysteria.StringToBps(options.Up)
		if up == 0 {
			return nil, E.New("invalid up speed format: ", options.Up)
		}
	} else {
		up = uint64(options.UpMbps) * hysteria.MbpsToBps
	}
	if len(options.Down) > 0 {
		down = hysteria.StringToBps(options.Down)
		if down == 0 {
			return nil, E.New("invalid down speed format: ", options.Down)
		}
	} else {
		down = uint64(options.DownMbps) * hysteria.MbpsToBps
	}
	if up < hysteria.MinSpeedBPS {
		return nil, E.New("invalid up speed")
	}
	if down < hysteria.MinSpeedBPS {
		return nil, E.New("invalid down speed")
	}
	inbound := &Hysteria{
		myInboundAdapter: myInboundAdapter{
			protocol:      C.TypeHysteria,
			network:       []string{N.NetworkUDP},
			ctx:           ctx,
			router:        router,
			logger:        logger,
			tag:           tag,
			listenOptions: options.ListenOptions,
		},
		quicConfig:  quicConfig,
		authKey:     auth,
		xplusKey:    xplus,
		sendBPS:     up,
		recvBPS:     down,
		udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
	}
	if options.TLS == nil || !options.TLS.Enabled {
		return nil, C.ErrTLSRequired
	}
	// Default ALPN on a copy: options.TLS points into the caller's
	// configuration, and assigning through it would rewrite that config.
	tlsOptions := common.PtrValueOrDefault(options.TLS)
	if len(tlsOptions.ALPN) == 0 {
		tlsOptions.ALPN = []string{hysteria.DefaultALPN}
	}
	tlsConfig, err := NewTLSConfig(ctx, logger, tlsOptions)
	if err != nil {
		return nil, err
	}
	inbound.tlsConfig = tlsConfig
	return inbound, nil
}

// Start opens the UDP socket and wraps it in XPlus obfuscation when an obfs
// key is configured. It then starts the TLS config, creates the QUIC
// listener, and accepts connections on a background goroutine.
func (h *Hysteria) Start() error {
	udpConn, err := h.myInboundAdapter.ListenUDP()
	if err != nil {
		return err
	}
	if len(h.xplusKey) != 0 {
		udpConn = hysteria.NewXPlusPacketConn(udpConn, h.xplusKey)
		udpConn = &hysteria.PacketConnWrapper{PacketConn: udpConn}
	}
	if err = h.tlsConfig.Start(); err != nil {
		return err
	}
	// Keep the result in a local until it is known to be good, so h.listener
	// is only ever set to a working listener.
	ln, err := quic.Listen(udpConn, h.tlsConfig.Config(), h.quicConfig)
	if err != nil {
		return err
	}
	h.listener = ln
	h.logger.Info("udp server started at ", ln.Addr())
	go h.acceptLoop()
	return nil
}

// acceptLoop accepts QUIC connections until the listener's Accept fails, for
// example after Close. Each connection gets a context with a fresh log ID and
// is served on its own goroutine. A connection whose handler fails is closed
// and the failure is reported through NewError.
func (h *Hysteria) acceptLoop() {
	for {
		connCtx := log.ContextWithNewID(h.ctx)
		quicConn, err := h.listener.Accept(connCtx)
		if err != nil {
			return
		}
		h.logger.InfoContext(connCtx, "inbound connection from ", quicConn.RemoteAddr())
		go func(ctx context.Context, c quic.Connection) {
			if hErr := h.accept(ctx, c); hErr != nil {
				c.CloseWithError(0, "")
				NewError(h.logger, ctx, E.Cause(hErr, "process connection from ", c.RemoteAddr()))
			}
		}(connCtx, quicConn)
	}
}

// accept runs the hysteria handshake on a newly accepted QUIC connection and
// then serves the connection's streams.
//
// The first stream is the control stream carrying the client hello. The
// client must present the configured auth key and non-zero send/receive
// rates. The server replies with the negotiated rates: the server's send
// rate is the client's receive rate and vice versa, each capped by the
// server's own limit when that limit is non-zero. The connection then
// switches to the Brutal sender at the negotiated send rate and starts the
// datagram receive loop. Every further stream is handed to acceptStream on
// its own goroutine.
//
// The returned error is whatever ended the handshake or the stream-accept
// loop; the caller closes the connection.
func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
	controlStream, err := conn.AcceptStream(ctx)
	if err != nil {
		return err
	}
	clientHello, err := hysteria.ReadClientHello(controlStream)
	if err != nil {
		return err
	}
	// NOTE(review): bytes.Equal is not constant-time; consider
	// crypto/subtle.ConstantTimeCompare for the auth check.
	if !bytes.Equal(clientHello.Auth, h.authKey) {
		err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
			Message: "wrong password",
		})
		// Do not include the submitted credential: this error ends up in the
		// log, and a near-miss of the real password would be leaked there.
		return E.Errors(E.New("wrong password"), err)
	}
	if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 {
		return E.New("invalid rate from client")
	}
	serverSendBPS, serverRecvBPS := clientHello.RecvBPS, clientHello.SendBPS
	if h.sendBPS > 0 && serverSendBPS > h.sendBPS {
		serverSendBPS = h.sendBPS
	}
	if h.recvBPS > 0 && serverRecvBPS > h.recvBPS {
		serverRecvBPS = h.recvBPS
	}
	err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
		OK:      true,
		SendBPS: serverSendBPS,
		RecvBPS: serverRecvBPS,
	})
	if err != nil {
		return err
	}
	conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
	go h.udpRecvLoop(conn)
	for {
		var stream quic.Stream
		stream, err = conn.AcceptStream(ctx)
		if err != nil {
			return err
		}
		go func() {
			hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
			if hErr != nil {
				stream.Close()
				NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
			}
		}()
	}
}

// udpRecvLoop reads QUIC datagrams from conn until ReceiveMessage fails. It
// reassembles fragmented UDP messages and forwards each complete message to
// the session channel registered under its SessionID. Unparseable datagrams
// are logged and skipped. Messages for unknown sessions, or for sessions
// whose channel is full, are dropped.
//
// NOTE(review): session IDs live in one inbound-wide map with no record of
// the owning connection, so a datagram arriving on one connection can
// address a session opened on another; consider scoping sessions per
// connection.
func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
	// One defragger per connection. This loop runs on its own goroutine for
	// every connection, and a single shared Defragger would be mutated
	// concurrently (a data race) and would mix fragment state across clients.
	var defragger hysteria.Defragger
	for {
		packet, err := conn.ReceiveMessage()
		if err != nil {
			return
		}
		message, err := hysteria.ParseUDPMessage(packet)
		if err != nil {
			h.logger.Error("parse udp message: ", err)
			continue
		}
		dfMsg := defragger.Feed(message)
		if dfMsg == nil {
			// Still waiting for more fragments.
			continue
		}
		// The send happens under the read lock so it cannot race with a
		// session's channel being closed (closing takes the write lock).
		h.udpAccess.RLock()
		ch, ok := h.udpSessions[dfMsg.SessionID]
		if ok {
			select {
			case ch <- dfMsg:
				// OK
			default:
				// Silently drop the message when the channel is full
			}
		}
		h.udpAccess.RUnlock()
	}
}

// acceptStream serves one non-control stream of an authenticated connection.
// It reads the client request, then either routes the stream as a TCP
// connection or registers a UDP session and routes it as a packet
// connection.
//
// For UDP requests the session ID is allocated before the response is
// written, and the response carries it. The client uses that ID to tag its
// datagrams, which udpRecvLoop matches against the session registry.
func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
	request, err := hysteria.ReadClientRequest(stream)
	if err != nil {
		return err
	}
	var metadata adapter.InboundContext
	metadata.Inbound = h.tag
	metadata.InboundType = C.TypeHysteria
	metadata.SniffEnabled = h.listenOptions.SniffEnabled
	metadata.SniffOverrideDestination = h.listenOptions.SniffOverrideDestination
	metadata.DomainStrategy = dns.DomainStrategy(h.listenOptions.DomainStrategy)
	metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr())
	metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr())
	metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port)

	if !request.UDP {
		err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
			OK: true,
		})
		if err != nil {
			return err
		}
		h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
		return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination), metadata)
	}

	nCh := make(chan *hysteria.UDPMessage, 1024)
	h.udpAccess.Lock()
	id := h.udpSessionId
	// Skip IDs that are still registered so a wrapped-around counter can
	// never overwrite (and orphan) a live session.
	for {
		if _, inUse := h.udpSessions[id]; !inUse {
			break
		}
		id++
	}
	h.udpSessions[id] = nCh
	h.udpSessionId = id + 1
	h.udpAccess.Unlock()

	closeSession := func() error {
		h.udpAccess.Lock()
		// Only close our own channel: after Close resets the registry, this
		// ID may later be handed to a different session.
		if ch, ok := h.udpSessions[id]; ok && ch == nCh {
			close(ch)
			delete(h.udpSessions, id)
		}
		h.udpAccess.Unlock()
		return nil
	}

	err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
		OK:           true,
		UDPSessionID: id,
	})
	if err != nil {
		closeSession()
		return err
	}
	h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
	packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(closeSession))
	go packetConn.Hold()
	return h.router.RoutePacketConnection(ctx, packetConn, metadata)
}

// Close shuts the inbound down. It first closes every registered UDP session
// channel and replaces the registry with an empty map, all under the write
// lock. It then closes the embedded adapter, the QUIC listener and the TLS
// config, returning the result of common.Close.
func (h *Hysteria) Close() error {
	h.udpAccess.Lock()
	for id := range h.udpSessions {
		close(h.udpSessions[id])
	}
	h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
	h.udpAccess.Unlock()

	tlsCloser := common.PtrOrNil(h.tlsConfig)
	return common.Close(&h.myInboundAdapter, h.listener, tlsCloser)
}