mirror of
https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api.git
synced 2024-11-22 04:01:27 +00:00
Merge remote-tracking branch 'origin/master' into inex/service-settings
This commit is contained in:
commit
c8d00e6c87
|
@ -13,9 +13,9 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api),
|
||||||
|
|
||||||
For detailed installation information, please review and follow: [link](https://nixos.org/manual/nix/stable/installation/installing-binary.html#installing-a-binary-distribution).
|
For detailed installation information, please review and follow: [link](https://nixos.org/manual/nix/stable/installation/installing-binary.html#installing-a-binary-distribution).
|
||||||
|
|
||||||
3. **Change directory to the cloned repository and start a nix shell:**
|
3. **Change directory to the cloned repository and start a nix development shell:**
|
||||||
|
|
||||||
```cd selfprivacy-rest-api && nix-shell```
|
```cd selfprivacy-rest-api && nix develop```
|
||||||
|
|
||||||
Nix will install all of the necessary packages for development work, all further actions will take place only within nix-shell.
|
Nix will install all of the necessary packages for development work, all further actions will take place only within nix-shell.
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api),
|
||||||
|
|
||||||
Copy the path that starts with ```/nix/store/``` and ends with ```env/bin/python```
|
Copy the path that starts with ```/nix/store/``` and ends with ```env/bin/python```
|
||||||
|
|
||||||
```/nix/store/???-python3-3.9.??-env/bin/python```
|
```/nix/store/???-python3-3.10.??-env/bin/python```
|
||||||
|
|
||||||
Click on the python version selection in the lower right corner, and replace the path to the interpreter in the project with the one you copied from the terminal.
|
Click on the python version selection in the lower right corner, and replace the path to the interpreter in the project with the one you copied from the terminal.
|
||||||
|
|
||||||
|
@ -43,12 +43,13 @@ the [repository](https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api),
|
||||||
|
|
||||||
## What to do after making changes to the repository?
|
## What to do after making changes to the repository?
|
||||||
|
|
||||||
**Run unit tests** using ```pytest .```
|
**Run unit tests** using ```pytest-vm``` inside of the development shell. This will run all the test inside a virtual machine, which is necessary for the tests to pass successfully.
|
||||||
Make sure that all tests pass successfully and the API works correctly. For convenience, you can use the built-in VScode interface.
|
Make sure that all tests pass successfully and the API works correctly.
|
||||||
|
|
||||||
How to review the percentage of code coverage? Execute the command:
|
The ```pytest-vm``` command will also print out the coverage of the tests. To export the report to an XML file, use the following command:
|
||||||
|
|
||||||
|
```coverage xml```
|
||||||
|
|
||||||
```coverage run -m pytest && coverage xml && coverage report```
|
|
||||||
|
|
||||||
Next, use the recommended extension ```ryanluker.vscode-coverage-gutters```, navigate to one of the test files, and click the "watch" button on the bottom panel of VScode.
|
Next, use the recommended extension ```ryanluker.vscode-coverage-gutters```, navigate to one of the test files, and click the "watch" button on the bottom panel of VScode.
|
||||||
|
|
||||||
|
|
|
@ -14,10 +14,12 @@ pythonPackages.buildPythonPackage rec {
|
||||||
pydantic
|
pydantic
|
||||||
pytz
|
pytz
|
||||||
redis
|
redis
|
||||||
|
systemd
|
||||||
setuptools
|
setuptools
|
||||||
strawberry-graphql
|
strawberry-graphql
|
||||||
typing-extensions
|
typing-extensions
|
||||||
uvicorn
|
uvicorn
|
||||||
|
websockets
|
||||||
];
|
];
|
||||||
pythonImportsCheck = [ "selfprivacy_api" ];
|
pythonImportsCheck = [ "selfprivacy_api" ];
|
||||||
doCheck = false;
|
doCheck = false;
|
||||||
|
|
|
@ -2,11 +2,11 @@
|
||||||
"nodes": {
|
"nodes": {
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1709677081,
|
"lastModified": 1719957072,
|
||||||
"narHash": "sha256-tix36Y7u0rkn6mTm0lA45b45oab2cFLqAzDbJxeXS+c=",
|
"narHash": "sha256-gvFhEf5nszouwLAkT9nWsDzocUTqLWHuL++dvNjMp9I=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "880992dcc006a5e00dd0591446fdf723e6a51a64",
|
"rev": "7144d6241f02d171d25fba3edeaf15e0f2592105",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
pytest-datadir
|
pytest-datadir
|
||||||
pytest-mock
|
pytest-mock
|
||||||
pytest-subprocess
|
pytest-subprocess
|
||||||
|
pytest-asyncio
|
||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
pylsp-mypy
|
pylsp-mypy
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from strawberry.fastapi import GraphQLRouter
|
from strawberry.fastapi import GraphQLRouter
|
||||||
|
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
@ -13,8 +14,12 @@ from selfprivacy_api.migrations import run_migrations
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
graphql_app = GraphQLRouter(
|
graphql_app: GraphQLRouter = GraphQLRouter(
|
||||||
schema,
|
schema,
|
||||||
|
subscription_protocols=[
|
||||||
|
GRAPHQL_TRANSPORT_WS_PROTOCOL,
|
||||||
|
GRAPHQL_WS_PROTOCOL,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
|
|
@ -27,4 +27,4 @@ async def get_token_header(
|
||||||
|
|
||||||
def get_api_version() -> str:
|
def get_api_version() -> str:
|
||||||
"""Get API version"""
|
"""Get API version"""
|
||||||
return "3.2.2+configs"
|
return "3.3.0+configs"
|
||||||
|
|
|
@ -16,6 +16,10 @@ class IsAuthenticated(BasePermission):
|
||||||
token = info.context["request"].headers.get("Authorization")
|
token = info.context["request"].headers.get("Authorization")
|
||||||
if token is None:
|
if token is None:
|
||||||
token = info.context["request"].query_params.get("token")
|
token = info.context["request"].query_params.get("token")
|
||||||
|
if token is None:
|
||||||
|
connection_params = info.context.get("connection_params")
|
||||||
|
if connection_params is not None:
|
||||||
|
token = connection_params.get("Authorization")
|
||||||
if token is None:
|
if token is None:
|
||||||
return False
|
return False
|
||||||
return is_token_valid(token.replace("Bearer ", ""))
|
return is_token_valid(token.replace("Bearer ", ""))
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
"""API access status"""
|
"""API access status"""
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
import datetime
|
import datetime
|
||||||
import typing
|
import typing
|
||||||
import strawberry
|
import strawberry
|
||||||
|
|
||||||
from strawberry.types import Info
|
from strawberry.types import Info
|
||||||
from selfprivacy_api.actions.api_tokens import (
|
from selfprivacy_api.actions.api_tokens import (
|
||||||
get_api_tokens_with_caller_flag,
|
get_api_tokens_with_caller_flag,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Backup"""
|
"""Backup"""
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
import typing
|
import typing
|
||||||
import strawberry
|
import strawberry
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Common types and enums used by different types of queries."""
|
"""Common types and enums used by different types of queries."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import datetime
|
import datetime
|
||||||
import typing
|
import typing
|
||||||
|
|
|
@ -1,24 +1,30 @@
|
||||||
"""Jobs status"""
|
"""Jobs status"""
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
import typing
|
|
||||||
import strawberry
|
import strawberry
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from selfprivacy_api.jobs import Jobs
|
||||||
from selfprivacy_api.graphql.common_types.jobs import (
|
from selfprivacy_api.graphql.common_types.jobs import (
|
||||||
ApiJob,
|
ApiJob,
|
||||||
get_api_job_by_id,
|
get_api_job_by_id,
|
||||||
job_to_api_job,
|
job_to_api_job,
|
||||||
)
|
)
|
||||||
|
|
||||||
from selfprivacy_api.jobs import Jobs
|
|
||||||
|
def get_all_jobs() -> List[ApiJob]:
|
||||||
|
jobs = Jobs.get_jobs()
|
||||||
|
api_jobs = [job_to_api_job(job) for job in jobs]
|
||||||
|
assert api_jobs is not None
|
||||||
|
return api_jobs
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
|
@strawberry.type
|
||||||
class Job:
|
class Job:
|
||||||
@strawberry.field
|
@strawberry.field
|
||||||
def get_jobs(self) -> typing.List[ApiJob]:
|
def get_jobs(self) -> List[ApiJob]:
|
||||||
Jobs.get_jobs()
|
return get_all_jobs()
|
||||||
|
|
||||||
return [job_to_api_job(job) for job in Jobs.get_jobs()]
|
|
||||||
|
|
||||||
@strawberry.field
|
@strawberry.field
|
||||||
def get_job(self, job_id: str) -> typing.Optional[ApiJob]:
|
def get_job(self, job_id: str) -> Optional[ApiJob]:
|
||||||
return get_api_job_by_id(job_id)
|
return get_api_job_by_id(job_id)
|
||||||
|
|
88
selfprivacy_api/graphql/queries/logs.py
Normal file
88
selfprivacy_api/graphql/queries/logs.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
"""System logs"""
|
||||||
|
from datetime import datetime
|
||||||
|
import typing
|
||||||
|
import strawberry
|
||||||
|
from selfprivacy_api.utils.systemd_journal import get_paginated_logs
|
||||||
|
|
||||||
|
|
||||||
|
@strawberry.type
|
||||||
|
class LogEntry:
|
||||||
|
message: str = strawberry.field()
|
||||||
|
timestamp: datetime = strawberry.field()
|
||||||
|
priority: typing.Optional[int] = strawberry.field()
|
||||||
|
systemd_unit: typing.Optional[str] = strawberry.field()
|
||||||
|
systemd_slice: typing.Optional[str] = strawberry.field()
|
||||||
|
|
||||||
|
def __init__(self, journal_entry: typing.Dict):
|
||||||
|
self.entry = journal_entry
|
||||||
|
self.message = journal_entry["MESSAGE"]
|
||||||
|
self.timestamp = journal_entry["__REALTIME_TIMESTAMP"]
|
||||||
|
self.priority = journal_entry.get("PRIORITY")
|
||||||
|
self.systemd_unit = journal_entry.get("_SYSTEMD_UNIT")
|
||||||
|
self.systemd_slice = journal_entry.get("_SYSTEMD_SLICE")
|
||||||
|
|
||||||
|
@strawberry.field()
|
||||||
|
def cursor(self) -> str:
|
||||||
|
return self.entry["__CURSOR"]
|
||||||
|
|
||||||
|
|
||||||
|
@strawberry.type
|
||||||
|
class LogsPageMeta:
|
||||||
|
up_cursor: typing.Optional[str] = strawberry.field()
|
||||||
|
down_cursor: typing.Optional[str] = strawberry.field()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, up_cursor: typing.Optional[str], down_cursor: typing.Optional[str]
|
||||||
|
):
|
||||||
|
self.up_cursor = up_cursor
|
||||||
|
self.down_cursor = down_cursor
|
||||||
|
|
||||||
|
|
||||||
|
@strawberry.type
|
||||||
|
class PaginatedEntries:
|
||||||
|
page_meta: LogsPageMeta = strawberry.field(
|
||||||
|
description="Metadata to aid in pagination."
|
||||||
|
)
|
||||||
|
entries: typing.List[LogEntry] = strawberry.field(
|
||||||
|
description="The list of log entries."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, meta: LogsPageMeta, entries: typing.List[LogEntry]):
|
||||||
|
self.page_meta = meta
|
||||||
|
self.entries = entries
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_entries(entries: typing.List[LogEntry]):
|
||||||
|
if entries == []:
|
||||||
|
return PaginatedEntries(LogsPageMeta(None, None), [])
|
||||||
|
|
||||||
|
return PaginatedEntries(
|
||||||
|
LogsPageMeta(
|
||||||
|
entries[0].cursor(),
|
||||||
|
entries[-1].cursor(),
|
||||||
|
),
|
||||||
|
entries,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@strawberry.type
|
||||||
|
class Logs:
|
||||||
|
@strawberry.field()
|
||||||
|
def paginated(
|
||||||
|
self,
|
||||||
|
limit: int = 20,
|
||||||
|
# All entries returned will be lesser than this cursor. Sets upper bound on results.
|
||||||
|
up_cursor: str | None = None,
|
||||||
|
# All entries returned will be greater than this cursor. Sets lower bound on results.
|
||||||
|
down_cursor: str | None = None,
|
||||||
|
) -> PaginatedEntries:
|
||||||
|
if limit > 50:
|
||||||
|
raise Exception("You can't fetch more than 50 entries via single request.")
|
||||||
|
return PaginatedEntries.from_entries(
|
||||||
|
list(
|
||||||
|
map(
|
||||||
|
lambda x: LogEntry(x),
|
||||||
|
get_paginated_logs(limit, up_cursor, down_cursor),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
|
@ -1,4 +1,5 @@
|
||||||
"""Enums representing different service providers."""
|
"""Enums representing different service providers."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import strawberry
|
import strawberry
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Services status"""
|
"""Services status"""
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
import typing
|
import typing
|
||||||
import strawberry
|
import strawberry
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Storage queries."""
|
"""Storage queries."""
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
import typing
|
import typing
|
||||||
import strawberry
|
import strawberry
|
||||||
|
@ -18,9 +19,11 @@ class Storage:
|
||||||
"""Get list of volumes"""
|
"""Get list of volumes"""
|
||||||
return [
|
return [
|
||||||
StorageVolume(
|
StorageVolume(
|
||||||
total_space=str(volume.fssize)
|
total_space=(
|
||||||
if volume.fssize is not None
|
str(volume.fssize)
|
||||||
else str(volume.size),
|
if volume.fssize is not None
|
||||||
|
else str(volume.size)
|
||||||
|
),
|
||||||
free_space=str(volume.fsavail),
|
free_space=str(volume.fsavail),
|
||||||
used_space=str(volume.fsused),
|
used_space=str(volume.fsused),
|
||||||
root=volume.is_root(),
|
root=volume.is_root(),
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
"""Common system information and settings"""
|
"""Common system information and settings"""
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
import os
|
import os
|
||||||
import typing
|
import typing
|
||||||
import strawberry
|
import strawberry
|
||||||
|
|
||||||
from selfprivacy_api.graphql.common_types.dns import DnsRecord
|
from selfprivacy_api.graphql.common_types.dns import DnsRecord
|
||||||
|
|
||||||
from selfprivacy_api.graphql.queries.common import Alert, Severity
|
from selfprivacy_api.graphql.queries.common import Alert, Severity
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Users"""
|
"""Users"""
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
import typing
|
import typing
|
||||||
import strawberry
|
import strawberry
|
||||||
|
|
|
@ -2,8 +2,10 @@
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator, List
|
||||||
import strawberry
|
import strawberry
|
||||||
|
from strawberry.types import Info
|
||||||
|
|
||||||
from selfprivacy_api.graphql import IsAuthenticated
|
from selfprivacy_api.graphql import IsAuthenticated
|
||||||
from selfprivacy_api.graphql.mutations.deprecated_mutations import (
|
from selfprivacy_api.graphql.mutations.deprecated_mutations import (
|
||||||
DeprecatedApiMutations,
|
DeprecatedApiMutations,
|
||||||
|
@ -24,10 +26,17 @@ from selfprivacy_api.graphql.mutations.backup_mutations import BackupMutations
|
||||||
from selfprivacy_api.graphql.queries.api_queries import Api
|
from selfprivacy_api.graphql.queries.api_queries import Api
|
||||||
from selfprivacy_api.graphql.queries.backup import Backup
|
from selfprivacy_api.graphql.queries.backup import Backup
|
||||||
from selfprivacy_api.graphql.queries.jobs import Job
|
from selfprivacy_api.graphql.queries.jobs import Job
|
||||||
|
from selfprivacy_api.graphql.queries.logs import LogEntry, Logs
|
||||||
from selfprivacy_api.graphql.queries.services import Services
|
from selfprivacy_api.graphql.queries.services import Services
|
||||||
from selfprivacy_api.graphql.queries.storage import Storage
|
from selfprivacy_api.graphql.queries.storage import Storage
|
||||||
from selfprivacy_api.graphql.queries.system import System
|
from selfprivacy_api.graphql.queries.system import System
|
||||||
|
|
||||||
|
from selfprivacy_api.graphql.subscriptions.jobs import ApiJob
|
||||||
|
from selfprivacy_api.graphql.subscriptions.jobs import (
|
||||||
|
job_updates as job_update_generator,
|
||||||
|
)
|
||||||
|
from selfprivacy_api.graphql.subscriptions.logs import log_stream
|
||||||
|
|
||||||
from selfprivacy_api.graphql.common_types.service import (
|
from selfprivacy_api.graphql.common_types.service import (
|
||||||
StringConfigItem,
|
StringConfigItem,
|
||||||
BoolConfigItem,
|
BoolConfigItem,
|
||||||
|
@ -53,6 +62,11 @@ class Query:
|
||||||
"""System queries"""
|
"""System queries"""
|
||||||
return System()
|
return System()
|
||||||
|
|
||||||
|
@strawberry.field(permission_classes=[IsAuthenticated])
|
||||||
|
def logs(self) -> Logs:
|
||||||
|
"""Log queries"""
|
||||||
|
return Logs()
|
||||||
|
|
||||||
@strawberry.field(permission_classes=[IsAuthenticated])
|
@strawberry.field(permission_classes=[IsAuthenticated])
|
||||||
def users(self) -> Users:
|
def users(self) -> Users:
|
||||||
"""Users queries"""
|
"""Users queries"""
|
||||||
|
@ -135,19 +149,42 @@ class Mutation(
|
||||||
code=200,
|
code=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
pass
|
|
||||||
|
# A cruft for Websockets
|
||||||
|
def authenticated(info: Info) -> bool:
|
||||||
|
return IsAuthenticated().has_permission(source=None, info=info)
|
||||||
|
|
||||||
|
|
||||||
|
def reject_if_unauthenticated(info: Info):
|
||||||
|
if not authenticated(info):
|
||||||
|
raise Exception(IsAuthenticated().message)
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
|
@strawberry.type
|
||||||
class Subscription:
|
class Subscription:
|
||||||
"""Root schema for subscriptions"""
|
"""Root schema for subscriptions.
|
||||||
|
Every field here should be an AsyncIterator or AsyncGenerator
|
||||||
|
It is not a part of the spec but graphql-core (dep of strawberryql)
|
||||||
|
demands it while the spec is vague in this area."""
|
||||||
|
|
||||||
@strawberry.subscription(permission_classes=[IsAuthenticated])
|
@strawberry.subscription
|
||||||
async def count(self, target: int = 100) -> AsyncGenerator[int, None]:
|
async def job_updates(self, info: Info) -> AsyncGenerator[List[ApiJob], None]:
|
||||||
for i in range(target):
|
reject_if_unauthenticated(info)
|
||||||
|
return job_update_generator()
|
||||||
|
|
||||||
|
@strawberry.subscription
|
||||||
|
# Used for testing, consider deletion to shrink attack surface
|
||||||
|
async def count(self, info: Info) -> AsyncGenerator[int, None]:
|
||||||
|
reject_if_unauthenticated(info)
|
||||||
|
for i in range(10):
|
||||||
yield i
|
yield i
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
@strawberry.subscription
|
||||||
|
async def log_entries(self, info: Info) -> AsyncGenerator[LogEntry, None]:
|
||||||
|
reject_if_unauthenticated(info)
|
||||||
|
return log_stream()
|
||||||
|
|
||||||
|
|
||||||
schema = strawberry.Schema(
|
schema = strawberry.Schema(
|
||||||
query=Query,
|
query=Query,
|
||||||
|
|
0
selfprivacy_api/graphql/subscriptions/__init__.py
Normal file
0
selfprivacy_api/graphql/subscriptions/__init__.py
Normal file
14
selfprivacy_api/graphql/subscriptions/jobs.py
Normal file
14
selfprivacy_api/graphql/subscriptions/jobs.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
# pylint: disable=too-few-public-methods
|
||||||
|
|
||||||
|
from typing import AsyncGenerator, List
|
||||||
|
|
||||||
|
from selfprivacy_api.jobs import job_notifications
|
||||||
|
|
||||||
|
from selfprivacy_api.graphql.common_types.jobs import ApiJob
|
||||||
|
from selfprivacy_api.graphql.queries.jobs import get_all_jobs
|
||||||
|
|
||||||
|
|
||||||
|
async def job_updates() -> AsyncGenerator[List[ApiJob], None]:
|
||||||
|
# Send the complete list of jobs every time anything gets updated
|
||||||
|
async for notification in job_notifications():
|
||||||
|
yield get_all_jobs()
|
31
selfprivacy_api/graphql/subscriptions/logs.py
Normal file
31
selfprivacy_api/graphql/subscriptions/logs.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from systemd import journal
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from selfprivacy_api.graphql.queries.logs import LogEntry
|
||||||
|
|
||||||
|
|
||||||
|
async def log_stream() -> AsyncGenerator[LogEntry, None]:
|
||||||
|
j = journal.Reader()
|
||||||
|
|
||||||
|
j.seek_tail()
|
||||||
|
j.get_previous()
|
||||||
|
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def callback():
|
||||||
|
if j.process() != journal.APPEND:
|
||||||
|
return
|
||||||
|
for entry in j:
|
||||||
|
await queue.put(entry)
|
||||||
|
|
||||||
|
asyncio.get_event_loop().add_reader(j, lambda: asyncio.ensure_future(callback()))
|
||||||
|
|
||||||
|
while True:
|
||||||
|
entry = await queue.get()
|
||||||
|
try:
|
||||||
|
yield LogEntry(entry)
|
||||||
|
except Exception:
|
||||||
|
asyncio.get_event_loop().remove_reader(j)
|
||||||
|
return
|
||||||
|
queue.task_done()
|
|
@ -15,6 +15,7 @@ A job is a dictionary with the following keys:
|
||||||
- result: result of the job
|
- result: result of the job
|
||||||
"""
|
"""
|
||||||
import typing
|
import typing
|
||||||
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -23,6 +24,7 @@ from enum import Enum
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from selfprivacy_api.utils.redis_pool import RedisPool
|
from selfprivacy_api.utils.redis_pool import RedisPool
|
||||||
|
from selfprivacy_api.utils.redis_model_storage import store_model_as_hash
|
||||||
|
|
||||||
JOB_EXPIRATION_SECONDS = 10 * 24 * 60 * 60 # ten days
|
JOB_EXPIRATION_SECONDS = 10 * 24 * 60 * 60 # ten days
|
||||||
|
|
||||||
|
@ -102,7 +104,7 @@ class Jobs:
|
||||||
result=None,
|
result=None,
|
||||||
)
|
)
|
||||||
redis = RedisPool().get_connection()
|
redis = RedisPool().get_connection()
|
||||||
_store_job_as_hash(redis, _redis_key_from_uuid(job.uid), job)
|
store_model_as_hash(redis, _redis_key_from_uuid(job.uid), job)
|
||||||
return job
|
return job
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -218,7 +220,7 @@ class Jobs:
|
||||||
redis = RedisPool().get_connection()
|
redis = RedisPool().get_connection()
|
||||||
key = _redis_key_from_uuid(job.uid)
|
key = _redis_key_from_uuid(job.uid)
|
||||||
if redis.exists(key):
|
if redis.exists(key):
|
||||||
_store_job_as_hash(redis, key, job)
|
store_model_as_hash(redis, key, job)
|
||||||
if status in (JobStatus.FINISHED, JobStatus.ERROR):
|
if status in (JobStatus.FINISHED, JobStatus.ERROR):
|
||||||
redis.expire(key, JOB_EXPIRATION_SECONDS)
|
redis.expire(key, JOB_EXPIRATION_SECONDS)
|
||||||
|
|
||||||
|
@ -294,17 +296,6 @@ def _progress_log_key_from_uuid(uuid_string) -> str:
|
||||||
return PROGRESS_LOGS_PREFIX + str(uuid_string)
|
return PROGRESS_LOGS_PREFIX + str(uuid_string)
|
||||||
|
|
||||||
|
|
||||||
def _store_job_as_hash(redis, redis_key, model) -> None:
|
|
||||||
for key, value in model.dict().items():
|
|
||||||
if isinstance(value, uuid.UUID):
|
|
||||||
value = str(value)
|
|
||||||
if isinstance(value, datetime.datetime):
|
|
||||||
value = value.isoformat()
|
|
||||||
if isinstance(value, JobStatus):
|
|
||||||
value = value.value
|
|
||||||
redis.hset(redis_key, key, str(value))
|
|
||||||
|
|
||||||
|
|
||||||
def _job_from_hash(redis, redis_key) -> typing.Optional[Job]:
|
def _job_from_hash(redis, redis_key) -> typing.Optional[Job]:
|
||||||
if redis.exists(redis_key):
|
if redis.exists(redis_key):
|
||||||
job_dict = redis.hgetall(redis_key)
|
job_dict = redis.hgetall(redis_key)
|
||||||
|
@ -321,3 +312,15 @@ def _job_from_hash(redis, redis_key) -> typing.Optional[Job]:
|
||||||
|
|
||||||
return Job(**job_dict)
|
return Job(**job_dict)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def job_notifications() -> typing.AsyncGenerator[dict, None]:
|
||||||
|
channel = await RedisPool().subscribe_to_keys("jobs:*")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# we cannot timeout here because we do not know when the next message is supposed to arrive
|
||||||
|
message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore
|
||||||
|
if message is not None:
|
||||||
|
yield message
|
||||||
|
except GeneratorExit:
|
||||||
|
break
|
||||||
|
|
|
@ -14,10 +14,12 @@ from selfprivacy_api.migrations.write_token_to_redis import WriteTokenToRedis
|
||||||
from selfprivacy_api.migrations.check_for_system_rebuild_jobs import (
|
from selfprivacy_api.migrations.check_for_system_rebuild_jobs import (
|
||||||
CheckForSystemRebuildJobs,
|
CheckForSystemRebuildJobs,
|
||||||
)
|
)
|
||||||
|
from selfprivacy_api.migrations.add_roundcube import AddRoundcube
|
||||||
|
|
||||||
migrations = [
|
migrations = [
|
||||||
WriteTokenToRedis(),
|
WriteTokenToRedis(),
|
||||||
CheckForSystemRebuildJobs(),
|
CheckForSystemRebuildJobs(),
|
||||||
|
AddRoundcube(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
36
selfprivacy_api/migrations/add_roundcube.py
Normal file
36
selfprivacy_api/migrations/add_roundcube.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
from selfprivacy_api.migrations.migration import Migration
|
||||||
|
|
||||||
|
from selfprivacy_api.services.flake_service_manager import FlakeServiceManager
|
||||||
|
from selfprivacy_api.utils import ReadUserData, WriteUserData
|
||||||
|
|
||||||
|
|
||||||
|
class AddRoundcube(Migration):
|
||||||
|
"""Adds the Roundcube if it is not present."""
|
||||||
|
|
||||||
|
def get_migration_name(self) -> str:
|
||||||
|
return "add_roundcube"
|
||||||
|
|
||||||
|
def get_migration_description(self) -> str:
|
||||||
|
return "Adds the Roundcube if it is not present."
|
||||||
|
|
||||||
|
def is_migration_needed(self) -> bool:
|
||||||
|
with FlakeServiceManager() as manager:
|
||||||
|
if "roundcube" not in manager.services:
|
||||||
|
return True
|
||||||
|
with ReadUserData() as data:
|
||||||
|
if "roundcube" not in data["modules"]:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def migrate(self) -> None:
|
||||||
|
with FlakeServiceManager() as manager:
|
||||||
|
if "roundcube" not in manager.services:
|
||||||
|
manager.services[
|
||||||
|
"roundcube"
|
||||||
|
] = "git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/roundcube"
|
||||||
|
with WriteUserData() as data:
|
||||||
|
if "roundcube" not in data["modules"]:
|
||||||
|
data["modules"]["roundcube"] = {
|
||||||
|
"enable": False,
|
||||||
|
"subdomain": "roundcube",
|
||||||
|
}
|
|
@ -5,13 +5,13 @@ from selfprivacy_api.jobs import JobStatus, Jobs
|
||||||
class CheckForSystemRebuildJobs(Migration):
|
class CheckForSystemRebuildJobs(Migration):
|
||||||
"""Check if there are unfinished system rebuild jobs and finish them"""
|
"""Check if there are unfinished system rebuild jobs and finish them"""
|
||||||
|
|
||||||
def get_migration_name(self):
|
def get_migration_name(self) -> str:
|
||||||
return "check_for_system_rebuild_jobs"
|
return "check_for_system_rebuild_jobs"
|
||||||
|
|
||||||
def get_migration_description(self):
|
def get_migration_description(self) -> str:
|
||||||
return "Check if there are unfinished system rebuild jobs and finish them"
|
return "Check if there are unfinished system rebuild jobs and finish them"
|
||||||
|
|
||||||
def is_migration_needed(self):
|
def is_migration_needed(self) -> bool:
|
||||||
# Check if there are any unfinished system rebuild jobs
|
# Check if there are any unfinished system rebuild jobs
|
||||||
for job in Jobs.get_jobs():
|
for job in Jobs.get_jobs():
|
||||||
if (
|
if (
|
||||||
|
@ -25,8 +25,9 @@ class CheckForSystemRebuildJobs(Migration):
|
||||||
JobStatus.RUNNING,
|
JobStatus.RUNNING,
|
||||||
]:
|
]:
|
||||||
return True
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def migrate(self):
|
def migrate(self) -> None:
|
||||||
# As the API is restarted, we assume that the jobs are finished
|
# As the API is restarted, we assume that the jobs are finished
|
||||||
for job in Jobs.get_jobs():
|
for job in Jobs.get_jobs():
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -12,17 +12,17 @@ class Migration(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_migration_name(self):
|
def get_migration_name(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_migration_description(self):
|
def get_migration_description(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_migration_needed(self):
|
def is_migration_needed(self) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def migrate(self):
|
def migrate(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -15,10 +15,10 @@ from selfprivacy_api.utils import ReadUserData, UserDataFiles
|
||||||
class WriteTokenToRedis(Migration):
|
class WriteTokenToRedis(Migration):
|
||||||
"""Load Json tokens into Redis"""
|
"""Load Json tokens into Redis"""
|
||||||
|
|
||||||
def get_migration_name(self):
|
def get_migration_name(self) -> str:
|
||||||
return "write_token_to_redis"
|
return "write_token_to_redis"
|
||||||
|
|
||||||
def get_migration_description(self):
|
def get_migration_description(self) -> str:
|
||||||
return "Loads the initial token into redis token storage"
|
return "Loads the initial token into redis token storage"
|
||||||
|
|
||||||
def is_repo_empty(self, repo: AbstractTokensRepository) -> bool:
|
def is_repo_empty(self, repo: AbstractTokensRepository) -> bool:
|
||||||
|
@ -38,7 +38,7 @@ class WriteTokenToRedis(Migration):
|
||||||
print(e)
|
print(e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def is_migration_needed(self):
|
def is_migration_needed(self) -> bool:
|
||||||
try:
|
try:
|
||||||
if self.get_token_from_json() is not None and self.is_repo_empty(
|
if self.get_token_from_json() is not None and self.is_repo_empty(
|
||||||
RedisTokensRepository()
|
RedisTokensRepository()
|
||||||
|
@ -47,8 +47,9 @@ class WriteTokenToRedis(Migration):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
return False
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
def migrate(self):
|
def migrate(self) -> None:
|
||||||
# Write info about providers to userdata.json
|
# Write info about providers to userdata.json
|
||||||
try:
|
try:
|
||||||
token = self.get_token_from_json()
|
token = self.get_token_from_json()
|
||||||
|
|
|
@ -4,6 +4,7 @@ import typing
|
||||||
from selfprivacy_api.services.bitwarden import Bitwarden
|
from selfprivacy_api.services.bitwarden import Bitwarden
|
||||||
from selfprivacy_api.services.forgejo import Forgejo
|
from selfprivacy_api.services.forgejo import Forgejo
|
||||||
from selfprivacy_api.services.jitsimeet import JitsiMeet
|
from selfprivacy_api.services.jitsimeet import JitsiMeet
|
||||||
|
from selfprivacy_api.services.roundcube import Roundcube
|
||||||
from selfprivacy_api.services.mailserver import MailServer
|
from selfprivacy_api.services.mailserver import MailServer
|
||||||
from selfprivacy_api.services.nextcloud import Nextcloud
|
from selfprivacy_api.services.nextcloud import Nextcloud
|
||||||
from selfprivacy_api.services.pleroma import Pleroma
|
from selfprivacy_api.services.pleroma import Pleroma
|
||||||
|
@ -19,6 +20,7 @@ services: list[Service] = [
|
||||||
Pleroma(),
|
Pleroma(),
|
||||||
Ocserv(),
|
Ocserv(),
|
||||||
JitsiMeet(),
|
JitsiMeet(),
|
||||||
|
Roundcube(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -37,14 +37,14 @@ class Bitwarden(Service):
|
||||||
def get_user() -> str:
|
def get_user() -> str:
|
||||||
return "vaultwarden"
|
return "vaultwarden"
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_url() -> Optional[str]:
|
def get_url(cls) -> Optional[str]:
|
||||||
"""Return service url."""
|
"""Return service url."""
|
||||||
domain = get_domain()
|
domain = get_domain()
|
||||||
return f"https://password.{domain}"
|
return f"https://password.{domain}"
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_subdomain() -> Optional[str]:
|
def get_subdomain(cls) -> Optional[str]:
|
||||||
return "password"
|
return "password"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -34,20 +34,20 @@ class FlakeServiceManager:
|
||||||
file.write(
|
file.write(
|
||||||
"""
|
"""
|
||||||
{
|
{
|
||||||
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";\n
|
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";\n
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
for key, value in self.services.items():
|
for key, value in self.services.items():
|
||||||
file.write(
|
file.write(
|
||||||
f"""
|
f"""
|
||||||
inputs.{key}.url = {value};
|
inputs.{key}.url = {value};
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
file.write(
|
file.write(
|
||||||
"""
|
"""
|
||||||
outputs = _: { };
|
outputs = _: { };
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
|
@ -90,14 +90,14 @@ class Forgejo(Service):
|
||||||
"""Read SVG icon from file and return it as base64 encoded string."""
|
"""Read SVG icon from file and return it as base64 encoded string."""
|
||||||
return base64.b64encode(FORGEJO_ICON.encode("utf-8")).decode("utf-8")
|
return base64.b64encode(FORGEJO_ICON.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_url() -> Optional[str]:
|
def get_url(cls) -> Optional[str]:
|
||||||
"""Return service url."""
|
"""Return service url."""
|
||||||
domain = get_domain()
|
domain = get_domain()
|
||||||
return f"https://git.{domain}"
|
return f"https://git.{domain}"
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_subdomain() -> Optional[str]:
|
def get_subdomain(cls) -> Optional[str]:
|
||||||
return "git"
|
return "git"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -36,14 +36,14 @@ class JitsiMeet(Service):
|
||||||
"""Read SVG icon from file and return it as base64 encoded string."""
|
"""Read SVG icon from file and return it as base64 encoded string."""
|
||||||
return base64.b64encode(JITSI_ICON.encode("utf-8")).decode("utf-8")
|
return base64.b64encode(JITSI_ICON.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_url() -> Optional[str]:
|
def get_url(cls) -> Optional[str]:
|
||||||
"""Return service url."""
|
"""Return service url."""
|
||||||
domain = get_domain()
|
domain = get_domain()
|
||||||
return f"https://meet.{domain}"
|
return f"https://meet.{domain}"
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_subdomain() -> Optional[str]:
|
def get_subdomain(cls) -> Optional[str]:
|
||||||
return "meet"
|
return "meet"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -35,13 +35,13 @@ class MailServer(Service):
|
||||||
def get_user() -> str:
|
def get_user() -> str:
|
||||||
return "virtualMail"
|
return "virtualMail"
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_url() -> Optional[str]:
|
def get_url(cls) -> Optional[str]:
|
||||||
"""Return service url."""
|
"""Return service url."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_subdomain() -> Optional[str]:
|
def get_subdomain(cls) -> Optional[str]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -4,7 +4,6 @@ import subprocess
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from selfprivacy_api.utils import get_domain
|
from selfprivacy_api.utils import get_domain
|
||||||
from selfprivacy_api.jobs import Job, Jobs
|
|
||||||
|
|
||||||
from selfprivacy_api.utils.systemd import get_service_status
|
from selfprivacy_api.utils.systemd import get_service_status
|
||||||
from selfprivacy_api.services.service import Service, ServiceStatus
|
from selfprivacy_api.services.service import Service, ServiceStatus
|
||||||
|
@ -35,14 +34,14 @@ class Nextcloud(Service):
|
||||||
"""Read SVG icon from file and return it as base64 encoded string."""
|
"""Read SVG icon from file and return it as base64 encoded string."""
|
||||||
return base64.b64encode(NEXTCLOUD_ICON.encode("utf-8")).decode("utf-8")
|
return base64.b64encode(NEXTCLOUD_ICON.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_url() -> Optional[str]:
|
def get_url(cls) -> Optional[str]:
|
||||||
"""Return service url."""
|
"""Return service url."""
|
||||||
domain = get_domain()
|
domain = get_domain()
|
||||||
return f"https://cloud.{domain}"
|
return f"https://cloud.{domain}"
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_subdomain() -> Optional[str]:
|
def get_subdomain(cls) -> Optional[str]:
|
||||||
return "cloud"
|
return "cloud"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -28,13 +28,13 @@ class Ocserv(Service):
|
||||||
def get_svg_icon() -> str:
|
def get_svg_icon() -> str:
|
||||||
return base64.b64encode(OCSERV_ICON.encode("utf-8")).decode("utf-8")
|
return base64.b64encode(OCSERV_ICON.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_url() -> typing.Optional[str]:
|
def get_url(cls) -> typing.Optional[str]:
|
||||||
"""Return service url."""
|
"""Return service url."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_subdomain() -> typing.Optional[str]:
|
def get_subdomain(cls) -> typing.Optional[str]:
|
||||||
return "vpn"
|
return "vpn"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -31,14 +31,14 @@ class Pleroma(Service):
|
||||||
def get_svg_icon() -> str:
|
def get_svg_icon() -> str:
|
||||||
return base64.b64encode(PLEROMA_ICON.encode("utf-8")).decode("utf-8")
|
return base64.b64encode(PLEROMA_ICON.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_url() -> Optional[str]:
|
def get_url(cls) -> Optional[str]:
|
||||||
"""Return service url."""
|
"""Return service url."""
|
||||||
domain = get_domain()
|
domain = get_domain()
|
||||||
return f"https://social.{domain}"
|
return f"https://social.{domain}"
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_subdomain() -> Optional[str]:
|
def get_subdomain(cls) -> Optional[str]:
|
||||||
return "social"
|
return "social"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
113
selfprivacy_api/services/roundcube/__init__.py
Normal file
113
selfprivacy_api/services/roundcube/__init__.py
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
"""Class representing Roundcube service"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import subprocess
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from selfprivacy_api.jobs import Job
|
||||||
|
from selfprivacy_api.utils.systemd import (
|
||||||
|
get_service_status_from_several_units,
|
||||||
|
)
|
||||||
|
from selfprivacy_api.services.service import Service, ServiceStatus
|
||||||
|
from selfprivacy_api.utils import ReadUserData, get_domain
|
||||||
|
from selfprivacy_api.utils.block_devices import BlockDevice
|
||||||
|
from selfprivacy_api.services.roundcube.icon import ROUNDCUBE_ICON
|
||||||
|
|
||||||
|
|
||||||
|
class Roundcube(Service):
|
||||||
|
"""Class representing roundcube service"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_id() -> str:
|
||||||
|
"""Return service id."""
|
||||||
|
return "roundcube"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_display_name() -> str:
|
||||||
|
"""Return service display name."""
|
||||||
|
return "Roundcube"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_description() -> str:
|
||||||
|
"""Return service description."""
|
||||||
|
return "Roundcube is an open source webmail software."
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_svg_icon() -> str:
|
||||||
|
"""Read SVG icon from file and return it as base64 encoded string."""
|
||||||
|
return base64.b64encode(ROUNDCUBE_ICON.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_url(cls) -> Optional[str]:
|
||||||
|
"""Return service url."""
|
||||||
|
domain = get_domain()
|
||||||
|
subdomain = cls.get_subdomain()
|
||||||
|
return f"https://{subdomain}.{domain}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_subdomain(cls) -> Optional[str]:
|
||||||
|
with ReadUserData() as data:
|
||||||
|
if "roundcube" in data["modules"]:
|
||||||
|
return data["modules"]["roundcube"]["subdomain"]
|
||||||
|
|
||||||
|
return "roundcube"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_movable() -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_required() -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def can_be_backed_up() -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_backup_description() -> str:
|
||||||
|
return "Nothing to backup."
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_status() -> ServiceStatus:
|
||||||
|
return get_service_status_from_several_units(["phpfpm-roundcube.service"])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def stop():
|
||||||
|
subprocess.run(
|
||||||
|
["systemctl", "stop", "phpfpm-roundcube.service"],
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def start():
|
||||||
|
subprocess.run(
|
||||||
|
["systemctl", "start", "phpfpm-roundcube.service"],
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def restart():
|
||||||
|
subprocess.run(
|
||||||
|
["systemctl", "restart", "phpfpm-roundcube.service"],
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_configuration():
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_configuration(config_items):
|
||||||
|
return super().set_configuration(config_items)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_logs():
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_folders() -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def move_to_volume(self, volume: BlockDevice) -> Job:
|
||||||
|
raise NotImplementedError("roundcube service is not movable")
|
7
selfprivacy_api/services/roundcube/icon.py
Normal file
7
selfprivacy_api/services/roundcube/icon.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
ROUNDCUBE_ICON = """
|
||||||
|
<svg fill="none" version="1.1" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<g transform="translate(29.07 -.3244)">
|
||||||
|
<path d="m-17.02 2.705c-4.01 2e-7 -7.283 3.273-7.283 7.283 0 0.00524-1.1e-5 0.01038 0 0.01562l-1.85 1.068v5.613l9.105 5.26 9.104-5.26v-5.613l-1.797-1.037c1.008e-4 -0.01573 0.00195-0.03112 0.00195-0.04688-1e-7 -4.01-3.271-7.283-7.281-7.283zm0 2.012c2.923 1e-7 5.27 2.349 5.27 5.271 0 2.923-2.347 5.27-5.27 5.27-2.923-1e-6 -5.271-2.347-5.271-5.27 0-2.923 2.349-5.271 5.271-5.271z" fill="#000" fill-rule="evenodd" stroke-linejoin="bevel"/>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
"""
|
|
@ -65,17 +65,17 @@ class Service(ABC):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_url() -> Optional[str]:
|
def get_url(cls) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
The url of the service if it is accessible from the internet browser.
|
The url of the service if it is accessible from the internet browser.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_subdomain() -> Optional[str]:
|
def get_subdomain(cls) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
The assigned primary subdomain for this service.
|
The assigned primary subdomain for this service.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -57,14 +57,14 @@ class DummyService(Service):
|
||||||
# return ""
|
# return ""
|
||||||
return base64.b64encode(BITWARDEN_ICON.encode("utf-8")).decode("utf-8")
|
return base64.b64encode(BITWARDEN_ICON.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_url() -> typing.Optional[str]:
|
def get_url(cls) -> typing.Optional[str]:
|
||||||
"""Return service url."""
|
"""Return service url."""
|
||||||
domain = "test.com"
|
domain = "test.com"
|
||||||
return f"https://password.{domain}"
|
return f"https://password.{domain}"
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_subdomain() -> typing.Optional[str]:
|
def get_subdomain(cls) -> typing.Optional[str]:
|
||||||
return "password"
|
return "password"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -1,15 +1,23 @@
|
||||||
|
import uuid
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
def store_model_as_hash(redis, redis_key, model):
|
def store_model_as_hash(redis, redis_key, model):
|
||||||
for key, value in model.dict().items():
|
model_dict = model.dict()
|
||||||
|
for key, value in model_dict.items():
|
||||||
|
if isinstance(value, uuid.UUID):
|
||||||
|
value = str(value)
|
||||||
if isinstance(value, datetime):
|
if isinstance(value, datetime):
|
||||||
value = value.isoformat()
|
value = value.isoformat()
|
||||||
if isinstance(value, Enum):
|
if isinstance(value, Enum):
|
||||||
value = value.value
|
value = value.value
|
||||||
redis.hset(redis_key, key, str(value))
|
value = str(value)
|
||||||
|
model_dict[key] = value
|
||||||
|
|
||||||
|
redis.hset(redis_key, mapping=model_dict)
|
||||||
|
|
||||||
|
|
||||||
def hash_as_model(redis, redis_key: str, model_class):
|
def hash_as_model(redis, redis_key: str, model_class):
|
||||||
|
|
|
@ -2,23 +2,33 @@
|
||||||
Redis pool module for selfprivacy_api
|
Redis pool module for selfprivacy_api
|
||||||
"""
|
"""
|
||||||
import redis
|
import redis
|
||||||
|
import redis.asyncio as redis_async
|
||||||
|
|
||||||
from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass
|
from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass
|
||||||
|
|
||||||
REDIS_SOCKET = "/run/redis-sp-api/redis.sock"
|
REDIS_SOCKET = "/run/redis-sp-api/redis.sock"
|
||||||
|
|
||||||
|
|
||||||
class RedisPool(metaclass=SingletonMetaclass):
|
# class RedisPool(metaclass=SingletonMetaclass):
|
||||||
|
class RedisPool:
|
||||||
"""
|
"""
|
||||||
Redis connection pool singleton.
|
Redis connection pool singleton.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
self._dbnumber = 0
|
||||||
|
url = RedisPool.connection_url(dbnumber=self._dbnumber)
|
||||||
|
# We need a normal sync pool because otherwise
|
||||||
|
# our whole API will need to be async
|
||||||
self._pool = redis.ConnectionPool.from_url(
|
self._pool = redis.ConnectionPool.from_url(
|
||||||
RedisPool.connection_url(dbnumber=0),
|
url,
|
||||||
|
decode_responses=True,
|
||||||
|
)
|
||||||
|
# We need an async pool for pubsub
|
||||||
|
self._async_pool = redis_async.ConnectionPool.from_url(
|
||||||
|
url,
|
||||||
decode_responses=True,
|
decode_responses=True,
|
||||||
)
|
)
|
||||||
self._pubsub_connection = self.get_connection()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def connection_url(dbnumber: int) -> str:
|
def connection_url(dbnumber: int) -> str:
|
||||||
|
@ -34,8 +44,15 @@ class RedisPool(metaclass=SingletonMetaclass):
|
||||||
"""
|
"""
|
||||||
return redis.Redis(connection_pool=self._pool)
|
return redis.Redis(connection_pool=self._pool)
|
||||||
|
|
||||||
def get_pubsub(self):
|
def get_connection_async(self) -> redis_async.Redis:
|
||||||
"""
|
"""
|
||||||
Get a pubsub connection from the pool.
|
Get an async connection from the pool.
|
||||||
|
Async connections allow pubsub.
|
||||||
"""
|
"""
|
||||||
return self._pubsub_connection.pubsub()
|
return redis_async.Redis(connection_pool=self._async_pool)
|
||||||
|
|
||||||
|
async def subscribe_to_keys(self, pattern: str) -> redis_async.client.PubSub:
|
||||||
|
async_redis = self.get_connection_async()
|
||||||
|
pubsub = async_redis.pubsub()
|
||||||
|
await pubsub.psubscribe(f"__keyspace@{self._dbnumber}__:" + pattern)
|
||||||
|
return pubsub
|
||||||
|
|
55
selfprivacy_api/utils/systemd_journal.py
Normal file
55
selfprivacy_api/utils/systemd_journal.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
import typing
|
||||||
|
from systemd import journal
|
||||||
|
|
||||||
|
|
||||||
|
def get_events_from_journal(
|
||||||
|
j: journal.Reader, limit: int, next: typing.Callable[[journal.Reader], typing.Dict]
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
i = 0
|
||||||
|
while i < limit:
|
||||||
|
entry = next(j)
|
||||||
|
if entry is None or entry == dict():
|
||||||
|
break
|
||||||
|
if entry["MESSAGE"] != "":
|
||||||
|
events.append(entry)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
def get_paginated_logs(
|
||||||
|
limit: int = 20,
|
||||||
|
# All entries returned will be lesser than this cursor. Sets upper bound on results.
|
||||||
|
up_cursor: str | None = None,
|
||||||
|
# All entries returned will be greater than this cursor. Sets lower bound on results.
|
||||||
|
down_cursor: str | None = None,
|
||||||
|
):
|
||||||
|
j = journal.Reader()
|
||||||
|
|
||||||
|
if up_cursor is None and down_cursor is None:
|
||||||
|
j.seek_tail()
|
||||||
|
|
||||||
|
events = get_events_from_journal(j, limit, lambda j: j.get_previous())
|
||||||
|
events.reverse()
|
||||||
|
|
||||||
|
return events
|
||||||
|
elif up_cursor is None and down_cursor is not None:
|
||||||
|
j.seek_cursor(down_cursor)
|
||||||
|
j.get_previous() # pagination is exclusive
|
||||||
|
|
||||||
|
events = get_events_from_journal(j, limit, lambda j: j.get_previous())
|
||||||
|
events.reverse()
|
||||||
|
|
||||||
|
return events
|
||||||
|
elif up_cursor is not None and down_cursor is None:
|
||||||
|
j.seek_cursor(up_cursor)
|
||||||
|
j.get_next() # pagination is exclusive
|
||||||
|
|
||||||
|
events = get_events_from_journal(j, limit, lambda j: j.get_next())
|
||||||
|
|
||||||
|
return events
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Pagination by both up_cursor and down_cursor is not implemented"
|
||||||
|
)
|
2
setup.py
2
setup.py
|
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="selfprivacy_api",
|
name="selfprivacy_api",
|
||||||
version="3.2.2",
|
version="3.3.0",
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
scripts=[
|
scripts=[
|
||||||
"selfprivacy_api/app.py",
|
"selfprivacy_api/app.py",
|
||||||
|
|
|
@ -69,10 +69,22 @@ def generate_backup_query(query_array):
|
||||||
return "query TestBackup {\n backup {" + "\n".join(query_array) + "}\n}"
|
return "query TestBackup {\n backup {" + "\n".join(query_array) + "}\n}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_jobs_query(query_array):
|
||||||
|
return "query TestJobs {\n jobs {" + "\n".join(query_array) + "}\n}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_jobs_subscription(query_array):
|
||||||
|
return "subscription TestSubscription {\n jobs {" + "\n".join(query_array) + "}\n}"
|
||||||
|
|
||||||
|
|
||||||
def generate_service_query(query_array):
|
def generate_service_query(query_array):
|
||||||
return "query TestService {\n services {" + "\n".join(query_array) + "}\n}"
|
return "query TestService {\n services {" + "\n".join(query_array) + "}\n}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_logs_query(query_array):
|
||||||
|
return "query TestService {\n logs {" + "\n".join(query_array) + "}\n}"
|
||||||
|
|
||||||
|
|
||||||
def mnemonic_to_hex(mnemonic):
|
def mnemonic_to_hex(mnemonic):
|
||||||
return Mnemonic(language="english").to_entropy(mnemonic).hex()
|
return Mnemonic(language="english").to_entropy(mnemonic).hex()
|
||||||
|
|
||||||
|
|
|
@ -62,6 +62,9 @@
|
||||||
"simple-nixos-mailserver": {
|
"simple-nixos-mailserver": {
|
||||||
"enable": true,
|
"enable": true,
|
||||||
"location": "sdb"
|
"location": "sdb"
|
||||||
|
},
|
||||||
|
"roundcube": {
|
||||||
|
"enable": true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"volumes": [
|
"volumes": [
|
||||||
|
|
|
@ -4,40 +4,40 @@ from selfprivacy_api.services.flake_service_manager import FlakeServiceManager
|
||||||
|
|
||||||
all_services_file = """
|
all_services_file = """
|
||||||
{
|
{
|
||||||
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
|
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
|
||||||
|
|
||||||
|
|
||||||
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
|
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
|
||||||
|
|
||||||
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
|
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
|
||||||
|
|
||||||
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
|
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
|
||||||
|
|
||||||
inputs.nextcloud.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/nextcloud;
|
inputs.nextcloud.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/nextcloud;
|
||||||
|
|
||||||
inputs.ocserv.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/ocserv;
|
inputs.ocserv.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/ocserv;
|
||||||
|
|
||||||
inputs.pleroma.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/pleroma;
|
inputs.pleroma.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/pleroma;
|
||||||
|
|
||||||
inputs.simple-nixos-mailserver.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/simple-nixos-mailserver;
|
inputs.simple-nixos-mailserver.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/simple-nixos-mailserver;
|
||||||
|
|
||||||
outputs = _: { };
|
outputs = _: { };
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
some_services_file = """
|
some_services_file = """
|
||||||
{
|
{
|
||||||
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
|
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
|
||||||
|
|
||||||
|
|
||||||
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
|
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
|
||||||
|
|
||||||
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
|
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
|
||||||
|
|
||||||
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
|
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
|
||||||
|
|
||||||
outputs = _: { };
|
outputs = _: { };
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
{
|
{
|
||||||
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
|
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
|
||||||
outputs = _: { };
|
outputs = _: { };
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
{
|
{
|
||||||
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
|
description = "SelfPrivacy NixOS PoC modules/extensions/bundles/packages/etc";
|
||||||
|
|
||||||
|
|
||||||
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
|
inputs.bitwarden.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/bitwarden;
|
||||||
|
|
||||||
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
|
inputs.gitea.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/gitea;
|
||||||
|
|
||||||
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
|
inputs.jitsi-meet.url = git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes&dir=sp-modules/jitsi-meet;
|
||||||
|
|
||||||
outputs = _: { };
|
outputs = _: { };
|
||||||
}
|
}
|
||||||
|
|
172
tests/test_graphql/test_api_logs.py
Normal file
172
tests/test_graphql/test_api_logs.py
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime
|
||||||
|
from systemd import journal
|
||||||
|
|
||||||
|
from tests.test_graphql.test_websocket import init_graphql
|
||||||
|
|
||||||
|
|
||||||
|
def assert_log_entry_equals_to_journal_entry(api_entry, journal_entry):
|
||||||
|
assert api_entry["message"] == journal_entry["MESSAGE"]
|
||||||
|
assert (
|
||||||
|
datetime.fromisoformat(api_entry["timestamp"])
|
||||||
|
== journal_entry["__REALTIME_TIMESTAMP"]
|
||||||
|
)
|
||||||
|
assert api_entry.get("priority") == journal_entry.get("PRIORITY")
|
||||||
|
assert api_entry.get("systemdUnit") == journal_entry.get("_SYSTEMD_UNIT")
|
||||||
|
assert api_entry.get("systemdSlice") == journal_entry.get("_SYSTEMD_SLICE")
|
||||||
|
|
||||||
|
|
||||||
|
def take_from_journal(j, limit, next):
|
||||||
|
entries = []
|
||||||
|
for _ in range(0, limit):
|
||||||
|
entry = next(j)
|
||||||
|
if entry["MESSAGE"] != "":
|
||||||
|
entries.append(entry)
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
API_GET_LOGS_WITH_UP_BORDER = """
|
||||||
|
query TestQuery($upCursor: String) {
|
||||||
|
logs {
|
||||||
|
paginated(limit: 4, upCursor: $upCursor) {
|
||||||
|
pageMeta {
|
||||||
|
upCursor
|
||||||
|
downCursor
|
||||||
|
}
|
||||||
|
entries {
|
||||||
|
message
|
||||||
|
timestamp
|
||||||
|
priority
|
||||||
|
systemdUnit
|
||||||
|
systemdSlice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
API_GET_LOGS_WITH_DOWN_BORDER = """
|
||||||
|
query TestQuery($downCursor: String) {
|
||||||
|
logs {
|
||||||
|
paginated(limit: 4, downCursor: $downCursor) {
|
||||||
|
pageMeta {
|
||||||
|
upCursor
|
||||||
|
downCursor
|
||||||
|
}
|
||||||
|
entries {
|
||||||
|
message
|
||||||
|
timestamp
|
||||||
|
priority
|
||||||
|
systemdUnit
|
||||||
|
systemdSlice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_graphql_get_logs_with_up_border(authorized_client):
|
||||||
|
j = journal.Reader()
|
||||||
|
j.seek_tail()
|
||||||
|
|
||||||
|
# < - cursor
|
||||||
|
# <- - log entry will be returned by API call.
|
||||||
|
# ...
|
||||||
|
# log <
|
||||||
|
# log <-
|
||||||
|
# log <-
|
||||||
|
# log <-
|
||||||
|
# log <-
|
||||||
|
# log
|
||||||
|
|
||||||
|
expected_entries = take_from_journal(j, 6, lambda j: j.get_previous())
|
||||||
|
expected_entries.reverse()
|
||||||
|
|
||||||
|
response = authorized_client.post(
|
||||||
|
"/graphql",
|
||||||
|
json={
|
||||||
|
"query": API_GET_LOGS_WITH_UP_BORDER,
|
||||||
|
"variables": {"upCursor": expected_entries[0]["__CURSOR"]},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
expected_entries = expected_entries[1:-1]
|
||||||
|
returned_entries = response.json()["data"]["logs"]["paginated"]["entries"]
|
||||||
|
|
||||||
|
assert len(returned_entries) == len(expected_entries)
|
||||||
|
|
||||||
|
for api_entry, journal_entry in zip(returned_entries, expected_entries):
|
||||||
|
assert_log_entry_equals_to_journal_entry(api_entry, journal_entry)
|
||||||
|
|
||||||
|
|
||||||
|
def test_graphql_get_logs_with_down_border(authorized_client):
|
||||||
|
j = journal.Reader()
|
||||||
|
j.seek_head()
|
||||||
|
j.get_next()
|
||||||
|
|
||||||
|
# < - cursor
|
||||||
|
# <- - log entry will be returned by API call.
|
||||||
|
# log
|
||||||
|
# log <-
|
||||||
|
# log <-
|
||||||
|
# log <-
|
||||||
|
# log <-
|
||||||
|
# log <
|
||||||
|
# ...
|
||||||
|
|
||||||
|
expected_entries = take_from_journal(j, 5, lambda j: j.get_next())
|
||||||
|
|
||||||
|
response = authorized_client.post(
|
||||||
|
"/graphql",
|
||||||
|
json={
|
||||||
|
"query": API_GET_LOGS_WITH_DOWN_BORDER,
|
||||||
|
"variables": {"downCursor": expected_entries[-1]["__CURSOR"]},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
expected_entries = expected_entries[:-1]
|
||||||
|
returned_entries = response.json()["data"]["logs"]["paginated"]["entries"]
|
||||||
|
|
||||||
|
assert len(returned_entries) == len(expected_entries)
|
||||||
|
|
||||||
|
for api_entry, journal_entry in zip(returned_entries, expected_entries):
|
||||||
|
assert_log_entry_equals_to_journal_entry(api_entry, journal_entry)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_websocket_subscription_for_logs(authorized_client):
|
||||||
|
with authorized_client.websocket_connect(
|
||||||
|
"/graphql", subprotocols=["graphql-transport-ws"]
|
||||||
|
) as websocket:
|
||||||
|
init_graphql(websocket)
|
||||||
|
websocket.send_json(
|
||||||
|
{
|
||||||
|
"id": "3aaa2445",
|
||||||
|
"type": "subscribe",
|
||||||
|
"payload": {
|
||||||
|
"query": "subscription TestSubscription { logEntries { message } }",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
def read_until(message, limit=5):
|
||||||
|
i = 0
|
||||||
|
while i < limit:
|
||||||
|
msg = websocket.receive_json()["payload"]["data"]["logEntries"][
|
||||||
|
"message"
|
||||||
|
]
|
||||||
|
if msg == message:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
raise Exception("Failed to read websocket data, timeout")
|
||||||
|
|
||||||
|
for i in range(0, 10):
|
||||||
|
journal.send(f"Lorem ipsum number {i}")
|
||||||
|
read_until(f"Lorem ipsum number {i}")
|
74
tests/test_graphql/test_jobs.py
Normal file
74
tests/test_graphql/test_jobs.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
from tests.common import generate_jobs_query
|
||||||
|
import tests.test_graphql.test_api_backup
|
||||||
|
|
||||||
|
from tests.test_graphql.common import (
|
||||||
|
assert_ok,
|
||||||
|
assert_empty,
|
||||||
|
assert_errorcode,
|
||||||
|
get_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
from selfprivacy_api.jobs import Jobs
|
||||||
|
|
||||||
|
API_JOBS_QUERY = """
|
||||||
|
getJobs {
|
||||||
|
uid
|
||||||
|
typeId
|
||||||
|
name
|
||||||
|
description
|
||||||
|
status
|
||||||
|
statusText
|
||||||
|
progress
|
||||||
|
createdAt
|
||||||
|
updatedAt
|
||||||
|
finishedAt
|
||||||
|
error
|
||||||
|
result
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def graphql_send_query(client, query: str, variables: dict = {}):
|
||||||
|
return client.post("/graphql", json={"query": query, "variables": variables})
|
||||||
|
|
||||||
|
|
||||||
|
def api_jobs(authorized_client):
|
||||||
|
response = graphql_send_query(
|
||||||
|
authorized_client, generate_jobs_query([API_JOBS_QUERY])
|
||||||
|
)
|
||||||
|
data = get_data(response)
|
||||||
|
result = data["jobs"]["getJobs"]
|
||||||
|
assert result is not None
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_jobs_unauthorized(client):
|
||||||
|
response = graphql_send_query(client, generate_jobs_query([API_JOBS_QUERY]))
|
||||||
|
assert_empty(response)
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_jobs_when_none(authorized_client):
|
||||||
|
output = api_jobs(authorized_client)
|
||||||
|
assert output == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_jobs_when_some(authorized_client):
|
||||||
|
# We cannot make new jobs via API, at least directly
|
||||||
|
job = Jobs.add("bogus", "bogus.bogus", "fungus")
|
||||||
|
output = api_jobs(authorized_client)
|
||||||
|
|
||||||
|
len(output) == 1
|
||||||
|
api_job = output[0]
|
||||||
|
|
||||||
|
assert api_job["uid"] == str(job.uid)
|
||||||
|
assert api_job["typeId"] == job.type_id
|
||||||
|
assert api_job["name"] == job.name
|
||||||
|
assert api_job["description"] == job.description
|
||||||
|
assert api_job["status"] == job.status
|
||||||
|
assert api_job["statusText"] == job.status_text
|
||||||
|
assert api_job["progress"] == job.progress
|
||||||
|
assert api_job["createdAt"] == job.created_at.isoformat()
|
||||||
|
assert api_job["updatedAt"] == job.updated_at.isoformat()
|
||||||
|
assert api_job["finishedAt"] == None
|
||||||
|
assert api_job["error"] == None
|
||||||
|
assert api_job["result"] == None
|
|
@ -543,8 +543,8 @@ def test_disable_enable(authorized_client, only_dummy_service):
|
||||||
assert api_dummy_service["status"] == ServiceStatus.ACTIVE.value
|
assert api_dummy_service["status"] == ServiceStatus.ACTIVE.value
|
||||||
|
|
||||||
|
|
||||||
def test_move_immovable(authorized_client, only_dummy_service):
|
def test_move_immovable(authorized_client, dummy_service_with_binds):
|
||||||
dummy_service = only_dummy_service
|
dummy_service = dummy_service_with_binds
|
||||||
dummy_service.set_movable(False)
|
dummy_service.set_movable(False)
|
||||||
root = BlockDevices().get_root_block_device()
|
root = BlockDevices().get_root_block_device()
|
||||||
mutation_response = api_move(authorized_client, dummy_service, root.name)
|
mutation_response = api_move(authorized_client, dummy_service, root.name)
|
||||||
|
|
225
tests/test_graphql/test_websocket.py
Normal file
225
tests/test_graphql/test_websocket.py
Normal file
|
@ -0,0 +1,225 @@
|
||||||
|
# from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from typing import Generator
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
from starlette.testclient import WebSocketTestSession
|
||||||
|
|
||||||
|
from selfprivacy_api.jobs import Jobs
|
||||||
|
from selfprivacy_api.actions.api_tokens import TOKEN_REPO
|
||||||
|
from selfprivacy_api.graphql import IsAuthenticated
|
||||||
|
|
||||||
|
from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH
|
||||||
|
from tests.test_jobs import jobs as empty_jobs
|
||||||
|
|
||||||
|
# We do not iterate through them yet
|
||||||
|
TESTED_SUBPROTOCOLS = ["graphql-transport-ws"]
|
||||||
|
|
||||||
|
JOBS_SUBSCRIPTION = """
|
||||||
|
jobUpdates {
|
||||||
|
uid
|
||||||
|
typeId
|
||||||
|
name
|
||||||
|
description
|
||||||
|
status
|
||||||
|
statusText
|
||||||
|
progress
|
||||||
|
createdAt
|
||||||
|
updatedAt
|
||||||
|
finishedAt
|
||||||
|
error
|
||||||
|
result
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def api_subscribe(websocket, id, subscription):
|
||||||
|
websocket.send_json(
|
||||||
|
{
|
||||||
|
"id": id,
|
||||||
|
"type": "subscribe",
|
||||||
|
"payload": {
|
||||||
|
"query": "subscription TestSubscription {" + subscription + "}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def connect_ws_authenticated(authorized_client) -> WebSocketTestSession:
|
||||||
|
token = "Bearer " + str(DEVICE_WE_AUTH_TESTS_WITH["token"])
|
||||||
|
return authorized_client.websocket_connect(
|
||||||
|
"/graphql",
|
||||||
|
subprotocols=TESTED_SUBPROTOCOLS,
|
||||||
|
params={"token": token},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def connect_ws_not_authenticated(client) -> WebSocketTestSession:
|
||||||
|
return client.websocket_connect(
|
||||||
|
"/graphql",
|
||||||
|
subprotocols=TESTED_SUBPROTOCOLS,
|
||||||
|
params={"token": "I like vegan icecream but it is not a valid token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_graphql(websocket):
|
||||||
|
websocket.send_json({"type": "connection_init", "payload": {}})
|
||||||
|
ack = websocket.receive_json()
|
||||||
|
assert ack == {"type": "connection_ack"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def authenticated_websocket(
|
||||||
|
authorized_client,
|
||||||
|
) -> Generator[WebSocketTestSession, None, None]:
|
||||||
|
# We use authorized_client only to have token in the repo, this client by itself is not enough to authorize websocket
|
||||||
|
|
||||||
|
ValueError(TOKEN_REPO.get_tokens())
|
||||||
|
with connect_ws_authenticated(authorized_client) as websocket:
|
||||||
|
yield websocket
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def unauthenticated_websocket(client) -> Generator[WebSocketTestSession, None, None]:
|
||||||
|
with connect_ws_not_authenticated(client) as websocket:
|
||||||
|
yield websocket
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_connection_bare(authorized_client):
|
||||||
|
client = authorized_client
|
||||||
|
with client.websocket_connect(
|
||||||
|
"/graphql", subprotocols=["graphql-transport-ws", "graphql-ws"]
|
||||||
|
) as websocket:
|
||||||
|
assert websocket is not None
|
||||||
|
assert websocket.scope is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_graphql_init(authorized_client):
|
||||||
|
client = authorized_client
|
||||||
|
with client.websocket_connect(
|
||||||
|
"/graphql", subprotocols=["graphql-transport-ws"]
|
||||||
|
) as websocket:
|
||||||
|
websocket.send_json({"type": "connection_init", "payload": {}})
|
||||||
|
ack = websocket.receive_json()
|
||||||
|
assert ack == {"type": "connection_ack"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_graphql_ping(authorized_client):
|
||||||
|
client = authorized_client
|
||||||
|
with client.websocket_connect(
|
||||||
|
"/graphql", subprotocols=["graphql-transport-ws"]
|
||||||
|
) as websocket:
|
||||||
|
# https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping
|
||||||
|
websocket.send_json({"type": "ping", "payload": {}})
|
||||||
|
pong = websocket.receive_json()
|
||||||
|
assert pong == {"type": "pong"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_subscription_minimal(authorized_client, authenticated_websocket):
|
||||||
|
# Test a small endpoint that exists specifically for tests
|
||||||
|
websocket = authenticated_websocket
|
||||||
|
init_graphql(websocket)
|
||||||
|
arbitrary_id = "3aaa2445"
|
||||||
|
api_subscribe(websocket, arbitrary_id, "count")
|
||||||
|
response = websocket.receive_json()
|
||||||
|
assert response == {
|
||||||
|
"id": arbitrary_id,
|
||||||
|
"payload": {"data": {"count": 0}},
|
||||||
|
"type": "next",
|
||||||
|
}
|
||||||
|
response = websocket.receive_json()
|
||||||
|
assert response == {
|
||||||
|
"id": arbitrary_id,
|
||||||
|
"payload": {"data": {"count": 1}},
|
||||||
|
"type": "next",
|
||||||
|
}
|
||||||
|
response = websocket.receive_json()
|
||||||
|
assert response == {
|
||||||
|
"id": arbitrary_id,
|
||||||
|
"payload": {"data": {"count": 2}},
|
||||||
|
"type": "next",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket):
|
||||||
|
websocket = unauthenticated_websocket
|
||||||
|
init_graphql(websocket)
|
||||||
|
arbitrary_id = "3aaa2445"
|
||||||
|
api_subscribe(websocket, arbitrary_id, "count")
|
||||||
|
|
||||||
|
response = websocket.receive_json()
|
||||||
|
assert response == {
|
||||||
|
"id": arbitrary_id,
|
||||||
|
"payload": [{"message": IsAuthenticated.message}],
|
||||||
|
"type": "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def read_one_job(websocket):
|
||||||
|
# Bug? We only get them starting from the second job update
|
||||||
|
# That's why we receive two jobs in the list them
|
||||||
|
# The first update gets lost somewhere
|
||||||
|
response = websocket.receive_json()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs):
|
||||||
|
websocket = authenticated_websocket
|
||||||
|
init_graphql(websocket)
|
||||||
|
arbitrary_id = "3aaa2445"
|
||||||
|
api_subscribe(websocket, arbitrary_id, JOBS_SUBSCRIPTION)
|
||||||
|
|
||||||
|
future = asyncio.create_task(read_one_job(websocket))
|
||||||
|
jobs = []
|
||||||
|
jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works"))
|
||||||
|
sleep(0.5)
|
||||||
|
jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works"))
|
||||||
|
|
||||||
|
response = await future
|
||||||
|
data = response["payload"]["data"]
|
||||||
|
jobs_received = data["jobUpdates"]
|
||||||
|
received_names = [job["name"] for job in jobs_received]
|
||||||
|
for job in jobs:
|
||||||
|
assert job.name in received_names
|
||||||
|
|
||||||
|
assert len(jobs_received) == 2
|
||||||
|
|
||||||
|
for job in jobs:
|
||||||
|
for api_job in jobs_received:
|
||||||
|
if (job.name) == api_job["name"]:
|
||||||
|
assert api_job["uid"] == str(job.uid)
|
||||||
|
assert api_job["typeId"] == job.type_id
|
||||||
|
assert api_job["name"] == job.name
|
||||||
|
assert api_job["description"] == job.description
|
||||||
|
assert api_job["status"] == job.status
|
||||||
|
assert api_job["statusText"] == job.status_text
|
||||||
|
assert api_job["progress"] == job.progress
|
||||||
|
assert api_job["createdAt"] == job.created_at.isoformat()
|
||||||
|
assert api_job["updatedAt"] == job.updated_at.isoformat()
|
||||||
|
assert api_job["finishedAt"] == None
|
||||||
|
assert api_job["error"] == None
|
||||||
|
assert api_job["result"] == None
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_subscription_unauthorized(unauthenticated_websocket):
|
||||||
|
websocket = unauthenticated_websocket
|
||||||
|
init_graphql(websocket)
|
||||||
|
id = "3aaa2445"
|
||||||
|
api_subscribe(websocket, id, JOBS_SUBSCRIPTION)
|
||||||
|
|
||||||
|
response = websocket.receive_json()
|
||||||
|
# I do not really know why strawberry gives more info on this
|
||||||
|
# One versus the counter
|
||||||
|
payload = response["payload"][0]
|
||||||
|
assert isinstance(payload, dict)
|
||||||
|
assert "locations" in payload.keys()
|
||||||
|
# It looks like this 'locations': [{'column': 32, 'line': 1}]
|
||||||
|
# We cannot test locations feasibly
|
||||||
|
del payload["locations"]
|
||||||
|
assert response == {
|
||||||
|
"id": id,
|
||||||
|
"payload": [{"message": IsAuthenticated.message, "path": ["jobUpdates"]}],
|
||||||
|
"type": "error",
|
||||||
|
}
|
218
tests/test_redis.py
Normal file
218
tests/test_redis.py
Normal file
|
@ -0,0 +1,218 @@
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from asyncio import streams
|
||||||
|
import redis
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from selfprivacy_api.utils.redis_pool import RedisPool
|
||||||
|
|
||||||
|
from selfprivacy_api.jobs import Jobs, job_notifications
|
||||||
|
|
||||||
|
TEST_KEY = "test:test"
|
||||||
|
STOPWORD = "STOP"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def empty_redis(event_loop):
|
||||||
|
r = RedisPool().get_connection()
|
||||||
|
r.flushdb()
|
||||||
|
assert r.config_get("notify-keyspace-events")["notify-keyspace-events"] == "AKE"
|
||||||
|
yield r
|
||||||
|
r.flushdb()
|
||||||
|
|
||||||
|
|
||||||
|
async def write_to_test_key():
|
||||||
|
r = RedisPool().get_connection_async()
|
||||||
|
async with r.pipeline(transaction=True) as pipe:
|
||||||
|
ok1, ok2 = await pipe.set(TEST_KEY, "value1").set(TEST_KEY, "value2").execute()
|
||||||
|
assert ok1
|
||||||
|
assert ok2
|
||||||
|
assert await r.get(TEST_KEY) == "value2"
|
||||||
|
await r.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_connection(empty_redis):
|
||||||
|
r = RedisPool().get_connection()
|
||||||
|
assert not r.exists(TEST_KEY)
|
||||||
|
# It _will_ report an error if it arises
|
||||||
|
asyncio.run(write_to_test_key())
|
||||||
|
# Confirming that we can read result from sync connection too
|
||||||
|
assert r.get(TEST_KEY) == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
async def channel_reader(channel: redis.client.PubSub) -> List[dict]:
|
||||||
|
result: List[dict] = []
|
||||||
|
while True:
|
||||||
|
# Mypy cannot correctly detect that it is a coroutine
|
||||||
|
# But it is
|
||||||
|
message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore
|
||||||
|
if message is not None:
|
||||||
|
result.append(message)
|
||||||
|
if message["data"] == STOPWORD:
|
||||||
|
break
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def channel_reader_onemessage(channel: redis.client.PubSub) -> dict:
|
||||||
|
while True:
|
||||||
|
# Mypy cannot correctly detect that it is a coroutine
|
||||||
|
# But it is
|
||||||
|
message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore
|
||||||
|
if message is not None:
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pubsub(empty_redis, event_loop):
|
||||||
|
# Adapted from :
|
||||||
|
# https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html
|
||||||
|
# Sanity checking because of previous event loop bugs
|
||||||
|
assert event_loop == asyncio.get_event_loop()
|
||||||
|
assert event_loop == asyncio.events.get_event_loop()
|
||||||
|
assert event_loop == asyncio.events._get_event_loop()
|
||||||
|
assert event_loop == asyncio.events.get_running_loop()
|
||||||
|
|
||||||
|
reader = streams.StreamReader(34)
|
||||||
|
assert event_loop == reader._loop
|
||||||
|
f = reader._loop.create_future()
|
||||||
|
f.set_result(3)
|
||||||
|
await f
|
||||||
|
|
||||||
|
r = RedisPool().get_connection_async()
|
||||||
|
async with r.pubsub() as pubsub:
|
||||||
|
await pubsub.subscribe("channel:1")
|
||||||
|
future = asyncio.create_task(channel_reader(pubsub))
|
||||||
|
|
||||||
|
await r.publish("channel:1", "Hello")
|
||||||
|
# message: dict = await pubsub.get_message(ignore_subscribe_messages=True, timeout=5.0) # type: ignore
|
||||||
|
# raise ValueError(message)
|
||||||
|
await r.publish("channel:1", "World")
|
||||||
|
await r.publish("channel:1", STOPWORD)
|
||||||
|
|
||||||
|
messages = await future
|
||||||
|
|
||||||
|
assert len(messages) == 3
|
||||||
|
|
||||||
|
message = messages[0]
|
||||||
|
assert "data" in message.keys()
|
||||||
|
assert message["data"] == "Hello"
|
||||||
|
message = messages[1]
|
||||||
|
assert "data" in message.keys()
|
||||||
|
assert message["data"] == "World"
|
||||||
|
message = messages[2]
|
||||||
|
assert "data" in message.keys()
|
||||||
|
assert message["data"] == STOPWORD
|
||||||
|
|
||||||
|
await r.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keyspace_notifications_simple(empty_redis, event_loop):
|
||||||
|
r = RedisPool().get_connection_async()
|
||||||
|
await r.set(TEST_KEY, "I am not empty")
|
||||||
|
async with r.pubsub() as pubsub:
|
||||||
|
await pubsub.subscribe("__keyspace@0__:" + TEST_KEY)
|
||||||
|
|
||||||
|
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
|
||||||
|
empty_redis.set(TEST_KEY, "I am set!")
|
||||||
|
message = await future_message
|
||||||
|
assert message is not None
|
||||||
|
assert message["data"] is not None
|
||||||
|
assert message == {
|
||||||
|
"channel": f"__keyspace@0__:{TEST_KEY}",
|
||||||
|
"data": "set",
|
||||||
|
"pattern": None,
|
||||||
|
"type": "message",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keyspace_notifications(empty_redis, event_loop):
|
||||||
|
pubsub = await RedisPool().subscribe_to_keys(TEST_KEY)
|
||||||
|
async with pubsub:
|
||||||
|
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
|
||||||
|
empty_redis.set(TEST_KEY, "I am set!")
|
||||||
|
message = await future_message
|
||||||
|
assert message is not None
|
||||||
|
assert message["data"] is not None
|
||||||
|
assert message == {
|
||||||
|
"channel": f"__keyspace@0__:{TEST_KEY}",
|
||||||
|
"data": "set",
|
||||||
|
"pattern": f"__keyspace@0__:{TEST_KEY}",
|
||||||
|
"type": "pmessage",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keyspace_notifications_patterns(empty_redis, event_loop):
|
||||||
|
pattern = "test*"
|
||||||
|
pubsub = await RedisPool().subscribe_to_keys(pattern)
|
||||||
|
async with pubsub:
|
||||||
|
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
|
||||||
|
empty_redis.set(TEST_KEY, "I am set!")
|
||||||
|
message = await future_message
|
||||||
|
assert message is not None
|
||||||
|
assert message["data"] is not None
|
||||||
|
assert message == {
|
||||||
|
"channel": f"__keyspace@0__:{TEST_KEY}",
|
||||||
|
"data": "set",
|
||||||
|
"pattern": f"__keyspace@0__:{pattern}",
|
||||||
|
"type": "pmessage",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keyspace_notifications_jobs(empty_redis, event_loop):
|
||||||
|
pattern = "jobs:*"
|
||||||
|
pubsub = await RedisPool().subscribe_to_keys(pattern)
|
||||||
|
async with pubsub:
|
||||||
|
future_message = asyncio.create_task(channel_reader_onemessage(pubsub))
|
||||||
|
Jobs.add("testjob1", "test.test", "Testing aaaalll day")
|
||||||
|
message = await future_message
|
||||||
|
assert message is not None
|
||||||
|
assert message["data"] is not None
|
||||||
|
assert message["data"] == "hset"
|
||||||
|
|
||||||
|
|
||||||
|
async def reader_of_jobs() -> List[dict]:
|
||||||
|
"""
|
||||||
|
Reads 3 job updates and exits
|
||||||
|
"""
|
||||||
|
result: List[dict] = []
|
||||||
|
async for message in job_notifications():
|
||||||
|
result.append(message)
|
||||||
|
if len(result) >= 3:
|
||||||
|
break
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_jobs_generator(empty_redis, event_loop):
|
||||||
|
# Will read exactly 3 job messages
|
||||||
|
future_messages = asyncio.create_task(reader_of_jobs())
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
Jobs.add("testjob1", "test.test", "Testing aaaalll day")
|
||||||
|
Jobs.add("testjob2", "test.test", "Testing aaaalll day")
|
||||||
|
Jobs.add("testjob3", "test.test", "Testing aaaalll day")
|
||||||
|
Jobs.add("testjob4", "test.test", "Testing aaaalll day")
|
||||||
|
|
||||||
|
assert len(Jobs.get_jobs()) == 4
|
||||||
|
r = RedisPool().get_connection()
|
||||||
|
assert len(r.keys("jobs:*")) == 4
|
||||||
|
|
||||||
|
messages = await future_messages
|
||||||
|
assert len(messages) == 3
|
||||||
|
channels = [message["channel"] for message in messages]
|
||||||
|
operations = [message["data"] for message in messages]
|
||||||
|
assert set(operations) == set(["hset"]) # all of them are hsets
|
||||||
|
|
||||||
|
# Asserting that all of jobs emitted exactly one message
|
||||||
|
jobs = Jobs.get_jobs()
|
||||||
|
names = ["testjob1", "testjob2", "testjob3"]
|
||||||
|
ids = [str(job.uid) for job in jobs if job.name in names]
|
||||||
|
for id in ids:
|
||||||
|
assert id in " ".join(channels)
|
||||||
|
# Asserting that they came in order
|
||||||
|
assert "testjob4" not in " ".join(channels)
|
39
tests/test_websocket_uvicorn_standalone.py
Normal file
39
tests/test_websocket_uvicorn_standalone.py
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, WebSocket
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
# import subprocess
|
||||||
|
from multiprocessing import Process
|
||||||
|
import asyncio
|
||||||
|
from time import sleep
|
||||||
|
from websockets import client
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/")
|
||||||
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
|
await websocket.accept()
|
||||||
|
while True:
|
||||||
|
data = await websocket.receive_text()
|
||||||
|
await websocket.send_text(f"You sent: {data}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_uvicorn():
|
||||||
|
uvicorn.run(app, port=5000)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_uvcorn_ws_works_in_prod():
|
||||||
|
proc = Process(target=run_uvicorn)
|
||||||
|
proc.start()
|
||||||
|
sleep(2)
|
||||||
|
|
||||||
|
ws = await client.connect("ws://127.0.0.1:5000")
|
||||||
|
|
||||||
|
await ws.send("hohoho")
|
||||||
|
message = await ws.read_message()
|
||||||
|
assert message == "You sent: hohoho"
|
||||||
|
await ws.close()
|
||||||
|
proc.kill()
|
Loading…
Reference in a new issue