package sniff

import (
	"bytes"
	"context"
	"crypto"
	"crypto/aes"
	"crypto/tls"
	"encoding/binary"
	"io"
	"os"

	"github.com/sagernet/sing-box/adapter"
	"github.com/sagernet/sing-box/common/ja3"
	"github.com/sagernet/sing-box/common/sniff/internal/qtls"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"

	"golang.org/x/crypto/hkdf"
)

var (
	// ErrClientHelloFragmented is returned by QUICClientHello when the CRYPTO
	// data gathered so far does not yet parse as a ClientHello. The collected
	// fragments are saved in metadata.SniffContext so that a later call with the
	// same metadata can merge them with the next packet's fragments.
	ErrClientHelloFragmented = E.New("need more packet for chromium QUIC connection")
)

// QUICClientHello sniffs a client QUIC Initial packet (draft-29, v1 or v2).
//
// It decrypts the packet with the Initial keys derived from the destination
// connection ID and collects its CRYPTO frames. Fragments saved in
// metadata.SniffContext by an earlier call are merged in. The CRYPTO stream is
// then reassembled contiguously from offset 0, wrapped in a TLS handshake
// record header and passed to ja3.Compute.
//
// On success, metadata.Protocol and metadata.Domain (the SNI) are set. If
// metadata.Client is still empty, a client implementation is guessed from the
// frame layout and the fingerprint.
//
// If the reassembled data cannot be parsed yet, metadata.Client is set to
// C.ClientChromium and the fragments are stored in metadata.SniffContext.
// ErrClientHelloFragmented is then returned so the caller can supply the next
// packet. Any other error reports a malformed or non-Initial packet.
func QUICClientHello(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error {
	decrypted, err := decryptQUICInitialPacket(packet)
	if err != nil {
		return err
	}

	const (
		frameTypePadding         = 0x00
		frameTypePing            = 0x01
		frameTypeAck             = 0x02
		frameTypeAck2            = 0x03 // ACK carrying ECN counts
		frameTypeCrypto          = 0x06
		frameTypeConnectionClose = 0x1c
	)
	var (
		fragments     []qCryptoFragment
		frameTypeList []uint8 // type of every frame in packet order, used for client detection
	)
	frames := bytes.NewReader(decrypted)
	// skipVarints discards n consecutive QUIC variable-length integers.
	skipVarints := func(n uint64) error {
		for ; n > 0; n-- {
			if _, skipErr := qtls.ReadUvarint(frames); skipErr != nil {
				return skipErr
			}
		}
		return nil
	}
	for frames.Len() > 0 {
		frameType, _ := frames.ReadByte() // cannot fail: at least one byte remains
		frameTypeList = append(frameTypeList, frameType)
		switch frameType {
		case frameTypePadding, frameTypePing:
			// Single-byte frames without a body.
		case frameTypeAck, frameTypeAck2:
			// Largest Acknowledged, ACK Delay.
			if err = skipVarints(2); err != nil {
				return err
			}
			var ackRangeCount uint64
			if ackRangeCount, err = qtls.ReadUvarint(frames); err != nil {
				return err
			}
			// First ACK Range.
			if err = skipVarints(1); err != nil {
				return err
			}
			// One (Gap, ACK Range Length) pair per additional range.
			for i := uint64(0); i < ackRangeCount; i++ {
				if err = skipVarints(2); err != nil {
					return err
				}
			}
			if frameType == frameTypeAck2 {
				// ECT0, ECT1 and ECN-CE counts.
				if err = skipVarints(3); err != nil {
					return err
				}
			}
		case frameTypeCrypto:
			var offset, length uint64
			if offset, err = qtls.ReadUvarint(frames); err != nil {
				return err
			}
			if length, err = qtls.ReadUvarint(frames); err != nil {
				return err
			}
			remaining := frames.Len()
			// Reject lengths beyond the decrypted payload instead of panicking on the slice below.
			if length > uint64(remaining) {
				return E.New("crypto frame exceeds packet payload")
			}
			start := len(decrypted) - remaining
			fragments = append(fragments, qCryptoFragment{offset, length, decrypted[start : start+int(length)]})
			if _, err = frames.Seek(int64(length), io.SeekCurrent); err != nil {
				return err
			}
		case frameTypeConnectionClose:
			// Error Code, Frame Type.
			if err = skipVarints(2); err != nil {
				return err
			}
			var reasonLength uint64
			if reasonLength, err = qtls.ReadUvarint(frames); err != nil {
				return err
			}
			// Reason Phrase.
			if _, err = frames.Seek(int64(reasonLength), io.SeekCurrent); err != nil {
				return err
			}
		default:
			return os.ErrInvalid
		}
	}

	// Merge fragments saved by a previous call; leave foreign context values untouched.
	if previous, ok := metadata.SniffContext.([]qCryptoFragment); ok {
		fragments = append(fragments, previous...)
		metadata.SniffContext = nil
	}

	var frameLen uint64
	for _, fragment := range fragments {
		frameLen += fragment.length
	}
	buffer := buf.NewSize(5 + int(frameLen))
	defer buffer.Release()
	// TLS record header: handshake content type (0x16), version 0x0303, 16-bit length.
	buffer.Write([]byte{0x16, 0x03, 0x03, byte(frameLen >> 8), byte(frameLen)})
	// Append fragments that continue the stream from offset 0 without gaps.
	// Empty fragments are skipped: they never advance the offset and would
	// otherwise be matched forever.
	var streamOffset uint64
	for {
		next := -1
		for i, fragment := range fragments {
			if fragment.length > 0 && fragment.offset == streamOffset {
				next = i
				break
			}
		}
		if next < 0 {
			break
		}
		buffer.Write(fragments[next].payload)
		streamOffset += fragments[next].length
	}

	metadata.Protocol = C.ProtocolQUIC
	fingerprint, err := ja3.Compute(buffer.Bytes())
	if err != nil {
		// Not a complete ClientHello yet: keep the fragments for the next packet.
		metadata.Client = C.ClientChromium
		metadata.SniffContext = fragments
		return ErrClientHelloFragmented
	}
	metadata.Domain = fingerprint.ServerName
	if metadata.Client != "" {
		return nil
	}

	last := len(frameTypeList) - 1
	// Safari variant 1: GREASE first supported version and first curve.
	safariGrease := len(fingerprint.Versions) == 2 && fingerprint.Versions[0]&ja3.GreaseBitmask == 0x0A0A &&
		len(fingerprint.EllipticCurves) == 5 && fingerprint.EllipticCurves[0]&ja3.GreaseBitmask == 0x0A0A
	// Safari variant 2: single suite, single curve, single signature algorithm.
	safariMinimal := len(fingerprint.CipherSuites) == 1 && fingerprint.CipherSuites[0] == tls.TLS_AES_256_GCM_SHA384 &&
		len(fingerprint.EllipticCurves) == 1 && fingerprint.EllipticCurves[0] == uint16(tls.X25519) &&
		len(fingerprint.SignatureAlgorithms) == 1 && fingerprint.SignatureAlgorithms[0] == uint16(tls.ECDSAWithP256AndSHA256)
	switch {
	case last < 0:
		// This packet carried no frames; the ClientHello came from earlier fragments.
		metadata.Client = C.ClientUnknown
	case last == 0:
		metadata.Client = C.ClientFirefox
	case frameTypeList[0] == frameTypeCrypto && isZero(frameTypeList[1:]) && (safariGrease || safariMinimal):
		metadata.Client = C.ClientSafari
	case frameTypeList[last] == frameTypeCrypto && isZero(frameTypeList[:last]):
		metadata.Client = C.ClientQUICGo
	case count(frameTypeList, frameTypeCrypto) > 1 || count(frameTypeList, frameTypePing) > 0:
		if maybeUQUIC(fingerprint) {
			metadata.Client = C.ClientQUICGo
		} else {
			metadata.Client = C.ClientChromium
		}
	default:
		metadata.Client = C.ClientUnknown
	}
	return nil
}

// decryptQUICInitialPacket validates the long header of a client QUIC Initial
// packet (draft-29, v1 or v2) and removes header protection. It then decrypts
// the payload with the Initial secrets derived from the destination connection
// ID and returns the plaintext frames. Bytes after the packet delimited by the
// Length field are ignored. The caller's packet slice is not modified.
func decryptQUICInitialPacket(packet []byte) ([]byte, error) {
	reader := bytes.NewReader(packet)
	typeByte, err := reader.ReadByte()
	if err != nil {
		return nil, err
	}
	// The fixed bit (0x40) must be set.
	if typeByte&0x40 == 0 {
		return nil, E.New("bad type byte")
	}
	var versionBytes [4]byte
	if _, err = io.ReadFull(reader, versionBytes[:]); err != nil {
		return nil, err
	}
	versionNumber := binary.BigEndian.Uint32(versionBytes[:])
	if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 {
		return nil, E.New("bad version")
	}
	// Only Initial packets are protected with the keys derived below. Their long
	// packet type is 0b00 in draft-29/v1 and 0b01 in v2.
	packetType := (typeByte & 0x30) >> 4
	initialType := byte(0)
	if versionNumber == qtls.Version2 {
		initialType = 1
	}
	if packetType != initialType {
		return nil, E.New("bad packet type")
	}

	destConnIDLen, err := reader.ReadByte()
	if err != nil {
		return nil, err
	}
	if destConnIDLen == 0 || destConnIDLen > 20 {
		return nil, E.New("bad destination connection id length")
	}
	destConnID := make([]byte, destConnIDLen)
	if _, err = io.ReadFull(reader, destConnID); err != nil {
		return nil, err
	}
	srcConnIDLen, err := reader.ReadByte()
	if err != nil {
		return nil, err
	}
	if _, err = io.CopyN(io.Discard, reader, int64(srcConnIDLen)); err != nil {
		return nil, err
	}
	tokenLen, err := qtls.ReadUvarint(reader)
	if err != nil {
		return nil, err
	}
	if _, err = io.CopyN(io.Discard, reader, int64(tokenLen)); err != nil {
		return nil, err
	}
	packetLen, err := qtls.ReadUvarint(reader)
	if err != nil {
		return nil, err
	}
	hdrLen := int(reader.Size()) - reader.Len()
	// Compare in uint64 so an oversized Length cannot overflow int.
	if packetLen > uint64(len(packet)-hdrLen) {
		return nil, os.ErrInvalid
	}

	// Header-protection sample: 16 bytes starting 4 bytes after the packet number offset.
	if _, err = io.CopyN(io.Discard, reader, 4); err != nil {
		return nil, err
	}
	sample := make([]byte, aes.BlockSize)
	if _, err = io.ReadFull(reader, sample); err != nil {
		return nil, err
	}

	var (
		salt     []byte
		hpLabel  string
		keyLabel string
		ivLabel  string
	)
	switch versionNumber {
	case qtls.Version1:
		salt = qtls.SaltV1
		hpLabel, keyLabel, ivLabel = qtls.HKDFLabelHeaderProtectionV1, qtls.HKDFLabelKeyV1, qtls.HKDFLabelIVV1
	case qtls.Version2:
		salt = qtls.SaltV2
		hpLabel, keyLabel, ivLabel = qtls.HKDFLabelHeaderProtectionV2, qtls.HKDFLabelKeyV2, qtls.HKDFLabelIVV2
	default: // draft-29, the only other version accepted above
		salt = qtls.SaltOld
		hpLabel, keyLabel, ivLabel = qtls.HKDFLabelHeaderProtectionV1, qtls.HKDFLabelKeyV1, qtls.HKDFLabelIVV1
	}
	initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt)
	secret := qtls.HKDFExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
	hpKey := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, hpLabel, 16)
	block, err := aes.NewCipher(hpKey)
	if err != nil {
		return nil, err
	}
	mask := make([]byte, aes.BlockSize)
	block.Encrypt(mask, sample)

	// Work on a copy so the caller's packet is left untouched.
	unprotected := make([]byte, len(packet))
	copy(unprotected, packet)
	unprotected[0] ^= mask[0] & 0x0f
	// The packet number length is only known after unmasking the first byte, so
	// unmask the maximum of 4 bytes and restore the excess afterwards.
	for i := 0; i < 4; i++ {
		unprotected[hdrLen+i] ^= mask[1+i]
	}
	pnLen := int(unprotected[0]&0x03) + 1
	if uint64(pnLen) > packetLen {
		return nil, os.ErrInvalid
	}
	var packetNumber uint32
	for _, b := range unprotected[hdrLen : hdrLen+pnLen] {
		packetNumber = packetNumber<<8 | uint32(b)
	}
	payloadOffset := hdrLen + pnLen
	copy(unprotected[payloadOffset:hdrLen+4], packet[payloadOffset:])
	ciphertext := unprotected[payloadOffset : hdrLen+int(packetLen)]

	key := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
	iv := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
	aead := qtls.AEADAESGCMTLS13(key, iv)
	// The packet number fills the low 8 bytes of the nonce. Presumably
	// AEADAESGCMTLS13 XORs it with iv as in TLS 1.3 — confirm in qtls.
	nonce := make([]byte, aead.NonceSize())
	binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
	// Decrypt in place; the unprotected header is the associated data.
	plaintext, err := aead.Open(unprotected[payloadOffset:payloadOffset], nonce, ciphertext, unprotected[:payloadOffset])
	if err != nil {
		return nil, err
	}
	return plaintext, nil
}

// isZero reports whether every element of frameTypes is 0x00 (in this file,
// the PADDING frame type). An empty or nil slice counts as all-zero.
func isZero(frameTypes []uint8) bool {
	return bytes.Count(frameTypes, []byte{0}) == len(frameTypes)
}

// count returns how many elements of frameTypes are equal to value.
func count(frameTypes []uint8, value uint8) int {
	return bytes.Count(frameTypes, []byte{value})
}

// qCryptoFragment is one CRYPTO frame taken from a decrypted QUIC Initial packet.
type qCryptoFragment struct {
	offset, length uint64 // position and size of the fragment within the CRYPTO stream
	payload        []byte // frame data; a sub-slice of the decrypted packet buffer
}