feature(auth): tz_aware recovery

This commit is contained in:
Houkime 2023-11-10 17:10:01 +00:00
parent 8badb9aaaf
commit dd6f37a17d
5 changed files with 36 additions and 14 deletions

View file

@ -7,7 +7,7 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from mnemonic import Mnemonic from mnemonic import Mnemonic
from selfprivacy_api.utils.timeutils import ensure_tz_aware from selfprivacy_api.utils.timeutils import ensure_tz_aware, ensure_tz_aware_strict
from selfprivacy_api.repositories.tokens.redis_tokens_repository import ( from selfprivacy_api.repositories.tokens.redis_tokens_repository import (
RedisTokensRepository, RedisTokensRepository,
) )
@ -95,16 +95,22 @@ class RecoveryTokenStatus(BaseModel):
def get_api_recovery_token_status() -> RecoveryTokenStatus: def get_api_recovery_token_status() -> RecoveryTokenStatus:
"""Get the recovery token status""" """Get the recovery token status, timezone-aware"""
token = TOKEN_REPO.get_recovery_key() token = TOKEN_REPO.get_recovery_key()
if token is None: if token is None:
return RecoveryTokenStatus(exists=False, valid=False) return RecoveryTokenStatus(exists=False, valid=False)
is_valid = TOKEN_REPO.is_recovery_key_valid() is_valid = TOKEN_REPO.is_recovery_key_valid()
# New tokens are tz-aware, but older ones might not be
expiry_date = token.expires_at
if expiry_date is not None:
expiry_date = ensure_tz_aware_strict(expiry_date)
return RecoveryTokenStatus( return RecoveryTokenStatus(
exists=True, exists=True,
valid=is_valid, valid=is_valid,
date=_naive(token.created_at), date=ensure_tz_aware_strict(token.created_at),
expiration=_naive(token.expires_at), expiration=expiry_date,
uses_left=token.uses_left, uses_left=token.uses_left,
) )

View file

@ -38,7 +38,7 @@ class ApiRecoveryKeyStatus:
def get_recovery_key_status() -> ApiRecoveryKeyStatus: def get_recovery_key_status() -> ApiRecoveryKeyStatus:
"""Get recovery key status""" """Get recovery key status, times are timezone-aware"""
status = get_api_recovery_token_status() status = get_api_recovery_token_status()
if status is None or not status.exists: if status is None or not status.exists:
return ApiRecoveryKeyStatus( return ApiRecoveryKeyStatus(

View file

@ -67,8 +67,7 @@ def mnemonic_to_hex(mnemonic):
return Mnemonic(language="english").to_entropy(mnemonic).hex() return Mnemonic(language="english").to_entropy(mnemonic).hex()
def assert_recovery_recent(time_generated): def assert_recovery_recent(time_generated: str):
assert ( assert datetime.fromisoformat(time_generated) - timedelta(seconds=5) < datetime.now(
datetime.strptime(time_generated, "%Y-%m-%dT%H:%M:%S.%f") - timedelta(seconds=5) timezone.utc
< datetime.now()
) )

View file

@ -2,6 +2,10 @@
# pylint: disable=unused-argument # pylint: disable=unused-argument
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
import pytest
from datetime import datetime, timezone
from tests.common import ( from tests.common import (
generate_api_query, generate_api_query,
assert_recovery_recent, assert_recovery_recent,
@ -11,6 +15,7 @@ from tests.common import (
# Graphql API's output should be timezone-naive # Graphql API's output should be timezone-naive
from tests.common import five_minutes_into_future_naive_utc as five_minutes_into_future from tests.common import five_minutes_into_future_naive_utc as five_minutes_into_future
from tests.common import five_minutes_into_future as five_minutes_into_future_tz
from tests.common import five_minutes_into_past_naive_utc as five_minutes_into_past from tests.common import five_minutes_into_past_naive_utc as five_minutes_into_past
from tests.test_graphql.api_common import ( from tests.test_graphql.api_common import (
@ -158,17 +163,24 @@ def test_graphql_generate_recovery_key(client, authorized_client, tokens_file):
graphql_use_recovery_key(client, key, "new_test_token2") graphql_use_recovery_key(client, key, "new_test_token2")
@pytest.mark.parametrize(
"expiration_date", [five_minutes_into_future(), five_minutes_into_future_tz()]
)
def test_graphql_generate_recovery_key_with_expiration_date( def test_graphql_generate_recovery_key_with_expiration_date(
client, authorized_client, tokens_file client, authorized_client, tokens_file, expiration_date: datetime
): ):
expiration_date = five_minutes_into_future()
key = graphql_make_new_recovery_key(authorized_client, expires_at=expiration_date) key = graphql_make_new_recovery_key(authorized_client, expires_at=expiration_date)
status = graphql_recovery_status(authorized_client) status = graphql_recovery_status(authorized_client)
assert status["exists"] is True assert status["exists"] is True
assert status["valid"] is True assert status["valid"] is True
assert_recovery_recent(status["creationDate"]) assert_recovery_recent(status["creationDate"])
assert status["expirationDate"] == expiration_date.isoformat()
# timezone-aware comparison. Should pass regardless of server's tz
assert datetime.fromisoformat(
status["expirationDate"]
) == expiration_date.astimezone(timezone.utc)
assert status["usesLeft"] is None assert status["usesLeft"] is None
graphql_use_recovery_key(client, key, "new_test_token") graphql_use_recovery_key(client, key, "new_test_token")
@ -194,7 +206,11 @@ def test_graphql_use_recovery_key_after_expiration(
assert status["exists"] is True assert status["exists"] is True
assert status["valid"] is False assert status["valid"] is False
assert_recovery_recent(status["creationDate"]) assert_recovery_recent(status["creationDate"])
assert status["expirationDate"] == expiration_date.isoformat()
# timezone-aware comparison. Should pass regardless of server's tz
assert datetime.fromisoformat(
status["expirationDate"]
) == expiration_date.astimezone(timezone.utc)
assert status["usesLeft"] is None assert status["usesLeft"] is None

View file

@ -2,6 +2,7 @@
# pylint: disable=unused-argument # pylint: disable=unused-argument
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
import datetime import datetime
from datetime import timezone
import pytest import pytest
from tests.conftest import TOKENS_FILE_CONTENTS from tests.conftest import TOKENS_FILE_CONTENTS
@ -337,7 +338,7 @@ def test_generate_recovery_token_with_expiration_date(
"exists": True, "exists": True,
"valid": True, "valid": True,
"date": time_generated, "date": time_generated,
"expiration": expiration_date.isoformat(), "expiration": expiration_date.astimezone(timezone.utc).isoformat(),
"uses_left": None, "uses_left": None,
} }