Merge remote-tracking branch 'origin/master' into inex/service-settings

This commit is contained in:
Inex Code 2024-07-15 18:15:14 +04:00
commit c8d00e6c87
55 changed files with 1313 additions and 125 deletions

View file

@ -13,9 +13,9 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api),
For detailed installation information, please review and follow: [link](https://nixos.org/manual/nix/stable/installation/installing-binary.html#installing-a-binary-distribution).
3. **Change directory to the cloned repository and start a nix shell:**
3. **Change directory to the cloned repository and start a nix development shell:**
```cd selfprivacy-rest-api && nix-shell```
```cd selfprivacy-rest-api && nix develop```
Nix will install all of the necessary packages for development work, all further actions will take place only within nix-shell.
@ -31,7 +31,7 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api),
Copy the path that starts with ```/nix/store/``` and ends with ```env/bin/python```
```/nix/store/???-python3-3.9.??-env/bin/python```
```/nix/store/???-python3-3.10.??-env/bin/python```
Click on the python version selection in the lower right corner, and replace the path to the interpreter in the project with the one you copied from the terminal.
@ -43,12 +43,13 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api),
## What to do after making changes to the repository?
**Run unit tests** using ```pytest .```
Make sure that all tests pass successfully and the API works correctly. For convenience, you can use the built-in VScode interface.
**Run unit tests** using ```pytest-vm``` inside of the development shell. This will run all the test inside a virtual machine, which is necessary for the tests to pass successfully.
Make sure that all tests pass successfully and the API works correctly.
How to review the percentage of code coverage? Execute the command:
The ```pytest-vm``` command will also print out the coverage of the tests. To export the report to an XML file, use the following command:
```coverage xml```
```coverage run -m pytest && coverage xml && coverage report```
Next, use the recommended extension ```ryanluker.vscode-coverage-gutters```, navigate to one of the test files, and click the "watch" button on the bottom panel of VScode.

View file

@ -14,10 +14,12 @@ pythonPackages.buildPythonPackage rec {
pydantic
pytz
redis
systemd
setuptools
strawberry-graphql
typing-extensions
uvicorn
websockets
];
pythonImportsCheck = [ "selfprivacy_api" ];
doCheck = false;

View file

@ -2,11 +2,11 @@
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1709677081,
"narHash": "sha256-tix36Y7u0rkn6mTm0lA45b45oab2cFLqAzDbJxeXS+c=",
"lastModified": 1719957072,
"narHash": "sha256-gvFhEf5nszouwLAkT9nWsDzocUTqLWHuL++dvNjMp9I=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "880992dcc006a5e00dd0591446fdf723e6a51a64",
"rev": "7144d6241f02d171d25fba3edeaf15e0f2592105",
"type": "github"
},
"original": {

View file

@ -20,6 +20,7 @@
pytest-datadir
pytest-mock
pytest-subprocess
pytest-asyncio
black
mypy
pylsp-mypy

View file

@ -3,6 +3,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from strawberry.fastapi import GraphQLRouter
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
import uvicorn
@ -13,8 +14,12 @@ from selfprivacy_api.migrations import run_migrations
app = FastAPI()
graphql_app = GraphQLRouter(
graphql_app: GraphQLRouter = GraphQLRouter(
schema,
subscription_protocols=[
GRAPHQL_TRANSPORT_WS_PROTOCOL,
GRAPHQL_WS_PROTOCOL,
],
)
app.add_middleware(

View file

@ -27,4 +27,4 @@ async def get_token_header(
def get_api_version() -> str:
"""Get API version"""
return "3.2.2+configs"
return "3.3.0+configs"

View file

@ -16,6 +16,10 @@ class IsAuthenticated(BasePermission):
token = info.context["request"].headers.get("Authorization")
if token is None:
token = info.context["request"].query_params.get("token")
if token is None:
connection_params = info.context.get("connection_params")
if connection_params is not None:
token = connection_params.get("Authorization")
if token is None:
return False
return is_token_valid(token.replace("Bearer ", ""))

View file

@ -1,8 +1,10 @@
"""API access status"""
# pylint: disable=too-few-public-methods
import datetime
import typing
import strawberry
from strawberry.types import Info
from selfprivacy_api.actions.api_tokens import (
get_api_tokens_with_caller_flag,

View file

@ -1,4 +1,5 @@
"""Backup"""
# pylint: disable=too-few-public-methods
import typing
import strawberry

View file

@ -1,4 +1,5 @@
"""Common types and enums used by different types of queries."""
from enum import Enum
import datetime
import typing

View file

@ -1,24 +1,30 @@
"""Jobs status"""
# pylint: disable=too-few-public-methods
import typing
import strawberry
from typing import List, Optional
from selfprivacy_api.jobs import Jobs
from selfprivacy_api.graphql.common_types.jobs import (
ApiJob,
get_api_job_by_id,
job_to_api_job,
)
from selfprivacy_api.jobs import Jobs
def get_all_jobs() -> List[ApiJob]:
jobs = Jobs.get_jobs()
api_jobs = [job_to_api_job(job) for job in jobs]
assert api_jobs is not None
return api_jobs
@strawberry.type
class Job:
@strawberry.field
def get_jobs(self) -> typing.List[ApiJob]:
Jobs.get_jobs()
return [job_to_api_job(job) for job in Jobs.get_jobs()]
def get_jobs(self) -> List[ApiJob]:
return get_all_jobs()
@strawberry.field
def get_job(self, job_id: str) -> typing.Optional[ApiJob]:
def get_job(self, job_id: str) -> Optional[ApiJob]:
return get_api_job_by_id(job_id)

View file

@ -0,0 +1,88 @@
"""System logs"""
from datetime import datetime
import typing
import strawberry
from selfprivacy_api.utils.systemd_journal import get_paginated_logs
@strawberry.type
class LogEntry:
message: str = strawberry.field()
timestamp: datetime = strawberry.field()
priority: typing.Optional[int] = strawberry.field()
systemd_unit: typing.Optional[str] = strawberry.field()
systemd_slice: typing.Optional[str] = strawberry.field()
def __init__(self, journal_entry: typing.Dict):
self.entry = journal_entry
self.message = journal_entry["MESSAGE"]
self.timestamp = journal_entry["__REALTIME_TIMESTAMP"]
self.priority = journal_entry.get("PRIORITY")
self.systemd_unit = journal_entry.get("_SYSTEMD_UNIT")
self.systemd_slice = journal_entry.get("_SYSTEMD_SLICE")
@strawberry.field()
def cursor(self) -> str:
return self.entry["__CURSOR"]
@strawberry.type
class LogsPageMeta:
up_cursor: typing.Optional[str] = strawberry.field()
down_cursor: typing.Optional[str] = strawberry.field()
def __init__(
self, up_cursor: typing.Optional[str], down_cursor: typing.Optional[str]
):
self.up_cursor = up_cursor
self.down_cursor = down_cursor
@strawberry.type
class PaginatedEntries:
page_meta: LogsPageMeta = strawberry.field(
description="Metadata to aid in pagination."
)
entries: typing.List[LogEntry] = strawberry.field(
description="The list of log entries."
)
def __init__(self, meta: LogsPageMeta, entries: typing.List[LogEntry]):
self.page_meta = meta
self.entries = entries
@staticmethod
def from_entries(entries: typing.List[LogEntry]):
if entries == []:
return PaginatedEntries(LogsPageMeta(None, None), [])
return PaginatedEntries(
LogsPageMeta(
entries[0].cursor(),
entries[-1].cursor(),
),
entries,
)
@strawberry.type
class Logs:
@strawberry.field()
def paginated(
self,
limit: int = 20,
# All entries returned will be lesser than this cursor. Sets upper bound on results.
up_cursor: str | None = None,
# All entries returned will be greater than this cursor. Sets lower bound on results.
down_cursor: str | None = None,
) -> PaginatedEntries:
if limit > 50:
raise Exception("You can't fetch more than 50 entries via single request.")
return PaginatedEntries.from_entries(
list(
map(
lambda x: LogEntry(x),
get_paginated_logs(limit, up_cursor, down_cursor),
)
)
)

View file

@ -1,4 +1,5 @@
"""Enums representing different service providers."""
from enum import Enum
import strawberry

View file

@ -1,4 +1,5 @@
"""Services status"""
# pylint: disable=too-few-public-methods
import typing
import strawberry

View file

@ -1,4 +1,5 @@
"""Storage queries."""
# pylint: disable=too-few-public-methods
import typing
import strawberry
@ -18,9 +19,11 @@ class Storage:
"""Get list of volumes"""
return [
StorageVolume(
total_space=str(volume.fssize)
if volume.fssize is not None
else str(volume.size),
total_space=(
str(volume.fssize)
if volume.fssize is not None
else str(volume.size)
),
free_space=str(volume.fsavail),
used_space=str(volume.fsused),
root=volume.is_root(),

View file

@ -1,8 +1,10 @@
"""Common system information and settings"""
# pylint: disable=too-few-public-methods
import os
import typing
import strawberry
from selfprivacy_api.graphql.common_types.dns import DnsRecord
from selfprivacy_api.graphql.queries.common import Alert, Severity

View file

@ -1,4 +1,5 @@
"""Users"""
# pylint: disable=too-few-public-methods
import typing
import strawberry

View file

@ -2,8 +2,10 @@
# pylint: disable=too-few-public-methods
import asyncio
from typing import AsyncGenerator
from typing import AsyncGenerator, List
import strawberry
from strawberry.types import Info
from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.mutations.deprecated_mutations import (
DeprecatedApiMutations,
@ -24,10 +26,17 @@ from selfprivacy_api.graphql.mutations.backup_mutations import BackupMutations
from selfprivacy_api.graphql.queries.api_queries import Api
from selfprivacy_api.graphql.queries.backup import Backup
from selfprivacy_api.graphql.queries.jobs import Job
from selfprivacy_api.graphql.queries.logs import LogEntry, Logs
from selfprivacy_api.graphql.queries.services import Services
from selfprivacy_api.graphql.queries.storage import Storage
from selfprivacy_api.graphql.queries.system import System
from selfprivacy_api.graphql.subscriptions.jobs import ApiJob
from selfprivacy_api.graphql.subscriptions.jobs import (
job_updates as job_update_generator,
)
from selfprivacy_api.graphql.subscriptions.logs import log_stream
from selfprivacy_api.graphql.common_types.service import (
StringConfigItem,
BoolConfigItem,
@ -53,6 +62,11 @@ class Query:
"""System queries"""
return System()
@strawberry.field(permission_classes=[IsAuthenticated])
def logs(self) -> Logs:
"""Log queries"""
return Logs()
@strawberry.field(permission_classes=[IsAuthenticated])
def users(self) -> Users:
"""Users queries"""
@ -135,19 +149,42 @@ class Mutation(
code=200,
)
pass
# A cruft for Websockets
def authenticated(info: Info) -> bool:
return IsAuthenticated().has_permission(source=None, info=info)
def reject_if_unauthenticated(info: Info):
if not authenticated(info):
raise Exception(IsAuthenticated().message)
@strawberry.type
class Subscription:
"""Root schema for subscriptions"""
"""Root schema for subscriptions.
Every field here should be an AsyncIterator or AsyncGenerator
It is not a part of the spec but graphql-core (dep of strawberryql)
demands it while the spec is vague in this area."""
@strawberry.subscription(permission_classes=[IsAuthenticated])
async def count(self, target: int = 100) -> AsyncGenerator[int, None]:
for i in range(target):
@strawberry.subscription
async def job_updates(self, info: Info) -> AsyncGenerator[List[ApiJob], None]:
reject_if_unauthenticated(info)
return job_update_generator()
@strawberry.subscription
# Used for testing, consider deletion to shrink attack surface
async def count(self, info: Info) -> AsyncGenerator[int, None]:
reject_if_unauthenticated(info)
for i in range(10):
yield i
await asyncio.sleep(0.5)
@strawberry.subscription
async def log_entries(self, info: Info) -> AsyncGenerator[LogEntry, None]:
reject_if_unauthenticated(info)
return log_stream()
schema = strawberry.Schema(
query=Query,

View file

@ -0,0 +1,14 @@
# pylint: disable=too-few-public-methods
from typing import AsyncGenerator, List
from selfprivacy_api.jobs import job_notifications
from selfprivacy_api.graphql.common_types.jobs import ApiJob
from selfprivacy_api.graphql.queries.jobs import get_all_jobs
async def job_updates() -> AsyncGenerator[List[ApiJob], None]:
# Send the complete list of jobs every time anything gets updated
async for notification in job_notifications():
yield get_all_jobs()

View file

@ -0,0 +1,31 @@
from typing import AsyncGenerator
from systemd import journal
import asyncio
from selfprivacy_api.graphql.queries.logs import LogEntry
async def log_stream() -> AsyncGenerator[LogEntry, None]:
j = journal.Reader()
j.seek_tail()
j.get_previous()
queue = asyncio.Queue()
async def callback():
if j.process() != journal.APPEND:
return
for entry in j:
await queue.put(entry)
asyncio.get_event_loop().add_reader(j, lambda: asyncio.ensure_future(callback()))
while True:
entry = await queue.get()
try:
yield LogEntry(entry)
except Exception:
asyncio.get_event_loop().remove_reader(j)
return
queue.task_done()

View file

@ -15,6 +15,7 @@ A job is a dictionary with the following keys:
- result: result of the job
"""
import typing
import asyncio
import datetime
from uuid import UUID
import uuid
@ -23,6 +24,7 @@ from enum import Enum
from pydantic import BaseModel
from selfprivacy_api.utils.redis_pool import RedisPool
from selfprivacy_api.utils.redis_model_storage import store_model_as_hash
JOB_EXPIRATION_SECONDS = 10 * 24 * 60 * 60 # ten days
@ -102,7 +104,7 @@ class Jobs:
result=None,
)
redis = RedisPool().get_connection()
_store_job_as_hash(redis, _redis_key_from_uuid(job.uid), job)
store_model_as_hash(redis, _redis_key_from_uuid(job.uid), job)
return job
@staticmethod
@ -218,7 +220,7 @@ class Jobs:
redis = RedisPool().get_connection()
key = _redis_key_from_uuid(job.uid)
if redis.exists(key):
_store_job_as_hash(redis, key, job)
store_model_as_hash(redis, key, job)
if status in (JobStatus.FINISHED, JobStatus.ERROR):
redis.expire(key, JOB_EXPIRATION_SECONDS)
@ -294,17 +296,6 @@ def _progress_log_key_from_uuid(uuid_string) -> str:
return PROGRESS_LOGS_PREFIX + str(uuid_string)
def _store_job_as_hash(redis, redis_key, model) -> None:
for key, value in model.dict().items():
if isinstance(value, uuid.UUID):
value = str(value)
if isinstance(value, datetime.datetime):
value = value.isoformat()
if isinstance(value, JobStatus):
value = value.value
redis.hset(redis_key, key, str(value))
def _job_from_hash(redis, redis_key) -> typing.Optional[Job]:
if redis.exists(redis_key):
job_dict = redis.hgetall(redis_key)
@ -321,3 +312,15 @@ def _job_from_hash(redis, redis_key) -> typing.Optional[Job]:
return Job(**job_dict)
return None
async def job_notifications() -> typing.AsyncGenerator[dict, None]:
channel = await RedisPool().subscribe_to_keys("jobs:*")
while True:
try:
# we cannot timeout here because we do not know when the next message is supposed to arrive
message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore
if message is not None:
yield message
except GeneratorExit:
break

View file

@ -14,10 +14,12 @@ from selfprivacy_api.migrations.write_token_to_redis import WriteTokenToRedis
from selfprivacy_api.migrations.check_for_system_rebuild_jobs import (
CheckForSystemRebuildJobs,
)
from selfprivacy_api.migrations.add_roundcube import AddRoundcube
migrations = [
WriteTokenToRedis(),
CheckForSystemRebuildJobs(),
AddRoundcube(),
]

View file

@ -0,0 +1,36 @@
from selfprivacy_api.migrations.migration import Migration
from selfprivacy_api.services.flake_service_manager import FlakeServiceManager
from selfprivacy_api.utils import ReadUserData, WriteUserData
class AddRoundcube(Migration):
"""Adds the Roundcube if it is not present."""
def get_migration_name(self) -> str:
return "add_roundcube"
def get_migration_description(self) -> str:
return "Adds the Roundcube if it is not present."
def is_migration_needed(self) -> bool:
with FlakeServiceManager() as manager:
if "roundcube" not in manager.services:
return True
with ReadUserData() as data:
if "roundcube" not in data["modules"]:
return True
return False
def migrate(self) -> None:
with FlakeServiceManager() as manager:
if "roundcube" not in manager.services:
manager.services[
"roundcube"
] = "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/roundcube"
with WriteUserData() as data:
if "roundcube" not in data["modules"]:
data["modules"]["roundcube"] = {
"enable": False,
"subdomain": "roundcube",
}

View file

@ -5,13 +5,13 @@ from selfprivacy_api.jobs import JobStatus, Jobs
class CheckForSystemRebuildJobs(Migration):
"""Check if there are unfinished system rebuild jobs and finish them"""
def get_migration_name(self):
def get_migration_name(self) -> str:
return "check_for_system_rebuild_jobs"
def get_migration_description(self):
def get_migration_description(self) -> str:
return "Check if there are unfinished system rebuild jobs and finish them"
def is_migration_needed(self):
def is_migration_needed(self) -> bool:
# Check if there are any unfinished system rebuild jobs
for job in Jobs.get_jobs():
if (
@ -25,8 +25,9 @@ class CheckForSystemRebuildJobs(Migration):
JobStatus.RUNNING,
]:
return True
return False
def migrate(self):
def migrate(self) -> None:
# As the API is restarted, we assume that the jobs are finished
for job in Jobs.get_jobs():
if (

View file

@ -12,17 +12,17 @@ class Migration(ABC):
"""
@abstractmethod
def get_migration_name(self):
def get_migration_name(self) -> str:
pass
@abstractmethod
def get_migration_description(self):
def get_migration_description(self) -> str:
pass
@abstractmethod
def is_migration_needed(self):
def is_migration_needed(self) -> bool:
pass
@abstractmethod
def migrate(self):
def migrate(self) -> None:
pass

View file

@ -15,10 +15,10 @@ from selfprivacy_api.utils import ReadUserData, UserDataFiles
class WriteTokenToRedis(Migration):
"""Load Json tokens into Redis"""
def get_migration_name(self):
def get_migration_name(self) -> str:
return "write_token_to_redis"
def get_migration_description(self):
def get_migration_description(self) -> str:
return "Loads the initial token into redis token storage"
def is_repo_empty(self, repo: AbstractTokensRepository) -> bool:
@ -38,7 +38,7 @@ class WriteTokenToRedis(Migration):
print(e)
return None
def is_migration_needed(self):
def is_migration_needed(self) -> bool:
try:
if self.get_token_from_json() is not None and self.is_repo_empty(
RedisTokensRepository()
@ -47,8 +47,9 @@ class WriteTokenToRedis(Migration):
except Exception as e:
print(e)
return False
return False
def migrate(self):
def migrate(self) -> None:
# Write info about providers to userdata.json
try:
token = self.get_token_from_json()

View file

@ -4,6 +4,7 @@ import typing
from selfprivacy_api.services.bitwarden import Bitwarden
from selfprivacy_api.services.forgejo import Forgejo
from selfprivacy_api.services.jitsimeet import JitsiMeet
from selfprivacy_api.services.roundcube import Roundcube
from selfprivacy_api.services.mailserver import MailServer
from selfprivacy_api.services.nextcloud import Nextcloud
from selfprivacy_api.services.pleroma import Pleroma
@ -19,6 +20,7 @@ services: list[Service] = [
Pleroma(),
Ocserv(),
JitsiMeet(),
Roundcube(),
]

View file

@ -37,14 +37,14 @@ class Bitwarden(Service):
def get_user() -> str:
return "vaultwarden"
@staticmethod
def get_url() -> Optional[str]:
@classmethod
def get_url(cls) -> Optional[str]:
"""Return service url."""
domain = get_domain()
return f"https://password.{domain}"
@staticmethod
def get_subdomain() -> Optional[str]:
@classmethod
def get_subdomain(cls) -> Optional[str]:
return "password"
@staticmethod

View file

@ -34,20 +34,20 @@ class FlakeServiceManager:
file.write(
"""
{
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";\n
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";\n
"""
)
for key, value in self.services.items():
file.write(
f"""
inputs.{key}.url = {value};
inputs.{key}.url = {value};
"""
)
file.write(
"""
outputs = _: { };
outputs = _: { };
}
"""
)

View file

@ -90,14 +90,14 @@ class Forgejo(Service):
"""Read SVG icon from file and return it as base64 encoded string."""
return base64.b64encode(FORGEJO_ICON.encode("utf-8")).decode("utf-8")
@staticmethod
def get_url() -> Optional[str]:
@classmethod
def get_url(cls) -> Optional[str]:
"""Return service url."""
domain = get_domain()
return f"https://git.{domain}"
@staticmethod
def get_subdomain() -> Optional[str]:
@classmethod
def get_subdomain(cls) -> Optional[str]:
return "git"
@staticmethod

View file

@ -36,14 +36,14 @@ class JitsiMeet(Service):
"""Read SVG icon from file and return it as base64 encoded string."""
return base64.b64encode(JITSI_ICON.encode("utf-8")).decode("utf-8")
@staticmethod
def get_url() -> Optional[str]:
@classmethod
def get_url(cls) -> Optional[str]:
"""Return service url."""
domain = get_domain()
return f"https://meet.{domain}"
@staticmethod
def get_subdomain() -> Optional[str]:
@classmethod
def get_subdomain(cls) -> Optional[str]:
return "meet"
@staticmethod

View file

@ -35,13 +35,13 @@ class MailServer(Service):
def get_user() -> str:
return "virtualMail"
@staticmethod
def get_url() -> Optional[str]:
@classmethod
def get_url(cls) -> Optional[str]:
"""Return service url."""
return None
@staticmethod
def get_subdomain() -> Optional[str]:
@classmethod
def get_subdomain(cls) -> Optional[str]:
return None
@staticmethod

View file

@ -4,7 +4,6 @@ import subprocess
from typing import Optional, List
from selfprivacy_api.utils import get_domain
from selfprivacy_api.jobs import Job, Jobs
from selfprivacy_api.utils.systemd import get_service_status
from selfprivacy_api.services.service import Service, ServiceStatus
@ -35,14 +34,14 @@ class Nextcloud(Service):
"""Read SVG icon from file and return it as base64 encoded string."""
return base64.b64encode(NEXTCLOUD_ICON.encode("utf-8")).decode("utf-8")
@staticmethod
def get_url() -> Optional[str]:
@classmethod
def get_url(cls) -> Optional[str]:
"""Return service url."""
domain = get_domain()
return f"https://cloud.{domain}"
@staticmethod
def get_subdomain() -> Optional[str]:
@classmethod
def get_subdomain(cls) -> Optional[str]:
return "cloud"
@staticmethod

View file

@ -28,13 +28,13 @@ class Ocserv(Service):
def get_svg_icon() -> str:
return base64.b64encode(OCSERV_ICON.encode("utf-8")).decode("utf-8")
@staticmethod
def get_url() -> typing.Optional[str]:
@classmethod
def get_url(cls) -> typing.Optional[str]:
"""Return service url."""
return None
@staticmethod
def get_subdomain() -> typing.Optional[str]:
@classmethod
def get_subdomain(cls) -> typing.Optional[str]:
return "vpn"
@staticmethod

View file

@ -31,14 +31,14 @@ class Pleroma(Service):
def get_svg_icon() -> str:
return base64.b64encode(PLEROMA_ICON.encode("utf-8")).decode("utf-8")
@staticmethod
def get_url() -> Optional[str]:
@classmethod
def get_url(cls) -> Optional[str]:
"""Return service url."""
domain = get_domain()
return f"https://social.{domain}"
@staticmethod
def get_subdomain() -> Optional[str]:
@classmethod
def get_subdomain(cls) -> Optional[str]:
return "social"
@staticmethod

View file

@ -0,0 +1,113 @@
"""Class representing Roundcube service"""
import base64
import subprocess
from typing import List, Optional
from selfprivacy_api.jobs import Job
from selfprivacy_api.utils.systemd import (
get_service_status_from_several_units,
)
from selfprivacy_api.services.service import Service, ServiceStatus
from selfprivacy_api.utils import ReadUserData, get_domain
from selfprivacy_api.utils.block_devices import BlockDevice
from selfprivacy_api.services.roundcube.icon import ROUNDCUBE_ICON
class Roundcube(Service):
"""Class representing roundcube service"""
@staticmethod
def get_id() -> str:
"""Return service id."""
return "roundcube"
@staticmethod
def get_display_name() -> str:
"""Return service display name."""
return "Roundcube"
@staticmethod
def get_description() -> str:
"""Return service description."""
return "Roundcube is an open source webmail software."
@staticmethod
def get_svg_icon() -> str:
"""Read SVG icon from file and return it as base64 encoded string."""
return base64.b64encode(ROUNDCUBE_ICON.encode("utf-8")).decode("utf-8")
@classmethod
def get_url(cls) -> Optional[str]:
"""Return service url."""
domain = get_domain()
subdomain = cls.get_subdomain()
return f"https://{subdomain}.{domain}"
@classmethod
def get_subdomain(cls) -> Optional[str]:
with ReadUserData() as data:
if "roundcube" in data["modules"]:
return data["modules"]["roundcube"]["subdomain"]
return "roundcube"
@staticmethod
def is_movable() -> bool:
return False
@staticmethod
def is_required() -> bool:
return False
@staticmethod
def can_be_backed_up() -> bool:
return False
@staticmethod
def get_backup_description() -> str:
return "Nothing to backup."
@staticmethod
def get_status() -> ServiceStatus:
return get_service_status_from_several_units(["phpfpm-roundcube.service"])
@staticmethod
def stop():
subprocess.run(
["systemctl", "stop", "phpfpm-roundcube.service"],
check=False,
)
@staticmethod
def start():
subprocess.run(
["systemctl", "start", "phpfpm-roundcube.service"],
check=False,
)
@staticmethod
def restart():
subprocess.run(
["systemctl", "restart", "phpfpm-roundcube.service"],
check=False,
)
@staticmethod
def get_configuration():
return {}
@staticmethod
def set_configuration(config_items):
return super().set_configuration(config_items)
@staticmethod
def get_logs():
return ""
@staticmethod
def get_folders() -> List[str]:
return []
def move_to_volume(self, volume: BlockDevice) -> Job:
raise NotImplementedError("roundcube service is not movable")

View file

@ -0,0 +1,7 @@
ROUNDCUBE_ICON = """
<svg fill="none" version="1.1" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
<g transform="translate(29.07 -.3244)">
<path d="m-17.02 2.705c-4.01 2e-7 -7.283 3.273-7.283 7.283 0 0.00524-1.1e-5 0.01038 0 0.01562l-1.85 1.068v5.613l9.105 5.26 9.104-5.26v-5.613l-1.797-1.037c1.008e-4 -0.01573 0.00195-0.03112 0.00195-0.04688-1e-7 -4.01-3.271-7.283-7.281-7.283zm0 2.012c2.923 1e-7 5.27 2.349 5.27 5.271 0 2.923-2.347 5.27-5.27 5.27-2.923-1e-6 -5.271-2.347-5.271-5.27 0-2.923 2.349-5.271 5.271-5.271z" fill="#000" fill-rule="evenodd" stroke-linejoin="bevel"/>
</g>
</svg>
"""

View file

@ -65,17 +65,17 @@ class Service(ABC):
"""
pass
@staticmethod
@classmethod
@abstractmethod
def get_url() -> Optional[str]:
def get_url(cls) -> Optional[str]:
"""
The url of the service if it is accessible from the internet browser.
"""
pass
@staticmethod
@classmethod
@abstractmethod
def get_subdomain() -> Optional[str]:
def get_subdomain(cls) -> Optional[str]:
"""
The assigned primary subdomain for this service.
"""

View file

@ -57,14 +57,14 @@ class DummyService(Service):
# return ""
return base64.b64encode(BITWARDEN_ICON.encode("utf-8")).decode("utf-8")
@staticmethod
def get_url() -> typing.Optional[str]:
@classmethod
def get_url(cls) -> typing.Optional[str]:
"""Return service url."""
domain = "test.com"
return f"https://password.{domain}"
@staticmethod
def get_subdomain() -> typing.Optional[str]:
@classmethod
def get_subdomain(cls) -> typing.Optional[str]:
return "password"
@classmethod

View file

@ -1,15 +1,23 @@
import uuid
from datetime import datetime
from typing import Optional
from enum import Enum
def store_model_as_hash(redis, redis_key, model):
for key, value in model.dict().items():
model_dict = model.dict()
for key, value in model_dict.items():
if isinstance(value, uuid.UUID):
value = str(value)
if isinstance(value, datetime):
value = value.isoformat()
if isinstance(value, Enum):
value = value.value
redis.hset(redis_key, key, str(value))
value = str(value)
model_dict[key] = value
redis.hset(redis_key, mapping=model_dict)
def hash_as_model(redis, redis_key: str, model_class):

View file

@ -2,23 +2,33 @@
Redis pool module for selfprivacy_api
"""
import redis
import redis.asyncio as redis_async
from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass
REDIS_SOCKET = "/run/redis-sp-api/redis.sock"
class RedisPool(metaclass=SingletonMetaclass):
# class RedisPool(metaclass=SingletonMetaclass):
class RedisPool:
"""
Redis connection pool singleton.
"""
def __init__(self):
self._dbnumber = 0
url = RedisPool.connection_url(dbnumber=self._dbnumber)
# We need a normal sync pool because otherwise
# our whole API will need to be async
self._pool = redis.ConnectionPool.from_url(
RedisPool.connection_url(dbnumber=0),
url,
decode_responses=True,
)
# We need an async pool for pubsub
self._async_pool = redis_async.ConnectionPool.from_url(
url,
decode_responses=True,
)
self._pubsub_connection = self.get_connection()
@staticmethod
def connection_url(dbnumber: int) -> str:
@ -34,8 +44,15 @@ class RedisPool(metaclass=SingletonMetaclass):
"""
return redis.Redis(connection_pool=self._pool)
def get_pubsub(self):
def get_connection_async(self) -> redis_async.Redis:
"""
Get a pubsub connection from the pool.
Get an async connection from the pool.
Async connections allow pubsub.
"""
return self._pubsub_connection.pubsub()
return redis_async.Redis(connection_pool=self._async_pool)
async def subscribe_to_keys(self, pattern: str) -> redis_async.client.PubSub:
async_redis = self.get_connection_async()
pubsub = async_redis.pubsub()
await pubsub.psubscribe(f"__keyspace@{self._dbnumber}__:" + pattern)
return pubsub

View file

@ -0,0 +1,55 @@
import typing
from systemd import journal
def get_events_from_journal(
j: journal.Reader, limit: int, next: typing.Callable[[journal.Reader], typing.Dict]
):
events = []
i = 0
while i < limit:
entry = next(j)
if entry is None or entry == dict():
break
if entry["MESSAGE"] != "":
events.append(entry)
i += 1
return events
def get_paginated_logs(
limit: int = 20,
# All entries returned will be lesser than this cursor. Sets upper bound on results.
up_cursor: str | None = None,
# All entries returned will be greater than this cursor. Sets lower bound on results.
down_cursor: str | None = None,
):
j = journal.Reader()
if up_cursor is None and down_cursor is None:
j.seek_tail()
events = get_events_from_journal(j, limit, lambda j: j.get_previous())
events.reverse()
return events
elif up_cursor is None and down_cursor is not None:
j.seek_cursor(down_cursor)
j.get_previous() # pagination is exclusive
events = get_events_from_journal(j, limit, lambda j: j.get_previous())
events.reverse()
return events
elif up_cursor is not None and down_cursor is None:
j.seek_cursor(up_cursor)
j.get_next() # pagination is exclusive
events = get_events_from_journal(j, limit, lambda j: j.get_next())
return events
else:
raise NotImplementedError(
"Pagination by both up_cursor and down_cursor is not implemented"
)

View file

@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup(
name="selfprivacy_api",
version="3.2.2",
version="3.3.0",
packages=find_packages(),
scripts=[
"selfprivacy_api/app.py",

View file

@ -69,10 +69,22 @@ def generate_backup_query(query_array):
return "query TestBackup {\n backup {" + "\n".join(query_array) + "}\n}"
def generate_jobs_query(query_array):
return "query TestJobs {\n jobs {" + "\n".join(query_array) + "}\n}"
def generate_jobs_subscription(query_array):
return "subscription TestSubscription {\n jobs {" + "\n".join(query_array) + "}\n}"
def generate_service_query(query_array):
return "query TestService {\n services {" + "\n".join(query_array) + "}\n}"
def generate_logs_query(query_array):
return "query TestService {\n logs {" + "\n".join(query_array) + "}\n}"
def mnemonic_to_hex(mnemonic):
return Mnemonic(language="english").to_entropy(mnemonic).hex()

View file

@ -62,6 +62,9 @@
"simple-nixos-mailserver": {
"enable": true,
"location": "sdb"
},
"roundcube": {
"enable": true
}
},
"volumes": [

View file

@ -4,40 +4,40 @@ from selfprivacy_api.services.flake_service_manager import FlakeServiceManager
all_services_file = """
{
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
inputs.nextcloud.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/nextcloud;
inputs.nextcloud.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/nextcloud;
inputs.ocserv.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/ocserv;
inputs.ocserv.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/ocserv;
inputs.pleroma.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/pleroma;
inputs.pleroma.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/pleroma;
inputs.simple-nixos-mailserver.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/simple-nixos-mailserver;
inputs.simple-nixos-mailserver.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/simple-nixos-mailserver;
outputs = _: { };
outputs = _: { };
}
"""
some_services_file = """
{
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
outputs = _: { };
outputs = _: { };
}
"""

View file

@ -1,4 +1,4 @@
{
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
outputs = _: { };
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
outputs = _: { };
}

View file

@ -1,12 +1,12 @@
{
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
outputs = _: { };
outputs = _: { };
}

View file

@ -0,0 +1,172 @@
import asyncio
import pytest
from datetime import datetime
from systemd import journal
from tests.test_graphql.test_websocket import init_graphql
def assert_log_entry_equals_to_journal_entry(api_entry, journal_entry):
assert api_entry["message"] == journal_entry["MESSAGE"]
assert (
datetime.fromisoformat(api_entry["timestamp"])
== journal_entry["__REALTIME_TIMESTAMP"]
)
assert api_entry.get("priority") == journal_entry.get("PRIORITY")
assert api_entry.get("systemdUnit") == journal_entry.get("_SYSTEMD_UNIT")
assert api_entry.get("systemdSlice") == journal_entry.get("_SYSTEMD_SLICE")
def take_from_journal(j, limit, next):
entries = []
for _ in range(0, limit):
entry = next(j)
if entry["MESSAGE"] != "":
entries.append(entry)
return entries
API_GET_LOGS_WITH_UP_BORDER = """
query TestQuery($upCursor: String) {
logs {
paginated(limit: 4, upCursor: $upCursor) {
pageMeta {
upCursor
downCursor
}
entries {
message
timestamp
priority
systemdUnit
systemdSlice
}
}
}
}
"""
API_GET_LOGS_WITH_DOWN_BORDER = """
query TestQuery($downCursor: String) {
logs {
paginated(limit: 4, downCursor: $downCursor) {
pageMeta {
upCursor
downCursor
}
entries {
message
timestamp
priority
systemdUnit
systemdSlice
}
}
}
}
"""
def test_graphql_get_logs_with_up_border(authorized_client):
j = journal.Reader()
j.seek_tail()
# < - cursor
# <- - log entry will be returned by API call.
# ...
# log <
# log <-
# log <-
# log <-
# log <-
# log
expected_entries = take_from_journal(j, 6, lambda j: j.get_previous())
expected_entries.reverse()
response = authorized_client.post(
"/graphql",
json={
"query": API_GET_LOGS_WITH_UP_BORDER,
"variables": {"upCursor": expected_entries[0]["__CURSOR"]},
},
)
assert response.status_code == 200
expected_entries = expected_entries[1:-1]
returned_entries = response.json()["data"]["logs"]["paginated"]["entries"]
assert len(returned_entries) == len(expected_entries)
for api_entry, journal_entry in zip(returned_entries, expected_entries):
assert_log_entry_equals_to_journal_entry(api_entry, journal_entry)
def test_graphql_get_logs_with_down_border(authorized_client):
j = journal.Reader()
j.seek_head()
j.get_next()
# < - cursor
# <- - log entry will be returned by API call.
# log
# log <-
# log <-
# log <-
# log <-
# log <
# ...
expected_entries = take_from_journal(j, 5, lambda j: j.get_next())
response = authorized_client.post(
"/graphql",
json={
"query": API_GET_LOGS_WITH_DOWN_BORDER,
"variables": {"downCursor": expected_entries[-1]["__CURSOR"]},
},
)
assert response.status_code == 200
expected_entries = expected_entries[:-1]
returned_entries = response.json()["data"]["logs"]["paginated"]["entries"]
assert len(returned_entries) == len(expected_entries)
for api_entry, journal_entry in zip(returned_entries, expected_entries):
assert_log_entry_equals_to_journal_entry(api_entry, journal_entry)
@pytest.mark.asyncio
async def test_websocket_subscription_for_logs(authorized_client):
with authorized_client.websocket_connect(
"/graphql", subprotocols=["graphql-transport-ws"]
) as websocket:
init_graphql(websocket)
websocket.send_json(
{
"id": "3aaa2445",
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription { logEntries { message } }",
},
}
)
await asyncio.sleep(1)
def read_until(message, limit=5):
i = 0
while i < limit:
msg = websocket.receive_json()["payload"]["data"]["logEntries"][
"message"
]
if msg == message:
return
else:
i += 1
continue
raise Exception("Failed to read websocket data, timeout")
for i in range(0, 10):
journal.send(f"Lorem ipsum number {i}")
read_until(f"Lorem ipsum number {i}")

View file

@ -0,0 +1,74 @@
from tests.common import generate_jobs_query
import tests.test_graphql.test_api_backup
from tests.test_graphql.common import (
assert_ok,
assert_empty,
assert_errorcode,
get_data,
)
from selfprivacy_api.jobs import Jobs
API_JOBS_QUERY = """
getJobs {
uid
typeId
name
description
status
statusText
progress
createdAt
updatedAt
finishedAt
error
result
}
"""
def graphql_send_query(client, query: str, variables: dict = {}):
return client.post("/graphql", json={"query": query, "variables": variables})
def api_jobs(authorized_client):
response = graphql_send_query(
authorized_client, generate_jobs_query([API_JOBS_QUERY])
)
data = get_data(response)
result = data["jobs"]["getJobs"]
assert result is not None
return result
def test_all_jobs_unauthorized(client):
response = graphql_send_query(client, generate_jobs_query([API_JOBS_QUERY]))
assert_empty(response)
def test_all_jobs_when_none(authorized_client):
output = api_jobs(authorized_client)
assert output == []
def test_all_jobs_when_some(authorized_client):
# We cannot make new jobs via API, at least directly
job = Jobs.add("bogus", "bogus.bogus", "fungus")
output = api_jobs(authorized_client)
len(output) == 1
api_job = output[0]
assert api_job["uid"] == str(job.uid)
assert api_job["typeId"] == job.type_id
assert api_job["name"] == job.name
assert api_job["description"] == job.description
assert api_job["status"] == job.status
assert api_job["statusText"] == job.status_text
assert api_job["progress"] == job.progress
assert api_job["createdAt"] == job.created_at.isoformat()
assert api_job["updatedAt"] == job.updated_at.isoformat()
assert api_job["finishedAt"] == None
assert api_job["error"] == None
assert api_job["result"] == None

View file

@ -543,8 +543,8 @@ def test_disable_enable(authorized_client, only_dummy_service):
assert api_dummy_service["status"] == ServiceStatus.ACTIVE.value
def test_move_immovable(authorized_client, only_dummy_service):
dummy_service = only_dummy_service
def test_move_immovable(authorized_client, dummy_service_with_binds):
dummy_service = dummy_service_with_binds
dummy_service.set_movable(False)
root = BlockDevices().get_root_block_device()
mutation_response = api_move(authorized_client, dummy_service, root.name)

View file

@ -0,0 +1,225 @@
# from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions
import pytest
import asyncio
from typing import Generator
from time import sleep
from starlette.testclient import WebSocketTestSession
from selfprivacy_api.jobs import Jobs
from selfprivacy_api.actions.api_tokens import TOKEN_REPO
from selfprivacy_api.graphql import IsAuthenticated
from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH
from tests.test_jobs import jobs as empty_jobs
# We do not iterate through them yet
TESTED_SUBPROTOCOLS = ["graphql-transport-ws"]
JOBS_SUBSCRIPTION = """
jobUpdates {
uid
typeId
name
description
status
statusText
progress
createdAt
updatedAt
finishedAt
error
result
}
"""
def api_subscribe(websocket, id, subscription):
websocket.send_json(
{
"id": id,
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription {" + subscription + "}",
},
}
)
def connect_ws_authenticated(authorized_client) -> WebSocketTestSession:
token = "Bearer " + str(DEVICE_WE_AUTH_TESTS_WITH["token"])
return authorized_client.websocket_connect(
"/graphql",
subprotocols=TESTED_SUBPROTOCOLS,
params={"token": token},
)
def connect_ws_not_authenticated(client) -> WebSocketTestSession:
return client.websocket_connect(
"/graphql",
subprotocols=TESTED_SUBPROTOCOLS,
params={"token": "I like vegan icecream but it is not a valid token"},
)
def init_graphql(websocket):
websocket.send_json({"type": "connection_init", "payload": {}})
ack = websocket.receive_json()
assert ack == {"type": "connection_ack"}
@pytest.fixture
def authenticated_websocket(
authorized_client,
) -> Generator[WebSocketTestSession, None, None]:
# We use authorized_client only to have token in the repo, this client by itself is not enough to authorize websocket
ValueError(TOKEN_REPO.get_tokens())
with connect_ws_authenticated(authorized_client) as websocket:
yield websocket
@pytest.fixture
def unauthenticated_websocket(client) -> Generator[WebSocketTestSession, None, None]:
with connect_ws_not_authenticated(client) as websocket:
yield websocket
def test_websocket_connection_bare(authorized_client):
client = authorized_client
with client.websocket_connect(
"/graphql", subprotocols=["graphql-transport-ws", "graphql-ws"]
) as websocket:
assert websocket is not None
assert websocket.scope is not None
def test_websocket_graphql_init(authorized_client):
client = authorized_client
with client.websocket_connect(
"/graphql", subprotocols=["graphql-transport-ws"]
) as websocket:
websocket.send_json({"type": "connection_init", "payload": {}})
ack = websocket.receive_json()
assert ack == {"type": "connection_ack"}
def test_websocket_graphql_ping(authorized_client):
client = authorized_client
with client.websocket_connect(
"/graphql", subprotocols=["graphql-transport-ws"]
) as websocket:
# https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping
websocket.send_json({"type": "ping", "payload": {}})
pong = websocket.receive_json()
assert pong == {"type": "pong"}
def test_websocket_subscription_minimal(authorized_client, authenticated_websocket):
# Test a small endpoint that exists specifically for tests
websocket = authenticated_websocket
init_graphql(websocket)
arbitrary_id = "3aaa2445"
api_subscribe(websocket, arbitrary_id, "count")
response = websocket.receive_json()
assert response == {
"id": arbitrary_id,
"payload": {"data": {"count": 0}},
"type": "next",
}
response = websocket.receive_json()
assert response == {
"id": arbitrary_id,
"payload": {"data": {"count": 1}},
"type": "next",
}
response = websocket.receive_json()
assert response == {
"id": arbitrary_id,
"payload": {"data": {"count": 2}},
"type": "next",
}
def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket):
websocket = unauthenticated_websocket
init_graphql(websocket)
arbitrary_id = "3aaa2445"
api_subscribe(websocket, arbitrary_id, "count")
response = websocket.receive_json()
assert response == {
"id": arbitrary_id,
"payload": [{"message": IsAuthenticated.message}],
"type": "error",
}
async def read_one_job(websocket):
# Bug? We only get them starting from the second job update
# That's why we receive two jobs in the list them
# The first update gets lost somewhere
response = websocket.receive_json()
return response
@pytest.mark.asyncio
async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs):
websocket = authenticated_websocket
init_graphql(websocket)
arbitrary_id = "3aaa2445"
api_subscribe(websocket, arbitrary_id, JOBS_SUBSCRIPTION)
future = asyncio.create_task(read_one_job(websocket))
jobs = []
jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works"))
sleep(0.5)
jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works"))
response = await future
data = response["payload"]["data"]
jobs_received = data["jobUpdates"]
received_names = [job["name"] for job in jobs_received]
for job in jobs:
assert job.name in received_names
assert len(jobs_received) == 2
for job in jobs:
for api_job in jobs_received:
if (job.name) == api_job["name"]:
assert api_job["uid"] == str(job.uid)
assert api_job["typeId"] == job.type_id
assert api_job["name"] == job.name
assert api_job["description"] == job.description
assert api_job["status"] == job.status
assert api_job["statusText"] == job.status_text
assert api_job["progress"] == job.progress
assert api_job["createdAt"] == job.created_at.isoformat()
assert api_job["updatedAt"] == job.updated_at.isoformat()
assert api_job["finishedAt"] == None
assert api_job["error"] == None
assert api_job["result"] == None
def test_websocket_subscription_unauthorized(unauthenticated_websocket):
websocket = unauthenticated_websocket
init_graphql(websocket)
id = "3aaa2445"
api_subscribe(websocket, id, JOBS_SUBSCRIPTION)
response = websocket.receive_json()
# I do not really know why strawberry gives more info on this
# One versus the counter
payload = response["payload"][0]
assert isinstance(payload, dict)
assert "locations" in payload.keys()
# It looks like this 'locations': [{'column': 32, 'line': 1}]
# We cannot test locations feasibly
del payload["locations"]
assert response == {
"id": id,
"payload": [{"message": IsAuthenticated.message, "path": ["jobUpdates"]}],
"type": "error",
}

218
tests/test_redis.py Normal file
View file

@ -0,0 +1,218 @@
import asyncio
import pytest
import pytest_asyncio
from asyncio import streams
import redis
from typing import List
from selfprivacy_api.utils.redis_pool import RedisPool
from selfprivacy_api.jobs import Jobs, job_notifications
TEST_KEY = "test:test"
STOPWORD = "STOP"
@pytest.fixture()
def empty_redis(event_loop):
r = RedisPool().get_connection()
r.flushdb()
assert r.config_get("notify-keyspace-events")["notify-keyspace-events"] == "AKE"
yield r
r.flushdb()
async def write_to_test_key():
r = RedisPool().get_connection_async()
async with r.pipeline(transaction=True) as pipe:
ok1, ok2 = await pipe.set(TEST_KEY, "value1").set(TEST_KEY, "value2").execute()
assert ok1
assert ok2
assert await r.get(TEST_KEY) == "value2"
await r.close()
def test_async_connection(empty_redis):
r = RedisPool().get_connection()
assert not r.exists(TEST_KEY)
# It _will_ report an error if it arises
asyncio.run(write_to_test_key())
# Confirming that we can read result from sync connection too
assert r.get(TEST_KEY) == "value2"
async def channel_reader(channel: redis.client.PubSub) -> List[dict]:
result: List[dict] = []
while True:
# Mypy cannot correctly detect that it is a coroutine
# But it is
message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore
if message is not None:
result.append(message)
if message["data"] == STOPWORD:
break
return result
async def channel_reader_onemessage(channel: redis.client.PubSub) -> dict:
while True:
# Mypy cannot correctly detect that it is a coroutine
# But it is
message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore
if message is not None:
return message
@pytest.mark.asyncio
async def test_pubsub(empty_redis, event_loop):
# Adapted from :
# https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html
# Sanity checking because of previous event loop bugs
assert event_loop == asyncio.get_event_loop()
assert event_loop == asyncio.events.get_event_loop()
assert event_loop == asyncio.events._get_event_loop()
assert event_loop == asyncio.events.get_running_loop()
reader = streams.StreamReader(34)
assert event_loop == reader._loop
f = reader._loop.create_future()
f.set_result(3)
await f
r = RedisPool().get_connection_async()
async with r.pubsub() as pubsub:
await pubsub.subscribe("channel:1")
future = asyncio.create_task(channel_reader(pubsub))
await r.publish("channel:1", "Hello")
# message: dict = await pubsub.get_message(ignore_subscribe_messages=True, timeout=5.0) # type: ignore
# raise ValueError(message)
await r.publish("channel:1", "World")
await r.publish("channel:1", STOPWORD)
messages = await future
assert len(messages) == 3
message = messages[0]
assert "data" in message.keys()
assert message["data"] == "Hello"
message = messages[1]
assert "data" in message.keys()
assert message["data"] == "World"
message = messages[2]
assert "data" in message.keys()
assert message["data"] == STOPWORD
await r.close()
@pytest.mark.asyncio
async def test_keyspace_notifications_simple(empty_redis, event_loop):
r = RedisPool().get_connection_async()
await r.set(TEST_KEY, "I am not empty")
async with r.pubsub() as pubsub:
await pubsub.subscribe("__keyspace@0__:" + TEST_KEY)
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
empty_redis.set(TEST_KEY, "I am set!")
message = await future_message
assert message is not None
assert message["data"] is not None
assert message == {
"channel": f"__keyspace@0__:{TEST_KEY}",
"data": "set",
"pattern": None,
"type": "message",
}
@pytest.mark.asyncio
async def test_keyspace_notifications(empty_redis, event_loop):
pubsub = await RedisPool().subscribe_to_keys(TEST_KEY)
async with pubsub:
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
empty_redis.set(TEST_KEY, "I am set!")
message = await future_message
assert message is not None
assert message["data"] is not None
assert message == {
"channel": f"__keyspace@0__:{TEST_KEY}",
"data": "set",
"pattern": f"__keyspace@0__:{TEST_KEY}",
"type": "pmessage",
}
@pytest.mark.asyncio
async def test_keyspace_notifications_patterns(empty_redis, event_loop):
pattern = "test*"
pubsub = await RedisPool().subscribe_to_keys(pattern)
async with pubsub:
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
empty_redis.set(TEST_KEY, "I am set!")
message = await future_message
assert message is not None
assert message["data"] is not None
assert message == {
"channel": f"__keyspace@0__:{TEST_KEY}",
"data": "set",
"pattern": f"__keyspace@0__:{pattern}",
"type": "pmessage",
}
@pytest.mark.asyncio
async def test_keyspace_notifications_jobs(empty_redis, event_loop):
pattern = "jobs:*"
pubsub = await RedisPool().subscribe_to_keys(pattern)
async with pubsub:
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
Jobs.add("testjob1", "test.test", "Testing aaaalll day")
message = await future_message
assert message is not None
assert message["data"] is not None
assert message["data"] == "hset"
async def reader_of_jobs() -> List[dict]:
"""
Reads 3 job updates and exits
"""
result: List[dict] = []
async for message in job_notifications():
result.append(message)
if len(result) >= 3:
break
return result
@pytest.mark.asyncio
async def test_jobs_generator(empty_redis, event_loop):
# Will read exactly 3 job messages
future_messages = asyncio.create_task(reader_of_jobs())
await asyncio.sleep(1)
Jobs.add("testjob1", "test.test", "Testing aaaalll day")
Jobs.add("testjob2", "test.test", "Testing aaaalll day")
Jobs.add("testjob3", "test.test", "Testing aaaalll day")
Jobs.add("testjob4", "test.test", "Testing aaaalll day")
assert len(Jobs.get_jobs()) == 4
r = RedisPool().get_connection()
assert len(r.keys("jobs:*")) == 4
messages = await future_messages
assert len(messages) == 3
channels = [message["channel"] for message in messages]
operations = [message["data"] for message in messages]
assert set(operations) == set(["hset"]) # all of them are hsets
# Asserting that all of jobs emitted exactly one message
jobs = Jobs.get_jobs()
names = ["testjob1", "testjob2", "testjob3"]
ids = [str(job.uid) for job in jobs if job.name in names]
for id in ids:
assert id in " ".join(channels)
# Asserting that they came in order
assert "testjob4" not in " ".join(channels)

View file

@ -0,0 +1,39 @@
import pytest
from fastapi import FastAPI, WebSocket
import uvicorn
# import subprocess
from multiprocessing import Process
import asyncio
from time import sleep
from websockets import client
app = FastAPI()
@app.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
while True:
data = await websocket.receive_text()
await websocket.send_text(f"You sent: {data}")
def run_uvicorn():
uvicorn.run(app, port=5000)
return True
@pytest.mark.asyncio
async def test_uvcorn_ws_works_in_prod():
proc = Process(target=run_uvicorn)
proc.start()
sleep(2)
ws = await client.connect("ws://127.0.0.1:5000")
await ws.send("hohoho")
message = await ws.read_message()
assert message == "You sent: hohoho"
await ws.close()
proc.kill()