refactor(tokens-repo): switch token backend to redis

And use timezone-aware comparisons for expiry checks
This commit is contained in:
Houkime 2023-01-11 17:02:01 +00:00 committed by Inex Code
parent 9cc6e304c0
commit 158c1f13a6
11 changed files with 136 additions and 70 deletions

View file

@ -1,11 +1,11 @@
"""App tokens actions"""
from datetime import datetime
from datetime import datetime, timezone
from typing import Optional
from pydantic import BaseModel
from mnemonic import Mnemonic
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
JsonTokensRepository,
from selfprivacy_api.repositories.tokens.redis_tokens_repository import (
RedisTokensRepository,
)
from selfprivacy_api.repositories.tokens.exceptions import (
TokenNotFound,
@ -14,7 +14,7 @@ from selfprivacy_api.repositories.tokens.exceptions import (
NewDeviceKeyNotFound,
)
TOKEN_REPO = JsonTokensRepository()
TOKEN_REPO = RedisTokensRepository()
class TokenInfoWithIsCaller(BaseModel):
@ -82,6 +82,14 @@ class RecoveryTokenStatus(BaseModel):
uses_left: Optional[int] = None
def naive(date_time: datetime) -> datetime:
if date_time is None:
return None
if date_time.tzinfo is not None:
date_time.astimezone(timezone.utc)
return date_time.replace(tzinfo=None)
def get_api_recovery_token_status() -> RecoveryTokenStatus:
"""Get the recovery token status"""
token = TOKEN_REPO.get_recovery_key()
@ -91,8 +99,8 @@ def get_api_recovery_token_status() -> RecoveryTokenStatus:
return RecoveryTokenStatus(
exists=True,
valid=is_valid,
date=token.created_at,
expiration=token.expires_at,
date=naive(token.created_at),
expiration=naive(token.expires_at),
uses_left=token.uses_left,
)

View file

@ -1,11 +1,13 @@
"""
New device key used to obtain access token.
"""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
import secrets
from pydantic import BaseModel
from mnemonic import Mnemonic
from selfprivacy_api.models.tokens.time import is_past
class NewDeviceKey(BaseModel):
"""
@ -22,7 +24,7 @@ class NewDeviceKey(BaseModel):
"""
Check if the recovery key is valid.
"""
if self.expires_at < datetime.now():
if is_past(self.expires_at):
return False
return True
@ -37,10 +39,10 @@ class NewDeviceKey(BaseModel):
"""
Factory to generate a random token.
"""
creation_date = datetime.now()
creation_date = datetime.now(timezone.utc)
key = secrets.token_bytes(16).hex()
return NewDeviceKey(
key=key,
created_at=creation_date,
expires_at=datetime.now() + timedelta(minutes=10),
expires_at=creation_date + timedelta(minutes=10),
)

View file

@ -3,12 +3,14 @@ Recovery key used to obtain access token.
Recovery key has a token string, date of creation, optional date of expiration and optional count of uses left.
"""
from datetime import datetime
from datetime import datetime, timezone
import secrets
from typing import Optional
from pydantic import BaseModel
from mnemonic import Mnemonic
from selfprivacy_api.models.tokens.time import is_past, ensure_timezone
class RecoveryKey(BaseModel):
"""
@ -26,7 +28,7 @@ class RecoveryKey(BaseModel):
"""
Check if the recovery key is valid.
"""
if self.expires_at is not None and self.expires_at < datetime.now():
if self.expires_at is not None and is_past(self.expires_at):
return False
if self.uses_left is not None and self.uses_left <= 0:
return False
@ -46,7 +48,9 @@ class RecoveryKey(BaseModel):
"""
Factory to generate a random token.
"""
creation_date = datetime.now()
creation_date = datetime.now(timezone.utc)
if expiration is not None:
expiration = ensure_timezone(expiration)
key = secrets.token_bytes(24).hex()
return RecoveryKey(
key=key,

View file

@ -0,0 +1,13 @@
from datetime import datetime, timezone
def is_past(dt: datetime) -> bool:
# we cannot compare a naive now()
# to dt which might be tz-aware or unaware
dt = ensure_timezone(dt)
return dt < datetime.now(timezone.utc)
def ensure_timezone(dt:datetime) -> datetime:
if dt.tzinfo is None or dt.tzinfo.utcoffset(None) is None:
dt = dt.replace(tzinfo= timezone.utc)
return dt

View file

@ -2,7 +2,7 @@
temporary legacy
"""
from typing import Optional
from datetime import datetime
from datetime import datetime, timezone
from selfprivacy_api.utils import UserDataFiles, WriteUserData, ReadUserData
from selfprivacy_api.models.tokens.token import Token
@ -15,6 +15,7 @@ from selfprivacy_api.repositories.tokens.abstract_tokens_repository import (
AbstractTokensRepository,
)
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"
@ -56,6 +57,20 @@ class JsonTokensRepository(AbstractTokensRepository):
raise TokenNotFound("Token not found!")
def __key_date_from_str(self, date_string: str) -> datetime:
if date_string is None or date_string == "":
return None
# we assume that we store dates in json as naive utc
utc_no_tz = datetime.fromisoformat(date_string)
utc_with_tz = utc_no_tz.replace(tzinfo=timezone.utc)
return utc_with_tz
def __date_from_tokens_file(
self, tokens_file: object, tokenfield: str, datefield: str
):
date_string = tokens_file[tokenfield].get(datefield)
return self.__key_date_from_str(date_string)
def get_recovery_key(self) -> Optional[RecoveryKey]:
"""Get the recovery key"""
with ReadUserData(UserDataFiles.TOKENS) as tokens_file:
@ -68,8 +83,12 @@ class JsonTokensRepository(AbstractTokensRepository):
recovery_key = RecoveryKey(
key=tokens_file["recovery_token"].get("token"),
created_at=tokens_file["recovery_token"].get("date"),
expires_at=tokens_file["recovery_token"].get("expiration"),
created_at=self.__date_from_tokens_file(
tokens_file, "recovery_token", "date"
),
expires_at=self.__date_from_tokens_file(
tokens_file, "recovery_token", "expiration"
),
uses_left=tokens_file["recovery_token"].get("uses_left"),
)

View file

@ -2,7 +2,7 @@
Token repository using Redis as backend.
"""
from typing import Optional
from datetime import datetime
from datetime import datetime, timezone
from selfprivacy_api.repositories.tokens.abstract_tokens_repository import (
AbstractTokensRepository,
@ -38,6 +38,8 @@ class RedisTokensRepository(AbstractTokensRepository):
for key in token_keys:
token = self._token_from_hash(key)
if token is not None:
# token creation dates are temporarily not tz-aware
token.created_at = token.created_at.replace(tzinfo=None)
tokens.append(token)
return tokens
@ -150,5 +152,7 @@ class RedisTokensRepository(AbstractTokensRepository):
redis = self.connection
for key, value in model.dict().items():
if isinstance(value, datetime):
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
value = value.isoformat()
redis.hset(redis_key, key, str(value))

View file

@ -1,16 +1,21 @@
import json
import datetime
from datetime import datetime, timezone, timedelta
from mnemonic import Mnemonic
# for expiration tests. If headache, consider freezegun
RECOVERY_KEY_VALIDATION_DATETIME = "selfprivacy_api.models.tokens.recovery_key.datetime"
DEVICE_KEY_VALIDATION_DATETIME = "selfprivacy_api.models.tokens.new_device_key.datetime"
RECOVERY_KEY_VALIDATION_DATETIME = "selfprivacy_api.models.tokens.time.datetime"
DEVICE_KEY_VALIDATION_DATETIME = RECOVERY_KEY_VALIDATION_DATETIME
FIVE_MINUTES_INTO_FUTURE_NAIVE = datetime.now() + timedelta(minutes=5)
FIVE_MINUTES_INTO_FUTURE = datetime.now(timezone.utc) + timedelta(minutes=5)
FIVE_MINUTES_INTO_PAST_NAIVE = datetime.now() - timedelta(minutes=5)
FIVE_MINUTES_INTO_PAST = datetime.now(timezone.utc) - timedelta(minutes=5)
class NearFuture(datetime.datetime):
class NearFuture(datetime):
@classmethod
def now(cls):
return datetime.datetime.now() + datetime.timedelta(minutes=13)
def now(cls, tz=None):
return datetime.now(tz) + timedelta(minutes=13)
def read_json(file_path):
@ -41,7 +46,6 @@ def mnemonic_to_hex(mnemonic):
def assert_recovery_recent(time_generated):
assert (
datetime.datetime.strptime(time_generated, "%Y-%m-%dT%H:%M:%S.%f")
- datetime.timedelta(seconds=5)
< datetime.datetime.now()
datetime.strptime(time_generated, "%Y-%m-%dT%H:%M:%S.%f") - timedelta(seconds=5)
< datetime.now()
)

View file

@ -11,6 +11,9 @@ from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
JsonTokensRepository,
)
from selfprivacy_api.repositories.tokens.redis_tokens_repository import (
RedisTokensRepository,
)
from tests.common import read_json
@ -63,21 +66,26 @@ def empty_json_repo(empty_tokens):
@pytest.fixture
def tokens_file(empty_json_repo, tmpdir):
def empty_redis_repo():
repo = RedisTokensRepository()
repo.reset()
assert repo.get_tokens() == []
return repo
@pytest.fixture
def tokens_file(empty_redis_repo, tmpdir):
"""A state with tokens"""
repo = empty_redis_repo
for token in TOKENS_FILE_CONTENTS["tokens"]:
empty_json_repo._store_token(
repo._store_token(
Token(
token=token["token"],
device_name=token["name"],
created_at=token["date"],
)
)
# temporary return for compatibility with older tests
tokenfile = tmpdir / "empty_tokens.json"
assert path.exists(tokenfile)
return tokenfile
return repo
@pytest.fixture

View file

@ -1,7 +1,6 @@
# pylint: disable=redefined-outer-name
# pylint: disable=unused-argument
# pylint: disable=missing-function-docstring
import datetime
from tests.common import (
generate_api_query,
@ -9,6 +8,11 @@ from tests.common import (
NearFuture,
RECOVERY_KEY_VALIDATION_DATETIME,
)
# Graphql API's output should be timezone-naive
from tests.common import FIVE_MINUTES_INTO_FUTURE_NAIVE as FIVE_MINUTES_INTO_FUTURE
from tests.common import FIVE_MINUTES_INTO_PAST_NAIVE as FIVE_MINUTES_INTO_PAST
from tests.test_graphql.common import (
assert_empty,
assert_data,
@ -153,7 +157,7 @@ def test_graphql_generate_recovery_key(client, authorized_client, tokens_file):
def test_graphql_generate_recovery_key_with_expiration_date(
client, authorized_client, tokens_file
):
expiration_date = datetime.datetime.now() + datetime.timedelta(minutes=5)
expiration_date = FIVE_MINUTES_INTO_FUTURE
key = graphql_make_new_recovery_key(authorized_client, expires_at=expiration_date)
status = graphql_recovery_status(authorized_client)
@ -171,7 +175,7 @@ def test_graphql_generate_recovery_key_with_expiration_date(
def test_graphql_use_recovery_key_after_expiration(
client, authorized_client, tokens_file, mocker
):
expiration_date = datetime.datetime.now() + datetime.timedelta(minutes=5)
expiration_date = FIVE_MINUTES_INTO_FUTURE
key = graphql_make_new_recovery_key(authorized_client, expires_at=expiration_date)
# Timewarp to after it expires
@ -193,7 +197,7 @@ def test_graphql_use_recovery_key_after_expiration(
def test_graphql_generate_recovery_key_with_expiration_in_the_past(
authorized_client, tokens_file
):
expiration_date = datetime.datetime.now() - datetime.timedelta(minutes=5)
expiration_date = FIVE_MINUTES_INTO_PAST
response = request_make_new_recovery_key(
authorized_client, expires_at=expiration_date
)

View file

@ -2,7 +2,7 @@
# pylint: disable=unused-argument
# pylint: disable=missing-function-docstring
from datetime import datetime, timedelta
from datetime import datetime, timezone
from mnemonic import Mnemonic
import pytest
@ -16,9 +16,8 @@ from selfprivacy_api.repositories.tokens.exceptions import (
TokenNotFound,
NewDeviceKeyNotFound,
)
from selfprivacy_api.repositories.tokens.redis_tokens_repository import (
RedisTokensRepository,
)
from tests.common import FIVE_MINUTES_INTO_PAST
ORIGINAL_DEVICE_NAMES = [
@ -28,6 +27,10 @@ ORIGINAL_DEVICE_NAMES = [
"forth_token",
]
TEST_DATE = datetime(2022, 7, 15, 17, 41, 31, 675698, timezone.utc)
# tokens are not tz-aware
TOKEN_TEST_DATE = datetime(2022, 7, 15, 17, 41, 31, 675698)
def mnemonic_from_hex(hexkey):
return Mnemonic(language="english").to_mnemonic(bytes.fromhex(hexkey))
@ -40,8 +43,8 @@ def mock_new_device_key_generate(mocker):
autospec=True,
return_value=NewDeviceKey(
key="43478d05b35e4781598acd76e33832bb",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
expires_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TEST_DATE,
expires_at=TEST_DATE,
),
)
return mock
@ -55,8 +58,8 @@ def mock_new_device_key_generate_for_mnemonic(mocker):
autospec=True,
return_value=NewDeviceKey(
key="2237238de23dc71ab558e317bdb8ff8e",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
expires_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TEST_DATE,
expires_at=TEST_DATE,
),
)
return mock
@ -83,7 +86,7 @@ def mock_recovery_key_generate_invalid(mocker):
autospec=True,
return_value=RecoveryKey(
key="889bf49c1d3199d71a2e704718772bd53a422020334db051",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TEST_DATE,
expires_at=None,
uses_left=0,
),
@ -99,7 +102,7 @@ def mock_token_generate(mocker):
return_value=Token(
token="ZuLNKtnxDeq6w2dpOJhbB3iat_sJLPTPl_rN5uc5MvM",
device_name="IamNewDevice",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TOKEN_TEST_DATE,
),
)
return mock
@ -112,7 +115,7 @@ def mock_recovery_key_generate(mocker):
autospec=True,
return_value=RecoveryKey(
key="889bf49c1d3199d71a2e704718772bd53a422020334db051",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TEST_DATE,
expires_at=None,
uses_left=1,
),
@ -120,14 +123,6 @@ def mock_recovery_key_generate(mocker):
return mock
@pytest.fixture
def empty_redis_repo():
repo = RedisTokensRepository()
repo.reset()
assert repo.get_tokens() == []
return repo
@pytest.fixture(params=["json", "redis"])
def empty_repo(request, empty_json_repo, empty_redis_repo):
if request.param == "json":
@ -224,13 +219,13 @@ def test_create_token(empty_repo, mock_token_generate):
assert repo.create_token(device_name="IamNewDevice") == Token(
token="ZuLNKtnxDeq6w2dpOJhbB3iat_sJLPTPl_rN5uc5MvM",
device_name="IamNewDevice",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TOKEN_TEST_DATE,
)
assert repo.get_tokens() == [
Token(
token="ZuLNKtnxDeq6w2dpOJhbB3iat_sJLPTPl_rN5uc5MvM",
device_name="IamNewDevice",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TOKEN_TEST_DATE,
)
]
@ -266,7 +261,7 @@ def test_delete_not_found_token(some_tokens_repo):
input_token = Token(
token="imbadtoken",
device_name="primary_token",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TEST_DATE,
)
with pytest.raises(TokenNotFound):
assert repo.delete_token(input_token) is None
@ -295,7 +290,7 @@ def test_refresh_not_found_token(some_tokens_repo, mock_token_generate):
input_token = Token(
token="idontknowwhoiam",
device_name="tellmewhoiam?",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TEST_DATE,
)
with pytest.raises(TokenNotFound):
@ -319,7 +314,7 @@ def test_create_get_recovery_key(some_tokens_repo, mock_recovery_key_generate):
assert repo.create_recovery_key(uses_left=1, expiration=None) is not None
assert repo.get_recovery_key() == RecoveryKey(
key="889bf49c1d3199d71a2e704718772bd53a422020334db051",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TEST_DATE,
expires_at=None,
uses_left=1,
)
@ -358,10 +353,13 @@ def test_use_mnemonic_expired_recovery_key(
some_tokens_repo,
):
repo = some_tokens_repo
expiration = datetime.now() - timedelta(minutes=5)
expiration = FIVE_MINUTES_INTO_PAST
assert repo.create_recovery_key(uses_left=2, expiration=expiration) is not None
recovery_key = repo.get_recovery_key()
assert recovery_key.expires_at == expiration
# TODO: do not ignore timezone once json backend is deleted
assert recovery_key.expires_at.replace(tzinfo=None) == expiration.replace(
tzinfo=None
)
assert not repo.is_recovery_key_valid()
with pytest.raises(RecoveryKeyNotFound):
@ -458,8 +456,8 @@ def test_get_new_device_key(some_tokens_repo, mock_new_device_key_generate):
assert repo.get_new_device_key() == NewDeviceKey(
key="43478d05b35e4781598acd76e33832bb",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
expires_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
created_at=TEST_DATE,
expires_at=TEST_DATE,
)
@ -535,7 +533,7 @@ def test_use_mnemonic_expired_new_device_key(
some_tokens_repo,
):
repo = some_tokens_repo
expiration = datetime.now() - timedelta(minutes=5)
expiration = FIVE_MINUTES_INTO_PAST
key = repo.get_new_device_key()
assert key is not None

View file

@ -11,6 +11,8 @@ from tests.common import (
NearFuture,
assert_recovery_recent,
)
from tests.common import FIVE_MINUTES_INTO_FUTURE_NAIVE as FIVE_MINUTES_INTO_FUTURE
from tests.common import FIVE_MINUTES_INTO_PAST_NAIVE as FIVE_MINUTES_INTO_PAST
DATE_FORMATS = [
"%Y-%m-%dT%H:%M:%S.%fZ",
@ -110,7 +112,7 @@ def rest_recover_with_mnemonic(client, mnemonic_token, device_name):
def test_get_tokens_info(authorized_client, tokens_file):
assert rest_get_tokens_info(authorized_client) == [
assert sorted(rest_get_tokens_info(authorized_client), key=lambda x: x["name"]) == [
{"name": "test_token", "date": "2022-01-14T08:31:10.789314", "is_caller": True},
{
"name": "test_token2",
@ -321,7 +323,7 @@ def test_generate_recovery_token_with_expiration_date(
):
# Generate token with expiration date
# Generate expiration date in the future
expiration_date = datetime.datetime.now() + datetime.timedelta(minutes=5)
expiration_date = FIVE_MINUTES_INTO_FUTURE
mnemonic_token = rest_make_recovery_token(
authorized_client, expires_at=expiration_date, timeformat=timeformat
)
@ -333,7 +335,7 @@ def test_generate_recovery_token_with_expiration_date(
"exists": True,
"valid": True,
"date": time_generated,
"expiration": expiration_date.strftime("%Y-%m-%dT%H:%M:%S.%f"),
"expiration": expiration_date.isoformat(),
"uses_left": None,
}
@ -360,7 +362,7 @@ def test_generate_recovery_token_with_expiration_in_the_past(
authorized_client, tokens_file, timeformat
):
# Server must return 400 if expiration date is in the past
expiration_date = datetime.datetime.utcnow() - datetime.timedelta(minutes=5)
expiration_date = FIVE_MINUTES_INTO_PAST
expiration_date_str = expiration_date.strftime(timeformat)
response = authorized_client.post(
"/auth/recovery_token",