"""
Token repository using Redis as backend.
"""

from typing import Any, Optional
from datetime import datetime, timezone
from hashlib import md5

from selfprivacy_api.utils.redis_pool import RedisPool

from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.models.tokens.recovery_key import RecoveryKey
from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey

from selfprivacy_api.repositories.tokens.exceptions import TokenNotFound
from selfprivacy_api.repositories.tokens.abstract_tokens_repository import (
    AbstractTokensRepository,
)

TOKENS_PREFIX = "token_repo:tokens:"
NEW_DEVICE_KEY_REDIS_KEY = "token_repo:new_device_key"
RECOVERY_KEY_REDIS_KEY = "token_repo:recovery_key"


class RedisTokensRepository(AbstractTokensRepository):
    """
    Token repository using Redis as a backend
    """

    def __init__(self):
        self.connection = RedisPool().get_connection()

    @staticmethod
    def token_key_for_device(device_name: str):
        md5_hash = md5(usedforsecurity=False)
        md5_hash.update(bytes(device_name, "utf-8"))
        digest = md5_hash.hexdigest()
        return TOKENS_PREFIX + digest

    def get_tokens(self) -> list[Token]:
        """Get the tokens"""
        redis = self.connection
        token_keys: list[str] = redis.keys(TOKENS_PREFIX + "*")  # type: ignore
        tokens = []
        for key in token_keys:
            token = self._token_from_hash(key)
            if token is not None:
                tokens.append(token)
        return tokens

    def _discover_token_key(self, input_token: Token) -> Optional[str]:
        """brute-force searching for tokens, for robust deletion"""
        redis = self.connection
        token_keys: list[str] = redis.keys(TOKENS_PREFIX + "*")  # type: ignore
        for key in token_keys:
            token = self._token_from_hash(key)
            if token == input_token:
                return key
        return None

    def delete_token(self, input_token: Token) -> None:
        """Delete the token"""
        redis = self.connection
        key = self._discover_token_key(input_token)
        if key is None:
            raise TokenNotFound
        redis.delete(key)

    def get_recovery_key(self) -> Optional[RecoveryKey]:
        """Get the recovery key"""
        redis = self.connection
        if redis.exists(RECOVERY_KEY_REDIS_KEY):
            return self._recovery_key_from_hash(RECOVERY_KEY_REDIS_KEY)
        return None

    def _store_recovery_key(self, recovery_key: RecoveryKey) -> None:
        self._store_model_as_hash(RECOVERY_KEY_REDIS_KEY, recovery_key)

    def _delete_recovery_key(self) -> None:
        """Delete the recovery key"""
        redis = self.connection
        redis.delete(RECOVERY_KEY_REDIS_KEY)

    def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
        """Store new device key directly"""
        self._store_model_as_hash(NEW_DEVICE_KEY_REDIS_KEY, new_device_key)

    def delete_new_device_key(self) -> None:
        """Delete the new device key"""
        redis = self.connection
        redis.delete(NEW_DEVICE_KEY_REDIS_KEY)

    @staticmethod
    def _token_redis_key(token: Token) -> str:
        return RedisTokensRepository.token_key_for_device(token.device_name)

    def _store_token(self, new_token: Token):
        """Store a token directly"""
        key = RedisTokensRepository._token_redis_key(new_token)
        self._store_model_as_hash(key, new_token)

    def _decrement_recovery_token(self):
        """Decrement recovery key use count by one"""
        if self.is_recovery_key_valid():
            recovery_key = self.get_recovery_key()
            if recovery_key is None:
                return
            uses_left = recovery_key.uses_left
            if uses_left is not None:
                redis = self.connection
                redis.hset(RECOVERY_KEY_REDIS_KEY, "uses_left", uses_left - 1)

    def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]:
        """Retrieves new device key that is already stored."""
        return self._new_device_key_from_hash(NEW_DEVICE_KEY_REDIS_KEY)

    @staticmethod
    def _is_date_key(key: str) -> bool:
        return key in [
            "created_at",
            "expires_at",
        ]

    @staticmethod
    def _prepare_model_dict(model_dict: dict[str, Any]) -> None:
        date_keys = [
            key for key in model_dict.keys() if RedisTokensRepository._is_date_key(key)
        ]
        for date in date_keys:
            if model_dict[date] != "None":
                model_dict[date] = datetime.fromisoformat(model_dict[date])
        for key in model_dict.keys():
            if model_dict[key] == "None":
                model_dict[key] = None

    def _model_dict_from_hash(self, redis_key: str) -> Optional[dict[str, Any]]:
        redis = self.connection
        if redis.exists(redis_key):
            token_dict: dict[str, Any] = redis.hgetall(redis_key)  # type: ignore
            RedisTokensRepository._prepare_model_dict(token_dict)
            return token_dict
        return None

    def _hash_as_model(self, redis_key: str, model_class):
        token_dict = self._model_dict_from_hash(redis_key)
        if token_dict is not None:
            return model_class(**token_dict)
        return None

    def _token_from_hash(self, redis_key: str) -> Optional[Token]:
        token = self._hash_as_model(redis_key, Token)
        if token is not None:
            token.created_at = token.created_at.replace(tzinfo=None)
            return token
        return None

    def _recovery_key_from_hash(self, redis_key: str) -> Optional[RecoveryKey]:
        return self._hash_as_model(redis_key, RecoveryKey)

    def _new_device_key_from_hash(self, redis_key: str) -> Optional[NewDeviceKey]:
        return self._hash_as_model(redis_key, NewDeviceKey)

    def _store_model_as_hash(self, redis_key, model):
        redis = self.connection
        for key, value in model.dict().items():
            if isinstance(value, datetime):
                if value.tzinfo is None:
                    value = value.replace(tzinfo=timezone.utc)
                value = value.isoformat()
            redis.hset(redis_key, key, str(value))