diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 45ebd2a..a188a00 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,9 +13,9 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api), For detailed installation information, please review and follow: [link](https://nixos.org/manual/nix/stable/installation/installing-binary.html#installing-a-binary-distribution). -3. **Change directory to the cloned repository and start a nix shell:** +3. **Change directory to the cloned repository and start a nix development shell:** - ```cd selfprivacy-rest-api && nix-shell``` + ```cd selfprivacy-rest-api && nix develop``` Nix will install all of the necessary packages for development work, all further actions will take place only within nix-shell. @@ -31,7 +31,7 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api), Copy the path that starts with ```/nix/store/``` and ends with ```env/bin/python``` - ```/nix/store/???-python3-3.9.??-env/bin/python``` + ```/nix/store/???-python3-3.10.??-env/bin/python``` Click on the python version selection in the lower right corner, and replace the path to the interpreter in the project with the one you copied from the terminal. @@ -43,12 +43,13 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api), ## What to do after making changes to the repository? -**Run unit tests** using ```pytest .``` -Make sure that all tests pass successfully and the API works correctly. For convenience, you can use the built-in VScode interface. +**Run unit tests** using ```pytest-vm``` inside of the development shell. This will run all the test inside a virtual machine, which is necessary for the tests to pass successfully. +Make sure that all tests pass successfully and the API works correctly. -How to review the percentage of code coverage? Execute the command: +The ```pytest-vm``` command will also print out the coverage of the tests. To export the report to an XML file, use the following command: + +```coverage xml``` -```coverage run -m pytest && coverage xml && coverage report``` Next, use the recommended extension ```ryanluker.vscode-coverage-gutters```, navigate to one of the test files, and click the "watch" button on the bottom panel of VScode. diff --git a/default.nix b/default.nix index 15320df..1af935e 100644 --- a/default.nix +++ b/default.nix @@ -14,11 +14,12 @@ pythonPackages.buildPythonPackage rec { pydantic pytz redis + systemd setuptools strawberry-graphql typing-extensions uvicorn - requests + websockets ]; pythonImportsCheck = [ "selfprivacy_api" ]; doCheck = false; diff --git a/flake.lock b/flake.lock index 1f52d36..ba47e51 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1709677081, - "narHash": "sha256-tix36Y7u0rkn6mTm0lA45b45oab2cFLqAzDbJxeXS+c=", + "lastModified": 1719957072, + "narHash": "sha256-gvFhEf5nszouwLAkT9nWsDzocUTqLWHuL++dvNjMp9I=", "owner": "nixos", "repo": "nixpkgs", - "rev": "880992dcc006a5e00dd0591446fdf723e6a51a64", + "rev": "7144d6241f02d171d25fba3edeaf15e0f2592105", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index f8b81aa..4c8880e 100644 --- a/flake.nix +++ b/flake.nix @@ -20,6 +20,7 @@ pytest-datadir pytest-mock pytest-subprocess + pytest-asyncio black mypy pylsp-mypy @@ -65,7 +66,7 @@ SCRIPT=$(cat <&2") - machine.succeed("cd ${vmtest-src-dir} && coverage run -m pytest -v $@ >&2") + machine.succeed("cd ${vmtest-src-dir} && coverage run -m pytest $@ >&2") machine.succeed("cd ${vmtest-src-dir} && coverage report >&2") EOF ) diff --git a/selfprivacy_api/actions/services.py b/selfprivacy_api/actions/services.py index ebb0917..f9486d1 100644 --- a/selfprivacy_api/actions/services.py +++ b/selfprivacy_api/actions/services.py @@ -27,7 +27,7 @@ def move_service(service_id: str, volume_name: str) -> Job: job = Jobs.add( type_id=f"services.{service.get_id()}.move", name=f"Move {service.get_display_name()}", - description=f"Moving {service.get_display_name()} data to {volume.name}", + description=f"Moving {service.get_display_name()} data to {volume.get_display_name().lower()}", ) move_service_task(service, volume, job) diff --git a/selfprivacy_api/app.py b/selfprivacy_api/app.py index 64ca85a..2f7e2f7 100644 --- a/selfprivacy_api/app.py +++ b/selfprivacy_api/app.py @@ -3,6 +3,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from strawberry.fastapi import GraphQLRouter +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL import uvicorn @@ -13,8 +14,12 @@ from selfprivacy_api.migrations import run_migrations app = FastAPI() -graphql_app = GraphQLRouter( +graphql_app: GraphQLRouter = GraphQLRouter( schema, + subscription_protocols=[ + GRAPHQL_TRANSPORT_WS_PROTOCOL, + GRAPHQL_WS_PROTOCOL, + ], ) app.add_middleware( diff --git a/selfprivacy_api/dependencies.py b/selfprivacy_api/dependencies.py index b9d0904..b2e2b19 100644 --- a/selfprivacy_api/dependencies.py +++ b/selfprivacy_api/dependencies.py @@ -27,4 +27,4 @@ async def get_token_header( def get_api_version() -> str: """Get API version""" - return "3.2.1" + return "3.3.0" diff --git a/selfprivacy_api/graphql/__init__.py b/selfprivacy_api/graphql/__init__.py index 6124a1a..edd8a78 100644 --- a/selfprivacy_api/graphql/__init__.py +++ b/selfprivacy_api/graphql/__init__.py @@ -16,6 +16,10 @@ class IsAuthenticated(BasePermission): token = info.context["request"].headers.get("Authorization") if token is None: token = info.context["request"].query_params.get("token") + if token is None: + connection_params = info.context.get("connection_params") + if connection_params is not None: + token = connection_params.get("Authorization") if token is None: return False return is_token_valid(token.replace("Bearer ", "")) diff --git a/selfprivacy_api/graphql/queries/api_queries.py b/selfprivacy_api/graphql/queries/api_queries.py index 7052ded..77b0387 100644 --- a/selfprivacy_api/graphql/queries/api_queries.py +++ b/selfprivacy_api/graphql/queries/api_queries.py @@ -1,8 +1,10 @@ """API access status""" + # pylint: disable=too-few-public-methods import datetime import typing import strawberry + from strawberry.types import Info from selfprivacy_api.actions.api_tokens import ( get_api_tokens_with_caller_flag, diff --git a/selfprivacy_api/graphql/queries/backup.py b/selfprivacy_api/graphql/queries/backup.py index afb24ae..7695f0d 100644 --- a/selfprivacy_api/graphql/queries/backup.py +++ b/selfprivacy_api/graphql/queries/backup.py @@ -1,4 +1,5 @@ """Backup""" + # pylint: disable=too-few-public-methods import typing import strawberry diff --git a/selfprivacy_api/graphql/queries/common.py b/selfprivacy_api/graphql/queries/common.py index a1abbdc..09dbaf4 100644 --- a/selfprivacy_api/graphql/queries/common.py +++ b/selfprivacy_api/graphql/queries/common.py @@ -1,4 +1,5 @@ """Common types and enums used by different types of queries.""" + from enum import Enum import datetime import typing diff --git a/selfprivacy_api/graphql/queries/jobs.py b/selfprivacy_api/graphql/queries/jobs.py index e7b99e6..35a2182 100644 --- a/selfprivacy_api/graphql/queries/jobs.py +++ b/selfprivacy_api/graphql/queries/jobs.py @@ -1,24 +1,30 @@ """Jobs status""" + # pylint: disable=too-few-public-methods -import typing import strawberry +from typing import List, Optional + +from selfprivacy_api.jobs import Jobs from selfprivacy_api.graphql.common_types.jobs import ( ApiJob, get_api_job_by_id, job_to_api_job, ) -from selfprivacy_api.jobs import Jobs + +def get_all_jobs() -> List[ApiJob]: + jobs = Jobs.get_jobs() + api_jobs = [job_to_api_job(job) for job in jobs] + assert api_jobs is not None + return api_jobs @strawberry.type class Job: @strawberry.field - def get_jobs(self) -> typing.List[ApiJob]: - Jobs.get_jobs() - - return [job_to_api_job(job) for job in Jobs.get_jobs()] + def get_jobs(self) -> List[ApiJob]: + return get_all_jobs() @strawberry.field - def get_job(self, job_id: str) -> typing.Optional[ApiJob]: + def get_job(self, job_id: str) -> Optional[ApiJob]: return get_api_job_by_id(job_id) diff --git a/selfprivacy_api/graphql/queries/logs.py b/selfprivacy_api/graphql/queries/logs.py new file mode 100644 index 0000000..cf8fe21 --- /dev/null +++ b/selfprivacy_api/graphql/queries/logs.py @@ -0,0 +1,88 @@ +"""System logs""" +from datetime import datetime +import typing +import strawberry +from selfprivacy_api.utils.systemd_journal import get_paginated_logs + + +@strawberry.type +class LogEntry: + message: str = strawberry.field() + timestamp: datetime = strawberry.field() + priority: typing.Optional[int] = strawberry.field() + systemd_unit: typing.Optional[str] = strawberry.field() + systemd_slice: typing.Optional[str] = strawberry.field() + + def __init__(self, journal_entry: typing.Dict): + self.entry = journal_entry + self.message = journal_entry["MESSAGE"] + self.timestamp = journal_entry["__REALTIME_TIMESTAMP"] + self.priority = journal_entry.get("PRIORITY") + self.systemd_unit = journal_entry.get("_SYSTEMD_UNIT") + self.systemd_slice = journal_entry.get("_SYSTEMD_SLICE") + + @strawberry.field() + def cursor(self) -> str: + return self.entry["__CURSOR"] + + +@strawberry.type +class LogsPageMeta: + up_cursor: typing.Optional[str] = strawberry.field() + down_cursor: typing.Optional[str] = strawberry.field() + + def __init__( + self, up_cursor: typing.Optional[str], down_cursor: typing.Optional[str] + ): + self.up_cursor = up_cursor + self.down_cursor = down_cursor + + +@strawberry.type +class PaginatedEntries: + page_meta: LogsPageMeta = strawberry.field( + description="Metadata to aid in pagination." + ) + entries: typing.List[LogEntry] = strawberry.field( + description="The list of log entries." + ) + + def __init__(self, meta: LogsPageMeta, entries: typing.List[LogEntry]): + self.page_meta = meta + self.entries = entries + + @staticmethod + def from_entries(entries: typing.List[LogEntry]): + if entries == []: + return PaginatedEntries(LogsPageMeta(None, None), []) + + return PaginatedEntries( + LogsPageMeta( + entries[0].cursor(), + entries[-1].cursor(), + ), + entries, + ) + + +@strawberry.type +class Logs: + @strawberry.field() + def paginated( + self, + limit: int = 20, + # All entries returned will be lesser than this cursor. Sets upper bound on results. + up_cursor: str | None = None, + # All entries returned will be greater than this cursor. Sets lower bound on results. + down_cursor: str | None = None, + ) -> PaginatedEntries: + if limit > 50: + raise Exception("You can't fetch more than 50 entries via single request.") + return PaginatedEntries.from_entries( + list( + map( + lambda x: LogEntry(x), + get_paginated_logs(limit, up_cursor, down_cursor), + ) + ) + ) diff --git a/selfprivacy_api/graphql/queries/providers.py b/selfprivacy_api/graphql/queries/providers.py index 2995fe8..c08ea6c 100644 --- a/selfprivacy_api/graphql/queries/providers.py +++ b/selfprivacy_api/graphql/queries/providers.py @@ -1,4 +1,5 @@ """Enums representing different service providers.""" + from enum import Enum import strawberry diff --git a/selfprivacy_api/graphql/queries/services.py b/selfprivacy_api/graphql/queries/services.py index 5398f81..3085d61 100644 --- a/selfprivacy_api/graphql/queries/services.py +++ b/selfprivacy_api/graphql/queries/services.py @@ -1,4 +1,5 @@ """Services status""" + # pylint: disable=too-few-public-methods import typing import strawberry diff --git a/selfprivacy_api/graphql/queries/storage.py b/selfprivacy_api/graphql/queries/storage.py index 4b9a291..c221d26 100644 --- a/selfprivacy_api/graphql/queries/storage.py +++ b/selfprivacy_api/graphql/queries/storage.py @@ -1,4 +1,5 @@ """Storage queries.""" + # pylint: disable=too-few-public-methods import typing import strawberry @@ -18,9 +19,11 @@ class Storage: """Get list of volumes""" return [ StorageVolume( - total_space=str(volume.fssize) - if volume.fssize is not None - else str(volume.size), + total_space=( + str(volume.fssize) + if volume.fssize is not None + else str(volume.size) + ), free_space=str(volume.fsavail), used_space=str(volume.fsused), root=volume.is_root(), diff --git a/selfprivacy_api/graphql/queries/system.py b/selfprivacy_api/graphql/queries/system.py index 82c9260..55537d7 100644 --- a/selfprivacy_api/graphql/queries/system.py +++ b/selfprivacy_api/graphql/queries/system.py @@ -1,8 +1,10 @@ """Common system information and settings""" + # pylint: disable=too-few-public-methods import os import typing import strawberry + from selfprivacy_api.graphql.common_types.dns import DnsRecord from selfprivacy_api.graphql.queries.common import Alert, Severity diff --git a/selfprivacy_api/graphql/queries/users.py b/selfprivacy_api/graphql/queries/users.py index d2c0555..992ce01 100644 --- a/selfprivacy_api/graphql/queries/users.py +++ b/selfprivacy_api/graphql/queries/users.py @@ -1,4 +1,5 @@ """Users""" + # pylint: disable=too-few-public-methods import typing import strawberry diff --git a/selfprivacy_api/graphql/schema.py b/selfprivacy_api/graphql/schema.py index bcebbac..b5e6765 100644 --- a/selfprivacy_api/graphql/schema.py +++ b/selfprivacy_api/graphql/schema.py @@ -2,8 +2,10 @@ # pylint: disable=too-few-public-methods import asyncio -from typing import AsyncGenerator +from typing import AsyncGenerator, List import strawberry +from strawberry.types import Info + from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql.mutations.deprecated_mutations import ( DeprecatedApiMutations, @@ -24,11 +26,18 @@ from selfprivacy_api.graphql.mutations.backup_mutations import BackupMutations from selfprivacy_api.graphql.queries.api_queries import Api from selfprivacy_api.graphql.queries.backup import Backup from selfprivacy_api.graphql.queries.jobs import Job +from selfprivacy_api.graphql.queries.logs import LogEntry, Logs from selfprivacy_api.graphql.queries.services import Services from selfprivacy_api.graphql.queries.storage import Storage from selfprivacy_api.graphql.queries.system import System from selfprivacy_api.graphql.queries.monitoring import Monitoring +from selfprivacy_api.graphql.subscriptions.jobs import ApiJob +from selfprivacy_api.graphql.subscriptions.jobs import ( + job_updates as job_update_generator, +) +from selfprivacy_api.graphql.subscriptions.logs import log_stream + from selfprivacy_api.graphql.mutations.users_mutations import UsersMutations from selfprivacy_api.graphql.queries.users import Users from selfprivacy_api.jobs.test import test_job @@ -48,6 +57,11 @@ class Query: """System queries""" return System() + @strawberry.field(permission_classes=[IsAuthenticated]) + def logs(self) -> Logs: + """Log queries""" + return Logs() + @strawberry.field(permission_classes=[IsAuthenticated]) def users(self) -> Users: """Users queries""" @@ -135,19 +149,42 @@ class Mutation( code=200, ) - pass + +# A cruft for Websockets +def authenticated(info: Info) -> bool: + return IsAuthenticated().has_permission(source=None, info=info) + + +def reject_if_unauthenticated(info: Info): + if not authenticated(info): + raise Exception(IsAuthenticated().message) @strawberry.type class Subscription: - """Root schema for subscriptions""" + """Root schema for subscriptions. + Every field here should be an AsyncIterator or AsyncGenerator + It is not a part of the spec but graphql-core (dep of strawberryql) + demands it while the spec is vague in this area.""" - @strawberry.subscription(permission_classes=[IsAuthenticated]) - async def count(self, target: int = 100) -> AsyncGenerator[int, None]: - for i in range(target): + @strawberry.subscription + async def job_updates(self, info: Info) -> AsyncGenerator[List[ApiJob], None]: + reject_if_unauthenticated(info) + return job_update_generator() + + @strawberry.subscription + # Used for testing, consider deletion to shrink attack surface + async def count(self, info: Info) -> AsyncGenerator[int, None]: + reject_if_unauthenticated(info) + for i in range(10): yield i await asyncio.sleep(0.5) + @strawberry.subscription + async def log_entries(self, info: Info) -> AsyncGenerator[LogEntry, None]: + reject_if_unauthenticated(info) + return log_stream() + schema = strawberry.Schema( query=Query, diff --git a/selfprivacy_api/graphql/subscriptions/__init__.py b/selfprivacy_api/graphql/subscriptions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/selfprivacy_api/graphql/subscriptions/jobs.py b/selfprivacy_api/graphql/subscriptions/jobs.py new file mode 100644 index 0000000..11d6263 --- /dev/null +++ b/selfprivacy_api/graphql/subscriptions/jobs.py @@ -0,0 +1,14 @@ +# pylint: disable=too-few-public-methods + +from typing import AsyncGenerator, List + +from selfprivacy_api.jobs import job_notifications + +from selfprivacy_api.graphql.common_types.jobs import ApiJob +from selfprivacy_api.graphql.queries.jobs import get_all_jobs + + +async def job_updates() -> AsyncGenerator[List[ApiJob], None]: + # Send the complete list of jobs every time anything gets updated + async for notification in job_notifications(): + yield get_all_jobs() diff --git a/selfprivacy_api/graphql/subscriptions/logs.py b/selfprivacy_api/graphql/subscriptions/logs.py new file mode 100644 index 0000000..1e47dba --- /dev/null +++ b/selfprivacy_api/graphql/subscriptions/logs.py @@ -0,0 +1,31 @@ +from typing import AsyncGenerator +from systemd import journal +import asyncio + +from selfprivacy_api.graphql.queries.logs import LogEntry + + +async def log_stream() -> AsyncGenerator[LogEntry, None]: + j = journal.Reader() + + j.seek_tail() + j.get_previous() + + queue = asyncio.Queue() + + async def callback(): + if j.process() != journal.APPEND: + return + for entry in j: + await queue.put(entry) + + asyncio.get_event_loop().add_reader(j, lambda: asyncio.ensure_future(callback())) + + while True: + entry = await queue.get() + try: + yield LogEntry(entry) + except Exception: + asyncio.get_event_loop().remove_reader(j) + return + queue.task_done() diff --git a/selfprivacy_api/jobs/__init__.py b/selfprivacy_api/jobs/__init__.py index 4649bb0..3dd48c4 100644 --- a/selfprivacy_api/jobs/__init__.py +++ b/selfprivacy_api/jobs/__init__.py @@ -15,6 +15,7 @@ A job is a dictionary with the following keys: - result: result of the job """ import typing +import asyncio import datetime from uuid import UUID import uuid @@ -23,6 +24,7 @@ from enum import Enum from pydantic import BaseModel from selfprivacy_api.utils.redis_pool import RedisPool +from selfprivacy_api.utils.redis_model_storage import store_model_as_hash JOB_EXPIRATION_SECONDS = 10 * 24 * 60 * 60 # ten days @@ -102,7 +104,7 @@ class Jobs: result=None, ) redis = RedisPool().get_connection() - _store_job_as_hash(redis, _redis_key_from_uuid(job.uid), job) + store_model_as_hash(redis, _redis_key_from_uuid(job.uid), job) return job @staticmethod @@ -218,7 +220,7 @@ class Jobs: redis = RedisPool().get_connection() key = _redis_key_from_uuid(job.uid) if redis.exists(key): - _store_job_as_hash(redis, key, job) + store_model_as_hash(redis, key, job) if status in (JobStatus.FINISHED, JobStatus.ERROR): redis.expire(key, JOB_EXPIRATION_SECONDS) @@ -294,17 +296,6 @@ def _progress_log_key_from_uuid(uuid_string) -> str: return PROGRESS_LOGS_PREFIX + str(uuid_string) -def _store_job_as_hash(redis, redis_key, model) -> None: - for key, value in model.dict().items(): - if isinstance(value, uuid.UUID): - value = str(value) - if isinstance(value, datetime.datetime): - value = value.isoformat() - if isinstance(value, JobStatus): - value = value.value - redis.hset(redis_key, key, str(value)) - - def _job_from_hash(redis, redis_key) -> typing.Optional[Job]: if redis.exists(redis_key): job_dict = redis.hgetall(redis_key) @@ -321,3 +312,15 @@ def _job_from_hash(redis, redis_key) -> typing.Optional[Job]: return Job(**job_dict) return None + + +async def job_notifications() -> typing.AsyncGenerator[dict, None]: + channel = await RedisPool().subscribe_to_keys("jobs:*") + while True: + try: + # we cannot timeout here because we do not know when the next message is supposed to arrive + message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore + if message is not None: + yield message + except GeneratorExit: + break diff --git a/selfprivacy_api/jobs/migrate_to_binds.py b/selfprivacy_api/jobs/migrate_to_binds.py index 3250c9a..782b361 100644 --- a/selfprivacy_api/jobs/migrate_to_binds.py +++ b/selfprivacy_api/jobs/migrate_to_binds.py @@ -6,7 +6,7 @@ import shutil from pydantic import BaseModel from selfprivacy_api.jobs import Job, JobStatus, Jobs from selfprivacy_api.services.bitwarden import Bitwarden -from selfprivacy_api.services.gitea import Gitea +from selfprivacy_api.services.forgejo import Forgejo from selfprivacy_api.services.mailserver import MailServer from selfprivacy_api.services.nextcloud import Nextcloud from selfprivacy_api.services.pleroma import Pleroma @@ -230,7 +230,7 @@ def migrate_to_binds(config: BindMigrationConfig, job: Job): status_text="Migrating Gitea.", ) - Gitea().stop() + Forgejo().stop() if not pathlib.Path("/volumes/sda1/gitea").exists(): if not pathlib.Path("/volumes/sdb/gitea").exists(): @@ -241,7 +241,7 @@ def migrate_to_binds(config: BindMigrationConfig, job: Job): group="gitea", ) - Gitea().start() + Forgejo().start() # Perform migration of Mail server diff --git a/selfprivacy_api/migrations/__init__.py b/selfprivacy_api/migrations/__init__.py index de426d6..5eb3194 100644 --- a/selfprivacy_api/migrations/__init__.py +++ b/selfprivacy_api/migrations/__init__.py @@ -14,12 +14,14 @@ from selfprivacy_api.migrations.write_token_to_redis import WriteTokenToRedis from selfprivacy_api.migrations.check_for_system_rebuild_jobs import ( CheckForSystemRebuildJobs, ) +from selfprivacy_api.migrations.add_roundcube import AddRoundcube from selfprivacy_api.migrations.add_prometheus import AddPrometheus migrations = [ WriteTokenToRedis(), CheckForSystemRebuildJobs(), AddPrometheus(), + AddRoundcube(), ] diff --git a/selfprivacy_api/migrations/add_roundcube.py b/selfprivacy_api/migrations/add_roundcube.py new file mode 100644 index 0000000..3c422c2 --- /dev/null +++ b/selfprivacy_api/migrations/add_roundcube.py @@ -0,0 +1,36 @@ +from selfprivacy_api.migrations.migration import Migration + +from selfprivacy_api.services.flake_service_manager import FlakeServiceManager +from selfprivacy_api.utils import ReadUserData, WriteUserData + + +class AddRoundcube(Migration): + """Adds the Roundcube if it is not present.""" + + def get_migration_name(self) -> str: + return "add_roundcube" + + def get_migration_description(self) -> str: + return "Adds the Roundcube if it is not present." + + def is_migration_needed(self) -> bool: + with FlakeServiceManager() as manager: + if "roundcube" not in manager.services: + return True + with ReadUserData() as data: + if "roundcube" not in data["modules"]: + return True + return False + + def migrate(self) -> None: + with FlakeServiceManager() as manager: + if "roundcube" not in manager.services: + manager.services[ + "roundcube" + ] = "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/roundcube" + with WriteUserData() as data: + if "roundcube" not in data["modules"]: + data["modules"]["roundcube"] = { + "enable": False, + "subdomain": "roundcube", + } diff --git a/selfprivacy_api/migrations/check_for_system_rebuild_jobs.py b/selfprivacy_api/migrations/check_for_system_rebuild_jobs.py index 9bbac8a..bb8eb74 100644 --- a/selfprivacy_api/migrations/check_for_system_rebuild_jobs.py +++ b/selfprivacy_api/migrations/check_for_system_rebuild_jobs.py @@ -5,13 +5,13 @@ from selfprivacy_api.jobs import JobStatus, Jobs class CheckForSystemRebuildJobs(Migration): """Check if there are unfinished system rebuild jobs and finish them""" - def get_migration_name(self): + def get_migration_name(self) -> str: return "check_for_system_rebuild_jobs" - def get_migration_description(self): + def get_migration_description(self) -> str: return "Check if there are unfinished system rebuild jobs and finish them" - def is_migration_needed(self): + def is_migration_needed(self) -> bool: # Check if there are any unfinished system rebuild jobs for job in Jobs.get_jobs(): if ( @@ -25,8 +25,9 @@ class CheckForSystemRebuildJobs(Migration): JobStatus.RUNNING, ]: return True + return False - def migrate(self): + def migrate(self) -> None: # As the API is restarted, we assume that the jobs are finished for job in Jobs.get_jobs(): if ( diff --git a/selfprivacy_api/migrations/migration.py b/selfprivacy_api/migrations/migration.py index 1116672..8eb047d 100644 --- a/selfprivacy_api/migrations/migration.py +++ b/selfprivacy_api/migrations/migration.py @@ -12,17 +12,17 @@ class Migration(ABC): """ @abstractmethod - def get_migration_name(self): + def get_migration_name(self) -> str: pass @abstractmethod - def get_migration_description(self): + def get_migration_description(self) -> str: pass @abstractmethod - def is_migration_needed(self): + def is_migration_needed(self) -> bool: pass @abstractmethod - def migrate(self): + def migrate(self) -> None: pass diff --git a/selfprivacy_api/migrations/write_token_to_redis.py b/selfprivacy_api/migrations/write_token_to_redis.py index aab4f72..ccf1c04 100644 --- a/selfprivacy_api/migrations/write_token_to_redis.py +++ b/selfprivacy_api/migrations/write_token_to_redis.py @@ -15,10 +15,10 @@ from selfprivacy_api.utils import ReadUserData, UserDataFiles class WriteTokenToRedis(Migration): """Load Json tokens into Redis""" - def get_migration_name(self): + def get_migration_name(self) -> str: return "write_token_to_redis" - def get_migration_description(self): + def get_migration_description(self) -> str: return "Loads the initial token into redis token storage" def is_repo_empty(self, repo: AbstractTokensRepository) -> bool: @@ -38,7 +38,7 @@ class WriteTokenToRedis(Migration): print(e) return None - def is_migration_needed(self): + def is_migration_needed(self) -> bool: try: if self.get_token_from_json() is not None and self.is_repo_empty( RedisTokensRepository() @@ -47,8 +47,9 @@ class WriteTokenToRedis(Migration): except Exception as e: print(e) return False + return False - def migrate(self): + def migrate(self) -> None: # Write info about providers to userdata.json try: token = self.get_token_from_json() diff --git a/selfprivacy_api/services/__init__.py b/selfprivacy_api/services/__init__.py index f9dfac2..5a2414c 100644 --- a/selfprivacy_api/services/__init__.py +++ b/selfprivacy_api/services/__init__.py @@ -2,8 +2,9 @@ import typing from selfprivacy_api.services.bitwarden import Bitwarden -from selfprivacy_api.services.gitea import Gitea +from selfprivacy_api.services.forgejo import Forgejo from selfprivacy_api.services.jitsimeet import JitsiMeet +from selfprivacy_api.services.roundcube import Roundcube from selfprivacy_api.services.mailserver import MailServer from selfprivacy_api.services.nextcloud import Nextcloud from selfprivacy_api.services.pleroma import Pleroma @@ -13,12 +14,13 @@ import selfprivacy_api.utils.network as network_utils services: list[Service] = [ Bitwarden(), - Gitea(), + Forgejo(), MailServer(), Nextcloud(), Pleroma(), Ocserv(), JitsiMeet(), + Roundcube(), ] diff --git a/selfprivacy_api/services/bitwarden/__init__.py b/selfprivacy_api/services/bitwarden/__init__.py index 52f1466..56ee6e5 100644 --- a/selfprivacy_api/services/bitwarden/__init__.py +++ b/selfprivacy_api/services/bitwarden/__init__.py @@ -37,14 +37,14 @@ class Bitwarden(Service): def get_user() -> str: return "vaultwarden" - @staticmethod - def get_url() -> Optional[str]: + @classmethod + def get_url(cls) -> Optional[str]: """Return service url.""" domain = get_domain() return f"https://password.{domain}" - @staticmethod - def get_subdomain() -> Optional[str]: + @classmethod + def get_subdomain(cls) -> Optional[str]: return "password" @staticmethod diff --git a/selfprivacy_api/services/flake_service_manager.py b/selfprivacy_api/services/flake_service_manager.py new file mode 100644 index 0000000..8b76e5d --- /dev/null +++ b/selfprivacy_api/services/flake_service_manager.py @@ -0,0 +1,53 @@ +import re +from typing import Tuple, Optional + +FLAKE_CONFIG_PATH = "/etc/nixos/sp-modules/flake.nix" + + +class FlakeServiceManager: + def __enter__(self) -> "FlakeServiceManager": + self.services = {} + + with open(FLAKE_CONFIG_PATH, "r") as file: + for line in file: + service_name, url = self._extract_services(input_string=line) + if service_name and url: + self.services[service_name] = url + + return self + + def _extract_services( + self, input_string: str + ) -> Tuple[Optional[str], Optional[str]]: + pattern = r"inputs\.([\w-]+)\.url\s*=\s*([\S]+);" + match = re.search(pattern, input_string) + + if match: + variable_name = match.group(1) + url = match.group(2) + return variable_name, url + else: + return None, None + + def __exit__(self, exc_type, exc_value, traceback) -> None: + with open(FLAKE_CONFIG_PATH, "w") as file: + file.write( + """ +{ + description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";\n +""" + ) + + for key, value in self.services.items(): + file.write( + f""" + inputs.{key}.url = {value}; +""" + ) + + file.write( + """ + outputs = _: { }; +} +""" + ) diff --git a/selfprivacy_api/services/gitea/__init__.py b/selfprivacy_api/services/forgejo/__init__.py similarity index 68% rename from selfprivacy_api/services/gitea/__init__.py rename to selfprivacy_api/services/forgejo/__init__.py index 311d59e..06cf614 100644 --- a/selfprivacy_api/services/gitea/__init__.py +++ b/selfprivacy_api/services/forgejo/__init__.py @@ -7,40 +7,43 @@ from selfprivacy_api.utils import get_domain from selfprivacy_api.utils.systemd import get_service_status from selfprivacy_api.services.service import Service, ServiceStatus -from selfprivacy_api.services.gitea.icon import GITEA_ICON +from selfprivacy_api.services.forgejo.icon import FORGEJO_ICON -class Gitea(Service): - """Class representing Gitea service""" +class Forgejo(Service): + """Class representing Forgejo service. + + Previously was Gitea, so some IDs are still called gitea for compatibility. + """ @staticmethod def get_id() -> str: - """Return service id.""" + """Return service id. For compatibility keep in gitea.""" return "gitea" @staticmethod def get_display_name() -> str: """Return service display name.""" - return "Gitea" + return "Forgejo" @staticmethod def get_description() -> str: """Return service description.""" - return "Gitea is a Git forge." + return "Forgejo is a Git forge." @staticmethod def get_svg_icon() -> str: """Read SVG icon from file and return it as base64 encoded string.""" - return base64.b64encode(GITEA_ICON.encode("utf-8")).decode("utf-8") + return base64.b64encode(FORGEJO_ICON.encode("utf-8")).decode("utf-8") - @staticmethod - def get_url() -> Optional[str]: + @classmethod + def get_url(cls) -> Optional[str]: """Return service url.""" domain = get_domain() return f"https://git.{domain}" - @staticmethod - def get_subdomain() -> Optional[str]: + @classmethod + def get_subdomain(cls) -> Optional[str]: return "git" @staticmethod @@ -65,19 +68,19 @@ class Gitea(Service): Return code 3 means service is stopped. Return code 4 means service is off. """ - return get_service_status("gitea.service") + return get_service_status("forgejo.service") @staticmethod def stop(): - subprocess.run(["systemctl", "stop", "gitea.service"]) + subprocess.run(["systemctl", "stop", "forgejo.service"]) @staticmethod def start(): - subprocess.run(["systemctl", "start", "gitea.service"]) + subprocess.run(["systemctl", "start", "forgejo.service"]) @staticmethod def restart(): - subprocess.run(["systemctl", "restart", "gitea.service"]) + subprocess.run(["systemctl", "restart", "forgejo.service"]) @staticmethod def get_configuration(): @@ -93,4 +96,5 @@ class Gitea(Service): @staticmethod def get_folders() -> List[str]: + """The data folder is still called gitea for compatibility.""" return ["/var/lib/gitea"] diff --git a/selfprivacy_api/services/gitea/gitea.svg b/selfprivacy_api/services/forgejo/gitea.svg similarity index 100% rename from selfprivacy_api/services/gitea/gitea.svg rename to selfprivacy_api/services/forgejo/gitea.svg diff --git a/selfprivacy_api/services/gitea/icon.py b/selfprivacy_api/services/forgejo/icon.py similarity index 98% rename from selfprivacy_api/services/gitea/icon.py rename to selfprivacy_api/services/forgejo/icon.py index 569f96a..5e600cf 100644 --- a/selfprivacy_api/services/gitea/icon.py +++ b/selfprivacy_api/services/forgejo/icon.py @@ -1,4 +1,4 @@ -GITEA_ICON = """ +FORGEJO_ICON = """ diff --git a/selfprivacy_api/services/jitsimeet/__init__.py b/selfprivacy_api/services/jitsimeet/__init__.py index 53d572c..27a497a 100644 --- a/selfprivacy_api/services/jitsimeet/__init__.py +++ b/selfprivacy_api/services/jitsimeet/__init__.py @@ -36,14 +36,14 @@ class JitsiMeet(Service): """Read SVG icon from file and return it as base64 encoded string.""" return base64.b64encode(JITSI_ICON.encode("utf-8")).decode("utf-8") - @staticmethod - def get_url() -> Optional[str]: + @classmethod + def get_url(cls) -> Optional[str]: """Return service url.""" domain = get_domain() return f"https://meet.{domain}" - @staticmethod - def get_subdomain() -> Optional[str]: + @classmethod + def get_subdomain(cls) -> Optional[str]: return "meet" @staticmethod diff --git a/selfprivacy_api/services/mailserver/__init__.py b/selfprivacy_api/services/mailserver/__init__.py index d2e9b5d..aba302d 100644 --- a/selfprivacy_api/services/mailserver/__init__.py +++ b/selfprivacy_api/services/mailserver/__init__.py @@ -35,13 +35,13 @@ class MailServer(Service): def get_user() -> str: return "virtualMail" - @staticmethod - def get_url() -> Optional[str]: + @classmethod + def get_url(cls) -> Optional[str]: """Return service url.""" return None - @staticmethod - def get_subdomain() -> Optional[str]: + @classmethod + def get_subdomain(cls) -> Optional[str]: return None @staticmethod diff --git a/selfprivacy_api/services/nextcloud/__init__.py b/selfprivacy_api/services/nextcloud/__init__.py index 3e5b8d3..275b11d 100644 --- a/selfprivacy_api/services/nextcloud/__init__.py +++ b/selfprivacy_api/services/nextcloud/__init__.py @@ -4,7 +4,6 @@ import subprocess from typing import Optional, List from selfprivacy_api.utils import get_domain -from selfprivacy_api.jobs import Job, Jobs from selfprivacy_api.utils.systemd import get_service_status from selfprivacy_api.services.service import Service, ServiceStatus @@ -35,14 +34,14 @@ class Nextcloud(Service): """Read SVG icon from file and return it as base64 encoded string.""" return base64.b64encode(NEXTCLOUD_ICON.encode("utf-8")).decode("utf-8") - @staticmethod - def get_url() -> Optional[str]: + @classmethod + def get_url(cls) -> Optional[str]: """Return service url.""" domain = get_domain() return f"https://cloud.{domain}" - @staticmethod - def get_subdomain() -> Optional[str]: + @classmethod + def get_subdomain(cls) -> Optional[str]: return "cloud" @staticmethod diff --git a/selfprivacy_api/services/ocserv/__init__.py b/selfprivacy_api/services/ocserv/__init__.py index 4dd802f..f600772 100644 --- a/selfprivacy_api/services/ocserv/__init__.py +++ b/selfprivacy_api/services/ocserv/__init__.py @@ -28,13 +28,13 @@ class Ocserv(Service): def get_svg_icon() -> str: return base64.b64encode(OCSERV_ICON.encode("utf-8")).decode("utf-8") - @staticmethod - def get_url() -> typing.Optional[str]: + @classmethod + def get_url(cls) -> typing.Optional[str]: """Return service url.""" return None - @staticmethod - def get_subdomain() -> typing.Optional[str]: + @classmethod + def get_subdomain(cls) -> typing.Optional[str]: return "vpn" @staticmethod diff --git a/selfprivacy_api/services/pleroma/__init__.py b/selfprivacy_api/services/pleroma/__init__.py index 44a9be8..64edd96 100644 --- a/selfprivacy_api/services/pleroma/__init__.py +++ b/selfprivacy_api/services/pleroma/__init__.py @@ -31,14 +31,14 @@ class Pleroma(Service): def get_svg_icon() -> str: return base64.b64encode(PLEROMA_ICON.encode("utf-8")).decode("utf-8") - @staticmethod - def get_url() -> Optional[str]: + @classmethod + def get_url(cls) -> Optional[str]: """Return service url.""" domain = get_domain() return f"https://social.{domain}" - @staticmethod - def get_subdomain() -> Optional[str]: + @classmethod + def get_subdomain(cls) -> Optional[str]: return "social" @staticmethod diff --git a/selfprivacy_api/services/roundcube/__init__.py b/selfprivacy_api/services/roundcube/__init__.py new file mode 100644 index 0000000..22604f5 --- /dev/null +++ b/selfprivacy_api/services/roundcube/__init__.py @@ -0,0 +1,113 @@ +"""Class representing Roundcube service""" + +import base64 +import subprocess +from typing import List, Optional + +from selfprivacy_api.jobs import Job +from selfprivacy_api.utils.systemd import ( + get_service_status_from_several_units, +) +from selfprivacy_api.services.service import Service, ServiceStatus +from selfprivacy_api.utils import ReadUserData, get_domain +from selfprivacy_api.utils.block_devices import BlockDevice +from selfprivacy_api.services.roundcube.icon import ROUNDCUBE_ICON + + +class Roundcube(Service): + """Class representing roundcube service""" + + @staticmethod + def get_id() -> str: + """Return service id.""" + return "roundcube" + + @staticmethod + def get_display_name() -> str: + """Return service display name.""" + return "Roundcube" + + @staticmethod + def get_description() -> str: + """Return service description.""" + return "Roundcube is an open source webmail software." + + @staticmethod + def get_svg_icon() -> str: + """Read SVG icon from file and return it as base64 encoded string.""" + return base64.b64encode(ROUNDCUBE_ICON.encode("utf-8")).decode("utf-8") + + @classmethod + def get_url(cls) -> Optional[str]: + """Return service url.""" + domain = get_domain() + subdomain = cls.get_subdomain() + return f"https://{subdomain}.{domain}" + + @classmethod + def get_subdomain(cls) -> Optional[str]: + with ReadUserData() as data: + if "roundcube" in data["modules"]: + return data["modules"]["roundcube"]["subdomain"] + + return "roundcube" + + @staticmethod + def is_movable() -> bool: + return False + + @staticmethod + def is_required() -> bool: + return False + + @staticmethod + def can_be_backed_up() -> bool: + return False + + @staticmethod + def get_backup_description() -> str: + return "Nothing to backup." + + @staticmethod + def get_status() -> ServiceStatus: + return get_service_status_from_several_units(["phpfpm-roundcube.service"]) + + @staticmethod + def stop(): + subprocess.run( + ["systemctl", "stop", "phpfpm-roundcube.service"], + check=False, + ) + + @staticmethod + def start(): + subprocess.run( + ["systemctl", "start", "phpfpm-roundcube.service"], + check=False, + ) + + @staticmethod + def restart(): + subprocess.run( + ["systemctl", "restart", "phpfpm-roundcube.service"], + check=False, + ) + + @staticmethod + def get_configuration(): + return {} + + @staticmethod + def set_configuration(config_items): + return super().set_configuration(config_items) + + @staticmethod + def get_logs(): + return "" + + @staticmethod + def get_folders() -> List[str]: + return [] + + def move_to_volume(self, volume: BlockDevice) -> Job: + raise NotImplementedError("roundcube service is not movable") diff --git a/selfprivacy_api/services/roundcube/icon.py b/selfprivacy_api/services/roundcube/icon.py new file mode 100644 index 0000000..4a08207 --- /dev/null +++ b/selfprivacy_api/services/roundcube/icon.py @@ -0,0 +1,7 @@ +ROUNDCUBE_ICON = """ + + + + + +""" diff --git a/selfprivacy_api/services/service.py b/selfprivacy_api/services/service.py index 64a1e80..6e3decf 100644 --- a/selfprivacy_api/services/service.py +++ b/selfprivacy_api/services/service.py @@ -65,17 +65,17 @@ class Service(ABC): """ pass - @staticmethod + @classmethod @abstractmethod - def get_url() -> Optional[str]: + def get_url(cls) -> Optional[str]: """ The url of the service if it is accessible from the internet browser. """ pass - @staticmethod + @classmethod @abstractmethod - def get_subdomain() -> Optional[str]: + def get_subdomain(cls) -> Optional[str]: """ The assigned primary subdomain for this service. """ diff --git a/selfprivacy_api/services/test_service/__init__.py b/selfprivacy_api/services/test_service/__init__.py index caf4666..de3c493 100644 --- a/selfprivacy_api/services/test_service/__init__.py +++ b/selfprivacy_api/services/test_service/__init__.py @@ -57,14 +57,14 @@ class DummyService(Service): # return "" return base64.b64encode(BITWARDEN_ICON.encode("utf-8")).decode("utf-8") - @staticmethod - def get_url() -> typing.Optional[str]: + @classmethod + def get_url(cls) -> typing.Optional[str]: """Return service url.""" domain = "test.com" return f"https://password.{domain}" - @staticmethod - def get_subdomain() -> typing.Optional[str]: + @classmethod + def get_subdomain(cls) -> typing.Optional[str]: return "password" @classmethod diff --git a/selfprivacy_api/utils/block_devices.py b/selfprivacy_api/utils/block_devices.py index 4de5b75..0db8fe0 100644 --- a/selfprivacy_api/utils/block_devices.py +++ b/selfprivacy_api/utils/block_devices.py @@ -90,6 +90,14 @@ class BlockDevice: def __hash__(self): return hash(self.name) + def get_display_name(self) -> str: + if self.is_root(): + return "System disk" + elif self.model == "Volume": + return "Expandable volume" + else: + return self.name + def is_root(self) -> bool: """ Return True if the block device is the root device. diff --git a/selfprivacy_api/utils/redis_model_storage.py b/selfprivacy_api/utils/redis_model_storage.py index 06dfe8c..7d84210 100644 --- a/selfprivacy_api/utils/redis_model_storage.py +++ b/selfprivacy_api/utils/redis_model_storage.py @@ -1,15 +1,23 @@ +import uuid + from datetime import datetime from typing import Optional from enum import Enum def store_model_as_hash(redis, redis_key, model): - for key, value in model.dict().items(): + model_dict = model.dict() + for key, value in model_dict.items(): + if isinstance(value, uuid.UUID): + value = str(value) if isinstance(value, datetime): value = value.isoformat() if isinstance(value, Enum): value = value.value - redis.hset(redis_key, key, str(value)) + value = str(value) + model_dict[key] = value + + redis.hset(redis_key, mapping=model_dict) def hash_as_model(redis, redis_key: str, model_class): diff --git a/selfprivacy_api/utils/redis_pool.py b/selfprivacy_api/utils/redis_pool.py index 3d35f01..ea827d1 100644 --- a/selfprivacy_api/utils/redis_pool.py +++ b/selfprivacy_api/utils/redis_pool.py @@ -2,23 +2,33 @@ Redis pool module for selfprivacy_api """ import redis +import redis.asyncio as redis_async from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass REDIS_SOCKET = "/run/redis-sp-api/redis.sock" -class RedisPool(metaclass=SingletonMetaclass): +# class RedisPool(metaclass=SingletonMetaclass): +class RedisPool: """ Redis connection pool singleton. """ def __init__(self): + self._dbnumber = 0 + url = RedisPool.connection_url(dbnumber=self._dbnumber) + # We need a normal sync pool because otherwise + # our whole API will need to be async self._pool = redis.ConnectionPool.from_url( - RedisPool.connection_url(dbnumber=0), + url, + decode_responses=True, + ) + # We need an async pool for pubsub + self._async_pool = redis_async.ConnectionPool.from_url( + url, decode_responses=True, ) - self._pubsub_connection = self.get_connection() @staticmethod def connection_url(dbnumber: int) -> str: @@ -34,8 +44,15 @@ class RedisPool(metaclass=SingletonMetaclass): """ return redis.Redis(connection_pool=self._pool) - def get_pubsub(self): + def get_connection_async(self) -> redis_async.Redis: """ - Get a pubsub connection from the pool. + Get an async connection from the pool. + Async connections allow pubsub. """ - return self._pubsub_connection.pubsub() + return redis_async.Redis(connection_pool=self._async_pool) + + async def subscribe_to_keys(self, pattern: str) -> redis_async.client.PubSub: + async_redis = self.get_connection_async() + pubsub = async_redis.pubsub() + await pubsub.psubscribe(f"__keyspace@{self._dbnumber}__:" + pattern) + return pubsub diff --git a/selfprivacy_api/utils/systemd_journal.py b/selfprivacy_api/utils/systemd_journal.py new file mode 100644 index 0000000..48e97b8 --- /dev/null +++ b/selfprivacy_api/utils/systemd_journal.py @@ -0,0 +1,55 @@ +import typing +from systemd import journal + + +def get_events_from_journal( + j: journal.Reader, limit: int, next: typing.Callable[[journal.Reader], typing.Dict] +): + events = [] + i = 0 + while i < limit: + entry = next(j) + if entry is None or entry == dict(): + break + if entry["MESSAGE"] != "": + events.append(entry) + i += 1 + + return events + + +def get_paginated_logs( + limit: int = 20, + # All entries returned will be lesser than this cursor. Sets upper bound on results. + up_cursor: str | None = None, + # All entries returned will be greater than this cursor. Sets lower bound on results. + down_cursor: str | None = None, +): + j = journal.Reader() + + if up_cursor is None and down_cursor is None: + j.seek_tail() + + events = get_events_from_journal(j, limit, lambda j: j.get_previous()) + events.reverse() + + return events + elif up_cursor is None and down_cursor is not None: + j.seek_cursor(down_cursor) + j.get_previous() # pagination is exclusive + + events = get_events_from_journal(j, limit, lambda j: j.get_previous()) + events.reverse() + + return events + elif up_cursor is not None and down_cursor is None: + j.seek_cursor(up_cursor) + j.get_next() # pagination is exclusive + + events = get_events_from_journal(j, limit, lambda j: j.get_next()) + + return events + else: + raise NotImplementedError( + "Pagination by both up_cursor and down_cursor is not implemented" + ) diff --git a/setup.py b/setup.py index 473ece8..aaf333e 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name="selfprivacy_api", - version="3.2.1", + version="3.3.0", packages=find_packages(), scripts=[ "selfprivacy_api/app.py", diff --git a/tests/common.py b/tests/common.py index 5f69f3f..1de3893 100644 --- a/tests/common.py +++ b/tests/common.py @@ -69,10 +69,22 @@ def generate_backup_query(query_array): return "query TestBackup {\n backup {" + "\n".join(query_array) + "}\n}" +def generate_jobs_query(query_array): + return "query TestJobs {\n jobs {" + "\n".join(query_array) + "}\n}" + + +def generate_jobs_subscription(query_array): + return "subscription TestSubscription {\n jobs {" + "\n".join(query_array) + "}\n}" + + def generate_service_query(query_array): return "query TestService {\n services {" + "\n".join(query_array) + "}\n}" +def generate_logs_query(query_array): + return "query TestService {\n logs {" + "\n".join(query_array) + "}\n}" + + def mnemonic_to_hex(mnemonic): return Mnemonic(language="english").to_entropy(mnemonic).hex() diff --git a/tests/data/turned_on.json b/tests/data/turned_on.json index 0bcc2f0..9c285b1 100644 --- a/tests/data/turned_on.json +++ b/tests/data/turned_on.json @@ -62,6 +62,9 @@ "simple-nixos-mailserver": { "enable": true, "location": "sdb" + }, + "roundcube": { + "enable": true } }, "volumes": [ diff --git a/tests/test_flake_services_manager.py b/tests/test_flake_services_manager.py new file mode 100644 index 0000000..93c6e1d --- /dev/null +++ b/tests/test_flake_services_manager.py @@ -0,0 +1,127 @@ +import pytest + +from selfprivacy_api.services.flake_service_manager import FlakeServiceManager + +all_services_file = """ +{ + description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc"; + + + inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden; + + inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea; + + inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet; + + inputs.nextcloud.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/nextcloud; + + inputs.ocserv.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/ocserv; + + inputs.pleroma.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/pleroma; + + inputs.simple-nixos-mailserver.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/simple-nixos-mailserver; + + outputs = _: { }; +} +""" + + +some_services_file = """ +{ + description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc"; + + + inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden; + + inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea; + + inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet; + + outputs = _: { }; +} +""" + + +@pytest.fixture +def some_services_flake_mock(mocker, datadir): + flake_config_path = datadir / "some_services.nix" + mocker.patch( + "selfprivacy_api.services.flake_service_manager.FLAKE_CONFIG_PATH", + new=flake_config_path, + ) + return flake_config_path + + +@pytest.fixture +def no_services_flake_mock(mocker, datadir): + flake_config_path = datadir / "no_services.nix" + mocker.patch( + "selfprivacy_api.services.flake_service_manager.FLAKE_CONFIG_PATH", + new=flake_config_path, + ) + return flake_config_path + + +# --- + + +def test_read_services_list(some_services_flake_mock): + with FlakeServiceManager() as manager: + services = { + "bitwarden": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden", + "gitea": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea", + "jitsi-meet": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet", + } + assert manager.services == services + + +def test_change_services_list(some_services_flake_mock): + services = { + "bitwarden": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden", + "gitea": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea", + "jitsi-meet": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet", + "nextcloud": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/nextcloud", + "ocserv": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/ocserv", + "pleroma": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/pleroma", + "simple-nixos-mailserver": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/simple-nixos-mailserver", + } + + with FlakeServiceManager() as manager: + manager.services = services + + with FlakeServiceManager() as manager: + assert manager.services == services + + with open(some_services_flake_mock, "r", encoding="utf-8") as file: + file_content = file.read().strip() + + assert all_services_file.strip() == file_content + + +def test_read_empty_services_list(no_services_flake_mock): + with FlakeServiceManager() as manager: + services = {} + assert manager.services == services + + +def test_change_empty_services_list(no_services_flake_mock): + services = { + "bitwarden": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden", + "gitea": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea", + "jitsi-meet": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet", + "nextcloud": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/nextcloud", + "ocserv": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/ocserv", + "pleroma": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/pleroma", + "simple-nixos-mailserver": "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/simple-nixos-mailserver", + } + + with FlakeServiceManager() as manager: + manager.services = services + + with FlakeServiceManager() as manager: + assert manager.services == services + + with open(no_services_flake_mock, "r", encoding="utf-8") as file: + file_content = file.read().strip() + + assert all_services_file.strip() == file_content diff --git a/tests/test_flake_services_manager/no_services.nix b/tests/test_flake_services_manager/no_services.nix new file mode 100644 index 0000000..8588bc7 --- /dev/null +++ b/tests/test_flake_services_manager/no_services.nix @@ -0,0 +1,4 @@ +{ + description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc"; + outputs = _: { }; +} diff --git a/tests/test_flake_services_manager/some_services.nix b/tests/test_flake_services_manager/some_services.nix new file mode 100644 index 0000000..8c2e6af --- /dev/null +++ b/tests/test_flake_services_manager/some_services.nix @@ -0,0 +1,12 @@ +{ + description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc"; + + + inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden; + + inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea; + + inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet; + + outputs = _: { }; +} diff --git a/tests/test_graphql/test_api_logs.py b/tests/test_graphql/test_api_logs.py new file mode 100644 index 0000000..6875531 --- /dev/null +++ b/tests/test_graphql/test_api_logs.py @@ -0,0 +1,172 @@ +import asyncio +import pytest +from datetime import datetime +from systemd import journal + +from tests.test_graphql.test_websocket import init_graphql + + +def assert_log_entry_equals_to_journal_entry(api_entry, journal_entry): + assert api_entry["message"] == journal_entry["MESSAGE"] + assert ( + datetime.fromisoformat(api_entry["timestamp"]) + == journal_entry["__REALTIME_TIMESTAMP"] + ) + assert api_entry.get("priority") == journal_entry.get("PRIORITY") + assert api_entry.get("systemdUnit") == journal_entry.get("_SYSTEMD_UNIT") + assert api_entry.get("systemdSlice") == journal_entry.get("_SYSTEMD_SLICE") + + +def take_from_journal(j, limit, next): + entries = [] + for _ in range(0, limit): + entry = next(j) + if entry["MESSAGE"] != "": + entries.append(entry) + return entries + + +API_GET_LOGS_WITH_UP_BORDER = """ +query TestQuery($upCursor: String) { + logs { + paginated(limit: 4, upCursor: $upCursor) { + pageMeta { + upCursor + downCursor + } + entries { + message + timestamp + priority + systemdUnit + systemdSlice + } + } + } +} +""" + +API_GET_LOGS_WITH_DOWN_BORDER = """ +query TestQuery($downCursor: String) { + logs { + paginated(limit: 4, downCursor: $downCursor) { + pageMeta { + upCursor + downCursor + } + entries { + message + timestamp + priority + systemdUnit + systemdSlice + } + } + } +} +""" + + +def test_graphql_get_logs_with_up_border(authorized_client): + j = journal.Reader() + j.seek_tail() + + # < - cursor + # <- - log entry will be returned by API call. + # ... + # log < + # log <- + # log <- + # log <- + # log <- + # log + + expected_entries = take_from_journal(j, 6, lambda j: j.get_previous()) + expected_entries.reverse() + + response = authorized_client.post( + "/graphql", + json={ + "query": API_GET_LOGS_WITH_UP_BORDER, + "variables": {"upCursor": expected_entries[0]["__CURSOR"]}, + }, + ) + assert response.status_code == 200 + + expected_entries = expected_entries[1:-1] + returned_entries = response.json()["data"]["logs"]["paginated"]["entries"] + + assert len(returned_entries) == len(expected_entries) + + for api_entry, journal_entry in zip(returned_entries, expected_entries): + assert_log_entry_equals_to_journal_entry(api_entry, journal_entry) + + +def test_graphql_get_logs_with_down_border(authorized_client): + j = journal.Reader() + j.seek_head() + j.get_next() + + # < - cursor + # <- - log entry will be returned by API call. + # log + # log <- + # log <- + # log <- + # log <- + # log < + # ... + + expected_entries = take_from_journal(j, 5, lambda j: j.get_next()) + + response = authorized_client.post( + "/graphql", + json={ + "query": API_GET_LOGS_WITH_DOWN_BORDER, + "variables": {"downCursor": expected_entries[-1]["__CURSOR"]}, + }, + ) + assert response.status_code == 200 + + expected_entries = expected_entries[:-1] + returned_entries = response.json()["data"]["logs"]["paginated"]["entries"] + + assert len(returned_entries) == len(expected_entries) + + for api_entry, journal_entry in zip(returned_entries, expected_entries): + assert_log_entry_equals_to_journal_entry(api_entry, journal_entry) + + +@pytest.mark.asyncio +async def test_websocket_subscription_for_logs(authorized_client): + with authorized_client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws"] + ) as websocket: + init_graphql(websocket) + websocket.send_json( + { + "id": "3aaa2445", + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription { logEntries { message } }", + }, + } + ) + await asyncio.sleep(1) + + def read_until(message, limit=5): + i = 0 + while i < limit: + msg = websocket.receive_json()["payload"]["data"]["logEntries"][ + "message" + ] + if msg == message: + return + else: + i += 1 + continue + raise Exception("Failed to read websocket data, timeout") + + for i in range(0, 10): + journal.send(f"Lorem ipsum number {i}") + read_until(f"Lorem ipsum number {i}") diff --git a/tests/test_graphql/test_jobs.py b/tests/test_graphql/test_jobs.py new file mode 100644 index 0000000..68a6d20 --- /dev/null +++ b/tests/test_graphql/test_jobs.py @@ -0,0 +1,74 @@ +from tests.common import generate_jobs_query +import tests.test_graphql.test_api_backup + +from tests.test_graphql.common import ( + assert_ok, + assert_empty, + assert_errorcode, + get_data, +) + +from selfprivacy_api.jobs import Jobs + +API_JOBS_QUERY = """ +getJobs { + uid + typeId + name + description + status + statusText + progress + createdAt + updatedAt + finishedAt + error + result +} +""" + + +def graphql_send_query(client, query: str, variables: dict = {}): + return client.post("/graphql", json={"query": query, "variables": variables}) + + +def api_jobs(authorized_client): + response = graphql_send_query( + authorized_client, generate_jobs_query([API_JOBS_QUERY]) + ) + data = get_data(response) + result = data["jobs"]["getJobs"] + assert result is not None + return result + + +def test_all_jobs_unauthorized(client): + response = graphql_send_query(client, generate_jobs_query([API_JOBS_QUERY])) + assert_empty(response) + + +def test_all_jobs_when_none(authorized_client): + output = api_jobs(authorized_client) + assert output == [] + + +def test_all_jobs_when_some(authorized_client): + # We cannot make new jobs via API, at least directly + job = Jobs.add("bogus", "bogus.bogus", "fungus") + output = api_jobs(authorized_client) + + len(output) == 1 + api_job = output[0] + + assert api_job["uid"] == str(job.uid) + assert api_job["typeId"] == job.type_id + assert api_job["name"] == job.name + assert api_job["description"] == job.description + assert api_job["status"] == job.status + assert api_job["statusText"] == job.status_text + assert api_job["progress"] == job.progress + assert api_job["createdAt"] == job.created_at.isoformat() + assert api_job["updatedAt"] == job.updated_at.isoformat() + assert api_job["finishedAt"] == None + assert api_job["error"] == None + assert api_job["result"] == None diff --git a/tests/test_graphql/test_services.py b/tests/test_graphql/test_services.py index 6e8dcf6..b7faf3d 100644 --- a/tests/test_graphql/test_services.py +++ b/tests/test_graphql/test_services.py @@ -543,8 +543,8 @@ def test_disable_enable(authorized_client, only_dummy_service): assert api_dummy_service["status"] == ServiceStatus.ACTIVE.value -def test_move_immovable(authorized_client, only_dummy_service): - dummy_service = only_dummy_service +def test_move_immovable(authorized_client, dummy_service_with_binds): + dummy_service = dummy_service_with_binds dummy_service.set_movable(False) root = BlockDevices().get_root_block_device() mutation_response = api_move(authorized_client, dummy_service, root.name) diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py new file mode 100644 index 0000000..754fbbf --- /dev/null +++ b/tests/test_graphql/test_websocket.py @@ -0,0 +1,225 @@ +# from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions +import pytest +import asyncio +from typing import Generator +from time import sleep + +from starlette.testclient import WebSocketTestSession + +from selfprivacy_api.jobs import Jobs +from selfprivacy_api.actions.api_tokens import TOKEN_REPO +from selfprivacy_api.graphql import IsAuthenticated + +from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH +from tests.test_jobs import jobs as empty_jobs + +# We do not iterate through them yet +TESTED_SUBPROTOCOLS = ["graphql-transport-ws"] + +JOBS_SUBSCRIPTION = """ +jobUpdates { + uid + typeId + name + description + status + statusText + progress + createdAt + updatedAt + finishedAt + error + result +} +""" + + +def api_subscribe(websocket, id, subscription): + websocket.send_json( + { + "id": id, + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {" + subscription + "}", + }, + } + ) + + +def connect_ws_authenticated(authorized_client) -> WebSocketTestSession: + token = "Bearer " + str(DEVICE_WE_AUTH_TESTS_WITH["token"]) + return authorized_client.websocket_connect( + "/graphql", + subprotocols=TESTED_SUBPROTOCOLS, + params={"token": token}, + ) + + +def connect_ws_not_authenticated(client) -> WebSocketTestSession: + return client.websocket_connect( + "/graphql", + subprotocols=TESTED_SUBPROTOCOLS, + params={"token": "I like vegan icecream but it is not a valid token"}, + ) + + +def init_graphql(websocket): + websocket.send_json({"type": "connection_init", "payload": {}}) + ack = websocket.receive_json() + assert ack == {"type": "connection_ack"} + + +@pytest.fixture +def authenticated_websocket( + authorized_client, +) -> Generator[WebSocketTestSession, None, None]: + # We use authorized_client only to have token in the repo, this client by itself is not enough to authorize websocket + + ValueError(TOKEN_REPO.get_tokens()) + with connect_ws_authenticated(authorized_client) as websocket: + yield websocket + + +@pytest.fixture +def unauthenticated_websocket(client) -> Generator[WebSocketTestSession, None, None]: + with connect_ws_not_authenticated(client) as websocket: + yield websocket + + +def test_websocket_connection_bare(authorized_client): + client = authorized_client + with client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws", "graphql-ws"] + ) as websocket: + assert websocket is not None + assert websocket.scope is not None + + +def test_websocket_graphql_init(authorized_client): + client = authorized_client + with client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws"] + ) as websocket: + websocket.send_json({"type": "connection_init", "payload": {}}) + ack = websocket.receive_json() + assert ack == {"type": "connection_ack"} + + +def test_websocket_graphql_ping(authorized_client): + client = authorized_client + with client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws"] + ) as websocket: + # https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping + websocket.send_json({"type": "ping", "payload": {}}) + pong = websocket.receive_json() + assert pong == {"type": "pong"} + + +def test_websocket_subscription_minimal(authorized_client, authenticated_websocket): + # Test a small endpoint that exists specifically for tests + websocket = authenticated_websocket + init_graphql(websocket) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, "count") + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": {"data": {"count": 0}}, + "type": "next", + } + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": {"data": {"count": 1}}, + "type": "next", + } + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": {"data": {"count": 2}}, + "type": "next", + } + + +def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket): + websocket = unauthenticated_websocket + init_graphql(websocket) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, "count") + + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": [{"message": IsAuthenticated.message}], + "type": "error", + } + + +async def read_one_job(websocket): + # Bug? We only get them starting from the second job update + # That's why we receive two jobs in the list them + # The first update gets lost somewhere + response = websocket.receive_json() + return response + + +@pytest.mark.asyncio +async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs): + websocket = authenticated_websocket + init_graphql(websocket) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, JOBS_SUBSCRIPTION) + + future = asyncio.create_task(read_one_job(websocket)) + jobs = [] + jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works")) + sleep(0.5) + jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works")) + + response = await future + data = response["payload"]["data"] + jobs_received = data["jobUpdates"] + received_names = [job["name"] for job in jobs_received] + for job in jobs: + assert job.name in received_names + + assert len(jobs_received) == 2 + + for job in jobs: + for api_job in jobs_received: + if (job.name) == api_job["name"]: + assert api_job["uid"] == str(job.uid) + assert api_job["typeId"] == job.type_id + assert api_job["name"] == job.name + assert api_job["description"] == job.description + assert api_job["status"] == job.status + assert api_job["statusText"] == job.status_text + assert api_job["progress"] == job.progress + assert api_job["createdAt"] == job.created_at.isoformat() + assert api_job["updatedAt"] == job.updated_at.isoformat() + assert api_job["finishedAt"] == None + assert api_job["error"] == None + assert api_job["result"] == None + + +def test_websocket_subscription_unauthorized(unauthenticated_websocket): + websocket = unauthenticated_websocket + init_graphql(websocket) + id = "3aaa2445" + api_subscribe(websocket, id, JOBS_SUBSCRIPTION) + + response = websocket.receive_json() + # I do not really know why strawberry gives more info on this + # One versus the counter + payload = response["payload"][0] + assert isinstance(payload, dict) + assert "locations" in payload.keys() + # It looks like this 'locations': [{'column': 32, 'line': 1}] + # We cannot test locations feasibly + del payload["locations"] + assert response == { + "id": id, + "payload": [{"message": IsAuthenticated.message, "path": ["jobUpdates"]}], + "type": "error", + } diff --git a/tests/test_redis.py b/tests/test_redis.py new file mode 100644 index 0000000..02dfb21 --- /dev/null +++ b/tests/test_redis.py @@ -0,0 +1,218 @@ +import asyncio +import pytest +import pytest_asyncio +from asyncio import streams +import redis +from typing import List + +from selfprivacy_api.utils.redis_pool import RedisPool + +from selfprivacy_api.jobs import Jobs, job_notifications + +TEST_KEY = "test:test" +STOPWORD = "STOP" + + +@pytest.fixture() +def empty_redis(event_loop): + r = RedisPool().get_connection() + r.flushdb() + assert r.config_get("notify-keyspace-events")["notify-keyspace-events"] == "AKE" + yield r + r.flushdb() + + +async def write_to_test_key(): + r = RedisPool().get_connection_async() + async with r.pipeline(transaction=True) as pipe: + ok1, ok2 = await pipe.set(TEST_KEY, "value1").set(TEST_KEY, "value2").execute() + assert ok1 + assert ok2 + assert await r.get(TEST_KEY) == "value2" + await r.close() + + +def test_async_connection(empty_redis): + r = RedisPool().get_connection() + assert not r.exists(TEST_KEY) + # It _will_ report an error if it arises + asyncio.run(write_to_test_key()) + # Confirming that we can read result from sync connection too + assert r.get(TEST_KEY) == "value2" + + +async def channel_reader(channel: redis.client.PubSub) -> List[dict]: + result: List[dict] = [] + while True: + # Mypy cannot correctly detect that it is a coroutine + # But it is + message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore + if message is not None: + result.append(message) + if message["data"] == STOPWORD: + break + return result + + +async def channel_reader_onemessage(channel: redis.client.PubSub) -> dict: + while True: + # Mypy cannot correctly detect that it is a coroutine + # But it is + message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore + if message is not None: + return message + + +@pytest.mark.asyncio +async def test_pubsub(empty_redis, event_loop): + # Adapted from : + # https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html + # Sanity checking because of previous event loop bugs + assert event_loop == asyncio.get_event_loop() + assert event_loop == asyncio.events.get_event_loop() + assert event_loop == asyncio.events._get_event_loop() + assert event_loop == asyncio.events.get_running_loop() + + reader = streams.StreamReader(34) + assert event_loop == reader._loop + f = reader._loop.create_future() + f.set_result(3) + await f + + r = RedisPool().get_connection_async() + async with r.pubsub() as pubsub: + await pubsub.subscribe("channel:1") + future = asyncio.create_task(channel_reader(pubsub)) + + await r.publish("channel:1", "Hello") + # message: dict = await pubsub.get_message(ignore_subscribe_messages=True, timeout=5.0) # type: ignore + # raise ValueError(message) + await r.publish("channel:1", "World") + await r.publish("channel:1", STOPWORD) + + messages = await future + + assert len(messages) == 3 + + message = messages[0] + assert "data" in message.keys() + assert message["data"] == "Hello" + message = messages[1] + assert "data" in message.keys() + assert message["data"] == "World" + message = messages[2] + assert "data" in message.keys() + assert message["data"] == STOPWORD + + await r.close() + + +@pytest.mark.asyncio +async def test_keyspace_notifications_simple(empty_redis, event_loop): + r = RedisPool().get_connection_async() + await r.set(TEST_KEY, "I am not empty") + async with r.pubsub() as pubsub: + await pubsub.subscribe("__keyspace@0__:" + TEST_KEY) + + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + empty_redis.set(TEST_KEY, "I am set!") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message == { + "channel": f"__keyspace@0__:{TEST_KEY}", + "data": "set", + "pattern": None, + "type": "message", + } + + +@pytest.mark.asyncio +async def test_keyspace_notifications(empty_redis, event_loop): + pubsub = await RedisPool().subscribe_to_keys(TEST_KEY) + async with pubsub: + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + empty_redis.set(TEST_KEY, "I am set!") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message == { + "channel": f"__keyspace@0__:{TEST_KEY}", + "data": "set", + "pattern": f"__keyspace@0__:{TEST_KEY}", + "type": "pmessage", + } + + +@pytest.mark.asyncio +async def test_keyspace_notifications_patterns(empty_redis, event_loop): + pattern = "test*" + pubsub = await RedisPool().subscribe_to_keys(pattern) + async with pubsub: + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + empty_redis.set(TEST_KEY, "I am set!") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message == { + "channel": f"__keyspace@0__:{TEST_KEY}", + "data": "set", + "pattern": f"__keyspace@0__:{pattern}", + "type": "pmessage", + } + + +@pytest.mark.asyncio +async def test_keyspace_notifications_jobs(empty_redis, event_loop): + pattern = "jobs:*" + pubsub = await RedisPool().subscribe_to_keys(pattern) + async with pubsub: + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + Jobs.add("testjob1", "test.test", "Testing aaaalll day") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message["data"] == "hset" + + +async def reader_of_jobs() -> List[dict]: + """ + Reads 3 job updates and exits + """ + result: List[dict] = [] + async for message in job_notifications(): + result.append(message) + if len(result) >= 3: + break + return result + + +@pytest.mark.asyncio +async def test_jobs_generator(empty_redis, event_loop): + # Will read exactly 3 job messages + future_messages = asyncio.create_task(reader_of_jobs()) + await asyncio.sleep(1) + + Jobs.add("testjob1", "test.test", "Testing aaaalll day") + Jobs.add("testjob2", "test.test", "Testing aaaalll day") + Jobs.add("testjob3", "test.test", "Testing aaaalll day") + Jobs.add("testjob4", "test.test", "Testing aaaalll day") + + assert len(Jobs.get_jobs()) == 4 + r = RedisPool().get_connection() + assert len(r.keys("jobs:*")) == 4 + + messages = await future_messages + assert len(messages) == 3 + channels = [message["channel"] for message in messages] + operations = [message["data"] for message in messages] + assert set(operations) == set(["hset"]) # all of them are hsets + + # Asserting that all of jobs emitted exactly one message + jobs = Jobs.get_jobs() + names = ["testjob1", "testjob2", "testjob3"] + ids = [str(job.uid) for job in jobs if job.name in names] + for id in ids: + assert id in " ".join(channels) + # Asserting that they came in order + assert "testjob4" not in " ".join(channels) diff --git a/tests/test_services_systemctl.py b/tests/test_services_systemctl.py index 8b247e0..43805e8 100644 --- a/tests/test_services_systemctl.py +++ b/tests/test_services_systemctl.py @@ -2,7 +2,7 @@ import pytest from selfprivacy_api.services.service import ServiceStatus from selfprivacy_api.services.bitwarden import Bitwarden -from selfprivacy_api.services.gitea import Gitea +from selfprivacy_api.services.forgejo import Forgejo from selfprivacy_api.services.mailserver import MailServer from selfprivacy_api.services.nextcloud import Nextcloud from selfprivacy_api.services.ocserv import Ocserv @@ -22,7 +22,7 @@ def call_args_asserts(mocked_object): "dovecot2.service", "postfix.service", "vaultwarden.service", - "gitea.service", + "forgejo.service", "phpfpm-nextcloud.service", "ocserv.service", "pleroma.service", @@ -77,7 +77,7 @@ def mock_popen_systemctl_service_not_ok(mocker): def test_systemctl_ok(mock_popen_systemctl_service_ok): assert MailServer.get_status() == ServiceStatus.ACTIVE assert Bitwarden.get_status() == ServiceStatus.ACTIVE - assert Gitea.get_status() == ServiceStatus.ACTIVE + assert Forgejo.get_status() == ServiceStatus.ACTIVE assert Nextcloud.get_status() == ServiceStatus.ACTIVE assert Ocserv.get_status() == ServiceStatus.ACTIVE assert Pleroma.get_status() == ServiceStatus.ACTIVE @@ -87,7 +87,7 @@ def test_systemctl_ok(mock_popen_systemctl_service_ok): def test_systemctl_failed_service(mock_popen_systemctl_service_not_ok): assert MailServer.get_status() == ServiceStatus.FAILED assert Bitwarden.get_status() == ServiceStatus.FAILED - assert Gitea.get_status() == ServiceStatus.FAILED + assert Forgejo.get_status() == ServiceStatus.FAILED assert Nextcloud.get_status() == ServiceStatus.FAILED assert Ocserv.get_status() == ServiceStatus.FAILED assert Pleroma.get_status() == ServiceStatus.FAILED diff --git a/tests/test_websocket_uvicorn_standalone.py b/tests/test_websocket_uvicorn_standalone.py new file mode 100644 index 0000000..43a53ef --- /dev/null +++ b/tests/test_websocket_uvicorn_standalone.py @@ -0,0 +1,39 @@ +import pytest +from fastapi import FastAPI, WebSocket +import uvicorn + +# import subprocess +from multiprocessing import Process +import asyncio +from time import sleep +from websockets import client + +app = FastAPI() + + +@app.websocket("/") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + while True: + data = await websocket.receive_text() + await websocket.send_text(f"You sent: {data}") + + +def run_uvicorn(): + uvicorn.run(app, port=5000) + return True + + +@pytest.mark.asyncio +async def test_uvcorn_ws_works_in_prod(): + proc = Process(target=run_uvicorn) + proc.start() + sleep(2) + + ws = await client.connect("ws://127.0.0.1:5000") + + await ws.send("hohoho") + message = await ws.read_message() + assert message == "You sent: hohoho" + await ws.close() + proc.kill()