from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
from mnemonic import Mnemonic
from secrets import randbelow
import re

from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.repositories.tokens.exceptions import (
    TokenNotFound,
    InvalidMnemonic,
    RecoveryKeyNotFound,
    NewDeviceKeyNotFound,
)
from selfprivacy_api.models.tokens.recovery_key import RecoveryKey
from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey


class AbstractTokensRepository(ABC):
    def get_token_by_token_string(self, token_string: str) -> Token:
        """Get the token by token"""
        tokens = self.get_tokens()
        for token in tokens:
            if token.token == token_string:
                return token

        raise TokenNotFound("Token not found!")

    def get_token_by_name(self, token_name: str) -> Token:
        """Get the token by name"""
        tokens = self.get_tokens()
        for token in tokens:
            if token.device_name == token_name:
                return token

        raise TokenNotFound("Token not found!")

    @abstractmethod
    def get_tokens(self) -> list[Token]:
        """Get the tokens"""

    def create_token(self, device_name: str) -> Token:
        """Create new token"""
        unique_name = self._make_unique_device_name(device_name)
        new_token = Token.generate(unique_name)

        self._store_token(new_token)

        return new_token

    @abstractmethod
    def delete_token(self, input_token: Token) -> None:
        """Delete the token"""

    def refresh_token(self, input_token: Token) -> Token:
        """Change the token field of the existing token"""
        new_token = Token.generate(device_name=input_token.device_name)
        new_token.created_at = input_token.created_at

        if input_token in self.get_tokens():
            self.delete_token(input_token)
            self._store_token(new_token)
            return new_token

        raise TokenNotFound("Token not found!")

    def is_token_valid(self, token_string: str) -> bool:
        """Check if the token is valid"""
        return token_string in [token.token for token in self.get_tokens()]

    def is_token_name_exists(self, token_name: str) -> bool:
        """Check if the token name exists"""
        return token_name in [token.device_name for token in self.get_tokens()]

    def is_token_name_pair_valid(self, token_name: str, token_string: str) -> bool:
        """Check if the token name and token are valid"""
        try:
            token = self.get_token_by_name(token_name)
            if token is None:
                return False
        except TokenNotFound:
            return False
        return token.token == token_string

    @abstractmethod
    def get_recovery_key(self) -> Optional[RecoveryKey]:
        """Get the recovery key"""

    def create_recovery_key(
        self,
        expiration: Optional[datetime],
        uses_left: Optional[int],
    ) -> RecoveryKey:
        """Create the recovery key"""
        recovery_key = RecoveryKey.generate(expiration, uses_left)
        self._store_recovery_key(recovery_key)
        return recovery_key

    def use_mnemonic_recovery_key(
        self, mnemonic_phrase: str, device_name: str
    ) -> Token:
        """Use the mnemonic recovery key and create a new token with the given name"""
        if not self.is_recovery_key_valid():
            raise RecoveryKeyNotFound("Recovery key not found")

        recovery_key = self.get_recovery_key()

        if recovery_key is None:
            raise RecoveryKeyNotFound("Recovery key not found")

        recovery_hex_key = recovery_key.key
        if not self._assert_mnemonic(recovery_hex_key, mnemonic_phrase):
            raise RecoveryKeyNotFound("Recovery key not found")

        new_token = self.create_token(device_name=device_name)

        self._decrement_recovery_token()

        return new_token

    def is_recovery_key_valid(self) -> bool:
        """Check if the recovery key is valid"""
        recovery_key = self.get_recovery_key()
        if recovery_key is None:
            return False
        return recovery_key.is_valid()

    @abstractmethod
    def _store_recovery_key(self, recovery_key: RecoveryKey) -> None:
        """Store recovery key directly"""

    @abstractmethod
    def _delete_recovery_key(self) -> None:
        """Delete the recovery key"""

    def get_new_device_key(self) -> NewDeviceKey:
        """Creates and returns the new device key"""
        new_device_key = NewDeviceKey.generate()
        self._store_new_device_key(new_device_key)

        return new_device_key

    def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
        """Store new device key directly"""

    @abstractmethod
    def delete_new_device_key(self) -> None:
        """Delete the new device key"""

    def use_mnemonic_new_device_key(
        self, mnemonic_phrase: str, device_name: str
    ) -> Token:
        """Use the mnemonic new device key"""
        new_device_key = self._get_stored_new_device_key()
        if not new_device_key:
            raise NewDeviceKeyNotFound

        if not new_device_key.is_valid():
            raise NewDeviceKeyNotFound

        if not self._assert_mnemonic(new_device_key.key, mnemonic_phrase):
            raise NewDeviceKeyNotFound("Phrase is not token!")

        new_token = self.create_token(device_name=device_name)
        self.delete_new_device_key()

        return new_token

    def reset(self):
        for token in self.get_tokens():
            self.delete_token(token)
        self.delete_new_device_key()
        self._delete_recovery_key()

    def clone(self, source: AbstractTokensRepository) -> None:
        """Clone the state of another repository to this one"""
        self.reset()
        for token in source.get_tokens():
            self._store_token(token)

        recovery_key = source.get_recovery_key()
        if recovery_key is not None:
            self._store_recovery_key(recovery_key)

        new_device_key = source._get_stored_new_device_key()
        if new_device_key is not None:
            self._store_new_device_key(new_device_key)

    @abstractmethod
    def _store_token(self, new_token: Token):
        """Store a token directly"""

    @abstractmethod
    def _decrement_recovery_token(self):
        """Decrement recovery key use count by one"""

    @abstractmethod
    def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]:
        """Retrieves new device key that is already stored."""

    def _make_unique_device_name(self, name: str) -> str:
        """Token name must be an alphanumeric string and not empty.
        Replace invalid characters with '_'
        If name exists, add a random number to the end of the name until it is unique.
        """
        if not re.match("^[a-zA-Z0-9]*$", name):
            name = re.sub("[^a-zA-Z0-9]", "_", name)
        if name == "":
            name = "Unknown device"
        while self.is_token_name_exists(name):
            name += str(randbelow(10))
        return name

    # TODO: find a proper place for it
    def _assert_mnemonic(self, hex_key: str, mnemonic_phrase: str):
        """Return true if hex string matches the phrase, false otherwise
        Raise an InvalidMnemonic error if not mnemonic"""
        recovery_token = bytes.fromhex(hex_key)
        if not Mnemonic(language="english").check(mnemonic_phrase):
            raise InvalidMnemonic("Phrase is not mnemonic!")

        phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase)
        return phrase_bytes == recovery_token