Merge pull request 'Migrate to AbstractTokenRepository API' (#28) from redis/token-repo into redis/connection-pool
Reviewed-on: https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api/pulls/28
Commit: 45c6133881
@ -5,12 +5,16 @@ name: default
steps:
  - name: Run Tests and Generate Coverage Report
    commands:
      - kill $(ps aux | grep '[r]edis-server 127.0.0.1:6389' | awk '{print $2}')
      - redis-server --bind 127.0.0.1 --port 6389 >/dev/null &
      - coverage run -m pytest -q
      - coverage xml
      - sonar-scanner -Dsonar.projectKey=SelfPrivacy-REST-API -Dsonar.sources=. -Dsonar.host.url=http://analyzer.lan:9000 -Dsonar.login="$SONARQUBE_TOKEN"
    environment:
      SONARQUBE_TOKEN:
        from_secret: SONARQUBE_TOKEN
      USE_REDIS_PORT: 6389

  - name: Run Bandit Checks
    commands:
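The CI step above starts a disposable redis-server on port 6389 and exposes that port to the test run through the USE_REDIS_PORT environment variable. As a rough illustration of how a test helper might pick the port up (a minimal sketch, assuming the redis-py client; the project's own RedisPool may resolve the variable differently):

    import os

    import redis


    def connect_to_test_redis() -> redis.Redis:
        # Fall back to the conventional Redis port when the CI variable is absent.
        port = int(os.environ.get("USE_REDIS_PORT", 6379))
        # decode_responses=True returns str instead of bytes, convenient for tests.
        return redis.Redis(host="127.0.0.1", port=port, decode_responses=True)


    if __name__ == "__main__":
        connection = connect_to_test_redis()
        connection.set("healthcheck", "ok")
        print(connection.get("healthcheck"))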
@ -2,20 +2,19 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
from mnemonic import Mnemonic

from selfprivacy_api.utils.auth import (
    delete_token,
    generate_recovery_token,
    get_recovery_token_status,
    get_tokens_info,
    is_recovery_token_exists,
    is_recovery_token_valid,
    is_token_name_exists,
    is_token_name_pair_valid,
    refresh_token,
    get_token_name,
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
    JsonTokensRepository,
)
from selfprivacy_api.repositories.tokens.exceptions import (
    TokenNotFound,
    RecoveryKeyNotFound,
    InvalidMnemonic,
    NewDeviceKeyNotFound,
)

TOKEN_REPO = JsonTokensRepository()


class TokenInfoWithIsCaller(BaseModel):

@ -28,18 +27,23 @@ class TokenInfoWithIsCaller(BaseModel):

def get_api_tokens_with_caller_flag(caller_token: str) -> list[TokenInfoWithIsCaller]:
    """Get the tokens info"""
    caller_name = get_token_name(caller_token)
    tokens = get_tokens_info()
    caller_name = TOKEN_REPO.get_token_by_token_string(caller_token).device_name
    tokens = TOKEN_REPO.get_tokens()
    return [
        TokenInfoWithIsCaller(
            name=token.name,
            date=token.date,
            is_caller=token.name == caller_name,
            name=token.device_name,
            date=token.created_at,
            is_caller=token.device_name == caller_name,
        )
        for token in tokens
    ]


def is_token_valid(token) -> bool:
    """Check if token is valid"""
    return TOKEN_REPO.is_token_valid(token)


class NotFoundException(Exception):
    """Not found exception"""

@ -50,19 +54,22 @@ class CannotDeleteCallerException(Exception):

def delete_api_token(caller_token: str, token_name: str) -> None:
    """Delete the token"""
    if is_token_name_pair_valid(token_name, caller_token):
    if TOKEN_REPO.is_token_name_pair_valid(token_name, caller_token):
        raise CannotDeleteCallerException("Cannot delete caller's token")
    if not is_token_name_exists(token_name):
    if not TOKEN_REPO.is_token_name_exists(token_name):
        raise NotFoundException("Token not found")
    delete_token(token_name)
    token = TOKEN_REPO.get_token_by_name(token_name)
    TOKEN_REPO.delete_token(token)


def refresh_api_token(caller_token: str) -> str:
    """Refresh the token"""
    new_token = refresh_token(caller_token)
    if new_token is None:
    try:
        old_token = TOKEN_REPO.get_token_by_token_string(caller_token)
        new_token = TOKEN_REPO.refresh_token(old_token)
    except TokenNotFound:
        raise NotFoundException("Token not found")
    return new_token
    return new_token.token


class RecoveryTokenStatus(BaseModel):

@ -77,18 +84,16 @@ class RecoveryTokenStatus(BaseModel):

def get_api_recovery_token_status() -> RecoveryTokenStatus:
    """Get the recovery token status"""
    if not is_recovery_token_exists():
    token = TOKEN_REPO.get_recovery_key()
    if token is None:
        return RecoveryTokenStatus(exists=False, valid=False)
    status = get_recovery_token_status()
    if status is None:
        return RecoveryTokenStatus(exists=False, valid=False)
    is_valid = is_recovery_token_valid()
    is_valid = TOKEN_REPO.is_recovery_key_valid()
    return RecoveryTokenStatus(
        exists=True,
        valid=is_valid,
        date=status["date"],
        expiration=status["expiration"],
        uses_left=status["uses_left"],
        date=token.created_at,
        expiration=token.expires_at,
        uses_left=token.uses_left,
    )
@ -112,5 +117,46 @@ def get_new_api_recovery_key(
    if uses_left <= 0:
        raise InvalidUsesLeft("Uses must be greater than 0")

    key = generate_recovery_token(expiration_date, uses_left)
    return key
    key = TOKEN_REPO.create_recovery_key(expiration_date, uses_left)
    mnemonic_phrase = Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key))
    return mnemonic_phrase


def use_mnemonic_recovery_token(mnemonic_phrase, name):
    """Use the recovery token by converting the mnemonic word list to a byte array.
    If the recovery token itself is invalid, return None.
    If the binary representation of the phrase does not match
    the byte array of the recovery token, return None.
    If the mnemonic phrase is valid, generate a device token and return it.
    Subtract 1 from uses_left if it exists.
    mnemonic_phrase is a string representation of the mnemonic word list.
    """
    try:
        token = TOKEN_REPO.use_mnemonic_recovery_key(mnemonic_phrase, name)
        return token.token
    except (RecoveryKeyNotFound, InvalidMnemonic):
        return None


def delete_new_device_auth_token() -> None:
    TOKEN_REPO.delete_new_device_key()


def get_new_device_auth_token() -> str:
    """Generate and store a new device auth token which is valid for 10 minutes
    and return a mnemonic phrase representation
    """
    key = TOKEN_REPO.get_new_device_key()
    return Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key))


def use_new_device_auth_token(mnemonic_phrase, name) -> Optional[str]:
    """Use the new device auth token by converting the mnemonic string to a byte array.
    If the mnemonic phrase is valid, generate a device token and return it.
    The new device auth token is then deleted.
    """
    try:
        token = TOKEN_REPO.use_mnemonic_new_device_key(mnemonic_phrase, name)
        return token.token
    except (NewDeviceKeyNotFound, InvalidMnemonic):
        return None
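Taken together, the hunks above replace direct calls into utils.auth with a single JsonTokensRepository instance hidden behind the action functions. A minimal sketch of how a caller (REST or GraphQL layer) consumes these actions; the function and field names come from the diff above, while the surrounding error handling is illustrative only:

    from selfprivacy_api.actions.api_tokens import (
        NotFoundException,
        get_api_tokens_with_caller_flag,
        refresh_api_token,
    )


    def rotate_caller_token(caller_token: str) -> str:
        # List all devices, flagging which entry belongs to the calling token.
        for info in get_api_tokens_with_caller_flag(caller_token):
            print(info.name, info.date, info.is_caller)

        # Swap the caller's token string for a fresh one; the repository keeps
        # the device name and creation date.
        try:
            return refresh_api_token(caller_token)
        except NotFoundException:
            raise RuntimeError("caller token is not registered")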
@ -2,7 +2,7 @@ from fastapi import Depends, HTTPException, status
from fastapi.security import APIKeyHeader
from pydantic import BaseModel

from selfprivacy_api.utils.auth import is_token_valid
from selfprivacy_api.actions.api_tokens import is_token_valid


class TokenHeader(BaseModel):
@ -4,7 +4,7 @@ import typing
from strawberry.permission import BasePermission
from strawberry.types import Info

from selfprivacy_api.utils.auth import is_token_valid
from selfprivacy_api.actions.api_tokens import is_token_valid


class IsAuthenticated(BasePermission):
@ -11,6 +11,11 @@ from selfprivacy_api.actions.api_tokens import (
    NotFoundException,
    delete_api_token,
    get_new_api_recovery_key,
    use_mnemonic_recovery_token,
    refresh_api_token,
    delete_new_device_auth_token,
    get_new_device_auth_token,
    use_new_device_auth_token,
)
from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.mutations.mutation_interface import (

@ -18,14 +23,6 @@ from selfprivacy_api.graphql.mutations.mutation_interface import (
    MutationReturnInterface,
)

from selfprivacy_api.utils.auth import (
    delete_new_device_auth_token,
    get_new_device_auth_token,
    refresh_token,
    use_mnemonic_recoverery_token,
    use_new_device_auth_token,
)


@strawberry.type
class ApiKeyMutationReturn(MutationReturnInterface):

@ -98,50 +95,53 @@ class ApiMutations:
        self, input: UseRecoveryKeyInput
    ) -> DeviceApiTokenMutationReturn:
        """Use recovery key"""
        token = use_mnemonic_recoverery_token(input.key, input.deviceName)
        if token is None:
        token = use_mnemonic_recovery_token(input.key, input.deviceName)
        if token is not None:
            return DeviceApiTokenMutationReturn(
                success=True,
                message="Recovery key used",
                code=200,
                token=token,
            )
        else:
            return DeviceApiTokenMutationReturn(
                success=False,
                message="Recovery key not found",
                code=404,
                token=None,
            )
        return DeviceApiTokenMutationReturn(
            success=True,
            message="Recovery key used",
            code=200,
            token=token,
        )

    @strawberry.mutation(permission_classes=[IsAuthenticated])
    def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn:
        """Refresh device api token"""
        token = (
        token_string = (
            info.context["request"]
            .headers.get("Authorization", "")
            .replace("Bearer ", "")
        )
        if token is None:
        if token_string is None:
            return DeviceApiTokenMutationReturn(
                success=False,
                message="Token not found",
                code=404,
                token=None,
            )
        new_token = refresh_token(token)
        if new_token is None:

        try:
            new_token = refresh_api_token(token_string)
            return DeviceApiTokenMutationReturn(
                success=True,
                message="Token refreshed",
                code=200,
                token=new_token,
            )
        except NotFoundException:
            return DeviceApiTokenMutationReturn(
                success=False,
                message="Token not found",
                code=404,
                token=None,
            )
        return DeviceApiTokenMutationReturn(
            success=True,
            message="Token refreshed",
            code=200,
            token=new_token,
        )

    @strawberry.mutation(permission_classes=[IsAuthenticated])
    def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn:
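For reference, the refresh mutation is exercised from the tests below through a plain GraphQL document. The exact selection set used by the project is not shown in this diff, so the query string here is an assumption based on the fields the tests assert on (success, message, code, token):

    # Hypothetical GraphQL document; field names mirror the assertions in the tests.
    REFRESH_TOKEN_MUTATION = """
    mutation RefreshToken {
        refreshDeviceApiToken {
            success
            message
            code
            token
        }
    }
    """


    def refresh_via_graphql(authorized_client):
        response = authorized_client.post(
            "/graphql",
            json={"query": REFRESH_TOKEN_MUTATION},
        )
        return response.json()["data"]["refreshDeviceApiToken"]["token"]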
@ -4,16 +4,12 @@ import datetime
import typing
import strawberry
from strawberry.types import Info
from selfprivacy_api.actions.api_tokens import get_api_tokens_with_caller_flag
from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.utils import parse_date
from selfprivacy_api.dependencies import get_api_version as get_api_version_dependency

from selfprivacy_api.utils.auth import (
    get_recovery_token_status,
    is_recovery_token_exists,
    is_recovery_token_valid,
from selfprivacy_api.actions.api_tokens import (
    get_api_tokens_with_caller_flag,
    get_api_recovery_token_status,
)
from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.dependencies import get_api_version as get_api_version_dependency


def get_api_version() -> str:

@ -43,16 +39,8 @@ class ApiRecoveryKeyStatus:

def get_recovery_key_status() -> ApiRecoveryKeyStatus:
    """Get recovery key status"""
    if not is_recovery_token_exists():
        return ApiRecoveryKeyStatus(
            exists=False,
            valid=False,
            creation_date=None,
            expiration_date=None,
            uses_left=None,
        )
    status = get_recovery_token_status()
    if status is None:
    status = get_api_recovery_token_status()
    if status is None or not status.exists:
        return ApiRecoveryKeyStatus(
            exists=False,
            valid=False,

@ -62,12 +50,10 @@ def get_recovery_key_status() -> ApiRecoveryKeyStatus:
        )
    return ApiRecoveryKeyStatus(
        exists=True,
        valid=is_recovery_token_valid(),
        creation_date=parse_date(status["date"]),
        expiration_date=parse_date(status["expiration"])
        if status["expiration"] is not None
        else None,
        uses_left=status["uses_left"] if status["uses_left"] is not None else None,
        valid=status.valid,
        creation_date=status.date,
        expiration_date=status.expiration,
        uses_left=status.uses_left,
    )
@ -97,8 +97,8 @@ class Jobs:
            error=None,
            result=None,
        )
        r = RedisPool().get_connection()
        _store_job_as_hash(r, _redis_key_from_uuid(job.uid), job)
        redis = RedisPool().get_connection()
        _store_job_as_hash(redis, _redis_key_from_uuid(job.uid), job)
        return job

    @staticmethod

@ -113,10 +113,10 @@ class Jobs:
        """
        Remove a job from the jobs list.
        """
        r = RedisPool().get_connection()
        redis = RedisPool().get_connection()
        key = _redis_key_from_uuid(job_uuid)
        if (r.exists(key)):
            r.delete(key)
        if redis.exists(key):
            redis.delete(key)
            return True
        return False

@ -149,12 +149,12 @@ class Jobs:
        if status in (JobStatus.FINISHED, JobStatus.ERROR):
            job.finished_at = datetime.datetime.now()

        r = RedisPool().get_connection()
        redis = RedisPool().get_connection()
        key = _redis_key_from_uuid(job.uid)
        if r.exists(key):
            _store_job_as_hash(r, key, job)
        if redis.exists(key):
            _store_job_as_hash(redis, key, job)
            if status in (JobStatus.FINISHED, JobStatus.ERROR):
                r.expire(key, JOB_EXPIRATION_SECONDS)
                redis.expire(key, JOB_EXPIRATION_SECONDS)

        return job

@ -163,10 +163,10 @@ class Jobs:
        """
        Get a job from the jobs list.
        """
        r = RedisPool().get_connection()
        redis = RedisPool().get_connection()
        key = _redis_key_from_uuid(uid)
        if r.exists(key):
            return _job_from_hash(r, key)
        if redis.exists(key):
            return _job_from_hash(redis, key)
        return None

    @staticmethod

@ -174,9 +174,14 @@ class Jobs:
        """
        Get the jobs list.
        """
        r = RedisPool().get_connection()
        jobs = r.keys("jobs:*")
        return [_job_from_hash(r, job_key) for job_key in jobs]
        redis = RedisPool().get_connection()
        job_keys = redis.keys("jobs:*")
        jobs = []
        for job_key in job_keys:
            job = _job_from_hash(redis, job_key)
            if job is not None:
                jobs.append(job)
        return jobs

    @staticmethod
    def is_busy() -> bool:

@ -189,11 +194,11 @@ class Jobs:
        return False


def _redis_key_from_uuid(uuid):
    return "jobs:" + str(uuid)
def _redis_key_from_uuid(uuid_string):
    return "jobs:" + str(uuid_string)


def _store_job_as_hash(r, redis_key, model):
def _store_job_as_hash(redis, redis_key, model):
    for key, value in model.dict().items():
        if isinstance(value, uuid.UUID):
            value = str(value)

@ -201,12 +206,12 @@ def _store_job_as_hash(r, redis_key, model):
            value = value.isoformat()
        if isinstance(value, JobStatus):
            value = value.value
        r.hset(redis_key, key, str(value))
        redis.hset(redis_key, key, str(value))


def _job_from_hash(r, redis_key):
    if r.exists(redis_key):
        job_dict = r.hgetall(redis_key)
def _job_from_hash(redis, redis_key):
    if redis.exists(redis_key):
        job_dict = redis.hgetall(redis_key)
        for date in [
            "created_at",
            "updated_at",
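The Jobs changes above are mostly a rename of the connection variable from r to redis, plus a None-guard when rebuilding jobs from hashes. The underlying pattern — one Redis hash per job, field-by-field hset, and a TTL applied once the job reaches a terminal state — can be sketched independently of the project (illustrative only; the key prefix and the TTL value are assumptions for this sketch):

    import uuid
    from datetime import datetime

    import redis

    JOB_EXPIRATION_SECONDS = 600  # assumed terminal-job TTL for this sketch


    def store_job_as_hash(connection: redis.Redis, job_fields: dict) -> str:
        key = "jobs:" + str(uuid.uuid4())
        for field, value in job_fields.items():
            if isinstance(value, datetime):
                value = value.isoformat()
            connection.hset(key, field, str(value))
        return key


    def finish_job(connection: redis.Redis, key: str) -> None:
        connection.hset(key, "status", "FINISHED")
        # Finished and errored jobs are kept only for a limited time.
        connection.expire(key, JOB_EXPIRATION_SECONDS)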
@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
from mnemonic import Mnemonic
from secrets import randbelow
import re

from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.repositories.tokens.exceptions import (

@ -15,7 +17,7 @@ from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey


class AbstractTokensRepository(ABC):
    def get_token_by_token_string(self, token_string: str) -> Optional[Token]:
    def get_token_by_token_string(self, token_string: str) -> Token:
        """Get the token by token"""
        tokens = self.get_tokens()
        for token in tokens:

@ -24,7 +26,7 @@ class AbstractTokensRepository(ABC):

        raise TokenNotFound("Token not found!")

    def get_token_by_name(self, token_name: str) -> Optional[Token]:
    def get_token_by_name(self, token_name: str) -> Token:
        """Get the token by name"""
        tokens = self.get_tokens()
        for token in tokens:

@ -39,7 +41,8 @@ class AbstractTokensRepository(ABC):

    def create_token(self, device_name: str) -> Token:
        """Create new token"""
        new_token = Token.generate(device_name)
        unique_name = self._make_unique_device_name(device_name)
        new_token = Token.generate(unique_name)

        self._store_token(new_token)

@ -52,6 +55,7 @@ class AbstractTokensRepository(ABC):
    def refresh_token(self, input_token: Token) -> Token:
        """Change the token field of the existing token"""
        new_token = Token.generate(device_name=input_token.device_name)
        new_token.created_at = input_token.created_at

        if input_token in self.get_tokens():
            self.delete_token(input_token)

@ -62,22 +66,19 @@ class AbstractTokensRepository(ABC):

    def is_token_valid(self, token_string: str) -> bool:
        """Check if the token is valid"""
        token = self.get_token_by_token_string(token_string)
        if token is None:
            return False
        return True
        return token_string in [token.token for token in self.get_tokens()]

    def is_token_name_exists(self, token_name: str) -> bool:
        """Check if the token name exists"""
        token = self.get_token_by_name(token_name)
        if token is None:
            return False
        return True
        return token_name in [token.device_name for token in self.get_tokens()]

    def is_token_name_pair_valid(self, token_name: str, token_string: str) -> bool:
        """Check if the token name and token are valid"""
        token = self.get_token_by_name(token_name)
        if token is None:
        try:
            token = self.get_token_by_name(token_name)
            if token is None:
                return False
        except TokenNotFound:
            return False
        return token.token == token_string

@ -100,7 +101,12 @@ class AbstractTokensRepository(ABC):
        if not self.is_recovery_key_valid():
            raise RecoveryKeyNotFound("Recovery key not found")

        recovery_hex_key = self.get_recovery_key().key
        recovery_key = self.get_recovery_key()

        if recovery_key is None:
            raise RecoveryKeyNotFound("Recovery key not found")

        recovery_hex_key = recovery_key.key
        if not self._assert_mnemonic(recovery_hex_key, mnemonic_phrase):
            raise RecoveryKeyNotFound("Recovery key not found")

@ -117,9 +123,15 @@ class AbstractTokensRepository(ABC):
            return False
        return recovery_key.is_valid()

    @abstractmethod
    def get_new_device_key(self) -> NewDeviceKey:
        """Creates and returns the new device key"""
        new_device_key = NewDeviceKey.generate()
        self._store_new_device_key(new_device_key)

        return new_device_key

    def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
        """Store new device key directly"""

    @abstractmethod
    def delete_new_device_key(self) -> None:

@ -133,6 +145,9 @@ class AbstractTokensRepository(ABC):
        if not new_device_key:
            raise NewDeviceKeyNotFound

        if not new_device_key.is_valid():
            raise NewDeviceKeyNotFound

        if not self._assert_mnemonic(new_device_key.key, mnemonic_phrase):
            raise NewDeviceKeyNotFound("Phrase is not token!")

@ -153,6 +168,19 @@ class AbstractTokensRepository(ABC):
    def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]:
        """Retrieves new device key that is already stored."""

    def _make_unique_device_name(self, name: str) -> str:
        """Token name must be an alphanumeric string and not empty.
        Replace invalid characters with '_'
        If name exists, add a random number to the end of the name until it is unique.
        """
        if not re.match("^[a-zA-Z0-9]*$", name):
            name = re.sub("[^a-zA-Z0-9]", "_", name)
        if name == "":
            name = "Unknown device"
        while self.is_token_name_exists(name):
            name += str(randbelow(10))
        return name

    # TODO: find a proper place for it
    def _assert_mnemonic(self, hex_key: str, mnemonic_phrase: str):
        """Return true if hex string matches the phrase, false otherwise
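After this change the abstract repository owns device-name deduplication, token refresh, and the recovery/new-device key flows, so concrete backends only have to supply storage primitives. A short usage sketch against a concrete repository, using only methods that appear in this diff; the call sequence itself is illustrative, not part of the PR:

    from selfprivacy_api.repositories.tokens.json_tokens_repository import (
        JsonTokensRepository,
    )

    repo = JsonTokensRepository()

    # Create a device token; an already-used device name gets a digit suffix.
    token = repo.create_token(device_name="laptop")
    assert repo.is_token_valid(token.token)

    # Rotate the token string while keeping device_name and created_at.
    token = repo.refresh_token(token)

    # Issue a recovery key and check its status.
    recovery_key = repo.create_recovery_key(expiration=None, uses_left=3)
    assert repo.is_recovery_key_valid()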
@ -69,7 +69,7 @@ class JsonTokensRepository(AbstractTokensRepository):
        recovery_key = RecoveryKey(
            key=tokens_file["recovery_token"].get("token"),
            created_at=tokens_file["recovery_token"].get("date"),
            expires_at=tokens_file["recovery_token"].get("expitation"),
            expires_at=tokens_file["recovery_token"].get("expiration"),
            uses_left=tokens_file["recovery_token"].get("uses_left"),
        )

@ -85,10 +85,13 @@ class JsonTokensRepository(AbstractTokensRepository):
        recovery_key = RecoveryKey.generate(expiration, uses_left)

        with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
            key_expiration: Optional[str] = None
            if recovery_key.expires_at is not None:
                key_expiration = recovery_key.expires_at.strftime(DATETIME_FORMAT)
            tokens_file["recovery_token"] = {
                "token": recovery_key.key,
                "date": recovery_key.created_at.strftime(DATETIME_FORMAT),
                "expiration": recovery_key.expires_at,
                "expiration": key_expiration,
                "uses_left": recovery_key.uses_left,
            }

@ -98,12 +101,10 @@ class JsonTokensRepository(AbstractTokensRepository):
        """Decrement recovery key use count by one"""
        if self.is_recovery_key_valid():
            with WriteUserData(UserDataFiles.TOKENS) as tokens:
                tokens["recovery_token"]["uses_left"] -= 1

    def get_new_device_key(self) -> NewDeviceKey:
        """Creates and returns the new device key"""
        new_device_key = NewDeviceKey.generate()
                if tokens["recovery_token"]["uses_left"] is not None:
                    tokens["recovery_token"]["uses_left"] -= 1

    def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
        with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
            tokens_file["new_device"] = {
                "token": new_device_key.key,

@ -111,8 +112,6 @@ class JsonTokensRepository(AbstractTokensRepository):
                "expiration": new_device_key.expires_at.strftime(DATETIME_FORMAT),
            }

        return new_device_key

    def delete_new_device_key(self) -> None:
        """Delete the new device key"""
        with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
@ -32,29 +32,34 @@ class RedisTokensRepository(AbstractTokensRepository):

    def get_tokens(self) -> list[Token]:
        """Get the tokens"""
        r = self.connection
        token_keys = r.keys(TOKENS_PREFIX + "*")
        return [self._token_from_hash(key) for key in token_keys]
        redis = self.connection
        token_keys = redis.keys(TOKENS_PREFIX + "*")
        tokens = []
        for key in token_keys:
            token = self._token_from_hash(key)
            if token is not None:
                tokens.append(token)
        return tokens

    def delete_token(self, input_token: Token) -> None:
        """Delete the token"""
        r = self.connection
        redis = self.connection
        key = RedisTokensRepository._token_redis_key(input_token)
        if input_token not in self.get_tokens():
            raise TokenNotFound
        r.delete(key)
        redis.delete(key)

    def reset(self):
        for token in self.get_tokens():
            self.delete_token(token)
        self.delete_new_device_key()
        r = self.connection
        r.delete(RECOVERY_KEY_REDIS_KEY)
        redis = self.connection
        redis.delete(RECOVERY_KEY_REDIS_KEY)

    def get_recovery_key(self) -> Optional[RecoveryKey]:
        """Get the recovery key"""
        r = self.connection
        if r.exists(RECOVERY_KEY_REDIS_KEY):
        redis = self.connection
        if redis.exists(RECOVERY_KEY_REDIS_KEY):
            return self._recovery_key_from_hash(RECOVERY_KEY_REDIS_KEY)
        return None

@ -68,16 +73,14 @@ class RedisTokensRepository(AbstractTokensRepository):
        self._store_model_as_hash(RECOVERY_KEY_REDIS_KEY, recovery_key)
        return recovery_key

    def get_new_device_key(self) -> NewDeviceKey:
        """Creates and returns the new device key"""
        new_device_key = NewDeviceKey.generate()
    def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
        """Store new device key directly"""
        self._store_model_as_hash(NEW_DEVICE_KEY_REDIS_KEY, new_device_key)
        return new_device_key

    def delete_new_device_key(self) -> None:
        """Delete the new device key"""
        r = self.connection
        r.delete(NEW_DEVICE_KEY_REDIS_KEY)
        redis = self.connection
        redis.delete(NEW_DEVICE_KEY_REDIS_KEY)

    @staticmethod
    def _token_redis_key(token: Token) -> str:

@ -91,9 +94,13 @@ class RedisTokensRepository(AbstractTokensRepository):
    def _decrement_recovery_token(self):
        """Decrement recovery key use count by one"""
        if self.is_recovery_key_valid():
            uses_left = self.get_recovery_key().uses_left
            r = self.connection
            r.hset(RECOVERY_KEY_REDIS_KEY, "uses_left", uses_left - 1)
            recovery_key = self.get_recovery_key()
            if recovery_key is None:
                return
            uses_left = recovery_key.uses_left
            if uses_left is not None:
                redis = self.connection
                redis.hset(RECOVERY_KEY_REDIS_KEY, "uses_left", uses_left - 1)

    def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]:
        """Retrieves new device key that is already stored."""

@ -117,9 +124,9 @@ class RedisTokensRepository(AbstractTokensRepository):
            d[key] = None

    def _model_dict_from_hash(self, redis_key: str) -> Optional[dict]:
        r = self.connection
        if r.exists(redis_key):
            token_dict = r.hgetall(redis_key)
        redis = self.connection
        if redis.exists(redis_key):
            token_dict = redis.hgetall(redis_key)
            RedisTokensRepository._prepare_model_dict(token_dict)
            return token_dict
        return None

@ -140,8 +147,8 @@ class RedisTokensRepository(AbstractTokensRepository):
        return self._hash_as_model(redis_key, NewDeviceKey)

    def _store_model_as_hash(self, redis_key, model):
        r = self.connection
        redis = self.connection
        for key, value in model.dict().items():
            if isinstance(value, datetime):
                value = value.isoformat()
            r.hset(redis_key, key, str(value))
            redis.hset(redis_key, key, str(value))
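The Redis backend serialises each pydantic model into a hash: datetimes become ISO strings on the way in, and _model_dict_from_hash together with _prepare_model_dict undo that on the way out (mapping the literal string "None" back to None). A self-contained sketch of that read path under the same assumptions; Device is a stand-in model, not one of the project's:

    from datetime import datetime
    from typing import Optional

    import redis
    from pydantic import BaseModel


    class Device(BaseModel):
        name: str
        created_at: datetime
        expires_at: Optional[datetime] = None


    def model_from_hash(connection: redis.Redis, key: str) -> Optional[Device]:
        if not connection.exists(key):
            return None
        raw = connection.hgetall(key)
        # Undo the string encoding used when the hash was written.
        for field, value in raw.items():
            if value == "None":
                raw[field] = None
        return Device(**raw)  # pydantic parses ISO datetime strings itself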
@ -8,20 +8,18 @@ from selfprivacy_api.actions.api_tokens import (
    InvalidUsesLeft,
    NotFoundException,
    delete_api_token,
    refresh_api_token,
    get_api_recovery_token_status,
    get_api_tokens_with_caller_flag,
    get_new_api_recovery_key,
    refresh_api_token,
    use_mnemonic_recovery_token,
    delete_new_device_auth_token,
    get_new_device_auth_token,
    use_new_device_auth_token,
)

from selfprivacy_api.dependencies import TokenHeader, get_token_header

from selfprivacy_api.utils.auth import (
    delete_new_device_auth_token,
    get_new_device_auth_token,
    use_mnemonic_recoverery_token,
    use_new_device_auth_token,
)

router = APIRouter(
    prefix="/auth",

@ -99,7 +97,7 @@ class UseTokenInput(BaseModel):

@router.post("/recovery_token/use")
async def rest_use_recovery_token(input: UseTokenInput):
    token = use_mnemonic_recoverery_token(input.token, input.device)
    token = use_mnemonic_recovery_token(input.token, input.device)
    if token is None:
        raise HTTPException(status_code=404, detail="Token not found")
    return {"token": token}
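The REST hunk above only swaps in the correctly spelled action; the endpoint contract stays the same. An illustrative client call against it — the field names token and device come from UseTokenInput in the diff, while the host URL is a placeholder:

    import requests

    response = requests.post(
        "http://localhost:5050/auth/recovery_token/use",
        json={
            "token": "uniform clarify napkin ...",  # mnemonic recovery phrase
            "device": "new-laptop",
        },
    )
    if response.status_code == 404:
        print("Recovery token not found or invalid")
    else:
        device_token = response.json()["token"]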
@ -1,329 +0,0 @@
#!/usr/bin/env python3
"""Token management utils"""
import secrets
from datetime import datetime, timedelta
import re
import typing

from pydantic import BaseModel
from mnemonic import Mnemonic

from . import ReadUserData, UserDataFiles, WriteUserData, parse_date

"""
Token are stored in the tokens.json file.
File contains device tokens, recovery token and new device auth token.
File structure:
{
    "tokens": [
        {
            "token": "device token",
            "name": "device name",
            "date": "date of creation",
        }
    ],
    "recovery_token": {
        "token": "recovery token",
        "date": "date of creation",
        "expiration": "date of expiration",
        "uses_left": "number of uses left"
    },
    "new_device": {
        "token": "new device auth token",
        "date": "date of creation",
        "expiration": "date of expiration",
    }
}
Recovery token may or may not have expiration date and uses_left.
There may be no recovery token at all.
Device tokens must be unique.
"""


def _get_tokens():
    """Get all tokens as list of tokens of every device"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        return [token["token"] for token in tokens["tokens"]]


def _get_token_names():
    """Get all token names"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        return [t["name"] for t in tokens["tokens"]]


def _validate_token_name(name):
    """Token name must be an alphanumeric string and not empty.
    Replace invalid characters with '_'
    If token name exists, add a random number to the end of the name until it is unique.
    """
    if not re.match("^[a-zA-Z0-9]*$", name):
        name = re.sub("[^a-zA-Z0-9]", "_", name)
    if name == "":
        name = "Unknown device"
    while name in _get_token_names():
        name += str(secrets.randbelow(10))
    return name


def is_token_valid(token):
    """Check if token is valid"""
    if token in _get_tokens():
        return True
    return False


def is_token_name_exists(token_name):
    """Check if token name exists"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        return token_name in [t["name"] for t in tokens["tokens"]]


def is_token_name_pair_valid(token_name, token):
    """Check if token name and token pair exists"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        for t in tokens["tokens"]:
            if t["name"] == token_name and t["token"] == token:
                return True
        return False


def get_token_name(token: str) -> typing.Optional[str]:
    """Return the name of the token provided"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        for t in tokens["tokens"]:
            if t["token"] == token:
                return t["name"]
        return None


class BasicTokenInfo(BaseModel):
    """Token info"""

    name: str
    date: datetime


def get_tokens_info():
    """Get all tokens info without tokens themselves"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        return [
            BasicTokenInfo(
                name=t["name"],
                date=parse_date(t["date"]),
            )
            for t in tokens["tokens"]
        ]


def _generate_token():
    """Generates new token and makes sure it is unique"""
    token = secrets.token_urlsafe(32)
    while token in _get_tokens():
        token = secrets.token_urlsafe(32)
    return token


def create_token(name):
    """Create new token"""
    token = _generate_token()
    name = _validate_token_name(name)
    with WriteUserData(UserDataFiles.TOKENS) as tokens:
        tokens["tokens"].append(
            {
                "token": token,
                "name": name,
                "date": str(datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")),
            }
        )
    return token


def delete_token(token_name):
    """Delete token"""
    with WriteUserData(UserDataFiles.TOKENS) as tokens:
        tokens["tokens"] = [t for t in tokens["tokens"] if t["name"] != token_name]


def refresh_token(token: str) -> typing.Optional[str]:
    """Change the token field of the existing token"""
    new_token = _generate_token()
    with WriteUserData(UserDataFiles.TOKENS) as tokens:
        for t in tokens["tokens"]:
            if t["token"] == token:
                t["token"] = new_token
                return new_token
    return None


def is_recovery_token_exists():
    """Check if recovery token exists"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        return "recovery_token" in tokens


def is_recovery_token_valid():
    """Check if recovery token is valid"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        if "recovery_token" not in tokens:
            return False
        recovery_token = tokens["recovery_token"]
        if "uses_left" in recovery_token and recovery_token["uses_left"] is not None:
            if recovery_token["uses_left"] <= 0:
                return False
        if "expiration" not in recovery_token or recovery_token["expiration"] is None:
            return True
        return datetime.now() < parse_date(recovery_token["expiration"])


def get_recovery_token_status():
    """Get recovery token date of creation, expiration and uses left"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        if "recovery_token" not in tokens:
            return None
        recovery_token = tokens["recovery_token"]
        return {
            "date": recovery_token["date"],
            "expiration": recovery_token["expiration"]
            if "expiration" in recovery_token
            else None,
            "uses_left": recovery_token["uses_left"]
            if "uses_left" in recovery_token
            else None,
        }


def _get_recovery_token():
    """Get recovery token"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        if "recovery_token" not in tokens:
            return None
        return tokens["recovery_token"]["token"]


def generate_recovery_token(
    expiration: typing.Optional[datetime], uses_left: typing.Optional[int]
) -> str:
    """Generate a 24 bytes recovery token and return a mneomnic word list.
    Write a string representation of the recovery token to the tokens.json file.
    """
    # expires must be a date or None
    # uses_left must be an integer or None
    if expiration is not None:
        if not isinstance(expiration, datetime):
            raise TypeError("expires must be a datetime object")
    if uses_left is not None:
        if not isinstance(uses_left, int):
            raise TypeError("uses_left must be an integer")
        if uses_left <= 0:
            raise ValueError("uses_left must be greater than 0")

    recovery_token = secrets.token_bytes(24)
    recovery_token_str = recovery_token.hex()
    with WriteUserData(UserDataFiles.TOKENS) as tokens:
        tokens["recovery_token"] = {
            "token": recovery_token_str,
            "date": str(datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")),
            "expiration": expiration.strftime("%Y-%m-%dT%H:%M:%S.%f")
            if expiration is not None
            else None,
            "uses_left": uses_left if uses_left is not None else None,
        }
    return Mnemonic(language="english").to_mnemonic(recovery_token)


def use_mnemonic_recoverery_token(mnemonic_phrase, name):
    """Use the recovery token by converting the mnemonic word list to a byte array.
    If the recovery token if invalid itself, return None
    If the binary representation of phrase not matches
    the byte array of the recovery token, return None.
    If the mnemonic phrase is valid then generate a device token and return it.
    Substract 1 from uses_left if it exists.
    mnemonic_phrase is a string representation of the mnemonic word list.
    """
    if not is_recovery_token_valid():
        return None
    recovery_token_str = _get_recovery_token()
    if recovery_token_str is None:
        return None
    recovery_token = bytes.fromhex(recovery_token_str)
    if not Mnemonic(language="english").check(mnemonic_phrase):
        return None
    phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase)
    if phrase_bytes != recovery_token:
        return None
    token = _generate_token()
    name = _validate_token_name(name)
    with WriteUserData(UserDataFiles.TOKENS) as tokens:
        tokens["tokens"].append(
            {
                "token": token,
                "name": name,
                "date": str(datetime.now()),
            }
        )
        if "recovery_token" in tokens:
            if (
                "uses_left" in tokens["recovery_token"]
                and tokens["recovery_token"]["uses_left"] is not None
            ):
                tokens["recovery_token"]["uses_left"] -= 1
    return token


def get_new_device_auth_token() -> str:
    """Generate a new device auth token which is valid for 10 minutes
    and return a mnemonic phrase representation
    Write token to the new_device of the tokens.json file.
    """
    token = secrets.token_bytes(16)
    token_str = token.hex()
    with WriteUserData(UserDataFiles.TOKENS) as tokens:
        tokens["new_device"] = {
            "token": token_str,
            "date": str(datetime.now()),
            "expiration": str(datetime.now() + timedelta(minutes=10)),
        }
    return Mnemonic(language="english").to_mnemonic(token)


def _get_new_device_auth_token():
    """Get new device auth token. If it is expired, return None"""
    with ReadUserData(UserDataFiles.TOKENS) as tokens:
        if "new_device" not in tokens:
            return None
        new_device = tokens["new_device"]
        if "expiration" not in new_device:
            return None
        expiration = parse_date(new_device["expiration"])
        if datetime.now() > expiration:
            return None
        return new_device["token"]


def delete_new_device_auth_token():
    """Delete new device auth token"""
    with WriteUserData(UserDataFiles.TOKENS) as tokens:
        if "new_device" in tokens:
            del tokens["new_device"]


def use_new_device_auth_token(mnemonic_phrase, name):
    """Use the new device auth token by converting the mnemonic string to a byte array.
    If the mnemonic phrase is valid then generate a device token and return it.
    New device auth token must be deleted.
    """
    token_str = _get_new_device_auth_token()
    if token_str is None:
        return None
    token = bytes.fromhex(token_str)
    if not Mnemonic(language="english").check(mnemonic_phrase):
        return None
    phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase)
    if phrase_bytes != token:
        return None
    token = create_token(name)
    with WriteUserData(UserDataFiles.TOKENS) as tokens:
        if "new_device" in tokens:
            del tokens["new_device"]
    return token
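Both the deleted module above and its repository-based replacement lean on the same mnemonic encoding: a random key is stored as hex, shown to the user as a BIP-39 phrase, and later checked by converting the phrase back to bytes. The round trip in isolation, a minimal sketch using the mnemonic package exactly as the code above does:

    import secrets

    from mnemonic import Mnemonic

    mnemo = Mnemonic(language="english")

    key_bytes = secrets.token_bytes(24)           # recovery keys above use 24 bytes
    hex_key = key_bytes.hex()                     # what gets persisted
    phrase = mnemo.to_mnemonic(key_bytes)         # what the user sees

    # Later, when the user submits the phrase:
    assert mnemo.check(phrase)                    # valid word list and checksum
    assert mnemo.to_entropy(phrase) == key_bytes  # matches the stored key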
@ -2,8 +2,14 @@
# pylint: disable=unused-argument
# pylint: disable=missing-function-docstring
import datetime
import pytest
from mnemonic import Mnemonic

from selfprivacy_api.repositories.tokens.json_tokens_repository import (
    JsonTokensRepository,
)
from selfprivacy_api.models.tokens.token import Token

from tests.common import generate_api_query, read_json, write_json

TOKENS_FILE_CONTETS = {

@ -30,6 +36,11 @@ devices {
"""


@pytest.fixture
def token_repo():
    return JsonTokensRepository()


def test_graphql_tokens_info(authorized_client, tokens_file):
    response = authorized_client.post(
        "/graphql",

@ -170,7 +181,7 @@ def test_graphql_refresh_token_unauthorized(client, tokens_file):
    assert response.json()["data"] is None


def test_graphql_refresh_token(authorized_client, tokens_file):
def test_graphql_refresh_token(authorized_client, tokens_file, token_repo):
    response = authorized_client.post(
        "/graphql",
        json={"query": REFRESH_TOKEN_MUTATION},

@ -180,11 +191,12 @@ def test_graphql_refresh_token(authorized_client, tokens_file):
    assert response.json()["data"]["refreshDeviceApiToken"]["success"] is True
    assert response.json()["data"]["refreshDeviceApiToken"]["message"] is not None
    assert response.json()["data"]["refreshDeviceApiToken"]["code"] == 200
    assert read_json(tokens_file)["tokens"][0] == {
        "token": response.json()["data"]["refreshDeviceApiToken"]["token"],
        "name": "test_token",
        "date": "2022-01-14 08:31:10.789314",
    }
    token = token_repo.get_token_by_name("test_token")
    assert token == Token(
        token=response.json()["data"]["refreshDeviceApiToken"]["token"],
        device_name="test_token",
        created_at=datetime.datetime(2022, 1, 14, 8, 31, 10, 789314),
    )


NEW_DEVICE_KEY_MUTATION = """
@ -2,7 +2,8 @@
# pylint: disable=unused-argument
# pylint: disable=missing-function-docstring

from datetime import datetime
from datetime import datetime, timedelta
from mnemonic import Mnemonic

import pytest

@ -32,6 +33,10 @@ ORIGINAL_DEVICE_NAMES = [
]


def mnemonic_from_hex(hexkey):
    return Mnemonic(language="english").to_mnemonic(bytes.fromhex(hexkey))


@pytest.fixture
def empty_keys(mocker, datadir):
    mocker.patch("selfprivacy_api.utils.TOKENS_FILE", new=datadir / "empty_keys.json")

@ -132,21 +137,6 @@ def mock_recovery_key_generate(mocker):
    return mock


@pytest.fixture
def mock_recovery_key_generate_for_mnemonic(mocker):
    mock = mocker.patch(
        "selfprivacy_api.models.tokens.recovery_key.RecoveryKey.generate",
        autospec=True,
        return_value=RecoveryKey(
            key="ed653e4b8b042b841d285fa7a682fa09e925ddb2d8906f54",
            created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
            expires_at=None,
            uses_left=1,
        ),
    )
    return mock


@pytest.fixture
def empty_json_repo(empty_keys):
    repo = JsonTokensRepository()

@ -221,6 +211,28 @@ def test_get_token_by_non_existent_name(some_tokens_repo):
    assert repo.get_token_by_name(token_name="badname") is None


def test_is_token_valid(some_tokens_repo):
    repo = some_tokens_repo
    token = repo.get_tokens()[0]
    assert repo.is_token_valid(token.token)
    assert not repo.is_token_valid("gibberish")


def test_is_token_name_pair_valid(some_tokens_repo):
    repo = some_tokens_repo
    token = repo.get_tokens()[0]
    assert repo.is_token_name_pair_valid(token.device_name, token.token)
    assert not repo.is_token_name_pair_valid(token.device_name, "gibberish")
    assert not repo.is_token_name_pair_valid("gibberish", token.token)


def test_is_token_name_exists(some_tokens_repo):
    repo = some_tokens_repo
    token = repo.get_tokens()[0]
    assert repo.is_token_name_exists(token.device_name)
    assert not repo.is_token_name_exists("gibberish")


def test_get_tokens(some_tokens_repo):
    repo = some_tokens_repo
    tokenstrings = []

@ -249,6 +261,17 @@ def test_create_token(empty_repo, mock_token_generate):
    ]


def test_create_token_existing(some_tokens_repo):
    repo = some_tokens_repo
    old_token = repo.get_tokens()[0]

    new_token = repo.create_token(device_name=old_token.device_name)
    assert new_token.device_name != old_token.device_name

    assert old_token in repo.get_tokens()
    assert new_token in repo.get_tokens()


def test_delete_token(some_tokens_repo):
    repo = some_tokens_repo
    original_tokens = repo.get_tokens()

@ -280,15 +303,17 @@ def test_delete_not_found_token(some_tokens_repo):
    assert token in new_tokens


def test_refresh_token(some_tokens_repo, mock_token_generate):
def test_refresh_token(some_tokens_repo):
    repo = some_tokens_repo
    input_token = some_tokens_repo.get_tokens()[0]

    assert repo.refresh_token(input_token) == Token(
        token="ZuLNKtnxDeq6w2dpOJhbB3iat_sJLPTPl_rN5uc5MvM",
        device_name="IamNewDevice",
        created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
    )
    output_token = repo.refresh_token(input_token)

    assert output_token.token != input_token.token
    assert output_token.device_name == input_token.device_name
    assert output_token.created_at == input_token.created_at

    assert output_token in repo.get_tokens()


def test_refresh_not_found_token(some_tokens_repo, mock_token_generate):

@ -355,6 +380,23 @@ def test_use_mnemonic_not_valid_recovery_key(
    )


def test_use_mnemonic_expired_recovery_key(
    some_tokens_repo,
):
    repo = some_tokens_repo
    expiration = datetime.now() - timedelta(minutes=5)
    assert repo.create_recovery_key(uses_left=2, expiration=expiration) is not None
    recovery_key = repo.get_recovery_key()
    assert recovery_key.expires_at == expiration
    assert not repo.is_recovery_key_valid()

    with pytest.raises(RecoveryKeyNotFound):
        token = repo.use_mnemonic_recovery_key(
            mnemonic_phrase=mnemonic_from_hex(recovery_key.key),
            device_name="newdevice",
        )


def test_use_mnemonic_not_mnemonic_recovery_key(some_tokens_repo):
    repo = some_tokens_repo
    assert repo.create_recovery_key(uses_left=1, expiration=None) is not None

@ -397,46 +439,38 @@ def test_use_not_found_mnemonic_recovery_key(some_tokens_repo):
    )


def test_use_mnemonic_recovery_key_when_empty(empty_repo):
    repo = empty_repo

    with pytest.raises(RecoveryKeyNotFound):
        assert (
            repo.use_mnemonic_recovery_key(
                mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb",
                device_name="primary_token",
            )
            is None
        )
@pytest.fixture(params=["recovery_uses_1", "recovery_eternal"])
def recovery_key_uses_left(request):
    if request.param == "recovery_uses_1":
        return 1
    if request.param == "recovery_eternal":
        return None


# agnostic test mixed with an implementation test
def test_use_mnemonic_recovery_key(
    some_tokens_repo, mock_recovery_key_generate_for_mnemonic, mock_generate_token
):
def test_use_mnemonic_recovery_key(some_tokens_repo, recovery_key_uses_left):
    repo = some_tokens_repo
    assert repo.create_recovery_key(uses_left=1, expiration=None) is not None

    test_token = Token(
        token="ur71mC4aiI6FIYAN--cTL-38rPHS5D6NuB1bgN_qKF4",
        device_name="newdevice",
        created_at=datetime(2022, 11, 14, 6, 6, 32, 777123),
    )

    assert (
        repo.use_mnemonic_recovery_key(
            mnemonic_phrase="uniform clarify napkin bid dress search input armor police cross salon because myself uphold slice bamboo hungry park",
            device_name="newdevice",
        )
        == test_token
        repo.create_recovery_key(uses_left=recovery_key_uses_left, expiration=None)
        is not None
    )
    assert repo.is_recovery_key_valid()
    recovery_key = repo.get_recovery_key()

    token = repo.use_mnemonic_recovery_key(
        mnemonic_phrase=mnemonic_from_hex(recovery_key.key),
        device_name="newdevice",
    )

    assert test_token in repo.get_tokens()
    assert token.device_name == "newdevice"
    assert token in repo.get_tokens()
    new_uses = None
    if recovery_key_uses_left is not None:
        new_uses = recovery_key_uses_left - 1
    assert repo.get_recovery_key() == RecoveryKey(
        key="ed653e4b8b042b841d285fa7a682fa09e925ddb2d8906f54",
        created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
        key=recovery_key.key,
        created_at=recovery_key.created_at,
        expires_at=None,
        uses_left=0,
        uses_left=new_uses,
    )


@ -497,15 +531,16 @@ def test_use_not_exists_mnemonic_new_device_key(
    )


def test_use_mnemonic_new_device_key(
    empty_repo, mock_new_device_key_generate_for_mnemonic
):
def test_use_mnemonic_new_device_key(empty_repo):
    repo = empty_repo
    assert repo.get_new_device_key() is not None
    key = repo.get_new_device_key()
    assert key is not None

    mnemonic_phrase = mnemonic_from_hex(key.key)

    new_token = repo.use_mnemonic_new_device_key(
        device_name="imnew",
        mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb",
        mnemonic_phrase=mnemonic_phrase,
    )

    assert new_token.device_name == "imnew"

@ -516,12 +551,32 @@ def test_use_mnemonic_new_device_key(
    assert (
        repo.use_mnemonic_new_device_key(
            device_name="imnew",
            mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb",
            mnemonic_phrase=mnemonic_phrase,
        )
        is None
    )


def test_use_mnemonic_expired_new_device_key(
    some_tokens_repo,
):
    repo = some_tokens_repo
    expiration = datetime.now() - timedelta(minutes=5)

    key = repo.get_new_device_key()
    assert key is not None
    assert key.expires_at is not None
    key.expires_at = expiration
    assert not key.is_valid()
    repo._store_new_device_key(key)

    with pytest.raises(NewDeviceKeyNotFound):
        token = repo.use_mnemonic_new_device_key(
            mnemonic_phrase=mnemonic_from_hex(key.key),
            device_name="imnew",
        )


def test_use_mnemonic_new_device_key_when_empty(empty_repo):
    repo = empty_repo
tests/test_models.py (new file, 18 lines)

@ -0,0 +1,18 @@
import pytest
from datetime import datetime, timedelta

from selfprivacy_api.models.tokens.recovery_key import RecoveryKey
from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey


def test_recovery_key_expired():
    expiration = datetime.now() - timedelta(minutes=5)
    key = RecoveryKey.generate(expiration=expiration, uses_left=2)
    assert not key.is_valid()


def test_new_device_key_expired():
    expiration = datetime.now() - timedelta(minutes=5)
    key = NewDeviceKey.generate()
    key.expires_at = expiration
    assert not key.is_valid()
@ -5,6 +5,12 @@ import datetime
import pytest
from mnemonic import Mnemonic

from selfprivacy_api.repositories.tokens.json_tokens_repository import (
    JsonTokensRepository,
)

TOKEN_REPO = JsonTokensRepository()

from tests.common import read_json, write_json


@ -97,7 +103,7 @@ def test_refresh_token(authorized_client, tokens_file):
    response = authorized_client.post("/auth/tokens")
    assert response.status_code == 200
    new_token = response.json()["token"]
    assert read_json(tokens_file)["tokens"][0]["token"] == new_token
    assert TOKEN_REPO.get_token_by_token_string(new_token) is not None


# new device