Merge pull request 'Migrate to AbstractTokenRepository API' (#28) from redis/token-repo into redis/connection-pool

Reviewed-on: https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api/pulls/28
This commit is contained in:
Inex Code 2022-12-30 20:06:43 +02:00
commit 45c6133881
16 changed files with 394 additions and 559 deletions

View file

@ -5,12 +5,16 @@ name: default
steps: steps:
- name: Run Tests and Generate Coverage Report - name: Run Tests and Generate Coverage Report
commands: commands:
- kill $(ps aux | grep '[r]edis-server 127.0.0.1:6389' | awk '{print $2}')
- redis-server --bind 127.0.0.1 --port 6389 >/dev/null &
- coverage run -m pytest -q - coverage run -m pytest -q
- coverage xml - coverage xml
- sonar-scanner -Dsonar.projectKey=SelfPrivacy-REST-API -Dsonar.sources=. -Dsonar.host.url=http://analyzer.lan:9000 -Dsonar.login="$SONARQUBE_TOKEN" - sonar-scanner -Dsonar.projectKey=SelfPrivacy-REST-API -Dsonar.sources=. -Dsonar.host.url=http://analyzer.lan:9000 -Dsonar.login="$SONARQUBE_TOKEN"
environment: environment:
SONARQUBE_TOKEN: SONARQUBE_TOKEN:
from_secret: SONARQUBE_TOKEN from_secret: SONARQUBE_TOKEN
USE_REDIS_PORT: 6389
- name: Run Bandit Checks - name: Run Bandit Checks
commands: commands:

View file

@ -2,20 +2,19 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from mnemonic import Mnemonic
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
from selfprivacy_api.utils.auth import ( JsonTokensRepository,
delete_token,
generate_recovery_token,
get_recovery_token_status,
get_tokens_info,
is_recovery_token_exists,
is_recovery_token_valid,
is_token_name_exists,
is_token_name_pair_valid,
refresh_token,
get_token_name,
) )
from selfprivacy_api.repositories.tokens.exceptions import (
TokenNotFound,
RecoveryKeyNotFound,
InvalidMnemonic,
NewDeviceKeyNotFound,
)
TOKEN_REPO = JsonTokensRepository()
class TokenInfoWithIsCaller(BaseModel): class TokenInfoWithIsCaller(BaseModel):
@ -28,18 +27,23 @@ class TokenInfoWithIsCaller(BaseModel):
def get_api_tokens_with_caller_flag(caller_token: str) -> list[TokenInfoWithIsCaller]: def get_api_tokens_with_caller_flag(caller_token: str) -> list[TokenInfoWithIsCaller]:
"""Get the tokens info""" """Get the tokens info"""
caller_name = get_token_name(caller_token) caller_name = TOKEN_REPO.get_token_by_token_string(caller_token).device_name
tokens = get_tokens_info() tokens = TOKEN_REPO.get_tokens()
return [ return [
TokenInfoWithIsCaller( TokenInfoWithIsCaller(
name=token.name, name=token.device_name,
date=token.date, date=token.created_at,
is_caller=token.name == caller_name, is_caller=token.device_name == caller_name,
) )
for token in tokens for token in tokens
] ]
def is_token_valid(token) -> bool:
"""Check if token is valid"""
return TOKEN_REPO.is_token_valid(token)
class NotFoundException(Exception): class NotFoundException(Exception):
"""Not found exception""" """Not found exception"""
@ -50,19 +54,22 @@ class CannotDeleteCallerException(Exception):
def delete_api_token(caller_token: str, token_name: str) -> None: def delete_api_token(caller_token: str, token_name: str) -> None:
"""Delete the token""" """Delete the token"""
if is_token_name_pair_valid(token_name, caller_token): if TOKEN_REPO.is_token_name_pair_valid(token_name, caller_token):
raise CannotDeleteCallerException("Cannot delete caller's token") raise CannotDeleteCallerException("Cannot delete caller's token")
if not is_token_name_exists(token_name): if not TOKEN_REPO.is_token_name_exists(token_name):
raise NotFoundException("Token not found") raise NotFoundException("Token not found")
delete_token(token_name) token = TOKEN_REPO.get_token_by_name(token_name)
TOKEN_REPO.delete_token(token)
def refresh_api_token(caller_token: str) -> str: def refresh_api_token(caller_token: str) -> str:
"""Refresh the token""" """Refresh the token"""
new_token = refresh_token(caller_token) try:
if new_token is None: old_token = TOKEN_REPO.get_token_by_token_string(caller_token)
new_token = TOKEN_REPO.refresh_token(old_token)
except TokenNotFound:
raise NotFoundException("Token not found") raise NotFoundException("Token not found")
return new_token return new_token.token
class RecoveryTokenStatus(BaseModel): class RecoveryTokenStatus(BaseModel):
@ -77,18 +84,16 @@ class RecoveryTokenStatus(BaseModel):
def get_api_recovery_token_status() -> RecoveryTokenStatus: def get_api_recovery_token_status() -> RecoveryTokenStatus:
"""Get the recovery token status""" """Get the recovery token status"""
if not is_recovery_token_exists(): token = TOKEN_REPO.get_recovery_key()
if token is None:
return RecoveryTokenStatus(exists=False, valid=False) return RecoveryTokenStatus(exists=False, valid=False)
status = get_recovery_token_status() is_valid = TOKEN_REPO.is_recovery_key_valid()
if status is None:
return RecoveryTokenStatus(exists=False, valid=False)
is_valid = is_recovery_token_valid()
return RecoveryTokenStatus( return RecoveryTokenStatus(
exists=True, exists=True,
valid=is_valid, valid=is_valid,
date=status["date"], date=token.created_at,
expiration=status["expiration"], expiration=token.expires_at,
uses_left=status["uses_left"], uses_left=token.uses_left,
) )
@ -112,5 +117,46 @@ def get_new_api_recovery_key(
if uses_left <= 0: if uses_left <= 0:
raise InvalidUsesLeft("Uses must be greater than 0") raise InvalidUsesLeft("Uses must be greater than 0")
key = generate_recovery_token(expiration_date, uses_left) key = TOKEN_REPO.create_recovery_key(expiration_date, uses_left)
return key mnemonic_phrase = Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key))
return mnemonic_phrase
def use_mnemonic_recovery_token(mnemonic_phrase, name):
"""Use the recovery token by converting the mnemonic word list to a byte array.
If the recovery token if invalid itself, return None
If the binary representation of phrase not matches
the byte array of the recovery token, return None.
If the mnemonic phrase is valid then generate a device token and return it.
Substract 1 from uses_left if it exists.
mnemonic_phrase is a string representation of the mnemonic word list.
"""
try:
token = TOKEN_REPO.use_mnemonic_recovery_key(mnemonic_phrase, name)
return token.token
except (RecoveryKeyNotFound, InvalidMnemonic):
return None
def delete_new_device_auth_token() -> None:
TOKEN_REPO.delete_new_device_key()
def get_new_device_auth_token() -> str:
"""Generate and store a new device auth token which is valid for 10 minutes
and return a mnemonic phrase representation
"""
key = TOKEN_REPO.get_new_device_key()
return Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key))
def use_new_device_auth_token(mnemonic_phrase, name) -> Optional[str]:
"""Use the new device auth token by converting the mnemonic string to a byte array.
If the mnemonic phrase is valid then generate a device token and return it.
New device auth token must be deleted.
"""
try:
token = TOKEN_REPO.use_mnemonic_new_device_key(mnemonic_phrase, name)
return token.token
except (NewDeviceKeyNotFound, InvalidMnemonic):
return None

View file

@ -2,7 +2,7 @@ from fastapi import Depends, HTTPException, status
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
from pydantic import BaseModel from pydantic import BaseModel
from selfprivacy_api.utils.auth import is_token_valid from selfprivacy_api.actions.api_tokens import is_token_valid
class TokenHeader(BaseModel): class TokenHeader(BaseModel):

View file

@ -4,7 +4,7 @@ import typing
from strawberry.permission import BasePermission from strawberry.permission import BasePermission
from strawberry.types import Info from strawberry.types import Info
from selfprivacy_api.utils.auth import is_token_valid from selfprivacy_api.actions.api_tokens import is_token_valid
class IsAuthenticated(BasePermission): class IsAuthenticated(BasePermission):

View file

@ -11,6 +11,11 @@ from selfprivacy_api.actions.api_tokens import (
NotFoundException, NotFoundException,
delete_api_token, delete_api_token,
get_new_api_recovery_key, get_new_api_recovery_key,
use_mnemonic_recovery_token,
refresh_api_token,
delete_new_device_auth_token,
get_new_device_auth_token,
use_new_device_auth_token,
) )
from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.mutations.mutation_interface import ( from selfprivacy_api.graphql.mutations.mutation_interface import (
@ -18,14 +23,6 @@ from selfprivacy_api.graphql.mutations.mutation_interface import (
MutationReturnInterface, MutationReturnInterface,
) )
from selfprivacy_api.utils.auth import (
delete_new_device_auth_token,
get_new_device_auth_token,
refresh_token,
use_mnemonic_recoverery_token,
use_new_device_auth_token,
)
@strawberry.type @strawberry.type
class ApiKeyMutationReturn(MutationReturnInterface): class ApiKeyMutationReturn(MutationReturnInterface):
@ -98,50 +95,53 @@ class ApiMutations:
self, input: UseRecoveryKeyInput self, input: UseRecoveryKeyInput
) -> DeviceApiTokenMutationReturn: ) -> DeviceApiTokenMutationReturn:
"""Use recovery key""" """Use recovery key"""
token = use_mnemonic_recoverery_token(input.key, input.deviceName) token = use_mnemonic_recovery_token(input.key, input.deviceName)
if token is None: if token is not None:
return DeviceApiTokenMutationReturn(
success=False,
message="Recovery key not found",
code=404,
token=None,
)
return DeviceApiTokenMutationReturn( return DeviceApiTokenMutationReturn(
success=True, success=True,
message="Recovery key used", message="Recovery key used",
code=200, code=200,
token=token, token=token,
) )
else:
return DeviceApiTokenMutationReturn(
success=False,
message="Recovery key not found",
code=404,
token=None,
)
@strawberry.mutation(permission_classes=[IsAuthenticated]) @strawberry.mutation(permission_classes=[IsAuthenticated])
def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn: def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn:
"""Refresh device api token""" """Refresh device api token"""
token = ( token_string = (
info.context["request"] info.context["request"]
.headers.get("Authorization", "") .headers.get("Authorization", "")
.replace("Bearer ", "") .replace("Bearer ", "")
) )
if token is None: if token_string is None:
return DeviceApiTokenMutationReturn(
success=False,
message="Token not found",
code=404,
token=None,
)
new_token = refresh_token(token)
if new_token is None:
return DeviceApiTokenMutationReturn( return DeviceApiTokenMutationReturn(
success=False, success=False,
message="Token not found", message="Token not found",
code=404, code=404,
token=None, token=None,
) )
try:
new_token = refresh_api_token(token_string)
return DeviceApiTokenMutationReturn( return DeviceApiTokenMutationReturn(
success=True, success=True,
message="Token refreshed", message="Token refreshed",
code=200, code=200,
token=new_token, token=new_token,
) )
except NotFoundException:
return DeviceApiTokenMutationReturn(
success=False,
message="Token not found",
code=404,
token=None,
)
@strawberry.mutation(permission_classes=[IsAuthenticated]) @strawberry.mutation(permission_classes=[IsAuthenticated])
def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn: def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn:

View file

@ -4,16 +4,12 @@ import datetime
import typing import typing
import strawberry import strawberry
from strawberry.types import Info from strawberry.types import Info
from selfprivacy_api.actions.api_tokens import get_api_tokens_with_caller_flag from selfprivacy_api.actions.api_tokens import (
from selfprivacy_api.graphql import IsAuthenticated get_api_tokens_with_caller_flag,
from selfprivacy_api.utils import parse_date get_api_recovery_token_status,
from selfprivacy_api.dependencies import get_api_version as get_api_version_dependency
from selfprivacy_api.utils.auth import (
get_recovery_token_status,
is_recovery_token_exists,
is_recovery_token_valid,
) )
from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.dependencies import get_api_version as get_api_version_dependency
def get_api_version() -> str: def get_api_version() -> str:
@ -43,16 +39,8 @@ class ApiRecoveryKeyStatus:
def get_recovery_key_status() -> ApiRecoveryKeyStatus: def get_recovery_key_status() -> ApiRecoveryKeyStatus:
"""Get recovery key status""" """Get recovery key status"""
if not is_recovery_token_exists(): status = get_api_recovery_token_status()
return ApiRecoveryKeyStatus( if status is None or not status.exists:
exists=False,
valid=False,
creation_date=None,
expiration_date=None,
uses_left=None,
)
status = get_recovery_token_status()
if status is None:
return ApiRecoveryKeyStatus( return ApiRecoveryKeyStatus(
exists=False, exists=False,
valid=False, valid=False,
@ -62,12 +50,10 @@ def get_recovery_key_status() -> ApiRecoveryKeyStatus:
) )
return ApiRecoveryKeyStatus( return ApiRecoveryKeyStatus(
exists=True, exists=True,
valid=is_recovery_token_valid(), valid=status.valid,
creation_date=parse_date(status["date"]), creation_date=status.date,
expiration_date=parse_date(status["expiration"]) expiration_date=status.expiration,
if status["expiration"] is not None uses_left=status.uses_left,
else None,
uses_left=status["uses_left"] if status["uses_left"] is not None else None,
) )

View file

@ -97,8 +97,8 @@ class Jobs:
error=None, error=None,
result=None, result=None,
) )
r = RedisPool().get_connection() redis = RedisPool().get_connection()
_store_job_as_hash(r, _redis_key_from_uuid(job.uid), job) _store_job_as_hash(redis, _redis_key_from_uuid(job.uid), job)
return job return job
@staticmethod @staticmethod
@ -113,10 +113,10 @@ class Jobs:
""" """
Remove a job from the jobs list. Remove a job from the jobs list.
""" """
r = RedisPool().get_connection() redis = RedisPool().get_connection()
key = _redis_key_from_uuid(job_uuid) key = _redis_key_from_uuid(job_uuid)
if (r.exists(key)): if redis.exists(key):
r.delete(key) redis.delete(key)
return True return True
return False return False
@ -149,12 +149,12 @@ class Jobs:
if status in (JobStatus.FINISHED, JobStatus.ERROR): if status in (JobStatus.FINISHED, JobStatus.ERROR):
job.finished_at = datetime.datetime.now() job.finished_at = datetime.datetime.now()
r = RedisPool().get_connection() redis = RedisPool().get_connection()
key = _redis_key_from_uuid(job.uid) key = _redis_key_from_uuid(job.uid)
if r.exists(key): if redis.exists(key):
_store_job_as_hash(r, key, job) _store_job_as_hash(redis, key, job)
if status in (JobStatus.FINISHED, JobStatus.ERROR): if status in (JobStatus.FINISHED, JobStatus.ERROR):
r.expire(key, JOB_EXPIRATION_SECONDS) redis.expire(key, JOB_EXPIRATION_SECONDS)
return job return job
@ -163,10 +163,10 @@ class Jobs:
""" """
Get a job from the jobs list. Get a job from the jobs list.
""" """
r = RedisPool().get_connection() redis = RedisPool().get_connection()
key = _redis_key_from_uuid(uid) key = _redis_key_from_uuid(uid)
if r.exists(key): if redis.exists(key):
return _job_from_hash(r, key) return _job_from_hash(redis, key)
return None return None
@staticmethod @staticmethod
@ -174,9 +174,14 @@ class Jobs:
""" """
Get the jobs list. Get the jobs list.
""" """
r = RedisPool().get_connection() redis = RedisPool().get_connection()
jobs = r.keys("jobs:*") job_keys = redis.keys("jobs:*")
return [_job_from_hash(r, job_key) for job_key in jobs] jobs = []
for job_key in job_keys:
job = _job_from_hash(redis, job_key)
if job is not None:
jobs.append(job)
return jobs
@staticmethod @staticmethod
def is_busy() -> bool: def is_busy() -> bool:
@ -189,11 +194,11 @@ class Jobs:
return False return False
def _redis_key_from_uuid(uuid): def _redis_key_from_uuid(uuid_string):
return "jobs:" + str(uuid) return "jobs:" + str(uuid_string)
def _store_job_as_hash(r, redis_key, model): def _store_job_as_hash(redis, redis_key, model):
for key, value in model.dict().items(): for key, value in model.dict().items():
if isinstance(value, uuid.UUID): if isinstance(value, uuid.UUID):
value = str(value) value = str(value)
@ -201,12 +206,12 @@ def _store_job_as_hash(r, redis_key, model):
value = value.isoformat() value = value.isoformat()
if isinstance(value, JobStatus): if isinstance(value, JobStatus):
value = value.value value = value.value
r.hset(redis_key, key, str(value)) redis.hset(redis_key, key, str(value))
def _job_from_hash(r, redis_key): def _job_from_hash(redis, redis_key):
if r.exists(redis_key): if redis.exists(redis_key):
job_dict = r.hgetall(redis_key) job_dict = redis.hgetall(redis_key)
for date in [ for date in [
"created_at", "created_at",
"updated_at", "updated_at",

View file

@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from mnemonic import Mnemonic from mnemonic import Mnemonic
from secrets import randbelow
import re
from selfprivacy_api.models.tokens.token import Token from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.repositories.tokens.exceptions import ( from selfprivacy_api.repositories.tokens.exceptions import (
@ -15,7 +17,7 @@ from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey
class AbstractTokensRepository(ABC): class AbstractTokensRepository(ABC):
def get_token_by_token_string(self, token_string: str) -> Optional[Token]: def get_token_by_token_string(self, token_string: str) -> Token:
"""Get the token by token""" """Get the token by token"""
tokens = self.get_tokens() tokens = self.get_tokens()
for token in tokens: for token in tokens:
@ -24,7 +26,7 @@ class AbstractTokensRepository(ABC):
raise TokenNotFound("Token not found!") raise TokenNotFound("Token not found!")
def get_token_by_name(self, token_name: str) -> Optional[Token]: def get_token_by_name(self, token_name: str) -> Token:
"""Get the token by name""" """Get the token by name"""
tokens = self.get_tokens() tokens = self.get_tokens()
for token in tokens: for token in tokens:
@ -39,7 +41,8 @@ class AbstractTokensRepository(ABC):
def create_token(self, device_name: str) -> Token: def create_token(self, device_name: str) -> Token:
"""Create new token""" """Create new token"""
new_token = Token.generate(device_name) unique_name = self._make_unique_device_name(device_name)
new_token = Token.generate(unique_name)
self._store_token(new_token) self._store_token(new_token)
@ -52,6 +55,7 @@ class AbstractTokensRepository(ABC):
def refresh_token(self, input_token: Token) -> Token: def refresh_token(self, input_token: Token) -> Token:
"""Change the token field of the existing token""" """Change the token field of the existing token"""
new_token = Token.generate(device_name=input_token.device_name) new_token = Token.generate(device_name=input_token.device_name)
new_token.created_at = input_token.created_at
if input_token in self.get_tokens(): if input_token in self.get_tokens():
self.delete_token(input_token) self.delete_token(input_token)
@ -62,23 +66,20 @@ class AbstractTokensRepository(ABC):
def is_token_valid(self, token_string: str) -> bool: def is_token_valid(self, token_string: str) -> bool:
"""Check if the token is valid""" """Check if the token is valid"""
token = self.get_token_by_token_string(token_string) return token_string in [token.token for token in self.get_tokens()]
if token is None:
return False
return True
def is_token_name_exists(self, token_name: str) -> bool: def is_token_name_exists(self, token_name: str) -> bool:
"""Check if the token name exists""" """Check if the token name exists"""
token = self.get_token_by_name(token_name) return token_name in [token.device_name for token in self.get_tokens()]
if token is None:
return False
return True
def is_token_name_pair_valid(self, token_name: str, token_string: str) -> bool: def is_token_name_pair_valid(self, token_name: str, token_string: str) -> bool:
"""Check if the token name and token are valid""" """Check if the token name and token are valid"""
try:
token = self.get_token_by_name(token_name) token = self.get_token_by_name(token_name)
if token is None: if token is None:
return False return False
except TokenNotFound:
return False
return token.token == token_string return token.token == token_string
@abstractmethod @abstractmethod
@ -100,7 +101,12 @@ class AbstractTokensRepository(ABC):
if not self.is_recovery_key_valid(): if not self.is_recovery_key_valid():
raise RecoveryKeyNotFound("Recovery key not found") raise RecoveryKeyNotFound("Recovery key not found")
recovery_hex_key = self.get_recovery_key().key recovery_key = self.get_recovery_key()
if recovery_key is None:
raise RecoveryKeyNotFound("Recovery key not found")
recovery_hex_key = recovery_key.key
if not self._assert_mnemonic(recovery_hex_key, mnemonic_phrase): if not self._assert_mnemonic(recovery_hex_key, mnemonic_phrase):
raise RecoveryKeyNotFound("Recovery key not found") raise RecoveryKeyNotFound("Recovery key not found")
@ -117,9 +123,15 @@ class AbstractTokensRepository(ABC):
return False return False
return recovery_key.is_valid() return recovery_key.is_valid()
@abstractmethod
def get_new_device_key(self) -> NewDeviceKey: def get_new_device_key(self) -> NewDeviceKey:
"""Creates and returns the new device key""" """Creates and returns the new device key"""
new_device_key = NewDeviceKey.generate()
self._store_new_device_key(new_device_key)
return new_device_key
def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
"""Store new device key directly"""
@abstractmethod @abstractmethod
def delete_new_device_key(self) -> None: def delete_new_device_key(self) -> None:
@ -133,6 +145,9 @@ class AbstractTokensRepository(ABC):
if not new_device_key: if not new_device_key:
raise NewDeviceKeyNotFound raise NewDeviceKeyNotFound
if not new_device_key.is_valid():
raise NewDeviceKeyNotFound
if not self._assert_mnemonic(new_device_key.key, mnemonic_phrase): if not self._assert_mnemonic(new_device_key.key, mnemonic_phrase):
raise NewDeviceKeyNotFound("Phrase is not token!") raise NewDeviceKeyNotFound("Phrase is not token!")
@ -153,6 +168,19 @@ class AbstractTokensRepository(ABC):
def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]: def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]:
"""Retrieves new device key that is already stored.""" """Retrieves new device key that is already stored."""
def _make_unique_device_name(self, name: str) -> str:
"""Token name must be an alphanumeric string and not empty.
Replace invalid characters with '_'
If name exists, add a random number to the end of the name until it is unique.
"""
if not re.match("^[a-zA-Z0-9]*$", name):
name = re.sub("[^a-zA-Z0-9]", "_", name)
if name == "":
name = "Unknown device"
while self.is_token_name_exists(name):
name += str(randbelow(10))
return name
# TODO: find a proper place for it # TODO: find a proper place for it
def _assert_mnemonic(self, hex_key: str, mnemonic_phrase: str): def _assert_mnemonic(self, hex_key: str, mnemonic_phrase: str):
"""Return true if hex string matches the phrase, false otherwise """Return true if hex string matches the phrase, false otherwise

View file

@ -69,7 +69,7 @@ class JsonTokensRepository(AbstractTokensRepository):
recovery_key = RecoveryKey( recovery_key = RecoveryKey(
key=tokens_file["recovery_token"].get("token"), key=tokens_file["recovery_token"].get("token"),
created_at=tokens_file["recovery_token"].get("date"), created_at=tokens_file["recovery_token"].get("date"),
expires_at=tokens_file["recovery_token"].get("expitation"), expires_at=tokens_file["recovery_token"].get("expiration"),
uses_left=tokens_file["recovery_token"].get("uses_left"), uses_left=tokens_file["recovery_token"].get("uses_left"),
) )
@ -85,10 +85,13 @@ class JsonTokensRepository(AbstractTokensRepository):
recovery_key = RecoveryKey.generate(expiration, uses_left) recovery_key = RecoveryKey.generate(expiration, uses_left)
with WriteUserData(UserDataFiles.TOKENS) as tokens_file: with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
key_expiration: Optional[str] = None
if recovery_key.expires_at is not None:
key_expiration = recovery_key.expires_at.strftime(DATETIME_FORMAT)
tokens_file["recovery_token"] = { tokens_file["recovery_token"] = {
"token": recovery_key.key, "token": recovery_key.key,
"date": recovery_key.created_at.strftime(DATETIME_FORMAT), "date": recovery_key.created_at.strftime(DATETIME_FORMAT),
"expiration": recovery_key.expires_at, "expiration": key_expiration,
"uses_left": recovery_key.uses_left, "uses_left": recovery_key.uses_left,
} }
@ -98,12 +101,10 @@ class JsonTokensRepository(AbstractTokensRepository):
"""Decrement recovery key use count by one""" """Decrement recovery key use count by one"""
if self.is_recovery_key_valid(): if self.is_recovery_key_valid():
with WriteUserData(UserDataFiles.TOKENS) as tokens: with WriteUserData(UserDataFiles.TOKENS) as tokens:
if tokens["recovery_token"]["uses_left"] is not None:
tokens["recovery_token"]["uses_left"] -= 1 tokens["recovery_token"]["uses_left"] -= 1
def get_new_device_key(self) -> NewDeviceKey: def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
"""Creates and returns the new device key"""
new_device_key = NewDeviceKey.generate()
with WriteUserData(UserDataFiles.TOKENS) as tokens_file: with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
tokens_file["new_device"] = { tokens_file["new_device"] = {
"token": new_device_key.key, "token": new_device_key.key,
@ -111,8 +112,6 @@ class JsonTokensRepository(AbstractTokensRepository):
"expiration": new_device_key.expires_at.strftime(DATETIME_FORMAT), "expiration": new_device_key.expires_at.strftime(DATETIME_FORMAT),
} }
return new_device_key
def delete_new_device_key(self) -> None: def delete_new_device_key(self) -> None:
"""Delete the new device key""" """Delete the new device key"""
with WriteUserData(UserDataFiles.TOKENS) as tokens_file: with WriteUserData(UserDataFiles.TOKENS) as tokens_file:

View file

@ -32,29 +32,34 @@ class RedisTokensRepository(AbstractTokensRepository):
def get_tokens(self) -> list[Token]: def get_tokens(self) -> list[Token]:
"""Get the tokens""" """Get the tokens"""
r = self.connection redis = self.connection
token_keys = r.keys(TOKENS_PREFIX + "*") token_keys = redis.keys(TOKENS_PREFIX + "*")
return [self._token_from_hash(key) for key in token_keys] tokens = []
for key in token_keys:
token = self._token_from_hash(key)
if token is not None:
tokens.append(token)
return tokens
def delete_token(self, input_token: Token) -> None: def delete_token(self, input_token: Token) -> None:
"""Delete the token""" """Delete the token"""
r = self.connection redis = self.connection
key = RedisTokensRepository._token_redis_key(input_token) key = RedisTokensRepository._token_redis_key(input_token)
if input_token not in self.get_tokens(): if input_token not in self.get_tokens():
raise TokenNotFound raise TokenNotFound
r.delete(key) redis.delete(key)
def reset(self): def reset(self):
for token in self.get_tokens(): for token in self.get_tokens():
self.delete_token(token) self.delete_token(token)
self.delete_new_device_key() self.delete_new_device_key()
r = self.connection redis = self.connection
r.delete(RECOVERY_KEY_REDIS_KEY) redis.delete(RECOVERY_KEY_REDIS_KEY)
def get_recovery_key(self) -> Optional[RecoveryKey]: def get_recovery_key(self) -> Optional[RecoveryKey]:
"""Get the recovery key""" """Get the recovery key"""
r = self.connection redis = self.connection
if r.exists(RECOVERY_KEY_REDIS_KEY): if redis.exists(RECOVERY_KEY_REDIS_KEY):
return self._recovery_key_from_hash(RECOVERY_KEY_REDIS_KEY) return self._recovery_key_from_hash(RECOVERY_KEY_REDIS_KEY)
return None return None
@ -68,16 +73,14 @@ class RedisTokensRepository(AbstractTokensRepository):
self._store_model_as_hash(RECOVERY_KEY_REDIS_KEY, recovery_key) self._store_model_as_hash(RECOVERY_KEY_REDIS_KEY, recovery_key)
return recovery_key return recovery_key
def get_new_device_key(self) -> NewDeviceKey: def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
"""Creates and returns the new device key""" """Store new device key directly"""
new_device_key = NewDeviceKey.generate()
self._store_model_as_hash(NEW_DEVICE_KEY_REDIS_KEY, new_device_key) self._store_model_as_hash(NEW_DEVICE_KEY_REDIS_KEY, new_device_key)
return new_device_key
def delete_new_device_key(self) -> None: def delete_new_device_key(self) -> None:
"""Delete the new device key""" """Delete the new device key"""
r = self.connection redis = self.connection
r.delete(NEW_DEVICE_KEY_REDIS_KEY) redis.delete(NEW_DEVICE_KEY_REDIS_KEY)
@staticmethod @staticmethod
def _token_redis_key(token: Token) -> str: def _token_redis_key(token: Token) -> str:
@ -91,9 +94,13 @@ class RedisTokensRepository(AbstractTokensRepository):
def _decrement_recovery_token(self): def _decrement_recovery_token(self):
"""Decrement recovery key use count by one""" """Decrement recovery key use count by one"""
if self.is_recovery_key_valid(): if self.is_recovery_key_valid():
uses_left = self.get_recovery_key().uses_left recovery_key = self.get_recovery_key()
r = self.connection if recovery_key is None:
r.hset(RECOVERY_KEY_REDIS_KEY, "uses_left", uses_left - 1) return
uses_left = recovery_key.uses_left
if uses_left is not None:
redis = self.connection
redis.hset(RECOVERY_KEY_REDIS_KEY, "uses_left", uses_left - 1)
def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]: def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]:
"""Retrieves new device key that is already stored.""" """Retrieves new device key that is already stored."""
@ -117,9 +124,9 @@ class RedisTokensRepository(AbstractTokensRepository):
d[key] = None d[key] = None
def _model_dict_from_hash(self, redis_key: str) -> Optional[dict]: def _model_dict_from_hash(self, redis_key: str) -> Optional[dict]:
r = self.connection redis = self.connection
if r.exists(redis_key): if redis.exists(redis_key):
token_dict = r.hgetall(redis_key) token_dict = redis.hgetall(redis_key)
RedisTokensRepository._prepare_model_dict(token_dict) RedisTokensRepository._prepare_model_dict(token_dict)
return token_dict return token_dict
return None return None
@ -140,8 +147,8 @@ class RedisTokensRepository(AbstractTokensRepository):
return self._hash_as_model(redis_key, NewDeviceKey) return self._hash_as_model(redis_key, NewDeviceKey)
def _store_model_as_hash(self, redis_key, model): def _store_model_as_hash(self, redis_key, model):
r = self.connection redis = self.connection
for key, value in model.dict().items(): for key, value in model.dict().items():
if isinstance(value, datetime): if isinstance(value, datetime):
value = value.isoformat() value = value.isoformat()
r.hset(redis_key, key, str(value)) redis.hset(redis_key, key, str(value))

View file

@ -8,20 +8,18 @@ from selfprivacy_api.actions.api_tokens import (
InvalidUsesLeft, InvalidUsesLeft,
NotFoundException, NotFoundException,
delete_api_token, delete_api_token,
refresh_api_token,
get_api_recovery_token_status, get_api_recovery_token_status,
get_api_tokens_with_caller_flag, get_api_tokens_with_caller_flag,
get_new_api_recovery_key, get_new_api_recovery_key,
refresh_api_token, use_mnemonic_recovery_token,
delete_new_device_auth_token,
get_new_device_auth_token,
use_new_device_auth_token,
) )
from selfprivacy_api.dependencies import TokenHeader, get_token_header from selfprivacy_api.dependencies import TokenHeader, get_token_header
from selfprivacy_api.utils.auth import (
delete_new_device_auth_token,
get_new_device_auth_token,
use_mnemonic_recoverery_token,
use_new_device_auth_token,
)
router = APIRouter( router = APIRouter(
prefix="/auth", prefix="/auth",
@ -99,7 +97,7 @@ class UseTokenInput(BaseModel):
@router.post("/recovery_token/use") @router.post("/recovery_token/use")
async def rest_use_recovery_token(input: UseTokenInput): async def rest_use_recovery_token(input: UseTokenInput):
token = use_mnemonic_recoverery_token(input.token, input.device) token = use_mnemonic_recovery_token(input.token, input.device)
if token is None: if token is None:
raise HTTPException(status_code=404, detail="Token not found") raise HTTPException(status_code=404, detail="Token not found")
return {"token": token} return {"token": token}

View file

@ -1,329 +0,0 @@
#!/usr/bin/env python3
"""Token management utils"""
import secrets
from datetime import datetime, timedelta
import re
import typing
from pydantic import BaseModel
from mnemonic import Mnemonic
from . import ReadUserData, UserDataFiles, WriteUserData, parse_date
"""
Token are stored in the tokens.json file.
File contains device tokens, recovery token and new device auth token.
File structure:
{
"tokens": [
{
"token": "device token",
"name": "device name",
"date": "date of creation",
}
],
"recovery_token": {
"token": "recovery token",
"date": "date of creation",
"expiration": "date of expiration",
"uses_left": "number of uses left"
},
"new_device": {
"token": "new device auth token",
"date": "date of creation",
"expiration": "date of expiration",
}
}
Recovery token may or may not have expiration date and uses_left.
There may be no recovery token at all.
Device tokens must be unique.
"""
def _get_tokens():
"""Get all tokens as list of tokens of every device"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return [token["token"] for token in tokens["tokens"]]
def _get_token_names():
"""Get all token names"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return [t["name"] for t in tokens["tokens"]]
def _validate_token_name(name):
"""Token name must be an alphanumeric string and not empty.
Replace invalid characters with '_'
If token name exists, add a random number to the end of the name until it is unique.
"""
if not re.match("^[a-zA-Z0-9]*$", name):
name = re.sub("[^a-zA-Z0-9]", "_", name)
if name == "":
name = "Unknown device"
while name in _get_token_names():
name += str(secrets.randbelow(10))
return name
def is_token_valid(token):
"""Check if token is valid"""
if token in _get_tokens():
return True
return False
def is_token_name_exists(token_name):
"""Check if token name exists"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return token_name in [t["name"] for t in tokens["tokens"]]
def is_token_name_pair_valid(token_name, token):
"""Check if token name and token pair exists"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
for t in tokens["tokens"]:
if t["name"] == token_name and t["token"] == token:
return True
return False
def get_token_name(token: str) -> typing.Optional[str]:
"""Return the name of the token provided"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
for t in tokens["tokens"]:
if t["token"] == token:
return t["name"]
return None
class BasicTokenInfo(BaseModel):
"""Token info"""
name: str
date: datetime
def get_tokens_info():
"""Get all tokens info without tokens themselves"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return [
BasicTokenInfo(
name=t["name"],
date=parse_date(t["date"]),
)
for t in tokens["tokens"]
]
def _generate_token():
"""Generates new token and makes sure it is unique"""
token = secrets.token_urlsafe(32)
while token in _get_tokens():
token = secrets.token_urlsafe(32)
return token
def create_token(name):
"""Create new token"""
token = _generate_token()
name = _validate_token_name(name)
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["tokens"].append(
{
"token": token,
"name": name,
"date": str(datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")),
}
)
return token
def delete_token(token_name):
"""Delete token"""
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["tokens"] = [t for t in tokens["tokens"] if t["name"] != token_name]
def refresh_token(token: str) -> typing.Optional[str]:
"""Change the token field of the existing token"""
new_token = _generate_token()
with WriteUserData(UserDataFiles.TOKENS) as tokens:
for t in tokens["tokens"]:
if t["token"] == token:
t["token"] = new_token
return new_token
return None
def is_recovery_token_exists():
"""Check if recovery token exists"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return "recovery_token" in tokens
def is_recovery_token_valid():
"""Check if recovery token is valid"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
if "recovery_token" not in tokens:
return False
recovery_token = tokens["recovery_token"]
if "uses_left" in recovery_token and recovery_token["uses_left"] is not None:
if recovery_token["uses_left"] <= 0:
return False
if "expiration" not in recovery_token or recovery_token["expiration"] is None:
return True
return datetime.now() < parse_date(recovery_token["expiration"])
def get_recovery_token_status():
"""Get recovery token date of creation, expiration and uses left"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
if "recovery_token" not in tokens:
return None
recovery_token = tokens["recovery_token"]
return {
"date": recovery_token["date"],
"expiration": recovery_token["expiration"]
if "expiration" in recovery_token
else None,
"uses_left": recovery_token["uses_left"]
if "uses_left" in recovery_token
else None,
}
def _get_recovery_token():
"""Get recovery token"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
if "recovery_token" not in tokens:
return None
return tokens["recovery_token"]["token"]
def generate_recovery_token(
expiration: typing.Optional[datetime], uses_left: typing.Optional[int]
) -> str:
"""Generate a 24 bytes recovery token and return a mneomnic word list.
Write a string representation of the recovery token to the tokens.json file.
"""
# expires must be a date or None
# uses_left must be an integer or None
if expiration is not None:
if not isinstance(expiration, datetime):
raise TypeError("expires must be a datetime object")
if uses_left is not None:
if not isinstance(uses_left, int):
raise TypeError("uses_left must be an integer")
if uses_left <= 0:
raise ValueError("uses_left must be greater than 0")
recovery_token = secrets.token_bytes(24)
recovery_token_str = recovery_token.hex()
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["recovery_token"] = {
"token": recovery_token_str,
"date": str(datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")),
"expiration": expiration.strftime("%Y-%m-%dT%H:%M:%S.%f")
if expiration is not None
else None,
"uses_left": uses_left if uses_left is not None else None,
}
return Mnemonic(language="english").to_mnemonic(recovery_token)
def use_mnemonic_recoverery_token(mnemonic_phrase, name):
"""Use the recovery token by converting the mnemonic word list to a byte array.
If the recovery token if invalid itself, return None
If the binary representation of phrase not matches
the byte array of the recovery token, return None.
If the mnemonic phrase is valid then generate a device token and return it.
Substract 1 from uses_left if it exists.
mnemonic_phrase is a string representation of the mnemonic word list.
"""
if not is_recovery_token_valid():
return None
recovery_token_str = _get_recovery_token()
if recovery_token_str is None:
return None
recovery_token = bytes.fromhex(recovery_token_str)
if not Mnemonic(language="english").check(mnemonic_phrase):
return None
phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase)
if phrase_bytes != recovery_token:
return None
token = _generate_token()
name = _validate_token_name(name)
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["tokens"].append(
{
"token": token,
"name": name,
"date": str(datetime.now()),
}
)
if "recovery_token" in tokens:
if (
"uses_left" in tokens["recovery_token"]
and tokens["recovery_token"]["uses_left"] is not None
):
tokens["recovery_token"]["uses_left"] -= 1
return token
def get_new_device_auth_token() -> str:
"""Generate a new device auth token which is valid for 10 minutes
and return a mnemonic phrase representation
Write token to the new_device of the tokens.json file.
"""
token = secrets.token_bytes(16)
token_str = token.hex()
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["new_device"] = {
"token": token_str,
"date": str(datetime.now()),
"expiration": str(datetime.now() + timedelta(minutes=10)),
}
return Mnemonic(language="english").to_mnemonic(token)
def _get_new_device_auth_token():
"""Get new device auth token. If it is expired, return None"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
if "new_device" not in tokens:
return None
new_device = tokens["new_device"]
if "expiration" not in new_device:
return None
expiration = parse_date(new_device["expiration"])
if datetime.now() > expiration:
return None
return new_device["token"]
def delete_new_device_auth_token():
"""Delete new device auth token"""
with WriteUserData(UserDataFiles.TOKENS) as tokens:
if "new_device" in tokens:
del tokens["new_device"]
def use_new_device_auth_token(mnemonic_phrase, name):
"""Use the new device auth token by converting the mnemonic string to a byte array.
If the mnemonic phrase is valid then generate a device token and return it.
New device auth token must be deleted.
"""
token_str = _get_new_device_auth_token()
if token_str is None:
return None
token = bytes.fromhex(token_str)
if not Mnemonic(language="english").check(mnemonic_phrase):
return None
phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase)
if phrase_bytes != token:
return None
token = create_token(name)
with WriteUserData(UserDataFiles.TOKENS) as tokens:
if "new_device" in tokens:
del tokens["new_device"]
return token

View file

@ -2,8 +2,14 @@
# pylint: disable=unused-argument # pylint: disable=unused-argument
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
import datetime import datetime
import pytest
from mnemonic import Mnemonic from mnemonic import Mnemonic
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
JsonTokensRepository,
)
from selfprivacy_api.models.tokens.token import Token
from tests.common import generate_api_query, read_json, write_json from tests.common import generate_api_query, read_json, write_json
TOKENS_FILE_CONTETS = { TOKENS_FILE_CONTETS = {
@ -30,6 +36,11 @@ devices {
""" """
@pytest.fixture
def token_repo():
return JsonTokensRepository()
def test_graphql_tokens_info(authorized_client, tokens_file): def test_graphql_tokens_info(authorized_client, tokens_file):
response = authorized_client.post( response = authorized_client.post(
"/graphql", "/graphql",
@ -170,7 +181,7 @@ def test_graphql_refresh_token_unauthorized(client, tokens_file):
assert response.json()["data"] is None assert response.json()["data"] is None
def test_graphql_refresh_token(authorized_client, tokens_file): def test_graphql_refresh_token(authorized_client, tokens_file, token_repo):
response = authorized_client.post( response = authorized_client.post(
"/graphql", "/graphql",
json={"query": REFRESH_TOKEN_MUTATION}, json={"query": REFRESH_TOKEN_MUTATION},
@ -180,11 +191,12 @@ def test_graphql_refresh_token(authorized_client, tokens_file):
assert response.json()["data"]["refreshDeviceApiToken"]["success"] is True assert response.json()["data"]["refreshDeviceApiToken"]["success"] is True
assert response.json()["data"]["refreshDeviceApiToken"]["message"] is not None assert response.json()["data"]["refreshDeviceApiToken"]["message"] is not None
assert response.json()["data"]["refreshDeviceApiToken"]["code"] == 200 assert response.json()["data"]["refreshDeviceApiToken"]["code"] == 200
assert read_json(tokens_file)["tokens"][0] == { token = token_repo.get_token_by_name("test_token")
"token": response.json()["data"]["refreshDeviceApiToken"]["token"], assert token == Token(
"name": "test_token", token=response.json()["data"]["refreshDeviceApiToken"]["token"],
"date": "2022-01-14 08:31:10.789314", device_name="test_token",
} created_at=datetime.datetime(2022, 1, 14, 8, 31, 10, 789314),
)
NEW_DEVICE_KEY_MUTATION = """ NEW_DEVICE_KEY_MUTATION = """

View file

@ -2,7 +2,8 @@
# pylint: disable=unused-argument # pylint: disable=unused-argument
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
from datetime import datetime from datetime import datetime, timedelta
from mnemonic import Mnemonic
import pytest import pytest
@ -32,6 +33,10 @@ ORIGINAL_DEVICE_NAMES = [
] ]
def mnemonic_from_hex(hexkey):
return Mnemonic(language="english").to_mnemonic(bytes.fromhex(hexkey))
@pytest.fixture @pytest.fixture
def empty_keys(mocker, datadir): def empty_keys(mocker, datadir):
mocker.patch("selfprivacy_api.utils.TOKENS_FILE", new=datadir / "empty_keys.json") mocker.patch("selfprivacy_api.utils.TOKENS_FILE", new=datadir / "empty_keys.json")
@ -132,21 +137,6 @@ def mock_recovery_key_generate(mocker):
return mock return mock
@pytest.fixture
def mock_recovery_key_generate_for_mnemonic(mocker):
mock = mocker.patch(
"selfprivacy_api.models.tokens.recovery_key.RecoveryKey.generate",
autospec=True,
return_value=RecoveryKey(
key="ed653e4b8b042b841d285fa7a682fa09e925ddb2d8906f54",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
expires_at=None,
uses_left=1,
),
)
return mock
@pytest.fixture @pytest.fixture
def empty_json_repo(empty_keys): def empty_json_repo(empty_keys):
repo = JsonTokensRepository() repo = JsonTokensRepository()
@ -221,6 +211,28 @@ def test_get_token_by_non_existent_name(some_tokens_repo):
assert repo.get_token_by_name(token_name="badname") is None assert repo.get_token_by_name(token_name="badname") is None
def test_is_token_valid(some_tokens_repo):
repo = some_tokens_repo
token = repo.get_tokens()[0]
assert repo.is_token_valid(token.token)
assert not repo.is_token_valid("gibberish")
def test_is_token_name_pair_valid(some_tokens_repo):
repo = some_tokens_repo
token = repo.get_tokens()[0]
assert repo.is_token_name_pair_valid(token.device_name, token.token)
assert not repo.is_token_name_pair_valid(token.device_name, "gibberish")
assert not repo.is_token_name_pair_valid("gibberish", token.token)
def test_is_token_name_exists(some_tokens_repo):
repo = some_tokens_repo
token = repo.get_tokens()[0]
assert repo.is_token_name_exists(token.device_name)
assert not repo.is_token_name_exists("gibberish")
def test_get_tokens(some_tokens_repo): def test_get_tokens(some_tokens_repo):
repo = some_tokens_repo repo = some_tokens_repo
tokenstrings = [] tokenstrings = []
@ -249,6 +261,17 @@ def test_create_token(empty_repo, mock_token_generate):
] ]
def test_create_token_existing(some_tokens_repo):
repo = some_tokens_repo
old_token = repo.get_tokens()[0]
new_token = repo.create_token(device_name=old_token.device_name)
assert new_token.device_name != old_token.device_name
assert old_token in repo.get_tokens()
assert new_token in repo.get_tokens()
def test_delete_token(some_tokens_repo): def test_delete_token(some_tokens_repo):
repo = some_tokens_repo repo = some_tokens_repo
original_tokens = repo.get_tokens() original_tokens = repo.get_tokens()
@ -280,15 +303,17 @@ def test_delete_not_found_token(some_tokens_repo):
assert token in new_tokens assert token in new_tokens
def test_refresh_token(some_tokens_repo, mock_token_generate): def test_refresh_token(some_tokens_repo):
repo = some_tokens_repo repo = some_tokens_repo
input_token = some_tokens_repo.get_tokens()[0] input_token = some_tokens_repo.get_tokens()[0]
assert repo.refresh_token(input_token) == Token( output_token = repo.refresh_token(input_token)
token="ZuLNKtnxDeq6w2dpOJhbB3iat_sJLPTPl_rN5uc5MvM",
device_name="IamNewDevice", assert output_token.token != input_token.token
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), assert output_token.device_name == input_token.device_name
) assert output_token.created_at == input_token.created_at
assert output_token in repo.get_tokens()
def test_refresh_not_found_token(some_tokens_repo, mock_token_generate): def test_refresh_not_found_token(some_tokens_repo, mock_token_generate):
@ -355,6 +380,23 @@ def test_use_mnemonic_not_valid_recovery_key(
) )
def test_use_mnemonic_expired_recovery_key(
some_tokens_repo,
):
repo = some_tokens_repo
expiration = datetime.now() - timedelta(minutes=5)
assert repo.create_recovery_key(uses_left=2, expiration=expiration) is not None
recovery_key = repo.get_recovery_key()
assert recovery_key.expires_at == expiration
assert not repo.is_recovery_key_valid()
with pytest.raises(RecoveryKeyNotFound):
token = repo.use_mnemonic_recovery_key(
mnemonic_phrase=mnemonic_from_hex(recovery_key.key),
device_name="newdevice",
)
def test_use_mnemonic_not_mnemonic_recovery_key(some_tokens_repo): def test_use_mnemonic_not_mnemonic_recovery_key(some_tokens_repo):
repo = some_tokens_repo repo = some_tokens_repo
assert repo.create_recovery_key(uses_left=1, expiration=None) is not None assert repo.create_recovery_key(uses_left=1, expiration=None) is not None
@ -397,46 +439,38 @@ def test_use_not_found_mnemonic_recovery_key(some_tokens_repo):
) )
def test_use_mnemonic_recovery_key_when_empty(empty_repo): @pytest.fixture(params=["recovery_uses_1", "recovery_eternal"])
repo = empty_repo def recovery_key_uses_left(request):
if request.param == "recovery_uses_1":
with pytest.raises(RecoveryKeyNotFound): return 1
assert ( if request.param == "recovery_eternal":
repo.use_mnemonic_recovery_key( return None
mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb",
device_name="primary_token",
)
is None
)
# agnostic test mixed with an implementation test def test_use_mnemonic_recovery_key(some_tokens_repo, recovery_key_uses_left):
def test_use_mnemonic_recovery_key(
some_tokens_repo, mock_recovery_key_generate_for_mnemonic, mock_generate_token
):
repo = some_tokens_repo repo = some_tokens_repo
assert repo.create_recovery_key(uses_left=1, expiration=None) is not None
test_token = Token(
token="ur71mC4aiI6FIYAN--cTL-38rPHS5D6NuB1bgN_qKF4",
device_name="newdevice",
created_at=datetime(2022, 11, 14, 6, 6, 32, 777123),
)
assert ( assert (
repo.use_mnemonic_recovery_key( repo.create_recovery_key(uses_left=recovery_key_uses_left, expiration=None)
mnemonic_phrase="uniform clarify napkin bid dress search input armor police cross salon because myself uphold slice bamboo hungry park", is not None
)
assert repo.is_recovery_key_valid()
recovery_key = repo.get_recovery_key()
token = repo.use_mnemonic_recovery_key(
mnemonic_phrase=mnemonic_from_hex(recovery_key.key),
device_name="newdevice", device_name="newdevice",
) )
== test_token
)
assert test_token in repo.get_tokens() assert token.device_name == "newdevice"
assert token in repo.get_tokens()
new_uses = None
if recovery_key_uses_left is not None:
new_uses = recovery_key_uses_left - 1
assert repo.get_recovery_key() == RecoveryKey( assert repo.get_recovery_key() == RecoveryKey(
key="ed653e4b8b042b841d285fa7a682fa09e925ddb2d8906f54", key=recovery_key.key,
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), created_at=recovery_key.created_at,
expires_at=None, expires_at=None,
uses_left=0, uses_left=new_uses,
) )
@ -497,15 +531,16 @@ def test_use_not_exists_mnemonic_new_device_key(
) )
def test_use_mnemonic_new_device_key( def test_use_mnemonic_new_device_key(empty_repo):
empty_repo, mock_new_device_key_generate_for_mnemonic
):
repo = empty_repo repo = empty_repo
assert repo.get_new_device_key() is not None key = repo.get_new_device_key()
assert key is not None
mnemonic_phrase = mnemonic_from_hex(key.key)
new_token = repo.use_mnemonic_new_device_key( new_token = repo.use_mnemonic_new_device_key(
device_name="imnew", device_name="imnew",
mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb", mnemonic_phrase=mnemonic_phrase,
) )
assert new_token.device_name == "imnew" assert new_token.device_name == "imnew"
@ -516,12 +551,32 @@ def test_use_mnemonic_new_device_key(
assert ( assert (
repo.use_mnemonic_new_device_key( repo.use_mnemonic_new_device_key(
device_name="imnew", device_name="imnew",
mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb", mnemonic_phrase=mnemonic_phrase,
) )
is None is None
) )
def test_use_mnemonic_expired_new_device_key(
some_tokens_repo,
):
repo = some_tokens_repo
expiration = datetime.now() - timedelta(minutes=5)
key = repo.get_new_device_key()
assert key is not None
assert key.expires_at is not None
key.expires_at = expiration
assert not key.is_valid()
repo._store_new_device_key(key)
with pytest.raises(NewDeviceKeyNotFound):
token = repo.use_mnemonic_new_device_key(
mnemonic_phrase=mnemonic_from_hex(key.key),
device_name="imnew",
)
def test_use_mnemonic_new_device_key_when_empty(empty_repo): def test_use_mnemonic_new_device_key_when_empty(empty_repo):
repo = empty_repo repo = empty_repo

18
tests/test_models.py Normal file
View file

@ -0,0 +1,18 @@
import pytest
from datetime import datetime, timedelta
from selfprivacy_api.models.tokens.recovery_key import RecoveryKey
from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey
def test_recovery_key_expired():
expiration = datetime.now() - timedelta(minutes=5)
key = RecoveryKey.generate(expiration=expiration, uses_left=2)
assert not key.is_valid()
def test_new_device_key_expired():
expiration = datetime.now() - timedelta(minutes=5)
key = NewDeviceKey.generate()
key.expires_at = expiration
assert not key.is_valid()

View file

@ -5,6 +5,12 @@ import datetime
import pytest import pytest
from mnemonic import Mnemonic from mnemonic import Mnemonic
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
JsonTokensRepository,
)
TOKEN_REPO = JsonTokensRepository()
from tests.common import read_json, write_json from tests.common import read_json, write_json
@ -97,7 +103,7 @@ def test_refresh_token(authorized_client, tokens_file):
response = authorized_client.post("/auth/tokens") response = authorized_client.post("/auth/tokens")
assert response.status_code == 200 assert response.status_code == 200
new_token = response.json()["token"] new_token = response.json()["token"]
assert read_json(tokens_file)["tokens"][0]["token"] == new_token assert TOKEN_REPO.get_token_by_token_string(new_token) is not None
# new device # new device