diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 45ebd2a..a188a00 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,9 +13,9 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api), For detailed installation information, please review and follow: [link](https://nixos.org/manual/nix/stable/installation/installing-binary.html#installing-a-binary-distribution). -3. **Change directory to the cloned repository and start a nix shell:** +3. **Change directory to the cloned repository and start a nix development shell:** - ```cd selfprivacy-rest-api && nix-shell``` + ```cd selfprivacy-rest-api && nix develop``` Nix will install all of the necessary packages for development work, all further actions will take place only within nix-shell. @@ -31,7 +31,7 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api), Copy the path that starts with ```/nix/store/``` and ends with ```env/bin/python``` - ```/nix/store/???-python3-3.9.??-env/bin/python``` + ```/nix/store/???-python3-3.10.??-env/bin/python``` Click on the python version selection in the lower right corner, and replace the path to the interpreter in the project with the one you copied from the terminal. @@ -43,12 +43,13 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api), ## What to do after making changes to the repository? -**Run unit tests** using ```pytest .``` -Make sure that all tests pass successfully and the API works correctly. For convenience, you can use the built-in VScode interface. +**Run unit tests** using ```pytest-vm``` inside of the development shell. This will run all the test inside a virtual machine, which is necessary for the tests to pass successfully. +Make sure that all tests pass successfully and the API works correctly. -How to review the percentage of code coverage? Execute the command: +The ```pytest-vm``` command will also print out the coverage of the tests. To export the report to an XML file, use the following command: + +```coverage xml``` -```coverage run -m pytest && coverage xml && coverage report``` Next, use the recommended extension ```ryanluker.vscode-coverage-gutters```, navigate to one of the test files, and click the "watch" button on the bottom panel of VScode. diff --git a/default.nix b/default.nix index e7e6fcf..1af935e 100644 --- a/default.nix +++ b/default.nix @@ -14,10 +14,12 @@ pythonPackages.buildPythonPackage rec { pydantic pytz redis + systemd setuptools strawberry-graphql typing-extensions uvicorn + websockets ]; pythonImportsCheck = [ "selfprivacy_api" ]; doCheck = false; diff --git a/flake.lock b/flake.lock index 1f52d36..ba47e51 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1709677081, - "narHash": "sha256-tix36Y7u0rkn6mTm0lA45b45oab2cFLqAzDbJxeXS+c=", + "lastModified": 1719957072, + "narHash": "sha256-gvFhEf5nszouwLAkT9nWsDzocUTqLWHuL++dvNjMp9I=", "owner": "nixos", "repo": "nixpkgs", - "rev": "880992dcc006a5e00dd0591446fdf723e6a51a64", + "rev": "7144d6241f02d171d25fba3edeaf15e0f2592105", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index f8b81aa..ab969a4 100644 --- a/flake.nix +++ b/flake.nix @@ -20,6 +20,7 @@ pytest-datadir pytest-mock pytest-subprocess + pytest-asyncio black mypy pylsp-mypy diff --git a/selfprivacy_api/app.py b/selfprivacy_api/app.py index 64ca85a..2f7e2f7 100644 --- a/selfprivacy_api/app.py +++ b/selfprivacy_api/app.py @@ -3,6 +3,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from strawberry.fastapi import GraphQLRouter +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL import uvicorn @@ -13,8 +14,12 @@ from selfprivacy_api.migrations import run_migrations app = FastAPI() -graphql_app = GraphQLRouter( +graphql_app: GraphQLRouter = GraphQLRouter( schema, + subscription_protocols=[ + GRAPHQL_TRANSPORT_WS_PROTOCOL, + GRAPHQL_WS_PROTOCOL, + ], ) app.add_middleware( diff --git a/selfprivacy_api/dependencies.py b/selfprivacy_api/dependencies.py index b9d0904..b2e2b19 100644 --- a/selfprivacy_api/dependencies.py +++ b/selfprivacy_api/dependencies.py @@ -27,4 +27,4 @@ async def get_token_header( def get_api_version() -> str: """Get API version""" - return "3.2.1" + return "3.3.0" diff --git a/selfprivacy_api/graphql/__init__.py b/selfprivacy_api/graphql/__init__.py index 6124a1a..edd8a78 100644 --- a/selfprivacy_api/graphql/__init__.py +++ b/selfprivacy_api/graphql/__init__.py @@ -16,6 +16,10 @@ class IsAuthenticated(BasePermission): token = info.context["request"].headers.get("Authorization") if token is None: token = info.context["request"].query_params.get("token") + if token is None: + connection_params = info.context.get("connection_params") + if connection_params is not None: + token = connection_params.get("Authorization") if token is None: return False return is_token_valid(token.replace("Bearer ", "")) diff --git a/selfprivacy_api/graphql/queries/jobs.py b/selfprivacy_api/graphql/queries/jobs.py index eebba43..35a2182 100644 --- a/selfprivacy_api/graphql/queries/jobs.py +++ b/selfprivacy_api/graphql/queries/jobs.py @@ -1,26 +1,30 @@ """Jobs status""" # pylint: disable=too-few-public-methods -import typing import strawberry +from typing import List, Optional +from selfprivacy_api.jobs import Jobs from selfprivacy_api.graphql.common_types.jobs import ( ApiJob, get_api_job_by_id, job_to_api_job, ) -from selfprivacy_api.jobs import Jobs + +def get_all_jobs() -> List[ApiJob]: + jobs = Jobs.get_jobs() + api_jobs = [job_to_api_job(job) for job in jobs] + assert api_jobs is not None + return api_jobs @strawberry.type class Job: @strawberry.field - def get_jobs(self) -> typing.List[ApiJob]: - Jobs.get_jobs() - - return [job_to_api_job(job) for job in Jobs.get_jobs()] + def get_jobs(self) -> List[ApiJob]: + return get_all_jobs() @strawberry.field - def get_job(self, job_id: str) -> typing.Optional[ApiJob]: + def get_job(self, job_id: str) -> Optional[ApiJob]: return get_api_job_by_id(job_id) diff --git a/selfprivacy_api/graphql/queries/logs.py b/selfprivacy_api/graphql/queries/logs.py new file mode 100644 index 0000000..cf8fe21 --- /dev/null +++ b/selfprivacy_api/graphql/queries/logs.py @@ -0,0 +1,88 @@ +"""System logs""" +from datetime import datetime +import typing +import strawberry +from selfprivacy_api.utils.systemd_journal import get_paginated_logs + + +@strawberry.type +class LogEntry: + message: str = strawberry.field() + timestamp: datetime = strawberry.field() + priority: typing.Optional[int] = strawberry.field() + systemd_unit: typing.Optional[str] = strawberry.field() + systemd_slice: typing.Optional[str] = strawberry.field() + + def __init__(self, journal_entry: typing.Dict): + self.entry = journal_entry + self.message = journal_entry["MESSAGE"] + self.timestamp = journal_entry["__REALTIME_TIMESTAMP"] + self.priority = journal_entry.get("PRIORITY") + self.systemd_unit = journal_entry.get("_SYSTEMD_UNIT") + self.systemd_slice = journal_entry.get("_SYSTEMD_SLICE") + + @strawberry.field() + def cursor(self) -> str: + return self.entry["__CURSOR"] + + +@strawberry.type +class LogsPageMeta: + up_cursor: typing.Optional[str] = strawberry.field() + down_cursor: typing.Optional[str] = strawberry.field() + + def __init__( + self, up_cursor: typing.Optional[str], down_cursor: typing.Optional[str] + ): + self.up_cursor = up_cursor + self.down_cursor = down_cursor + + +@strawberry.type +class PaginatedEntries: + page_meta: LogsPageMeta = strawberry.field( + description="Metadata to aid in pagination." + ) + entries: typing.List[LogEntry] = strawberry.field( + description="The list of log entries." + ) + + def __init__(self, meta: LogsPageMeta, entries: typing.List[LogEntry]): + self.page_meta = meta + self.entries = entries + + @staticmethod + def from_entries(entries: typing.List[LogEntry]): + if entries == []: + return PaginatedEntries(LogsPageMeta(None, None), []) + + return PaginatedEntries( + LogsPageMeta( + entries[0].cursor(), + entries[-1].cursor(), + ), + entries, + ) + + +@strawberry.type +class Logs: + @strawberry.field() + def paginated( + self, + limit: int = 20, + # All entries returned will be lesser than this cursor. Sets upper bound on results. + up_cursor: str | None = None, + # All entries returned will be greater than this cursor. Sets lower bound on results. + down_cursor: str | None = None, + ) -> PaginatedEntries: + if limit > 50: + raise Exception("You can't fetch more than 50 entries via single request.") + return PaginatedEntries.from_entries( + list( + map( + lambda x: LogEntry(x), + get_paginated_logs(limit, up_cursor, down_cursor), + ) + ) + ) diff --git a/selfprivacy_api/graphql/schema.py b/selfprivacy_api/graphql/schema.py index e4e7264..b49a629 100644 --- a/selfprivacy_api/graphql/schema.py +++ b/selfprivacy_api/graphql/schema.py @@ -2,8 +2,10 @@ # pylint: disable=too-few-public-methods import asyncio -from typing import AsyncGenerator +from typing import AsyncGenerator, List import strawberry +from strawberry.types import Info + from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql.mutations.deprecated_mutations import ( DeprecatedApiMutations, @@ -24,10 +26,17 @@ from selfprivacy_api.graphql.mutations.backup_mutations import BackupMutations from selfprivacy_api.graphql.queries.api_queries import Api from selfprivacy_api.graphql.queries.backup import Backup from selfprivacy_api.graphql.queries.jobs import Job +from selfprivacy_api.graphql.queries.logs import LogEntry, Logs from selfprivacy_api.graphql.queries.services import Services from selfprivacy_api.graphql.queries.storage import Storage from selfprivacy_api.graphql.queries.system import System +from selfprivacy_api.graphql.subscriptions.jobs import ApiJob +from selfprivacy_api.graphql.subscriptions.jobs import ( + job_updates as job_update_generator, +) +from selfprivacy_api.graphql.subscriptions.logs import log_stream + from selfprivacy_api.graphql.mutations.users_mutations import UsersMutations from selfprivacy_api.graphql.queries.users import Users from selfprivacy_api.jobs.test import test_job @@ -47,6 +56,11 @@ class Query: """System queries""" return System() + @strawberry.field(permission_classes=[IsAuthenticated]) + def logs(self) -> Logs: + """Log queries""" + return Logs() + @strawberry.field(permission_classes=[IsAuthenticated]) def users(self) -> Users: """Users queries""" @@ -129,19 +143,42 @@ class Mutation( code=200, ) - pass + +# A cruft for Websockets +def authenticated(info: Info) -> bool: + return IsAuthenticated().has_permission(source=None, info=info) + + +def reject_if_unauthenticated(info: Info): + if not authenticated(info): + raise Exception(IsAuthenticated().message) @strawberry.type class Subscription: - """Root schema for subscriptions""" + """Root schema for subscriptions. + Every field here should be an AsyncIterator or AsyncGenerator + It is not a part of the spec but graphql-core (dep of strawberryql) + demands it while the spec is vague in this area.""" - @strawberry.subscription(permission_classes=[IsAuthenticated]) - async def count(self, target: int = 100) -> AsyncGenerator[int, None]: - for i in range(target): + @strawberry.subscription + async def job_updates(self, info: Info) -> AsyncGenerator[List[ApiJob], None]: + reject_if_unauthenticated(info) + return job_update_generator() + + @strawberry.subscription + # Used for testing, consider deletion to shrink attack surface + async def count(self, info: Info) -> AsyncGenerator[int, None]: + reject_if_unauthenticated(info) + for i in range(10): yield i await asyncio.sleep(0.5) + @strawberry.subscription + async def log_entries(self, info: Info) -> AsyncGenerator[LogEntry, None]: + reject_if_unauthenticated(info) + return log_stream() + schema = strawberry.Schema( query=Query, diff --git a/selfprivacy_api/graphql/subscriptions/__init__.py b/selfprivacy_api/graphql/subscriptions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/selfprivacy_api/graphql/subscriptions/jobs.py b/selfprivacy_api/graphql/subscriptions/jobs.py new file mode 100644 index 0000000..11d6263 --- /dev/null +++ b/selfprivacy_api/graphql/subscriptions/jobs.py @@ -0,0 +1,14 @@ +# pylint: disable=too-few-public-methods + +from typing import AsyncGenerator, List + +from selfprivacy_api.jobs import job_notifications + +from selfprivacy_api.graphql.common_types.jobs import ApiJob +from selfprivacy_api.graphql.queries.jobs import get_all_jobs + + +async def job_updates() -> AsyncGenerator[List[ApiJob], None]: + # Send the complete list of jobs every time anything gets updated + async for notification in job_notifications(): + yield get_all_jobs() diff --git a/selfprivacy_api/graphql/subscriptions/logs.py b/selfprivacy_api/graphql/subscriptions/logs.py new file mode 100644 index 0000000..1e47dba --- /dev/null +++ b/selfprivacy_api/graphql/subscriptions/logs.py @@ -0,0 +1,31 @@ +from typing import AsyncGenerator +from systemd import journal +import asyncio + +from selfprivacy_api.graphql.queries.logs import LogEntry + + +async def log_stream() -> AsyncGenerator[LogEntry, None]: + j = journal.Reader() + + j.seek_tail() + j.get_previous() + + queue = asyncio.Queue() + + async def callback(): + if j.process() != journal.APPEND: + return + for entry in j: + await queue.put(entry) + + asyncio.get_event_loop().add_reader(j, lambda: asyncio.ensure_future(callback())) + + while True: + entry = await queue.get() + try: + yield LogEntry(entry) + except Exception: + asyncio.get_event_loop().remove_reader(j) + return + queue.task_done() diff --git a/selfprivacy_api/jobs/__init__.py b/selfprivacy_api/jobs/__init__.py index 4649bb0..3dd48c4 100644 --- a/selfprivacy_api/jobs/__init__.py +++ b/selfprivacy_api/jobs/__init__.py @@ -15,6 +15,7 @@ A job is a dictionary with the following keys: - result: result of the job """ import typing +import asyncio import datetime from uuid import UUID import uuid @@ -23,6 +24,7 @@ from enum import Enum from pydantic import BaseModel from selfprivacy_api.utils.redis_pool import RedisPool +from selfprivacy_api.utils.redis_model_storage import store_model_as_hash JOB_EXPIRATION_SECONDS = 10 * 24 * 60 * 60 # ten days @@ -102,7 +104,7 @@ class Jobs: result=None, ) redis = RedisPool().get_connection() - _store_job_as_hash(redis, _redis_key_from_uuid(job.uid), job) + store_model_as_hash(redis, _redis_key_from_uuid(job.uid), job) return job @staticmethod @@ -218,7 +220,7 @@ class Jobs: redis = RedisPool().get_connection() key = _redis_key_from_uuid(job.uid) if redis.exists(key): - _store_job_as_hash(redis, key, job) + store_model_as_hash(redis, key, job) if status in (JobStatus.FINISHED, JobStatus.ERROR): redis.expire(key, JOB_EXPIRATION_SECONDS) @@ -294,17 +296,6 @@ def _progress_log_key_from_uuid(uuid_string) -> str: return PROGRESS_LOGS_PREFIX + str(uuid_string) -def _store_job_as_hash(redis, redis_key, model) -> None: - for key, value in model.dict().items(): - if isinstance(value, uuid.UUID): - value = str(value) - if isinstance(value, datetime.datetime): - value = value.isoformat() - if isinstance(value, JobStatus): - value = value.value - redis.hset(redis_key, key, str(value)) - - def _job_from_hash(redis, redis_key) -> typing.Optional[Job]: if redis.exists(redis_key): job_dict = redis.hgetall(redis_key) @@ -321,3 +312,15 @@ def _job_from_hash(redis, redis_key) -> typing.Optional[Job]: return Job(**job_dict) return None + + +async def job_notifications() -> typing.AsyncGenerator[dict, None]: + channel = await RedisPool().subscribe_to_keys("jobs:*") + while True: + try: + # we cannot timeout here because we do not know when the next message is supposed to arrive + message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore + if message is not None: + yield message + except GeneratorExit: + break diff --git a/selfprivacy_api/jobs/migrate_to_binds.py b/selfprivacy_api/jobs/migrate_to_binds.py index 3250c9a..782b361 100644 --- a/selfprivacy_api/jobs/migrate_to_binds.py +++ b/selfprivacy_api/jobs/migrate_to_binds.py @@ -6,7 +6,7 @@ import shutil from pydantic import BaseModel from selfprivacy_api.jobs import Job, JobStatus, Jobs from selfprivacy_api.services.bitwarden import Bitwarden -from selfprivacy_api.services.gitea import Gitea +from selfprivacy_api.services.forgejo import Forgejo from selfprivacy_api.services.mailserver import MailServer from selfprivacy_api.services.nextcloud import Nextcloud from selfprivacy_api.services.pleroma import Pleroma @@ -230,7 +230,7 @@ def migrate_to_binds(config: BindMigrationConfig, job: Job): status_text="Migrating Gitea.", ) - Gitea().stop() + Forgejo().stop() if not pathlib.Path("/volumes/sda1/gitea").exists(): if not pathlib.Path("/volumes/sdb/gitea").exists(): @@ -241,7 +241,7 @@ def migrate_to_binds(config: BindMigrationConfig, job: Job): group="gitea", ) - Gitea().start() + Forgejo().start() # Perform migration of Mail server diff --git a/selfprivacy_api/services/__init__.py b/selfprivacy_api/services/__init__.py index 267cc31..5a2414c 100644 --- a/selfprivacy_api/services/__init__.py +++ b/selfprivacy_api/services/__init__.py @@ -2,7 +2,7 @@ import typing from selfprivacy_api.services.bitwarden import Bitwarden -from selfprivacy_api.services.gitea import Gitea +from selfprivacy_api.services.forgejo import Forgejo from selfprivacy_api.services.jitsimeet import JitsiMeet from selfprivacy_api.services.roundcube import Roundcube from selfprivacy_api.services.mailserver import MailServer @@ -14,7 +14,7 @@ import selfprivacy_api.utils.network as network_utils services: list[Service] = [ Bitwarden(), - Gitea(), + Forgejo(), MailServer(), Nextcloud(), Pleroma(), diff --git a/selfprivacy_api/services/gitea/__init__.py b/selfprivacy_api/services/forgejo/__init__.py similarity index 72% rename from selfprivacy_api/services/gitea/__init__.py rename to selfprivacy_api/services/forgejo/__init__.py index 88df4ed..06cf614 100644 --- a/selfprivacy_api/services/gitea/__init__.py +++ b/selfprivacy_api/services/forgejo/__init__.py @@ -7,31 +7,34 @@ from selfprivacy_api.utils import get_domain from selfprivacy_api.utils.systemd import get_service_status from selfprivacy_api.services.service import Service, ServiceStatus -from selfprivacy_api.services.gitea.icon import GITEA_ICON +from selfprivacy_api.services.forgejo.icon import FORGEJO_ICON -class Gitea(Service): - """Class representing Gitea service""" +class Forgejo(Service): + """Class representing Forgejo service. + + Previously was Gitea, so some IDs are still called gitea for compatibility. + """ @staticmethod def get_id() -> str: - """Return service id.""" + """Return service id. For compatibility keep in gitea.""" return "gitea" @staticmethod def get_display_name() -> str: """Return service display name.""" - return "Gitea" + return "Forgejo" @staticmethod def get_description() -> str: """Return service description.""" - return "Gitea is a Git forge." + return "Forgejo is a Git forge." @staticmethod def get_svg_icon() -> str: """Read SVG icon from file and return it as base64 encoded string.""" - return base64.b64encode(GITEA_ICON.encode("utf-8")).decode("utf-8") + return base64.b64encode(FORGEJO_ICON.encode("utf-8")).decode("utf-8") @classmethod def get_url(cls) -> Optional[str]: @@ -65,19 +68,19 @@ class Gitea(Service): Return code 3 means service is stopped. Return code 4 means service is off. """ - return get_service_status("gitea.service") + return get_service_status("forgejo.service") @staticmethod def stop(): - subprocess.run(["systemctl", "stop", "gitea.service"]) + subprocess.run(["systemctl", "stop", "forgejo.service"]) @staticmethod def start(): - subprocess.run(["systemctl", "start", "gitea.service"]) + subprocess.run(["systemctl", "start", "forgejo.service"]) @staticmethod def restart(): - subprocess.run(["systemctl", "restart", "gitea.service"]) + subprocess.run(["systemctl", "restart", "forgejo.service"]) @staticmethod def get_configuration(): @@ -93,4 +96,5 @@ class Gitea(Service): @staticmethod def get_folders() -> List[str]: + """The data folder is still called gitea for compatibility.""" return ["/var/lib/gitea"] diff --git a/selfprivacy_api/services/gitea/gitea.svg b/selfprivacy_api/services/forgejo/gitea.svg similarity index 100% rename from selfprivacy_api/services/gitea/gitea.svg rename to selfprivacy_api/services/forgejo/gitea.svg diff --git a/selfprivacy_api/services/gitea/icon.py b/selfprivacy_api/services/forgejo/icon.py similarity index 98% rename from selfprivacy_api/services/gitea/icon.py rename to selfprivacy_api/services/forgejo/icon.py index 569f96a..5e600cf 100644 --- a/selfprivacy_api/services/gitea/icon.py +++ b/selfprivacy_api/services/forgejo/icon.py @@ -1,4 +1,4 @@ -GITEA_ICON = """ +FORGEJO_ICON = """ diff --git a/selfprivacy_api/utils/redis_model_storage.py b/selfprivacy_api/utils/redis_model_storage.py index 06dfe8c..7d84210 100644 --- a/selfprivacy_api/utils/redis_model_storage.py +++ b/selfprivacy_api/utils/redis_model_storage.py @@ -1,15 +1,23 @@ +import uuid + from datetime import datetime from typing import Optional from enum import Enum def store_model_as_hash(redis, redis_key, model): - for key, value in model.dict().items(): + model_dict = model.dict() + for key, value in model_dict.items(): + if isinstance(value, uuid.UUID): + value = str(value) if isinstance(value, datetime): value = value.isoformat() if isinstance(value, Enum): value = value.value - redis.hset(redis_key, key, str(value)) + value = str(value) + model_dict[key] = value + + redis.hset(redis_key, mapping=model_dict) def hash_as_model(redis, redis_key: str, model_class): diff --git a/selfprivacy_api/utils/redis_pool.py b/selfprivacy_api/utils/redis_pool.py index 3d35f01..ea827d1 100644 --- a/selfprivacy_api/utils/redis_pool.py +++ b/selfprivacy_api/utils/redis_pool.py @@ -2,23 +2,33 @@ Redis pool module for selfprivacy_api """ import redis +import redis.asyncio as redis_async from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass REDIS_SOCKET = "/run/redis-sp-api/redis.sock" -class RedisPool(metaclass=SingletonMetaclass): +# class RedisPool(metaclass=SingletonMetaclass): +class RedisPool: """ Redis connection pool singleton. """ def __init__(self): + self._dbnumber = 0 + url = RedisPool.connection_url(dbnumber=self._dbnumber) + # We need a normal sync pool because otherwise + # our whole API will need to be async self._pool = redis.ConnectionPool.from_url( - RedisPool.connection_url(dbnumber=0), + url, + decode_responses=True, + ) + # We need an async pool for pubsub + self._async_pool = redis_async.ConnectionPool.from_url( + url, decode_responses=True, ) - self._pubsub_connection = self.get_connection() @staticmethod def connection_url(dbnumber: int) -> str: @@ -34,8 +44,15 @@ class RedisPool(metaclass=SingletonMetaclass): """ return redis.Redis(connection_pool=self._pool) - def get_pubsub(self): + def get_connection_async(self) -> redis_async.Redis: """ - Get a pubsub connection from the pool. + Get an async connection from the pool. + Async connections allow pubsub. """ - return self._pubsub_connection.pubsub() + return redis_async.Redis(connection_pool=self._async_pool) + + async def subscribe_to_keys(self, pattern: str) -> redis_async.client.PubSub: + async_redis = self.get_connection_async() + pubsub = async_redis.pubsub() + await pubsub.psubscribe(f"__keyspace@{self._dbnumber}__:" + pattern) + return pubsub diff --git a/selfprivacy_api/utils/systemd_journal.py b/selfprivacy_api/utils/systemd_journal.py new file mode 100644 index 0000000..48e97b8 --- /dev/null +++ b/selfprivacy_api/utils/systemd_journal.py @@ -0,0 +1,55 @@ +import typing +from systemd import journal + + +def get_events_from_journal( + j: journal.Reader, limit: int, next: typing.Callable[[journal.Reader], typing.Dict] +): + events = [] + i = 0 + while i < limit: + entry = next(j) + if entry is None or entry == dict(): + break + if entry["MESSAGE"] != "": + events.append(entry) + i += 1 + + return events + + +def get_paginated_logs( + limit: int = 20, + # All entries returned will be lesser than this cursor. Sets upper bound on results. + up_cursor: str | None = None, + # All entries returned will be greater than this cursor. Sets lower bound on results. + down_cursor: str | None = None, +): + j = journal.Reader() + + if up_cursor is None and down_cursor is None: + j.seek_tail() + + events = get_events_from_journal(j, limit, lambda j: j.get_previous()) + events.reverse() + + return events + elif up_cursor is None and down_cursor is not None: + j.seek_cursor(down_cursor) + j.get_previous() # pagination is exclusive + + events = get_events_from_journal(j, limit, lambda j: j.get_previous()) + events.reverse() + + return events + elif up_cursor is not None and down_cursor is None: + j.seek_cursor(up_cursor) + j.get_next() # pagination is exclusive + + events = get_events_from_journal(j, limit, lambda j: j.get_next()) + + return events + else: + raise NotImplementedError( + "Pagination by both up_cursor and down_cursor is not implemented" + ) diff --git a/setup.py b/setup.py index 473ece8..aaf333e 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name="selfprivacy_api", - version="3.2.1", + version="3.3.0", packages=find_packages(), scripts=[ "selfprivacy_api/app.py", diff --git a/tests/common.py b/tests/common.py index 5f69f3f..1de3893 100644 --- a/tests/common.py +++ b/tests/common.py @@ -69,10 +69,22 @@ def generate_backup_query(query_array): return "query TestBackup {\n backup {" + "\n".join(query_array) + "}\n}" +def generate_jobs_query(query_array): + return "query TestJobs {\n jobs {" + "\n".join(query_array) + "}\n}" + + +def generate_jobs_subscription(query_array): + return "subscription TestSubscription {\n jobs {" + "\n".join(query_array) + "}\n}" + + def generate_service_query(query_array): return "query TestService {\n services {" + "\n".join(query_array) + "}\n}" +def generate_logs_query(query_array): + return "query TestService {\n logs {" + "\n".join(query_array) + "}\n}" + + def mnemonic_to_hex(mnemonic): return Mnemonic(language="english").to_entropy(mnemonic).hex() diff --git a/tests/test_graphql/test_api_logs.py b/tests/test_graphql/test_api_logs.py new file mode 100644 index 0000000..6875531 --- /dev/null +++ b/tests/test_graphql/test_api_logs.py @@ -0,0 +1,172 @@ +import asyncio +import pytest +from datetime import datetime +from systemd import journal + +from tests.test_graphql.test_websocket import init_graphql + + +def assert_log_entry_equals_to_journal_entry(api_entry, journal_entry): + assert api_entry["message"] == journal_entry["MESSAGE"] + assert ( + datetime.fromisoformat(api_entry["timestamp"]) + == journal_entry["__REALTIME_TIMESTAMP"] + ) + assert api_entry.get("priority") == journal_entry.get("PRIORITY") + assert api_entry.get("systemdUnit") == journal_entry.get("_SYSTEMD_UNIT") + assert api_entry.get("systemdSlice") == journal_entry.get("_SYSTEMD_SLICE") + + +def take_from_journal(j, limit, next): + entries = [] + for _ in range(0, limit): + entry = next(j) + if entry["MESSAGE"] != "": + entries.append(entry) + return entries + + +API_GET_LOGS_WITH_UP_BORDER = """ +query TestQuery($upCursor: String) { + logs { + paginated(limit: 4, upCursor: $upCursor) { + pageMeta { + upCursor + downCursor + } + entries { + message + timestamp + priority + systemdUnit + systemdSlice + } + } + } +} +""" + +API_GET_LOGS_WITH_DOWN_BORDER = """ +query TestQuery($downCursor: String) { + logs { + paginated(limit: 4, downCursor: $downCursor) { + pageMeta { + upCursor + downCursor + } + entries { + message + timestamp + priority + systemdUnit + systemdSlice + } + } + } +} +""" + + +def test_graphql_get_logs_with_up_border(authorized_client): + j = journal.Reader() + j.seek_tail() + + # < - cursor + # <- - log entry will be returned by API call. + # ... + # log < + # log <- + # log <- + # log <- + # log <- + # log + + expected_entries = take_from_journal(j, 6, lambda j: j.get_previous()) + expected_entries.reverse() + + response = authorized_client.post( + "/graphql", + json={ + "query": API_GET_LOGS_WITH_UP_BORDER, + "variables": {"upCursor": expected_entries[0]["__CURSOR"]}, + }, + ) + assert response.status_code == 200 + + expected_entries = expected_entries[1:-1] + returned_entries = response.json()["data"]["logs"]["paginated"]["entries"] + + assert len(returned_entries) == len(expected_entries) + + for api_entry, journal_entry in zip(returned_entries, expected_entries): + assert_log_entry_equals_to_journal_entry(api_entry, journal_entry) + + +def test_graphql_get_logs_with_down_border(authorized_client): + j = journal.Reader() + j.seek_head() + j.get_next() + + # < - cursor + # <- - log entry will be returned by API call. + # log + # log <- + # log <- + # log <- + # log <- + # log < + # ... + + expected_entries = take_from_journal(j, 5, lambda j: j.get_next()) + + response = authorized_client.post( + "/graphql", + json={ + "query": API_GET_LOGS_WITH_DOWN_BORDER, + "variables": {"downCursor": expected_entries[-1]["__CURSOR"]}, + }, + ) + assert response.status_code == 200 + + expected_entries = expected_entries[:-1] + returned_entries = response.json()["data"]["logs"]["paginated"]["entries"] + + assert len(returned_entries) == len(expected_entries) + + for api_entry, journal_entry in zip(returned_entries, expected_entries): + assert_log_entry_equals_to_journal_entry(api_entry, journal_entry) + + +@pytest.mark.asyncio +async def test_websocket_subscription_for_logs(authorized_client): + with authorized_client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws"] + ) as websocket: + init_graphql(websocket) + websocket.send_json( + { + "id": "3aaa2445", + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription { logEntries { message } }", + }, + } + ) + await asyncio.sleep(1) + + def read_until(message, limit=5): + i = 0 + while i < limit: + msg = websocket.receive_json()["payload"]["data"]["logEntries"][ + "message" + ] + if msg == message: + return + else: + i += 1 + continue + raise Exception("Failed to read websocket data, timeout") + + for i in range(0, 10): + journal.send(f"Lorem ipsum number {i}") + read_until(f"Lorem ipsum number {i}") diff --git a/tests/test_graphql/test_jobs.py b/tests/test_graphql/test_jobs.py new file mode 100644 index 0000000..68a6d20 --- /dev/null +++ b/tests/test_graphql/test_jobs.py @@ -0,0 +1,74 @@ +from tests.common import generate_jobs_query +import tests.test_graphql.test_api_backup + +from tests.test_graphql.common import ( + assert_ok, + assert_empty, + assert_errorcode, + get_data, +) + +from selfprivacy_api.jobs import Jobs + +API_JOBS_QUERY = """ +getJobs { + uid + typeId + name + description + status + statusText + progress + createdAt + updatedAt + finishedAt + error + result +} +""" + + +def graphql_send_query(client, query: str, variables: dict = {}): + return client.post("/graphql", json={"query": query, "variables": variables}) + + +def api_jobs(authorized_client): + response = graphql_send_query( + authorized_client, generate_jobs_query([API_JOBS_QUERY]) + ) + data = get_data(response) + result = data["jobs"]["getJobs"] + assert result is not None + return result + + +def test_all_jobs_unauthorized(client): + response = graphql_send_query(client, generate_jobs_query([API_JOBS_QUERY])) + assert_empty(response) + + +def test_all_jobs_when_none(authorized_client): + output = api_jobs(authorized_client) + assert output == [] + + +def test_all_jobs_when_some(authorized_client): + # We cannot make new jobs via API, at least directly + job = Jobs.add("bogus", "bogus.bogus", "fungus") + output = api_jobs(authorized_client) + + len(output) == 1 + api_job = output[0] + + assert api_job["uid"] == str(job.uid) + assert api_job["typeId"] == job.type_id + assert api_job["name"] == job.name + assert api_job["description"] == job.description + assert api_job["status"] == job.status + assert api_job["statusText"] == job.status_text + assert api_job["progress"] == job.progress + assert api_job["createdAt"] == job.created_at.isoformat() + assert api_job["updatedAt"] == job.updated_at.isoformat() + assert api_job["finishedAt"] == None + assert api_job["error"] == None + assert api_job["result"] == None diff --git a/tests/test_graphql/test_services.py b/tests/test_graphql/test_services.py index 6e8dcf6..b7faf3d 100644 --- a/tests/test_graphql/test_services.py +++ b/tests/test_graphql/test_services.py @@ -543,8 +543,8 @@ def test_disable_enable(authorized_client, only_dummy_service): assert api_dummy_service["status"] == ServiceStatus.ACTIVE.value -def test_move_immovable(authorized_client, only_dummy_service): - dummy_service = only_dummy_service +def test_move_immovable(authorized_client, dummy_service_with_binds): + dummy_service = dummy_service_with_binds dummy_service.set_movable(False) root = BlockDevices().get_root_block_device() mutation_response = api_move(authorized_client, dummy_service, root.name) diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py new file mode 100644 index 0000000..754fbbf --- /dev/null +++ b/tests/test_graphql/test_websocket.py @@ -0,0 +1,225 @@ +# from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions +import pytest +import asyncio +from typing import Generator +from time import sleep + +from starlette.testclient import WebSocketTestSession + +from selfprivacy_api.jobs import Jobs +from selfprivacy_api.actions.api_tokens import TOKEN_REPO +from selfprivacy_api.graphql import IsAuthenticated + +from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH +from tests.test_jobs import jobs as empty_jobs + +# We do not iterate through them yet +TESTED_SUBPROTOCOLS = ["graphql-transport-ws"] + +JOBS_SUBSCRIPTION = """ +jobUpdates { + uid + typeId + name + description + status + statusText + progress + createdAt + updatedAt + finishedAt + error + result +} +""" + + +def api_subscribe(websocket, id, subscription): + websocket.send_json( + { + "id": id, + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {" + subscription + "}", + }, + } + ) + + +def connect_ws_authenticated(authorized_client) -> WebSocketTestSession: + token = "Bearer " + str(DEVICE_WE_AUTH_TESTS_WITH["token"]) + return authorized_client.websocket_connect( + "/graphql", + subprotocols=TESTED_SUBPROTOCOLS, + params={"token": token}, + ) + + +def connect_ws_not_authenticated(client) -> WebSocketTestSession: + return client.websocket_connect( + "/graphql", + subprotocols=TESTED_SUBPROTOCOLS, + params={"token": "I like vegan icecream but it is not a valid token"}, + ) + + +def init_graphql(websocket): + websocket.send_json({"type": "connection_init", "payload": {}}) + ack = websocket.receive_json() + assert ack == {"type": "connection_ack"} + + +@pytest.fixture +def authenticated_websocket( + authorized_client, +) -> Generator[WebSocketTestSession, None, None]: + # We use authorized_client only to have token in the repo, this client by itself is not enough to authorize websocket + + ValueError(TOKEN_REPO.get_tokens()) + with connect_ws_authenticated(authorized_client) as websocket: + yield websocket + + +@pytest.fixture +def unauthenticated_websocket(client) -> Generator[WebSocketTestSession, None, None]: + with connect_ws_not_authenticated(client) as websocket: + yield websocket + + +def test_websocket_connection_bare(authorized_client): + client = authorized_client + with client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws", "graphql-ws"] + ) as websocket: + assert websocket is not None + assert websocket.scope is not None + + +def test_websocket_graphql_init(authorized_client): + client = authorized_client + with client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws"] + ) as websocket: + websocket.send_json({"type": "connection_init", "payload": {}}) + ack = websocket.receive_json() + assert ack == {"type": "connection_ack"} + + +def test_websocket_graphql_ping(authorized_client): + client = authorized_client + with client.websocket_connect( + "/graphql", subprotocols=["graphql-transport-ws"] + ) as websocket: + # https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping + websocket.send_json({"type": "ping", "payload": {}}) + pong = websocket.receive_json() + assert pong == {"type": "pong"} + + +def test_websocket_subscription_minimal(authorized_client, authenticated_websocket): + # Test a small endpoint that exists specifically for tests + websocket = authenticated_websocket + init_graphql(websocket) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, "count") + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": {"data": {"count": 0}}, + "type": "next", + } + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": {"data": {"count": 1}}, + "type": "next", + } + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": {"data": {"count": 2}}, + "type": "next", + } + + +def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket): + websocket = unauthenticated_websocket + init_graphql(websocket) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, "count") + + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": [{"message": IsAuthenticated.message}], + "type": "error", + } + + +async def read_one_job(websocket): + # Bug? We only get them starting from the second job update + # That's why we receive two jobs in the list them + # The first update gets lost somewhere + response = websocket.receive_json() + return response + + +@pytest.mark.asyncio +async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs): + websocket = authenticated_websocket + init_graphql(websocket) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, JOBS_SUBSCRIPTION) + + future = asyncio.create_task(read_one_job(websocket)) + jobs = [] + jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works")) + sleep(0.5) + jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works")) + + response = await future + data = response["payload"]["data"] + jobs_received = data["jobUpdates"] + received_names = [job["name"] for job in jobs_received] + for job in jobs: + assert job.name in received_names + + assert len(jobs_received) == 2 + + for job in jobs: + for api_job in jobs_received: + if (job.name) == api_job["name"]: + assert api_job["uid"] == str(job.uid) + assert api_job["typeId"] == job.type_id + assert api_job["name"] == job.name + assert api_job["description"] == job.description + assert api_job["status"] == job.status + assert api_job["statusText"] == job.status_text + assert api_job["progress"] == job.progress + assert api_job["createdAt"] == job.created_at.isoformat() + assert api_job["updatedAt"] == job.updated_at.isoformat() + assert api_job["finishedAt"] == None + assert api_job["error"] == None + assert api_job["result"] == None + + +def test_websocket_subscription_unauthorized(unauthenticated_websocket): + websocket = unauthenticated_websocket + init_graphql(websocket) + id = "3aaa2445" + api_subscribe(websocket, id, JOBS_SUBSCRIPTION) + + response = websocket.receive_json() + # I do not really know why strawberry gives more info on this + # One versus the counter + payload = response["payload"][0] + assert isinstance(payload, dict) + assert "locations" in payload.keys() + # It looks like this 'locations': [{'column': 32, 'line': 1}] + # We cannot test locations feasibly + del payload["locations"] + assert response == { + "id": id, + "payload": [{"message": IsAuthenticated.message, "path": ["jobUpdates"]}], + "type": "error", + } diff --git a/tests/test_redis.py b/tests/test_redis.py new file mode 100644 index 0000000..02dfb21 --- /dev/null +++ b/tests/test_redis.py @@ -0,0 +1,218 @@ +import asyncio +import pytest +import pytest_asyncio +from asyncio import streams +import redis +from typing import List + +from selfprivacy_api.utils.redis_pool import RedisPool + +from selfprivacy_api.jobs import Jobs, job_notifications + +TEST_KEY = "test:test" +STOPWORD = "STOP" + + +@pytest.fixture() +def empty_redis(event_loop): + r = RedisPool().get_connection() + r.flushdb() + assert r.config_get("notify-keyspace-events")["notify-keyspace-events"] == "AKE" + yield r + r.flushdb() + + +async def write_to_test_key(): + r = RedisPool().get_connection_async() + async with r.pipeline(transaction=True) as pipe: + ok1, ok2 = await pipe.set(TEST_KEY, "value1").set(TEST_KEY, "value2").execute() + assert ok1 + assert ok2 + assert await r.get(TEST_KEY) == "value2" + await r.close() + + +def test_async_connection(empty_redis): + r = RedisPool().get_connection() + assert not r.exists(TEST_KEY) + # It _will_ report an error if it arises + asyncio.run(write_to_test_key()) + # Confirming that we can read result from sync connection too + assert r.get(TEST_KEY) == "value2" + + +async def channel_reader(channel: redis.client.PubSub) -> List[dict]: + result: List[dict] = [] + while True: + # Mypy cannot correctly detect that it is a coroutine + # But it is + message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore + if message is not None: + result.append(message) + if message["data"] == STOPWORD: + break + return result + + +async def channel_reader_onemessage(channel: redis.client.PubSub) -> dict: + while True: + # Mypy cannot correctly detect that it is a coroutine + # But it is + message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore + if message is not None: + return message + + +@pytest.mark.asyncio +async def test_pubsub(empty_redis, event_loop): + # Adapted from : + # https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html + # Sanity checking because of previous event loop bugs + assert event_loop == asyncio.get_event_loop() + assert event_loop == asyncio.events.get_event_loop() + assert event_loop == asyncio.events._get_event_loop() + assert event_loop == asyncio.events.get_running_loop() + + reader = streams.StreamReader(34) + assert event_loop == reader._loop + f = reader._loop.create_future() + f.set_result(3) + await f + + r = RedisPool().get_connection_async() + async with r.pubsub() as pubsub: + await pubsub.subscribe("channel:1") + future = asyncio.create_task(channel_reader(pubsub)) + + await r.publish("channel:1", "Hello") + # message: dict = await pubsub.get_message(ignore_subscribe_messages=True, timeout=5.0) # type: ignore + # raise ValueError(message) + await r.publish("channel:1", "World") + await r.publish("channel:1", STOPWORD) + + messages = await future + + assert len(messages) == 3 + + message = messages[0] + assert "data" in message.keys() + assert message["data"] == "Hello" + message = messages[1] + assert "data" in message.keys() + assert message["data"] == "World" + message = messages[2] + assert "data" in message.keys() + assert message["data"] == STOPWORD + + await r.close() + + +@pytest.mark.asyncio +async def test_keyspace_notifications_simple(empty_redis, event_loop): + r = RedisPool().get_connection_async() + await r.set(TEST_KEY, "I am not empty") + async with r.pubsub() as pubsub: + await pubsub.subscribe("__keyspace@0__:" + TEST_KEY) + + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + empty_redis.set(TEST_KEY, "I am set!") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message == { + "channel": f"__keyspace@0__:{TEST_KEY}", + "data": "set", + "pattern": None, + "type": "message", + } + + +@pytest.mark.asyncio +async def test_keyspace_notifications(empty_redis, event_loop): + pubsub = await RedisPool().subscribe_to_keys(TEST_KEY) + async with pubsub: + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + empty_redis.set(TEST_KEY, "I am set!") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message == { + "channel": f"__keyspace@0__:{TEST_KEY}", + "data": "set", + "pattern": f"__keyspace@0__:{TEST_KEY}", + "type": "pmessage", + } + + +@pytest.mark.asyncio +async def test_keyspace_notifications_patterns(empty_redis, event_loop): + pattern = "test*" + pubsub = await RedisPool().subscribe_to_keys(pattern) + async with pubsub: + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + empty_redis.set(TEST_KEY, "I am set!") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message == { + "channel": f"__keyspace@0__:{TEST_KEY}", + "data": "set", + "pattern": f"__keyspace@0__:{pattern}", + "type": "pmessage", + } + + +@pytest.mark.asyncio +async def test_keyspace_notifications_jobs(empty_redis, event_loop): + pattern = "jobs:*" + pubsub = await RedisPool().subscribe_to_keys(pattern) + async with pubsub: + future_message = asyncio.create_task(channel_reader_onemessage(pubsub)) + Jobs.add("testjob1", "test.test", "Testing aaaalll day") + message = await future_message + assert message is not None + assert message["data"] is not None + assert message["data"] == "hset" + + +async def reader_of_jobs() -> List[dict]: + """ + Reads 3 job updates and exits + """ + result: List[dict] = [] + async for message in job_notifications(): + result.append(message) + if len(result) >= 3: + break + return result + + +@pytest.mark.asyncio +async def test_jobs_generator(empty_redis, event_loop): + # Will read exactly 3 job messages + future_messages = asyncio.create_task(reader_of_jobs()) + await asyncio.sleep(1) + + Jobs.add("testjob1", "test.test", "Testing aaaalll day") + Jobs.add("testjob2", "test.test", "Testing aaaalll day") + Jobs.add("testjob3", "test.test", "Testing aaaalll day") + Jobs.add("testjob4", "test.test", "Testing aaaalll day") + + assert len(Jobs.get_jobs()) == 4 + r = RedisPool().get_connection() + assert len(r.keys("jobs:*")) == 4 + + messages = await future_messages + assert len(messages) == 3 + channels = [message["channel"] for message in messages] + operations = [message["data"] for message in messages] + assert set(operations) == set(["hset"]) # all of them are hsets + + # Asserting that all of jobs emitted exactly one message + jobs = Jobs.get_jobs() + names = ["testjob1", "testjob2", "testjob3"] + ids = [str(job.uid) for job in jobs if job.name in names] + for id in ids: + assert id in " ".join(channels) + # Asserting that they came in order + assert "testjob4" not in " ".join(channels) diff --git a/tests/test_services_systemctl.py b/tests/test_services_systemctl.py index 8b247e0..43805e8 100644 --- a/tests/test_services_systemctl.py +++ b/tests/test_services_systemctl.py @@ -2,7 +2,7 @@ import pytest from selfprivacy_api.services.service import ServiceStatus from selfprivacy_api.services.bitwarden import Bitwarden -from selfprivacy_api.services.gitea import Gitea +from selfprivacy_api.services.forgejo import Forgejo from selfprivacy_api.services.mailserver import MailServer from selfprivacy_api.services.nextcloud import Nextcloud from selfprivacy_api.services.ocserv import Ocserv @@ -22,7 +22,7 @@ def call_args_asserts(mocked_object): "dovecot2.service", "postfix.service", "vaultwarden.service", - "gitea.service", + "forgejo.service", "phpfpm-nextcloud.service", "ocserv.service", "pleroma.service", @@ -77,7 +77,7 @@ def mock_popen_systemctl_service_not_ok(mocker): def test_systemctl_ok(mock_popen_systemctl_service_ok): assert MailServer.get_status() == ServiceStatus.ACTIVE assert Bitwarden.get_status() == ServiceStatus.ACTIVE - assert Gitea.get_status() == ServiceStatus.ACTIVE + assert Forgejo.get_status() == ServiceStatus.ACTIVE assert Nextcloud.get_status() == ServiceStatus.ACTIVE assert Ocserv.get_status() == ServiceStatus.ACTIVE assert Pleroma.get_status() == ServiceStatus.ACTIVE @@ -87,7 +87,7 @@ def test_systemctl_ok(mock_popen_systemctl_service_ok): def test_systemctl_failed_service(mock_popen_systemctl_service_not_ok): assert MailServer.get_status() == ServiceStatus.FAILED assert Bitwarden.get_status() == ServiceStatus.FAILED - assert Gitea.get_status() == ServiceStatus.FAILED + assert Forgejo.get_status() == ServiceStatus.FAILED assert Nextcloud.get_status() == ServiceStatus.FAILED assert Ocserv.get_status() == ServiceStatus.FAILED assert Pleroma.get_status() == ServiceStatus.FAILED diff --git a/tests/test_websocket_uvicorn_standalone.py b/tests/test_websocket_uvicorn_standalone.py new file mode 100644 index 0000000..43a53ef --- /dev/null +++ b/tests/test_websocket_uvicorn_standalone.py @@ -0,0 +1,39 @@ +import pytest +from fastapi import FastAPI, WebSocket +import uvicorn + +# import subprocess +from multiprocessing import Process +import asyncio +from time import sleep +from websockets import client + +app = FastAPI() + + +@app.websocket("/") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + while True: + data = await websocket.receive_text() + await websocket.send_text(f"You sent: {data}") + + +def run_uvicorn(): + uvicorn.run(app, port=5000) + return True + + +@pytest.mark.asyncio +async def test_uvcorn_ws_works_in_prod(): + proc = Process(target=run_uvicorn) + proc.start() + sleep(2) + + ws = await client.connect("ws://127.0.0.1:5000") + + await ws.send("hohoho") + message = await ws.read_message() + assert message == "You sent: hohoho" + await ws.close() + proc.kill()