From a742e66cc32fcf8420f5580c8fc78ea4a82c6468 Mon Sep 17 00:00:00 2001 From: Inex Code Date: Thu, 28 Mar 2024 14:28:13 +0300 Subject: [PATCH 01/32] feat: Add "OTHER" as a server provider We should allow manual SelfPrivacy installations on unsupported server providers. The ServerProvider enum is one of the gatekeepers that prevent this and we can change it easily as not much server-side logic rely on this. The next step would be manual DNS management, but it would be much more involved than just adding the enum value. --- selfprivacy_api/graphql/queries/providers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/selfprivacy_api/graphql/queries/providers.py b/selfprivacy_api/graphql/queries/providers.py index b9ca7ef..2995fe8 100644 --- a/selfprivacy_api/graphql/queries/providers.py +++ b/selfprivacy_api/graphql/queries/providers.py @@ -14,6 +14,7 @@ class DnsProvider(Enum): class ServerProvider(Enum): HETZNER = "HETZNER" DIGITALOCEAN = "DIGITALOCEAN" + OTHER = "OTHER" @strawberry.enum From f90eb3fb4c917c1c3ac1faf9905b64b0b5ad4b8a Mon Sep 17 00:00:00 2001 From: dettlaff Date: Fri, 21 Jun 2024 23:35:04 +0300 Subject: [PATCH 02/32] feat: add flake services manager (#113) Reviewed-on: https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api/pulls/113 Reviewed-by: Inex Code Reviewed-by: houkime Co-authored-by: dettlaff Co-committed-by: dettlaff --- .../services/flake_service_manager.py | 53 ++++++++ tests/test_flake_services_manager.py | 127 ++++++++++++++++++ .../no_services.nix | 4 + .../some_services.nix | 12 ++ 4 files changed, 196 insertions(+) create mode 100644 selfprivacy_api/services/flake_service_manager.py create mode 100644 tests/test_flake_services_manager.py create mode 100644 tests/test_flake_services_manager/no_services.nix create mode 100644 tests/test_flake_services_manager/some_services.nix diff --git a/selfprivacy_api/services/flake_service_manager.py b/selfprivacy_api/services/flake_service_manager.py new file mode 100644 index 
0000000..63c2279 --- /dev/null +++ b/selfprivacy_api/services/flake_service_manager.py @@ -0,0 +1,53 @@ +import re +from typing import Tuple, Optional + +FLAKE_CONFIG_PATH = "/etc/nixos/sp-modules/flake.nix" + + +class FlakeServiceManager: + def __enter__(self) -> "FlakeServiceManager": + self.services = {} + + with open(FLAKE_CONFIG_PATH, "r") as file: + for line in file: + service_name, url = self._extract_services(input_string=line) + if service_name and url: + self.services[service_name] = url + + return self + + def _extract_services( + self, input_string: str + ) -> Tuple[Optional[str], Optional[str]]: + pattern = r"inputs\.([\w-]+)\.url\s*=\s*([\S]+);" + match = re.search(pattern, input_string) + + if match: + variable_name = match.group(1) + url = match.group(2) + return variable_name, url + else: + return None, None + + def __exit__(self, exc_type, exc_value, traceback) -> None: + with open(FLAKE_CONFIG_PATH, "w") as file: + file.write( + """ +{ + description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";\n +""" + ) + + for key, value in self.services.items(): + file.write( + f""" + inputs.{key}.url = {value}; +""" + ) + + file.write( + """ + outputs = _: { }; +} +""" + ) diff --git a/tests/test_flake_services_manager.py b/tests/test_flake_services_manager.py new file mode 100644 index 0000000..4650b6d --- /dev/null +++ b/tests/test_flake_services_manager.py @@ -0,0 +1,127 @@ +import pytest + +from selfprivacy_api.services.flake_service_manager import FlakeServiceManager + +all_services_file = """ +{ + description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc"; + + + inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden; + + inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea; + + inputs.jitsi-meet.url = 
git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet; + + inputs.nextcloud.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/nextcloud; + + inputs.ocserv.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/ocserv; + + inputs.pleroma.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/pleroma; + + inputs.simple-nixos-mailserver.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/simple-nixos-mailserver; + + outputs = _: { }; +} +""" + + +some_services_file = """ +{ + description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc"; + + + inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden; + + inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea; + + inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet; + + outputs = _: { }; +} +""" + + +@pytest.fixture +def some_services_flake_mock(mocker, datadir): + flake_config_path = datadir / "some_services.nix" + mocker.patch( + "selfprivacy_api.services.flake_service_manager.FLAKE_CONFIG_PATH", + new=flake_config_path, + ) + return flake_config_path + + +@pytest.fixture +def no_services_flake_mock(mocker, datadir): + flake_config_path = datadir / "no_services.nix" + mocker.patch( + "selfprivacy_api.services.flake_service_manager.FLAKE_CONFIG_PATH", + new=flake_config_path, + ) + return flake_config_path + + +# --- + + +def test_read_services_list(some_services_flake_mock): + with FlakeServiceManager() as manager: + services = { + "bitwarden": 
"git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden", + "gitea": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea", + "jitsi-meet": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet", + } + assert manager.services == services + + +def test_change_services_list(some_services_flake_mock): + services = { + "bitwarden": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden", + "gitea": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea", + "jitsi-meet": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet", + "nextcloud": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/nextcloud", + "ocserv": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/ocserv", + "pleroma": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/pleroma", + "simple-nixos-mailserver": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/simple-nixos-mailserver", + } + + with FlakeServiceManager() as manager: + manager.services = services + + with FlakeServiceManager() as manager: + assert manager.services == services + + with open(some_services_flake_mock, "r", encoding="utf-8") as file: + file_content = file.read().strip() + + assert all_services_file.strip() == file_content + + +def test_read_empty_services_list(no_services_flake_mock): + with FlakeServiceManager() as manager: + services = {} + assert manager.services == services + + +def test_change_empty_services_list(no_services_flake_mock): + services = { + "bitwarden": 
"git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden", + "gitea": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea", + "jitsi-meet": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet", + "nextcloud": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/nextcloud", + "ocserv": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/ocserv", + "pleroma": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/pleroma", + "simple-nixos-mailserver": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/simple-nixos-mailserver", + } + + with FlakeServiceManager() as manager: + manager.services = services + + with FlakeServiceManager() as manager: + assert manager.services == services + + with open(no_services_flake_mock, "r", encoding="utf-8") as file: + file_content = file.read().strip() + + assert all_services_file.strip() == file_content diff --git a/tests/test_flake_services_manager/no_services.nix b/tests/test_flake_services_manager/no_services.nix new file mode 100644 index 0000000..5967016 --- /dev/null +++ b/tests/test_flake_services_manager/no_services.nix @@ -0,0 +1,4 @@ +{ + description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc"; + outputs = _: { }; +} diff --git a/tests/test_flake_services_manager/some_services.nix b/tests/test_flake_services_manager/some_services.nix new file mode 100644 index 0000000..4bbb919 --- /dev/null +++ b/tests/test_flake_services_manager/some_services.nix @@ -0,0 +1,12 @@ +{ + description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc"; + + + inputs.bitwarden.url = 
git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden; + + inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea; + + inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet; + + outputs = _: { }; +} From 5602c960565ffe23a771f8f9fa221fb8ed04303d Mon Sep 17 00:00:00 2001 From: Maxim Leshchenko Date: Thu, 27 Jun 2024 17:41:46 +0300 Subject: [PATCH 03/32] feat(services): rename "sda1" to "system disk" and etc (#122) Closes #51 Reviewed-on: https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api/pulls/122 Reviewed-by: Inex Code Co-authored-by: Maxim Leshchenko Co-committed-by: Maxim Leshchenko --- selfprivacy_api/actions/services.py | 2 +- selfprivacy_api/utils/block_devices.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/selfprivacy_api/actions/services.py b/selfprivacy_api/actions/services.py index ebb0917..f9486d1 100644 --- a/selfprivacy_api/actions/services.py +++ b/selfprivacy_api/actions/services.py @@ -27,7 +27,7 @@ def move_service(service_id: str, volume_name: str) -> Job: job = Jobs.add( type_id=f"services.{service.get_id()}.move", name=f"Move {service.get_display_name()}", - description=f"Moving {service.get_display_name()} data to {volume.name}", + description=f"Moving {service.get_display_name()} data to {volume.get_display_name().lower()}", ) move_service_task(service, volume, job) diff --git a/selfprivacy_api/utils/block_devices.py b/selfprivacy_api/utils/block_devices.py index 4de5b75..0db8fe0 100644 --- a/selfprivacy_api/utils/block_devices.py +++ b/selfprivacy_api/utils/block_devices.py @@ -90,6 +90,14 @@ class BlockDevice: def __hash__(self): return hash(self.name) + def get_display_name(self) -> str: + if self.is_root(): + return "System disk" + elif self.model == "Volume": + return "Expandable volume" + else: + 
return self.name + def is_root(self) -> bool: """ Return True if the block device is the root device. From 7522c2d796ac8060843e2824851dc67ca19d7a8c Mon Sep 17 00:00:00 2001 From: Inex Code Date: Sun, 30 Jun 2024 23:02:07 +0400 Subject: [PATCH 04/32] refactor: Change gitea to Forgejo --- selfprivacy_api/jobs/migrate_to_binds.py | 6 ++--- selfprivacy_api/services/__init__.py | 4 +-- .../services/{gitea => forgejo}/__init__.py | 26 +++++++++++-------- .../services/{gitea => forgejo}/gitea.svg | 0 .../services/{gitea => forgejo}/icon.py | 2 +- tests/test_services_systemctl.py | 8 +++--- 6 files changed, 25 insertions(+), 21 deletions(-) rename selfprivacy_api/services/{gitea => forgejo}/__init__.py (72%) rename selfprivacy_api/services/{gitea => forgejo}/gitea.svg (100%) rename selfprivacy_api/services/{gitea => forgejo}/icon.py (98%) diff --git a/selfprivacy_api/jobs/migrate_to_binds.py b/selfprivacy_api/jobs/migrate_to_binds.py index 3250c9a..782b361 100644 --- a/selfprivacy_api/jobs/migrate_to_binds.py +++ b/selfprivacy_api/jobs/migrate_to_binds.py @@ -6,7 +6,7 @@ import shutil from pydantic import BaseModel from selfprivacy_api.jobs import Job, JobStatus, Jobs from selfprivacy_api.services.bitwarden import Bitwarden -from selfprivacy_api.services.gitea import Gitea +from selfprivacy_api.services.forgejo import Forgejo from selfprivacy_api.services.mailserver import MailServer from selfprivacy_api.services.nextcloud import Nextcloud from selfprivacy_api.services.pleroma import Pleroma @@ -230,7 +230,7 @@ def migrate_to_binds(config: BindMigrationConfig, job: Job): status_text="Migrating Gitea.", ) - Gitea().stop() + Forgejo().stop() if not pathlib.Path("/volumes/sda1/gitea").exists(): if not pathlib.Path("/volumes/sdb/gitea").exists(): @@ -241,7 +241,7 @@ def migrate_to_binds(config: BindMigrationConfig, job: Job): group="gitea", ) - Gitea().start() + Forgejo().start() # Perform migration of Mail server diff --git a/selfprivacy_api/services/__init__.py 
b/selfprivacy_api/services/__init__.py index f9dfac2..da02eba 100644 --- a/selfprivacy_api/services/__init__.py +++ b/selfprivacy_api/services/__init__.py @@ -2,7 +2,7 @@ import typing from selfprivacy_api.services.bitwarden import Bitwarden -from selfprivacy_api.services.gitea import Gitea +from selfprivacy_api.services.forgejo import Forgejo from selfprivacy_api.services.jitsimeet import JitsiMeet from selfprivacy_api.services.mailserver import MailServer from selfprivacy_api.services.nextcloud import Nextcloud @@ -13,7 +13,7 @@ import selfprivacy_api.utils.network as network_utils services: list[Service] = [ Bitwarden(), - Gitea(), + Forgejo(), MailServer(), Nextcloud(), Pleroma(), diff --git a/selfprivacy_api/services/gitea/__init__.py b/selfprivacy_api/services/forgejo/__init__.py similarity index 72% rename from selfprivacy_api/services/gitea/__init__.py rename to selfprivacy_api/services/forgejo/__init__.py index 311d59e..d035736 100644 --- a/selfprivacy_api/services/gitea/__init__.py +++ b/selfprivacy_api/services/forgejo/__init__.py @@ -7,31 +7,34 @@ from selfprivacy_api.utils import get_domain from selfprivacy_api.utils.systemd import get_service_status from selfprivacy_api.services.service import Service, ServiceStatus -from selfprivacy_api.services.gitea.icon import GITEA_ICON +from selfprivacy_api.services.forgejo.icon import FORGEJO_ICON -class Gitea(Service): - """Class representing Gitea service""" +class Forgejo(Service): + """Class representing Forgejo service. + + Previously was Gitea, so some IDs are still called gitea for compatibility. + """ @staticmethod def get_id() -> str: - """Return service id.""" + """Return service id. For compatibility keep in gitea.""" return "gitea" @staticmethod def get_display_name() -> str: """Return service display name.""" - return "Gitea" + return "Forgejo" @staticmethod def get_description() -> str: """Return service description.""" - return "Gitea is a Git forge." + return "Forgejo is a Git forge." 
@staticmethod def get_svg_icon() -> str: """Read SVG icon from file and return it as base64 encoded string.""" - return base64.b64encode(GITEA_ICON.encode("utf-8")).decode("utf-8") + return base64.b64encode(FORGEJO_ICON.encode("utf-8")).decode("utf-8") @staticmethod def get_url() -> Optional[str]: @@ -65,19 +68,19 @@ class Gitea(Service): Return code 3 means service is stopped. Return code 4 means service is off. """ - return get_service_status("gitea.service") + return get_service_status("forgejo.service") @staticmethod def stop(): - subprocess.run(["systemctl", "stop", "gitea.service"]) + subprocess.run(["systemctl", "stop", "forgejo.service"]) @staticmethod def start(): - subprocess.run(["systemctl", "start", "gitea.service"]) + subprocess.run(["systemctl", "start", "forgejo.service"]) @staticmethod def restart(): - subprocess.run(["systemctl", "restart", "gitea.service"]) + subprocess.run(["systemctl", "restart", "forgejo.service"]) @staticmethod def get_configuration(): @@ -93,4 +96,5 @@ class Gitea(Service): @staticmethod def get_folders() -> List[str]: + """The data folder is still called gitea for compatibility.""" return ["/var/lib/gitea"] diff --git a/selfprivacy_api/services/gitea/gitea.svg b/selfprivacy_api/services/forgejo/gitea.svg similarity index 100% rename from selfprivacy_api/services/gitea/gitea.svg rename to selfprivacy_api/services/forgejo/gitea.svg diff --git a/selfprivacy_api/services/gitea/icon.py b/selfprivacy_api/services/forgejo/icon.py similarity index 98% rename from selfprivacy_api/services/gitea/icon.py rename to selfprivacy_api/services/forgejo/icon.py index 569f96a..5e600cf 100644 --- a/selfprivacy_api/services/gitea/icon.py +++ b/selfprivacy_api/services/forgejo/icon.py @@ -1,4 +1,4 @@ -GITEA_ICON = """ +FORGEJO_ICON = """ diff --git a/tests/test_services_systemctl.py b/tests/test_services_systemctl.py index 8b247e0..43805e8 100644 --- a/tests/test_services_systemctl.py +++ b/tests/test_services_systemctl.py @@ -2,7 +2,7 @@ import 
pytest from selfprivacy_api.services.service import ServiceStatus from selfprivacy_api.services.bitwarden import Bitwarden -from selfprivacy_api.services.gitea import Gitea +from selfprivacy_api.services.forgejo import Forgejo from selfprivacy_api.services.mailserver import MailServer from selfprivacy_api.services.nextcloud import Nextcloud from selfprivacy_api.services.ocserv import Ocserv @@ -22,7 +22,7 @@ def call_args_asserts(mocked_object): "dovecot2.service", "postfix.service", "vaultwarden.service", - "gitea.service", + "forgejo.service", "phpfpm-nextcloud.service", "ocserv.service", "pleroma.service", @@ -77,7 +77,7 @@ def mock_popen_systemctl_service_not_ok(mocker): def test_systemctl_ok(mock_popen_systemctl_service_ok): assert MailServer.get_status() == ServiceStatus.ACTIVE assert Bitwarden.get_status() == ServiceStatus.ACTIVE - assert Gitea.get_status() == ServiceStatus.ACTIVE + assert Forgejo.get_status() == ServiceStatus.ACTIVE assert Nextcloud.get_status() == ServiceStatus.ACTIVE assert Ocserv.get_status() == ServiceStatus.ACTIVE assert Pleroma.get_status() == ServiceStatus.ACTIVE @@ -87,7 +87,7 @@ def test_systemctl_ok(mock_popen_systemctl_service_ok): def test_systemctl_failed_service(mock_popen_systemctl_service_not_ok): assert MailServer.get_status() == ServiceStatus.FAILED assert Bitwarden.get_status() == ServiceStatus.FAILED - assert Gitea.get_status() == ServiceStatus.FAILED + assert Forgejo.get_status() == ServiceStatus.FAILED assert Nextcloud.get_status() == ServiceStatus.FAILED assert Ocserv.get_status() == ServiceStatus.FAILED assert Pleroma.get_status() == ServiceStatus.FAILED From 4066be38ec11aabf47b03afd35778a53c6d28942 Mon Sep 17 00:00:00 2001 From: Inex Code Date: Mon, 1 Jul 2024 19:25:54 +0400 Subject: [PATCH 05/32] chore: Bump version to 3.2.2 --- selfprivacy_api/dependencies.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/selfprivacy_api/dependencies.py b/selfprivacy_api/dependencies.py index 
b9d0904..69ce319 100644 --- a/selfprivacy_api/dependencies.py +++ b/selfprivacy_api/dependencies.py @@ -27,4 +27,4 @@ async def get_token_header( def get_api_version() -> str: """Get API version""" - return "3.2.1" + return "3.2.2" diff --git a/setup.py b/setup.py index 473ece8..23c544e 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name="selfprivacy_api", - version="3.2.1", + version="3.2.2", packages=find_packages(), scripts=[ "selfprivacy_api/app.py", From b6118465a071a88577a1e6b2bfa59524b4094ecb Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 1 Apr 2024 20:12:02 +0000 Subject: [PATCH 06/32] feature(redis): async connections --- selfprivacy_api/utils/redis_pool.py | 19 ++++++++++++----- tests/test_redis.py | 33 +++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) create mode 100644 tests/test_redis.py diff --git a/selfprivacy_api/utils/redis_pool.py b/selfprivacy_api/utils/redis_pool.py index 3d35f01..04ccb51 100644 --- a/selfprivacy_api/utils/redis_pool.py +++ b/selfprivacy_api/utils/redis_pool.py @@ -2,6 +2,7 @@ Redis pool module for selfprivacy_api """ import redis +import redis.asyncio as redis_async from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass @@ -14,11 +15,18 @@ class RedisPool(metaclass=SingletonMetaclass): """ def __init__(self): + url = RedisPool.connection_url(dbnumber=0) + # We need a normal sync pool because otherwise + # our whole API will need to be async self._pool = redis.ConnectionPool.from_url( - RedisPool.connection_url(dbnumber=0), + url, + decode_responses=True, + ) + # We need an async pool for pubsub + self._async_pool = redis_async.ConnectionPool.from_url( + url, decode_responses=True, ) - self._pubsub_connection = self.get_connection() @staticmethod def connection_url(dbnumber: int) -> str: @@ -34,8 +42,9 @@ class RedisPool(metaclass=SingletonMetaclass): """ return redis.Redis(connection_pool=self._pool) - def 
get_pubsub(self): + def get_connection_async(self) -> redis_async.Redis: """ - Get a pubsub connection from the pool. + Get an async connection from the pool. + Async connections allow pubsub. """ - return self._pubsub_connection.pubsub() + return redis_async.Redis(connection_pool=self._async_pool) diff --git a/tests/test_redis.py b/tests/test_redis.py new file mode 100644 index 0000000..48ec56e --- /dev/null +++ b/tests/test_redis.py @@ -0,0 +1,33 @@ +import asyncio +import pytest + +from selfprivacy_api.utils.redis_pool import RedisPool + +TEST_KEY = "test:test" + + +@pytest.fixture() +def empty_redis(): + r = RedisPool().get_connection() + r.flushdb() + yield r + r.flushdb() + + +async def write_to_test_key(): + r = RedisPool().get_connection_async() + async with r.pipeline(transaction=True) as pipe: + ok1, ok2 = await pipe.set(TEST_KEY, "value1").set(TEST_KEY, "value2").execute() + assert ok1 + assert ok2 + assert await r.get(TEST_KEY) == "value2" + await r.close() + + +def test_async_connection(empty_redis): + r = RedisPool().get_connection() + assert not r.exists(TEST_KEY) + # It _will_ report an error if it arises + asyncio.run(write_to_test_key()) + # Confirming that we can read result from sync connection too + assert r.get(TEST_KEY) == "value2" From 94386fc53d33656d40101a9977ea06fa51c07b8f Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 15 Apr 2024 13:35:44 +0000 Subject: [PATCH 07/32] chore(nixos): add pytest-asyncio --- flake.nix | 1 + 1 file changed, 1 insertion(+) diff --git a/flake.nix b/flake.nix index f8b81aa..ab969a4 100644 --- a/flake.nix +++ b/flake.nix @@ -20,6 +20,7 @@ pytest-datadir pytest-mock pytest-subprocess + pytest-asyncio black mypy pylsp-mypy From f08dc3ad232a16cc70f8b22fd1db08dcb58e37a6 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 15 Apr 2024 13:37:04 +0000 Subject: [PATCH 08/32] test(async): pubsub --- selfprivacy_api/utils/redis_pool.py | 12 +++++- tests/test_redis.py | 64 ++++++++++++++++++++++++++++- 2 files changed, 
73 insertions(+), 3 deletions(-) diff --git a/selfprivacy_api/utils/redis_pool.py b/selfprivacy_api/utils/redis_pool.py index 04ccb51..ea827d1 100644 --- a/selfprivacy_api/utils/redis_pool.py +++ b/selfprivacy_api/utils/redis_pool.py @@ -9,13 +9,15 @@ from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass REDIS_SOCKET = "/run/redis-sp-api/redis.sock" -class RedisPool(metaclass=SingletonMetaclass): +# class RedisPool(metaclass=SingletonMetaclass): +class RedisPool: """ Redis connection pool singleton. """ def __init__(self): - url = RedisPool.connection_url(dbnumber=0) + self._dbnumber = 0 + url = RedisPool.connection_url(dbnumber=self._dbnumber) # We need a normal sync pool because otherwise # our whole API will need to be async self._pool = redis.ConnectionPool.from_url( @@ -48,3 +50,9 @@ class RedisPool(metaclass=SingletonMetaclass): Async connections allow pubsub. """ return redis_async.Redis(connection_pool=self._async_pool) + + async def subscribe_to_keys(self, pattern: str) -> redis_async.client.PubSub: + async_redis = self.get_connection_async() + pubsub = async_redis.pubsub() + await pubsub.psubscribe(f"__keyspace@{self._dbnumber}__:" + pattern) + return pubsub diff --git a/tests/test_redis.py b/tests/test_redis.py index 48ec56e..2def280 100644 --- a/tests/test_redis.py +++ b/tests/test_redis.py @@ -1,13 +1,18 @@ import asyncio import pytest +import pytest_asyncio +from asyncio import streams +import redis +from typing import List from selfprivacy_api.utils.redis_pool import RedisPool TEST_KEY = "test:test" +STOPWORD = "STOP" @pytest.fixture() -def empty_redis(): +def empty_redis(event_loop): r = RedisPool().get_connection() r.flushdb() yield r @@ -31,3 +36,60 @@ def test_async_connection(empty_redis): asyncio.run(write_to_test_key()) # Confirming that we can read result from sync connection too assert r.get(TEST_KEY) == "value2" + + +async def channel_reader(channel: redis.client.PubSub) -> List[dict]: + result: List[dict] = [] + while 
True: + # Mypy cannot correctly detect that it is a coroutine + # But it is + message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore + if message is not None: + result.append(message) + if message["data"] == STOPWORD: + break + return result + + +@pytest.mark.asyncio +async def test_pubsub(empty_redis, event_loop): + # Adapted from : + # https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html + # Sanity checking because of previous event loop bugs + assert event_loop == asyncio.get_event_loop() + assert event_loop == asyncio.events.get_event_loop() + assert event_loop == asyncio.events._get_event_loop() + assert event_loop == asyncio.events.get_running_loop() + + reader = streams.StreamReader(34) + assert event_loop == reader._loop + f = reader._loop.create_future() + f.set_result(3) + await f + + r = RedisPool().get_connection_async() + async with r.pubsub() as pubsub: + await pubsub.subscribe("channel:1") + future = asyncio.create_task(channel_reader(pubsub)) + + await r.publish("channel:1", "Hello") + # message: dict = await pubsub.get_message(ignore_subscribe_messages=True, timeout=5.0) # type: ignore + # raise ValueError(message) + await r.publish("channel:1", "World") + await r.publish("channel:1", STOPWORD) + + messages = await future + + assert len(messages) == 3 + + message = messages[0] + assert "data" in message.keys() + assert message["data"] == "Hello" + message = messages[1] + assert "data" in message.keys() + assert message["data"] == "World" + message = messages[2] + assert "data" in message.keys() + assert message["data"] == STOPWORD + + await r.close() From 5558577927a4711b06890349117d86f694bd2704 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 22 Apr 2024 14:40:55 +0000 Subject: [PATCH 09/32] test(redis): test key event notifications --- tests/test_redis.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/test_redis.py 
b/tests/test_redis.py index 2def280..181d325 100644 --- a/tests/test_redis.py +++ b/tests/test_redis.py @@ -15,6 +15,8 @@ STOPWORD = "STOP" def empty_redis(event_loop): r = RedisPool().get_connection() r.flushdb() + r.config_set("notify-keyspace-events", "KEA") + assert r.config_get("notify-keyspace-events")["notify-keyspace-events"] == "AKE" yield r r.flushdb() @@ -51,6 +53,15 @@ async def channel_reader(channel: redis.client.PubSub) -> List[dict]: return result +async def channel_reader_onemessage(channel: redis.client.PubSub) -> dict: + while True: + # Mypy cannot correctly detect that it is a coroutine + # But it is + message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore + if message is not None: + return message + + @pytest.mark.asyncio async def test_pubsub(empty_redis, event_loop): # Adapted from : @@ -93,3 +104,40 @@ async def test_pubsub(empty_redis, event_loop): assert message["data"] == STOPWORD await r.close() + + +@pytest.mark.asyncio +async def test_keyspace_notifications_simple(empty_redis, event_loop): + r = RedisPool().get_connection_async() + await r.set(TEST_KEY, "I am not empty") + async with r.pubsub() as pubsub: + await pubsub.subscribe("__keyspace@0__:" + TEST_KEY) + + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + empty_redis.set(TEST_KEY, "I am set!") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message == { + "channel": f"__keyspace@0__:{TEST_KEY}", + "data": "set", + "pattern": None, + "type": "message", + } + + +@pytest.mark.asyncio +async def test_keyspace_notifications(empty_redis, event_loop): + pubsub = await RedisPool().subscribe_to_keys(TEST_KEY) + async with pubsub: + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + empty_redis.set(TEST_KEY, "I am set!") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message 
== { + "channel": f"__keyspace@0__:{TEST_KEY}", + "data": "set", + "pattern": f"__keyspace@0__:{TEST_KEY}", + "type": "pmessage", + } From fff8a49992c3af9ba76931e72aa0502c3450c903 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 22 Apr 2024 14:41:56 +0000 Subject: [PATCH 10/32] refactoring(jobs): break out a function returning all jobs --- selfprivacy_api/graphql/queries/jobs.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/selfprivacy_api/graphql/queries/jobs.py b/selfprivacy_api/graphql/queries/jobs.py index e7b99e6..337382a 100644 --- a/selfprivacy_api/graphql/queries/jobs.py +++ b/selfprivacy_api/graphql/queries/jobs.py @@ -11,13 +11,17 @@ from selfprivacy_api.graphql.common_types.jobs import ( from selfprivacy_api.jobs import Jobs +def get_all_jobs() -> typing.List[ApiJob]: + Jobs.get_jobs() + + return [job_to_api_job(job) for job in Jobs.get_jobs()] + + @strawberry.type class Job: @strawberry.field def get_jobs(self) -> typing.List[ApiJob]: - Jobs.get_jobs() - - return [job_to_api_job(job) for job in Jobs.get_jobs()] + return get_all_jobs() @strawberry.field def get_job(self, job_id: str) -> typing.Optional[ApiJob]: From 6510d4cac6d336139f319d879843151a8fe92335 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 22 Apr 2024 14:50:08 +0000 Subject: [PATCH 11/32] feature(redis): enable key space notifications by default --- selfprivacy_api/utils/redis_pool.py | 2 ++ tests/test_redis.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/selfprivacy_api/utils/redis_pool.py b/selfprivacy_api/utils/redis_pool.py index ea827d1..39c536f 100644 --- a/selfprivacy_api/utils/redis_pool.py +++ b/selfprivacy_api/utils/redis_pool.py @@ -29,6 +29,8 @@ class RedisPool: url, decode_responses=True, ) + # TODO: inefficient, this is probably done each time we connect + self.get_connection().config_set("notify-keyspace-events", "KEA") @staticmethod def connection_url(dbnumber: int) -> str: diff --git a/tests/test_redis.py 
b/tests/test_redis.py index 181d325..70ef43a 100644 --- a/tests/test_redis.py +++ b/tests/test_redis.py @@ -15,7 +15,6 @@ STOPWORD = "STOP" def empty_redis(event_loop): r = RedisPool().get_connection() r.flushdb() - r.config_set("notify-keyspace-events", "KEA") assert r.config_get("notify-keyspace-events")["notify-keyspace-events"] == "AKE" yield r r.flushdb() From 9bfffcd820d0e43d68258a98b0e7c920cdd478f7 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 6 May 2024 14:54:13 +0000 Subject: [PATCH 12/32] feature(jobs): job update generator --- selfprivacy_api/jobs/__init__.py | 29 ++++---- selfprivacy_api/utils/redis_model_storage.py | 12 +++- tests/test_redis.py | 76 ++++++++++++++++++++ 3 files changed, 102 insertions(+), 15 deletions(-) diff --git a/selfprivacy_api/jobs/__init__.py b/selfprivacy_api/jobs/__init__.py index 4649bb0..3dd48c4 100644 --- a/selfprivacy_api/jobs/__init__.py +++ b/selfprivacy_api/jobs/__init__.py @@ -15,6 +15,7 @@ A job is a dictionary with the following keys: - result: result of the job """ import typing +import asyncio import datetime from uuid import UUID import uuid @@ -23,6 +24,7 @@ from enum import Enum from pydantic import BaseModel from selfprivacy_api.utils.redis_pool import RedisPool +from selfprivacy_api.utils.redis_model_storage import store_model_as_hash JOB_EXPIRATION_SECONDS = 10 * 24 * 60 * 60 # ten days @@ -102,7 +104,7 @@ class Jobs: result=None, ) redis = RedisPool().get_connection() - _store_job_as_hash(redis, _redis_key_from_uuid(job.uid), job) + store_model_as_hash(redis, _redis_key_from_uuid(job.uid), job) return job @staticmethod @@ -218,7 +220,7 @@ class Jobs: redis = RedisPool().get_connection() key = _redis_key_from_uuid(job.uid) if redis.exists(key): - _store_job_as_hash(redis, key, job) + store_model_as_hash(redis, key, job) if status in (JobStatus.FINISHED, JobStatus.ERROR): redis.expire(key, JOB_EXPIRATION_SECONDS) @@ -294,17 +296,6 @@ def _progress_log_key_from_uuid(uuid_string) -> str: return 
PROGRESS_LOGS_PREFIX + str(uuid_string) -def _store_job_as_hash(redis, redis_key, model) -> None: - for key, value in model.dict().items(): - if isinstance(value, uuid.UUID): - value = str(value) - if isinstance(value, datetime.datetime): - value = value.isoformat() - if isinstance(value, JobStatus): - value = value.value - redis.hset(redis_key, key, str(value)) - - def _job_from_hash(redis, redis_key) -> typing.Optional[Job]: if redis.exists(redis_key): job_dict = redis.hgetall(redis_key) @@ -321,3 +312,15 @@ def _job_from_hash(redis, redis_key) -> typing.Optional[Job]: return Job(**job_dict) return None + + +async def job_notifications() -> typing.AsyncGenerator[dict, None]: + channel = await RedisPool().subscribe_to_keys("jobs:*") + while True: + try: + # we cannot timeout here because we do not know when the next message is supposed to arrive + message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore + if message is not None: + yield message + except GeneratorExit: + break diff --git a/selfprivacy_api/utils/redis_model_storage.py b/selfprivacy_api/utils/redis_model_storage.py index 06dfe8c..7d84210 100644 --- a/selfprivacy_api/utils/redis_model_storage.py +++ b/selfprivacy_api/utils/redis_model_storage.py @@ -1,15 +1,23 @@ +import uuid + from datetime import datetime from typing import Optional from enum import Enum def store_model_as_hash(redis, redis_key, model): - for key, value in model.dict().items(): + model_dict = model.dict() + for key, value in model_dict.items(): + if isinstance(value, uuid.UUID): + value = str(value) if isinstance(value, datetime): value = value.isoformat() if isinstance(value, Enum): value = value.value - redis.hset(redis_key, key, str(value)) + value = str(value) + model_dict[key] = value + + redis.hset(redis_key, mapping=model_dict) def hash_as_model(redis, redis_key: str, model_class): diff --git a/tests/test_redis.py b/tests/test_redis.py index 70ef43a..02dfb21 100644 --- 
a/tests/test_redis.py +++ b/tests/test_redis.py @@ -7,6 +7,8 @@ from typing import List from selfprivacy_api.utils.redis_pool import RedisPool +from selfprivacy_api.jobs import Jobs, job_notifications + TEST_KEY = "test:test" STOPWORD = "STOP" @@ -140,3 +142,77 @@ async def test_keyspace_notifications(empty_redis, event_loop): "pattern": f"__keyspace@0__:{TEST_KEY}", "type": "pmessage", } + + +@pytest.mark.asyncio +async def test_keyspace_notifications_patterns(empty_redis, event_loop): + pattern = "test*" + pubsub = await RedisPool().subscribe_to_keys(pattern) + async with pubsub: + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + empty_redis.set(TEST_KEY, "I am set!") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message == { + "channel": f"__keyspace@0__:{TEST_KEY}", + "data": "set", + "pattern": f"__keyspace@0__:{pattern}", + "type": "pmessage", + } + + +@pytest.mark.asyncio +async def test_keyspace_notifications_jobs(empty_redis, event_loop): + pattern = "jobs:*" + pubsub = await RedisPool().subscribe_to_keys(pattern) + async with pubsub: + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + Jobs.add("testjob1", "test.test", "Testing aaaalll day") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message["data"] == "hset" + + +async def reader_of_jobs() -> List[dict]: + """ + Reads 3 job updates and exits + """ + result: List[dict] = [] + async for message in job_notifications(): + result.append(message) + if len(result) >= 3: + break + return result + + +@pytest.mark.asyncio +async def test_jobs_generator(empty_redis, event_loop): + # Will read exactly 3 job messages + future_messages = asyncio.create_task(reader_of_jobs()) + await asyncio.sleep(1) + + Jobs.add("testjob1", "test.test", "Testing aaaalll day") + Jobs.add("testjob2", "test.test", "Testing aaaalll day") + Jobs.add("testjob3", 
"test.test", "Testing aaaalll day") + Jobs.add("testjob4", "test.test", "Testing aaaalll day") + + assert len(Jobs.get_jobs()) == 4 + r = RedisPool().get_connection() + assert len(r.keys("jobs:*")) == 4 + + messages = await future_messages + assert len(messages) == 3 + channels = [message["channel"] for message in messages] + operations = [message["data"] for message in messages] + assert set(operations) == set(["hset"]) # all of them are hsets + + # Asserting that all of jobs emitted exactly one message + jobs = Jobs.get_jobs() + names = ["testjob1", "testjob2", "testjob3"] + ids = [str(job.uid) for job in jobs if job.name in names] + for id in ids: + assert id in " ".join(channels) + # Asserting that they came in order + assert "testjob4" not in " ".join(channels) From 63d2e48a98c65d16ee2f3b2c1b38b71a29888689 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 15 May 2024 11:29:20 +0000 Subject: [PATCH 13/32] feature(jobs): websocket connection --- selfprivacy_api/app.py | 7 ++++++- tests/test_graphql/test_websocket.py | 6 ++++++ 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 tests/test_graphql/test_websocket.py diff --git a/selfprivacy_api/app.py b/selfprivacy_api/app.py index 64ca85a..2f7e2f7 100644 --- a/selfprivacy_api/app.py +++ b/selfprivacy_api/app.py @@ -3,6 +3,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from strawberry.fastapi import GraphQLRouter +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL import uvicorn @@ -13,8 +14,12 @@ from selfprivacy_api.migrations import run_migrations app = FastAPI() -graphql_app = GraphQLRouter( +graphql_app: GraphQLRouter = GraphQLRouter( schema, + subscription_protocols=[ + GRAPHQL_TRANSPORT_WS_PROTOCOL, + GRAPHQL_WS_PROTOCOL, + ], ) app.add_middleware( diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py new file mode 100644 index 0000000..fb2ac33 --- /dev/null +++ 
b/tests/test_graphql/test_websocket.py @@ -0,0 +1,6 @@ + +def test_websocket_connection_bare(authorized_client): + client =authorized_client + with client.websocket_connect('/graphql', subprotocols=[ "graphql-transport-ws","graphql-ws"] ) as websocket: + assert websocket is not None + assert websocket.scope is not None From c4aa757ca4f23b1fd1d7a47825a7f5bcd31e8efa Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 15 May 2024 13:01:07 +0000 Subject: [PATCH 14/32] test(jobs): test Graphql job getting --- tests/common.py | 4 +++ tests/test_graphql/test_jobs.py | 48 +++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 tests/test_graphql/test_jobs.py diff --git a/tests/common.py b/tests/common.py index 5f69f3f..8c81f48 100644 --- a/tests/common.py +++ b/tests/common.py @@ -69,6 +69,10 @@ def generate_backup_query(query_array): return "query TestBackup {\n backup {" + "\n".join(query_array) + "}\n}" +def generate_jobs_query(query_array): + return "query TestJobs {\n jobs {" + "\n".join(query_array) + "}\n}" + + def generate_service_query(query_array): return "query TestService {\n services {" + "\n".join(query_array) + "}\n}" diff --git a/tests/test_graphql/test_jobs.py b/tests/test_graphql/test_jobs.py new file mode 100644 index 0000000..8dfb102 --- /dev/null +++ b/tests/test_graphql/test_jobs.py @@ -0,0 +1,48 @@ +from tests.common import generate_jobs_query +from tests.test_graphql.common import ( + assert_ok, + assert_empty, + assert_errorcode, + get_data, +) + +API_JOBS_QUERY = """ +getJobs { + uid + typeId + name + description + status + statusText + progress + createdAt + updatedAt + finishedAt + error + result +} +""" + + +def graphql_send_query(client, query: str, variables: dict = {}): + return client.post("/graphql", json={"query": query, "variables": variables}) + + +def api_jobs(authorized_client): + response = graphql_send_query( + authorized_client, generate_jobs_query([API_JOBS_QUERY]) + ) + data = get_data(response) + 
result = data["jobs"]["getJobs"] + assert result is not None + return result + + +def test_all_jobs_unauthorized(client): + response = graphql_send_query(client, generate_jobs_query([API_JOBS_QUERY])) + assert_empty(response) + + +def test_all_jobs_when_none(authorized_client): + output = api_jobs(authorized_client) + assert output == [] From 2d9f48650e398e72c6b22d13b0327da28e13fb26 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 15 May 2024 13:42:17 +0000 Subject: [PATCH 15/32] test(jobs) test API job format --- tests/test_graphql/test_jobs.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_graphql/test_jobs.py b/tests/test_graphql/test_jobs.py index 8dfb102..68a6d20 100644 --- a/tests/test_graphql/test_jobs.py +++ b/tests/test_graphql/test_jobs.py @@ -1,4 +1,6 @@ from tests.common import generate_jobs_query +import tests.test_graphql.test_api_backup + from tests.test_graphql.common import ( assert_ok, assert_empty, @@ -6,6 +8,8 @@ from tests.test_graphql.common import ( get_data, ) +from selfprivacy_api.jobs import Jobs + API_JOBS_QUERY = """ getJobs { uid @@ -46,3 +50,25 @@ def test_all_jobs_unauthorized(client): def test_all_jobs_when_none(authorized_client): output = api_jobs(authorized_client) assert output == [] + + +def test_all_jobs_when_some(authorized_client): + # We cannot make new jobs via API, at least directly + job = Jobs.add("bogus", "bogus.bogus", "fungus") + output = api_jobs(authorized_client) + + len(output) == 1 + api_job = output[0] + + assert api_job["uid"] == str(job.uid) + assert api_job["typeId"] == job.type_id + assert api_job["name"] == job.name + assert api_job["description"] == job.description + assert api_job["status"] == job.status + assert api_job["statusText"] == job.status_text + assert api_job["progress"] == job.progress + assert api_job["createdAt"] == job.created_at.isoformat() + assert api_job["updatedAt"] == job.updated_at.isoformat() + assert api_job["finishedAt"] == None + 
assert api_job["error"] == None + assert api_job["result"] == None From 00c42d966099580d841ea32b886fec8cfe5ff10e Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 15 May 2024 18:14:14 +0000 Subject: [PATCH 16/32] test(jobs): subscription query generating function --- tests/common.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/common.py b/tests/common.py index 8c81f48..3c05033 100644 --- a/tests/common.py +++ b/tests/common.py @@ -73,6 +73,10 @@ def generate_jobs_query(query_array): return "query TestJobs {\n jobs {" + "\n".join(query_array) + "}\n}" +def generate_jobs_subscription(query_array): + return "subscription TestSubscription {\n jobs {" + "\n".join(query_array) + "}\n}" + + def generate_service_query(query_array): return "query TestService {\n services {" + "\n".join(query_array) + "}\n}" From 9add0b1dc1db9725de183d8ae8840994558ce6da Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 15 May 2024 18:15:16 +0000 Subject: [PATCH 17/32] test(websocket) test connection init --- tests/test_graphql/test_websocket.py | 48 ++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index fb2ac33..2431285 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -1,6 +1,50 @@ +from tests.common import generate_jobs_subscription +from selfprivacy_api.graphql.queries.jobs import Job as _Job +from selfprivacy_api.jobs import Jobs + +# JOBS_SUBSCRIPTION = """ +# jobUpdates { +# uid +# typeId +# name +# description +# status +# statusText +# progress +# createdAt +# updatedAt +# finishedAt +# error +# result +# } +# """ + def test_websocket_connection_bare(authorized_client): - client =authorized_client - with client.websocket_connect('/graphql', subprotocols=[ "graphql-transport-ws","graphql-ws"] ) as websocket: + client = authorized_client + with client.websocket_connect( + "/graphql", 
subprotocols=["graphql-transport-ws", "graphql-ws"] + ) as websocket: assert websocket is not None assert websocket.scope is not None + + +def test_websocket_graphql_init(authorized_client): + client = authorized_client + with client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws"] + ) as websocket: + websocket.send_json({"type": "connection_init", "payload": {}}) + ack = websocket.receive_json() + assert ack == {"type": "connection_ack"} + + +# def test_websocket_subscription(authorized_client): +# client = authorized_client +# with client.websocket_connect( +# "/graphql", subprotocols=["graphql-transport-ws", "graphql-ws"] +# ) as websocket: +# websocket.send(generate_jobs_subscription([JOBS_SUBSCRIPTION])) +# Jobs.add("bogus","bogus.bogus", "yyyaaaaayy") +# joblist = websocket.receive_json() +# raise NotImplementedError(joblist) From a2a4b461e7054f712ef19b15caf95e2b56b83d52 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 15 May 2024 18:31:16 +0000 Subject: [PATCH 18/32] test(websocket): ping pong test --- tests/test_graphql/test_websocket.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index 2431285..ef71312 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -29,7 +29,7 @@ def test_websocket_connection_bare(authorized_client): assert websocket.scope is not None -def test_websocket_graphql_init(authorized_client): +def test_websocket_graphql_ping(authorized_client): client = authorized_client with client.websocket_connect( "/graphql", subprotocols=["graphql-transport-ws"] @@ -38,6 +38,11 @@ def test_websocket_graphql_init(authorized_client): ack = websocket.receive_json() assert ack == {"type": "connection_ack"} + # https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping + websocket.send_json({"type": "ping", "payload": {}}) + pong = websocket.receive_json() + assert pong == 
{"type": "pong"} + # def test_websocket_subscription(authorized_client): # client = authorized_client From f14866bdbc2ba3d54dbfbc59605c7b30452f44f9 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 15 May 2024 18:36:17 +0000 Subject: [PATCH 19/32] test(websocket): separate ping and init --- tests/test_graphql/test_websocket.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index ef71312..d534269 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -29,7 +29,7 @@ def test_websocket_connection_bare(authorized_client): assert websocket.scope is not None -def test_websocket_graphql_ping(authorized_client): +def test_websocket_graphql_init(authorized_client): client = authorized_client with client.websocket_connect( "/graphql", subprotocols=["graphql-transport-ws"] @@ -38,6 +38,12 @@ def test_websocket_graphql_ping(authorized_client): ack = websocket.receive_json() assert ack == {"type": "connection_ack"} + +def test_websocket_graphql_ping(authorized_client): + client = authorized_client + with client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws"] + ) as websocket: # https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping websocket.send_json({"type": "ping", "payload": {}}) pong = websocket.receive_json() From ed777e3ebf5da70ff91c2fd5e99a8312ae7250df Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 15 May 2024 20:41:36 +0000 Subject: [PATCH 20/32] feature(jobs): add subscription endpoint --- .../graphql/subscriptions/__init__.py | 0 selfprivacy_api/graphql/subscriptions/jobs.py | 20 +++++++++++++++++++ 2 files changed, 20 insertions(+) create mode 100644 selfprivacy_api/graphql/subscriptions/__init__.py create mode 100644 selfprivacy_api/graphql/subscriptions/jobs.py diff --git a/selfprivacy_api/graphql/subscriptions/__init__.py b/selfprivacy_api/graphql/subscriptions/__init__.py new 
file mode 100644 index 0000000..e69de29 diff --git a/selfprivacy_api/graphql/subscriptions/jobs.py b/selfprivacy_api/graphql/subscriptions/jobs.py new file mode 100644 index 0000000..380badb --- /dev/null +++ b/selfprivacy_api/graphql/subscriptions/jobs.py @@ -0,0 +1,20 @@ +# pylint: disable=too-few-public-methods +import strawberry + +from typing import AsyncGenerator, List + +from selfprivacy_api.jobs import job_notifications + +from selfprivacy_api.graphql.common_types.jobs import ApiJob +from selfprivacy_api.graphql.queries.jobs import get_all_jobs + + +@strawberry.type +class JobSubscriptions: + """Subscriptions related to jobs""" + + @strawberry.subscription + async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]: + # Send the complete list of jobs every time anything gets updated + async for notification in job_notifications(): + yield get_all_jobs() From cbe5c56270609cb4fb7c42c19da520a1cdfbf9d3 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 15 May 2024 20:41:48 +0000 Subject: [PATCH 21/32] chore(jobs): shorter typehints and import sorting --- selfprivacy_api/graphql/queries/jobs.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/selfprivacy_api/graphql/queries/jobs.py b/selfprivacy_api/graphql/queries/jobs.py index 337382a..3cc3bf7 100644 --- a/selfprivacy_api/graphql/queries/jobs.py +++ b/selfprivacy_api/graphql/queries/jobs.py @@ -1,17 +1,17 @@ """Jobs status""" # pylint: disable=too-few-public-methods -import typing import strawberry +from typing import List, Optional + +from selfprivacy_api.jobs import Jobs from selfprivacy_api.graphql.common_types.jobs import ( ApiJob, get_api_job_by_id, job_to_api_job, ) -from selfprivacy_api.jobs import Jobs - -def get_all_jobs() -> typing.List[ApiJob]: +def get_all_jobs() -> List[ApiJob]: Jobs.get_jobs() return [job_to_api_job(job) for job in Jobs.get_jobs()] @@ -20,9 +20,9 @@ def get_all_jobs() -> typing.List[ApiJob]: @strawberry.type class Job: @strawberry.field - 
def get_jobs(self) -> typing.List[ApiJob]: + def get_jobs(self) -> List[ApiJob]: return get_all_jobs() @strawberry.field - def get_job(self, job_id: str) -> typing.Optional[ApiJob]: + def get_job(self, job_id: str) -> Optional[ApiJob]: return get_api_job_by_id(job_id) From 51ccde8b0750f88dd5cb3c19b213d30c274cb5f2 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 15 May 2024 20:43:17 +0000 Subject: [PATCH 22/32] test(jobs): test simple counting --- selfprivacy_api/graphql/schema.py | 15 +++-- tests/test_graphql/test_websocket.py | 92 +++++++++++++++++++++------- 2 files changed, 81 insertions(+), 26 deletions(-) diff --git a/selfprivacy_api/graphql/schema.py b/selfprivacy_api/graphql/schema.py index e4e7264..078ee3d 100644 --- a/selfprivacy_api/graphql/schema.py +++ b/selfprivacy_api/graphql/schema.py @@ -28,6 +28,8 @@ from selfprivacy_api.graphql.queries.services import Services from selfprivacy_api.graphql.queries.storage import Storage from selfprivacy_api.graphql.queries.system import System +from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions + from selfprivacy_api.graphql.mutations.users_mutations import UsersMutations from selfprivacy_api.graphql.queries.users import Users from selfprivacy_api.jobs.test import test_job @@ -129,16 +131,19 @@ class Mutation( code=200, ) - pass - @strawberry.type class Subscription: """Root schema for subscriptions""" - @strawberry.subscription(permission_classes=[IsAuthenticated]) - async def count(self, target: int = 100) -> AsyncGenerator[int, None]: - for i in range(target): + @strawberry.field(permission_classes=[IsAuthenticated]) + def jobs(self) -> JobSubscriptions: + """Jobs subscriptions""" + return JobSubscriptions() + + @strawberry.subscription + async def count(self) -> AsyncGenerator[int, None]: + for i in range(10): yield i await asyncio.sleep(0.5) diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index d534269..58681e0 100644 --- 
a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -2,22 +2,22 @@ from tests.common import generate_jobs_subscription from selfprivacy_api.graphql.queries.jobs import Job as _Job from selfprivacy_api.jobs import Jobs -# JOBS_SUBSCRIPTION = """ -# jobUpdates { -# uid -# typeId -# name -# description -# status -# statusText -# progress -# createdAt -# updatedAt -# finishedAt -# error -# result -# } -# """ +JOBS_SUBSCRIPTION = """ +jobUpdates { + uid + typeId + name + description + status + statusText + progress + createdAt + updatedAt + finishedAt + error + result +} +""" def test_websocket_connection_bare(authorized_client): @@ -50,12 +50,62 @@ def test_websocket_graphql_ping(authorized_client): assert pong == {"type": "pong"} +def init_graphql(websocket): + websocket.send_json({"type": "connection_init", "payload": {}}) + ack = websocket.receive_json() + assert ack == {"type": "connection_ack"} + + +def test_websocket_subscription_minimal(authorized_client): + client = authorized_client + with client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws"] + ) as websocket: + init_graphql(websocket) + websocket.send_json( + { + "id": "3aaa2445", + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {count}", + }, + } + ) + response = websocket.receive_json() + assert response == { + "id": "3aaa2445", + "payload": {"data": {"count": 0}}, + "type": "next", + } + response = websocket.receive_json() + assert response == { + "id": "3aaa2445", + "payload": {"data": {"count": 1}}, + "type": "next", + } + response = websocket.receive_json() + assert response == { + "id": "3aaa2445", + "payload": {"data": {"count": 2}}, + "type": "next", + } + + # def test_websocket_subscription(authorized_client): # client = authorized_client # with client.websocket_connect( -# "/graphql", subprotocols=["graphql-transport-ws", "graphql-ws"] +# "/graphql", subprotocols=["graphql-transport-ws"] # ) as websocket: -# 
websocket.send(generate_jobs_subscription([JOBS_SUBSCRIPTION])) -# Jobs.add("bogus","bogus.bogus", "yyyaaaaayy") -# joblist = websocket.receive_json() -# raise NotImplementedError(joblist) +# init_graphql(websocket) +# websocket.send_json( +# { +# "id": "3aaa2445", +# "type": "subscribe", +# "payload": { +# "query": generate_jobs_subscription([JOBS_SUBSCRIPTION]), +# }, +# } +# ) +# Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy") +# response = websocket.receive_json() +# raise NotImplementedError(response) From 442538ee4361b4cf045e0b0b9673f05d1985f35d Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 22 May 2024 11:04:37 +0000 Subject: [PATCH 23/32] feature(jobs): UNSAFE endpoint to get job updates --- selfprivacy_api/graphql/queries/jobs.py | 7 +- selfprivacy_api/graphql/schema.py | 19 +++-- selfprivacy_api/graphql/subscriptions/jobs.py | 14 +--- tests/test_graphql/test_websocket.py | 81 ++++++++++++++----- 4 files changed, 83 insertions(+), 38 deletions(-) diff --git a/selfprivacy_api/graphql/queries/jobs.py b/selfprivacy_api/graphql/queries/jobs.py index 3cc3bf7..6a12838 100644 --- a/selfprivacy_api/graphql/queries/jobs.py +++ b/selfprivacy_api/graphql/queries/jobs.py @@ -12,9 +12,10 @@ from selfprivacy_api.graphql.common_types.jobs import ( def get_all_jobs() -> List[ApiJob]: - Jobs.get_jobs() - - return [job_to_api_job(job) for job in Jobs.get_jobs()] + jobs = Jobs.get_jobs() + api_jobs = [job_to_api_job(job) for job in jobs] + assert api_jobs is not None + return api_jobs @strawberry.type diff --git a/selfprivacy_api/graphql/schema.py b/selfprivacy_api/graphql/schema.py index 078ee3d..b8ed4e2 100644 --- a/selfprivacy_api/graphql/schema.py +++ b/selfprivacy_api/graphql/schema.py @@ -2,7 +2,7 @@ # pylint: disable=too-few-public-methods import asyncio -from typing import AsyncGenerator +from typing import AsyncGenerator, List import strawberry from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql.mutations.deprecated_mutations import 
( @@ -28,7 +28,9 @@ from selfprivacy_api.graphql.queries.services import Services from selfprivacy_api.graphql.queries.storage import Storage from selfprivacy_api.graphql.queries.system import System -from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions +from selfprivacy_api.graphql.subscriptions.jobs import ApiJob +from selfprivacy_api.jobs import job_notifications +from selfprivacy_api.graphql.queries.jobs import get_all_jobs from selfprivacy_api.graphql.mutations.users_mutations import UsersMutations from selfprivacy_api.graphql.queries.users import Users @@ -136,10 +138,15 @@ class Mutation( class Subscription: """Root schema for subscriptions""" - @strawberry.field(permission_classes=[IsAuthenticated]) - def jobs(self) -> JobSubscriptions: - """Jobs subscriptions""" - return JobSubscriptions() + @strawberry.subscription + async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]: + # Send the complete list of jobs every time anything gets updated + async for notification in job_notifications(): + yield get_all_jobs() + + # @strawberry.subscription + # async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]: + # return job_updates() @strawberry.subscription async def count(self) -> AsyncGenerator[int, None]: diff --git a/selfprivacy_api/graphql/subscriptions/jobs.py b/selfprivacy_api/graphql/subscriptions/jobs.py index 380badb..11d6263 100644 --- a/selfprivacy_api/graphql/subscriptions/jobs.py +++ b/selfprivacy_api/graphql/subscriptions/jobs.py @@ -1,5 +1,4 @@ # pylint: disable=too-few-public-methods -import strawberry from typing import AsyncGenerator, List @@ -9,12 +8,7 @@ from selfprivacy_api.graphql.common_types.jobs import ApiJob from selfprivacy_api.graphql.queries.jobs import get_all_jobs -@strawberry.type -class JobSubscriptions: - """Subscriptions related to jobs""" - - @strawberry.subscription - async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]: - # Send the complete list of jobs every time anything 
gets updated - async for notification in job_notifications(): - yield get_all_jobs() +async def job_updates() -> AsyncGenerator[List[ApiJob], None]: + # Send the complete list of jobs every time anything gets updated + async for notification in job_notifications(): + yield get_all_jobs() diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index 58681e0..ee33262 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -1,6 +1,13 @@ from tests.common import generate_jobs_subscription -from selfprivacy_api.graphql.queries.jobs import Job as _Job + +# from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions +import pytest +import asyncio + from selfprivacy_api.jobs import Jobs +from time import sleep + +from tests.test_redis import empty_redis JOBS_SUBSCRIPTION = """ jobUpdates { @@ -91,21 +98,57 @@ def test_websocket_subscription_minimal(authorized_client): } -# def test_websocket_subscription(authorized_client): -# client = authorized_client -# with client.websocket_connect( -# "/graphql", subprotocols=["graphql-transport-ws"] -# ) as websocket: -# init_graphql(websocket) -# websocket.send_json( -# { -# "id": "3aaa2445", -# "type": "subscribe", -# "payload": { -# "query": generate_jobs_subscription([JOBS_SUBSCRIPTION]), -# }, -# } -# ) -# Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy") -# response = websocket.receive_json() -# raise NotImplementedError(response) +async def read_one_job(websocket): + # bug? 
We only get them starting from the second job update + # that's why we receive two jobs in the list them + # the first update gets lost somewhere + response = websocket.receive_json() + return response + + +@pytest.mark.asyncio +async def test_websocket_subscription(authorized_client, empty_redis, event_loop): + client = authorized_client + with client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws"] + ) as websocket: + init_graphql(websocket) + websocket.send_json( + { + "id": "3aaa2445", + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {" + + JOBS_SUBSCRIPTION + + "}", + }, + } + ) + future = asyncio.create_task(read_one_job(websocket)) + jobs = [] + jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works")) + sleep(0.5) + jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works")) + + response = await future + data = response["payload"]["data"] + jobs_received = data["jobUpdates"] + received_names = [job["name"] for job in jobs_received] + for job in jobs: + assert job.name in received_names + + for job in jobs: + for api_job in jobs_received: + if (job.name) == api_job["name"]: + assert api_job["uid"] == str(job.uid) + assert api_job["typeId"] == job.type_id + assert api_job["name"] == job.name + assert api_job["description"] == job.description + assert api_job["status"] == job.status + assert api_job["statusText"] == job.status_text + assert api_job["progress"] == job.progress + assert api_job["createdAt"] == job.created_at.isoformat() + assert api_job["updatedAt"] == job.updated_at.isoformat() + assert api_job["finishedAt"] == None + assert api_job["error"] == None + assert api_job["result"] == None From 0fda29cdd7489608563d1250bf0a54fb7e41599c Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 27 May 2024 18:22:20 +0000 Subject: [PATCH 24/32] test(devices): provide devices for a service test to fix conditional test fail. 
--- tests/test_graphql/test_services.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_graphql/test_services.py b/tests/test_graphql/test_services.py index 6e8dcf6..b7faf3d 100644 --- a/tests/test_graphql/test_services.py +++ b/tests/test_graphql/test_services.py @@ -543,8 +543,8 @@ def test_disable_enable(authorized_client, only_dummy_service): assert api_dummy_service["status"] == ServiceStatus.ACTIVE.value -def test_move_immovable(authorized_client, only_dummy_service): - dummy_service = only_dummy_service +def test_move_immovable(authorized_client, dummy_service_with_binds): + dummy_service = dummy_service_with_binds dummy_service.set_movable(False) root = BlockDevices().get_root_block_device() mutation_response = api_move(authorized_client, dummy_service, root.name) From cb641e4f37d1ec73595b058d49b93721adc2555d Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 27 May 2024 20:21:11 +0000 Subject: [PATCH 25/32] feature(websocket): add auth --- selfprivacy_api/graphql/schema.py | 18 ++- tests/test_graphql/test_websocket.py | 169 ++++++++++++++++++--------- 2 files changed, 131 insertions(+), 56 deletions(-) diff --git a/selfprivacy_api/graphql/schema.py b/selfprivacy_api/graphql/schema.py index b8ed4e2..c6cf46b 100644 --- a/selfprivacy_api/graphql/schema.py +++ b/selfprivacy_api/graphql/schema.py @@ -4,6 +4,7 @@ import asyncio from typing import AsyncGenerator, List import strawberry + from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql.mutations.deprecated_mutations import ( DeprecatedApiMutations, @@ -134,12 +135,25 @@ class Mutation( ) +# A cruft for Websockets +def authenticated(info) -> bool: + return IsAuthenticated().has_permission(source=None, info=info) + + @strawberry.type class Subscription: - """Root schema for subscriptions""" + """Root schema for subscriptions. 
+ Every field here should be an AsyncIterator or AsyncGenerator + It is not a part of the spec but graphql-core (dep of strawberryql) + demands it while the spec is vague in this area.""" @strawberry.subscription - async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]: + async def job_updates( + self, info: strawberry.types.Info + ) -> AsyncGenerator[List[ApiJob], None]: + if not authenticated(info): + raise Exception(IsAuthenticated().message) + # Send the complete list of jobs every time anything gets updated async for notification in job_notifications(): yield get_all_jobs() diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index ee33262..5a92416 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -1,13 +1,20 @@ -from tests.common import generate_jobs_subscription - # from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions import pytest import asyncio - -from selfprivacy_api.jobs import Jobs +from typing import Generator from time import sleep -from tests.test_redis import empty_redis +from starlette.testclient import WebSocketTestSession + +from selfprivacy_api.jobs import Jobs +from selfprivacy_api.actions.api_tokens import TOKEN_REPO +from selfprivacy_api.graphql import IsAuthenticated + +from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH +from tests.test_jobs import jobs as empty_jobs + +# We do not iterate through them yet +TESTED_SUBPROTOCOLS = ["graphql-transport-ws"] JOBS_SUBSCRIPTION = """ jobUpdates { @@ -27,6 +34,48 @@ jobUpdates { """ +def connect_ws_authenticated(authorized_client) -> WebSocketTestSession: + token = "Bearer " + str(DEVICE_WE_AUTH_TESTS_WITH["token"]) + return authorized_client.websocket_connect( + "/graphql", + subprotocols=TESTED_SUBPROTOCOLS, + params={"token": token}, + ) + + +def connect_ws_not_authenticated(client) -> WebSocketTestSession: + return client.websocket_connect( + "/graphql", + 
subprotocols=TESTED_SUBPROTOCOLS, + params={"token": "I like vegan icecream but it is not a valid token"}, + ) + + +def init_graphql(websocket): + websocket.send_json({"type": "connection_init", "payload": {}}) + ack = websocket.receive_json() + assert ack == {"type": "connection_ack"} + + +@pytest.fixture +def authenticated_websocket( + authorized_client, +) -> Generator[WebSocketTestSession, None, None]: + # We use authorized_client only tohave token in the repo, this client by itself is not enough to authorize websocket + + ValueError(TOKEN_REPO.get_tokens()) + with connect_ws_authenticated(authorized_client) as websocket: + yield websocket + sleep(1) + + +@pytest.fixture +def unauthenticated_websocket(client) -> Generator[WebSocketTestSession, None, None]: + with connect_ws_not_authenticated(client) as websocket: + yield websocket + sleep(1) + + def test_websocket_connection_bare(authorized_client): client = authorized_client with client.websocket_connect( @@ -57,12 +106,6 @@ def test_websocket_graphql_ping(authorized_client): assert pong == {"type": "pong"} -def init_graphql(websocket): - websocket.send_json({"type": "connection_init", "payload": {}}) - ack = websocket.receive_json() - assert ack == {"type": "connection_ack"} - - def test_websocket_subscription_minimal(authorized_client): client = authorized_client with client.websocket_connect( @@ -107,48 +150,66 @@ async def read_one_job(websocket): @pytest.mark.asyncio -async def test_websocket_subscription(authorized_client, empty_redis, event_loop): - client = authorized_client - with client.websocket_connect( - "/graphql", subprotocols=["graphql-transport-ws"] - ) as websocket: - init_graphql(websocket) - websocket.send_json( - { - "id": "3aaa2445", - "type": "subscribe", - "payload": { - "query": "subscription TestSubscription {" - + JOBS_SUBSCRIPTION - + "}", - }, - } - ) - future = asyncio.create_task(read_one_job(websocket)) - jobs = [] - jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it 
works")) - sleep(0.5) - jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works")) +async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs): + websocket = authenticated_websocket + init_graphql(websocket) + websocket.send_json( + { + "id": "3aaa2445", + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}", + }, + } + ) + future = asyncio.create_task(read_one_job(websocket)) + jobs = [] + jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works")) + sleep(0.5) + jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works")) - response = await future - data = response["payload"]["data"] - jobs_received = data["jobUpdates"] - received_names = [job["name"] for job in jobs_received] - for job in jobs: - assert job.name in received_names + response = await future + data = response["payload"]["data"] + jobs_received = data["jobUpdates"] + received_names = [job["name"] for job in jobs_received] + for job in jobs: + assert job.name in received_names - for job in jobs: - for api_job in jobs_received: - if (job.name) == api_job["name"]: - assert api_job["uid"] == str(job.uid) - assert api_job["typeId"] == job.type_id - assert api_job["name"] == job.name - assert api_job["description"] == job.description - assert api_job["status"] == job.status - assert api_job["statusText"] == job.status_text - assert api_job["progress"] == job.progress - assert api_job["createdAt"] == job.created_at.isoformat() - assert api_job["updatedAt"] == job.updated_at.isoformat() - assert api_job["finishedAt"] == None - assert api_job["error"] == None - assert api_job["result"] == None + assert len(jobs_received) == 2 + + for job in jobs: + for api_job in jobs_received: + if (job.name) == api_job["name"]: + assert api_job["uid"] == str(job.uid) + assert api_job["typeId"] == job.type_id + assert api_job["name"] == job.name + assert api_job["description"] == job.description + assert 
api_job["status"] == job.status + assert api_job["statusText"] == job.status_text + assert api_job["progress"] == job.progress + assert api_job["createdAt"] == job.created_at.isoformat() + assert api_job["updatedAt"] == job.updated_at.isoformat() + assert api_job["finishedAt"] == None + assert api_job["error"] == None + assert api_job["result"] == None + + +def test_websocket_subscription_unauthorized(unauthenticated_websocket): + websocket = unauthenticated_websocket + init_graphql(websocket) + websocket.send_json( + { + "id": "3aaa2445", + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}", + }, + } + ) + + response = websocket.receive_json() + assert response == { + "id": "3aaa2445", + "payload": [{"message": IsAuthenticated.message}], + "type": "error", + } From ccf71078b85c0e036d51da5c816ea00665edc2e8 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 27 May 2024 20:38:51 +0000 Subject: [PATCH 26/32] feature(websocket): add auth to counter too --- selfprivacy_api/graphql/schema.py | 17 +++---- tests/test_graphql/test_websocket.py | 69 +++++++++++++++------------- 2 files changed, 47 insertions(+), 39 deletions(-) diff --git a/selfprivacy_api/graphql/schema.py b/selfprivacy_api/graphql/schema.py index c6cf46b..3280396 100644 --- a/selfprivacy_api/graphql/schema.py +++ b/selfprivacy_api/graphql/schema.py @@ -136,10 +136,15 @@ class Mutation( # A cruft for Websockets -def authenticated(info) -> bool: +def authenticated(info: strawberry.types.Info) -> bool: return IsAuthenticated().has_permission(source=None, info=info) +def reject_if_unauthenticated(info: strawberry.types.Info): + if not authenticated(info): + raise Exception(IsAuthenticated().message) + + @strawberry.type class Subscription: """Root schema for subscriptions. 
@@ -151,19 +156,15 @@ class Subscription: async def job_updates( self, info: strawberry.types.Info ) -> AsyncGenerator[List[ApiJob], None]: - if not authenticated(info): - raise Exception(IsAuthenticated().message) + reject_if_unauthenticated(info) # Send the complete list of jobs every time anything gets updated async for notification in job_notifications(): yield get_all_jobs() - # @strawberry.subscription - # async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]: - # return job_updates() - @strawberry.subscription - async def count(self) -> AsyncGenerator[int, None]: + async def count(self, info: strawberry.types.Info) -> AsyncGenerator[int, None]: + reject_if_unauthenticated(info) for i in range(10): yield i await asyncio.sleep(0.5) diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index 5a92416..49cc944 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -106,41 +106,61 @@ def test_websocket_graphql_ping(authorized_client): assert pong == {"type": "pong"} +def api_subscribe(websocket, id, subscription): + websocket.send_json( + { + "id": id, + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {" + subscription + "}", + }, + } + ) + + def test_websocket_subscription_minimal(authorized_client): + # Test a small endpoint that exists specifically for tests client = authorized_client with client.websocket_connect( "/graphql", subprotocols=["graphql-transport-ws"] ) as websocket: init_graphql(websocket) - websocket.send_json( - { - "id": "3aaa2445", - "type": "subscribe", - "payload": { - "query": "subscription TestSubscription {count}", - }, - } - ) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, "count") response = websocket.receive_json() assert response == { - "id": "3aaa2445", + "id": arbitrary_id, "payload": {"data": {"count": 0}}, "type": "next", } response = websocket.receive_json() assert response == { - "id": 
"3aaa2445", + "id": arbitrary_id, "payload": {"data": {"count": 1}}, "type": "next", } response = websocket.receive_json() assert response == { - "id": "3aaa2445", + "id": arbitrary_id, "payload": {"data": {"count": 2}}, "type": "next", } +def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket): + websocket = unauthenticated_websocket + init_graphql(websocket) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, "count") + + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": [{"message": IsAuthenticated.message}], + "type": "error", + } + + async def read_one_job(websocket): # bug? We only get them starting from the second job update # that's why we receive two jobs in the list them @@ -153,15 +173,9 @@ async def read_one_job(websocket): async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs): websocket = authenticated_websocket init_graphql(websocket) - websocket.send_json( - { - "id": "3aaa2445", - "type": "subscribe", - "payload": { - "query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}", - }, - } - ) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, JOBS_SUBSCRIPTION) + future = asyncio.create_task(read_one_job(websocket)) jobs = [] jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works")) @@ -197,19 +211,12 @@ async def test_websocket_subscription(authenticated_websocket, event_loop, empty def test_websocket_subscription_unauthorized(unauthenticated_websocket): websocket = unauthenticated_websocket init_graphql(websocket) - websocket.send_json( - { - "id": "3aaa2445", - "type": "subscribe", - "payload": { - "query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}", - }, - } - ) + id = "3aaa2445" + api_subscribe(websocket, id, JOBS_SUBSCRIPTION) response = websocket.receive_json() assert response == { - "id": "3aaa2445", + "id": id, "payload": [{"message": IsAuthenticated.message}], "type": 
"error", } From 05ffa036b3d916720ed3276603f5db4bd890309e Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 27 May 2024 21:13:57 +0000 Subject: [PATCH 27/32] refactor(jobs): offload job subscription logic to a separate file --- selfprivacy_api/graphql/schema.py | 11 +++++------ tests/test_graphql/test_websocket.py | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/selfprivacy_api/graphql/schema.py b/selfprivacy_api/graphql/schema.py index 3280396..05e6bf9 100644 --- a/selfprivacy_api/graphql/schema.py +++ b/selfprivacy_api/graphql/schema.py @@ -30,8 +30,9 @@ from selfprivacy_api.graphql.queries.storage import Storage from selfprivacy_api.graphql.queries.system import System from selfprivacy_api.graphql.subscriptions.jobs import ApiJob -from selfprivacy_api.jobs import job_notifications -from selfprivacy_api.graphql.queries.jobs import get_all_jobs +from selfprivacy_api.graphql.subscriptions.jobs import ( + job_updates as job_update_generator, +) from selfprivacy_api.graphql.mutations.users_mutations import UsersMutations from selfprivacy_api.graphql.queries.users import Users @@ -157,12 +158,10 @@ class Subscription: self, info: strawberry.types.Info ) -> AsyncGenerator[List[ApiJob], None]: reject_if_unauthenticated(info) - - # Send the complete list of jobs every time anything gets updated - async for notification in job_notifications(): - yield get_all_jobs() + return job_update_generator() @strawberry.subscription + # Used for testing, consider deletion to shrink attack surface async def count(self, info: strawberry.types.Info) -> AsyncGenerator[int, None]: reject_if_unauthenticated(info) for i in range(10): diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index 49cc944..d538ca1 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -162,9 +162,9 @@ def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket): async def 
read_one_job(websocket): - # bug? We only get them starting from the second job update - # that's why we receive two jobs in the list them - # the first update gets lost somewhere + # Bug? We only get them starting from the second job update + # That's why we receive two jobs in the list them + # The first update gets lost somewhere response = websocket.receive_json() return response @@ -215,8 +215,16 @@ def test_websocket_subscription_unauthorized(unauthenticated_websocket): api_subscribe(websocket, id, JOBS_SUBSCRIPTION) response = websocket.receive_json() + # I do not really know why strawberry gives more info on this + # One versus the counter + payload = response["payload"][0] + assert isinstance(payload, dict) + assert "locations" in payload.keys() + # It looks like this 'locations': [{'column': 32, 'line': 1}] + # We cannot test locations feasibly + del payload["locations"] assert response == { "id": id, - "payload": [{"message": IsAuthenticated.message}], + "payload": [{"message": IsAuthenticated.message, "path": ["jobUpdates"]}], "type": "error", } From 57378a794089eb39bea3ec4da0ce92066859825e Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 27 May 2024 21:15:47 +0000 Subject: [PATCH 28/32] test(websocket): remove excessive sleeping --- tests/test_graphql/test_websocket.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index d538ca1..27cfd55 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -66,14 +66,12 @@ def authenticated_websocket( ValueError(TOKEN_REPO.get_tokens()) with connect_ws_authenticated(authorized_client) as websocket: yield websocket - sleep(1) @pytest.fixture def unauthenticated_websocket(client) -> Generator[WebSocketTestSession, None, None]: with connect_ws_not_authenticated(client) as websocket: yield websocket - sleep(1) def test_websocket_connection_bare(authorized_client): From 
41f6d8b6d2078a0c62f4567ffa7279a5d9c9a198 Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 27 May 2024 21:28:29 +0000 Subject: [PATCH 29/32] test(websocket): remove some duplication --- tests/test_graphql/test_websocket.py | 75 +++++++++++++--------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index 27cfd55..754fbbf 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -34,6 +34,18 @@ jobUpdates { """ +def api_subscribe(websocket, id, subscription): + websocket.send_json( + { + "id": id, + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {" + subscription + "}", + }, + } + ) + + def connect_ws_authenticated(authorized_client) -> WebSocketTestSession: token = "Bearer " + str(DEVICE_WE_AUTH_TESTS_WITH["token"]) return authorized_client.websocket_connect( @@ -61,7 +73,7 @@ def init_graphql(websocket): def authenticated_websocket( authorized_client, ) -> Generator[WebSocketTestSession, None, None]: - # We use authorized_client only tohave token in the repo, this client by itself is not enough to authorize websocket + # We use authorized_client only to have token in the repo, this client by itself is not enough to authorize websocket ValueError(TOKEN_REPO.get_tokens()) with connect_ws_authenticated(authorized_client) as websocket: @@ -104,45 +116,30 @@ def test_websocket_graphql_ping(authorized_client): assert pong == {"type": "pong"} -def api_subscribe(websocket, id, subscription): - websocket.send_json( - { - "id": id, - "type": "subscribe", - "payload": { - "query": "subscription TestSubscription {" + subscription + "}", - }, - } - ) - - -def test_websocket_subscription_minimal(authorized_client): +def test_websocket_subscription_minimal(authorized_client, authenticated_websocket): # Test a small endpoint that exists specifically for tests - client = authorized_client - with 
client.websocket_connect( - "/graphql", subprotocols=["graphql-transport-ws"] - ) as websocket: - init_graphql(websocket) - arbitrary_id = "3aaa2445" - api_subscribe(websocket, arbitrary_id, "count") - response = websocket.receive_json() - assert response == { - "id": arbitrary_id, - "payload": {"data": {"count": 0}}, - "type": "next", - } - response = websocket.receive_json() - assert response == { - "id": arbitrary_id, - "payload": {"data": {"count": 1}}, - "type": "next", - } - response = websocket.receive_json() - assert response == { - "id": arbitrary_id, - "payload": {"data": {"count": 2}}, - "type": "next", - } + websocket = authenticated_websocket + init_graphql(websocket) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, "count") + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": {"data": {"count": 0}}, + "type": "next", + } + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": {"data": {"count": 1}}, + "type": "next", + } + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": {"data": {"count": 2}}, + "type": "next", + } def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket): From 9accf861c51a4f96f00eeb06a4da839d4ba92cfa Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 17 Jun 2024 11:34:23 +0000 Subject: [PATCH 30/32] fix(websockets): add websockets dep so that uvicorn works --- default.nix | 1 + tests/test_websocket_uvicorn_standalone.py | 39 ++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 tests/test_websocket_uvicorn_standalone.py diff --git a/default.nix b/default.nix index e7e6fcf..f6d85d6 100644 --- a/default.nix +++ b/default.nix @@ -18,6 +18,7 @@ pythonPackages.buildPythonPackage rec { strawberry-graphql typing-extensions uvicorn + websockets ]; pythonImportsCheck = [ "selfprivacy_api" ]; doCheck = false; diff --git 
a/tests/test_websocket_uvicorn_standalone.py b/tests/test_websocket_uvicorn_standalone.py new file mode 100644 index 0000000..43a53ef --- /dev/null +++ b/tests/test_websocket_uvicorn_standalone.py @@ -0,0 +1,39 @@ +import pytest +from fastapi import FastAPI, WebSocket +import uvicorn + +# import subprocess +from multiprocessing import Process +import asyncio +from time import sleep +from websockets import client + +app = FastAPI() + + +@app.websocket("/") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + while True: + data = await websocket.receive_text() + await websocket.send_text(f"You sent: {data}") + + +def run_uvicorn(): + uvicorn.run(app, port=5000) + return True + + +@pytest.mark.asyncio +async def test_uvcorn_ws_works_in_prod(): + proc = Process(target=run_uvicorn) + proc.start() + sleep(2) + + ws = await client.connect("ws://127.0.0.1:5000") + + await ws.send("hohoho") + message = await ws.read_message() + assert message == "You sent: hohoho" + await ws.close() + proc.kill() From a7be03a6d31d3017bf9ffe87b02680c62e6aeb5a Mon Sep 17 00:00:00 2001 From: Inex Code Date: Thu, 4 Jul 2024 18:49:17 +0400 Subject: [PATCH 31/32] refactor: Remove setting KEA This is already done via NixOS config --- selfprivacy_api/utils/redis_pool.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/selfprivacy_api/utils/redis_pool.py b/selfprivacy_api/utils/redis_pool.py index 39c536f..ea827d1 100644 --- a/selfprivacy_api/utils/redis_pool.py +++ b/selfprivacy_api/utils/redis_pool.py @@ -29,8 +29,6 @@ class RedisPool: url, decode_responses=True, ) - # TODO: inefficient, this is probably done each time we connect - self.get_connection().config_set("notify-keyspace-events", "KEA") @staticmethod def connection_url(dbnumber: int) -> str: From ceee6e4db9a7def34d8e2193a6088b2076e39fb8 Mon Sep 17 00:00:00 2001 From: Inex Code Date: Thu, 4 Jul 2024 21:08:40 +0400 Subject: [PATCH 32/32] fix: Read auth token from the connection initialization payload 
Websockets do not provide headers, and sending a token as a query param
is also not good (it gets into the server's logs). As an alternative,
we can provide a token in the first ws payload.

Read more: https://strawberry.rocks/docs/general/subscriptions#authenticating-subscriptions
---
 selfprivacy_api/graphql/__init__.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/selfprivacy_api/graphql/__init__.py b/selfprivacy_api/graphql/__init__.py
index 6124a1a..edd8a78 100644
--- a/selfprivacy_api/graphql/__init__.py
+++ b/selfprivacy_api/graphql/__init__.py
@@ -16,6 +16,10 @@ class IsAuthenticated(BasePermission):
         token = info.context["request"].headers.get("Authorization")
         if token is None:
             token = info.context["request"].query_params.get("token")
+        if token is None:
+            connection_params = info.context.get("connection_params")
+            if connection_params is not None:
+                token = connection_params.get("Authorization")
         if token is None:
             return False
         return is_token_valid(token.replace("Bearer ", ""))