feat: removed unnecessary functionality from the repository

This commit is contained in:
dettlaff 2024-11-03 03:15:51 +04:00
parent 516cb781dc
commit fbc0ae61fb
15 changed files with 261 additions and 295 deletions

View file

@ -9,9 +9,7 @@ from pydantic import BaseModel
from mnemonic import Mnemonic from mnemonic import Mnemonic
from selfprivacy_api.utils.timeutils import ensure_tz_aware, ensure_tz_aware_strict from selfprivacy_api.utils.timeutils import ensure_tz_aware, ensure_tz_aware_strict
from selfprivacy_api.repositories.tokens.redis_tokens_repository import ( from selfprivacy_api.repositories.tokens import ACTIVE_TOKEN_PROVIDER
RedisTokensRepository,
)
from selfprivacy_api.repositories.tokens.exceptions import ( from selfprivacy_api.repositories.tokens.exceptions import (
TokenNotFound, TokenNotFound,
RecoveryKeyNotFound, RecoveryKeyNotFound,
@ -19,8 +17,6 @@ from selfprivacy_api.repositories.tokens.exceptions import (
NewDeviceKeyNotFound, NewDeviceKeyNotFound,
) )
TOKEN_REPO = RedisTokensRepository()
class TokenInfoWithIsCaller(BaseModel): class TokenInfoWithIsCaller(BaseModel):
"""Token info""" """Token info"""
@ -40,8 +36,10 @@ def _naive(date_time: datetime) -> datetime:
def get_api_tokens_with_caller_flag(caller_token: str) -> list[TokenInfoWithIsCaller]: def get_api_tokens_with_caller_flag(caller_token: str) -> list[TokenInfoWithIsCaller]:
"""Get the tokens info""" """Get the tokens info"""
caller_name = TOKEN_REPO.get_token_by_token_string(caller_token).device_name caller_name = ACTIVE_TOKEN_PROVIDER.get_token_by_token_string(
tokens = TOKEN_REPO.get_tokens() caller_token
).device_name
tokens = ACTIVE_TOKEN_PROVIDER.get_tokens()
return [ return [
TokenInfoWithIsCaller( TokenInfoWithIsCaller(
name=token.device_name, name=token.device_name,
@ -54,7 +52,7 @@ def get_api_tokens_with_caller_flag(caller_token: str) -> list[TokenInfoWithIsCa
def is_token_valid(token) -> bool: def is_token_valid(token) -> bool:
"""Check if token is valid""" """Check if token is valid"""
return TOKEN_REPO.is_token_valid(token) return ACTIVE_TOKEN_PROVIDER.is_token_valid(token)
class NotFoundException(Exception): class NotFoundException(Exception):
@ -67,19 +65,19 @@ class CannotDeleteCallerException(Exception):
def delete_api_token(caller_token: str, token_name: str) -> None: def delete_api_token(caller_token: str, token_name: str) -> None:
"""Delete the token""" """Delete the token"""
if TOKEN_REPO.is_token_name_pair_valid(token_name, caller_token): if ACTIVE_TOKEN_PROVIDER.is_token_name_pair_valid(token_name, caller_token):
raise CannotDeleteCallerException("Cannot delete caller's token") raise CannotDeleteCallerException("Cannot delete caller's token")
if not TOKEN_REPO.is_token_name_exists(token_name): if not ACTIVE_TOKEN_PROVIDER.is_token_name_exists(token_name):
raise NotFoundException("Token not found") raise NotFoundException("Token not found")
token = TOKEN_REPO.get_token_by_name(token_name) token = ACTIVE_TOKEN_PROVIDER.get_token_by_name(token_name)
TOKEN_REPO.delete_token(token) ACTIVE_TOKEN_PROVIDER.delete_token(token)
def refresh_api_token(caller_token: str) -> str: def refresh_api_token(caller_token: str) -> str:
"""Refresh the token""" """Refresh the token"""
try: try:
old_token = TOKEN_REPO.get_token_by_token_string(caller_token) old_token = ACTIVE_TOKEN_PROVIDER.get_token_by_token_string(caller_token)
new_token = TOKEN_REPO.refresh_token(old_token) new_token = ACTIVE_TOKEN_PROVIDER.refresh_token(old_token)
except TokenNotFound: except TokenNotFound:
raise NotFoundException("Token not found") raise NotFoundException("Token not found")
return new_token.token return new_token.token
@ -97,10 +95,10 @@ class RecoveryTokenStatus(BaseModel):
def get_api_recovery_token_status() -> RecoveryTokenStatus: def get_api_recovery_token_status() -> RecoveryTokenStatus:
"""Get the recovery token status, timezone-aware""" """Get the recovery token status, timezone-aware"""
token = TOKEN_REPO.get_recovery_key() token = ACTIVE_TOKEN_PROVIDER.get_recovery_key()
if token is None: if token is None:
return RecoveryTokenStatus(exists=False, valid=False) return RecoveryTokenStatus(exists=False, valid=False)
is_valid = TOKEN_REPO.is_recovery_key_valid() is_valid = ACTIVE_TOKEN_PROVIDER.is_recovery_key_valid()
# New tokens are tz-aware, but older ones might not be # New tokens are tz-aware, but older ones might not be
expiry_date = token.expires_at expiry_date = token.expires_at
@ -137,7 +135,7 @@ def get_new_api_recovery_key(
if uses_left <= 0: if uses_left <= 0:
raise InvalidUsesLeft("Uses must be greater than 0") raise InvalidUsesLeft("Uses must be greater than 0")
key = TOKEN_REPO.create_recovery_key(expiration_date, uses_left) key = ACTIVE_TOKEN_PROVIDER.create_recovery_key(expiration_date, uses_left)
mnemonic_phrase = Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key)) mnemonic_phrase = Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key))
return mnemonic_phrase return mnemonic_phrase
@ -152,21 +150,21 @@ def use_mnemonic_recovery_token(mnemonic_phrase, name):
mnemonic_phrase is a string representation of the mnemonic word list. mnemonic_phrase is a string representation of the mnemonic word list.
""" """
try: try:
token = TOKEN_REPO.use_mnemonic_recovery_key(mnemonic_phrase, name) token = ACTIVE_TOKEN_PROVIDER.use_mnemonic_recovery_key(mnemonic_phrase, name)
return token.token return token.token
except (RecoveryKeyNotFound, InvalidMnemonic): except (RecoveryKeyNotFound, InvalidMnemonic):
return None return None
def delete_new_device_auth_token() -> None: def delete_new_device_auth_token() -> None:
TOKEN_REPO.delete_new_device_key() ACTIVE_TOKEN_PROVIDER.delete_new_device_key()
def get_new_device_auth_token() -> str: def get_new_device_auth_token() -> str:
"""Generate and store a new device auth token which is valid for 10 minutes """Generate and store a new device auth token which is valid for 10 minutes
and return a mnemonic phrase representation and return a mnemonic phrase representation
""" """
key = TOKEN_REPO.get_new_device_key() key = ACTIVE_TOKEN_PROVIDER.get_new_device_key()
return Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key)) return Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key))
@ -176,7 +174,7 @@ def use_new_device_auth_token(mnemonic_phrase, name) -> Optional[str]:
New device auth token must be deleted. New device auth token must be deleted.
""" """
try: try:
token = TOKEN_REPO.use_mnemonic_new_device_key(mnemonic_phrase, name) token = ACTIVE_TOKEN_PROVIDER.use_mnemonic_new_device_key(mnemonic_phrase, name)
return token.token return token.token
except (NewDeviceKeyNotFound, InvalidMnemonic): except (NewDeviceKeyNotFound, InvalidMnemonic):
return None return None

View file

@ -2,12 +2,10 @@
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from selfprivacy_api.actions.users import (
UserNotFound,
ensure_ssh_and_users_fields_exist,
)
from selfprivacy_api.utils import WriteUserData, ReadUserData, validate_ssh_public_key from selfprivacy_api.utils import WriteUserData, ReadUserData, validate_ssh_public_key
from selfprivacy_api.repositories.users.exceptions import UserNotFound
from selfprivacy_api.utils import ensure_ssh_and_users_fields_exist
def enable_ssh(): def enable_ssh():

View file

@ -2,78 +2,27 @@
import re import re
from typing import Optional from typing import Optional
from selfprivacy_api.utils import (
ReadUserData,
WriteUserData,
hash_password,
is_username_forbidden,
)
from selfprivacy_api.repositories.users.abstract_user_repository import ( from selfprivacy_api.models.user import UserDataUser
UserDataUser,
UserDataUserOrigin,
)
from selfprivacy_api.utils import hash_password, is_username_forbidden
from selfprivacy_api.repositories.users import ACTIVE_USERS_PROVIDER
from selfprivacy_api.repositories.users.exceptions import ( from selfprivacy_api.repositories.users.exceptions import (
InvalidConfiguration,
PasswordIsEmpty, PasswordIsEmpty,
UserAlreadyExists,
UserIsProtected,
UsernameForbidden, UsernameForbidden,
UsernameNotAlphanumeric, UsernameNotAlphanumeric,
UsernameTooLong, UsernameTooLong,
UserNotFound,
) )
def ensure_ssh_and_users_fields_exist(data):
if "ssh" not in data:
data["ssh"] = {}
data["ssh"]["rootKeys"] = []
elif data["ssh"].get("rootKeys") is None:
data["ssh"]["rootKeys"] = []
if "sshKeys" not in data:
data["sshKeys"] = []
if "users" not in data:
data["users"] = []
def get_users( def get_users(
exclude_primary: bool = False, exclude_primary: bool = False,
exclude_root: bool = False, exclude_root: bool = False,
) -> list[UserDataUser]: ) -> list[UserDataUser]:
"""Get the list of users""" return ACTIVE_USERS_PROVIDER.get_users(
users = [] exclude_primary=exclude_primary, exclude_root=exclude_root
with ReadUserData() as user_data: )
ensure_ssh_and_users_fields_exist(user_data)
users = [
UserDataUser(
username=user["username"],
ssh_keys=user.get("sshKeys", []),
origin=UserDataUserOrigin.NORMAL,
)
for user in user_data["users"]
]
if not exclude_primary and "username" in user_data.keys():
users.append(
UserDataUser(
username=user_data["username"],
ssh_keys=user_data["sshKeys"],
origin=UserDataUserOrigin.PRIMARY,
)
)
if not exclude_root:
users.append(
UserDataUser(
username="root",
ssh_keys=user_data["ssh"]["rootKeys"],
origin=UserDataUserOrigin.ROOT,
)
)
return users
def create_user(username: str, password: str) -> None: def create_user(username: str, password: str) -> None:
@ -91,39 +40,15 @@ def create_user(username: str, password: str) -> None:
if len(username) >= 32: if len(username) >= 32:
raise UsernameTooLong("Username must be less than 32 characters") raise UsernameTooLong("Username must be less than 32 characters")
with ReadUserData() as user_data:
ensure_ssh_and_users_fields_exist(user_data)
if "username" not in user_data.keys():
raise InvalidConfiguration(
"Broken config: Admin name is not defined. Consider recovery or add it manually"
)
if username == user_data["username"]:
raise UserAlreadyExists("User already exists")
if username in [user["username"] for user in user_data["users"]]:
raise UserAlreadyExists("User already exists")
hashed_password = hash_password(password) hashed_password = hash_password(password)
with WriteUserData() as user_data: return ACTIVE_USERS_PROVIDER.create_user(
ensure_ssh_and_users_fields_exist(user_data) username=username, hashed_password=hashed_password
)
user_data["users"].append(
{"username": username, "sshKeys": [], "hashedPassword": hashed_password}
)
def delete_user(username: str) -> None: def delete_user(username: str) -> None:
with WriteUserData() as user_data: return ACTIVE_USERS_PROVIDER.delete_user(username=username)
ensure_ssh_and_users_fields_exist(user_data)
if username == user_data["username"] or username == "root":
raise UserIsProtected("Cannot delete main or root user")
for data_user in user_data["users"]:
if data_user["username"] == username:
user_data["users"].remove(data_user)
break
else:
raise UserNotFound("User did not exist")
def update_user(username: str, password: str) -> None: def update_user(username: str, password: str) -> None:
@ -132,49 +57,10 @@ def update_user(username: str, password: str) -> None:
hashed_password = hash_password(password) hashed_password = hash_password(password)
with WriteUserData() as data: return ACTIVE_USERS_PROVIDER.update_user(
ensure_ssh_and_users_fields_exist(data) username=username, hashed_password=hashed_password
)
if username == data["username"]:
data["hashedMasterPassword"] = hashed_password
# Return 404 if user does not exist
else:
for data_user in data["users"]:
if data_user["username"] == username:
data_user["hashedPassword"] = hashed_password
break
else:
raise UserNotFound("User does not exist")
def get_user_by_username(username: str) -> Optional[UserDataUser]: def get_user_by_username(username: str) -> Optional[UserDataUser]:
with ReadUserData() as data: return ACTIVE_USERS_PROVIDER.get_user_by_username(username=username)
ensure_ssh_and_users_fields_exist(data)
if username == "root":
return UserDataUser(
origin=UserDataUserOrigin.ROOT,
username="root",
ssh_keys=data["ssh"]["rootKeys"],
)
if username == data["username"]:
return UserDataUser(
origin=UserDataUserOrigin.PRIMARY,
username=username,
ssh_keys=data["sshKeys"],
)
for user in data["users"]:
if user["username"] == username:
if "sshKeys" not in user:
user["sshKeys"] = []
return UserDataUser(
origin=UserDataUserOrigin.NORMAL,
username=username,
ssh_keys=user["sshKeys"],
)
return None

View file

@ -1,7 +1,12 @@
import typing import typing
from enum import Enum from enum import Enum
import strawberry import strawberry
from selfprivacy_api.repositories.users import ACTIVE_USERS_PROVIDER as users_actions
from selfprivacy_api.actions.users import (
get_user_by_username as actions_get_user_by_username,
)
from selfprivacy_api.actions.users import get_users as actions_get_users
from selfprivacy_api.graphql.mutations.mutation_interface import ( from selfprivacy_api.graphql.mutations.mutation_interface import (
MutationReturnInterface, MutationReturnInterface,
@ -31,7 +36,7 @@ class UserMutationReturn(MutationReturnInterface):
def get_user_by_username(username: str) -> typing.Optional[User]: def get_user_by_username(username: str) -> typing.Optional[User]:
user = users_actions.get_user_by_username(username=username) user = actions_get_user_by_username(username=username)
if user is None: if user is None:
return None return None
@ -44,7 +49,7 @@ def get_user_by_username(username: str) -> typing.Optional[User]:
def get_users() -> typing.List[User]: def get_users() -> typing.List[User]:
"""Get users""" """Get users"""
users = users_actions.get_users(exclude_root=True) users = actions_get_users(exclude_root=True)
return [ return [
User( User(
user_type=UserType(user.origin.value), user_type=UserType(user.origin.value),

View file

@ -18,7 +18,7 @@ from selfprivacy_api.actions.ssh import (
from selfprivacy_api.graphql.mutations.mutation_interface import ( from selfprivacy_api.graphql.mutations.mutation_interface import (
GenericMutationReturn, GenericMutationReturn,
) )
from selfprivacy_api.repositories.users import ACTIVE_USERS_PROVIDER as users_actions from selfprivacy_api.actions.users import create_user, delete_user, update_user
from selfprivacy_api.repositories.users.exceptions import ( from selfprivacy_api.repositories.users.exceptions import (
PasswordIsEmpty, PasswordIsEmpty,
UsernameForbidden, UsernameForbidden,
@ -54,7 +54,7 @@ class UsersMutations:
@strawberry.mutation(permission_classes=[IsAuthenticated]) @strawberry.mutation(permission_classes=[IsAuthenticated])
def create_user(self, user: UserMutationInput) -> UserMutationReturn: def create_user(self, user: UserMutationInput) -> UserMutationReturn:
try: try:
users_actions.create_user(user.username, user.password) create_user(user.username, user.password)
except PasswordIsEmpty as e: except PasswordIsEmpty as e:
return UserMutationReturn( return UserMutationReturn(
success=False, success=False,
@ -103,7 +103,7 @@ class UsersMutations:
@strawberry.mutation(permission_classes=[IsAuthenticated]) @strawberry.mutation(permission_classes=[IsAuthenticated])
def delete_user(self, username: str) -> GenericMutationReturn: def delete_user(self, username: str) -> GenericMutationReturn:
try: try:
users_actions.delete_user(username) delete_user(username)
except UserNotFound as e: except UserNotFound as e:
return GenericMutationReturn( return GenericMutationReturn(
success=False, success=False,
@ -127,7 +127,7 @@ class UsersMutations:
def update_user(self, user: UserMutationInput) -> UserMutationReturn: def update_user(self, user: UserMutationInput) -> UserMutationReturn:
"""Update user mutation""" """Update user mutation"""
try: try:
users_actions.update_user(user.username, user.password) update_user(user.username, user.password)
except PasswordIsEmpty as e: except PasswordIsEmpty as e:
return UserMutationReturn( return UserMutationReturn(
success=False, success=False,

View file

@ -0,0 +1,18 @@
from enum import Enum
from pydantic import BaseModel
class UserDataUserOrigin(Enum):
"""Origin of the user in the user data"""
NORMAL = "NORMAL"
PRIMARY = "PRIMARY"
ROOT = "ROOT"
class UserDataUser(BaseModel):
"""The user model from the userdata file"""
username: str
ssh_keys: list[str]
origin: UserDataUserOrigin

View file

@ -0,0 +1,5 @@
from selfprivacy_api.repositories.tokens.redis_tokens_repository import (
RedisTokensRepository,
)
ACTIVE_TOKEN_PROVIDER = RedisTokensRepository()

View file

@ -3,18 +3,19 @@ Token repository using Redis as backend.
""" """
from typing import Any, Optional from typing import Any, Optional
from datetime import datetime from datetime import datetime, timezone
from hashlib import md5 from hashlib import md5
from datetime import timezone
from selfprivacy_api.repositories.tokens.abstract_tokens_repository import (
AbstractTokensRepository,
)
from selfprivacy_api.utils.redis_pool import RedisPool from selfprivacy_api.utils.redis_pool import RedisPool
from selfprivacy_api.models.tokens.token import Token from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.models.tokens.recovery_key import RecoveryKey from selfprivacy_api.models.tokens.recovery_key import RecoveryKey
from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey
from selfprivacy_api.repositories.tokens.exceptions import TokenNotFound from selfprivacy_api.repositories.tokens.exceptions import TokenNotFound
from selfprivacy_api.repositories.tokens.abstract_tokens_repository import (
AbstractTokensRepository,
)
TOKENS_PREFIX = "token_repo:tokens:" TOKENS_PREFIX = "token_repo:tokens:"
NEW_DEVICE_KEY_REDIS_KEY = "token_repo:new_device_key" NEW_DEVICE_KEY_REDIS_KEY = "token_repo:new_device_key"

View file

@ -1,28 +1,12 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from pydantic import BaseModel from selfprivacy_api.models.user import UserDataUser
from enum import Enum
class UserDataUserOrigin(Enum):
"""Origin of the user in the user data"""
NORMAL = "NORMAL"
PRIMARY = "PRIMARY"
ROOT = "ROOT"
class UserDataUser(BaseModel):
"""The user model from the userdata file"""
username: str
ssh_keys: list[str]
origin: UserDataUserOrigin
class AbstractUserRepository(ABC): class AbstractUserRepository(ABC):
@staticmethod
@abstractmethod @abstractmethod
def get_users( def get_users(
exclude_primary: bool = False, exclude_primary: bool = False,
@ -30,18 +14,22 @@ class AbstractUserRepository(ABC):
) -> list[UserDataUser]: ) -> list[UserDataUser]:
"""Retrieves a list of users with options to exclude specific user groups""" """Retrieves a list of users with options to exclude specific user groups"""
@staticmethod
@abstractmethod @abstractmethod
def create_user(username: str, password: str): def create_user(username: str, hashed_password: str) -> None:
"""Creates a new user""" """Creates a new user"""
@staticmethod
@abstractmethod @abstractmethod
def delete_user(username: str) -> None: def delete_user(username: str) -> None:
"""Deletes an existing user""" """Deletes an existing user"""
@staticmethod
@abstractmethod @abstractmethod
def update_user(username: str, password: str) -> None: def update_user(username: str, hashed_password: str) -> None:
"""Updates the password of an existing user""" """Updates the password of an existing user"""
@staticmethod
@abstractmethod @abstractmethod
def get_user_by_username(username: str) -> Optional[UserDataUser]: def get_user_by_username(username: str) -> Optional[UserDataUser]:
"""Retrieves user data (UserDataUser) by username""" """Retrieves user data (UserDataUser) by username"""

View file

@ -1,38 +1,138 @@
from typing import Optional from typing import Optional
from selfprivacy_api.models.user import UserDataUser, UserDataUserOrigin
from selfprivacy_api.utils import (
ReadUserData,
WriteUserData,
ensure_ssh_and_users_fields_exist,
)
from selfprivacy_api.repositories.users.abstract_user_repository import ( from selfprivacy_api.repositories.users.abstract_user_repository import (
AbstractUserRepository, AbstractUserRepository,
UserDataUser,
) )
from selfprivacy_api.repositories.users.exceptions import (
from selfprivacy_api.actions.users import ( InvalidConfiguration,
create_user, UserAlreadyExists,
delete_user, UserIsProtected,
get_user_by_username, UserNotFound,
get_users,
update_user,
) )
class JsonUserRepository(AbstractUserRepository): class JsonUserRepository(AbstractUserRepository):
@staticmethod
def get_users( def get_users(
exclude_primary: bool = False, exclude_primary: bool = False,
exclude_root: bool = False, exclude_root: bool = False,
) -> list[UserDataUser]: ) -> list[UserDataUser]:
return get_users(exclude_primary=exclude_primary, exclude_root=exclude_root) """Get the list of users"""
users = []
with ReadUserData() as user_data:
ensure_ssh_and_users_fields_exist(user_data)
users = [
UserDataUser(
username=user["username"],
ssh_keys=user.get("sshKeys", []),
origin=UserDataUserOrigin.NORMAL,
)
for user in user_data["users"]
]
if not exclude_primary and "username" in user_data.keys():
users.append(
UserDataUser(
username=user_data["username"],
ssh_keys=user_data["sshKeys"],
origin=UserDataUserOrigin.PRIMARY,
)
)
if not exclude_root:
users.append(
UserDataUser(
username="root",
ssh_keys=user_data["ssh"]["rootKeys"],
origin=UserDataUserOrigin.ROOT,
)
)
return users
def create_user(username: str, password: str): @staticmethod
"""Creates a new user""" def create_user(username: str, hashed_password: str) -> None:
return create_user(username=username, password=password) with ReadUserData() as user_data:
ensure_ssh_and_users_fields_exist(user_data)
if "username" not in user_data.keys():
raise InvalidConfiguration(
"Broken config: Admin name is not defined. Consider recovery or add it manually"
)
if username == user_data["username"]:
raise UserAlreadyExists("User already exists")
if username in [user["username"] for user in user_data["users"]]:
raise UserAlreadyExists("User already exists")
with WriteUserData() as user_data:
ensure_ssh_and_users_fields_exist(user_data)
user_data["users"].append(
{"username": username, "sshKeys": [], "hashedPassword": hashed_password}
)
@staticmethod
def delete_user(username: str) -> None: def delete_user(username: str) -> None:
"""Deletes an existing user""" with WriteUserData() as user_data:
return delete_user(username=username) ensure_ssh_and_users_fields_exist(user_data)
if username == user_data["username"] or username == "root":
raise UserIsProtected("Cannot delete main or root user")
def update_user(username: str, password: str) -> None: for data_user in user_data["users"]:
"""Updates the password of an existing user""" if data_user["username"] == username:
return update_user(username=username, password=password) user_data["users"].remove(data_user)
break
else:
raise UserNotFound("User did not exist")
@staticmethod
def update_user(username: str, hashed_password: str) -> None:
with WriteUserData() as data:
ensure_ssh_and_users_fields_exist(data)
if username == data["username"]:
data["hashedMasterPassword"] = hashed_password
# Return 404 if user does not exist
else:
for data_user in data["users"]:
if data_user["username"] == username:
data_user["hashedPassword"] = hashed_password
break
else:
raise UserNotFound("User does not exist")
@staticmethod
def get_user_by_username(username: str) -> Optional[UserDataUser]: def get_user_by_username(username: str) -> Optional[UserDataUser]:
"""Retrieves user data (UserDataUser) by username""" with ReadUserData() as data:
return get_user_by_username(username=username) ensure_ssh_and_users_fields_exist(data)
if username == "root":
return UserDataUser(
origin=UserDataUserOrigin.ROOT,
username="root",
ssh_keys=data["ssh"]["rootKeys"],
)
if username == data["username"]:
return UserDataUser(
origin=UserDataUserOrigin.PRIMARY,
username=username,
ssh_keys=data["sshKeys"],
)
for user in data["users"]:
if user["username"] == username:
if "sshKeys" not in user:
user["sshKeys"] = []
return UserDataUser(
origin=UserDataUserOrigin.NORMAL,
username=username,
ssh_keys=user["sshKeys"],
)
return None

View file

@ -1,38 +1,61 @@
from typing import Optional from typing import Optional
import requests
from selfprivacy_api.models.user import UserDataUser
from selfprivacy_api.repositories.users.abstract_user_repository import ( from selfprivacy_api.repositories.users.abstract_user_repository import (
AbstractUserRepository, AbstractUserRepository,
UserDataUser,
) )
from selfprivacy_api.utils.kanidm_manager import ( KANIDM_URL = "http://localhost:9001"
create_user,
delete_user,
get_user_by_username, class KanidmQueryError(Exception):
get_users, """Error occurred during Kanidm query"""
update_user,
)
class KanidmUserRepository(AbstractUserRepository): class KanidmUserRepository(AbstractUserRepository):
@staticmethod
def _send_query(endpoint: str, method: str = "GET", **kwargs):
request_method = getattr(requests, method.lower(), None)
try:
response = request_method(
f"{KANIDM_URL}/api/v1/{endpoint}",
params=kwargs,
timeout=0.8, # TODO: change timeout
)
if response.status_code != 200:
raise KanidmQueryError(
error=f"Kanidm returned unexpected HTTP status code. Error: {response.text}."
)
json = response.json()
return json["data"]
except Exception as error:
raise KanidmQueryError(error=f"Kanidm request failed! Error: {str(error)}")
@staticmethod
def create_user(username: str, password: str):
return KanidmUserRepository._send_query(
endpoint="person", method="POST", name=username, displayname=username
)
def get_users( def get_users(
exclude_primary: bool = False, exclude_primary: bool = False,
exclude_root: bool = False, exclude_root: bool = False,
) -> list[UserDataUser]: ) -> list[UserDataUser]:
return get_users(exclude_primary=exclude_primary, exclude_root=exclude_root) return KanidmUserRepository._send_query()
def create_user(username: str, password: str):
"""Creates a new user"""
return create_user(username=username, password=password)
def delete_user(username: str) -> None: def delete_user(username: str) -> None:
"""Deletes an existing user""" """Deletes an existing user"""
return delete_user(username=username) return KanidmUserRepository._send_query()
def update_user(username: str, password: str) -> None: def update_user(username: str, password: str) -> None:
"""Updates the password of an existing user""" """Updates the password of an existing user"""
return update_user(username=username, password=password) return KanidmUserRepository._send_query()
def get_user_by_username(username: str) -> Optional[UserDataUser]: def get_user_by_username(username: str) -> Optional[UserDataUser]:
"""Retrieves user data (UserDataUser) by username""" """Retrieves user data (UserDataUser) by username"""
return get_user_by_username(username=username) return KanidmUserRepository._send_query()

View file

@ -92,6 +92,21 @@ class ReadUserData(object):
self.userdata_file.close() self.userdata_file.close()
def ensure_ssh_and_users_fields_exist(data):
if "ssh" not in data:
data["ssh"] = {}
data["ssh"]["rootKeys"] = []
elif data["ssh"].get("rootKeys") is None:
data["ssh"]["rootKeys"] = []
if "sshKeys" not in data:
data["sshKeys"] = []
if "users" not in data:
data["users"] = []
def validate_ssh_public_key(key): def validate_ssh_public_key(key):
"""Validate SSH public key. """Validate SSH public key.
It may be ssh-ed25519, ssh-rsa or ecdsa-sha2-nistp256.""" It may be ssh-ed25519, ssh-rsa or ecdsa-sha2-nistp256."""

View file

@ -1,71 +0,0 @@
"""Kanidm queries."""
# pylint: disable=too-few-public-methods
import requests
import strawberry
from typing import Annotated, Union
KANIDM_URL = "http://localhost:9001"
@strawberry.type
class KanidmQueryError:
error: str
KanidmValuesResult = Annotated[
Union[str, KanidmQueryError], # WIP. TODO: change str
strawberry.union("KanidmValuesResult"),
]
# WIP WIP WIP WIP WIP WIP
class KanidmQueries:
@staticmethod
def _send_query(query: str) -> Union[dict, KanidmQueryError]:
try:
response = requests.get(
f"{KANIDM_URL}/api/v1/query",
params={
"query": query,
},
timeout=0.8, # TODO: change timeout
)
if response.status_code != 200:
return KanidmQueryError(
error=f"Kanidm returned unexpected HTTP status code. Error: {response.text}. The query was {query}"
)
json = response.json()
return json["data"]
except Exception as error:
return KanidmQueryError(error=f"Kanidm request failed! Error: {str(error)}")
@staticmethod
def create_user(username: str, password: str) -> KanidmValuesResult:
query = """"""
data = KanidmQueries._send_query(query=query)
if isinstance(data, KanidmQueryError):
return data
return KanidmValuesResult(data)
# def get_users(
# exclude_primary: bool = False,
# exclude_root: bool = False,
# ) -> list[UserDataUser]:
# def create_user(username: str, password: str):
# def delete_user(username: str) -> None:
# def update_user(username: str, password: str) -> None:
# def get_user_by_username(username: str) -> Optional[UserDataUser]:

View file

@ -7,7 +7,7 @@ from time import sleep
from starlette.testclient import WebSocketTestSession from starlette.testclient import WebSocketTestSession
from selfprivacy_api.jobs import Jobs from selfprivacy_api.jobs import Jobs
from selfprivacy_api.actions.api_tokens import TOKEN_REPO from selfprivacy_api.actions.api_tokens import ACTIVE_TOKEN_PROVIDER
from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql import IsAuthenticated
from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH
@ -75,7 +75,7 @@ def authenticated_websocket(
) -> Generator[WebSocketTestSession, None, None]: ) -> Generator[WebSocketTestSession, None, None]:
# We use authorized_client only to have token in the repo, this client by itself is not enough to authorize websocket # We use authorized_client only to have token in the repo, this client by itself is not enough to authorize websocket
ValueError(TOKEN_REPO.get_tokens()) ValueError(ACTIVE_TOKEN_PROVIDER.get_tokens())
with connect_ws_authenticated(authorized_client) as websocket: with connect_ws_authenticated(authorized_client) as websocket:
yield websocket yield websocket

View file

@ -1,4 +1,4 @@
from selfprivacy_api.utils import ReadUserData, WriteUserData from selfprivacy_api.utils import ReadUserData
from selfprivacy_api.actions.users import delete_user from selfprivacy_api.actions.users import delete_user
""" """