mirror of
https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api.git
synced 2024-11-22 04:01:27 +00:00
Merge remote-tracking branch 'origin/master' into inex/service-settings
This commit is contained in:
commit
c8d00e6c87
|
@ -13,9 +13,9 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api),
|
|||
|
||||
For detailed installation information, please review and follow: [link](https://nixos.org/manual/nix/stable/installation/installing-binary.html#installing-a-binary-distribution).
|
||||
|
||||
3. **Change directory to the cloned repository and start a nix shell:**
|
||||
3. **Change directory to the cloned repository and start a nix development shell:**
|
||||
|
||||
```cd selfprivacy-rest-api && nix-shell```
|
||||
```cd selfprivacy-rest-api && nix develop```
|
||||
|
||||
Nix will install all of the necessary packages for development work, all further actions will take place only within nix-shell.
|
||||
|
||||
|
@ -31,7 +31,7 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api),
|
|||
|
||||
Copy the path that starts with ```/nix/store/``` and ends with ```env/bin/python```
|
||||
|
||||
```/nix/store/???-python3-3.9.??-env/bin/python```
|
||||
```/nix/store/???-python3-3.10.??-env/bin/python```
|
||||
|
||||
Click on the python version selection in the lower right corner, and replace the path to the interpreter in the project with the one you copied from the terminal.
|
||||
|
||||
|
@ -43,12 +43,13 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api),
|
|||
|
||||
## What to do after making changes to the repository?
|
||||
|
||||
**Run unit tests** using ```pytest .```
|
||||
Make sure that all tests pass successfully and the API works correctly. For convenience, you can use the built-in VScode interface.
|
||||
**Run unit tests** using ```pytest-vm``` inside of the development shell. This will run all the test inside a virtual machine, which is necessary for the tests to pass successfully.
|
||||
Make sure that all tests pass successfully and the API works correctly.
|
||||
|
||||
How to review the percentage of code coverage? Execute the command:
|
||||
The ```pytest-vm``` command will also print out the coverage of the tests. To export the report to an XML file, use the following command:
|
||||
|
||||
```coverage xml```
|
||||
|
||||
```coverage run -m pytest && coverage xml && coverage report```
|
||||
|
||||
Next, use the recommended extension ```ryanluker.vscode-coverage-gutters```, navigate to one of the test files, and click the "watch" button on the bottom panel of VScode.
|
||||
|
||||
|
|
|
@ -14,10 +14,12 @@ pythonPackages.buildPythonPackage rec {
|
|||
pydantic
|
||||
pytz
|
||||
redis
|
||||
systemd
|
||||
setuptools
|
||||
strawberry-graphql
|
||||
typing-extensions
|
||||
uvicorn
|
||||
websockets
|
||||
];
|
||||
pythonImportsCheck = [ "selfprivacy_api" ];
|
||||
doCheck = false;
|
||||
|
|
|
@ -2,11 +2,11 @@
|
|||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1709677081,
|
||||
"narHash": "sha256-tix36Y7u0rkn6mTm0lA45b45oab2cFLqAzDbJxeXS+c=",
|
||||
"lastModified": 1719957072,
|
||||
"narHash": "sha256-gvFhEf5nszouwLAkT9nWsDzocUTqLWHuL++dvNjMp9I=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "880992dcc006a5e00dd0591446fdf723e6a51a64",
|
||||
"rev": "7144d6241f02d171d25fba3edeaf15e0f2592105",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
pytest-datadir
|
||||
pytest-mock
|
||||
pytest-subprocess
|
||||
pytest-asyncio
|
||||
black
|
||||
mypy
|
||||
pylsp-mypy
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from strawberry.fastapi import GraphQLRouter
|
||||
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
@ -13,8 +14,12 @@ from selfprivacy_api.migrations import run_migrations
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
graphql_app = GraphQLRouter(
|
||||
graphql_app: GraphQLRouter = GraphQLRouter(
|
||||
schema,
|
||||
subscription_protocols=[
|
||||
GRAPHQL_TRANSPORT_WS_PROTOCOL,
|
||||
GRAPHQL_WS_PROTOCOL,
|
||||
],
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
|
|
|
@ -27,4 +27,4 @@ async def get_token_header(
|
|||
|
||||
def get_api_version() -> str:
|
||||
"""Get API version"""
|
||||
return "3.2.2+configs"
|
||||
return "3.3.0+configs"
|
||||
|
|
|
@ -16,6 +16,10 @@ class IsAuthenticated(BasePermission):
|
|||
token = info.context["request"].headers.get("Authorization")
|
||||
if token is None:
|
||||
token = info.context["request"].query_params.get("token")
|
||||
if token is None:
|
||||
connection_params = info.context.get("connection_params")
|
||||
if connection_params is not None:
|
||||
token = connection_params.get("Authorization")
|
||||
if token is None:
|
||||
return False
|
||||
return is_token_valid(token.replace("Bearer ", ""))
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
"""API access status"""
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
import datetime
|
||||
import typing
|
||||
import strawberry
|
||||
|
||||
from strawberry.types import Info
|
||||
from selfprivacy_api.actions.api_tokens import (
|
||||
get_api_tokens_with_caller_flag,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Backup"""
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
import typing
|
||||
import strawberry
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Common types and enums used by different types of queries."""
|
||||
|
||||
from enum import Enum
|
||||
import datetime
|
||||
import typing
|
||||
|
|
|
@ -1,24 +1,30 @@
|
|||
"""Jobs status"""
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
import typing
|
||||
import strawberry
|
||||
from typing import List, Optional
|
||||
|
||||
from selfprivacy_api.jobs import Jobs
|
||||
from selfprivacy_api.graphql.common_types.jobs import (
|
||||
ApiJob,
|
||||
get_api_job_by_id,
|
||||
job_to_api_job,
|
||||
)
|
||||
|
||||
from selfprivacy_api.jobs import Jobs
|
||||
|
||||
def get_all_jobs() -> List[ApiJob]:
|
||||
jobs = Jobs.get_jobs()
|
||||
api_jobs = [job_to_api_job(job) for job in jobs]
|
||||
assert api_jobs is not None
|
||||
return api_jobs
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class Job:
|
||||
@strawberry.field
|
||||
def get_jobs(self) -> typing.List[ApiJob]:
|
||||
Jobs.get_jobs()
|
||||
|
||||
return [job_to_api_job(job) for job in Jobs.get_jobs()]
|
||||
def get_jobs(self) -> List[ApiJob]:
|
||||
return get_all_jobs()
|
||||
|
||||
@strawberry.field
|
||||
def get_job(self, job_id: str) -> typing.Optional[ApiJob]:
|
||||
def get_job(self, job_id: str) -> Optional[ApiJob]:
|
||||
return get_api_job_by_id(job_id)
|
||||
|
|
88
selfprivacy_api/graphql/queries/logs.py
Normal file
88
selfprivacy_api/graphql/queries/logs.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
"""System logs"""
|
||||
from datetime import datetime
|
||||
import typing
|
||||
import strawberry
|
||||
from selfprivacy_api.utils.systemd_journal import get_paginated_logs
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class LogEntry:
|
||||
message: str = strawberry.field()
|
||||
timestamp: datetime = strawberry.field()
|
||||
priority: typing.Optional[int] = strawberry.field()
|
||||
systemd_unit: typing.Optional[str] = strawberry.field()
|
||||
systemd_slice: typing.Optional[str] = strawberry.field()
|
||||
|
||||
def __init__(self, journal_entry: typing.Dict):
|
||||
self.entry = journal_entry
|
||||
self.message = journal_entry["MESSAGE"]
|
||||
self.timestamp = journal_entry["__REALTIME_TIMESTAMP"]
|
||||
self.priority = journal_entry.get("PRIORITY")
|
||||
self.systemd_unit = journal_entry.get("_SYSTEMD_UNIT")
|
||||
self.systemd_slice = journal_entry.get("_SYSTEMD_SLICE")
|
||||
|
||||
@strawberry.field()
|
||||
def cursor(self) -> str:
|
||||
return self.entry["__CURSOR"]
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class LogsPageMeta:
|
||||
up_cursor: typing.Optional[str] = strawberry.field()
|
||||
down_cursor: typing.Optional[str] = strawberry.field()
|
||||
|
||||
def __init__(
|
||||
self, up_cursor: typing.Optional[str], down_cursor: typing.Optional[str]
|
||||
):
|
||||
self.up_cursor = up_cursor
|
||||
self.down_cursor = down_cursor
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class PaginatedEntries:
|
||||
page_meta: LogsPageMeta = strawberry.field(
|
||||
description="Metadata to aid in pagination."
|
||||
)
|
||||
entries: typing.List[LogEntry] = strawberry.field(
|
||||
description="The list of log entries."
|
||||
)
|
||||
|
||||
def __init__(self, meta: LogsPageMeta, entries: typing.List[LogEntry]):
|
||||
self.page_meta = meta
|
||||
self.entries = entries
|
||||
|
||||
@staticmethod
|
||||
def from_entries(entries: typing.List[LogEntry]):
|
||||
if entries == []:
|
||||
return PaginatedEntries(LogsPageMeta(None, None), [])
|
||||
|
||||
return PaginatedEntries(
|
||||
LogsPageMeta(
|
||||
entries[0].cursor(),
|
||||
entries[-1].cursor(),
|
||||
),
|
||||
entries,
|
||||
)
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class Logs:
|
||||
@strawberry.field()
|
||||
def paginated(
|
||||
self,
|
||||
limit: int = 20,
|
||||
# All entries returned will be lesser than this cursor. Sets upper bound on results.
|
||||
up_cursor: str | None = None,
|
||||
# All entries returned will be greater than this cursor. Sets lower bound on results.
|
||||
down_cursor: str | None = None,
|
||||
) -> PaginatedEntries:
|
||||
if limit > 50:
|
||||
raise Exception("You can't fetch more than 50 entries via single request.")
|
||||
return PaginatedEntries.from_entries(
|
||||
list(
|
||||
map(
|
||||
lambda x: LogEntry(x),
|
||||
get_paginated_logs(limit, up_cursor, down_cursor),
|
||||
)
|
||||
)
|
||||
)
|
|
@ -1,4 +1,5 @@
|
|||
"""Enums representing different service providers."""
|
||||
|
||||
from enum import Enum
|
||||
import strawberry
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Services status"""
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
import typing
|
||||
import strawberry
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Storage queries."""
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
import typing
|
||||
import strawberry
|
||||
|
@ -18,9 +19,11 @@ class Storage:
|
|||
"""Get list of volumes"""
|
||||
return [
|
||||
StorageVolume(
|
||||
total_space=str(volume.fssize)
|
||||
total_space=(
|
||||
str(volume.fssize)
|
||||
if volume.fssize is not None
|
||||
else str(volume.size),
|
||||
else str(volume.size)
|
||||
),
|
||||
free_space=str(volume.fsavail),
|
||||
used_space=str(volume.fsused),
|
||||
root=volume.is_root(),
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
"""Common system information and settings"""
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
import os
|
||||
import typing
|
||||
import strawberry
|
||||
|
||||
from selfprivacy_api.graphql.common_types.dns import DnsRecord
|
||||
|
||||
from selfprivacy_api.graphql.queries.common import Alert, Severity
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Users"""
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
import typing
|
||||
import strawberry
|
||||
|
|
|
@ -2,8 +2,10 @@
|
|||
# pylint: disable=too-few-public-methods
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, List
|
||||
import strawberry
|
||||
from strawberry.types import Info
|
||||
|
||||
from selfprivacy_api.graphql import IsAuthenticated
|
||||
from selfprivacy_api.graphql.mutations.deprecated_mutations import (
|
||||
DeprecatedApiMutations,
|
||||
|
@ -24,10 +26,17 @@ from selfprivacy_api.graphql.mutations.backup_mutations import BackupMutations
|
|||
from selfprivacy_api.graphql.queries.api_queries import Api
|
||||
from selfprivacy_api.graphql.queries.backup import Backup
|
||||
from selfprivacy_api.graphql.queries.jobs import Job
|
||||
from selfprivacy_api.graphql.queries.logs import LogEntry, Logs
|
||||
from selfprivacy_api.graphql.queries.services import Services
|
||||
from selfprivacy_api.graphql.queries.storage import Storage
|
||||
from selfprivacy_api.graphql.queries.system import System
|
||||
|
||||
from selfprivacy_api.graphql.subscriptions.jobs import ApiJob
|
||||
from selfprivacy_api.graphql.subscriptions.jobs import (
|
||||
job_updates as job_update_generator,
|
||||
)
|
||||
from selfprivacy_api.graphql.subscriptions.logs import log_stream
|
||||
|
||||
from selfprivacy_api.graphql.common_types.service import (
|
||||
StringConfigItem,
|
||||
BoolConfigItem,
|
||||
|
@ -53,6 +62,11 @@ class Query:
|
|||
"""System queries"""
|
||||
return System()
|
||||
|
||||
@strawberry.field(permission_classes=[IsAuthenticated])
|
||||
def logs(self) -> Logs:
|
||||
"""Log queries"""
|
||||
return Logs()
|
||||
|
||||
@strawberry.field(permission_classes=[IsAuthenticated])
|
||||
def users(self) -> Users:
|
||||
"""Users queries"""
|
||||
|
@ -135,19 +149,42 @@ class Mutation(
|
|||
code=200,
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
# A cruft for Websockets
|
||||
def authenticated(info: Info) -> bool:
|
||||
return IsAuthenticated().has_permission(source=None, info=info)
|
||||
|
||||
|
||||
def reject_if_unauthenticated(info: Info):
|
||||
if not authenticated(info):
|
||||
raise Exception(IsAuthenticated().message)
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class Subscription:
|
||||
"""Root schema for subscriptions"""
|
||||
"""Root schema for subscriptions.
|
||||
Every field here should be an AsyncIterator or AsyncGenerator
|
||||
It is not a part of the spec but graphql-core (dep of strawberryql)
|
||||
demands it while the spec is vague in this area."""
|
||||
|
||||
@strawberry.subscription(permission_classes=[IsAuthenticated])
|
||||
async def count(self, target: int = 100) -> AsyncGenerator[int, None]:
|
||||
for i in range(target):
|
||||
@strawberry.subscription
|
||||
async def job_updates(self, info: Info) -> AsyncGenerator[List[ApiJob], None]:
|
||||
reject_if_unauthenticated(info)
|
||||
return job_update_generator()
|
||||
|
||||
@strawberry.subscription
|
||||
# Used for testing, consider deletion to shrink attack surface
|
||||
async def count(self, info: Info) -> AsyncGenerator[int, None]:
|
||||
reject_if_unauthenticated(info)
|
||||
for i in range(10):
|
||||
yield i
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
@strawberry.subscription
|
||||
async def log_entries(self, info: Info) -> AsyncGenerator[LogEntry, None]:
|
||||
reject_if_unauthenticated(info)
|
||||
return log_stream()
|
||||
|
||||
|
||||
schema = strawberry.Schema(
|
||||
query=Query,
|
||||
|
|
0
selfprivacy_api/graphql/subscriptions/__init__.py
Normal file
0
selfprivacy_api/graphql/subscriptions/__init__.py
Normal file
14
selfprivacy_api/graphql/subscriptions/jobs.py
Normal file
14
selfprivacy_api/graphql/subscriptions/jobs.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# pylint: disable=too-few-public-methods
|
||||
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from selfprivacy_api.jobs import job_notifications
|
||||
|
||||
from selfprivacy_api.graphql.common_types.jobs import ApiJob
|
||||
from selfprivacy_api.graphql.queries.jobs import get_all_jobs
|
||||
|
||||
|
||||
async def job_updates() -> AsyncGenerator[List[ApiJob], None]:
|
||||
# Send the complete list of jobs every time anything gets updated
|
||||
async for notification in job_notifications():
|
||||
yield get_all_jobs()
|
31
selfprivacy_api/graphql/subscriptions/logs.py
Normal file
31
selfprivacy_api/graphql/subscriptions/logs.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
from typing import AsyncGenerator
|
||||
from systemd import journal
|
||||
import asyncio
|
||||
|
||||
from selfprivacy_api.graphql.queries.logs import LogEntry
|
||||
|
||||
|
||||
async def log_stream() -> AsyncGenerator[LogEntry, None]:
|
||||
j = journal.Reader()
|
||||
|
||||
j.seek_tail()
|
||||
j.get_previous()
|
||||
|
||||
queue = asyncio.Queue()
|
||||
|
||||
async def callback():
|
||||
if j.process() != journal.APPEND:
|
||||
return
|
||||
for entry in j:
|
||||
await queue.put(entry)
|
||||
|
||||
asyncio.get_event_loop().add_reader(j, lambda: asyncio.ensure_future(callback()))
|
||||
|
||||
while True:
|
||||
entry = await queue.get()
|
||||
try:
|
||||
yield LogEntry(entry)
|
||||
except Exception:
|
||||
asyncio.get_event_loop().remove_reader(j)
|
||||
return
|
||||
queue.task_done()
|
|
@ -15,6 +15,7 @@ A job is a dictionary with the following keys:
|
|||
- result: result of the job
|
||||
"""
|
||||
import typing
|
||||
import asyncio
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
import uuid
|
||||
|
@ -23,6 +24,7 @@ from enum import Enum
|
|||
from pydantic import BaseModel
|
||||
|
||||
from selfprivacy_api.utils.redis_pool import RedisPool
|
||||
from selfprivacy_api.utils.redis_model_storage import store_model_as_hash
|
||||
|
||||
JOB_EXPIRATION_SECONDS = 10 * 24 * 60 * 60 # ten days
|
||||
|
||||
|
@ -102,7 +104,7 @@ class Jobs:
|
|||
result=None,
|
||||
)
|
||||
redis = RedisPool().get_connection()
|
||||
_store_job_as_hash(redis, _redis_key_from_uuid(job.uid), job)
|
||||
store_model_as_hash(redis, _redis_key_from_uuid(job.uid), job)
|
||||
return job
|
||||
|
||||
@staticmethod
|
||||
|
@ -218,7 +220,7 @@ class Jobs:
|
|||
redis = RedisPool().get_connection()
|
||||
key = _redis_key_from_uuid(job.uid)
|
||||
if redis.exists(key):
|
||||
_store_job_as_hash(redis, key, job)
|
||||
store_model_as_hash(redis, key, job)
|
||||
if status in (JobStatus.FINISHED, JobStatus.ERROR):
|
||||
redis.expire(key, JOB_EXPIRATION_SECONDS)
|
||||
|
||||
|
@ -294,17 +296,6 @@ def _progress_log_key_from_uuid(uuid_string) -> str:
|
|||
return PROGRESS_LOGS_PREFIX + str(uuid_string)
|
||||
|
||||
|
||||
def _store_job_as_hash(redis, redis_key, model) -> None:
|
||||
for key, value in model.dict().items():
|
||||
if isinstance(value, uuid.UUID):
|
||||
value = str(value)
|
||||
if isinstance(value, datetime.datetime):
|
||||
value = value.isoformat()
|
||||
if isinstance(value, JobStatus):
|
||||
value = value.value
|
||||
redis.hset(redis_key, key, str(value))
|
||||
|
||||
|
||||
def _job_from_hash(redis, redis_key) -> typing.Optional[Job]:
|
||||
if redis.exists(redis_key):
|
||||
job_dict = redis.hgetall(redis_key)
|
||||
|
@ -321,3 +312,15 @@ def _job_from_hash(redis, redis_key) -> typing.Optional[Job]:
|
|||
|
||||
return Job(**job_dict)
|
||||
return None
|
||||
|
||||
|
||||
async def job_notifications() -> typing.AsyncGenerator[dict, None]:
|
||||
channel = await RedisPool().subscribe_to_keys("jobs:*")
|
||||
while True:
|
||||
try:
|
||||
# we cannot timeout here because we do not know when the next message is supposed to arrive
|
||||
message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore
|
||||
if message is not None:
|
||||
yield message
|
||||
except GeneratorExit:
|
||||
break
|
||||
|
|
|
@ -14,10 +14,12 @@ from selfprivacy_api.migrations.write_token_to_redis import WriteTokenToRedis
|
|||
from selfprivacy_api.migrations.check_for_system_rebuild_jobs import (
|
||||
CheckForSystemRebuildJobs,
|
||||
)
|
||||
from selfprivacy_api.migrations.add_roundcube import AddRoundcube
|
||||
|
||||
migrations = [
|
||||
WriteTokenToRedis(),
|
||||
CheckForSystemRebuildJobs(),
|
||||
AddRoundcube(),
|
||||
]
|
||||
|
||||
|
||||
|
|
36
selfprivacy_api/migrations/add_roundcube.py
Normal file
36
selfprivacy_api/migrations/add_roundcube.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
from selfprivacy_api.migrations.migration import Migration
|
||||
|
||||
from selfprivacy_api.services.flake_service_manager import FlakeServiceManager
|
||||
from selfprivacy_api.utils import ReadUserData, WriteUserData
|
||||
|
||||
|
||||
class AddRoundcube(Migration):
|
||||
"""Adds the Roundcube if it is not present."""
|
||||
|
||||
def get_migration_name(self) -> str:
|
||||
return "add_roundcube"
|
||||
|
||||
def get_migration_description(self) -> str:
|
||||
return "Adds the Roundcube if it is not present."
|
||||
|
||||
def is_migration_needed(self) -> bool:
|
||||
with FlakeServiceManager() as manager:
|
||||
if "roundcube" not in manager.services:
|
||||
return True
|
||||
with ReadUserData() as data:
|
||||
if "roundcube" not in data["modules"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def migrate(self) -> None:
|
||||
with FlakeServiceManager() as manager:
|
||||
if "roundcube" not in manager.services:
|
||||
manager.services[
|
||||
"roundcube"
|
||||
] = "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/roundcube"
|
||||
with WriteUserData() as data:
|
||||
if "roundcube" not in data["modules"]:
|
||||
data["modules"]["roundcube"] = {
|
||||
"enable": False,
|
||||
"subdomain": "roundcube",
|
||||
}
|
|
@ -5,13 +5,13 @@ from selfprivacy_api.jobs import JobStatus, Jobs
|
|||
class CheckForSystemRebuildJobs(Migration):
|
||||
"""Check if there are unfinished system rebuild jobs and finish them"""
|
||||
|
||||
def get_migration_name(self):
|
||||
def get_migration_name(self) -> str:
|
||||
return "check_for_system_rebuild_jobs"
|
||||
|
||||
def get_migration_description(self):
|
||||
def get_migration_description(self) -> str:
|
||||
return "Check if there are unfinished system rebuild jobs and finish them"
|
||||
|
||||
def is_migration_needed(self):
|
||||
def is_migration_needed(self) -> bool:
|
||||
# Check if there are any unfinished system rebuild jobs
|
||||
for job in Jobs.get_jobs():
|
||||
if (
|
||||
|
@ -25,8 +25,9 @@ class CheckForSystemRebuildJobs(Migration):
|
|||
JobStatus.RUNNING,
|
||||
]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def migrate(self):
|
||||
def migrate(self) -> None:
|
||||
# As the API is restarted, we assume that the jobs are finished
|
||||
for job in Jobs.get_jobs():
|
||||
if (
|
||||
|
|
|
@ -12,17 +12,17 @@ class Migration(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_migration_name(self):
|
||||
def get_migration_name(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_migration_description(self):
|
||||
def get_migration_description(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_migration_needed(self):
|
||||
def is_migration_needed(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def migrate(self):
|
||||
def migrate(self) -> None:
|
||||
pass
|
||||
|
|
|
@ -15,10 +15,10 @@ from selfprivacy_api.utils import ReadUserData, UserDataFiles
|
|||
class WriteTokenToRedis(Migration):
|
||||
"""Load Json tokens into Redis"""
|
||||
|
||||
def get_migration_name(self):
|
||||
def get_migration_name(self) -> str:
|
||||
return "write_token_to_redis"
|
||||
|
||||
def get_migration_description(self):
|
||||
def get_migration_description(self) -> str:
|
||||
return "Loads the initial token into redis token storage"
|
||||
|
||||
def is_repo_empty(self, repo: AbstractTokensRepository) -> bool:
|
||||
|
@ -38,7 +38,7 @@ class WriteTokenToRedis(Migration):
|
|||
print(e)
|
||||
return None
|
||||
|
||||
def is_migration_needed(self):
|
||||
def is_migration_needed(self) -> bool:
|
||||
try:
|
||||
if self.get_token_from_json() is not None and self.is_repo_empty(
|
||||
RedisTokensRepository()
|
||||
|
@ -47,8 +47,9 @@ class WriteTokenToRedis(Migration):
|
|||
except Exception as e:
|
||||
print(e)
|
||||
return False
|
||||
return False
|
||||
|
||||
def migrate(self):
|
||||
def migrate(self) -> None:
|
||||
# Write info about providers to userdata.json
|
||||
try:
|
||||
token = self.get_token_from_json()
|
||||
|
|
|
@ -4,6 +4,7 @@ import typing
|
|||
from selfprivacy_api.services.bitwarden import Bitwarden
|
||||
from selfprivacy_api.services.forgejo import Forgejo
|
||||
from selfprivacy_api.services.jitsimeet import JitsiMeet
|
||||
from selfprivacy_api.services.roundcube import Roundcube
|
||||
from selfprivacy_api.services.mailserver import MailServer
|
||||
from selfprivacy_api.services.nextcloud import Nextcloud
|
||||
from selfprivacy_api.services.pleroma import Pleroma
|
||||
|
@ -19,6 +20,7 @@ services: list[Service] = [
|
|||
Pleroma(),
|
||||
Ocserv(),
|
||||
JitsiMeet(),
|
||||
Roundcube(),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -37,14 +37,14 @@ class Bitwarden(Service):
|
|||
def get_user() -> str:
|
||||
return "vaultwarden"
|
||||
|
||||
@staticmethod
|
||||
def get_url() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_url(cls) -> Optional[str]:
|
||||
"""Return service url."""
|
||||
domain = get_domain()
|
||||
return f"https://password.{domain}"
|
||||
|
||||
@staticmethod
|
||||
def get_subdomain() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_subdomain(cls) -> Optional[str]:
|
||||
return "password"
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -90,14 +90,14 @@ class Forgejo(Service):
|
|||
"""Read SVG icon from file and return it as base64 encoded string."""
|
||||
return base64.b64encode(FORGEJO_ICON.encode("utf-8")).decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def get_url() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_url(cls) -> Optional[str]:
|
||||
"""Return service url."""
|
||||
domain = get_domain()
|
||||
return f"https://git.{domain}"
|
||||
|
||||
@staticmethod
|
||||
def get_subdomain() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_subdomain(cls) -> Optional[str]:
|
||||
return "git"
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -36,14 +36,14 @@ class JitsiMeet(Service):
|
|||
"""Read SVG icon from file and return it as base64 encoded string."""
|
||||
return base64.b64encode(JITSI_ICON.encode("utf-8")).decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def get_url() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_url(cls) -> Optional[str]:
|
||||
"""Return service url."""
|
||||
domain = get_domain()
|
||||
return f"https://meet.{domain}"
|
||||
|
||||
@staticmethod
|
||||
def get_subdomain() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_subdomain(cls) -> Optional[str]:
|
||||
return "meet"
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -35,13 +35,13 @@ class MailServer(Service):
|
|||
def get_user() -> str:
|
||||
return "virtualMail"
|
||||
|
||||
@staticmethod
|
||||
def get_url() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_url(cls) -> Optional[str]:
|
||||
"""Return service url."""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_subdomain() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_subdomain(cls) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -4,7 +4,6 @@ import subprocess
|
|||
from typing import Optional, List
|
||||
|
||||
from selfprivacy_api.utils import get_domain
|
||||
from selfprivacy_api.jobs import Job, Jobs
|
||||
|
||||
from selfprivacy_api.utils.systemd import get_service_status
|
||||
from selfprivacy_api.services.service import Service, ServiceStatus
|
||||
|
@ -35,14 +34,14 @@ class Nextcloud(Service):
|
|||
"""Read SVG icon from file and return it as base64 encoded string."""
|
||||
return base64.b64encode(NEXTCLOUD_ICON.encode("utf-8")).decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def get_url() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_url(cls) -> Optional[str]:
|
||||
"""Return service url."""
|
||||
domain = get_domain()
|
||||
return f"https://cloud.{domain}"
|
||||
|
||||
@staticmethod
|
||||
def get_subdomain() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_subdomain(cls) -> Optional[str]:
|
||||
return "cloud"
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -28,13 +28,13 @@ class Ocserv(Service):
|
|||
def get_svg_icon() -> str:
|
||||
return base64.b64encode(OCSERV_ICON.encode("utf-8")).decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def get_url() -> typing.Optional[str]:
|
||||
@classmethod
|
||||
def get_url(cls) -> typing.Optional[str]:
|
||||
"""Return service url."""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_subdomain() -> typing.Optional[str]:
|
||||
@classmethod
|
||||
def get_subdomain(cls) -> typing.Optional[str]:
|
||||
return "vpn"
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -31,14 +31,14 @@ class Pleroma(Service):
|
|||
def get_svg_icon() -> str:
|
||||
return base64.b64encode(PLEROMA_ICON.encode("utf-8")).decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def get_url() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_url(cls) -> Optional[str]:
|
||||
"""Return service url."""
|
||||
domain = get_domain()
|
||||
return f"https://social.{domain}"
|
||||
|
||||
@staticmethod
|
||||
def get_subdomain() -> Optional[str]:
|
||||
@classmethod
|
||||
def get_subdomain(cls) -> Optional[str]:
|
||||
return "social"
|
||||
|
||||
@staticmethod
|
||||
|
|
113
selfprivacy_api/services/roundcube/__init__.py
Normal file
113
selfprivacy_api/services/roundcube/__init__.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
"""Class representing Roundcube service"""
|
||||
|
||||
import base64
|
||||
import subprocess
|
||||
from typing import List, Optional
|
||||
|
||||
from selfprivacy_api.jobs import Job
|
||||
from selfprivacy_api.utils.systemd import (
|
||||
get_service_status_from_several_units,
|
||||
)
|
||||
from selfprivacy_api.services.service import Service, ServiceStatus
|
||||
from selfprivacy_api.utils import ReadUserData, get_domain
|
||||
from selfprivacy_api.utils.block_devices import BlockDevice
|
||||
from selfprivacy_api.services.roundcube.icon import ROUNDCUBE_ICON
|
||||
|
||||
|
||||
class Roundcube(Service):
|
||||
"""Class representing roundcube service"""
|
||||
|
||||
@staticmethod
|
||||
def get_id() -> str:
|
||||
"""Return service id."""
|
||||
return "roundcube"
|
||||
|
||||
@staticmethod
|
||||
def get_display_name() -> str:
|
||||
"""Return service display name."""
|
||||
return "Roundcube"
|
||||
|
||||
@staticmethod
|
||||
def get_description() -> str:
|
||||
"""Return service description."""
|
||||
return "Roundcube is an open source webmail software."
|
||||
|
||||
@staticmethod
|
||||
def get_svg_icon() -> str:
|
||||
"""Read SVG icon from file and return it as base64 encoded string."""
|
||||
return base64.b64encode(ROUNDCUBE_ICON.encode("utf-8")).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def get_url(cls) -> Optional[str]:
|
||||
"""Return service url."""
|
||||
domain = get_domain()
|
||||
subdomain = cls.get_subdomain()
|
||||
return f"https://{subdomain}.{domain}"
|
||||
|
||||
@classmethod
|
||||
def get_subdomain(cls) -> Optional[str]:
|
||||
with ReadUserData() as data:
|
||||
if "roundcube" in data["modules"]:
|
||||
return data["modules"]["roundcube"]["subdomain"]
|
||||
|
||||
return "roundcube"
|
||||
|
||||
@staticmethod
|
||||
def is_movable() -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_required() -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def can_be_backed_up() -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_backup_description() -> str:
|
||||
return "Nothing to backup."
|
||||
|
||||
@staticmethod
|
||||
def get_status() -> ServiceStatus:
|
||||
return get_service_status_from_several_units(["phpfpm-roundcube.service"])
|
||||
|
||||
@staticmethod
|
||||
def stop():
|
||||
subprocess.run(
|
||||
["systemctl", "stop", "phpfpm-roundcube.service"],
|
||||
check=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def start():
|
||||
subprocess.run(
|
||||
["systemctl", "start", "phpfpm-roundcube.service"],
|
||||
check=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def restart():
|
||||
subprocess.run(
|
||||
["systemctl", "restart", "phpfpm-roundcube.service"],
|
||||
check=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_configuration():
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def set_configuration(config_items):
|
||||
return super().set_configuration(config_items)
|
||||
|
||||
@staticmethod
|
||||
def get_logs():
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def get_folders() -> List[str]:
|
||||
return []
|
||||
|
||||
def move_to_volume(self, volume: BlockDevice) -> Job:
|
||||
raise NotImplementedError("roundcube service is not movable")
|
7
selfprivacy_api/services/roundcube/icon.py
Normal file
7
selfprivacy_api/services/roundcube/icon.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
ROUNDCUBE_ICON = """
|
||||
<svg fill="none" version="1.1" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
|
||||
<g transform="translate(29.07 -.3244)">
|
||||
<path d="m-17.02 2.705c-4.01 2e-7 -7.283 3.273-7.283 7.283 0 0.00524-1.1e-5 0.01038 0 0.01562l-1.85 1.068v5.613l9.105 5.26 9.104-5.26v-5.613l-1.797-1.037c1.008e-4 -0.01573 0.00195-0.03112 0.00195-0.04688-1e-7 -4.01-3.271-7.283-7.281-7.283zm0 2.012c2.923 1e-7 5.27 2.349 5.27 5.271 0 2.923-2.347 5.27-5.27 5.27-2.923-1e-6 -5.271-2.347-5.271-5.27 0-2.923 2.349-5.271 5.271-5.271z" fill="#000" fill-rule="evenodd" stroke-linejoin="bevel"/>
|
||||
</g>
|
||||
</svg>
|
||||
"""
|
|
@ -65,17 +65,17 @@ class Service(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_url() -> Optional[str]:
|
||||
def get_url(cls) -> Optional[str]:
|
||||
"""
|
||||
The url of the service if it is accessible from the internet browser.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_subdomain() -> Optional[str]:
|
||||
def get_subdomain(cls) -> Optional[str]:
|
||||
"""
|
||||
The assigned primary subdomain for this service.
|
||||
"""
|
||||
|
|
|
@ -57,14 +57,14 @@ class DummyService(Service):
|
|||
# return ""
|
||||
return base64.b64encode(BITWARDEN_ICON.encode("utf-8")).decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def get_url() -> typing.Optional[str]:
|
||||
@classmethod
|
||||
def get_url(cls) -> typing.Optional[str]:
|
||||
"""Return service url."""
|
||||
domain = "test.com"
|
||||
return f"https://password.{domain}"
|
||||
|
||||
@staticmethod
|
||||
def get_subdomain() -> typing.Optional[str]:
|
||||
@classmethod
|
||||
def get_subdomain(cls) -> typing.Optional[str]:
|
||||
return "password"
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -1,15 +1,23 @@
|
|||
import uuid
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
def store_model_as_hash(redis, redis_key, model):
|
||||
for key, value in model.dict().items():
|
||||
model_dict = model.dict()
|
||||
for key, value in model_dict.items():
|
||||
if isinstance(value, uuid.UUID):
|
||||
value = str(value)
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
if isinstance(value, Enum):
|
||||
value = value.value
|
||||
redis.hset(redis_key, key, str(value))
|
||||
value = str(value)
|
||||
model_dict[key] = value
|
||||
|
||||
redis.hset(redis_key, mapping=model_dict)
|
||||
|
||||
|
||||
def hash_as_model(redis, redis_key: str, model_class):
|
||||
|
|
|
@ -2,23 +2,33 @@
|
|||
Redis pool module for selfprivacy_api
|
||||
"""
|
||||
import redis
|
||||
import redis.asyncio as redis_async
|
||||
|
||||
from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass
|
||||
|
||||
REDIS_SOCKET = "/run/redis-sp-api/redis.sock"
|
||||
|
||||
|
||||
class RedisPool(metaclass=SingletonMetaclass):
|
||||
# class RedisPool(metaclass=SingletonMetaclass):
|
||||
class RedisPool:
|
||||
"""
|
||||
Redis connection pool singleton.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._dbnumber = 0
|
||||
url = RedisPool.connection_url(dbnumber=self._dbnumber)
|
||||
# We need a normal sync pool because otherwise
|
||||
# our whole API will need to be async
|
||||
self._pool = redis.ConnectionPool.from_url(
|
||||
RedisPool.connection_url(dbnumber=0),
|
||||
url,
|
||||
decode_responses=True,
|
||||
)
|
||||
# We need an async pool for pubsub
|
||||
self._async_pool = redis_async.ConnectionPool.from_url(
|
||||
url,
|
||||
decode_responses=True,
|
||||
)
|
||||
self._pubsub_connection = self.get_connection()
|
||||
|
||||
@staticmethod
|
||||
def connection_url(dbnumber: int) -> str:
|
||||
|
@ -34,8 +44,15 @@ class RedisPool(metaclass=SingletonMetaclass):
|
|||
"""
|
||||
return redis.Redis(connection_pool=self._pool)
|
||||
|
||||
def get_pubsub(self):
|
||||
def get_connection_async(self) -> redis_async.Redis:
|
||||
"""
|
||||
Get a pubsub connection from the pool.
|
||||
Get an async connection from the pool.
|
||||
Async connections allow pubsub.
|
||||
"""
|
||||
return self._pubsub_connection.pubsub()
|
||||
return redis_async.Redis(connection_pool=self._async_pool)
|
||||
|
||||
async def subscribe_to_keys(self, pattern: str) -> redis_async.client.PubSub:
|
||||
async_redis = self.get_connection_async()
|
||||
pubsub = async_redis.pubsub()
|
||||
await pubsub.psubscribe(f"__keyspace@{self._dbnumber}__:" + pattern)
|
||||
return pubsub
|
||||
|
|
55
selfprivacy_api/utils/systemd_journal.py
Normal file
55
selfprivacy_api/utils/systemd_journal.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
import typing
|
||||
from systemd import journal
|
||||
|
||||
|
||||
def get_events_from_journal(
|
||||
j: journal.Reader, limit: int, next: typing.Callable[[journal.Reader], typing.Dict]
|
||||
):
|
||||
events = []
|
||||
i = 0
|
||||
while i < limit:
|
||||
entry = next(j)
|
||||
if entry is None or entry == dict():
|
||||
break
|
||||
if entry["MESSAGE"] != "":
|
||||
events.append(entry)
|
||||
i += 1
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def get_paginated_logs(
|
||||
limit: int = 20,
|
||||
# All entries returned will be lesser than this cursor. Sets upper bound on results.
|
||||
up_cursor: str | None = None,
|
||||
# All entries returned will be greater than this cursor. Sets lower bound on results.
|
||||
down_cursor: str | None = None,
|
||||
):
|
||||
j = journal.Reader()
|
||||
|
||||
if up_cursor is None and down_cursor is None:
|
||||
j.seek_tail()
|
||||
|
||||
events = get_events_from_journal(j, limit, lambda j: j.get_previous())
|
||||
events.reverse()
|
||||
|
||||
return events
|
||||
elif up_cursor is None and down_cursor is not None:
|
||||
j.seek_cursor(down_cursor)
|
||||
j.get_previous() # pagination is exclusive
|
||||
|
||||
events = get_events_from_journal(j, limit, lambda j: j.get_previous())
|
||||
events.reverse()
|
||||
|
||||
return events
|
||||
elif up_cursor is not None and down_cursor is None:
|
||||
j.seek_cursor(up_cursor)
|
||||
j.get_next() # pagination is exclusive
|
||||
|
||||
events = get_events_from_journal(j, limit, lambda j: j.get_next())
|
||||
|
||||
return events
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Pagination by both up_cursor and down_cursor is not implemented"
|
||||
)
|
2
setup.py
2
setup.py
|
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|||
|
||||
setup(
|
||||
name="selfprivacy_api",
|
||||
version="3.2.2",
|
||||
version="3.3.0",
|
||||
packages=find_packages(),
|
||||
scripts=[
|
||||
"selfprivacy_api/app.py",
|
||||
|
|
|
@ -69,10 +69,22 @@ def generate_backup_query(query_array):
|
|||
return "query TestBackup {\n backup {" + "\n".join(query_array) + "}\n}"
|
||||
|
||||
|
||||
def generate_jobs_query(query_array):
|
||||
return "query TestJobs {\n jobs {" + "\n".join(query_array) + "}\n}"
|
||||
|
||||
|
||||
def generate_jobs_subscription(query_array):
|
||||
return "subscription TestSubscription {\n jobs {" + "\n".join(query_array) + "}\n}"
|
||||
|
||||
|
||||
def generate_service_query(query_array):
|
||||
return "query TestService {\n services {" + "\n".join(query_array) + "}\n}"
|
||||
|
||||
|
||||
def generate_logs_query(query_array):
|
||||
return "query TestService {\n logs {" + "\n".join(query_array) + "}\n}"
|
||||
|
||||
|
||||
def mnemonic_to_hex(mnemonic):
|
||||
return Mnemonic(language="english").to_entropy(mnemonic).hex()
|
||||
|
||||
|
|
|
@ -62,6 +62,9 @@
|
|||
"simple-nixos-mailserver": {
|
||||
"enable": true,
|
||||
"location": "sdb"
|
||||
},
|
||||
"roundcube": {
|
||||
"enable": true
|
||||
}
|
||||
},
|
||||
"volumes": [
|
||||
|
|
172
tests/test_graphql/test_api_logs.py
Normal file
172
tests/test_graphql/test_api_logs.py
Normal file
|
@ -0,0 +1,172 @@
|
|||
import asyncio
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from systemd import journal
|
||||
|
||||
from tests.test_graphql.test_websocket import init_graphql
|
||||
|
||||
|
||||
def assert_log_entry_equals_to_journal_entry(api_entry, journal_entry):
|
||||
assert api_entry["message"] == journal_entry["MESSAGE"]
|
||||
assert (
|
||||
datetime.fromisoformat(api_entry["timestamp"])
|
||||
== journal_entry["__REALTIME_TIMESTAMP"]
|
||||
)
|
||||
assert api_entry.get("priority") == journal_entry.get("PRIORITY")
|
||||
assert api_entry.get("systemdUnit") == journal_entry.get("_SYSTEMD_UNIT")
|
||||
assert api_entry.get("systemdSlice") == journal_entry.get("_SYSTEMD_SLICE")
|
||||
|
||||
|
||||
def take_from_journal(j, limit, next):
|
||||
entries = []
|
||||
for _ in range(0, limit):
|
||||
entry = next(j)
|
||||
if entry["MESSAGE"] != "":
|
||||
entries.append(entry)
|
||||
return entries
|
||||
|
||||
|
||||
API_GET_LOGS_WITH_UP_BORDER = """
|
||||
query TestQuery($upCursor: String) {
|
||||
logs {
|
||||
paginated(limit: 4, upCursor: $upCursor) {
|
||||
pageMeta {
|
||||
upCursor
|
||||
downCursor
|
||||
}
|
||||
entries {
|
||||
message
|
||||
timestamp
|
||||
priority
|
||||
systemdUnit
|
||||
systemdSlice
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
API_GET_LOGS_WITH_DOWN_BORDER = """
|
||||
query TestQuery($downCursor: String) {
|
||||
logs {
|
||||
paginated(limit: 4, downCursor: $downCursor) {
|
||||
pageMeta {
|
||||
upCursor
|
||||
downCursor
|
||||
}
|
||||
entries {
|
||||
message
|
||||
timestamp
|
||||
priority
|
||||
systemdUnit
|
||||
systemdSlice
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def test_graphql_get_logs_with_up_border(authorized_client):
|
||||
j = journal.Reader()
|
||||
j.seek_tail()
|
||||
|
||||
# < - cursor
|
||||
# <- - log entry will be returned by API call.
|
||||
# ...
|
||||
# log <
|
||||
# log <-
|
||||
# log <-
|
||||
# log <-
|
||||
# log <-
|
||||
# log
|
||||
|
||||
expected_entries = take_from_journal(j, 6, lambda j: j.get_previous())
|
||||
expected_entries.reverse()
|
||||
|
||||
response = authorized_client.post(
|
||||
"/graphql",
|
||||
json={
|
||||
"query": API_GET_LOGS_WITH_UP_BORDER,
|
||||
"variables": {"upCursor": expected_entries[0]["__CURSOR"]},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
expected_entries = expected_entries[1:-1]
|
||||
returned_entries = response.json()["data"]["logs"]["paginated"]["entries"]
|
||||
|
||||
assert len(returned_entries) == len(expected_entries)
|
||||
|
||||
for api_entry, journal_entry in zip(returned_entries, expected_entries):
|
||||
assert_log_entry_equals_to_journal_entry(api_entry, journal_entry)
|
||||
|
||||
|
||||
def test_graphql_get_logs_with_down_border(authorized_client):
|
||||
j = journal.Reader()
|
||||
j.seek_head()
|
||||
j.get_next()
|
||||
|
||||
# < - cursor
|
||||
# <- - log entry will be returned by API call.
|
||||
# log
|
||||
# log <-
|
||||
# log <-
|
||||
# log <-
|
||||
# log <-
|
||||
# log <
|
||||
# ...
|
||||
|
||||
expected_entries = take_from_journal(j, 5, lambda j: j.get_next())
|
||||
|
||||
response = authorized_client.post(
|
||||
"/graphql",
|
||||
json={
|
||||
"query": API_GET_LOGS_WITH_DOWN_BORDER,
|
||||
"variables": {"downCursor": expected_entries[-1]["__CURSOR"]},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
expected_entries = expected_entries[:-1]
|
||||
returned_entries = response.json()["data"]["logs"]["paginated"]["entries"]
|
||||
|
||||
assert len(returned_entries) == len(expected_entries)
|
||||
|
||||
for api_entry, journal_entry in zip(returned_entries, expected_entries):
|
||||
assert_log_entry_equals_to_journal_entry(api_entry, journal_entry)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_subscription_for_logs(authorized_client):
|
||||
with authorized_client.websocket_connect(
|
||||
"/graphql", subprotocols=["graphql-transport-ws"]
|
||||
) as websocket:
|
||||
init_graphql(websocket)
|
||||
websocket.send_json(
|
||||
{
|
||||
"id": "3aaa2445",
|
||||
"type": "subscribe",
|
||||
"payload": {
|
||||
"query": "subscription TestSubscription { logEntries { message } }",
|
||||
},
|
||||
}
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def read_until(message, limit=5):
|
||||
i = 0
|
||||
while i < limit:
|
||||
msg = websocket.receive_json()["payload"]["data"]["logEntries"][
|
||||
"message"
|
||||
]
|
||||
if msg == message:
|
||||
return
|
||||
else:
|
||||
i += 1
|
||||
continue
|
||||
raise Exception("Failed to read websocket data, timeout")
|
||||
|
||||
for i in range(0, 10):
|
||||
journal.send(f"Lorem ipsum number {i}")
|
||||
read_until(f"Lorem ipsum number {i}")
|
74
tests/test_graphql/test_jobs.py
Normal file
74
tests/test_graphql/test_jobs.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
from tests.common import generate_jobs_query
|
||||
import tests.test_graphql.test_api_backup
|
||||
|
||||
from tests.test_graphql.common import (
|
||||
assert_ok,
|
||||
assert_empty,
|
||||
assert_errorcode,
|
||||
get_data,
|
||||
)
|
||||
|
||||
from selfprivacy_api.jobs import Jobs
|
||||
|
||||
API_JOBS_QUERY = """
|
||||
getJobs {
|
||||
uid
|
||||
typeId
|
||||
name
|
||||
description
|
||||
status
|
||||
statusText
|
||||
progress
|
||||
createdAt
|
||||
updatedAt
|
||||
finishedAt
|
||||
error
|
||||
result
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def graphql_send_query(client, query: str, variables: dict = {}):
|
||||
return client.post("/graphql", json={"query": query, "variables": variables})
|
||||
|
||||
|
||||
def api_jobs(authorized_client):
|
||||
response = graphql_send_query(
|
||||
authorized_client, generate_jobs_query([API_JOBS_QUERY])
|
||||
)
|
||||
data = get_data(response)
|
||||
result = data["jobs"]["getJobs"]
|
||||
assert result is not None
|
||||
return result
|
||||
|
||||
|
||||
def test_all_jobs_unauthorized(client):
|
||||
response = graphql_send_query(client, generate_jobs_query([API_JOBS_QUERY]))
|
||||
assert_empty(response)
|
||||
|
||||
|
||||
def test_all_jobs_when_none(authorized_client):
|
||||
output = api_jobs(authorized_client)
|
||||
assert output == []
|
||||
|
||||
|
||||
def test_all_jobs_when_some(authorized_client):
|
||||
# We cannot make new jobs via API, at least directly
|
||||
job = Jobs.add("bogus", "bogus.bogus", "fungus")
|
||||
output = api_jobs(authorized_client)
|
||||
|
||||
len(output) == 1
|
||||
api_job = output[0]
|
||||
|
||||
assert api_job["uid"] == str(job.uid)
|
||||
assert api_job["typeId"] == job.type_id
|
||||
assert api_job["name"] == job.name
|
||||
assert api_job["description"] == job.description
|
||||
assert api_job["status"] == job.status
|
||||
assert api_job["statusText"] == job.status_text
|
||||
assert api_job["progress"] == job.progress
|
||||
assert api_job["createdAt"] == job.created_at.isoformat()
|
||||
assert api_job["updatedAt"] == job.updated_at.isoformat()
|
||||
assert api_job["finishedAt"] == None
|
||||
assert api_job["error"] == None
|
||||
assert api_job["result"] == None
|
|
@ -543,8 +543,8 @@ def test_disable_enable(authorized_client, only_dummy_service):
|
|||
assert api_dummy_service["status"] == ServiceStatus.ACTIVE.value
|
||||
|
||||
|
||||
def test_move_immovable(authorized_client, only_dummy_service):
|
||||
dummy_service = only_dummy_service
|
||||
def test_move_immovable(authorized_client, dummy_service_with_binds):
|
||||
dummy_service = dummy_service_with_binds
|
||||
dummy_service.set_movable(False)
|
||||
root = BlockDevices().get_root_block_device()
|
||||
mutation_response = api_move(authorized_client, dummy_service, root.name)
|
||||
|
|
225
tests/test_graphql/test_websocket.py
Normal file
225
tests/test_graphql/test_websocket.py
Normal file
|
@ -0,0 +1,225 @@
|
|||
# from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions
|
||||
import pytest
|
||||
import asyncio
|
||||
from typing import Generator
|
||||
from time import sleep
|
||||
|
||||
from starlette.testclient import WebSocketTestSession
|
||||
|
||||
from selfprivacy_api.jobs import Jobs
|
||||
from selfprivacy_api.actions.api_tokens import TOKEN_REPO
|
||||
from selfprivacy_api.graphql import IsAuthenticated
|
||||
|
||||
from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH
|
||||
from tests.test_jobs import jobs as empty_jobs
|
||||
|
||||
# We do not iterate through them yet
|
||||
TESTED_SUBPROTOCOLS = ["graphql-transport-ws"]
|
||||
|
||||
JOBS_SUBSCRIPTION = """
|
||||
jobUpdates {
|
||||
uid
|
||||
typeId
|
||||
name
|
||||
description
|
||||
status
|
||||
statusText
|
||||
progress
|
||||
createdAt
|
||||
updatedAt
|
||||
finishedAt
|
||||
error
|
||||
result
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def api_subscribe(websocket, id, subscription):
|
||||
websocket.send_json(
|
||||
{
|
||||
"id": id,
|
||||
"type": "subscribe",
|
||||
"payload": {
|
||||
"query": "subscription TestSubscription {" + subscription + "}",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def connect_ws_authenticated(authorized_client) -> WebSocketTestSession:
|
||||
token = "Bearer " + str(DEVICE_WE_AUTH_TESTS_WITH["token"])
|
||||
return authorized_client.websocket_connect(
|
||||
"/graphql",
|
||||
subprotocols=TESTED_SUBPROTOCOLS,
|
||||
params={"token": token},
|
||||
)
|
||||
|
||||
|
||||
def connect_ws_not_authenticated(client) -> WebSocketTestSession:
|
||||
return client.websocket_connect(
|
||||
"/graphql",
|
||||
subprotocols=TESTED_SUBPROTOCOLS,
|
||||
params={"token": "I like vegan icecream but it is not a valid token"},
|
||||
)
|
||||
|
||||
|
||||
def init_graphql(websocket):
|
||||
websocket.send_json({"type": "connection_init", "payload": {}})
|
||||
ack = websocket.receive_json()
|
||||
assert ack == {"type": "connection_ack"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_websocket(
|
||||
authorized_client,
|
||||
) -> Generator[WebSocketTestSession, None, None]:
|
||||
# We use authorized_client only to have token in the repo, this client by itself is not enough to authorize websocket
|
||||
|
||||
ValueError(TOKEN_REPO.get_tokens())
|
||||
with connect_ws_authenticated(authorized_client) as websocket:
|
||||
yield websocket
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthenticated_websocket(client) -> Generator[WebSocketTestSession, None, None]:
|
||||
with connect_ws_not_authenticated(client) as websocket:
|
||||
yield websocket
|
||||
|
||||
|
||||
def test_websocket_connection_bare(authorized_client):
|
||||
client = authorized_client
|
||||
with client.websocket_connect(
|
||||
"/graphql", subprotocols=["graphql-transport-ws", "graphql-ws"]
|
||||
) as websocket:
|
||||
assert websocket is not None
|
||||
assert websocket.scope is not None
|
||||
|
||||
|
||||
def test_websocket_graphql_init(authorized_client):
|
||||
client = authorized_client
|
||||
with client.websocket_connect(
|
||||
"/graphql", subprotocols=["graphql-transport-ws"]
|
||||
) as websocket:
|
||||
websocket.send_json({"type": "connection_init", "payload": {}})
|
||||
ack = websocket.receive_json()
|
||||
assert ack == {"type": "connection_ack"}
|
||||
|
||||
|
||||
def test_websocket_graphql_ping(authorized_client):
|
||||
client = authorized_client
|
||||
with client.websocket_connect(
|
||||
"/graphql", subprotocols=["graphql-transport-ws"]
|
||||
) as websocket:
|
||||
# https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping
|
||||
websocket.send_json({"type": "ping", "payload": {}})
|
||||
pong = websocket.receive_json()
|
||||
assert pong == {"type": "pong"}
|
||||
|
||||
|
||||
def test_websocket_subscription_minimal(authorized_client, authenticated_websocket):
|
||||
# Test a small endpoint that exists specifically for tests
|
||||
websocket = authenticated_websocket
|
||||
init_graphql(websocket)
|
||||
arbitrary_id = "3aaa2445"
|
||||
api_subscribe(websocket, arbitrary_id, "count")
|
||||
response = websocket.receive_json()
|
||||
assert response == {
|
||||
"id": arbitrary_id,
|
||||
"payload": {"data": {"count": 0}},
|
||||
"type": "next",
|
||||
}
|
||||
response = websocket.receive_json()
|
||||
assert response == {
|
||||
"id": arbitrary_id,
|
||||
"payload": {"data": {"count": 1}},
|
||||
"type": "next",
|
||||
}
|
||||
response = websocket.receive_json()
|
||||
assert response == {
|
||||
"id": arbitrary_id,
|
||||
"payload": {"data": {"count": 2}},
|
||||
"type": "next",
|
||||
}
|
||||
|
||||
|
||||
def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket):
|
||||
websocket = unauthenticated_websocket
|
||||
init_graphql(websocket)
|
||||
arbitrary_id = "3aaa2445"
|
||||
api_subscribe(websocket, arbitrary_id, "count")
|
||||
|
||||
response = websocket.receive_json()
|
||||
assert response == {
|
||||
"id": arbitrary_id,
|
||||
"payload": [{"message": IsAuthenticated.message}],
|
||||
"type": "error",
|
||||
}
|
||||
|
||||
|
||||
async def read_one_job(websocket):
|
||||
# Bug? We only get them starting from the second job update
|
||||
# That's why we receive two jobs in the list them
|
||||
# The first update gets lost somewhere
|
||||
response = websocket.receive_json()
|
||||
return response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs):
|
||||
websocket = authenticated_websocket
|
||||
init_graphql(websocket)
|
||||
arbitrary_id = "3aaa2445"
|
||||
api_subscribe(websocket, arbitrary_id, JOBS_SUBSCRIPTION)
|
||||
|
||||
future = asyncio.create_task(read_one_job(websocket))
|
||||
jobs = []
|
||||
jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works"))
|
||||
sleep(0.5)
|
||||
jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works"))
|
||||
|
||||
response = await future
|
||||
data = response["payload"]["data"]
|
||||
jobs_received = data["jobUpdates"]
|
||||
received_names = [job["name"] for job in jobs_received]
|
||||
for job in jobs:
|
||||
assert job.name in received_names
|
||||
|
||||
assert len(jobs_received) == 2
|
||||
|
||||
for job in jobs:
|
||||
for api_job in jobs_received:
|
||||
if (job.name) == api_job["name"]:
|
||||
assert api_job["uid"] == str(job.uid)
|
||||
assert api_job["typeId"] == job.type_id
|
||||
assert api_job["name"] == job.name
|
||||
assert api_job["description"] == job.description
|
||||
assert api_job["status"] == job.status
|
||||
assert api_job["statusText"] == job.status_text
|
||||
assert api_job["progress"] == job.progress
|
||||
assert api_job["createdAt"] == job.created_at.isoformat()
|
||||
assert api_job["updatedAt"] == job.updated_at.isoformat()
|
||||
assert api_job["finishedAt"] == None
|
||||
assert api_job["error"] == None
|
||||
assert api_job["result"] == None
|
||||
|
||||
|
||||
def test_websocket_subscription_unauthorized(unauthenticated_websocket):
|
||||
websocket = unauthenticated_websocket
|
||||
init_graphql(websocket)
|
||||
id = "3aaa2445"
|
||||
api_subscribe(websocket, id, JOBS_SUBSCRIPTION)
|
||||
|
||||
response = websocket.receive_json()
|
||||
# I do not really know why strawberry gives more info on this
|
||||
# One versus the counter
|
||||
payload = response["payload"][0]
|
||||
assert isinstance(payload, dict)
|
||||
assert "locations" in payload.keys()
|
||||
# It looks like this 'locations': [{'column': 32, 'line': 1}]
|
||||
# We cannot test locations feasibly
|
||||
del payload["locations"]
|
||||
assert response == {
|
||||
"id": id,
|
||||
"payload": [{"message": IsAuthenticated.message, "path": ["jobUpdates"]}],
|
||||
"type": "error",
|
||||
}
|
218
tests/test_redis.py
Normal file
218
tests/test_redis.py
Normal file
|
@ -0,0 +1,218 @@
|
|||
import asyncio
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from asyncio import streams
|
||||
import redis
|
||||
from typing import List
|
||||
|
||||
from selfprivacy_api.utils.redis_pool import RedisPool
|
||||
|
||||
from selfprivacy_api.jobs import Jobs, job_notifications
|
||||
|
||||
TEST_KEY = "test:test"
|
||||
STOPWORD = "STOP"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def empty_redis(event_loop):
|
||||
r = RedisPool().get_connection()
|
||||
r.flushdb()
|
||||
assert r.config_get("notify-keyspace-events")["notify-keyspace-events"] == "AKE"
|
||||
yield r
|
||||
r.flushdb()
|
||||
|
||||
|
||||
async def write_to_test_key():
|
||||
r = RedisPool().get_connection_async()
|
||||
async with r.pipeline(transaction=True) as pipe:
|
||||
ok1, ok2 = await pipe.set(TEST_KEY, "value1").set(TEST_KEY, "value2").execute()
|
||||
assert ok1
|
||||
assert ok2
|
||||
assert await r.get(TEST_KEY) == "value2"
|
||||
await r.close()
|
||||
|
||||
|
||||
def test_async_connection(empty_redis):
|
||||
r = RedisPool().get_connection()
|
||||
assert not r.exists(TEST_KEY)
|
||||
# It _will_ report an error if it arises
|
||||
asyncio.run(write_to_test_key())
|
||||
# Confirming that we can read result from sync connection too
|
||||
assert r.get(TEST_KEY) == "value2"
|
||||
|
||||
|
||||
async def channel_reader(channel: redis.client.PubSub) -> List[dict]:
|
||||
result: List[dict] = []
|
||||
while True:
|
||||
# Mypy cannot correctly detect that it is a coroutine
|
||||
# But it is
|
||||
message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore
|
||||
if message is not None:
|
||||
result.append(message)
|
||||
if message["data"] == STOPWORD:
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
async def channel_reader_onemessage(channel: redis.client.PubSub) -> dict:
|
||||
while True:
|
||||
# Mypy cannot correctly detect that it is a coroutine
|
||||
# But it is
|
||||
message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore
|
||||
if message is not None:
|
||||
return message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pubsub(empty_redis, event_loop):
|
||||
# Adapted from :
|
||||
# https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html
|
||||
# Sanity checking because of previous event loop bugs
|
||||
assert event_loop == asyncio.get_event_loop()
|
||||
assert event_loop == asyncio.events.get_event_loop()
|
||||
assert event_loop == asyncio.events._get_event_loop()
|
||||
assert event_loop == asyncio.events.get_running_loop()
|
||||
|
||||
reader = streams.StreamReader(34)
|
||||
assert event_loop == reader._loop
|
||||
f = reader._loop.create_future()
|
||||
f.set_result(3)
|
||||
await f
|
||||
|
||||
r = RedisPool().get_connection_async()
|
||||
async with r.pubsub() as pubsub:
|
||||
await pubsub.subscribe("channel:1")
|
||||
future = asyncio.create_task(channel_reader(pubsub))
|
||||
|
||||
await r.publish("channel:1", "Hello")
|
||||
# message: dict = await pubsub.get_message(ignore_subscribe_messages=True, timeout=5.0) # type: ignore
|
||||
# raise ValueError(message)
|
||||
await r.publish("channel:1", "World")
|
||||
await r.publish("channel:1", STOPWORD)
|
||||
|
||||
messages = await future
|
||||
|
||||
assert len(messages) == 3
|
||||
|
||||
message = messages[0]
|
||||
assert "data" in message.keys()
|
||||
assert message["data"] == "Hello"
|
||||
message = messages[1]
|
||||
assert "data" in message.keys()
|
||||
assert message["data"] == "World"
|
||||
message = messages[2]
|
||||
assert "data" in message.keys()
|
||||
assert message["data"] == STOPWORD
|
||||
|
||||
await r.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyspace_notifications_simple(empty_redis, event_loop):
|
||||
r = RedisPool().get_connection_async()
|
||||
await r.set(TEST_KEY, "I am not empty")
|
||||
async with r.pubsub() as pubsub:
|
||||
await pubsub.subscribe("__keyspace@0__:" + TEST_KEY)
|
||||
|
||||
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
|
||||
empty_redis.set(TEST_KEY, "I am set!")
|
||||
message = await future_message
|
||||
assert message is not None
|
||||
assert message["data"] is not None
|
||||
assert message == {
|
||||
"channel": f"__keyspace@0__:{TEST_KEY}",
|
||||
"data": "set",
|
||||
"pattern": None,
|
||||
"type": "message",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyspace_notifications(empty_redis, event_loop):
|
||||
pubsub = await RedisPool().subscribe_to_keys(TEST_KEY)
|
||||
async with pubsub:
|
||||
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
|
||||
empty_redis.set(TEST_KEY, "I am set!")
|
||||
message = await future_message
|
||||
assert message is not None
|
||||
assert message["data"] is not None
|
||||
assert message == {
|
||||
"channel": f"__keyspace@0__:{TEST_KEY}",
|
||||
"data": "set",
|
||||
"pattern": f"__keyspace@0__:{TEST_KEY}",
|
||||
"type": "pmessage",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyspace_notifications_patterns(empty_redis, event_loop):
|
||||
pattern = "test*"
|
||||
pubsub = await RedisPool().subscribe_to_keys(pattern)
|
||||
async with pubsub:
|
||||
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
|
||||
empty_redis.set(TEST_KEY, "I am set!")
|
||||
message = await future_message
|
||||
assert message is not None
|
||||
assert message["data"] is not None
|
||||
assert message == {
|
||||
"channel": f"__keyspace@0__:{TEST_KEY}",
|
||||
"data": "set",
|
||||
"pattern": f"__keyspace@0__:{pattern}",
|
||||
"type": "pmessage",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyspace_notifications_jobs(empty_redis, event_loop):
|
||||
pattern = "jobs:*"
|
||||
pubsub = await RedisPool().subscribe_to_keys(pattern)
|
||||
async with pubsub:
|
||||
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
|
||||
Jobs.add("testjob1", "test.test", "Testing aaaalll day")
|
||||
message = await future_message
|
||||
assert message is not None
|
||||
assert message["data"] is not None
|
||||
assert message["data"] == "hset"
|
||||
|
||||
|
||||
async def reader_of_jobs() -> List[dict]:
|
||||
"""
|
||||
Reads 3 job updates and exits
|
||||
"""
|
||||
result: List[dict] = []
|
||||
async for message in job_notifications():
|
||||
result.append(message)
|
||||
if len(result) >= 3:
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jobs_generator(empty_redis, event_loop):
|
||||
# Will read exactly 3 job messages
|
||||
future_messages = asyncio.create_task(reader_of_jobs())
|
||||
await asyncio.sleep(1)
|
||||
|
||||
Jobs.add("testjob1", "test.test", "Testing aaaalll day")
|
||||
Jobs.add("testjob2", "test.test", "Testing aaaalll day")
|
||||
Jobs.add("testjob3", "test.test", "Testing aaaalll day")
|
||||
Jobs.add("testjob4", "test.test", "Testing aaaalll day")
|
||||
|
||||
assert len(Jobs.get_jobs()) == 4
|
||||
r = RedisPool().get_connection()
|
||||
assert len(r.keys("jobs:*")) == 4
|
||||
|
||||
messages = await future_messages
|
||||
assert len(messages) == 3
|
||||
channels = [message["channel"] for message in messages]
|
||||
operations = [message["data"] for message in messages]
|
||||
assert set(operations) == set(["hset"]) # all of them are hsets
|
||||
|
||||
# Asserting that all of jobs emitted exactly one message
|
||||
jobs = Jobs.get_jobs()
|
||||
names = ["testjob1", "testjob2", "testjob3"]
|
||||
ids = [str(job.uid) for job in jobs if job.name in names]
|
||||
for id in ids:
|
||||
assert id in " ".join(channels)
|
||||
# Asserting that they came in order
|
||||
assert "testjob4" not in " ".join(channels)
|
39
tests/test_websocket_uvicorn_standalone.py
Normal file
39
tests/test_websocket_uvicorn_standalone.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
import pytest
|
||||
from fastapi import FastAPI, WebSocket
|
||||
import uvicorn
|
||||
|
||||
# import subprocess
|
||||
from multiprocessing import Process
|
||||
import asyncio
|
||||
from time import sleep
|
||||
from websockets import client
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.websocket("/")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
await websocket.send_text(f"You sent: {data}")
|
||||
|
||||
|
||||
def run_uvicorn():
|
||||
uvicorn.run(app, port=5000)
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uvcorn_ws_works_in_prod():
|
||||
proc = Process(target=run_uvicorn)
|
||||
proc.start()
|
||||
sleep(2)
|
||||
|
||||
ws = await client.connect("ws://127.0.0.1:5000")
|
||||
|
||||
await ws.send("hohoho")
|
||||
message = await ws.read_message()
|
||||
assert message == "You sent: hohoho"
|
||||
await ws.close()
|
||||
proc.kill()
|
Loading…
Reference in a new issue